updates

httpcore/__init__.py

@@ -34,7 +34,7 @@ from ._exceptions import (
     WriteError,
     WriteTimeout,
 )
-from ._models import URL, Origin, Request, Response
+from ._models import URL, Origin, Proxy, Request, Response
 from ._ssl import default_ssl_context
 from ._sync import (
     ConnectionInterface,
@@ -79,6 +79,7 @@ __all__ = [
     "URL",
     "Request",
     "Response",
+    "Proxy",
     # async
     "AsyncHTTPConnection",
     "AsyncConnectionPool",
@@ -130,10 +131,11 @@ __all__ = [
     "WriteError",
 ]

-__version__ = "0.17.3"
+__version__ = "1.0.9"


 __locals = locals()
 for __name in __all__:
-    if not __name.startswith("__"):
+    # Exclude SOCKET_OPTION, it causes AttributeError on Python 3.14
+    if not __name.startswith(("__", "SOCKET_OPTION")):
         setattr(__locals[__name], "__module__", "httpcore")  # noqa
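A note on the re-export loop above: `SOCKET_OPTION` is a `typing.Union` alias, and per the commit comment, assigning `__module__` on it raises `AttributeError` on Python 3.14, which is why the loop now skips that name. A minimal sketch of the failure mode; the alias shape here is paraphrased from `httpcore._backends.base`, not copied from this diff:

```python
import typing

# Paraphrased, simplified shape of httpcore._backends.base.SOCKET_OPTION.
SOCKET_OPTION = typing.Union[
    typing.Tuple[int, int, int],
    typing.Tuple[int, int, bytes],
]

# On Python 3.14 this raises AttributeError, which is the failure the
# `("__", "SOCKET_OPTION")` exclusion above guards against.
try:
    setattr(SOCKET_OPTION, "__module__", "httpcore")
except AttributeError:
    pass
```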
(8 binary files not shown)
httpcore/_api.py

@@ -1,17 +1,19 @@
-from contextlib import contextmanager
-from typing import Iterator, Optional, Union
+from __future__ import annotations
+
+import contextlib
+import typing

 from ._models import URL, Extensions, HeaderTypes, Response
 from ._sync.connection_pool import ConnectionPool


 def request(
-    method: Union[bytes, str],
-    url: Union[URL, bytes, str],
+    method: bytes | str,
+    url: URL | bytes | str,
     *,
     headers: HeaderTypes = None,
-    content: Union[bytes, Iterator[bytes], None] = None,
-    extensions: Optional[Extensions] = None,
+    content: bytes | typing.Iterator[bytes] | None = None,
+    extensions: Extensions | None = None,
 ) -> Response:
     """
     Sends an HTTP request, returning the response.
@@ -45,15 +47,15 @@ def request(
     )


-@contextmanager
+@contextlib.contextmanager
 def stream(
-    method: Union[bytes, str],
-    url: Union[URL, bytes, str],
+    method: bytes | str,
+    url: URL | bytes | str,
     *,
     headers: HeaderTypes = None,
-    content: Union[bytes, Iterator[bytes], None] = None,
-    extensions: Optional[Extensions] = None,
-) -> Iterator[Response]:
+    content: bytes | typing.Iterator[bytes] | None = None,
+    extensions: Extensions | None = None,
+) -> typing.Iterator[Response]:
     """
     Sends an HTTP request, returning the response within a context manager.
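For review context: the top-level API keeps the same call shape, only the annotations move to PEP 604 unions. A quick usage sketch (the URL is a placeholder):

```python
import httpcore

# One-shot request; the response body is fully read before returning.
response = httpcore.request("GET", "https://www.example.com/")
print(response.status, len(response.content))

# Streaming variant; the connection stays leased until the block exits.
with httpcore.stream("GET", "https://www.example.com/") as response:
    for chunk in response.iter_stream():
        print(f"received {len(chunk)} bytes")
```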
(8 binary files not shown)
httpcore/_async/connection.py

@@ -1,12 +1,14 @@
+from __future__ import annotations
+
 import itertools
 import logging
 import ssl
-from types import TracebackType
-from typing import Iterable, Iterator, Optional, Type
+import types
+import typing

 from .._backends.auto import AutoBackend
 from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
-from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout
+from .._exceptions import ConnectError, ConnectTimeout
 from .._models import Origin, Request, Response
 from .._ssl import default_ssl_context
 from .._synchronization import AsyncLock
@@ -20,25 +22,32 @@ RETRIES_BACKOFF_FACTOR = 0.5  # 0s, 0.5s, 1s, 2s, 4s, etc.
 logger = logging.getLogger("httpcore.connection")


-def exponential_backoff(factor: float) -> Iterator[float]:
+def exponential_backoff(factor: float) -> typing.Iterator[float]:
+    """
+    Generate a geometric sequence that has a ratio of 2 and starts with 0.
+
+    For example:
+    - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
+    - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
+    """
     yield 0
-    for n in itertools.count(2):
-        yield factor * (2 ** (n - 2))
+    for n in itertools.count():
+        yield factor * 2**n
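The backoff rewrite is a pure simplification: both generators emit the same series, and with `RETRIES_BACKOFF_FACTOR = 0.5` that is exactly the `0s, 0.5s, 1s, 2s, 4s` schedule in the comment above. A standalone check, not part of the diff:

```python
import itertools

def old_backoff(factor: float):
    yield 0
    for n in itertools.count(2):
        yield factor * (2 ** (n - 2))

def new_backoff(factor: float):
    yield 0
    for n in itertools.count():
        yield factor * 2**n

# Both yield 0, f, 2f, 4f, ... so the offset arithmetic was dead weight.
assert list(itertools.islice(old_backoff(0.5), 5)) == [0, 0.5, 1.0, 2.0, 4.0]
assert list(itertools.islice(new_backoff(0.5), 5)) == [0, 0.5, 1.0, 2.0, 4.0]
```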
 class AsyncHTTPConnection(AsyncConnectionInterface):
     def __init__(
         self,
         origin: Origin,
-        ssl_context: Optional[ssl.SSLContext] = None,
-        keepalive_expiry: Optional[float] = None,
+        ssl_context: ssl.SSLContext | None = None,
+        keepalive_expiry: float | None = None,
         http1: bool = True,
         http2: bool = False,
         retries: int = 0,
-        local_address: Optional[str] = None,
-        uds: Optional[str] = None,
-        network_backend: Optional[AsyncNetworkBackend] = None,
-        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+        local_address: str | None = None,
+        uds: str | None = None,
+        network_backend: AsyncNetworkBackend | None = None,
+        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
     ) -> None:
         self._origin = origin
         self._ssl_context = ssl_context
@@ -52,7 +61,7 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
         self._network_backend: AsyncNetworkBackend = (
             AutoBackend() if network_backend is None else network_backend
         )
-        self._connection: Optional[AsyncConnectionInterface] = None
+        self._connection: AsyncConnectionInterface | None = None
         self._connect_failed: bool = False
         self._request_lock = AsyncLock()
         self._socket_options = socket_options
@@ -63,9 +72,9 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
             f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
         )

-        async with self._request_lock:
-            if self._connection is None:
-                try:
+        try:
+            async with self._request_lock:
+                if self._connection is None:
                     stream = await self._connect(request)

                     ssl_object = stream.get_extra_info("ssl_object")
@@ -87,11 +96,9 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
                         stream=stream,
                         keepalive_expiry=self._keepalive_expiry,
                     )
-                except Exception as exc:
-                    self._connect_failed = True
-                    raise exc
-            elif not self._connection.is_available():
-                raise ConnectionNotAvailable()
+        except BaseException as exc:
+            self._connect_failed = True
+            raise exc

         return await self._connection.handle_async_request(request)

@@ -130,7 +137,7 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
         )
         trace.return_value = stream

-        if self._origin.scheme == b"https":
+        if self._origin.scheme in (b"https", b"wss"):
             ssl_context = (
                 default_ssl_context()
                 if self._ssl_context is None
@@ -203,13 +210,13 @@ class AsyncHTTPConnection(AsyncConnectionInterface):
     # These context managers are not used in the standard flow, but are
     # useful for testing or working with connection instances directly.

-    async def __aenter__(self) -> "AsyncHTTPConnection":
+    async def __aenter__(self) -> AsyncHTTPConnection:
         return self

     async def __aexit__(
         self,
-        exc_type: Optional[Type[BaseException]] = None,
-        exc_value: Optional[BaseException] = None,
-        traceback: Optional[TracebackType] = None,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: types.TracebackType | None = None,
     ) -> None:
         await self.aclose()
httpcore/_async/connection_pool.py

@@ -1,41 +1,44 @@
+from __future__ import annotations
+
 import ssl
 import sys
-from types import TracebackType
-from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type
+import types
+import typing

 from .._backends.auto import AutoBackend
 from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
 from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
-from .._models import Origin, Request, Response
-from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation
+from .._models import Origin, Proxy, Request, Response
+from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
 from .connection import AsyncHTTPConnection
 from .interfaces import AsyncConnectionInterface, AsyncRequestInterface


-class RequestStatus:
-    def __init__(self, request: Request):
+class AsyncPoolRequest:
+    def __init__(self, request: Request) -> None:
         self.request = request
-        self.connection: Optional[AsyncConnectionInterface] = None
+        self.connection: AsyncConnectionInterface | None = None
         self._connection_acquired = AsyncEvent()

-    def set_connection(self, connection: AsyncConnectionInterface) -> None:
-        assert self.connection is None
+    def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None:
         self.connection = connection
         self._connection_acquired.set()

-    def unset_connection(self) -> None:
-        assert self.connection is not None
+    def clear_connection(self) -> None:
         self.connection = None
         self._connection_acquired = AsyncEvent()

     async def wait_for_connection(
-        self, timeout: Optional[float] = None
+        self, timeout: float | None = None
     ) -> AsyncConnectionInterface:
         if self.connection is None:
             await self._connection_acquired.wait(timeout=timeout)
         assert self.connection is not None
         return self.connection

+    def is_queued(self) -> bool:
+        return self.connection is None
+

 class AsyncConnectionPool(AsyncRequestInterface):
     """
@@ -44,17 +47,18 @@ class AsyncConnectionPool(AsyncRequestInterface):

     def __init__(
         self,
-        ssl_context: Optional[ssl.SSLContext] = None,
-        max_connections: Optional[int] = 10,
-        max_keepalive_connections: Optional[int] = None,
-        keepalive_expiry: Optional[float] = None,
+        ssl_context: ssl.SSLContext | None = None,
+        proxy: Proxy | None = None,
+        max_connections: int | None = 10,
+        max_keepalive_connections: int | None = None,
+        keepalive_expiry: float | None = None,
         http1: bool = True,
         http2: bool = False,
         retries: int = 0,
-        local_address: Optional[str] = None,
-        uds: Optional[str] = None,
-        network_backend: Optional[AsyncNetworkBackend] = None,
-        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+        local_address: str | None = None,
+        uds: str | None = None,
+        network_backend: AsyncNetworkBackend | None = None,
+        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
     ) -> None:
         """
         A connection pool for making HTTP requests.
@@ -86,7 +90,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
             in the TCP socket when the connection was established.
         """
         self._ssl_context = ssl_context
+        self._proxy = proxy
         self._max_connections = (
             sys.maxsize if max_connections is None else max_connections
         )
@@ -106,15 +110,61 @@ class AsyncConnectionPool(AsyncRequestInterface):
         self._local_address = local_address
         self._uds = uds

-        self._pool: List[AsyncConnectionInterface] = []
-        self._requests: List[RequestStatus] = []
-        self._pool_lock = AsyncLock()
         self._network_backend = (
             AutoBackend() if network_backend is None else network_backend
         )
         self._socket_options = socket_options

+        # The mutable state on a connection pool is the queue of incoming requests,
+        # and the set of connections that are servicing those requests.
+        self._connections: list[AsyncConnectionInterface] = []
+        self._requests: list[AsyncPoolRequest] = []
+
+        # We only mutate the state of the connection pool within an 'optional_thread_lock'
+        # context. This holds a threading lock unless we're running in async mode,
+        # in which case it is a no-op.
+        self._optional_thread_lock = AsyncThreadLock()
+
+    def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
+        if self._proxy is not None:
+            if self._proxy.url.scheme in (b"socks5", b"socks5h"):
+                from .socks_proxy import AsyncSocks5Connection
+
+                return AsyncSocks5Connection(
+                    proxy_origin=self._proxy.url.origin,
+                    proxy_auth=self._proxy.auth,
+                    remote_origin=origin,
+                    ssl_context=self._ssl_context,
+                    keepalive_expiry=self._keepalive_expiry,
+                    http1=self._http1,
+                    http2=self._http2,
+                    network_backend=self._network_backend,
+                )
+            elif origin.scheme == b"http":
+                from .http_proxy import AsyncForwardHTTPConnection
+
+                return AsyncForwardHTTPConnection(
+                    proxy_origin=self._proxy.url.origin,
+                    proxy_headers=self._proxy.headers,
+                    proxy_ssl_context=self._proxy.ssl_context,
+                    remote_origin=origin,
+                    keepalive_expiry=self._keepalive_expiry,
+                    network_backend=self._network_backend,
+                )
+            from .http_proxy import AsyncTunnelHTTPConnection
+
+            return AsyncTunnelHTTPConnection(
+                proxy_origin=self._proxy.url.origin,
+                proxy_headers=self._proxy.headers,
+                proxy_ssl_context=self._proxy.ssl_context,
+                remote_origin=origin,
+                ssl_context=self._ssl_context,
+                keepalive_expiry=self._keepalive_expiry,
+                http1=self._http1,
+                http2=self._http2,
+                network_backend=self._network_backend,
+            )
+
         return AsyncHTTPConnection(
             origin=origin,
             ssl_context=self._ssl_context,
@@ -129,7 +179,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
         )

     @property
-    def connections(self) -> List[AsyncConnectionInterface]:
+    def connections(self) -> list[AsyncConnectionInterface]:
         """
         Return a list of the connections currently in the pool.

@@ -144,64 +194,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
         ]
         ```
         """
-        return list(self._pool)
-
-    async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
-        """
-        Attempt to provide a connection that can handle the given origin.
-        """
-        origin = status.request.url.origin
-
-        # If there are queued requests in front of us, then don't acquire a
-        # connection. We handle requests strictly in order.
-        waiting = [s for s in self._requests if s.connection is None]
-        if waiting and waiting[0] is not status:
-            return False
-
-        # Reuse an existing connection if one is currently available.
-        for idx, connection in enumerate(self._pool):
-            if connection.can_handle_request(origin) and connection.is_available():
-                self._pool.pop(idx)
-                self._pool.insert(0, connection)
-                status.set_connection(connection)
-                return True
-
-        # If the pool is currently full, attempt to close one idle connection.
-        if len(self._pool) >= self._max_connections:
-            for idx, connection in reversed(list(enumerate(self._pool))):
-                if connection.is_idle():
-                    await connection.aclose()
-                    self._pool.pop(idx)
-                    break
-
-        # If the pool is still full, then we cannot acquire a connection.
-        if len(self._pool) >= self._max_connections:
-            return False
-
-        # Otherwise create a new connection.
-        connection = self.create_connection(origin)
-        self._pool.insert(0, connection)
-        status.set_connection(connection)
-        return True
-
-    async def _close_expired_connections(self) -> None:
-        """
-        Clean up the connection pool by closing off any connections that have expired.
-        """
-        # Close any connections that have expired their keep-alive time.
-        for idx, connection in reversed(list(enumerate(self._pool))):
-            if connection.has_expired():
-                await connection.aclose()
-                self._pool.pop(idx)
-
-        # If the pool size exceeds the maximum number of allowed keep-alive connections,
-        # then close off idle connections as required.
-        pool_size = len(self._pool)
-        for idx, connection in reversed(list(enumerate(self._pool))):
-            if connection.is_idle() and pool_size > self._max_keepalive_connections:
-                await connection.aclose()
-                self._pool.pop(idx)
-                pool_size -= 1
+        return list(self._connections)
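Review note: `create_connection` above is where the new `Proxy` model plugs in. SOCKS5 URLs get an `AsyncSocks5Connection`, plain `http://` origins are forwarded, and everything else is tunnelled via CONNECT. A minimal usage sketch of the new argument (sync variant shown; addresses are placeholders):

```python
import httpcore

proxy = httpcore.Proxy("http://127.0.0.1:8080/")  # or "socks5://..."
with httpcore.ConnectionPool(proxy=proxy) as pool:
    response = pool.request("GET", "https://www.example.com/")
    print(response.status)
```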
     async def handle_async_request(self, request: Request) -> Response:
         """
@@ -219,138 +212,209 @@ class AsyncConnectionPool(AsyncRequestInterface):
                 f"Request URL has an unsupported protocol '{scheme}://'."
             )

-        status = RequestStatus(request)
+        timeouts = request.extensions.get("timeout", {})
+        timeout = timeouts.get("pool", None)

-        async with self._pool_lock:
-            self._requests.append(status)
-            await self._close_expired_connections()
-            await self._attempt_to_acquire_connection(status)
+        with self._optional_thread_lock:
+            # Add the incoming request to our request queue.
+            pool_request = AsyncPoolRequest(request)
+            self._requests.append(pool_request)

-        while True:
-            timeouts = request.extensions.get("timeout", {})
-            timeout = timeouts.get("pool", None)
-            try:
-                connection = await status.wait_for_connection(timeout=timeout)
-            except BaseException as exc:
-                # If we timeout here, or if the task is cancelled, then make
-                # sure to remove the request from the queue before bubbling
-                # up the exception.
-                async with self._pool_lock:
-                    # Ensure only remove when task exists.
-                    if status in self._requests:
-                        self._requests.remove(status)
-                raise exc
+        try:
+            while True:
+                with self._optional_thread_lock:
+                    # Assign incoming requests to available connections,
+                    # closing or creating new connections as required.
+                    closing = self._assign_requests_to_connections()
+                await self._close_connections(closing)

-            try:
-                response = await connection.handle_async_request(request)
-            except ConnectionNotAvailable:
-                # The ConnectionNotAvailable exception is a special case, that
-                # indicates we need to retry the request on a new connection.
-                #
-                # The most common case where this can occur is when multiple
-                # requests are queued waiting for a single connection, which
-                # might end up as an HTTP/2 connection, but which actually ends
-                # up as HTTP/1.1.
-                async with self._pool_lock:
-                    # Maintain our position in the request queue, but reset the
-                    # status so that the request becomes queued again.
-                    status.unset_connection()
-                    await self._attempt_to_acquire_connection(status)
-            except BaseException as exc:
-                with AsyncShieldCancellation():
-                    await self.response_closed(status)
-                raise exc
-            else:
-                break
+                # Wait until this request has an assigned connection.
+                connection = await pool_request.wait_for_connection(timeout=timeout)

-        # When we return the response, we wrap the stream in a special class
-        # that handles notifying the connection pool once the response
-        # has been released.
-        assert isinstance(response.stream, AsyncIterable)
+                try:
+                    # Send the request on the assigned connection.
+                    response = await connection.handle_async_request(
+                        pool_request.request
+                    )
+                except ConnectionNotAvailable:
+                    # In some cases a connection may initially be available to
+                    # handle a request, but then become unavailable.
+                    #
+                    # In this case we clear the connection and try again.
+                    pool_request.clear_connection()
+                else:
+                    break  # pragma: nocover
+
+        except BaseException as exc:
+            with self._optional_thread_lock:
+                # For any exception or cancellation we remove the request from
+                # the queue, and then re-assign requests to connections.
+                self._requests.remove(pool_request)
+                closing = self._assign_requests_to_connections()
+
+            await self._close_connections(closing)
+            raise exc from None
+
+        # Return the response. Note that in this case we still have to manage
+        # the point at which the response is closed.
+        assert isinstance(response.stream, typing.AsyncIterable)
         return Response(
             status=response.status,
             headers=response.headers,
-            content=ConnectionPoolByteStream(response.stream, self, status),
+            content=PoolByteStream(
+                stream=response.stream, pool_request=pool_request, pool=self
+            ),
             extensions=response.extensions,
         )
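One behavioural constant worth flagging: the returned response wraps its byte stream (now `PoolByteStream`), and closing the response is what returns the connection to the pool. Streaming callers therefore must close, which the context-manager API handles for them. A sketch (URL is a placeholder):

```python
import httpcore

with httpcore.ConnectionPool(max_connections=10) as pool:
    # stream() leases a pooled connection for the lifetime of the block;
    # exiting the block closes the PoolByteStream, which releases the lease
    # and lets _assign_requests_to_connections() serve any queued request.
    with pool.stream("GET", "https://www.example.com/") as response:
        for chunk in response.iter_stream():
            ...
```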
-    async def response_closed(self, status: RequestStatus) -> None:
+    def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]:
         """
-        This method acts as a callback once the request/response cycle is complete.
+        Manage the state of the connection pool, assigning incoming
+        requests to connections as available.

-        It is called into from the `ConnectionPoolByteStream.aclose()` method.
+        Called whenever a new request is added or removed from the pool.
+
+        Any closing connections are returned, allowing the I/O for closing
+        those connections to be handled separately.
         """
-        assert status.connection is not None
-        connection = status.connection
+        closing_connections = []

-        async with self._pool_lock:
-            # Update the state of the connection pool.
-            if status in self._requests:
-                self._requests.remove(status)
+        # First we handle cleaning up any connections that are closed,
+        # have expired their keep-alive, or surplus idle connections.
+        for connection in list(self._connections):
+            if connection.is_closed():
+                # log: "removing closed connection"
+                self._connections.remove(connection)
+            elif connection.has_expired():
+                # log: "closing expired connection"
+                self._connections.remove(connection)
+                closing_connections.append(connection)
+            elif (
+                connection.is_idle()
+                and len([connection.is_idle() for connection in self._connections])
+                > self._max_keepalive_connections
+            ):
+                # log: "closing idle connection"
+                self._connections.remove(connection)
+                closing_connections.append(connection)

-            if connection.is_closed() and connection in self._pool:
-                self._pool.remove(connection)
+        # Assign queued requests to connections.
+        queued_requests = [request for request in self._requests if request.is_queued()]
+        for pool_request in queued_requests:
+            origin = pool_request.request.url.origin
+            available_connections = [
+                connection
+                for connection in self._connections
+                if connection.can_handle_request(origin) and connection.is_available()
+            ]
+            idle_connections = [
+                connection for connection in self._connections if connection.is_idle()
+            ]

-            # Since we've had a response closed, it's possible we'll now be able
-            # to service one or more requests that are currently pending.
-            for status in self._requests:
-                if status.connection is None:
-                    acquired = await self._attempt_to_acquire_connection(status)
-                    # If we could not acquire a connection for a queued request
-                    # then we don't need to check anymore requests that are
-                    # queued later behind it.
-                    if not acquired:
-                        break
+            # There are three cases for how we may be able to handle the request:
+            #
+            # 1. There is an existing connection that can handle the request.
+            # 2. We can create a new connection to handle the request.
+            # 3. We can close an idle connection and then create a new connection
+            #    to handle the request.
+            if available_connections:
+                # log: "reusing existing connection"
+                connection = available_connections[0]
+                pool_request.assign_to_connection(connection)
+            elif len(self._connections) < self._max_connections:
+                # log: "creating new connection"
+                connection = self.create_connection(origin)
+                self._connections.append(connection)
+                pool_request.assign_to_connection(connection)
+            elif idle_connections:
+                # log: "closing idle connection"
+                connection = idle_connections[0]
+                self._connections.remove(connection)
+                closing_connections.append(connection)
+                # log: "creating new connection"
+                connection = self.create_connection(origin)
+                self._connections.append(connection)
+                pool_request.assign_to_connection(connection)

-            # Housekeeping.
-            await self._close_expired_connections()
+        return closing_connections
+
+    async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None:
+        # Close connections which have been removed from the pool.
+        with AsyncShieldCancellation():
+            for connection in closing:
+                await connection.aclose()

     async def aclose(self) -> None:
-        """
-        Close any connections in the pool.
-        """
-        async with self._pool_lock:
-            for connection in self._pool:
-                await connection.aclose()
-            self._pool = []
-            self._requests = []
+        # Explicitly close the connection pool.
+        # Clears all existing requests and connections.
+        with self._optional_thread_lock:
+            closing_connections = list(self._connections)
+            self._connections = []
+        await self._close_connections(closing_connections)

-    async def __aenter__(self) -> "AsyncConnectionPool":
+    async def __aenter__(self) -> AsyncConnectionPool:
         return self

     async def __aexit__(
         self,
-        exc_type: Optional[Type[BaseException]] = None,
-        exc_value: Optional[BaseException] = None,
-        traceback: Optional[TracebackType] = None,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: types.TracebackType | None = None,
     ) -> None:
         await self.aclose()

-
-class ConnectionPoolByteStream:
-    """
-    A wrapper around the response byte stream, that additionally handles
-    notifying the connection pool when the response has been closed.
-    """
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        with self._optional_thread_lock:
+            request_is_queued = [request.is_queued() for request in self._requests]
+            connection_is_idle = [
+                connection.is_idle() for connection in self._connections
+            ]
+
+        num_active_requests = request_is_queued.count(False)
+        num_queued_requests = request_is_queued.count(True)
+        num_active_connections = connection_is_idle.count(False)
+        num_idle_connections = connection_is_idle.count(True)
+
+        requests_info = (
+            f"Requests: {num_active_requests} active, {num_queued_requests} queued"
+        )
+        connection_info = (
+            f"Connections: {num_active_connections} active, {num_idle_connections} idle"
+        )
+
+        return f"<{class_name} [{requests_info} | {connection_info}]>"
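The new `__repr__` exposes the same queue/connection bookkeeping as a one-line diagnostic. For the sync mirror of this class the output looks like (illustrative):

```python
import httpcore

pool = httpcore.ConnectionPool()
print(pool)
# <ConnectionPool [Requests: 0 active, 0 queued | Connections: 0 active, 0 idle]>
```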
+class PoolByteStream:
     def __init__(
         self,
-        stream: AsyncIterable[bytes],
+        stream: typing.AsyncIterable[bytes],
+        pool_request: AsyncPoolRequest,
         pool: AsyncConnectionPool,
-        status: RequestStatus,
     ) -> None:
         self._stream = stream
+        self._pool_request = pool_request
         self._pool = pool
-        self._status = status
         self._closed = False

-    async def __aiter__(self) -> AsyncIterator[bytes]:
-        async for part in self._stream:
-            yield part
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        try:
+            async for part in self._stream:
+                yield part
+        except BaseException as exc:
+            await self.aclose()
+            raise exc from None

     async def aclose(self) -> None:
-        try:
-            if hasattr(self._stream, "aclose"):
-                await self._stream.aclose()
-        finally:
-            if not self._closed:
-                self._closed = True
-                with AsyncShieldCancellation():
-                    await self._pool.response_closed(self._status)
+        if not self._closed:
+            self._closed = True
+            with AsyncShieldCancellation():
+                if hasattr(self._stream, "aclose"):
+                    await self._stream.aclose()
+
+            with self._pool._optional_thread_lock:
+                self._pool._requests.remove(self._pool_request)
+                closing = self._pool._assign_requests_to_connections()
+
+            await self._pool._close_connections(closing)
httpcore/_async/http11.py

@@ -1,17 +1,11 @@
+from __future__ import annotations
+
 import enum
 import logging
 import ssl
 import time
-from types import TracebackType
-from typing import (
-    AsyncIterable,
-    AsyncIterator,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-    cast,
-)
+import types
+import typing

 import h11

@@ -20,6 +14,7 @@ from .._exceptions import (
     ConnectionNotAvailable,
     LocalProtocolError,
     RemoteProtocolError,
+    WriteError,
     map_exceptions,
 )
 from .._models import Origin, Request, Response
@@ -31,7 +26,7 @@ logger = logging.getLogger("httpcore.http11")


 # A subset of `h11.Event` types supported by `_send_event`
-H11SendEvent = Union[
+H11SendEvent = typing.Union[
     h11.Request,
     h11.Data,
     h11.EndOfMessage,
@@ -53,12 +48,12 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
         self,
         origin: Origin,
         stream: AsyncNetworkStream,
-        keepalive_expiry: Optional[float] = None,
+        keepalive_expiry: float | None = None,
     ) -> None:
         self._origin = origin
         self._network_stream = stream
-        self._keepalive_expiry: Optional[float] = keepalive_expiry
-        self._expire_at: Optional[float] = None
+        self._keepalive_expiry: float | None = keepalive_expiry
+        self._expire_at: float | None = None
         self._state = HTTPConnectionState.NEW
         self._state_lock = AsyncLock()
         self._request_count = 0
@@ -84,10 +79,21 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):

         try:
             kwargs = {"request": request}
-            async with Trace("send_request_headers", logger, request, kwargs) as trace:
-                await self._send_request_headers(**kwargs)
-            async with Trace("send_request_body", logger, request, kwargs) as trace:
-                await self._send_request_body(**kwargs)
+            try:
+                async with Trace(
+                    "send_request_headers", logger, request, kwargs
+                ) as trace:
+                    await self._send_request_headers(**kwargs)
+                async with Trace("send_request_body", logger, request, kwargs) as trace:
+                    await self._send_request_body(**kwargs)
+            except WriteError:
+                # If we get a write error while we're writing the request,
+                # then we suppress this error and move on to attempting to
+                # read the response. Servers can sometimes close the request
+                # pre-emptively and then respond with a well formed HTTP
+                # error response.
+                pass

             async with Trace(
                 "receive_response_headers", logger, request, kwargs
             ) as trace:
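Reviewer note on the `except WriteError` added above: a server may reject a request (say, with a 413) and close its read side while the client is still sending the body. The write then fails even though a well-formed response is already waiting. The shape of the fix in isolation; the helper names are illustrative, not httpcore API:

```python
class WriteError(Exception):
    """Stand-in for httpcore's WriteError, for this sketch only."""

def send_with_tolerant_write(send_request, read_response):
    try:
        send_request()      # may fail once the server closes its read side
    except WriteError:
        pass                # a well-formed error response may still be readable
    return read_response()  # surfaces a read error if the connection is dead
```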
@@ -96,6 +102,7 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
                     status,
                     reason_phrase,
                     headers,
+                    trailing_data,
                 ) = await self._receive_response_headers(**kwargs)
                 trace.return_value = (
                     http_version,
@@ -104,6 +111,14 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
                     headers,
                 )

+            network_stream = self._network_stream
+
+            # CONNECT or Upgrade request
+            if (status == 101) or (
+                (request.method == b"CONNECT") and (200 <= status < 300)
+            ):
+                network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data)
+
             return Response(
                 status=status,
                 headers=headers,
@@ -111,7 +126,7 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
                 extensions={
                     "http_version": http_version,
                     "reason_phrase": reason_phrase,
-                    "network_stream": self._network_stream,
+                    "network_stream": network_stream,
                 },
             )
         except BaseException as exc:
@@ -138,16 +153,14 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
         timeouts = request.extensions.get("timeout", {})
         timeout = timeouts.get("write", None)

-        assert isinstance(request.stream, AsyncIterable)
+        assert isinstance(request.stream, typing.AsyncIterable)
         async for chunk in request.stream:
             event = h11.Data(data=chunk)
             await self._send_event(event, timeout=timeout)

         await self._send_event(h11.EndOfMessage(), timeout=timeout)

-    async def _send_event(
-        self, event: h11.Event, timeout: Optional[float] = None
-    ) -> None:
+    async def _send_event(self, event: h11.Event, timeout: float | None = None) -> None:
         bytes_to_send = self._h11_state.send(event)
         if bytes_to_send is not None:
             await self._network_stream.write(bytes_to_send, timeout=timeout)
@@ -156,7 +169,7 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):

     async def _receive_response_headers(
         self, request: Request
-    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
+    ) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]:
         timeouts = request.extensions.get("timeout", {})
         timeout = timeouts.get("read", None)

@@ -176,9 +189,13 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
         # raw header casing, rather than the enforced lowercase headers.
         headers = event.headers.raw_items()

-        return http_version, event.status_code, event.reason, headers
+        trailing_data, _ = self._h11_state.trailing_data

-    async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
+        return http_version, event.status_code, event.reason, headers, trailing_data
+
+    async def _receive_response_body(
+        self, request: Request
+    ) -> typing.AsyncIterator[bytes]:
         timeouts = request.extensions.get("timeout", {})
         timeout = timeouts.get("read", None)

@@ -190,8 +207,8 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
                 break

     async def _receive_event(
-        self, timeout: Optional[float] = None
-    ) -> Union[h11.Event, Type[h11.PAUSED]]:
+        self, timeout: float | None = None
+    ) -> h11.Event | type[h11.PAUSED]:
         while True:
             with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
                 event = self._h11_state.next_event()
@@ -216,7 +233,7 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
                 self._h11_state.receive_data(data)
             else:
-                # mypy fails to narrow the type in the above if statement above
-                return cast(Union[h11.Event, Type[h11.PAUSED]], event)
+                return event  # type: ignore[return-value]

     async def _response_closed(self) -> None:
         async with self._state_lock:
@@ -292,14 +309,14 @@ class AsyncHTTP11Connection(AsyncConnectionInterface):
     # These context managers are not used in the standard flow, but are
     # useful for testing or working with connection instances directly.

-    async def __aenter__(self) -> "AsyncHTTP11Connection":
+    async def __aenter__(self) -> AsyncHTTP11Connection:
         return self

     async def __aexit__(
         self,
-        exc_type: Optional[Type[BaseException]] = None,
-        exc_value: Optional[BaseException] = None,
-        traceback: Optional[TracebackType] = None,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: types.TracebackType | None = None,
     ) -> None:
         await self.aclose()

@@ -310,7 +327,7 @@ class HTTP11ConnectionByteStream:
         self._request = request
         self._closed = False

-    async def __aiter__(self) -> AsyncIterator[bytes]:
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
         kwargs = {"request": self._request}
         try:
             async with Trace("receive_response_body", logger, self._request, kwargs):
@@ -329,3 +346,34 @@ class HTTP11ConnectionByteStream:
         self._closed = True
         async with Trace("response_closed", logger, self._request):
             await self._connection._response_closed()
+
+
+class AsyncHTTP11UpgradeStream(AsyncNetworkStream):
+    def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
+        self._stream = stream
+        self._leading_data = leading_data
+
+    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
+        if self._leading_data:
+            buffer = self._leading_data[:max_bytes]
+            self._leading_data = self._leading_data[max_bytes:]
+            return buffer
+        else:
+            return await self._stream.read(max_bytes, timeout)
+
+    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
+        await self._stream.write(buffer, timeout)
+
+    async def aclose(self) -> None:
+        await self._stream.aclose()
+
+    async def start_tls(
+        self,
+        ssl_context: ssl.SSLContext,
+        server_hostname: str | None = None,
+        timeout: float | None = None,
+    ) -> AsyncNetworkStream:
+        return await self._stream.start_tls(ssl_context, server_hostname, timeout)
+
+    def get_extra_info(self, info: str) -> typing.Any:
+        return self._stream.get_extra_info(info)
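`AsyncHTTP11UpgradeStream` re-injects any bytes that h11 over-read past the response headers (the new `trailing_data` return value) before handing reads off to the raw socket, which is what makes the `network_stream` extension usable after a `CONNECT` tunnel or a `101` upgrade such as a WebSocket handshake. A rough usage sketch (sync mirror, placeholder headers; not taken from this diff):

```python
import httpcore

with httpcore.ConnectionPool() as pool:
    with pool.stream(
        "GET",
        "https://www.example.com/chat",
        headers={
            b"Connection": b"Upgrade",
            b"Upgrade": b"websocket",
            # A real WebSocket handshake also needs Sec-WebSocket-* headers.
        },
    ) as response:
        if response.status == 101:
            network_stream = response.extensions["network_stream"]
            # From here on, speak the upgraded protocol directly.
            network_stream.write(b"...", timeout=5.0)
            data = network_stream.read(max_bytes=65536, timeout=5.0)
```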
httpcore/_async/http2.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import enum
 import logging
 import time
@@ -45,14 +47,14 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
         self,
         origin: Origin,
         stream: AsyncNetworkStream,
-        keepalive_expiry: typing.Optional[float] = None,
+        keepalive_expiry: float | None = None,
     ):
         self._origin = origin
         self._network_stream = stream
-        self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
+        self._keepalive_expiry: float | None = keepalive_expiry
         self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
         self._state = HTTPConnectionState.IDLE
-        self._expire_at: typing.Optional[float] = None
+        self._expire_at: float | None = None
         self._request_count = 0
         self._init_lock = AsyncLock()
         self._state_lock = AsyncLock()
@@ -63,24 +65,22 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
         self._connection_error = False

         # Mapping from stream ID to response stream events.
-        self._events: typing.Dict[
+        self._events: dict[
             int,
-            typing.Union[
-                h2.events.ResponseReceived,
-                h2.events.DataReceived,
-                h2.events.StreamEnded,
-                h2.events.StreamReset,
+            list[
+                h2.events.ResponseReceived
+                | h2.events.DataReceived
+                | h2.events.StreamEnded
+                | h2.events.StreamReset,
             ],
         ] = {}

         # Connection terminated events are stored as state since
         # we need to handle them for all streams.
-        self._connection_terminated: typing.Optional[
-            h2.events.ConnectionTerminated
-        ] = None
+        self._connection_terminated: h2.events.ConnectionTerminated | None = None

-        self._read_exception: typing.Optional[Exception] = None
-        self._write_exception: typing.Optional[Exception] = None
+        self._read_exception: Exception | None = None
+        self._write_exception: Exception | None = None

     async def handle_async_request(self, request: Request) -> Response:
         if not self.can_handle_request(request.url.origin):
@@ -104,9 +104,11 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
         async with self._init_lock:
             if not self._sent_connection_init:
                 try:
-                    kwargs = {"request": request}
-                    async with Trace("send_connection_init", logger, request, kwargs):
-                        await self._send_connection_init(**kwargs)
+                    sci_kwargs = {"request": request}
+                    async with Trace(
+                        "send_connection_init", logger, request, sci_kwargs
+                    ):
+                        await self._send_connection_init(**sci_kwargs)
                 except BaseException as exc:
                     with AsyncShieldCancellation():
                         await self.aclose()
@@ -284,7 +286,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):

     async def _receive_response(
         self, request: Request, stream_id: int
-    ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
+    ) -> tuple[int, list[tuple[bytes, bytes]]]:
         """
         Return the response status code and headers for a given stream ID.
         """
@@ -295,6 +297,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):

         status_code = 200
         headers = []
+        assert event.headers is not None
         for k, v in event.headers:
             if k == b":status":
                 status_code = int(v.decode("ascii", errors="ignore"))
@@ -312,6 +315,8 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
         while True:
             event = await self._receive_stream_event(request, stream_id)
             if isinstance(event, h2.events.DataReceived):
+                assert event.flow_controlled_length is not None
+                assert event.data is not None
                 amount = event.flow_controlled_length
                 self._h2_state.acknowledge_received_data(amount, stream_id)
                 await self._write_outgoing_data(request)
@@ -321,9 +326,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):

     async def _receive_stream_event(
         self, request: Request, stream_id: int
-    ) -> typing.Union[
-        h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
-    ]:
+    ) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
         """
         Return the next available event for a given stream ID.

@@ -337,7 +340,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
         return event

     async def _receive_events(
-        self, request: Request, stream_id: typing.Optional[int] = None
+        self, request: Request, stream_id: int | None = None
     ) -> None:
         """
         Read some data from the network until we see one or more events
@@ -384,7 +387,9 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):

         await self._write_outgoing_data(request)

-    async def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
+    async def _receive_remote_settings_change(
+        self, event: h2.events.RemoteSettingsChanged
+    ) -> None:
         max_concurrent_streams = event.changed_settings.get(
             h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
         )
@@ -425,9 +430,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):

     # Wrappers around network read/write operations...

-    async def _read_incoming_data(
-        self, request: Request
-    ) -> typing.List[h2.events.Event]:
+    async def _read_incoming_data(self, request: Request) -> list[h2.events.Event]:
         timeouts = request.extensions.get("timeout", {})
         timeout = timeouts.get("read", None)

@@ -451,7 +454,7 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
             self._connection_error = True
             raise exc

-        events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)
+        events: list[h2.events.Event] = self._h2_state.receive_data(data)

         return events

@@ -544,14 +547,14 @@ class AsyncHTTP2Connection(AsyncConnectionInterface):
     # These context managers are not used in the standard flow, but are
     # useful for testing or working with connection instances directly.

-    async def __aenter__(self) -> "AsyncHTTP2Connection":
+    async def __aenter__(self) -> AsyncHTTP2Connection:
         return self

     async def __aexit__(
         self,
-        exc_type: typing.Optional[typing.Type[BaseException]] = None,
-        exc_value: typing.Optional[BaseException] = None,
-        traceback: typing.Optional[types.TracebackType] = None,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: types.TracebackType | None = None,
     ) -> None:
         await self.aclose()
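Beyond the mechanical `typing.Optional`/`typing.Union` conversion, note that the `_events` change above also tightens the annotation: the old one read as a single event per stream ID, while the new one spells out the list of events that each stream accumulates. In miniature (assuming the `h2` package is installed):

```python
import h2.events

# Each HTTP/2 stream ID maps to the *list* of events received for it.
events: dict[int, list[h2.events.Event]] = {}
events.setdefault(1, []).append(h2.events.StreamEnded())
```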
httpcore/_async/http_proxy.py

@@ -1,7 +1,9 @@
+from __future__ import annotations
+
+import base64
 import logging
 import ssl
-from base64 import b64encode
-from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
+import typing

 from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
 from .._exceptions import ProxyError
@@ -22,17 +24,18 @@ from .connection_pool import AsyncConnectionPool
 from .http11 import AsyncHTTP11Connection
 from .interfaces import AsyncConnectionInterface

-HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
-HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
+ByteOrStr = typing.Union[bytes, str]
+HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
+HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]


 logger = logging.getLogger("httpcore.proxy")


 def merge_headers(
-    default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
-    override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
-) -> List[Tuple[bytes, bytes]]:
+    default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
+    override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
+) -> list[tuple[bytes, bytes]]:
     """
     Append default_headers and override_headers, de-duplicating if a key exists
     in both cases.
@@ -48,32 +51,28 @@ def merge_headers(
     return default_headers + override_headers


-def build_auth_header(username: bytes, password: bytes) -> bytes:
-    userpass = username + b":" + password
-    return b"Basic " + b64encode(userpass)
-
-
-class AsyncHTTPProxy(AsyncConnectionPool):
+class AsyncHTTPProxy(AsyncConnectionPool):  # pragma: nocover
     """
     A connection pool that sends requests via an HTTP proxy.
     """

     def __init__(
         self,
-        proxy_url: Union[URL, bytes, str],
-        proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
-        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
-        ssl_context: Optional[ssl.SSLContext] = None,
-        max_connections: Optional[int] = 10,
-        max_keepalive_connections: Optional[int] = None,
-        keepalive_expiry: Optional[float] = None,
+        proxy_url: URL | bytes | str,
+        proxy_auth: tuple[bytes | str, bytes | str] | None = None,
+        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
+        ssl_context: ssl.SSLContext | None = None,
+        proxy_ssl_context: ssl.SSLContext | None = None,
+        max_connections: int | None = 10,
+        max_keepalive_connections: int | None = None,
+        keepalive_expiry: float | None = None,
         http1: bool = True,
         http2: bool = False,
         retries: int = 0,
-        local_address: Optional[str] = None,
-        uds: Optional[str] = None,
-        network_backend: Optional[AsyncNetworkBackend] = None,
-        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+        local_address: str | None = None,
+        uds: str | None = None,
+        network_backend: AsyncNetworkBackend | None = None,
+        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
     ) -> None:
         """
         A connection pool for making HTTP requests.
@@ -88,6 +87,7 @@ class AsyncHTTPProxy(AsyncConnectionPool):
             ssl_context: An SSL context to use for verifying connections.
                 If not specified, the default `httpcore.default_ssl_context()`
                 will be used.
+            proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin.
             max_connections: The maximum number of concurrent HTTP connections that
                 the pool should allow. Any attempt to send a request on a pool that
                 would exceed this amount will block until a connection is available.
@@ -122,13 +122,23 @@ class AsyncHTTPProxy(AsyncConnectionPool):
             uds=uds,
             socket_options=socket_options,
         )
-        self._ssl_context = ssl_context

         self._proxy_url = enforce_url(proxy_url, name="proxy_url")
+        if (
+            self._proxy_url.scheme == b"http" and proxy_ssl_context is not None
+        ):  # pragma: no cover
+            raise RuntimeError(
+                "The `proxy_ssl_context` argument is not allowed for the http scheme"
+            )
+
+        self._ssl_context = ssl_context
+        self._proxy_ssl_context = proxy_ssl_context
         self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
         if proxy_auth is not None:
             username = enforce_bytes(proxy_auth[0], name="proxy_auth")
             password = enforce_bytes(proxy_auth[1], name="proxy_auth")
-            authorization = build_auth_header(username, password)
+            userpass = username + b":" + password
+            authorization = b"Basic " + base64.b64encode(userpass)
             self._proxy_headers = [
                 (b"Proxy-Authorization", authorization)
             ] + self._proxy_headers
@@ -141,12 +151,14 @@ class AsyncHTTPProxy(AsyncConnectionPool):
                 remote_origin=origin,
                 keepalive_expiry=self._keepalive_expiry,
                 network_backend=self._network_backend,
+                proxy_ssl_context=self._proxy_ssl_context,
             )
         return AsyncTunnelHTTPConnection(
             proxy_origin=self._proxy_url.origin,
             proxy_headers=self._proxy_headers,
             remote_origin=origin,
             ssl_context=self._ssl_context,
+            proxy_ssl_context=self._proxy_ssl_context,
             keepalive_expiry=self._keepalive_expiry,
             http1=self._http1,
             http2=self._http2,
@@ -159,16 +171,18 @@ class AsyncForwardHTTPConnection(AsyncConnectionInterface):
         self,
         proxy_origin: Origin,
         remote_origin: Origin,
-        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
-        keepalive_expiry: Optional[float] = None,
-        network_backend: Optional[AsyncNetworkBackend] = None,
-        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
+        keepalive_expiry: float | None = None,
+        network_backend: AsyncNetworkBackend | None = None,
+        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
+        proxy_ssl_context: ssl.SSLContext | None = None,
     ) -> None:
         self._connection = AsyncHTTPConnection(
             origin=proxy_origin,
             keepalive_expiry=keepalive_expiry,
             network_backend=network_backend,
             socket_options=socket_options,
+            ssl_context=proxy_ssl_context,
         )
         self._proxy_origin = proxy_origin
         self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
@@ -221,23 +235,26 @@ class AsyncTunnelHTTPConnection(AsyncConnectionInterface):
         self,
         proxy_origin: Origin,
         remote_origin: Origin,
-        ssl_context: Optional[ssl.SSLContext] = None,
-        proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
-        keepalive_expiry: Optional[float] = None,
+        ssl_context: ssl.SSLContext | None = None,
+        proxy_ssl_context: ssl.SSLContext | None = None,
+        proxy_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
+        keepalive_expiry: float | None = None,
         http1: bool = True,
         http2: bool = False,
-        network_backend: Optional[AsyncNetworkBackend] = None,
-        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+        network_backend: AsyncNetworkBackend | None = None,
+        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
     ) -> None:
         self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
             origin=proxy_origin,
             keepalive_expiry=keepalive_expiry,
             network_backend=network_backend,
             socket_options=socket_options,
+            ssl_context=proxy_ssl_context,
         )
         self._proxy_origin = proxy_origin
         self._remote_origin = remote_origin
         self._ssl_context = ssl_context
+        self._proxy_ssl_context = proxy_ssl_context
         self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
         self._keepalive_expiry = keepalive_expiry
         self._http1 = http1
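Forward vs. tunnel selection here is scheme-driven: plain `http://` origins are proxied with absolute-form requests, while `https://` origins go through a `CONNECT` tunnel, optionally TLS-wrapped to the proxy itself via the new `proxy_ssl_context`. Basic usage of the proxy pool (sync mirror, placeholder addresses and credentials):

```python
import httpcore

with httpcore.HTTPProxy(
    proxy_url="http://127.0.0.1:8080/",
    proxy_auth=("username", "password"),  # sent as a Proxy-Authorization header
) as proxy:
    response = proxy.request("GET", "https://www.example.com/")
    print(response.status)
```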
httpcore/_async/interfaces.py

@@ -1,5 +1,7 @@
-from contextlib import asynccontextmanager
-from typing import AsyncIterator, Optional, Union
+from __future__ import annotations
+
+import contextlib
+import typing

 from .._models import (
     URL,
@@ -18,12 +20,12 @@ from .._models import (
 class AsyncRequestInterface:
     async def request(
         self,
-        method: Union[bytes, str],
-        url: Union[URL, bytes, str],
+        method: bytes | str,
+        url: URL | bytes | str,
         *,
         headers: HeaderTypes = None,
-        content: Union[bytes, AsyncIterator[bytes], None] = None,
-        extensions: Optional[Extensions] = None,
+        content: bytes | typing.AsyncIterator[bytes] | None = None,
+        extensions: Extensions | None = None,
     ) -> Response:
         # Strict type checking on our parameters.
         method = enforce_bytes(method, name="method")
@@ -47,16 +49,16 @@ class AsyncRequestInterface:
         await response.aclose()
         return response

-    @asynccontextmanager
+    @contextlib.asynccontextmanager
     async def stream(
         self,
-        method: Union[bytes, str],
-        url: Union[URL, bytes, str],
+        method: bytes | str,
+        url: URL | bytes | str,
         *,
         headers: HeaderTypes = None,
-        content: Union[bytes, AsyncIterator[bytes], None] = None,
-        extensions: Optional[Extensions] = None,
-    ) -> AsyncIterator[Response]:
+        content: bytes | typing.AsyncIterator[bytes] | None = None,
+        extensions: Extensions | None = None,
+    ) -> typing.AsyncIterator[Response]:
         # Strict type checking on our parameters.
         method = enforce_bytes(method, name="method")
         url = enforce_url(url, name="url")
httpcore/_async/socks_proxy.py

@@ -1,8 +1,9 @@
+from __future__ import annotations
+
 import logging
 import ssl
 import typing

-from socksio import socks5
+import socksio

 from .._backends.auto import AutoBackend
 from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
@@ -43,24 +44,24 @@ async def _init_socks5_connection(
     *,
     host: bytes,
     port: int,
-    auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
+    auth: tuple[bytes, bytes] | None = None,
 ) -> None:
-    conn = socks5.SOCKS5Connection()
+    conn = socksio.socks5.SOCKS5Connection()

     # Auth method request
     auth_method = (
-        socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
+        socksio.socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
         if auth is None
-        else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
+        else socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
     )
-    conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
+    conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method]))
     outgoing_bytes = conn.data_to_send()
     await stream.write(outgoing_bytes)

     # Auth method response
     incoming_bytes = await stream.read(max_bytes=4096)
     response = conn.receive_data(incoming_bytes)
-    assert isinstance(response, socks5.SOCKS5AuthReply)
+    assert isinstance(response, socksio.socks5.SOCKS5AuthReply)
     if response.method != auth_method:
         requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
         responded = AUTH_METHODS.get(response.method, "UNKNOWN")
@@ -68,25 +69,25 @@ async def _init_socks5_connection(
             f"Requested {requested} from proxy server, but got {responded}."
         )

-    if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
+    if response.method == socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
         # Username/password request
         assert auth is not None
         username, password = auth
-        conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
+        conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password))
         outgoing_bytes = conn.data_to_send()
         await stream.write(outgoing_bytes)

         # Username/password response
         incoming_bytes = await stream.read(max_bytes=4096)
         response = conn.receive_data(incoming_bytes)
-        assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
+        assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply)
         if not response.success:
             raise ProxyError("Invalid username/password")

     # Connect request
     conn.send(
-        socks5.SOCKS5CommandRequest.from_address(
-            socks5.SOCKS5Command.CONNECT, (host, port)
+        socksio.socks5.SOCKS5CommandRequest.from_address(
+            socksio.socks5.SOCKS5Command.CONNECT, (host, port)
         )
     )
     outgoing_bytes = conn.data_to_send()
@@ -95,31 +96,29 @@ async def _init_socks5_connection(
     # Connect response
     incoming_bytes = await stream.read(max_bytes=4096)
     response = conn.receive_data(incoming_bytes)
-    assert isinstance(response, socks5.SOCKS5Reply)
-    if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
+    assert isinstance(response, socksio.socks5.SOCKS5Reply)
+    if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED:
         reply_code = REPLY_CODES.get(response.reply_code, "UNKNOWN")
         raise ProxyError(f"Proxy Server could not connect: {reply_code}.")


-class AsyncSOCKSProxy(AsyncConnectionPool):
+class AsyncSOCKSProxy(AsyncConnectionPool):  # pragma: nocover
     """
     A connection pool that sends requests via a SOCKS5 proxy.
     """

     def __init__(
         self,
-        proxy_url: typing.Union[URL, bytes, str],
-        proxy_auth: typing.Optional[
-            typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
-        ] = None,
-        ssl_context: typing.Optional[ssl.SSLContext] = None,
-        max_connections: typing.Optional[int] = 10,
-        max_keepalive_connections: typing.Optional[int] = None,
-        keepalive_expiry: typing.Optional[float] = None,
+        proxy_url: URL | bytes | str,
+        proxy_auth: tuple[bytes | str, bytes | str] | None = None,
+        ssl_context: ssl.SSLContext | None = None,
+        max_connections: int | None = 10,
+        max_keepalive_connections: int | None = None,
+        keepalive_expiry: float | None = None,
         http1: bool = True,
         http2: bool = False,
         retries: int = 0,
-        network_backend: typing.Optional[AsyncNetworkBackend] = None,
+        network_backend: AsyncNetworkBackend | None = None,
     ) -> None:
         """
         A connection pool for making HTTP requests.
@@ -167,7 +166,7 @@ class AsyncSOCKSProxy(AsyncConnectionPool):
             username, password = proxy_auth
             username_bytes = enforce_bytes(username, name="proxy_auth")
             password_bytes = enforce_bytes(password, name="proxy_auth")
-            self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = (
+            self._proxy_auth: tuple[bytes, bytes] | None = (
|
||||
username_bytes,
|
||||
password_bytes,
|
||||
)
|
||||
@@ -192,12 +191,12 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
|
||||
ssl_context: typing.Optional[ssl.SSLContext] = None,
|
||||
keepalive_expiry: typing.Optional[float] = None,
|
||||
proxy_auth: tuple[bytes, bytes] | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
keepalive_expiry: float | None = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
network_backend: typing.Optional[AsyncNetworkBackend] = None,
|
||||
network_backend: AsyncNetworkBackend | None = None,
|
||||
) -> None:
|
||||
self._proxy_origin = proxy_origin
|
||||
self._remote_origin = remote_origin
|
||||
@@ -211,11 +210,12 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
|
||||
AutoBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._connect_lock = AsyncLock()
|
||||
self._connection: typing.Optional[AsyncConnectionInterface] = None
|
||||
self._connection: AsyncConnectionInterface | None = None
|
||||
self._connect_failed = False
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
sni_hostname = request.extensions.get("sni_hostname", None)
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
async with self._connect_lock:
|
||||
@@ -227,7 +227,7 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
|
||||
"port": self._proxy_origin.port,
|
||||
"timeout": timeout,
|
||||
}
|
||||
with Trace("connect_tcp", logger, request, kwargs) as trace:
|
||||
async with Trace("connect_tcp", logger, request, kwargs) as trace:
|
||||
stream = await self._network_backend.connect_tcp(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
@@ -238,7 +238,7 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
|
||||
"port": self._remote_origin.port,
|
||||
"auth": self._proxy_auth,
|
||||
}
|
||||
with Trace(
|
||||
async with Trace(
|
||||
"setup_socks5_connection", logger, request, kwargs
|
||||
) as trace:
|
||||
await _init_socks5_connection(**kwargs)
|
||||
@@ -258,7 +258,8 @@ class AsyncSocks5Connection(AsyncConnectionInterface):
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": self._remote_origin.host.decode("ascii"),
|
||||
"server_hostname": sni_hostname
|
||||
or self._remote_origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
async with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
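The SOCKS support above is easiest to see end-to-end from the public API. A minimal usage sketch, not part of the diff itself; it assumes a SOCKS5 proxy listening on `localhost:1080`, and uses the `proxy_url` parameter shown in the `AsyncSOCKSProxy.__init__` signature above:

```python
import asyncio
import httpcore

async def main() -> None:
    # Hypothetical proxy endpoint; substitute a real SOCKS5 server.
    async with httpcore.AsyncSOCKSProxy(proxy_url="socks5://localhost:1080/") as proxy:
        response = await proxy.request("GET", "https://www.example.com/")
        print(response.status)

asyncio.run(main())
```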
@@ -1,3 +1,5 @@
from __future__ import annotations

import ssl
import typing

@@ -20,13 +22,12 @@ class AnyIOStream(AsyncNetworkStream):
    def __init__(self, stream: anyio.abc.ByteStream) -> None:
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        exc_map = {
            TimeoutError: ReadTimeout,
            anyio.BrokenResourceError: ReadError,
            anyio.ClosedResourceError: ReadError,
            anyio.EndOfStream: ReadError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
@@ -35,9 +36,7 @@ class AnyIOStream(AsyncNetworkStream):
            except anyio.EndOfStream:  # pragma: nocover
                return b""

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        if not buffer:
            return

@@ -56,12 +55,14 @@ class AnyIOStream(AsyncNetworkStream):
    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        exc_map = {
            TimeoutError: ConnectTimeout,
            anyio.BrokenResourceError: ConnectError,
            anyio.EndOfStream: ConnectError,
            ssl.SSLError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
@@ -98,12 +99,12 @@ class AnyIOBackend(AsyncNetworkBackend):
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        if socket_options is None:
            socket_options = []  # pragma: no cover
            socket_options = []
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
@@ -124,8 +125,8 @@ class AnyIOBackend(AsyncNetworkBackend):
    async def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        if socket_options is None:
            socket_options = []

@@ -1,15 +1,15 @@
from __future__ import annotations

import typing
from typing import Optional

import sniffio

from .._synchronization import current_async_library
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream


class AutoBackend(AsyncNetworkBackend):
    async def _init_backend(self) -> None:
        if not (hasattr(self, "_backend")):
            backend = sniffio.current_async_library()
            backend = current_async_library()
            if backend == "trio":
                from .trio import TrioBackend

@@ -23,9 +23,9 @@ class AutoBackend(AsyncNetworkBackend):
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        await self._init_backend()
        return await self._backend.connect_tcp(
@@ -39,8 +39,8 @@ class AutoBackend(AsyncNetworkBackend):
    async def connect_unix_socket(
        self,
        path: str,
        timeout: Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        await self._init_backend()
        return await self._backend.connect_unix_socket(

@@ -1,3 +1,5 @@
from __future__ import annotations

import ssl
import time
import typing
@@ -10,10 +12,10 @@ SOCKET_OPTION = typing.Union[


class NetworkStream:
    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        raise NotImplementedError()  # pragma: nocover

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        raise NotImplementedError()  # pragma: nocover

    def close(self) -> None:
@@ -22,9 +24,9 @@ class NetworkStream:
    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> "NetworkStream":
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        raise NotImplementedError()  # pragma: nocover

    def get_extra_info(self, info: str) -> typing.Any:
@@ -36,17 +38,17 @@ class NetworkBackend:
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        raise NotImplementedError()  # pragma: nocover

    def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        raise NotImplementedError()  # pragma: nocover

@@ -55,14 +57,10 @@ class NetworkBackend:


class AsyncNetworkStream:
    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        raise NotImplementedError()  # pragma: nocover

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        raise NotImplementedError()  # pragma: nocover

    async def aclose(self) -> None:
@@ -71,9 +69,9 @@ class AsyncNetworkStream:
    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> "AsyncNetworkStream":
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        raise NotImplementedError()  # pragma: nocover

    def get_extra_info(self, info: str) -> typing.Any:
@@ -85,17 +83,17 @@ class AsyncNetworkBackend:
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        raise NotImplementedError()  # pragma: nocover

    async def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        raise NotImplementedError()  # pragma: nocover


@@ -1,6 +1,7 @@
from __future__ import annotations

import ssl
import typing
from typing import Optional

from .._exceptions import ReadError
from .base import (
@@ -21,19 +22,19 @@ class MockSSLObject:


class MockStream(NetworkStream):
    def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
    def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
        self._buffer = buffer
        self._http2 = http2
        self._closed = False

    def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        if self._closed:
            raise ReadError("Connection closed")
        if not self._buffer:
            return b""
        return self._buffer.pop(0)

    def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        pass

    def close(self) -> None:
@@ -42,8 +43,8 @@ class MockStream(NetworkStream):
    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: Optional[str] = None,
        timeout: Optional[float] = None,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        return self

@@ -55,7 +56,7 @@ class MockStream(NetworkStream):


class MockBackend(NetworkBackend):
    def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
    def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
        self._buffer = buffer
        self._http2 = http2

@@ -63,17 +64,17 @@ class MockBackend(NetworkBackend):
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        return MockStream(list(self._buffer), http2=self._http2)

    def connect_unix_socket(
        self,
        path: str,
        timeout: Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        return MockStream(list(self._buffer), http2=self._http2)

@@ -82,19 +83,19 @@ class MockBackend(NetworkBackend):


class AsyncMockStream(AsyncNetworkStream):
    def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
    def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
        self._buffer = buffer
        self._http2 = http2
        self._closed = False

    async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        if self._closed:
            raise ReadError("Connection closed")
        if not self._buffer:
            return b""
        return self._buffer.pop(0)

    async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        pass

    async def aclose(self) -> None:
@@ -103,8 +104,8 @@ class AsyncMockStream(AsyncNetworkStream):
    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: Optional[str] = None,
        timeout: Optional[float] = None,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        return self

@@ -116,7 +117,7 @@ class AsyncMockStream(AsyncNetworkStream):


class AsyncMockBackend(AsyncNetworkBackend):
    def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
    def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
        self._buffer = buffer
        self._http2 = http2

@@ -124,17 +125,17 @@ class AsyncMockBackend(AsyncNetworkBackend):
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        return AsyncMockStream(list(self._buffer), http2=self._http2)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        return AsyncMockStream(list(self._buffer), http2=self._http2)


@@ -1,3 +1,6 @@
from __future__ import annotations

import functools
import socket
import ssl
import sys
@@ -17,17 +20,114 @@ from .._utils import is_socket_readable
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream


class TLSinTLSStream(NetworkStream):  # pragma: no cover
    """
    Because the standard `SSLContext.wrap_socket` method does
    not work for `SSLSocket` objects, we need this class
    to implement TLS stream using an underlying `SSLObject`
    instance in order to support TLS on top of TLS.
    """

    # Defined in RFC 8449
    TLS_RECORD_SIZE = 16384

    def __init__(
        self,
        sock: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ):
        self._sock = sock
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()

        self.ssl_obj = ssl_context.wrap_bio(
            incoming=self._incoming,
            outgoing=self._outgoing,
            server_hostname=server_hostname,
        )

        self._sock.settimeout(timeout)
        self._perform_io(self.ssl_obj.do_handshake)

    def _perform_io(
        self,
        func: typing.Callable[..., typing.Any],
    ) -> typing.Any:
        ret = None

        while True:
            errno = None
            try:
                ret = func()
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
                errno = e.errno

            self._sock.sendall(self._outgoing.read())

            if errno == ssl.SSL_ERROR_WANT_READ:
                buf = self._sock.recv(self.TLS_RECORD_SIZE)

                if buf:
                    self._incoming.write(buf)
                else:
                    self._incoming.write_eof()
            if errno is None:
                return ret

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return typing.cast(
                bytes, self._perform_io(functools.partial(self.ssl_obj.read, max_bytes))
            )

    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            while buffer:
                nsent = self._perform_io(functools.partial(self.ssl_obj.write, buffer))
                buffer = buffer[nsent:]

    def close(self) -> None:
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        raise NotImplementedError()

    def get_extra_info(self, info: str) -> typing.Any:
        if info == "ssl_object":
            return self.ssl_obj
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None
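The `wrap_bio` approach used above is worth seeing in isolation. A minimal standard-library sketch, not part of the diff; it only illustrates why an `SSLObject` over a pair of `MemoryBIO`s can layer TLS over an already-encrypted stream, where `wrap_socket` cannot:

```python
import ssl

ctx = ssl.create_default_context()
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
tls = ctx.wrap_bio(incoming, outgoing, server_hostname="example.com")

try:
    tls.do_handshake()
except ssl.SSLWantReadError:
    # The handshake cannot progress until bytes from `outgoing` are sent over
    # the underlying stream (which may itself already be TLS), and the peer's
    # reply is fed into `incoming`: exactly the shuttling that `_perform_io`
    # implements above.
    pass
```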


class SyncStream(NetworkStream):
    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return self._sock.recv(max_bytes)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        if not buffer:
            return

@@ -44,8 +144,8 @@ class SyncStream(NetworkStream):
    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
@@ -53,10 +153,18 @@ class SyncStream(NetworkStream):
        }
        with map_exceptions(exc_map):
            try:
                self._sock.settimeout(timeout)
                sock = ssl_context.wrap_socket(
                    self._sock, server_hostname=server_hostname
                )
                if isinstance(self._sock, ssl.SSLSocket):  # pragma: no cover
                    # If the underlying socket has already been upgraded
                    # to the TLS layer (i.e. is an instance of SSLSocket),
                    # we need some additional smarts to support TLS-in-TLS.
                    return TLSinTLSStream(
                        self._sock, ssl_context, server_hostname, timeout
                    )
                else:
                    self._sock.settimeout(timeout)
                    sock = ssl_context.wrap_socket(
                        self._sock, server_hostname=server_hostname
                    )
            except Exception as exc:  # pragma: nocover
                self.close()
                raise exc
@@ -81,9 +189,9 @@ class SyncBackend(NetworkBackend):
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        # Note that we automatically include `TCP_NODELAY`
        # in addition to any other custom socket options.
@@ -110,8 +218,8 @@ class SyncBackend(NetworkBackend):
    def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:  # pragma: nocover
        if sys.platform == "win32":
            raise RuntimeError(

@@ -1,3 +1,5 @@
from __future__ import annotations

import ssl
import typing

@@ -20,9 +22,7 @@ class TrioStream(AsyncNetworkStream):
    def __init__(self, stream: trio.abc.Stream) -> None:
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ReadTimeout,
@@ -34,9 +34,7 @@ class TrioStream(AsyncNetworkStream):
            data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
            return data

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        if not buffer:
            return

@@ -56,8 +54,8 @@ class TrioStream(AsyncNetworkStream):
    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
@@ -113,9 +111,9 @@ class TrioBackend(AsyncNetworkBackend):
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        # By default for TCP sockets, trio enables TCP_NODELAY.
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
@@ -139,8 +137,8 @@ class TrioBackend(AsyncNetworkBackend):
    async def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        if socket_options is None:
            socket_options = []

@@ -1,11 +1,11 @@
import contextlib
from typing import Iterator, Mapping, Type
import typing

ExceptionMapping = Mapping[Type[Exception], Type[Exception]]
ExceptionMapping = typing.Mapping[typing.Type[Exception], typing.Type[Exception]]


@contextlib.contextmanager
def map_exceptions(map: ExceptionMapping) -> Iterator[None]:
def map_exceptions(map: ExceptionMapping) -> typing.Iterator[None]:
    try:
        yield
    except Exception as exc:  # noqa: PIE786

@@ -1,29 +1,22 @@
from typing import (
    Any,
    AsyncIterable,
    AsyncIterator,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from urllib.parse import urlparse
from __future__ import annotations

import base64
import ssl
import typing
import urllib.parse

# Functions for typechecking...


HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
HeaderTypes = Union[HeadersAsSequence, HeadersAsMapping, None]
ByteOrStr = typing.Union[bytes, str]
HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]
HeaderTypes = typing.Union[HeadersAsSequence, HeadersAsMapping, None]

Extensions = Mapping[str, Any]
Extensions = typing.MutableMapping[str, typing.Any]


def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
def enforce_bytes(value: bytes | str, *, name: str) -> bytes:
    """
    Any arguments that are ultimately represented as bytes can be specified
    either as bytes or as strings.
@@ -44,7 +37,7 @@ def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
    raise TypeError(f"{name} must be bytes or str, but got {seen_type}.")
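A quick illustration of that contract, not part of the diff:

```python
enforce_bytes("GET", name="method")   # b"GET"
enforce_bytes(b"GET", name="method")  # b"GET" (passed through unchanged)
enforce_bytes(123, name="method")     # TypeError: method must be bytes or str, ...
```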


def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":
def enforce_url(value: URL | bytes | str, *, name: str) -> URL:
    """
    Type check for URL parameters.
    """
@@ -58,15 +51,15 @@ def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":


def enforce_headers(
    value: Union[HeadersAsMapping, HeadersAsSequence, None] = None, *, name: str
) -> List[Tuple[bytes, bytes]]:
    value: HeadersAsMapping | HeadersAsSequence | None = None, *, name: str
) -> list[tuple[bytes, bytes]]:
    """
    Convenience function that ensures all items in request or response headers
    are either bytes or strings in the plain ASCII range.
    """
    if value is None:
        return []
    elif isinstance(value, Mapping):
    elif isinstance(value, typing.Mapping):
        return [
            (
                enforce_bytes(k, name="header name"),
@@ -74,7 +67,7 @@ def enforce_headers(
            )
            for k, v in value.items()
        ]
    elif isinstance(value, Sequence):
    elif isinstance(value, typing.Sequence):
        return [
            (
                enforce_bytes(k, name="header name"),
@@ -90,8 +83,10 @@ def enforce_headers(


def enforce_stream(
    value: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None], *, name: str
) -> Union[Iterable[bytes], AsyncIterable[bytes]]:
    value: bytes | typing.Iterable[bytes] | typing.AsyncIterable[bytes] | None,
    *,
    name: str,
) -> typing.Iterable[bytes] | typing.AsyncIterable[bytes]:
    if value is None:
        return ByteStream(b"")
    elif isinstance(value, bytes):
@@ -112,11 +107,11 @@ DEFAULT_PORTS = {


def include_request_headers(
    headers: List[Tuple[bytes, bytes]],
    headers: list[tuple[bytes, bytes]],
    *,
    url: "URL",
    content: Union[None, bytes, Iterable[bytes], AsyncIterable[bytes]],
) -> List[Tuple[bytes, bytes]]:
    content: None | bytes | typing.Iterable[bytes] | typing.AsyncIterable[bytes],
) -> list[tuple[bytes, bytes]]:
    headers_set = set(k.lower() for k, v in headers)

    if b"host" not in headers_set:
@@ -153,10 +148,10 @@ class ByteStream:
    def __init__(self, content: bytes) -> None:
        self._content = content

    def __iter__(self) -> Iterator[bytes]:
    def __iter__(self) -> typing.Iterator[bytes]:
        yield self._content

    async def __aiter__(self) -> AsyncIterator[bytes]:
    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        yield self._content

    def __repr__(self) -> str:
@@ -169,7 +164,7 @@ class Origin:
        self.host = host
        self.port = port

    def __eq__(self, other: Any) -> bool:
    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Origin)
            and self.scheme == other.scheme
@@ -253,12 +248,12 @@ class URL:

    def __init__(
        self,
        url: Union[bytes, str] = "",
        url: bytes | str = "",
        *,
        scheme: Union[bytes, str] = b"",
        host: Union[bytes, str] = b"",
        port: Optional[int] = None,
        target: Union[bytes, str] = b"",
        scheme: bytes | str = b"",
        host: bytes | str = b"",
        port: int | None = None,
        target: bytes | str = b"",
    ) -> None:
        """
        Parameters:
@@ -270,7 +265,7 @@ class URL:
            target: The target of the HTTP request. Such as `"/items?search=red"`.
        """
        if url:
            parsed = urlparse(enforce_bytes(url, name="url"))
            parsed = urllib.parse.urlparse(enforce_bytes(url, name="url"))
            self.scheme = parsed.scheme
            self.host = parsed.hostname or b""
            self.port = parsed.port
@@ -291,12 +286,13 @@ class URL:
            b"ws": 80,
            b"wss": 443,
            b"socks5": 1080,
            b"socks5h": 1080,
        }[self.scheme]
        return Origin(
            scheme=self.scheme, host=self.host, port=self.port or default_port
        )
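The practical effect of the two new table entries is that SOCKS proxy URLs without an explicit port resolve to the conventional port 1080. An illustrative sketch, not from the diff:

```python
import httpcore

url = httpcore.URL("socks5://localhost/")
assert url.port is None              # no port given in the URL itself
assert url.origin.port == 1080       # filled in from the default-ports table above
```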

    def __eq__(self, other: Any) -> bool:
    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, URL)
            and other.scheme == self.scheme
@@ -324,12 +320,15 @@ class Request:

    def __init__(
        self,
        method: Union[bytes, str],
        url: Union[URL, bytes, str],
        method: bytes | str,
        url: URL | bytes | str,
        *,
        headers: HeaderTypes = None,
        content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
        extensions: Optional[Extensions] = None,
        content: bytes
        | typing.Iterable[bytes]
        | typing.AsyncIterable[bytes]
        | None = None,
        extensions: Extensions | None = None,
    ) -> None:
        """
        Parameters:
@@ -338,20 +337,28 @@ class Request:
            url: The request URL, either as a `URL` instance, or as a string or bytes.
                For example: `"https://www.example.com".`
            headers: The HTTP request headers.
            content: The content of the response body.
            content: The content of the request body.
            extensions: A dictionary of optional extra information included on
                the request. Possible keys include `"timeout"`, and `"trace"`.
        """
        self.method: bytes = enforce_bytes(method, name="method")
        self.url: URL = enforce_url(url, name="url")
        self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
        self.headers: list[tuple[bytes, bytes]] = enforce_headers(
            headers, name="headers"
        )
        self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
            content, name="content"
        self.stream: typing.Iterable[bytes] | typing.AsyncIterable[bytes] = (
            enforce_stream(content, name="content")
        )
        self.extensions = {} if extensions is None else extensions

        if "target" in self.extensions:
            self.url = URL(
                scheme=self.url.scheme,
                host=self.url.host,
                port=self.url.port,
                target=self.extensions["target"],
            )
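The `"target"` extension handling here is new behaviour: it lets the request-line target differ from the URL path while still connecting to the same origin. A hedged sketch of how it might be used:

```python
import httpcore

# Aims to send "OPTIONS * HTTP/1.1" while still connecting to example.com:80.
request = httpcore.Request(
    "OPTIONS",
    "http://example.com/",
    extensions={"target": "*"},
)
assert request.url.target == b"*"  # the URL is rebuilt with the custom target
```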

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.method!r}]>"

@@ -366,8 +373,11 @@ class Response:
        status: int,
        *,
        headers: HeaderTypes = None,
        content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
        extensions: Optional[Extensions] = None,
        content: bytes
        | typing.Iterable[bytes]
        | typing.AsyncIterable[bytes]
        | None = None,
        extensions: Extensions | None = None,
    ) -> None:
        """
        Parameters:
@@ -379,11 +389,11 @@ class Response:
                `"reason_phrase"`, and `"network_stream"`.
        """
        self.status: int = status
        self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
        self.headers: list[tuple[bytes, bytes]] = enforce_headers(
            headers, name="headers"
        )
        self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
            content, name="content"
        self.stream: typing.Iterable[bytes] | typing.AsyncIterable[bytes] = (
            enforce_stream(content, name="content")
        )
        self.extensions = {} if extensions is None else extensions

@@ -392,7 +402,7 @@ class Response:
    @property
    def content(self) -> bytes:
        if not hasattr(self, "_content"):
            if isinstance(self.stream, Iterable):
            if isinstance(self.stream, typing.Iterable):
                raise RuntimeError(
                    "Attempted to access 'response.content' on a streaming response. "
                    "Call 'response.read()' first."
@@ -410,7 +420,7 @@ class Response:
    # Sync interface...

    def read(self) -> bytes:
        if not isinstance(self.stream, Iterable):  # pragma: nocover
        if not isinstance(self.stream, typing.Iterable):  # pragma: nocover
            raise RuntimeError(
                "Attempted to read an asynchronous response using 'response.read()'. "
                "You should use 'await response.aread()' instead."
@@ -419,8 +429,8 @@ class Response:
            self._content = b"".join([part for part in self.iter_stream()])
        return self._content

    def iter_stream(self) -> Iterator[bytes]:
        if not isinstance(self.stream, Iterable):  # pragma: nocover
    def iter_stream(self) -> typing.Iterator[bytes]:
        if not isinstance(self.stream, typing.Iterable):  # pragma: nocover
            raise RuntimeError(
                "Attempted to stream an asynchronous response using 'for ... in "
                "response.iter_stream()'. "
@@ -435,7 +445,7 @@ class Response:
            yield chunk

    def close(self) -> None:
        if not isinstance(self.stream, Iterable):  # pragma: nocover
        if not isinstance(self.stream, typing.Iterable):  # pragma: nocover
            raise RuntimeError(
                "Attempted to close an asynchronous response using 'response.close()'. "
                "You should use 'await response.aclose()' instead."
@@ -446,7 +456,7 @@ class Response:
    # Async interface...

    async def aread(self) -> bytes:
        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
        if not isinstance(self.stream, typing.AsyncIterable):  # pragma: nocover
            raise RuntimeError(
                "Attempted to read a synchronous response using "
                "'await response.aread()'. "
@@ -456,8 +466,8 @@ class Response:
            self._content = b"".join([part async for part in self.aiter_stream()])
        return self._content

    async def aiter_stream(self) -> AsyncIterator[bytes]:
        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
    async def aiter_stream(self) -> typing.AsyncIterator[bytes]:
        if not isinstance(self.stream, typing.AsyncIterable):  # pragma: nocover
            raise RuntimeError(
                "Attempted to stream a synchronous response using 'async for ... in "
                "response.aiter_stream()'. "
@@ -473,7 +483,7 @@ class Response:
            yield chunk

    async def aclose(self) -> None:
        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
        if not isinstance(self.stream, typing.AsyncIterable):  # pragma: nocover
            raise RuntimeError(
                "Attempted to close a synchronous response using "
                "'await response.aclose()'. "
@@ -481,3 +491,26 @@ class Response:
            )
        if hasattr(self.stream, "aclose"):
            await self.stream.aclose()
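These guard rails exist because a `Response` may hold either a sync or an async stream, and mixing the two interfaces is an easy mistake. Typical streaming usage on the sync side, a sketch using the top-level API:

```python
import httpcore

with httpcore.stream("GET", "https://www.example.com/") as response:
    for chunk in response.iter_stream():
        print(f"Downloaded {len(chunk)} bytes...")
```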


class Proxy:
    def __init__(
        self,
        url: URL | bytes | str,
        auth: tuple[bytes | str, bytes | str] | None = None,
        headers: HeadersAsMapping | HeadersAsSequence | None = None,
        ssl_context: ssl.SSLContext | None = None,
    ):
        self.url = enforce_url(url, name="url")
        self.headers = enforce_headers(headers, name="headers")
        self.ssl_context = ssl_context

        if auth is not None:
            username = enforce_bytes(auth[0], name="auth")
            password = enforce_bytes(auth[1], name="auth")
            userpass = username + b":" + password
            authorization = b"Basic " + base64.b64encode(userpass)
            self.auth: tuple[bytes, bytes] | None = (username, password)
            self.headers = [(b"Proxy-Authorization", authorization)] + self.headers
        else:
self.auth = None
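The new `Proxy` model pairs with the `proxy=...` argument that `ConnectionPool` gains later in this commit. A minimal sketch (the proxy endpoint and credentials are placeholders):

```python
import httpcore

proxy = httpcore.Proxy(
    url="http://127.0.0.1:8080/",
    auth=("username", "password"),  # emitted as a Proxy-Authorization header, per above
)
pool = httpcore.ConnectionPool(proxy=proxy)
response = pool.request("GET", "https://www.example.com/")
```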
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,12 +1,14 @@
from __future__ import annotations

import itertools
import logging
import ssl
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type
import types
import typing

from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout
from .._exceptions import ConnectError, ConnectTimeout
from .._models import Origin, Request, Response
from .._ssl import default_ssl_context
from .._synchronization import Lock
@@ -20,25 +22,32 @@ RETRIES_BACKOFF_FACTOR = 0.5  # 0s, 0.5s, 1s, 2s, 4s, etc.
logger = logging.getLogger("httpcore.connection")


def exponential_backoff(factor: float) -> Iterator[float]:
def exponential_backoff(factor: float) -> typing.Iterator[float]:
    """
    Generate a geometric sequence that has a ratio of 2 and starts with 0.

    For example:
    - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
    - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
    """
    yield 0
    for n in itertools.count(2):
        yield factor * (2 ** (n - 2))
    for n in itertools.count():
        yield factor * 2**n
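The rewrite is behaviour-preserving; it simply drops the offset bookkeeping. Using the function as defined above, with the module's `RETRIES_BACKOFF_FACTOR = 0.5`, the sequence is unchanged:

```python
delays = exponential_backoff(0.5)
print([next(delays) for _ in range(5)])  # [0, 0.5, 1.0, 2.0, 4.0]
```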


class HTTPConnection(ConnectionInterface):
    def __init__(
        self,
        origin: Origin,
        ssl_context: Optional[ssl.SSLContext] = None,
        keepalive_expiry: Optional[float] = None,
        ssl_context: ssl.SSLContext | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[NetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
        local_address: str | None = None,
        uds: str | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        self._origin = origin
        self._ssl_context = ssl_context
@@ -52,7 +61,7 @@ class HTTPConnection(ConnectionInterface):
        self._network_backend: NetworkBackend = (
            SyncBackend() if network_backend is None else network_backend
        )
        self._connection: Optional[ConnectionInterface] = None
        self._connection: ConnectionInterface | None = None
        self._connect_failed: bool = False
        self._request_lock = Lock()
        self._socket_options = socket_options
@@ -63,9 +72,9 @@ class HTTPConnection(ConnectionInterface):
                f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
            )

        with self._request_lock:
            if self._connection is None:
                try:
        try:
            with self._request_lock:
                if self._connection is None:
                    stream = self._connect(request)

                    ssl_object = stream.get_extra_info("ssl_object")
@@ -87,11 +96,9 @@ class HTTPConnection(ConnectionInterface):
                        stream=stream,
                        keepalive_expiry=self._keepalive_expiry,
                    )
                except Exception as exc:
                    self._connect_failed = True
                    raise exc
            elif not self._connection.is_available():
                raise ConnectionNotAvailable()
        except BaseException as exc:
            self._connect_failed = True
            raise exc

        return self._connection.handle_request(request)

@@ -130,7 +137,7 @@ class HTTPConnection(ConnectionInterface):
            )
            trace.return_value = stream

        if self._origin.scheme == b"https":
        if self._origin.scheme in (b"https", b"wss"):
            ssl_context = (
                default_ssl_context()
                if self._ssl_context is None
@@ -203,13 +210,13 @@ class HTTPConnection(ConnectionInterface):
    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    def __enter__(self) -> "HTTPConnection":
    def __enter__(self) -> HTTPConnection:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()
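As the comment above notes, the context-manager protocol is mainly for driving a single connection directly, for instance in tests. A hedged sketch of that direct usage:

```python
import httpcore

origin = httpcore.Origin(scheme=b"https", host=b"www.example.com", port=443)
with httpcore.HTTPConnection(origin=origin) as connection:
    # A connection only accepts requests that match its origin.
    response = connection.request("GET", "https://www.example.com/")
```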

@@ -1,41 +1,44 @@
from __future__ import annotations

import ssl
import sys
from types import TracebackType
from typing import Iterable, Iterator, Iterable, List, Optional, Type
import types
import typing

from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import Event, Lock, ShieldCancellation
from .._models import Origin, Proxy, Request, Response
from .._synchronization import Event, ShieldCancellation, ThreadLock
from .connection import HTTPConnection
from .interfaces import ConnectionInterface, RequestInterface


class RequestStatus:
    def __init__(self, request: Request):
class PoolRequest:
    def __init__(self, request: Request) -> None:
        self.request = request
        self.connection: Optional[ConnectionInterface] = None
        self.connection: ConnectionInterface | None = None
        self._connection_acquired = Event()

    def set_connection(self, connection: ConnectionInterface) -> None:
        assert self.connection is None
    def assign_to_connection(self, connection: ConnectionInterface | None) -> None:
        self.connection = connection
        self._connection_acquired.set()

    def unset_connection(self) -> None:
        assert self.connection is not None
    def clear_connection(self) -> None:
        self.connection = None
        self._connection_acquired = Event()

    def wait_for_connection(
        self, timeout: Optional[float] = None
        self, timeout: float | None = None
    ) -> ConnectionInterface:
        if self.connection is None:
            self._connection_acquired.wait(timeout=timeout)
        assert self.connection is not None
        return self.connection

    def is_queued(self) -> bool:
        return self.connection is None


class ConnectionPool(RequestInterface):
    """
@@ -44,17 +47,18 @@ class ConnectionPool(RequestInterface):

    def __init__(
        self,
        ssl_context: Optional[ssl.SSLContext] = None,
        max_connections: Optional[int] = 10,
        max_keepalive_connections: Optional[int] = None,
        keepalive_expiry: Optional[float] = None,
        ssl_context: ssl.SSLContext | None = None,
        proxy: Proxy | None = None,
        max_connections: int | None = 10,
        max_keepalive_connections: int | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[NetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
        local_address: str | None = None,
        uds: str | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        """
        A connection pool for making HTTP requests.
@@ -86,7 +90,7 @@ class ConnectionPool(RequestInterface):
            in the TCP socket when the connection was established.
        """
        self._ssl_context = ssl_context

        self._proxy = proxy
        self._max_connections = (
            sys.maxsize if max_connections is None else max_connections
        )
@@ -106,15 +110,61 @@ class ConnectionPool(RequestInterface):
        self._local_address = local_address
        self._uds = uds

        self._pool: List[ConnectionInterface] = []
        self._requests: List[RequestStatus] = []
        self._pool_lock = Lock()
        self._network_backend = (
            SyncBackend() if network_backend is None else network_backend
        )
        self._socket_options = socket_options

        # The mutable state on a connection pool is the queue of incoming requests,
        # and the set of connections that are servicing those requests.
        self._connections: list[ConnectionInterface] = []
        self._requests: list[PoolRequest] = []

        # We only mutate the state of the connection pool within an 'optional_thread_lock'
        # context. This holds a threading lock unless we're running in async mode,
        # in which case it is a no-op.
        self._optional_thread_lock = ThreadLock()

    def create_connection(self, origin: Origin) -> ConnectionInterface:
        if self._proxy is not None:
            if self._proxy.url.scheme in (b"socks5", b"socks5h"):
                from .socks_proxy import Socks5Connection

                return Socks5Connection(
                    proxy_origin=self._proxy.url.origin,
                    proxy_auth=self._proxy.auth,
                    remote_origin=origin,
                    ssl_context=self._ssl_context,
                    keepalive_expiry=self._keepalive_expiry,
                    http1=self._http1,
                    http2=self._http2,
                    network_backend=self._network_backend,
                )
            elif origin.scheme == b"http":
                from .http_proxy import ForwardHTTPConnection

                return ForwardHTTPConnection(
                    proxy_origin=self._proxy.url.origin,
                    proxy_headers=self._proxy.headers,
                    proxy_ssl_context=self._proxy.ssl_context,
                    remote_origin=origin,
                    keepalive_expiry=self._keepalive_expiry,
                    network_backend=self._network_backend,
                )
            from .http_proxy import TunnelHTTPConnection

            return TunnelHTTPConnection(
                proxy_origin=self._proxy.url.origin,
                proxy_headers=self._proxy.headers,
                proxy_ssl_context=self._proxy.ssl_context,
                remote_origin=origin,
                ssl_context=self._ssl_context,
                keepalive_expiry=self._keepalive_expiry,
                http1=self._http1,
                http2=self._http2,
                network_backend=self._network_backend,
            )

        return HTTPConnection(
            origin=origin,
            ssl_context=self._ssl_context,
@@ -129,7 +179,7 @@ class ConnectionPool(RequestInterface):
        )

    @property
    def connections(self) -> List[ConnectionInterface]:
    def connections(self) -> list[ConnectionInterface]:
        """
        Return a list of the connections currently in the pool.

@@ -144,64 +194,7 @@ class ConnectionPool(RequestInterface):
        ]
        ```
        """
        return list(self._pool)

    def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
        """
        Attempt to provide a connection that can handle the given origin.
        """
        origin = status.request.url.origin

        # If there are queued requests in front of us, then don't acquire a
        # connection. We handle requests strictly in order.
        waiting = [s for s in self._requests if s.connection is None]
        if waiting and waiting[0] is not status:
            return False

        # Reuse an existing connection if one is currently available.
        for idx, connection in enumerate(self._pool):
            if connection.can_handle_request(origin) and connection.is_available():
                self._pool.pop(idx)
                self._pool.insert(0, connection)
                status.set_connection(connection)
                return True

        # If the pool is currently full, attempt to close one idle connection.
        if len(self._pool) >= self._max_connections:
            for idx, connection in reversed(list(enumerate(self._pool))):
                if connection.is_idle():
                    connection.close()
                    self._pool.pop(idx)
                    break

        # If the pool is still full, then we cannot acquire a connection.
        if len(self._pool) >= self._max_connections:
            return False

        # Otherwise create a new connection.
        connection = self.create_connection(origin)
        self._pool.insert(0, connection)
        status.set_connection(connection)
        return True

    def _close_expired_connections(self) -> None:
        """
        Clean up the connection pool by closing off any connections that have expired.
        """
        # Close any connections that have expired their keep-alive time.
        for idx, connection in reversed(list(enumerate(self._pool))):
            if connection.has_expired():
                connection.close()
                self._pool.pop(idx)

        # If the pool size exceeds the maximum number of allowed keep-alive connections,
        # then close off idle connections as required.
        pool_size = len(self._pool)
        for idx, connection in reversed(list(enumerate(self._pool))):
            if connection.is_idle() and pool_size > self._max_keepalive_connections:
                connection.close()
                self._pool.pop(idx)
                pool_size -= 1
        return list(self._connections)
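Note how the new `create_connection` dispatches on the proxy scheme alone: `socks5`/`socks5h` URLs go to `Socks5Connection`, plain `http://` target origins are forwarded via `ForwardHTTPConnection`, and everything else is tunnelled with `TunnelHTTPConnection`. A quick sketch of the pool introspection the `connections` property gives you (the exact repr is illustrative, not guaranteed):

```python
import httpcore

pool = httpcore.ConnectionPool()
pool.request("GET", "https://www.example.com/")
print(pool.connections)
# e.g. [<HTTPConnection ['https://www.example.com:443', HTTP/1.1, IDLE, Request Count: 1]>]
```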

    def handle_request(self, request: Request) -> Response:
        """
@@ -219,138 +212,209 @@ class ConnectionPool(RequestInterface):
                f"Request URL has an unsupported protocol '{scheme}://'."
            )

        status = RequestStatus(request)
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("pool", None)

        with self._pool_lock:
            self._requests.append(status)
            self._close_expired_connections()
            self._attempt_to_acquire_connection(status)
        with self._optional_thread_lock:
            # Add the incoming request to our request queue.
            pool_request = PoolRequest(request)
            self._requests.append(pool_request)

        while True:
            timeouts = request.extensions.get("timeout", {})
            timeout = timeouts.get("pool", None)
            try:
                connection = status.wait_for_connection(timeout=timeout)
            except BaseException as exc:
                # If we timeout here, or if the task is cancelled, then make
                # sure to remove the request from the queue before bubbling
                # up the exception.
                with self._pool_lock:
                    # Ensure only remove when task exists.
                    if status in self._requests:
                        self._requests.remove(status)
                raise exc
        try:
            while True:
                with self._optional_thread_lock:
                    # Assign incoming requests to available connections,
                    # closing or creating new connections as required.
                    closing = self._assign_requests_to_connections()
                self._close_connections(closing)

            try:
                response = connection.handle_request(request)
            except ConnectionNotAvailable:
                # The ConnectionNotAvailable exception is a special case, that
                # indicates we need to retry the request on a new connection.
                #
                # The most common case where this can occur is when multiple
                # requests are queued waiting for a single connection, which
                # might end up as an HTTP/2 connection, but which actually ends
                # up as HTTP/1.1.
                with self._pool_lock:
                    # Maintain our position in the request queue, but reset the
                    # status so that the request becomes queued again.
                    status.unset_connection()
                    self._attempt_to_acquire_connection(status)
            except BaseException as exc:
                with ShieldCancellation():
                    self.response_closed(status)
                raise exc
            else:
                break
                # Wait until this request has an assigned connection.
                connection = pool_request.wait_for_connection(timeout=timeout)

        # When we return the response, we wrap the stream in a special class
        # that handles notifying the connection pool once the response
        # has been released.
        assert isinstance(response.stream, Iterable)
                try:
                    # Send the request on the assigned connection.
                    response = connection.handle_request(
                        pool_request.request
                    )
                except ConnectionNotAvailable:
                    # In some cases a connection may initially be available to
                    # handle a request, but then become unavailable.
                    #
                    # In this case we clear the connection and try again.
                    pool_request.clear_connection()
                else:
                    break  # pragma: nocover

        except BaseException as exc:
            with self._optional_thread_lock:
                # For any exception or cancellation we remove the request from
                # the queue, and then re-assign requests to connections.
                self._requests.remove(pool_request)
                closing = self._assign_requests_to_connections()

            self._close_connections(closing)
            raise exc from None

        # Return the response. Note that in this case we still have to manage
        # the point at which the response is closed.
        assert isinstance(response.stream, typing.Iterable)
        return Response(
            status=response.status,
            headers=response.headers,
|
||||
content=ConnectionPoolByteStream(response.stream, self, status),
|
||||
content=PoolByteStream(
|
||||
stream=response.stream, pool_request=pool_request, pool=self
|
||||
),
|
||||
extensions=response.extensions,
|
||||
)
|
||||
|
||||
def response_closed(self, status: RequestStatus) -> None:
|
||||
def _assign_requests_to_connections(self) -> list[ConnectionInterface]:
|
||||
"""
|
||||
This method acts as a callback once the request/response cycle is complete.
|
||||
Manage the state of the connection pool, assigning incoming
|
||||
requests to connections as available.
|
||||
|
||||
It is called into from the `ConnectionPoolByteStream.close()` method.
|
||||
Called whenever a new request is added or removed from the pool.
|
||||
|
||||
Any closing connections are returned, allowing the I/O for closing
|
||||
those connections to be handled seperately.
|
||||
"""
|
||||
assert status.connection is not None
|
||||
connection = status.connection
|
||||
closing_connections = []
|
||||
|
||||
with self._pool_lock:
|
||||
# Update the state of the connection pool.
|
||||
if status in self._requests:
|
||||
self._requests.remove(status)
|
||||
# First we handle cleaning up any connections that are closed,
|
||||
# have expired their keep-alive, or surplus idle connections.
|
||||
for connection in list(self._connections):
|
||||
if connection.is_closed():
|
||||
# log: "removing closed connection"
|
||||
self._connections.remove(connection)
|
||||
elif connection.has_expired():
|
||||
# log: "closing expired connection"
|
||||
self._connections.remove(connection)
|
||||
closing_connections.append(connection)
|
||||
elif (
|
||||
connection.is_idle()
|
||||
and len([connection.is_idle() for connection in self._connections])
|
||||
> self._max_keepalive_connections
|
||||
):
|
||||
# log: "closing idle connection"
|
||||
self._connections.remove(connection)
|
||||
closing_connections.append(connection)
|
||||
|
||||
if connection.is_closed() and connection in self._pool:
|
||||
self._pool.remove(connection)
|
||||
# Assign queued requests to connections.
|
||||
queued_requests = [request for request in self._requests if request.is_queued()]
|
||||
for pool_request in queued_requests:
|
||||
origin = pool_request.request.url.origin
|
||||
available_connections = [
|
||||
connection
|
||||
for connection in self._connections
|
||||
if connection.can_handle_request(origin) and connection.is_available()
|
||||
]
|
||||
idle_connections = [
|
||||
connection for connection in self._connections if connection.is_idle()
|
||||
]
|
||||
|
||||
# Since we've had a response closed, it's possible we'll now be able
|
||||
# to service one or more requests that are currently pending.
|
||||
for status in self._requests:
|
||||
if status.connection is None:
|
||||
acquired = self._attempt_to_acquire_connection(status)
|
||||
# If we could not acquire a connection for a queued request
|
||||
# then we don't need to check anymore requests that are
|
||||
# queued later behind it.
|
||||
if not acquired:
|
||||
break
|
||||
# There are three cases for how we may be able to handle the request:
|
||||
#
|
||||
# 1. There is an existing connection that can handle the request.
|
||||
# 2. We can create a new connection to handle the request.
|
||||
# 3. We can close an idle connection and then create a new connection
|
||||
# to handle the request.
|
||||
if available_connections:
|
||||
# log: "reusing existing connection"
|
||||
connection = available_connections[0]
|
||||
pool_request.assign_to_connection(connection)
|
||||
elif len(self._connections) < self._max_connections:
|
||||
# log: "creating new connection"
|
||||
connection = self.create_connection(origin)
|
||||
self._connections.append(connection)
|
||||
pool_request.assign_to_connection(connection)
|
||||
elif idle_connections:
|
||||
# log: "closing idle connection"
|
||||
connection = idle_connections[0]
|
||||
self._connections.remove(connection)
|
||||
closing_connections.append(connection)
|
||||
# log: "creating new connection"
|
||||
connection = self.create_connection(origin)
|
||||
self._connections.append(connection)
|
||||
pool_request.assign_to_connection(connection)
|
||||
|
||||
# Housekeeping.
|
||||
self._close_expired_connections()
|
||||
return closing_connections
|
||||
|
||||
def _close_connections(self, closing: list[ConnectionInterface]) -> None:
|
||||
# Close connections which have been removed from the pool.
|
||||
with ShieldCancellation():
|
||||
for connection in closing:
|
||||
connection.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close any connections in the pool.
|
||||
"""
|
||||
with self._pool_lock:
|
||||
for connection in self._pool:
|
||||
connection.close()
|
||||
self._pool = []
|
||||
self._requests = []
|
||||
# Explicitly close the connection pool.
|
||||
# Clears all existing requests and connections.
|
||||
with self._optional_thread_lock:
|
||||
closing_connections = list(self._connections)
|
||||
self._connections = []
|
||||
self._close_connections(closing_connections)
|
||||
|
||||
def __enter__(self) -> "ConnectionPool":
|
||||
def __enter__(self) -> ConnectionPool:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: types.TracebackType | None = None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
with self._optional_thread_lock:
|
||||
request_is_queued = [request.is_queued() for request in self._requests]
|
||||
connection_is_idle = [
|
||||
connection.is_idle() for connection in self._connections
|
||||
]
|
||||
|
||||
class ConnectionPoolByteStream:
|
||||
"""
|
||||
A wrapper around the response byte stream, that additionally handles
|
||||
notifying the connection pool when the response has been closed.
|
||||
"""
|
||||
num_active_requests = request_is_queued.count(False)
|
||||
num_queued_requests = request_is_queued.count(True)
|
||||
num_active_connections = connection_is_idle.count(False)
|
||||
num_idle_connections = connection_is_idle.count(True)
|
||||
|
||||
requests_info = (
|
||||
f"Requests: {num_active_requests} active, {num_queued_requests} queued"
|
||||
)
|
||||
connection_info = (
|
||||
f"Connections: {num_active_connections} active, {num_idle_connections} idle"
|
||||
)
|
||||
|
||||
return f"<{class_name} [{requests_info} | {connection_info}]>"
|
||||
|
||||
|
||||
class PoolByteStream:
|
||||
def __init__(
|
||||
self,
|
||||
stream: Iterable[bytes],
|
||||
stream: typing.Iterable[bytes],
|
||||
pool_request: PoolRequest,
|
||||
pool: ConnectionPool,
|
||||
status: RequestStatus,
|
||||
) -> None:
|
||||
self._stream = stream
|
||||
self._pool_request = pool_request
|
||||
self._pool = pool
|
||||
self._status = status
|
||||
self._closed = False
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
for part in self._stream:
|
||||
yield part
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
try:
|
||||
for part in self._stream:
|
||||
yield part
|
||||
except BaseException as exc:
|
||||
self.close()
|
||||
raise exc from None
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
if hasattr(self._stream, "close"):
|
||||
self._stream.close()
|
||||
finally:
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
with ShieldCancellation():
|
||||
self._pool.response_closed(self._status)
|
||||
if hasattr(self._stream, "close"):
|
||||
self._stream.close()
|
||||
|
||||
with self._pool._optional_thread_lock:
|
||||
self._pool._requests.remove(self._pool_request)
|
||||
closing = self._pool._assign_requests_to_connections()
|
||||
|
||||
self._pool._close_connections(closing)
|
||||
|
||||
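The rework above replaces the per-request `RequestStatus` machinery with a single `_assign_requests_to_connections()` pass and adds a diagnostic `__repr__`. A minimal usage sketch of the resulting pool behaviour (not part of the diff; assumes httpcore >= 1.0 is installed and www.example.com is reachable):

```python
import httpcore

with httpcore.ConnectionPool(max_connections=10) as pool:
    response = pool.request("GET", "https://www.example.com/")
    print(response.status)  # e.g. 200
    # The new __repr__ summarises pool state, along the lines of:
    # <ConnectionPool [Requests: 0 active, 0 queued | Connections: 0 active, 1 idle]>
    print(pool)
```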
@@ -1,17 +1,11 @@
from __future__ import annotations

import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)
import types
import typing

import h11

@@ -20,6 +14,7 @@ from .._exceptions import (
    ConnectionNotAvailable,
    LocalProtocolError,
    RemoteProtocolError,
    WriteError,
    map_exceptions,
)
from .._models import Origin, Request, Response
@@ -31,7 +26,7 @@ logger = logging.getLogger("httpcore.http11")


# A subset of `h11.Event` types supported by `_send_event`
H11SendEvent = Union[
H11SendEvent = typing.Union[
    h11.Request,
    h11.Data,
    h11.EndOfMessage,
@@ -53,12 +48,12 @@ class HTTP11Connection(ConnectionInterface):
        self,
        origin: Origin,
        stream: NetworkStream,
        keepalive_expiry: Optional[float] = None,
        keepalive_expiry: float | None = None,
    ) -> None:
        self._origin = origin
        self._network_stream = stream
        self._keepalive_expiry: Optional[float] = keepalive_expiry
        self._expire_at: Optional[float] = None
        self._keepalive_expiry: float | None = keepalive_expiry
        self._expire_at: float | None = None
        self._state = HTTPConnectionState.NEW
        self._state_lock = Lock()
        self._request_count = 0
@@ -84,10 +79,21 @@ class HTTP11Connection(ConnectionInterface):

        try:
            kwargs = {"request": request}
            with Trace("send_request_headers", logger, request, kwargs) as trace:
                self._send_request_headers(**kwargs)
            with Trace("send_request_body", logger, request, kwargs) as trace:
                self._send_request_body(**kwargs)
            try:
                with Trace(
                    "send_request_headers", logger, request, kwargs
                ) as trace:
                    self._send_request_headers(**kwargs)
                with Trace("send_request_body", logger, request, kwargs) as trace:
                    self._send_request_body(**kwargs)
            except WriteError:
                # If we get a write error while we're writing the request,
                # then we suppress this error and move on to attempting to
                # read the response. Servers can sometimes close the request
                # pre-emptively and then respond with a well formed HTTP
                # error response.
                pass

            with Trace(
                "receive_response_headers", logger, request, kwargs
            ) as trace:
@@ -96,6 +102,7 @@ class HTTP11Connection(ConnectionInterface):
                    status,
                    reason_phrase,
                    headers,
                    trailing_data,
                ) = self._receive_response_headers(**kwargs)
                trace.return_value = (
                    http_version,
@@ -104,6 +111,14 @@ class HTTP11Connection(ConnectionInterface):
                    headers,
                )

            network_stream = self._network_stream

            # CONNECT or Upgrade request
            if (status == 101) or (
                (request.method == b"CONNECT") and (200 <= status < 300)
            ):
                network_stream = HTTP11UpgradeStream(network_stream, trailing_data)

            return Response(
                status=status,
                headers=headers,
@@ -111,7 +126,7 @@ class HTTP11Connection(ConnectionInterface):
                extensions={
                    "http_version": http_version,
                    "reason_phrase": reason_phrase,
                    "network_stream": self._network_stream,
                    "network_stream": network_stream,
                },
            )
        except BaseException as exc:
@@ -138,16 +153,14 @@ class HTTP11Connection(ConnectionInterface):
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("write", None)

        assert isinstance(request.stream, Iterable)
        assert isinstance(request.stream, typing.Iterable)
        for chunk in request.stream:
            event = h11.Data(data=chunk)
            self._send_event(event, timeout=timeout)

        self._send_event(h11.EndOfMessage(), timeout=timeout)

    def _send_event(
        self, event: h11.Event, timeout: Optional[float] = None
    ) -> None:
    def _send_event(self, event: h11.Event, timeout: float | None = None) -> None:
        bytes_to_send = self._h11_state.send(event)
        if bytes_to_send is not None:
            self._network_stream.write(bytes_to_send, timeout=timeout)
@@ -156,7 +169,7 @@ class HTTP11Connection(ConnectionInterface):

    def _receive_response_headers(
        self, request: Request
    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
    ) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

@@ -176,9 +189,13 @@ class HTTP11Connection(ConnectionInterface):
        # raw header casing, rather than the enforced lowercase headers.
        headers = event.headers.raw_items()

        return http_version, event.status_code, event.reason, headers
        trailing_data, _ = self._h11_state.trailing_data

    def _receive_response_body(self, request: Request) -> Iterator[bytes]:
        return http_version, event.status_code, event.reason, headers, trailing_data

    def _receive_response_body(
        self, request: Request
    ) -> typing.Iterator[bytes]:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

@@ -190,8 +207,8 @@ class HTTP11Connection(ConnectionInterface):
                break

    def _receive_event(
        self, timeout: Optional[float] = None
    ) -> Union[h11.Event, Type[h11.PAUSED]]:
        self, timeout: float | None = None
    ) -> h11.Event | type[h11.PAUSED]:
        while True:
            with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
                event = self._h11_state.next_event()
@@ -216,7 +233,7 @@ class HTTP11Connection(ConnectionInterface):
                self._h11_state.receive_data(data)
            else:
                # mypy fails to narrow the type in the if statement above
                return cast(Union[h11.Event, Type[h11.PAUSED]], event)
                return event  # type: ignore[return-value]

    def _response_closed(self) -> None:
        with self._state_lock:
@@ -292,14 +309,14 @@ class HTTP11Connection(ConnectionInterface):
    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    def __enter__(self) -> "HTTP11Connection":
    def __enter__(self) -> HTTP11Connection:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()

@@ -310,7 +327,7 @@ class HTTP11ConnectionByteStream:
        self._request = request
        self._closed = False

    def __iter__(self) -> Iterator[bytes]:
    def __iter__(self) -> typing.Iterator[bytes]:
        kwargs = {"request": self._request}
        try:
            with Trace("receive_response_body", logger, self._request, kwargs):
@@ -329,3 +346,34 @@ class HTTP11ConnectionByteStream:
            self._closed = True
            with Trace("response_closed", logger, self._request):
                self._connection._response_closed()


class HTTP11UpgradeStream(NetworkStream):
    def __init__(self, stream: NetworkStream, leading_data: bytes) -> None:
        self._stream = stream
        self._leading_data = leading_data

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        if self._leading_data:
            buffer = self._leading_data[:max_bytes]
            self._leading_data = self._leading_data[max_bytes:]
            return buffer
        else:
            return self._stream.read(max_bytes, timeout)

    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        self._stream.write(buffer, timeout)

    def close(self) -> None:
        self._stream.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        return self._stream.start_tls(ssl_context, server_hostname, timeout)

    def get_extra_info(self, info: str) -> typing.Any:
        return self._stream.get_extra_info(info)
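The `HTTP11UpgradeStream` wrapper above exists because h11 may have already buffered bytes that arrived after the 101/CONNECT response headers; those bytes must be replayed before any further reads hit the raw stream. A standalone sketch of the same buffering pattern (the `FakeStream` class is hypothetical, not part of the diff):

```python
class FakeStream:
    """Stands in for a NetworkStream; returns canned network data."""

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        return b"<from the network>"


class LeadingDataStream:
    """Replay `leading_data` before delegating reads to the wrapped stream."""

    def __init__(self, stream: FakeStream, leading_data: bytes) -> None:
        self._stream = stream
        self._leading_data = leading_data

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        if self._leading_data:
            buffer = self._leading_data[:max_bytes]
            self._leading_data = self._leading_data[max_bytes:]
            return buffer
        return self._stream.read(max_bytes, timeout)


stream = LeadingDataStream(FakeStream(), leading_data=b"abc")
assert stream.read(2) == b"ab"  # served from the buffered trailing data
assert stream.read(2) == b"c"   # buffer fully drains before touching the network
assert stream.read(2) == b"<from the network>"
```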
@@ -1,3 +1,5 @@
from __future__ import annotations

import enum
import logging
import time
@@ -45,14 +47,14 @@ class HTTP2Connection(ConnectionInterface):
        self,
        origin: Origin,
        stream: NetworkStream,
        keepalive_expiry: typing.Optional[float] = None,
        keepalive_expiry: float | None = None,
    ):
        self._origin = origin
        self._network_stream = stream
        self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
        self._keepalive_expiry: float | None = keepalive_expiry
        self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
        self._state = HTTPConnectionState.IDLE
        self._expire_at: typing.Optional[float] = None
        self._expire_at: float | None = None
        self._request_count = 0
        self._init_lock = Lock()
        self._state_lock = Lock()
@@ -63,24 +65,22 @@ class HTTP2Connection(ConnectionInterface):
        self._connection_error = False

        # Mapping from stream ID to response stream events.
        self._events: typing.Dict[
        self._events: dict[
            int,
            typing.Union[
                h2.events.ResponseReceived,
                h2.events.DataReceived,
                h2.events.StreamEnded,
                h2.events.StreamReset,
            list[
                h2.events.ResponseReceived
                | h2.events.DataReceived
                | h2.events.StreamEnded
                | h2.events.StreamReset,
            ],
        ] = {}

        # Connection terminated events are stored as state since
        # we need to handle them for all streams.
        self._connection_terminated: typing.Optional[
            h2.events.ConnectionTerminated
        ] = None
        self._connection_terminated: h2.events.ConnectionTerminated | None = None

        self._read_exception: typing.Optional[Exception] = None
        self._write_exception: typing.Optional[Exception] = None
        self._read_exception: Exception | None = None
        self._write_exception: Exception | None = None

    def handle_request(self, request: Request) -> Response:
        if not self.can_handle_request(request.url.origin):
@@ -104,9 +104,11 @@ class HTTP2Connection(ConnectionInterface):
        with self._init_lock:
            if not self._sent_connection_init:
                try:
                    kwargs = {"request": request}
                    with Trace("send_connection_init", logger, request, kwargs):
                        self._send_connection_init(**kwargs)
                    sci_kwargs = {"request": request}
                    with Trace(
                        "send_connection_init", logger, request, sci_kwargs
                    ):
                        self._send_connection_init(**sci_kwargs)
                except BaseException as exc:
                    with ShieldCancellation():
                        self.close()
@@ -284,7 +286,7 @@ class HTTP2Connection(ConnectionInterface):

    def _receive_response(
        self, request: Request, stream_id: int
    ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
    ) -> tuple[int, list[tuple[bytes, bytes]]]:
        """
        Return the response status code and headers for a given stream ID.
        """
@@ -295,6 +297,7 @@ class HTTP2Connection(ConnectionInterface):

        status_code = 200
        headers = []
        assert event.headers is not None
        for k, v in event.headers:
            if k == b":status":
                status_code = int(v.decode("ascii", errors="ignore"))
@@ -312,6 +315,8 @@ class HTTP2Connection(ConnectionInterface):
        while True:
            event = self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.DataReceived):
                assert event.flow_controlled_length is not None
                assert event.data is not None
                amount = event.flow_controlled_length
                self._h2_state.acknowledge_received_data(amount, stream_id)
                self._write_outgoing_data(request)
@@ -321,9 +326,7 @@ class HTTP2Connection(ConnectionInterface):

    def _receive_stream_event(
        self, request: Request, stream_id: int
    ) -> typing.Union[
        h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
    ]:
    ) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
        """
        Return the next available event for a given stream ID.

@@ -337,7 +340,7 @@ class HTTP2Connection(ConnectionInterface):
        return event

    def _receive_events(
        self, request: Request, stream_id: typing.Optional[int] = None
        self, request: Request, stream_id: int | None = None
    ) -> None:
        """
        Read some data from the network until we see one or more events
@@ -384,7 +387,9 @@ class HTTP2Connection(ConnectionInterface):

        self._write_outgoing_data(request)

    def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
    def _receive_remote_settings_change(
        self, event: h2.events.RemoteSettingsChanged
    ) -> None:
        max_concurrent_streams = event.changed_settings.get(
            h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
        )
@@ -425,9 +430,7 @@ class HTTP2Connection(ConnectionInterface):

    # Wrappers around network read/write operations...

    def _read_incoming_data(
        self, request: Request
    ) -> typing.List[h2.events.Event]:
    def _read_incoming_data(self, request: Request) -> list[h2.events.Event]:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

@@ -451,7 +454,7 @@ class HTTP2Connection(ConnectionInterface):
            self._connection_error = True
            raise exc

        events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)
        events: list[h2.events.Event] = self._h2_state.receive_data(data)

        return events

@@ -544,14 +547,14 @@ class HTTP2Connection(ConnectionInterface):
    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    def __enter__(self) -> "HTTP2Connection":
    def __enter__(self) -> HTTP2Connection:
        return self

    def __exit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]] = None,
        exc_value: typing.Optional[BaseException] = None,
        traceback: typing.Optional[types.TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()
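For context on the `_events` mapping above: h2 is a sans-I/O library, so `receive_data()` returns a list of event objects, which the connection then files away per stream ID. A minimal sketch of that event-dispatch shape, using only documented h2 APIs (the socket I/O is elided):

```python
import h2.connection
import h2.events

conn = h2.connection.H2Connection()
conn.initiate_connection()
outgoing = conn.data_to_send()  # client preface + SETTINGS, to be written to the socket

def dispatch(incoming: bytes) -> None:
    # `incoming` would be bytes read from the network.
    for event in conn.receive_data(incoming):
        if isinstance(event, h2.events.ResponseReceived):
            print("headers for stream", event.stream_id)
        elif isinstance(event, h2.events.DataReceived):
            print("data for stream", event.stream_id)
        elif isinstance(event, h2.events.StreamEnded):
            print("stream complete", event.stream_id)
```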
@@ -1,7 +1,9 @@
from __future__ import annotations

import base64
import logging
import ssl
from base64 import b64encode
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
import typing

from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ProxyError
@@ -22,17 +24,18 @@ from .connection_pool import ConnectionPool
from .http11 import HTTP11Connection
from .interfaces import ConnectionInterface

HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
ByteOrStr = typing.Union[bytes, str]
HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]


logger = logging.getLogger("httpcore.proxy")


def merge_headers(
    default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
    override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
) -> List[Tuple[bytes, bytes]]:
    default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
    override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
) -> list[tuple[bytes, bytes]]:
    """
    Append default_headers and override_headers, de-duplicating if a key exists
    in both cases.
@@ -48,32 +51,28 @@ def merge_headers(
    return default_headers + override_headers


def build_auth_header(username: bytes, password: bytes) -> bytes:
    userpass = username + b":" + password
    return b"Basic " + b64encode(userpass)


class HTTPProxy(ConnectionPool):
class HTTPProxy(ConnectionPool):  # pragma: nocover
    """
    A connection pool that sends requests via an HTTP proxy.
    """

    def __init__(
        self,
        proxy_url: Union[URL, bytes, str],
        proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
        ssl_context: Optional[ssl.SSLContext] = None,
        max_connections: Optional[int] = 10,
        max_keepalive_connections: Optional[int] = None,
        keepalive_expiry: Optional[float] = None,
        proxy_url: URL | bytes | str,
        proxy_auth: tuple[bytes | str, bytes | str] | None = None,
        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
        ssl_context: ssl.SSLContext | None = None,
        proxy_ssl_context: ssl.SSLContext | None = None,
        max_connections: int | None = 10,
        max_keepalive_connections: int | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[NetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
        local_address: str | None = None,
        uds: str | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        """
        A connection pool for making HTTP requests.
@@ -88,6 +87,7 @@ class HTTPProxy(ConnectionPool):
            ssl_context: An SSL context to use for verifying connections.
                If not specified, the default `httpcore.default_ssl_context()`
                will be used.
            proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin.
            max_connections: The maximum number of concurrent HTTP connections that
                the pool should allow. Any attempt to send a request on a pool that
                would exceed this amount will block until a connection is available.
@@ -122,13 +122,23 @@ class HTTPProxy(ConnectionPool):
            uds=uds,
            socket_options=socket_options,
        )
        self._ssl_context = ssl_context

        self._proxy_url = enforce_url(proxy_url, name="proxy_url")
        if (
            self._proxy_url.scheme == b"http" and proxy_ssl_context is not None
        ):  # pragma: no cover
            raise RuntimeError(
                "The `proxy_ssl_context` argument is not allowed for the http scheme"
            )

        self._ssl_context = ssl_context
        self._proxy_ssl_context = proxy_ssl_context
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        if proxy_auth is not None:
            username = enforce_bytes(proxy_auth[0], name="proxy_auth")
            password = enforce_bytes(proxy_auth[1], name="proxy_auth")
            authorization = build_auth_header(username, password)
            userpass = username + b":" + password
            authorization = b"Basic " + base64.b64encode(userpass)
            self._proxy_headers = [
                (b"Proxy-Authorization", authorization)
            ] + self._proxy_headers
@@ -141,12 +151,14 @@ class HTTPProxy(ConnectionPool):
                remote_origin=origin,
                keepalive_expiry=self._keepalive_expiry,
                network_backend=self._network_backend,
                proxy_ssl_context=self._proxy_ssl_context,
            )
        return TunnelHTTPConnection(
            proxy_origin=self._proxy_url.origin,
            proxy_headers=self._proxy_headers,
            remote_origin=origin,
            ssl_context=self._ssl_context,
            proxy_ssl_context=self._proxy_ssl_context,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
@@ -159,16 +171,18 @@ class ForwardHTTPConnection(ConnectionInterface):
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
        keepalive_expiry: Optional[float] = None,
        network_backend: Optional[NetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
        keepalive_expiry: float | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
        proxy_ssl_context: ssl.SSLContext | None = None,
    ) -> None:
        self._connection = HTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
            socket_options=socket_options,
            ssl_context=proxy_ssl_context,
        )
        self._proxy_origin = proxy_origin
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
@@ -221,23 +235,26 @@ class TunnelHTTPConnection(ConnectionInterface):
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        ssl_context: Optional[ssl.SSLContext] = None,
        proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
        keepalive_expiry: Optional[float] = None,
        ssl_context: ssl.SSLContext | None = None,
        proxy_ssl_context: ssl.SSLContext | None = None,
        proxy_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        network_backend: Optional[NetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        self._connection: ConnectionInterface = HTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
            socket_options=socket_options,
            ssl_context=proxy_ssl_context,
        )
        self._proxy_origin = proxy_origin
        self._remote_origin = remote_origin
        self._ssl_context = ssl_context
        self._proxy_ssl_context = proxy_ssl_context
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
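The diff above inlines the old `build_auth_header()` helper at its single call site. The header value itself is just RFC 7617 basic auth over the proxy credentials; for example:

```python
import base64

username, password = b"user", b"pass"
userpass = username + b":" + password
authorization = b"Basic " + base64.b64encode(userpass)
assert authorization == b"Basic dXNlcjpwYXNz"
# Sent to the proxy as: Proxy-Authorization: Basic dXNlcjpwYXNz
```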
@@ -1,5 +1,7 @@
from contextlib import contextmanager
from typing import Iterator, Optional, Union
from __future__ import annotations

import contextlib
import typing

from .._models import (
    URL,
@@ -18,12 +20,12 @@ from .._models import (
class RequestInterface:
    def request(
        self,
        method: Union[bytes, str],
        url: Union[URL, bytes, str],
        method: bytes | str,
        url: URL | bytes | str,
        *,
        headers: HeaderTypes = None,
        content: Union[bytes, Iterator[bytes], None] = None,
        extensions: Optional[Extensions] = None,
        content: bytes | typing.Iterator[bytes] | None = None,
        extensions: Extensions | None = None,
    ) -> Response:
        # Strict type checking on our parameters.
        method = enforce_bytes(method, name="method")
@@ -47,16 +49,16 @@ class RequestInterface:
            response.close()
        return response

    @contextmanager
    @contextlib.contextmanager
    def stream(
        self,
        method: Union[bytes, str],
        url: Union[URL, bytes, str],
        method: bytes | str,
        url: URL | bytes | str,
        *,
        headers: HeaderTypes = None,
        content: Union[bytes, Iterator[bytes], None] = None,
        extensions: Optional[Extensions] = None,
    ) -> Iterator[Response]:
        content: bytes | typing.Iterator[bytes] | None = None,
        extensions: Extensions | None = None,
    ) -> typing.Iterator[Response]:
        # Strict type checking on our parameters.
        method = enforce_bytes(method, name="method")
        url = enforce_url(url, name="url")
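These `request()`/`stream()` interface methods back the package-level convenience API. A hedged usage sketch (assumes network access to www.example.com):

```python
import httpcore

# One-shot request; the response body is fully read before returning.
response = httpcore.request("GET", "https://www.example.com/")
print(response.status, len(response.content))

# Streaming variant; the body is consumed incrementally inside the context.
with httpcore.stream("GET", "https://www.example.com/") as response:
    for chunk in response.iter_stream():
        print(len(chunk))
```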
@@ -1,8 +1,9 @@
from __future__ import annotations

import logging
import ssl
import typing

from socksio import socks5
import socksio

from .._backends.sync import SyncBackend
from .._backends.base import NetworkBackend, NetworkStream
@@ -43,24 +44,24 @@ def _init_socks5_connection(
    *,
    host: bytes,
    port: int,
    auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
    auth: tuple[bytes, bytes] | None = None,
) -> None:
    conn = socks5.SOCKS5Connection()
    conn = socksio.socks5.SOCKS5Connection()

    # Auth method request
    auth_method = (
        socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
        socksio.socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
        if auth is None
        else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
        else socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
    )
    conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
    conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method]))
    outgoing_bytes = conn.data_to_send()
    stream.write(outgoing_bytes)

    # Auth method response
    incoming_bytes = stream.read(max_bytes=4096)
    response = conn.receive_data(incoming_bytes)
    assert isinstance(response, socks5.SOCKS5AuthReply)
    assert isinstance(response, socksio.socks5.SOCKS5AuthReply)
    if response.method != auth_method:
        requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
        responded = AUTH_METHODS.get(response.method, "UNKNOWN")
@@ -68,25 +69,25 @@ def _init_socks5_connection(
            f"Requested {requested} from proxy server, but got {responded}."
        )

    if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
    if response.method == socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
        # Username/password request
        assert auth is not None
        username, password = auth
        conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
        conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password))
        outgoing_bytes = conn.data_to_send()
        stream.write(outgoing_bytes)

        # Username/password response
        incoming_bytes = stream.read(max_bytes=4096)
        response = conn.receive_data(incoming_bytes)
        assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
        assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply)
        if not response.success:
            raise ProxyError("Invalid username/password")

    # Connect request
    conn.send(
        socks5.SOCKS5CommandRequest.from_address(
            socks5.SOCKS5Command.CONNECT, (host, port)
        socksio.socks5.SOCKS5CommandRequest.from_address(
            socksio.socks5.SOCKS5Command.CONNECT, (host, port)
        )
    )
    outgoing_bytes = conn.data_to_send()
@@ -95,31 +96,29 @@ def _init_socks5_connection(
    # Connect response
    incoming_bytes = stream.read(max_bytes=4096)
    response = conn.receive_data(incoming_bytes)
    assert isinstance(response, socks5.SOCKS5Reply)
    if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
    assert isinstance(response, socksio.socks5.SOCKS5Reply)
    if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED:
        reply_code = REPLY_CODES.get(response.reply_code, "UNKNOWN")
        raise ProxyError(f"Proxy Server could not connect: {reply_code}.")


class SOCKSProxy(ConnectionPool):
class SOCKSProxy(ConnectionPool):  # pragma: nocover
    """
    A connection pool that sends requests via a SOCKS proxy.
    """

    def __init__(
        self,
        proxy_url: typing.Union[URL, bytes, str],
        proxy_auth: typing.Optional[
            typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
        ] = None,
        ssl_context: typing.Optional[ssl.SSLContext] = None,
        max_connections: typing.Optional[int] = 10,
        max_keepalive_connections: typing.Optional[int] = None,
        keepalive_expiry: typing.Optional[float] = None,
        proxy_url: URL | bytes | str,
        proxy_auth: tuple[bytes | str, bytes | str] | None = None,
        ssl_context: ssl.SSLContext | None = None,
        max_connections: int | None = 10,
        max_keepalive_connections: int | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        network_backend: typing.Optional[NetworkBackend] = None,
        network_backend: NetworkBackend | None = None,
    ) -> None:
        """
        A connection pool for making HTTP requests.
@@ -167,7 +166,7 @@ class SOCKSProxy(ConnectionPool):
            username, password = proxy_auth
            username_bytes = enforce_bytes(username, name="proxy_auth")
            password_bytes = enforce_bytes(password, name="proxy_auth")
            self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = (
            self._proxy_auth: tuple[bytes, bytes] | None = (
                username_bytes,
                password_bytes,
            )
@@ -192,12 +191,12 @@ class Socks5Connection(ConnectionInterface):
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
        ssl_context: typing.Optional[ssl.SSLContext] = None,
        keepalive_expiry: typing.Optional[float] = None,
        proxy_auth: tuple[bytes, bytes] | None = None,
        ssl_context: ssl.SSLContext | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        network_backend: typing.Optional[NetworkBackend] = None,
        network_backend: NetworkBackend | None = None,
    ) -> None:
        self._proxy_origin = proxy_origin
        self._remote_origin = remote_origin
@@ -211,11 +210,12 @@ class Socks5Connection(ConnectionInterface):
            SyncBackend() if network_backend is None else network_backend
        )
        self._connect_lock = Lock()
        self._connection: typing.Optional[ConnectionInterface] = None
        self._connection: ConnectionInterface | None = None
        self._connect_failed = False

    def handle_request(self, request: Request) -> Response:
        timeouts = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname", None)
        timeout = timeouts.get("connect", None)

        with self._connect_lock:
@@ -258,7 +258,8 @@ class Socks5Connection(ConnectionInterface):

                kwargs = {
                    "ssl_context": ssl_context,
                    "server_hostname": self._remote_origin.host.decode("ascii"),
                    "server_hostname": sni_hostname
                    or self._remote_origin.host.decode("ascii"),
                    "timeout": timeout,
                }
                with Trace("start_tls", logger, request, kwargs) as trace:
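A hedged usage sketch for the `SOCKSProxy` pool above (assumes the `socksio` dependency is installed and that a SOCKS5 proxy is listening locally on port 1080):

```python
import httpcore

proxy = httpcore.SOCKSProxy(
    proxy_url="socks5://127.0.0.1:1080/",
    proxy_auth=(b"username", b"password"),  # optional
)
response = proxy.request("GET", "https://www.example.com/")
print(response.status)
```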
@@ -1,8 +1,7 @@
import threading
from types import TracebackType
from typing import Optional, Type
from __future__ import annotations

import sniffio
import threading
import types

from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions

@@ -11,7 +10,7 @@ from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions

try:
    import trio
except ImportError:  # pragma: nocover
except (ImportError, NotImplementedError):  # pragma: nocover
    trio = None  # type: ignore

try:
@@ -20,7 +19,40 @@ except ImportError:  # pragma: nocover
    anyio = None  # type: ignore


def current_async_library() -> str:
    # Determine if we're running under trio or asyncio.
    # See https://sniffio.readthedocs.io/en/latest/
    try:
        import sniffio
    except ImportError:  # pragma: nocover
        environment = "asyncio"
    else:
        environment = sniffio.current_async_library()

    if environment not in ("asyncio", "trio"):  # pragma: nocover
        raise RuntimeError("Running under an unsupported async environment.")

    if environment == "asyncio" and anyio is None:  # pragma: nocover
        raise RuntimeError(
            "Running with asyncio requires installation of 'httpcore[asyncio]'."
        )

    if environment == "trio" and trio is None:  # pragma: nocover
        raise RuntimeError(
            "Running with trio requires installation of 'httpcore[trio]'."
        )

    return environment
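The new `current_async_library()` helper above degrades gracefully when `sniffio` is missing by assuming asyncio. The detection it wraps can be observed directly with sniffio (a minimal sketch, not part of the diff):

```python
import asyncio

import sniffio

async def main() -> None:
    # Under asyncio this prints "asyncio"; under a trio run loop it
    # would print "trio" instead.
    print(sniffio.current_async_library())

asyncio.run(main())
```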
class AsyncLock:
    """
    This is a standard lock.

    In the sync case `Lock` provides thread locking.
    In the async case `AsyncLock` provides async locking.
    """

    def __init__(self) -> None:
        self._backend = ""

@@ -29,43 +61,55 @@ class AsyncLock:
        Detect if we're running under 'asyncio' or 'trio' and create
        a lock with the correct implementation.
        """
        self._backend = sniffio.current_async_library()
        self._backend = current_async_library()
        if self._backend == "trio":
            if trio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under trio, requires the 'trio' package to be installed."
                )
            self._trio_lock = trio.Lock()
        else:
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )
        elif self._backend == "asyncio":
            self._anyio_lock = anyio.Lock()

    async def __aenter__(self) -> "AsyncLock":
    async def __aenter__(self) -> AsyncLock:
        if not self._backend:
            self.setup()

        if self._backend == "trio":
            await self._trio_lock.acquire()
        else:
        elif self._backend == "asyncio":
            await self._anyio_lock.acquire()

        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        if self._backend == "trio":
            self._trio_lock.release()
        else:
        elif self._backend == "asyncio":
            self._anyio_lock.release()


class AsyncThreadLock:
    """
    This is a threading-only lock for no-I/O contexts.

    In the sync case `ThreadLock` provides thread locking.
    In the async case `AsyncThreadLock` is a no-op.
    """

    def __enter__(self) -> AsyncThreadLock:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        pass


class AsyncEvent:
    def __init__(self) -> None:
        self._backend = ""
@@ -75,18 +119,10 @@ class AsyncEvent:
        Detect if we're running under 'asyncio' or 'trio' and create
        a lock with the correct implementation.
        """
        self._backend = sniffio.current_async_library()
        self._backend = current_async_library()
        if self._backend == "trio":
            if trio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under trio requires the 'trio' package to be installed."
                )
            self._trio_event = trio.Event()
        else:
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )
        elif self._backend == "asyncio":
            self._anyio_event = anyio.Event()

    def set(self) -> None:
@@ -95,30 +131,20 @@ class AsyncEvent:

        if self._backend == "trio":
            self._trio_event.set()
        else:
        elif self._backend == "asyncio":
            self._anyio_event.set()

    async def wait(self, timeout: Optional[float] = None) -> None:
    async def wait(self, timeout: float | None = None) -> None:
        if not self._backend:
            self.setup()

        if self._backend == "trio":
            if trio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under trio requires the 'trio' package to be installed."
                )

            trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
            timeout_or_inf = float("inf") if timeout is None else timeout
            with map_exceptions(trio_exc_map):
                with trio.fail_after(timeout_or_inf):
                    await self._trio_event.wait()
        else:
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )

        elif self._backend == "asyncio":
            anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
            with map_exceptions(anyio_exc_map):
                with anyio.fail_after(timeout):
@@ -135,22 +161,12 @@ class AsyncSemaphore:
        Detect if we're running under 'asyncio' or 'trio' and create
        a semaphore with the correct implementation.
        """
        self._backend = sniffio.current_async_library()
        self._backend = current_async_library()
        if self._backend == "trio":
            if trio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under trio requires the 'trio' package to be installed."
                )

            self._trio_semaphore = trio.Semaphore(
                initial_value=self._bound, max_value=self._bound
            )
        else:
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )

        elif self._backend == "asyncio":
            self._anyio_semaphore = anyio.Semaphore(
                initial_value=self._bound, max_value=self._bound
            )
@@ -161,13 +177,13 @@ class AsyncSemaphore:

        if self._backend == "trio":
            await self._trio_semaphore.acquire()
        else:
        elif self._backend == "asyncio":
            await self._anyio_semaphore.acquire()

    async def release(self) -> None:
        if self._backend == "trio":
            self._trio_semaphore.release()
        else:
        elif self._backend == "asyncio":
            self._anyio_semaphore.release()


@@ -184,39 +200,29 @@ class AsyncShieldCancellation:
        Detect if we're running under 'asyncio' or 'trio' and create
        a shielded scope with the correct implementation.
        """
        self._backend = sniffio.current_async_library()
        self._backend = current_async_library()

        if self._backend == "trio":
            if trio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under trio requires the 'trio' package to be installed."
                )

            self._trio_shield = trio.CancelScope(shield=True)
        else:
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )

        elif self._backend == "asyncio":
            self._anyio_shield = anyio.CancelScope(shield=True)

    def __enter__(self) -> "AsyncShieldCancellation":
    def __enter__(self) -> AsyncShieldCancellation:
        if self._backend == "trio":
            self._trio_shield.__enter__()
        else:
        elif self._backend == "asyncio":
            self._anyio_shield.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        if self._backend == "trio":
            self._trio_shield.__exit__(exc_type, exc_value, traceback)
        else:
        elif self._backend == "asyncio":
            self._anyio_shield.__exit__(exc_type, exc_value, traceback)


@@ -224,18 +230,49 @@ class AsyncShieldCancellation:


class Lock:
    """
    This is a standard lock.

    In the sync case `Lock` provides thread locking.
    In the async case `AsyncLock` provides async locking.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def __enter__(self) -> "Lock":
    def __enter__(self) -> Lock:
        self._lock.acquire()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self._lock.release()


class ThreadLock:
    """
    This is a threading-only lock for no-I/O contexts.

    In the sync case `ThreadLock` provides thread locking.
    In the async case `AsyncThreadLock` is a no-op.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def __enter__(self) -> ThreadLock:
        self._lock.acquire()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self._lock.release()

@@ -247,7 +284,9 @@ class Event:
    def set(self) -> None:
        self._event.set()

    def wait(self, timeout: Optional[float] = None) -> None:
    def wait(self, timeout: float | None = None) -> None:
        if timeout == float("inf"):  # pragma: no cover
            timeout = None
        if not self._event.wait(timeout=timeout):
            raise PoolTimeout()  # pragma: nocover

@@ -267,13 +306,13 @@ class ShieldCancellation:
    # Thread-synchronous codebases don't support cancellation semantics.
    # We have this class because we need to mirror the async and sync
    # cases within our package, but it's just a no-op.
    def __enter__(self) -> "ShieldCancellation":
    def __enter__(self) -> ShieldCancellation:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        pass
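For the sync `Event.wait()` above: `threading.Event.wait(timeout=...)` returns `False` on timeout rather than raising, which is why the code maps a falsy result to `PoolTimeout`. For example:

```python
import threading

event = threading.Event()
# Never set, so the wait times out and returns False.
assert event.wait(timeout=0.01) is False
event.set()
assert event.wait(timeout=0.01) is True
```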
@@ -1,7 +1,9 @@
from __future__ import annotations

import inspect
import logging
from types import TracebackType
from typing import Any, Dict, Optional, Type
import types
import typing

from ._models import Request

@@ -11,8 +13,8 @@ class Trace:
        self,
        name: str,
        logger: logging.Logger,
        request: Optional[Request] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        request: Request | None = None,
        kwargs: dict[str, typing.Any] | None = None,
    ) -> None:
        self.name = name
        self.logger = logger
@@ -21,11 +23,11 @@ class Trace:
        )
        self.debug = self.logger.isEnabledFor(logging.DEBUG)
        self.kwargs = kwargs or {}
        self.return_value: Any = None
        self.return_value: typing.Any = None
        self.should_trace = self.debug or self.trace_extension is not None
        self.prefix = self.logger.name.split(".")[-1]

    def trace(self, name: str, info: Dict[str, Any]) -> None:
    def trace(self, name: str, info: dict[str, typing.Any]) -> None:
        if self.trace_extension is not None:
            prefix_and_name = f"{self.prefix}.{name}"
            ret = self.trace_extension(prefix_and_name, info)
@@ -44,7 +46,7 @@ class Trace:
            message = f"{name} {args}"
            self.logger.debug(message)

    def __enter__(self) -> "Trace":
    def __enter__(self) -> Trace:
        if self.should_trace:
            info = self.kwargs
            self.trace(f"{self.name}.started", info)
@@ -52,9 +54,9 @@ class Trace:

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        if self.should_trace:
            if exc_value is None:
@@ -64,7 +66,7 @@ class Trace:
                info = {"exception": exc_value}
                self.trace(f"{self.name}.failed", info)

    async def atrace(self, name: str, info: Dict[str, Any]) -> None:
    async def atrace(self, name: str, info: dict[str, typing.Any]) -> None:
        if self.trace_extension is not None:
            prefix_and_name = f"{self.prefix}.{name}"
            coro = self.trace_extension(prefix_and_name, info)
@@ -84,7 +86,7 @@ class Trace:
            message = f"{name} {args}"
            self.logger.debug(message)

    async def __aenter__(self) -> "Trace":
    async def __aenter__(self) -> Trace:
        if self.should_trace:
            info = self.kwargs
            await self.atrace(f"{self.name}.started", info)
@@ -92,9 +94,9 @@ class Trace:

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        if self.should_trace:
            if exc_value is None:
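The `Trace` helper above logs `<name>.started` / `.complete` / `.failed` events and also forwards them to a per-request "trace" extension when one is set. A hedged sketch of the logging side only (`Trace` lives in an internal module, so the import path is an assumption and subject to change):

```python
import logging

from httpcore._trace import Trace  # internal API; not a stability guarantee

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("httpcore.example")

with Trace("connect", logger, request=None, kwargs={"host": "example.com"}) as trace:
    trace.return_value = "connected"
# Logs roughly:
#   example connect.started host='example.com'
#   example connect.complete return_value='connected'
```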
@@ -1,10 +1,11 @@
from __future__ import annotations

import select
import socket
import sys
import typing


def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool:
def is_socket_readable(sock: socket.socket | None) -> bool:
    """
    Return whether a socket, as identified by its file descriptor, is readable.
    "A socket is readable" means that the read buffer isn't empty, i.e. that calling
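`is_socket_readable()` is essentially a zero-timeout readiness probe on the socket's file descriptor. The core check can be reproduced with the standard library (a minimal sketch; the actual implementation may prefer `select.poll` over `select.select` depending on platform):

```python
import select
import socket

def readable_now(sock: socket.socket) -> bool:
    # Zero-timeout select: returns immediately, reporting whether a
    # recv() call would have data (or EOF) ready without blocking.
    rready, _, _ = select.select([sock], [], [], 0)
    return bool(rready)

a, b = socket.socketpair()
print(readable_now(a))  # False: nothing has been sent yet.
b.send(b"ping")
print(readable_now(a))  # True: the read buffer is non-empty.
```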