Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

httpx/__init__.py

@@ -1,48 +1,15 @@
from .__version__ import __description__, __title__, __version__
from ._api import delete, get, head, options, patch, post, put, request, stream
from ._auth import Auth, BasicAuth, DigestAuth, NetRCAuth
from ._client import USE_CLIENT_DEFAULT, AsyncClient, Client
from ._config import Limits, Proxy, Timeout, create_ssl_context
from ._content import ByteStream
from ._exceptions import (
CloseError,
ConnectError,
ConnectTimeout,
CookieConflict,
DecodingError,
HTTPError,
HTTPStatusError,
InvalidURL,
LocalProtocolError,
NetworkError,
PoolTimeout,
ProtocolError,
ProxyError,
ReadError,
ReadTimeout,
RemoteProtocolError,
RequestError,
RequestNotRead,
ResponseNotRead,
StreamClosed,
StreamConsumed,
StreamError,
TimeoutException,
TooManyRedirects,
TransportError,
UnsupportedProtocol,
WriteError,
WriteTimeout,
)
from ._models import Cookies, Headers, Request, Response
from ._status_codes import codes
from ._transports.asgi import ASGITransport
from ._transports.base import AsyncBaseTransport, BaseTransport
from ._transports.default import AsyncHTTPTransport, HTTPTransport
from ._transports.mock import MockTransport
from ._transports.wsgi import WSGITransport
from ._types import AsyncByteStream, SyncByteStream
from ._urls import URL, QueryParams
from ._api import *
from ._auth import *
from ._client import *
from ._config import *
from ._content import *
from ._exceptions import *
from ._models import *
from ._status_codes import *
from ._transports import *
from ._types import *
from ._urls import *
try:
from ._main import main
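
The old explicit import list is replaced by star imports; each submodule now declares an __all__ (added throughout this commit) that controls what it re-exports, so the package namespace stays explicit. A minimal sketch of the mechanism, using hypothetical module names:

    # mypkg/_api.py
    __all__ = ["get"]      # only names listed here are picked up by "import *"

    def get(url):          # re-exported via mypkg
        ...

    def _helper():         # not in __all__, stays private
        ...

    # mypkg/__init__.py
    from ._api import *    # imports exactly the names in _api.__all__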

httpx/__version__.py

@@ -1,3 +1,3 @@
__title__ = "httpx"
__description__ = "A next generation HTTP client, for Python 3."
__version__ = "0.24.1"
__version__ = "0.28.1"

httpx/_api.py

@@ -1,3 +1,5 @@
from __future__ import annotations
import typing
from contextlib import contextmanager
@@ -6,37 +8,50 @@ from ._config import DEFAULT_TIMEOUT_CONFIG
from ._models import Response
from ._types import (
AuthTypes,
CertTypes,
CookieTypes,
HeaderTypes,
ProxiesTypes,
ProxyTypes,
QueryParamTypes,
RequestContent,
RequestData,
RequestFiles,
TimeoutTypes,
URLTypes,
VerifyTypes,
)
from ._urls import URL
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
__all__ = [
"delete",
"get",
"head",
"options",
"patch",
"post",
"put",
"request",
"stream",
]
def request(
method: str,
url: URLTypes,
url: URL | str,
*,
params: typing.Optional[QueryParamTypes] = None,
content: typing.Optional[RequestContent] = None,
data: typing.Optional[RequestData] = None,
files: typing.Optional[RequestFiles] = None,
json: typing.Optional[typing.Any] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
params: QueryParamTypes | None = None,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
follow_redirects: bool = False,
verify: VerifyTypes = True,
cert: typing.Optional[CertTypes] = None,
verify: ssl.SSLContext | str | bool = True,
trust_env: bool = True,
) -> Response:
"""
@@ -63,18 +78,13 @@ def request(
request.
* **auth** - *(optional)* An authentication class to use when sending the
request.
* **proxies** - *(optional)* A dictionary mapping proxy keys to proxy URLs.
* **proxy** - *(optional)* A proxy URL where all the traffic should be routed.
* **timeout** - *(optional)* The timeout configuration to use when sending
the request.
* **follow_redirects** - *(optional)* Enables or disables HTTP redirects.
* **verify** - *(optional)* SSL certificates (a.k.a CA bundle) used to
verify the identity of requested hosts. Either `True` (default CA bundle),
a path to an SSL certificate file, an `ssl.SSLContext`, or `False`
(which will disable verification).
* **cert** - *(optional)* An SSL certificate used by the requested host
to authenticate the client. Either a path to an SSL certificate file, or
two-tuple of (certificate file, key file), or a three-tuple of (certificate
file, key file, password).
* **verify** - *(optional)* Either `True` to use an SSL context with the
default CA bundle, `False` to disable verification, or an instance of
`ssl.SSLContext` to use a custom context.
* **trust_env** - *(optional)* Enables or disables usage of environment
variables for configuration.
@@ -91,8 +101,7 @@ def request(
"""
with Client(
cookies=cookies,
proxies=proxies,
cert=cert,
proxy=proxy,
verify=verify,
timeout=timeout,
trust_env=trust_env,
@@ -114,21 +123,20 @@ def request(
@contextmanager
def stream(
method: str,
url: URLTypes,
url: URL | str,
*,
params: typing.Optional[QueryParamTypes] = None,
content: typing.Optional[RequestContent] = None,
data: typing.Optional[RequestData] = None,
files: typing.Optional[RequestFiles] = None,
json: typing.Optional[typing.Any] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
params: QueryParamTypes | None = None,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
follow_redirects: bool = False,
verify: VerifyTypes = True,
cert: typing.Optional[CertTypes] = None,
verify: ssl.SSLContext | str | bool = True,
trust_env: bool = True,
) -> typing.Iterator[Response]:
"""
@@ -143,8 +151,7 @@ def stream(
"""
with Client(
cookies=cookies,
proxies=proxies,
cert=cert,
proxy=proxy,
verify=verify,
timeout=timeout,
trust_env=trust_env,
@@ -165,16 +172,15 @@ def stream(
def get(
url: URLTypes,
url: URL | str,
*,
params: typing.Optional[QueryParamTypes] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
@@ -193,9 +199,8 @@ def get(
headers=headers,
cookies=cookies,
auth=auth,
proxies=proxies,
proxy=proxy,
follow_redirects=follow_redirects,
cert=cert,
verify=verify,
timeout=timeout,
trust_env=trust_env,
@@ -203,16 +208,15 @@ def get(
def options(
url: URLTypes,
url: URL | str,
*,
params: typing.Optional[QueryParamTypes] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
@@ -231,9 +235,8 @@ def options(
headers=headers,
cookies=cookies,
auth=auth,
proxies=proxies,
proxy=proxy,
follow_redirects=follow_redirects,
cert=cert,
verify=verify,
timeout=timeout,
trust_env=trust_env,
@@ -241,16 +244,15 @@ def options(
def head(
url: URLTypes,
url: URL | str,
*,
params: typing.Optional[QueryParamTypes] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
@@ -269,9 +271,8 @@ def head(
headers=headers,
cookies=cookies,
auth=auth,
proxies=proxies,
proxy=proxy,
follow_redirects=follow_redirects,
cert=cert,
verify=verify,
timeout=timeout,
trust_env=trust_env,
@@ -279,20 +280,19 @@ def head(
def post(
url: URLTypes,
url: URL | str,
*,
content: typing.Optional[RequestContent] = None,
data: typing.Optional[RequestData] = None,
files: typing.Optional[RequestFiles] = None,
json: typing.Optional[typing.Any] = None,
params: typing.Optional[QueryParamTypes] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
@@ -312,9 +312,8 @@ def post(
headers=headers,
cookies=cookies,
auth=auth,
proxies=proxies,
proxy=proxy,
follow_redirects=follow_redirects,
cert=cert,
verify=verify,
timeout=timeout,
trust_env=trust_env,
@@ -322,20 +321,19 @@ def post(
def put(
url: URLTypes,
url: URL | str,
*,
content: typing.Optional[RequestContent] = None,
data: typing.Optional[RequestData] = None,
files: typing.Optional[RequestFiles] = None,
json: typing.Optional[typing.Any] = None,
params: typing.Optional[QueryParamTypes] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
@@ -355,9 +353,8 @@ def put(
headers=headers,
cookies=cookies,
auth=auth,
proxies=proxies,
proxy=proxy,
follow_redirects=follow_redirects,
cert=cert,
verify=verify,
timeout=timeout,
trust_env=trust_env,
@@ -365,20 +362,19 @@ def put(
def patch(
url: URLTypes,
url: URL | str,
*,
content: typing.Optional[RequestContent] = None,
data: typing.Optional[RequestData] = None,
files: typing.Optional[RequestFiles] = None,
json: typing.Optional[typing.Any] = None,
params: typing.Optional[QueryParamTypes] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
@@ -398,9 +394,8 @@ def patch(
headers=headers,
cookies=cookies,
auth=auth,
proxies=proxies,
proxy=proxy,
follow_redirects=follow_redirects,
cert=cert,
verify=verify,
timeout=timeout,
trust_env=trust_env,
@@ -408,17 +403,16 @@ def patch(
def delete(
url: URLTypes,
url: URL | str,
*,
params: typing.Optional[QueryParamTypes] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
auth: typing.Optional[AuthTypes] = None,
proxies: typing.Optional[ProxiesTypes] = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
verify: ssl.SSLContext | str | bool = True,
trust_env: bool = True,
) -> Response:
"""
@@ -436,9 +430,8 @@ def delete(
headers=headers,
cookies=cookies,
auth=auth,
proxies=proxies,
proxy=proxy,
follow_redirects=follow_redirects,
cert=cert,
verify=verify,
timeout=timeout,
trust_env=trust_env,
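
Taken together, these signature changes mean call sites migrate from the removed proxies= and cert= arguments to a single proxy URL plus an ssl.SSLContext passed as verify=. A sketch of the migration, with placeholder URLs and file paths:

    import ssl
    import httpx

    # httpx 0.24.x style (removed):
    #   httpx.get("https://example.org", proxies="http://localhost:8030",
    #             cert=("client.pem", "client.key"), verify="custom-ca.pem")

    # httpx 0.28.x style: client certificates and CA bundles live on the context.
    ctx = ssl.create_default_context(cafile="custom-ca.pem")      # placeholder CA bundle
    ctx.load_cert_chain(certfile="client.pem", keyfile="client.key")  # placeholder cert

    response = httpx.get(
        "https://example.org",
        proxy="http://localhost:8030",
        verify=ctx,
    )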

httpx/_auth.py

@@ -1,5 +1,6 @@
from __future__ import annotations
import hashlib
import netrc
import os
import re
import time
@@ -8,13 +9,16 @@ from base64 import b64encode
from urllib.request import parse_http_list
from ._exceptions import ProtocolError
from ._models import Request, Response
from ._models import Cookies, Request, Response
from ._utils import to_bytes, to_str, unquote
if typing.TYPE_CHECKING: # pragma: no cover
from hashlib import _Hash
__all__ = ["Auth", "BasicAuth", "DigestAuth", "NetRCAuth"]
class Auth:
"""
Base class for all authentication schemes.
@@ -125,18 +129,14 @@ class BasicAuth(Auth):
and uses HTTP Basic authentication.
"""
def __init__(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
):
def __init__(self, username: str | bytes, password: str | bytes) -> None:
self._auth_header = self._build_auth_header(username, password)
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
request.headers["Authorization"] = self._auth_header
yield request
def _build_auth_header(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
) -> str:
def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
userpass = b":".join((to_bytes(username), to_bytes(password)))
token = b64encode(userpass).decode()
return f"Basic {token}"
@@ -147,7 +147,11 @@ class NetRCAuth(Auth):
Use a 'netrc' file to look up basic auth credentials based on the URL host.
"""
def __init__(self, file: typing.Optional[str] = None):
def __init__(self, file: str | None = None) -> None:
# Lazily import 'netrc'.
# There's no need for us to load this module unless 'NetRCAuth' is being used.
import netrc
self._netrc_info = netrc.netrc(file)
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
@@ -162,16 +166,14 @@ class NetRCAuth(Auth):
)
yield request
def _build_auth_header(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
) -> str:
def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
userpass = b":".join((to_bytes(username), to_bytes(password)))
token = b64encode(userpass).decode()
return f"Basic {token}"
class DigestAuth(Auth):
_ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable[[bytes], "_Hash"]] = {
_ALGORITHM_TO_HASH_FUNCTION: dict[str, typing.Callable[[bytes], _Hash]] = {
"MD5": hashlib.md5,
"MD5-SESS": hashlib.md5,
"SHA": hashlib.sha1,
@@ -182,12 +184,10 @@ class DigestAuth(Auth):
"SHA-512-SESS": hashlib.sha512,
}
def __init__(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
) -> None:
def __init__(self, username: str | bytes, password: str | bytes) -> None:
self._username = to_bytes(username)
self._password = to_bytes(password)
self._last_challenge: typing.Optional[_DigestAuthChallenge] = None
self._last_challenge: _DigestAuthChallenge | None = None
self._nonce_count = 1
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
@@ -217,11 +217,13 @@ class DigestAuth(Auth):
request.headers["Authorization"] = self._build_auth_header(
request, self._last_challenge
)
if response.cookies:
Cookies(response.cookies).set_cookie_header(request=request)
yield request
def _parse_challenge(
self, request: Request, response: Response, auth_header: str
) -> "_DigestAuthChallenge":
) -> _DigestAuthChallenge:
"""
Returns a challenge from a Digest WWW-Authenticate header.
These take the form of:
@@ -232,7 +234,7 @@ class DigestAuth(Auth):
# This method should only ever have been called with a Digest auth header.
assert scheme.lower() == "digest"
header_dict: typing.Dict[str, str] = {}
header_dict: dict[str, str] = {}
for field in parse_http_list(fields):
key, value = field.strip().split("=", 1)
header_dict[key] = unquote(value)
@@ -251,7 +253,7 @@ class DigestAuth(Auth):
raise ProtocolError(message, request=request) from exc
def _build_auth_header(
self, request: Request, challenge: "_DigestAuthChallenge"
self, request: Request, challenge: _DigestAuthChallenge
) -> str:
hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()]
@@ -275,17 +277,18 @@ class DigestAuth(Auth):
qop = self._resolve_qop(challenge.qop, request=request)
if qop is None:
# Following RFC 2069
digest_data = [HA1, challenge.nonce, HA2]
else:
digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2]
key_digest = b":".join(digest_data)
# Following RFC 2617/7616
digest_data = [HA1, challenge.nonce, nc_value, cnonce, qop, HA2]
format_args = {
"username": self._username,
"realm": challenge.realm,
"nonce": challenge.nonce,
"uri": path,
"response": digest(b":".join((HA1, key_digest))),
"response": digest(b":".join(digest_data)),
"algorithm": challenge.algorithm.encode(),
}
if challenge.opaque:
@@ -305,7 +308,7 @@ class DigestAuth(Auth):
return hashlib.sha1(s).hexdigest()[:16].encode()
def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
def _get_header_value(self, header_fields: dict[str, bytes]) -> str:
NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
QUOTED_TEMPLATE = '{}="{}"'
NON_QUOTED_TEMPLATE = "{}={}"
@@ -323,9 +326,7 @@ class DigestAuth(Auth):
return header_value
def _resolve_qop(
self, qop: typing.Optional[bytes], request: Request
) -> typing.Optional[bytes]:
def _resolve_qop(self, qop: bytes | None, request: Request) -> bytes | None:
if qop is None:
return None
qops = re.split(b", ?", qop)
@@ -343,5 +344,5 @@ class _DigestAuthChallenge(typing.NamedTuple):
realm: bytes
nonce: bytes
algorithm: str
opaque: typing.Optional[bytes]
qop: typing.Optional[bytes]
opaque: bytes | None
qop: bytes | None
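
Two behavioural fixes land here: the retried request after a Digest challenge now replays any cookies set on the 401 response, and the qop branch computes the response digest per RFC 2617/7616 instead of double-hashing. Usage is unchanged; a minimal example with placeholder credentials and URL:

    import httpx

    auth = httpx.DigestAuth(username="user", password="pass")
    with httpx.Client(auth=auth) as client:
        response = client.get("https://example.org/protected")
        print(response.status_code)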

httpx/_client.py (file diff suppressed because it is too large)

httpx/_compat.py (deleted)

@@ -1,43 +0,0 @@
"""
The _compat module is used for code which requires branching between different
Python environments. It is excluded from the code coverage checks.
"""
import ssl
import sys
# Brotli support is optional
# The C bindings in `brotli` are recommended for CPython.
# The CFFI bindings in `brotlicffi` are recommended for PyPy and everything else.
try:
import brotlicffi as brotli
except ImportError: # pragma: no cover
try:
import brotli
except ImportError:
brotli = None
if sys.version_info >= (3, 10) or (
sys.version_info >= (3, 7) and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0, 7)
):
def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None:
# The OP_NO_SSL* and OP_NO_TLS* become deprecated in favor of
# 'SSLContext.minimum_version' from Python 3.7 onwards, however
# this attribute is not available unless the ssl module is compiled
# with OpenSSL 1.1.0g or newer.
# https://docs.python.org/3.10/library/ssl.html#ssl.SSLContext.minimum_version
# https://docs.python.org/3.7/library/ssl.html#ssl.SSLContext.minimum_version
context.minimum_version = ssl.TLSVersion.TLSv1_2
else:
def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None:
# If 'minimum_version' isn't available, we configure these options with
# the older deprecated variants.
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
context.options |= ssl.OP_NO_TLSv1_1
__all__ = ["brotli", "set_minimum_tls_version_1_2"]

httpx/_config.py

@@ -1,39 +1,16 @@
import logging
from __future__ import annotations
import os
import ssl
import sys
import typing
from pathlib import Path
import certifi
from ._compat import set_minimum_tls_version_1_2
from ._models import Headers
from ._types import CertTypes, HeaderTypes, TimeoutTypes, URLTypes, VerifyTypes
from ._types import CertTypes, HeaderTypes, TimeoutTypes
from ._urls import URL
from ._utils import get_ca_bundle_from_env
DEFAULT_CIPHERS = ":".join(
[
"ECDHE+AESGCM",
"ECDHE+CHACHA20",
"DHE+AESGCM",
"DHE+CHACHA20",
"ECDH+AESGCM",
"DH+AESGCM",
"ECDH+AES",
"DH+AES",
"RSA+AESGCM",
"RSA+AES",
"!aNULL",
"!eNULL",
"!MD5",
"!DSS",
]
)
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
logger = logging.getLogger("httpx")
__all__ = ["Limits", "Proxy", "Timeout", "create_ssl_context"]
class UnsetType:
@@ -44,152 +21,52 @@ UNSET = UnsetType()
def create_ssl_context(
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
verify: ssl.SSLContext | str | bool = True,
cert: CertTypes | None = None,
trust_env: bool = True,
http2: bool = False,
) -> ssl.SSLContext:
return SSLConfig(
cert=cert, verify=verify, trust_env=trust_env, http2=http2
).ssl_context
import ssl
import warnings
import certifi
class SSLConfig:
"""
SSL Configuration.
"""
DEFAULT_CA_BUNDLE_PATH = Path(certifi.where())
def __init__(
self,
*,
cert: typing.Optional[CertTypes] = None,
verify: VerifyTypes = True,
trust_env: bool = True,
http2: bool = False,
):
self.cert = cert
self.verify = verify
self.trust_env = trust_env
self.http2 = http2
self.ssl_context = self.load_ssl_context()
def load_ssl_context(self) -> ssl.SSLContext:
logger.debug(
"load_ssl_context verify=%r cert=%r trust_env=%r http2=%r",
self.verify,
self.cert,
self.trust_env,
self.http2,
)
if self.verify:
return self.load_ssl_context_verify()
return self.load_ssl_context_no_verify()
def load_ssl_context_no_verify(self) -> ssl.SSLContext:
"""
Return an SSL context for unverified connections.
"""
context = self._create_default_ssl_context()
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
self._load_client_certs(context)
return context
def load_ssl_context_verify(self) -> ssl.SSLContext:
"""
Return an SSL context for verified connections.
"""
if self.trust_env and self.verify is True:
ca_bundle = get_ca_bundle_from_env()
if ca_bundle is not None:
self.verify = ca_bundle
if isinstance(self.verify, ssl.SSLContext):
# Allow passing in our own SSLContext object that's pre-configured.
context = self.verify
self._load_client_certs(context)
return context
elif isinstance(self.verify, bool):
ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH
elif Path(self.verify).exists():
ca_bundle_path = Path(self.verify)
if verify is True:
if trust_env and os.environ.get("SSL_CERT_FILE"): # pragma: nocover
ctx = ssl.create_default_context(cafile=os.environ["SSL_CERT_FILE"])
elif trust_env and os.environ.get("SSL_CERT_DIR"): # pragma: nocover
ctx = ssl.create_default_context(capath=os.environ["SSL_CERT_DIR"])
else:
raise IOError(
"Could not find a suitable TLS CA certificate bundle, "
"invalid path: {}".format(self.verify)
)
# Default case...
ctx = ssl.create_default_context(cafile=certifi.where())
elif verify is False:
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
elif isinstance(verify, str): # pragma: nocover
message = (
"`verify=<str>` is deprecated. "
"Use `verify=ssl.create_default_context(cafile=...)` "
"or `verify=ssl.create_default_context(capath=...)` instead."
)
warnings.warn(message, DeprecationWarning)
if os.path.isdir(verify):
return ssl.create_default_context(capath=verify)
return ssl.create_default_context(cafile=verify)
else:
ctx = verify
context = self._create_default_ssl_context()
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True
if cert: # pragma: nocover
message = (
"`cert=...` is deprecated. Use `verify=<ssl_context>` instead,"
"with `.load_cert_chain()` to configure the certificate chain."
)
warnings.warn(message, DeprecationWarning)
if isinstance(cert, str):
ctx.load_cert_chain(cert)
else:
ctx.load_cert_chain(*cert)
# Signal to server support for PHA in TLS 1.3. Raises an
# AttributeError if only read-only access is implemented.
if sys.version_info >= (3, 8): # pragma: no cover
try:
context.post_handshake_auth = True
except AttributeError: # pragma: no cover
pass
# Disable using 'commonName' for SSLContext.check_hostname
# when the 'subjectAltName' extension isn't available.
try:
context.hostname_checks_common_name = False
except AttributeError: # pragma: no cover
pass
if ca_bundle_path.is_file():
cafile = str(ca_bundle_path)
logger.debug("load_verify_locations cafile=%r", cafile)
context.load_verify_locations(cafile=cafile)
elif ca_bundle_path.is_dir():
capath = str(ca_bundle_path)
logger.debug("load_verify_locations capath=%r", capath)
context.load_verify_locations(capath=capath)
self._load_client_certs(context)
return context
def _create_default_ssl_context(self) -> ssl.SSLContext:
"""
Creates the default SSLContext object that's used for both verified
and unverified connections.
"""
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
set_minimum_tls_version_1_2(context)
context.options |= ssl.OP_NO_COMPRESSION
context.set_ciphers(DEFAULT_CIPHERS)
if ssl.HAS_ALPN:
alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"]
context.set_alpn_protocols(alpn_idents)
if sys.version_info >= (3, 8): # pragma: no cover
keylogfile = os.environ.get("SSLKEYLOGFILE")
if keylogfile and self.trust_env:
context.keylog_filename = keylogfile
return context
def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
"""
Loads client certificates into our SSLContext object
"""
if self.cert is not None:
if isinstance(self.cert, str):
ssl_context.load_cert_chain(certfile=self.cert)
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
elif isinstance(self.cert, tuple) and len(self.cert) == 3:
ssl_context.load_cert_chain(
certfile=self.cert[0],
keyfile=self.cert[1],
password=self.cert[2], # type: ignore
)
return ctx
class Timeout:
@@ -208,13 +85,13 @@ class Timeout:
def __init__(
self,
timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
timeout: TimeoutTypes | UnsetType = UNSET,
*,
connect: typing.Union[None, float, UnsetType] = UNSET,
read: typing.Union[None, float, UnsetType] = UNSET,
write: typing.Union[None, float, UnsetType] = UNSET,
pool: typing.Union[None, float, UnsetType] = UNSET,
):
connect: None | float | UnsetType = UNSET,
read: None | float | UnsetType = UNSET,
write: None | float | UnsetType = UNSET,
pool: None | float | UnsetType = UNSET,
) -> None:
if isinstance(timeout, Timeout):
# Passed as a single explicit Timeout.
assert connect is UNSET
@@ -252,7 +129,7 @@ class Timeout:
self.write = timeout if isinstance(write, UnsetType) else write
self.pool = timeout if isinstance(pool, UnsetType) else pool
def as_dict(self) -> typing.Dict[str, typing.Optional[float]]:
def as_dict(self) -> dict[str, float | None]:
return {
"connect": self.connect,
"read": self.read,
@@ -296,10 +173,10 @@ class Limits:
def __init__(
self,
*,
max_connections: typing.Optional[int] = None,
max_keepalive_connections: typing.Optional[int] = None,
keepalive_expiry: typing.Optional[float] = 5.0,
):
max_connections: int | None = None,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = 5.0,
) -> None:
self.max_connections = max_connections
self.max_keepalive_connections = max_keepalive_connections
self.keepalive_expiry = keepalive_expiry
@@ -324,15 +201,16 @@ class Limits:
class Proxy:
def __init__(
self,
url: URLTypes,
url: URL | str,
*,
auth: typing.Optional[typing.Tuple[str, str]] = None,
headers: typing.Optional[HeaderTypes] = None,
):
ssl_context: ssl.SSLContext | None = None,
auth: tuple[str, str] | None = None,
headers: HeaderTypes | None = None,
) -> None:
url = URL(url)
headers = Headers(headers)
if url.scheme not in ("http", "https", "socks5"):
if url.scheme not in ("http", "https", "socks5", "socks5h"):
raise ValueError(f"Unknown scheme for proxy URL {url!r}")
if url.username or url.password:
@@ -343,9 +221,10 @@ class Proxy:
self.url = url
self.auth = auth
self.headers = headers
self.ssl_context = ssl_context
@property
def raw_auth(self) -> typing.Optional[typing.Tuple[bytes, bytes]]:
def raw_auth(self) -> tuple[bytes, bytes] | None:
# The proxy authentication as raw bytes.
return (
None
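
Proxy also gains an ssl_context argument for the client-to-proxy connection and accepts the socks5h scheme (hostname resolution on the proxy side; SOCKS still requires the optional socksio dependency). A sketch with placeholder addresses and paths:

    import ssl
    import httpx

    # HTTPS proxy with a custom TLS context for the client<->proxy hop.
    proxy_ctx = ssl.create_default_context(cafile="proxy-ca.pem")  # placeholder CA bundle
    proxy = httpx.Proxy("https://localhost:8030", ssl_context=proxy_ctx)

    # SOCKS5 proxy, resolving hostnames remotely.
    socks_proxy = httpx.Proxy("socks5h://localhost:1080")

    # Timeout semantics are unchanged: one default plus per-phase overrides.
    timeout = httpx.Timeout(5.0, connect=10.0)

    client = httpx.Client(proxy=proxy, timeout=timeout)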

httpx/_content.py

@@ -1,3 +1,5 @@
from __future__ import annotations
import inspect
import warnings
from json import dumps as json_dumps
@@ -5,13 +7,9 @@ from typing import (
Any,
AsyncIterable,
AsyncIterator,
Dict,
Iterable,
Iterator,
Mapping,
Optional,
Tuple,
Union,
)
from urllib.parse import urlencode
@@ -27,6 +25,8 @@ from ._types import (
)
from ._utils import peek_filelike_length, primitive_value_to_str
__all__ = ["ByteStream"]
class ByteStream(AsyncByteStream, SyncByteStream):
def __init__(self, stream: bytes) -> None:
@@ -42,7 +42,7 @@ class ByteStream(AsyncByteStream, SyncByteStream):
class IteratorByteStream(SyncByteStream):
CHUNK_SIZE = 65_536
def __init__(self, stream: Iterable[bytes]):
def __init__(self, stream: Iterable[bytes]) -> None:
self._stream = stream
self._is_stream_consumed = False
self._is_generator = inspect.isgenerator(stream)
@@ -67,7 +67,7 @@ class IteratorByteStream(SyncByteStream):
class AsyncIteratorByteStream(AsyncByteStream):
CHUNK_SIZE = 65_536
def __init__(self, stream: AsyncIterable[bytes]):
def __init__(self, stream: AsyncIterable[bytes]) -> None:
self._stream = stream
self._is_stream_consumed = False
self._is_generator = inspect.isasyncgen(stream)
@@ -105,8 +105,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream):
def encode_content(
content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
content: str | bytes | Iterable[bytes] | AsyncIterable[bytes],
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
if isinstance(content, (bytes, str)):
body = content.encode("utf-8") if isinstance(content, str) else content
content_length = len(body)
@@ -135,7 +135,7 @@ def encode_content(
def encode_urlencoded_data(
data: RequestData,
) -> Tuple[Dict[str, str], ByteStream]:
) -> tuple[dict[str, str], ByteStream]:
plain_data = []
for key, value in data.items():
if isinstance(value, (list, tuple)):
@@ -150,14 +150,14 @@ def encode_urlencoded_data(
def encode_multipart_data(
data: RequestData, files: RequestFiles, boundary: Optional[bytes]
) -> Tuple[Dict[str, str], MultipartStream]:
data: RequestData, files: RequestFiles, boundary: bytes | None
) -> tuple[dict[str, str], MultipartStream]:
multipart = MultipartStream(data=data, files=files, boundary=boundary)
headers = multipart.get_headers()
return headers, multipart
def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
def encode_text(text: str) -> tuple[dict[str, str], ByteStream]:
body = text.encode("utf-8")
content_length = str(len(body))
content_type = "text/plain; charset=utf-8"
@@ -165,7 +165,7 @@ def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
return headers, ByteStream(body)
def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
def encode_html(html: str) -> tuple[dict[str, str], ByteStream]:
body = html.encode("utf-8")
content_length = str(len(body))
content_type = "text/html; charset=utf-8"
@@ -173,8 +173,10 @@ def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
return headers, ByteStream(body)
def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
body = json_dumps(json).encode("utf-8")
def encode_json(json: Any) -> tuple[dict[str, str], ByteStream]:
body = json_dumps(
json, ensure_ascii=False, separators=(",", ":"), allow_nan=False
).encode("utf-8")
content_length = str(len(body))
content_type = "application/json"
headers = {"Content-Length": content_length, "Content-Type": content_type}
@@ -182,12 +184,12 @@ def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
def encode_request(
content: Optional[RequestContent] = None,
data: Optional[RequestData] = None,
files: Optional[RequestFiles] = None,
json: Optional[Any] = None,
boundary: Optional[bytes] = None,
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: Any | None = None,
boundary: bytes | None = None,
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
"""
Handles encoding the given `content`, `data`, `files`, and `json`,
returning a two-tuple of (<headers>, <stream>).
@@ -201,7 +203,7 @@ def encode_request(
# `data=<bytes...>` usages. We deal with that case here, treating it
# as if `content=<...>` had been supplied instead.
message = "Use 'content=<...>' to upload raw bytes/text content."
warnings.warn(message, DeprecationWarning)
warnings.warn(message, DeprecationWarning, stacklevel=2)
return encode_content(data)
if content is not None:
@@ -217,11 +219,11 @@ def encode_request(
def encode_response(
content: Optional[ResponseContent] = None,
text: Optional[str] = None,
html: Optional[str] = None,
json: Optional[Any] = None,
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
content: ResponseContent | None = None,
text: str | None = None,
html: str | None = None,
json: Any | None = None,
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
"""
Handles encoding the given `content`, returning a two-tuple of
(<headers>, <stream>).
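
The encode_json change alters the bytes on the wire: compact separators, non-ASCII characters passed through rather than \u-escaped, and NaN/Infinity rejected instead of silently emitting invalid JSON. Roughly:

    import json

    payload = {"name": "café", "value": 1}

    json.dumps(payload)
    # '{"name": "caf\u00e9", "value": 1}'   <- old output (escaped, padded)

    json.dumps(payload, ensure_ascii=False, separators=(",", ":"), allow_nan=False)
    # '{"name":"café","value":1}'           <- new output

    json.dumps(float("nan"), allow_nan=False)  # now raises ValueError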

httpx/_decoders.py

@@ -3,14 +3,35 @@ Handlers for Content-Encoding.
See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
"""
from __future__ import annotations
import codecs
import io
import typing
import zlib
from ._compat import brotli
from ._exceptions import DecodingError
# Brotli support is optional
try:
# The C bindings in `brotli` are recommended for CPython.
import brotli
except ImportError: # pragma: no cover
try:
# The CFFI bindings in `brotlicffi` are recommended for PyPy
# and other environments.
import brotlicffi as brotli
except ImportError:
brotli = None
# Zstandard support is optional
try:
import zstandard
except ImportError: # pragma: no cover
zstandard = None # type: ignore
class ContentDecoder:
def decode(self, data: bytes) -> bytes:
@@ -137,6 +158,48 @@ class BrotliDecoder(ContentDecoder):
raise DecodingError(str(exc)) from exc
class ZStandardDecoder(ContentDecoder):
"""
Handle 'zstd' RFC 8878 decoding.
Requires `pip install zstandard`.
Can be installed as a dependency of httpx using `pip install httpx[zstd]`.
"""
# inspired by the ZstdDecoder implementation in urllib3
def __init__(self) -> None:
if zstandard is None: # pragma: no cover
raise ImportError(
"Using 'ZStandardDecoder', ..."
"Make sure to install httpx using `pip install httpx[zstd]`."
) from None
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
self.seen_data = False
def decode(self, data: bytes) -> bytes:
assert zstandard is not None
self.seen_data = True
output = io.BytesIO()
try:
output.write(self.decompressor.decompress(data))
while self.decompressor.eof and self.decompressor.unused_data:
unused_data = self.decompressor.unused_data
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
output.write(self.decompressor.decompress(unused_data))
except zstandard.ZstdError as exc:
raise DecodingError(str(exc)) from exc
return output.getvalue()
def flush(self) -> bytes:
if not self.seen_data:
return b""
ret = self.decompressor.flush() # note: this is a no-op
if not self.decompressor.eof:
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
return bytes(ret)
class MultiDecoder(ContentDecoder):
"""
Handle the case where multiple encodings have been applied.
@@ -167,11 +230,11 @@ class ByteChunker:
Handles returning byte content in fixed-size chunks.
"""
def __init__(self, chunk_size: typing.Optional[int] = None) -> None:
def __init__(self, chunk_size: int | None = None) -> None:
self._buffer = io.BytesIO()
self._chunk_size = chunk_size
def decode(self, content: bytes) -> typing.List[bytes]:
def decode(self, content: bytes) -> list[bytes]:
if self._chunk_size is None:
return [content] if content else []
@@ -194,7 +257,7 @@ class ByteChunker:
else:
return []
def flush(self) -> typing.List[bytes]:
def flush(self) -> list[bytes]:
value = self._buffer.getvalue()
self._buffer.seek(0)
self._buffer.truncate()
@@ -206,13 +269,13 @@ class TextChunker:
Handles returning text content in fixed-size chunks.
"""
def __init__(self, chunk_size: typing.Optional[int] = None) -> None:
def __init__(self, chunk_size: int | None = None) -> None:
self._buffer = io.StringIO()
self._chunk_size = chunk_size
def decode(self, content: str) -> typing.List[str]:
def decode(self, content: str) -> list[str]:
if self._chunk_size is None:
return [content]
return [content] if content else []
self._buffer.write(content)
if self._buffer.tell() >= self._chunk_size:
@@ -233,7 +296,7 @@ class TextChunker:
else:
return []
def flush(self) -> typing.List[str]:
def flush(self) -> list[str]:
value = self._buffer.getvalue()
self._buffer.seek(0)
self._buffer.truncate()
@@ -245,7 +308,7 @@ class TextDecoder:
Handles incrementally decoding bytes into text
"""
def __init__(self, encoding: str = "utf-8"):
def __init__(self, encoding: str = "utf-8") -> None:
self.decoder = codecs.getincrementaldecoder(encoding)(errors="replace")
def decode(self, data: bytes) -> str:
@@ -259,14 +322,15 @@ class LineDecoder:
"""
Handles incrementally reading lines from text.
Has the same behaviour as the stdlib splitlines, but handling the input iteratively.
Has the same behaviour as the stdlib splitlines,
but handling the input iteratively.
"""
def __init__(self) -> None:
self.buffer: typing.List[str] = []
self.buffer: list[str] = []
self.trailing_cr: bool = False
def decode(self, text: str) -> typing.List[str]:
def decode(self, text: str) -> list[str]:
# See https://docs.python.org/3/library/stdtypes.html#str.splitlines
NEWLINE_CHARS = "\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029"
@@ -279,7 +343,9 @@ class LineDecoder:
text = text[:-1]
if not text:
return []
# NOTE: the edge case input of empty text doesn't occur in practice,
# because other httpx internals filter out this value
return [] # pragma: no cover
trailing_newline = text[-1] in NEWLINE_CHARS
lines = text.splitlines()
@@ -302,7 +368,7 @@ class LineDecoder:
return lines
def flush(self) -> typing.List[str]:
def flush(self) -> list[str]:
if not self.buffer and not self.trailing_cr:
return []
@@ -317,8 +383,11 @@ SUPPORTED_DECODERS = {
"gzip": GZipDecoder,
"deflate": DeflateDecoder,
"br": BrotliDecoder,
"zstd": ZStandardDecoder,
}
if brotli is None:
SUPPORTED_DECODERS.pop("br") # pragma: no cover
if zstandard is None:
SUPPORTED_DECODERS.pop("zstd") # pragma: no cover

httpx/_exceptions.py

@@ -30,12 +30,46 @@ Our exception hierarchy:
x ResponseNotRead
x RequestNotRead
"""
from __future__ import annotations
import contextlib
import typing
if typing.TYPE_CHECKING:
from ._models import Request, Response # pragma: no cover
__all__ = [
"CloseError",
"ConnectError",
"ConnectTimeout",
"CookieConflict",
"DecodingError",
"HTTPError",
"HTTPStatusError",
"InvalidURL",
"LocalProtocolError",
"NetworkError",
"PoolTimeout",
"ProtocolError",
"ProxyError",
"ReadError",
"ReadTimeout",
"RemoteProtocolError",
"RequestError",
"RequestNotRead",
"ResponseNotRead",
"StreamClosed",
"StreamConsumed",
"StreamError",
"TimeoutException",
"TooManyRedirects",
"TransportError",
"UnsupportedProtocol",
"WriteError",
"WriteTimeout",
]
class HTTPError(Exception):
"""
@@ -57,16 +91,16 @@ class HTTPError(Exception):
def __init__(self, message: str) -> None:
super().__init__(message)
self._request: typing.Optional["Request"] = None
self._request: Request | None = None
@property
def request(self) -> "Request":
def request(self) -> Request:
if self._request is None:
raise RuntimeError("The .request property has not been set.")
return self._request
@request.setter
def request(self, request: "Request") -> None:
def request(self, request: Request) -> None:
self._request = request
@@ -75,9 +109,7 @@ class RequestError(HTTPError):
Base class for all exceptions that may occur when issuing a `.request()`.
"""
def __init__(
self, message: str, *, request: typing.Optional["Request"] = None
) -> None:
def __init__(self, message: str, *, request: Request | None = None) -> None:
super().__init__(message)
# At the point an exception is raised we won't typically have a request
# instance to associate it with.
@@ -230,9 +262,7 @@ class HTTPStatusError(HTTPError):
May be raised when calling `response.raise_for_status()`
"""
def __init__(
self, message: str, *, request: "Request", response: "Response"
) -> None:
def __init__(self, message: str, *, request: Request, response: Response) -> None:
super().__init__(message)
self.request = request
self.response = response
@@ -313,7 +343,10 @@ class ResponseNotRead(StreamError):
"""
def __init__(self) -> None:
message = "Attempted to access streaming response content, without having called `read()`."
message = (
"Attempted to access streaming response content,"
" without having called `read()`."
)
super().__init__(message)
@@ -323,13 +356,16 @@ class RequestNotRead(StreamError):
"""
def __init__(self) -> None:
message = "Attempted to access streaming request content, without having called `read()`."
message = (
"Attempted to access streaming request content,"
" without having called `read()`."
)
super().__init__(message)
@contextlib.contextmanager
def request_context(
request: typing.Optional["Request"] = None,
request: Request | None = None,
) -> typing.Iterator[None]:
"""
A context manager that can be used to attach the given request context
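
With the hierarchy above, callers catch at whatever granularity they need: RequestError covers everything that fails before a response exists, while HTTPStatusError is what raise_for_status() raises for 4xx/5xx responses. For instance:

    import httpx

    try:
        response = httpx.get("https://example.org")
        response.raise_for_status()
    except httpx.TimeoutException as exc:   # ConnectTimeout, ReadTimeout, ...
        print(f"Timed out requesting {exc.request.url}")
    except httpx.RequestError as exc:       # any other transport-level failure
        print(f"Request failed: {exc!r}")
    except httpx.HTTPStatusError as exc:    # response arrived, but was 4xx/5xx
        print(f"Got {exc.response.status_code} from {exc.request.url}")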

httpx/_main.py

@@ -1,10 +1,11 @@
from __future__ import annotations
import functools
import json
import sys
import typing
import click
import httpcore
import pygments.lexers
import pygments.util
import rich.console
@@ -18,6 +19,9 @@ from ._exceptions import RequestError
from ._models import Response
from ._status_codes import codes
if typing.TYPE_CHECKING:
import httpcore # pragma: no cover
def print_help() -> None:
console = rich.console.Console()
@@ -63,20 +67,21 @@ def print_help() -> None:
)
table.add_row(
"--auth [cyan]<USER PASS>",
"Username and password to include in the request. Specify '-' for the password to use "
"a password prompt. Note that using --verbose/-v will expose the Authorization "
"header, including the password encoding in a trivially reversible format.",
"Username and password to include in the request. Specify '-' for the password"
" to use a password prompt. Note that using --verbose/-v will expose"
" the Authorization header, including the password encoding"
" in a trivially reversible format.",
)
table.add_row(
"--proxies [cyan]URL",
"--proxy [cyan]URL",
"Send the request via a proxy. Should be the URL giving the proxy address.",
)
table.add_row(
"--timeout [cyan]FLOAT",
"Timeout value to use for network operations, such as establishing the connection, "
"reading some data, etc... [Default: 5.0]",
"Timeout value to use for network operations, such as establishing the"
" connection, reading some data, etc... [Default: 5.0]",
)
table.add_row("--follow-redirects", "Automatically follow redirects.")
@@ -124,8 +129,8 @@ def format_request_headers(request: httpcore.Request, http2: bool = False) -> st
def format_response_headers(
http_version: bytes,
status: int,
reason_phrase: typing.Optional[bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
reason_phrase: bytes | None,
headers: list[tuple[bytes, bytes]],
) -> str:
version = http_version.decode("ascii")
reason = (
@@ -151,8 +156,8 @@ def print_request_headers(request: httpcore.Request, http2: bool = False) -> Non
def print_response_headers(
http_version: bytes,
status: int,
reason_phrase: typing.Optional[bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
reason_phrase: bytes | None,
headers: list[tuple[bytes, bytes]],
) -> None:
console = rich.console.Console()
http_text = format_response_headers(http_version, status, reason_phrase, headers)
@@ -267,7 +272,7 @@ def download_response(response: Response, download: typing.BinaryIO) -> None:
def validate_json(
ctx: click.Context,
param: typing.Union[click.Option, click.Parameter],
param: click.Option | click.Parameter,
value: typing.Any,
) -> typing.Any:
if value is None:
@@ -281,7 +286,7 @@ def validate_json(
def validate_auth(
ctx: click.Context,
param: typing.Union[click.Option, click.Parameter],
param: click.Option | click.Parameter,
value: typing.Any,
) -> typing.Any:
if value == (None, None):
@@ -295,7 +300,7 @@ def validate_auth(
def handle_help(
ctx: click.Context,
param: typing.Union[click.Option, click.Parameter],
param: click.Option | click.Parameter,
value: typing.Any,
) -> None:
if not value or ctx.resilient_parsing:
@@ -385,8 +390,8 @@ def handle_help(
),
)
@click.option(
"--proxies",
"proxies",
"--proxy",
"proxy",
type=str,
default=None,
help="Send the request via a proxy. Should be the URL giving the proxy address.",
@@ -447,20 +452,20 @@ def handle_help(
def main(
url: str,
method: str,
params: typing.List[typing.Tuple[str, str]],
params: list[tuple[str, str]],
content: str,
data: typing.List[typing.Tuple[str, str]],
files: typing.List[typing.Tuple[str, click.File]],
data: list[tuple[str, str]],
files: list[tuple[str, click.File]],
json: str,
headers: typing.List[typing.Tuple[str, str]],
cookies: typing.List[typing.Tuple[str, str]],
auth: typing.Optional[typing.Tuple[str, str]],
proxies: str,
headers: list[tuple[str, str]],
cookies: list[tuple[str, str]],
auth: tuple[str, str] | None,
proxy: str,
timeout: float,
follow_redirects: bool,
verify: bool,
http2: bool,
download: typing.Optional[typing.BinaryIO],
download: typing.BinaryIO | None,
verbose: bool,
) -> None:
"""
@@ -471,12 +476,7 @@ def main(
method = "POST" if content or data or files or json else "GET"
try:
with Client(
proxies=proxies,
timeout=timeout,
verify=verify,
http2=http2,
) as client:
with Client(proxy=proxy, timeout=timeout, http2=http2, verify=verify) as client:
with client.stream(
method,
url,
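
On the command line the visible change is the flag rename: where 0.24.x used --proxies URL, 0.28.x takes --proxy URL, e.g. httpx --proxy http://localhost:8030 https://example.org.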

httpx/_models.py

@@ -1,6 +1,10 @@
from __future__ import annotations
import codecs
import datetime
import email.message
import json as jsonlib
import re
import typing
import urllib.request
from collections.abc import Mapping
@@ -42,15 +46,94 @@ from ._types import (
SyncByteStream,
)
from ._urls import URL
from ._utils import (
guess_json_utf,
is_known_encoding,
normalize_header_key,
normalize_header_value,
obfuscate_sensitive_headers,
parse_content_type_charset,
parse_header_links,
)
from ._utils import to_bytes_or_str, to_str
__all__ = ["Cookies", "Headers", "Request", "Response"]
SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}
def _is_known_encoding(encoding: str) -> bool:
"""
Return `True` if `encoding` is a known codec.
"""
try:
codecs.lookup(encoding)
except LookupError:
return False
return True
def _normalize_header_key(key: str | bytes, encoding: str | None = None) -> bytes:
"""
Coerce str/bytes into a strictly byte-wise HTTP header key.
"""
return key if isinstance(key, bytes) else key.encode(encoding or "ascii")
def _normalize_header_value(value: str | bytes, encoding: str | None = None) -> bytes:
"""
Coerce str/bytes into a strictly byte-wise HTTP header value.
"""
if isinstance(value, bytes):
return value
if not isinstance(value, str):
raise TypeError(f"Header value must be str or bytes, not {type(value)}")
return value.encode(encoding or "ascii")
def _parse_content_type_charset(content_type: str) -> str | None:
# We used to use `cgi.parse_header()` here, but `cgi` became a dead battery.
# See: https://peps.python.org/pep-0594/#cgi
msg = email.message.Message()
msg["content-type"] = content_type
return msg.get_content_charset(failobj=None)
def _parse_header_links(value: str) -> list[dict[str, str]]:
"""
Returns a list of parsed link headers, for more info see:
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
The generic syntax of those is:
Link: < uri-reference >; param1=value1; param2="value2"
So for instance:
Link; '<http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;'
would return
[
{"url": "http:/.../front.jpeg", "type": "image/jpeg"},
{"url": "http://.../back.jpeg"},
]
:param value: HTTP Link entity-header field
:return: list of parsed link headers
"""
links: list[dict[str, str]] = []
replace_chars = " '\""
value = value.strip(replace_chars)
if not value:
return links
for val in re.split(", *<", value):
try:
url, params = val.split(";", 1)
except ValueError:
url, params = val, ""
link = {"url": url.strip("<> '\"")}
for param in params.split(";"):
try:
key, value = param.split("=")
except ValueError:
break
link[key.strip(replace_chars)] = value.strip(replace_chars)
links.append(link)
return links
def _obfuscate_sensitive_headers(
items: typing.Iterable[tuple[typing.AnyStr, typing.AnyStr]],
) -> typing.Iterator[tuple[typing.AnyStr, typing.AnyStr]]:
for k, v in items:
if to_str(k.lower()) in SENSITIVE_HEADERS:
v = to_bytes_or_str("[secure]", match_type_of=v)
yield k, v
class Headers(typing.MutableMapping[str, str]):
@@ -60,31 +143,23 @@ class Headers(typing.MutableMapping[str, str]):
def __init__(
self,
headers: typing.Optional[HeaderTypes] = None,
encoding: typing.Optional[str] = None,
headers: HeaderTypes | None = None,
encoding: str | None = None,
) -> None:
if headers is None:
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
elif isinstance(headers, Headers):
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
if isinstance(headers, Headers):
self._list = list(headers._list)
elif isinstance(headers, Mapping):
self._list = [
(
normalize_header_key(k, lower=False, encoding=encoding),
normalize_header_key(k, lower=True, encoding=encoding),
normalize_header_value(v, encoding),
)
for k, v in headers.items()
]
else:
self._list = [
(
normalize_header_key(k, lower=False, encoding=encoding),
normalize_header_key(k, lower=True, encoding=encoding),
normalize_header_value(v, encoding),
)
for k, v in headers
]
for k, v in headers.items():
bytes_key = _normalize_header_key(k, encoding)
bytes_value = _normalize_header_value(v, encoding)
self._list.append((bytes_key, bytes_key.lower(), bytes_value))
elif headers is not None:
for k, v in headers:
bytes_key = _normalize_header_key(k, encoding)
bytes_value = _normalize_header_value(v, encoding)
self._list.append((bytes_key, bytes_key.lower(), bytes_value))
self._encoding = encoding
@@ -118,7 +193,7 @@ class Headers(typing.MutableMapping[str, str]):
self._encoding = value
@property
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
def raw(self) -> list[tuple[bytes, bytes]]:
"""
Returns a list of the raw header items, as byte pairs.
"""
@@ -128,7 +203,7 @@ class Headers(typing.MutableMapping[str, str]):
return {key.decode(self.encoding): None for _, key, value in self._list}.keys()
def values(self) -> typing.ValuesView[str]:
values_dict: typing.Dict[str, str] = {}
values_dict: dict[str, str] = {}
for _, key, value in self._list:
str_key = key.decode(self.encoding)
str_value = value.decode(self.encoding)
@@ -143,7 +218,7 @@ class Headers(typing.MutableMapping[str, str]):
Return `(key, value)` items of headers. Concatenate headers
into a single comma separated value when a key occurs multiple times.
"""
values_dict: typing.Dict[str, str] = {}
values_dict: dict[str, str] = {}
for _, key, value in self._list:
str_key = key.decode(self.encoding)
str_value = value.decode(self.encoding)
@@ -153,7 +228,7 @@ class Headers(typing.MutableMapping[str, str]):
values_dict[str_key] = str_value
return values_dict.items()
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
def multi_items(self) -> list[tuple[str, str]]:
"""
Return a list of `(key, value)` pairs of headers. Allow multiple
occurrences of the same key without concatenating into a single
@@ -174,7 +249,7 @@ class Headers(typing.MutableMapping[str, str]):
except KeyError:
return default
def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]:
def get_list(self, key: str, split_commas: bool = False) -> list[str]:
"""
Return a list of all header values for a given key.
If `split_commas=True` is passed, then any comma separated header
@@ -196,14 +271,14 @@ class Headers(typing.MutableMapping[str, str]):
split_values.extend([item.strip() for item in value.split(",")])
return split_values
def update(self, headers: typing.Optional[HeaderTypes] = None) -> None: # type: ignore
def update(self, headers: HeaderTypes | None = None) -> None: # type: ignore
headers = Headers(headers)
for key in headers.keys():
if key in self:
self.pop(key)
self._list.extend(headers._list)
def copy(self) -> "Headers":
def copy(self) -> Headers:
return Headers(self, encoding=self.encoding)
def __getitem__(self, key: str) -> str:
@@ -295,7 +370,7 @@ class Headers(typing.MutableMapping[str, str]):
if self.encoding != "ascii":
encoding_str = f", encoding={self.encoding!r}"
as_list = list(obfuscate_sensitive_headers(self.multi_items()))
as_list = list(_obfuscate_sensitive_headers(self.multi_items()))
as_dict = dict(as_list)
no_duplicate_keys = len(as_dict) == len(as_list)
@@ -307,35 +382,29 @@ class Headers(typing.MutableMapping[str, str]):
class Request:
def __init__(
self,
method: typing.Union[str, bytes],
url: typing.Union["URL", str],
method: str,
url: URL | str,
*,
params: typing.Optional[QueryParamTypes] = None,
headers: typing.Optional[HeaderTypes] = None,
cookies: typing.Optional[CookieTypes] = None,
content: typing.Optional[RequestContent] = None,
data: typing.Optional[RequestData] = None,
files: typing.Optional[RequestFiles] = None,
json: typing.Optional[typing.Any] = None,
stream: typing.Union[SyncByteStream, AsyncByteStream, None] = None,
extensions: typing.Optional[RequestExtensions] = None,
):
self.method = (
method.decode("ascii").upper()
if isinstance(method, bytes)
else method.upper()
)
self.url = URL(url)
if params is not None:
self.url = self.url.copy_merge_params(params=params)
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
stream: SyncByteStream | AsyncByteStream | None = None,
extensions: RequestExtensions | None = None,
) -> None:
self.method = method.upper()
self.url = URL(url) if params is None else URL(url, params=params)
self.headers = Headers(headers)
self.extensions = {} if extensions is None else extensions
self.extensions = {} if extensions is None else dict(extensions)
if cookies:
Cookies(cookies).set_cookie_header(self)
if stream is None:
content_type: typing.Optional[str] = self.headers.get("content-type")
content_type: str | None = self.headers.get("content-type")
headers, stream = encode_request(
content=content,
data=data,
@@ -359,7 +428,8 @@ class Request:
# Using `content=...` implies automatically populated `Host` and content
# headers, of either `Content-Length: ...` or `Transfer-Encoding: chunked`.
#
# Using `stream=...` will not automatically include *any* auto-populated headers.
# Using `stream=...` will not automatically include *any*
# auto-populated headers.
#
# As an end-user you don't really need `stream=...`. It's only
# useful when:
@@ -368,14 +438,14 @@ class Request:
# * Creating request instances on the *server-side* of the transport API.
self.stream = stream
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
def _prepare(self, default_headers: dict[str, str]) -> None:
for key, value in default_headers.items():
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
if key.lower() == "transfer-encoding" and "Content-Length" in self.headers:
continue
self.headers.setdefault(key, value)
auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []
auto_headers: list[tuple[bytes, bytes]] = []
has_host = "Host" in self.headers
has_content_length = (
@@ -428,14 +498,14 @@ class Request:
url = str(self.url)
return f"<{class_name}({self.method!r}, {url!r})>"
def __getstate__(self) -> typing.Dict[str, typing.Any]:
def __getstate__(self) -> dict[str, typing.Any]:
return {
name: value
for name, value in self.__dict__.items()
if name not in ["extensions", "stream"]
}
def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None:
def __setstate__(self, state: dict[str, typing.Any]) -> None:
for name, value in state.items():
setattr(self, name, value)
self.extensions = {}
@@ -447,27 +517,27 @@ class Response:
self,
status_code: int,
*,
headers: typing.Optional[HeaderTypes] = None,
content: typing.Optional[ResponseContent] = None,
text: typing.Optional[str] = None,
html: typing.Optional[str] = None,
headers: HeaderTypes | None = None,
content: ResponseContent | None = None,
text: str | None = None,
html: str | None = None,
json: typing.Any = None,
stream: typing.Union[SyncByteStream, AsyncByteStream, None] = None,
request: typing.Optional[Request] = None,
extensions: typing.Optional[ResponseExtensions] = None,
history: typing.Optional[typing.List["Response"]] = None,
default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8",
):
stream: SyncByteStream | AsyncByteStream | None = None,
request: Request | None = None,
extensions: ResponseExtensions | None = None,
history: list[Response] | None = None,
default_encoding: str | typing.Callable[[bytes], str] = "utf-8",
) -> None:
self.status_code = status_code
self.headers = Headers(headers)
self._request: typing.Optional[Request] = request
self._request: Request | None = request
# When follow_redirects=False and a redirect is received,
# the client will set `response.next_request`.
self.next_request: typing.Optional[Request] = None
self.next_request: Request | None = None
self.extensions = {} if extensions is None else extensions
self.extensions = {} if extensions is None else dict(extensions)
self.history = [] if history is None else list(history)
self.is_closed = False
@@ -498,7 +568,7 @@ class Response:
self._num_bytes_downloaded = 0
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
def _prepare(self, default_headers: dict[str, str]) -> None:
for key, value in default_headers.items():
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
if key.lower() == "transfer-encoding" and "content-length" in self.headers:
@@ -580,7 +650,7 @@ class Response:
return self._text
@property
def encoding(self) -> typing.Optional[str]:
def encoding(self) -> str | None:
"""
Return an encoding to use for decoding the byte content into text.
The priority for determining this is given by...
@@ -593,7 +663,7 @@ class Response:
"""
if not hasattr(self, "_encoding"):
encoding = self.charset_encoding
if encoding is None or not is_known_encoding(encoding):
if encoding is None or not _is_known_encoding(encoding):
if isinstance(self.default_encoding, str):
encoding = self.default_encoding
elif hasattr(self, "_content"):
@@ -603,10 +673,20 @@ class Response:
@encoding.setter
def encoding(self, value: str) -> None:
"""
Set the encoding to use for decoding the byte content into text.
If the `text` attribute has been accessed, attempting to set the
encoding will throw a ValueError.
"""
if hasattr(self, "_text"):
raise ValueError(
"Setting encoding after `text` has been accessed is not allowed."
)
self._encoding = value
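The setter above now fails loudly instead of silently re-decoding text with the wrong charset. A sketch:

```python
import httpx

response = httpx.Response(200, content=b"caf\xe9")
response.encoding = "latin-1"    # fine: text has not been decoded yet
assert response.text == "caf\u00e9"

try:
    response.encoding = "utf-8"  # too late: `text` was already accessed
except ValueError as exc:
    print(exc)
```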
@property
def charset_encoding(self) -> typing.Optional[str]:
def charset_encoding(self) -> str | None:
"""
Return the encoding, as specified by the Content-Type header.
"""
@@ -614,7 +694,7 @@ class Response:
if content_type is None:
return None
return parse_content_type_charset(content_type)
return _parse_content_type_charset(content_type)
def _get_content_decoder(self) -> ContentDecoder:
"""
@@ -622,7 +702,7 @@ class Response:
content, depending on the Content-Encoding used in the response.
"""
if not hasattr(self, "_decoder"):
decoders: typing.List[ContentDecoder] = []
decoders: list[ContentDecoder] = []
values = self.headers.get_list("content-encoding", split_commas=True)
for value in values:
value = value.strip().lower()
@@ -711,7 +791,7 @@ class Response:
and "Location" in self.headers
)
def raise_for_status(self) -> None:
def raise_for_status(self) -> Response:
"""
Raise the `HTTPStatusError` if one occurred.
"""
@@ -723,18 +803,18 @@ class Response:
)
if self.is_success:
return
return self
if self.has_redirect_location:
message = (
"{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n"
"Redirect location: '{0.headers[location]}'\n"
"For more information check: https://httpstatuses.com/{0.status_code}"
"For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}"
)
else:
message = (
"{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n"
"For more information check: https://httpstatuses.com/{0.status_code}"
"For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}"
)
status_class = self.status_code // 100
@@ -749,32 +829,28 @@ class Response:
raise HTTPStatusError(message, request=request, response=self)
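Because `raise_for_status()` now returns the response on success, the call can be chained inline. A sketch, assuming a reachable endpoint:

```python
import httpx

# Returns the Response when the status is a success, raises otherwise.
response = httpx.get("https://www.example.com/").raise_for_status()
print(response.status_code)
```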
def json(self, **kwargs: typing.Any) -> typing.Any:
if self.charset_encoding is None and self.content and len(self.content) > 3:
encoding = guess_json_utf(self.content)
if encoding is not None:
return jsonlib.loads(self.content.decode(encoding), **kwargs)
return jsonlib.loads(self.text, **kwargs)
return jsonlib.loads(self.content, **kwargs)
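The `guess_json_utf` step above could be dropped because `json.loads` has accepted bytes, with automatic UTF-8/16/32 detection, since Python 3.6. A quick check:

```python
import json

payload = '{"ok": true}'.encode("utf-16")
assert json.loads(payload) == {"ok": True}  # stdlib detects the encoding itself
```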
@property
def cookies(self) -> "Cookies":
def cookies(self) -> Cookies:
if not hasattr(self, "_cookies"):
self._cookies = Cookies()
self._cookies.extract_cookies(self)
return self._cookies
@property
def links(self) -> typing.Dict[typing.Optional[str], typing.Dict[str, str]]:
def links(self) -> dict[str | None, dict[str, str]]:
"""
Returns the parsed header links of the response, if any
"""
header = self.headers.get("link")
ldict = {}
if header:
links = parse_header_links(header)
for link in links:
key = link.get("rel") or link.get("url")
ldict[key] = link
return ldict
if header is None:
return {}
return {
(link.get("rel") or link.get("url")): link
for link in _parse_header_links(header)
}
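A sketch of the rewritten `links` property in use, with a hypothetical paginated API:

```python
import httpx

response = httpx.Response(
    200,
    headers={"Link": '<https://api.example.com/items?page=2>; rel="next"'},
)
assert response.links["next"]["url"] == "https://api.example.com/items?page=2"
```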
@property
def num_bytes_downloaded(self) -> int:
@@ -783,14 +859,14 @@ class Response:
def __repr__(self) -> str:
return f"<Response [{self.status_code} {self.reason_phrase}]>"
def __getstate__(self) -> typing.Dict[str, typing.Any]:
def __getstate__(self) -> dict[str, typing.Any]:
return {
name: value
for name, value in self.__dict__.items()
if name not in ["extensions", "stream", "is_closed", "_decoder"]
}
def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None:
def __setstate__(self, state: dict[str, typing.Any]) -> None:
for name, value in state.items():
setattr(self, name, value)
self.is_closed = True
@@ -805,12 +881,10 @@ class Response:
self._content = b"".join(self.iter_bytes())
return self._content
def iter_bytes(
self, chunk_size: typing.Optional[int] = None
) -> typing.Iterator[bytes]:
def iter_bytes(self, chunk_size: int | None = None) -> typing.Iterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
This allows us to handle gzip, deflate, brotli, and zstd encoded responses.
"""
if hasattr(self, "_content"):
chunk_size = len(self._content) if chunk_size is None else chunk_size
@@ -830,9 +904,7 @@ class Response:
for chunk in chunker.flush():
yield chunk
def iter_text(
self, chunk_size: typing.Optional[int] = None
) -> typing.Iterator[str]:
def iter_text(self, chunk_size: int | None = None) -> typing.Iterator[str]:
"""
A str-iterator over the decoded response content
that handles gzip, deflate, etc. and also detects the content's
@@ -847,7 +919,7 @@ class Response:
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk
@@ -860,9 +932,7 @@ class Response:
for line in decoder.flush():
yield line
def iter_raw(
self, chunk_size: typing.Optional[int] = None
) -> typing.Iterator[bytes]:
def iter_raw(self, chunk_size: int | None = None) -> typing.Iterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
@@ -910,11 +980,11 @@ class Response:
return self._content
async def aiter_bytes(
self, chunk_size: typing.Optional[int] = None
self, chunk_size: int | None = None
) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
This allows us to handle gzip, deflate, brotli, and zstd encoded responses.
"""
if hasattr(self, "_content"):
chunk_size = len(self._content) if chunk_size is None else chunk_size
@@ -935,7 +1005,7 @@ class Response:
yield chunk
async def aiter_text(
self, chunk_size: typing.Optional[int] = None
self, chunk_size: int | None = None
) -> typing.AsyncIterator[str]:
"""
A str-iterator over the decoded response content
@@ -951,7 +1021,7 @@ class Response:
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk
@@ -965,7 +1035,7 @@ class Response:
yield line
async def aiter_raw(
self, chunk_size: typing.Optional[int] = None
self, chunk_size: int | None = None
) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
@@ -1011,7 +1081,7 @@ class Cookies(typing.MutableMapping[str, str]):
HTTP Cookies, as a mutable mapping.
"""
def __init__(self, cookies: typing.Optional[CookieTypes] = None) -> None:
def __init__(self, cookies: CookieTypes | None = None) -> None:
if cookies is None or isinstance(cookies, dict):
self.jar = CookieJar()
if isinstance(cookies, dict):
@@ -1073,10 +1143,10 @@ class Cookies(typing.MutableMapping[str, str]):
def get( # type: ignore
self,
name: str,
default: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
path: typing.Optional[str] = None,
) -> typing.Optional[str]:
default: str | None = None,
domain: str | None = None,
path: str | None = None,
) -> str | None:
"""
Get a cookie by name. May optionally include domain and path
in order to specify exactly which cookie to retrieve.
@@ -1098,8 +1168,8 @@ class Cookies(typing.MutableMapping[str, str]):
def delete(
self,
name: str,
domain: typing.Optional[str] = None,
path: typing.Optional[str] = None,
domain: str | None = None,
path: str | None = None,
) -> None:
"""
Delete a cookie by name. May optionally include domain and path
@@ -1119,9 +1189,7 @@ class Cookies(typing.MutableMapping[str, str]):
for cookie in remove:
self.jar.clear(cookie.domain, cookie.path, cookie.name)
def clear(
self, domain: typing.Optional[str] = None, path: typing.Optional[str] = None
) -> None:
def clear(self, domain: str | None = None, path: str | None = None) -> None:
"""
Delete all cookies. Optionally include a domain and path in
order to only delete a subset of all the cookies.
@@ -1134,7 +1202,7 @@ class Cookies(typing.MutableMapping[str, str]):
args.append(path)
self.jar.clear(*args)
def update(self, cookies: typing.Optional[CookieTypes] = None) -> None: # type: ignore
def update(self, cookies: CookieTypes | None = None) -> None: # type: ignore
cookies = Cookies(cookies)
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
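A short sketch of the `Cookies` API covered by the signatures above:

```python
import httpx

cookies = httpx.Cookies()
cookies.set("session", "abc123", domain="example.com", path="/")
assert cookies.get("session", domain="example.com") == "abc123"

cookies.update({"theme": "dark"})  # merge another cookie source
cookies.delete("session", domain="example.com")
assert cookies.get("session") is None
```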
@@ -1196,7 +1264,7 @@ class Cookies(typing.MutableMapping[str, str]):
for use with `CookieJar` operations.
"""
def __init__(self, response: Response):
def __init__(self, response: Response) -> None:
self.response = response
def info(self) -> email.message.Message:

View File

@@ -1,6 +1,9 @@
import binascii
from __future__ import annotations
import io
import mimetypes
import os
import re
import typing
from pathlib import Path
@@ -13,17 +16,46 @@ from ._types import (
SyncByteStream,
)
from ._utils import (
format_form_param,
guess_content_type,
peek_filelike_length,
primitive_value_to_str,
to_bytes,
)
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
)
_HTML5_FORM_ENCODING_RE = re.compile(
r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
)
def _format_form_param(name: str, value: str) -> bytes:
"""
Encode a name/value pair within a multipart form.
"""
def replacer(match: typing.Match[str]) -> str:
return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
return f'{name}="{value}"'.encode()
def _guess_content_type(filename: str | None) -> str | None:
"""
Guesses the mimetype based on a filename. Defaults to `application/octet-stream`.
Returns `None` if `filename` is `None` or empty.
"""
if filename:
return mimetypes.guess_type(filename)[0] or "application/octet-stream"
return None
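The renamed `_guess_content_type` helper keeps the same logic; a sketch of the fallback behaviour using the stdlib directly:

```python
import mimetypes

def guess(filename):
    # Mirrors the helper above: known extensions resolve normally,
    # unknown ones fall back to application/octet-stream.
    if filename:
        return mimetypes.guess_type(filename)[0] or "application/octet-stream"
    return None

assert guess("photo.png") == "image/png"
assert guess("blob.zzz") == "application/octet-stream"
assert guess(None) is None
```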
def get_multipart_boundary_from_content_type(
content_type: typing.Optional[bytes],
) -> typing.Optional[bytes]:
content_type: bytes | None,
) -> bytes | None:
if not content_type or not content_type.startswith(b"multipart/form-data"):
return None
# parse boundary according to
@@ -40,25 +72,24 @@ class DataField:
A single form field item, within a multipart form field.
"""
def __init__(
self, name: str, value: typing.Union[str, bytes, int, float, None]
) -> None:
def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
if not isinstance(name, str):
raise TypeError(
f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
)
if value is not None and not isinstance(value, (str, bytes, int, float)):
raise TypeError(
f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}"
"Invalid type for value. Expected primitive type,"
f" got {type(value)}: {value!r}"
)
self.name = name
self.value: typing.Union[str, bytes] = (
self.value: str | bytes = (
value if isinstance(value, bytes) else primitive_value_to_str(value)
)
def render_headers(self) -> bytes:
if not hasattr(self, "_headers"):
name = format_form_param("name", self.name)
name = _format_form_param("name", self.name)
self._headers = b"".join(
[b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
)
@@ -93,18 +124,20 @@ class FileField:
fileobj: FileContent
headers: typing.Dict[str, str] = {}
content_type: typing.Optional[str] = None
headers: dict[str, str] = {}
content_type: str | None = None
# This large tuple based API largely mirrors requests' API
# It would be good to think of better APIs for this that we could include in httpx 2.0
# since variable length tuples (especially of 4 elements) are quite unwieldy
# It would be good to think of better APIs for this that we could
# include in httpx 2.0 since variable length tuples (especially of 4 elements)
# are quite unwieldy
if isinstance(value, tuple):
if len(value) == 2:
# neither the 3rd parameter (content_type) nor the 4th (headers) was included
filename, fileobj = value # type: ignore
# neither the 3rd parameter (content_type) nor the 4th (headers)
# was included
filename, fileobj = value
elif len(value) == 3:
filename, fileobj, content_type = value # type: ignore
filename, fileobj, content_type = value
else:
# all 4 parameters included
filename, fileobj, content_type, headers = value # type: ignore
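In practice, the tuple lengths handled above correspond to the `files=` shapes accepted by the client API. A sketch (the URL is a placeholder; no request is actually sent):

```python
import io
import httpx

files = {
    "simple": io.BytesIO(b"raw bytes"),                       # bare file object
    "named": ("notes.txt", io.BytesIO(b"hello")),             # 2-tuple
    "typed": ("data.csv", io.BytesIO(b"a,b\n"), "text/csv"),  # 3-tuple
    "full": ("img.png", io.BytesIO(b"..."), "image/png",
             {"X-Extra": "1"}),                               # 4-tuple
}
request = httpx.Request("POST", "https://example.com/upload", files=files)
```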
@@ -113,13 +146,13 @@ class FileField:
fileobj = value
if content_type is None:
content_type = guess_content_type(filename)
content_type = _guess_content_type(filename)
has_content_type_header = any("content-type" in key.lower() for key in headers)
if content_type is not None and not has_content_type_header:
# note that unlike requests, we ignore the content_type
# provided in the 3rd tuple element if it is also included in the headers
# requests does the opposite (it overwrites the header with the 3rd tuple element)
# note that unlike requests, we ignore the content_type provided in the 3rd
# tuple element if it is also included in the headers; requests does
# the opposite (it overwrites the header with the 3rd tuple element)
headers["Content-Type"] = content_type
if isinstance(fileobj, io.StringIO):
@@ -135,7 +168,7 @@ class FileField:
self.file = fileobj
self.headers = headers
def get_length(self) -> typing.Optional[int]:
def get_length(self) -> int | None:
headers = self.render_headers()
if isinstance(self.file, (str, bytes)):
@@ -154,10 +187,10 @@ class FileField:
if not hasattr(self, "_headers"):
parts = [
b"Content-Disposition: form-data; ",
format_form_param("name", self.name),
_format_form_param("name", self.name),
]
if self.filename:
filename = format_form_param("filename", self.filename)
filename = _format_form_param("filename", self.filename)
parts.extend([b"; ", filename])
for header_name, header_value in self.headers.items():
key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
@@ -197,10 +230,10 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
self,
data: RequestData,
files: RequestFiles,
boundary: typing.Optional[bytes] = None,
boundary: bytes | None = None,
) -> None:
if boundary is None:
boundary = binascii.hexlify(os.urandom(16))
boundary = os.urandom(16).hex().encode("ascii")
self.boundary = boundary
self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
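The boundary change drops the `binascii` import; the two expressions produce identical bytes, which a quick check confirms:

```python
import binascii
import os

raw = os.urandom(16)
assert binascii.hexlify(raw) == raw.hex().encode("ascii")
```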
@@ -210,7 +243,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
def _iter_fields(
self, data: RequestData, files: RequestFiles
) -> typing.Iterator[typing.Union[FileField, DataField]]:
) -> typing.Iterator[FileField | DataField]:
for name, value in data.items():
if isinstance(value, (tuple, list)):
for item in value:
@@ -229,7 +262,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
yield b"\r\n"
yield b"--%s--\r\n" % self.boundary
def get_content_length(self) -> typing.Optional[int]:
def get_content_length(self) -> int | None:
"""
Return the length of the multipart encoded content, or `None` if
any of the files have a length that cannot be determined upfront.
@@ -251,7 +284,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
# Content stream interface.
def get_headers(self) -> typing.Dict[str, str]:
def get_headers(self) -> dict[str, str]:
content_length = self.get_content_length()
content_type = self.content_type
if content_length is None:

View File

@@ -1,5 +1,9 @@
from __future__ import annotations
from enum import IntEnum
__all__ = ["codes"]
class codes(IntEnum):
"""HTTP status codes and reason phrases
@@ -21,7 +25,7 @@ class codes(IntEnum):
* RFC 8470: Using Early Data in HTTP
"""
def __new__(cls, value: int, phrase: str = "") -> "codes":
def __new__(cls, value: int, phrase: str = "") -> codes:
obj = int.__new__(cls, value)
obj._value_ = value
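A sketch of the `codes` IntEnum in use:

```python
import httpx

assert httpx.codes.NOT_FOUND == 404  # members compare as plain ints
assert httpx.codes.get_reason_phrase(404) == "Not Found"
```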

View File

@@ -0,0 +1,15 @@
from .asgi import *
from .base import *
from .default import *
from .mock import *
from .wsgi import *
__all__ = [
"ASGITransport",
"AsyncBaseTransport",
"BaseTransport",
"AsyncHTTPTransport",
"HTTPTransport",
"MockTransport",
"WSGITransport",
]

View File

@@ -1,6 +1,6 @@
import typing
from __future__ import annotations
import sniffio
import typing
from .._models import Request, Response
from .._types import AsyncByteStream
@@ -14,29 +14,46 @@ if typing.TYPE_CHECKING: # pragma: no cover
Event = typing.Union[asyncio.Event, trio.Event]
_Message = typing.Dict[str, typing.Any]
_Message = typing.MutableMapping[str, typing.Any]
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
_Send = typing.Callable[
[typing.Dict[str, typing.Any]], typing.Coroutine[None, None, None]
[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]
]
_ASGIApp = typing.Callable[
[typing.Dict[str, typing.Any], _Receive, _Send], typing.Coroutine[None, None, None]
[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
]
__all__ = ["ASGITransport"]
def create_event() -> "Event":
if sniffio.current_async_library() == "trio":
def is_running_trio() -> bool:
try:
# sniffio is a dependency of trio.
# See https://github.com/python-trio/trio/issues/2802
import sniffio
if sniffio.current_async_library() == "trio":
return True
except ImportError: # pragma: nocover
pass
return False
def create_event() -> Event:
if is_running_trio():
import trio
return trio.Event()
else:
import asyncio
return asyncio.Event()
import asyncio
return asyncio.Event()
class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: typing.List[bytes]) -> None:
def __init__(self, body: list[bytes]) -> None:
self._body = body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
@@ -46,17 +63,8 @@ class ASGIResponseStream(AsyncByteStream):
class ASGITransport(AsyncBaseTransport):
"""
A custom AsyncTransport that handles sending requests directly to an ASGI app.
The simplest way to use this functionality is to use the `app` argument.
```
client = httpx.AsyncClient(app=app)
```
Alternatively, you can setup the transport instance explicitly.
This allows you to include any additional configuration arguments specific
to the ASGITransport class:
```
```python
transport = httpx.ASGITransport(
app=app,
root_path="/submount",
@@ -81,7 +89,7 @@ class ASGITransport(AsyncBaseTransport):
app: _ASGIApp,
raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
client: tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
@@ -103,7 +111,7 @@ class ASGITransport(AsyncBaseTransport):
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path,
"raw_path": request.url.raw_path.split(b"?")[0],
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": self.client,
@@ -123,7 +131,7 @@ class ASGITransport(AsyncBaseTransport):
# ASGI callables.
async def receive() -> typing.Dict[str, typing.Any]:
async def receive() -> dict[str, typing.Any]:
nonlocal request_complete
if request_complete:
@@ -137,7 +145,7 @@ class ASGITransport(AsyncBaseTransport):
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: typing.Dict[str, typing.Any]) -> None:
async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
nonlocal status_code, response_headers, response_started
if message["type"] == "http.response.start":
@@ -161,9 +169,15 @@ class ASGITransport(AsyncBaseTransport):
try:
await self.app(scope, receive, send)
except Exception: # noqa: PIE-786
if self.raise_app_exceptions or not response_complete.is_set():
if self.raise_app_exceptions:
raise
response_complete.set()
if status_code is None:
status_code = 500
if response_headers is None:
response_headers = {}
assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None
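With the `app=...` client shortcut removed from the docstring above, the explicit transport form is the one to use. A runnable sketch with a minimal raw ASGI app:

```python
import asyncio
import httpx

async def app(scope, receive, send):
    # Minimal ASGI app: always answer 200 with a plain-text body.
    assert scope["type"] == "http"
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-type", b"text/plain")]})
    await send({"type": "http.response.body", "body": b"Hello, ASGI"})

async def main():
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport,
                                 base_url="http://testserver") as client:
        response = await client.get("/")
        assert response.text == "Hello, ASGI"

asyncio.run(main())
```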

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import typing
from types import TracebackType
@@ -6,6 +8,8 @@ from .._models import Request, Response
T = typing.TypeVar("T", bound="BaseTransport")
A = typing.TypeVar("A", bound="AsyncBaseTransport")
__all__ = ["AsyncBaseTransport", "BaseTransport"]
class BaseTransport:
def __enter__(self: T) -> T:
@@ -13,9 +17,9 @@ class BaseTransport:
def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
self.close()
@@ -64,9 +68,9 @@ class AsyncBaseTransport:
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
await self.aclose()

View File

@@ -23,11 +23,17 @@ client = httpx.Client(transport=transport)
transport = httpx.HTTPTransport(uds="socket.uds")
client = httpx.Client(transport=transport)
"""
from __future__ import annotations
import contextlib
import typing
from types import TracebackType
import httpcore
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
import httpx # pragma: no cover
from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context
from .._exceptions import (
@@ -47,18 +53,53 @@ from .._exceptions import (
WriteTimeout,
)
from .._models import Request, Response
from .._types import AsyncByteStream, CertTypes, SyncByteStream, VerifyTypes
from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream
from .._urls import URL
from .base import AsyncBaseTransport, BaseTransport
T = typing.TypeVar("T", bound="HTTPTransport")
A = typing.TypeVar("A", bound="AsyncHTTPTransport")
SOCKET_OPTION = typing.Union[
typing.Tuple[int, int, int],
typing.Tuple[int, int, typing.Union[bytes, bytearray]],
typing.Tuple[int, int, None, int],
]
__all__ = ["AsyncHTTPTransport", "HTTPTransport"]
HTTPCORE_EXC_MAP: dict[type[Exception], type[httpx.HTTPError]] = {}
def _load_httpcore_exceptions() -> dict[type[Exception], type[httpx.HTTPError]]:
import httpcore
return {
httpcore.TimeoutException: TimeoutException,
httpcore.ConnectTimeout: ConnectTimeout,
httpcore.ReadTimeout: ReadTimeout,
httpcore.WriteTimeout: WriteTimeout,
httpcore.PoolTimeout: PoolTimeout,
httpcore.NetworkError: NetworkError,
httpcore.ConnectError: ConnectError,
httpcore.ReadError: ReadError,
httpcore.WriteError: WriteError,
httpcore.ProxyError: ProxyError,
httpcore.UnsupportedProtocol: UnsupportedProtocol,
httpcore.ProtocolError: ProtocolError,
httpcore.LocalProtocolError: LocalProtocolError,
httpcore.RemoteProtocolError: RemoteProtocolError,
}
@contextlib.contextmanager
def map_httpcore_exceptions() -> typing.Iterator[None]:
global HTTPCORE_EXC_MAP
if len(HTTPCORE_EXC_MAP) == 0:
HTTPCORE_EXC_MAP = _load_httpcore_exceptions()
try:
yield
except Exception as exc: # noqa: PIE-786
except Exception as exc:
mapped_exc = None
for from_exc, to_exc in HTTPCORE_EXC_MAP.items():
@@ -77,26 +118,8 @@ def map_httpcore_exceptions() -> typing.Iterator[None]:
raise mapped_exc(message) from exc
HTTPCORE_EXC_MAP = {
httpcore.TimeoutException: TimeoutException,
httpcore.ConnectTimeout: ConnectTimeout,
httpcore.ReadTimeout: ReadTimeout,
httpcore.WriteTimeout: WriteTimeout,
httpcore.PoolTimeout: PoolTimeout,
httpcore.NetworkError: NetworkError,
httpcore.ConnectError: ConnectError,
httpcore.ReadError: ReadError,
httpcore.WriteError: WriteError,
httpcore.ProxyError: ProxyError,
httpcore.UnsupportedProtocol: UnsupportedProtocol,
httpcore.ProtocolError: ProtocolError,
httpcore.LocalProtocolError: LocalProtocolError,
httpcore.RemoteProtocolError: RemoteProtocolError,
}
class ResponseStream(SyncByteStream):
def __init__(self, httpcore_stream: typing.Iterable[bytes]):
def __init__(self, httpcore_stream: typing.Iterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream
def __iter__(self) -> typing.Iterator[bytes]:
@@ -112,17 +135,21 @@ class ResponseStream(SyncByteStream):
class HTTPTransport(BaseTransport):
def __init__(
self,
verify: VerifyTypes = True,
cert: typing.Optional[CertTypes] = None,
verify: ssl.SSLContext | str | bool = True,
cert: CertTypes | None = None,
trust_env: bool = True,
http1: bool = True,
http2: bool = False,
limits: Limits = DEFAULT_LIMITS,
trust_env: bool = True,
proxy: typing.Optional[Proxy] = None,
uds: typing.Optional[str] = None,
local_address: typing.Optional[str] = None,
proxy: ProxyTypes | None = None,
uds: str | None = None,
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
import httpcore
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
if proxy is None:
@@ -136,6 +163,7 @@ class HTTPTransport(BaseTransport):
uds=uds,
local_address=local_address,
retries=retries,
socket_options=socket_options,
)
elif proxy.url.scheme in ("http", "https"):
self._pool = httpcore.HTTPProxy(
@@ -148,13 +176,15 @@ class HTTPTransport(BaseTransport):
proxy_auth=proxy.raw_auth,
proxy_headers=proxy.headers.raw,
ssl_context=ssl_context,
proxy_ssl_context=proxy.ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
socket_options=socket_options,
)
elif proxy.url.scheme == "socks5":
elif proxy.url.scheme in ("socks5", "socks5h"):
try:
import socksio # noqa
except ImportError: # pragma: no cover
@@ -180,7 +210,8 @@ class HTTPTransport(BaseTransport):
)
else: # pragma: no cover
raise ValueError(
f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}."
"Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h',"
f" but got {proxy.url.scheme!r}."
)
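A sketch of constructing the transport with the parameters added above. SOCKS support needs the optional `socksio` dependency (`pip install httpx[socks]`), and the proxy URL here is a placeholder:

```python
import socket
import httpx

transport = httpx.HTTPTransport(
    # 'socks5h' is now accepted; hostname resolution happens on the proxy side.
    proxy="socks5h://localhost:1080",
    socket_options=[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)],
)
client = httpx.Client(transport=transport)
```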
def __enter__(self: T) -> T: # Use generics for subclass support.
@@ -189,9 +220,9 @@ class HTTPTransport(BaseTransport):
def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
with map_httpcore_exceptions():
self._pool.__exit__(exc_type, exc_value, traceback)
@@ -201,6 +232,7 @@ class HTTPTransport(BaseTransport):
request: Request,
) -> Response:
assert isinstance(request.stream, SyncByteStream)
import httpcore
req = httpcore.Request(
method=request.method,
@@ -231,7 +263,7 @@ class HTTPTransport(BaseTransport):
class AsyncResponseStream(AsyncByteStream):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
@@ -247,17 +279,21 @@ class AsyncResponseStream(AsyncByteStream):
class AsyncHTTPTransport(AsyncBaseTransport):
def __init__(
self,
verify: VerifyTypes = True,
cert: typing.Optional[CertTypes] = None,
verify: ssl.SSLContext | str | bool = True,
cert: CertTypes | None = None,
trust_env: bool = True,
http1: bool = True,
http2: bool = False,
limits: Limits = DEFAULT_LIMITS,
trust_env: bool = True,
proxy: typing.Optional[Proxy] = None,
uds: typing.Optional[str] = None,
local_address: typing.Optional[str] = None,
proxy: ProxyTypes | None = None,
uds: str | None = None,
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
import httpcore
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
if proxy is None:
@@ -271,6 +307,7 @@ class AsyncHTTPTransport(AsyncBaseTransport):
uds=uds,
local_address=local_address,
retries=retries,
socket_options=socket_options,
)
elif proxy.url.scheme in ("http", "https"):
self._pool = httpcore.AsyncHTTPProxy(
@@ -282,14 +319,16 @@ class AsyncHTTPTransport(AsyncBaseTransport):
),
proxy_auth=proxy.raw_auth,
proxy_headers=proxy.headers.raw,
proxy_ssl_context=proxy.ssl_context,
ssl_context=ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
socket_options=socket_options,
)
elif proxy.url.scheme == "socks5":
elif proxy.url.scheme in ("socks5", "socks5h"):
try:
import socksio # noqa
except ImportError: # pragma: no cover
@@ -315,7 +354,8 @@ class AsyncHTTPTransport(AsyncBaseTransport):
)
else: # pragma: no cover
raise ValueError(
f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}."
"Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h',"
" but got {proxy.url.scheme!r}."
)
async def __aenter__(self: A) -> A: # Use generics for subclass support.
@@ -324,9 +364,9 @@ class AsyncHTTPTransport(AsyncBaseTransport):
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
with map_httpcore_exceptions():
await self._pool.__aexit__(exc_type, exc_value, traceback)
@@ -336,6 +376,7 @@ class AsyncHTTPTransport(AsyncBaseTransport):
request: Request,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)
import httpcore
req = httpcore.Request(
method=request.method,

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import typing
from .._models import Request, Response
@@ -7,8 +9,11 @@ SyncHandler = typing.Callable[[Request], Response]
AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]]
__all__ = ["MockTransport"]
class MockTransport(AsyncBaseTransport, BaseTransport):
def __init__(self, handler: typing.Union[SyncHandler, AsyncHandler]) -> None:
def __init__(self, handler: SyncHandler | AsyncHandler) -> None:
self.handler = handler
def handle_request(

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import io
import itertools
import sys
@@ -14,6 +16,9 @@ if typing.TYPE_CHECKING:
_T = typing.TypeVar("_T")
__all__ = ["WSGITransport"]
def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
body = iter(body)
for chunk in body:
@@ -71,11 +76,11 @@ class WSGITransport(BaseTransport):
def __init__(
self,
app: "WSGIApplication",
app: WSGIApplication,
raise_app_exceptions: bool = True,
script_name: str = "",
remote_addr: str = "127.0.0.1",
wsgi_errors: typing.Optional[typing.TextIO] = None,
wsgi_errors: typing.TextIO | None = None,
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
@@ -102,6 +107,7 @@ class WSGITransport(BaseTransport):
"QUERY_STRING": request.url.query.decode("ascii"),
"SERVER_NAME": request.url.host,
"SERVER_PORT": str(port),
"SERVER_PROTOCOL": "HTTP/1.1",
"REMOTE_ADDR": self.remote_addr,
}
for header_key, header_value in request.headers.raw:
@@ -116,8 +122,8 @@ class WSGITransport(BaseTransport):
def start_response(
status: str,
response_headers: typing.List[typing.Tuple[str, str]],
exc_info: typing.Optional["OptExcInfo"] = None,
response_headers: list[tuple[str, str]],
exc_info: OptExcInfo | None = None,
) -> typing.Callable[[bytes], typing.Any]:
nonlocal seen_status, seen_response_headers, seen_exc_info
seen_status = status
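A runnable sketch of `WSGITransport` with a bare-bones WSGI callable, which also exercises the newly populated `SERVER_PROTOCOL` key:

```python
import httpx

def app(environ, start_response):
    assert environ["SERVER_PROTOCOL"] == "HTTP/1.1"  # added in the environ above
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"Hello, WSGI"]

client = httpx.Client(transport=httpx.WSGITransport(app=app),
                      base_url="http://testserver")
response = client.get("/")
assert response.text == "Hello, WSGI"
```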

View File

@@ -2,7 +2,6 @@
Type definitions for type checking purposes.
"""
import ssl
from http.cookiejar import CookieJar
from typing import (
IO,
@@ -16,7 +15,6 @@ from typing import (
Iterator,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
@@ -32,16 +30,6 @@ if TYPE_CHECKING: # pragma: no cover
PrimitiveData = Optional[Union[str, int, float, bool]]
RawURL = NamedTuple(
"RawURL",
[
("raw_scheme", bytes),
("raw_host", bytes),
("port", Optional[int]),
("raw_path", bytes),
],
)
URLTypes = Union["URL", str]
QueryParamTypes = Union[
@@ -63,21 +51,13 @@ HeaderTypes = Union[
CookieTypes = Union["Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
CertTypes = Union[
# certfile
str,
# (certfile, keyfile)
Tuple[str, Optional[str]],
# (certfile, keyfile, password)
Tuple[str, Optional[str], Optional[str]],
]
VerifyTypes = Union[str, bool, ssl.SSLContext]
TimeoutTypes = Union[
Optional[float],
Tuple[Optional[float], Optional[float], Optional[float], Optional[float]],
"Timeout",
]
ProxiesTypes = Union[URLTypes, "Proxy", Dict[URLTypes, Union[None, URLTypes, "Proxy"]]]
ProxyTypes = Union["URL", str, "Proxy"]
CertTypes = Union[str, Tuple[str, str], Tuple[str, str, str]]
AuthTypes = Union[
Tuple[Union[str, bytes], Union[str, bytes]],
@@ -106,6 +86,8 @@ RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
RequestExtensions = Mapping[str, Any]
__all__ = ["AsyncByteStream", "SyncByteStream"]
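With `VerifyTypes` gone from the public types, `verify` is annotated inline as `ssl.SSLContext | str | bool`. A sketch of the three accepted forms (the bundle path is a placeholder):

```python
import ssl
import httpx

client_default = httpx.Client(verify=True)                          # bool: default verification
client_custom = httpx.Client(verify=ssl.create_default_context())   # explicit SSL context
client_bundle = httpx.Client(verify="/etc/ssl/certs/ca-bundle.crt") # CA bundle path (placeholder)
```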
class SyncByteStream:
def __iter__(self) -> Iterator[bytes]:

View File

@@ -15,6 +15,9 @@ Previously we relied on the excellent `rfc3986` package to handle URL parsing an
validation, but this module provides a simpler alternative, with less indirection
required.
"""
from __future__ import annotations
import ipaddress
import re
import typing
@@ -33,6 +36,67 @@ SUB_DELIMS = "!$&'()*+,;="
PERCENT_ENCODED_REGEX = re.compile("%[A-Fa-f0-9]{2}")
# https://url.spec.whatwg.org/#percent-encoded-bytes
# The fragment percent-encode set is the C0 control percent-encode set
# and U+0020 SPACE, U+0022 ("), U+003C (<), U+003E (>), and U+0060 (`).
FRAG_SAFE = "".join(
[chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x3C, 0x3E, 0x60)]
)
# The query percent-encode set is the C0 control percent-encode set
# and U+0020 SPACE, U+0022 ("), U+0023 (#), U+003C (<), and U+003E (>).
QUERY_SAFE = "".join(
[chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E)]
)
# The path percent-encode set is the query percent-encode set
# and U+003F (?), U+0060 (`), U+007B ({), and U+007D (}).
PATH_SAFE = "".join(
[
chr(i)
for i in range(0x20, 0x7F)
if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + (0x3F, 0x60, 0x7B, 0x7D)
]
)
# The userinfo percent-encode set is the path percent-encode set
# and U+002F (/), U+003A (:), U+003B (;), U+003D (=), U+0040 (@),
# U+005B ([) to U+005E (^), inclusive, and U+007C (|).
USERNAME_SAFE = "".join(
[
chr(i)
for i in range(0x20, 0x7F)
if i
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
+ (0x3F, 0x60, 0x7B, 0x7D)
+ (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
]
)
PASSWORD_SAFE = "".join(
[
chr(i)
for i in range(0x20, 0x7F)
if i
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
+ (0x3F, 0x60, 0x7B, 0x7D)
+ (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
]
)
# Note... The terminology 'userinfo' percent-encode set in the WHATWG document
# is used for the username and password quoting. For the joint userinfo component
# we remove U+003A (:) from the safe set.
USERINFO_SAFE = "".join(
[
chr(i)
for i in range(0x20, 0x7F)
if i
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
+ (0x3F, 0x60, 0x7B, 0x7D)
+ (0x2F, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
]
)
# {scheme}: (optional)
# //{authority} (optional)
@@ -62,8 +126,8 @@ AUTHORITY_REGEX = re.compile(
(
r"(?:(?P<userinfo>{userinfo})@)?" r"(?P<host>{host})" r":?(?P<port>{port})?"
).format(
userinfo="[^@]*", # Any character sequence not including '@'.
host="(\\[.*\\]|[^:]*)", # Either any character sequence not including ':',
userinfo=".*", # Any character sequence.
host="(\\[.*\\]|[^:@]*)", # Either any character sequence excluding ':' or '@',
# or an IPv6 address enclosed within square brackets.
port=".*", # Any character sequence.
)
@@ -87,7 +151,7 @@ COMPONENT_REGEX = {
# We use these simple regexs as a first pass before handing off to
# the stdlib 'ipaddress' module for IP address validation.
IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+.[0-9]+.[0-9]+.[0-9]+$")
IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$")
IPv6_STYLE_HOSTNAME = re.compile(r"^\[.*\]$")
@@ -95,10 +159,10 @@ class ParseResult(typing.NamedTuple):
scheme: str
userinfo: str
host: str
port: typing.Optional[int]
port: int | None
path: str
query: typing.Optional[str]
fragment: typing.Optional[str]
query: str | None
fragment: str | None
@property
def authority(self) -> str:
@@ -119,7 +183,7 @@ class ParseResult(typing.NamedTuple):
]
)
def copy_with(self, **kwargs: typing.Optional[str]) -> "ParseResult":
def copy_with(self, **kwargs: str | None) -> ParseResult:
if not kwargs:
return self
@@ -146,7 +210,7 @@ class ParseResult(typing.NamedTuple):
)
def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
def urlparse(url: str = "", **kwargs: str | None) -> ParseResult:
# Initial basic checks on allowable URLs.
# ---------------------------------------
@@ -157,7 +221,12 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
# If a URL includes any ASCII control characters including \t, \r, \n,
# then treat it as invalid.
if any(char.isascii() and not char.isprintable() for char in url):
raise InvalidURL("Invalid non-printable ASCII character in URL")
char = next(char for char in url if char.isascii() and not char.isprintable())
idx = url.find(char)
error = (
f"Invalid non-printable ASCII character in URL, {char!r} at position {idx}."
)
raise InvalidURL(error)
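The richer error message now pinpoints the offending character. A sketch:

```python
import httpx

try:
    httpx.URL("https://example.com/\n")
except httpx.InvalidURL as exc:
    # e.g. "Invalid non-printable ASCII character in URL, '\n' at position 20."
    print(exc)
```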
# Some keyword arguments require special handling.
# ------------------------------------------------
@@ -174,8 +243,8 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
# Replace "username" and/or "password" with "userinfo".
if "username" in kwargs or "password" in kwargs:
username = quote(kwargs.pop("username", "") or "")
password = quote(kwargs.pop("password", "") or "")
username = quote(kwargs.pop("username", "") or "", safe=USERNAME_SAFE)
password = quote(kwargs.pop("password", "") or "", safe=PASSWORD_SAFE)
kwargs["userinfo"] = f"{username}:{password}" if password else username
# Replace "raw_path" with "path" and "query".
@@ -202,9 +271,15 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
# If a component includes any ASCII control characters including \t, \r, \n,
# then treat it as invalid.
if any(char.isascii() and not char.isprintable() for char in value):
raise InvalidURL(
f"Invalid non-printable ASCII character in URL component '{key}'"
char = next(
char for char in value if char.isascii() and not char.isprintable()
)
idx = value.find(char)
error = (
f"Invalid non-printable ASCII character in URL {key} component, "
f"{char!r} at position {idx}."
)
raise InvalidURL(error)
# Ensure that keyword arguments match as a valid regex.
if not COMPONENT_REGEX[key].fullmatch(value):
@@ -224,7 +299,7 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
authority = kwargs.get("authority", url_dict["authority"]) or ""
path = kwargs.get("path", url_dict["path"]) or ""
query = kwargs.get("query", url_dict["query"])
fragment = kwargs.get("fragment", url_dict["fragment"])
frag = kwargs.get("fragment", url_dict["fragment"])
# The AUTHORITY_REGEX will always match, but may have empty components.
authority_match = AUTHORITY_REGEX.match(authority)
@@ -241,32 +316,21 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
# We end up with a parsed representation of the URL,
# with components that are plain ASCII bytestrings.
parsed_scheme: str = scheme.lower()
parsed_userinfo: str = quote(userinfo, safe=SUB_DELIMS + ":")
parsed_userinfo: str = quote(userinfo, safe=USERINFO_SAFE)
parsed_host: str = encode_host(host)
parsed_port: typing.Optional[int] = normalize_port(port, scheme)
parsed_port: int | None = normalize_port(port, scheme)
has_scheme = parsed_scheme != ""
has_authority = (
parsed_userinfo != "" or parsed_host != "" or parsed_port is not None
)
validate_path(path, has_scheme=has_scheme, has_authority=has_authority)
if has_authority:
if has_scheme or has_authority:
path = normalize_path(path)
# The GEN_DELIMS set is... : / ? # [ ] @
# These do not need to be percent-quoted unless they serve as delimiters for the
# specific component.
# For 'path' we need to drop ? and # from the GEN_DELIMS set.
parsed_path: str = quote(path, safe=SUB_DELIMS + ":/[]@")
# For 'query' we need to drop '#' from the GEN_DELIMS set.
parsed_query: typing.Optional[str] = (
None if query is None else quote(query, safe=SUB_DELIMS + ":/?[]@")
)
# For 'fragment' we can include all of the GEN_DELIMS set.
parsed_fragment: typing.Optional[str] = (
None if fragment is None else quote(fragment, safe=SUB_DELIMS + ":/?#[]@")
)
parsed_path: str = quote(path, safe=PATH_SAFE)
parsed_query: str | None = None if query is None else quote(query, safe=QUERY_SAFE)
parsed_frag: str | None = None if frag is None else quote(frag, safe=FRAG_SAFE)
# The parsed ASCII bytestrings are our canonical form.
# All properties of the URL are derived from these.
@@ -277,7 +341,7 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
parsed_port,
parsed_path,
parsed_query,
parsed_fragment,
parsed_frag,
)
@@ -318,7 +382,8 @@ def encode_host(host: str) -> str:
# From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2
#
# reg-name = *( unreserved / pct-encoded / sub-delims )
return quote(host.lower(), safe=SUB_DELIMS)
WHATWG_SAFE = '"`{}%|\\'
return quote(host.lower(), safe=SUB_DELIMS + WHATWG_SAFE)
# IDNA hostnames
try:
@@ -327,9 +392,7 @@ def encode_host(host: str) -> str:
raise InvalidURL(f"Invalid IDNA hostname: {host!r}")
def normalize_port(
port: typing.Optional[typing.Union[str, int]], scheme: str
) -> typing.Optional[int]:
def normalize_port(port: str | int | None, scheme: str) -> int | None:
# From https://tools.ietf.org/html/rfc3986#section-3.2.3
#
# "A scheme may define a default port. For example, the "http" scheme
@@ -358,28 +421,27 @@ def normalize_port(
def validate_path(path: str, has_scheme: bool, has_authority: bool) -> None:
"""
Path validation rules that depend on if the URL contains a scheme or authority component.
Path validation rules that depend on whether the URL contains
a scheme or authority component.
See https://datatracker.ietf.org/doc/html/rfc3986.html#section-3.3
"""
if has_authority:
# > If a URI contains an authority component, then the path component
# > must either be empty or begin with a slash ("/") character."
# If a URI contains an authority component, then the path component
# must either be empty or begin with a slash ("/") character."
if path and not path.startswith("/"):
raise InvalidURL("For absolute URLs, path must be empty or begin with '/'")
else:
# > If a URI does not contain an authority component, then the path cannot begin
# > with two slash characters ("//").
if not has_scheme and not has_authority:
# If a URI does not contain an authority component, then the path cannot begin
# with two slash characters ("//").
if path.startswith("//"):
raise InvalidURL(
"URLs with no authority component cannot have a path starting with '//'"
)
# > In addition, a URI reference (Section 4.1) may be a relative-path reference, in which
# > case the first path segment cannot contain a colon (":") character.
if path.startswith(":") and not has_scheme:
raise InvalidURL(
"URLs with no scheme component cannot have a path starting with ':'"
)
raise InvalidURL("Relative URLs cannot have a path starting with '//'")
# In addition, a URI reference (Section 4.1) may be a relative-path reference,
# in which case the first path segment cannot contain a colon (":") character.
if path.startswith(":"):
raise InvalidURL("Relative URLs cannot have a path starting with ':'")
def normalize_path(path: str) -> str:
@@ -390,9 +452,18 @@ def normalize_path(path: str) -> str:
normalize_path("/path/./to/somewhere/..") == "/path/to"
"""
# https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4
# Fast return when no '.' characters in the path.
if "." not in path:
return path
components = path.split("/")
output: typing.List[str] = []
# Fast return when no '.' or '..' components in the path.
if "." not in components and ".." not in components:
return path
# https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4
output: list[str] = []
for component in components:
if component == ".":
pass
@@ -404,59 +475,53 @@ def normalize_path(path: str) -> str:
return "/".join(output)
def percent_encode(char: str) -> str:
def PERCENT(string: str) -> str:
return "".join([f"%{byte:02X}" for byte in string.encode("utf-8")])
def percent_encoded(string: str, safe: str) -> str:
"""
Replace a single character with the percent-encoded representation.
Characters outside the ASCII range are represented with a percent-encoded
representation of their UTF-8 byte sequence.
For example:
percent_encode(" ") == "%20"
Use percent-encoding to quote a string.
"""
return "".join([f"%{byte:02x}" for byte in char.encode("utf-8")]).upper()
NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe
def is_safe(string: str, safe: str = "/") -> bool:
"""
Determine if a given string is already quote-safe.
"""
NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + "%"
# All characters must already be non-escaping or '%'
for char in string:
if char not in NON_ESCAPED_CHARS:
return False
# Any '%' characters must be valid '%xx' escape sequences.
return string.count("%") == len(PERCENT_ENCODED_REGEX.findall(string))
def quote(string: str, safe: str = "/") -> str:
"""
Use percent-encoding to quote a string if required.
"""
if is_safe(string, safe=safe):
# Fast path for strings that don't need escaping.
if not string.rstrip(NON_ESCAPED_CHARS):
return string
NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe
return "".join(
[char if char in NON_ESCAPED_CHARS else percent_encode(char) for char in string]
[char if char in NON_ESCAPED_CHARS else PERCENT(char) for char in string]
)
def urlencode(items: typing.List[typing.Tuple[str, str]]) -> str:
# We can use a much simpler version of the stdlib urlencode here because
# we don't need to handle a bunch of different typing cases, such as bytes vs str.
#
# https://github.com/python/cpython/blob/b2f7b2ef0b5421e01efb8c7bee2ef95d3bab77eb/Lib/urllib/parse.py#L926
#
# Note that we use '%20' encoding for spaces, and treat '/' as a safe
# character. This means our query params have the same escaping as other
# characters in the URL path. This is slightly different to `requests`,
# but is the behaviour that browsers use.
#
# See https://github.com/encode/httpx/issues/2536 and
# https://docs.python.org/3/library/urllib.parse.html#urllib.parse.urlencode
return "&".join([quote(k) + "=" + quote(v) for k, v in items])
def quote(string: str, safe: str) -> str:
"""
Use percent-encoding to quote a string, omitting existing '%xx' escape sequences.
See: https://www.rfc-editor.org/rfc/rfc3986#section-2.1
* `string`: The string to be percent-escaped.
* `safe`: A string containing characters that may be treated as safe, and do not
need to be escaped. Unreserved characters are always treated as safe.
See: https://www.rfc-editor.org/rfc/rfc3986#section-2.3
"""
parts = []
current_position = 0
for match in re.finditer(PERCENT_ENCODED_REGEX, string):
start_position, end_position = match.start(), match.end()
matched_text = match.group(0)
# Add any text up to the '%xx' escape sequence.
if start_position != current_position:
leading_text = string[current_position:start_position]
parts.append(percent_encoded(leading_text, safe=safe))
# Add the '%xx' escape sequence.
parts.append(matched_text)
current_position = end_position
# Add any text after the final '%xx' escape sequence.
if current_position != len(string):
trailing_text = string[current_position:]
parts.append(percent_encoded(trailing_text, safe=safe))
return "".join(parts)

View File

@@ -1,12 +1,16 @@
from __future__ import annotations
import typing
from urllib.parse import parse_qs, unquote
from urllib.parse import parse_qs, unquote, urlencode
import idna
from ._types import QueryParamTypes, RawURL, URLTypes
from ._urlparse import urlencode, urlparse
from ._types import QueryParamTypes
from ._urlparse import urlparse
from ._utils import primitive_value_to_str
__all__ = ["URL", "QueryParams"]
class URL:
"""
@@ -51,26 +55,26 @@ class URL:
assert url.raw_host == b"xn--fiqs8s.icom.museum"
* `url.port` is either None or an integer. URLs that include the default port for
"http", "https", "ws", "wss", and "ftp" schemes have their port normalized to `None`.
"http", "https", "ws", "wss", and "ftp" schemes have their port
normalized to `None`.
assert httpx.URL("http://example.com") == httpx.URL("http://example.com:80")
assert httpx.URL("http://example.com").port is None
assert httpx.URL("http://example.com:80").port is None
* `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work with
`url.username` and `url.password` instead, which handle the URL escaping.
* `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work
with `url.username` and `url.password` instead, which handle the URL escaping.
* `url.raw_path` is raw bytes of both the path and query, without URL escaping.
This portion is used as the target when constructing HTTP requests. Usually you'll
want to work with `url.path` instead.
* `url.query` is raw bytes, without URL escaping. A URL query string portion can only
be properly URL escaped when decoding the parameter names and values themselves.
* `url.query` is raw bytes, without URL escaping. A URL query string portion can
only be properly URL escaped when decoding the parameter names and values
themselves.
"""
def __init__(
self, url: typing.Union["URL", str] = "", **kwargs: typing.Any
) -> None:
def __init__(self, url: URL | str = "", **kwargs: typing.Any) -> None:
if kwargs:
allowed = {
"scheme": str,
@@ -115,7 +119,8 @@ class URL:
self._uri_reference = url._uri_reference.copy_with(**kwargs)
else:
raise TypeError(
f"Invalid type for url. Expected str or httpx.URL, got {type(url)}: {url!r}"
"Invalid type for url. Expected str or httpx.URL,"
f" got {type(url)}: {url!r}"
)
@property
@@ -210,7 +215,7 @@ class URL:
return self._uri_reference.host.encode("ascii")
@property
def port(self) -> typing.Optional[int]:
def port(self) -> int | None:
"""
The URL port as an integer.
@@ -267,7 +272,7 @@ class URL:
return query.encode("ascii")
@property
def params(self) -> "QueryParams":
def params(self) -> QueryParams:
"""
The URL query parameters, neatly parsed and packaged into an immutable
multidict representation.
@@ -299,21 +304,6 @@ class URL:
"""
return unquote(self._uri_reference.fragment or "")
@property
def raw(self) -> RawURL:
"""
Provides the (scheme, host, port, target) for the outgoing request.
In older versions of `httpx` this was used in the low-level transport API.
We no longer use `RawURL`, and this property will be deprecated in a future release.
"""
return RawURL(
self.raw_scheme,
self.raw_host,
self.port,
self.raw_path,
)
@property
def is_absolute_url(self) -> bool:
"""
@@ -334,7 +324,7 @@ class URL:
"""
return not self.is_absolute_url
def copy_with(self, **kwargs: typing.Any) -> "URL":
def copy_with(self, **kwargs: typing.Any) -> URL:
"""
Copy this URL, returning a new URL with some components altered.
Accepts the same set of parameters as the components that are made
@@ -342,24 +332,26 @@ class URL:
For example:
url = httpx.URL("https://www.example.com").copy_with(username="jo@gmail.com", password="a secret")
url = httpx.URL("https://www.example.com").copy_with(
username="jo@gmail.com", password="a secret"
)
assert url == "https://jo%40email.com:a%20secret@www.example.com"
"""
return URL(self, **kwargs)
def copy_set_param(self, key: str, value: typing.Any = None) -> "URL":
def copy_set_param(self, key: str, value: typing.Any = None) -> URL:
return self.copy_with(params=self.params.set(key, value))
def copy_add_param(self, key: str, value: typing.Any = None) -> "URL":
def copy_add_param(self, key: str, value: typing.Any = None) -> URL:
return self.copy_with(params=self.params.add(key, value))
def copy_remove_param(self, key: str) -> "URL":
def copy_remove_param(self, key: str) -> URL:
return self.copy_with(params=self.params.remove(key))
def copy_merge_params(self, params: QueryParamTypes) -> "URL":
def copy_merge_params(self, params: QueryParamTypes) -> URL:
return self.copy_with(params=self.params.merge(params))
def join(self, url: URLTypes) -> "URL":
def join(self, url: URL | str) -> URL:
"""
Return an absolute URL, using this URL as the base.
@@ -408,15 +400,29 @@ class URL:
return f"{self.__class__.__name__}({url!r})"
@property
def raw(self) -> tuple[bytes, bytes, int, bytes]: # pragma: nocover
import collections
import warnings
warnings.warn("URL.raw is deprecated.")
RawURL = collections.namedtuple(
"RawURL", ["raw_scheme", "raw_host", "port", "raw_path"]
)
return RawURL(
raw_scheme=self.raw_scheme,
raw_host=self.raw_host,
port=self.port,
raw_path=self.raw_path,
)
class QueryParams(typing.Mapping[str, str]):
"""
URL query parameters, as a multi-dict.
"""
def __init__(
self, *args: typing.Optional[QueryParamTypes], **kwargs: typing.Any
) -> None:
def __init__(self, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
assert len(args) < 2, "Too many arguments."
assert not (args and kwargs), "Cannot mix named and unnamed arguments."
@@ -428,7 +434,7 @@ class QueryParams(typing.Mapping[str, str]):
elif isinstance(value, QueryParams):
self._dict = {k: list(v) for k, v in value._dict.items()}
else:
dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {}
dict_value: dict[typing.Any, list[typing.Any]] = {}
if isinstance(value, (list, tuple)):
# Convert list inputs like:
# [("a", "123"), ("a", "456"), ("b", "789")]
@@ -489,7 +495,7 @@ class QueryParams(typing.Mapping[str, str]):
"""
return {k: v[0] for k, v in self._dict.items()}.items()
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
def multi_items(self) -> list[tuple[str, str]]:
"""
Return all items in the query params. Allow duplicate keys to occur.
@@ -498,7 +504,7 @@ class QueryParams(typing.Mapping[str, str]):
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
"""
multi_items: typing.List[typing.Tuple[str, str]] = []
multi_items: list[tuple[str, str]] = []
for k, v in self._dict.items():
multi_items.extend([(k, i) for i in v])
return multi_items
@@ -517,7 +523,7 @@ class QueryParams(typing.Mapping[str, str]):
return self._dict[str(key)][0]
return default
def get_list(self, key: str) -> typing.List[str]:
def get_list(self, key: str) -> list[str]:
"""
Get all values from the query param for a given key.
@@ -528,7 +534,7 @@ class QueryParams(typing.Mapping[str, str]):
"""
return list(self._dict.get(str(key), []))
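# For example (a sketch of the multi-value lookup):
#
#     q = httpx.QueryParams("a=123&a=456")
#     q.get_list("a")   # -> ["123", "456"]
#     q.get_list("b")   # -> []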
def set(self, key: str, value: typing.Any = None) -> "QueryParams":
def set(self, key: str, value: typing.Any = None) -> QueryParams:
"""
Return a new QueryParams instance, setting the value of a key.
@@ -543,7 +549,7 @@ class QueryParams(typing.Mapping[str, str]):
q._dict[str(key)] = [primitive_value_to_str(value)]
return q
def add(self, key: str, value: typing.Any = None) -> "QueryParams":
def add(self, key: str, value: typing.Any = None) -> QueryParams:
"""
Return a new QueryParams instance, setting or appending the value of a key.
@@ -558,7 +564,7 @@ class QueryParams(typing.Mapping[str, str]):
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
return q
def remove(self, key: str) -> "QueryParams":
def remove(self, key: str) -> QueryParams:
"""
Return a new QueryParams instance, removing the value of a key.
@@ -573,7 +579,7 @@ class QueryParams(typing.Mapping[str, str]):
q._dict.pop(str(key), None)
return q
def merge(self, params: typing.Optional[QueryParamTypes] = None) -> "QueryParams":
def merge(self, params: QueryParamTypes | None = None) -> QueryParams:
"""
Return a new QueryParams instance, updated with the given query params.
@@ -615,13 +621,6 @@ class QueryParams(typing.Mapping[str, str]):
return sorted(self.multi_items()) == sorted(other.multi_items())
def __str__(self) -> str:
"""
Note that we use '%20' encoding for spaces, and treat '/' as a safe
character.
See https://github.com/encode/httpx/issues/2536 and
https://docs.python.org/3/library/urllib.parse.html#urllib.parse.urlencode
"""
return urlencode(self.multi_items())
def __repr__(self) -> str:
@@ -629,7 +628,7 @@ class QueryParams(typing.Mapping[str, str]):
query_string = str(self)
return f"{class_name}({query_string!r})"
def update(self, params: typing.Optional[QueryParamTypes] = None) -> None:
def update(self, params: QueryParamTypes | None = None) -> None:
raise RuntimeError(
"QueryParams are immutable since 0.18.0. "
"Use `q = q.merge(...)` to create an updated copy."

View File

@@ -1,59 +1,18 @@
import codecs
import email.message
from __future__ import annotations
import ipaddress
import mimetypes
import os
import re
import time
import typing
from pathlib import Path
from urllib.request import getproxies
import sniffio
from ._types import PrimitiveData
if typing.TYPE_CHECKING: # pragma: no cover
from ._urls import URL
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
)
_HTML5_FORM_ENCODING_RE = re.compile(
r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
)
def normalize_header_key(
value: typing.Union[str, bytes],
lower: bool,
encoding: typing.Optional[str] = None,
) -> bytes:
"""
Coerce str/bytes into a strictly byte-wise HTTP header key.
"""
if isinstance(value, bytes):
bytes_value = value
else:
bytes_value = value.encode(encoding or "ascii")
return bytes_value.lower() if lower else bytes_value
def normalize_header_value(
value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
"""
Coerce str/bytes into a strictly byte-wise HTTP header value.
"""
if isinstance(value, bytes):
return value
return value.encode(encoding or "ascii")
def primitive_value_to_str(value: "PrimitiveData") -> str:
def primitive_value_to_str(value: PrimitiveData) -> str:
"""
Coerce a primitive data type into a string value.
@@ -68,166 +27,7 @@ def primitive_value_to_str(value: "PrimitiveData") -> str:
return str(value)
def is_known_encoding(encoding: str) -> bool:
"""
Return `True` if `encoding` is a known codec.
"""
try:
codecs.lookup(encoding)
except LookupError:
return False
return True
def format_form_param(name: str, value: str) -> bytes:
"""
Encode a name/value pair within a multipart form.
"""
def replacer(match: typing.Match[str]) -> str:
return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
return f'{name}="{value}"'.encode()
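# For example (hand-worked; double quotes become "%22" per the HTML5
# replacement table above):
#
#     format_form_param("name", 'a "quoted" value')
#     # -> b'name="a %22quoted%22 value"'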
# Null bytes; no need to recreate these on each call to guess_json_utf
_null = b"\x00"
_null2 = _null * 2
_null3 = _null * 3
def guess_json_utf(data: bytes) -> typing.Optional[str]:
# JSON always starts with two ASCII characters, so detection is as
# easy as counting the nulls: their location and count determine the
# encoding. Also detect a BOM, if present.
sample = data[:4]
if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE):
return "utf-32" # BOM included
if sample[:3] == codecs.BOM_UTF8:
return "utf-8-sig" # BOM included, MS style (discouraged)
if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
return "utf-16" # BOM included
nullcount = sample.count(_null)
if nullcount == 0:
return "utf-8"
if nullcount == 2:
if sample[::2] == _null2: # 1st and 3rd are null
return "utf-16-be"
if sample[1::2] == _null2: # 2nd and 4th are null
return "utf-16-le"
# Did not detect 2 valid UTF-16 ascii-range characters
if nullcount == 3:
if sample[:3] == _null3:
return "utf-32-be"
if sample[1:] == _null3:
return "utf-32-le"
# Did not detect a valid UTF-32 ascii-range character
return None
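# A hand-worked check of the heuristic on a small JSON document:
#
#     guess_json_utf(b'{"a": 1}')                       # -> "utf-8" (no nulls)
#     guess_json_utf('{"a": 1}'.encode("utf-16-le"))    # -> "utf-16-le"
#     guess_json_utf('{"a": 1}'.encode("utf-32-be"))    # -> "utf-32-be"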
def get_ca_bundle_from_env() -> typing.Optional[str]:
if "SSL_CERT_FILE" in os.environ:
ssl_file = Path(os.environ["SSL_CERT_FILE"])
if ssl_file.is_file():
return str(ssl_file)
if "SSL_CERT_DIR" in os.environ:
ssl_path = Path(os.environ["SSL_CERT_DIR"])
if ssl_path.is_dir():
return str(ssl_path)
return None
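# Behaviour sketch: SSL_CERT_FILE wins over SSL_CERT_DIR, and either is
# ignored unless it actually exists on disk, e.g. (assumed path):
#
#     os.environ["SSL_CERT_FILE"] = "/etc/ssl/certs/ca-certificates.crt"
#     get_ca_bundle_from_env()   # -> that path, if the file exists; else None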
def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]:
"""
Returns a list of parsed link headers; for more info see:
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
The generic syntax of those is:
Link: < uri-reference >; param1=value1; param2="value2"
So for instance:
Link: '<http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;'
would return
[
{"url": "http:/.../front.jpeg", "type": "image/jpeg"},
{"url": "http://.../back.jpeg"},
]
:param value: HTTP Link entity-header field
:return: list of parsed link headers
"""
links: typing.List[typing.Dict[str, str]] = []
replace_chars = " '\""
value = value.strip(replace_chars)
if not value:
return links
for val in re.split(", *<", value):
try:
url, params = val.split(";", 1)
except ValueError:
url, params = val, ""
link = {"url": url.strip("<> '\"")}
for param in params.split(";"):
try:
key, value = param.split("=")
except ValueError:
break
link[key.strip(replace_chars)] = value.strip(replace_chars)
links.append(link)
return links
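# For instance, a typical pagination header (hand-worked sketch):
#
#     parse_header_links('<https://api.example.com/items?page=2>; rel="next"')
#     # -> [{"url": "https://api.example.com/items?page=2", "rel": "next"}]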
def parse_content_type_charset(content_type: str) -> typing.Optional[str]:
# We used to use `cgi.parse_header()` here, but `cgi` became a dead battery.
# See: https://peps.python.org/pep-0594/#cgi
msg = email.message.Message()
msg["content-type"] = content_type
return msg.get_content_charset(failobj=None)
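# For example (note that the charset is normalised to lowercase):
#
#     parse_content_type_charset("text/html; charset=UTF-8")   # -> "utf-8"
#     parse_content_type_charset("application/json")           # -> None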
SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}
def obfuscate_sensitive_headers(
items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
for k, v in items:
if to_str(k.lower()) in SENSITIVE_HEADERS:
v = to_bytes_or_str("[secure]", match_type_of=v)
yield k, v
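# For example (values redacted, key casing preserved; a sketch):
#
#     list(obfuscate_sensitive_headers([("Authorization", "Bearer abc123")]))
#     # -> [("Authorization", "[secure]")]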
def port_or_default(url: "URL") -> typing.Optional[int]:
if url.port is not None:
return url.port
return {"http": 80, "https": 443}.get(url.scheme)
def same_origin(url: "URL", other: "URL") -> bool:
"""
Return 'True' if the given URLs share the same origin.
"""
return (
url.scheme == other.scheme
and url.host == other.host
and port_or_default(url) == port_or_default(other)
)
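# Default ports take part in the comparison via `port_or_default` (sketch):
#
#     same_origin(URL("https://example.com"), URL("https://example.com:443/x"))  # True
#     same_origin(URL("https://example.com"), URL("http://example.com"))         # False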
def is_https_redirect(url: "URL", location: "URL") -> bool:
"""
Return 'True' if 'location' is an HTTPS upgrade of 'url'.
"""
if url.host != location.host:
return False
return (
url.scheme == "http"
and port_or_default(url) == 80
and location.scheme == "https"
and port_or_default(location) == 443
)
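# For example (only a default-port http -> https hop on the same host counts):
#
#     is_https_redirect(URL("http://example.com/a"), URL("https://example.com/a"))       # True
#     is_https_redirect(URL("http://example.com:8080/a"), URL("https://example.com/a"))  # False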
def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
def get_environment_proxies() -> dict[str, str | None]:
"""Gets proxy information from the environment"""
# urllib.request.getproxies() falls back on System
@@ -235,7 +35,7 @@ def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
# We don't want to propagate non-HTTP proxies into
# our configuration such as 'TRAVIS_APT_PROXY'.
proxy_info = getproxies()
mounts: typing.Dict[str, typing.Optional[str]] = {}
mounts: dict[str, str | None] = {}
for scheme in ("http", "https", "all"):
if proxy_info.get(scheme):
@@ -262,7 +62,9 @@ def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
# (But not "wwwgoogle.com")
# NO_PROXY can include domains, IPv6, IPv4 addresses and "localhost"
# NO_PROXY=example.com,::1,localhost,192.168.0.0/16
if is_ipv4_hostname(hostname):
if "://" in hostname:
mounts[hostname] = None
elif is_ipv4_hostname(hostname):
mounts[f"all://{hostname}"] = None
elif is_ipv6_hostname(hostname):
mounts[f"all://[{hostname}]"] = None
@@ -274,11 +76,11 @@ def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
return mounts
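# An environment sketch (assumed variables; the plain-domain NO_PROXY branch
# is elided above):
#
#     HTTPS_PROXY=http://localhost:8030  NO_PROXY=example.com,127.0.0.1
#
# would yield roughly:
#
#     {"https://": "http://localhost:8030",
#      "all://*example.com": None,
#      "all://127.0.0.1": None}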
def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
def to_bytes(value: str | bytes, encoding: str = "utf-8") -> bytes:
return value.encode(encoding) if isinstance(value, str) else value
def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
def to_str(value: str | bytes, encoding: str = "utf-8") -> str:
return value if isinstance(value, str) else value.decode(encoding)
@@ -290,13 +92,7 @@ def unquote(value: str) -> str:
return value[1:-1] if value[0] == value[-1] == '"' else value
def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
if filename:
return mimetypes.guess_type(filename)[0] or "application/octet-stream"
return None
def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]:
def peek_filelike_length(stream: typing.Any) -> int | None:
"""
Given a file-like stream object, return its length in number of bytes
without reading it into memory.
@@ -321,48 +117,17 @@ def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]:
return length
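# For example (a sketch; the length comes from `os.fstat`/seek-style
# introspection elided above, never from reading the stream):
#
#     with open("data.bin", "rb") as f:    # assumed file
#         size = peek_filelike_length(f)   # -> file size in bytes, stream unread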
class Timer:
async def _get_time(self) -> float:
library = sniffio.current_async_library()
if library == "trio":
import trio
return trio.current_time()
elif library == "curio": # pragma: no cover
import curio
return typing.cast(float, await curio.clock())
import asyncio
return asyncio.get_event_loop().time()
def sync_start(self) -> None:
self.started = time.perf_counter()
async def async_start(self) -> None:
self.started = await self._get_time()
def sync_elapsed(self) -> float:
now = time.perf_counter()
return now - self.started
async def async_elapsed(self) -> float:
now = await self._get_time()
return now - self.started
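# Usage sketch (sync flavour; the async variant awaits the library clock):
#
#     timer = Timer()
#     timer.sync_start()
#     do_work()                        # assumed workload
#     seconds = timer.sync_elapsed()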
class URLPattern:
"""
A utility class currently used for making lookups against proxy keys...
# Wildcard matching...
>>> pattern = URLPattern("all")
>>> pattern = URLPattern("all://")
>>> pattern.matches(httpx.URL("http://example.com"))
True
# With scheme matching...
>>> pattern = URLPattern("https")
>>> pattern = URLPattern("https://")
>>> pattern.matches(httpx.URL("https://example.com"))
True
>>> pattern.matches(httpx.URL("http://example.com"))
@@ -410,7 +175,7 @@ class URLPattern:
self.host = "" if url.host == "*" else url.host
self.port = url.port
if not url.host or url.host == "*":
self.host_regex: typing.Optional[typing.Pattern[str]] = None
self.host_regex: typing.Pattern[str] | None = None
elif url.host.startswith("*."):
# *.example.com should match "www.example.com", but not "example.com"
domain = re.escape(url.host[2:])
@@ -424,7 +189,7 @@ class URLPattern:
domain = re.escape(url.host)
self.host_regex = re.compile(f"^{domain}$")
def matches(self, other: "URL") -> bool:
def matches(self, other: URL) -> bool:
if self.scheme and self.scheme != other.scheme:
return False
if (
@@ -438,7 +203,7 @@ class URLPattern:
return True
@property
def priority(self) -> typing.Tuple[int, int, int]:
def priority(self) -> tuple[int, int, int]:
"""
The priority allows URLPattern instances to be sortable, so that
we can match from most specific to least specific.
@@ -454,7 +219,7 @@ class URLPattern:
def __hash__(self) -> int:
return hash(self.pattern)
def __lt__(self, other: "URLPattern") -> bool:
def __lt__(self, other: URLPattern) -> bool:
return self.priority < other.priority
def __eq__(self, other: typing.Any) -> bool: