updates
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
import binascii
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
@@ -13,17 +16,46 @@ from ._types import (
|
||||
SyncByteStream,
|
||||
)
|
||||
from ._utils import (
|
||||
format_form_param,
|
||||
guess_content_type,
|
||||
peek_filelike_length,
|
||||
primitive_value_to_str,
|
||||
to_bytes,
|
||||
)
|
||||
|
||||
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
|
||||
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
|
||||
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
|
||||
)
|
||||
_HTML5_FORM_ENCODING_RE = re.compile(
|
||||
r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
|
||||
)
|
||||
|
||||
|
||||
def _format_form_param(name: str, value: str) -> bytes:
|
||||
"""
|
||||
Encode a name/value pair within a multipart form.
|
||||
"""
|
||||
|
||||
def replacer(match: typing.Match[str]) -> str:
|
||||
return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
|
||||
|
||||
value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
|
||||
return f'{name}="{value}"'.encode()
|
||||
|
||||
|
||||
def _guess_content_type(filename: str | None) -> str | None:
|
||||
"""
|
||||
Guesses the mimetype based on a filename. Defaults to `application/octet-stream`.
|
||||
|
||||
Returns `None` if `filename` is `None` or empty.
|
||||
"""
|
||||
if filename:
|
||||
return mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
return None
|
||||
|
||||
|
||||
def get_multipart_boundary_from_content_type(
|
||||
content_type: typing.Optional[bytes],
|
||||
) -> typing.Optional[bytes]:
|
||||
content_type: bytes | None,
|
||||
) -> bytes | None:
|
||||
if not content_type or not content_type.startswith(b"multipart/form-data"):
|
||||
return None
|
||||
# parse boundary according to
|
||||
@@ -40,25 +72,24 @@ class DataField:
|
||||
A single form field item, within a multipart form field.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name: str, value: typing.Union[str, bytes, int, float, None]
|
||||
) -> None:
|
||||
def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
|
||||
if not isinstance(name, str):
|
||||
raise TypeError(
|
||||
f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
|
||||
)
|
||||
if value is not None and not isinstance(value, (str, bytes, int, float)):
|
||||
raise TypeError(
|
||||
f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}"
|
||||
"Invalid type for value. Expected primitive type,"
|
||||
f" got {type(value)}: {value!r}"
|
||||
)
|
||||
self.name = name
|
||||
self.value: typing.Union[str, bytes] = (
|
||||
self.value: str | bytes = (
|
||||
value if isinstance(value, bytes) else primitive_value_to_str(value)
|
||||
)
|
||||
|
||||
def render_headers(self) -> bytes:
|
||||
if not hasattr(self, "_headers"):
|
||||
name = format_form_param("name", self.name)
|
||||
name = _format_form_param("name", self.name)
|
||||
self._headers = b"".join(
|
||||
[b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
|
||||
)
|
||||
@@ -93,18 +124,20 @@ class FileField:
|
||||
|
||||
fileobj: FileContent
|
||||
|
||||
headers: typing.Dict[str, str] = {}
|
||||
content_type: typing.Optional[str] = None
|
||||
headers: dict[str, str] = {}
|
||||
content_type: str | None = None
|
||||
|
||||
# This large tuple based API largely mirror's requests' API
|
||||
# It would be good to think of better APIs for this that we could include in httpx 2.0
|
||||
# since variable length tuples (especially of 4 elements) are quite unwieldly
|
||||
# It would be good to think of better APIs for this that we could
|
||||
# include in httpx 2.0 since variable length tuples(especially of 4 elements)
|
||||
# are quite unwieldly
|
||||
if isinstance(value, tuple):
|
||||
if len(value) == 2:
|
||||
# neither the 3rd parameter (content_type) nor the 4th (headers) was included
|
||||
filename, fileobj = value # type: ignore
|
||||
# neither the 3rd parameter (content_type) nor the 4th (headers)
|
||||
# was included
|
||||
filename, fileobj = value
|
||||
elif len(value) == 3:
|
||||
filename, fileobj, content_type = value # type: ignore
|
||||
filename, fileobj, content_type = value
|
||||
else:
|
||||
# all 4 parameters included
|
||||
filename, fileobj, content_type, headers = value # type: ignore
|
||||
@@ -113,13 +146,13 @@ class FileField:
|
||||
fileobj = value
|
||||
|
||||
if content_type is None:
|
||||
content_type = guess_content_type(filename)
|
||||
content_type = _guess_content_type(filename)
|
||||
|
||||
has_content_type_header = any("content-type" in key.lower() for key in headers)
|
||||
if content_type is not None and not has_content_type_header:
|
||||
# note that unlike requests, we ignore the content_type
|
||||
# provided in the 3rd tuple element if it is also included in the headers
|
||||
# requests does the opposite (it overwrites the header with the 3rd tuple element)
|
||||
# note that unlike requests, we ignore the content_type provided in the 3rd
|
||||
# tuple element if it is also included in the headers requests does
|
||||
# the opposite (it overwrites the headerwith the 3rd tuple element)
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
if isinstance(fileobj, io.StringIO):
|
||||
@@ -135,7 +168,7 @@ class FileField:
|
||||
self.file = fileobj
|
||||
self.headers = headers
|
||||
|
||||
def get_length(self) -> typing.Optional[int]:
|
||||
def get_length(self) -> int | None:
|
||||
headers = self.render_headers()
|
||||
|
||||
if isinstance(self.file, (str, bytes)):
|
||||
@@ -154,10 +187,10 @@ class FileField:
|
||||
if not hasattr(self, "_headers"):
|
||||
parts = [
|
||||
b"Content-Disposition: form-data; ",
|
||||
format_form_param("name", self.name),
|
||||
_format_form_param("name", self.name),
|
||||
]
|
||||
if self.filename:
|
||||
filename = format_form_param("filename", self.filename)
|
||||
filename = _format_form_param("filename", self.filename)
|
||||
parts.extend([b"; ", filename])
|
||||
for header_name, header_value in self.headers.items():
|
||||
key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
|
||||
@@ -197,10 +230,10 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
self,
|
||||
data: RequestData,
|
||||
files: RequestFiles,
|
||||
boundary: typing.Optional[bytes] = None,
|
||||
boundary: bytes | None = None,
|
||||
) -> None:
|
||||
if boundary is None:
|
||||
boundary = binascii.hexlify(os.urandom(16))
|
||||
boundary = os.urandom(16).hex().encode("ascii")
|
||||
|
||||
self.boundary = boundary
|
||||
self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
|
||||
@@ -210,7 +243,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
|
||||
def _iter_fields(
|
||||
self, data: RequestData, files: RequestFiles
|
||||
) -> typing.Iterator[typing.Union[FileField, DataField]]:
|
||||
) -> typing.Iterator[FileField | DataField]:
|
||||
for name, value in data.items():
|
||||
if isinstance(value, (tuple, list)):
|
||||
for item in value:
|
||||
@@ -229,7 +262,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
yield b"\r\n"
|
||||
yield b"--%s--\r\n" % self.boundary
|
||||
|
||||
def get_content_length(self) -> typing.Optional[int]:
|
||||
def get_content_length(self) -> int | None:
|
||||
"""
|
||||
Return the length of the multipart encoded content, or `None` if
|
||||
any of the files have a length that cannot be determined upfront.
|
||||
@@ -251,7 +284,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
|
||||
# Content stream interface.
|
||||
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
content_length = self.get_content_length()
|
||||
content_type = self.content_type
|
||||
if content_length is None:
|
||||
|
||||
Reference in New Issue
Block a user