This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -1 +1 @@
__version__ = "0.27.0"
__version__ = "0.50.0"

View File

@@ -1,28 +0,0 @@
import hashlib
# Compat wrapper to always include the `usedforsecurity=...` parameter,
# which is only added from Python 3.9 onwards.
# We use this flag to indicate that we use `md5` hashes only for non-security
# cases (our ETag checksums).
# If we don't indicate that we're using MD5 for non-security related reasons,
# then attempting to use this function will raise an error when used
# environments which enable a strict "FIPs mode".
#
# See issue: https://github.com/encode/starlette/issues/1365
try:
# check if the Python version supports the parameter
# using usedforsecurity=False to avoid an exception on FIPS systems
# that reject usedforsecurity=True
hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg]
def md5_hexdigest(
data: bytes, *, usedforsecurity: bool = True
) -> str: # pragma: no cover
return hashlib.md5( # type: ignore[call-arg]
data, usedforsecurity=usedforsecurity
).hexdigest()
except TypeError: # pragma: no cover
def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:
return hashlib.md5(data).hexdigest()

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
from typing import Any
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
ExceptionHandlers = dict[Any, ExceptionHandler]
StatusHandlers = dict[int, ExceptionHandler]
def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
return None
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
exception_handlers: ExceptionHandlers
status_handlers: StatusHandlers
try:
exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
response_started = False
async def sender(message: Message) -> None:
nonlocal response_started
if message["type"] == "http.response.start":
response_started = True
await send(message)
try:
await app(scope, receive, sender)
except Exception as exc:
handler = None
if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)
if handler is None:
handler = _lookup_exception_handler(exception_handlers, exc)
if handler is None:
raise exc
if response_started:
raise RuntimeError("Caught handled exception, but response already started.") from exc
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
if response is not None:
await response(scope, receive, sender)
return wrapped_app

View File

@@ -1,74 +1,103 @@
import asyncio
from __future__ import annotations
import functools
import sys
import typing
from types import TracebackType
from collections.abc import Awaitable, Callable, Generator
from contextlib import AbstractAsyncContextManager, contextmanager
from typing import Any, Generic, Protocol, TypeVar, overload
if sys.version_info < (3, 8): # pragma: no cover
from typing_extensions import Protocol
from starlette.types import Scope
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
from typing import TypeIs
else: # pragma: no cover
from typing import Protocol
from asyncio import iscoroutinefunction
from typing_extensions import TypeIs
has_exceptiongroups = True
if sys.version_info < (3, 11): # pragma: no cover
try:
from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
except ImportError:
has_exceptiongroups = False
T = TypeVar("T")
AwaitableCallable = Callable[..., Awaitable[T]]
def is_async_callable(obj: typing.Any) -> bool:
@overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
@overload
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
def is_async_callable(obj: Any) -> Any:
while isinstance(obj, functools.partial):
obj = obj.func
return asyncio.iscoroutinefunction(obj) or (
callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
)
return iscoroutinefunction(obj) or (callable(obj) and iscoroutinefunction(obj.__call__))
T_co = typing.TypeVar("T_co", covariant=True)
T_co = TypeVar("T_co", covariant=True)
# TODO: once 3.8 is the minimum supported version (27 Jun 2023)
# this can just become
# class AwaitableOrContextManager(
# typing.Awaitable[T_co],
# typing.AsyncContextManager[T_co],
# typing.Protocol[T_co],
# ):
# pass
class AwaitableOrContextManager(Protocol[T_co]):
def __await__(self) -> typing.Generator[typing.Any, None, T_co]:
... # pragma: no cover
async def __aenter__(self) -> T_co:
... # pragma: no cover
async def __aexit__(
self,
__exc_type: typing.Optional[typing.Type[BaseException]],
__exc_value: typing.Optional[BaseException],
__traceback: typing.Optional[TracebackType],
) -> typing.Union[bool, None]:
... # pragma: no cover
class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]): ...
class SupportsAsyncClose(Protocol):
async def close(self) -> None:
... # pragma: no cover
async def close(self) -> None: ... # pragma: no cover
SupportsAsyncCloseType = typing.TypeVar(
"SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
)
SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
__slots__ = ("aw", "entered")
def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
self.aw = aw
def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
return self.aw.__await__()
async def __aenter__(self) -> SupportsAsyncCloseType:
self.entered = await self.aw
return self.entered
async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]:
async def __aexit__(self, *args: Any) -> None | bool:
await self.entered.close()
return None
@contextmanager
def collapse_excgroups() -> Generator[None, None, None]:
try:
yield
except BaseException as exc:
if has_exceptiongroups: # pragma: no cover
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0]
raise exc
def get_route_path(scope: Scope) -> str:
path: str = scope["path"]
root_path = scope.get("root_path", "")
if not root_path:
return path
if not path.startswith(root_path):
return path
if path == root_path:
return ""
if path[len(root_path)] == "/":
return path[len(root_path) :]
return path

View File

@@ -1,90 +1,80 @@
import typing
from __future__ import annotations
import warnings
from collections.abc import Awaitable, Callable, Mapping, Sequence
from typing import Any, ParamSpec, TypeVar
from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket
AppType = typing.TypeVar("AppType", bound="Starlette")
AppType = TypeVar("AppType", bound="Starlette")
P = ParamSpec("P")
class Starlette:
"""
Creates an application instance.
**Parameters:**
* **debug** - Boolean indicating if debug tracebacks should be returned on errors.
* **routes** - A list of routes to serve incoming HTTP and WebSocket requests.
* **middleware** - A list of middleware to run for every request. A starlette
application will always automatically include two middleware classes.
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
any uncaught errors occurring anywhere in the entire stack.
`ExceptionMiddleware` is added as the very innermost middleware, to deal
with handled exception cases occurring in the routing or endpoints.
* **exception_handlers** - A mapping of either integer status codes,
or exception class types onto callables which handle the exceptions.
Exception handler callables should be of the form
`handler(request, exc) -> response` and may be be either standard functions, or
async functions.
* **on_startup** - A list of callables to run on application startup.
Startup handler callables do not take any arguments, and may be be either
standard functions, or async functions.
* **on_shutdown** - A list of callables to run on application shutdown.
Shutdown handler callables do not take any arguments, and may be be either
standard functions, or async functions.
* **lifespan** - A lifespan context function, which can be used to perform
startup and shutdown tasks. This is a newer style that replaces the
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
"""
"""Creates an Starlette application."""
def __init__(
self: "AppType",
self: AppType,
debug: bool = False,
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
exception_handlers: typing.Optional[
typing.Mapping[
typing.Any,
typing.Callable[
[Request, Exception],
typing.Union[Response, typing.Awaitable[Response]],
],
]
] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[Lifespan["AppType"]] = None,
routes: Sequence[BaseRoute] | None = None,
middleware: Sequence[Middleware] | None = None,
exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
on_startup: Sequence[Callable[[], Any]] | None = None,
on_shutdown: Sequence[Callable[[], Any]] | None = None,
lifespan: Lifespan[AppType] | None = None,
) -> None:
"""Initializes the application.
Parameters:
debug: Boolean indicating if debug tracebacks should be returned on errors.
routes: A list of routes to serve incoming HTTP and WebSocket requests.
middleware: A list of middleware to run for every request. A starlette
application will always automatically include two middleware classes.
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
any uncaught errors occurring anywhere in the entire stack.
`ExceptionMiddleware` is added as the very innermost middleware, to deal
with handled exception cases occurring in the routing or endpoints.
exception_handlers: A mapping of either integer status codes,
or exception class types onto callables which handle the exceptions.
Exception handler callables should be of the form
`handler(request, exc) -> response` and may be either standard functions, or
async functions.
on_startup: A list of callables to run on application startup.
Startup handler callables do not take any arguments, and may be either
standard functions, or async functions.
on_shutdown: A list of callables to run on application shutdown.
Shutdown handler callables do not take any arguments, and may be either
standard functions, or async functions.
lifespan: A lifespan context function, which can be used to perform
startup and shutdown tasks. This is a newer style that replaces the
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
"""
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
assert lifespan is None or (
on_startup is None and on_shutdown is None
), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
assert lifespan is None or (on_startup is None and on_shutdown is None), (
"Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
)
self.debug = debug
self.state = State()
self.router = Router(
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
)
self.exception_handlers = (
{} if exception_handlers is None else dict(exception_handlers)
)
self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: typing.Optional[ASGIApp] = None
self.middleware_stack: ASGIApp | None = None
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
exception_handlers: typing.Dict[
typing.Any, typing.Callable[[Request, Exception], Response]
] = {}
exception_handlers: dict[Any, ExceptionHandler] = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
@@ -95,25 +85,20 @@ class Starlette:
middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
+ [
Middleware(
ExceptionMiddleware, handlers=exception_handlers, debug=debug
)
]
+ [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
)
app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
for cls, args, kwargs in reversed(middleware):
app = cls(app, *args, **kwargs)
return app
@property
def routes(self) -> typing.List[BaseRoute]:
def routes(self) -> list[BaseRoute]:
return self.router.routes
# TODO: Make `__name` a positional-only argument when we drop Python 3.7 support.
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
return self.router.url_path_for(__name, **path_params)
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
return self.router.url_path_for(name, **path_params)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
@@ -121,63 +106,65 @@ class Starlette:
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)
def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
return self.router.on_event(event_type)
def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg]
return self.router.on_event(event_type) # pragma: no cover
def mount(
self, path: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None: # pragma: nocover
self.router.mount(path, app=app, name=name)
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
self.router.mount(path, app=app, name=name) # pragma: no cover
def host(
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None: # pragma: no cover
self.router.host(host, app=app, name=name)
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
self.router.host(host, app=app, name=name) # pragma: no cover
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
def add_middleware(
self,
middleware_class: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, **options))
self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))
def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable,
exc_class_or_status_code: int | type[Exception],
handler: ExceptionHandler,
) -> None: # pragma: no cover
self.exception_handlers[exc_class_or_status_code] = handler
def add_event_handler(
self, event_type: str, func: typing.Callable
self,
event_type: str,
func: Callable, # type: ignore[type-arg]
) -> None: # pragma: no cover
self.router.add_event_handler(event_type, func)
def add_route(
self,
path: str,
route: typing.Callable,
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
route: Callable[[Request], Awaitable[Response] | Response],
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: no cover
self.router.add_route(
path, route, methods=methods, name=name, include_in_schema=include_in_schema
)
self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
def add_websocket_route(
self, path: str, route: typing.Callable, name: typing.Optional[str] = None
self,
path: str,
route: Callable[[WebSocket], Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
self.router.add_websocket_route(path, route, name=name)
def exception_handler(
self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
) -> typing.Callable:
def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> Callable: # type: ignore[type-arg]
warnings.warn(
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://starlette.dev/exceptions/ for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_exception_handler(exc_class_or_status_code, func)
return func
@@ -186,10 +173,10 @@ class Starlette:
def route(
self,
path: str,
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> typing.Callable:
) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
@@ -198,12 +185,12 @@ class Starlette:
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/routing/ for the recommended approach.", # noqa: E501
"The `route` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://starlette.dev/routing/ for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.router.add_route(
path,
func,
@@ -215,9 +202,7 @@ class Starlette:
return decorator
def websocket_route(
self, path: str, name: typing.Optional[str] = None
) -> typing.Callable:
def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
@@ -226,18 +211,18 @@ class Starlette:
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://starlette.dev/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.router.add_websocket_route(path, func, name=name)
return func
return decorator
def middleware(self, middleware_type: str) -> typing.Callable:
def middleware(self, middleware_type: str) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
@@ -246,15 +231,13 @@ class Starlette:
>>> app = Starlette(middleware=middleware)
"""
warnings.warn(
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", # noqa: E501
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://starlette.dev/middleware/#using-middleware for recommended approach.",
DeprecationWarning,
)
assert (
middleware_type == "http"
), 'Currently only middleware("http") is supported.'
assert middleware_type == "http", 'Currently only middleware("http") is supported.'
def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
return func

View File

@@ -1,18 +1,21 @@
from __future__ import annotations
import functools
import inspect
import typing
from collections.abc import Callable, Sequence
from typing import Any, ParamSpec
from urllib.parse import urlencode
from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket
_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
_P = ParamSpec("_P")
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool:
for scope in scopes:
if scope not in conn.auth.scopes:
return False
@@ -20,32 +23,28 @@ def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bo
def requires(
scopes: typing.Union[str, typing.Sequence[str]],
scopes: str | Sequence[str],
status_code: int = 403,
redirect: typing.Optional[str] = None,
) -> typing.Callable[[_CallableType], _CallableType]:
redirect: str | None = None,
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]:
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(func: typing.Callable) -> typing.Callable:
def decorator(
func: Callable[_P, Any],
) -> Callable[_P, Any]:
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
type_ = parameter.name
break
else:
raise Exception(
f'No "request" or "websocket" argument on function "{func}"'
)
raise Exception(f'No "request" or "websocket" argument on function "{func}"')
if type_ == "websocket":
# Handle websocket functions. (Always async)
@functools.wraps(func)
async def websocket_wrapper(
*args: typing.Any, **kwargs: typing.Any
) -> None:
websocket = kwargs.get(
"websocket", args[idx] if idx < len(args) else None
)
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
assert isinstance(websocket, WebSocket)
if not has_required_scope(websocket, scopes_list):
@@ -58,19 +57,14 @@ def requires(
elif is_async_callable(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(
*args: typing.Any, **kwargs: typing.Any
) -> Response:
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)
if not has_required_scope(request, scopes_list):
if redirect is not None:
orig_request_qparam = urlencode({"next": str(request.url)})
next_url = "{redirect_path}?{orig_request}".format(
redirect_path=request.url_for(redirect),
orig_request=orig_request_qparam,
)
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
return RedirectResponse(url=next_url, status_code=303)
raise HTTPException(status_code=status_code)
return await func(*args, **kwargs)
@@ -80,24 +74,21 @@ def requires(
else:
# Handle sync request/response functions.
@functools.wraps(func)
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)
if not has_required_scope(request, scopes_list):
if redirect is not None:
orig_request_qparam = urlencode({"next": str(request.url)})
next_url = "{redirect_path}?{orig_request}".format(
redirect_path=request.url_for(redirect),
orig_request=orig_request_qparam,
)
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
return RedirectResponse(url=next_url, status_code=303)
raise HTTPException(status_code=status_code)
return func(*args, **kwargs)
return sync_wrapper
return decorator # type: ignore[return-value]
return decorator
class AuthenticationError(Exception):
@@ -105,14 +96,12 @@ class AuthenticationError(Exception):
class AuthenticationBackend:
async def authenticate(
self, conn: HTTPConnection
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
raise NotImplementedError() # pragma: no cover
class AuthCredentials:
def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
def __init__(self, scopes: Sequence[str] | None = None):
self.scopes = [] if scopes is None else list(scopes)

View File

@@ -1,10 +1,7 @@
import sys
import typing
from __future__ import annotations
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from collections.abc import Callable, Sequence
from typing import Any, ParamSpec
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
@@ -13,9 +10,7 @@ P = ParamSpec("P")
class BackgroundTask:
def __init__(
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
@@ -29,12 +24,10 @@ class BackgroundTask:
class BackgroundTasks(BackgroundTask):
def __init__(self, tasks: typing.Optional[typing.Sequence[BackgroundTask]] = None):
def __init__(self, tasks: Sequence[BackgroundTask] | None = None):
self.tasks = list(tasks) if tasks else []
def add_task(
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)

View File

@@ -1,30 +1,25 @@
from __future__ import annotations
import functools
import sys
import typing
import warnings
from collections.abc import AsyncIterator, Callable, Coroutine, Iterable, Iterator
from typing import ParamSpec, TypeVar
import anyio
import anyio.to_thread
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
T = typing.TypeVar("T")
P = ParamSpec("P")
T = TypeVar("T")
async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
async def run_until_first_complete(*args: tuple[Callable, dict]) -> None: # type: ignore[type-arg]
warnings.warn(
"run_until_first_complete is deprecated "
"and will be removed in a future version.",
"run_until_first_complete is deprecated and will be removed in a future version.",
DeprecationWarning,
)
async with anyio.create_task_group() as task_group:
async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
async def run(func: Callable[[], Coroutine]) -> None: # type: ignore[type-arg]
await func()
task_group.cancel_scope.cancel()
@@ -32,20 +27,16 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -
task_group.start_soon(run, functools.partial(func, **kwargs))
async def run_in_threadpool(
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
return await anyio.to_thread.run_sync(func, *args)
async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func = functools.partial(func, *args, **kwargs)
return await anyio.to_thread.run_sync(func)
class _StopIteration(Exception):
pass
def _next(iterator: typing.Iterator[T]) -> T:
def _next(iterator: Iterator[T]) -> T:
# We can't raise `StopIteration` from within the threadpool iterator
# and catch it outside that context, so we coerce them into a different
# exception type.
@@ -56,10 +47,11 @@ def _next(iterator: typing.Iterator[T]) -> T:
async def iterate_in_threadpool(
iterator: typing.Iterator[T],
) -> typing.AsyncIterator[T]:
iterator: Iterable[T],
) -> AsyncIterator[T]:
as_iterator = iter(iterator)
while True:
try:
yield await anyio.to_thread.run_sync(_next, iterator)
yield await anyio.to_thread.run_sync(_next, as_iterator)
except _StopIteration:
break

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
import os
import typing
from collections.abc import MutableMapping
import warnings
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from pathlib import Path
from typing import Any, TypeVar, overload
class undefined:
@@ -12,32 +15,26 @@ class EnvironError(Exception):
pass
class Environ(MutableMapping):
def __init__(self, environ: typing.MutableMapping = os.environ):
class Environ(MutableMapping[str, str]):
def __init__(self, environ: MutableMapping[str, str] = os.environ):
self._environ = environ
self._has_been_read: typing.Set[typing.Any] = set()
self._has_been_read: set[str] = set()
def __getitem__(self, key: typing.Any) -> typing.Any:
def __getitem__(self, key: str) -> str:
self._has_been_read.add(key)
return self._environ.__getitem__(key)
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
def __setitem__(self, key: str, value: str) -> None:
if key in self._has_been_read:
raise EnvironError(
f"Attempting to set environ['{key}'], but the value has already been "
"read."
)
raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
self._environ.__setitem__(key, value)
def __delitem__(self, key: typing.Any) -> None:
def __delitem__(self, key: str) -> None:
if key in self._has_been_read:
raise EnvironError(
f"Attempting to delete environ['{key}'], but the value has already "
"been read."
)
raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
self._environ.__delitem__(key)
def __iter__(self) -> typing.Iterator:
def __iter__(self) -> Iterator[str]:
return iter(self._environ)
def __len__(self) -> int:
@@ -46,65 +43,60 @@ class Environ(MutableMapping):
environ = Environ()
T = typing.TypeVar("T")
T = TypeVar("T")
class Config:
def __init__(
self,
env_file: typing.Optional[typing.Union[str, Path]] = None,
environ: typing.Mapping[str, str] = environ,
env_file: str | Path | None = None,
environ: Mapping[str, str] = environ,
env_prefix: str = "",
encoding: str = "utf-8",
) -> None:
self.environ = environ
self.env_prefix = env_prefix
self.file_values: typing.Dict[str, str] = {}
if env_file is not None and os.path.isfile(env_file):
self.file_values = self._read_file(env_file)
self.file_values: dict[str, str] = {}
if env_file is not None:
if not os.path.isfile(env_file):
warnings.warn(f"Config file '{env_file}' not found.")
else:
self.file_values = self._read_file(env_file, encoding)
@typing.overload
def __call__(self, key: str, *, default: None) -> typing.Optional[str]:
...
@overload
def __call__(self, key: str, *, default: None) -> str | None: ...
@typing.overload
def __call__(self, key: str, cast: typing.Type[T], default: T = ...) -> T:
...
@overload
def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...
@typing.overload
def __call__(
self, key: str, cast: typing.Type[str] = ..., default: str = ...
) -> str:
...
@overload
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...
@typing.overload
@overload
def __call__(
self,
key: str,
cast: typing.Callable[[typing.Any], T] = ...,
default: typing.Any = ...,
) -> T:
...
cast: Callable[[Any], T] = ...,
default: Any = ...,
) -> T: ...
@typing.overload
def __call__(
self, key: str, cast: typing.Type[str] = ..., default: T = ...
) -> typing.Union[T, str]:
...
@overload
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
def __call__(
self,
key: str,
cast: typing.Optional[typing.Callable] = None,
default: typing.Any = undefined,
) -> typing.Any:
cast: Callable[[Any], Any] | None = None,
default: Any = undefined,
) -> Any:
return self.get(key, cast, default)
def get(
self,
key: str,
cast: typing.Optional[typing.Callable] = None,
default: typing.Any = undefined,
) -> typing.Any:
cast: Callable[[Any], Any] | None = None,
default: Any = undefined,
) -> Any:
key = self.env_prefix + key
if key in self.environ:
value = self.environ[key]
@@ -116,9 +108,9 @@ class Config:
return self._perform_cast(key, default, cast)
raise KeyError(f"Config '{key}' is missing, and has no default.")
def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str]:
file_values: typing.Dict[str, str] = {}
with open(file_name) as input_file:
def _read_file(self, file_name: str | Path, encoding: str) -> dict[str, str]:
file_values: dict[str, str] = {}
with open(file_name, encoding=encoding) as input_file:
for line in input_file.readlines():
line = line.strip()
if "=" in line and not line.startswith("#"):
@@ -129,21 +121,20 @@ class Config:
return file_values
def _perform_cast(
self, key: str, value: typing.Any, cast: typing.Optional[typing.Callable] = None
) -> typing.Any:
self,
key: str,
value: Any,
cast: Callable[[Any], Any] | None = None,
) -> Any:
if cast is None or value is None:
return value
elif cast is bool and isinstance(value, str):
mapping = {"true": True, "1": True, "false": False, "0": False}
value = value.lower()
if value not in mapping:
raise ValueError(
f"Config '{key}' has value '{value}'. Not a valid bool."
)
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
return mapping[value]
try:
return cast(value)
except (TypeError, ValueError):
raise ValueError(
f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
)
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import math
import typing
import uuid
from typing import Any, ClassVar, Generic, TypeVar
T = typing.TypeVar("T")
T = TypeVar("T")
class Convertor(typing.Generic[T]):
regex: typing.ClassVar[str] = ""
class Convertor(Generic[T]):
regex: ClassVar[str] = ""
def convert(self, value: str) -> T:
raise NotImplementedError() # pragma: no cover
@@ -15,7 +17,7 @@ class Convertor(typing.Generic[T]):
raise NotImplementedError() # pragma: no cover
class StringConvertor(Convertor):
class StringConvertor(Convertor[str]):
regex = "[^/]+"
def convert(self, value: str) -> str:
@@ -28,7 +30,7 @@ class StringConvertor(Convertor):
return value
class PathConvertor(Convertor):
class PathConvertor(Convertor[str]):
regex = ".*"
def convert(self, value: str) -> str:
@@ -38,7 +40,7 @@ class PathConvertor(Convertor):
return str(value)
class IntegerConvertor(Convertor):
class IntegerConvertor(Convertor[int]):
regex = "[0-9]+"
def convert(self, value: str) -> int:
@@ -50,7 +52,7 @@ class IntegerConvertor(Convertor):
return str(value)
class FloatConvertor(Convertor):
class FloatConvertor(Convertor[float]):
regex = r"[0-9]+(\.[0-9]+)?"
def convert(self, value: str) -> float:
@@ -64,8 +66,8 @@ class FloatConvertor(Convertor):
return ("%0.20f" % value).rstrip("0").rstrip(".")
class UUIDConvertor(Convertor):
regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
class UUIDConvertor(Convertor[uuid.UUID]):
regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}"
def convert(self, value: str) -> uuid.UUID:
return uuid.UUID(value)
@@ -74,7 +76,7 @@ class UUIDConvertor(Convertor):
return str(value)
CONVERTOR_TYPES = {
CONVERTOR_TYPES: dict[str, Convertor[Any]] = {
"str": StringConvertor(),
"path": PathConvertor(),
"int": IntegerConvertor(),
@@ -83,5 +85,5 @@ CONVERTOR_TYPES = {
}
def register_url_convertor(key: str, convertor: Convertor) -> None:
def register_url_convertor(key: str, convertor: Convertor[Any]) -> None:
CONVERTOR_TYPES[key] = convertor

View File

@@ -1,37 +1,45 @@
import typing
from collections.abc import Sequence
from __future__ import annotations
from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
from shlex import shlex
from typing import (
Any,
BinaryIO,
NamedTuple,
TypeVar,
cast,
)
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
from starlette.concurrency import run_in_threadpool
from starlette.types import Scope
class Address(typing.NamedTuple):
class Address(NamedTuple):
host: str
port: int
_KeyType = typing.TypeVar("_KeyType")
_KeyType = TypeVar("_KeyType")
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
class URL:
def __init__(
self,
url: str = "",
scope: typing.Optional[Scope] = None,
**components: typing.Any,
scope: Scope | None = None,
**components: Any,
) -> None:
if scope is not None:
assert not url, 'Cannot set both "url" and "scope".'
assert not components, 'Cannot set both "scope" and "**components".'
scheme = scope.get("scheme", "http")
server = scope.get("server", None)
path = scope.get("root_path", "") + scope["path"]
path = scope["path"]
query_string = scope.get("query_string", b"")
host_header = None
@@ -87,32 +95,27 @@ class URL:
return self.components.fragment
@property
def username(self) -> typing.Union[None, str]:
def username(self) -> None | str:
return self.components.username
@property
def password(self) -> typing.Union[None, str]:
def password(self) -> None | str:
return self.components.password
@property
def hostname(self) -> typing.Union[None, str]:
def hostname(self) -> None | str:
return self.components.hostname
@property
def port(self) -> typing.Optional[int]:
def port(self) -> int | None:
return self.components.port
@property
def is_secure(self) -> bool:
return self.scheme in ("https", "wss")
def replace(self, **kwargs: typing.Any) -> "URL":
if (
"username" in kwargs
or "password" in kwargs
or "hostname" in kwargs
or "port" in kwargs
):
def replace(self, **kwargs: Any) -> URL:
if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
hostname = kwargs.pop("hostname", None)
port = kwargs.pop("port", self.port)
username = kwargs.pop("username", self.username)
@@ -139,19 +142,17 @@ class URL:
components = self.components._replace(**kwargs)
return self.__class__(components.geturl())
def include_query_params(self, **kwargs: typing.Any) -> "URL":
def include_query_params(self, **kwargs: Any) -> URL:
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
params.update({str(key): str(value) for key, value in kwargs.items()})
query = urlencode(params.multi_items())
return self.replace(query=query)
def replace_query_params(self, **kwargs: typing.Any) -> "URL":
def replace_query_params(self, **kwargs: Any) -> URL:
query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
return self.replace(query=query)
def remove_query_params(
self, keys: typing.Union[str, typing.Sequence[str]]
) -> "URL":
def remove_query_params(self, keys: str | Sequence[str]) -> URL:
if isinstance(keys, str):
keys = [keys]
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
@@ -160,7 +161,7 @@ class URL:
query = urlencode(params.multi_items())
return self.replace(query=query)
def __eq__(self, other: typing.Any) -> bool:
def __eq__(self, other: Any) -> bool:
return str(self) == str(other)
def __str__(self) -> str:
@@ -179,7 +180,7 @@ class URLPath(str):
Used by the routing to return `url_path_for` matches.
"""
def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
assert protocol in ("http", "websocket", "")
return str.__new__(cls, path)
@@ -187,7 +188,7 @@ class URLPath(str):
self.protocol = protocol
self.host = host
def make_absolute_url(self, base_url: typing.Union[str, URL]) -> URL:
def make_absolute_url(self, base_url: str | URL) -> URL:
if isinstance(base_url, str):
base_url = URL(base_url)
if self.protocol:
@@ -223,8 +224,8 @@ class Secret:
return bool(self._value)
class CommaSeparatedStrings(Sequence):
def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
class CommaSeparatedStrings(Sequence[str]):
def __init__(self, value: str | Sequence[str]):
if isinstance(value, str):
splitter = shlex(value, posix=True)
splitter.whitespace = ","
@@ -236,10 +237,10 @@ class CommaSeparatedStrings(Sequence):
def __len__(self) -> int:
return len(self._items)
def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
def __getitem__(self, index: int | slice) -> Any:
return self._items[index]
def __iter__(self) -> typing.Iterator[str]:
def __iter__(self) -> Iterator[str]:
return iter(self._items)
def __repr__(self) -> str:
@@ -251,74 +252,65 @@ class CommaSeparatedStrings(Sequence):
return ", ".join(repr(item) for item in self)
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
_dict: typing.Dict[_KeyType, _CovariantValueType]
class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
_dict: dict[_KeyType, _CovariantValueType]
def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict[_KeyType, _CovariantValueType]",
typing.Mapping[_KeyType, _CovariantValueType],
typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]],
],
**kwargs: typing.Any,
*args: ImmutableMultiDict[_KeyType, _CovariantValueType]
| Mapping[_KeyType, _CovariantValueType]
| Iterable[tuple[_KeyType, _CovariantValueType]],
**kwargs: Any,
) -> None:
assert len(args) < 2, "Too many arguments."
value: typing.Any = args[0] if args else []
value: Any = args[0] if args else []
if kwargs:
value = (
ImmutableMultiDict(value).multi_items()
+ ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator]
)
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
if not value:
_items: typing.List[typing.Tuple[typing.Any, typing.Any]] = []
_items: list[tuple[Any, Any]] = []
elif hasattr(value, "multi_items"):
value = typing.cast(
ImmutableMultiDict[_KeyType, _CovariantValueType], value
)
value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
_items = list(value.multi_items())
elif hasattr(value, "items"):
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
value = cast(Mapping[_KeyType, _CovariantValueType], value)
_items = list(value.items())
else:
value = typing.cast(
typing.List[typing.Tuple[typing.Any, typing.Any]], value
)
value = cast("list[tuple[Any, Any]]", value)
_items = list(value)
self._dict = {k: v for k, v in _items}
self._list = _items
def getlist(self, key: typing.Any) -> typing.List[_CovariantValueType]:
def getlist(self, key: Any) -> list[_CovariantValueType]:
return [item_value for item_key, item_value in self._list if item_key == key]
def keys(self) -> typing.KeysView[_KeyType]:
def keys(self) -> KeysView[_KeyType]:
return self._dict.keys()
def values(self) -> typing.ValuesView[_CovariantValueType]:
def values(self) -> ValuesView[_CovariantValueType]:
return self._dict.values()
def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
return self._dict.items()
def multi_items(self) -> typing.List[typing.Tuple[_KeyType, _CovariantValueType]]:
def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
return list(self._list)
def __getitem__(self, key: _KeyType) -> _CovariantValueType:
return self._dict[key]
def __contains__(self, key: typing.Any) -> bool:
def __contains__(self, key: Any) -> bool:
return key in self._dict
def __iter__(self) -> typing.Iterator[_KeyType]:
def __iter__(self) -> Iterator[_KeyType]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._dict)
def __eq__(self, other: typing.Any) -> bool:
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
return sorted(self._list) == sorted(other._list)
@@ -329,24 +321,24 @@ class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
return f"{class_name}({items!r})"
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
class MultiDict(ImmutableMultiDict[Any, Any]):
def __setitem__(self, key: Any, value: Any) -> None:
self.setlist(key, [value])
def __delitem__(self, key: typing.Any) -> None:
def __delitem__(self, key: Any) -> None:
self._list = [(k, v) for k, v in self._list if k != key]
del self._dict[key]
def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
def pop(self, key: Any, default: Any = None) -> Any:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)
def popitem(self) -> typing.Tuple:
def popitem(self) -> tuple[Any, Any]:
key, value = self._dict.popitem()
self._list = [(k, v) for k, v in self._list if k != key]
return key, value
def poplist(self, key: typing.Any) -> typing.List:
def poplist(self, key: Any) -> list[Any]:
values = [v for k, v in self._list if k == key]
self.pop(key)
return values
@@ -355,14 +347,14 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
self._dict.clear()
self._list.clear()
def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
def setdefault(self, key: Any, default: Any = None) -> Any:
if key not in self:
self._dict[key] = default
self._list.append((key, default))
return self[key]
def setlist(self, key: typing.Any, values: typing.List) -> None:
def setlist(self, key: Any, values: list[Any]) -> None:
if not values:
self.pop(key, None)
else:
@@ -370,18 +362,14 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
self._list = existing_items + [(key, value) for value in values]
self._dict[key] = values[-1]
def append(self, key: typing.Any, value: typing.Any) -> None:
def append(self, key: Any, value: Any) -> None:
self._list.append((key, value))
self._dict[key] = value
def update(
self,
*args: typing.Union[
"MultiDict",
typing.Mapping,
typing.List[typing.Tuple[typing.Any, typing.Any]],
],
**kwargs: typing.Any,
*args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
**kwargs: Any,
) -> None:
value = MultiDict(*args, **kwargs)
existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
@@ -396,14 +384,8 @@ class QueryParams(ImmutableMultiDict[str, str]):
def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict",
typing.Mapping,
typing.List[typing.Tuple[typing.Any, typing.Any]],
str,
bytes,
],
**kwargs: typing.Any,
*args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
**kwargs: Any,
) -> None:
assert len(args) < 2, "Too many arguments."
@@ -412,9 +394,7 @@ class QueryParams(ImmutableMultiDict[str, str]):
if isinstance(value, str):
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
elif isinstance(value, bytes):
super().__init__(
parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
)
super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
else:
super().__init__(*args, **kwargs) # type: ignore[arg-type]
self._list = [(str(k), str(v)) for k, v in self._list]
@@ -436,19 +416,23 @@ class UploadFile:
def __init__(
self,
file: typing.BinaryIO,
file: BinaryIO,
*,
size: typing.Optional[int] = None,
filename: typing.Optional[str] = None,
headers: "typing.Optional[Headers]" = None,
size: int | None = None,
filename: str | None = None,
headers: Headers | None = None,
) -> None:
self.filename = filename
self.file = file
self.size = size
self.headers = headers or Headers()
# Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
# Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
self._max_mem_size = getattr(self.file, "_max_size", 0)
@property
def content_type(self) -> typing.Optional[str]:
def content_type(self) -> str | None:
return self.headers.get("content-type", None)
@property
@@ -457,14 +441,24 @@ class UploadFile:
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk
async def write(self, data: bytes) -> None:
if self.size is not None:
self.size += len(data)
def _will_roll(self, size_to_add: int) -> bool:
# If we're not in_memory then we will always roll
if not self._in_memory:
return True
if self._in_memory:
self.file.write(data)
else:
# Check for SpooledTemporaryFile._max_size
future_size = self.file.tell() + size_to_add
return bool(future_size > self._max_mem_size) if self._max_mem_size else False
async def write(self, data: bytes) -> None:
new_data_len = len(data)
if self.size is not None:
self.size += new_data_len
if self._will_roll(new_data_len):
await run_in_threadpool(self.file.write, data)
else:
self.file.write(data)
async def read(self, size: int = -1) -> bytes:
if self._in_memory:
@@ -483,20 +477,19 @@ class UploadFile:
else:
await run_in_threadpool(self.file.close)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
class FormData(ImmutableMultiDict[str, UploadFile | str]):
"""
An immutable multidict, containing both file uploads and text input.
"""
def __init__(
self,
*args: typing.Union[
"FormData",
typing.Mapping[str, typing.Union[str, UploadFile]],
typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
],
**kwargs: typing.Union[str, UploadFile],
*args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
**kwargs: str | UploadFile,
) -> None:
super().__init__(*args, **kwargs)
@@ -506,25 +499,22 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
await value.close()
class Headers(typing.Mapping[str, str]):
class Headers(Mapping[str, str]):
"""
An immutable, case-insensitive multidict.
"""
def __init__(
self,
headers: typing.Optional[typing.Mapping[str, str]] = None,
raw: typing.Optional[typing.List[typing.Tuple[bytes, bytes]]] = None,
scope: typing.Optional[typing.MutableMapping[str, typing.Any]] = None,
headers: Mapping[str, str] | None = None,
raw: list[tuple[bytes, bytes]] | None = None,
scope: MutableMapping[str, Any] | None = None,
) -> None:
self._list: typing.List[typing.Tuple[bytes, bytes]] = []
self._list: list[tuple[bytes, bytes]] = []
if headers is not None:
assert raw is None, 'Cannot set both "headers" and "raw".'
assert scope is None, 'Cannot set both "headers" and "scope".'
self._list = [
(key.lower().encode("latin-1"), value.encode("latin-1"))
for key, value in headers.items()
]
self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
elif raw is not None:
assert scope is None, 'Cannot set both "raw" and "scope".'
self._list = raw
@@ -534,30 +524,23 @@ class Headers(typing.Mapping[str, str]):
self._list = scope["headers"] = list(scope["headers"])
@property
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
def raw(self) -> list[tuple[bytes, bytes]]:
return list(self._list)
def keys(self) -> typing.List[str]: # type: ignore[override]
def keys(self) -> list[str]: # type: ignore[override]
return [key.decode("latin-1") for key, value in self._list]
def values(self) -> typing.List[str]: # type: ignore[override]
def values(self) -> list[str]: # type: ignore[override]
return [value.decode("latin-1") for key, value in self._list]
def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore[override]
return [
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in self._list
]
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
def getlist(self, key: str) -> typing.List[str]:
def getlist(self, key: str) -> list[str]:
get_header_key = key.lower().encode("latin-1")
return [
item_value.decode("latin-1")
for item_key, item_value in self._list
if item_key == get_header_key
]
return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
def mutablecopy(self) -> "MutableHeaders":
def mutablecopy(self) -> MutableHeaders:
return MutableHeaders(raw=self._list[:])
def __getitem__(self, key: str) -> str:
@@ -567,20 +550,20 @@ class Headers(typing.Mapping[str, str]):
return header_value.decode("latin-1")
raise KeyError(key)
def __contains__(self, key: typing.Any) -> bool:
def __contains__(self, key: Any) -> bool:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return True
return False
def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> Iterator[Any]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._list)
def __eq__(self, other: typing.Any) -> bool:
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Headers):
return False
return sorted(self._list) == sorted(other._list)
@@ -602,7 +585,7 @@ class MutableHeaders(Headers):
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
found_indexes: "typing.List[int]" = []
found_indexes: list[int] = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
found_indexes.append(idx)
@@ -622,7 +605,7 @@ class MutableHeaders(Headers):
"""
del_key = key.lower().encode("latin-1")
pop_indexes: "typing.List[int]" = []
pop_indexes: list[int] = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)
@@ -630,21 +613,21 @@ class MutableHeaders(Headers):
for idx in reversed(pop_indexes):
del self._list[idx]
def __ior__(self, other: typing.Mapping[str, str]) -> "MutableHeaders":
if not isinstance(other, typing.Mapping):
def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
if not isinstance(other, Mapping):
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
self.update(other)
return self
def __or__(self, other: typing.Mapping[str, str]) -> "MutableHeaders":
if not isinstance(other, typing.Mapping):
def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
if not isinstance(other, Mapping):
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
new = self.mutablecopy()
new.update(other)
return new
@property
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
def raw(self) -> list[tuple[bytes, bytes]]:
return self._list
def setdefault(self, key: str, value: str) -> str:
@@ -661,7 +644,7 @@ class MutableHeaders(Headers):
self._list.append((set_key, set_value))
return value
def update(self, other: typing.Mapping[str, str]) -> None:
def update(self, other: Mapping[str, str]) -> None:
for key, val in other.items():
self[key] = val
@@ -687,22 +670,22 @@ class State:
Used for `request.state` and `app.state`.
"""
_state: typing.Dict[str, typing.Any]
_state: dict[str, Any]
def __init__(self, state: typing.Optional[typing.Dict[str, typing.Any]] = None):
def __init__(self, state: dict[str, Any] | None = None):
if state is None:
state = {}
super().__setattr__("_state", state)
def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
def __setattr__(self, key: Any, value: Any) -> None:
self._state[key] = value
def __getattr__(self, key: typing.Any) -> typing.Any:
def __getattr__(self, key: Any) -> Any:
try:
return self._state[key]
except KeyError:
message = "'{}' object has no attribute '{}'"
raise AttributeError(message.format(self.__class__.__name__, key))
def __delattr__(self, key: typing.Any) -> None:
def __delattr__(self, key: Any) -> None:
del self._state[key]

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
import json
import typing
from collections.abc import Callable, Generator
from typing import Any, Literal
from starlette import status
from starlette._utils import is_async_callable
@@ -23,20 +26,14 @@ class HTTPEndpoint:
if getattr(self, method.lower(), None) is not None
]
def __await__(self) -> typing.Generator:
def __await__(self) -> Generator[Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
request = Request(self.scope, receive=self.receive)
handler_name = (
"get"
if request.method == "HEAD" and not hasattr(self, "head")
else request.method.lower()
)
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
handler: typing.Callable[[Request], typing.Any] = getattr(
self, handler_name, self.method_not_allowed
)
handler: Callable[[Request], Any] = getattr(self, handler_name, self.method_not_allowed)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
@@ -55,7 +52,7 @@ class HTTPEndpoint:
class WebSocketEndpoint:
encoding: typing.Optional[str] = None # May be "text", "bytes", or "json".
encoding: Literal["text", "bytes", "json"] | None = None
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "websocket"
@@ -63,7 +60,7 @@ class WebSocketEndpoint:
self.receive = receive
self.send = send
def __await__(self) -> typing.Generator:
def __await__(self) -> Generator[Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
@@ -78,10 +75,8 @@ class WebSocketEndpoint:
if message["type"] == "websocket.receive":
data = await self.decode(websocket, message)
await self.on_receive(websocket, data)
elif message["type"] == "websocket.disconnect":
close_code = int(
message.get("code") or status.WS_1000_NORMAL_CLOSURE
)
elif message["type"] == "websocket.disconnect": # pragma: no branch
close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
break
except Exception as exc:
close_code = status.WS_1011_INTERNAL_ERROR
@@ -89,7 +84,7 @@ class WebSocketEndpoint:
finally:
await self.on_disconnect(websocket, close_code)
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
async def decode(self, websocket: WebSocket, message: Message) -> Any:
if self.encoding == "text":
if "text" not in message:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
@@ -114,16 +109,14 @@ class WebSocketEndpoint:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Malformed JSON data received.")
assert (
self.encoding is None
), f"Unsupported 'encoding' attribute {self.encoding}"
assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
return message["text"] if message.get("text") else message["bytes"]
async def on_connect(self, websocket: WebSocket) -> None:
"""Override to handle an incoming websocket connection"""
await websocket.accept()
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
async def on_receive(self, websocket: WebSocket, data: Any) -> None:
"""Override to handle an incoming websocket message"""
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:

View File

@@ -1,54 +1,33 @@
import http
import typing
import warnings
from __future__ import annotations
__all__ = ("HTTPException", "WebSocketException")
import http
from collections.abc import Mapping
class HTTPException(Exception):
def __init__(
self,
status_code: int,
detail: typing.Optional[str] = None,
headers: typing.Optional[dict] = None,
) -> None:
def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None:
if detail is None:
detail = http.HTTPStatus(status_code).phrase
self.status_code = status_code
self.detail = detail
self.headers = headers
def __str__(self) -> str:
return f"{self.status_code}: {self.detail}"
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
class WebSocketException(Exception):
def __init__(self, code: int, reason: typing.Optional[str] = None) -> None:
def __init__(self, code: int, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
def __str__(self) -> str:
return f"{self.code}: {self.reason}"
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"
__deprecated__ = "ExceptionMiddleware"
def __getattr__(name: str) -> typing.Any: # pragma: no cover
if name == __deprecated__:
from starlette.middleware.exceptions import ExceptionMiddleware
warnings.warn(
f"{__deprecated__} is deprecated on `starlette.exceptions`. "
f"Import it from `starlette.middleware.exceptions` instead.",
category=DeprecationWarning,
stacklevel=3,
)
return ExceptionMiddleware
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
def __dir__() -> typing.List[str]:
return sorted(list(__all__) + [__deprecated__]) # pragma: no cover

View File

@@ -1,17 +1,28 @@
import typing
from __future__ import annotations
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from enum import Enum
from tempfile import SpooledTemporaryFile
from typing import TYPE_CHECKING
from urllib.parse import unquote_plus
from starlette.datastructures import FormData, Headers, UploadFile
try:
import multipart
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: nocover
parse_options_header = None
multipart = None
if TYPE_CHECKING:
import python_multipart as multipart
from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
else:
try:
try:
import python_multipart as multipart
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
import multipart
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
multipart = None
parse_options_header = None
class FormMessage(Enum):
@@ -24,14 +35,14 @@ class FormMessage(Enum):
@dataclass
class MultipartPart:
content_disposition: typing.Optional[bytes] = None
content_disposition: bytes | None = None
field_name: str = ""
data: bytes = b""
file: typing.Optional[UploadFile] = None
item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(default_factory=list)
data: bytearray = field(default_factory=bytearray)
file: UploadFile | None = None
item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
def _user_safe_decode(src: bytes, codec: str) -> str:
def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
try:
return src.decode(codec)
except (UnicodeDecodeError, LookupError):
@@ -44,15 +55,11 @@ class MultiPartException(Exception):
class FormParser:
def __init__(
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
) -> None:
assert (
multipart is not None
), "The `python-multipart` library must be installed to use form parsing."
def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = []
self.messages: list[tuple[FormMessage, bytes]] = []
def on_field_start(self) -> None:
message = (FormMessage.FIELD_START, b"")
@@ -76,7 +83,7 @@ class FormParser:
async def parse(self) -> FormData:
# Callbacks dictionary.
callbacks = {
callbacks: QuerystringCallbacks = {
"on_field_start": self.on_field_start,
"on_field_name": self.on_field_name,
"on_field_data": self.on_field_data,
@@ -89,7 +96,7 @@ class FormParser:
field_name = b""
field_value = b""
items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
items: list[tuple[str, str | UploadFile]] = []
# Feed the parser with data from the request.
async for chunk in self.stream:
@@ -116,33 +123,36 @@ class FormParser:
class MultiPartParser:
max_file_size = 1024 * 1024
spool_max_size = 1024 * 1024 # 1MB
"""The maximum size of the spooled temporary file used to store file data."""
max_part_size = 1024 * 1024 # 1MB
"""The maximum size of a part in the multipart request."""
def __init__(
self,
headers: Headers,
stream: typing.AsyncGenerator[bytes, None],
stream: AsyncGenerator[bytes, None],
*,
max_files: typing.Union[int, float] = 1000,
max_fields: typing.Union[int, float] = 1000,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024, # 1MB
) -> None:
assert (
multipart is not None
), "The `python-multipart` library must be installed to use form parsing."
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
self.max_fields = max_fields
self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
self.items: list[tuple[str, str | UploadFile]] = []
self._current_files = 0
self._current_fields = 0
self._current_partial_header_name: bytes = b""
self._current_partial_header_value: bytes = b""
self._current_part = MultipartPart()
self._charset = ""
self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: typing.List[MultipartPart] = []
self._files_to_close_on_error: typing.List[SpooledTemporaryFile] = []
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: list[MultipartPart] = []
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
self.max_part_size = max_part_size
def on_part_begin(self) -> None:
self._current_part = MultipartPart()
@@ -150,7 +160,9 @@ class MultiPartParser:
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self._current_part.file is None:
self._current_part.data += message_bytes
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
self._current_part.data.extend(message_bytes)
else:
self._file_parts_to_write.append((self._current_part, message_bytes))
@@ -179,32 +191,22 @@ class MultiPartParser:
field = self._current_partial_header_name.lower()
if field == b"content-disposition":
self._current_part.content_disposition = self._current_partial_header_value
self._current_part.item_headers.append(
(field, self._current_partial_header_value)
)
self._current_part.item_headers.append((field, self._current_partial_header_value))
self._current_partial_header_name = b""
self._current_partial_header_value = b""
def on_headers_finished(self) -> None:
disposition, options = parse_options_header(
self._current_part.content_disposition
)
disposition, options = parse_options_header(self._current_part.content_disposition)
try:
self._current_part.field_name = _user_safe_decode(
options[b"name"], self._charset
)
self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
except KeyError:
raise MultiPartException(
'The Content-Disposition header field "name" must be ' "provided."
)
raise MultiPartException('The Content-Disposition header field "name" must be provided.')
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
raise MultiPartException(
f"Too many files. Maximum number of files is {self.max_files}."
)
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = UploadFile(
file=tempfile, # type: ignore[arg-type]
@@ -215,9 +217,7 @@ class MultiPartParser:
else:
self._current_fields += 1
if self._current_fields > self.max_fields:
raise MultiPartException(
f"Too many fields. Maximum number of fields is {self.max_fields}."
)
raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
self._current_part.file = None
def on_end(self) -> None:
@@ -227,7 +227,7 @@ class MultiPartParser:
# Parse the Content-Type header to get the multipart boundary.
_, params = parse_options_header(self.headers["Content-Type"])
charset = params.get(b"charset", "utf-8")
if type(charset) == bytes:
if isinstance(charset, bytes):
charset = charset.decode("latin-1")
self._charset = charset
try:
@@ -236,7 +236,7 @@ class MultiPartParser:
raise MultiPartException("Missing boundary in multipart.")
# Callbacks dictionary.
callbacks = {
callbacks: MultipartCallbacks = {
"on_part_begin": self.on_part_begin,
"on_part_data": self.on_part_data,
"on_part_end": self.on_part_end,

View File

@@ -1,17 +1,37 @@
import typing
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterator
from typing import Any, ParamSpec, Protocol
P = ParamSpec("P")
_Scope = Any
_Receive = Callable[[], Awaitable[Any]]
_Send = Callable[[Any], Awaitable[None]]
# Since `starlette.types.ASGIApp` type differs from `ASGIApplication` from `asgiref`
# we need to define a more permissive version of ASGIApp that doesn't cause type errors.
_ASGIApp = Callable[[_Scope, _Receive, _Send], Awaitable[None]]
class _MiddlewareFactory(Protocol[P]):
def __call__(self, app: _ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> _ASGIApp: ... # pragma: no cover
class Middleware:
def __init__(self, cls: type, **options: typing.Any) -> None:
def __init__(self, cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None:
self.cls = cls
self.options = options
self.args = args
self.kwargs = kwargs
def __iter__(self) -> typing.Iterator:
as_tuple = (self.cls, self.options)
def __iter__(self) -> Iterator[Any]:
as_tuple = (self.cls, self.args, self.kwargs)
return iter(as_tuple)
def __repr__(self) -> str:
class_name = self.__class__.__name__
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
args_repr = ", ".join([self.cls.__name__] + option_strings)
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
name = getattr(self.cls, "__name__", "")
args_repr = ", ".join([name] + args_strings + option_strings)
return f"{class_name}({args_repr})"

View File

@@ -1,4 +1,6 @@
import typing
from __future__ import annotations
from collections.abc import Callable
from starlette.authentication import (
AuthCredentials,
@@ -16,15 +18,13 @@ class AuthenticationMiddleware:
self,
app: ASGIApp,
backend: AuthenticationBackend,
on_error: typing.Optional[
typing.Callable[[HTTPConnection, AuthenticationError], Response]
] = None,
on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
) -> None:
self.app = app
self.backend = backend
self.on_error: typing.Callable[
[HTTPConnection, AuthenticationError], Response
] = (on_error if on_error is not None else self.default_on_error)
self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = (
on_error if on_error is not None else self.default_on_error
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ["http", "websocket"]:

View File

@@ -1,23 +1,100 @@
import typing
from __future__ import annotations
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Mapping, MutableMapping
from typing import Any, TypeVar
import anyio
from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette._utils import collapse_excgroups
from starlette.requests import ClientDisconnect, Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")
RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
BodyStreamGenerator = AsyncGenerator[bytes | MutableMapping[str, Any], None]
AsyncContentStream = AsyncIterable[str | bytes | memoryview | MutableMapping[str, Any]]
T = TypeVar("T")
class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
we cache the entire request body in memory and pass that to downstream middlewares,
but if they call Request.stream() then all we do is send an
empty body so that downstream things don't hang forever.
"""
def __init__(self, scope: Scope, receive: Receive):
super().__init__(scope, receive)
self._wrapped_rcv_disconnected = False
self._wrapped_rcv_consumed = False
self._wrapped_rc_stream = self.stream()
async def wrapped_receive(self) -> Message:
# wrapped_rcv state 1: disconnected
if self._wrapped_rcv_disconnected:
# we've already sent a disconnect to the downstream app
# we don't need to wait to get another one
# (although most ASGI servers will just keep sending it)
return {"type": "http.disconnect"}
# wrapped_rcv state 1: consumed but not yet disconnected
if self._wrapped_rcv_consumed:
# since the downstream app has consumed us all that is left
# is to send it a disconnect
if self._is_disconnected:
# the middleware has already seen the disconnect
# since we know the client is disconnected no need to wait
# for the message
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# we don't know yet if the client is disconnected or not
# so we'll wait until we get that message
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
self._wrapped_rcv_disconnected = True
return msg
# wrapped_rcv state 3: not yet consumed
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion
# return an empty body so that downstream apps don't hang
# waiting for a disconnect
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": b"",
"more_body": False,
}
else:
# body() was never called and stream() wasn't consumed
try:
stream = self.stream()
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": chunk,
"more_body": not self._stream_consumed,
}
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
) -> None:
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch is None else dispatch
@@ -26,35 +103,32 @@ class BaseHTTPMiddleware:
await self.app(scope, receive, send)
return
request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()
app_exc: Exception | None = None
exception_already_raised = False
async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()
async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}
async with anyio.create_task_group() as task_group:
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
async def wrap(func: Callable[[], Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result
task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
message = await wrap(wrapped_receive)
if response_sent.is_set():
return {"type": "http.disconnect"}
return message
async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()
async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
@@ -65,13 +139,12 @@ class BaseHTTPMiddleware:
async def coro() -> None:
nonlocal app_exc
async with send_stream:
with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc
task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)
try:
@@ -81,54 +154,91 @@ class BaseHTTPMiddleware:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
nonlocal exception_already_raised
exception_already_raised = True
# Prevent `anyio.EndOfStream` from polluting app exception context.
# If both cause and context are None then the context is suppressed
# and `anyio.EndOfStream` is not present in the exception traceback.
# If exception cause is not None then it is propagated with
# reraising here.
# If exception has no cause but has context set then the context is
# propagated as a cause with the reraise. This is necessary in order
# to prevent `anyio.EndOfStream` from polluting the exception
# context.
raise app_exc from app_exc.__cause__ or app_exc.__context__
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
async def body_stream() -> BodyStreamGenerator:
async for message in recv_stream:
if message["type"] == "http.response.pathsend":
yield message
break
assert message["type"] == "http.response.body", f"Unexpected message: {message}"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break
if app_exc is not None:
raise app_exc
response = _StreamingResponse(
status_code=message["status"], content=body_stream(), info=info
)
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
response.raw_headers = message["headers"]
return response
async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
response_sent.set()
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
send_stream, recv_stream = streams
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
recv_stream.close()
if app_exc is not None and not exception_already_raised:
raise app_exc
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
class _StreamingResponse(StreamingResponse):
class _StreamingResponse(Response):
def __init__(
self,
content: ContentStream,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
info: Mapping[str, Any] | None = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)
self.background = None
async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
should_close_body = True
async for chunk in self.body_iterator:
if isinstance(chunk, dict):
# We got an ASGI message which is not response body (eg: pathsend)
should_close_body = False
await send(chunk)
continue
await send({"type": "http.response.body", "body": chunk, "more_body": True})
if should_close_body:
await send({"type": "http.response.body", "body": b"", "more_body": False})
if self.background:
await self.background()

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import functools
import re
import typing
from collections.abc import Sequence
from starlette.datastructures import Headers, MutableHeaders
from starlette.responses import PlainTextResponse, Response
@@ -14,12 +16,12 @@ class CORSMiddleware:
def __init__(
self,
app: ASGIApp,
allow_origins: typing.Sequence[str] = (),
allow_methods: typing.Sequence[str] = ("GET",),
allow_headers: typing.Sequence[str] = (),
allow_origins: Sequence[str] = (),
allow_methods: Sequence[str] = ("GET",),
allow_headers: Sequence[str] = (),
allow_credentials: bool = False,
allow_origin_regex: typing.Optional[str] = None,
expose_headers: typing.Sequence[str] = (),
allow_origin_regex: str | None = None,
expose_headers: Sequence[str] = (),
max_age: int = 600,
) -> None:
if "*" in allow_methods:
@@ -94,9 +96,7 @@ class CORSMiddleware:
if self.allow_all_origins:
return True
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
origin
):
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
return True
return origin in self.allow_origins
@@ -139,15 +139,11 @@ class CORSMiddleware:
return PlainTextResponse("OK", status_code=200, headers=headers)
async def simple_response(
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
) -> None:
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
send = functools.partial(self.send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)
async def send(
self, message: Message, send: Send, request_headers: Headers
) -> None:
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
if message["type"] != "http.response.start":
await send(message)
return

View File

@@ -1,13 +1,15 @@
from __future__ import annotations
import html
import inspect
import sys
import traceback
import typing
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
STYLES = """
p {
@@ -137,7 +139,7 @@ class ServerErrorMiddleware:
def __init__(
self,
app: ASGIApp,
handler: typing.Optional[typing.Callable] = None,
handler: ExceptionHandler | None = None,
debug: bool = False,
) -> None:
self.app = app
@@ -183,9 +185,7 @@ class ServerErrorMiddleware:
# to optionally raise the error within the test case.
raise exc
def format_line(
self, index: int, line: str, frame_lineno: int, frame_index: int
) -> str:
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
values = {
# HTML escape - line could contain < or >
"line": html.escape(line).replace(" ", "&nbsp"),
@@ -199,7 +199,10 @@ class ServerErrorMiddleware:
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
code_context = "".join(
self.format_line(
index, line, frame.lineno, frame.index # type: ignore[arg-type]
index,
line,
frame.lineno,
frame.index, # type: ignore[arg-type]
)
for index, line in enumerate(frame.code_context or [])
)
@@ -219,9 +222,7 @@ class ServerErrorMiddleware:
return FRAME_TEMPLATE.format(**values)
def generate_html(self, exc: Exception, limit: int = 7) -> str:
traceback_obj = traceback.TracebackException.from_exception(
exc, capture_locals=True
)
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
exc_html = ""
is_collapsed = False
@@ -232,11 +233,13 @@ class ServerErrorMiddleware:
exc_html += self.generate_frame_html(frame, is_collapsed)
is_collapsed = True
if sys.version_info >= (3, 13): # pragma: no cover
exc_type_str = traceback_obj.exc_type_str
else: # pragma: no cover
exc_type_str = traceback_obj.exc_type.__name__
# escape error class and text
error = (
f"{html.escape(traceback_obj.exc_type.__name__)}: "
f"{html.escape(str(traceback_obj))}"
)
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)

View File

@@ -1,11 +1,17 @@
import typing
from __future__ import annotations
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from collections.abc import Mapping
from typing import Any
from starlette._exception_handler import (
ExceptionHandlers,
StatusHandlers,
wrap_app_handling_exceptions,
)
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
from starlette.websockets import WebSocket
@@ -13,28 +19,24 @@ class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: typing.Optional[
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
] = None,
handlers: Mapping[Any, ExceptionHandler] | None = None,
debug: bool = False,
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
}
if handlers is not None:
if handlers is not None: # pragma: no branch
for key, value in handlers.items():
self.add_exception_handler(key, value)
def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable[[Request, Exception], Response],
exc_class_or_status_code: int | type[Exception],
handler: ExceptionHandler,
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
@@ -42,68 +44,30 @@ class ExceptionMiddleware:
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler
def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return
response_started = False
async def sender(message: Message) -> None:
nonlocal response_started
if message["type"] == "http.response.start":
response_started = True
await send(message)
try:
await self.app(scope, receive, sender)
except Exception as exc:
handler = None
if isinstance(exc, HTTPException):
handler = self._status_handlers.get(exc.status_code)
if handler is None:
handler = self._lookup_exception_handler(exc)
if handler is None:
raise exc
if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc
if scope["type"] == "http":
request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if is_async_callable(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)
def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)
async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
conn: Request | WebSocket
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
conn = WebSocket(scope, receive, send)
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
async def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover

View File

@@ -1,49 +1,51 @@
import gzip
import io
import typing
from typing import NoReturn
from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
class GZipMiddleware:
def __init__(
self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
) -> None:
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
self.app = app
self.minimum_size = minimum_size
self.compresslevel = compresslevel
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
headers = Headers(scope=scope)
if "gzip" in headers.get("Accept-Encoding", ""):
responder = GZipResponder(
self.app, self.minimum_size, compresslevel=self.compresslevel
)
await responder(scope, receive, send)
return
await self.app(scope, receive, send)
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
responder: ASGIApp
if "gzip" in headers.get("Accept-Encoding", ""):
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
else:
responder = IdentityResponder(self.app, self.minimum_size)
await responder(scope, receive, send)
class GZipResponder:
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
class IdentityResponder:
content_encoding: str
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
self.app = app
self.minimum_size = minimum_size
self.send: Send = unattached_send
self.initial_message: Message = {}
self.started = False
self.content_encoding_set = False
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(
mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
)
self.content_type_is_excluded = False
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
await self.app(scope, receive, self.send_with_gzip)
await self.app(scope, receive, self.send_with_compression)
async def send_with_gzip(self, message: Message) -> None:
async def send_with_compression(self, message: Message) -> None:
message_type = message["type"]
if message_type == "http.response.start":
# Don't send the initial message until we've determined how to
@@ -51,7 +53,8 @@ class GZipResponder:
self.initial_message = message
headers = Headers(raw=self.initial_message["headers"])
self.content_encoding_set = "content-encoding" in headers
elif message_type == "http.response.body" and self.content_encoding_set:
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
if not self.started:
self.started = True
await self.send(self.initial_message)
@@ -61,53 +64,82 @@ class GZipResponder:
body = message.get("body", b"")
more_body = message.get("more_body", False)
if len(body) < self.minimum_size and not more_body:
# Don't apply GZip to small outgoing responses.
# Don't apply compression to small outgoing responses.
await self.send(self.initial_message)
await self.send(message)
elif not more_body:
# Standard GZip response.
self.gzip_file.write(body)
self.gzip_file.close()
body = self.gzip_buffer.getvalue()
# Standard response.
body = self.apply_compression(body, more_body=False)
headers = MutableHeaders(raw=self.initial_message["headers"])
headers["Content-Encoding"] = "gzip"
headers["Content-Length"] = str(len(body))
headers.add_vary_header("Accept-Encoding")
message["body"] = body
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
headers["Content-Length"] = str(len(body))
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
else:
# Initial body in streaming GZip response.
headers = MutableHeaders(raw=self.initial_message["headers"])
headers["Content-Encoding"] = "gzip"
headers.add_vary_header("Accept-Encoding")
del headers["Content-Length"]
# Initial body in streaming response.
body = self.apply_compression(body, more_body=True)
self.gzip_file.write(body)
message["body"] = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
headers = MutableHeaders(raw=self.initial_message["headers"])
headers.add_vary_header("Accept-Encoding")
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
del headers["Content-Length"]
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body":
# Remaining body in streaming GZip response.
# Remaining body in streaming response.
body = message.get("body", b"")
more_body = message.get("more_body", False)
self.gzip_file.write(body)
if not more_body:
self.gzip_file.close()
message["body"] = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
message["body"] = self.apply_compression(body, more_body=more_body)
await self.send(message)
elif message_type == "http.response.pathsend": # pragma: no branch
# Don't apply GZip to pathsend responses
await self.send(self.initial_message)
await self.send(message)
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
"""Apply compression on the response body.
If more_body is False, any compression file should be closed. If it
isn't, it won't be closed automatically until all background tasks
complete.
"""
return body
async def unattached_send(message: Message) -> typing.NoReturn:
class GZipResponder(IdentityResponder):
content_encoding = "gzip"
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
super().__init__(app, minimum_size)
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
with self.gzip_buffer, self.gzip_file:
await super().__call__(scope, receive, send)
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
self.gzip_file.write(body)
if not more_body:
self.gzip_file.close()
body = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
return body
async def unattached_send(message: Message) -> NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
import json
import sys
import typing
from base64 import b64decode, b64encode
from typing import Literal
import itsdangerous
from itsdangerous.exc import BadSignature
@@ -10,22 +11,18 @@ from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
if sys.version_info >= (3, 8): # pragma: no cover
from typing import Literal
else: # pragma: no cover
from typing_extensions import Literal
class SessionMiddleware:
def __init__(
self,
app: ASGIApp,
secret_key: typing.Union[str, Secret],
secret_key: str | Secret,
session_cookie: str = "session",
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
path: str = "/",
same_site: Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
domain: str | None = None,
) -> None:
self.app = app
self.signer = itsdangerous.TimestampSigner(str(secret_key))
@@ -35,6 +32,8 @@ class SessionMiddleware:
self.security_flags = "httponly; samesite=" + same_site
if https_only: # Secure flag can be used with HTTPS only
self.security_flags += "; secure"
if domain is not None:
self.security_flags += f"; domain={domain}"
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
@@ -62,7 +61,7 @@ class SessionMiddleware:
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
@@ -73,7 +72,7 @@ class SessionMiddleware:
elif not initial_session_was_empty:
# The session has been cleared.
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
session_cookie=self.session_cookie,
data="null",
path=self.path,

View File

@@ -1,4 +1,6 @@
import typing
from __future__ import annotations
from collections.abc import Sequence
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse, Response
@@ -11,7 +13,7 @@ class TrustedHostMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
allowed_hosts: Sequence[str] | None = None,
www_redirect: bool = True,
) -> None:
if allowed_hosts is None:
@@ -39,9 +41,7 @@ class TrustedHostMiddleware:
is_valid_host = False
found_www_redirect = False
for pattern in self.allowed_hosts:
if host == pattern or (
pattern.startswith("*") and host.endswith(pattern[1:])
):
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
is_valid_host = True
break
elif "www." + host == pattern:

View File

@@ -1,10 +1,14 @@
from __future__ import annotations
import io
import math
import sys
import typing
import warnings
from collections.abc import Callable, MutableMapping
from typing import Any
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from starlette.types import Receive, Scope, Send
@@ -15,14 +19,20 @@ warnings.warn(
)
def build_environ(scope: Scope, body: bytes) -> dict:
def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
path_info = scope["path"].encode("utf8").decode("latin1")
if path_info.startswith(script_name):
path_info = path_info[len(script_name) :]
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
"SCRIPT_NAME": script_name,
"PATH_INFO": path_info,
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
"wsgi.version": (1, 0),
@@ -62,7 +72,7 @@ def build_environ(scope: Scope, body: bytes) -> dict:
class WSGIMiddleware:
def __init__(self, app: typing.Callable) -> None:
def __init__(self, app: Callable[..., Any]) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -72,16 +82,17 @@ class WSGIMiddleware:
class WSGIResponder:
def __init__(self, app: typing.Callable, scope: Scope) -> None:
stream_send: ObjectSendStream[MutableMapping[str, Any]]
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
self.app = app
self.scope = scope
self.status = None
self.response_headers = None
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
math.inf
)
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
self.response_started = False
self.exc_info: typing.Any = None
self.exc_info: Any = None
async def __call__(self, receive: Receive, send: Send) -> None:
body = b""
@@ -107,11 +118,11 @@ class WSGIResponder:
def start_response(
self,
status: str,
response_headers: typing.List[typing.Tuple[str, str]],
exc_info: typing.Any = None,
response_headers: list[tuple[str, str]],
exc_info: Any = None,
) -> None:
self.exc_info = exc_info
if not self.response_started:
if not self.response_started: # pragma: no branch
self.response_started = True
status_code_string, _ = status.split(" ", 1)
status_code = int(status_code_string)
@@ -128,13 +139,15 @@ class WSGIResponder:
},
)
def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
def wsgi(
self,
environ: dict[str, Any],
start_response: Callable[..., Any],
) -> None:
for chunk in self.app(environ, start_response):
anyio.from_thread.run(
self.stream_send.send,
{"type": "http.response.body", "body": chunk, "more_body": True},
)
anyio.from_thread.run(
self.stream_send.send, {"type": "http.response.body", "body": b""}
)
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import json
import typing
from collections.abc import AsyncGenerator, Iterator, Mapping
from http import cookies as http_cookies
from typing import TYPE_CHECKING, Any, NoReturn, cast
import anyio
@@ -10,14 +13,19 @@ from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
try:
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: nocover
parse_options_header = None
if TYPE_CHECKING:
from python_multipart.multipart import parse_options_header
if typing.TYPE_CHECKING:
from starlette.applications import Starlette
from starlette.routing import Router
else:
try:
try:
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
parse_options_header = None
SERVER_PUSH_HEADERS_TO_COPY = {
@@ -29,7 +37,7 @@ SERVER_PUSH_HEADERS_TO_COPY = {
}
def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
def cookie_parser(cookie_string: str) -> dict[str, str]:
"""
This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
@@ -41,7 +49,7 @@ def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
on an outdated spec and will fail on lots of input we want to support
"""
cookie_dict: typing.Dict[str, str] = {}
cookie_dict: dict[str, str] = {}
for chunk in cookie_string.split(";"):
if "=" in chunk:
key, val = chunk.split("=", 1)
@@ -60,20 +68,20 @@ class ClientDisconnect(Exception):
pass
class HTTPConnection(typing.Mapping[str, typing.Any]):
class HTTPConnection(Mapping[str, Any]):
"""
A base class for incoming HTTP connections, that is used to provide
any functionality that is common to both `Request` and `WebSocket`.
"""
def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None:
def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
assert scope["type"] in ("http", "websocket")
self.scope = scope
def __getitem__(self, key: str) -> typing.Any:
def __getitem__(self, key: str) -> Any:
return self.scope[key]
def __iter__(self) -> typing.Iterator[str]:
def __iter__(self) -> Iterator[str]:
return iter(self.scope)
def __len__(self) -> int:
@@ -86,12 +94,12 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
__hash__ = object.__hash__
@property
def app(self) -> typing.Any:
def app(self) -> Any:
return self.scope["app"]
@property
def url(self) -> URL:
if not hasattr(self, "_url"):
if not hasattr(self, "_url"): # pragma: no branch
self._url = URL(scope=self.scope)
return self._url
@@ -99,11 +107,16 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
def base_url(self) -> URL:
if not hasattr(self, "_base_url"):
base_url_scope = dict(self.scope)
base_url_scope["path"] = "/"
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
path = app_root_path
if not path.endswith("/"):
path += "/"
base_url_scope["path"] = path
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
base_url_scope["root_path"] = app_root_path
self._base_url = URL(scope=base_url_scope)
return self._base_url
@@ -115,52 +128,47 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
@property
def query_params(self) -> QueryParams:
if not hasattr(self, "_query_params"):
if not hasattr(self, "_query_params"): # pragma: no branch
self._query_params = QueryParams(self.scope["query_string"])
return self._query_params
@property
def path_params(self) -> typing.Dict[str, typing.Any]:
def path_params(self) -> dict[str, Any]:
return self.scope.get("path_params", {})
@property
def cookies(self) -> typing.Dict[str, str]:
def cookies(self) -> dict[str, str]:
if not hasattr(self, "_cookies"):
cookies: typing.Dict[str, str] = {}
cookie_header = self.headers.get("cookie")
cookies: dict[str, str] = {}
cookie_headers = self.headers.getlist("cookie")
for header in cookie_headers:
cookies.update(cookie_parser(header))
if cookie_header:
cookies = cookie_parser(cookie_header)
self._cookies = cookies
return self._cookies
@property
def client(self) -> typing.Optional[Address]:
# client is a 2 item tuple of (host, port), None or missing
def client(self) -> Address | None:
# client is a 2 item tuple of (host, port), None if missing
host_port = self.scope.get("client")
if host_port is not None:
return Address(*host_port)
return None
@property
def session(self) -> typing.Dict[str, typing.Any]:
assert (
"session" in self.scope
), "SessionMiddleware must be installed to access request.session"
return self.scope["session"]
def session(self) -> dict[str, Any]:
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
return self.scope["session"] # type: ignore[no-any-return]
@property
def auth(self) -> typing.Any:
assert (
"auth" in self.scope
), "AuthenticationMiddleware must be installed to access request.auth"
def auth(self) -> Any:
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
return self.scope["auth"]
@property
def user(self) -> typing.Any:
assert (
"user" in self.scope
), "AuthenticationMiddleware must be installed to access request.user"
def user(self) -> Any:
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
return self.scope["user"]
@property
@@ -173,26 +181,26 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
self._state = State(self.scope["state"])
return self._state
def url_for(self, __name: str, **path_params: typing.Any) -> URL:
router: Router = self.scope["router"]
url_path = router.url_path_for(__name, **path_params)
def url_for(self, name: str, /, **path_params: Any) -> URL:
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
if url_path_provider is None:
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
url_path = url_path_provider.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.base_url)
async def empty_receive() -> typing.NoReturn:
async def empty_receive() -> NoReturn:
raise RuntimeError("Receive channel has not been made available")
async def empty_send(message: Message) -> typing.NoReturn:
async def empty_send(message: Message) -> NoReturn:
raise RuntimeError("Send channel has not been made available")
class Request(HTTPConnection):
_form: typing.Optional[FormData]
_form: FormData | None
def __init__(
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
):
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
super().__init__(scope)
assert scope["type"] == "http"
self._receive = receive
@@ -203,43 +211,42 @@ class Request(HTTPConnection):
@property
def method(self) -> str:
return self.scope["method"]
return cast(str, self.scope["method"])
@property
def receive(self) -> Receive:
return self._receive
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
async def stream(self) -> AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):
yield self._body
yield b""
return
if self._stream_consumed:
raise RuntimeError("Stream consumed")
self._stream_consumed = True
while True:
while not self._stream_consumed:
message = await self._receive()
if message["type"] == "http.request":
body = message.get("body", b"")
if not message.get("more_body", False):
self._stream_consumed = True
if body:
yield body
if not message.get("more_body", False):
break
elif message["type"] == "http.disconnect":
elif message["type"] == "http.disconnect": # pragma: no branch
self._is_disconnected = True
raise ClientDisconnect()
yield b""
async def body(self) -> bytes:
if not hasattr(self, "_body"):
chunks: "typing.List[bytes]" = []
chunks: list[bytes] = []
async for chunk in self.stream():
chunks.append(chunk)
self._body = b"".join(chunks)
return self._body
async def json(self) -> typing.Any:
if not hasattr(self, "_json"):
async def json(self) -> Any:
if not hasattr(self, "_json"): # pragma: no branch
body = await self.body()
self._json = json.loads(body)
return self._json
@@ -247,13 +254,14 @@ class Request(HTTPConnection):
async def _get_form(
self,
*,
max_files: typing.Union[int, float] = 1000,
max_fields: typing.Union[int, float] = 1000,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
), "The `python-multipart` library must be installed to use form parsing."
if self._form is None: # pragma: no branch
assert parse_options_header is not None, (
"The `python-multipart` library must be installed to use form parsing."
)
content_type_header = self.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
@@ -264,6 +272,7 @@ class Request(HTTPConnection):
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_part_size=max_part_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
@@ -280,15 +289,16 @@ class Request(HTTPConnection):
def form(
self,
*,
max_files: typing.Union[int, float] = 1000,
max_fields: typing.Union[int, float] = 1000,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields)
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
)
async def close(self) -> None:
if self._form is not None:
if self._form is not None: # pragma: no branch
await self._form.close()
async def is_disconnected(self) -> bool:
@@ -307,12 +317,8 @@ class Request(HTTPConnection):
async def send_push_promise(self, path: str) -> None:
if "http.response.push" in self.scope.get("extensions", {}):
raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = []
raw_headers: list[tuple[bytes, bytes]] = []
for name in SERVER_PUSH_HEADERS_TO_COPY:
for value in self.headers.getlist(name):
raw_headers.append(
(name.encode("latin-1"), value.encode("latin-1"))
)
await self._send(
{"type": "http.response.push", "path": path, "headers": raw_headers}
)
raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})

View File

@@ -1,40 +1,31 @@
from __future__ import annotations
import hashlib
import http.cookies
import json
import os
import stat
import sys
import typing
import warnings
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, Sequence
from datetime import datetime
from email.utils import format_datetime, formatdate
from functools import partial
from mimetypes import guess_type as mimetypes_guess_type
from mimetypes import guess_type
from secrets import token_hex
from typing import Any, Literal
from urllib.parse import quote
import anyio
import anyio.to_thread
from starlette._compat import md5_hexdigest
from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.datastructures import URL, Headers, MutableHeaders
from starlette.requests import ClientDisconnect
from starlette.types import Receive, Scope, Send
if sys.version_info >= (3, 8): # pragma: no cover
from typing import Literal
else: # pragma: no cover
from typing_extensions import Literal
# Workaround for adding samesite support to pre 3.8 python
http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore[attr-defined]
# Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on <py3.8
def guess_type(
url: typing.Union[str, "os.PathLike[str]"], strict: bool = True
) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
if sys.version_info < (3, 8): # pragma: no cover
url = os.fspath(url)
return mimetypes_guess_type(url, strict)
class Response:
media_type = None
@@ -42,11 +33,11 @@ class Response:
def __init__(
self,
content: typing.Any = None,
content: Any = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
self.status_code = status_code
if media_type is not None:
@@ -55,25 +46,20 @@ class Response:
self.body = self.render(content)
self.init_headers(headers)
def render(self, content: typing.Any) -> bytes:
def render(self, content: Any) -> bytes | memoryview:
if content is None:
return b""
if isinstance(content, bytes):
if isinstance(content, bytes | memoryview):
return content
return content.encode(self.charset)
return content.encode(self.charset) # type: ignore
def init_headers(
self, headers: typing.Optional[typing.Mapping[str, str]] = None
) -> None:
def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
if headers is None:
raw_headers: typing.List[typing.Tuple[bytes, bytes]] = []
raw_headers: list[tuple[bytes, bytes]] = []
populate_content_length = True
populate_content_type = True
else:
raw_headers = [
(k.lower().encode("latin-1"), v.encode("latin-1"))
for k, v in headers.items()
]
raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
keys = [h[0] for h in raw_headers]
populate_content_length = b"content-length" not in keys
populate_content_type = b"content-type" not in keys
@@ -89,7 +75,7 @@ class Response:
content_type = self.media_type
if content_type is not None and populate_content_type:
if content_type.startswith("text/"):
if content_type.startswith("text/") and "charset=" not in content_type.lower():
content_type += "; charset=" + self.charset
raw_headers.append((b"content-type", content_type.encode("latin-1")))
@@ -105,15 +91,16 @@ class Response:
self,
key: str,
value: str = "",
max_age: typing.Optional[int] = None,
expires: typing.Optional[typing.Union[datetime, str, int]] = None,
path: str = "/",
domain: typing.Optional[str] = None,
max_age: int | None = None,
expires: datetime | str | int | None = None,
path: str | None = "/",
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
samesite: Literal["lax", "strict", "none"] | None = "lax",
partitioned: bool = False,
) -> None:
cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie()
cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
cookie[key] = value
if max_age is not None:
cookie[key]["max-age"] = max_age
@@ -137,6 +124,11 @@ class Response:
"none",
], "samesite must be either 'strict', 'lax' or 'none'"
cookie[key]["samesite"] = samesite
if partitioned:
if sys.version_info < (3, 14):
raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.") # pragma: no cover
cookie[key]["partitioned"] = True # pragma: no cover
cookie_val = cookie.output(header="").strip()
self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
@@ -144,10 +136,10 @@ class Response:
self,
key: str,
path: str = "/",
domain: typing.Optional[str] = None,
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
samesite: Literal["lax", "strict", "none"] | None = "lax",
) -> None:
self.set_cookie(
key,
@@ -161,14 +153,15 @@ class Response:
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
"type": "http.response.start",
"type": prefix + "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
await send({"type": "http.response.body", "body": self.body})
await send({"type": prefix + "http.response.body", "body": self.body})
if self.background is not None:
await self.background()
@@ -187,15 +180,15 @@ class JSONResponse(Response):
def __init__(
self,
content: typing.Any,
content: Any,
status_code: int = 200,
headers: typing.Optional[typing.Dict[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(content, status_code, headers, media_type, background)
def render(self, content: typing.Any) -> bytes:
def render(self, content: Any) -> bytes:
return json.dumps(
content,
ensure_ascii=False,
@@ -208,21 +201,19 @@ class JSONResponse(Response):
class RedirectResponse(Response):
def __init__(
self,
url: typing.Union[str, URL],
url: str | URL,
status_code: int = 307,
headers: typing.Optional[typing.Mapping[str, str]] = None,
background: typing.Optional[BackgroundTask] = None,
headers: Mapping[str, str] | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(
content=b"", status_code=status_code, headers=headers, background=background
)
super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
Content = typing.Union[str, bytes]
SyncContentStream = typing.Iterator[Content]
AsyncContentStream = typing.AsyncIterable[Content]
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
Content = str | bytes | memoryview
SyncContentStream = Iterable[Content]
AsyncContentStream = AsyncIterable[Content]
ContentStream = AsyncContentStream | SyncContentStream
class StreamingResponse(Response):
@@ -232,11 +223,11 @@ class StreamingResponse(Response):
self,
content: ContentStream,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
if isinstance(content, typing.AsyncIterable):
if isinstance(content, AsyncIterable):
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content)
@@ -260,60 +251,80 @@ class StreamingResponse(Response):
}
)
async for chunk in self.body_iterator:
if not isinstance(chunk, bytes):
if not isinstance(chunk, bytes | memoryview):
chunk = chunk.encode(self.charset)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:
spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
async def wrap(func: "typing.Callable[[], typing.Awaitable[None]]") -> None:
await func()
task_group.cancel_scope.cancel()
if spec_version >= (2, 4):
try:
await self.stream_response(send)
except OSError:
raise ClientDisconnect()
else:
with collapse_excgroups():
async with anyio.create_task_group() as task_group:
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
async def wrap(func: Callable[[], Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
if self.background is not None:
await self.background()
class MalformedRangeHeader(Exception):
def __init__(self, content: str = "Malformed range header.") -> None:
self.content = content
class RangeNotSatisfiable(Exception):
def __init__(self, max_size: int) -> None:
self.max_size = max_size
class FileResponse(Response):
chunk_size = 64 * 1024
def __init__(
self,
path: typing.Union[str, "os.PathLike[str]"],
path: str | os.PathLike[str],
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
filename: typing.Optional[str] = None,
stat_result: typing.Optional[os.stat_result] = None,
method: typing.Optional[str] = None,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
filename: str | None = None,
stat_result: os.stat_result | None = None,
method: str | None = None,
content_disposition_type: str = "attachment",
) -> None:
self.path = path
self.status_code = status_code
self.filename = filename
self.send_header_only = method is not None and method.upper() == "HEAD"
if method is not None:
warnings.warn(
"The 'method' parameter is not used, and it will be removed.",
DeprecationWarning,
)
if media_type is None:
media_type = guess_type(filename or path)[0] or "text/plain"
self.media_type = media_type
self.background = background
self.init_headers(headers)
self.headers.setdefault("accept-ranges", "bytes")
if self.filename is not None:
content_disposition_filename = quote(self.filename)
if content_disposition_filename != self.filename:
content_disposition = "{}; filename*=utf-8''{}".format(
content_disposition_type, content_disposition_filename
)
content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
else:
content_disposition = '{}; filename="{}"'.format(
content_disposition_type, self.filename
)
content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
self.headers.setdefault("content-disposition", content_disposition)
self.stat_result = stat_result
if stat_result is not None:
@@ -323,13 +334,16 @@ class FileResponse(Response):
content_length = str(stat_result.st_size)
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
etag = md5_hexdigest(etag_base.encode(), usedforsecurity=False)
etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'
self.headers.setdefault("content-length", content_length)
self.headers.setdefault("last-modified", last_modified)
self.headers.setdefault("etag", etag)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
send_header_only: bool = scope["method"].upper() == "HEAD"
send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})
if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
@@ -340,27 +354,213 @@ class FileResponse(Response):
mode = stat_result.st_mode
if not stat.S_ISREG(mode):
raise RuntimeError(f"File at path {self.path} is not a file.")
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
if self.send_header_only:
else:
stat_result = self.stat_result
headers = Headers(scope=scope)
http_range = headers.get("range")
http_if_range = headers.get("if-range")
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
await self._handle_simple(send, send_header_only, send_pathsend)
else:
try:
ranges = self._parse_range_header(http_range, stat_result.st_size)
except MalformedRangeHeader as exc:
return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
except RangeNotSatisfiable as exc:
response = PlainTextResponse(status_code=416, headers={"Content-Range": f"*/{exc.max_size}"})
return await response(scope, receive, send)
if len(ranges) == 1:
start, end = ranges[0]
await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
else:
await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)
if self.background is not None:
await self.background()
async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
elif send_pathsend:
await send({"type": "http.response.pathsend", "path": str(self.path)})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
more_body = len(chunk) == self.chunk_size
await send(
{
"type": "http.response.body",
"body": chunk,
"more_body": more_body,
}
)
if self.background is not None:
await self.background()
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
async def _handle_single_range(
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
) -> None:
self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
self.headers["content-length"] = str(end - start)
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
await file.seek(start)
more_body = True
while more_body:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
more_body = len(chunk) == self.chunk_size and start < end
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
async def _handle_multiple_ranges(
self,
send: Send,
ranges: list[tuple[int, int]],
file_size: int,
send_header_only: bool,
) -> None:
# In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
boundary = token_hex(13)
content_length, header_generator = self.generate_multipart(
ranges, boundary, file_size, self.headers["content-type"]
)
self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}"
self.headers["content-length"] = str(content_length)
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
for start, end in ranges:
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
await file.seek(start)
while start < end:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"\n", "more_body": True})
await send(
{
"type": "http.response.body",
"body": f"\n--{boundary}--\n".encode("latin-1"),
"more_body": False,
}
)
def _should_use_range(self, http_if_range: str) -> bool:
return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]
@classmethod
def _parse_range_header(cls, http_range: str, file_size: int) -> list[tuple[int, int]]:
ranges: list[tuple[int, int]] = []
try:
units, range_ = http_range.split("=", 1)
except ValueError:
raise MalformedRangeHeader()
units = units.strip().lower()
if units != "bytes":
raise MalformedRangeHeader("Only support bytes range")
ranges = cls._parse_ranges(range_, file_size)
if len(ranges) == 0:
raise MalformedRangeHeader("Range header: range must be requested")
if any(not (0 <= start < file_size) for start, _ in ranges):
raise RangeNotSatisfiable(file_size)
if any(start > end for start, end in ranges):
raise MalformedRangeHeader("Range header: start must be less than end")
if len(ranges) == 1:
return ranges
# Merge ranges
result: list[tuple[int, int]] = []
for start, end in ranges:
for p in range(len(result)):
p_start, p_end = result[p]
if start > p_end:
continue
elif end < p_start:
result.insert(p, (start, end)) # THIS IS NOT REACHED!
break
else:
result[p] = (min(start, p_start), max(end, p_end))
break
else:
result.append((start, end))
return result
@classmethod
def _parse_ranges(cls, range_: str, file_size: int) -> list[tuple[int, int]]:
ranges: list[tuple[int, int]] = []
for part in range_.split(","):
part = part.strip()
# If the range is empty or a single dash, we ignore it.
if not part or part == "-":
continue
# If the range is not in the format "start-end", we ignore it.
if "-" not in part:
continue
start_str, end_str = part.split("-", 1)
start_str = start_str.strip()
end_str = end_str.strip()
try:
start = int(start_str) if start_str else file_size - int(end_str)
end = int(end_str) + 1 if start_str and end_str and int(end_str) < file_size else file_size
ranges.append((start, end))
except ValueError:
# If the range is not numeric, we ignore it.
continue
return ranges
def generate_multipart(
self,
ranges: Sequence[tuple[int, int]],
boundary: str,
max_size: int,
content_type: str,
) -> tuple[int, Callable[[int, int], bytes]]:
r"""
Multipart response headers generator.
```
--{boundary}\n
Content-Type: {content_type}\n
Content-Range: bytes {start}-{end-1}/{max_size}\n
\n
..........content...........\n
--{boundary}\n
Content-Type: {content_type}\n
Content-Range: bytes {start}-{end-1}/{max_size}\n
\n
..........content...........\n
--{boundary}--\n
```
"""
boundary_len = len(boundary)
static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
content_length = sum(
(len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers
+ (end - start) # Content
for start, end in ranges
) + (
5 + boundary_len # --boundary--\n
)
return (
content_length,
lambda start, end: (
f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n"
).encode("latin-1"),
)

View File

@@ -1,22 +1,27 @@
from __future__ import annotations
import contextlib
import functools
import inspect
import re
import traceback
import types
import typing
import warnings
from contextlib import asynccontextmanager
from collections.abc import Awaitable, Callable, Collection, Generator, Sequence
from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
from enum import Enum
from re import Pattern
from typing import Any, TypeVar
from starlette._utils import is_async_callable
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import get_route_path, is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose
@@ -27,7 +32,7 @@ class NoMatchFound(Exception):
if no matching route exists.
"""
def __init__(self, name: str, path_params: typing.Dict[str, typing.Any]) -> None:
def __init__(self, name: str, path_params: dict[str, Any]) -> None:
params = ", ".join(list(path_params.keys()))
super().__init__(f'No route exists for name "{name}" and params "{params}".')
@@ -38,14 +43,13 @@ class Match(Enum):
FULL = 2
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
def iscoroutinefunction_or_partial(obj: Any) -> bool: # pragma: no cover
"""
Correctly determines if an object is a coroutine function,
including those wrapped in functools.partial objects.
"""
warnings.warn(
"iscoroutinefunction_or_partial is deprecated, "
"and will be removed in a future release.",
"iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
DeprecationWarning,
)
while isinstance(obj, functools.partial):
@@ -53,25 +57,32 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
return inspect.iscoroutinefunction(obj)
def request_response(func: typing.Callable) -> ASGIApp:
def request_response(
func: Callable[[Request], Awaitable[Response] | Response],
) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
is_coroutine = is_async_callable(func)
f: Callable[[Request], Awaitable[Response]] = (
func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)
request = Request(scope, receive, send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = await f(request)
await response(scope, receive, send)
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
return app
def websocket_session(func: typing.Callable) -> ASGIApp:
def websocket_session(
func: Callable[[WebSocket], Awaitable[None]],
) -> ASGIApp:
"""
Takes a coroutine `func(session)`, and returns an ASGI application.
"""
@@ -79,22 +90,24 @@ def websocket_session(func: typing.Callable) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
await func(session)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await func(session)
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
return app
def get_name(endpoint: typing.Callable) -> str:
if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
return endpoint.__name__
return endpoint.__class__.__name__
def get_name(endpoint: Callable[..., Any]) -> str:
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
def replace_params(
path: str,
param_convertors: typing.Dict[str, Convertor],
path_params: typing.Dict[str, str],
) -> typing.Tuple[str, dict]:
param_convertors: dict[str, Convertor[Any]],
path_params: dict[str, str],
) -> tuple[str, dict[str, str]]:
for key, value in list(path_params.items()):
if "{" + key + "}" in path:
convertor = param_convertors[key]
@@ -110,7 +123,7 @@ PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
def compile_path(
path: str,
) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
) -> tuple[Pattern[str], str, dict[str, Convertor[Any]]]:
"""
Given a path string, like: "/{username:str}",
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
@@ -131,9 +144,7 @@ def compile_path(
for match in PARAM_REGEX.finditer(path):
param_name, convertor_type = match.groups("str")
convertor_type = convertor_type.lstrip(":")
assert (
convertor_type in CONVERTOR_TYPES
), f"Unknown path convertor '{convertor_type}'"
assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
convertor = CONVERTOR_TYPES[convertor_type]
path_regex += re.escape(path[idx : match.start()])
@@ -167,10 +178,10 @@ def compile_path(
class BaseRoute:
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
def matches(self, scope: Scope) -> tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
raise NotImplementedError() # pragma: no cover
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -187,7 +198,7 @@ class BaseRoute:
if scope["type"] == "http":
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
elif scope["type"] == "websocket":
elif scope["type"] == "websocket": # pragma: no branch
websocket_close = WebSocketClose()
await websocket_close(scope, receive, send)
return
@@ -200,11 +211,12 @@ class Route(BaseRoute):
def __init__(
self,
path: str,
endpoint: typing.Callable,
endpoint: Callable[..., Any],
*,
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
methods: Collection[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
middleware: Sequence[Middleware] | None = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
@@ -224,6 +236,10 @@ class Route(BaseRoute):
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
if methods is None:
self.methods = None
else:
@@ -233,9 +249,11 @@ class Route(BaseRoute):
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, Any]
if scope["type"] == "http":
match = self.path_regex.match(scope["path"])
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
@@ -249,16 +267,14 @@ class Route(BaseRoute):
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())
if __name != self.name or seen_params != expected_params:
raise NoMatchFound(__name, path_params)
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="http")
@@ -268,14 +284,12 @@ class Route(BaseRoute):
if "app" in scope:
raise HTTPException(status_code=405, headers=headers)
else:
response = PlainTextResponse(
"Method Not Allowed", status_code=405, headers=headers
)
response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, Route)
and self.path == other.path
@@ -292,7 +306,12 @@ class Route(BaseRoute):
class WebSocketRoute(BaseRoute):
def __init__(
self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
*,
name: str | None = None,
middleware: Sequence[Middleware] | None = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
@@ -309,11 +328,17 @@ class WebSocketRoute(BaseRoute):
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, Any]
if scope["type"] == "websocket":
match = self.path_regex.match(scope["path"])
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
@@ -324,28 +349,22 @@ class WebSocketRoute(BaseRoute):
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())
if __name != self.name or seen_params != expected_params:
raise NoMatchFound(__name, path_params)
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="websocket")
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, WebSocketRoute)
and self.path == other.path
and self.endpoint == other.endpoint
)
def __eq__(self, other: Any) -> bool:
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
@@ -355,16 +374,14 @@ class Mount(BaseRoute):
def __init__(
self,
path: str,
app: typing.Optional[ASGIApp] = None,
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
name: typing.Optional[str] = None,
app: ASGIApp | None = None,
routes: Sequence[BaseRoute] | None = None,
name: str | None = None,
*,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
middleware: Sequence[Middleware] | None = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert (
app is not None or routes is not None
), "Either 'app=...', or 'routes=' must be specified"
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if app is not None:
self._base_app: ASGIApp = app
@@ -372,82 +389,80 @@ class Mount(BaseRoute):
self._base_app = Router(routes=routes)
self.app = self._base_app
if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path + "/{path:path}"
)
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
@property
def routes(self) -> typing.List[BaseRoute]:
def routes(self) -> list[BaseRoute]:
return getattr(self._base_app, "routes", [])
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"):
path = scope["path"]
match = self.path_regex.match(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, Any]
if scope["type"] in ("http", "websocket"): # pragma: no branch
root_path = scope.get("root_path", "")
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
matched_path = route_path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
root_path = scope.get("root_path", "")
child_scope = {
"path_params": path_params,
# app_root_path will only be set at the top level scope,
# initialized with the (optional) value of a root_path
# set above/before Starlette. And even though any
# mount will have its own child scope with its own respective
# root_path, the app_root_path will always be available in all
# the child scopes with the same top level value because it's
# set only once here with a default, any other child scope will
# just inherit that app_root_path default value stored in the
# scope. All this is needed to support Request.url_for(), as it
# uses the app_root_path to build the URL path.
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
if self.name is not None and __name == self.name and "path" in path_params:
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path)
elif self.name is None or __name.startswith(self.name + ":"):
elif self.name is None or name.startswith(self.name + ":"):
if self.name is None:
# No mount name.
remaining_name = __name
remaining_name = name
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = __name[len(self.name) + 1 :]
remaining_name = name[len(self.name) + 1 :]
path_kwarg = path_params.get("path")
path_params["path"] = ""
path_prefix, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
)
path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if path_kwarg is not None:
remaining_params["path"] = path_kwarg
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
return URLPath(
path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
)
return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
except NoMatchFound:
pass
raise NoMatchFound(__name, path_params)
raise NoMatchFound(name, path_params)
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, Mount)
and self.path == other.path
and self.app == other.app
)
def __eq__(self, other: Any) -> bool:
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
@@ -456,9 +471,7 @@ class Mount(BaseRoute):
class Host(BaseRoute):
def __init__(
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None:
def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
assert not host.startswith("/"), "Host must not start with '/'"
self.host = host
self.app = app
@@ -466,11 +479,11 @@ class Host(BaseRoute):
self.host_regex, self.host_format, self.param_convertors = compile_path(host)
@property
def routes(self) -> typing.List[BaseRoute]:
def routes(self) -> list[BaseRoute]:
return getattr(self.app, "routes", [])
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"):
def matches(self, scope: Scope) -> tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"): # pragma:no branch
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
@@ -484,42 +497,34 @@ class Host(BaseRoute):
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
if self.name is not None and __name == self.name and "path" in path_params:
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
)
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path, host=host)
elif self.name is None or __name.startswith(self.name + ":"):
elif self.name is None or name.startswith(self.name + ":"):
if self.name is None:
# No mount name.
remaining_name = __name
remaining_name = name
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = __name[len(self.name) + 1 :]
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
)
remaining_name = name[len(self.name) + 1 :]
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
return URLPath(path=str(url), protocol=url.protocol, host=host)
except NoMatchFound:
pass
raise NoMatchFound(__name, path_params)
raise NoMatchFound(name, path_params)
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, Host)
and self.host == other.host
and self.app == other.app
)
def __eq__(self, other: Any) -> bool:
return isinstance(other, Host) and self.host == other.host and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
@@ -527,11 +532,11 @@ class Host(BaseRoute):
return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
_T = typing.TypeVar("_T")
_T = TypeVar("_T")
class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
def __init__(self, cm: typing.ContextManager[_T]):
class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
def __init__(self, cm: AbstractContextManager[_T]):
self._cm = cm
async def __aenter__(self) -> _T:
@@ -539,27 +544,27 @@ class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> typing.Optional[bool]:
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
return self._cm.__exit__(exc_type, exc_value, traceback)
def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
cmgr = contextlib.contextmanager(lifespan_context)
@functools.wraps(cmgr)
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
return _AsyncLiftContextManager(cmgr(app))
return wrapper
class _DefaultLifespan:
def __init__(self, router: "Router"):
def __init__(self, router: Router):
self._router = router
async def __aenter__(self) -> None:
@@ -575,14 +580,16 @@ class _DefaultLifespan:
class Router:
def __init__(
self,
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
routes: Sequence[BaseRoute] | None = None,
redirect_slashes: bool = True,
default: typing.Optional[ASGIApp] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
default: ASGIApp | None = None,
on_startup: Sequence[Callable[[], Any]] | None = None,
on_shutdown: Sequence[Callable[[], Any]] | None = None,
# the generic to Lifespan[AppType] is the type of the top level application
# which the router cannot know statically, so we use typing.Any
lifespan: typing.Optional[Lifespan[typing.Any]] = None,
# which the router cannot know statically, so we use Any
lifespan: Lifespan[Any] | None = None,
*,
middleware: Sequence[Middleware] | None = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
@@ -594,12 +601,18 @@ class Router:
warnings.warn(
"The on_startup and on_shutdown parameters are deprecated, and they "
"will be removed on version 1.0. Use the lifespan parameter instead. "
"See more about it on https://www.starlette.io/lifespan/.",
"See more about it on https://starlette.dev/lifespan/.",
DeprecationWarning,
)
if lifespan:
warnings.warn(
"The `lifespan` parameter cannot be used with `on_startup` or "
"`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
"ignored."
)
if lifespan is None:
self.lifespan_context: Lifespan = _DefaultLifespan(self)
self.lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
@@ -608,20 +621,24 @@ class Router:
DeprecationWarning,
)
self.lifespan_context = asynccontextmanager(
lifespan, # type: ignore[arg-type]
lifespan,
)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
"generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = _wrap_gen_lifespan_context(
lifespan, # type: ignore[arg-type]
lifespan,
)
else:
self.lifespan_context = lifespan
self.middleware_stack = self.app
if middleware:
for cls, args, kwargs in reversed(middleware):
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
websocket_close = WebSocketClose()
@@ -637,13 +654,13 @@ class Router:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
for route in self.routes:
try:
return route.url_path_for(__name, **path_params)
return route.url_path_for(name, **path_params)
except NoMatchFound:
pass
raise NoMatchFound(__name, path_params)
raise NoMatchFound(name, path_params)
async def startup(self) -> None:
"""
@@ -671,15 +688,13 @@ class Router:
startup and shutdown events.
"""
started = False
app: typing.Any = scope.get("app")
app: Any = scope.get("app")
await receive()
try:
async with self.lifespan_context(app) as maybe_state:
if maybe_state is not None:
if "state" not in scope:
raise RuntimeError(
'The server does not support "state" in the lifespan scope.'
)
raise RuntimeError('The server does not support "state" in the lifespan scope.')
scope["state"].update(maybe_state)
await send({"type": "lifespan.startup.complete"})
started = True
@@ -698,6 +713,9 @@ class Router:
"""
The main entry point to the Router class.
"""
await self.middleware_stack(scope, receive, send)
async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] in ("http", "websocket", "lifespan")
if "router" not in scope:
@@ -729,9 +747,10 @@ class Router:
await partial.handle(scope, receive, send)
return
if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
route_path = get_route_path(scope)
if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
redirect_scope = dict(scope)
if scope["path"].endswith("/"):
if route_path.endswith("/"):
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
else:
redirect_scope["path"] = redirect_scope["path"] + "/"
@@ -746,29 +765,25 @@ class Router:
await self.default(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
def __eq__(self, other: Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes
def mount(
self, path: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None: # pragma: nocover
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Mount(path, app=app, name=name)
self.routes.append(route)
def host(
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None: # pragma: no cover
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Host(host, app=app, name=name)
self.routes.append(route)
def add_route(
self,
path: str,
endpoint: typing.Callable,
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
endpoint: Callable[[Request], Awaitable[Response] | Response],
methods: Collection[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: nocover
) -> None: # pragma: no cover
route = Route(
path,
endpoint=endpoint,
@@ -779,7 +794,10 @@ class Router:
self.routes.append(route)
def add_websocket_route(
self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None
self,
path: str,
endpoint: Callable[[WebSocket], Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
route = WebSocketRoute(path, endpoint=endpoint, name=name)
self.routes.append(route)
@@ -787,10 +805,10 @@ class Router:
def route(
self,
path: str,
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
methods: Collection[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> typing.Callable:
) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
@@ -800,11 +818,11 @@ class Router:
"""
warnings.warn(
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", # noqa: E501
"Refer to https://starlette.dev/routing/#http-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_route(
path,
func,
@@ -816,9 +834,7 @@ class Router:
return decorator
def websocket_route(
self, path: str, name: typing.Optional[str] = None
) -> typing.Callable:
def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
@@ -827,20 +843,18 @@ class Router:
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " # noqa: E501
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
"https://starlette.dev/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_websocket_route(path, func, name=name)
return func
return decorator
def add_event_handler(
self, event_type: str, func: typing.Callable
) -> None: # pragma: no cover
def add_event_handler(self, event_type: str, func: Callable[[], Any]) -> None: # pragma: no cover
assert event_type in ("startup", "shutdown")
if event_type == "startup":
@@ -848,14 +862,14 @@ class Router:
else:
self.on_shutdown.append(func)
def on_event(self, event_type: str) -> typing.Callable:
def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg]
warnings.warn(
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://starlette.dev/lifespan/ for recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_event_handler(event_type, func)
return func

View File

@@ -1,41 +1,43 @@
from __future__ import annotations
import inspect
import re
import typing
from collections.abc import Callable
from typing import Any, NamedTuple
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Mount, Route
from starlette.routing import BaseRoute, Host, Mount, Route
try:
import yaml
except ModuleNotFoundError: # pragma: nocover
except ModuleNotFoundError: # pragma: no cover
yaml = None # type: ignore[assignment]
class OpenAPIResponse(Response):
media_type = "application/vnd.oai.openapi"
def render(self, content: typing.Any) -> bytes:
def render(self, content: Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert isinstance(
content, dict
), "The schema passed to OpenAPIResponse should be a dictionary."
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
class EndpointInfo(typing.NamedTuple):
class EndpointInfo(NamedTuple):
path: str
http_method: str
func: typing.Callable
func: Callable[..., Any]
_remove_converter_pattern = re.compile(r":\w+}")
class BaseSchemaGenerator:
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
raise NotImplementedError() # pragma: no cover
def get_endpoints(
self, routes: typing.List[BaseRoute]
) -> typing.List[EndpointInfo]:
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
"""
Given the routes, yields the following information:
@@ -46,12 +48,15 @@ class BaseSchemaGenerator:
- func
method ready to extract the docstring
"""
endpoints_info: list = []
endpoints_info: list[EndpointInfo] = []
for route in routes:
if isinstance(route, Mount):
path = self._remove_converter(route.path)
if isinstance(route, Mount | Host):
routes = route.routes or []
if isinstance(route, Mount):
path = self._remove_converter(route.path)
else:
path = ""
sub_endpoints = [
EndpointInfo(
path="".join((path, sub_endpoint.path)),
@@ -70,9 +75,7 @@ class BaseSchemaGenerator:
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
endpoints_info.append(
EndpointInfo(path, method.lower(), route.endpoint)
)
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
@@ -90,9 +93,9 @@ class BaseSchemaGenerator:
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
Should be represented as `/users/{id}` in the OpenAPI schema.
"""
return re.sub(r":\w+}", "}", path)
return _remove_converter_pattern.sub("}", path)
def parse_docstring(self, func_or_method: typing.Callable) -> dict:
def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
@@ -123,10 +126,10 @@ class BaseSchemaGenerator:
class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, base_schema: dict) -> None:
def __init__(self, base_schema: dict[str, Any]) -> None:
self.base_schema = base_schema
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
schema = dict(self.base_schema)
schema.setdefault("paths", {})
endpoints_info = self.get_endpoints(routes)

View File

@@ -1,17 +1,22 @@
from __future__ import annotations
import errno
import importlib.util
import os
import stat
import typing
from email.utils import parsedate
from typing import Union
import anyio
import anyio.to_thread
from starlette._utils import get_route_path
from starlette.datastructures import URL, Headers
from starlette.exceptions import HTTPException
from starlette.responses import FileResponse, RedirectResponse, Response
from starlette.types import Receive, Scope, Send
PathLike = typing.Union[str, "os.PathLike[str]"]
PathLike = Union[str, "os.PathLike[str]"]
class NotModifiedResponse(Response):
@@ -27,11 +32,7 @@ class NotModifiedResponse(Response):
def __init__(self, headers: Headers):
super().__init__(
status_code=304,
headers={
name: value
for name, value in headers.items()
if name in self.NOT_MODIFIED_HEADERS
},
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
)
@@ -39,10 +40,8 @@ class StaticFiles:
def __init__(
self,
*,
directory: typing.Optional[PathLike] = None,
packages: typing.Optional[
typing.List[typing.Union[str, typing.Tuple[str, str]]]
] = None,
directory: PathLike | None = None,
packages: list[str | tuple[str, str]] | None = None,
html: bool = False,
check_dir: bool = True,
follow_symlink: bool = False,
@@ -58,11 +57,9 @@ class StaticFiles:
def get_directories(
self,
directory: typing.Optional[PathLike] = None,
packages: typing.Optional[
typing.List[typing.Union[str, typing.Tuple[str, str]]]
] = None,
) -> typing.List[PathLike]:
directory: PathLike | None = None,
packages: list[str | tuple[str, str]] | None = None,
) -> list[PathLike]:
"""
Given `directory` and `packages` arguments, return a list of all the
directories that should be used for serving static files from.
@@ -79,12 +76,10 @@ class StaticFiles:
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = os.path.normpath(
os.path.join(spec.origin, "..", statics_dir)
package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
assert os.path.isdir(package_directory), (
f"Directory '{statics_dir!r}' in package {package!r} could not be found."
)
assert os.path.isdir(
package_directory
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
directories.append(package_directory)
return directories
@@ -108,7 +103,8 @@ class StaticFiles:
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
return os.path.normpath(os.path.join(*scope["path"].split("/")))
route_path = get_route_path(scope)
return os.path.normpath(os.path.join(*route_path.split("/")))
async def get_response(self, path: str, scope: Scope) -> Response:
"""
@@ -118,13 +114,15 @@ class StaticFiles:
raise HTTPException(status_code=405)
try:
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, path
)
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
except PermissionError:
raise HTTPException(status_code=401)
except OSError:
raise
except OSError as exc:
# Filename is too long, so it can't be a valid static file.
if exc.errno == errno.ENAMETOOLONG:
raise HTTPException(status_code=404)
raise exc
if stat_result and stat.S_ISREG(stat_result.st_mode):
# We have a static file to serve.
@@ -134,9 +132,7 @@ class StaticFiles:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, index_path
)
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
if not scope["path"].endswith("/"):
# Directory URLs should redirect to always end in "/".
@@ -147,31 +143,22 @@ class StaticFiles:
if self.html:
# Check for '404.html' if we're in HTML mode.
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, "404.html"
)
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
if stat_result and stat.S_ISREG(stat_result.st_mode):
return FileResponse(
full_path,
stat_result=stat_result,
method=scope["method"],
status_code=404,
)
return FileResponse(full_path, stat_result=stat_result, status_code=404)
raise HTTPException(status_code=404)
def lookup_path(
self, path: str
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
for directory in self.all_directories:
joined_path = os.path.join(directory, path)
if self.follow_symlink:
full_path = os.path.abspath(joined_path)
directory = os.path.abspath(directory)
else:
full_path = os.path.realpath(joined_path)
directory = os.path.realpath(directory)
if os.path.commonpath([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
# directory.
directory = os.path.realpath(directory)
if os.path.commonpath([full_path, directory]) != str(directory):
# Don't allow misbehaving clients to break out of the static files directory.
continue
try:
return full_path, os.stat(full_path)
@@ -186,12 +173,9 @@ class StaticFiles:
scope: Scope,
status_code: int = 200,
) -> Response:
method = scope["method"]
request_headers = Headers(scope=scope)
response = FileResponse(
full_path, status_code=status_code, stat_result=stat_result, method=method
)
response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
if self.is_not_modified(response.headers, request_headers):
return NotModifiedResponse(response.headers)
return response
@@ -208,37 +192,24 @@ class StaticFiles:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
except FileNotFoundError:
raise RuntimeError(
f"StaticFiles directory '{self.directory}' does not exist."
)
raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
raise RuntimeError(
f"StaticFiles path '{self.directory}' is not a directory."
)
raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
def is_not_modified(
self, response_headers: Headers, request_headers: Headers
) -> bool:
def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
"""
Given the request and response headers, return `True` if an HTTP
"Not Modified" response could be returned instead.
"""
try:
if_none_match = request_headers["if-none-match"]
if if_none_match := request_headers.get("if-none-match"):
# The "etag" header is added by FileResponse, so it's always present.
etag = response_headers["etag"]
if if_none_match == etag:
return True
except KeyError:
pass
return etag in [tag.strip(" W/") for tag in if_none_match.split(",")]
try:
if_modified_since = parsedate(request_headers["if-modified-since"])
last_modified = parsedate(response_headers["last-modified"])
if (
if_modified_since is not None
and last_modified is not None
and if_modified_since >= last_modified
):
if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
return True
except KeyError:
pass

View File

@@ -3,12 +3,14 @@ HTTP codes
See HTTP Status Code Registry:
https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
And RFC 2324 - https://tools.ietf.org/html/rfc2324
And RFC 9110 - https://www.rfc-editor.org/rfc/rfc9110
"""
import warnings
from typing import List
__all__ = (
from __future__ import annotations
import warnings
__all__ = [
"HTTP_100_CONTINUE",
"HTTP_101_SWITCHING_PROTOCOLS",
"HTTP_102_PROCESSING",
@@ -45,14 +47,14 @@ __all__ = (
"HTTP_410_GONE",
"HTTP_411_LENGTH_REQUIRED",
"HTTP_412_PRECONDITION_FAILED",
"HTTP_413_REQUEST_ENTITY_TOO_LARGE",
"HTTP_414_REQUEST_URI_TOO_LONG",
"HTTP_413_CONTENT_TOO_LARGE",
"HTTP_414_URI_TOO_LONG",
"HTTP_415_UNSUPPORTED_MEDIA_TYPE",
"HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE",
"HTTP_416_RANGE_NOT_SATISFIABLE",
"HTTP_417_EXPECTATION_FAILED",
"HTTP_418_IM_A_TEAPOT",
"HTTP_421_MISDIRECTED_REQUEST",
"HTTP_422_UNPROCESSABLE_ENTITY",
"HTTP_422_UNPROCESSABLE_CONTENT",
"HTTP_423_LOCKED",
"HTTP_424_FAILED_DEPENDENCY",
"HTTP_425_TOO_EARLY",
@@ -87,7 +89,7 @@ __all__ = (
"WS_1013_TRY_AGAIN_LATER",
"WS_1014_BAD_GATEWAY",
"WS_1015_TLS_HANDSHAKE",
)
]
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
@@ -125,14 +127,14 @@ HTTP_409_CONFLICT = 409
HTTP_410_GONE = 410
HTTP_411_LENGTH_REQUIRED = 411
HTTP_412_PRECONDITION_FAILED = 412
HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413
HTTP_414_REQUEST_URI_TOO_LONG = 414
HTTP_413_CONTENT_TOO_LARGE = 413
HTTP_414_URI_TOO_LONG = 414
HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416
HTTP_416_RANGE_NOT_SATISFIABLE = 416
HTTP_417_EXPECTATION_FAILED = 417
HTTP_418_IM_A_TEAPOT = 418
HTTP_421_MISDIRECTED_REQUEST = 421
HTTP_422_UNPROCESSABLE_ENTITY = 422
HTTP_422_UNPROCESSABLE_CONTENT = 422
HTTP_423_LOCKED = 423
HTTP_424_FAILED_DEPENDENCY = 424
HTTP_425_TOO_EARLY = 425
@@ -175,15 +177,22 @@ WS_1013_TRY_AGAIN_LATER = 1013
WS_1014_BAD_GATEWAY = 1014
WS_1015_TLS_HANDSHAKE = 1015
__deprecated__ = {"WS_1004_NO_STATUS_RCVD": 1004, "WS_1005_ABNORMAL_CLOSURE": 1005}
__deprecated__ = {
"HTTP_413_REQUEST_ENTITY_TOO_LARGE": 413,
"HTTP_414_REQUEST_URI_TOO_LONG": 414,
"HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE": 416,
"HTTP_422_UNPROCESSABLE_ENTITY": 422,
}
def __getattr__(name: str) -> int:
deprecation_changes = {
"WS_1004_NO_STATUS_RCVD": "WS_1005_NO_STATUS_RCVD",
"WS_1005_ABNORMAL_CLOSURE": "WS_1006_ABNORMAL_CLOSURE",
"HTTP_413_REQUEST_ENTITY_TOO_LARGE": "HTTP_413_CONTENT_TOO_LARGE",
"HTTP_414_REQUEST_URI_TOO_LONG": "HTTP_414_URI_TOO_LONG",
"HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE": "HTTP_416_RANGE_NOT_SATISFIABLE",
"HTTP_422_UNPROCESSABLE_ENTITY": "HTTP_422_UNPROCESSABLE_CONTENT",
}
deprecated = __deprecated__.get(name)
if deprecated:
warnings.warn(
@@ -192,8 +201,9 @@ def __getattr__(name: str) -> int:
stacklevel=3,
)
return deprecated
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
raise AttributeError(f"module 'starlette.status' has no attribute '{name}'")
def __dir__() -> List[str]:
def __dir__() -> list[str]:
return sorted(list(__all__) + list(__deprecated__.keys())) # pragma: no cover

View File

@@ -1,10 +1,14 @@
import typing
from __future__ import annotations
import warnings
from collections.abc import Callable, Mapping, Sequence
from os import PathLike
from typing import Any, cast, overload
from starlette.background import BackgroundTask
from starlette.datastructures import URL
from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import HTMLResponse
from starlette.types import Receive, Scope, Send
try:
@@ -16,23 +20,21 @@ try:
# adding a type ignore for mypy to let us access an attribute that may not exist
if hasattr(jinja2, "pass_context"):
pass_context = jinja2.pass_context
else: # pragma: nocover
else: # pragma: no cover
pass_context = jinja2.contextfunction # type: ignore[attr-defined]
except ModuleNotFoundError: # pragma: nocover
except ModuleNotFoundError: # pragma: no cover
jinja2 = None # type: ignore[assignment]
class _TemplateResponse(Response):
media_type = "text/html"
class _TemplateResponse(HTMLResponse):
def __init__(
self,
template: typing.Any,
context: dict,
template: Any,
context: dict[str, Any],
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
):
self.template = template
self.context = context
@@ -42,7 +44,7 @@ class _TemplateResponse(Response):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = self.context.get("request", {})
extensions = request.get("extensions", {})
if "http.response.debug" in extensions:
if "http.response.debug" in extensions: # pragma: no branch
await send(
{
"type": "http.response.debug",
@@ -62,50 +64,145 @@ class Jinja2Templates:
return templates.TemplateResponse("index.html", {"request": request})
"""
@overload
def __init__(
self,
directory: typing.Union[str, PathLike],
context_processors: typing.Optional[
typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
] = None,
**env_options: typing.Any,
directory: str | PathLike[str] | Sequence[str | PathLike[str]],
*,
context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
**env_options: Any,
) -> None: ...
@overload
def __init__(
self,
*,
env: jinja2.Environment,
context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
) -> None: ...
def __init__(
self,
directory: str | PathLike[str] | Sequence[str | PathLike[str]] | None = None,
*,
context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
env: jinja2.Environment | None = None,
**env_options: Any,
) -> None:
if env_options:
warnings.warn(
"Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
DeprecationWarning,
)
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
self.env = self._create_env(directory, **env_options)
assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
self.context_processors = context_processors or []
if directory is not None:
self.env = self._create_env(directory, **env_options)
elif env is not None: # pragma: no branch
self.env = env
self._setup_env_defaults(self.env)
def _create_env(
self, directory: typing.Union[str, PathLike], **env_options: typing.Any
) -> "jinja2.Environment":
@pass_context
def url_for(context: dict, name: str, **path_params: typing.Any) -> URL:
request = context["request"]
return request.url_for(name, **path_params)
self,
directory: str | PathLike[str] | Sequence[str | PathLike[str]],
**env_options: Any,
) -> jinja2.Environment:
loader = jinja2.FileSystemLoader(directory)
env_options.setdefault("loader", loader)
env_options.setdefault("autoescape", True)
env = jinja2.Environment(**env_options)
env.globals["url_for"] = url_for
return env
return jinja2.Environment(**env_options)
def get_template(self, name: str) -> "jinja2.Template":
def _setup_env_defaults(self, env: jinja2.Environment) -> None:
@pass_context
def url_for(
context: dict[str, Any],
name: str,
/,
**path_params: Any,
) -> URL:
request: Request = context["request"]
return request.url_for(name, **path_params)
env.globals.setdefault("url_for", url_for)
def get_template(self, name: str) -> jinja2.Template:
return self.env.get_template(name)
@overload
def TemplateResponse(
self,
request: Request,
name: str,
context: dict[str, Any] | None = None,
status_code: int = 200,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse: ...
@overload
def TemplateResponse(
self,
name: str,
context: dict,
context: dict[str, Any] | None = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse:
if "request" not in context:
raise ValueError('context must include a "request" key')
# Deprecated usage
...
request = typing.cast(Request, context["request"])
def TemplateResponse(self, *args: Any, **kwargs: Any) -> _TemplateResponse:
if args:
if isinstance(args[0], str): # the first argument is template name (old style)
warnings.warn(
"The `name` is not the first parameter anymore. "
"The first parameter should be the `Request` instance.\n"
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
name = args[0]
context = args[1] if len(args) > 1 else kwargs.get("context", {})
status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
headers = args[3] if len(args) > 3 else kwargs.get("headers")
media_type = args[4] if len(args) > 4 else kwargs.get("media_type")
background = args[5] if len(args) > 5 else kwargs.get("background")
if "request" not in context:
raise ValueError('context must include a "request" key')
request = context["request"]
else: # the first argument is a request instance (new style)
request = args[0]
name = args[1] if len(args) > 1 else kwargs["name"]
context = args[2] if len(args) > 2 else kwargs.get("context", {})
status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
headers = args[4] if len(args) > 4 else kwargs.get("headers")
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
background = args[6] if len(args) > 6 else kwargs.get("background")
else: # all arguments are kwargs
if "request" not in kwargs:
warnings.warn(
"The `TemplateResponse` now requires the `request` argument.\n"
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
if "request" not in kwargs.get("context", {}):
raise ValueError('context must include a "request" key')
context = kwargs.get("context", {})
request = kwargs.get("request", context.get("request"))
name = cast(str, kwargs["name"])
status_code = kwargs.get("status_code", 200)
headers = kwargs.get("headers")
media_type = kwargs.get("media_type")
background = kwargs.get("background")
context.setdefault("request", request)
for context_processor in self.context_processors:
context.update(context_processor(request))

View File

@@ -1,43 +1,58 @@
from __future__ import annotations
import contextlib
import inspect
import io
import json
import math
import queue
import sys
import typing
import warnings
from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence
from concurrent.futures import Future
from contextlib import AbstractContextManager
from types import GeneratorType
from typing import (
Any,
Literal,
TypedDict,
TypeGuard,
cast,
)
from urllib.parse import unquote, urljoin
import anyio
import anyio.abc
import anyio.from_thread
import httpx
from anyio.streams.stapled import StapledObjectStream
from starlette._utils import is_async_callable
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict
if sys.version_info >= (3, 11): # pragma: no cover
from typing import Self
else: # pragma: no cover
from typing_extensions import TypedDict
from typing_extensions import Self
_PortalFactoryType = typing.Callable[
[], typing.ContextManager[anyio.abc.BlockingPortal]
]
try:
import httpx
except ModuleNotFoundError: # pragma: no cover
raise RuntimeError(
"The starlette.testclient module requires the httpx package to be installed.\n"
"You can install this with:\n"
" $ pip install httpx\n"
)
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
_RequestData = Mapping[str, str | Iterable[str] | bytes]
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
if inspect.isclass(app):
return hasattr(app, "__await__")
return is_async_callable(app)
@@ -58,14 +73,24 @@ class _WrapASGI2:
class _AsyncBackend(TypedDict):
backend: str
backend_options: typing.Dict[str, typing.Any]
backend_options: dict[str, Any]
class _Upgrade(Exception):
def __init__(self, session: "WebSocketTestSession") -> None:
def __init__(self, session: WebSocketTestSession) -> None:
self.session = session
class WebSocketDenialResponse( # type: ignore[misc]
httpx.Response,
WebSocketDisconnect,
):
"""
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
`WebSocket` is closed before being accepted with a `send_denial_response()`.
"""
class WebSocketTestSession:
def __init__(
self,
@@ -77,65 +102,60 @@ class WebSocketTestSession:
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
self.extra_headers = None
def __enter__(self) -> "WebSocketTestSession":
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())
try:
_: "Future[None]" = self.portal.start_task_soon(self._run)
def __enter__(self) -> WebSocketTestSession:
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())
fut, cs = portal.start_task(self._run)
stack.callback(fut.result)
stack.callback(portal.call, cs.cancel)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
except Exception:
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
return self
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self
def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
self.exit_stack.close()
while not self._send_queue.empty():
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
def __exit__(self, *args: Any) -> bool | None:
return self.exit_stack.__exit__(*args)
async def _run(self) -> None:
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""
scope = self.scope
receive = self._asgi_receive
send = self._asgi_send
try:
await self.app(scope, receive, send)
except BaseException as exc:
self._send_queue.put(exc)
raise
send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
send_tx, send_rx = send
receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
receive_tx, receive_rx = receive
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
self._receive_tx = receive_tx
self._send_rx = send_rx
task_status.started(cs)
await self.app(self.scope, receive_rx.receive, send_tx.send)
async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
await anyio.sleep(0)
return self._receive_queue.get()
async def _asgi_send(self, message: Message) -> None:
self._send_queue.put(message)
# wait for cs.cancel to be called before closing streams
await anyio.sleep_forever()
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(
message.get("code", 1000), message.get("reason", "")
)
raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
elif message["type"] == "websocket.http.response.start":
status_code: int = message["status"]
headers: list[tuple[bytes, bytes]] = message["headers"]
body: list[bytes] = []
while True:
message = self.receive()
assert message["type"] == "websocket.http.response.body"
body.append(message["body"])
if not message.get("more_body", False):
break
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
def send(self, message: Message) -> None:
self._receive_queue.put(message)
self.portal.call(self._receive_tx.send, message)
def send_text(self, data: str) -> None:
self.send({"type": "websocket.receive", "text": data})
@@ -143,35 +163,30 @@ class WebSocketTestSession:
def send_bytes(self, data: bytes) -> None:
self.send({"type": "websocket.receive", "bytes": data})
def send_json(self, data: typing.Any, mode: str = "text") -> None:
assert mode in ["text", "binary"]
text = json.dumps(data, separators=(",", ":"))
def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
self.send({"type": "websocket.receive", "text": text})
else:
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
def close(self, code: int = 1000) -> None:
self.send({"type": "websocket.disconnect", "code": code})
def close(self, code: int = 1000, reason: str | None = None) -> None:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
def receive(self) -> Message:
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
return message
return self.portal.call(self._send_rx.receive)
def receive_text(self) -> str:
message = self.receive()
self._raise_on_close(message)
return message["text"]
return cast(str, message["text"])
def receive_bytes(self) -> bytes:
message = self.receive()
self._raise_on_close(message)
return message["bytes"]
return cast(bytes, message["bytes"])
def receive_json(self, mode: str = "text") -> typing.Any:
assert mode in ["text", "binary"]
def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
message = self.receive()
self._raise_on_close(message)
if mode == "text":
@@ -189,13 +204,15 @@ class _TestClientTransport(httpx.BaseTransport):
raise_server_exceptions: bool = True,
root_path: str = "",
*,
app_state: typing.Dict[str, typing.Any],
client: tuple[str, int],
app_state: dict[str, Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state
self.client = client
def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
@@ -215,38 +232,36 @@ class _TestClientTransport(httpx.BaseTransport):
# Include the 'host' header.
if "host" in request.headers:
headers: typing.List[typing.Tuple[bytes, bytes]] = []
headers: list[tuple[bytes, bytes]] = []
elif port == default_port: # pragma: no cover
headers = [(b"host", host.encode())]
else: # pragma: no cover
headers = [(b"host", (f"{host}:{port}").encode())]
# Include other request headers.
headers += [
(key.lower().encode(), value.encode())
for key, value in request.headers.items()
]
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
scope: typing.Dict[str, typing.Any]
scope: dict[str, Any]
if scheme in {"ws", "wss"}:
subprotocol = request.headers.get("sec-websocket-protocol", None)
if subprotocol is None:
subprotocols: typing.Sequence[str] = []
subprotocols: Sequence[str] = []
else:
subprotocols = [value.strip() for value in subprotocol.split(",")]
scope = {
"type": "websocket",
"path": unquote(path),
"raw_path": raw_path,
"raw_path": raw_path.split(b"?", 1)[0],
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"client": self.client,
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
@@ -256,12 +271,12 @@ class _TestClientTransport(httpx.BaseTransport):
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
"raw_path": raw_path,
"raw_path": raw_path.split(b"?", 1)[0],
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"client": self.client,
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
@@ -270,7 +285,7 @@ class _TestClientTransport(httpx.BaseTransport):
request_complete = False
response_started = False
response_complete: anyio.Event
raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
template = None
context = None
@@ -306,22 +321,13 @@ class _TestClientTransport(httpx.BaseTransport):
nonlocal raw_kwargs, response_started, template, context
if message["type"] == "http.response.start":
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
assert not response_started, 'Received multiple "http.response.start" messages.'
raw_kwargs["status_code"] = message["status"]
raw_kwargs["headers"] = [
(key.decode(), value.decode())
for key, value in message.get("headers", [])
]
raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
response_started = True
elif message["type"] == "http.response.body":
assert (
response_started
), 'Received "http.response.body" without "http.response.start".'
assert (
not response_complete.is_set()
), 'Received "http.response.body" after response completed.'
assert response_started, 'Received "http.response.body" without "http.response.start".'
assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
@@ -361,8 +367,8 @@ class _TestClientTransport(httpx.BaseTransport):
class TestClient(httpx.Client):
__test__ = False
task: "Future[None]"
portal: typing.Optional[anyio.abc.BlockingPortal] = None
task: Future[None]
portal: anyio.abc.BlockingPortal | None = None
def __init__(
self,
@@ -370,110 +376,84 @@ class TestClient(httpx.Client):
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
backend: str = "asyncio",
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
cookies: httpx._client.CookieTypes = None,
headers: typing.Dict[str, str] = None,
backend: Literal["asyncio", "trio"] = "asyncio",
backend_options: dict[str, Any] | None = None,
cookies: httpx._types.CookieTypes | None = None,
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
client: tuple[str, int] = ("testclient", 50000),
) -> None:
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
if _is_asgi3(app):
app = typing.cast(ASGI3App, app)
asgi_app = app
else:
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
app = cast(ASGI2App, app) # type: ignore[assignment]
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
self.app = asgi_app
self.app_state: typing.Dict[str, typing.Any] = {}
self.app_state: dict[str, Any] = {}
transport = _TestClientTransport(
self.app,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
client=client,
)
if headers is None:
headers = {}
headers.setdefault("user-agent", "testclient")
super().__init__(
app=self.app,
base_url=base_url,
headers=headers,
transport=transport,
follow_redirects=True,
follow_redirects=follow_redirects,
cookies=cookies,
)
@contextlib.contextmanager
def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
if self.portal is not None:
yield self.portal
else:
with anyio.from_thread.start_blocking_portal(
**self.async_backend
) as portal:
with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
yield portal
def _choose_redirect_arg(
self,
follow_redirects: typing.Optional[bool],
allow_redirects: typing.Optional[bool],
) -> typing.Union[bool, httpx._client.UseClientDefault]:
redirect: typing.Union[
bool, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT
if allow_redirects is not None:
message = (
"The `allow_redirects` argument is deprecated. "
"Use `follow_redirects` instead."
)
warnings.warn(message, DeprecationWarning)
redirect = allow_redirects
if follow_redirects is not None:
redirect = follow_redirects
elif allow_redirects is not None and follow_redirects is not None:
raise RuntimeError( # pragma: no cover
"Cannot use both `allow_redirects` and `follow_redirects`."
)
return redirect
def request( # type: ignore[override]
self,
method: str,
url: httpx._types.URLTypes,
*,
content: typing.Optional[httpx._types.RequestContent] = None,
data: typing.Optional[_RequestData] = None,
files: typing.Optional[httpx._types.RequestFiles] = None,
json: typing.Any = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
) -> httpx.Response:
url = self.base_url.join(url)
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
if timeout is not httpx.USE_CLIENT_DEFAULT:
warnings.warn(
"You should not use the 'timeout' argument with the TestClient. "
"See https://github.com/Kludex/starlette/issues/1108 for more information.",
DeprecationWarning,
)
url = self._merge_url(url)
return super().request(
method,
url,
content=content,
data=data, # type: ignore[arg-type]
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -482,27 +462,21 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().get(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -511,27 +485,21 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().options(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -540,27 +508,21 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().head(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -569,35 +531,29 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
content: typing.Optional[httpx._types.RequestContent] = None,
data: typing.Optional[_RequestData] = None,
files: typing.Optional[httpx._types.RequestFiles] = None,
json: typing.Any = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().post(
url,
content=content,
data=data, # type: ignore[arg-type]
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -606,35 +562,29 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
content: typing.Optional[httpx._types.RequestContent] = None,
data: typing.Optional[_RequestData] = None,
files: typing.Optional[httpx._types.RequestFiles] = None,
json: typing.Any = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().put(
url,
content=content,
data=data, # type: ignore[arg-type]
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -643,35 +593,29 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
content: typing.Optional[httpx._types.RequestContent] = None,
data: typing.Optional[_RequestData] = None,
files: typing.Optional[httpx._types.RequestFiles] = None,
json: typing.Any = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().patch(
url,
content=content,
data=data, # type: ignore[arg-type]
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -680,34 +624,31 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().delete(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def websocket_connect(
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
) -> typing.Any:
self,
url: str,
subprotocols: Sequence[str] | None = None,
**kwargs: Any,
) -> WebSocketTestSession:
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
@@ -725,22 +666,24 @@ class TestClient(httpx.Client):
return session
def __enter__(self) -> "TestClient":
def __enter__(self) -> Self:
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(
anyio.from_thread.start_blocking_portal(**self.async_backend)
)
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
@stack.callback
def reset_portal() -> None:
self.portal = None
self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
math.inf
)
for channel in (*send, *receive):
stack.callback(channel.close)
self.stream_send = StapledObjectStream(*send)
self.stream_receive = StapledObjectStream(*receive)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)
@@ -752,7 +695,7 @@ class TestClient(httpx.Client):
return self
def __exit__(self, *args: typing.Any) -> None:
def __exit__(self, *args: Any) -> None:
self.exit_stack.close()
async def lifespan(self) -> None:
@@ -765,7 +708,7 @@ class TestClient(httpx.Client):
async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
async def receive() -> typing.Any:
async def receive() -> Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
@@ -780,18 +723,17 @@ class TestClient(httpx.Client):
await receive()
async def wait_shutdown(self) -> None:
async def receive() -> typing.Any:
async def receive() -> Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()

View File

@@ -1,17 +1,26 @@
import typing
from collections.abc import Awaitable, Callable, Mapping, MutableMapping
from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Any, TypeVar
AppType = typing.TypeVar("AppType")
if TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
from starlette.websockets import WebSocket
Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any]
AppType = TypeVar("AppType")
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Awaitable[None]]
Scope = MutableMapping[str, Any]
Message = MutableMapping[str, Any]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
StatefulLifespan = typing.Callable[
[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
]
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]]
StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]]
Lifespan = StatelessLifespan[AppType] | StatefulLifespan[AppType]
HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"]
WebSocketExceptionHandler = Callable[["WebSocket", Exception], Awaitable[None]]
ExceptionHandler = HTTPExceptionHandler | WebSocketExceptionHandler

View File

@@ -1,8 +1,12 @@
from __future__ import annotations
import enum
import json
import typing
from collections.abc import AsyncIterator, Iterable
from typing import Any, cast
from starlette.requests import HTTPConnection
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send
@@ -10,10 +14,11 @@ class WebSocketState(enum.Enum):
CONNECTING = 0
CONNECTED = 1
DISCONNECTED = 2
RESPONSE = 3
class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
@@ -35,10 +40,7 @@ class WebSocket(HTTPConnection):
message = await self._receive()
message_type = message["type"]
if message_type != "websocket.connect":
raise RuntimeError(
'Expected ASGI message "websocket.connect", '
f"but got {message_type!r}"
)
raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
@@ -46,16 +48,13 @@ class WebSocket(HTTPConnection):
message_type = message["type"]
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
'Expected ASGI message "websocket.receive" or '
f'"websocket.disconnect", but got {message_type!r}'
f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
else:
raise RuntimeError(
'Cannot call "receive" once a disconnect message has been received.'
)
raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
async def send(self, message: Message) -> None:
"""
@@ -63,13 +62,15 @@ class WebSocket(HTTPConnection):
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
if message_type not in {"websocket.accept", "websocket.close"}:
if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
raise RuntimeError(
'Expected ASGI message "websocket.accept" or '
f'"websocket.close", but got {message_type!r}'
'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
@@ -77,58 +78,60 @@ class WebSocket(HTTPConnection):
message_type = message["type"]
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
'Expected ASGI message "websocket.send" or "websocket.close", '
f"but got {message_type!r}"
f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
try:
await self._send(message)
except OSError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
async def accept(
self,
subprotocol: typing.Optional[str] = None,
headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None,
subprotocol: str | None = None,
headers: Iterable[tuple[bytes, bytes]] | None = None,
) -> None:
headers = headers or []
if self.client_state == WebSocketState.CONNECTING:
if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
await self.send(
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
)
await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":
raise WebSocketDisconnect(message["code"])
raise WebSocketDisconnect(message["code"], message.get("reason"))
async def receive_text(self) -> str:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return message["text"]
return cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return message["bytes"]
return cast(bytes, message["bytes"])
async def receive_json(self, mode: str = "text") -> typing.Any:
async def receive_json(self, mode: str = "text") -> Any:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
@@ -138,21 +141,21 @@ class WebSocket(HTTPConnection):
text = message["bytes"].decode("utf-8")
return json.loads(text)
async def iter_text(self) -> typing.AsyncIterator[str]:
async def iter_text(self) -> AsyncIterator[str]:
try:
while True:
yield await self.receive_text()
except WebSocketDisconnect:
pass
async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
async def iter_bytes(self) -> AsyncIterator[bytes]:
try:
while True:
yield await self.receive_bytes()
except WebSocketDisconnect:
pass
async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
async def iter_json(self) -> AsyncIterator[Any]:
try:
while True:
yield await self.receive_json()
@@ -165,29 +168,29 @@ class WebSocket(HTTPConnection):
async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})
async def send_json(self, data: typing.Any, mode: str = "text") -> None:
async def send_json(self, data: Any, mode: str = "text") -> None:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
text = json.dumps(data, separators=(",", ":"))
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
await self.send({"type": "websocket.send", "text": text})
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
async def close(
self, code: int = 1000, reason: typing.Optional[str] = None
) -> None:
await self.send(
{"type": "websocket.close", "code": code, "reason": reason or ""}
)
async def close(self, code: int = 1000, reason: str | None = None) -> None:
await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
class WebSocketClose:
def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send(
{"type": "websocket.close", "code": self.code, "reason": self.reason}
)
await send({"type": "websocket.close", "code": self.code, "reason": self.reason})