This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,15 @@
from .asgi import *
from .base import *
from .default import *
from .mock import *
from .wsgi import *
__all__ = [
"ASGITransport",
"AsyncBaseTransport",
"BaseTransport",
"AsyncHTTPTransport",
"HTTPTransport",
"MockTransport",
"WSGITransport",
]

View File

@@ -1,6 +1,6 @@
import typing
from __future__ import annotations
import sniffio
import typing
from .._models import Request, Response
from .._types import AsyncByteStream
@@ -14,29 +14,46 @@ if typing.TYPE_CHECKING: # pragma: no cover
Event = typing.Union[asyncio.Event, trio.Event]
_Message = typing.Dict[str, typing.Any]
_Message = typing.MutableMapping[str, typing.Any]
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
_Send = typing.Callable[
[typing.Dict[str, typing.Any]], typing.Coroutine[None, None, None]
[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]
]
_ASGIApp = typing.Callable[
[typing.Dict[str, typing.Any], _Receive, _Send], typing.Coroutine[None, None, None]
[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
]
__all__ = ["ASGITransport"]
def create_event() -> "Event":
if sniffio.current_async_library() == "trio":
def is_running_trio() -> bool:
try:
# sniffio is a dependency of trio.
# See https://github.com/python-trio/trio/issues/2802
import sniffio
if sniffio.current_async_library() == "trio":
return True
except ImportError: # pragma: nocover
pass
return False
def create_event() -> Event:
if is_running_trio():
import trio
return trio.Event()
else:
import asyncio
return asyncio.Event()
import asyncio
return asyncio.Event()
class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: typing.List[bytes]) -> None:
def __init__(self, body: list[bytes]) -> None:
self._body = body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
@@ -46,17 +63,8 @@ class ASGIResponseStream(AsyncByteStream):
class ASGITransport(AsyncBaseTransport):
"""
A custom AsyncTransport that handles sending requests directly to an ASGI app.
The simplest way to use this functionality is to use the `app` argument.
```
client = httpx.AsyncClient(app=app)
```
Alternatively, you can setup the transport instance explicitly.
This allows you to include any additional configuration arguments specific
to the ASGITransport class:
```
```python
transport = httpx.ASGITransport(
app=app,
root_path="/submount",
@@ -81,7 +89,7 @@ class ASGITransport(AsyncBaseTransport):
app: _ASGIApp,
raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
client: tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
@@ -103,7 +111,7 @@ class ASGITransport(AsyncBaseTransport):
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path,
"raw_path": request.url.raw_path.split(b"?")[0],
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": self.client,
@@ -123,7 +131,7 @@ class ASGITransport(AsyncBaseTransport):
# ASGI callables.
async def receive() -> typing.Dict[str, typing.Any]:
async def receive() -> dict[str, typing.Any]:
nonlocal request_complete
if request_complete:
@@ -137,7 +145,7 @@ class ASGITransport(AsyncBaseTransport):
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: typing.Dict[str, typing.Any]) -> None:
async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
nonlocal status_code, response_headers, response_started
if message["type"] == "http.response.start":
@@ -161,9 +169,15 @@ class ASGITransport(AsyncBaseTransport):
try:
await self.app(scope, receive, send)
except Exception: # noqa: PIE-786
if self.raise_app_exceptions or not response_complete.is_set():
if self.raise_app_exceptions:
raise
response_complete.set()
if status_code is None:
status_code = 500
if response_headers is None:
response_headers = {}
assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import typing
from types import TracebackType
@@ -6,6 +8,8 @@ from .._models import Request, Response
T = typing.TypeVar("T", bound="BaseTransport")
A = typing.TypeVar("A", bound="AsyncBaseTransport")
__all__ = ["AsyncBaseTransport", "BaseTransport"]
class BaseTransport:
def __enter__(self: T) -> T:
@@ -13,9 +17,9 @@ class BaseTransport:
def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
self.close()
@@ -64,9 +68,9 @@ class AsyncBaseTransport:
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
await self.aclose()

View File

@@ -23,11 +23,17 @@ client = httpx.Client(transport=transport)
transport = httpx.HTTPTransport(uds="socket.uds")
client = httpx.Client(transport=transport)
"""
from __future__ import annotations
import contextlib
import typing
from types import TracebackType
import httpcore
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
import httpx # pragma: no cover
from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context
from .._exceptions import (
@@ -47,18 +53,53 @@ from .._exceptions import (
WriteTimeout,
)
from .._models import Request, Response
from .._types import AsyncByteStream, CertTypes, SyncByteStream, VerifyTypes
from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream
from .._urls import URL
from .base import AsyncBaseTransport, BaseTransport
T = typing.TypeVar("T", bound="HTTPTransport")
A = typing.TypeVar("A", bound="AsyncHTTPTransport")
SOCKET_OPTION = typing.Union[
typing.Tuple[int, int, int],
typing.Tuple[int, int, typing.Union[bytes, bytearray]],
typing.Tuple[int, int, None, int],
]
__all__ = ["AsyncHTTPTransport", "HTTPTransport"]
HTTPCORE_EXC_MAP: dict[type[Exception], type[httpx.HTTPError]] = {}
def _load_httpcore_exceptions() -> dict[type[Exception], type[httpx.HTTPError]]:
import httpcore
return {
httpcore.TimeoutException: TimeoutException,
httpcore.ConnectTimeout: ConnectTimeout,
httpcore.ReadTimeout: ReadTimeout,
httpcore.WriteTimeout: WriteTimeout,
httpcore.PoolTimeout: PoolTimeout,
httpcore.NetworkError: NetworkError,
httpcore.ConnectError: ConnectError,
httpcore.ReadError: ReadError,
httpcore.WriteError: WriteError,
httpcore.ProxyError: ProxyError,
httpcore.UnsupportedProtocol: UnsupportedProtocol,
httpcore.ProtocolError: ProtocolError,
httpcore.LocalProtocolError: LocalProtocolError,
httpcore.RemoteProtocolError: RemoteProtocolError,
}
@contextlib.contextmanager
def map_httpcore_exceptions() -> typing.Iterator[None]:
global HTTPCORE_EXC_MAP
if len(HTTPCORE_EXC_MAP) == 0:
HTTPCORE_EXC_MAP = _load_httpcore_exceptions()
try:
yield
except Exception as exc: # noqa: PIE-786
except Exception as exc:
mapped_exc = None
for from_exc, to_exc in HTTPCORE_EXC_MAP.items():
@@ -77,26 +118,8 @@ def map_httpcore_exceptions() -> typing.Iterator[None]:
raise mapped_exc(message) from exc
HTTPCORE_EXC_MAP = {
httpcore.TimeoutException: TimeoutException,
httpcore.ConnectTimeout: ConnectTimeout,
httpcore.ReadTimeout: ReadTimeout,
httpcore.WriteTimeout: WriteTimeout,
httpcore.PoolTimeout: PoolTimeout,
httpcore.NetworkError: NetworkError,
httpcore.ConnectError: ConnectError,
httpcore.ReadError: ReadError,
httpcore.WriteError: WriteError,
httpcore.ProxyError: ProxyError,
httpcore.UnsupportedProtocol: UnsupportedProtocol,
httpcore.ProtocolError: ProtocolError,
httpcore.LocalProtocolError: LocalProtocolError,
httpcore.RemoteProtocolError: RemoteProtocolError,
}
class ResponseStream(SyncByteStream):
def __init__(self, httpcore_stream: typing.Iterable[bytes]):
def __init__(self, httpcore_stream: typing.Iterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream
def __iter__(self) -> typing.Iterator[bytes]:
@@ -112,17 +135,21 @@ class ResponseStream(SyncByteStream):
class HTTPTransport(BaseTransport):
def __init__(
self,
verify: VerifyTypes = True,
cert: typing.Optional[CertTypes] = None,
verify: ssl.SSLContext | str | bool = True,
cert: CertTypes | None = None,
trust_env: bool = True,
http1: bool = True,
http2: bool = False,
limits: Limits = DEFAULT_LIMITS,
trust_env: bool = True,
proxy: typing.Optional[Proxy] = None,
uds: typing.Optional[str] = None,
local_address: typing.Optional[str] = None,
proxy: ProxyTypes | None = None,
uds: str | None = None,
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
import httpcore
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
if proxy is None:
@@ -136,6 +163,7 @@ class HTTPTransport(BaseTransport):
uds=uds,
local_address=local_address,
retries=retries,
socket_options=socket_options,
)
elif proxy.url.scheme in ("http", "https"):
self._pool = httpcore.HTTPProxy(
@@ -148,13 +176,15 @@ class HTTPTransport(BaseTransport):
proxy_auth=proxy.raw_auth,
proxy_headers=proxy.headers.raw,
ssl_context=ssl_context,
proxy_ssl_context=proxy.ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
socket_options=socket_options,
)
elif proxy.url.scheme == "socks5":
elif proxy.url.scheme in ("socks5", "socks5h"):
try:
import socksio # noqa
except ImportError: # pragma: no cover
@@ -180,7 +210,8 @@ class HTTPTransport(BaseTransport):
)
else: # pragma: no cover
raise ValueError(
f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}."
"Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h',"
f" but got {proxy.url.scheme!r}."
)
def __enter__(self: T) -> T: # Use generics for subclass support.
@@ -189,9 +220,9 @@ class HTTPTransport(BaseTransport):
def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
with map_httpcore_exceptions():
self._pool.__exit__(exc_type, exc_value, traceback)
@@ -201,6 +232,7 @@ class HTTPTransport(BaseTransport):
request: Request,
) -> Response:
assert isinstance(request.stream, SyncByteStream)
import httpcore
req = httpcore.Request(
method=request.method,
@@ -231,7 +263,7 @@ class HTTPTransport(BaseTransport):
class AsyncResponseStream(AsyncByteStream):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
@@ -247,17 +279,21 @@ class AsyncResponseStream(AsyncByteStream):
class AsyncHTTPTransport(AsyncBaseTransport):
def __init__(
self,
verify: VerifyTypes = True,
cert: typing.Optional[CertTypes] = None,
verify: ssl.SSLContext | str | bool = True,
cert: CertTypes | None = None,
trust_env: bool = True,
http1: bool = True,
http2: bool = False,
limits: Limits = DEFAULT_LIMITS,
trust_env: bool = True,
proxy: typing.Optional[Proxy] = None,
uds: typing.Optional[str] = None,
local_address: typing.Optional[str] = None,
proxy: ProxyTypes | None = None,
uds: str | None = None,
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
import httpcore
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
if proxy is None:
@@ -271,6 +307,7 @@ class AsyncHTTPTransport(AsyncBaseTransport):
uds=uds,
local_address=local_address,
retries=retries,
socket_options=socket_options,
)
elif proxy.url.scheme in ("http", "https"):
self._pool = httpcore.AsyncHTTPProxy(
@@ -282,14 +319,16 @@ class AsyncHTTPTransport(AsyncBaseTransport):
),
proxy_auth=proxy.raw_auth,
proxy_headers=proxy.headers.raw,
proxy_ssl_context=proxy.ssl_context,
ssl_context=ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
socket_options=socket_options,
)
elif proxy.url.scheme == "socks5":
elif proxy.url.scheme in ("socks5", "socks5h"):
try:
import socksio # noqa
except ImportError: # pragma: no cover
@@ -315,7 +354,8 @@ class AsyncHTTPTransport(AsyncBaseTransport):
)
else: # pragma: no cover
raise ValueError(
f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}."
"Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h',"
" but got {proxy.url.scheme!r}."
)
async def __aenter__(self: A) -> A: # Use generics for subclass support.
@@ -324,9 +364,9 @@ class AsyncHTTPTransport(AsyncBaseTransport):
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
with map_httpcore_exceptions():
await self._pool.__aexit__(exc_type, exc_value, traceback)
@@ -336,6 +376,7 @@ class AsyncHTTPTransport(AsyncBaseTransport):
request: Request,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)
import httpcore
req = httpcore.Request(
method=request.method,

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import typing
from .._models import Request, Response
@@ -7,8 +9,11 @@ SyncHandler = typing.Callable[[Request], Response]
AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]]
__all__ = ["MockTransport"]
class MockTransport(AsyncBaseTransport, BaseTransport):
def __init__(self, handler: typing.Union[SyncHandler, AsyncHandler]) -> None:
def __init__(self, handler: SyncHandler | AsyncHandler) -> None:
self.handler = handler
def handle_request(

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import io
import itertools
import sys
@@ -14,6 +16,9 @@ if typing.TYPE_CHECKING:
_T = typing.TypeVar("_T")
__all__ = ["WSGITransport"]
def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
body = iter(body)
for chunk in body:
@@ -71,11 +76,11 @@ class WSGITransport(BaseTransport):
def __init__(
self,
app: "WSGIApplication",
app: WSGIApplication,
raise_app_exceptions: bool = True,
script_name: str = "",
remote_addr: str = "127.0.0.1",
wsgi_errors: typing.Optional[typing.TextIO] = None,
wsgi_errors: typing.TextIO | None = None,
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
@@ -102,6 +107,7 @@ class WSGITransport(BaseTransport):
"QUERY_STRING": request.url.query.decode("ascii"),
"SERVER_NAME": request.url.host,
"SERVER_PORT": str(port),
"SERVER_PROTOCOL": "HTTP/1.1",
"REMOTE_ADDR": self.remote_addr,
}
for header_key, header_value in request.headers.raw:
@@ -116,8 +122,8 @@ class WSGITransport(BaseTransport):
def start_response(
status: str,
response_headers: typing.List[typing.Tuple[str, str]],
exc_info: typing.Optional["OptExcInfo"] = None,
response_headers: list[tuple[str, str]],
exc_info: OptExcInfo | None = None,
) -> typing.Callable[[bytes], typing.Any]:
nonlocal seen_status, seen_response_headers, seen_exc_info
seen_status = status