updates
This commit is contained in:
@@ -1,43 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import queue
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence
|
||||
from concurrent.futures import Future
|
||||
from contextlib import AbstractContextManager
|
||||
from types import GeneratorType
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
TypedDict,
|
||||
TypeGuard,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import unquote, urljoin
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
import anyio.from_thread
|
||||
import httpx
|
||||
from anyio.streams.stapled import StapledObjectStream
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
from typing import TypedDict
|
||||
if sys.version_info >= (3, 11): # pragma: no cover
|
||||
from typing import Self
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import Self
|
||||
|
||||
_PortalFactoryType = typing.Callable[
|
||||
[], typing.ContextManager[anyio.abc.BlockingPortal]
|
||||
]
|
||||
try:
|
||||
import httpx
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
raise RuntimeError(
|
||||
"The starlette.testclient module requires the httpx package to be installed.\n"
|
||||
"You can install this with:\n"
|
||||
" $ pip install httpx\n"
|
||||
)
|
||||
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]
|
||||
|
||||
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
|
||||
ASGI2App = typing.Callable[[Scope], ASGIInstance]
|
||||
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
||||
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
|
||||
ASGI2App = Callable[[Scope], ASGIInstance]
|
||||
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]
|
||||
|
||||
|
||||
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
|
||||
_RequestData = Mapping[str, str | Iterable[str] | bytes]
|
||||
|
||||
|
||||
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
|
||||
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
|
||||
if inspect.isclass(app):
|
||||
return hasattr(app, "__await__")
|
||||
return is_async_callable(app)
|
||||
@@ -58,14 +73,24 @@ class _WrapASGI2:
|
||||
|
||||
class _AsyncBackend(TypedDict):
|
||||
backend: str
|
||||
backend_options: typing.Dict[str, typing.Any]
|
||||
backend_options: dict[str, Any]
|
||||
|
||||
|
||||
class _Upgrade(Exception):
|
||||
def __init__(self, session: "WebSocketTestSession") -> None:
|
||||
def __init__(self, session: WebSocketTestSession) -> None:
|
||||
self.session = session
|
||||
|
||||
|
||||
class WebSocketDenialResponse( # type: ignore[misc]
|
||||
httpx.Response,
|
||||
WebSocketDisconnect,
|
||||
):
|
||||
"""
|
||||
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
|
||||
`WebSocket` is closed before being accepted with a `send_denial_response()`.
|
||||
"""
|
||||
|
||||
|
||||
class WebSocketTestSession:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -77,65 +102,60 @@ class WebSocketTestSession:
|
||||
self.scope = scope
|
||||
self.accepted_subprotocol = None
|
||||
self.portal_factory = portal_factory
|
||||
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
|
||||
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
|
||||
self.extra_headers = None
|
||||
|
||||
def __enter__(self) -> "WebSocketTestSession":
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
self.portal = self.exit_stack.enter_context(self.portal_factory())
|
||||
|
||||
try:
|
||||
_: "Future[None]" = self.portal.start_task_soon(self._run)
|
||||
def __enter__(self) -> WebSocketTestSession:
|
||||
with contextlib.ExitStack() as stack:
|
||||
self.portal = portal = stack.enter_context(self.portal_factory())
|
||||
fut, cs = portal.start_task(self._run)
|
||||
stack.callback(fut.result)
|
||||
stack.callback(portal.call, cs.cancel)
|
||||
self.send({"type": "websocket.connect"})
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
except Exception:
|
||||
self.exit_stack.close()
|
||||
raise
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
self.extra_headers = message.get("headers", None)
|
||||
return self
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
self.extra_headers = message.get("headers", None)
|
||||
stack.callback(self.close, 1000)
|
||||
self.exit_stack = stack.pop_all()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
try:
|
||||
self.close(1000)
|
||||
finally:
|
||||
self.exit_stack.close()
|
||||
while not self._send_queue.empty():
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
def __exit__(self, *args: Any) -> bool | None:
|
||||
return self.exit_stack.__exit__(*args)
|
||||
|
||||
async def _run(self) -> None:
|
||||
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
|
||||
"""
|
||||
The sub-thread in which the websocket session runs.
|
||||
"""
|
||||
scope = self.scope
|
||||
receive = self._asgi_receive
|
||||
send = self._asgi_send
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except BaseException as exc:
|
||||
self._send_queue.put(exc)
|
||||
raise
|
||||
send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
|
||||
send_tx, send_rx = send
|
||||
receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
|
||||
receive_tx, receive_rx = receive
|
||||
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
|
||||
self._receive_tx = receive_tx
|
||||
self._send_rx = send_rx
|
||||
task_status.started(cs)
|
||||
await self.app(self.scope, receive_rx.receive, send_tx.send)
|
||||
|
||||
async def _asgi_receive(self) -> Message:
|
||||
while self._receive_queue.empty():
|
||||
await anyio.sleep(0)
|
||||
return self._receive_queue.get()
|
||||
|
||||
async def _asgi_send(self, message: Message) -> None:
|
||||
self._send_queue.put(message)
|
||||
# wait for cs.cancel to be called before closing streams
|
||||
await anyio.sleep_forever()
|
||||
|
||||
def _raise_on_close(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.close":
|
||||
raise WebSocketDisconnect(
|
||||
message.get("code", 1000), message.get("reason", "")
|
||||
)
|
||||
raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
|
||||
elif message["type"] == "websocket.http.response.start":
|
||||
status_code: int = message["status"]
|
||||
headers: list[tuple[bytes, bytes]] = message["headers"]
|
||||
body: list[bytes] = []
|
||||
while True:
|
||||
message = self.receive()
|
||||
assert message["type"] == "websocket.http.response.body"
|
||||
body.append(message["body"])
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
|
||||
|
||||
def send(self, message: Message) -> None:
|
||||
self._receive_queue.put(message)
|
||||
self.portal.call(self._receive_tx.send, message)
|
||||
|
||||
def send_text(self, data: str) -> None:
|
||||
self.send({"type": "websocket.receive", "text": data})
|
||||
@@ -143,35 +163,30 @@ class WebSocketTestSession:
|
||||
def send_bytes(self, data: bytes) -> None:
|
||||
self.send({"type": "websocket.receive", "bytes": data})
|
||||
|
||||
def send_json(self, data: typing.Any, mode: str = "text") -> None:
|
||||
assert mode in ["text", "binary"]
|
||||
text = json.dumps(data, separators=(",", ":"))
|
||||
def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
|
||||
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
if mode == "text":
|
||||
self.send({"type": "websocket.receive", "text": text})
|
||||
else:
|
||||
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
|
||||
|
||||
def close(self, code: int = 1000) -> None:
|
||||
self.send({"type": "websocket.disconnect", "code": code})
|
||||
def close(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
def receive(self) -> Message:
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
return message
|
||||
return self.portal.call(self._send_rx.receive)
|
||||
|
||||
def receive_text(self) -> str:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return message["text"]
|
||||
return cast(str, message["text"])
|
||||
|
||||
def receive_bytes(self) -> bytes:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return message["bytes"]
|
||||
return cast(bytes, message["bytes"])
|
||||
|
||||
def receive_json(self, mode: str = "text") -> typing.Any:
|
||||
assert mode in ["text", "binary"]
|
||||
def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
if mode == "text":
|
||||
@@ -189,13 +204,15 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
raise_server_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
*,
|
||||
app_state: typing.Dict[str, typing.Any],
|
||||
client: tuple[str, int],
|
||||
app_state: dict[str, Any],
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_server_exceptions = raise_server_exceptions
|
||||
self.root_path = root_path
|
||||
self.portal_factory = portal_factory
|
||||
self.app_state = app_state
|
||||
self.client = client
|
||||
|
||||
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||
scheme = request.url.scheme
|
||||
@@ -215,38 +232,36 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
|
||||
# Include the 'host' header.
|
||||
if "host" in request.headers:
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = []
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
elif port == default_port: # pragma: no cover
|
||||
headers = [(b"host", host.encode())]
|
||||
else: # pragma: no cover
|
||||
headers = [(b"host", (f"{host}:{port}").encode())]
|
||||
|
||||
# Include other request headers.
|
||||
headers += [
|
||||
(key.lower().encode(), value.encode())
|
||||
for key, value in request.headers.items()
|
||||
]
|
||||
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
|
||||
|
||||
scope: typing.Dict[str, typing.Any]
|
||||
scope: dict[str, Any]
|
||||
|
||||
if scheme in {"ws", "wss"}:
|
||||
subprotocol = request.headers.get("sec-websocket-protocol", None)
|
||||
if subprotocol is None:
|
||||
subprotocols: typing.Sequence[str] = []
|
||||
subprotocols: Sequence[str] = []
|
||||
else:
|
||||
subprotocols = [value.strip() for value in subprotocol.split(",")]
|
||||
scope = {
|
||||
"type": "websocket",
|
||||
"path": unquote(path),
|
||||
"raw_path": raw_path,
|
||||
"raw_path": raw_path.split(b"?", 1)[0],
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": ["testclient", 50000],
|
||||
"client": self.client,
|
||||
"server": [host, port],
|
||||
"subprotocols": subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
session = WebSocketTestSession(self.app, scope, self.portal_factory)
|
||||
raise _Upgrade(session)
|
||||
@@ -256,12 +271,12 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"path": unquote(path),
|
||||
"raw_path": raw_path,
|
||||
"raw_path": raw_path.split(b"?", 1)[0],
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": ["testclient", 50000],
|
||||
"client": self.client,
|
||||
"server": [host, port],
|
||||
"extensions": {"http.response.debug": {}},
|
||||
"state": self.app_state.copy(),
|
||||
@@ -270,7 +285,7 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
request_complete = False
|
||||
response_started = False
|
||||
response_complete: anyio.Event
|
||||
raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
|
||||
raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
|
||||
template = None
|
||||
context = None
|
||||
|
||||
@@ -306,22 +321,13 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
nonlocal raw_kwargs, response_started, template, context
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
assert (
|
||||
not response_started
|
||||
), 'Received multiple "http.response.start" messages.'
|
||||
assert not response_started, 'Received multiple "http.response.start" messages.'
|
||||
raw_kwargs["status_code"] = message["status"]
|
||||
raw_kwargs["headers"] = [
|
||||
(key.decode(), value.decode())
|
||||
for key, value in message.get("headers", [])
|
||||
]
|
||||
raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
|
||||
response_started = True
|
||||
elif message["type"] == "http.response.body":
|
||||
assert (
|
||||
response_started
|
||||
), 'Received "http.response.body" without "http.response.start".'
|
||||
assert (
|
||||
not response_complete.is_set()
|
||||
), 'Received "http.response.body" after response completed.'
|
||||
assert response_started, 'Received "http.response.body" without "http.response.start".'
|
||||
assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if request.method != "HEAD":
|
||||
@@ -361,8 +367,8 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
|
||||
class TestClient(httpx.Client):
|
||||
__test__ = False
|
||||
task: "Future[None]"
|
||||
portal: typing.Optional[anyio.abc.BlockingPortal] = None
|
||||
task: Future[None]
|
||||
portal: anyio.abc.BlockingPortal | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -370,110 +376,84 @@ class TestClient(httpx.Client):
|
||||
base_url: str = "http://testserver",
|
||||
raise_server_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
backend: str = "asyncio",
|
||||
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
cookies: httpx._client.CookieTypes = None,
|
||||
headers: typing.Dict[str, str] = None,
|
||||
backend: Literal["asyncio", "trio"] = "asyncio",
|
||||
backend_options: dict[str, Any] | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
follow_redirects: bool = True,
|
||||
client: tuple[str, int] = ("testclient", 50000),
|
||||
) -> None:
|
||||
self.async_backend = _AsyncBackend(
|
||||
backend=backend, backend_options=backend_options or {}
|
||||
)
|
||||
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
|
||||
if _is_asgi3(app):
|
||||
app = typing.cast(ASGI3App, app)
|
||||
asgi_app = app
|
||||
else:
|
||||
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
|
||||
app = cast(ASGI2App, app) # type: ignore[assignment]
|
||||
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
|
||||
self.app = asgi_app
|
||||
self.app_state: typing.Dict[str, typing.Any] = {}
|
||||
self.app_state: dict[str, Any] = {}
|
||||
transport = _TestClientTransport(
|
||||
self.app,
|
||||
portal_factory=self._portal_factory,
|
||||
raise_server_exceptions=raise_server_exceptions,
|
||||
root_path=root_path,
|
||||
app_state=self.app_state,
|
||||
client=client,
|
||||
)
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.setdefault("user-agent", "testclient")
|
||||
super().__init__(
|
||||
app=self.app,
|
||||
base_url=base_url,
|
||||
headers=headers,
|
||||
transport=transport,
|
||||
follow_redirects=True,
|
||||
follow_redirects=follow_redirects,
|
||||
cookies=cookies,
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
|
||||
def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
|
||||
if self.portal is not None:
|
||||
yield self.portal
|
||||
else:
|
||||
with anyio.from_thread.start_blocking_portal(
|
||||
**self.async_backend
|
||||
) as portal:
|
||||
with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
|
||||
yield portal
|
||||
|
||||
def _choose_redirect_arg(
|
||||
self,
|
||||
follow_redirects: typing.Optional[bool],
|
||||
allow_redirects: typing.Optional[bool],
|
||||
) -> typing.Union[bool, httpx._client.UseClientDefault]:
|
||||
redirect: typing.Union[
|
||||
bool, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT
|
||||
if allow_redirects is not None:
|
||||
message = (
|
||||
"The `allow_redirects` argument is deprecated. "
|
||||
"Use `follow_redirects` instead."
|
||||
)
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
redirect = allow_redirects
|
||||
if follow_redirects is not None:
|
||||
redirect = follow_redirects
|
||||
elif allow_redirects is not None and follow_redirects is not None:
|
||||
raise RuntimeError( # pragma: no cover
|
||||
"Cannot use both `allow_redirects` and `follow_redirects`."
|
||||
)
|
||||
return redirect
|
||||
|
||||
def request( # type: ignore[override]
|
||||
self,
|
||||
method: str,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: typing.Optional[httpx._types.RequestContent] = None,
|
||||
data: typing.Optional[_RequestData] = None,
|
||||
files: typing.Optional[httpx._types.RequestFiles] = None,
|
||||
json: typing.Any = None,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
url = self.base_url.join(url)
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
if timeout is not httpx.USE_CLIENT_DEFAULT:
|
||||
warnings.warn(
|
||||
"You should not use the 'timeout' argument with the TestClient. "
|
||||
"See https://github.com/Kludex/starlette/issues/1108 for more information.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
url = self._merge_url(url)
|
||||
return super().request(
|
||||
method,
|
||||
url,
|
||||
content=content,
|
||||
data=data, # type: ignore[arg-type]
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -482,27 +462,21 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().get(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -511,27 +485,21 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().options(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -540,27 +508,21 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().head(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -569,35 +531,29 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: typing.Optional[httpx._types.RequestContent] = None,
|
||||
data: typing.Optional[_RequestData] = None,
|
||||
files: typing.Optional[httpx._types.RequestFiles] = None,
|
||||
json: typing.Any = None,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().post(
|
||||
url,
|
||||
content=content,
|
||||
data=data, # type: ignore[arg-type]
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -606,35 +562,29 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: typing.Optional[httpx._types.RequestContent] = None,
|
||||
data: typing.Optional[_RequestData] = None,
|
||||
files: typing.Optional[httpx._types.RequestFiles] = None,
|
||||
json: typing.Any = None,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().put(
|
||||
url,
|
||||
content=content,
|
||||
data=data, # type: ignore[arg-type]
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -643,35 +593,29 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: typing.Optional[httpx._types.RequestContent] = None,
|
||||
data: typing.Optional[_RequestData] = None,
|
||||
files: typing.Optional[httpx._types.RequestFiles] = None,
|
||||
json: typing.Any = None,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().patch(
|
||||
url,
|
||||
content=content,
|
||||
data=data, # type: ignore[arg-type]
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -680,34 +624,31 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().delete(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def websocket_connect(
|
||||
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
self,
|
||||
url: str,
|
||||
subprotocols: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> WebSocketTestSession:
|
||||
url = urljoin("ws://testserver", url)
|
||||
headers = kwargs.get("headers", {})
|
||||
headers.setdefault("connection", "upgrade")
|
||||
@@ -725,22 +666,24 @@ class TestClient(httpx.Client):
|
||||
|
||||
return session
|
||||
|
||||
def __enter__(self) -> "TestClient":
|
||||
def __enter__(self) -> Self:
|
||||
with contextlib.ExitStack() as stack:
|
||||
self.portal = portal = stack.enter_context(
|
||||
anyio.from_thread.start_blocking_portal(**self.async_backend)
|
||||
)
|
||||
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
|
||||
|
||||
@stack.callback
|
||||
def reset_portal() -> None:
|
||||
self.portal = None
|
||||
|
||||
self.stream_send = StapledObjectStream(
|
||||
*anyio.create_memory_object_stream(math.inf)
|
||||
send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
|
||||
anyio.create_memory_object_stream(math.inf)
|
||||
)
|
||||
self.stream_receive = StapledObjectStream(
|
||||
*anyio.create_memory_object_stream(math.inf)
|
||||
receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
|
||||
math.inf
|
||||
)
|
||||
for channel in (*send, *receive):
|
||||
stack.callback(channel.close)
|
||||
self.stream_send = StapledObjectStream(*send)
|
||||
self.stream_receive = StapledObjectStream(*receive)
|
||||
self.task = portal.start_task_soon(self.lifespan)
|
||||
portal.call(self.wait_startup)
|
||||
|
||||
@@ -752,7 +695,7 @@ class TestClient(httpx.Client):
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.exit_stack.close()
|
||||
|
||||
async def lifespan(self) -> None:
|
||||
@@ -765,7 +708,7 @@ class TestClient(httpx.Client):
|
||||
async def wait_startup(self) -> None:
|
||||
await self.stream_receive.send({"type": "lifespan.startup"})
|
||||
|
||||
async def receive() -> typing.Any:
|
||||
async def receive() -> Any:
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
@@ -780,18 +723,17 @@ class TestClient(httpx.Client):
|
||||
await receive()
|
||||
|
||||
async def wait_shutdown(self) -> None:
|
||||
async def receive() -> typing.Any:
|
||||
async def receive() -> Any:
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
return message
|
||||
|
||||
async with self.stream_send:
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.shutdown.failed":
|
||||
await receive()
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.shutdown.failed":
|
||||
await receive()
|
||||
|
||||
Reference in New Issue
Block a user