This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -34,7 +34,7 @@ from ._exceptions import (
WriteError,
WriteTimeout,
)
from ._models import URL, Origin, Request, Response
from ._models import URL, Origin, Proxy, Request, Response
from ._ssl import default_ssl_context
from ._sync import (
ConnectionInterface,
@@ -79,6 +79,7 @@ __all__ = [
"URL",
"Request",
"Response",
"Proxy",
# async
"AsyncHTTPConnection",
"AsyncConnectionPool",
@@ -130,10 +131,11 @@ __all__ = [
"WriteError",
]
__version__ = "0.17.3"
__version__ = "1.0.9"
__locals = locals()
for __name in __all__:
if not __name.startswith("__"):
# Exclude SOCKET_OPTION, it causes AttributeError on Python 3.14
if not __name.startswith(("__", "SOCKET_OPTION")):
setattr(__locals[__name], "__module__", "httpcore") # noqa

View File

@@ -1,17 +1,19 @@
from contextlib import contextmanager
from typing import Iterator, Optional, Union
from __future__ import annotations
import contextlib
import typing
from ._models import URL, Extensions, HeaderTypes, Response
from ._sync.connection_pool import ConnectionPool
def request(
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
content: bytes | typing.Iterator[bytes] | None = None,
extensions: Extensions | None = None,
) -> Response:
"""
Sends an HTTP request, returning the response.
@@ -45,15 +47,15 @@ def request(
)
@contextmanager
@contextlib.contextmanager
def stream(
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> Iterator[Response]:
content: bytes | typing.Iterator[bytes] | None = None,
extensions: Extensions | None = None,
) -> typing.Iterator[Response]:
"""
Sends an HTTP request, returning the response within a content manager.

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import itertools
import logging
import ssl
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type
import types
import typing
from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout
from .._exceptions import ConnectError, ConnectTimeout
from .._models import Origin, Request, Response
from .._ssl import default_ssl_context
from .._synchronization import AsyncLock
@@ -20,25 +22,32 @@ RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
logger = logging.getLogger("httpcore.connection")
def exponential_backoff(factor: float) -> Iterator[float]:
def exponential_backoff(factor: float) -> typing.Iterator[float]:
"""
Generate a geometric sequence that has a ratio of 2 and starts with 0.
For example:
- `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
- `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
"""
yield 0
for n in itertools.count(2):
yield factor * (2 ** (n - 2))
for n in itertools.count():
yield factor * 2**n
class AsyncHTTPConnection(AsyncConnectionInterface):
def __init__(
self,
origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
keepalive_expiry: Optional[float] = None,
ssl_context: ssl.SSLContext | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
local_address: str | None = None,
uds: str | None = None,
network_backend: AsyncNetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
self._origin = origin
self._ssl_context = ssl_context
@@ -52,7 +61,7 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
self._network_backend: AsyncNetworkBackend = (
AutoBackend() if network_backend is None else network_backend
)
self._connection: Optional[AsyncConnectionInterface] = None
self._connection: AsyncConnectionInterface | None = None
self._connect_failed: bool = False
self._request_lock = AsyncLock()
self._socket_options = socket_options
@@ -63,9 +72,9 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
)
async with self._request_lock:
if self._connection is None:
try:
try:
async with self._request_lock:
if self._connection is None:
stream = await self._connect(request)
ssl_object = stream.get_extra_info("ssl_object")
@@ -87,11 +96,9 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
except Exception as exc:
self._connect_failed = True
raise exc
elif not self._connection.is_available():
raise ConnectionNotAvailable()
except BaseException as exc:
self._connect_failed = True
raise exc
return await self._connection.handle_async_request(request)
@@ -130,7 +137,7 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
)
trace.return_value = stream
if self._origin.scheme == b"https":
if self._origin.scheme in (b"https", b"wss"):
ssl_context = (
default_ssl_context()
if self._ssl_context is None
@@ -203,13 +210,13 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
async def __aenter__(self) -> "AsyncHTTPConnection":
async def __aenter__(self) -> AsyncHTTPConnection:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
await self.aclose()

View File

@@ -1,41 +1,44 @@
from __future__ import annotations
import ssl
import sys
from types import TracebackType
from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type
import types
import typing
from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation
from .._models import Origin, Proxy, Request, Response
from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
from .connection import AsyncHTTPConnection
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
class RequestStatus:
def __init__(self, request: Request):
class AsyncPoolRequest:
def __init__(self, request: Request) -> None:
self.request = request
self.connection: Optional[AsyncConnectionInterface] = None
self.connection: AsyncConnectionInterface | None = None
self._connection_acquired = AsyncEvent()
def set_connection(self, connection: AsyncConnectionInterface) -> None:
assert self.connection is None
def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None:
self.connection = connection
self._connection_acquired.set()
def unset_connection(self) -> None:
assert self.connection is not None
def clear_connection(self) -> None:
self.connection = None
self._connection_acquired = AsyncEvent()
async def wait_for_connection(
self, timeout: Optional[float] = None
self, timeout: float | None = None
) -> AsyncConnectionInterface:
if self.connection is None:
await self._connection_acquired.wait(timeout=timeout)
assert self.connection is not None
return self.connection
def is_queued(self) -> bool:
return self.connection is None
class AsyncConnectionPool(AsyncRequestInterface):
"""
@@ -44,17 +47,18 @@ class AsyncConnectionPool(AsyncRequestInterface):
def __init__(
self,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
ssl_context: ssl.SSLContext | None = None,
proxy: Proxy | None = None,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
local_address: str | None = None,
uds: str | None = None,
network_backend: AsyncNetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
"""
A connection pool for making HTTP requests.
@@ -86,7 +90,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
in the TCP socket when the connection was established.
"""
self._ssl_context = ssl_context
self._proxy = proxy
self._max_connections = (
sys.maxsize if max_connections is None else max_connections
)
@@ -106,15 +110,61 @@ class AsyncConnectionPool(AsyncRequestInterface):
self._local_address = local_address
self._uds = uds
self._pool: List[AsyncConnectionInterface] = []
self._requests: List[RequestStatus] = []
self._pool_lock = AsyncLock()
self._network_backend = (
AutoBackend() if network_backend is None else network_backend
)
self._socket_options = socket_options
# The mutable state on a connection pool is the queue of incoming requests,
# and the set of connections that are servicing those requests.
self._connections: list[AsyncConnectionInterface] = []
self._requests: list[AsyncPoolRequest] = []
# We only mutate the state of the connection pool within an 'optional_thread_lock'
# context. This holds a threading lock unless we're running in async mode,
# in which case it is a no-op.
self._optional_thread_lock = AsyncThreadLock()
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
if self._proxy is not None:
if self._proxy.url.scheme in (b"socks5", b"socks5h"):
from .socks_proxy import AsyncSocks5Connection
return AsyncSocks5Connection(
proxy_origin=self._proxy.url.origin,
proxy_auth=self._proxy.auth,
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)
elif origin.scheme == b"http":
from .http_proxy import AsyncForwardHTTPConnection
return AsyncForwardHTTPConnection(
proxy_origin=self._proxy.url.origin,
proxy_headers=self._proxy.headers,
proxy_ssl_context=self._proxy.ssl_context,
remote_origin=origin,
keepalive_expiry=self._keepalive_expiry,
network_backend=self._network_backend,
)
from .http_proxy import AsyncTunnelHTTPConnection
return AsyncTunnelHTTPConnection(
proxy_origin=self._proxy.url.origin,
proxy_headers=self._proxy.headers,
proxy_ssl_context=self._proxy.ssl_context,
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)
return AsyncHTTPConnection(
origin=origin,
ssl_context=self._ssl_context,
@@ -129,7 +179,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
)
@property
def connections(self) -> List[AsyncConnectionInterface]:
def connections(self) -> list[AsyncConnectionInterface]:
"""
Return a list of the connections currently in the pool.
@@ -144,64 +194,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
]
```
"""
return list(self._pool)
async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
"""
Attempt to provide a connection that can handle the given origin.
"""
origin = status.request.url.origin
# If there are queued requests in front of us, then don't acquire a
# connection. We handle requests strictly in order.
waiting = [s for s in self._requests if s.connection is None]
if waiting and waiting[0] is not status:
return False
# Reuse an existing connection if one is currently available.
for idx, connection in enumerate(self._pool):
if connection.can_handle_request(origin) and connection.is_available():
self._pool.pop(idx)
self._pool.insert(0, connection)
status.set_connection(connection)
return True
# If the pool is currently full, attempt to close one idle connection.
if len(self._pool) >= self._max_connections:
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.is_idle():
await connection.aclose()
self._pool.pop(idx)
break
# If the pool is still full, then we cannot acquire a connection.
if len(self._pool) >= self._max_connections:
return False
# Otherwise create a new connection.
connection = self.create_connection(origin)
self._pool.insert(0, connection)
status.set_connection(connection)
return True
async def _close_expired_connections(self) -> None:
"""
Clean up the connection pool by closing off any connections that have expired.
"""
# Close any connections that have expired their keep-alive time.
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.has_expired():
await connection.aclose()
self._pool.pop(idx)
# If the pool size exceeds the maximum number of allowed keep-alive connections,
# then close off idle connections as required.
pool_size = len(self._pool)
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.is_idle() and pool_size > self._max_keepalive_connections:
await connection.aclose()
self._pool.pop(idx)
pool_size -= 1
return list(self._connections)
async def handle_async_request(self, request: Request) -> Response:
"""
@@ -219,138 +212,209 @@ class AsyncConnectionPool(AsyncRequestInterface):
f"Request URL has an unsupported protocol '{scheme}://'."
)
status = RequestStatus(request)
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
async with self._pool_lock:
self._requests.append(status)
await self._close_expired_connections()
await self._attempt_to_acquire_connection(status)
with self._optional_thread_lock:
# Add the incoming request to our request queue.
pool_request = AsyncPoolRequest(request)
self._requests.append(pool_request)
while True:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
try:
connection = await status.wait_for_connection(timeout=timeout)
except BaseException as exc:
# If we timeout here, or if the task is cancelled, then make
# sure to remove the request from the queue before bubbling
# up the exception.
async with self._pool_lock:
# Ensure only remove when task exists.
if status in self._requests:
self._requests.remove(status)
raise exc
try:
while True:
with self._optional_thread_lock:
# Assign incoming requests to available connections,
# closing or creating new connections as required.
closing = self._assign_requests_to_connections()
await self._close_connections(closing)
try:
response = await connection.handle_async_request(request)
except ConnectionNotAvailable:
# The ConnectionNotAvailable exception is a special case, that
# indicates we need to retry the request on a new connection.
#
# The most common case where this can occur is when multiple
# requests are queued waiting for a single connection, which
# might end up as an HTTP/2 connection, but which actually ends
# up as HTTP/1.1.
async with self._pool_lock:
# Maintain our position in the request queue, but reset the
# status so that the request becomes queued again.
status.unset_connection()
await self._attempt_to_acquire_connection(status)
except BaseException as exc:
with AsyncShieldCancellation():
await self.response_closed(status)
raise exc
else:
break
# Wait until this request has an assigned connection.
connection = await pool_request.wait_for_connection(timeout=timeout)
# When we return the response, we wrap the stream in a special class
# that handles notifying the connection pool once the response
# has been released.
assert isinstance(response.stream, AsyncIterable)
try:
# Send the request on the assigned connection.
response = await connection.handle_async_request(
pool_request.request
)
except ConnectionNotAvailable:
# In some cases a connection may initially be available to
# handle a request, but then become unavailable.
#
# In this case we clear the connection and try again.
pool_request.clear_connection()
else:
break # pragma: nocover
except BaseException as exc:
with self._optional_thread_lock:
# For any exception or cancellation we remove the request from
# the queue, and then re-assign requests to connections.
self._requests.remove(pool_request)
closing = self._assign_requests_to_connections()
await self._close_connections(closing)
raise exc from None
# Return the response. Note that in this case we still have to manage
# the point at which the response is closed.
assert isinstance(response.stream, typing.AsyncIterable)
return Response(
status=response.status,
headers=response.headers,
content=ConnectionPoolByteStream(response.stream, self, status),
content=PoolByteStream(
stream=response.stream, pool_request=pool_request, pool=self
),
extensions=response.extensions,
)
async def response_closed(self, status: RequestStatus) -> None:
def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]:
"""
This method acts as a callback once the request/response cycle is complete.
Manage the state of the connection pool, assigning incoming
requests to connections as available.
It is called into from the `ConnectionPoolByteStream.aclose()` method.
Called whenever a new request is added or removed from the pool.
Any closing connections are returned, allowing the I/O for closing
those connections to be handled seperately.
"""
assert status.connection is not None
connection = status.connection
closing_connections = []
async with self._pool_lock:
# Update the state of the connection pool.
if status in self._requests:
self._requests.remove(status)
# First we handle cleaning up any connections that are closed,
# have expired their keep-alive, or surplus idle connections.
for connection in list(self._connections):
if connection.is_closed():
# log: "removing closed connection"
self._connections.remove(connection)
elif connection.has_expired():
# log: "closing expired connection"
self._connections.remove(connection)
closing_connections.append(connection)
elif (
connection.is_idle()
and len([connection.is_idle() for connection in self._connections])
> self._max_keepalive_connections
):
# log: "closing idle connection"
self._connections.remove(connection)
closing_connections.append(connection)
if connection.is_closed() and connection in self._pool:
self._pool.remove(connection)
# Assign queued requests to connections.
queued_requests = [request for request in self._requests if request.is_queued()]
for pool_request in queued_requests:
origin = pool_request.request.url.origin
available_connections = [
connection
for connection in self._connections
if connection.can_handle_request(origin) and connection.is_available()
]
idle_connections = [
connection for connection in self._connections if connection.is_idle()
]
# Since we've had a response closed, it's possible we'll now be able
# to service one or more requests that are currently pending.
for status in self._requests:
if status.connection is None:
acquired = await self._attempt_to_acquire_connection(status)
# If we could not acquire a connection for a queued request
# then we don't need to check anymore requests that are
# queued later behind it.
if not acquired:
break
# There are three cases for how we may be able to handle the request:
#
# 1. There is an existing connection that can handle the request.
# 2. We can create a new connection to handle the request.
# 3. We can close an idle connection and then create a new connection
# to handle the request.
if available_connections:
# log: "reusing existing connection"
connection = available_connections[0]
pool_request.assign_to_connection(connection)
elif len(self._connections) < self._max_connections:
# log: "creating new connection"
connection = self.create_connection(origin)
self._connections.append(connection)
pool_request.assign_to_connection(connection)
elif idle_connections:
# log: "closing idle connection"
connection = idle_connections[0]
self._connections.remove(connection)
closing_connections.append(connection)
# log: "creating new connection"
connection = self.create_connection(origin)
self._connections.append(connection)
pool_request.assign_to_connection(connection)
# Housekeeping.
await self._close_expired_connections()
return closing_connections
async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None:
# Close connections which have been removed from the pool.
with AsyncShieldCancellation():
for connection in closing:
await connection.aclose()
async def aclose(self) -> None:
"""
Close any connections in the pool.
"""
async with self._pool_lock:
for connection in self._pool:
await connection.aclose()
self._pool = []
self._requests = []
# Explicitly close the connection pool.
# Clears all existing requests and connections.
with self._optional_thread_lock:
closing_connections = list(self._connections)
self._connections = []
await self._close_connections(closing_connections)
async def __aenter__(self) -> "AsyncConnectionPool":
async def __aenter__(self) -> AsyncConnectionPool:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
await self.aclose()
def __repr__(self) -> str:
class_name = self.__class__.__name__
with self._optional_thread_lock:
request_is_queued = [request.is_queued() for request in self._requests]
connection_is_idle = [
connection.is_idle() for connection in self._connections
]
class ConnectionPoolByteStream:
"""
A wrapper around the response byte stream, that additionally handles
notifying the connection pool when the response has been closed.
"""
num_active_requests = request_is_queued.count(False)
num_queued_requests = request_is_queued.count(True)
num_active_connections = connection_is_idle.count(False)
num_idle_connections = connection_is_idle.count(True)
requests_info = (
f"Requests: {num_active_requests} active, {num_queued_requests} queued"
)
connection_info = (
f"Connections: {num_active_connections} active, {num_idle_connections} idle"
)
return f"<{class_name} [{requests_info} | {connection_info}]>"
class PoolByteStream:
def __init__(
self,
stream: AsyncIterable[bytes],
stream: typing.AsyncIterable[bytes],
pool_request: AsyncPoolRequest,
pool: AsyncConnectionPool,
status: RequestStatus,
) -> None:
self._stream = stream
self._pool_request = pool_request
self._pool = pool
self._status = status
self._closed = False
async def __aiter__(self) -> AsyncIterator[bytes]:
async for part in self._stream:
yield part
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
try:
async for part in self._stream:
yield part
except BaseException as exc:
await self.aclose()
raise exc from None
async def aclose(self) -> None:
try:
if hasattr(self._stream, "aclose"):
await self._stream.aclose()
finally:
if not self._closed:
self._closed = True
with AsyncShieldCancellation():
await self._pool.response_closed(self._status)
if hasattr(self._stream, "aclose"):
await self._stream.aclose()
with self._pool._optional_thread_lock:
self._pool._requests.remove(self._pool_request)
closing = self._pool._assign_requests_to_connections()
await self._pool._close_connections(closing)

View File

@@ -1,17 +1,11 @@
from __future__ import annotations
import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
AsyncIterable,
AsyncIterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import types
import typing
import h11
@@ -20,6 +14,7 @@ from .._exceptions import (
ConnectionNotAvailable,
LocalProtocolError,
RemoteProtocolError,
WriteError,
map_exceptions,
)
from .._models import Origin, Request, Response
@@ -31,7 +26,7 @@ logger = logging.getLogger("httpcore.http11")
# A subset of `h11.Event` types supported by `_send_event`
H11SendEvent = Union[
H11SendEvent = typing.Union[
h11.Request,
h11.Data,
h11.EndOfMessage,
@@ -53,12 +48,12 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
self,
origin: Origin,
stream: AsyncNetworkStream,
keepalive_expiry: Optional[float] = None,
keepalive_expiry: float | None = None,
) -> None:
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: Optional[float] = keepalive_expiry
self._expire_at: Optional[float] = None
self._keepalive_expiry: float | None = keepalive_expiry
self._expire_at: float | None = None
self._state = HTTPConnectionState.NEW
self._state_lock = AsyncLock()
self._request_count = 0
@@ -84,10 +79,21 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
try:
kwargs = {"request": request}
async with Trace("send_request_headers", logger, request, kwargs) as trace:
await self._send_request_headers(**kwargs)
async with Trace("send_request_body", logger, request, kwargs) as trace:
await self._send_request_body(**kwargs)
try:
async with Trace(
"send_request_headers", logger, request, kwargs
) as trace:
await self._send_request_headers(**kwargs)
async with Trace("send_request_body", logger, request, kwargs) as trace:
await self._send_request_body(**kwargs)
except WriteError:
# If we get a write error while we're writing the request,
# then we supress this error and move on to attempting to
# read the response. Servers can sometimes close the request
# pre-emptively and then respond with a well formed HTTP
# error response.
pass
async with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
@@ -96,6 +102,7 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
status,
reason_phrase,
headers,
trailing_data,
) = await self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
@@ -104,6 +111,14 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
headers,
)
network_stream = self._network_stream
# CONNECT or Upgrade request
if (status == 101) or (
(request.method == b"CONNECT") and (200 <= status < 300)
):
network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data)
return Response(
status=status,
headers=headers,
@@ -111,7 +126,7 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"network_stream": network_stream,
},
)
except BaseException as exc:
@@ -138,16 +153,14 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
assert isinstance(request.stream, AsyncIterable)
assert isinstance(request.stream, typing.AsyncIterable)
async for chunk in request.stream:
event = h11.Data(data=chunk)
await self._send_event(event, timeout=timeout)
await self._send_event(h11.EndOfMessage(), timeout=timeout)
async def _send_event(
self, event: h11.Event, timeout: Optional[float] = None
) -> None:
async def _send_event(self, event: h11.Event, timeout: float | None = None) -> None:
bytes_to_send = self._h11_state.send(event)
if bytes_to_send is not None:
await self._network_stream.write(bytes_to_send, timeout=timeout)
@@ -156,7 +169,7 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
async def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
@@ -176,9 +189,13 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()
return http_version, event.status_code, event.reason, headers
trailing_data, _ = self._h11_state.trailing_data
async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
return http_version, event.status_code, event.reason, headers, trailing_data
async def _receive_response_body(
self, request: Request
) -> typing.AsyncIterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
@@ -190,8 +207,8 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
break
async def _receive_event(
self, timeout: Optional[float] = None
) -> Union[h11.Event, Type[h11.PAUSED]]:
self, timeout: float | None = None
) -> h11.Event | type[h11.PAUSED]:
while True:
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
event = self._h11_state.next_event()
@@ -216,7 +233,7 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
self._h11_state.receive_data(data)
else:
# mypy fails to narrow the type in the above if statement above
return cast(Union[h11.Event, Type[h11.PAUSED]], event)
return event # type: ignore[return-value]
async def _response_closed(self) -> None:
async with self._state_lock:
@@ -292,14 +309,14 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
async def __aenter__(self) -> "AsyncHTTP11Connection":
async def __aenter__(self) -> AsyncHTTP11Connection:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
await self.aclose()
@@ -310,7 +327,7 @@ class HTTP11ConnectionByteStream:
self._request = request
self._closed = False
async def __aiter__(self) -> AsyncIterator[bytes]:
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
kwargs = {"request": self._request}
try:
async with Trace("receive_response_body", logger, self._request, kwargs):
@@ -329,3 +346,34 @@ class HTTP11ConnectionByteStream:
self._closed = True
async with Trace("response_closed", logger, self._request):
await self._connection._response_closed()
class AsyncHTTP11UpgradeStream(AsyncNetworkStream):
def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data
async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return await self._stream.read(max_bytes, timeout)
async def write(self, buffer: bytes, timeout: float | None = None) -> None:
await self._stream.write(buffer, timeout)
async def aclose(self) -> None:
await self._stream.aclose()
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
timeout: float | None = None,
) -> AsyncNetworkStream:
return await self._stream.start_tls(ssl_context, server_hostname, timeout)
def get_extra_info(self, info: str) -> typing.Any:
return self._stream.get_extra_info(info)

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import enum
import logging
import time
@@ -45,14 +47,14 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
self,
origin: Origin,
stream: AsyncNetworkStream,
keepalive_expiry: typing.Optional[float] = None,
keepalive_expiry: float | None = None,
):
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
self._keepalive_expiry: float | None = keepalive_expiry
self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
self._state = HTTPConnectionState.IDLE
self._expire_at: typing.Optional[float] = None
self._expire_at: float | None = None
self._request_count = 0
self._init_lock = AsyncLock()
self._state_lock = AsyncLock()
@@ -63,24 +65,22 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
self._connection_error = False
# Mapping from stream ID to response stream events.
self._events: typing.Dict[
self._events: dict[
int,
typing.Union[
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
list[
h2.events.ResponseReceived
| h2.events.DataReceived
| h2.events.StreamEnded
| h2.events.StreamReset,
],
] = {}
# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: typing.Optional[
h2.events.ConnectionTerminated
] = None
self._connection_terminated: h2.events.ConnectionTerminated | None = None
self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._read_exception: Exception | None = None
self._write_exception: Exception | None = None
async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
@@ -104,9 +104,11 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
async with self._init_lock:
if not self._sent_connection_init:
try:
kwargs = {"request": request}
async with Trace("send_connection_init", logger, request, kwargs):
await self._send_connection_init(**kwargs)
sci_kwargs = {"request": request}
async with Trace(
"send_connection_init", logger, request, sci_kwargs
):
await self._send_connection_init(**sci_kwargs)
except BaseException as exc:
with AsyncShieldCancellation():
await self.aclose()
@@ -284,7 +286,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
async def _receive_response(
self, request: Request, stream_id: int
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
) -> tuple[int, list[tuple[bytes, bytes]]]:
"""
Return the response status code and headers for a given stream ID.
"""
@@ -295,6 +297,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
status_code = 200
headers = []
assert event.headers is not None
for k, v in event.headers:
if k == b":status":
status_code = int(v.decode("ascii", errors="ignore"))
@@ -312,6 +315,8 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
while True:
event = await self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.DataReceived):
assert event.flow_controlled_length is not None
assert event.data is not None
amount = event.flow_controlled_length
self._h2_state.acknowledge_received_data(amount, stream_id)
await self._write_outgoing_data(request)
@@ -321,9 +326,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
async def _receive_stream_event(
self, request: Request, stream_id: int
) -> typing.Union[
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
]:
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
"""
Return the next available event for a given stream ID.
@@ -337,7 +340,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
return event
async def _receive_events(
self, request: Request, stream_id: typing.Optional[int] = None
self, request: Request, stream_id: int | None = None
) -> None:
"""
Read some data from the network until we see one or more events
@@ -384,7 +387,9 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
await self._write_outgoing_data(request)
async def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
async def _receive_remote_settings_change(
self, event: h2.events.RemoteSettingsChanged
) -> None:
max_concurrent_streams = event.changed_settings.get(
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
)
@@ -425,9 +430,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
# Wrappers around network read/write operations...
async def _read_incoming_data(
self, request: Request
) -> typing.List[h2.events.Event]:
async def _read_incoming_data(self, request: Request) -> list[h2.events.Event]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
@@ -451,7 +454,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
self._connection_error = True
raise exc
events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)
events: list[h2.events.Event] = self._h2_state.receive_data(data)
return events
@@ -544,14 +547,14 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
async def __aenter__(self) -> "AsyncHTTP2Connection":
async def __aenter__(self) -> AsyncHTTP2Connection:
return self
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[types.TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
await self.aclose()

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import base64
import logging
import ssl
from base64 import b64encode
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
import typing
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ProxyError
@@ -22,17 +24,18 @@ from .connection_pool import AsyncConnectionPool
from .http11 import AsyncHTTP11Connection
from .interfaces import AsyncConnectionInterface
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
ByteOrStr = typing.Union[bytes, str]
HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]
logger = logging.getLogger("httpcore.proxy")
def merge_headers(
default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
) -> List[Tuple[bytes, bytes]]:
default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
) -> list[tuple[bytes, bytes]]:
"""
Append default_headers and override_headers, de-duplicating if a key exists
in both cases.
@@ -48,32 +51,28 @@ def merge_headers(
return default_headers + override_headers
def build_auth_header(username: bytes, password: bytes) -> bytes:
userpass = username + b":" + password
return b"Basic " + b64encode(userpass)
class AsyncHTTPProxy(AsyncConnectionPool):
class AsyncHTTPProxy(AsyncConnectionPool): # pragma: nocover
"""
A connection pool that sends requests via an HTTP proxy.
"""
def __init__(
self,
proxy_url: Union[URL, bytes, str],
proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
proxy_url: URL | bytes | str,
proxy_auth: tuple[bytes | str, bytes | str] | None = None,
proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
ssl_context: ssl.SSLContext | None = None,
proxy_ssl_context: ssl.SSLContext | None = None,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
local_address: str | None = None,
uds: str | None = None,
network_backend: AsyncNetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
"""
A connection pool for making HTTP requests.
@@ -88,6 +87,7 @@ class AsyncHTTPProxy(AsyncConnectionPool):
ssl_context: An SSL context to use for verifying connections.
If not specified, the default `httpcore.default_ssl_context()`
will be used.
proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin.
max_connections: The maximum number of concurrent HTTP connections that
the pool should allow. Any attempt to send a request on a pool that
would exceed this amount will block until a connection is available.
@@ -122,13 +122,23 @@ class AsyncHTTPProxy(AsyncConnectionPool):
uds=uds,
socket_options=socket_options,
)
self._ssl_context = ssl_context
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
if (
self._proxy_url.scheme == b"http" and proxy_ssl_context is not None
): # pragma: no cover
raise RuntimeError(
"The `proxy_ssl_context` argument is not allowed for the http scheme"
)
self._ssl_context = ssl_context
self._proxy_ssl_context = proxy_ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
if proxy_auth is not None:
username = enforce_bytes(proxy_auth[0], name="proxy_auth")
password = enforce_bytes(proxy_auth[1], name="proxy_auth")
authorization = build_auth_header(username, password)
userpass = username + b":" + password
authorization = b"Basic " + base64.b64encode(userpass)
self._proxy_headers = [
(b"Proxy-Authorization", authorization)
] + self._proxy_headers
@@ -141,12 +151,14 @@ class AsyncHTTPProxy(AsyncConnectionPool):
remote_origin=origin,
keepalive_expiry=self._keepalive_expiry,
network_backend=self._network_backend,
proxy_ssl_context=self._proxy_ssl_context,
)
return AsyncTunnelHTTPConnection(
proxy_origin=self._proxy_url.origin,
proxy_headers=self._proxy_headers,
remote_origin=origin,
ssl_context=self._ssl_context,
proxy_ssl_context=self._proxy_ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
@@ -159,16 +171,18 @@ class AsyncForwardHTTPConnection(AsyncConnectionInterface):
self,
proxy_origin: Origin,
remote_origin: Origin,
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
keepalive_expiry: Optional[float] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
keepalive_expiry: float | None = None,
network_backend: AsyncNetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
proxy_ssl_context: ssl.SSLContext | None = None,
) -> None:
self._connection = AsyncHTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
ssl_context=proxy_ssl_context,
)
self._proxy_origin = proxy_origin
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
@@ -221,23 +235,26 @@ class AsyncTunnelHTTPConnection(AsyncConnectionInterface):
self,
proxy_origin: Origin,
remote_origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
keepalive_expiry: Optional[float] = None,
ssl_context: ssl.SSLContext | None = None,
proxy_ssl_context: ssl.SSLContext | None = None,
proxy_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
network_backend: AsyncNetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
ssl_context=proxy_ssl_context,
)
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
self._ssl_context = ssl_context
self._proxy_ssl_context = proxy_ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._keepalive_expiry = keepalive_expiry
self._http1 = http1

View File

@@ -1,5 +1,7 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional, Union
from __future__ import annotations
import contextlib
import typing
from .._models import (
URL,
@@ -18,12 +20,12 @@ from .._models import (
class AsyncRequestInterface:
async def request(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, AsyncIterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
content: bytes | typing.AsyncIterator[bytes] | None = None,
extensions: Extensions | None = None,
) -> Response:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
@@ -47,16 +49,16 @@ class AsyncRequestInterface:
await response.aclose()
return response
@asynccontextmanager
@contextlib.asynccontextmanager
async def stream(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, AsyncIterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> AsyncIterator[Response]:
content: bytes | typing.AsyncIterator[bytes] | None = None,
extensions: Extensions | None = None,
) -> typing.AsyncIterator[Response]:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
url = enforce_url(url, name="url")

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import logging
import ssl
import typing
from socksio import socks5
import socksio
from .._backends.auto import AutoBackend
from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
@@ -43,24 +44,24 @@ async def _init_socks5_connection(
*,
host: bytes,
port: int,
auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
auth: tuple[bytes, bytes] | None = None,
) -> None:
conn = socks5.SOCKS5Connection()
conn = socksio.socks5.SOCKS5Connection()
# Auth method request
auth_method = (
socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
socksio.socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
if auth is None
else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
else socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
)
conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method]))
outgoing_bytes = conn.data_to_send()
await stream.write(outgoing_bytes)
# Auth method response
incoming_bytes = await stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5AuthReply)
assert isinstance(response, socksio.socks5.SOCKS5AuthReply)
if response.method != auth_method:
requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
responded = AUTH_METHODS.get(response.method, "UNKNOWN")
@@ -68,25 +69,25 @@ async def _init_socks5_connection(
f"Requested {requested} from proxy server, but got {responded}."
)
if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
if response.method == socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
# Username/password request
assert auth is not None
username, password = auth
conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password))
outgoing_bytes = conn.data_to_send()
await stream.write(outgoing_bytes)
# Username/password response
incoming_bytes = await stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply)
if not response.success:
raise ProxyError("Invalid username/password")
# Connect request
conn.send(
socks5.SOCKS5CommandRequest.from_address(
socks5.SOCKS5Command.CONNECT, (host, port)
socksio.socks5.SOCKS5CommandRequest.from_address(
socksio.socks5.SOCKS5Command.CONNECT, (host, port)
)
)
outgoing_bytes = conn.data_to_send()
@@ -95,31 +96,29 @@ async def _init_socks5_connection(
# Connect response
incoming_bytes = await stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5Reply)
if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
assert isinstance(response, socksio.socks5.SOCKS5Reply)
if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED:
reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN")
raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
class AsyncSOCKSProxy(AsyncConnectionPool):
class AsyncSOCKSProxy(AsyncConnectionPool): # pragma: nocover
"""
A connection pool that sends requests via an HTTP proxy.
"""
def __init__(
self,
proxy_url: typing.Union[URL, bytes, str],
proxy_auth: typing.Optional[
typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
] = None,
ssl_context: typing.Optional[ssl.SSLContext] = None,
max_connections: typing.Optional[int] = 10,
max_keepalive_connections: typing.Optional[int] = None,
keepalive_expiry: typing.Optional[float] = None,
proxy_url: URL | bytes | str,
proxy_auth: tuple[bytes | str, bytes | str] | None = None,
ssl_context: ssl.SSLContext | None = None,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
network_backend: typing.Optional[AsyncNetworkBackend] = None,
network_backend: AsyncNetworkBackend | None = None,
) -> None:
"""
A connection pool for making HTTP requests.
@@ -167,7 +166,7 @@ class AsyncSOCKSProxy(AsyncConnectionPool):
username, password = proxy_auth
username_bytes = enforce_bytes(username, name="proxy_auth")
password_bytes = enforce_bytes(password, name="proxy_auth")
self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = (
self._proxy_auth: tuple[bytes, bytes] | None = (
username_bytes,
password_bytes,
)
@@ -192,12 +191,12 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
self,
proxy_origin: Origin,
remote_origin: Origin,
proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
ssl_context: typing.Optional[ssl.SSLContext] = None,
keepalive_expiry: typing.Optional[float] = None,
proxy_auth: tuple[bytes, bytes] | None = None,
ssl_context: ssl.SSLContext | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
network_backend: typing.Optional[AsyncNetworkBackend] = None,
network_backend: AsyncNetworkBackend | None = None,
) -> None:
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
@@ -211,11 +210,12 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
AutoBackend() if network_backend is None else network_backend
)
self._connect_lock = AsyncLock()
self._connection: typing.Optional[AsyncConnectionInterface] = None
self._connection: AsyncConnectionInterface | None = None
self._connect_failed = False
async def handle_async_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)
async with self._connect_lock:
@@ -227,7 +227,7 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
"port": self._proxy_origin.port,
"timeout": timeout,
}
with Trace("connect_tcp", logger, request, kwargs) as trace:
async with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = await self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
@@ -238,7 +238,7 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
"port": self._remote_origin.port,
"auth": self._proxy_auth,
}
with Trace(
async with Trace(
"setup_socks5_connection", logger, request, kwargs
) as trace:
await _init_socks5_connection(**kwargs)
@@ -258,7 +258,8 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"server_hostname": sni_hostname
or self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import ssl
import typing
@@ -20,13 +22,12 @@ class AnyIOStream(AsyncNetworkStream):
def __init__(self, stream: anyio.abc.ByteStream) -> None:
self._stream = stream
async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
exc_map = {
TimeoutError: ReadTimeout,
anyio.BrokenResourceError: ReadError,
anyio.ClosedResourceError: ReadError,
anyio.EndOfStream: ReadError,
}
with map_exceptions(exc_map):
with anyio.fail_after(timeout):
@@ -35,9 +36,7 @@ class AnyIOStream(AsyncNetworkStream):
except anyio.EndOfStream: # pragma: nocover
return b""
async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
async def write(self, buffer: bytes, timeout: float | None = None) -> None:
if not buffer:
return
@@ -56,12 +55,14 @@ class AnyIOStream(AsyncNetworkStream):
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
server_hostname: str | None = None,
timeout: float | None = None,
) -> AsyncNetworkStream:
exc_map = {
TimeoutError: ConnectTimeout,
anyio.BrokenResourceError: ConnectError,
anyio.EndOfStream: ConnectError,
ssl.SSLError: ConnectError,
}
with map_exceptions(exc_map):
try:
@@ -98,12 +99,12 @@ class AnyIOBackend(AsyncNetworkBackend):
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = [] # pragma: no cover
socket_options = []
exc_map = {
TimeoutError: ConnectTimeout,
OSError: ConnectError,
@@ -124,8 +125,8 @@ class AnyIOBackend(AsyncNetworkBackend):
async def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = []

View File

@@ -1,15 +1,15 @@
from __future__ import annotations
import typing
from typing import Optional
import sniffio
from .._synchronization import current_async_library
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
class AutoBackend(AsyncNetworkBackend):
async def _init_backend(self) -> None:
if not (hasattr(self, "_backend")):
backend = sniffio.current_async_library()
backend = current_async_library()
if backend == "trio":
from .trio import TrioBackend
@@ -23,9 +23,9 @@ class AutoBackend(AsyncNetworkBackend):
self,
host: str,
port: int,
timeout: Optional[float] = None,
local_address: Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
await self._init_backend()
return await self._backend.connect_tcp(
@@ -39,8 +39,8 @@ class AutoBackend(AsyncNetworkBackend):
async def connect_unix_socket(
self,
path: str,
timeout: Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream: # pragma: nocover
await self._init_backend()
return await self._backend.connect_unix_socket(

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import ssl
import time
import typing
@@ -10,10 +12,10 @@ SOCKET_OPTION = typing.Union[
class NetworkStream:
def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
raise NotImplementedError() # pragma: nocover
def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
def write(self, buffer: bytes, timeout: float | None = None) -> None:
raise NotImplementedError() # pragma: nocover
def close(self) -> None:
@@ -22,9 +24,9 @@ class NetworkStream:
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> "NetworkStream":
server_hostname: str | None = None,
timeout: float | None = None,
) -> NetworkStream:
raise NotImplementedError() # pragma: nocover
def get_extra_info(self, info: str) -> typing.Any:
@@ -36,17 +38,17 @@ class NetworkBackend:
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream:
raise NotImplementedError() # pragma: nocover
def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream:
raise NotImplementedError() # pragma: nocover
@@ -55,14 +57,10 @@ class NetworkBackend:
class AsyncNetworkStream:
async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
raise NotImplementedError() # pragma: nocover
async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
async def write(self, buffer: bytes, timeout: float | None = None) -> None:
raise NotImplementedError() # pragma: nocover
async def aclose(self) -> None:
@@ -71,9 +69,9 @@ class AsyncNetworkStream:
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> "AsyncNetworkStream":
server_hostname: str | None = None,
timeout: float | None = None,
) -> AsyncNetworkStream:
raise NotImplementedError() # pragma: nocover
def get_extra_info(self, info: str) -> typing.Any:
@@ -85,17 +83,17 @@ class AsyncNetworkBackend:
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
raise NotImplementedError() # pragma: nocover
async def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
raise NotImplementedError() # pragma: nocover

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import ssl
import typing
from typing import Optional
from .._exceptions import ReadError
from .base import (
@@ -21,19 +22,19 @@ class MockSSLObject:
class MockStream(NetworkStream):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
self._buffer = buffer
self._http2 = http2
self._closed = False
def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
if self._closed:
raise ReadError("Connection closed")
if not self._buffer:
return b""
return self._buffer.pop(0)
def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
def write(self, buffer: bytes, timeout: float | None = None) -> None:
pass
def close(self) -> None:
@@ -42,8 +43,8 @@ class MockStream(NetworkStream):
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
server_hostname: str | None = None,
timeout: float | None = None,
) -> NetworkStream:
return self
@@ -55,7 +56,7 @@ class MockStream(NetworkStream):
class MockBackend(NetworkBackend):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
self._buffer = buffer
self._http2 = http2
@@ -63,17 +64,17 @@ class MockBackend(NetworkBackend):
self,
host: str,
port: int,
timeout: Optional[float] = None,
local_address: Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream:
return MockStream(list(self._buffer), http2=self._http2)
def connect_unix_socket(
self,
path: str,
timeout: Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream:
return MockStream(list(self._buffer), http2=self._http2)
@@ -82,19 +83,19 @@ class MockBackend(NetworkBackend):
class AsyncMockStream(AsyncNetworkStream):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
self._buffer = buffer
self._http2 = http2
self._closed = False
async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
if self._closed:
raise ReadError("Connection closed")
if not self._buffer:
return b""
return self._buffer.pop(0)
async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
async def write(self, buffer: bytes, timeout: float | None = None) -> None:
pass
async def aclose(self) -> None:
@@ -103,8 +104,8 @@ class AsyncMockStream(AsyncNetworkStream):
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
server_hostname: str | None = None,
timeout: float | None = None,
) -> AsyncNetworkStream:
return self
@@ -116,7 +117,7 @@ class AsyncMockStream(AsyncNetworkStream):
class AsyncMockBackend(AsyncNetworkBackend):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
self._buffer = buffer
self._http2 = http2
@@ -124,17 +125,17 @@ class AsyncMockBackend(AsyncNetworkBackend):
self,
host: str,
port: int,
timeout: Optional[float] = None,
local_address: Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
return AsyncMockStream(list(self._buffer), http2=self._http2)
async def connect_unix_socket(
self,
path: str,
timeout: Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
return AsyncMockStream(list(self._buffer), http2=self._http2)

View File

@@ -1,3 +1,6 @@
from __future__ import annotations
import functools
import socket
import ssl
import sys
@@ -17,17 +20,114 @@ from .._utils import is_socket_readable
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
class TLSinTLSStream(NetworkStream): # pragma: no cover
"""
Because the standard `SSLContext.wrap_socket` method does
not work for `SSLSocket` objects, we need this class
to implement TLS stream using an underlying `SSLObject`
instance in order to support TLS on top of TLS.
"""
# Defined in RFC 8449
TLS_RECORD_SIZE = 16384
def __init__(
self,
sock: socket.socket,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
timeout: float | None = None,
):
self._sock = sock
self._incoming = ssl.MemoryBIO()
self._outgoing = ssl.MemoryBIO()
self.ssl_obj = ssl_context.wrap_bio(
incoming=self._incoming,
outgoing=self._outgoing,
server_hostname=server_hostname,
)
self._sock.settimeout(timeout)
self._perform_io(self.ssl_obj.do_handshake)
def _perform_io(
self,
func: typing.Callable[..., typing.Any],
) -> typing.Any:
ret = None
while True:
errno = None
try:
ret = func()
except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
errno = e.errno
self._sock.sendall(self._outgoing.read())
if errno == ssl.SSL_ERROR_WANT_READ:
buf = self._sock.recv(self.TLS_RECORD_SIZE)
if buf:
self._incoming.write(buf)
else:
self._incoming.write_eof()
if errno is None:
return ret
def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
with map_exceptions(exc_map):
self._sock.settimeout(timeout)
return typing.cast(
bytes, self._perform_io(functools.partial(self.ssl_obj.read, max_bytes))
)
def write(self, buffer: bytes, timeout: float | None = None) -> None:
exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
with map_exceptions(exc_map):
self._sock.settimeout(timeout)
while buffer:
nsent = self._perform_io(functools.partial(self.ssl_obj.write, buffer))
buffer = buffer[nsent:]
def close(self) -> None:
self._sock.close()
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
timeout: float | None = None,
) -> NetworkStream:
raise NotImplementedError()
def get_extra_info(self, info: str) -> typing.Any:
if info == "ssl_object":
return self.ssl_obj
if info == "client_addr":
return self._sock.getsockname()
if info == "server_addr":
return self._sock.getpeername()
if info == "socket":
return self._sock
if info == "is_readable":
return is_socket_readable(self._sock)
return None
class SyncStream(NetworkStream):
def __init__(self, sock: socket.socket) -> None:
self._sock = sock
def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
with map_exceptions(exc_map):
self._sock.settimeout(timeout)
return self._sock.recv(max_bytes)
def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
def write(self, buffer: bytes, timeout: float | None = None) -> None:
if not buffer:
return
@@ -44,8 +144,8 @@ class SyncStream(NetworkStream):
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
server_hostname: str | None = None,
timeout: float | None = None,
) -> NetworkStream:
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
@@ -53,10 +153,18 @@ class SyncStream(NetworkStream):
}
with map_exceptions(exc_map):
try:
self._sock.settimeout(timeout)
sock = ssl_context.wrap_socket(
self._sock, server_hostname=server_hostname
)
if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover
# If the underlying socket has already been upgraded
# to the TLS layer (i.e. is an instance of SSLSocket),
# we need some additional smarts to support TLS-in-TLS.
return TLSinTLSStream(
self._sock, ssl_context, server_hostname, timeout
)
else:
self._sock.settimeout(timeout)
sock = ssl_context.wrap_socket(
self._sock, server_hostname=server_hostname
)
except Exception as exc: # pragma: nocover
self.close()
raise exc
@@ -81,9 +189,9 @@ class SyncBackend(NetworkBackend):
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream:
# Note that we automatically include `TCP_NODELAY`
# in addition to any other custom socket options.
@@ -110,8 +218,8 @@ class SyncBackend(NetworkBackend):
def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream: # pragma: nocover
if sys.platform == "win32":
raise RuntimeError(

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import ssl
import typing
@@ -20,9 +22,7 @@ class TrioStream(AsyncNetworkStream):
def __init__(self, stream: trio.abc.Stream) -> None:
self._stream = stream
async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
trio.TooSlowError: ReadTimeout,
@@ -34,9 +34,7 @@ class TrioStream(AsyncNetworkStream):
data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
return data
async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
async def write(self, buffer: bytes, timeout: float | None = None) -> None:
if not buffer:
return
@@ -56,8 +54,8 @@ class TrioStream(AsyncNetworkStream):
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
server_hostname: str | None = None,
timeout: float | None = None,
) -> AsyncNetworkStream:
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
@@ -113,9 +111,9 @@ class TrioBackend(AsyncNetworkBackend):
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
# By default for TCP sockets, trio enables TCP_NODELAY.
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
@@ -139,8 +137,8 @@ class TrioBackend(AsyncNetworkBackend):
async def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = []

View File

@@ -1,11 +1,11 @@
import contextlib
from typing import Iterator, Mapping, Type
import typing
ExceptionMapping = Mapping[Type[Exception], Type[Exception]]
ExceptionMapping = typing.Mapping[typing.Type[Exception], typing.Type[Exception]]
@contextlib.contextmanager
def map_exceptions(map: ExceptionMapping) -> Iterator[None]:
def map_exceptions(map: ExceptionMapping) -> typing.Iterator[None]:
try:
yield
except Exception as exc: # noqa: PIE786

View File

@@ -1,29 +1,22 @@
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from urllib.parse import urlparse
from __future__ import annotations
import base64
import ssl
import typing
import urllib.parse
# Functions for typechecking...
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
HeaderTypes = Union[HeadersAsSequence, HeadersAsMapping, None]
ByteOrStr = typing.Union[bytes, str]
HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]
HeaderTypes = typing.Union[HeadersAsSequence, HeadersAsMapping, None]
Extensions = Mapping[str, Any]
Extensions = typing.MutableMapping[str, typing.Any]
def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
def enforce_bytes(value: bytes | str, *, name: str) -> bytes:
"""
Any arguments that are ultimately represented as bytes can be specified
either as bytes or as strings.
@@ -44,7 +37,7 @@ def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
raise TypeError(f"{name} must be bytes or str, but got {seen_type}.")
def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":
def enforce_url(value: URL | bytes | str, *, name: str) -> URL:
"""
Type check for URL parameters.
"""
@@ -58,15 +51,15 @@ def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":
def enforce_headers(
value: Union[HeadersAsMapping, HeadersAsSequence, None] = None, *, name: str
) -> List[Tuple[bytes, bytes]]:
value: HeadersAsMapping | HeadersAsSequence | None = None, *, name: str
) -> list[tuple[bytes, bytes]]:
"""
Convienence function that ensure all items in request or response headers
are either bytes or strings in the plain ASCII range.
"""
if value is None:
return []
elif isinstance(value, Mapping):
elif isinstance(value, typing.Mapping):
return [
(
enforce_bytes(k, name="header name"),
@@ -74,7 +67,7 @@ def enforce_headers(
)
for k, v in value.items()
]
elif isinstance(value, Sequence):
elif isinstance(value, typing.Sequence):
return [
(
enforce_bytes(k, name="header name"),
@@ -90,8 +83,10 @@ def enforce_headers(
def enforce_stream(
value: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None], *, name: str
) -> Union[Iterable[bytes], AsyncIterable[bytes]]:
value: bytes | typing.Iterable[bytes] | typing.AsyncIterable[bytes] | None,
*,
name: str,
) -> typing.Iterable[bytes] | typing.AsyncIterable[bytes]:
if value is None:
return ByteStream(b"")
elif isinstance(value, bytes):
@@ -112,11 +107,11 @@ DEFAULT_PORTS = {
def include_request_headers(
headers: List[Tuple[bytes, bytes]],
headers: list[tuple[bytes, bytes]],
*,
url: "URL",
content: Union[None, bytes, Iterable[bytes], AsyncIterable[bytes]],
) -> List[Tuple[bytes, bytes]]:
content: None | bytes | typing.Iterable[bytes] | typing.AsyncIterable[bytes],
) -> list[tuple[bytes, bytes]]:
headers_set = set(k.lower() for k, v in headers)
if b"host" not in headers_set:
@@ -153,10 +148,10 @@ class ByteStream:
def __init__(self, content: bytes) -> None:
self._content = content
def __iter__(self) -> Iterator[bytes]:
def __iter__(self) -> typing.Iterator[bytes]:
yield self._content
async def __aiter__(self) -> AsyncIterator[bytes]:
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield self._content
def __repr__(self) -> str:
@@ -169,7 +164,7 @@ class Origin:
self.host = host
self.port = port
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, Origin)
and self.scheme == other.scheme
@@ -253,12 +248,12 @@ class URL:
def __init__(
self,
url: Union[bytes, str] = "",
url: bytes | str = "",
*,
scheme: Union[bytes, str] = b"",
host: Union[bytes, str] = b"",
port: Optional[int] = None,
target: Union[bytes, str] = b"",
scheme: bytes | str = b"",
host: bytes | str = b"",
port: int | None = None,
target: bytes | str = b"",
) -> None:
"""
Parameters:
@@ -270,7 +265,7 @@ class URL:
target: The target of the HTTP request. Such as `"/items?search=red"`.
"""
if url:
parsed = urlparse(enforce_bytes(url, name="url"))
parsed = urllib.parse.urlparse(enforce_bytes(url, name="url"))
self.scheme = parsed.scheme
self.host = parsed.hostname or b""
self.port = parsed.port
@@ -291,12 +286,13 @@ class URL:
b"ws": 80,
b"wss": 443,
b"socks5": 1080,
b"socks5h": 1080,
}[self.scheme]
return Origin(
scheme=self.scheme, host=self.host, port=self.port or default_port
)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, URL)
and other.scheme == self.scheme
@@ -324,12 +320,15 @@ class Request:
def __init__(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
extensions: Optional[Extensions] = None,
content: bytes
| typing.Iterable[bytes]
| typing.AsyncIterable[bytes]
| None = None,
extensions: Extensions | None = None,
) -> None:
"""
Parameters:
@@ -338,20 +337,28 @@ class Request:
url: The request URL, either as a `URL` instance, or as a string or bytes.
For example: `"https://www.example.com".`
headers: The HTTP request headers.
content: The content of the response body.
content: The content of the request body.
extensions: A dictionary of optional extra information included on
the request. Possible keys include `"timeout"`, and `"trace"`.
"""
self.method: bytes = enforce_bytes(method, name="method")
self.url: URL = enforce_url(url, name="url")
self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
self.headers: list[tuple[bytes, bytes]] = enforce_headers(
headers, name="headers"
)
self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
content, name="content"
self.stream: typing.Iterable[bytes] | typing.AsyncIterable[bytes] = (
enforce_stream(content, name="content")
)
self.extensions = {} if extensions is None else extensions
if "target" in self.extensions:
self.url = URL(
scheme=self.url.scheme,
host=self.url.host,
port=self.url.port,
target=self.extensions["target"],
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.method!r}]>"
@@ -366,8 +373,11 @@ class Response:
status: int,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
extensions: Optional[Extensions] = None,
content: bytes
| typing.Iterable[bytes]
| typing.AsyncIterable[bytes]
| None = None,
extensions: Extensions | None = None,
) -> None:
"""
Parameters:
@@ -379,11 +389,11 @@ class Response:
`"reason_phrase"`, and `"network_stream"`.
"""
self.status: int = status
self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
self.headers: list[tuple[bytes, bytes]] = enforce_headers(
headers, name="headers"
)
self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
content, name="content"
self.stream: typing.Iterable[bytes] | typing.AsyncIterable[bytes] = (
enforce_stream(content, name="content")
)
self.extensions = {} if extensions is None else extensions
@@ -392,7 +402,7 @@ class Response:
@property
def content(self) -> bytes:
if not hasattr(self, "_content"):
if isinstance(self.stream, Iterable):
if isinstance(self.stream, typing.Iterable):
raise RuntimeError(
"Attempted to access 'response.content' on a streaming response. "
"Call 'response.read()' first."
@@ -410,7 +420,7 @@ class Response:
# Sync interface...
def read(self) -> bytes:
if not isinstance(self.stream, Iterable): # pragma: nocover
if not isinstance(self.stream, typing.Iterable): # pragma: nocover
raise RuntimeError(
"Attempted to read an asynchronous response using 'response.read()'. "
"You should use 'await response.aread()' instead."
@@ -419,8 +429,8 @@ class Response:
self._content = b"".join([part for part in self.iter_stream()])
return self._content
def iter_stream(self) -> Iterator[bytes]:
if not isinstance(self.stream, Iterable): # pragma: nocover
def iter_stream(self) -> typing.Iterator[bytes]:
if not isinstance(self.stream, typing.Iterable): # pragma: nocover
raise RuntimeError(
"Attempted to stream an asynchronous response using 'for ... in "
"response.iter_stream()'. "
@@ -435,7 +445,7 @@ class Response:
yield chunk
def close(self) -> None:
if not isinstance(self.stream, Iterable): # pragma: nocover
if not isinstance(self.stream, typing.Iterable): # pragma: nocover
raise RuntimeError(
"Attempted to close an asynchronous response using 'response.close()'. "
"You should use 'await response.aclose()' instead."
@@ -446,7 +456,7 @@ class Response:
# Async interface...
async def aread(self) -> bytes:
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover
raise RuntimeError(
"Attempted to read an synchronous response using "
"'await response.aread()'. "
@@ -456,8 +466,8 @@ class Response:
self._content = b"".join([part async for part in self.aiter_stream()])
return self._content
async def aiter_stream(self) -> AsyncIterator[bytes]:
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
async def aiter_stream(self) -> typing.AsyncIterator[bytes]:
if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover
raise RuntimeError(
"Attempted to stream an synchronous response using 'async for ... in "
"response.aiter_stream()'. "
@@ -473,7 +483,7 @@ class Response:
yield chunk
async def aclose(self) -> None:
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover
raise RuntimeError(
"Attempted to close a synchronous response using "
"'await response.aclose()'. "
@@ -481,3 +491,26 @@ class Response:
)
if hasattr(self.stream, "aclose"):
await self.stream.aclose()
class Proxy:
def __init__(
self,
url: URL | bytes | str,
auth: tuple[bytes | str, bytes | str] | None = None,
headers: HeadersAsMapping | HeadersAsSequence | None = None,
ssl_context: ssl.SSLContext | None = None,
):
self.url = enforce_url(url, name="url")
self.headers = enforce_headers(headers, name="headers")
self.ssl_context = ssl_context
if auth is not None:
username = enforce_bytes(auth[0], name="auth")
password = enforce_bytes(auth[1], name="auth")
userpass = username + b":" + password
authorization = b"Basic " + base64.b64encode(userpass)
self.auth: tuple[bytes, bytes] | None = (username, password)
self.headers = [(b"Proxy-Authorization", authorization)] + self.headers
else:
self.auth = None

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import itertools
import logging
import ssl
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type
import types
import typing
from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout
from .._exceptions import ConnectError, ConnectTimeout
from .._models import Origin, Request, Response
from .._ssl import default_ssl_context
from .._synchronization import Lock
@@ -20,25 +22,32 @@ RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
logger = logging.getLogger("httpcore.connection")
def exponential_backoff(factor: float) -> Iterator[float]:
def exponential_backoff(factor: float) -> typing.Iterator[float]:
"""
Generate a geometric sequence that has a ratio of 2 and starts with 0.
For example:
- `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
- `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
"""
yield 0
for n in itertools.count(2):
yield factor * (2 ** (n - 2))
for n in itertools.count():
yield factor * 2**n
class HTTPConnection(ConnectionInterface):
def __init__(
self,
origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
keepalive_expiry: Optional[float] = None,
ssl_context: ssl.SSLContext | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
local_address: str | None = None,
uds: str | None = None,
network_backend: NetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
self._origin = origin
self._ssl_context = ssl_context
@@ -52,7 +61,7 @@ class HTTPConnection(ConnectionInterface):
self._network_backend: NetworkBackend = (
SyncBackend() if network_backend is None else network_backend
)
self._connection: Optional[ConnectionInterface] = None
self._connection: ConnectionInterface | None = None
self._connect_failed: bool = False
self._request_lock = Lock()
self._socket_options = socket_options
@@ -63,9 +72,9 @@ class HTTPConnection(ConnectionInterface):
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
)
with self._request_lock:
if self._connection is None:
try:
try:
with self._request_lock:
if self._connection is None:
stream = self._connect(request)
ssl_object = stream.get_extra_info("ssl_object")
@@ -87,11 +96,9 @@ class HTTPConnection(ConnectionInterface):
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
except Exception as exc:
self._connect_failed = True
raise exc
elif not self._connection.is_available():
raise ConnectionNotAvailable()
except BaseException as exc:
self._connect_failed = True
raise exc
return self._connection.handle_request(request)
@@ -130,7 +137,7 @@ class HTTPConnection(ConnectionInterface):
)
trace.return_value = stream
if self._origin.scheme == b"https":
if self._origin.scheme in (b"https", b"wss"):
ssl_context = (
default_ssl_context()
if self._ssl_context is None
@@ -203,13 +210,13 @@ class HTTPConnection(ConnectionInterface):
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
def __enter__(self) -> "HTTPConnection":
def __enter__(self) -> HTTPConnection:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
self.close()

View File

@@ -1,41 +1,44 @@
from __future__ import annotations
import ssl
import sys
from types import TracebackType
from typing import Iterable, Iterator, Iterable, List, Optional, Type
import types
import typing
from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import Event, Lock, ShieldCancellation
from .._models import Origin, Proxy, Request, Response
from .._synchronization import Event, ShieldCancellation, ThreadLock
from .connection import HTTPConnection
from .interfaces import ConnectionInterface, RequestInterface
class RequestStatus:
def __init__(self, request: Request):
class PoolRequest:
def __init__(self, request: Request) -> None:
self.request = request
self.connection: Optional[ConnectionInterface] = None
self.connection: ConnectionInterface | None = None
self._connection_acquired = Event()
def set_connection(self, connection: ConnectionInterface) -> None:
assert self.connection is None
def assign_to_connection(self, connection: ConnectionInterface | None) -> None:
self.connection = connection
self._connection_acquired.set()
def unset_connection(self) -> None:
assert self.connection is not None
def clear_connection(self) -> None:
self.connection = None
self._connection_acquired = Event()
def wait_for_connection(
self, timeout: Optional[float] = None
self, timeout: float | None = None
) -> ConnectionInterface:
if self.connection is None:
self._connection_acquired.wait(timeout=timeout)
assert self.connection is not None
return self.connection
def is_queued(self) -> bool:
return self.connection is None
class ConnectionPool(RequestInterface):
"""
@@ -44,17 +47,18 @@ class ConnectionPool(RequestInterface):
def __init__(
self,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
ssl_context: ssl.SSLContext | None = None,
proxy: Proxy | None = None,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
local_address: str | None = None,
uds: str | None = None,
network_backend: NetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
"""
A connection pool for making HTTP requests.
@@ -86,7 +90,7 @@ class ConnectionPool(RequestInterface):
in the TCP socket when the connection was established.
"""
self._ssl_context = ssl_context
self._proxy = proxy
self._max_connections = (
sys.maxsize if max_connections is None else max_connections
)
@@ -106,15 +110,61 @@ class ConnectionPool(RequestInterface):
self._local_address = local_address
self._uds = uds
self._pool: List[ConnectionInterface] = []
self._requests: List[RequestStatus] = []
self._pool_lock = Lock()
self._network_backend = (
SyncBackend() if network_backend is None else network_backend
)
self._socket_options = socket_options
# The mutable state on a connection pool is the queue of incoming requests,
# and the set of connections that are servicing those requests.
self._connections: list[ConnectionInterface] = []
self._requests: list[PoolRequest] = []
# We only mutate the state of the connection pool within an 'optional_thread_lock'
# context. This holds a threading lock unless we're running in async mode,
# in which case it is a no-op.
self._optional_thread_lock = ThreadLock()
def create_connection(self, origin: Origin) -> ConnectionInterface:
if self._proxy is not None:
if self._proxy.url.scheme in (b"socks5", b"socks5h"):
from .socks_proxy import Socks5Connection
return Socks5Connection(
proxy_origin=self._proxy.url.origin,
proxy_auth=self._proxy.auth,
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)
elif origin.scheme == b"http":
from .http_proxy import ForwardHTTPConnection
return ForwardHTTPConnection(
proxy_origin=self._proxy.url.origin,
proxy_headers=self._proxy.headers,
proxy_ssl_context=self._proxy.ssl_context,
remote_origin=origin,
keepalive_expiry=self._keepalive_expiry,
network_backend=self._network_backend,
)
from .http_proxy import TunnelHTTPConnection
return TunnelHTTPConnection(
proxy_origin=self._proxy.url.origin,
proxy_headers=self._proxy.headers,
proxy_ssl_context=self._proxy.ssl_context,
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)
return HTTPConnection(
origin=origin,
ssl_context=self._ssl_context,
@@ -129,7 +179,7 @@ class ConnectionPool(RequestInterface):
)
@property
def connections(self) -> List[ConnectionInterface]:
def connections(self) -> list[ConnectionInterface]:
"""
Return a list of the connections currently in the pool.
@@ -144,64 +194,7 @@ class ConnectionPool(RequestInterface):
]
```
"""
return list(self._pool)
def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
"""
Attempt to provide a connection that can handle the given origin.
"""
origin = status.request.url.origin
# If there are queued requests in front of us, then don't acquire a
# connection. We handle requests strictly in order.
waiting = [s for s in self._requests if s.connection is None]
if waiting and waiting[0] is not status:
return False
# Reuse an existing connection if one is currently available.
for idx, connection in enumerate(self._pool):
if connection.can_handle_request(origin) and connection.is_available():
self._pool.pop(idx)
self._pool.insert(0, connection)
status.set_connection(connection)
return True
# If the pool is currently full, attempt to close one idle connection.
if len(self._pool) >= self._max_connections:
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.is_idle():
connection.close()
self._pool.pop(idx)
break
# If the pool is still full, then we cannot acquire a connection.
if len(self._pool) >= self._max_connections:
return False
# Otherwise create a new connection.
connection = self.create_connection(origin)
self._pool.insert(0, connection)
status.set_connection(connection)
return True
def _close_expired_connections(self) -> None:
"""
Clean up the connection pool by closing off any connections that have expired.
"""
# Close any connections that have expired their keep-alive time.
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.has_expired():
connection.close()
self._pool.pop(idx)
# If the pool size exceeds the maximum number of allowed keep-alive connections,
# then close off idle connections as required.
pool_size = len(self._pool)
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.is_idle() and pool_size > self._max_keepalive_connections:
connection.close()
self._pool.pop(idx)
pool_size -= 1
return list(self._connections)
def handle_request(self, request: Request) -> Response:
"""
@@ -219,138 +212,209 @@ class ConnectionPool(RequestInterface):
f"Request URL has an unsupported protocol '{scheme}://'."
)
status = RequestStatus(request)
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
with self._pool_lock:
self._requests.append(status)
self._close_expired_connections()
self._attempt_to_acquire_connection(status)
with self._optional_thread_lock:
# Add the incoming request to our request queue.
pool_request = PoolRequest(request)
self._requests.append(pool_request)
while True:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
try:
connection = status.wait_for_connection(timeout=timeout)
except BaseException as exc:
# If we timeout here, or if the task is cancelled, then make
# sure to remove the request from the queue before bubbling
# up the exception.
with self._pool_lock:
# Ensure only remove when task exists.
if status in self._requests:
self._requests.remove(status)
raise exc
try:
while True:
with self._optional_thread_lock:
# Assign incoming requests to available connections,
# closing or creating new connections as required.
closing = self._assign_requests_to_connections()
self._close_connections(closing)
try:
response = connection.handle_request(request)
except ConnectionNotAvailable:
# The ConnectionNotAvailable exception is a special case, that
# indicates we need to retry the request on a new connection.
#
# The most common case where this can occur is when multiple
# requests are queued waiting for a single connection, which
# might end up as an HTTP/2 connection, but which actually ends
# up as HTTP/1.1.
with self._pool_lock:
# Maintain our position in the request queue, but reset the
# status so that the request becomes queued again.
status.unset_connection()
self._attempt_to_acquire_connection(status)
except BaseException as exc:
with ShieldCancellation():
self.response_closed(status)
raise exc
else:
break
# Wait until this request has an assigned connection.
connection = pool_request.wait_for_connection(timeout=timeout)
# When we return the response, we wrap the stream in a special class
# that handles notifying the connection pool once the response
# has been released.
assert isinstance(response.stream, Iterable)
try:
# Send the request on the assigned connection.
response = connection.handle_request(
pool_request.request
)
except ConnectionNotAvailable:
# In some cases a connection may initially be available to
# handle a request, but then become unavailable.
#
# In this case we clear the connection and try again.
pool_request.clear_connection()
else:
break # pragma: nocover
except BaseException as exc:
with self._optional_thread_lock:
# For any exception or cancellation we remove the request from
# the queue, and then re-assign requests to connections.
self._requests.remove(pool_request)
closing = self._assign_requests_to_connections()
self._close_connections(closing)
raise exc from None
# Return the response. Note that in this case we still have to manage
# the point at which the response is closed.
assert isinstance(response.stream, typing.Iterable)
return Response(
status=response.status,
headers=response.headers,
content=ConnectionPoolByteStream(response.stream, self, status),
content=PoolByteStream(
stream=response.stream, pool_request=pool_request, pool=self
),
extensions=response.extensions,
)
def response_closed(self, status: RequestStatus) -> None:
def _assign_requests_to_connections(self) -> list[ConnectionInterface]:
"""
This method acts as a callback once the request/response cycle is complete.
Manage the state of the connection pool, assigning incoming
requests to connections as available.
It is called into from the `ConnectionPoolByteStream.close()` method.
Called whenever a new request is added or removed from the pool.
Any closing connections are returned, allowing the I/O for closing
those connections to be handled seperately.
"""
assert status.connection is not None
connection = status.connection
closing_connections = []
with self._pool_lock:
# Update the state of the connection pool.
if status in self._requests:
self._requests.remove(status)
# First we handle cleaning up any connections that are closed,
# have expired their keep-alive, or surplus idle connections.
for connection in list(self._connections):
if connection.is_closed():
# log: "removing closed connection"
self._connections.remove(connection)
elif connection.has_expired():
# log: "closing expired connection"
self._connections.remove(connection)
closing_connections.append(connection)
elif (
connection.is_idle()
and len([connection.is_idle() for connection in self._connections])
> self._max_keepalive_connections
):
# log: "closing idle connection"
self._connections.remove(connection)
closing_connections.append(connection)
if connection.is_closed() and connection in self._pool:
self._pool.remove(connection)
# Assign queued requests to connections.
queued_requests = [request for request in self._requests if request.is_queued()]
for pool_request in queued_requests:
origin = pool_request.request.url.origin
available_connections = [
connection
for connection in self._connections
if connection.can_handle_request(origin) and connection.is_available()
]
idle_connections = [
connection for connection in self._connections if connection.is_idle()
]
# Since we've had a response closed, it's possible we'll now be able
# to service one or more requests that are currently pending.
for status in self._requests:
if status.connection is None:
acquired = self._attempt_to_acquire_connection(status)
# If we could not acquire a connection for a queued request
# then we don't need to check anymore requests that are
# queued later behind it.
if not acquired:
break
# There are three cases for how we may be able to handle the request:
#
# 1. There is an existing connection that can handle the request.
# 2. We can create a new connection to handle the request.
# 3. We can close an idle connection and then create a new connection
# to handle the request.
if available_connections:
# log: "reusing existing connection"
connection = available_connections[0]
pool_request.assign_to_connection(connection)
elif len(self._connections) < self._max_connections:
# log: "creating new connection"
connection = self.create_connection(origin)
self._connections.append(connection)
pool_request.assign_to_connection(connection)
elif idle_connections:
# log: "closing idle connection"
connection = idle_connections[0]
self._connections.remove(connection)
closing_connections.append(connection)
# log: "creating new connection"
connection = self.create_connection(origin)
self._connections.append(connection)
pool_request.assign_to_connection(connection)
# Housekeeping.
self._close_expired_connections()
return closing_connections
def _close_connections(self, closing: list[ConnectionInterface]) -> None:
# Close connections which have been removed from the pool.
with ShieldCancellation():
for connection in closing:
connection.close()
def close(self) -> None:
"""
Close any connections in the pool.
"""
with self._pool_lock:
for connection in self._pool:
connection.close()
self._pool = []
self._requests = []
# Explicitly close the connection pool.
# Clears all existing requests and connections.
with self._optional_thread_lock:
closing_connections = list(self._connections)
self._connections = []
self._close_connections(closing_connections)
def __enter__(self) -> "ConnectionPool":
def __enter__(self) -> ConnectionPool:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
self.close()
def __repr__(self) -> str:
class_name = self.__class__.__name__
with self._optional_thread_lock:
request_is_queued = [request.is_queued() for request in self._requests]
connection_is_idle = [
connection.is_idle() for connection in self._connections
]
class ConnectionPoolByteStream:
"""
A wrapper around the response byte stream, that additionally handles
notifying the connection pool when the response has been closed.
"""
num_active_requests = request_is_queued.count(False)
num_queued_requests = request_is_queued.count(True)
num_active_connections = connection_is_idle.count(False)
num_idle_connections = connection_is_idle.count(True)
requests_info = (
f"Requests: {num_active_requests} active, {num_queued_requests} queued"
)
connection_info = (
f"Connections: {num_active_connections} active, {num_idle_connections} idle"
)
return f"<{class_name} [{requests_info} | {connection_info}]>"
class PoolByteStream:
def __init__(
self,
stream: Iterable[bytes],
stream: typing.Iterable[bytes],
pool_request: PoolRequest,
pool: ConnectionPool,
status: RequestStatus,
) -> None:
self._stream = stream
self._pool_request = pool_request
self._pool = pool
self._status = status
self._closed = False
def __iter__(self) -> Iterator[bytes]:
for part in self._stream:
yield part
def __iter__(self) -> typing.Iterator[bytes]:
try:
for part in self._stream:
yield part
except BaseException as exc:
self.close()
raise exc from None
def close(self) -> None:
try:
if hasattr(self._stream, "close"):
self._stream.close()
finally:
if not self._closed:
self._closed = True
with ShieldCancellation():
self._pool.response_closed(self._status)
if hasattr(self._stream, "close"):
self._stream.close()
with self._pool._optional_thread_lock:
self._pool._requests.remove(self._pool_request)
closing = self._pool._assign_requests_to_connections()
self._pool._close_connections(closing)

View File

@@ -1,17 +1,11 @@
from __future__ import annotations
import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import types
import typing
import h11
@@ -20,6 +14,7 @@ from .._exceptions import (
ConnectionNotAvailable,
LocalProtocolError,
RemoteProtocolError,
WriteError,
map_exceptions,
)
from .._models import Origin, Request, Response
@@ -31,7 +26,7 @@ logger = logging.getLogger("httpcore.http11")
# A subset of `h11.Event` types supported by `_send_event`
H11SendEvent = Union[
H11SendEvent = typing.Union[
h11.Request,
h11.Data,
h11.EndOfMessage,
@@ -53,12 +48,12 @@ class HTTP11Connection(ConnectionInterface):
self,
origin: Origin,
stream: NetworkStream,
keepalive_expiry: Optional[float] = None,
keepalive_expiry: float | None = None,
) -> None:
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: Optional[float] = keepalive_expiry
self._expire_at: Optional[float] = None
self._keepalive_expiry: float | None = keepalive_expiry
self._expire_at: float | None = None
self._state = HTTPConnectionState.NEW
self._state_lock = Lock()
self._request_count = 0
@@ -84,10 +79,21 @@ class HTTP11Connection(ConnectionInterface):
try:
kwargs = {"request": request}
with Trace("send_request_headers", logger, request, kwargs) as trace:
self._send_request_headers(**kwargs)
with Trace("send_request_body", logger, request, kwargs) as trace:
self._send_request_body(**kwargs)
try:
with Trace(
"send_request_headers", logger, request, kwargs
) as trace:
self._send_request_headers(**kwargs)
with Trace("send_request_body", logger, request, kwargs) as trace:
self._send_request_body(**kwargs)
except WriteError:
# If we get a write error while we're writing the request,
# then we supress this error and move on to attempting to
# read the response. Servers can sometimes close the request
# pre-emptively and then respond with a well formed HTTP
# error response.
pass
with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
@@ -96,6 +102,7 @@ class HTTP11Connection(ConnectionInterface):
status,
reason_phrase,
headers,
trailing_data,
) = self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
@@ -104,6 +111,14 @@ class HTTP11Connection(ConnectionInterface):
headers,
)
network_stream = self._network_stream
# CONNECT or Upgrade request
if (status == 101) or (
(request.method == b"CONNECT") and (200 <= status < 300)
):
network_stream = HTTP11UpgradeStream(network_stream, trailing_data)
return Response(
status=status,
headers=headers,
@@ -111,7 +126,7 @@ class HTTP11Connection(ConnectionInterface):
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"network_stream": network_stream,
},
)
except BaseException as exc:
@@ -138,16 +153,14 @@ class HTTP11Connection(ConnectionInterface):
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
assert isinstance(request.stream, Iterable)
assert isinstance(request.stream, typing.Iterable)
for chunk in request.stream:
event = h11.Data(data=chunk)
self._send_event(event, timeout=timeout)
self._send_event(h11.EndOfMessage(), timeout=timeout)
def _send_event(
self, event: h11.Event, timeout: Optional[float] = None
) -> None:
def _send_event(self, event: h11.Event, timeout: float | None = None) -> None:
bytes_to_send = self._h11_state.send(event)
if bytes_to_send is not None:
self._network_stream.write(bytes_to_send, timeout=timeout)
@@ -156,7 +169,7 @@ class HTTP11Connection(ConnectionInterface):
def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
@@ -176,9 +189,13 @@ class HTTP11Connection(ConnectionInterface):
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()
return http_version, event.status_code, event.reason, headers
trailing_data, _ = self._h11_state.trailing_data
def _receive_response_body(self, request: Request) -> Iterator[bytes]:
return http_version, event.status_code, event.reason, headers, trailing_data
def _receive_response_body(
self, request: Request
) -> typing.Iterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
@@ -190,8 +207,8 @@ class HTTP11Connection(ConnectionInterface):
break
def _receive_event(
self, timeout: Optional[float] = None
) -> Union[h11.Event, Type[h11.PAUSED]]:
self, timeout: float | None = None
) -> h11.Event | type[h11.PAUSED]:
while True:
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
event = self._h11_state.next_event()
@@ -216,7 +233,7 @@ class HTTP11Connection(ConnectionInterface):
self._h11_state.receive_data(data)
else:
# mypy fails to narrow the type in the above if statement above
return cast(Union[h11.Event, Type[h11.PAUSED]], event)
return event # type: ignore[return-value]
def _response_closed(self) -> None:
with self._state_lock:
@@ -292,14 +309,14 @@ class HTTP11Connection(ConnectionInterface):
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
def __enter__(self) -> "HTTP11Connection":
def __enter__(self) -> HTTP11Connection:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
self.close()
@@ -310,7 +327,7 @@ class HTTP11ConnectionByteStream:
self._request = request
self._closed = False
def __iter__(self) -> Iterator[bytes]:
def __iter__(self) -> typing.Iterator[bytes]:
kwargs = {"request": self._request}
try:
with Trace("receive_response_body", logger, self._request, kwargs):
@@ -329,3 +346,34 @@ class HTTP11ConnectionByteStream:
self._closed = True
with Trace("response_closed", logger, self._request):
self._connection._response_closed()
class HTTP11UpgradeStream(NetworkStream):
def __init__(self, stream: NetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data
def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return self._stream.read(max_bytes, timeout)
def write(self, buffer: bytes, timeout: float | None = None) -> None:
self._stream.write(buffer, timeout)
def close(self) -> None:
self._stream.close()
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
timeout: float | None = None,
) -> NetworkStream:
return self._stream.start_tls(ssl_context, server_hostname, timeout)
def get_extra_info(self, info: str) -> typing.Any:
return self._stream.get_extra_info(info)

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import enum
import logging
import time
@@ -45,14 +47,14 @@ class HTTP2Connection(ConnectionInterface):
self,
origin: Origin,
stream: NetworkStream,
keepalive_expiry: typing.Optional[float] = None,
keepalive_expiry: float | None = None,
):
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
self._keepalive_expiry: float | None = keepalive_expiry
self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
self._state = HTTPConnectionState.IDLE
self._expire_at: typing.Optional[float] = None
self._expire_at: float | None = None
self._request_count = 0
self._init_lock = Lock()
self._state_lock = Lock()
@@ -63,24 +65,22 @@ class HTTP2Connection(ConnectionInterface):
self._connection_error = False
# Mapping from stream ID to response stream events.
self._events: typing.Dict[
self._events: dict[
int,
typing.Union[
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
list[
h2.events.ResponseReceived
| h2.events.DataReceived
| h2.events.StreamEnded
| h2.events.StreamReset,
],
] = {}
# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: typing.Optional[
h2.events.ConnectionTerminated
] = None
self._connection_terminated: h2.events.ConnectionTerminated | None = None
self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._read_exception: Exception | None = None
self._write_exception: Exception | None = None
def handle_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
@@ -104,9 +104,11 @@ class HTTP2Connection(ConnectionInterface):
with self._init_lock:
if not self._sent_connection_init:
try:
kwargs = {"request": request}
with Trace("send_connection_init", logger, request, kwargs):
self._send_connection_init(**kwargs)
sci_kwargs = {"request": request}
with Trace(
"send_connection_init", logger, request, sci_kwargs
):
self._send_connection_init(**sci_kwargs)
except BaseException as exc:
with ShieldCancellation():
self.close()
@@ -284,7 +286,7 @@ class HTTP2Connection(ConnectionInterface):
def _receive_response(
self, request: Request, stream_id: int
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
) -> tuple[int, list[tuple[bytes, bytes]]]:
"""
Return the response status code and headers for a given stream ID.
"""
@@ -295,6 +297,7 @@ class HTTP2Connection(ConnectionInterface):
status_code = 200
headers = []
assert event.headers is not None
for k, v in event.headers:
if k == b":status":
status_code = int(v.decode("ascii", errors="ignore"))
@@ -312,6 +315,8 @@ class HTTP2Connection(ConnectionInterface):
while True:
event = self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.DataReceived):
assert event.flow_controlled_length is not None
assert event.data is not None
amount = event.flow_controlled_length
self._h2_state.acknowledge_received_data(amount, stream_id)
self._write_outgoing_data(request)
@@ -321,9 +326,7 @@ class HTTP2Connection(ConnectionInterface):
def _receive_stream_event(
self, request: Request, stream_id: int
) -> typing.Union[
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
]:
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
"""
Return the next available event for a given stream ID.
@@ -337,7 +340,7 @@ class HTTP2Connection(ConnectionInterface):
return event
def _receive_events(
self, request: Request, stream_id: typing.Optional[int] = None
self, request: Request, stream_id: int | None = None
) -> None:
"""
Read some data from the network until we see one or more events
@@ -384,7 +387,9 @@ class HTTP2Connection(ConnectionInterface):
self._write_outgoing_data(request)
def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
def _receive_remote_settings_change(
self, event: h2.events.RemoteSettingsChanged
) -> None:
max_concurrent_streams = event.changed_settings.get(
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
)
@@ -425,9 +430,7 @@ class HTTP2Connection(ConnectionInterface):
# Wrappers around network read/write operations...
def _read_incoming_data(
self, request: Request
) -> typing.List[h2.events.Event]:
def _read_incoming_data(self, request: Request) -> list[h2.events.Event]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
@@ -451,7 +454,7 @@ class HTTP2Connection(ConnectionInterface):
self._connection_error = True
raise exc
events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)
events: list[h2.events.Event] = self._h2_state.receive_data(data)
return events
@@ -544,14 +547,14 @@ class HTTP2Connection(ConnectionInterface):
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
def __enter__(self) -> "HTTP2Connection":
def __enter__(self) -> HTTP2Connection:
return self
def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[types.TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
self.close()

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import base64
import logging
import ssl
from base64 import b64encode
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
import typing
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ProxyError
@@ -22,17 +24,18 @@ from .connection_pool import ConnectionPool
from .http11 import HTTP11Connection
from .interfaces import ConnectionInterface
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
ByteOrStr = typing.Union[bytes, str]
HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]
logger = logging.getLogger("httpcore.proxy")
def merge_headers(
default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
) -> List[Tuple[bytes, bytes]]:
default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
) -> list[tuple[bytes, bytes]]:
"""
Append default_headers and override_headers, de-duplicating if a key exists
in both cases.
@@ -48,32 +51,28 @@ def merge_headers(
return default_headers + override_headers
def build_auth_header(username: bytes, password: bytes) -> bytes:
userpass = username + b":" + password
return b"Basic " + b64encode(userpass)
class HTTPProxy(ConnectionPool):
class HTTPProxy(ConnectionPool): # pragma: nocover
"""
A connection pool that sends requests via an HTTP proxy.
"""
def __init__(
self,
proxy_url: Union[URL, bytes, str],
proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
proxy_url: URL | bytes | str,
proxy_auth: tuple[bytes | str, bytes | str] | None = None,
proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
ssl_context: ssl.SSLContext | None = None,
proxy_ssl_context: ssl.SSLContext | None = None,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
local_address: str | None = None,
uds: str | None = None,
network_backend: NetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
"""
A connection pool for making HTTP requests.
@@ -88,6 +87,7 @@ class HTTPProxy(ConnectionPool):
ssl_context: An SSL context to use for verifying connections.
If not specified, the default `httpcore.default_ssl_context()`
will be used.
proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin.
max_connections: The maximum number of concurrent HTTP connections that
the pool should allow. Any attempt to send a request on a pool that
would exceed this amount will block until a connection is available.
@@ -122,13 +122,23 @@ class HTTPProxy(ConnectionPool):
uds=uds,
socket_options=socket_options,
)
self._ssl_context = ssl_context
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
if (
self._proxy_url.scheme == b"http" and proxy_ssl_context is not None
): # pragma: no cover
raise RuntimeError(
"The `proxy_ssl_context` argument is not allowed for the http scheme"
)
self._ssl_context = ssl_context
self._proxy_ssl_context = proxy_ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
if proxy_auth is not None:
username = enforce_bytes(proxy_auth[0], name="proxy_auth")
password = enforce_bytes(proxy_auth[1], name="proxy_auth")
authorization = build_auth_header(username, password)
userpass = username + b":" + password
authorization = b"Basic " + base64.b64encode(userpass)
self._proxy_headers = [
(b"Proxy-Authorization", authorization)
] + self._proxy_headers
@@ -141,12 +151,14 @@ class HTTPProxy(ConnectionPool):
remote_origin=origin,
keepalive_expiry=self._keepalive_expiry,
network_backend=self._network_backend,
proxy_ssl_context=self._proxy_ssl_context,
)
return TunnelHTTPConnection(
proxy_origin=self._proxy_url.origin,
proxy_headers=self._proxy_headers,
remote_origin=origin,
ssl_context=self._ssl_context,
proxy_ssl_context=self._proxy_ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
@@ -159,16 +171,18 @@ class ForwardHTTPConnection(ConnectionInterface):
self,
proxy_origin: Origin,
remote_origin: Origin,
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
keepalive_expiry: Optional[float] = None,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
keepalive_expiry: float | None = None,
network_backend: NetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
proxy_ssl_context: ssl.SSLContext | None = None,
) -> None:
self._connection = HTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
ssl_context=proxy_ssl_context,
)
self._proxy_origin = proxy_origin
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
@@ -221,23 +235,26 @@ class TunnelHTTPConnection(ConnectionInterface):
self,
proxy_origin: Origin,
remote_origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
keepalive_expiry: Optional[float] = None,
ssl_context: ssl.SSLContext | None = None,
proxy_ssl_context: ssl.SSLContext | None = None,
proxy_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
network_backend: NetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
self._connection: ConnectionInterface = HTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
ssl_context=proxy_ssl_context,
)
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
self._ssl_context = ssl_context
self._proxy_ssl_context = proxy_ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._keepalive_expiry = keepalive_expiry
self._http1 = http1

View File

@@ -1,5 +1,7 @@
from contextlib import contextmanager
from typing import Iterator, Optional, Union
from __future__ import annotations
import contextlib
import typing
from .._models import (
URL,
@@ -18,12 +20,12 @@ from .._models import (
class RequestInterface:
def request(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
content: bytes | typing.Iterator[bytes] | None = None,
extensions: Extensions | None = None,
) -> Response:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
@@ -47,16 +49,16 @@ class RequestInterface:
response.close()
return response
@contextmanager
@contextlib.contextmanager
def stream(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
method: bytes | str,
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> Iterator[Response]:
content: bytes | typing.Iterator[bytes] | None = None,
extensions: Extensions | None = None,
) -> typing.Iterator[Response]:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
url = enforce_url(url, name="url")

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import logging
import ssl
import typing
from socksio import socks5
import socksio
from .._backends.sync import SyncBackend
from .._backends.base import NetworkBackend, NetworkStream
@@ -43,24 +44,24 @@ def _init_socks5_connection(
*,
host: bytes,
port: int,
auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
auth: tuple[bytes, bytes] | None = None,
) -> None:
conn = socks5.SOCKS5Connection()
conn = socksio.socks5.SOCKS5Connection()
# Auth method request
auth_method = (
socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
socksio.socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
if auth is None
else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
else socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
)
conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method]))
outgoing_bytes = conn.data_to_send()
stream.write(outgoing_bytes)
# Auth method response
incoming_bytes = stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5AuthReply)
assert isinstance(response, socksio.socks5.SOCKS5AuthReply)
if response.method != auth_method:
requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
responded = AUTH_METHODS.get(response.method, "UNKNOWN")
@@ -68,25 +69,25 @@ def _init_socks5_connection(
f"Requested {requested} from proxy server, but got {responded}."
)
if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
if response.method == socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
# Username/password request
assert auth is not None
username, password = auth
conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password))
outgoing_bytes = conn.data_to_send()
stream.write(outgoing_bytes)
# Username/password response
incoming_bytes = stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply)
if not response.success:
raise ProxyError("Invalid username/password")
# Connect request
conn.send(
socks5.SOCKS5CommandRequest.from_address(
socks5.SOCKS5Command.CONNECT, (host, port)
socksio.socks5.SOCKS5CommandRequest.from_address(
socksio.socks5.SOCKS5Command.CONNECT, (host, port)
)
)
outgoing_bytes = conn.data_to_send()
@@ -95,31 +96,29 @@ def _init_socks5_connection(
# Connect response
incoming_bytes = stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5Reply)
if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
assert isinstance(response, socksio.socks5.SOCKS5Reply)
if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED:
reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN")
raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
class SOCKSProxy(ConnectionPool):
class SOCKSProxy(ConnectionPool): # pragma: nocover
"""
A connection pool that sends requests via an HTTP proxy.
"""
def __init__(
self,
proxy_url: typing.Union[URL, bytes, str],
proxy_auth: typing.Optional[
typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
] = None,
ssl_context: typing.Optional[ssl.SSLContext] = None,
max_connections: typing.Optional[int] = 10,
max_keepalive_connections: typing.Optional[int] = None,
keepalive_expiry: typing.Optional[float] = None,
proxy_url: URL | bytes | str,
proxy_auth: tuple[bytes | str, bytes | str] | None = None,
ssl_context: ssl.SSLContext | None = None,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
network_backend: typing.Optional[NetworkBackend] = None,
network_backend: NetworkBackend | None = None,
) -> None:
"""
A connection pool for making HTTP requests.
@@ -167,7 +166,7 @@ class SOCKSProxy(ConnectionPool):
username, password = proxy_auth
username_bytes = enforce_bytes(username, name="proxy_auth")
password_bytes = enforce_bytes(password, name="proxy_auth")
self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = (
self._proxy_auth: tuple[bytes, bytes] | None = (
username_bytes,
password_bytes,
)
@@ -192,12 +191,12 @@ class Socks5Connection(ConnectionInterface):
self,
proxy_origin: Origin,
remote_origin: Origin,
proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
ssl_context: typing.Optional[ssl.SSLContext] = None,
keepalive_expiry: typing.Optional[float] = None,
proxy_auth: tuple[bytes, bytes] | None = None,
ssl_context: ssl.SSLContext | None = None,
keepalive_expiry: float | None = None,
http1: bool = True,
http2: bool = False,
network_backend: typing.Optional[NetworkBackend] = None,
network_backend: NetworkBackend | None = None,
) -> None:
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
@@ -211,11 +210,12 @@ class Socks5Connection(ConnectionInterface):
SyncBackend() if network_backend is None else network_backend
)
self._connect_lock = Lock()
self._connection: typing.Optional[ConnectionInterface] = None
self._connection: ConnectionInterface | None = None
self._connect_failed = False
def handle_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)
with self._connect_lock:
@@ -258,7 +258,8 @@ class Socks5Connection(ConnectionInterface):
kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"server_hostname": sni_hostname
or self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:

View File

@@ -1,8 +1,7 @@
import threading
from types import TracebackType
from typing import Optional, Type
from __future__ import annotations
import sniffio
import threading
import types
from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions
@@ -11,7 +10,7 @@ from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions
try:
import trio
except ImportError: # pragma: nocover
except (ImportError, NotImplementedError): # pragma: nocover
trio = None # type: ignore
try:
@@ -20,7 +19,40 @@ except ImportError: # pragma: nocover
anyio = None # type: ignore
def current_async_library() -> str:
# Determine if we're running under trio or asyncio.
# See https://sniffio.readthedocs.io/en/latest/
try:
import sniffio
except ImportError: # pragma: nocover
environment = "asyncio"
else:
environment = sniffio.current_async_library()
if environment not in ("asyncio", "trio"): # pragma: nocover
raise RuntimeError("Running under an unsupported async environment.")
if environment == "asyncio" and anyio is None: # pragma: nocover
raise RuntimeError(
"Running with asyncio requires installation of 'httpcore[asyncio]'."
)
if environment == "trio" and trio is None: # pragma: nocover
raise RuntimeError(
"Running with trio requires installation of 'httpcore[trio]'."
)
return environment
class AsyncLock:
"""
This is a standard lock.
In the sync case `Lock` provides thread locking.
In the async case `AsyncLock` provides async locking.
"""
def __init__(self) -> None:
self._backend = ""
@@ -29,43 +61,55 @@ class AsyncLock:
Detect if we're running under 'asyncio' or 'trio' and create
a lock with the correct implementation.
"""
self._backend = sniffio.current_async_library()
self._backend = current_async_library()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio, requires the 'trio' package to be installed."
)
self._trio_lock = trio.Lock()
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
elif self._backend == "asyncio":
self._anyio_lock = anyio.Lock()
async def __aenter__(self) -> "AsyncLock":
async def __aenter__(self) -> AsyncLock:
if not self._backend:
self.setup()
if self._backend == "trio":
await self._trio_lock.acquire()
else:
elif self._backend == "asyncio":
await self._anyio_lock.acquire()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
if self._backend == "trio":
self._trio_lock.release()
else:
elif self._backend == "asyncio":
self._anyio_lock.release()
class AsyncThreadLock:
"""
This is a threading-only lock for no-I/O contexts.
In the sync case `ThreadLock` provides thread locking.
In the async case `AsyncThreadLock` is a no-op.
"""
def __enter__(self) -> AsyncThreadLock:
return self
def __exit__(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
pass
class AsyncEvent:
def __init__(self) -> None:
self._backend = ""
@@ -75,18 +119,10 @@ class AsyncEvent:
Detect if we're running under 'asyncio' or 'trio' and create
a lock with the correct implementation.
"""
self._backend = sniffio.current_async_library()
self._backend = current_async_library()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)
self._trio_event = trio.Event()
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
elif self._backend == "asyncio":
self._anyio_event = anyio.Event()
def set(self) -> None:
@@ -95,30 +131,20 @@ class AsyncEvent:
if self._backend == "trio":
self._trio_event.set()
else:
elif self._backend == "asyncio":
self._anyio_event.set()
async def wait(self, timeout: Optional[float] = None) -> None:
async def wait(self, timeout: float | None = None) -> None:
if not self._backend:
self.setup()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)
trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
timeout_or_inf = float("inf") if timeout is None else timeout
with map_exceptions(trio_exc_map):
with trio.fail_after(timeout_or_inf):
await self._trio_event.wait()
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
elif self._backend == "asyncio":
anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
with map_exceptions(anyio_exc_map):
with anyio.fail_after(timeout):
@@ -135,22 +161,12 @@ class AsyncSemaphore:
Detect if we're running under 'asyncio' or 'trio' and create
a semaphore with the correct implementation.
"""
self._backend = sniffio.current_async_library()
self._backend = current_async_library()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)
self._trio_semaphore = trio.Semaphore(
initial_value=self._bound, max_value=self._bound
)
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
elif self._backend == "asyncio":
self._anyio_semaphore = anyio.Semaphore(
initial_value=self._bound, max_value=self._bound
)
@@ -161,13 +177,13 @@ class AsyncSemaphore:
if self._backend == "trio":
await self._trio_semaphore.acquire()
else:
elif self._backend == "asyncio":
await self._anyio_semaphore.acquire()
async def release(self) -> None:
if self._backend == "trio":
self._trio_semaphore.release()
else:
elif self._backend == "asyncio":
self._anyio_semaphore.release()
@@ -184,39 +200,29 @@ class AsyncShieldCancellation:
Detect if we're running under 'asyncio' or 'trio' and create
a shielded scope with the correct implementation.
"""
self._backend = sniffio.current_async_library()
self._backend = current_async_library()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)
self._trio_shield = trio.CancelScope(shield=True)
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
elif self._backend == "asyncio":
self._anyio_shield = anyio.CancelScope(shield=True)
def __enter__(self) -> "AsyncShieldCancellation":
def __enter__(self) -> AsyncShieldCancellation:
if self._backend == "trio":
self._trio_shield.__enter__()
else:
elif self._backend == "asyncio":
self._anyio_shield.__enter__()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
if self._backend == "trio":
self._trio_shield.__exit__(exc_type, exc_value, traceback)
else:
elif self._backend == "asyncio":
self._anyio_shield.__exit__(exc_type, exc_value, traceback)
@@ -224,18 +230,49 @@ class AsyncShieldCancellation:
class Lock:
"""
This is a standard lock.
In the sync case `Lock` provides thread locking.
In the async case `AsyncLock` provides async locking.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
def __enter__(self) -> "Lock":
def __enter__(self) -> Lock:
self._lock.acquire()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
self._lock.release()
class ThreadLock:
"""
This is a threading-only lock for no-I/O contexts.
In the sync case `ThreadLock` provides thread locking.
In the async case `AsyncThreadLock` is a no-op.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
def __enter__(self) -> ThreadLock:
self._lock.acquire()
return self
def __exit__(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
self._lock.release()
@@ -247,7 +284,9 @@ class Event:
def set(self) -> None:
self._event.set()
def wait(self, timeout: Optional[float] = None) -> None:
def wait(self, timeout: float | None = None) -> None:
if timeout == float("inf"): # pragma: no cover
timeout = None
if not self._event.wait(timeout=timeout):
raise PoolTimeout() # pragma: nocover
@@ -267,13 +306,13 @@ class ShieldCancellation:
# Thread-synchronous codebases don't support cancellation semantics.
# We have this class because we need to mirror the async and sync
# cases within our package, but it's just a no-op.
def __enter__(self) -> "ShieldCancellation":
def __enter__(self) -> ShieldCancellation:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
pass

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import inspect
import logging
from types import TracebackType
from typing import Any, Dict, Optional, Type
import types
import typing
from ._models import Request
@@ -11,8 +13,8 @@ class Trace:
self,
name: str,
logger: logging.Logger,
request: Optional[Request] = None,
kwargs: Optional[Dict[str, Any]] = None,
request: Request | None = None,
kwargs: dict[str, typing.Any] | None = None,
) -> None:
self.name = name
self.logger = logger
@@ -21,11 +23,11 @@ class Trace:
)
self.debug = self.logger.isEnabledFor(logging.DEBUG)
self.kwargs = kwargs or {}
self.return_value: Any = None
self.return_value: typing.Any = None
self.should_trace = self.debug or self.trace_extension is not None
self.prefix = self.logger.name.split(".")[-1]
def trace(self, name: str, info: Dict[str, Any]) -> None:
def trace(self, name: str, info: dict[str, typing.Any]) -> None:
if self.trace_extension is not None:
prefix_and_name = f"{self.prefix}.{name}"
ret = self.trace_extension(prefix_and_name, info)
@@ -44,7 +46,7 @@ class Trace:
message = f"{name} {args}"
self.logger.debug(message)
def __enter__(self) -> "Trace":
def __enter__(self) -> Trace:
if self.should_trace:
info = self.kwargs
self.trace(f"{self.name}.started", info)
@@ -52,9 +54,9 @@ class Trace:
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
if self.should_trace:
if exc_value is None:
@@ -64,7 +66,7 @@ class Trace:
info = {"exception": exc_value}
self.trace(f"{self.name}.failed", info)
async def atrace(self, name: str, info: Dict[str, Any]) -> None:
async def atrace(self, name: str, info: dict[str, typing.Any]) -> None:
if self.trace_extension is not None:
prefix_and_name = f"{self.prefix}.{name}"
coro = self.trace_extension(prefix_and_name, info)
@@ -84,7 +86,7 @@ class Trace:
message = f"{name} {args}"
self.logger.debug(message)
async def __aenter__(self) -> "Trace":
async def __aenter__(self) -> Trace:
if self.should_trace:
info = self.kwargs
await self.atrace(f"{self.name}.started", info)
@@ -92,9 +94,9 @@ class Trace:
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
if self.should_trace:
if exc_value is None:

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
import select
import socket
import sys
import typing
def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool:
def is_socket_readable(sock: socket.socket | None) -> bool:
"""
Return whether a socket, as identifed by its file descriptor, is readable.
"A socket is readable" means that the read buffer isn't empty, i.e. that calling