updates
This commit is contained in:
@@ -1,48 +1,15 @@
|
||||
from .__version__ import __description__, __title__, __version__
|
||||
from ._api import delete, get, head, options, patch, post, put, request, stream
|
||||
from ._auth import Auth, BasicAuth, DigestAuth, NetRCAuth
|
||||
from ._client import USE_CLIENT_DEFAULT, AsyncClient, Client
|
||||
from ._config import Limits, Proxy, Timeout, create_ssl_context
|
||||
from ._content import ByteStream
|
||||
from ._exceptions import (
|
||||
CloseError,
|
||||
ConnectError,
|
||||
ConnectTimeout,
|
||||
CookieConflict,
|
||||
DecodingError,
|
||||
HTTPError,
|
||||
HTTPStatusError,
|
||||
InvalidURL,
|
||||
LocalProtocolError,
|
||||
NetworkError,
|
||||
PoolTimeout,
|
||||
ProtocolError,
|
||||
ProxyError,
|
||||
ReadError,
|
||||
ReadTimeout,
|
||||
RemoteProtocolError,
|
||||
RequestError,
|
||||
RequestNotRead,
|
||||
ResponseNotRead,
|
||||
StreamClosed,
|
||||
StreamConsumed,
|
||||
StreamError,
|
||||
TimeoutException,
|
||||
TooManyRedirects,
|
||||
TransportError,
|
||||
UnsupportedProtocol,
|
||||
WriteError,
|
||||
WriteTimeout,
|
||||
)
|
||||
from ._models import Cookies, Headers, Request, Response
|
||||
from ._status_codes import codes
|
||||
from ._transports.asgi import ASGITransport
|
||||
from ._transports.base import AsyncBaseTransport, BaseTransport
|
||||
from ._transports.default import AsyncHTTPTransport, HTTPTransport
|
||||
from ._transports.mock import MockTransport
|
||||
from ._transports.wsgi import WSGITransport
|
||||
from ._types import AsyncByteStream, SyncByteStream
|
||||
from ._urls import URL, QueryParams
|
||||
from ._api import *
|
||||
from ._auth import *
|
||||
from ._client import *
|
||||
from ._config import *
|
||||
from ._content import *
|
||||
from ._exceptions import *
|
||||
from ._models import *
|
||||
from ._status_codes import *
|
||||
from ._transports import *
|
||||
from ._types import *
|
||||
from ._urls import *
|
||||
|
||||
try:
|
||||
from ._main import main
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,3 +1,3 @@
|
||||
__title__ = "httpx"
|
||||
__description__ = "A next generation HTTP client, for Python 3."
|
||||
__version__ = "0.24.1"
|
||||
__version__ = "0.28.1"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
|
||||
@@ -6,37 +8,50 @@ from ._config import DEFAULT_TIMEOUT_CONFIG
|
||||
from ._models import Response
|
||||
from ._types import (
|
||||
AuthTypes,
|
||||
CertTypes,
|
||||
CookieTypes,
|
||||
HeaderTypes,
|
||||
ProxiesTypes,
|
||||
ProxyTypes,
|
||||
QueryParamTypes,
|
||||
RequestContent,
|
||||
RequestData,
|
||||
RequestFiles,
|
||||
TimeoutTypes,
|
||||
URLTypes,
|
||||
VerifyTypes,
|
||||
)
|
||||
from ._urls import URL
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl # pragma: no cover
|
||||
|
||||
|
||||
__all__ = [
|
||||
"delete",
|
||||
"get",
|
||||
"head",
|
||||
"options",
|
||||
"patch",
|
||||
"post",
|
||||
"put",
|
||||
"request",
|
||||
"stream",
|
||||
]
|
||||
|
||||
|
||||
def request(
|
||||
method: str,
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
content: typing.Optional[RequestContent] = None,
|
||||
data: typing.Optional[RequestData] = None,
|
||||
files: typing.Optional[RequestFiles] = None,
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
content: RequestContent | None = None,
|
||||
data: RequestData | None = None,
|
||||
files: RequestFiles | None = None,
|
||||
json: typing.Any | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
follow_redirects: bool = False,
|
||||
verify: VerifyTypes = True,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
"""
|
||||
@@ -63,18 +78,13 @@ def request(
|
||||
request.
|
||||
* **auth** - *(optional)* An authentication class to use when sending the
|
||||
request.
|
||||
* **proxies** - *(optional)* A dictionary mapping proxy keys to proxy URLs.
|
||||
* **proxy** - *(optional)* A proxy URL where all the traffic should be routed.
|
||||
* **timeout** - *(optional)* The timeout configuration to use when sending
|
||||
the request.
|
||||
* **follow_redirects** - *(optional)* Enables or disables HTTP redirects.
|
||||
* **verify** - *(optional)* SSL certificates (a.k.a CA bundle) used to
|
||||
verify the identity of requested hosts. Either `True` (default CA bundle),
|
||||
a path to an SSL certificate file, an `ssl.SSLContext`, or `False`
|
||||
(which will disable verification).
|
||||
* **cert** - *(optional)* An SSL certificate used by the requested host
|
||||
to authenticate the client. Either a path to an SSL certificate file, or
|
||||
two-tuple of (certificate file, key file), or a three-tuple of (certificate
|
||||
file, key file, password).
|
||||
* **verify** - *(optional)* Either `True` to use an SSL context with the
|
||||
default CA bundle, `False` to disable verification, or an instance of
|
||||
`ssl.SSLContext` to use a custom context.
|
||||
* **trust_env** - *(optional)* Enables or disables usage of environment
|
||||
variables for configuration.
|
||||
|
||||
@@ -91,8 +101,7 @@ def request(
|
||||
"""
|
||||
with Client(
|
||||
cookies=cookies,
|
||||
proxies=proxies,
|
||||
cert=cert,
|
||||
proxy=proxy,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
@@ -114,21 +123,20 @@ def request(
|
||||
@contextmanager
|
||||
def stream(
|
||||
method: str,
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
content: typing.Optional[RequestContent] = None,
|
||||
data: typing.Optional[RequestData] = None,
|
||||
files: typing.Optional[RequestFiles] = None,
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
content: RequestContent | None = None,
|
||||
data: RequestData | None = None,
|
||||
files: RequestFiles | None = None,
|
||||
json: typing.Any | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
follow_redirects: bool = False,
|
||||
verify: VerifyTypes = True,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
trust_env: bool = True,
|
||||
) -> typing.Iterator[Response]:
|
||||
"""
|
||||
@@ -143,8 +151,7 @@ def stream(
|
||||
"""
|
||||
with Client(
|
||||
cookies=cookies,
|
||||
proxies=proxies,
|
||||
cert=cert,
|
||||
proxy=proxy,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
@@ -165,16 +172,15 @@ def stream(
|
||||
|
||||
|
||||
def get(
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
follow_redirects: bool = False,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
@@ -193,9 +199,8 @@ def get(
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
proxies=proxies,
|
||||
proxy=proxy,
|
||||
follow_redirects=follow_redirects,
|
||||
cert=cert,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
@@ -203,16 +208,15 @@ def get(
|
||||
|
||||
|
||||
def options(
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
follow_redirects: bool = False,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
@@ -231,9 +235,8 @@ def options(
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
proxies=proxies,
|
||||
proxy=proxy,
|
||||
follow_redirects=follow_redirects,
|
||||
cert=cert,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
@@ -241,16 +244,15 @@ def options(
|
||||
|
||||
|
||||
def head(
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
follow_redirects: bool = False,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
@@ -269,9 +271,8 @@ def head(
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
proxies=proxies,
|
||||
proxy=proxy,
|
||||
follow_redirects=follow_redirects,
|
||||
cert=cert,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
@@ -279,20 +280,19 @@ def head(
|
||||
|
||||
|
||||
def post(
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
content: typing.Optional[RequestContent] = None,
|
||||
data: typing.Optional[RequestData] = None,
|
||||
files: typing.Optional[RequestFiles] = None,
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
content: RequestContent | None = None,
|
||||
data: RequestData | None = None,
|
||||
files: RequestFiles | None = None,
|
||||
json: typing.Any | None = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
follow_redirects: bool = False,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
@@ -312,9 +312,8 @@ def post(
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
proxies=proxies,
|
||||
proxy=proxy,
|
||||
follow_redirects=follow_redirects,
|
||||
cert=cert,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
@@ -322,20 +321,19 @@ def post(
|
||||
|
||||
|
||||
def put(
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
content: typing.Optional[RequestContent] = None,
|
||||
data: typing.Optional[RequestData] = None,
|
||||
files: typing.Optional[RequestFiles] = None,
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
content: RequestContent | None = None,
|
||||
data: RequestData | None = None,
|
||||
files: RequestFiles | None = None,
|
||||
json: typing.Any | None = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
follow_redirects: bool = False,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
@@ -355,9 +353,8 @@ def put(
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
proxies=proxies,
|
||||
proxy=proxy,
|
||||
follow_redirects=follow_redirects,
|
||||
cert=cert,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
@@ -365,20 +362,19 @@ def put(
|
||||
|
||||
|
||||
def patch(
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
content: typing.Optional[RequestContent] = None,
|
||||
data: typing.Optional[RequestData] = None,
|
||||
files: typing.Optional[RequestFiles] = None,
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
content: RequestContent | None = None,
|
||||
data: RequestData | None = None,
|
||||
files: RequestFiles | None = None,
|
||||
json: typing.Any | None = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
follow_redirects: bool = False,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
@@ -398,9 +394,8 @@ def patch(
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
proxies=proxies,
|
||||
proxy=proxy,
|
||||
follow_redirects=follow_redirects,
|
||||
cert=cert,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
@@ -408,17 +403,16 @@ def patch(
|
||||
|
||||
|
||||
def delete(
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
auth: typing.Optional[AuthTypes] = None,
|
||||
proxies: typing.Optional[ProxiesTypes] = None,
|
||||
params: QueryParamTypes | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
auth: AuthTypes | None = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
follow_redirects: bool = False,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
"""
|
||||
@@ -436,9 +430,8 @@ def delete(
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
proxies=proxies,
|
||||
proxy=proxy,
|
||||
follow_redirects=follow_redirects,
|
||||
cert=cert,
|
||||
verify=verify,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import netrc
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
@@ -8,13 +9,16 @@ from base64 import b64encode
|
||||
from urllib.request import parse_http_list
|
||||
|
||||
from ._exceptions import ProtocolError
|
||||
from ._models import Request, Response
|
||||
from ._models import Cookies, Request, Response
|
||||
from ._utils import to_bytes, to_str, unquote
|
||||
|
||||
if typing.TYPE_CHECKING: # pragma: no cover
|
||||
from hashlib import _Hash
|
||||
|
||||
|
||||
__all__ = ["Auth", "BasicAuth", "DigestAuth", "NetRCAuth"]
|
||||
|
||||
|
||||
class Auth:
|
||||
"""
|
||||
Base class for all authentication schemes.
|
||||
@@ -125,18 +129,14 @@ class BasicAuth(Auth):
|
||||
and uses HTTP Basic authentication.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
):
|
||||
def __init__(self, username: str | bytes, password: str | bytes) -> None:
|
||||
self._auth_header = self._build_auth_header(username, password)
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
request.headers["Authorization"] = self._auth_header
|
||||
yield request
|
||||
|
||||
def _build_auth_header(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> str:
|
||||
def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
|
||||
userpass = b":".join((to_bytes(username), to_bytes(password)))
|
||||
token = b64encode(userpass).decode()
|
||||
return f"Basic {token}"
|
||||
@@ -147,7 +147,11 @@ class NetRCAuth(Auth):
|
||||
Use a 'netrc' file to lookup basic auth credentials based on the url host.
|
||||
"""
|
||||
|
||||
def __init__(self, file: typing.Optional[str] = None):
|
||||
def __init__(self, file: str | None = None) -> None:
|
||||
# Lazily import 'netrc'.
|
||||
# There's no need for us to load this module unless 'NetRCAuth' is being used.
|
||||
import netrc
|
||||
|
||||
self._netrc_info = netrc.netrc(file)
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
@@ -162,16 +166,14 @@ class NetRCAuth(Auth):
|
||||
)
|
||||
yield request
|
||||
|
||||
def _build_auth_header(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> str:
|
||||
def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
|
||||
userpass = b":".join((to_bytes(username), to_bytes(password)))
|
||||
token = b64encode(userpass).decode()
|
||||
return f"Basic {token}"
|
||||
|
||||
|
||||
class DigestAuth(Auth):
|
||||
_ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable[[bytes], "_Hash"]] = {
|
||||
_ALGORITHM_TO_HASH_FUNCTION: dict[str, typing.Callable[[bytes], _Hash]] = {
|
||||
"MD5": hashlib.md5,
|
||||
"MD5-SESS": hashlib.md5,
|
||||
"SHA": hashlib.sha1,
|
||||
@@ -182,12 +184,10 @@ class DigestAuth(Auth):
|
||||
"SHA-512-SESS": hashlib.sha512,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> None:
|
||||
def __init__(self, username: str | bytes, password: str | bytes) -> None:
|
||||
self._username = to_bytes(username)
|
||||
self._password = to_bytes(password)
|
||||
self._last_challenge: typing.Optional[_DigestAuthChallenge] = None
|
||||
self._last_challenge: _DigestAuthChallenge | None = None
|
||||
self._nonce_count = 1
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
@@ -217,11 +217,13 @@ class DigestAuth(Auth):
|
||||
request.headers["Authorization"] = self._build_auth_header(
|
||||
request, self._last_challenge
|
||||
)
|
||||
if response.cookies:
|
||||
Cookies(response.cookies).set_cookie_header(request=request)
|
||||
yield request
|
||||
|
||||
def _parse_challenge(
|
||||
self, request: Request, response: Response, auth_header: str
|
||||
) -> "_DigestAuthChallenge":
|
||||
) -> _DigestAuthChallenge:
|
||||
"""
|
||||
Returns a challenge from a Digest WWW-Authenticate header.
|
||||
These take the form of:
|
||||
@@ -232,7 +234,7 @@ class DigestAuth(Auth):
|
||||
# This method should only ever have been called with a Digest auth header.
|
||||
assert scheme.lower() == "digest"
|
||||
|
||||
header_dict: typing.Dict[str, str] = {}
|
||||
header_dict: dict[str, str] = {}
|
||||
for field in parse_http_list(fields):
|
||||
key, value = field.strip().split("=", 1)
|
||||
header_dict[key] = unquote(value)
|
||||
@@ -251,7 +253,7 @@ class DigestAuth(Auth):
|
||||
raise ProtocolError(message, request=request) from exc
|
||||
|
||||
def _build_auth_header(
|
||||
self, request: Request, challenge: "_DigestAuthChallenge"
|
||||
self, request: Request, challenge: _DigestAuthChallenge
|
||||
) -> str:
|
||||
hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()]
|
||||
|
||||
@@ -275,17 +277,18 @@ class DigestAuth(Auth):
|
||||
|
||||
qop = self._resolve_qop(challenge.qop, request=request)
|
||||
if qop is None:
|
||||
# Following RFC 2069
|
||||
digest_data = [HA1, challenge.nonce, HA2]
|
||||
else:
|
||||
digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2]
|
||||
key_digest = b":".join(digest_data)
|
||||
# Following RFC 2617/7616
|
||||
digest_data = [HA1, challenge.nonce, nc_value, cnonce, qop, HA2]
|
||||
|
||||
format_args = {
|
||||
"username": self._username,
|
||||
"realm": challenge.realm,
|
||||
"nonce": challenge.nonce,
|
||||
"uri": path,
|
||||
"response": digest(b":".join((HA1, key_digest))),
|
||||
"response": digest(b":".join(digest_data)),
|
||||
"algorithm": challenge.algorithm.encode(),
|
||||
}
|
||||
if challenge.opaque:
|
||||
@@ -305,7 +308,7 @@ class DigestAuth(Auth):
|
||||
|
||||
return hashlib.sha1(s).hexdigest()[:16].encode()
|
||||
|
||||
def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
|
||||
def _get_header_value(self, header_fields: dict[str, bytes]) -> str:
|
||||
NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
|
||||
QUOTED_TEMPLATE = '{}="{}"'
|
||||
NON_QUOTED_TEMPLATE = "{}={}"
|
||||
@@ -323,9 +326,7 @@ class DigestAuth(Auth):
|
||||
|
||||
return header_value
|
||||
|
||||
def _resolve_qop(
|
||||
self, qop: typing.Optional[bytes], request: Request
|
||||
) -> typing.Optional[bytes]:
|
||||
def _resolve_qop(self, qop: bytes | None, request: Request) -> bytes | None:
|
||||
if qop is None:
|
||||
return None
|
||||
qops = re.split(b", ?", qop)
|
||||
@@ -343,5 +344,5 @@ class _DigestAuthChallenge(typing.NamedTuple):
|
||||
realm: bytes
|
||||
nonce: bytes
|
||||
algorithm: str
|
||||
opaque: typing.Optional[bytes]
|
||||
qop: typing.Optional[bytes]
|
||||
opaque: bytes | None
|
||||
qop: bytes | None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,43 +0,0 @@
|
||||
"""
|
||||
The _compat module is used for code which requires branching between different
|
||||
Python environments. It is excluded from the code coverage checks.
|
||||
"""
|
||||
import ssl
|
||||
import sys
|
||||
|
||||
# Brotli support is optional
|
||||
# The C bindings in `brotli` are recommended for CPython.
|
||||
# The CFFI bindings in `brotlicffi` are recommended for PyPy and everything else.
|
||||
try:
|
||||
import brotlicffi as brotli
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
import brotli
|
||||
except ImportError:
|
||||
brotli = None
|
||||
|
||||
if sys.version_info >= (3, 10) or (
|
||||
sys.version_info >= (3, 7) and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0, 7)
|
||||
):
|
||||
|
||||
def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None:
|
||||
# The OP_NO_SSL* and OP_NO_TLS* become deprecated in favor of
|
||||
# 'SSLContext.minimum_version' from Python 3.7 onwards, however
|
||||
# this attribute is not available unless the ssl module is compiled
|
||||
# with OpenSSL 1.1.0g or newer.
|
||||
# https://docs.python.org/3.10/library/ssl.html#ssl.SSLContext.minimum_version
|
||||
# https://docs.python.org/3.7/library/ssl.html#ssl.SSLContext.minimum_version
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
|
||||
else:
|
||||
|
||||
def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None:
|
||||
# If 'minimum_version' isn't available, we configure these options with
|
||||
# the older deprecated variants.
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_TLSv1
|
||||
context.options |= ssl.OP_NO_TLSv1_1
|
||||
|
||||
|
||||
__all__ = ["brotli", "set_minimum_tls_version_1_2"]
|
||||
@@ -1,39 +1,16 @@
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
import certifi
|
||||
|
||||
from ._compat import set_minimum_tls_version_1_2
|
||||
from ._models import Headers
|
||||
from ._types import CertTypes, HeaderTypes, TimeoutTypes, URLTypes, VerifyTypes
|
||||
from ._types import CertTypes, HeaderTypes, TimeoutTypes
|
||||
from ._urls import URL
|
||||
from ._utils import get_ca_bundle_from_env
|
||||
|
||||
DEFAULT_CIPHERS = ":".join(
|
||||
[
|
||||
"ECDHE+AESGCM",
|
||||
"ECDHE+CHACHA20",
|
||||
"DHE+AESGCM",
|
||||
"DHE+CHACHA20",
|
||||
"ECDH+AESGCM",
|
||||
"DH+AESGCM",
|
||||
"ECDH+AES",
|
||||
"DH+AES",
|
||||
"RSA+AESGCM",
|
||||
"RSA+AES",
|
||||
"!aNULL",
|
||||
"!eNULL",
|
||||
"!MD5",
|
||||
"!DSS",
|
||||
]
|
||||
)
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger("httpx")
|
||||
__all__ = ["Limits", "Proxy", "Timeout", "create_ssl_context"]
|
||||
|
||||
|
||||
class UnsetType:
|
||||
@@ -44,152 +21,52 @@ UNSET = UnsetType()
|
||||
|
||||
|
||||
def create_ssl_context(
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
cert: CertTypes | None = None,
|
||||
trust_env: bool = True,
|
||||
http2: bool = False,
|
||||
) -> ssl.SSLContext:
|
||||
return SSLConfig(
|
||||
cert=cert, verify=verify, trust_env=trust_env, http2=http2
|
||||
).ssl_context
|
||||
import ssl
|
||||
import warnings
|
||||
|
||||
import certifi
|
||||
|
||||
class SSLConfig:
|
||||
"""
|
||||
SSL Configuration.
|
||||
"""
|
||||
|
||||
DEFAULT_CA_BUNDLE_PATH = Path(certifi.where())
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
trust_env: bool = True,
|
||||
http2: bool = False,
|
||||
):
|
||||
self.cert = cert
|
||||
self.verify = verify
|
||||
self.trust_env = trust_env
|
||||
self.http2 = http2
|
||||
self.ssl_context = self.load_ssl_context()
|
||||
|
||||
def load_ssl_context(self) -> ssl.SSLContext:
|
||||
logger.debug(
|
||||
"load_ssl_context verify=%r cert=%r trust_env=%r http2=%r",
|
||||
self.verify,
|
||||
self.cert,
|
||||
self.trust_env,
|
||||
self.http2,
|
||||
)
|
||||
|
||||
if self.verify:
|
||||
return self.load_ssl_context_verify()
|
||||
return self.load_ssl_context_no_verify()
|
||||
|
||||
def load_ssl_context_no_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for unverified connections.
|
||||
"""
|
||||
context = self._create_default_ssl_context()
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
self._load_client_certs(context)
|
||||
return context
|
||||
|
||||
def load_ssl_context_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for verified connections.
|
||||
"""
|
||||
if self.trust_env and self.verify is True:
|
||||
ca_bundle = get_ca_bundle_from_env()
|
||||
if ca_bundle is not None:
|
||||
self.verify = ca_bundle
|
||||
|
||||
if isinstance(self.verify, ssl.SSLContext):
|
||||
# Allow passing in our own SSLContext object that's pre-configured.
|
||||
context = self.verify
|
||||
self._load_client_certs(context)
|
||||
return context
|
||||
elif isinstance(self.verify, bool):
|
||||
ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH
|
||||
elif Path(self.verify).exists():
|
||||
ca_bundle_path = Path(self.verify)
|
||||
if verify is True:
|
||||
if trust_env and os.environ.get("SSL_CERT_FILE"): # pragma: nocover
|
||||
ctx = ssl.create_default_context(cafile=os.environ["SSL_CERT_FILE"])
|
||||
elif trust_env and os.environ.get("SSL_CERT_DIR"): # pragma: nocover
|
||||
ctx = ssl.create_default_context(capath=os.environ["SSL_CERT_DIR"])
|
||||
else:
|
||||
raise IOError(
|
||||
"Could not find a suitable TLS CA certificate bundle, "
|
||||
"invalid path: {}".format(self.verify)
|
||||
)
|
||||
# Default case...
|
||||
ctx = ssl.create_default_context(cafile=certifi.where())
|
||||
elif verify is False:
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
elif isinstance(verify, str): # pragma: nocover
|
||||
message = (
|
||||
"`verify=<str>` is deprecated. "
|
||||
"Use `verify=ssl.create_default_context(cafile=...)` "
|
||||
"or `verify=ssl.create_default_context(capath=...)` instead."
|
||||
)
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
if os.path.isdir(verify):
|
||||
return ssl.create_default_context(capath=verify)
|
||||
return ssl.create_default_context(cafile=verify)
|
||||
else:
|
||||
ctx = verify
|
||||
|
||||
context = self._create_default_ssl_context()
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
context.check_hostname = True
|
||||
if cert: # pragma: nocover
|
||||
message = (
|
||||
"`cert=...` is deprecated. Use `verify=<ssl_context>` instead,"
|
||||
"with `.load_cert_chain()` to configure the certificate chain."
|
||||
)
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
if isinstance(cert, str):
|
||||
ctx.load_cert_chain(cert)
|
||||
else:
|
||||
ctx.load_cert_chain(*cert)
|
||||
|
||||
# Signal to server support for PHA in TLS 1.3. Raises an
|
||||
# AttributeError if only read-only access is implemented.
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
try:
|
||||
context.post_handshake_auth = True
|
||||
except AttributeError: # pragma: no cover
|
||||
pass
|
||||
|
||||
# Disable using 'commonName' for SSLContext.check_hostname
|
||||
# when the 'subjectAltName' extension isn't available.
|
||||
try:
|
||||
context.hostname_checks_common_name = False
|
||||
except AttributeError: # pragma: no cover
|
||||
pass
|
||||
|
||||
if ca_bundle_path.is_file():
|
||||
cafile = str(ca_bundle_path)
|
||||
logger.debug("load_verify_locations cafile=%r", cafile)
|
||||
context.load_verify_locations(cafile=cafile)
|
||||
elif ca_bundle_path.is_dir():
|
||||
capath = str(ca_bundle_path)
|
||||
logger.debug("load_verify_locations capath=%r", capath)
|
||||
context.load_verify_locations(capath=capath)
|
||||
|
||||
self._load_client_certs(context)
|
||||
|
||||
return context
|
||||
|
||||
def _create_default_ssl_context(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Creates the default SSLContext object that's used for both verified
|
||||
and unverified connections.
|
||||
"""
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
set_minimum_tls_version_1_2(context)
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
context.set_ciphers(DEFAULT_CIPHERS)
|
||||
|
||||
if ssl.HAS_ALPN:
|
||||
alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"]
|
||||
context.set_alpn_protocols(alpn_idents)
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
keylogfile = os.environ.get("SSLKEYLOGFILE")
|
||||
if keylogfile and self.trust_env:
|
||||
context.keylog_filename = keylogfile
|
||||
|
||||
return context
|
||||
|
||||
def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
|
||||
"""
|
||||
Loads client certificates into our SSLContext object
|
||||
"""
|
||||
if self.cert is not None:
|
||||
if isinstance(self.cert, str):
|
||||
ssl_context.load_cert_chain(certfile=self.cert)
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
|
||||
ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 3:
|
||||
ssl_context.load_cert_chain(
|
||||
certfile=self.cert[0],
|
||||
keyfile=self.cert[1],
|
||||
password=self.cert[2], # type: ignore
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
class Timeout:
|
||||
@@ -208,13 +85,13 @@ class Timeout:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
|
||||
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||
*,
|
||||
connect: typing.Union[None, float, UnsetType] = UNSET,
|
||||
read: typing.Union[None, float, UnsetType] = UNSET,
|
||||
write: typing.Union[None, float, UnsetType] = UNSET,
|
||||
pool: typing.Union[None, float, UnsetType] = UNSET,
|
||||
):
|
||||
connect: None | float | UnsetType = UNSET,
|
||||
read: None | float | UnsetType = UNSET,
|
||||
write: None | float | UnsetType = UNSET,
|
||||
pool: None | float | UnsetType = UNSET,
|
||||
) -> None:
|
||||
if isinstance(timeout, Timeout):
|
||||
# Passed as a single explicit Timeout.
|
||||
assert connect is UNSET
|
||||
@@ -252,7 +129,7 @@ class Timeout:
|
||||
self.write = timeout if isinstance(write, UnsetType) else write
|
||||
self.pool = timeout if isinstance(pool, UnsetType) else pool
|
||||
|
||||
def as_dict(self) -> typing.Dict[str, typing.Optional[float]]:
|
||||
def as_dict(self) -> dict[str, float | None]:
|
||||
return {
|
||||
"connect": self.connect,
|
||||
"read": self.read,
|
||||
@@ -296,10 +173,10 @@ class Limits:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_connections: typing.Optional[int] = None,
|
||||
max_keepalive_connections: typing.Optional[int] = None,
|
||||
keepalive_expiry: typing.Optional[float] = 5.0,
|
||||
):
|
||||
max_connections: int | None = None,
|
||||
max_keepalive_connections: int | None = None,
|
||||
keepalive_expiry: float | None = 5.0,
|
||||
) -> None:
|
||||
self.max_connections = max_connections
|
||||
self.max_keepalive_connections = max_keepalive_connections
|
||||
self.keepalive_expiry = keepalive_expiry
|
||||
@@ -324,15 +201,16 @@ class Limits:
|
||||
class Proxy:
|
||||
def __init__(
|
||||
self,
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
auth: typing.Optional[typing.Tuple[str, str]] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
):
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
auth: tuple[str, str] | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
) -> None:
|
||||
url = URL(url)
|
||||
headers = Headers(headers)
|
||||
|
||||
if url.scheme not in ("http", "https", "socks5"):
|
||||
if url.scheme not in ("http", "https", "socks5", "socks5h"):
|
||||
raise ValueError(f"Unknown scheme for proxy URL {url!r}")
|
||||
|
||||
if url.username or url.password:
|
||||
@@ -343,9 +221,10 @@ class Proxy:
|
||||
self.url = url
|
||||
self.auth = auth
|
||||
self.headers = headers
|
||||
self.ssl_context = ssl_context
|
||||
|
||||
@property
|
||||
def raw_auth(self) -> typing.Optional[typing.Tuple[bytes, bytes]]:
|
||||
def raw_auth(self) -> tuple[bytes, bytes] | None:
|
||||
# The proxy authentication as raw bytes.
|
||||
return (
|
||||
None
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from json import dumps as json_dumps
|
||||
@@ -5,13 +7,9 @@ from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlencode
|
||||
|
||||
@@ -27,6 +25,8 @@ from ._types import (
|
||||
)
|
||||
from ._utils import peek_filelike_length, primitive_value_to_str
|
||||
|
||||
__all__ = ["ByteStream"]
|
||||
|
||||
|
||||
class ByteStream(AsyncByteStream, SyncByteStream):
|
||||
def __init__(self, stream: bytes) -> None:
|
||||
@@ -42,7 +42,7 @@ class ByteStream(AsyncByteStream, SyncByteStream):
|
||||
class IteratorByteStream(SyncByteStream):
|
||||
CHUNK_SIZE = 65_536
|
||||
|
||||
def __init__(self, stream: Iterable[bytes]):
|
||||
def __init__(self, stream: Iterable[bytes]) -> None:
|
||||
self._stream = stream
|
||||
self._is_stream_consumed = False
|
||||
self._is_generator = inspect.isgenerator(stream)
|
||||
@@ -67,7 +67,7 @@ class IteratorByteStream(SyncByteStream):
|
||||
class AsyncIteratorByteStream(AsyncByteStream):
|
||||
CHUNK_SIZE = 65_536
|
||||
|
||||
def __init__(self, stream: AsyncIterable[bytes]):
|
||||
def __init__(self, stream: AsyncIterable[bytes]) -> None:
|
||||
self._stream = stream
|
||||
self._is_stream_consumed = False
|
||||
self._is_generator = inspect.isasyncgen(stream)
|
||||
@@ -105,8 +105,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream):
|
||||
|
||||
|
||||
def encode_content(
|
||||
content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
|
||||
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
|
||||
content: str | bytes | Iterable[bytes] | AsyncIterable[bytes],
|
||||
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
|
||||
if isinstance(content, (bytes, str)):
|
||||
body = content.encode("utf-8") if isinstance(content, str) else content
|
||||
content_length = len(body)
|
||||
@@ -135,7 +135,7 @@ def encode_content(
|
||||
|
||||
def encode_urlencoded_data(
|
||||
data: RequestData,
|
||||
) -> Tuple[Dict[str, str], ByteStream]:
|
||||
) -> tuple[dict[str, str], ByteStream]:
|
||||
plain_data = []
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (list, tuple)):
|
||||
@@ -150,14 +150,14 @@ def encode_urlencoded_data(
|
||||
|
||||
|
||||
def encode_multipart_data(
|
||||
data: RequestData, files: RequestFiles, boundary: Optional[bytes]
|
||||
) -> Tuple[Dict[str, str], MultipartStream]:
|
||||
data: RequestData, files: RequestFiles, boundary: bytes | None
|
||||
) -> tuple[dict[str, str], MultipartStream]:
|
||||
multipart = MultipartStream(data=data, files=files, boundary=boundary)
|
||||
headers = multipart.get_headers()
|
||||
return headers, multipart
|
||||
|
||||
|
||||
def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
|
||||
def encode_text(text: str) -> tuple[dict[str, str], ByteStream]:
|
||||
body = text.encode("utf-8")
|
||||
content_length = str(len(body))
|
||||
content_type = "text/plain; charset=utf-8"
|
||||
@@ -165,7 +165,7 @@ def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
|
||||
return headers, ByteStream(body)
|
||||
|
||||
|
||||
def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
|
||||
def encode_html(html: str) -> tuple[dict[str, str], ByteStream]:
|
||||
body = html.encode("utf-8")
|
||||
content_length = str(len(body))
|
||||
content_type = "text/html; charset=utf-8"
|
||||
@@ -173,8 +173,10 @@ def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
|
||||
return headers, ByteStream(body)
|
||||
|
||||
|
||||
def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
|
||||
body = json_dumps(json).encode("utf-8")
|
||||
def encode_json(json: Any) -> tuple[dict[str, str], ByteStream]:
|
||||
body = json_dumps(
|
||||
json, ensure_ascii=False, separators=(",", ":"), allow_nan=False
|
||||
).encode("utf-8")
|
||||
content_length = str(len(body))
|
||||
content_type = "application/json"
|
||||
headers = {"Content-Length": content_length, "Content-Type": content_type}
|
||||
@@ -182,12 +184,12 @@ def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
|
||||
|
||||
|
||||
def encode_request(
|
||||
content: Optional[RequestContent] = None,
|
||||
data: Optional[RequestData] = None,
|
||||
files: Optional[RequestFiles] = None,
|
||||
json: Optional[Any] = None,
|
||||
boundary: Optional[bytes] = None,
|
||||
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
|
||||
content: RequestContent | None = None,
|
||||
data: RequestData | None = None,
|
||||
files: RequestFiles | None = None,
|
||||
json: Any | None = None,
|
||||
boundary: bytes | None = None,
|
||||
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
|
||||
"""
|
||||
Handles encoding the given `content`, `data`, `files`, and `json`,
|
||||
returning a two-tuple of (<headers>, <stream>).
|
||||
@@ -201,7 +203,7 @@ def encode_request(
|
||||
# `data=<bytes...>` usages. We deal with that case here, treating it
|
||||
# as if `content=<...>` had been supplied instead.
|
||||
message = "Use 'content=<...>' to upload raw bytes/text content."
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
warnings.warn(message, DeprecationWarning, stacklevel=2)
|
||||
return encode_content(data)
|
||||
|
||||
if content is not None:
|
||||
@@ -217,11 +219,11 @@ def encode_request(
|
||||
|
||||
|
||||
def encode_response(
|
||||
content: Optional[ResponseContent] = None,
|
||||
text: Optional[str] = None,
|
||||
html: Optional[str] = None,
|
||||
json: Optional[Any] = None,
|
||||
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
|
||||
content: ResponseContent | None = None,
|
||||
text: str | None = None,
|
||||
html: str | None = None,
|
||||
json: Any | None = None,
|
||||
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
|
||||
"""
|
||||
Handles encoding the given `content`, returning a two-tuple of
|
||||
(<headers>, <stream>).
|
||||
|
||||
@@ -3,14 +3,35 @@ Handlers for Content-Encoding.
|
||||
|
||||
See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import io
|
||||
import typing
|
||||
import zlib
|
||||
|
||||
from ._compat import brotli
|
||||
from ._exceptions import DecodingError
|
||||
|
||||
# Brotli support is optional
|
||||
try:
|
||||
# The C bindings in `brotli` are recommended for CPython.
|
||||
import brotli
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
# The CFFI bindings in `brotlicffi` are recommended for PyPy
|
||||
# and other environments.
|
||||
import brotlicffi as brotli
|
||||
except ImportError:
|
||||
brotli = None
|
||||
|
||||
|
||||
# Zstandard support is optional
|
||||
try:
|
||||
import zstandard
|
||||
except ImportError: # pragma: no cover
|
||||
zstandard = None # type: ignore
|
||||
|
||||
|
||||
class ContentDecoder:
|
||||
def decode(self, data: bytes) -> bytes:
|
||||
@@ -137,6 +158,48 @@ class BrotliDecoder(ContentDecoder):
|
||||
raise DecodingError(str(exc)) from exc
|
||||
|
||||
|
||||
class ZStandardDecoder(ContentDecoder):
|
||||
"""
|
||||
Handle 'zstd' RFC 8878 decoding.
|
||||
|
||||
Requires `pip install zstandard`.
|
||||
Can be installed as a dependency of httpx using `pip install httpx[zstd]`.
|
||||
"""
|
||||
|
||||
# inspired by the ZstdDecoder implementation in urllib3
|
||||
def __init__(self) -> None:
|
||||
if zstandard is None: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Using 'ZStandardDecoder', ..."
|
||||
"Make sure to install httpx using `pip install httpx[zstd]`."
|
||||
) from None
|
||||
|
||||
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
|
||||
self.seen_data = False
|
||||
|
||||
def decode(self, data: bytes) -> bytes:
|
||||
assert zstandard is not None
|
||||
self.seen_data = True
|
||||
output = io.BytesIO()
|
||||
try:
|
||||
output.write(self.decompressor.decompress(data))
|
||||
while self.decompressor.eof and self.decompressor.unused_data:
|
||||
unused_data = self.decompressor.unused_data
|
||||
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
|
||||
output.write(self.decompressor.decompress(unused_data))
|
||||
except zstandard.ZstdError as exc:
|
||||
raise DecodingError(str(exc)) from exc
|
||||
return output.getvalue()
|
||||
|
||||
def flush(self) -> bytes:
|
||||
if not self.seen_data:
|
||||
return b""
|
||||
ret = self.decompressor.flush() # note: this is a no-op
|
||||
if not self.decompressor.eof:
|
||||
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
|
||||
return bytes(ret)
|
||||
|
||||
|
||||
class MultiDecoder(ContentDecoder):
|
||||
"""
|
||||
Handle the case where multiple encodings have been applied.
|
||||
@@ -167,11 +230,11 @@ class ByteChunker:
|
||||
Handles returning byte content in fixed-size chunks.
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_size: typing.Optional[int] = None) -> None:
|
||||
def __init__(self, chunk_size: int | None = None) -> None:
|
||||
self._buffer = io.BytesIO()
|
||||
self._chunk_size = chunk_size
|
||||
|
||||
def decode(self, content: bytes) -> typing.List[bytes]:
|
||||
def decode(self, content: bytes) -> list[bytes]:
|
||||
if self._chunk_size is None:
|
||||
return [content] if content else []
|
||||
|
||||
@@ -194,7 +257,7 @@ class ByteChunker:
|
||||
else:
|
||||
return []
|
||||
|
||||
def flush(self) -> typing.List[bytes]:
|
||||
def flush(self) -> list[bytes]:
|
||||
value = self._buffer.getvalue()
|
||||
self._buffer.seek(0)
|
||||
self._buffer.truncate()
|
||||
@@ -206,13 +269,13 @@ class TextChunker:
|
||||
Handles returning text content in fixed-size chunks.
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_size: typing.Optional[int] = None) -> None:
|
||||
def __init__(self, chunk_size: int | None = None) -> None:
|
||||
self._buffer = io.StringIO()
|
||||
self._chunk_size = chunk_size
|
||||
|
||||
def decode(self, content: str) -> typing.List[str]:
|
||||
def decode(self, content: str) -> list[str]:
|
||||
if self._chunk_size is None:
|
||||
return [content]
|
||||
return [content] if content else []
|
||||
|
||||
self._buffer.write(content)
|
||||
if self._buffer.tell() >= self._chunk_size:
|
||||
@@ -233,7 +296,7 @@ class TextChunker:
|
||||
else:
|
||||
return []
|
||||
|
||||
def flush(self) -> typing.List[str]:
|
||||
def flush(self) -> list[str]:
|
||||
value = self._buffer.getvalue()
|
||||
self._buffer.seek(0)
|
||||
self._buffer.truncate()
|
||||
@@ -245,7 +308,7 @@ class TextDecoder:
|
||||
Handles incrementally decoding bytes into text
|
||||
"""
|
||||
|
||||
def __init__(self, encoding: str = "utf-8"):
|
||||
def __init__(self, encoding: str = "utf-8") -> None:
|
||||
self.decoder = codecs.getincrementaldecoder(encoding)(errors="replace")
|
||||
|
||||
def decode(self, data: bytes) -> str:
|
||||
@@ -259,14 +322,15 @@ class LineDecoder:
|
||||
"""
|
||||
Handles incrementally reading lines from text.
|
||||
|
||||
Has the same behaviour as the stdllib splitlines, but handling the input iteratively.
|
||||
Has the same behaviour as the stdllib splitlines,
|
||||
but handling the input iteratively.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer: typing.List[str] = []
|
||||
self.buffer: list[str] = []
|
||||
self.trailing_cr: bool = False
|
||||
|
||||
def decode(self, text: str) -> typing.List[str]:
|
||||
def decode(self, text: str) -> list[str]:
|
||||
# See https://docs.python.org/3/library/stdtypes.html#str.splitlines
|
||||
NEWLINE_CHARS = "\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029"
|
||||
|
||||
@@ -279,7 +343,9 @@ class LineDecoder:
|
||||
text = text[:-1]
|
||||
|
||||
if not text:
|
||||
return []
|
||||
# NOTE: the edge case input of empty text doesn't occur in practice,
|
||||
# because other httpx internals filter out this value
|
||||
return [] # pragma: no cover
|
||||
|
||||
trailing_newline = text[-1] in NEWLINE_CHARS
|
||||
lines = text.splitlines()
|
||||
@@ -302,7 +368,7 @@ class LineDecoder:
|
||||
|
||||
return lines
|
||||
|
||||
def flush(self) -> typing.List[str]:
|
||||
def flush(self) -> list[str]:
|
||||
if not self.buffer and not self.trailing_cr:
|
||||
return []
|
||||
|
||||
@@ -317,8 +383,11 @@ SUPPORTED_DECODERS = {
|
||||
"gzip": GZipDecoder,
|
||||
"deflate": DeflateDecoder,
|
||||
"br": BrotliDecoder,
|
||||
"zstd": ZStandardDecoder,
|
||||
}
|
||||
|
||||
|
||||
if brotli is None:
|
||||
SUPPORTED_DECODERS.pop("br") # pragma: no cover
|
||||
if zstandard is None:
|
||||
SUPPORTED_DECODERS.pop("zstd") # pragma: no cover
|
||||
|
||||
@@ -30,12 +30,46 @@ Our exception hierarchy:
|
||||
x ResponseNotRead
|
||||
x RequestNotRead
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._models import Request, Response # pragma: no cover
|
||||
|
||||
__all__ = [
|
||||
"CloseError",
|
||||
"ConnectError",
|
||||
"ConnectTimeout",
|
||||
"CookieConflict",
|
||||
"DecodingError",
|
||||
"HTTPError",
|
||||
"HTTPStatusError",
|
||||
"InvalidURL",
|
||||
"LocalProtocolError",
|
||||
"NetworkError",
|
||||
"PoolTimeout",
|
||||
"ProtocolError",
|
||||
"ProxyError",
|
||||
"ReadError",
|
||||
"ReadTimeout",
|
||||
"RemoteProtocolError",
|
||||
"RequestError",
|
||||
"RequestNotRead",
|
||||
"ResponseNotRead",
|
||||
"StreamClosed",
|
||||
"StreamConsumed",
|
||||
"StreamError",
|
||||
"TimeoutException",
|
||||
"TooManyRedirects",
|
||||
"TransportError",
|
||||
"UnsupportedProtocol",
|
||||
"WriteError",
|
||||
"WriteTimeout",
|
||||
]
|
||||
|
||||
|
||||
class HTTPError(Exception):
|
||||
"""
|
||||
@@ -57,16 +91,16 @@ class HTTPError(Exception):
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self._request: typing.Optional["Request"] = None
|
||||
self._request: Request | None = None
|
||||
|
||||
@property
|
||||
def request(self) -> "Request":
|
||||
def request(self) -> Request:
|
||||
if self._request is None:
|
||||
raise RuntimeError("The .request property has not been set.")
|
||||
return self._request
|
||||
|
||||
@request.setter
|
||||
def request(self, request: "Request") -> None:
|
||||
def request(self, request: Request) -> None:
|
||||
self._request = request
|
||||
|
||||
|
||||
@@ -75,9 +109,7 @@ class RequestError(HTTPError):
|
||||
Base class for all exceptions that may occur when issuing a `.request()`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, message: str, *, request: typing.Optional["Request"] = None
|
||||
) -> None:
|
||||
def __init__(self, message: str, *, request: Request | None = None) -> None:
|
||||
super().__init__(message)
|
||||
# At the point an exception is raised we won't typically have a request
|
||||
# instance to associate it with.
|
||||
@@ -230,9 +262,7 @@ class HTTPStatusError(HTTPError):
|
||||
May be raised when calling `response.raise_for_status()`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, message: str, *, request: "Request", response: "Response"
|
||||
) -> None:
|
||||
def __init__(self, message: str, *, request: Request, response: Response) -> None:
|
||||
super().__init__(message)
|
||||
self.request = request
|
||||
self.response = response
|
||||
@@ -313,7 +343,10 @@ class ResponseNotRead(StreamError):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
message = "Attempted to access streaming response content, without having called `read()`."
|
||||
message = (
|
||||
"Attempted to access streaming response content,"
|
||||
" without having called `read()`."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@@ -323,13 +356,16 @@ class RequestNotRead(StreamError):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
message = "Attempted to access streaming request content, without having called `read()`."
|
||||
message = (
|
||||
"Attempted to access streaming request content,"
|
||||
" without having called `read()`."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def request_context(
|
||||
request: typing.Optional["Request"] = None,
|
||||
request: Request | None = None,
|
||||
) -> typing.Iterator[None]:
|
||||
"""
|
||||
A context manager that can be used to attach the given request context
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import json
|
||||
import sys
|
||||
import typing
|
||||
|
||||
import click
|
||||
import httpcore
|
||||
import pygments.lexers
|
||||
import pygments.util
|
||||
import rich.console
|
||||
@@ -18,6 +19,9 @@ from ._exceptions import RequestError
|
||||
from ._models import Response
|
||||
from ._status_codes import codes
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import httpcore # pragma: no cover
|
||||
|
||||
|
||||
def print_help() -> None:
|
||||
console = rich.console.Console()
|
||||
@@ -63,20 +67,21 @@ def print_help() -> None:
|
||||
)
|
||||
table.add_row(
|
||||
"--auth [cyan]<USER PASS>",
|
||||
"Username and password to include in the request. Specify '-' for the password to use "
|
||||
"a password prompt. Note that using --verbose/-v will expose the Authorization "
|
||||
"header, including the password encoding in a trivially reversible format.",
|
||||
"Username and password to include in the request. Specify '-' for the password"
|
||||
" to use a password prompt. Note that using --verbose/-v will expose"
|
||||
" the Authorization header, including the password encoding"
|
||||
" in a trivially reversible format.",
|
||||
)
|
||||
|
||||
table.add_row(
|
||||
"--proxies [cyan]URL",
|
||||
"--proxy [cyan]URL",
|
||||
"Send the request via a proxy. Should be the URL giving the proxy address.",
|
||||
)
|
||||
|
||||
table.add_row(
|
||||
"--timeout [cyan]FLOAT",
|
||||
"Timeout value to use for network operations, such as establishing the connection, "
|
||||
"reading some data, etc... [Default: 5.0]",
|
||||
"Timeout value to use for network operations, such as establishing the"
|
||||
" connection, reading some data, etc... [Default: 5.0]",
|
||||
)
|
||||
|
||||
table.add_row("--follow-redirects", "Automatically follow redirects.")
|
||||
@@ -124,8 +129,8 @@ def format_request_headers(request: httpcore.Request, http2: bool = False) -> st
|
||||
def format_response_headers(
|
||||
http_version: bytes,
|
||||
status: int,
|
||||
reason_phrase: typing.Optional[bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
reason_phrase: bytes | None,
|
||||
headers: list[tuple[bytes, bytes]],
|
||||
) -> str:
|
||||
version = http_version.decode("ascii")
|
||||
reason = (
|
||||
@@ -151,8 +156,8 @@ def print_request_headers(request: httpcore.Request, http2: bool = False) -> Non
|
||||
def print_response_headers(
|
||||
http_version: bytes,
|
||||
status: int,
|
||||
reason_phrase: typing.Optional[bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
reason_phrase: bytes | None,
|
||||
headers: list[tuple[bytes, bytes]],
|
||||
) -> None:
|
||||
console = rich.console.Console()
|
||||
http_text = format_response_headers(http_version, status, reason_phrase, headers)
|
||||
@@ -267,7 +272,7 @@ def download_response(response: Response, download: typing.BinaryIO) -> None:
|
||||
|
||||
def validate_json(
|
||||
ctx: click.Context,
|
||||
param: typing.Union[click.Option, click.Parameter],
|
||||
param: click.Option | click.Parameter,
|
||||
value: typing.Any,
|
||||
) -> typing.Any:
|
||||
if value is None:
|
||||
@@ -281,7 +286,7 @@ def validate_json(
|
||||
|
||||
def validate_auth(
|
||||
ctx: click.Context,
|
||||
param: typing.Union[click.Option, click.Parameter],
|
||||
param: click.Option | click.Parameter,
|
||||
value: typing.Any,
|
||||
) -> typing.Any:
|
||||
if value == (None, None):
|
||||
@@ -295,7 +300,7 @@ def validate_auth(
|
||||
|
||||
def handle_help(
|
||||
ctx: click.Context,
|
||||
param: typing.Union[click.Option, click.Parameter],
|
||||
param: click.Option | click.Parameter,
|
||||
value: typing.Any,
|
||||
) -> None:
|
||||
if not value or ctx.resilient_parsing:
|
||||
@@ -385,8 +390,8 @@ def handle_help(
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--proxies",
|
||||
"proxies",
|
||||
"--proxy",
|
||||
"proxy",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Send the request via a proxy. Should be the URL giving the proxy address.",
|
||||
@@ -447,20 +452,20 @@ def handle_help(
|
||||
def main(
|
||||
url: str,
|
||||
method: str,
|
||||
params: typing.List[typing.Tuple[str, str]],
|
||||
params: list[tuple[str, str]],
|
||||
content: str,
|
||||
data: typing.List[typing.Tuple[str, str]],
|
||||
files: typing.List[typing.Tuple[str, click.File]],
|
||||
data: list[tuple[str, str]],
|
||||
files: list[tuple[str, click.File]],
|
||||
json: str,
|
||||
headers: typing.List[typing.Tuple[str, str]],
|
||||
cookies: typing.List[typing.Tuple[str, str]],
|
||||
auth: typing.Optional[typing.Tuple[str, str]],
|
||||
proxies: str,
|
||||
headers: list[tuple[str, str]],
|
||||
cookies: list[tuple[str, str]],
|
||||
auth: tuple[str, str] | None,
|
||||
proxy: str,
|
||||
timeout: float,
|
||||
follow_redirects: bool,
|
||||
verify: bool,
|
||||
http2: bool,
|
||||
download: typing.Optional[typing.BinaryIO],
|
||||
download: typing.BinaryIO | None,
|
||||
verbose: bool,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -471,12 +476,7 @@ def main(
|
||||
method = "POST" if content or data or files or json else "GET"
|
||||
|
||||
try:
|
||||
with Client(
|
||||
proxies=proxies,
|
||||
timeout=timeout,
|
||||
verify=verify,
|
||||
http2=http2,
|
||||
) as client:
|
||||
with Client(proxy=proxy, timeout=timeout, http2=http2, verify=verify) as client:
|
||||
with client.stream(
|
||||
method,
|
||||
url,
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import datetime
|
||||
import email.message
|
||||
import json as jsonlib
|
||||
import re
|
||||
import typing
|
||||
import urllib.request
|
||||
from collections.abc import Mapping
|
||||
@@ -42,15 +46,94 @@ from ._types import (
|
||||
SyncByteStream,
|
||||
)
|
||||
from ._urls import URL
|
||||
from ._utils import (
|
||||
guess_json_utf,
|
||||
is_known_encoding,
|
||||
normalize_header_key,
|
||||
normalize_header_value,
|
||||
obfuscate_sensitive_headers,
|
||||
parse_content_type_charset,
|
||||
parse_header_links,
|
||||
)
|
||||
from ._utils import to_bytes_or_str, to_str
|
||||
|
||||
__all__ = ["Cookies", "Headers", "Request", "Response"]
|
||||
|
||||
SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}
|
||||
|
||||
|
||||
def _is_known_encoding(encoding: str) -> bool:
|
||||
"""
|
||||
Return `True` if `encoding` is a known codec.
|
||||
"""
|
||||
try:
|
||||
codecs.lookup(encoding)
|
||||
except LookupError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _normalize_header_key(key: str | bytes, encoding: str | None = None) -> bytes:
|
||||
"""
|
||||
Coerce str/bytes into a strictly byte-wise HTTP header key.
|
||||
"""
|
||||
return key if isinstance(key, bytes) else key.encode(encoding or "ascii")
|
||||
|
||||
|
||||
def _normalize_header_value(value: str | bytes, encoding: str | None = None) -> bytes:
|
||||
"""
|
||||
Coerce str/bytes into a strictly byte-wise HTTP header value.
|
||||
"""
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(f"Header value must be str or bytes, not {type(value)}")
|
||||
return value.encode(encoding or "ascii")
|
||||
|
||||
|
||||
def _parse_content_type_charset(content_type: str) -> str | None:
|
||||
# We used to use `cgi.parse_header()` here, but `cgi` became a dead battery.
|
||||
# See: https://peps.python.org/pep-0594/#cgi
|
||||
msg = email.message.Message()
|
||||
msg["content-type"] = content_type
|
||||
return msg.get_content_charset(failobj=None)
|
||||
|
||||
|
||||
def _parse_header_links(value: str) -> list[dict[str, str]]:
|
||||
"""
|
||||
Returns a list of parsed link headers, for more info see:
|
||||
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
|
||||
The generic syntax of those is:
|
||||
Link: < uri-reference >; param1=value1; param2="value2"
|
||||
So for instance:
|
||||
Link; '<http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;'
|
||||
would return
|
||||
[
|
||||
{"url": "http:/.../front.jpeg", "type": "image/jpeg"},
|
||||
{"url": "http://.../back.jpeg"},
|
||||
]
|
||||
:param value: HTTP Link entity-header field
|
||||
:return: list of parsed link headers
|
||||
"""
|
||||
links: list[dict[str, str]] = []
|
||||
replace_chars = " '\""
|
||||
value = value.strip(replace_chars)
|
||||
if not value:
|
||||
return links
|
||||
for val in re.split(", *<", value):
|
||||
try:
|
||||
url, params = val.split(";", 1)
|
||||
except ValueError:
|
||||
url, params = val, ""
|
||||
link = {"url": url.strip("<> '\"")}
|
||||
for param in params.split(";"):
|
||||
try:
|
||||
key, value = param.split("=")
|
||||
except ValueError:
|
||||
break
|
||||
link[key.strip(replace_chars)] = value.strip(replace_chars)
|
||||
links.append(link)
|
||||
return links
|
||||
|
||||
|
||||
def _obfuscate_sensitive_headers(
|
||||
items: typing.Iterable[tuple[typing.AnyStr, typing.AnyStr]],
|
||||
) -> typing.Iterator[tuple[typing.AnyStr, typing.AnyStr]]:
|
||||
for k, v in items:
|
||||
if to_str(k.lower()) in SENSITIVE_HEADERS:
|
||||
v = to_bytes_or_str("[secure]", match_type_of=v)
|
||||
yield k, v
|
||||
|
||||
|
||||
class Headers(typing.MutableMapping[str, str]):
|
||||
@@ -60,31 +143,23 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
encoding: typing.Optional[str] = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
encoding: str | None = None,
|
||||
) -> None:
|
||||
if headers is None:
|
||||
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
|
||||
elif isinstance(headers, Headers):
|
||||
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
|
||||
|
||||
if isinstance(headers, Headers):
|
||||
self._list = list(headers._list)
|
||||
elif isinstance(headers, Mapping):
|
||||
self._list = [
|
||||
(
|
||||
normalize_header_key(k, lower=False, encoding=encoding),
|
||||
normalize_header_key(k, lower=True, encoding=encoding),
|
||||
normalize_header_value(v, encoding),
|
||||
)
|
||||
for k, v in headers.items()
|
||||
]
|
||||
else:
|
||||
self._list = [
|
||||
(
|
||||
normalize_header_key(k, lower=False, encoding=encoding),
|
||||
normalize_header_key(k, lower=True, encoding=encoding),
|
||||
normalize_header_value(v, encoding),
|
||||
)
|
||||
for k, v in headers
|
||||
]
|
||||
for k, v in headers.items():
|
||||
bytes_key = _normalize_header_key(k, encoding)
|
||||
bytes_value = _normalize_header_value(v, encoding)
|
||||
self._list.append((bytes_key, bytes_key.lower(), bytes_value))
|
||||
elif headers is not None:
|
||||
for k, v in headers:
|
||||
bytes_key = _normalize_header_key(k, encoding)
|
||||
bytes_value = _normalize_header_value(v, encoding)
|
||||
self._list.append((bytes_key, bytes_key.lower(), bytes_value))
|
||||
|
||||
self._encoding = encoding
|
||||
|
||||
@@ -118,7 +193,7 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
self._encoding = value
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
def raw(self) -> list[tuple[bytes, bytes]]:
|
||||
"""
|
||||
Returns a list of the raw header items, as byte pairs.
|
||||
"""
|
||||
@@ -128,7 +203,7 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
return {key.decode(self.encoding): None for _, key, value in self._list}.keys()
|
||||
|
||||
def values(self) -> typing.ValuesView[str]:
|
||||
values_dict: typing.Dict[str, str] = {}
|
||||
values_dict: dict[str, str] = {}
|
||||
for _, key, value in self._list:
|
||||
str_key = key.decode(self.encoding)
|
||||
str_value = value.decode(self.encoding)
|
||||
@@ -143,7 +218,7 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
Return `(key, value)` items of headers. Concatenate headers
|
||||
into a single comma separated value when a key occurs multiple times.
|
||||
"""
|
||||
values_dict: typing.Dict[str, str] = {}
|
||||
values_dict: dict[str, str] = {}
|
||||
for _, key, value in self._list:
|
||||
str_key = key.decode(self.encoding)
|
||||
str_value = value.decode(self.encoding)
|
||||
@@ -153,7 +228,7 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
values_dict[str_key] = str_value
|
||||
return values_dict.items()
|
||||
|
||||
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
|
||||
def multi_items(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return a list of `(key, value)` pairs of headers. Allow multiple
|
||||
occurrences of the same key without concatenating into a single
|
||||
@@ -174,7 +249,7 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]:
|
||||
def get_list(self, key: str, split_commas: bool = False) -> list[str]:
|
||||
"""
|
||||
Return a list of all header values for a given key.
|
||||
If `split_commas=True` is passed, then any comma separated header
|
||||
@@ -196,14 +271,14 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
split_values.extend([item.strip() for item in value.split(",")])
|
||||
return split_values
|
||||
|
||||
def update(self, headers: typing.Optional[HeaderTypes] = None) -> None: # type: ignore
|
||||
def update(self, headers: HeaderTypes | None = None) -> None: # type: ignore
|
||||
headers = Headers(headers)
|
||||
for key in headers.keys():
|
||||
if key in self:
|
||||
self.pop(key)
|
||||
self._list.extend(headers._list)
|
||||
|
||||
def copy(self) -> "Headers":
|
||||
def copy(self) -> Headers:
|
||||
return Headers(self, encoding=self.encoding)
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
@@ -295,7 +370,7 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
if self.encoding != "ascii":
|
||||
encoding_str = f", encoding={self.encoding!r}"
|
||||
|
||||
as_list = list(obfuscate_sensitive_headers(self.multi_items()))
|
||||
as_list = list(_obfuscate_sensitive_headers(self.multi_items()))
|
||||
as_dict = dict(as_list)
|
||||
|
||||
no_duplicate_keys = len(as_dict) == len(as_list)
|
||||
@@ -307,35 +382,29 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
class Request:
|
||||
def __init__(
|
||||
self,
|
||||
method: typing.Union[str, bytes],
|
||||
url: typing.Union["URL", str],
|
||||
method: str,
|
||||
url: URL | str,
|
||||
*,
|
||||
params: typing.Optional[QueryParamTypes] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
cookies: typing.Optional[CookieTypes] = None,
|
||||
content: typing.Optional[RequestContent] = None,
|
||||
data: typing.Optional[RequestData] = None,
|
||||
files: typing.Optional[RequestFiles] = None,
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
stream: typing.Union[SyncByteStream, AsyncByteStream, None] = None,
|
||||
extensions: typing.Optional[RequestExtensions] = None,
|
||||
):
|
||||
self.method = (
|
||||
method.decode("ascii").upper()
|
||||
if isinstance(method, bytes)
|
||||
else method.upper()
|
||||
)
|
||||
self.url = URL(url)
|
||||
if params is not None:
|
||||
self.url = self.url.copy_merge_params(params=params)
|
||||
params: QueryParamTypes | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
cookies: CookieTypes | None = None,
|
||||
content: RequestContent | None = None,
|
||||
data: RequestData | None = None,
|
||||
files: RequestFiles | None = None,
|
||||
json: typing.Any | None = None,
|
||||
stream: SyncByteStream | AsyncByteStream | None = None,
|
||||
extensions: RequestExtensions | None = None,
|
||||
) -> None:
|
||||
self.method = method.upper()
|
||||
self.url = URL(url) if params is None else URL(url, params=params)
|
||||
self.headers = Headers(headers)
|
||||
self.extensions = {} if extensions is None else extensions
|
||||
self.extensions = {} if extensions is None else dict(extensions)
|
||||
|
||||
if cookies:
|
||||
Cookies(cookies).set_cookie_header(self)
|
||||
|
||||
if stream is None:
|
||||
content_type: typing.Optional[str] = self.headers.get("content-type")
|
||||
content_type: str | None = self.headers.get("content-type")
|
||||
headers, stream = encode_request(
|
||||
content=content,
|
||||
data=data,
|
||||
@@ -359,7 +428,8 @@ class Request:
|
||||
# Using `content=...` implies automatically populated `Host` and content
|
||||
# headers, of either `Content-Length: ...` or `Transfer-Encoding: chunked`.
|
||||
#
|
||||
# Using `stream=...` will not automatically include *any* auto-populated headers.
|
||||
# Using `stream=...` will not automatically include *any*
|
||||
# auto-populated headers.
|
||||
#
|
||||
# As an end-user you don't really need `stream=...`. It's only
|
||||
# useful when:
|
||||
@@ -368,14 +438,14 @@ class Request:
|
||||
# * Creating request instances on the *server-side* of the transport API.
|
||||
self.stream = stream
|
||||
|
||||
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
|
||||
def _prepare(self, default_headers: dict[str, str]) -> None:
|
||||
for key, value in default_headers.items():
|
||||
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
|
||||
if key.lower() == "transfer-encoding" and "Content-Length" in self.headers:
|
||||
continue
|
||||
self.headers.setdefault(key, value)
|
||||
|
||||
auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []
|
||||
auto_headers: list[tuple[bytes, bytes]] = []
|
||||
|
||||
has_host = "Host" in self.headers
|
||||
has_content_length = (
|
||||
@@ -428,14 +498,14 @@ class Request:
|
||||
url = str(self.url)
|
||||
return f"<{class_name}({self.method!r}, {url!r})>"
|
||||
|
||||
def __getstate__(self) -> typing.Dict[str, typing.Any]:
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
return {
|
||||
name: value
|
||||
for name, value in self.__dict__.items()
|
||||
if name not in ["extensions", "stream"]
|
||||
}
|
||||
|
||||
def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None:
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
for name, value in state.items():
|
||||
setattr(self, name, value)
|
||||
self.extensions = {}
|
||||
@@ -447,27 +517,27 @@ class Response:
|
||||
self,
|
||||
status_code: int,
|
||||
*,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
content: typing.Optional[ResponseContent] = None,
|
||||
text: typing.Optional[str] = None,
|
||||
html: typing.Optional[str] = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
content: ResponseContent | None = None,
|
||||
text: str | None = None,
|
||||
html: str | None = None,
|
||||
json: typing.Any = None,
|
||||
stream: typing.Union[SyncByteStream, AsyncByteStream, None] = None,
|
||||
request: typing.Optional[Request] = None,
|
||||
extensions: typing.Optional[ResponseExtensions] = None,
|
||||
history: typing.Optional[typing.List["Response"]] = None,
|
||||
default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8",
|
||||
):
|
||||
stream: SyncByteStream | AsyncByteStream | None = None,
|
||||
request: Request | None = None,
|
||||
extensions: ResponseExtensions | None = None,
|
||||
history: list[Response] | None = None,
|
||||
default_encoding: str | typing.Callable[[bytes], str] = "utf-8",
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = Headers(headers)
|
||||
|
||||
self._request: typing.Optional[Request] = request
|
||||
self._request: Request | None = request
|
||||
|
||||
# When follow_redirects=False and a redirect is received,
|
||||
# the client will set `response.next_request`.
|
||||
self.next_request: typing.Optional[Request] = None
|
||||
self.next_request: Request | None = None
|
||||
|
||||
self.extensions = {} if extensions is None else extensions
|
||||
self.extensions = {} if extensions is None else dict(extensions)
|
||||
self.history = [] if history is None else list(history)
|
||||
|
||||
self.is_closed = False
|
||||
@@ -498,7 +568,7 @@ class Response:
|
||||
|
||||
self._num_bytes_downloaded = 0
|
||||
|
||||
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
|
||||
def _prepare(self, default_headers: dict[str, str]) -> None:
|
||||
for key, value in default_headers.items():
|
||||
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
|
||||
if key.lower() == "transfer-encoding" and "content-length" in self.headers:
|
||||
@@ -580,7 +650,7 @@ class Response:
|
||||
return self._text
|
||||
|
||||
@property
|
||||
def encoding(self) -> typing.Optional[str]:
|
||||
def encoding(self) -> str | None:
|
||||
"""
|
||||
Return an encoding to use for decoding the byte content into text.
|
||||
The priority for determining this is given by...
|
||||
@@ -593,7 +663,7 @@ class Response:
|
||||
"""
|
||||
if not hasattr(self, "_encoding"):
|
||||
encoding = self.charset_encoding
|
||||
if encoding is None or not is_known_encoding(encoding):
|
||||
if encoding is None or not _is_known_encoding(encoding):
|
||||
if isinstance(self.default_encoding, str):
|
||||
encoding = self.default_encoding
|
||||
elif hasattr(self, "_content"):
|
||||
@@ -603,10 +673,20 @@ class Response:
|
||||
|
||||
@encoding.setter
|
||||
def encoding(self, value: str) -> None:
|
||||
"""
|
||||
Set the encoding to use for decoding the byte content into text.
|
||||
|
||||
If the `text` attribute has been accessed, attempting to set the
|
||||
encoding will throw a ValueError.
|
||||
"""
|
||||
if hasattr(self, "_text"):
|
||||
raise ValueError(
|
||||
"Setting encoding after `text` has been accessed is not allowed."
|
||||
)
|
||||
self._encoding = value
|
||||
|
||||
@property
|
||||
def charset_encoding(self) -> typing.Optional[str]:
|
||||
def charset_encoding(self) -> str | None:
|
||||
"""
|
||||
Return the encoding, as specified by the Content-Type header.
|
||||
"""
|
||||
@@ -614,7 +694,7 @@ class Response:
|
||||
if content_type is None:
|
||||
return None
|
||||
|
||||
return parse_content_type_charset(content_type)
|
||||
return _parse_content_type_charset(content_type)
|
||||
|
||||
def _get_content_decoder(self) -> ContentDecoder:
|
||||
"""
|
||||
@@ -622,7 +702,7 @@ class Response:
|
||||
content, depending on the Content-Encoding used in the response.
|
||||
"""
|
||||
if not hasattr(self, "_decoder"):
|
||||
decoders: typing.List[ContentDecoder] = []
|
||||
decoders: list[ContentDecoder] = []
|
||||
values = self.headers.get_list("content-encoding", split_commas=True)
|
||||
for value in values:
|
||||
value = value.strip().lower()
|
||||
@@ -711,7 +791,7 @@ class Response:
|
||||
and "Location" in self.headers
|
||||
)
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
def raise_for_status(self) -> Response:
|
||||
"""
|
||||
Raise the `HTTPStatusError` if one occurred.
|
||||
"""
|
||||
@@ -723,18 +803,18 @@ class Response:
|
||||
)
|
||||
|
||||
if self.is_success:
|
||||
return
|
||||
return self
|
||||
|
||||
if self.has_redirect_location:
|
||||
message = (
|
||||
"{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n"
|
||||
"Redirect location: '{0.headers[location]}'\n"
|
||||
"For more information check: https://httpstatuses.com/{0.status_code}"
|
||||
"For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}"
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
"{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n"
|
||||
"For more information check: https://httpstatuses.com/{0.status_code}"
|
||||
"For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}"
|
||||
)
|
||||
|
||||
status_class = self.status_code // 100
|
||||
@@ -749,32 +829,28 @@ class Response:
|
||||
raise HTTPStatusError(message, request=request, response=self)
|
||||
|
||||
def json(self, **kwargs: typing.Any) -> typing.Any:
|
||||
if self.charset_encoding is None and self.content and len(self.content) > 3:
|
||||
encoding = guess_json_utf(self.content)
|
||||
if encoding is not None:
|
||||
return jsonlib.loads(self.content.decode(encoding), **kwargs)
|
||||
return jsonlib.loads(self.text, **kwargs)
|
||||
return jsonlib.loads(self.content, **kwargs)
|
||||
|
||||
@property
|
||||
def cookies(self) -> "Cookies":
|
||||
def cookies(self) -> Cookies:
|
||||
if not hasattr(self, "_cookies"):
|
||||
self._cookies = Cookies()
|
||||
self._cookies.extract_cookies(self)
|
||||
return self._cookies
|
||||
|
||||
@property
|
||||
def links(self) -> typing.Dict[typing.Optional[str], typing.Dict[str, str]]:
|
||||
def links(self) -> dict[str | None, dict[str, str]]:
|
||||
"""
|
||||
Returns the parsed header links of the response, if any
|
||||
"""
|
||||
header = self.headers.get("link")
|
||||
ldict = {}
|
||||
if header:
|
||||
links = parse_header_links(header)
|
||||
for link in links:
|
||||
key = link.get("rel") or link.get("url")
|
||||
ldict[key] = link
|
||||
return ldict
|
||||
if header is None:
|
||||
return {}
|
||||
|
||||
return {
|
||||
(link.get("rel") or link.get("url")): link
|
||||
for link in _parse_header_links(header)
|
||||
}
|
||||
|
||||
@property
|
||||
def num_bytes_downloaded(self) -> int:
|
||||
@@ -783,14 +859,14 @@ class Response:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Response [{self.status_code} {self.reason_phrase}]>"
|
||||
|
||||
def __getstate__(self) -> typing.Dict[str, typing.Any]:
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
return {
|
||||
name: value
|
||||
for name, value in self.__dict__.items()
|
||||
if name not in ["extensions", "stream", "is_closed", "_decoder"]
|
||||
}
|
||||
|
||||
def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None:
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
for name, value in state.items():
|
||||
setattr(self, name, value)
|
||||
self.is_closed = True
|
||||
@@ -805,12 +881,10 @@ class Response:
|
||||
self._content = b"".join(self.iter_bytes())
|
||||
return self._content
|
||||
|
||||
def iter_bytes(
|
||||
self, chunk_size: typing.Optional[int] = None
|
||||
) -> typing.Iterator[bytes]:
|
||||
def iter_bytes(self, chunk_size: int | None = None) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the decoded response content.
|
||||
This allows us to handle gzip, deflate, and brotli encoded responses.
|
||||
This allows us to handle gzip, deflate, brotli, and zstd encoded responses.
|
||||
"""
|
||||
if hasattr(self, "_content"):
|
||||
chunk_size = len(self._content) if chunk_size is None else chunk_size
|
||||
@@ -830,9 +904,7 @@ class Response:
|
||||
for chunk in chunker.flush():
|
||||
yield chunk
|
||||
|
||||
def iter_text(
|
||||
self, chunk_size: typing.Optional[int] = None
|
||||
) -> typing.Iterator[str]:
|
||||
def iter_text(self, chunk_size: int | None = None) -> typing.Iterator[str]:
|
||||
"""
|
||||
A str-iterator over the decoded response content
|
||||
that handles both gzip, deflate, etc but also detects the content's
|
||||
@@ -847,7 +919,7 @@ class Response:
|
||||
yield chunk
|
||||
text_content = decoder.flush()
|
||||
for chunk in chunker.decode(text_content):
|
||||
yield chunk
|
||||
yield chunk # pragma: no cover
|
||||
for chunk in chunker.flush():
|
||||
yield chunk
|
||||
|
||||
@@ -860,9 +932,7 @@ class Response:
|
||||
for line in decoder.flush():
|
||||
yield line
|
||||
|
||||
def iter_raw(
|
||||
self, chunk_size: typing.Optional[int] = None
|
||||
) -> typing.Iterator[bytes]:
|
||||
def iter_raw(self, chunk_size: int | None = None) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
"""
|
||||
@@ -910,11 +980,11 @@ class Response:
|
||||
return self._content
|
||||
|
||||
async def aiter_bytes(
|
||||
self, chunk_size: typing.Optional[int] = None
|
||||
self, chunk_size: int | None = None
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the decoded response content.
|
||||
This allows us to handle gzip, deflate, and brotli encoded responses.
|
||||
This allows us to handle gzip, deflate, brotli, and zstd encoded responses.
|
||||
"""
|
||||
if hasattr(self, "_content"):
|
||||
chunk_size = len(self._content) if chunk_size is None else chunk_size
|
||||
@@ -935,7 +1005,7 @@ class Response:
|
||||
yield chunk
|
||||
|
||||
async def aiter_text(
|
||||
self, chunk_size: typing.Optional[int] = None
|
||||
self, chunk_size: int | None = None
|
||||
) -> typing.AsyncIterator[str]:
|
||||
"""
|
||||
A str-iterator over the decoded response content
|
||||
@@ -951,7 +1021,7 @@ class Response:
|
||||
yield chunk
|
||||
text_content = decoder.flush()
|
||||
for chunk in chunker.decode(text_content):
|
||||
yield chunk
|
||||
yield chunk # pragma: no cover
|
||||
for chunk in chunker.flush():
|
||||
yield chunk
|
||||
|
||||
@@ -965,7 +1035,7 @@ class Response:
|
||||
yield line
|
||||
|
||||
async def aiter_raw(
|
||||
self, chunk_size: typing.Optional[int] = None
|
||||
self, chunk_size: int | None = None
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
@@ -1011,7 +1081,7 @@ class Cookies(typing.MutableMapping[str, str]):
|
||||
HTTP Cookies, as a mutable mapping.
|
||||
"""
|
||||
|
||||
def __init__(self, cookies: typing.Optional[CookieTypes] = None) -> None:
|
||||
def __init__(self, cookies: CookieTypes | None = None) -> None:
|
||||
if cookies is None or isinstance(cookies, dict):
|
||||
self.jar = CookieJar()
|
||||
if isinstance(cookies, dict):
|
||||
@@ -1073,10 +1143,10 @@ class Cookies(typing.MutableMapping[str, str]):
|
||||
def get( # type: ignore
|
||||
self,
|
||||
name: str,
|
||||
default: typing.Optional[str] = None,
|
||||
domain: typing.Optional[str] = None,
|
||||
path: typing.Optional[str] = None,
|
||||
) -> typing.Optional[str]:
|
||||
default: str | None = None,
|
||||
domain: str | None = None,
|
||||
path: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Get a cookie by name. May optionally include domain and path
|
||||
in order to specify exactly which cookie to retrieve.
|
||||
@@ -1098,8 +1168,8 @@ class Cookies(typing.MutableMapping[str, str]):
|
||||
def delete(
|
||||
self,
|
||||
name: str,
|
||||
domain: typing.Optional[str] = None,
|
||||
path: typing.Optional[str] = None,
|
||||
domain: str | None = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Delete a cookie by name. May optionally include domain and path
|
||||
@@ -1119,9 +1189,7 @@ class Cookies(typing.MutableMapping[str, str]):
|
||||
for cookie in remove:
|
||||
self.jar.clear(cookie.domain, cookie.path, cookie.name)
|
||||
|
||||
def clear(
|
||||
self, domain: typing.Optional[str] = None, path: typing.Optional[str] = None
|
||||
) -> None:
|
||||
def clear(self, domain: str | None = None, path: str | None = None) -> None:
|
||||
"""
|
||||
Delete all cookies. Optionally include a domain and path in
|
||||
order to only delete a subset of all the cookies.
|
||||
@@ -1134,7 +1202,7 @@ class Cookies(typing.MutableMapping[str, str]):
|
||||
args.append(path)
|
||||
self.jar.clear(*args)
|
||||
|
||||
def update(self, cookies: typing.Optional[CookieTypes] = None) -> None: # type: ignore
|
||||
def update(self, cookies: CookieTypes | None = None) -> None: # type: ignore
|
||||
cookies = Cookies(cookies)
|
||||
for cookie in cookies.jar:
|
||||
self.jar.set_cookie(cookie)
|
||||
@@ -1196,7 +1264,7 @@ class Cookies(typing.MutableMapping[str, str]):
|
||||
for use with `CookieJar` operations.
|
||||
"""
|
||||
|
||||
def __init__(self, response: Response):
|
||||
def __init__(self, response: Response) -> None:
|
||||
self.response = response
|
||||
|
||||
def info(self) -> email.message.Message:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import binascii
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
@@ -13,17 +16,46 @@ from ._types import (
|
||||
SyncByteStream,
|
||||
)
|
||||
from ._utils import (
|
||||
format_form_param,
|
||||
guess_content_type,
|
||||
peek_filelike_length,
|
||||
primitive_value_to_str,
|
||||
to_bytes,
|
||||
)
|
||||
|
||||
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
|
||||
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
|
||||
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
|
||||
)
|
||||
_HTML5_FORM_ENCODING_RE = re.compile(
|
||||
r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
|
||||
)
|
||||
|
||||
|
||||
def _format_form_param(name: str, value: str) -> bytes:
|
||||
"""
|
||||
Encode a name/value pair within a multipart form.
|
||||
"""
|
||||
|
||||
def replacer(match: typing.Match[str]) -> str:
|
||||
return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
|
||||
|
||||
value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
|
||||
return f'{name}="{value}"'.encode()
|
||||
|
||||
|
||||
def _guess_content_type(filename: str | None) -> str | None:
|
||||
"""
|
||||
Guesses the mimetype based on a filename. Defaults to `application/octet-stream`.
|
||||
|
||||
Returns `None` if `filename` is `None` or empty.
|
||||
"""
|
||||
if filename:
|
||||
return mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
return None
|
||||
|
||||
|
||||
def get_multipart_boundary_from_content_type(
|
||||
content_type: typing.Optional[bytes],
|
||||
) -> typing.Optional[bytes]:
|
||||
content_type: bytes | None,
|
||||
) -> bytes | None:
|
||||
if not content_type or not content_type.startswith(b"multipart/form-data"):
|
||||
return None
|
||||
# parse boundary according to
|
||||
@@ -40,25 +72,24 @@ class DataField:
|
||||
A single form field item, within a multipart form field.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name: str, value: typing.Union[str, bytes, int, float, None]
|
||||
) -> None:
|
||||
def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
|
||||
if not isinstance(name, str):
|
||||
raise TypeError(
|
||||
f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
|
||||
)
|
||||
if value is not None and not isinstance(value, (str, bytes, int, float)):
|
||||
raise TypeError(
|
||||
f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}"
|
||||
"Invalid type for value. Expected primitive type,"
|
||||
f" got {type(value)}: {value!r}"
|
||||
)
|
||||
self.name = name
|
||||
self.value: typing.Union[str, bytes] = (
|
||||
self.value: str | bytes = (
|
||||
value if isinstance(value, bytes) else primitive_value_to_str(value)
|
||||
)
|
||||
|
||||
def render_headers(self) -> bytes:
|
||||
if not hasattr(self, "_headers"):
|
||||
name = format_form_param("name", self.name)
|
||||
name = _format_form_param("name", self.name)
|
||||
self._headers = b"".join(
|
||||
[b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
|
||||
)
|
||||
@@ -93,18 +124,20 @@ class FileField:
|
||||
|
||||
fileobj: FileContent
|
||||
|
||||
headers: typing.Dict[str, str] = {}
|
||||
content_type: typing.Optional[str] = None
|
||||
headers: dict[str, str] = {}
|
||||
content_type: str | None = None
|
||||
|
||||
# This large tuple based API largely mirror's requests' API
|
||||
# It would be good to think of better APIs for this that we could include in httpx 2.0
|
||||
# since variable length tuples (especially of 4 elements) are quite unwieldly
|
||||
# It would be good to think of better APIs for this that we could
|
||||
# include in httpx 2.0 since variable length tuples(especially of 4 elements)
|
||||
# are quite unwieldly
|
||||
if isinstance(value, tuple):
|
||||
if len(value) == 2:
|
||||
# neither the 3rd parameter (content_type) nor the 4th (headers) was included
|
||||
filename, fileobj = value # type: ignore
|
||||
# neither the 3rd parameter (content_type) nor the 4th (headers)
|
||||
# was included
|
||||
filename, fileobj = value
|
||||
elif len(value) == 3:
|
||||
filename, fileobj, content_type = value # type: ignore
|
||||
filename, fileobj, content_type = value
|
||||
else:
|
||||
# all 4 parameters included
|
||||
filename, fileobj, content_type, headers = value # type: ignore
|
||||
@@ -113,13 +146,13 @@ class FileField:
|
||||
fileobj = value
|
||||
|
||||
if content_type is None:
|
||||
content_type = guess_content_type(filename)
|
||||
content_type = _guess_content_type(filename)
|
||||
|
||||
has_content_type_header = any("content-type" in key.lower() for key in headers)
|
||||
if content_type is not None and not has_content_type_header:
|
||||
# note that unlike requests, we ignore the content_type
|
||||
# provided in the 3rd tuple element if it is also included in the headers
|
||||
# requests does the opposite (it overwrites the header with the 3rd tuple element)
|
||||
# note that unlike requests, we ignore the content_type provided in the 3rd
|
||||
# tuple element if it is also included in the headers requests does
|
||||
# the opposite (it overwrites the headerwith the 3rd tuple element)
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
if isinstance(fileobj, io.StringIO):
|
||||
@@ -135,7 +168,7 @@ class FileField:
|
||||
self.file = fileobj
|
||||
self.headers = headers
|
||||
|
||||
def get_length(self) -> typing.Optional[int]:
|
||||
def get_length(self) -> int | None:
|
||||
headers = self.render_headers()
|
||||
|
||||
if isinstance(self.file, (str, bytes)):
|
||||
@@ -154,10 +187,10 @@ class FileField:
|
||||
if not hasattr(self, "_headers"):
|
||||
parts = [
|
||||
b"Content-Disposition: form-data; ",
|
||||
format_form_param("name", self.name),
|
||||
_format_form_param("name", self.name),
|
||||
]
|
||||
if self.filename:
|
||||
filename = format_form_param("filename", self.filename)
|
||||
filename = _format_form_param("filename", self.filename)
|
||||
parts.extend([b"; ", filename])
|
||||
for header_name, header_value in self.headers.items():
|
||||
key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
|
||||
@@ -197,10 +230,10 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
self,
|
||||
data: RequestData,
|
||||
files: RequestFiles,
|
||||
boundary: typing.Optional[bytes] = None,
|
||||
boundary: bytes | None = None,
|
||||
) -> None:
|
||||
if boundary is None:
|
||||
boundary = binascii.hexlify(os.urandom(16))
|
||||
boundary = os.urandom(16).hex().encode("ascii")
|
||||
|
||||
self.boundary = boundary
|
||||
self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
|
||||
@@ -210,7 +243,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
|
||||
def _iter_fields(
|
||||
self, data: RequestData, files: RequestFiles
|
||||
) -> typing.Iterator[typing.Union[FileField, DataField]]:
|
||||
) -> typing.Iterator[FileField | DataField]:
|
||||
for name, value in data.items():
|
||||
if isinstance(value, (tuple, list)):
|
||||
for item in value:
|
||||
@@ -229,7 +262,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
yield b"\r\n"
|
||||
yield b"--%s--\r\n" % self.boundary
|
||||
|
||||
def get_content_length(self) -> typing.Optional[int]:
|
||||
def get_content_length(self) -> int | None:
|
||||
"""
|
||||
Return the length of the multipart encoded content, or `None` if
|
||||
any of the files have a length that cannot be determined upfront.
|
||||
@@ -251,7 +284,7 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
|
||||
# Content stream interface.
|
||||
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
content_length = self.get_content_length()
|
||||
content_type = self.content_type
|
||||
if content_length is None:
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
__all__ = ["codes"]
|
||||
|
||||
|
||||
class codes(IntEnum):
|
||||
"""HTTP status codes and reason phrases
|
||||
@@ -21,7 +25,7 @@ class codes(IntEnum):
|
||||
* RFC 8470: Using Early Data in HTTP
|
||||
"""
|
||||
|
||||
def __new__(cls, value: int, phrase: str = "") -> "codes":
|
||||
def __new__(cls, value: int, phrase: str = "") -> codes:
|
||||
obj = int.__new__(cls, value)
|
||||
obj._value_ = value
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
from .asgi import *
|
||||
from .base import *
|
||||
from .default import *
|
||||
from .mock import *
|
||||
from .wsgi import *
|
||||
|
||||
__all__ = [
|
||||
"ASGITransport",
|
||||
"AsyncBaseTransport",
|
||||
"BaseTransport",
|
||||
"AsyncHTTPTransport",
|
||||
"HTTPTransport",
|
||||
"MockTransport",
|
||||
"WSGITransport",
|
||||
]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,6 +1,6 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
import sniffio
|
||||
import typing
|
||||
|
||||
from .._models import Request, Response
|
||||
from .._types import AsyncByteStream
|
||||
@@ -14,29 +14,46 @@ if typing.TYPE_CHECKING: # pragma: no cover
|
||||
Event = typing.Union[asyncio.Event, trio.Event]
|
||||
|
||||
|
||||
_Message = typing.Dict[str, typing.Any]
|
||||
_Message = typing.MutableMapping[str, typing.Any]
|
||||
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
|
||||
_Send = typing.Callable[
|
||||
[typing.Dict[str, typing.Any]], typing.Coroutine[None, None, None]
|
||||
[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]
|
||||
]
|
||||
_ASGIApp = typing.Callable[
|
||||
[typing.Dict[str, typing.Any], _Receive, _Send], typing.Coroutine[None, None, None]
|
||||
[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
|
||||
]
|
||||
|
||||
__all__ = ["ASGITransport"]
|
||||
|
||||
def create_event() -> "Event":
|
||||
if sniffio.current_async_library() == "trio":
|
||||
|
||||
def is_running_trio() -> bool:
|
||||
try:
|
||||
# sniffio is a dependency of trio.
|
||||
|
||||
# See https://github.com/python-trio/trio/issues/2802
|
||||
import sniffio
|
||||
|
||||
if sniffio.current_async_library() == "trio":
|
||||
return True
|
||||
except ImportError: # pragma: nocover
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def create_event() -> Event:
|
||||
if is_running_trio():
|
||||
import trio
|
||||
|
||||
return trio.Event()
|
||||
else:
|
||||
import asyncio
|
||||
|
||||
return asyncio.Event()
|
||||
import asyncio
|
||||
|
||||
return asyncio.Event()
|
||||
|
||||
|
||||
class ASGIResponseStream(AsyncByteStream):
|
||||
def __init__(self, body: typing.List[bytes]) -> None:
|
||||
def __init__(self, body: list[bytes]) -> None:
|
||||
self._body = body
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
@@ -46,17 +63,8 @@ class ASGIResponseStream(AsyncByteStream):
|
||||
class ASGITransport(AsyncBaseTransport):
|
||||
"""
|
||||
A custom AsyncTransport that handles sending requests directly to an ASGI app.
|
||||
The simplest way to use this functionality is to use the `app` argument.
|
||||
|
||||
```
|
||||
client = httpx.AsyncClient(app=app)
|
||||
```
|
||||
|
||||
Alternatively, you can setup the transport instance explicitly.
|
||||
This allows you to include any additional configuration arguments specific
|
||||
to the ASGITransport class:
|
||||
|
||||
```
|
||||
```python
|
||||
transport = httpx.ASGITransport(
|
||||
app=app,
|
||||
root_path="/submount",
|
||||
@@ -81,7 +89,7 @@ class ASGITransport(AsyncBaseTransport):
|
||||
app: _ASGIApp,
|
||||
raise_app_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
|
||||
client: tuple[str, int] = ("127.0.0.1", 123),
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
@@ -103,7 +111,7 @@ class ASGITransport(AsyncBaseTransport):
|
||||
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
|
||||
"scheme": request.url.scheme,
|
||||
"path": request.url.path,
|
||||
"raw_path": request.url.raw_path,
|
||||
"raw_path": request.url.raw_path.split(b"?")[0],
|
||||
"query_string": request.url.query,
|
||||
"server": (request.url.host, request.url.port),
|
||||
"client": self.client,
|
||||
@@ -123,7 +131,7 @@ class ASGITransport(AsyncBaseTransport):
|
||||
|
||||
# ASGI callables.
|
||||
|
||||
async def receive() -> typing.Dict[str, typing.Any]:
|
||||
async def receive() -> dict[str, typing.Any]:
|
||||
nonlocal request_complete
|
||||
|
||||
if request_complete:
|
||||
@@ -137,7 +145,7 @@ class ASGITransport(AsyncBaseTransport):
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
return {"type": "http.request", "body": body, "more_body": True}
|
||||
|
||||
async def send(message: typing.Dict[str, typing.Any]) -> None:
|
||||
async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
|
||||
nonlocal status_code, response_headers, response_started
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
@@ -161,9 +169,15 @@ class ASGITransport(AsyncBaseTransport):
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception: # noqa: PIE-786
|
||||
if self.raise_app_exceptions or not response_complete.is_set():
|
||||
if self.raise_app_exceptions:
|
||||
raise
|
||||
|
||||
response_complete.set()
|
||||
if status_code is None:
|
||||
status_code = 500
|
||||
if response_headers is None:
|
||||
response_headers = {}
|
||||
|
||||
assert response_complete.is_set()
|
||||
assert status_code is not None
|
||||
assert response_headers is not None
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
@@ -6,6 +8,8 @@ from .._models import Request, Response
|
||||
T = typing.TypeVar("T", bound="BaseTransport")
|
||||
A = typing.TypeVar("A", bound="AsyncBaseTransport")
|
||||
|
||||
__all__ = ["AsyncBaseTransport", "BaseTransport"]
|
||||
|
||||
|
||||
class BaseTransport:
|
||||
def __enter__(self: T) -> T:
|
||||
@@ -13,9 +17,9 @@ class BaseTransport:
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]] = None,
|
||||
exc_value: typing.Optional[BaseException] = None,
|
||||
traceback: typing.Optional[TracebackType] = None,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: TracebackType | None = None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
@@ -64,9 +68,9 @@ class AsyncBaseTransport:
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]] = None,
|
||||
exc_value: typing.Optional[BaseException] = None,
|
||||
traceback: typing.Optional[TracebackType] = None,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: TracebackType | None = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
|
||||
@@ -23,11 +23,17 @@ client = httpx.Client(transport=transport)
|
||||
transport = httpx.HTTPTransport(uds="socket.uds")
|
||||
client = httpx.Client(transport=transport)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
import httpcore
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl # pragma: no cover
|
||||
|
||||
import httpx # pragma: no cover
|
||||
|
||||
from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context
|
||||
from .._exceptions import (
|
||||
@@ -47,18 +53,53 @@ from .._exceptions import (
|
||||
WriteTimeout,
|
||||
)
|
||||
from .._models import Request, Response
|
||||
from .._types import AsyncByteStream, CertTypes, SyncByteStream, VerifyTypes
|
||||
from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream
|
||||
from .._urls import URL
|
||||
from .base import AsyncBaseTransport, BaseTransport
|
||||
|
||||
T = typing.TypeVar("T", bound="HTTPTransport")
|
||||
A = typing.TypeVar("A", bound="AsyncHTTPTransport")
|
||||
|
||||
SOCKET_OPTION = typing.Union[
|
||||
typing.Tuple[int, int, int],
|
||||
typing.Tuple[int, int, typing.Union[bytes, bytearray]],
|
||||
typing.Tuple[int, int, None, int],
|
||||
]
|
||||
|
||||
__all__ = ["AsyncHTTPTransport", "HTTPTransport"]
|
||||
|
||||
HTTPCORE_EXC_MAP: dict[type[Exception], type[httpx.HTTPError]] = {}
|
||||
|
||||
|
||||
def _load_httpcore_exceptions() -> dict[type[Exception], type[httpx.HTTPError]]:
|
||||
import httpcore
|
||||
|
||||
return {
|
||||
httpcore.TimeoutException: TimeoutException,
|
||||
httpcore.ConnectTimeout: ConnectTimeout,
|
||||
httpcore.ReadTimeout: ReadTimeout,
|
||||
httpcore.WriteTimeout: WriteTimeout,
|
||||
httpcore.PoolTimeout: PoolTimeout,
|
||||
httpcore.NetworkError: NetworkError,
|
||||
httpcore.ConnectError: ConnectError,
|
||||
httpcore.ReadError: ReadError,
|
||||
httpcore.WriteError: WriteError,
|
||||
httpcore.ProxyError: ProxyError,
|
||||
httpcore.UnsupportedProtocol: UnsupportedProtocol,
|
||||
httpcore.ProtocolError: ProtocolError,
|
||||
httpcore.LocalProtocolError: LocalProtocolError,
|
||||
httpcore.RemoteProtocolError: RemoteProtocolError,
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def map_httpcore_exceptions() -> typing.Iterator[None]:
|
||||
global HTTPCORE_EXC_MAP
|
||||
if len(HTTPCORE_EXC_MAP) == 0:
|
||||
HTTPCORE_EXC_MAP = _load_httpcore_exceptions()
|
||||
try:
|
||||
yield
|
||||
except Exception as exc: # noqa: PIE-786
|
||||
except Exception as exc:
|
||||
mapped_exc = None
|
||||
|
||||
for from_exc, to_exc in HTTPCORE_EXC_MAP.items():
|
||||
@@ -77,26 +118,8 @@ def map_httpcore_exceptions() -> typing.Iterator[None]:
|
||||
raise mapped_exc(message) from exc
|
||||
|
||||
|
||||
HTTPCORE_EXC_MAP = {
|
||||
httpcore.TimeoutException: TimeoutException,
|
||||
httpcore.ConnectTimeout: ConnectTimeout,
|
||||
httpcore.ReadTimeout: ReadTimeout,
|
||||
httpcore.WriteTimeout: WriteTimeout,
|
||||
httpcore.PoolTimeout: PoolTimeout,
|
||||
httpcore.NetworkError: NetworkError,
|
||||
httpcore.ConnectError: ConnectError,
|
||||
httpcore.ReadError: ReadError,
|
||||
httpcore.WriteError: WriteError,
|
||||
httpcore.ProxyError: ProxyError,
|
||||
httpcore.UnsupportedProtocol: UnsupportedProtocol,
|
||||
httpcore.ProtocolError: ProtocolError,
|
||||
httpcore.LocalProtocolError: LocalProtocolError,
|
||||
httpcore.RemoteProtocolError: RemoteProtocolError,
|
||||
}
|
||||
|
||||
|
||||
class ResponseStream(SyncByteStream):
|
||||
def __init__(self, httpcore_stream: typing.Iterable[bytes]):
|
||||
def __init__(self, httpcore_stream: typing.Iterable[bytes]) -> None:
|
||||
self._httpcore_stream = httpcore_stream
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
@@ -112,17 +135,21 @@ class ResponseStream(SyncByteStream):
|
||||
class HTTPTransport(BaseTransport):
|
||||
def __init__(
|
||||
self,
|
||||
verify: VerifyTypes = True,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
cert: CertTypes | None = None,
|
||||
trust_env: bool = True,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
limits: Limits = DEFAULT_LIMITS,
|
||||
trust_env: bool = True,
|
||||
proxy: typing.Optional[Proxy] = None,
|
||||
uds: typing.Optional[str] = None,
|
||||
local_address: typing.Optional[str] = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
uds: str | None = None,
|
||||
local_address: str | None = None,
|
||||
retries: int = 0,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> None:
|
||||
import httpcore
|
||||
|
||||
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
|
||||
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
|
||||
|
||||
if proxy is None:
|
||||
@@ -136,6 +163,7 @@ class HTTPTransport(BaseTransport):
|
||||
uds=uds,
|
||||
local_address=local_address,
|
||||
retries=retries,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
elif proxy.url.scheme in ("http", "https"):
|
||||
self._pool = httpcore.HTTPProxy(
|
||||
@@ -148,13 +176,15 @@ class HTTPTransport(BaseTransport):
|
||||
proxy_auth=proxy.raw_auth,
|
||||
proxy_headers=proxy.headers.raw,
|
||||
ssl_context=ssl_context,
|
||||
proxy_ssl_context=proxy.ssl_context,
|
||||
max_connections=limits.max_connections,
|
||||
max_keepalive_connections=limits.max_keepalive_connections,
|
||||
keepalive_expiry=limits.keepalive_expiry,
|
||||
http1=http1,
|
||||
http2=http2,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
elif proxy.url.scheme == "socks5":
|
||||
elif proxy.url.scheme in ("socks5", "socks5h"):
|
||||
try:
|
||||
import socksio # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
@@ -180,7 +210,8 @@ class HTTPTransport(BaseTransport):
|
||||
)
|
||||
else: # pragma: no cover
|
||||
raise ValueError(
|
||||
f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}."
|
||||
"Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h',"
|
||||
f" but got {proxy.url.scheme!r}."
|
||||
)
|
||||
|
||||
def __enter__(self: T) -> T: # Use generics for subclass support.
|
||||
@@ -189,9 +220,9 @@ class HTTPTransport(BaseTransport):
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]] = None,
|
||||
exc_value: typing.Optional[BaseException] = None,
|
||||
traceback: typing.Optional[TracebackType] = None,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: TracebackType | None = None,
|
||||
) -> None:
|
||||
with map_httpcore_exceptions():
|
||||
self._pool.__exit__(exc_type, exc_value, traceback)
|
||||
@@ -201,6 +232,7 @@ class HTTPTransport(BaseTransport):
|
||||
request: Request,
|
||||
) -> Response:
|
||||
assert isinstance(request.stream, SyncByteStream)
|
||||
import httpcore
|
||||
|
||||
req = httpcore.Request(
|
||||
method=request.method,
|
||||
@@ -231,7 +263,7 @@ class HTTPTransport(BaseTransport):
|
||||
|
||||
|
||||
class AsyncResponseStream(AsyncByteStream):
|
||||
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]):
|
||||
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
|
||||
self._httpcore_stream = httpcore_stream
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
@@ -247,17 +279,21 @@ class AsyncResponseStream(AsyncByteStream):
|
||||
class AsyncHTTPTransport(AsyncBaseTransport):
|
||||
def __init__(
|
||||
self,
|
||||
verify: VerifyTypes = True,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
cert: CertTypes | None = None,
|
||||
trust_env: bool = True,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
limits: Limits = DEFAULT_LIMITS,
|
||||
trust_env: bool = True,
|
||||
proxy: typing.Optional[Proxy] = None,
|
||||
uds: typing.Optional[str] = None,
|
||||
local_address: typing.Optional[str] = None,
|
||||
proxy: ProxyTypes | None = None,
|
||||
uds: str | None = None,
|
||||
local_address: str | None = None,
|
||||
retries: int = 0,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> None:
|
||||
import httpcore
|
||||
|
||||
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
|
||||
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
|
||||
|
||||
if proxy is None:
|
||||
@@ -271,6 +307,7 @@ class AsyncHTTPTransport(AsyncBaseTransport):
|
||||
uds=uds,
|
||||
local_address=local_address,
|
||||
retries=retries,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
elif proxy.url.scheme in ("http", "https"):
|
||||
self._pool = httpcore.AsyncHTTPProxy(
|
||||
@@ -282,14 +319,16 @@ class AsyncHTTPTransport(AsyncBaseTransport):
|
||||
),
|
||||
proxy_auth=proxy.raw_auth,
|
||||
proxy_headers=proxy.headers.raw,
|
||||
proxy_ssl_context=proxy.ssl_context,
|
||||
ssl_context=ssl_context,
|
||||
max_connections=limits.max_connections,
|
||||
max_keepalive_connections=limits.max_keepalive_connections,
|
||||
keepalive_expiry=limits.keepalive_expiry,
|
||||
http1=http1,
|
||||
http2=http2,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
elif proxy.url.scheme == "socks5":
|
||||
elif proxy.url.scheme in ("socks5", "socks5h"):
|
||||
try:
|
||||
import socksio # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
@@ -315,7 +354,8 @@ class AsyncHTTPTransport(AsyncBaseTransport):
|
||||
)
|
||||
else: # pragma: no cover
|
||||
raise ValueError(
|
||||
f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}."
|
||||
"Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h',"
|
||||
" but got {proxy.url.scheme!r}."
|
||||
)
|
||||
|
||||
async def __aenter__(self: A) -> A: # Use generics for subclass support.
|
||||
@@ -324,9 +364,9 @@ class AsyncHTTPTransport(AsyncBaseTransport):
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]] = None,
|
||||
exc_value: typing.Optional[BaseException] = None,
|
||||
traceback: typing.Optional[TracebackType] = None,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: TracebackType | None = None,
|
||||
) -> None:
|
||||
with map_httpcore_exceptions():
|
||||
await self._pool.__aexit__(exc_type, exc_value, traceback)
|
||||
@@ -336,6 +376,7 @@ class AsyncHTTPTransport(AsyncBaseTransport):
|
||||
request: Request,
|
||||
) -> Response:
|
||||
assert isinstance(request.stream, AsyncByteStream)
|
||||
import httpcore
|
||||
|
||||
req = httpcore.Request(
|
||||
method=request.method,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .._models import Request, Response
|
||||
@@ -7,8 +9,11 @@ SyncHandler = typing.Callable[[Request], Response]
|
||||
AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]]
|
||||
|
||||
|
||||
__all__ = ["MockTransport"]
|
||||
|
||||
|
||||
class MockTransport(AsyncBaseTransport, BaseTransport):
|
||||
def __init__(self, handler: typing.Union[SyncHandler, AsyncHandler]) -> None:
|
||||
def __init__(self, handler: SyncHandler | AsyncHandler) -> None:
|
||||
self.handler = handler
|
||||
|
||||
def handle_request(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import itertools
|
||||
import sys
|
||||
@@ -14,6 +16,9 @@ if typing.TYPE_CHECKING:
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
|
||||
__all__ = ["WSGITransport"]
|
||||
|
||||
|
||||
def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
|
||||
body = iter(body)
|
||||
for chunk in body:
|
||||
@@ -71,11 +76,11 @@ class WSGITransport(BaseTransport):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: "WSGIApplication",
|
||||
app: WSGIApplication,
|
||||
raise_app_exceptions: bool = True,
|
||||
script_name: str = "",
|
||||
remote_addr: str = "127.0.0.1",
|
||||
wsgi_errors: typing.Optional[typing.TextIO] = None,
|
||||
wsgi_errors: typing.TextIO | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
@@ -102,6 +107,7 @@ class WSGITransport(BaseTransport):
|
||||
"QUERY_STRING": request.url.query.decode("ascii"),
|
||||
"SERVER_NAME": request.url.host,
|
||||
"SERVER_PORT": str(port),
|
||||
"SERVER_PROTOCOL": "HTTP/1.1",
|
||||
"REMOTE_ADDR": self.remote_addr,
|
||||
}
|
||||
for header_key, header_value in request.headers.raw:
|
||||
@@ -116,8 +122,8 @@ class WSGITransport(BaseTransport):
|
||||
|
||||
def start_response(
|
||||
status: str,
|
||||
response_headers: typing.List[typing.Tuple[str, str]],
|
||||
exc_info: typing.Optional["OptExcInfo"] = None,
|
||||
response_headers: list[tuple[str, str]],
|
||||
exc_info: OptExcInfo | None = None,
|
||||
) -> typing.Callable[[bytes], typing.Any]:
|
||||
nonlocal seen_status, seen_response_headers, seen_exc_info
|
||||
seen_status = status
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Type definitions for type checking purposes.
|
||||
"""
|
||||
|
||||
import ssl
|
||||
from http.cookiejar import CookieJar
|
||||
from typing import (
|
||||
IO,
|
||||
@@ -16,7 +15,6 @@ from typing import (
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@@ -32,16 +30,6 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
PrimitiveData = Optional[Union[str, int, float, bool]]
|
||||
|
||||
RawURL = NamedTuple(
|
||||
"RawURL",
|
||||
[
|
||||
("raw_scheme", bytes),
|
||||
("raw_host", bytes),
|
||||
("port", Optional[int]),
|
||||
("raw_path", bytes),
|
||||
],
|
||||
)
|
||||
|
||||
URLTypes = Union["URL", str]
|
||||
|
||||
QueryParamTypes = Union[
|
||||
@@ -63,21 +51,13 @@ HeaderTypes = Union[
|
||||
|
||||
CookieTypes = Union["Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
|
||||
|
||||
CertTypes = Union[
|
||||
# certfile
|
||||
str,
|
||||
# (certfile, keyfile)
|
||||
Tuple[str, Optional[str]],
|
||||
# (certfile, keyfile, password)
|
||||
Tuple[str, Optional[str], Optional[str]],
|
||||
]
|
||||
VerifyTypes = Union[str, bool, ssl.SSLContext]
|
||||
TimeoutTypes = Union[
|
||||
Optional[float],
|
||||
Tuple[Optional[float], Optional[float], Optional[float], Optional[float]],
|
||||
"Timeout",
|
||||
]
|
||||
ProxiesTypes = Union[URLTypes, "Proxy", Dict[URLTypes, Union[None, URLTypes, "Proxy"]]]
|
||||
ProxyTypes = Union["URL", str, "Proxy"]
|
||||
CertTypes = Union[str, Tuple[str, str], Tuple[str, str, str]]
|
||||
|
||||
AuthTypes = Union[
|
||||
Tuple[Union[str, bytes], Union[str, bytes]],
|
||||
@@ -106,6 +86,8 @@ RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
|
||||
|
||||
RequestExtensions = Mapping[str, Any]
|
||||
|
||||
__all__ = ["AsyncByteStream", "SyncByteStream"]
|
||||
|
||||
|
||||
class SyncByteStream:
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
|
||||
@@ -15,6 +15,9 @@ Previously we relied on the excellent `rfc3986` package to handle URL parsing an
|
||||
validation, but this module provides a simpler alternative, with less indirection
|
||||
required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import re
|
||||
import typing
|
||||
@@ -33,6 +36,67 @@ SUB_DELIMS = "!$&'()*+,;="
|
||||
|
||||
PERCENT_ENCODED_REGEX = re.compile("%[A-Fa-f0-9]{2}")
|
||||
|
||||
# https://url.spec.whatwg.org/#percent-encoded-bytes
|
||||
|
||||
# The fragment percent-encode set is the C0 control percent-encode set
|
||||
# and U+0020 SPACE, U+0022 ("), U+003C (<), U+003E (>), and U+0060 (`).
|
||||
FRAG_SAFE = "".join(
|
||||
[chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x3C, 0x3E, 0x60)]
|
||||
)
|
||||
|
||||
# The query percent-encode set is the C0 control percent-encode set
|
||||
# and U+0020 SPACE, U+0022 ("), U+0023 (#), U+003C (<), and U+003E (>).
|
||||
QUERY_SAFE = "".join(
|
||||
[chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E)]
|
||||
)
|
||||
|
||||
# The path percent-encode set is the query percent-encode set
|
||||
# and U+003F (?), U+0060 (`), U+007B ({), and U+007D (}).
|
||||
PATH_SAFE = "".join(
|
||||
[
|
||||
chr(i)
|
||||
for i in range(0x20, 0x7F)
|
||||
if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + (0x3F, 0x60, 0x7B, 0x7D)
|
||||
]
|
||||
)
|
||||
|
||||
# The userinfo percent-encode set is the path percent-encode set
|
||||
# and U+002F (/), U+003A (:), U+003B (;), U+003D (=), U+0040 (@),
|
||||
# U+005B ([) to U+005E (^), inclusive, and U+007C (|).
|
||||
USERNAME_SAFE = "".join(
|
||||
[
|
||||
chr(i)
|
||||
for i in range(0x20, 0x7F)
|
||||
if i
|
||||
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
|
||||
+ (0x3F, 0x60, 0x7B, 0x7D)
|
||||
+ (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
|
||||
]
|
||||
)
|
||||
PASSWORD_SAFE = "".join(
|
||||
[
|
||||
chr(i)
|
||||
for i in range(0x20, 0x7F)
|
||||
if i
|
||||
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
|
||||
+ (0x3F, 0x60, 0x7B, 0x7D)
|
||||
+ (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
|
||||
]
|
||||
)
|
||||
# Note... The terminology 'userinfo' percent-encode set in the WHATWG document
|
||||
# is used for the username and password quoting. For the joint userinfo component
|
||||
# we remove U+003A (:) from the safe set.
|
||||
USERINFO_SAFE = "".join(
|
||||
[
|
||||
chr(i)
|
||||
for i in range(0x20, 0x7F)
|
||||
if i
|
||||
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
|
||||
+ (0x3F, 0x60, 0x7B, 0x7D)
|
||||
+ (0x2F, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# {scheme}: (optional)
|
||||
# //{authority} (optional)
|
||||
@@ -62,8 +126,8 @@ AUTHORITY_REGEX = re.compile(
|
||||
(
|
||||
r"(?:(?P<userinfo>{userinfo})@)?" r"(?P<host>{host})" r":?(?P<port>{port})?"
|
||||
).format(
|
||||
userinfo="[^@]*", # Any character sequence not including '@'.
|
||||
host="(\\[.*\\]|[^:]*)", # Either any character sequence not including ':',
|
||||
userinfo=".*", # Any character sequence.
|
||||
host="(\\[.*\\]|[^:@]*)", # Either any character sequence excluding ':' or '@',
|
||||
# or an IPv6 address enclosed within square brackets.
|
||||
port=".*", # Any character sequence.
|
||||
)
|
||||
@@ -87,7 +151,7 @@ COMPONENT_REGEX = {
|
||||
|
||||
# We use these simple regexs as a first pass before handing off to
|
||||
# the stdlib 'ipaddress' module for IP address validation.
|
||||
IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+.[0-9]+.[0-9]+.[0-9]+$")
|
||||
IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$")
|
||||
IPv6_STYLE_HOSTNAME = re.compile(r"^\[.*\]$")
|
||||
|
||||
|
||||
@@ -95,10 +159,10 @@ class ParseResult(typing.NamedTuple):
|
||||
scheme: str
|
||||
userinfo: str
|
||||
host: str
|
||||
port: typing.Optional[int]
|
||||
port: int | None
|
||||
path: str
|
||||
query: typing.Optional[str]
|
||||
fragment: typing.Optional[str]
|
||||
query: str | None
|
||||
fragment: str | None
|
||||
|
||||
@property
|
||||
def authority(self) -> str:
|
||||
@@ -119,7 +183,7 @@ class ParseResult(typing.NamedTuple):
|
||||
]
|
||||
)
|
||||
|
||||
def copy_with(self, **kwargs: typing.Optional[str]) -> "ParseResult":
|
||||
def copy_with(self, **kwargs: str | None) -> ParseResult:
|
||||
if not kwargs:
|
||||
return self
|
||||
|
||||
@@ -146,7 +210,7 @@ class ParseResult(typing.NamedTuple):
|
||||
)
|
||||
|
||||
|
||||
def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
|
||||
def urlparse(url: str = "", **kwargs: str | None) -> ParseResult:
|
||||
# Initial basic checks on allowable URLs.
|
||||
# ---------------------------------------
|
||||
|
||||
@@ -157,7 +221,12 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
|
||||
# If a URL includes any ASCII control characters including \t, \r, \n,
|
||||
# then treat it as invalid.
|
||||
if any(char.isascii() and not char.isprintable() for char in url):
|
||||
raise InvalidURL("Invalid non-printable ASCII character in URL")
|
||||
char = next(char for char in url if char.isascii() and not char.isprintable())
|
||||
idx = url.find(char)
|
||||
error = (
|
||||
f"Invalid non-printable ASCII character in URL, {char!r} at position {idx}."
|
||||
)
|
||||
raise InvalidURL(error)
|
||||
|
||||
# Some keyword arguments require special handling.
|
||||
# ------------------------------------------------
|
||||
@@ -174,8 +243,8 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
|
||||
|
||||
# Replace "username" and/or "password" with "userinfo".
|
||||
if "username" in kwargs or "password" in kwargs:
|
||||
username = quote(kwargs.pop("username", "") or "")
|
||||
password = quote(kwargs.pop("password", "") or "")
|
||||
username = quote(kwargs.pop("username", "") or "", safe=USERNAME_SAFE)
|
||||
password = quote(kwargs.pop("password", "") or "", safe=PASSWORD_SAFE)
|
||||
kwargs["userinfo"] = f"{username}:{password}" if password else username
|
||||
|
||||
# Replace "raw_path" with "path" and "query".
|
||||
@@ -202,9 +271,15 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
|
||||
# If a component includes any ASCII control characters including \t, \r, \n,
|
||||
# then treat it as invalid.
|
||||
if any(char.isascii() and not char.isprintable() for char in value):
|
||||
raise InvalidURL(
|
||||
f"Invalid non-printable ASCII character in URL component '{key}'"
|
||||
char = next(
|
||||
char for char in value if char.isascii() and not char.isprintable()
|
||||
)
|
||||
idx = value.find(char)
|
||||
error = (
|
||||
f"Invalid non-printable ASCII character in URL {key} component, "
|
||||
f"{char!r} at position {idx}."
|
||||
)
|
||||
raise InvalidURL(error)
|
||||
|
||||
# Ensure that keyword arguments match as a valid regex.
|
||||
if not COMPONENT_REGEX[key].fullmatch(value):
|
||||
@@ -224,7 +299,7 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
|
||||
authority = kwargs.get("authority", url_dict["authority"]) or ""
|
||||
path = kwargs.get("path", url_dict["path"]) or ""
|
||||
query = kwargs.get("query", url_dict["query"])
|
||||
fragment = kwargs.get("fragment", url_dict["fragment"])
|
||||
frag = kwargs.get("fragment", url_dict["fragment"])
|
||||
|
||||
# The AUTHORITY_REGEX will always match, but may have empty components.
|
||||
authority_match = AUTHORITY_REGEX.match(authority)
|
||||
@@ -241,32 +316,21 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
|
||||
# We end up with a parsed representation of the URL,
|
||||
# with components that are plain ASCII bytestrings.
|
||||
parsed_scheme: str = scheme.lower()
|
||||
parsed_userinfo: str = quote(userinfo, safe=SUB_DELIMS + ":")
|
||||
parsed_userinfo: str = quote(userinfo, safe=USERINFO_SAFE)
|
||||
parsed_host: str = encode_host(host)
|
||||
parsed_port: typing.Optional[int] = normalize_port(port, scheme)
|
||||
parsed_port: int | None = normalize_port(port, scheme)
|
||||
|
||||
has_scheme = parsed_scheme != ""
|
||||
has_authority = (
|
||||
parsed_userinfo != "" or parsed_host != "" or parsed_port is not None
|
||||
)
|
||||
validate_path(path, has_scheme=has_scheme, has_authority=has_authority)
|
||||
if has_authority:
|
||||
if has_scheme or has_authority:
|
||||
path = normalize_path(path)
|
||||
|
||||
# The GEN_DELIMS set is... : / ? # [ ] @
|
||||
# These do not need to be percent-quoted unless they serve as delimiters for the
|
||||
# specific component.
|
||||
|
||||
# For 'path' we need to drop ? and # from the GEN_DELIMS set.
|
||||
parsed_path: str = quote(path, safe=SUB_DELIMS + ":/[]@")
|
||||
# For 'query' we need to drop '#' from the GEN_DELIMS set.
|
||||
parsed_query: typing.Optional[str] = (
|
||||
None if query is None else quote(query, safe=SUB_DELIMS + ":/?[]@")
|
||||
)
|
||||
# For 'fragment' we can include all of the GEN_DELIMS set.
|
||||
parsed_fragment: typing.Optional[str] = (
|
||||
None if fragment is None else quote(fragment, safe=SUB_DELIMS + ":/?#[]@")
|
||||
)
|
||||
parsed_path: str = quote(path, safe=PATH_SAFE)
|
||||
parsed_query: str | None = None if query is None else quote(query, safe=QUERY_SAFE)
|
||||
parsed_frag: str | None = None if frag is None else quote(frag, safe=FRAG_SAFE)
|
||||
|
||||
# The parsed ASCII bytestrings are our canonical form.
|
||||
# All properties of the URL are derived from these.
|
||||
@@ -277,7 +341,7 @@ def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult:
|
||||
parsed_port,
|
||||
parsed_path,
|
||||
parsed_query,
|
||||
parsed_fragment,
|
||||
parsed_frag,
|
||||
)
|
||||
|
||||
|
||||
@@ -318,7 +382,8 @@ def encode_host(host: str) -> str:
|
||||
# From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2
|
||||
#
|
||||
# reg-name = *( unreserved / pct-encoded / sub-delims )
|
||||
return quote(host.lower(), safe=SUB_DELIMS)
|
||||
WHATWG_SAFE = '"`{}%|\\'
|
||||
return quote(host.lower(), safe=SUB_DELIMS + WHATWG_SAFE)
|
||||
|
||||
# IDNA hostnames
|
||||
try:
|
||||
@@ -327,9 +392,7 @@ def encode_host(host: str) -> str:
|
||||
raise InvalidURL(f"Invalid IDNA hostname: {host!r}")
|
||||
|
||||
|
||||
def normalize_port(
|
||||
port: typing.Optional[typing.Union[str, int]], scheme: str
|
||||
) -> typing.Optional[int]:
|
||||
def normalize_port(port: str | int | None, scheme: str) -> int | None:
|
||||
# From https://tools.ietf.org/html/rfc3986#section-3.2.3
|
||||
#
|
||||
# "A scheme may define a default port. For example, the "http" scheme
|
||||
@@ -358,28 +421,27 @@ def normalize_port(
|
||||
|
||||
def validate_path(path: str, has_scheme: bool, has_authority: bool) -> None:
|
||||
"""
|
||||
Path validation rules that depend on if the URL contains a scheme or authority component.
|
||||
Path validation rules that depend on if the URL contains
|
||||
a scheme or authority component.
|
||||
|
||||
See https://datatracker.ietf.org/doc/html/rfc3986.html#section-3.3
|
||||
"""
|
||||
if has_authority:
|
||||
# > If a URI contains an authority component, then the path component
|
||||
# > must either be empty or begin with a slash ("/") character."
|
||||
# If a URI contains an authority component, then the path component
|
||||
# must either be empty or begin with a slash ("/") character."
|
||||
if path and not path.startswith("/"):
|
||||
raise InvalidURL("For absolute URLs, path must be empty or begin with '/'")
|
||||
else:
|
||||
# > If a URI does not contain an authority component, then the path cannot begin
|
||||
# > with two slash characters ("//").
|
||||
|
||||
if not has_scheme and not has_authority:
|
||||
# If a URI does not contain an authority component, then the path cannot begin
|
||||
# with two slash characters ("//").
|
||||
if path.startswith("//"):
|
||||
raise InvalidURL(
|
||||
"URLs with no authority component cannot have a path starting with '//'"
|
||||
)
|
||||
# > In addition, a URI reference (Section 4.1) may be a relative-path reference, in which
|
||||
# > case the first path segment cannot contain a colon (":") character.
|
||||
if path.startswith(":") and not has_scheme:
|
||||
raise InvalidURL(
|
||||
"URLs with no scheme component cannot have a path starting with ':'"
|
||||
)
|
||||
raise InvalidURL("Relative URLs cannot have a path starting with '//'")
|
||||
|
||||
# In addition, a URI reference (Section 4.1) may be a relative-path reference,
|
||||
# in which case the first path segment cannot contain a colon (":") character.
|
||||
if path.startswith(":"):
|
||||
raise InvalidURL("Relative URLs cannot have a path starting with ':'")
|
||||
|
||||
|
||||
def normalize_path(path: str) -> str:
|
||||
@@ -390,9 +452,18 @@ def normalize_path(path: str) -> str:
|
||||
|
||||
normalize_path("/path/./to/somewhere/..") == "/path/to"
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4
|
||||
# Fast return when no '.' characters in the path.
|
||||
if "." not in path:
|
||||
return path
|
||||
|
||||
components = path.split("/")
|
||||
output: typing.List[str] = []
|
||||
|
||||
# Fast return when no '.' or '..' components in the path.
|
||||
if "." not in components and ".." not in components:
|
||||
return path
|
||||
|
||||
# https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4
|
||||
output: list[str] = []
|
||||
for component in components:
|
||||
if component == ".":
|
||||
pass
|
||||
@@ -404,59 +475,53 @@ def normalize_path(path: str) -> str:
|
||||
return "/".join(output)
|
||||
|
||||
|
||||
def percent_encode(char: str) -> str:
|
||||
def PERCENT(string: str) -> str:
|
||||
return "".join([f"%{byte:02X}" for byte in string.encode("utf-8")])
|
||||
|
||||
|
||||
def percent_encoded(string: str, safe: str) -> str:
|
||||
"""
|
||||
Replace a single character with the percent-encoded representation.
|
||||
|
||||
Characters outside the ASCII range are represented with their a percent-encoded
|
||||
representation of their UTF-8 byte sequence.
|
||||
|
||||
For example:
|
||||
|
||||
percent_encode(" ") == "%20"
|
||||
Use percent-encoding to quote a string.
|
||||
"""
|
||||
return "".join([f"%{byte:02x}" for byte in char.encode("utf-8")]).upper()
|
||||
NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe
|
||||
|
||||
|
||||
def is_safe(string: str, safe: str = "/") -> bool:
|
||||
"""
|
||||
Determine if a given string is already quote-safe.
|
||||
"""
|
||||
NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + "%"
|
||||
|
||||
# All characters must already be non-escaping or '%'
|
||||
for char in string:
|
||||
if char not in NON_ESCAPED_CHARS:
|
||||
return False
|
||||
|
||||
# Any '%' characters must be valid '%xx' escape sequences.
|
||||
return string.count("%") == len(PERCENT_ENCODED_REGEX.findall(string))
|
||||
|
||||
|
||||
def quote(string: str, safe: str = "/") -> str:
|
||||
"""
|
||||
Use percent-encoding to quote a string if required.
|
||||
"""
|
||||
if is_safe(string, safe=safe):
|
||||
# Fast path for strings that don't need escaping.
|
||||
if not string.rstrip(NON_ESCAPED_CHARS):
|
||||
return string
|
||||
|
||||
NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe
|
||||
return "".join(
|
||||
[char if char in NON_ESCAPED_CHARS else percent_encode(char) for char in string]
|
||||
[char if char in NON_ESCAPED_CHARS else PERCENT(char) for char in string]
|
||||
)
|
||||
|
||||
|
||||
def urlencode(items: typing.List[typing.Tuple[str, str]]) -> str:
|
||||
# We can use a much simpler version of the stdlib urlencode here because
|
||||
# we don't need to handle a bunch of different typing cases, such as bytes vs str.
|
||||
#
|
||||
# https://github.com/python/cpython/blob/b2f7b2ef0b5421e01efb8c7bee2ef95d3bab77eb/Lib/urllib/parse.py#L926
|
||||
#
|
||||
# Note that we use '%20' encoding for spaces, and treat '/' as a safe
|
||||
# character. This means our query params have the same escaping as other
|
||||
# characters in the URL path. This is slightly different to `requests`,
|
||||
# but is the behaviour that browsers use.
|
||||
#
|
||||
# See https://github.com/encode/httpx/issues/2536 and
|
||||
# https://docs.python.org/3/library/urllib.parse.html#urllib.parse.urlencode
|
||||
return "&".join([quote(k) + "=" + quote(v) for k, v in items])
|
||||
def quote(string: str, safe: str) -> str:
|
||||
"""
|
||||
Use percent-encoding to quote a string, omitting existing '%xx' escape sequences.
|
||||
|
||||
See: https://www.rfc-editor.org/rfc/rfc3986#section-2.1
|
||||
|
||||
* `string`: The string to be percent-escaped.
|
||||
* `safe`: A string containing characters that may be treated as safe, and do not
|
||||
need to be escaped. Unreserved characters are always treated as safe.
|
||||
See: https://www.rfc-editor.org/rfc/rfc3986#section-2.3
|
||||
"""
|
||||
parts = []
|
||||
current_position = 0
|
||||
for match in re.finditer(PERCENT_ENCODED_REGEX, string):
|
||||
start_position, end_position = match.start(), match.end()
|
||||
matched_text = match.group(0)
|
||||
# Add any text up to the '%xx' escape sequence.
|
||||
if start_position != current_position:
|
||||
leading_text = string[current_position:start_position]
|
||||
parts.append(percent_encoded(leading_text, safe=safe))
|
||||
|
||||
# Add the '%xx' escape sequence.
|
||||
parts.append(matched_text)
|
||||
current_position = end_position
|
||||
|
||||
# Add any text after the final '%xx' escape sequence.
|
||||
if current_position != len(string):
|
||||
trailing_text = string[current_position:]
|
||||
parts.append(percent_encoded(trailing_text, safe=safe))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from urllib.parse import parse_qs, unquote
|
||||
from urllib.parse import parse_qs, unquote, urlencode
|
||||
|
||||
import idna
|
||||
|
||||
from ._types import QueryParamTypes, RawURL, URLTypes
|
||||
from ._urlparse import urlencode, urlparse
|
||||
from ._types import QueryParamTypes
|
||||
from ._urlparse import urlparse
|
||||
from ._utils import primitive_value_to_str
|
||||
|
||||
__all__ = ["URL", "QueryParams"]
|
||||
|
||||
|
||||
class URL:
|
||||
"""
|
||||
@@ -51,26 +55,26 @@ class URL:
|
||||
assert url.raw_host == b"xn--fiqs8s.icom.museum"
|
||||
|
||||
* `url.port` is either None or an integer. URLs that include the default port for
|
||||
"http", "https", "ws", "wss", and "ftp" schemes have their port normalized to `None`.
|
||||
"http", "https", "ws", "wss", and "ftp" schemes have their port
|
||||
normalized to `None`.
|
||||
|
||||
assert httpx.URL("http://example.com") == httpx.URL("http://example.com:80")
|
||||
assert httpx.URL("http://example.com").port is None
|
||||
assert httpx.URL("http://example.com:80").port is None
|
||||
|
||||
* `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work with
|
||||
`url.username` and `url.password` instead, which handle the URL escaping.
|
||||
* `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work
|
||||
with `url.username` and `url.password` instead, which handle the URL escaping.
|
||||
|
||||
* `url.raw_path` is raw bytes of both the path and query, without URL escaping.
|
||||
This portion is used as the target when constructing HTTP requests. Usually you'll
|
||||
want to work with `url.path` instead.
|
||||
|
||||
* `url.query` is raw bytes, without URL escaping. A URL query string portion can only
|
||||
be properly URL escaped when decoding the parameter names and values themselves.
|
||||
* `url.query` is raw bytes, without URL escaping. A URL query string portion can
|
||||
only be properly URL escaped when decoding the parameter names and values
|
||||
themselves.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, url: typing.Union["URL", str] = "", **kwargs: typing.Any
|
||||
) -> None:
|
||||
def __init__(self, url: URL | str = "", **kwargs: typing.Any) -> None:
|
||||
if kwargs:
|
||||
allowed = {
|
||||
"scheme": str,
|
||||
@@ -115,7 +119,8 @@ class URL:
|
||||
self._uri_reference = url._uri_reference.copy_with(**kwargs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Invalid type for url. Expected str or httpx.URL, got {type(url)}: {url!r}"
|
||||
"Invalid type for url. Expected str or httpx.URL,"
|
||||
f" got {type(url)}: {url!r}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -210,7 +215,7 @@ class URL:
|
||||
return self._uri_reference.host.encode("ascii")
|
||||
|
||||
@property
|
||||
def port(self) -> typing.Optional[int]:
|
||||
def port(self) -> int | None:
|
||||
"""
|
||||
The URL port as an integer.
|
||||
|
||||
@@ -267,7 +272,7 @@ class URL:
|
||||
return query.encode("ascii")
|
||||
|
||||
@property
|
||||
def params(self) -> "QueryParams":
|
||||
def params(self) -> QueryParams:
|
||||
"""
|
||||
The URL query parameters, neatly parsed and packaged into an immutable
|
||||
multidict representation.
|
||||
@@ -299,21 +304,6 @@ class URL:
|
||||
"""
|
||||
return unquote(self._uri_reference.fragment or "")
|
||||
|
||||
@property
|
||||
def raw(self) -> RawURL:
|
||||
"""
|
||||
Provides the (scheme, host, port, target) for the outgoing request.
|
||||
|
||||
In older versions of `httpx` this was used in the low-level transport API.
|
||||
We no longer use `RawURL`, and this property will be deprecated in a future release.
|
||||
"""
|
||||
return RawURL(
|
||||
self.raw_scheme,
|
||||
self.raw_host,
|
||||
self.port,
|
||||
self.raw_path,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_absolute_url(self) -> bool:
|
||||
"""
|
||||
@@ -334,7 +324,7 @@ class URL:
|
||||
"""
|
||||
return not self.is_absolute_url
|
||||
|
||||
def copy_with(self, **kwargs: typing.Any) -> "URL":
|
||||
def copy_with(self, **kwargs: typing.Any) -> URL:
|
||||
"""
|
||||
Copy this URL, returning a new URL with some components altered.
|
||||
Accepts the same set of parameters as the components that are made
|
||||
@@ -342,24 +332,26 @@ class URL:
|
||||
|
||||
For example:
|
||||
|
||||
url = httpx.URL("https://www.example.com").copy_with(username="jo@gmail.com", password="a secret")
|
||||
url = httpx.URL("https://www.example.com").copy_with(
|
||||
username="jo@gmail.com", password="a secret"
|
||||
)
|
||||
assert url == "https://jo%40email.com:a%20secret@www.example.com"
|
||||
"""
|
||||
return URL(self, **kwargs)
|
||||
|
||||
def copy_set_param(self, key: str, value: typing.Any = None) -> "URL":
|
||||
def copy_set_param(self, key: str, value: typing.Any = None) -> URL:
|
||||
return self.copy_with(params=self.params.set(key, value))
|
||||
|
||||
def copy_add_param(self, key: str, value: typing.Any = None) -> "URL":
|
||||
def copy_add_param(self, key: str, value: typing.Any = None) -> URL:
|
||||
return self.copy_with(params=self.params.add(key, value))
|
||||
|
||||
def copy_remove_param(self, key: str) -> "URL":
|
||||
def copy_remove_param(self, key: str) -> URL:
|
||||
return self.copy_with(params=self.params.remove(key))
|
||||
|
||||
def copy_merge_params(self, params: QueryParamTypes) -> "URL":
|
||||
def copy_merge_params(self, params: QueryParamTypes) -> URL:
|
||||
return self.copy_with(params=self.params.merge(params))
|
||||
|
||||
def join(self, url: URLTypes) -> "URL":
|
||||
def join(self, url: URL | str) -> URL:
|
||||
"""
|
||||
Return an absolute URL, using this URL as the base.
|
||||
|
||||
@@ -408,15 +400,29 @@ class URL:
|
||||
|
||||
return f"{self.__class__.__name__}({url!r})"
|
||||
|
||||
@property
|
||||
def raw(self) -> tuple[bytes, bytes, int, bytes]: # pragma: nocover
|
||||
import collections
|
||||
import warnings
|
||||
|
||||
warnings.warn("URL.raw is deprecated.")
|
||||
RawURL = collections.namedtuple(
|
||||
"RawURL", ["raw_scheme", "raw_host", "port", "raw_path"]
|
||||
)
|
||||
return RawURL(
|
||||
raw_scheme=self.raw_scheme,
|
||||
raw_host=self.raw_host,
|
||||
port=self.port,
|
||||
raw_path=self.raw_path,
|
||||
)
|
||||
|
||||
|
||||
class QueryParams(typing.Mapping[str, str]):
|
||||
"""
|
||||
URL query parameters, as a multi-dict.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args: typing.Optional[QueryParamTypes], **kwargs: typing.Any
|
||||
) -> None:
|
||||
def __init__(self, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
|
||||
assert len(args) < 2, "Too many arguments."
|
||||
assert not (args and kwargs), "Cannot mix named and unnamed arguments."
|
||||
|
||||
@@ -428,7 +434,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
elif isinstance(value, QueryParams):
|
||||
self._dict = {k: list(v) for k, v in value._dict.items()}
|
||||
else:
|
||||
dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {}
|
||||
dict_value: dict[typing.Any, list[typing.Any]] = {}
|
||||
if isinstance(value, (list, tuple)):
|
||||
# Convert list inputs like:
|
||||
# [("a", "123"), ("a", "456"), ("b", "789")]
|
||||
@@ -489,7 +495,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
"""
|
||||
return {k: v[0] for k, v in self._dict.items()}.items()
|
||||
|
||||
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
|
||||
def multi_items(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return all items in the query params. Allow duplicate keys to occur.
|
||||
|
||||
@@ -498,7 +504,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
q = httpx.QueryParams("a=123&a=456&b=789")
|
||||
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
|
||||
"""
|
||||
multi_items: typing.List[typing.Tuple[str, str]] = []
|
||||
multi_items: list[tuple[str, str]] = []
|
||||
for k, v in self._dict.items():
|
||||
multi_items.extend([(k, i) for i in v])
|
||||
return multi_items
|
||||
@@ -517,7 +523,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
return self._dict[str(key)][0]
|
||||
return default
|
||||
|
||||
def get_list(self, key: str) -> typing.List[str]:
|
||||
def get_list(self, key: str) -> list[str]:
|
||||
"""
|
||||
Get all values from the query param for a given key.
|
||||
|
||||
@@ -528,7 +534,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
"""
|
||||
return list(self._dict.get(str(key), []))
|
||||
|
||||
def set(self, key: str, value: typing.Any = None) -> "QueryParams":
|
||||
def set(self, key: str, value: typing.Any = None) -> QueryParams:
|
||||
"""
|
||||
Return a new QueryParams instance, setting the value of a key.
|
||||
|
||||
@@ -543,7 +549,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
q._dict[str(key)] = [primitive_value_to_str(value)]
|
||||
return q
|
||||
|
||||
def add(self, key: str, value: typing.Any = None) -> "QueryParams":
|
||||
def add(self, key: str, value: typing.Any = None) -> QueryParams:
|
||||
"""
|
||||
Return a new QueryParams instance, setting or appending the value of a key.
|
||||
|
||||
@@ -558,7 +564,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
|
||||
return q
|
||||
|
||||
def remove(self, key: str) -> "QueryParams":
|
||||
def remove(self, key: str) -> QueryParams:
|
||||
"""
|
||||
Return a new QueryParams instance, removing the value of a key.
|
||||
|
||||
@@ -573,7 +579,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
q._dict.pop(str(key), None)
|
||||
return q
|
||||
|
||||
def merge(self, params: typing.Optional[QueryParamTypes] = None) -> "QueryParams":
|
||||
def merge(self, params: QueryParamTypes | None = None) -> QueryParams:
|
||||
"""
|
||||
Return a new QueryParams instance, updated with.
|
||||
|
||||
@@ -615,13 +621,6 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
return sorted(self.multi_items()) == sorted(other.multi_items())
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Note that we use '%20' encoding for spaces, and treat '/' as a safe
|
||||
character.
|
||||
|
||||
See https://github.com/encode/httpx/issues/2536 and
|
||||
https://docs.python.org/3/library/urllib.parse.html#urllib.parse.urlencode
|
||||
"""
|
||||
return urlencode(self.multi_items())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -629,7 +628,7 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
query_string = str(self)
|
||||
return f"{class_name}({query_string!r})"
|
||||
|
||||
def update(self, params: typing.Optional[QueryParamTypes] = None) -> None:
|
||||
def update(self, params: QueryParamTypes | None = None) -> None:
|
||||
raise RuntimeError(
|
||||
"QueryParams are immutable since 0.18.0. "
|
||||
"Use `q = q.merge(...)` to create an updated copy."
|
||||
|
||||
@@ -1,59 +1,18 @@
|
||||
import codecs
|
||||
import email.message
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from urllib.request import getproxies
|
||||
|
||||
import sniffio
|
||||
|
||||
from ._types import PrimitiveData
|
||||
|
||||
if typing.TYPE_CHECKING: # pragma: no cover
|
||||
from ._urls import URL
|
||||
|
||||
|
||||
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
|
||||
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
|
||||
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
|
||||
)
|
||||
_HTML5_FORM_ENCODING_RE = re.compile(
|
||||
r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
|
||||
)
|
||||
|
||||
|
||||
def normalize_header_key(
|
||||
value: typing.Union[str, bytes],
|
||||
lower: bool,
|
||||
encoding: typing.Optional[str] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Coerce str/bytes into a strictly byte-wise HTTP header key.
|
||||
"""
|
||||
if isinstance(value, bytes):
|
||||
bytes_value = value
|
||||
else:
|
||||
bytes_value = value.encode(encoding or "ascii")
|
||||
|
||||
return bytes_value.lower() if lower else bytes_value
|
||||
|
||||
|
||||
def normalize_header_value(
|
||||
value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Coerce str/bytes into a strictly byte-wise HTTP header value.
|
||||
"""
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
return value.encode(encoding or "ascii")
|
||||
|
||||
|
||||
def primitive_value_to_str(value: "PrimitiveData") -> str:
|
||||
def primitive_value_to_str(value: PrimitiveData) -> str:
|
||||
"""
|
||||
Coerce a primitive data type into a string value.
|
||||
|
||||
@@ -68,166 +27,7 @@ def primitive_value_to_str(value: "PrimitiveData") -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
def is_known_encoding(encoding: str) -> bool:
|
||||
"""
|
||||
Return `True` if `encoding` is a known codec.
|
||||
"""
|
||||
try:
|
||||
codecs.lookup(encoding)
|
||||
except LookupError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def format_form_param(name: str, value: str) -> bytes:
|
||||
"""
|
||||
Encode a name/value pair within a multipart form.
|
||||
"""
|
||||
|
||||
def replacer(match: typing.Match[str]) -> str:
|
||||
return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
|
||||
|
||||
value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
|
||||
return f'{name}="{value}"'.encode()
|
||||
|
||||
|
||||
# Null bytes; no need to recreate these on each call to guess_json_utf
|
||||
_null = b"\x00"
|
||||
_null2 = _null * 2
|
||||
_null3 = _null * 3
|
||||
|
||||
|
||||
def guess_json_utf(data: bytes) -> typing.Optional[str]:
|
||||
# JSON always starts with two ASCII characters, so detection is as
|
||||
# easy as counting the nulls and from their location and count
|
||||
# determine the encoding. Also detect a BOM, if present.
|
||||
sample = data[:4]
|
||||
if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE):
|
||||
return "utf-32" # BOM included
|
||||
if sample[:3] == codecs.BOM_UTF8:
|
||||
return "utf-8-sig" # BOM included, MS style (discouraged)
|
||||
if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
|
||||
return "utf-16" # BOM included
|
||||
nullcount = sample.count(_null)
|
||||
if nullcount == 0:
|
||||
return "utf-8"
|
||||
if nullcount == 2:
|
||||
if sample[::2] == _null2: # 1st and 3rd are null
|
||||
return "utf-16-be"
|
||||
if sample[1::2] == _null2: # 2nd and 4th are null
|
||||
return "utf-16-le"
|
||||
# Did not detect 2 valid UTF-16 ascii-range characters
|
||||
if nullcount == 3:
|
||||
if sample[:3] == _null3:
|
||||
return "utf-32-be"
|
||||
if sample[1:] == _null3:
|
||||
return "utf-32-le"
|
||||
# Did not detect a valid UTF-32 ascii-range character
|
||||
return None
|
||||
|
||||
|
||||
def get_ca_bundle_from_env() -> typing.Optional[str]:
|
||||
if "SSL_CERT_FILE" in os.environ:
|
||||
ssl_file = Path(os.environ["SSL_CERT_FILE"])
|
||||
if ssl_file.is_file():
|
||||
return str(ssl_file)
|
||||
if "SSL_CERT_DIR" in os.environ:
|
||||
ssl_path = Path(os.environ["SSL_CERT_DIR"])
|
||||
if ssl_path.is_dir():
|
||||
return str(ssl_path)
|
||||
return None
|
||||
|
||||
|
||||
def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]:
|
||||
"""
|
||||
Returns a list of parsed link headers, for more info see:
|
||||
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
|
||||
The generic syntax of those is:
|
||||
Link: < uri-reference >; param1=value1; param2="value2"
|
||||
So for instance:
|
||||
Link; '<http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;'
|
||||
would return
|
||||
[
|
||||
{"url": "http:/.../front.jpeg", "type": "image/jpeg"},
|
||||
{"url": "http://.../back.jpeg"},
|
||||
]
|
||||
:param value: HTTP Link entity-header field
|
||||
:return: list of parsed link headers
|
||||
"""
|
||||
links: typing.List[typing.Dict[str, str]] = []
|
||||
replace_chars = " '\""
|
||||
value = value.strip(replace_chars)
|
||||
if not value:
|
||||
return links
|
||||
for val in re.split(", *<", value):
|
||||
try:
|
||||
url, params = val.split(";", 1)
|
||||
except ValueError:
|
||||
url, params = val, ""
|
||||
link = {"url": url.strip("<> '\"")}
|
||||
for param in params.split(";"):
|
||||
try:
|
||||
key, value = param.split("=")
|
||||
except ValueError:
|
||||
break
|
||||
link[key.strip(replace_chars)] = value.strip(replace_chars)
|
||||
links.append(link)
|
||||
return links
|
||||
|
||||
|
||||
def parse_content_type_charset(content_type: str) -> typing.Optional[str]:
|
||||
# We used to use `cgi.parse_header()` here, but `cgi` became a dead battery.
|
||||
# See: https://peps.python.org/pep-0594/#cgi
|
||||
msg = email.message.Message()
|
||||
msg["content-type"] = content_type
|
||||
return msg.get_content_charset(failobj=None)
|
||||
|
||||
|
||||
SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}
|
||||
|
||||
|
||||
def obfuscate_sensitive_headers(
|
||||
items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
|
||||
) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
|
||||
for k, v in items:
|
||||
if to_str(k.lower()) in SENSITIVE_HEADERS:
|
||||
v = to_bytes_or_str("[secure]", match_type_of=v)
|
||||
yield k, v
|
||||
|
||||
|
||||
def port_or_default(url: "URL") -> typing.Optional[int]:
|
||||
if url.port is not None:
|
||||
return url.port
|
||||
return {"http": 80, "https": 443}.get(url.scheme)
|
||||
|
||||
|
||||
def same_origin(url: "URL", other: "URL") -> bool:
|
||||
"""
|
||||
Return 'True' if the given URLs share the same origin.
|
||||
"""
|
||||
return (
|
||||
url.scheme == other.scheme
|
||||
and url.host == other.host
|
||||
and port_or_default(url) == port_or_default(other)
|
||||
)
|
||||
|
||||
|
||||
def is_https_redirect(url: "URL", location: "URL") -> bool:
|
||||
"""
|
||||
Return 'True' if 'location' is a HTTPS upgrade of 'url'
|
||||
"""
|
||||
if url.host != location.host:
|
||||
return False
|
||||
|
||||
return (
|
||||
url.scheme == "http"
|
||||
and port_or_default(url) == 80
|
||||
and location.scheme == "https"
|
||||
and port_or_default(location) == 443
|
||||
)
|
||||
|
||||
|
||||
def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
|
||||
def get_environment_proxies() -> dict[str, str | None]:
|
||||
"""Gets proxy information from the environment"""
|
||||
|
||||
# urllib.request.getproxies() falls back on System
|
||||
@@ -235,7 +35,7 @@ def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
|
||||
# We don't want to propagate non-HTTP proxies into
|
||||
# our configuration such as 'TRAVIS_APT_PROXY'.
|
||||
proxy_info = getproxies()
|
||||
mounts: typing.Dict[str, typing.Optional[str]] = {}
|
||||
mounts: dict[str, str | None] = {}
|
||||
|
||||
for scheme in ("http", "https", "all"):
|
||||
if proxy_info.get(scheme):
|
||||
@@ -262,7 +62,9 @@ def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
|
||||
# (But not "wwwgoogle.com")
|
||||
# NO_PROXY can include domains, IPv6, IPv4 addresses and "localhost"
|
||||
# NO_PROXY=example.com,::1,localhost,192.168.0.0/16
|
||||
if is_ipv4_hostname(hostname):
|
||||
if "://" in hostname:
|
||||
mounts[hostname] = None
|
||||
elif is_ipv4_hostname(hostname):
|
||||
mounts[f"all://{hostname}"] = None
|
||||
elif is_ipv6_hostname(hostname):
|
||||
mounts[f"all://[{hostname}]"] = None
|
||||
@@ -274,11 +76,11 @@ def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
|
||||
return mounts
|
||||
|
||||
|
||||
def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
|
||||
def to_bytes(value: str | bytes, encoding: str = "utf-8") -> bytes:
|
||||
return value.encode(encoding) if isinstance(value, str) else value
|
||||
|
||||
|
||||
def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
|
||||
def to_str(value: str | bytes, encoding: str = "utf-8") -> str:
|
||||
return value if isinstance(value, str) else value.decode(encoding)
|
||||
|
||||
|
||||
@@ -290,13 +92,7 @@ def unquote(value: str) -> str:
|
||||
return value[1:-1] if value[0] == value[-1] == '"' else value
|
||||
|
||||
|
||||
def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
|
||||
if filename:
|
||||
return mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
return None
|
||||
|
||||
|
||||
def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]:
|
||||
def peek_filelike_length(stream: typing.Any) -> int | None:
|
||||
"""
|
||||
Given a file-like stream object, return its length in number of bytes
|
||||
without reading it into memory.
|
||||
@@ -321,48 +117,17 @@ def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]:
|
||||
return length
|
||||
|
||||
|
||||
class Timer:
|
||||
async def _get_time(self) -> float:
|
||||
library = sniffio.current_async_library()
|
||||
if library == "trio":
|
||||
import trio
|
||||
|
||||
return trio.current_time()
|
||||
elif library == "curio": # pragma: no cover
|
||||
import curio
|
||||
|
||||
return typing.cast(float, await curio.clock())
|
||||
|
||||
import asyncio
|
||||
|
||||
return asyncio.get_event_loop().time()
|
||||
|
||||
def sync_start(self) -> None:
|
||||
self.started = time.perf_counter()
|
||||
|
||||
async def async_start(self) -> None:
|
||||
self.started = await self._get_time()
|
||||
|
||||
def sync_elapsed(self) -> float:
|
||||
now = time.perf_counter()
|
||||
return now - self.started
|
||||
|
||||
async def async_elapsed(self) -> float:
|
||||
now = await self._get_time()
|
||||
return now - self.started
|
||||
|
||||
|
||||
class URLPattern:
|
||||
"""
|
||||
A utility class currently used for making lookups against proxy keys...
|
||||
|
||||
# Wildcard matching...
|
||||
>>> pattern = URLPattern("all")
|
||||
>>> pattern = URLPattern("all://")
|
||||
>>> pattern.matches(httpx.URL("http://example.com"))
|
||||
True
|
||||
|
||||
# Witch scheme matching...
|
||||
>>> pattern = URLPattern("https")
|
||||
>>> pattern = URLPattern("https://")
|
||||
>>> pattern.matches(httpx.URL("https://example.com"))
|
||||
True
|
||||
>>> pattern.matches(httpx.URL("http://example.com"))
|
||||
@@ -410,7 +175,7 @@ class URLPattern:
|
||||
self.host = "" if url.host == "*" else url.host
|
||||
self.port = url.port
|
||||
if not url.host or url.host == "*":
|
||||
self.host_regex: typing.Optional[typing.Pattern[str]] = None
|
||||
self.host_regex: typing.Pattern[str] | None = None
|
||||
elif url.host.startswith("*."):
|
||||
# *.example.com should match "www.example.com", but not "example.com"
|
||||
domain = re.escape(url.host[2:])
|
||||
@@ -424,7 +189,7 @@ class URLPattern:
|
||||
domain = re.escape(url.host)
|
||||
self.host_regex = re.compile(f"^{domain}$")
|
||||
|
||||
def matches(self, other: "URL") -> bool:
|
||||
def matches(self, other: URL) -> bool:
|
||||
if self.scheme and self.scheme != other.scheme:
|
||||
return False
|
||||
if (
|
||||
@@ -438,7 +203,7 @@ class URLPattern:
|
||||
return True
|
||||
|
||||
@property
|
||||
def priority(self) -> typing.Tuple[int, int, int]:
|
||||
def priority(self) -> tuple[int, int, int]:
|
||||
"""
|
||||
The priority allows URLPattern instances to be sortable, so that
|
||||
we can match from most specific to least specific.
|
||||
@@ -454,7 +219,7 @@ class URLPattern:
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.pattern)
|
||||
|
||||
def __lt__(self, other: "URLPattern") -> bool:
|
||||
def __lt__(self, other: URLPattern) -> bool:
|
||||
return self.priority < other.priority
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user