updates
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "0.27.0"
|
||||
__version__ = "0.50.0"
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,28 +0,0 @@
|
||||
import hashlib
|
||||
|
||||
# Compat wrapper to always include the `usedforsecurity=...` parameter,
|
||||
# which is only added from Python 3.9 onwards.
|
||||
# We use this flag to indicate that we use `md5` hashes only for non-security
|
||||
# cases (our ETag checksums).
|
||||
# If we don't indicate that we're using MD5 for non-security related reasons,
|
||||
# then attempting to use this function will raise an error when used
|
||||
# environments which enable a strict "FIPs mode".
|
||||
#
|
||||
# See issue: https://github.com/encode/starlette/issues/1365
|
||||
try:
|
||||
# check if the Python version supports the parameter
|
||||
# using usedforsecurity=False to avoid an exception on FIPS systems
|
||||
# that reject usedforsecurity=True
|
||||
hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg]
|
||||
|
||||
def md5_hexdigest(
|
||||
data: bytes, *, usedforsecurity: bool = True
|
||||
) -> str: # pragma: no cover
|
||||
return hashlib.md5( # type: ignore[call-arg]
|
||||
data, usedforsecurity=usedforsecurity
|
||||
).hexdigest()
|
||||
|
||||
except TypeError: # pragma: no cover
|
||||
|
||||
def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:
|
||||
return hashlib.md5(data).hexdigest()
|
||||
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
ExceptionHandlers = dict[Any, ExceptionHandler]
|
||||
StatusHandlers = dict[int, ExceptionHandler]
|
||||
|
||||
|
||||
def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
|
||||
for cls in type(exc).__mro__:
|
||||
if cls in exc_handlers:
|
||||
return exc_handlers[cls]
|
||||
return None
|
||||
|
||||
|
||||
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
|
||||
exception_handlers: ExceptionHandlers
|
||||
status_handlers: StatusHandlers
|
||||
try:
|
||||
exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
|
||||
except KeyError:
|
||||
exception_handlers, status_handlers = {}, {}
|
||||
|
||||
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
response_started = False
|
||||
|
||||
async def sender(message: Message) -> None:
|
||||
nonlocal response_started
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await app(scope, receive, sender)
|
||||
except Exception as exc:
|
||||
handler = None
|
||||
|
||||
if isinstance(exc, HTTPException):
|
||||
handler = status_handlers.get(exc.status_code)
|
||||
|
||||
if handler is None:
|
||||
handler = _lookup_exception_handler(exception_handlers, exc)
|
||||
|
||||
if handler is None:
|
||||
raise exc
|
||||
|
||||
if response_started:
|
||||
raise RuntimeError("Caught handled exception, but response already started.") from exc
|
||||
|
||||
if is_async_callable(handler):
|
||||
response = await handler(conn, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, conn, exc)
|
||||
if response is not None:
|
||||
await response(scope, receive, sender)
|
||||
|
||||
return wrapped_app
|
||||
@@ -1,74 +1,103 @@
|
||||
import asyncio
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import sys
|
||||
import typing
|
||||
from types import TracebackType
|
||||
from collections.abc import Awaitable, Callable, Generator
|
||||
from contextlib import AbstractAsyncContextManager, contextmanager
|
||||
from typing import Any, Generic, Protocol, TypeVar, overload
|
||||
|
||||
if sys.version_info < (3, 8): # pragma: no cover
|
||||
from typing_extensions import Protocol
|
||||
from starlette.types import Scope
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import TypeIs
|
||||
else: # pragma: no cover
|
||||
from typing import Protocol
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
has_exceptiongroups = True
|
||||
if sys.version_info < (3, 11): # pragma: no cover
|
||||
try:
|
||||
from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
|
||||
except ImportError:
|
||||
has_exceptiongroups = False
|
||||
|
||||
T = TypeVar("T")
|
||||
AwaitableCallable = Callable[..., Awaitable[T]]
|
||||
|
||||
|
||||
def is_async_callable(obj: typing.Any) -> bool:
|
||||
@overload
|
||||
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
|
||||
|
||||
|
||||
def is_async_callable(obj: Any) -> Any:
|
||||
while isinstance(obj, functools.partial):
|
||||
obj = obj.func
|
||||
|
||||
return asyncio.iscoroutinefunction(obj) or (
|
||||
callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
|
||||
)
|
||||
return iscoroutinefunction(obj) or (callable(obj) and iscoroutinefunction(obj.__call__))
|
||||
|
||||
|
||||
T_co = typing.TypeVar("T_co", covariant=True)
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
# TODO: once 3.8 is the minimum supported version (27 Jun 2023)
|
||||
# this can just become
|
||||
# class AwaitableOrContextManager(
|
||||
# typing.Awaitable[T_co],
|
||||
# typing.AsyncContextManager[T_co],
|
||||
# typing.Protocol[T_co],
|
||||
# ):
|
||||
# pass
|
||||
class AwaitableOrContextManager(Protocol[T_co]):
|
||||
def __await__(self) -> typing.Generator[typing.Any, None, T_co]:
|
||||
... # pragma: no cover
|
||||
|
||||
async def __aenter__(self) -> T_co:
|
||||
... # pragma: no cover
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
__exc_type: typing.Optional[typing.Type[BaseException]],
|
||||
__exc_value: typing.Optional[BaseException],
|
||||
__traceback: typing.Optional[TracebackType],
|
||||
) -> typing.Union[bool, None]:
|
||||
... # pragma: no cover
|
||||
class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]): ...
|
||||
|
||||
|
||||
class SupportsAsyncClose(Protocol):
|
||||
async def close(self) -> None:
|
||||
... # pragma: no cover
|
||||
async def close(self) -> None: ... # pragma: no cover
|
||||
|
||||
|
||||
SupportsAsyncCloseType = typing.TypeVar(
|
||||
"SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
|
||||
)
|
||||
SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
|
||||
|
||||
|
||||
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
|
||||
class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
|
||||
__slots__ = ("aw", "entered")
|
||||
|
||||
def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
|
||||
def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
|
||||
self.aw = aw
|
||||
|
||||
def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
|
||||
def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
|
||||
return self.aw.__await__()
|
||||
|
||||
async def __aenter__(self) -> SupportsAsyncCloseType:
|
||||
self.entered = await self.aw
|
||||
return self.entered
|
||||
|
||||
async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]:
|
||||
async def __aexit__(self, *args: Any) -> None | bool:
|
||||
await self.entered.close()
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def collapse_excgroups() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield
|
||||
except BaseException as exc:
|
||||
if has_exceptiongroups: # pragma: no cover
|
||||
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
|
||||
exc = exc.exceptions[0]
|
||||
|
||||
raise exc
|
||||
|
||||
|
||||
def get_route_path(scope: Scope) -> str:
|
||||
path: str = scope["path"]
|
||||
root_path = scope.get("root_path", "")
|
||||
if not root_path:
|
||||
return path
|
||||
|
||||
if not path.startswith(root_path):
|
||||
return path
|
||||
|
||||
if path == root_path:
|
||||
return ""
|
||||
|
||||
if path[len(root_path)] == "/":
|
||||
return path[len(root_path) :]
|
||||
|
||||
return path
|
||||
|
||||
@@ -1,90 +1,80 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from starlette.datastructures import State, URLPath
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware import Middleware, _MiddlewareFactory
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.middleware.exceptions import ExceptionMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import BaseRoute, Router
|
||||
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
AppType = typing.TypeVar("AppType", bound="Starlette")
|
||||
AppType = TypeVar("AppType", bound="Starlette")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class Starlette:
|
||||
"""
|
||||
Creates an application instance.
|
||||
|
||||
**Parameters:**
|
||||
|
||||
* **debug** - Boolean indicating if debug tracebacks should be returned on errors.
|
||||
* **routes** - A list of routes to serve incoming HTTP and WebSocket requests.
|
||||
* **middleware** - A list of middleware to run for every request. A starlette
|
||||
application will always automatically include two middleware classes.
|
||||
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
|
||||
any uncaught errors occurring anywhere in the entire stack.
|
||||
`ExceptionMiddleware` is added as the very innermost middleware, to deal
|
||||
with handled exception cases occurring in the routing or endpoints.
|
||||
* **exception_handlers** - A mapping of either integer status codes,
|
||||
or exception class types onto callables which handle the exceptions.
|
||||
Exception handler callables should be of the form
|
||||
`handler(request, exc) -> response` and may be be either standard functions, or
|
||||
async functions.
|
||||
* **on_startup** - A list of callables to run on application startup.
|
||||
Startup handler callables do not take any arguments, and may be be either
|
||||
standard functions, or async functions.
|
||||
* **on_shutdown** - A list of callables to run on application shutdown.
|
||||
Shutdown handler callables do not take any arguments, and may be be either
|
||||
standard functions, or async functions.
|
||||
* **lifespan** - A lifespan context function, which can be used to perform
|
||||
startup and shutdown tasks. This is a newer style that replaces the
|
||||
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
|
||||
"""
|
||||
"""Creates an Starlette application."""
|
||||
|
||||
def __init__(
|
||||
self: "AppType",
|
||||
self: AppType,
|
||||
debug: bool = False,
|
||||
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
|
||||
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
|
||||
exception_handlers: typing.Optional[
|
||||
typing.Mapping[
|
||||
typing.Any,
|
||||
typing.Callable[
|
||||
[Request, Exception],
|
||||
typing.Union[Response, typing.Awaitable[Response]],
|
||||
],
|
||||
]
|
||||
] = None,
|
||||
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
|
||||
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
|
||||
lifespan: typing.Optional[Lifespan["AppType"]] = None,
|
||||
routes: Sequence[BaseRoute] | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
|
||||
on_startup: Sequence[Callable[[], Any]] | None = None,
|
||||
on_shutdown: Sequence[Callable[[], Any]] | None = None,
|
||||
lifespan: Lifespan[AppType] | None = None,
|
||||
) -> None:
|
||||
"""Initializes the application.
|
||||
|
||||
Parameters:
|
||||
debug: Boolean indicating if debug tracebacks should be returned on errors.
|
||||
routes: A list of routes to serve incoming HTTP and WebSocket requests.
|
||||
middleware: A list of middleware to run for every request. A starlette
|
||||
application will always automatically include two middleware classes.
|
||||
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
|
||||
any uncaught errors occurring anywhere in the entire stack.
|
||||
`ExceptionMiddleware` is added as the very innermost middleware, to deal
|
||||
with handled exception cases occurring in the routing or endpoints.
|
||||
exception_handlers: A mapping of either integer status codes,
|
||||
or exception class types onto callables which handle the exceptions.
|
||||
Exception handler callables should be of the form
|
||||
`handler(request, exc) -> response` and may be either standard functions, or
|
||||
async functions.
|
||||
on_startup: A list of callables to run on application startup.
|
||||
Startup handler callables do not take any arguments, and may be either
|
||||
standard functions, or async functions.
|
||||
on_shutdown: A list of callables to run on application shutdown.
|
||||
Shutdown handler callables do not take any arguments, and may be either
|
||||
standard functions, or async functions.
|
||||
lifespan: A lifespan context function, which can be used to perform
|
||||
startup and shutdown tasks. This is a newer style that replaces the
|
||||
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
|
||||
"""
|
||||
# The lifespan context function is a newer style that replaces
|
||||
# on_startup / on_shutdown handlers. Use one or the other, not both.
|
||||
assert lifespan is None or (
|
||||
on_startup is None and on_shutdown is None
|
||||
), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
|
||||
assert lifespan is None or (on_startup is None and on_shutdown is None), (
|
||||
"Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
|
||||
)
|
||||
|
||||
self.debug = debug
|
||||
self.state = State()
|
||||
self.router = Router(
|
||||
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
|
||||
)
|
||||
self.exception_handlers = (
|
||||
{} if exception_handlers is None else dict(exception_handlers)
|
||||
)
|
||||
self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
|
||||
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
|
||||
self.user_middleware = [] if middleware is None else list(middleware)
|
||||
self.middleware_stack: typing.Optional[ASGIApp] = None
|
||||
self.middleware_stack: ASGIApp | None = None
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
debug = self.debug
|
||||
error_handler = None
|
||||
exception_handlers: typing.Dict[
|
||||
typing.Any, typing.Callable[[Request, Exception], Response]
|
||||
] = {}
|
||||
exception_handlers: dict[Any, ExceptionHandler] = {}
|
||||
|
||||
for key, value in self.exception_handlers.items():
|
||||
if key in (500, Exception):
|
||||
@@ -95,25 +85,20 @@ class Starlette:
|
||||
middleware = (
|
||||
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
|
||||
+ self.user_middleware
|
||||
+ [
|
||||
Middleware(
|
||||
ExceptionMiddleware, handlers=exception_handlers, debug=debug
|
||||
)
|
||||
]
|
||||
+ [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
|
||||
)
|
||||
|
||||
app = self.router
|
||||
for cls, options in reversed(middleware):
|
||||
app = cls(app=app, **options)
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
app = cls(app, *args, **kwargs)
|
||||
return app
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return self.router.routes
|
||||
|
||||
# TODO: Make `__name` a positional-only argument when we drop Python 3.7 support.
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
return self.router.url_path_for(__name, **path_params)
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
return self.router.url_path_for(name, **path_params)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
scope["app"] = self
|
||||
@@ -121,63 +106,65 @@ class Starlette:
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
await self.middleware_stack(scope, receive, send)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
|
||||
return self.router.on_event(event_type)
|
||||
def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg]
|
||||
return self.router.on_event(event_type) # pragma: no cover
|
||||
|
||||
def mount(
|
||||
self, path: str, app: ASGIApp, name: typing.Optional[str] = None
|
||||
) -> None: # pragma: nocover
|
||||
self.router.mount(path, app=app, name=name)
|
||||
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
|
||||
self.router.mount(path, app=app, name=name) # pragma: no cover
|
||||
|
||||
def host(
|
||||
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
|
||||
) -> None: # pragma: no cover
|
||||
self.router.host(host, app=app, name=name)
|
||||
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
|
||||
self.router.host(host, app=app, name=name) # pragma: no cover
|
||||
|
||||
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
|
||||
def add_middleware(
|
||||
self,
|
||||
middleware_class: _MiddlewareFactory[P],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
if self.middleware_stack is not None: # pragma: no cover
|
||||
raise RuntimeError("Cannot add middleware after an application has started")
|
||||
self.user_middleware.insert(0, Middleware(middleware_class, **options))
|
||||
self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
|
||||
handler: typing.Callable,
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: ExceptionHandler,
|
||||
) -> None: # pragma: no cover
|
||||
self.exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
def add_event_handler(
|
||||
self, event_type: str, func: typing.Callable
|
||||
self,
|
||||
event_type: str,
|
||||
func: Callable, # type: ignore[type-arg]
|
||||
) -> None: # pragma: no cover
|
||||
self.router.add_event_handler(event_type, func)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
path: str,
|
||||
route: typing.Callable,
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
route: Callable[[Request], Awaitable[Response] | Response],
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None: # pragma: no cover
|
||||
self.router.add_route(
|
||||
path, route, methods=methods, name=name, include_in_schema=include_in_schema
|
||||
)
|
||||
self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
|
||||
|
||||
def add_websocket_route(
|
||||
self, path: str, route: typing.Callable, name: typing.Optional[str] = None
|
||||
self,
|
||||
path: str,
|
||||
route: Callable[[WebSocket], Awaitable[None]],
|
||||
name: str | None = None,
|
||||
) -> None: # pragma: no cover
|
||||
self.router.add_websocket_route(path, route, name=name)
|
||||
|
||||
def exception_handler(
|
||||
self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
|
||||
) -> typing.Callable:
|
||||
def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> Callable: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501
|
||||
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://starlette.dev/exceptions/ for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.add_exception_handler(exc_class_or_status_code, func)
|
||||
return func
|
||||
|
||||
@@ -186,10 +173,10 @@ class Starlette:
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> typing.Callable:
|
||||
) -> Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
@@ -198,12 +185,12 @@ class Starlette:
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/routing/ for the recommended approach.", # noqa: E501
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://starlette.dev/routing/ for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.router.add_route(
|
||||
path,
|
||||
func,
|
||||
@@ -215,9 +202,7 @@ class Starlette:
|
||||
|
||||
return decorator
|
||||
|
||||
def websocket_route(
|
||||
self, path: str, name: typing.Optional[str] = None
|
||||
) -> typing.Callable:
|
||||
def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
@@ -226,18 +211,18 @@ class Starlette:
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://starlette.dev/routing/#websocket-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.router.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def middleware(self, middleware_type: str) -> typing.Callable:
|
||||
def middleware(self, middleware_type: str) -> Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
@@ -246,15 +231,13 @@ class Starlette:
|
||||
>>> app = Starlette(middleware=middleware)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", # noqa: E501
|
||||
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://starlette.dev/middleware/#using-middleware for recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert (
|
||||
middleware_type == "http"
|
||||
), 'Currently only middleware("http") is supported.'
|
||||
assert middleware_type == "http", 'Currently only middleware("http") is supported.'
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
|
||||
return func
|
||||
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import typing
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, ParamSpec
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
|
||||
def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool:
|
||||
for scope in scopes:
|
||||
if scope not in conn.auth.scopes:
|
||||
return False
|
||||
@@ -20,32 +23,28 @@ def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bo
|
||||
|
||||
|
||||
def requires(
|
||||
scopes: typing.Union[str, typing.Sequence[str]],
|
||||
scopes: str | Sequence[str],
|
||||
status_code: int = 403,
|
||||
redirect: typing.Optional[str] = None,
|
||||
) -> typing.Callable[[_CallableType], _CallableType]:
|
||||
redirect: str | None = None,
|
||||
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]:
|
||||
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(
|
||||
func: Callable[_P, Any],
|
||||
) -> Callable[_P, Any]:
|
||||
sig = inspect.signature(func)
|
||||
for idx, parameter in enumerate(sig.parameters.values()):
|
||||
if parameter.name == "request" or parameter.name == "websocket":
|
||||
type_ = parameter.name
|
||||
break
|
||||
else:
|
||||
raise Exception(
|
||||
f'No "request" or "websocket" argument on function "{func}"'
|
||||
)
|
||||
raise Exception(f'No "request" or "websocket" argument on function "{func}"')
|
||||
|
||||
if type_ == "websocket":
|
||||
# Handle websocket functions. (Always async)
|
||||
@functools.wraps(func)
|
||||
async def websocket_wrapper(
|
||||
*args: typing.Any, **kwargs: typing.Any
|
||||
) -> None:
|
||||
websocket = kwargs.get(
|
||||
"websocket", args[idx] if idx < len(args) else None
|
||||
)
|
||||
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(websocket, WebSocket)
|
||||
|
||||
if not has_required_scope(websocket, scopes_list):
|
||||
@@ -58,19 +57,14 @@ def requires(
|
||||
elif is_async_callable(func):
|
||||
# Handle async request/response functions.
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(
|
||||
*args: typing.Any, **kwargs: typing.Any
|
||||
) -> Response:
|
||||
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
|
||||
request = kwargs.get("request", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
orig_request_qparam = urlencode({"next": str(request.url)})
|
||||
next_url = "{redirect_path}?{orig_request}".format(
|
||||
redirect_path=request.url_for(redirect),
|
||||
orig_request=orig_request_qparam,
|
||||
)
|
||||
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
|
||||
return RedirectResponse(url=next_url, status_code=303)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return await func(*args, **kwargs)
|
||||
@@ -80,24 +74,21 @@ def requires(
|
||||
else:
|
||||
# Handle sync request/response functions.
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||
def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
|
||||
request = kwargs.get("request", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
orig_request_qparam = urlencode({"next": str(request.url)})
|
||||
next_url = "{redirect_path}?{orig_request}".format(
|
||||
redirect_path=request.url_for(redirect),
|
||||
orig_request=orig_request_qparam,
|
||||
)
|
||||
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
|
||||
return RedirectResponse(url=next_url, status_code=303)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
return decorator # type: ignore[return-value]
|
||||
return decorator
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
@@ -105,14 +96,12 @@ class AuthenticationError(Exception):
|
||||
|
||||
|
||||
class AuthenticationBackend:
|
||||
async def authenticate(
|
||||
self, conn: HTTPConnection
|
||||
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
|
||||
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class AuthCredentials:
|
||||
def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
|
||||
def __init__(self, scopes: Sequence[str] | None = None):
|
||||
self.scopes = [] if scopes is None else list(scopes)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import sys
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, ParamSpec
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
@@ -13,9 +10,7 @@ P = ParamSpec("P")
|
||||
|
||||
|
||||
class BackgroundTask:
|
||||
def __init__(
|
||||
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
@@ -29,12 +24,10 @@ class BackgroundTask:
|
||||
|
||||
|
||||
class BackgroundTasks(BackgroundTask):
|
||||
def __init__(self, tasks: typing.Optional[typing.Sequence[BackgroundTask]] = None):
|
||||
def __init__(self, tasks: Sequence[BackgroundTask] | None = None):
|
||||
self.tasks = list(tasks) if tasks else []
|
||||
|
||||
def add_task(
|
||||
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
task = BackgroundTask(func, *args, **kwargs)
|
||||
self.tasks.append(task)
|
||||
|
||||
|
||||
@@ -1,30 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Callable, Coroutine, Iterable, Iterator
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
|
||||
async def run_until_first_complete(*args: tuple[Callable, dict]) -> None: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"run_until_first_complete is deprecated "
|
||||
"and will be removed in a future version.",
|
||||
"run_until_first_complete is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
|
||||
async def run(func: Callable[[], Coroutine]) -> None: # type: ignore[type-arg]
|
||||
await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
|
||||
@@ -32,20 +27,16 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -
|
||||
task_group.start_soon(run, functools.partial(func, **kwargs))
|
||||
|
||||
|
||||
async def run_in_threadpool(
|
||||
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
|
||||
) -> T:
|
||||
if kwargs: # pragma: no cover
|
||||
# run_sync doesn't accept 'kwargs', so bind them in here
|
||||
func = functools.partial(func, **kwargs)
|
||||
return await anyio.to_thread.run_sync(func, *args)
|
||||
async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
func = functools.partial(func, *args, **kwargs)
|
||||
return await anyio.to_thread.run_sync(func)
|
||||
|
||||
|
||||
class _StopIteration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _next(iterator: typing.Iterator[T]) -> T:
|
||||
def _next(iterator: Iterator[T]) -> T:
|
||||
# We can't raise `StopIteration` from within the threadpool iterator
|
||||
# and catch it outside that context, so we coerce them into a different
|
||||
# exception type.
|
||||
@@ -56,10 +47,11 @@ def _next(iterator: typing.Iterator[T]) -> T:
|
||||
|
||||
|
||||
async def iterate_in_threadpool(
|
||||
iterator: typing.Iterator[T],
|
||||
) -> typing.AsyncIterator[T]:
|
||||
iterator: Iterable[T],
|
||||
) -> AsyncIterator[T]:
|
||||
as_iterator = iter(iterator)
|
||||
while True:
|
||||
try:
|
||||
yield await anyio.to_thread.run_sync(_next, iterator)
|
||||
yield await anyio.to_thread.run_sync(_next, as_iterator)
|
||||
except _StopIteration:
|
||||
break
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing
|
||||
from collections.abc import MutableMapping
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterator, Mapping, MutableMapping
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
|
||||
class undefined:
|
||||
@@ -12,32 +15,26 @@ class EnvironError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Environ(MutableMapping):
|
||||
def __init__(self, environ: typing.MutableMapping = os.environ):
|
||||
class Environ(MutableMapping[str, str]):
|
||||
def __init__(self, environ: MutableMapping[str, str] = os.environ):
|
||||
self._environ = environ
|
||||
self._has_been_read: typing.Set[typing.Any] = set()
|
||||
self._has_been_read: set[str] = set()
|
||||
|
||||
def __getitem__(self, key: typing.Any) -> typing.Any:
|
||||
def __getitem__(self, key: str) -> str:
|
||||
self._has_been_read.add(key)
|
||||
return self._environ.__getitem__(key)
|
||||
|
||||
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(
|
||||
f"Attempting to set environ['{key}'], but the value has already been "
|
||||
"read."
|
||||
)
|
||||
raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
|
||||
self._environ.__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key: typing.Any) -> None:
|
||||
def __delitem__(self, key: str) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(
|
||||
f"Attempting to delete environ['{key}'], but the value has already "
|
||||
"been read."
|
||||
)
|
||||
raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
|
||||
self._environ.__delitem__(key)
|
||||
|
||||
def __iter__(self) -> typing.Iterator:
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._environ)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -46,65 +43,60 @@ class Environ(MutableMapping):
|
||||
|
||||
environ = Environ()
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(
|
||||
self,
|
||||
env_file: typing.Optional[typing.Union[str, Path]] = None,
|
||||
environ: typing.Mapping[str, str] = environ,
|
||||
env_file: str | Path | None = None,
|
||||
environ: Mapping[str, str] = environ,
|
||||
env_prefix: str = "",
|
||||
encoding: str = "utf-8",
|
||||
) -> None:
|
||||
self.environ = environ
|
||||
self.env_prefix = env_prefix
|
||||
self.file_values: typing.Dict[str, str] = {}
|
||||
if env_file is not None and os.path.isfile(env_file):
|
||||
self.file_values = self._read_file(env_file)
|
||||
self.file_values: dict[str, str] = {}
|
||||
if env_file is not None:
|
||||
if not os.path.isfile(env_file):
|
||||
warnings.warn(f"Config file '{env_file}' not found.")
|
||||
else:
|
||||
self.file_values = self._read_file(env_file, encoding)
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, *, default: None) -> typing.Optional[str]:
|
||||
...
|
||||
@overload
|
||||
def __call__(self, key: str, *, default: None) -> str | None: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, cast: typing.Type[T], default: T = ...) -> T:
|
||||
...
|
||||
@overload
|
||||
def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(
|
||||
self, key: str, cast: typing.Type[str] = ..., default: str = ...
|
||||
) -> str:
|
||||
...
|
||||
@overload
|
||||
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...
|
||||
|
||||
@typing.overload
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
key: str,
|
||||
cast: typing.Callable[[typing.Any], T] = ...,
|
||||
default: typing.Any = ...,
|
||||
) -> T:
|
||||
...
|
||||
cast: Callable[[Any], T] = ...,
|
||||
default: Any = ...,
|
||||
) -> T: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(
|
||||
self, key: str, cast: typing.Type[str] = ..., default: T = ...
|
||||
) -> typing.Union[T, str]:
|
||||
...
|
||||
@overload
|
||||
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
key: str,
|
||||
cast: typing.Optional[typing.Callable] = None,
|
||||
default: typing.Any = undefined,
|
||||
) -> typing.Any:
|
||||
cast: Callable[[Any], Any] | None = None,
|
||||
default: Any = undefined,
|
||||
) -> Any:
|
||||
return self.get(key, cast, default)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
cast: typing.Optional[typing.Callable] = None,
|
||||
default: typing.Any = undefined,
|
||||
) -> typing.Any:
|
||||
cast: Callable[[Any], Any] | None = None,
|
||||
default: Any = undefined,
|
||||
) -> Any:
|
||||
key = self.env_prefix + key
|
||||
if key in self.environ:
|
||||
value = self.environ[key]
|
||||
@@ -116,9 +108,9 @@ class Config:
|
||||
return self._perform_cast(key, default, cast)
|
||||
raise KeyError(f"Config '{key}' is missing, and has no default.")
|
||||
|
||||
def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str]:
|
||||
file_values: typing.Dict[str, str] = {}
|
||||
with open(file_name) as input_file:
|
||||
def _read_file(self, file_name: str | Path, encoding: str) -> dict[str, str]:
|
||||
file_values: dict[str, str] = {}
|
||||
with open(file_name, encoding=encoding) as input_file:
|
||||
for line in input_file.readlines():
|
||||
line = line.strip()
|
||||
if "=" in line and not line.startswith("#"):
|
||||
@@ -129,21 +121,20 @@ class Config:
|
||||
return file_values
|
||||
|
||||
def _perform_cast(
|
||||
self, key: str, value: typing.Any, cast: typing.Optional[typing.Callable] = None
|
||||
) -> typing.Any:
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
cast: Callable[[Any], Any] | None = None,
|
||||
) -> Any:
|
||||
if cast is None or value is None:
|
||||
return value
|
||||
elif cast is bool and isinstance(value, str):
|
||||
mapping = {"true": True, "1": True, "false": False, "0": False}
|
||||
value = value.lower()
|
||||
if value not in mapping:
|
||||
raise ValueError(
|
||||
f"Config '{key}' has value '{value}'. Not a valid bool."
|
||||
)
|
||||
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
|
||||
return mapping[value]
|
||||
try:
|
||||
return cast(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
|
||||
)
|
||||
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import typing
|
||||
import uuid
|
||||
from typing import Any, ClassVar, Generic, TypeVar
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Convertor(typing.Generic[T]):
|
||||
regex: typing.ClassVar[str] = ""
|
||||
class Convertor(Generic[T]):
|
||||
regex: ClassVar[str] = ""
|
||||
|
||||
def convert(self, value: str) -> T:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
@@ -15,7 +17,7 @@ class Convertor(typing.Generic[T]):
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class StringConvertor(Convertor):
|
||||
class StringConvertor(Convertor[str]):
|
||||
regex = "[^/]+"
|
||||
|
||||
def convert(self, value: str) -> str:
|
||||
@@ -28,7 +30,7 @@ class StringConvertor(Convertor):
|
||||
return value
|
||||
|
||||
|
||||
class PathConvertor(Convertor):
|
||||
class PathConvertor(Convertor[str]):
|
||||
regex = ".*"
|
||||
|
||||
def convert(self, value: str) -> str:
|
||||
@@ -38,7 +40,7 @@ class PathConvertor(Convertor):
|
||||
return str(value)
|
||||
|
||||
|
||||
class IntegerConvertor(Convertor):
|
||||
class IntegerConvertor(Convertor[int]):
|
||||
regex = "[0-9]+"
|
||||
|
||||
def convert(self, value: str) -> int:
|
||||
@@ -50,7 +52,7 @@ class IntegerConvertor(Convertor):
|
||||
return str(value)
|
||||
|
||||
|
||||
class FloatConvertor(Convertor):
|
||||
class FloatConvertor(Convertor[float]):
|
||||
regex = r"[0-9]+(\.[0-9]+)?"
|
||||
|
||||
def convert(self, value: str) -> float:
|
||||
@@ -64,8 +66,8 @@ class FloatConvertor(Convertor):
|
||||
return ("%0.20f" % value).rstrip("0").rstrip(".")
|
||||
|
||||
|
||||
class UUIDConvertor(Convertor):
|
||||
regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
class UUIDConvertor(Convertor[uuid.UUID]):
|
||||
regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}"
|
||||
|
||||
def convert(self, value: str) -> uuid.UUID:
|
||||
return uuid.UUID(value)
|
||||
@@ -74,7 +76,7 @@ class UUIDConvertor(Convertor):
|
||||
return str(value)
|
||||
|
||||
|
||||
CONVERTOR_TYPES = {
|
||||
CONVERTOR_TYPES: dict[str, Convertor[Any]] = {
|
||||
"str": StringConvertor(),
|
||||
"path": PathConvertor(),
|
||||
"int": IntegerConvertor(),
|
||||
@@ -83,5 +85,5 @@ CONVERTOR_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def register_url_convertor(key: str, convertor: Convertor) -> None:
|
||||
def register_url_convertor(key: str, convertor: Convertor[Any]) -> None:
|
||||
CONVERTOR_TYPES[key] = convertor
|
||||
|
||||
@@ -1,37 +1,45 @@
|
||||
import typing
|
||||
from collections.abc import Sequence
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
|
||||
from shlex import shlex
|
||||
from typing import (
|
||||
Any,
|
||||
BinaryIO,
|
||||
NamedTuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.types import Scope
|
||||
|
||||
|
||||
class Address(typing.NamedTuple):
|
||||
class Address(NamedTuple):
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
_KeyType = typing.TypeVar("_KeyType")
|
||||
_KeyType = TypeVar("_KeyType")
|
||||
# Mapping keys are invariant but their values are covariant since
|
||||
# you can only read them
|
||||
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
|
||||
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
|
||||
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
|
||||
|
||||
|
||||
class URL:
|
||||
def __init__(
|
||||
self,
|
||||
url: str = "",
|
||||
scope: typing.Optional[Scope] = None,
|
||||
**components: typing.Any,
|
||||
scope: Scope | None = None,
|
||||
**components: Any,
|
||||
) -> None:
|
||||
if scope is not None:
|
||||
assert not url, 'Cannot set both "url" and "scope".'
|
||||
assert not components, 'Cannot set both "scope" and "**components".'
|
||||
scheme = scope.get("scheme", "http")
|
||||
server = scope.get("server", None)
|
||||
path = scope.get("root_path", "") + scope["path"]
|
||||
path = scope["path"]
|
||||
query_string = scope.get("query_string", b"")
|
||||
|
||||
host_header = None
|
||||
@@ -87,32 +95,27 @@ class URL:
|
||||
return self.components.fragment
|
||||
|
||||
@property
|
||||
def username(self) -> typing.Union[None, str]:
|
||||
def username(self) -> None | str:
|
||||
return self.components.username
|
||||
|
||||
@property
|
||||
def password(self) -> typing.Union[None, str]:
|
||||
def password(self) -> None | str:
|
||||
return self.components.password
|
||||
|
||||
@property
|
||||
def hostname(self) -> typing.Union[None, str]:
|
||||
def hostname(self) -> None | str:
|
||||
return self.components.hostname
|
||||
|
||||
@property
|
||||
def port(self) -> typing.Optional[int]:
|
||||
def port(self) -> int | None:
|
||||
return self.components.port
|
||||
|
||||
@property
|
||||
def is_secure(self) -> bool:
|
||||
return self.scheme in ("https", "wss")
|
||||
|
||||
def replace(self, **kwargs: typing.Any) -> "URL":
|
||||
if (
|
||||
"username" in kwargs
|
||||
or "password" in kwargs
|
||||
or "hostname" in kwargs
|
||||
or "port" in kwargs
|
||||
):
|
||||
def replace(self, **kwargs: Any) -> URL:
|
||||
if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
|
||||
hostname = kwargs.pop("hostname", None)
|
||||
port = kwargs.pop("port", self.port)
|
||||
username = kwargs.pop("username", self.username)
|
||||
@@ -139,19 +142,17 @@ class URL:
|
||||
components = self.components._replace(**kwargs)
|
||||
return self.__class__(components.geturl())
|
||||
|
||||
def include_query_params(self, **kwargs: typing.Any) -> "URL":
|
||||
def include_query_params(self, **kwargs: Any) -> URL:
|
||||
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
|
||||
params.update({str(key): str(value) for key, value in kwargs.items()})
|
||||
query = urlencode(params.multi_items())
|
||||
return self.replace(query=query)
|
||||
|
||||
def replace_query_params(self, **kwargs: typing.Any) -> "URL":
|
||||
def replace_query_params(self, **kwargs: Any) -> URL:
|
||||
query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
|
||||
return self.replace(query=query)
|
||||
|
||||
def remove_query_params(
|
||||
self, keys: typing.Union[str, typing.Sequence[str]]
|
||||
) -> "URL":
|
||||
def remove_query_params(self, keys: str | Sequence[str]) -> URL:
|
||||
if isinstance(keys, str):
|
||||
keys = [keys]
|
||||
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
|
||||
@@ -160,7 +161,7 @@ class URL:
|
||||
query = urlencode(params.multi_items())
|
||||
return self.replace(query=query)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return str(self) == str(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -179,7 +180,7 @@ class URLPath(str):
|
||||
Used by the routing to return `url_path_for` matches.
|
||||
"""
|
||||
|
||||
def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
|
||||
def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
|
||||
assert protocol in ("http", "websocket", "")
|
||||
return str.__new__(cls, path)
|
||||
|
||||
@@ -187,7 +188,7 @@ class URLPath(str):
|
||||
self.protocol = protocol
|
||||
self.host = host
|
||||
|
||||
def make_absolute_url(self, base_url: typing.Union[str, URL]) -> URL:
|
||||
def make_absolute_url(self, base_url: str | URL) -> URL:
|
||||
if isinstance(base_url, str):
|
||||
base_url = URL(base_url)
|
||||
if self.protocol:
|
||||
@@ -223,8 +224,8 @@ class Secret:
|
||||
return bool(self._value)
|
||||
|
||||
|
||||
class CommaSeparatedStrings(Sequence):
|
||||
def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
|
||||
class CommaSeparatedStrings(Sequence[str]):
|
||||
def __init__(self, value: str | Sequence[str]):
|
||||
if isinstance(value, str):
|
||||
splitter = shlex(value, posix=True)
|
||||
splitter.whitespace = ","
|
||||
@@ -236,10 +237,10 @@ class CommaSeparatedStrings(Sequence):
|
||||
def __len__(self) -> int:
|
||||
return len(self._items)
|
||||
|
||||
def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
|
||||
def __getitem__(self, index: int | slice) -> Any:
|
||||
return self._items[index]
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._items)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -251,74 +252,65 @@ class CommaSeparatedStrings(Sequence):
|
||||
return ", ".join(repr(item) for item in self)
|
||||
|
||||
|
||||
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
|
||||
_dict: typing.Dict[_KeyType, _CovariantValueType]
|
||||
class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
|
||||
_dict: dict[_KeyType, _CovariantValueType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: typing.Union[
|
||||
"ImmutableMultiDict[_KeyType, _CovariantValueType]",
|
||||
typing.Mapping[_KeyType, _CovariantValueType],
|
||||
typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]],
|
||||
],
|
||||
**kwargs: typing.Any,
|
||||
*args: ImmutableMultiDict[_KeyType, _CovariantValueType]
|
||||
| Mapping[_KeyType, _CovariantValueType]
|
||||
| Iterable[tuple[_KeyType, _CovariantValueType]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
assert len(args) < 2, "Too many arguments."
|
||||
|
||||
value: typing.Any = args[0] if args else []
|
||||
value: Any = args[0] if args else []
|
||||
if kwargs:
|
||||
value = (
|
||||
ImmutableMultiDict(value).multi_items()
|
||||
+ ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator]
|
||||
)
|
||||
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
|
||||
|
||||
if not value:
|
||||
_items: typing.List[typing.Tuple[typing.Any, typing.Any]] = []
|
||||
_items: list[tuple[Any, Any]] = []
|
||||
elif hasattr(value, "multi_items"):
|
||||
value = typing.cast(
|
||||
ImmutableMultiDict[_KeyType, _CovariantValueType], value
|
||||
)
|
||||
value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
|
||||
_items = list(value.multi_items())
|
||||
elif hasattr(value, "items"):
|
||||
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
|
||||
value = cast(Mapping[_KeyType, _CovariantValueType], value)
|
||||
_items = list(value.items())
|
||||
else:
|
||||
value = typing.cast(
|
||||
typing.List[typing.Tuple[typing.Any, typing.Any]], value
|
||||
)
|
||||
value = cast("list[tuple[Any, Any]]", value)
|
||||
_items = list(value)
|
||||
|
||||
self._dict = {k: v for k, v in _items}
|
||||
self._list = _items
|
||||
|
||||
def getlist(self, key: typing.Any) -> typing.List[_CovariantValueType]:
|
||||
def getlist(self, key: Any) -> list[_CovariantValueType]:
|
||||
return [item_value for item_key, item_value in self._list if item_key == key]
|
||||
|
||||
def keys(self) -> typing.KeysView[_KeyType]:
|
||||
def keys(self) -> KeysView[_KeyType]:
|
||||
return self._dict.keys()
|
||||
|
||||
def values(self) -> typing.ValuesView[_CovariantValueType]:
|
||||
def values(self) -> ValuesView[_CovariantValueType]:
|
||||
return self._dict.values()
|
||||
|
||||
def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
|
||||
def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
|
||||
return self._dict.items()
|
||||
|
||||
def multi_items(self) -> typing.List[typing.Tuple[_KeyType, _CovariantValueType]]:
|
||||
def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
|
||||
return list(self._list)
|
||||
|
||||
def __getitem__(self, key: _KeyType) -> _CovariantValueType:
|
||||
return self._dict[key]
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
def __contains__(self, key: Any) -> bool:
|
||||
return key in self._dict
|
||||
|
||||
def __iter__(self) -> typing.Iterator[_KeyType]:
|
||||
def __iter__(self) -> Iterator[_KeyType]:
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._dict)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
@@ -329,24 +321,24 @@ class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
|
||||
return f"{class_name}({items!r})"
|
||||
|
||||
|
||||
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
|
||||
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
class MultiDict(ImmutableMultiDict[Any, Any]):
|
||||
def __setitem__(self, key: Any, value: Any) -> None:
|
||||
self.setlist(key, [value])
|
||||
|
||||
def __delitem__(self, key: typing.Any) -> None:
|
||||
def __delitem__(self, key: Any) -> None:
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
del self._dict[key]
|
||||
|
||||
def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
def pop(self, key: Any, default: Any = None) -> Any:
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def popitem(self) -> typing.Tuple:
|
||||
def popitem(self) -> tuple[Any, Any]:
|
||||
key, value = self._dict.popitem()
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
return key, value
|
||||
|
||||
def poplist(self, key: typing.Any) -> typing.List:
|
||||
def poplist(self, key: Any) -> list[Any]:
|
||||
values = [v for k, v in self._list if k == key]
|
||||
self.pop(key)
|
||||
return values
|
||||
@@ -355,14 +347,14 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
|
||||
self._dict.clear()
|
||||
self._list.clear()
|
||||
|
||||
def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
def setdefault(self, key: Any, default: Any = None) -> Any:
|
||||
if key not in self:
|
||||
self._dict[key] = default
|
||||
self._list.append((key, default))
|
||||
|
||||
return self[key]
|
||||
|
||||
def setlist(self, key: typing.Any, values: typing.List) -> None:
|
||||
def setlist(self, key: Any, values: list[Any]) -> None:
|
||||
if not values:
|
||||
self.pop(key, None)
|
||||
else:
|
||||
@@ -370,18 +362,14 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
|
||||
self._list = existing_items + [(key, value) for value in values]
|
||||
self._dict[key] = values[-1]
|
||||
|
||||
def append(self, key: typing.Any, value: typing.Any) -> None:
|
||||
def append(self, key: Any, value: Any) -> None:
|
||||
self._list.append((key, value))
|
||||
self._dict[key] = value
|
||||
|
||||
def update(
|
||||
self,
|
||||
*args: typing.Union[
|
||||
"MultiDict",
|
||||
typing.Mapping,
|
||||
typing.List[typing.Tuple[typing.Any, typing.Any]],
|
||||
],
|
||||
**kwargs: typing.Any,
|
||||
*args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
value = MultiDict(*args, **kwargs)
|
||||
existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
|
||||
@@ -396,14 +384,8 @@ class QueryParams(ImmutableMultiDict[str, str]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: typing.Union[
|
||||
"ImmutableMultiDict",
|
||||
typing.Mapping,
|
||||
typing.List[typing.Tuple[typing.Any, typing.Any]],
|
||||
str,
|
||||
bytes,
|
||||
],
|
||||
**kwargs: typing.Any,
|
||||
*args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
assert len(args) < 2, "Too many arguments."
|
||||
|
||||
@@ -412,9 +394,7 @@ class QueryParams(ImmutableMultiDict[str, str]):
|
||||
if isinstance(value, str):
|
||||
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
|
||||
elif isinstance(value, bytes):
|
||||
super().__init__(
|
||||
parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
|
||||
)
|
||||
super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs) # type: ignore[arg-type]
|
||||
self._list = [(str(k), str(v)) for k, v in self._list]
|
||||
@@ -436,19 +416,23 @@ class UploadFile:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file: typing.BinaryIO,
|
||||
file: BinaryIO,
|
||||
*,
|
||||
size: typing.Optional[int] = None,
|
||||
filename: typing.Optional[str] = None,
|
||||
headers: "typing.Optional[Headers]" = None,
|
||||
size: int | None = None,
|
||||
filename: str | None = None,
|
||||
headers: Headers | None = None,
|
||||
) -> None:
|
||||
self.filename = filename
|
||||
self.file = file
|
||||
self.size = size
|
||||
self.headers = headers or Headers()
|
||||
|
||||
# Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
|
||||
# Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
|
||||
self._max_mem_size = getattr(self.file, "_max_size", 0)
|
||||
|
||||
@property
|
||||
def content_type(self) -> typing.Optional[str]:
|
||||
def content_type(self) -> str | None:
|
||||
return self.headers.get("content-type", None)
|
||||
|
||||
@property
|
||||
@@ -457,14 +441,24 @@ class UploadFile:
|
||||
rolled_to_disk = getattr(self.file, "_rolled", True)
|
||||
return not rolled_to_disk
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
if self.size is not None:
|
||||
self.size += len(data)
|
||||
def _will_roll(self, size_to_add: int) -> bool:
|
||||
# If we're not in_memory then we will always roll
|
||||
if not self._in_memory:
|
||||
return True
|
||||
|
||||
if self._in_memory:
|
||||
self.file.write(data)
|
||||
else:
|
||||
# Check for SpooledTemporaryFile._max_size
|
||||
future_size = self.file.tell() + size_to_add
|
||||
return bool(future_size > self._max_mem_size) if self._max_mem_size else False
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
new_data_len = len(data)
|
||||
if self.size is not None:
|
||||
self.size += new_data_len
|
||||
|
||||
if self._will_roll(new_data_len):
|
||||
await run_in_threadpool(self.file.write, data)
|
||||
else:
|
||||
self.file.write(data)
|
||||
|
||||
async def read(self, size: int = -1) -> bytes:
|
||||
if self._in_memory:
|
||||
@@ -483,20 +477,19 @@ class UploadFile:
|
||||
else:
|
||||
await run_in_threadpool(self.file.close)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
|
||||
|
||||
class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
|
||||
|
||||
class FormData(ImmutableMultiDict[str, UploadFile | str]):
|
||||
"""
|
||||
An immutable multidict, containing both file uploads and text input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: typing.Union[
|
||||
"FormData",
|
||||
typing.Mapping[str, typing.Union[str, UploadFile]],
|
||||
typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
|
||||
],
|
||||
**kwargs: typing.Union[str, UploadFile],
|
||||
*args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
|
||||
**kwargs: str | UploadFile,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -506,25 +499,22 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
|
||||
await value.close()
|
||||
|
||||
|
||||
class Headers(typing.Mapping[str, str]):
|
||||
class Headers(Mapping[str, str]):
|
||||
"""
|
||||
An immutable, case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
raw: typing.Optional[typing.List[typing.Tuple[bytes, bytes]]] = None,
|
||||
scope: typing.Optional[typing.MutableMapping[str, typing.Any]] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
raw: list[tuple[bytes, bytes]] | None = None,
|
||||
scope: MutableMapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._list: typing.List[typing.Tuple[bytes, bytes]] = []
|
||||
self._list: list[tuple[bytes, bytes]] = []
|
||||
if headers is not None:
|
||||
assert raw is None, 'Cannot set both "headers" and "raw".'
|
||||
assert scope is None, 'Cannot set both "headers" and "scope".'
|
||||
self._list = [
|
||||
(key.lower().encode("latin-1"), value.encode("latin-1"))
|
||||
for key, value in headers.items()
|
||||
]
|
||||
self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
|
||||
elif raw is not None:
|
||||
assert scope is None, 'Cannot set both "raw" and "scope".'
|
||||
self._list = raw
|
||||
@@ -534,30 +524,23 @@ class Headers(typing.Mapping[str, str]):
|
||||
self._list = scope["headers"] = list(scope["headers"])
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
def raw(self) -> list[tuple[bytes, bytes]]:
|
||||
return list(self._list)
|
||||
|
||||
def keys(self) -> typing.List[str]: # type: ignore[override]
|
||||
def keys(self) -> list[str]: # type: ignore[override]
|
||||
return [key.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def values(self) -> typing.List[str]: # type: ignore[override]
|
||||
def values(self) -> list[str]: # type: ignore[override]
|
||||
return [value.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore[override]
|
||||
return [
|
||||
(key.decode("latin-1"), value.decode("latin-1"))
|
||||
for key, value in self._list
|
||||
]
|
||||
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
|
||||
return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
|
||||
|
||||
def getlist(self, key: str) -> typing.List[str]:
|
||||
def getlist(self, key: str) -> list[str]:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
return [
|
||||
item_value.decode("latin-1")
|
||||
for item_key, item_value in self._list
|
||||
if item_key == get_header_key
|
||||
]
|
||||
return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
|
||||
|
||||
def mutablecopy(self) -> "MutableHeaders":
|
||||
def mutablecopy(self) -> MutableHeaders:
|
||||
return MutableHeaders(raw=self._list[:])
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
@@ -567,20 +550,20 @@ class Headers(typing.Mapping[str, str]):
|
||||
return header_value.decode("latin-1")
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
def __contains__(self, key: Any) -> bool:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._list)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Headers):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
@@ -602,7 +585,7 @@ class MutableHeaders(Headers):
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
found_indexes: "typing.List[int]" = []
|
||||
found_indexes: list[int] = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
found_indexes.append(idx)
|
||||
@@ -622,7 +605,7 @@ class MutableHeaders(Headers):
|
||||
"""
|
||||
del_key = key.lower().encode("latin-1")
|
||||
|
||||
pop_indexes: "typing.List[int]" = []
|
||||
pop_indexes: list[int] = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == del_key:
|
||||
pop_indexes.append(idx)
|
||||
@@ -630,21 +613,21 @@ class MutableHeaders(Headers):
|
||||
for idx in reversed(pop_indexes):
|
||||
del self._list[idx]
|
||||
|
||||
def __ior__(self, other: typing.Mapping[str, str]) -> "MutableHeaders":
|
||||
if not isinstance(other, typing.Mapping):
|
||||
def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
|
||||
if not isinstance(other, Mapping):
|
||||
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
|
||||
self.update(other)
|
||||
return self
|
||||
|
||||
def __or__(self, other: typing.Mapping[str, str]) -> "MutableHeaders":
|
||||
if not isinstance(other, typing.Mapping):
|
||||
def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
|
||||
if not isinstance(other, Mapping):
|
||||
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
|
||||
new = self.mutablecopy()
|
||||
new.update(other)
|
||||
return new
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
def raw(self) -> list[tuple[bytes, bytes]]:
|
||||
return self._list
|
||||
|
||||
def setdefault(self, key: str, value: str) -> str:
|
||||
@@ -661,7 +644,7 @@ class MutableHeaders(Headers):
|
||||
self._list.append((set_key, set_value))
|
||||
return value
|
||||
|
||||
def update(self, other: typing.Mapping[str, str]) -> None:
|
||||
def update(self, other: Mapping[str, str]) -> None:
|
||||
for key, val in other.items():
|
||||
self[key] = val
|
||||
|
||||
@@ -687,22 +670,22 @@ class State:
|
||||
Used for `request.state` and `app.state`.
|
||||
"""
|
||||
|
||||
_state: typing.Dict[str, typing.Any]
|
||||
_state: dict[str, Any]
|
||||
|
||||
def __init__(self, state: typing.Optional[typing.Dict[str, typing.Any]] = None):
|
||||
def __init__(self, state: dict[str, Any] | None = None):
|
||||
if state is None:
|
||||
state = {}
|
||||
super().__setattr__("_state", state)
|
||||
|
||||
def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
def __setattr__(self, key: Any, value: Any) -> None:
|
||||
self._state[key] = value
|
||||
|
||||
def __getattr__(self, key: typing.Any) -> typing.Any:
|
||||
def __getattr__(self, key: Any) -> Any:
|
||||
try:
|
||||
return self._state[key]
|
||||
except KeyError:
|
||||
message = "'{}' object has no attribute '{}'"
|
||||
raise AttributeError(message.format(self.__class__.__name__, key))
|
||||
|
||||
def __delattr__(self, key: typing.Any) -> None:
|
||||
def __delattr__(self, key: Any) -> None:
|
||||
del self._state[key]
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from starlette import status
|
||||
from starlette._utils import is_async_callable
|
||||
@@ -23,20 +26,14 @@ class HTTPEndpoint:
|
||||
if getattr(self, method.lower(), None) is not None
|
||||
]
|
||||
|
||||
def __await__(self) -> typing.Generator:
|
||||
def __await__(self) -> Generator[Any, None, None]:
|
||||
return self.dispatch().__await__()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
request = Request(self.scope, receive=self.receive)
|
||||
handler_name = (
|
||||
"get"
|
||||
if request.method == "HEAD" and not hasattr(self, "head")
|
||||
else request.method.lower()
|
||||
)
|
||||
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
|
||||
|
||||
handler: typing.Callable[[Request], typing.Any] = getattr(
|
||||
self, handler_name, self.method_not_allowed
|
||||
)
|
||||
handler: Callable[[Request], Any] = getattr(self, handler_name, self.method_not_allowed)
|
||||
is_async = is_async_callable(handler)
|
||||
if is_async:
|
||||
response = await handler(request)
|
||||
@@ -55,7 +52,7 @@ class HTTPEndpoint:
|
||||
|
||||
|
||||
class WebSocketEndpoint:
|
||||
encoding: typing.Optional[str] = None # May be "text", "bytes", or "json".
|
||||
encoding: Literal["text", "bytes", "json"] | None = None
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "websocket"
|
||||
@@ -63,7 +60,7 @@ class WebSocketEndpoint:
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
|
||||
def __await__(self) -> typing.Generator:
|
||||
def __await__(self) -> Generator[Any, None, None]:
|
||||
return self.dispatch().__await__()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
@@ -78,10 +75,8 @@ class WebSocketEndpoint:
|
||||
if message["type"] == "websocket.receive":
|
||||
data = await self.decode(websocket, message)
|
||||
await self.on_receive(websocket, data)
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
close_code = int(
|
||||
message.get("code") or status.WS_1000_NORMAL_CLOSURE
|
||||
)
|
||||
elif message["type"] == "websocket.disconnect": # pragma: no branch
|
||||
close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
|
||||
break
|
||||
except Exception as exc:
|
||||
close_code = status.WS_1011_INTERNAL_ERROR
|
||||
@@ -89,7 +84,7 @@ class WebSocketEndpoint:
|
||||
finally:
|
||||
await self.on_disconnect(websocket, close_code)
|
||||
|
||||
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
|
||||
async def decode(self, websocket: WebSocket, message: Message) -> Any:
|
||||
if self.encoding == "text":
|
||||
if "text" not in message:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
@@ -114,16 +109,14 @@ class WebSocketEndpoint:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Malformed JSON data received.")
|
||||
|
||||
assert (
|
||||
self.encoding is None
|
||||
), f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
return message["text"] if message.get("text") else message["bytes"]
|
||||
|
||||
async def on_connect(self, websocket: WebSocket) -> None:
|
||||
"""Override to handle an incoming websocket connection"""
|
||||
await websocket.accept()
|
||||
|
||||
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
|
||||
async def on_receive(self, websocket: WebSocket, data: Any) -> None:
|
||||
"""Override to handle an incoming websocket message"""
|
||||
|
||||
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
|
||||
|
||||
@@ -1,54 +1,33 @@
|
||||
import http
|
||||
import typing
|
||||
import warnings
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ("HTTPException", "WebSocketException")
|
||||
import http
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
class HTTPException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: typing.Optional[str] = None,
|
||||
headers: typing.Optional[dict] = None,
|
||||
) -> None:
|
||||
def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None:
|
||||
if detail is None:
|
||||
detail = http.HTTPStatus(status_code).phrase
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.status_code}: {self.detail}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
|
||||
|
||||
|
||||
class WebSocketException(Exception):
|
||||
def __init__(self, code: int, reason: typing.Optional[str] = None) -> None:
|
||||
def __init__(self, code: int, reason: str | None = None) -> None:
|
||||
self.code = code
|
||||
self.reason = reason or ""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.code}: {self.reason}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"
|
||||
|
||||
|
||||
__deprecated__ = "ExceptionMiddleware"
|
||||
|
||||
|
||||
def __getattr__(name: str) -> typing.Any: # pragma: no cover
|
||||
if name == __deprecated__:
|
||||
from starlette.middleware.exceptions import ExceptionMiddleware
|
||||
|
||||
warnings.warn(
|
||||
f"{__deprecated__} is deprecated on `starlette.exceptions`. "
|
||||
f"Import it from `starlette.middleware.exceptions` instead.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return ExceptionMiddleware
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
def __dir__() -> typing.List[str]:
|
||||
return sorted(list(__all__) + [__deprecated__]) # pragma: no cover
|
||||
|
||||
@@ -1,17 +1,28 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from tempfile import SpooledTemporaryFile
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from starlette.datastructures import FormData, Headers, UploadFile
|
||||
|
||||
try:
|
||||
import multipart
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
parse_options_header = None
|
||||
multipart = None
|
||||
if TYPE_CHECKING:
|
||||
import python_multipart as multipart
|
||||
from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
|
||||
else:
|
||||
try:
|
||||
try:
|
||||
import python_multipart as multipart
|
||||
from python_multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
import multipart
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
multipart = None
|
||||
parse_options_header = None
|
||||
|
||||
|
||||
class FormMessage(Enum):
|
||||
@@ -24,14 +35,14 @@ class FormMessage(Enum):
|
||||
|
||||
@dataclass
|
||||
class MultipartPart:
|
||||
content_disposition: typing.Optional[bytes] = None
|
||||
content_disposition: bytes | None = None
|
||||
field_name: str = ""
|
||||
data: bytes = b""
|
||||
file: typing.Optional[UploadFile] = None
|
||||
item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(default_factory=list)
|
||||
data: bytearray = field(default_factory=bytearray)
|
||||
file: UploadFile | None = None
|
||||
item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
|
||||
|
||||
|
||||
def _user_safe_decode(src: bytes, codec: str) -> str:
|
||||
def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
|
||||
try:
|
||||
return src.decode(codec)
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
@@ -44,15 +55,11 @@ class MultiPartException(Exception):
|
||||
|
||||
|
||||
class FormParser:
|
||||
def __init__(
|
||||
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
|
||||
) -> None:
|
||||
assert (
|
||||
multipart is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None:
|
||||
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
|
||||
self.headers = headers
|
||||
self.stream = stream
|
||||
self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = []
|
||||
self.messages: list[tuple[FormMessage, bytes]] = []
|
||||
|
||||
def on_field_start(self) -> None:
|
||||
message = (FormMessage.FIELD_START, b"")
|
||||
@@ -76,7 +83,7 @@ class FormParser:
|
||||
|
||||
async def parse(self) -> FormData:
|
||||
# Callbacks dictionary.
|
||||
callbacks = {
|
||||
callbacks: QuerystringCallbacks = {
|
||||
"on_field_start": self.on_field_start,
|
||||
"on_field_name": self.on_field_name,
|
||||
"on_field_data": self.on_field_data,
|
||||
@@ -89,7 +96,7 @@ class FormParser:
|
||||
field_name = b""
|
||||
field_value = b""
|
||||
|
||||
items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
|
||||
items: list[tuple[str, str | UploadFile]] = []
|
||||
|
||||
# Feed the parser with data from the request.
|
||||
async for chunk in self.stream:
|
||||
@@ -116,33 +123,36 @@ class FormParser:
|
||||
|
||||
|
||||
class MultiPartParser:
|
||||
max_file_size = 1024 * 1024
|
||||
spool_max_size = 1024 * 1024 # 1MB
|
||||
"""The maximum size of the spooled temporary file used to store file data."""
|
||||
max_part_size = 1024 * 1024 # 1MB
|
||||
"""The maximum size of a part in the multipart request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: Headers,
|
||||
stream: typing.AsyncGenerator[bytes, None],
|
||||
stream: AsyncGenerator[bytes, None],
|
||||
*,
|
||||
max_files: typing.Union[int, float] = 1000,
|
||||
max_fields: typing.Union[int, float] = 1000,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024, # 1MB
|
||||
) -> None:
|
||||
assert (
|
||||
multipart is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
|
||||
self.headers = headers
|
||||
self.stream = stream
|
||||
self.max_files = max_files
|
||||
self.max_fields = max_fields
|
||||
self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
|
||||
self.items: list[tuple[str, str | UploadFile]] = []
|
||||
self._current_files = 0
|
||||
self._current_fields = 0
|
||||
self._current_partial_header_name: bytes = b""
|
||||
self._current_partial_header_value: bytes = b""
|
||||
self._current_part = MultipartPart()
|
||||
self._charset = ""
|
||||
self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = []
|
||||
self._file_parts_to_finish: typing.List[MultipartPart] = []
|
||||
self._files_to_close_on_error: typing.List[SpooledTemporaryFile] = []
|
||||
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
|
||||
self._file_parts_to_finish: list[MultipartPart] = []
|
||||
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
|
||||
self.max_part_size = max_part_size
|
||||
|
||||
def on_part_begin(self) -> None:
|
||||
self._current_part = MultipartPart()
|
||||
@@ -150,7 +160,9 @@ class MultiPartParser:
|
||||
def on_part_data(self, data: bytes, start: int, end: int) -> None:
|
||||
message_bytes = data[start:end]
|
||||
if self._current_part.file is None:
|
||||
self._current_part.data += message_bytes
|
||||
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
|
||||
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
|
||||
self._current_part.data.extend(message_bytes)
|
||||
else:
|
||||
self._file_parts_to_write.append((self._current_part, message_bytes))
|
||||
|
||||
@@ -179,32 +191,22 @@ class MultiPartParser:
|
||||
field = self._current_partial_header_name.lower()
|
||||
if field == b"content-disposition":
|
||||
self._current_part.content_disposition = self._current_partial_header_value
|
||||
self._current_part.item_headers.append(
|
||||
(field, self._current_partial_header_value)
|
||||
)
|
||||
self._current_part.item_headers.append((field, self._current_partial_header_value))
|
||||
self._current_partial_header_name = b""
|
||||
self._current_partial_header_value = b""
|
||||
|
||||
def on_headers_finished(self) -> None:
|
||||
disposition, options = parse_options_header(
|
||||
self._current_part.content_disposition
|
||||
)
|
||||
disposition, options = parse_options_header(self._current_part.content_disposition)
|
||||
try:
|
||||
self._current_part.field_name = _user_safe_decode(
|
||||
options[b"name"], self._charset
|
||||
)
|
||||
self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
|
||||
except KeyError:
|
||||
raise MultiPartException(
|
||||
'The Content-Disposition header field "name" must be ' "provided."
|
||||
)
|
||||
raise MultiPartException('The Content-Disposition header field "name" must be provided.')
|
||||
if b"filename" in options:
|
||||
self._current_files += 1
|
||||
if self._current_files > self.max_files:
|
||||
raise MultiPartException(
|
||||
f"Too many files. Maximum number of files is {self.max_files}."
|
||||
)
|
||||
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
|
||||
filename = _user_safe_decode(options[b"filename"], self._charset)
|
||||
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
|
||||
tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
|
||||
self._files_to_close_on_error.append(tempfile)
|
||||
self._current_part.file = UploadFile(
|
||||
file=tempfile, # type: ignore[arg-type]
|
||||
@@ -215,9 +217,7 @@ class MultiPartParser:
|
||||
else:
|
||||
self._current_fields += 1
|
||||
if self._current_fields > self.max_fields:
|
||||
raise MultiPartException(
|
||||
f"Too many fields. Maximum number of fields is {self.max_fields}."
|
||||
)
|
||||
raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
|
||||
self._current_part.file = None
|
||||
|
||||
def on_end(self) -> None:
|
||||
@@ -227,7 +227,7 @@ class MultiPartParser:
|
||||
# Parse the Content-Type header to get the multipart boundary.
|
||||
_, params = parse_options_header(self.headers["Content-Type"])
|
||||
charset = params.get(b"charset", "utf-8")
|
||||
if type(charset) == bytes:
|
||||
if isinstance(charset, bytes):
|
||||
charset = charset.decode("latin-1")
|
||||
self._charset = charset
|
||||
try:
|
||||
@@ -236,7 +236,7 @@ class MultiPartParser:
|
||||
raise MultiPartException("Missing boundary in multipart.")
|
||||
|
||||
# Callbacks dictionary.
|
||||
callbacks = {
|
||||
callbacks: MultipartCallbacks = {
|
||||
"on_part_begin": self.on_part_begin,
|
||||
"on_part_data": self.on_part_data,
|
||||
"on_part_end": self.on_part_end,
|
||||
|
||||
@@ -1,17 +1,37 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterator
|
||||
from typing import Any, ParamSpec, Protocol
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
_Scope = Any
|
||||
_Receive = Callable[[], Awaitable[Any]]
|
||||
_Send = Callable[[Any], Awaitable[None]]
|
||||
# Since `starlette.types.ASGIApp` type differs from `ASGIApplication` from `asgiref`
|
||||
# we need to define a more permissive version of ASGIApp that doesn't cause type errors.
|
||||
_ASGIApp = Callable[[_Scope, _Receive, _Send], Awaitable[None]]
|
||||
|
||||
|
||||
class _MiddlewareFactory(Protocol[P]):
|
||||
def __call__(self, app: _ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> _ASGIApp: ... # pragma: no cover
|
||||
|
||||
|
||||
class Middleware:
|
||||
def __init__(self, cls: type, **options: typing.Any) -> None:
|
||||
def __init__(self, cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
self.cls = cls
|
||||
self.options = options
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __iter__(self) -> typing.Iterator:
|
||||
as_tuple = (self.cls, self.options)
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
as_tuple = (self.cls, self.args, self.kwargs)
|
||||
return iter(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
|
||||
args_repr = ", ".join([self.cls.__name__] + option_strings)
|
||||
args_strings = [f"{value!r}" for value in self.args]
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
|
||||
name = getattr(self.cls, "__name__", "")
|
||||
args_repr = ", ".join([name] + args_strings + option_strings)
|
||||
return f"{class_name}({args_repr})"
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,4 +1,6 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
@@ -16,15 +18,13 @@ class AuthenticationMiddleware:
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Optional[
|
||||
typing.Callable[[HTTPConnection, AuthenticationError], Response]
|
||||
] = None,
|
||||
on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error: typing.Callable[
|
||||
[HTTPConnection, AuthenticationError], Response
|
||||
] = (on_error if on_error is not None else self.default_on_error)
|
||||
self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ["http", "websocket"]:
|
||||
|
||||
@@ -1,23 +1,100 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Mapping, MutableMapping
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import anyio
|
||||
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import ContentStream, Response, StreamingResponse
|
||||
from starlette._utils import collapse_excgroups
|
||||
from starlette.requests import ClientDisconnect, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
|
||||
DispatchFunction = typing.Callable[
|
||||
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
|
||||
]
|
||||
T = typing.TypeVar("T")
|
||||
RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
|
||||
DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
|
||||
BodyStreamGenerator = AsyncGenerator[bytes | MutableMapping[str, Any], None]
|
||||
AsyncContentStream = AsyncIterable[str | bytes | memoryview | MutableMapping[str, Any]]
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class _CachedRequest(Request):
|
||||
"""
|
||||
If the user calls Request.body() from their dispatch function
|
||||
we cache the entire request body in memory and pass that to downstream middlewares,
|
||||
but if they call Request.stream() then all we do is send an
|
||||
empty body so that downstream things don't hang forever.
|
||||
"""
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive):
|
||||
super().__init__(scope, receive)
|
||||
self._wrapped_rcv_disconnected = False
|
||||
self._wrapped_rcv_consumed = False
|
||||
self._wrapped_rc_stream = self.stream()
|
||||
|
||||
async def wrapped_receive(self) -> Message:
|
||||
# wrapped_rcv state 1: disconnected
|
||||
if self._wrapped_rcv_disconnected:
|
||||
# we've already sent a disconnect to the downstream app
|
||||
# we don't need to wait to get another one
|
||||
# (although most ASGI servers will just keep sending it)
|
||||
return {"type": "http.disconnect"}
|
||||
# wrapped_rcv state 1: consumed but not yet disconnected
|
||||
if self._wrapped_rcv_consumed:
|
||||
# since the downstream app has consumed us all that is left
|
||||
# is to send it a disconnect
|
||||
if self._is_disconnected:
|
||||
# the middleware has already seen the disconnect
|
||||
# since we know the client is disconnected no need to wait
|
||||
# for the message
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
# we don't know yet if the client is disconnected or not
|
||||
# so we'll wait until we get that message
|
||||
msg = await self.receive()
|
||||
if msg["type"] != "http.disconnect": # pragma: no cover
|
||||
# at this point a disconnect is all that we should be receiving
|
||||
# if we get something else, things went wrong somewhere
|
||||
raise RuntimeError(f"Unexpected message received: {msg['type']}")
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return msg
|
||||
|
||||
# wrapped_rcv state 3: not yet consumed
|
||||
if getattr(self, "_body", None) is not None:
|
||||
# body() was called, we return it even if the client disconnected
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": self._body,
|
||||
"more_body": False,
|
||||
}
|
||||
elif self._stream_consumed:
|
||||
# stream() was called to completion
|
||||
# return an empty body so that downstream apps don't hang
|
||||
# waiting for a disconnect
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
else:
|
||||
# body() was never called and stream() wasn't consumed
|
||||
try:
|
||||
stream = self.stream()
|
||||
chunk = await stream.__anext__()
|
||||
self._wrapped_rcv_consumed = self._stream_consumed
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": not self._stream_consumed,
|
||||
}
|
||||
except ClientDisconnect:
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
|
||||
class BaseHTTPMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
|
||||
) -> None:
|
||||
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
|
||||
self.app = app
|
||||
self.dispatch_func = self.dispatch if dispatch is None else dispatch
|
||||
|
||||
@@ -26,35 +103,32 @@ class BaseHTTPMiddleware:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = _CachedRequest(scope, receive)
|
||||
wrapped_receive = request.wrapped_receive
|
||||
response_sent = anyio.Event()
|
||||
app_exc: Exception | None = None
|
||||
exception_already_raised = False
|
||||
|
||||
async def call_next(request: Request) -> Response:
|
||||
app_exc: typing.Optional[Exception] = None
|
||||
send_stream, recv_stream = anyio.create_memory_object_stream()
|
||||
|
||||
async def receive_or_disconnect() -> Message:
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
|
||||
async def wrap(func: Callable[[], Awaitable[T]]) -> T:
|
||||
result = await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
return result
|
||||
|
||||
task_group.start_soon(wrap, response_sent.wait)
|
||||
message = await wrap(request.receive)
|
||||
message = await wrap(wrapped_receive)
|
||||
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
return message
|
||||
|
||||
async def close_recv_stream_on_response_sent() -> None:
|
||||
await response_sent.wait()
|
||||
recv_stream.close()
|
||||
|
||||
async def send_no_error(message: Message) -> None:
|
||||
try:
|
||||
await send_stream.send(message)
|
||||
@@ -65,13 +139,12 @@ class BaseHTTPMiddleware:
|
||||
async def coro() -> None:
|
||||
nonlocal app_exc
|
||||
|
||||
async with send_stream:
|
||||
with send_stream:
|
||||
try:
|
||||
await self.app(scope, receive_or_disconnect, send_no_error)
|
||||
except Exception as exc:
|
||||
app_exc = exc
|
||||
|
||||
task_group.start_soon(close_recv_stream_on_response_sent)
|
||||
task_group.start_soon(coro)
|
||||
|
||||
try:
|
||||
@@ -81,54 +154,91 @@ class BaseHTTPMiddleware:
|
||||
message = await recv_stream.receive()
|
||||
except anyio.EndOfStream:
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
nonlocal exception_already_raised
|
||||
exception_already_raised = True
|
||||
# Prevent `anyio.EndOfStream` from polluting app exception context.
|
||||
# If both cause and context are None then the context is suppressed
|
||||
# and `anyio.EndOfStream` is not present in the exception traceback.
|
||||
# If exception cause is not None then it is propagated with
|
||||
# reraising here.
|
||||
# If exception has no cause but has context set then the context is
|
||||
# propagated as a cause with the reraise. This is necessary in order
|
||||
# to prevent `anyio.EndOfStream` from polluting the exception
|
||||
# context.
|
||||
raise app_exc from app_exc.__cause__ or app_exc.__context__
|
||||
raise RuntimeError("No response returned.")
|
||||
|
||||
assert message["type"] == "http.response.start"
|
||||
|
||||
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
|
||||
async with recv_stream:
|
||||
async for message in recv_stream:
|
||||
assert message["type"] == "http.response.body"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
async def body_stream() -> BodyStreamGenerator:
|
||||
async for message in recv_stream:
|
||||
if message["type"] == "http.response.pathsend":
|
||||
yield message
|
||||
break
|
||||
assert message["type"] == "http.response.body", f"Unexpected message: {message}"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
|
||||
response = _StreamingResponse(
|
||||
status_code=message["status"], content=body_stream(), info=info
|
||||
)
|
||||
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
|
||||
response.raw_headers = message["headers"]
|
||||
return response
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
request = Request(scope, receive=receive)
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, receive, send)
|
||||
response_sent.set()
|
||||
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
|
||||
send_stream, recv_stream = streams
|
||||
with recv_stream, send_stream, collapse_excgroups():
|
||||
async with anyio.create_task_group() as task_group:
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, wrapped_receive, send)
|
||||
response_sent.set()
|
||||
recv_stream.close()
|
||||
if app_exc is not None and not exception_already_raised:
|
||||
raise app_exc
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class _StreamingResponse(StreamingResponse):
|
||||
class _StreamingResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
content: ContentStream,
|
||||
content: AsyncContentStream,
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
info: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._info = info
|
||||
super().__init__(content, status_code, headers, media_type, background)
|
||||
self.info = info
|
||||
self.body_iterator = content
|
||||
self.status_code = status_code
|
||||
self.media_type = media_type
|
||||
self.init_headers(headers)
|
||||
self.background = None
|
||||
|
||||
async def stream_response(self, send: Send) -> None:
|
||||
if self._info:
|
||||
await send({"type": "http.response.debug", "info": self._info})
|
||||
return await super().stream_response(send)
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.info is not None:
|
||||
await send({"type": "http.response.debug", "info": self.info})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
|
||||
should_close_body = True
|
||||
async for chunk in self.body_iterator:
|
||||
if isinstance(chunk, dict):
|
||||
# We got an ASGI message which is not response body (eg: pathsend)
|
||||
should_close_body = False
|
||||
await send(chunk)
|
||||
continue
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
if should_close_body:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
if self.background:
|
||||
await self.background()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
from collections.abc import Sequence
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
@@ -14,12 +16,12 @@ class CORSMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allow_origins: typing.Sequence[str] = (),
|
||||
allow_methods: typing.Sequence[str] = ("GET",),
|
||||
allow_headers: typing.Sequence[str] = (),
|
||||
allow_origins: Sequence[str] = (),
|
||||
allow_methods: Sequence[str] = ("GET",),
|
||||
allow_headers: Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
allow_origin_regex: typing.Optional[str] = None,
|
||||
expose_headers: typing.Sequence[str] = (),
|
||||
allow_origin_regex: str | None = None,
|
||||
expose_headers: Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
) -> None:
|
||||
if "*" in allow_methods:
|
||||
@@ -94,9 +96,7 @@ class CORSMiddleware:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
|
||||
origin
|
||||
):
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
|
||||
return True
|
||||
|
||||
return origin in self.allow_origins
|
||||
@@ -139,15 +139,11 @@ class CORSMiddleware:
|
||||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(
|
||||
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(
|
||||
self, message: Message, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
|
||||
|
||||
STYLES = """
|
||||
p {
|
||||
@@ -137,7 +139,7 @@ class ServerErrorMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handler: typing.Optional[typing.Callable] = None,
|
||||
handler: ExceptionHandler | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
@@ -183,9 +185,7 @@ class ServerErrorMiddleware:
|
||||
# to optionally raise the error within the test case.
|
||||
raise exc
|
||||
|
||||
def format_line(
|
||||
self, index: int, line: str, frame_lineno: int, frame_index: int
|
||||
) -> str:
|
||||
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
|
||||
values = {
|
||||
# HTML escape - line could contain < or >
|
||||
"line": html.escape(line).replace(" ", " "),
|
||||
@@ -199,7 +199,10 @@ class ServerErrorMiddleware:
|
||||
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
|
||||
code_context = "".join(
|
||||
self.format_line(
|
||||
index, line, frame.lineno, frame.index # type: ignore[arg-type]
|
||||
index,
|
||||
line,
|
||||
frame.lineno,
|
||||
frame.index, # type: ignore[arg-type]
|
||||
)
|
||||
for index, line in enumerate(frame.code_context or [])
|
||||
)
|
||||
@@ -219,9 +222,7 @@ class ServerErrorMiddleware:
|
||||
return FRAME_TEMPLATE.format(**values)
|
||||
|
||||
def generate_html(self, exc: Exception, limit: int = 7) -> str:
|
||||
traceback_obj = traceback.TracebackException.from_exception(
|
||||
exc, capture_locals=True
|
||||
)
|
||||
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
|
||||
|
||||
exc_html = ""
|
||||
is_collapsed = False
|
||||
@@ -232,11 +233,13 @@ class ServerErrorMiddleware:
|
||||
exc_html += self.generate_frame_html(frame, is_collapsed)
|
||||
is_collapsed = True
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type_str
|
||||
else: # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type.__name__
|
||||
|
||||
# escape error class and text
|
||||
error = (
|
||||
f"{html.escape(traceback_obj.exc_type.__name__)}: "
|
||||
f"{html.escape(str(traceback_obj))}"
|
||||
)
|
||||
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
|
||||
|
||||
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
|
||||
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from starlette._exception_handler import (
|
||||
ExceptionHandlers,
|
||||
StatusHandlers,
|
||||
wrap_app_handling_exceptions,
|
||||
)
|
||||
from starlette.exceptions import HTTPException, WebSocketException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
@@ -13,28 +19,24 @@ class ExceptionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handlers: typing.Optional[
|
||||
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
|
||||
] = None,
|
||||
handlers: Mapping[Any, ExceptionHandler] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
|
||||
self._status_handlers: typing.Dict[int, typing.Callable] = {}
|
||||
self._exception_handlers: typing.Dict[
|
||||
typing.Type[Exception], typing.Callable
|
||||
] = {
|
||||
self._status_handlers: StatusHandlers = {}
|
||||
self._exception_handlers: ExceptionHandlers = {
|
||||
HTTPException: self.http_exception,
|
||||
WebSocketException: self.websocket_exception,
|
||||
}
|
||||
if handlers is not None:
|
||||
if handlers is not None: # pragma: no branch
|
||||
for key, value in handlers.items():
|
||||
self.add_exception_handler(key, value)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
|
||||
handler: typing.Callable[[Request, Exception], Response],
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: ExceptionHandler,
|
||||
) -> None:
|
||||
if isinstance(exc_class_or_status_code, int):
|
||||
self._status_handlers[exc_class_or_status_code] = handler
|
||||
@@ -42,68 +44,30 @@ class ExceptionMiddleware:
|
||||
assert issubclass(exc_class_or_status_code, Exception)
|
||||
self._exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
def _lookup_exception_handler(
|
||||
self, exc: Exception
|
||||
) -> typing.Optional[typing.Callable]:
|
||||
for cls in type(exc).__mro__:
|
||||
if cls in self._exception_handlers:
|
||||
return self._exception_handlers[cls]
|
||||
return None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response_started = False
|
||||
|
||||
async def sender(message: Message) -> None:
|
||||
nonlocal response_started
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, sender)
|
||||
except Exception as exc:
|
||||
handler = None
|
||||
|
||||
if isinstance(exc, HTTPException):
|
||||
handler = self._status_handlers.get(exc.status_code)
|
||||
|
||||
if handler is None:
|
||||
handler = self._lookup_exception_handler(exc)
|
||||
|
||||
if handler is None:
|
||||
raise exc
|
||||
|
||||
if response_started:
|
||||
msg = "Caught handled exception, but response already started."
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
if scope["type"] == "http":
|
||||
request = Request(scope, receive=receive)
|
||||
if is_async_callable(handler):
|
||||
response = await handler(request, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, request, exc)
|
||||
await response(scope, receive, sender)
|
||||
elif scope["type"] == "websocket":
|
||||
websocket = WebSocket(scope, receive=receive, send=send)
|
||||
if is_async_callable(handler):
|
||||
await handler(websocket, exc)
|
||||
else:
|
||||
await run_in_threadpool(handler, websocket, exc)
|
||||
|
||||
def http_exception(self, request: Request, exc: HTTPException) -> Response:
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(
|
||||
exc.detail, status_code=exc.status_code, headers=exc.headers
|
||||
scope["starlette.exception_handlers"] = (
|
||||
self._exception_handlers,
|
||||
self._status_handlers,
|
||||
)
|
||||
|
||||
async def websocket_exception(
|
||||
self, websocket: WebSocket, exc: WebSocketException
|
||||
) -> None:
|
||||
await websocket.close(code=exc.code, reason=exc.reason)
|
||||
conn: Request | WebSocket
|
||||
if scope["type"] == "http":
|
||||
conn = Request(scope, receive, send)
|
||||
else:
|
||||
conn = WebSocket(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
|
||||
|
||||
async def http_exception(self, request: Request, exc: Exception) -> Response:
|
||||
assert isinstance(exc, HTTPException)
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
|
||||
|
||||
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
|
||||
assert isinstance(exc, WebSocketException)
|
||||
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
|
||||
|
||||
@@ -1,49 +1,51 @@
|
||||
import gzip
|
||||
import io
|
||||
import typing
|
||||
from typing import NoReturn
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
|
||||
|
||||
|
||||
class GZipMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
|
||||
) -> None:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.compresslevel = compresslevel
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "http":
|
||||
headers = Headers(scope=scope)
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(
|
||||
self.app, self.minimum_size, compresslevel=self.compresslevel
|
||||
)
|
||||
await responder(scope, receive, send)
|
||||
return
|
||||
await self.app(scope, receive, send)
|
||||
if scope["type"] != "http": # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
responder: ASGIApp
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
|
||||
else:
|
||||
responder = IdentityResponder(self.app, self.minimum_size)
|
||||
|
||||
await responder(scope, receive, send)
|
||||
|
||||
|
||||
class GZipResponder:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
|
||||
class IdentityResponder:
|
||||
content_encoding: str
|
||||
|
||||
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.send: Send = unattached_send
|
||||
self.initial_message: Message = {}
|
||||
self.started = False
|
||||
self.content_encoding_set = False
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(
|
||||
mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
|
||||
)
|
||||
self.content_type_is_excluded = False
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
await self.app(scope, receive, self.send_with_gzip)
|
||||
await self.app(scope, receive, self.send_with_compression)
|
||||
|
||||
async def send_with_gzip(self, message: Message) -> None:
|
||||
async def send_with_compression(self, message: Message) -> None:
|
||||
message_type = message["type"]
|
||||
if message_type == "http.response.start":
|
||||
# Don't send the initial message until we've determined how to
|
||||
@@ -51,7 +53,8 @@ class GZipResponder:
|
||||
self.initial_message = message
|
||||
headers = Headers(raw=self.initial_message["headers"])
|
||||
self.content_encoding_set = "content-encoding" in headers
|
||||
elif message_type == "http.response.body" and self.content_encoding_set:
|
||||
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
|
||||
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
|
||||
if not self.started:
|
||||
self.started = True
|
||||
await self.send(self.initial_message)
|
||||
@@ -61,53 +64,82 @@ class GZipResponder:
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if len(body) < self.minimum_size and not more_body:
|
||||
# Don't apply GZip to small outgoing responses.
|
||||
# Don't apply compression to small outgoing responses.
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif not more_body:
|
||||
# Standard GZip response.
|
||||
self.gzip_file.write(body)
|
||||
self.gzip_file.close()
|
||||
body = self.gzip_buffer.getvalue()
|
||||
# Standard response.
|
||||
body = self.apply_compression(body, more_body=False)
|
||||
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers["Content-Length"] = str(len(body))
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
message["body"] = body
|
||||
if body != message["body"]:
|
||||
headers["Content-Encoding"] = self.content_encoding
|
||||
headers["Content-Length"] = str(len(body))
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
else:
|
||||
# Initial body in streaming GZip response.
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
del headers["Content-Length"]
|
||||
# Initial body in streaming response.
|
||||
body = self.apply_compression(body, more_body=True)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
if body != message["body"]:
|
||||
headers["Content-Encoding"] = self.content_encoding
|
||||
del headers["Content-Length"]
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
elif message_type == "http.response.body":
|
||||
# Remaining body in streaming GZip response.
|
||||
# Remaining body in streaming response.
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
message["body"] = self.apply_compression(body, more_body=more_body)
|
||||
|
||||
await self.send(message)
|
||||
elif message_type == "http.response.pathsend": # pragma: no branch
|
||||
# Don't apply GZip to pathsend responses
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
|
||||
"""Apply compression on the response body.
|
||||
|
||||
If more_body is False, any compression file should be closed. If it
|
||||
isn't, it won't be closed automatically until all background tasks
|
||||
complete.
|
||||
"""
|
||||
return body
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> typing.NoReturn:
|
||||
class GZipResponder(IdentityResponder):
|
||||
content_encoding = "gzip"
|
||||
|
||||
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
|
||||
super().__init__(app, minimum_size)
|
||||
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
with self.gzip_buffer, self.gzip_file:
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
body = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
return body
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> NoReturn:
|
||||
raise RuntimeError("send awaitable not set") # pragma: no cover
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import typing
|
||||
from base64 import b64decode, b64encode
|
||||
from typing import Literal
|
||||
|
||||
import itsdangerous
|
||||
from itsdangerous.exc import BadSignature
|
||||
@@ -10,22 +11,18 @@ from starlette.datastructures import MutableHeaders, Secret
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
from typing import Literal
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class SessionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
secret_key: typing.Union[str, Secret],
|
||||
secret_key: str | Secret,
|
||||
session_cookie: str = "session",
|
||||
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
path: str = "/",
|
||||
same_site: Literal["lax", "strict", "none"] = "lax",
|
||||
https_only: bool = False,
|
||||
domain: str | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.signer = itsdangerous.TimestampSigner(str(secret_key))
|
||||
@@ -35,6 +32,8 @@ class SessionMiddleware:
|
||||
self.security_flags = "httponly; samesite=" + same_site
|
||||
if https_only: # Secure flag can be used with HTTPS only
|
||||
self.security_flags += "; secure"
|
||||
if domain is not None:
|
||||
self.security_flags += f"; domain={domain}"
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"): # pragma: no cover
|
||||
@@ -62,7 +61,7 @@ class SessionMiddleware:
|
||||
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
|
||||
data = self.signer.sign(data)
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data=data.decode("utf-8"),
|
||||
path=self.path,
|
||||
@@ -73,7 +72,7 @@ class SessionMiddleware:
|
||||
elif not initial_session_was_empty:
|
||||
# The session has been cleared.
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data="null",
|
||||
path=self.path,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
@@ -11,7 +13,7 @@ class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
|
||||
allowed_hosts: Sequence[str] | None = None,
|
||||
www_redirect: bool = True,
|
||||
) -> None:
|
||||
if allowed_hosts is None:
|
||||
@@ -39,9 +41,7 @@ class TrustedHostMiddleware:
|
||||
is_valid_host = False
|
||||
found_www_redirect = False
|
||||
for pattern in self.allowed_hosts:
|
||||
if host == pattern or (
|
||||
pattern.startswith("*") and host.endswith(pattern[1:])
|
||||
):
|
||||
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
|
||||
is_valid_host = True
|
||||
break
|
||||
elif "www." + host == pattern:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Callable, MutableMapping
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||||
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
@@ -15,14 +19,20 @@ warnings.warn(
|
||||
)
|
||||
|
||||
|
||||
def build_environ(scope: Scope, body: bytes) -> dict:
|
||||
def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
|
||||
"""
|
||||
Builds a scope and request body into a WSGI environ object.
|
||||
"""
|
||||
|
||||
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
|
||||
path_info = scope["path"].encode("utf8").decode("latin1")
|
||||
if path_info.startswith(script_name):
|
||||
path_info = path_info[len(script_name) :]
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
|
||||
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
|
||||
"SCRIPT_NAME": script_name,
|
||||
"PATH_INFO": path_info,
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
|
||||
"wsgi.version": (1, 0),
|
||||
@@ -62,7 +72,7 @@ def build_environ(scope: Scope, body: bytes) -> dict:
|
||||
|
||||
|
||||
class WSGIMiddleware:
|
||||
def __init__(self, app: typing.Callable) -> None:
|
||||
def __init__(self, app: Callable[..., Any]) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
@@ -72,16 +82,17 @@ class WSGIMiddleware:
|
||||
|
||||
|
||||
class WSGIResponder:
|
||||
def __init__(self, app: typing.Callable, scope: Scope) -> None:
|
||||
stream_send: ObjectSendStream[MutableMapping[str, Any]]
|
||||
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
|
||||
|
||||
def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
|
||||
self.app = app
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
|
||||
math.inf
|
||||
)
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
|
||||
self.response_started = False
|
||||
self.exc_info: typing.Any = None
|
||||
self.exc_info: Any = None
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
body = b""
|
||||
@@ -107,11 +118,11 @@ class WSGIResponder:
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: typing.List[typing.Tuple[str, str]],
|
||||
exc_info: typing.Any = None,
|
||||
response_headers: list[tuple[str, str]],
|
||||
exc_info: Any = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
if not self.response_started: # pragma: no branch
|
||||
self.response_started = True
|
||||
status_code_string, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_string)
|
||||
@@ -128,13 +139,15 @@ class WSGIResponder:
|
||||
},
|
||||
)
|
||||
|
||||
def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
|
||||
def wsgi(
|
||||
self,
|
||||
environ: dict[str, Any],
|
||||
start_response: Callable[..., Any],
|
||||
) -> None:
|
||||
for chunk in self.app(environ, start_response):
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send,
|
||||
{"type": "http.response.body", "body": chunk, "more_body": True},
|
||||
)
|
||||
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send, {"type": "http.response.body", "body": b""}
|
||||
)
|
||||
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from collections.abc import AsyncGenerator, Iterator, Mapping
|
||||
from http import cookies as http_cookies
|
||||
from typing import TYPE_CHECKING, Any, NoReturn, cast
|
||||
|
||||
import anyio
|
||||
|
||||
@@ -10,14 +13,19 @@ from starlette.exceptions import HTTPException
|
||||
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
try:
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
parse_options_header = None
|
||||
if TYPE_CHECKING:
|
||||
from python_multipart.multipart import parse_options_header
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Router
|
||||
else:
|
||||
try:
|
||||
try:
|
||||
from python_multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
parse_options_header = None
|
||||
|
||||
|
||||
SERVER_PUSH_HEADERS_TO_COPY = {
|
||||
@@ -29,7 +37,7 @@ SERVER_PUSH_HEADERS_TO_COPY = {
|
||||
}
|
||||
|
||||
|
||||
def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
|
||||
def cookie_parser(cookie_string: str) -> dict[str, str]:
|
||||
"""
|
||||
This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
|
||||
|
||||
@@ -41,7 +49,7 @@ def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
|
||||
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
|
||||
on an outdated spec and will fail on lots of input we want to support
|
||||
"""
|
||||
cookie_dict: typing.Dict[str, str] = {}
|
||||
cookie_dict: dict[str, str] = {}
|
||||
for chunk in cookie_string.split(";"):
|
||||
if "=" in chunk:
|
||||
key, val = chunk.split("=", 1)
|
||||
@@ -60,20 +68,20 @@ class ClientDisconnect(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
class HTTPConnection(Mapping[str, Any]):
|
||||
"""
|
||||
A base class for incoming HTTP connections, that is used to provide
|
||||
any functionality that is common to both `Request` and `WebSocket`.
|
||||
"""
|
||||
|
||||
def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None:
|
||||
def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
|
||||
assert scope["type"] in ("http", "websocket")
|
||||
self.scope = scope
|
||||
|
||||
def __getitem__(self, key: str) -> typing.Any:
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.scope[key]
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self.scope)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -86,12 +94,12 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
__hash__ = object.__hash__
|
||||
|
||||
@property
|
||||
def app(self) -> typing.Any:
|
||||
def app(self) -> Any:
|
||||
return self.scope["app"]
|
||||
|
||||
@property
|
||||
def url(self) -> URL:
|
||||
if not hasattr(self, "_url"):
|
||||
if not hasattr(self, "_url"): # pragma: no branch
|
||||
self._url = URL(scope=self.scope)
|
||||
return self._url
|
||||
|
||||
@@ -99,11 +107,16 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
def base_url(self) -> URL:
|
||||
if not hasattr(self, "_base_url"):
|
||||
base_url_scope = dict(self.scope)
|
||||
base_url_scope["path"] = "/"
|
||||
# This is used by request.url_for, it might be used inside a Mount which
|
||||
# would have its own child scope with its own root_path, but the base URL
|
||||
# for url_for should still be the top level app root path.
|
||||
app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
|
||||
path = app_root_path
|
||||
if not path.endswith("/"):
|
||||
path += "/"
|
||||
base_url_scope["path"] = path
|
||||
base_url_scope["query_string"] = b""
|
||||
base_url_scope["root_path"] = base_url_scope.get(
|
||||
"app_root_path", base_url_scope.get("root_path", "")
|
||||
)
|
||||
base_url_scope["root_path"] = app_root_path
|
||||
self._base_url = URL(scope=base_url_scope)
|
||||
return self._base_url
|
||||
|
||||
@@ -115,52 +128,47 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
|
||||
@property
|
||||
def query_params(self) -> QueryParams:
|
||||
if not hasattr(self, "_query_params"):
|
||||
if not hasattr(self, "_query_params"): # pragma: no branch
|
||||
self._query_params = QueryParams(self.scope["query_string"])
|
||||
return self._query_params
|
||||
|
||||
@property
|
||||
def path_params(self) -> typing.Dict[str, typing.Any]:
|
||||
def path_params(self) -> dict[str, Any]:
|
||||
return self.scope.get("path_params", {})
|
||||
|
||||
@property
|
||||
def cookies(self) -> typing.Dict[str, str]:
|
||||
def cookies(self) -> dict[str, str]:
|
||||
if not hasattr(self, "_cookies"):
|
||||
cookies: typing.Dict[str, str] = {}
|
||||
cookie_header = self.headers.get("cookie")
|
||||
cookies: dict[str, str] = {}
|
||||
cookie_headers = self.headers.getlist("cookie")
|
||||
|
||||
for header in cookie_headers:
|
||||
cookies.update(cookie_parser(header))
|
||||
|
||||
if cookie_header:
|
||||
cookies = cookie_parser(cookie_header)
|
||||
self._cookies = cookies
|
||||
return self._cookies
|
||||
|
||||
@property
|
||||
def client(self) -> typing.Optional[Address]:
|
||||
# client is a 2 item tuple of (host, port), None or missing
|
||||
def client(self) -> Address | None:
|
||||
# client is a 2 item tuple of (host, port), None if missing
|
||||
host_port = self.scope.get("client")
|
||||
if host_port is not None:
|
||||
return Address(*host_port)
|
||||
return None
|
||||
|
||||
@property
|
||||
def session(self) -> typing.Dict[str, typing.Any]:
|
||||
assert (
|
||||
"session" in self.scope
|
||||
), "SessionMiddleware must be installed to access request.session"
|
||||
return self.scope["session"]
|
||||
def session(self) -> dict[str, Any]:
|
||||
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
|
||||
return self.scope["session"] # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def auth(self) -> typing.Any:
|
||||
assert (
|
||||
"auth" in self.scope
|
||||
), "AuthenticationMiddleware must be installed to access request.auth"
|
||||
def auth(self) -> Any:
|
||||
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
|
||||
return self.scope["auth"]
|
||||
|
||||
@property
|
||||
def user(self) -> typing.Any:
|
||||
assert (
|
||||
"user" in self.scope
|
||||
), "AuthenticationMiddleware must be installed to access request.user"
|
||||
def user(self) -> Any:
|
||||
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
|
||||
return self.scope["user"]
|
||||
|
||||
@property
|
||||
@@ -173,26 +181,26 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
self._state = State(self.scope["state"])
|
||||
return self._state
|
||||
|
||||
def url_for(self, __name: str, **path_params: typing.Any) -> URL:
|
||||
router: Router = self.scope["router"]
|
||||
url_path = router.url_path_for(__name, **path_params)
|
||||
def url_for(self, name: str, /, **path_params: Any) -> URL:
|
||||
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
|
||||
if url_path_provider is None:
|
||||
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
|
||||
url_path = url_path_provider.url_path_for(name, **path_params)
|
||||
return url_path.make_absolute_url(base_url=self.base_url)
|
||||
|
||||
|
||||
async def empty_receive() -> typing.NoReturn:
|
||||
async def empty_receive() -> NoReturn:
|
||||
raise RuntimeError("Receive channel has not been made available")
|
||||
|
||||
|
||||
async def empty_send(message: Message) -> typing.NoReturn:
|
||||
async def empty_send(message: Message) -> NoReturn:
|
||||
raise RuntimeError("Send channel has not been made available")
|
||||
|
||||
|
||||
class Request(HTTPConnection):
|
||||
_form: typing.Optional[FormData]
|
||||
_form: FormData | None
|
||||
|
||||
def __init__(
|
||||
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
|
||||
):
|
||||
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
|
||||
super().__init__(scope)
|
||||
assert scope["type"] == "http"
|
||||
self._receive = receive
|
||||
@@ -203,43 +211,42 @@ class Request(HTTPConnection):
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
return self.scope["method"]
|
||||
return cast(str, self.scope["method"])
|
||||
|
||||
@property
|
||||
def receive(self) -> Receive:
|
||||
return self._receive
|
||||
|
||||
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
async def stream(self) -> AsyncGenerator[bytes, None]:
|
||||
if hasattr(self, "_body"):
|
||||
yield self._body
|
||||
yield b""
|
||||
return
|
||||
if self._stream_consumed:
|
||||
raise RuntimeError("Stream consumed")
|
||||
self._stream_consumed = True
|
||||
while True:
|
||||
while not self._stream_consumed:
|
||||
message = await self._receive()
|
||||
if message["type"] == "http.request":
|
||||
body = message.get("body", b"")
|
||||
if not message.get("more_body", False):
|
||||
self._stream_consumed = True
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
elif message["type"] == "http.disconnect":
|
||||
elif message["type"] == "http.disconnect": # pragma: no branch
|
||||
self._is_disconnected = True
|
||||
raise ClientDisconnect()
|
||||
yield b""
|
||||
|
||||
async def body(self) -> bytes:
|
||||
if not hasattr(self, "_body"):
|
||||
chunks: "typing.List[bytes]" = []
|
||||
chunks: list[bytes] = []
|
||||
async for chunk in self.stream():
|
||||
chunks.append(chunk)
|
||||
self._body = b"".join(chunks)
|
||||
return self._body
|
||||
|
||||
async def json(self) -> typing.Any:
|
||||
if not hasattr(self, "_json"):
|
||||
async def json(self) -> Any:
|
||||
if not hasattr(self, "_json"): # pragma: no branch
|
||||
body = await self.body()
|
||||
self._json = json.loads(body)
|
||||
return self._json
|
||||
@@ -247,13 +254,14 @@ class Request(HTTPConnection):
|
||||
async def _get_form(
|
||||
self,
|
||||
*,
|
||||
max_files: typing.Union[int, float] = 1000,
|
||||
max_fields: typing.Union[int, float] = 1000,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024,
|
||||
) -> FormData:
|
||||
if self._form is None:
|
||||
assert (
|
||||
parse_options_header is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
if self._form is None: # pragma: no branch
|
||||
assert parse_options_header is not None, (
|
||||
"The `python-multipart` library must be installed to use form parsing."
|
||||
)
|
||||
content_type_header = self.headers.get("Content-Type")
|
||||
content_type: bytes
|
||||
content_type, _ = parse_options_header(content_type_header)
|
||||
@@ -264,6 +272,7 @@ class Request(HTTPConnection):
|
||||
self.stream(),
|
||||
max_files=max_files,
|
||||
max_fields=max_fields,
|
||||
max_part_size=max_part_size,
|
||||
)
|
||||
self._form = await multipart_parser.parse()
|
||||
except MultiPartException as exc:
|
||||
@@ -280,15 +289,16 @@ class Request(HTTPConnection):
|
||||
def form(
|
||||
self,
|
||||
*,
|
||||
max_files: typing.Union[int, float] = 1000,
|
||||
max_fields: typing.Union[int, float] = 1000,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024,
|
||||
) -> AwaitableOrContextManager[FormData]:
|
||||
return AwaitableOrContextManagerWrapper(
|
||||
self._get_form(max_files=max_files, max_fields=max_fields)
|
||||
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._form is not None:
|
||||
if self._form is not None: # pragma: no branch
|
||||
await self._form.close()
|
||||
|
||||
async def is_disconnected(self) -> bool:
|
||||
@@ -307,12 +317,8 @@ class Request(HTTPConnection):
|
||||
|
||||
async def send_push_promise(self, path: str) -> None:
|
||||
if "http.response.push" in self.scope.get("extensions", {}):
|
||||
raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = []
|
||||
raw_headers: list[tuple[bytes, bytes]] = []
|
||||
for name in SERVER_PUSH_HEADERS_TO_COPY:
|
||||
for value in self.headers.getlist(name):
|
||||
raw_headers.append(
|
||||
(name.encode("latin-1"), value.encode("latin-1"))
|
||||
)
|
||||
await self._send(
|
||||
{"type": "http.response.push", "path": path, "headers": raw_headers}
|
||||
)
|
||||
raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
|
||||
await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})
|
||||
|
||||
@@ -1,40 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import http.cookies
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from email.utils import format_datetime, formatdate
|
||||
from functools import partial
|
||||
from mimetypes import guess_type as mimetypes_guess_type
|
||||
from mimetypes import guess_type
|
||||
from secrets import token_hex
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import quote
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
|
||||
from starlette._compat import md5_hexdigest
|
||||
from starlette._utils import collapse_excgroups
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, MutableHeaders
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders
|
||||
from starlette.requests import ClientDisconnect
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
from typing import Literal
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import Literal
|
||||
|
||||
# Workaround for adding samesite support to pre 3.8 python
|
||||
http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on <py3.8
|
||||
def guess_type(
|
||||
url: typing.Union[str, "os.PathLike[str]"], strict: bool = True
|
||||
) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
|
||||
if sys.version_info < (3, 8): # pragma: no cover
|
||||
url = os.fspath(url)
|
||||
return mimetypes_guess_type(url, strict)
|
||||
|
||||
|
||||
class Response:
|
||||
media_type = None
|
||||
@@ -42,11 +33,11 @@ class Response:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: typing.Any = None,
|
||||
content: Any = None,
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
if media_type is not None:
|
||||
@@ -55,25 +46,20 @@ class Response:
|
||||
self.body = self.render(content)
|
||||
self.init_headers(headers)
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
def render(self, content: Any) -> bytes | memoryview:
|
||||
if content is None:
|
||||
return b""
|
||||
if isinstance(content, bytes):
|
||||
if isinstance(content, bytes | memoryview):
|
||||
return content
|
||||
return content.encode(self.charset)
|
||||
return content.encode(self.charset) # type: ignore
|
||||
|
||||
def init_headers(
|
||||
self, headers: typing.Optional[typing.Mapping[str, str]] = None
|
||||
) -> None:
|
||||
def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
|
||||
if headers is None:
|
||||
raw_headers: typing.List[typing.Tuple[bytes, bytes]] = []
|
||||
raw_headers: list[tuple[bytes, bytes]] = []
|
||||
populate_content_length = True
|
||||
populate_content_type = True
|
||||
else:
|
||||
raw_headers = [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
|
||||
keys = [h[0] for h in raw_headers]
|
||||
populate_content_length = b"content-length" not in keys
|
||||
populate_content_type = b"content-type" not in keys
|
||||
@@ -89,7 +75,7 @@ class Response:
|
||||
|
||||
content_type = self.media_type
|
||||
if content_type is not None and populate_content_type:
|
||||
if content_type.startswith("text/"):
|
||||
if content_type.startswith("text/") and "charset=" not in content_type.lower():
|
||||
content_type += "; charset=" + self.charset
|
||||
raw_headers.append((b"content-type", content_type.encode("latin-1")))
|
||||
|
||||
@@ -105,15 +91,16 @@ class Response:
|
||||
self,
|
||||
key: str,
|
||||
value: str = "",
|
||||
max_age: typing.Optional[int] = None,
|
||||
expires: typing.Optional[typing.Union[datetime, str, int]] = None,
|
||||
path: str = "/",
|
||||
domain: typing.Optional[str] = None,
|
||||
max_age: int | None = None,
|
||||
expires: datetime | str | int | None = None,
|
||||
path: str | None = "/",
|
||||
domain: str | None = None,
|
||||
secure: bool = False,
|
||||
httponly: bool = False,
|
||||
samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
|
||||
samesite: Literal["lax", "strict", "none"] | None = "lax",
|
||||
partitioned: bool = False,
|
||||
) -> None:
|
||||
cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie()
|
||||
cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
|
||||
cookie[key] = value
|
||||
if max_age is not None:
|
||||
cookie[key]["max-age"] = max_age
|
||||
@@ -137,6 +124,11 @@ class Response:
|
||||
"none",
|
||||
], "samesite must be either 'strict', 'lax' or 'none'"
|
||||
cookie[key]["samesite"] = samesite
|
||||
if partitioned:
|
||||
if sys.version_info < (3, 14):
|
||||
raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.") # pragma: no cover
|
||||
cookie[key]["partitioned"] = True # pragma: no cover
|
||||
|
||||
cookie_val = cookie.output(header="").strip()
|
||||
self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
|
||||
|
||||
@@ -144,10 +136,10 @@ class Response:
|
||||
self,
|
||||
key: str,
|
||||
path: str = "/",
|
||||
domain: typing.Optional[str] = None,
|
||||
domain: str | None = None,
|
||||
secure: bool = False,
|
||||
httponly: bool = False,
|
||||
samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
|
||||
samesite: Literal["lax", "strict", "none"] | None = "lax",
|
||||
) -> None:
|
||||
self.set_cookie(
|
||||
key,
|
||||
@@ -161,14 +153,15 @@ class Response:
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
prefix = "websocket." if scope["type"] == "websocket" else ""
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"type": prefix + "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": self.body})
|
||||
await send({"type": prefix + "http.response.body", "body": self.body})
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
@@ -187,15 +180,15 @@ class JSONResponse(Response):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: typing.Any,
|
||||
content: Any,
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Dict[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
super().__init__(content, status_code, headers, media_type, background)
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
def render(self, content: Any) -> bytes:
|
||||
return json.dumps(
|
||||
content,
|
||||
ensure_ascii=False,
|
||||
@@ -208,21 +201,19 @@ class JSONResponse(Response):
|
||||
class RedirectResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
url: typing.Union[str, URL],
|
||||
url: str | URL,
|
||||
status_code: int = 307,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
content=b"", status_code=status_code, headers=headers, background=background
|
||||
)
|
||||
super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
|
||||
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
|
||||
|
||||
|
||||
Content = typing.Union[str, bytes]
|
||||
SyncContentStream = typing.Iterator[Content]
|
||||
AsyncContentStream = typing.AsyncIterable[Content]
|
||||
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
|
||||
Content = str | bytes | memoryview
|
||||
SyncContentStream = Iterable[Content]
|
||||
AsyncContentStream = AsyncIterable[Content]
|
||||
ContentStream = AsyncContentStream | SyncContentStream
|
||||
|
||||
|
||||
class StreamingResponse(Response):
|
||||
@@ -232,11 +223,11 @@ class StreamingResponse(Response):
|
||||
self,
|
||||
content: ContentStream,
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
if isinstance(content, typing.AsyncIterable):
|
||||
if isinstance(content, AsyncIterable):
|
||||
self.body_iterator = content
|
||||
else:
|
||||
self.body_iterator = iterate_in_threadpool(content)
|
||||
@@ -260,60 +251,80 @@ class StreamingResponse(Response):
|
||||
}
|
||||
)
|
||||
async for chunk in self.body_iterator:
|
||||
if not isinstance(chunk, bytes):
|
||||
if not isinstance(chunk, bytes | memoryview):
|
||||
chunk = chunk.encode(self.charset)
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
async with anyio.create_task_group() as task_group:
|
||||
spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
|
||||
|
||||
async def wrap(func: "typing.Callable[[], typing.Awaitable[None]]") -> None:
|
||||
await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
if spec_version >= (2, 4):
|
||||
try:
|
||||
await self.stream_response(send)
|
||||
except OSError:
|
||||
raise ClientDisconnect()
|
||||
else:
|
||||
with collapse_excgroups():
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
task_group.start_soon(wrap, partial(self.stream_response, send))
|
||||
await wrap(partial(self.listen_for_disconnect, receive))
|
||||
async def wrap(func: Callable[[], Awaitable[None]]) -> None:
|
||||
await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
|
||||
task_group.start_soon(wrap, partial(self.stream_response, send))
|
||||
await wrap(partial(self.listen_for_disconnect, receive))
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
|
||||
class MalformedRangeHeader(Exception):
|
||||
def __init__(self, content: str = "Malformed range header.") -> None:
|
||||
self.content = content
|
||||
|
||||
|
||||
class RangeNotSatisfiable(Exception):
|
||||
def __init__(self, max_size: int) -> None:
|
||||
self.max_size = max_size
|
||||
|
||||
|
||||
class FileResponse(Response):
|
||||
chunk_size = 64 * 1024
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: typing.Union[str, "os.PathLike[str]"],
|
||||
path: str | os.PathLike[str],
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
filename: typing.Optional[str] = None,
|
||||
stat_result: typing.Optional[os.stat_result] = None,
|
||||
method: typing.Optional[str] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
filename: str | None = None,
|
||||
stat_result: os.stat_result | None = None,
|
||||
method: str | None = None,
|
||||
content_disposition_type: str = "attachment",
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.status_code = status_code
|
||||
self.filename = filename
|
||||
self.send_header_only = method is not None and method.upper() == "HEAD"
|
||||
if method is not None:
|
||||
warnings.warn(
|
||||
"The 'method' parameter is not used, and it will be removed.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if media_type is None:
|
||||
media_type = guess_type(filename or path)[0] or "text/plain"
|
||||
self.media_type = media_type
|
||||
self.background = background
|
||||
self.init_headers(headers)
|
||||
self.headers.setdefault("accept-ranges", "bytes")
|
||||
if self.filename is not None:
|
||||
content_disposition_filename = quote(self.filename)
|
||||
if content_disposition_filename != self.filename:
|
||||
content_disposition = "{}; filename*=utf-8''{}".format(
|
||||
content_disposition_type, content_disposition_filename
|
||||
)
|
||||
content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
|
||||
else:
|
||||
content_disposition = '{}; filename="{}"'.format(
|
||||
content_disposition_type, self.filename
|
||||
)
|
||||
content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
|
||||
self.headers.setdefault("content-disposition", content_disposition)
|
||||
self.stat_result = stat_result
|
||||
if stat_result is not None:
|
||||
@@ -323,13 +334,16 @@ class FileResponse(Response):
|
||||
content_length = str(stat_result.st_size)
|
||||
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
|
||||
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
|
||||
etag = md5_hexdigest(etag_base.encode(), usedforsecurity=False)
|
||||
etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'
|
||||
|
||||
self.headers.setdefault("content-length", content_length)
|
||||
self.headers.setdefault("last-modified", last_modified)
|
||||
self.headers.setdefault("etag", etag)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
send_header_only: bool = scope["method"].upper() == "HEAD"
|
||||
send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})
|
||||
|
||||
if self.stat_result is None:
|
||||
try:
|
||||
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
|
||||
@@ -340,27 +354,213 @@ class FileResponse(Response):
|
||||
mode = stat_result.st_mode
|
||||
if not stat.S_ISREG(mode):
|
||||
raise RuntimeError(f"File at path {self.path} is not a file.")
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
if self.send_header_only:
|
||||
else:
|
||||
stat_result = self.stat_result
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
http_range = headers.get("range")
|
||||
http_if_range = headers.get("if-range")
|
||||
|
||||
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
|
||||
await self._handle_simple(send, send_header_only, send_pathsend)
|
||||
else:
|
||||
try:
|
||||
ranges = self._parse_range_header(http_range, stat_result.st_size)
|
||||
except MalformedRangeHeader as exc:
|
||||
return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
|
||||
except RangeNotSatisfiable as exc:
|
||||
response = PlainTextResponse(status_code=416, headers={"Content-Range": f"*/{exc.max_size}"})
|
||||
return await response(scope, receive, send)
|
||||
|
||||
if len(ranges) == 1:
|
||||
start, end = ranges[0]
|
||||
await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
|
||||
else:
|
||||
await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
|
||||
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
elif send_pathsend:
|
||||
await send({"type": "http.response.pathsend", "path": str(self.path)})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await file.read(self.chunk_size)
|
||||
more_body = len(chunk) == self.chunk_size
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": chunk,
|
||||
"more_body": more_body,
|
||||
}
|
||||
)
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
|
||||
|
||||
async def _handle_single_range(
|
||||
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
|
||||
) -> None:
|
||||
self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
|
||||
self.headers["content-length"] = str(end - start)
|
||||
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
await file.seek(start)
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await file.read(min(self.chunk_size, end - start))
|
||||
start += len(chunk)
|
||||
more_body = len(chunk) == self.chunk_size and start < end
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
|
||||
|
||||
async def _handle_multiple_ranges(
|
||||
self,
|
||||
send: Send,
|
||||
ranges: list[tuple[int, int]],
|
||||
file_size: int,
|
||||
send_header_only: bool,
|
||||
) -> None:
|
||||
# In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
|
||||
boundary = token_hex(13)
|
||||
content_length, header_generator = self.generate_multipart(
|
||||
ranges, boundary, file_size, self.headers["content-type"]
|
||||
)
|
||||
self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}"
|
||||
self.headers["content-length"] = str(content_length)
|
||||
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
for start, end in ranges:
|
||||
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
|
||||
await file.seek(start)
|
||||
while start < end:
|
||||
chunk = await file.read(min(self.chunk_size, end - start))
|
||||
start += len(chunk)
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
await send({"type": "http.response.body", "body": b"\n", "more_body": True})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": f"\n--{boundary}--\n".encode("latin-1"),
|
||||
"more_body": False,
|
||||
}
|
||||
)
|
||||
|
||||
def _should_use_range(self, http_if_range: str) -> bool:
|
||||
return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]
|
||||
|
||||
@classmethod
|
||||
def _parse_range_header(cls, http_range: str, file_size: int) -> list[tuple[int, int]]:
|
||||
ranges: list[tuple[int, int]] = []
|
||||
try:
|
||||
units, range_ = http_range.split("=", 1)
|
||||
except ValueError:
|
||||
raise MalformedRangeHeader()
|
||||
|
||||
units = units.strip().lower()
|
||||
|
||||
if units != "bytes":
|
||||
raise MalformedRangeHeader("Only support bytes range")
|
||||
|
||||
ranges = cls._parse_ranges(range_, file_size)
|
||||
|
||||
if len(ranges) == 0:
|
||||
raise MalformedRangeHeader("Range header: range must be requested")
|
||||
|
||||
if any(not (0 <= start < file_size) for start, _ in ranges):
|
||||
raise RangeNotSatisfiable(file_size)
|
||||
|
||||
if any(start > end for start, end in ranges):
|
||||
raise MalformedRangeHeader("Range header: start must be less than end")
|
||||
|
||||
if len(ranges) == 1:
|
||||
return ranges
|
||||
|
||||
# Merge ranges
|
||||
result: list[tuple[int, int]] = []
|
||||
for start, end in ranges:
|
||||
for p in range(len(result)):
|
||||
p_start, p_end = result[p]
|
||||
if start > p_end:
|
||||
continue
|
||||
elif end < p_start:
|
||||
result.insert(p, (start, end)) # THIS IS NOT REACHED!
|
||||
break
|
||||
else:
|
||||
result[p] = (min(start, p_start), max(end, p_end))
|
||||
break
|
||||
else:
|
||||
result.append((start, end))
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _parse_ranges(cls, range_: str, file_size: int) -> list[tuple[int, int]]:
|
||||
ranges: list[tuple[int, int]] = []
|
||||
|
||||
for part in range_.split(","):
|
||||
part = part.strip()
|
||||
|
||||
# If the range is empty or a single dash, we ignore it.
|
||||
if not part or part == "-":
|
||||
continue
|
||||
|
||||
# If the range is not in the format "start-end", we ignore it.
|
||||
if "-" not in part:
|
||||
continue
|
||||
|
||||
start_str, end_str = part.split("-", 1)
|
||||
start_str = start_str.strip()
|
||||
end_str = end_str.strip()
|
||||
|
||||
try:
|
||||
start = int(start_str) if start_str else file_size - int(end_str)
|
||||
end = int(end_str) + 1 if start_str and end_str and int(end_str) < file_size else file_size
|
||||
ranges.append((start, end))
|
||||
except ValueError:
|
||||
# If the range is not numeric, we ignore it.
|
||||
continue
|
||||
|
||||
return ranges
|
||||
|
||||
def generate_multipart(
|
||||
self,
|
||||
ranges: Sequence[tuple[int, int]],
|
||||
boundary: str,
|
||||
max_size: int,
|
||||
content_type: str,
|
||||
) -> tuple[int, Callable[[int, int], bytes]]:
|
||||
r"""
|
||||
Multipart response headers generator.
|
||||
|
||||
```
|
||||
--{boundary}\n
|
||||
Content-Type: {content_type}\n
|
||||
Content-Range: bytes {start}-{end-1}/{max_size}\n
|
||||
\n
|
||||
..........content...........\n
|
||||
--{boundary}\n
|
||||
Content-Type: {content_type}\n
|
||||
Content-Range: bytes {start}-{end-1}/{max_size}\n
|
||||
\n
|
||||
..........content...........\n
|
||||
--{boundary}--\n
|
||||
```
|
||||
"""
|
||||
boundary_len = len(boundary)
|
||||
static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
|
||||
content_length = sum(
|
||||
(len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers
|
||||
+ (end - start) # Content
|
||||
for start, end in ranges
|
||||
) + (
|
||||
5 + boundary_len # --boundary--\n
|
||||
)
|
||||
return (
|
||||
content_length,
|
||||
lambda start, end: (
|
||||
f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n"
|
||||
).encode("latin-1"),
|
||||
)
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
import traceback
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import Awaitable, Callable, Collection, Generator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
|
||||
from enum import Enum
|
||||
from re import Pattern
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette._exception_handler import wrap_app_handling_exceptions
|
||||
from starlette._utils import get_route_path, is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.convertors import CONVERTOR_TYPES, Convertor
|
||||
from starlette.datastructures import URL, Headers, URLPath
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket, WebSocketClose
|
||||
|
||||
@@ -27,7 +32,7 @@ class NoMatchFound(Exception):
|
||||
if no matching route exists.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, path_params: typing.Dict[str, typing.Any]) -> None:
|
||||
def __init__(self, name: str, path_params: dict[str, Any]) -> None:
|
||||
params = ", ".join(list(path_params.keys()))
|
||||
super().__init__(f'No route exists for name "{name}" and params "{params}".')
|
||||
|
||||
@@ -38,14 +43,13 @@ class Match(Enum):
|
||||
FULL = 2
|
||||
|
||||
|
||||
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
|
||||
def iscoroutinefunction_or_partial(obj: Any) -> bool: # pragma: no cover
|
||||
"""
|
||||
Correctly determines if an object is a coroutine function,
|
||||
including those wrapped in functools.partial objects.
|
||||
"""
|
||||
warnings.warn(
|
||||
"iscoroutinefunction_or_partial is deprecated, "
|
||||
"and will be removed in a future release.",
|
||||
"iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
while isinstance(obj, functools.partial):
|
||||
@@ -53,25 +57,32 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
|
||||
return inspect.iscoroutinefunction(obj)
|
||||
|
||||
|
||||
def request_response(func: typing.Callable) -> ASGIApp:
|
||||
def request_response(
|
||||
func: Callable[[Request], Awaitable[Response] | Response],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a function or coroutine `func(request) -> response`,
|
||||
and returns an ASGI application.
|
||||
"""
|
||||
is_coroutine = is_async_callable(func)
|
||||
f: Callable[[Request], Awaitable[Response]] = (
|
||||
func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
|
||||
)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = Request(scope, receive=receive, send=send)
|
||||
if is_coroutine:
|
||||
response = await func(request)
|
||||
else:
|
||||
response = await run_in_threadpool(func, request)
|
||||
await response(scope, receive, send)
|
||||
request = Request(scope, receive, send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
response = await f(request)
|
||||
await response(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def websocket_session(func: typing.Callable) -> ASGIApp:
|
||||
def websocket_session(
|
||||
func: Callable[[WebSocket], Awaitable[None]],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a coroutine `func(session)`, and returns an ASGI application.
|
||||
"""
|
||||
@@ -79,22 +90,24 @@ def websocket_session(func: typing.Callable) -> ASGIApp:
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
session = WebSocket(scope, receive=receive, send=send)
|
||||
await func(session)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await func(session)
|
||||
|
||||
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_name(endpoint: typing.Callable) -> str:
|
||||
if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
|
||||
return endpoint.__name__
|
||||
return endpoint.__class__.__name__
|
||||
def get_name(endpoint: Callable[..., Any]) -> str:
|
||||
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
|
||||
|
||||
|
||||
def replace_params(
|
||||
path: str,
|
||||
param_convertors: typing.Dict[str, Convertor],
|
||||
path_params: typing.Dict[str, str],
|
||||
) -> typing.Tuple[str, dict]:
|
||||
param_convertors: dict[str, Convertor[Any]],
|
||||
path_params: dict[str, str],
|
||||
) -> tuple[str, dict[str, str]]:
|
||||
for key, value in list(path_params.items()):
|
||||
if "{" + key + "}" in path:
|
||||
convertor = param_convertors[key]
|
||||
@@ -110,7 +123,7 @@ PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
|
||||
|
||||
def compile_path(
|
||||
path: str,
|
||||
) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
|
||||
) -> tuple[Pattern[str], str, dict[str, Convertor[Any]]]:
|
||||
"""
|
||||
Given a path string, like: "/{username:str}",
|
||||
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
|
||||
@@ -131,9 +144,7 @@ def compile_path(
|
||||
for match in PARAM_REGEX.finditer(path):
|
||||
param_name, convertor_type = match.groups("str")
|
||||
convertor_type = convertor_type.lstrip(":")
|
||||
assert (
|
||||
convertor_type in CONVERTOR_TYPES
|
||||
), f"Unknown path convertor '{convertor_type}'"
|
||||
assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
|
||||
convertor = CONVERTOR_TYPES[convertor_type]
|
||||
|
||||
path_regex += re.escape(path[idx : match.start()])
|
||||
@@ -167,10 +178,10 @@ def compile_path(
|
||||
|
||||
|
||||
class BaseRoute:
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
@@ -187,7 +198,7 @@ class BaseRoute:
|
||||
if scope["type"] == "http":
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
elif scope["type"] == "websocket":
|
||||
elif scope["type"] == "websocket": # pragma: no branch
|
||||
websocket_close = WebSocketClose()
|
||||
await websocket_close(scope, receive, send)
|
||||
return
|
||||
@@ -200,11 +211,12 @@ class Route(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable,
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
methods: Collection[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path.startswith("/"), "Routed paths must start with '/'"
|
||||
self.path = path
|
||||
@@ -224,6 +236,10 @@ class Route(BaseRoute):
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
|
||||
if methods is None:
|
||||
self.methods = None
|
||||
else:
|
||||
@@ -233,9 +249,11 @@ class Route(BaseRoute):
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, Any]
|
||||
if scope["type"] == "http":
|
||||
match = self.path_regex.match(scope["path"])
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
@@ -249,16 +267,14 @@ class Route(BaseRoute):
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
seen_params = set(path_params.keys())
|
||||
expected_params = set(self.param_convertors.keys())
|
||||
|
||||
if __name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(__name, path_params)
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="http")
|
||||
|
||||
@@ -268,14 +284,12 @@ class Route(BaseRoute):
|
||||
if "app" in scope:
|
||||
raise HTTPException(status_code=405, headers=headers)
|
||||
else:
|
||||
response = PlainTextResponse(
|
||||
"Method Not Allowed", status_code=405, headers=headers
|
||||
)
|
||||
response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Route)
|
||||
and self.path == other.path
|
||||
@@ -292,7 +306,12 @@ class Route(BaseRoute):
|
||||
|
||||
class WebSocketRoute(BaseRoute):
|
||||
def __init__(
|
||||
self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
name: str | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path.startswith("/"), "Routed paths must start with '/'"
|
||||
self.path = path
|
||||
@@ -309,11 +328,17 @@ class WebSocketRoute(BaseRoute):
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, Any]
|
||||
if scope["type"] == "websocket":
|
||||
match = self.path_regex.match(scope["path"])
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
@@ -324,28 +349,22 @@ class WebSocketRoute(BaseRoute):
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
seen_params = set(path_params.keys())
|
||||
expected_params = set(self.param_convertors.keys())
|
||||
|
||||
if __name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(__name, path_params)
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="websocket")
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, WebSocketRoute)
|
||||
and self.path == other.path
|
||||
and self.endpoint == other.endpoint
|
||||
)
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
|
||||
@@ -355,16 +374,14 @@ class Mount(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
app: typing.Optional[ASGIApp] = None,
|
||||
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
app: ASGIApp | None = None,
|
||||
routes: Sequence[BaseRoute] | None = None,
|
||||
name: str | None = None,
|
||||
*,
|
||||
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
|
||||
assert (
|
||||
app is not None or routes is not None
|
||||
), "Either 'app=...', or 'routes=' must be specified"
|
||||
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
|
||||
self.path = path.rstrip("/")
|
||||
if app is not None:
|
||||
self._base_app: ASGIApp = app
|
||||
@@ -372,82 +389,80 @@ class Mount(BaseRoute):
|
||||
self._base_app = Router(routes=routes)
|
||||
self.app = self._base_app
|
||||
if middleware is not None:
|
||||
for cls, options in reversed(middleware):
|
||||
self.app = cls(app=self.app, **options)
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
self.name = name
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(
|
||||
self.path + "/{path:path}"
|
||||
)
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return getattr(self._base_app, "routes", [])
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
path = scope["path"]
|
||||
match = self.path_regex.match(path)
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, Any]
|
||||
if scope["type"] in ("http", "websocket"): # pragma: no branch
|
||||
root_path = scope.get("root_path", "")
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
remaining_path = "/" + matched_params.pop("path")
|
||||
matched_path = path[: -len(remaining_path)]
|
||||
matched_path = route_path[: -len(remaining_path)]
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
root_path = scope.get("root_path", "")
|
||||
child_scope = {
|
||||
"path_params": path_params,
|
||||
# app_root_path will only be set at the top level scope,
|
||||
# initialized with the (optional) value of a root_path
|
||||
# set above/before Starlette. And even though any
|
||||
# mount will have its own child scope with its own respective
|
||||
# root_path, the app_root_path will always be available in all
|
||||
# the child scopes with the same top level value because it's
|
||||
# set only once here with a default, any other child scope will
|
||||
# just inherit that app_root_path default value stored in the
|
||||
# scope. All this is needed to support Request.url_for(), as it
|
||||
# uses the app_root_path to build the URL path.
|
||||
"app_root_path": scope.get("app_root_path", root_path),
|
||||
"root_path": root_path + matched_path,
|
||||
"path": remaining_path,
|
||||
"endpoint": self.app,
|
||||
}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
if self.name is not None and __name == self.name and "path" in path_params:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path_params["path"] = path_params["path"].lstrip("/")
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path)
|
||||
elif self.name is None or __name.startswith(self.name + ":"):
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
if self.name is None:
|
||||
# No mount name.
|
||||
remaining_name = __name
|
||||
remaining_name = name
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = __name[len(self.name) + 1 :]
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
path_kwarg = path_params.get("path")
|
||||
path_params["path"] = ""
|
||||
path_prefix, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
if path_kwarg is not None:
|
||||
remaining_params["path"] = path_kwarg
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(
|
||||
path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
|
||||
)
|
||||
return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(__name, path_params)
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Mount)
|
||||
and self.path == other.path
|
||||
and self.app == other.app
|
||||
)
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
@@ -456,9 +471,7 @@ class Mount(BaseRoute):
|
||||
|
||||
|
||||
class Host(BaseRoute):
|
||||
def __init__(
|
||||
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
|
||||
) -> None:
|
||||
def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
|
||||
assert not host.startswith("/"), "Host must not start with '/'"
|
||||
self.host = host
|
||||
self.app = app
|
||||
@@ -466,11 +479,11 @@ class Host(BaseRoute):
|
||||
self.host_regex, self.host_format, self.param_convertors = compile_path(host)
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return getattr(self.app, "routes", [])
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"): # pragma:no branch
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
match = self.host_regex.match(host)
|
||||
@@ -484,42 +497,34 @@ class Host(BaseRoute):
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
if self.name is not None and __name == self.name and "path" in path_params:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path = path_params.pop("path")
|
||||
host, remaining_params = replace_params(
|
||||
self.host_format, self.param_convertors, path_params
|
||||
)
|
||||
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path, host=host)
|
||||
elif self.name is None or __name.startswith(self.name + ":"):
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
if self.name is None:
|
||||
# No mount name.
|
||||
remaining_name = __name
|
||||
remaining_name = name
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = __name[len(self.name) + 1 :]
|
||||
host, remaining_params = replace_params(
|
||||
self.host_format, self.param_convertors, path_params
|
||||
)
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(path=str(url), protocol=url.protocol, host=host)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(__name, path_params)
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Host)
|
||||
and self.host == other.host
|
||||
and self.app == other.app
|
||||
)
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Host) and self.host == other.host and self.app == other.app
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
@@ -527,11 +532,11 @@ class Host(BaseRoute):
|
||||
return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
|
||||
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
|
||||
def __init__(self, cm: typing.ContextManager[_T]):
|
||||
class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
|
||||
def __init__(self, cm: AbstractContextManager[_T]):
|
||||
self._cm = cm
|
||||
|
||||
async def __aenter__(self) -> _T:
|
||||
@@ -539,27 +544,27 @@ class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]],
|
||||
exc_value: typing.Optional[BaseException],
|
||||
traceback: typing.Optional[types.TracebackType],
|
||||
) -> typing.Optional[bool]:
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
return self._cm.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
|
||||
def _wrap_gen_lifespan_context(
|
||||
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
|
||||
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
|
||||
lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
|
||||
) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
|
||||
cmgr = contextlib.contextmanager(lifespan_context)
|
||||
|
||||
@functools.wraps(cmgr)
|
||||
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
|
||||
def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
|
||||
return _AsyncLiftContextManager(cmgr(app))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class _DefaultLifespan:
|
||||
def __init__(self, router: "Router"):
|
||||
def __init__(self, router: Router):
|
||||
self._router = router
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
@@ -575,14 +580,16 @@ class _DefaultLifespan:
|
||||
class Router:
|
||||
def __init__(
|
||||
self,
|
||||
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
|
||||
routes: Sequence[BaseRoute] | None = None,
|
||||
redirect_slashes: bool = True,
|
||||
default: typing.Optional[ASGIApp] = None,
|
||||
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
|
||||
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
|
||||
default: ASGIApp | None = None,
|
||||
on_startup: Sequence[Callable[[], Any]] | None = None,
|
||||
on_shutdown: Sequence[Callable[[], Any]] | None = None,
|
||||
# the generic to Lifespan[AppType] is the type of the top level application
|
||||
# which the router cannot know statically, so we use typing.Any
|
||||
lifespan: typing.Optional[Lifespan[typing.Any]] = None,
|
||||
# which the router cannot know statically, so we use Any
|
||||
lifespan: Lifespan[Any] | None = None,
|
||||
*,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
self.routes = [] if routes is None else list(routes)
|
||||
self.redirect_slashes = redirect_slashes
|
||||
@@ -594,12 +601,18 @@ class Router:
|
||||
warnings.warn(
|
||||
"The on_startup and on_shutdown parameters are deprecated, and they "
|
||||
"will be removed on version 1.0. Use the lifespan parameter instead. "
|
||||
"See more about it on https://www.starlette.io/lifespan/.",
|
||||
"See more about it on https://starlette.dev/lifespan/.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if lifespan:
|
||||
warnings.warn(
|
||||
"The `lifespan` parameter cannot be used with `on_startup` or "
|
||||
"`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
|
||||
"ignored."
|
||||
)
|
||||
|
||||
if lifespan is None:
|
||||
self.lifespan_context: Lifespan = _DefaultLifespan(self)
|
||||
self.lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
|
||||
|
||||
elif inspect.isasyncgenfunction(lifespan):
|
||||
warnings.warn(
|
||||
@@ -608,20 +621,24 @@ class Router:
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.lifespan_context = asynccontextmanager(
|
||||
lifespan, # type: ignore[arg-type]
|
||||
lifespan,
|
||||
)
|
||||
elif inspect.isgeneratorfunction(lifespan):
|
||||
warnings.warn(
|
||||
"generator function lifespans are deprecated, "
|
||||
"use an @contextlib.asynccontextmanager function instead",
|
||||
"generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.lifespan_context = _wrap_gen_lifespan_context(
|
||||
lifespan, # type: ignore[arg-type]
|
||||
lifespan,
|
||||
)
|
||||
else:
|
||||
self.lifespan_context = lifespan
|
||||
|
||||
self.middleware_stack = self.app
|
||||
if middleware:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
|
||||
|
||||
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "websocket":
|
||||
websocket_close = WebSocketClose()
|
||||
@@ -637,13 +654,13 @@ class Router:
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
for route in self.routes:
|
||||
try:
|
||||
return route.url_path_for(__name, **path_params)
|
||||
return route.url_path_for(name, **path_params)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(__name, path_params)
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def startup(self) -> None:
|
||||
"""
|
||||
@@ -671,15 +688,13 @@ class Router:
|
||||
startup and shutdown events.
|
||||
"""
|
||||
started = False
|
||||
app: typing.Any = scope.get("app")
|
||||
app: Any = scope.get("app")
|
||||
await receive()
|
||||
try:
|
||||
async with self.lifespan_context(app) as maybe_state:
|
||||
if maybe_state is not None:
|
||||
if "state" not in scope:
|
||||
raise RuntimeError(
|
||||
'The server does not support "state" in the lifespan scope.'
|
||||
)
|
||||
raise RuntimeError('The server does not support "state" in the lifespan scope.')
|
||||
scope["state"].update(maybe_state)
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
started = True
|
||||
@@ -698,6 +713,9 @@ class Router:
|
||||
"""
|
||||
The main entry point to the Router class.
|
||||
"""
|
||||
await self.middleware_stack(scope, receive, send)
|
||||
|
||||
async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] in ("http", "websocket", "lifespan")
|
||||
|
||||
if "router" not in scope:
|
||||
@@ -729,9 +747,10 @@ class Router:
|
||||
await partial.handle(scope, receive, send)
|
||||
return
|
||||
|
||||
if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
|
||||
route_path = get_route_path(scope)
|
||||
if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
|
||||
redirect_scope = dict(scope)
|
||||
if scope["path"].endswith("/"):
|
||||
if route_path.endswith("/"):
|
||||
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
|
||||
else:
|
||||
redirect_scope["path"] = redirect_scope["path"] + "/"
|
||||
@@ -746,29 +765,25 @@ class Router:
|
||||
|
||||
await self.default(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Router) and self.routes == other.routes
|
||||
|
||||
def mount(
|
||||
self, path: str, app: ASGIApp, name: typing.Optional[str] = None
|
||||
) -> None: # pragma: nocover
|
||||
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
|
||||
route = Mount(path, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def host(
|
||||
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
|
||||
) -> None: # pragma: no cover
|
||||
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
|
||||
route = Host(host, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable,
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
endpoint: Callable[[Request], Awaitable[Response] | Response],
|
||||
methods: Collection[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None: # pragma: nocover
|
||||
) -> None: # pragma: no cover
|
||||
route = Route(
|
||||
path,
|
||||
endpoint=endpoint,
|
||||
@@ -779,7 +794,10 @@ class Router:
|
||||
self.routes.append(route)
|
||||
|
||||
def add_websocket_route(
|
||||
self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable[[WebSocket], Awaitable[None]],
|
||||
name: str | None = None,
|
||||
) -> None: # pragma: no cover
|
||||
route = WebSocketRoute(path, endpoint=endpoint, name=name)
|
||||
self.routes.append(route)
|
||||
@@ -787,10 +805,10 @@ class Router:
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
methods: Collection[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> typing.Callable:
|
||||
) -> Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
@@ -800,11 +818,11 @@ class Router:
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
|
||||
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", # noqa: E501
|
||||
"Refer to https://starlette.dev/routing/#http-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.add_route(
|
||||
path,
|
||||
func,
|
||||
@@ -816,9 +834,7 @@ class Router:
|
||||
|
||||
return decorator
|
||||
|
||||
def websocket_route(
|
||||
self, path: str, name: typing.Optional[str] = None
|
||||
) -> typing.Callable:
|
||||
def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
@@ -827,20 +843,18 @@ class Router:
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " # noqa: E501
|
||||
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
|
||||
"https://starlette.dev/routing/#websocket-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def add_event_handler(
|
||||
self, event_type: str, func: typing.Callable
|
||||
) -> None: # pragma: no cover
|
||||
def add_event_handler(self, event_type: str, func: Callable[[], Any]) -> None: # pragma: no cover
|
||||
assert event_type in ("startup", "shutdown")
|
||||
|
||||
if event_type == "startup":
|
||||
@@ -848,14 +862,14 @@ class Router:
|
||||
else:
|
||||
self.on_shutdown.append(func)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable:
|
||||
def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
|
||||
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://starlette.dev/lifespan/ for recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.add_event_handler(event_type, func)
|
||||
return func
|
||||
|
||||
|
||||
@@ -1,41 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import BaseRoute, Mount, Route
|
||||
from starlette.routing import BaseRoute, Host, Mount, Route
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
yaml = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class OpenAPIResponse(Response):
|
||||
media_type = "application/vnd.oai.openapi"
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
def render(self, content: Any) -> bytes:
|
||||
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
|
||||
assert isinstance(
|
||||
content, dict
|
||||
), "The schema passed to OpenAPIResponse should be a dictionary."
|
||||
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
|
||||
return yaml.dump(content, default_flow_style=False).encode("utf-8")
|
||||
|
||||
|
||||
class EndpointInfo(typing.NamedTuple):
|
||||
class EndpointInfo(NamedTuple):
|
||||
path: str
|
||||
http_method: str
|
||||
func: typing.Callable
|
||||
func: Callable[..., Any]
|
||||
|
||||
|
||||
_remove_converter_pattern = re.compile(r":\w+}")
|
||||
|
||||
|
||||
class BaseSchemaGenerator:
|
||||
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
|
||||
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def get_endpoints(
|
||||
self, routes: typing.List[BaseRoute]
|
||||
) -> typing.List[EndpointInfo]:
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
|
||||
"""
|
||||
Given the routes, yields the following information:
|
||||
|
||||
@@ -46,12 +48,15 @@ class BaseSchemaGenerator:
|
||||
- func
|
||||
method ready to extract the docstring
|
||||
"""
|
||||
endpoints_info: list = []
|
||||
endpoints_info: list[EndpointInfo] = []
|
||||
|
||||
for route in routes:
|
||||
if isinstance(route, Mount):
|
||||
path = self._remove_converter(route.path)
|
||||
if isinstance(route, Mount | Host):
|
||||
routes = route.routes or []
|
||||
if isinstance(route, Mount):
|
||||
path = self._remove_converter(route.path)
|
||||
else:
|
||||
path = ""
|
||||
sub_endpoints = [
|
||||
EndpointInfo(
|
||||
path="".join((path, sub_endpoint.path)),
|
||||
@@ -70,9 +75,7 @@ class BaseSchemaGenerator:
|
||||
for method in route.methods or ["GET"]:
|
||||
if method == "HEAD":
|
||||
continue
|
||||
endpoints_info.append(
|
||||
EndpointInfo(path, method.lower(), route.endpoint)
|
||||
)
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
|
||||
else:
|
||||
path = self._remove_converter(route.path)
|
||||
for method in ["get", "post", "put", "patch", "delete", "options"]:
|
||||
@@ -90,9 +93,9 @@ class BaseSchemaGenerator:
|
||||
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
|
||||
Should be represented as `/users/{id}` in the OpenAPI schema.
|
||||
"""
|
||||
return re.sub(r":\w+}", "}", path)
|
||||
return _remove_converter_pattern.sub("}", path)
|
||||
|
||||
def parse_docstring(self, func_or_method: typing.Callable) -> dict:
|
||||
def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Given a function, parse the docstring as YAML and return a dictionary of info.
|
||||
"""
|
||||
@@ -123,10 +126,10 @@ class BaseSchemaGenerator:
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
def __init__(self, base_schema: dict) -> None:
|
||||
def __init__(self, base_schema: dict[str, Any]) -> None:
|
||||
self.base_schema = base_schema
|
||||
|
||||
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
|
||||
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault("paths", {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import importlib.util
|
||||
import os
|
||||
import stat
|
||||
import typing
|
||||
from email.utils import parsedate
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
|
||||
from starlette._utils import get_route_path
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.responses import FileResponse, RedirectResponse, Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
PathLike = typing.Union[str, "os.PathLike[str]"]
|
||||
PathLike = Union[str, "os.PathLike[str]"]
|
||||
|
||||
|
||||
class NotModifiedResponse(Response):
|
||||
@@ -27,11 +32,7 @@ class NotModifiedResponse(Response):
|
||||
def __init__(self, headers: Headers):
|
||||
super().__init__(
|
||||
status_code=304,
|
||||
headers={
|
||||
name: value
|
||||
for name, value in headers.items()
|
||||
if name in self.NOT_MODIFIED_HEADERS
|
||||
},
|
||||
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
|
||||
)
|
||||
|
||||
|
||||
@@ -39,10 +40,8 @@ class StaticFiles:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
directory: typing.Optional[PathLike] = None,
|
||||
packages: typing.Optional[
|
||||
typing.List[typing.Union[str, typing.Tuple[str, str]]]
|
||||
] = None,
|
||||
directory: PathLike | None = None,
|
||||
packages: list[str | tuple[str, str]] | None = None,
|
||||
html: bool = False,
|
||||
check_dir: bool = True,
|
||||
follow_symlink: bool = False,
|
||||
@@ -58,11 +57,9 @@ class StaticFiles:
|
||||
|
||||
def get_directories(
|
||||
self,
|
||||
directory: typing.Optional[PathLike] = None,
|
||||
packages: typing.Optional[
|
||||
typing.List[typing.Union[str, typing.Tuple[str, str]]]
|
||||
] = None,
|
||||
) -> typing.List[PathLike]:
|
||||
directory: PathLike | None = None,
|
||||
packages: list[str | tuple[str, str]] | None = None,
|
||||
) -> list[PathLike]:
|
||||
"""
|
||||
Given `directory` and `packages` arguments, return a list of all the
|
||||
directories that should be used for serving static files from.
|
||||
@@ -79,12 +76,10 @@ class StaticFiles:
|
||||
spec = importlib.util.find_spec(package)
|
||||
assert spec is not None, f"Package {package!r} could not be found."
|
||||
assert spec.origin is not None, f"Package {package!r} could not be found."
|
||||
package_directory = os.path.normpath(
|
||||
os.path.join(spec.origin, "..", statics_dir)
|
||||
package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
|
||||
assert os.path.isdir(package_directory), (
|
||||
f"Directory '{statics_dir!r}' in package {package!r} could not be found."
|
||||
)
|
||||
assert os.path.isdir(
|
||||
package_directory
|
||||
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
|
||||
directories.append(package_directory)
|
||||
|
||||
return directories
|
||||
@@ -108,7 +103,8 @@ class StaticFiles:
|
||||
Given the ASGI scope, return the `path` string to serve up,
|
||||
with OS specific path separators, and any '..', '.' components removed.
|
||||
"""
|
||||
return os.path.normpath(os.path.join(*scope["path"].split("/")))
|
||||
route_path = get_route_path(scope)
|
||||
return os.path.normpath(os.path.join(*route_path.split("/")))
|
||||
|
||||
async def get_response(self, path: str, scope: Scope) -> Response:
|
||||
"""
|
||||
@@ -118,13 +114,15 @@ class StaticFiles:
|
||||
raise HTTPException(status_code=405)
|
||||
|
||||
try:
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, path
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=401)
|
||||
except OSError:
|
||||
raise
|
||||
except OSError as exc:
|
||||
# Filename is too long, so it can't be a valid static file.
|
||||
if exc.errno == errno.ENAMETOOLONG:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
raise exc
|
||||
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
# We have a static file to serve.
|
||||
@@ -134,9 +132,7 @@ class StaticFiles:
|
||||
# We're in HTML mode, and have got a directory URL.
|
||||
# Check if we have 'index.html' file to serve.
|
||||
index_path = os.path.join(path, "index.html")
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, index_path
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
|
||||
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
|
||||
if not scope["path"].endswith("/"):
|
||||
# Directory URLs should redirect to always end in "/".
|
||||
@@ -147,31 +143,22 @@ class StaticFiles:
|
||||
|
||||
if self.html:
|
||||
# Check for '404.html' if we're in HTML mode.
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, "404.html"
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
return FileResponse(
|
||||
full_path,
|
||||
stat_result=stat_result,
|
||||
method=scope["method"],
|
||||
status_code=404,
|
||||
)
|
||||
return FileResponse(full_path, stat_result=stat_result, status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
def lookup_path(
|
||||
self, path: str
|
||||
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
|
||||
def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
|
||||
for directory in self.all_directories:
|
||||
joined_path = os.path.join(directory, path)
|
||||
if self.follow_symlink:
|
||||
full_path = os.path.abspath(joined_path)
|
||||
directory = os.path.abspath(directory)
|
||||
else:
|
||||
full_path = os.path.realpath(joined_path)
|
||||
directory = os.path.realpath(directory)
|
||||
if os.path.commonpath([full_path, directory]) != directory:
|
||||
# Don't allow misbehaving clients to break out of the static files
|
||||
# directory.
|
||||
directory = os.path.realpath(directory)
|
||||
if os.path.commonpath([full_path, directory]) != str(directory):
|
||||
# Don't allow misbehaving clients to break out of the static files directory.
|
||||
continue
|
||||
try:
|
||||
return full_path, os.stat(full_path)
|
||||
@@ -186,12 +173,9 @@ class StaticFiles:
|
||||
scope: Scope,
|
||||
status_code: int = 200,
|
||||
) -> Response:
|
||||
method = scope["method"]
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
response = FileResponse(
|
||||
full_path, status_code=status_code, stat_result=stat_result, method=method
|
||||
)
|
||||
response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
|
||||
if self.is_not_modified(response.headers, request_headers):
|
||||
return NotModifiedResponse(response.headers)
|
||||
return response
|
||||
@@ -208,37 +192,24 @@ class StaticFiles:
|
||||
try:
|
||||
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(
|
||||
f"StaticFiles directory '{self.directory}' does not exist."
|
||||
)
|
||||
raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
|
||||
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
|
||||
raise RuntimeError(
|
||||
f"StaticFiles path '{self.directory}' is not a directory."
|
||||
)
|
||||
raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
|
||||
|
||||
def is_not_modified(
|
||||
self, response_headers: Headers, request_headers: Headers
|
||||
) -> bool:
|
||||
def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
|
||||
"""
|
||||
Given the request and response headers, return `True` if an HTTP
|
||||
"Not Modified" response could be returned instead.
|
||||
"""
|
||||
try:
|
||||
if_none_match = request_headers["if-none-match"]
|
||||
if if_none_match := request_headers.get("if-none-match"):
|
||||
# The "etag" header is added by FileResponse, so it's always present.
|
||||
etag = response_headers["etag"]
|
||||
if if_none_match == etag:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
return etag in [tag.strip(" W/") for tag in if_none_match.split(",")]
|
||||
|
||||
try:
|
||||
if_modified_since = parsedate(request_headers["if-modified-since"])
|
||||
last_modified = parsedate(response_headers["last-modified"])
|
||||
if (
|
||||
if_modified_since is not None
|
||||
and last_modified is not None
|
||||
and if_modified_since >= last_modified
|
||||
):
|
||||
if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@@ -3,12 +3,14 @@ HTTP codes
|
||||
See HTTP Status Code Registry:
|
||||
https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
|
||||
|
||||
And RFC 2324 - https://tools.ietf.org/html/rfc2324
|
||||
And RFC 9110 - https://www.rfc-editor.org/rfc/rfc9110
|
||||
"""
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
__all__ = (
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
__all__ = [
|
||||
"HTTP_100_CONTINUE",
|
||||
"HTTP_101_SWITCHING_PROTOCOLS",
|
||||
"HTTP_102_PROCESSING",
|
||||
@@ -45,14 +47,14 @@ __all__ = (
|
||||
"HTTP_410_GONE",
|
||||
"HTTP_411_LENGTH_REQUIRED",
|
||||
"HTTP_412_PRECONDITION_FAILED",
|
||||
"HTTP_413_REQUEST_ENTITY_TOO_LARGE",
|
||||
"HTTP_414_REQUEST_URI_TOO_LONG",
|
||||
"HTTP_413_CONTENT_TOO_LARGE",
|
||||
"HTTP_414_URI_TOO_LONG",
|
||||
"HTTP_415_UNSUPPORTED_MEDIA_TYPE",
|
||||
"HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE",
|
||||
"HTTP_416_RANGE_NOT_SATISFIABLE",
|
||||
"HTTP_417_EXPECTATION_FAILED",
|
||||
"HTTP_418_IM_A_TEAPOT",
|
||||
"HTTP_421_MISDIRECTED_REQUEST",
|
||||
"HTTP_422_UNPROCESSABLE_ENTITY",
|
||||
"HTTP_422_UNPROCESSABLE_CONTENT",
|
||||
"HTTP_423_LOCKED",
|
||||
"HTTP_424_FAILED_DEPENDENCY",
|
||||
"HTTP_425_TOO_EARLY",
|
||||
@@ -87,7 +89,7 @@ __all__ = (
|
||||
"WS_1013_TRY_AGAIN_LATER",
|
||||
"WS_1014_BAD_GATEWAY",
|
||||
"WS_1015_TLS_HANDSHAKE",
|
||||
)
|
||||
]
|
||||
|
||||
HTTP_100_CONTINUE = 100
|
||||
HTTP_101_SWITCHING_PROTOCOLS = 101
|
||||
@@ -125,14 +127,14 @@ HTTP_409_CONFLICT = 409
|
||||
HTTP_410_GONE = 410
|
||||
HTTP_411_LENGTH_REQUIRED = 411
|
||||
HTTP_412_PRECONDITION_FAILED = 412
|
||||
HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413
|
||||
HTTP_414_REQUEST_URI_TOO_LONG = 414
|
||||
HTTP_413_CONTENT_TOO_LARGE = 413
|
||||
HTTP_414_URI_TOO_LONG = 414
|
||||
HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
|
||||
HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416
|
||||
HTTP_416_RANGE_NOT_SATISFIABLE = 416
|
||||
HTTP_417_EXPECTATION_FAILED = 417
|
||||
HTTP_418_IM_A_TEAPOT = 418
|
||||
HTTP_421_MISDIRECTED_REQUEST = 421
|
||||
HTTP_422_UNPROCESSABLE_ENTITY = 422
|
||||
HTTP_422_UNPROCESSABLE_CONTENT = 422
|
||||
HTTP_423_LOCKED = 423
|
||||
HTTP_424_FAILED_DEPENDENCY = 424
|
||||
HTTP_425_TOO_EARLY = 425
|
||||
@@ -175,15 +177,22 @@ WS_1013_TRY_AGAIN_LATER = 1013
|
||||
WS_1014_BAD_GATEWAY = 1014
|
||||
WS_1015_TLS_HANDSHAKE = 1015
|
||||
|
||||
|
||||
__deprecated__ = {"WS_1004_NO_STATUS_RCVD": 1004, "WS_1005_ABNORMAL_CLOSURE": 1005}
|
||||
__deprecated__ = {
|
||||
"HTTP_413_REQUEST_ENTITY_TOO_LARGE": 413,
|
||||
"HTTP_414_REQUEST_URI_TOO_LONG": 414,
|
||||
"HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE": 416,
|
||||
"HTTP_422_UNPROCESSABLE_ENTITY": 422,
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> int:
|
||||
deprecation_changes = {
|
||||
"WS_1004_NO_STATUS_RCVD": "WS_1005_NO_STATUS_RCVD",
|
||||
"WS_1005_ABNORMAL_CLOSURE": "WS_1006_ABNORMAL_CLOSURE",
|
||||
"HTTP_413_REQUEST_ENTITY_TOO_LARGE": "HTTP_413_CONTENT_TOO_LARGE",
|
||||
"HTTP_414_REQUEST_URI_TOO_LONG": "HTTP_414_URI_TOO_LONG",
|
||||
"HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE": "HTTP_416_RANGE_NOT_SATISFIABLE",
|
||||
"HTTP_422_UNPROCESSABLE_ENTITY": "HTTP_422_UNPROCESSABLE_CONTENT",
|
||||
}
|
||||
|
||||
deprecated = __deprecated__.get(name)
|
||||
if deprecated:
|
||||
warnings.warn(
|
||||
@@ -192,8 +201,9 @@ def __getattr__(name: str) -> int:
|
||||
stacklevel=3,
|
||||
)
|
||||
return deprecated
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
raise AttributeError(f"module 'starlette.status' has no attribute '{name}'")
|
||||
|
||||
|
||||
def __dir__() -> List[str]:
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(list(__all__) + list(__deprecated__.keys())) # pragma: no cover
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from os import PathLike
|
||||
from typing import Any, cast, overload
|
||||
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.datastructures import URL
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.responses import HTMLResponse
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
try:
|
||||
@@ -16,23 +20,21 @@ try:
|
||||
# adding a type ignore for mypy to let us access an attribute that may not exist
|
||||
if hasattr(jinja2, "pass_context"):
|
||||
pass_context = jinja2.pass_context
|
||||
else: # pragma: nocover
|
||||
else: # pragma: no cover
|
||||
pass_context = jinja2.contextfunction # type: ignore[attr-defined]
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
jinja2 = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class _TemplateResponse(Response):
|
||||
media_type = "text/html"
|
||||
|
||||
class _TemplateResponse(HTMLResponse):
|
||||
def __init__(
|
||||
self,
|
||||
template: typing.Any,
|
||||
context: dict,
|
||||
template: Any,
|
||||
context: dict[str, Any],
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
):
|
||||
self.template = template
|
||||
self.context = context
|
||||
@@ -42,7 +44,7 @@ class _TemplateResponse(Response):
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = self.context.get("request", {})
|
||||
extensions = request.get("extensions", {})
|
||||
if "http.response.debug" in extensions:
|
||||
if "http.response.debug" in extensions: # pragma: no branch
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.debug",
|
||||
@@ -62,50 +64,145 @@ class Jinja2Templates:
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
directory: typing.Union[str, PathLike],
|
||||
context_processors: typing.Optional[
|
||||
typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
|
||||
] = None,
|
||||
**env_options: typing.Any,
|
||||
directory: str | PathLike[str] | Sequence[str | PathLike[str]],
|
||||
*,
|
||||
context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
|
||||
**env_options: Any,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env: jinja2.Environment,
|
||||
context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
directory: str | PathLike[str] | Sequence[str | PathLike[str]] | None = None,
|
||||
*,
|
||||
context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
|
||||
env: jinja2.Environment | None = None,
|
||||
**env_options: Any,
|
||||
) -> None:
|
||||
if env_options:
|
||||
warnings.warn(
|
||||
"Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
|
||||
self.env = self._create_env(directory, **env_options)
|
||||
assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
|
||||
self.context_processors = context_processors or []
|
||||
if directory is not None:
|
||||
self.env = self._create_env(directory, **env_options)
|
||||
elif env is not None: # pragma: no branch
|
||||
self.env = env
|
||||
|
||||
self._setup_env_defaults(self.env)
|
||||
|
||||
def _create_env(
|
||||
self, directory: typing.Union[str, PathLike], **env_options: typing.Any
|
||||
) -> "jinja2.Environment":
|
||||
@pass_context
|
||||
def url_for(context: dict, name: str, **path_params: typing.Any) -> URL:
|
||||
request = context["request"]
|
||||
return request.url_for(name, **path_params)
|
||||
|
||||
self,
|
||||
directory: str | PathLike[str] | Sequence[str | PathLike[str]],
|
||||
**env_options: Any,
|
||||
) -> jinja2.Environment:
|
||||
loader = jinja2.FileSystemLoader(directory)
|
||||
env_options.setdefault("loader", loader)
|
||||
env_options.setdefault("autoescape", True)
|
||||
|
||||
env = jinja2.Environment(**env_options)
|
||||
env.globals["url_for"] = url_for
|
||||
return env
|
||||
return jinja2.Environment(**env_options)
|
||||
|
||||
def get_template(self, name: str) -> "jinja2.Template":
|
||||
def _setup_env_defaults(self, env: jinja2.Environment) -> None:
|
||||
@pass_context
|
||||
def url_for(
|
||||
context: dict[str, Any],
|
||||
name: str,
|
||||
/,
|
||||
**path_params: Any,
|
||||
) -> URL:
|
||||
request: Request = context["request"]
|
||||
return request.url_for(name, **path_params)
|
||||
|
||||
env.globals.setdefault("url_for", url_for)
|
||||
|
||||
def get_template(self, name: str) -> jinja2.Template:
|
||||
return self.env.get_template(name)
|
||||
|
||||
@overload
|
||||
def TemplateResponse(
|
||||
self,
|
||||
request: Request,
|
||||
name: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
status_code: int = 200,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> _TemplateResponse: ...
|
||||
|
||||
@overload
|
||||
def TemplateResponse(
|
||||
self,
|
||||
name: str,
|
||||
context: dict,
|
||||
context: dict[str, Any] | None = None,
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> _TemplateResponse:
|
||||
if "request" not in context:
|
||||
raise ValueError('context must include a "request" key')
|
||||
# Deprecated usage
|
||||
...
|
||||
|
||||
request = typing.cast(Request, context["request"])
|
||||
def TemplateResponse(self, *args: Any, **kwargs: Any) -> _TemplateResponse:
|
||||
if args:
|
||||
if isinstance(args[0], str): # the first argument is template name (old style)
|
||||
warnings.warn(
|
||||
"The `name` is not the first parameter anymore. "
|
||||
"The first parameter should be the `Request` instance.\n"
|
||||
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
name = args[0]
|
||||
context = args[1] if len(args) > 1 else kwargs.get("context", {})
|
||||
status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
|
||||
headers = args[3] if len(args) > 3 else kwargs.get("headers")
|
||||
media_type = args[4] if len(args) > 4 else kwargs.get("media_type")
|
||||
background = args[5] if len(args) > 5 else kwargs.get("background")
|
||||
|
||||
if "request" not in context:
|
||||
raise ValueError('context must include a "request" key')
|
||||
request = context["request"]
|
||||
else: # the first argument is a request instance (new style)
|
||||
request = args[0]
|
||||
name = args[1] if len(args) > 1 else kwargs["name"]
|
||||
context = args[2] if len(args) > 2 else kwargs.get("context", {})
|
||||
status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
|
||||
headers = args[4] if len(args) > 4 else kwargs.get("headers")
|
||||
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
|
||||
background = args[6] if len(args) > 6 else kwargs.get("background")
|
||||
else: # all arguments are kwargs
|
||||
if "request" not in kwargs:
|
||||
warnings.warn(
|
||||
"The `TemplateResponse` now requires the `request` argument.\n"
|
||||
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
if "request" not in kwargs.get("context", {}):
|
||||
raise ValueError('context must include a "request" key')
|
||||
|
||||
context = kwargs.get("context", {})
|
||||
request = kwargs.get("request", context.get("request"))
|
||||
name = cast(str, kwargs["name"])
|
||||
status_code = kwargs.get("status_code", 200)
|
||||
headers = kwargs.get("headers")
|
||||
media_type = kwargs.get("media_type")
|
||||
background = kwargs.get("background")
|
||||
|
||||
context.setdefault("request", request)
|
||||
for context_processor in self.context_processors:
|
||||
context.update(context_processor(request))
|
||||
|
||||
|
||||
@@ -1,43 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import queue
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence
|
||||
from concurrent.futures import Future
|
||||
from contextlib import AbstractContextManager
|
||||
from types import GeneratorType
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
TypedDict,
|
||||
TypeGuard,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import unquote, urljoin
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
import anyio.from_thread
|
||||
import httpx
|
||||
from anyio.streams.stapled import StapledObjectStream
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
from typing import TypedDict
|
||||
if sys.version_info >= (3, 11): # pragma: no cover
|
||||
from typing import Self
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import Self
|
||||
|
||||
_PortalFactoryType = typing.Callable[
|
||||
[], typing.ContextManager[anyio.abc.BlockingPortal]
|
||||
]
|
||||
try:
|
||||
import httpx
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
raise RuntimeError(
|
||||
"The starlette.testclient module requires the httpx package to be installed.\n"
|
||||
"You can install this with:\n"
|
||||
" $ pip install httpx\n"
|
||||
)
|
||||
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]
|
||||
|
||||
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
|
||||
ASGI2App = typing.Callable[[Scope], ASGIInstance]
|
||||
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
||||
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
|
||||
ASGI2App = Callable[[Scope], ASGIInstance]
|
||||
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]
|
||||
|
||||
|
||||
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
|
||||
_RequestData = Mapping[str, str | Iterable[str] | bytes]
|
||||
|
||||
|
||||
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
|
||||
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
|
||||
if inspect.isclass(app):
|
||||
return hasattr(app, "__await__")
|
||||
return is_async_callable(app)
|
||||
@@ -58,14 +73,24 @@ class _WrapASGI2:
|
||||
|
||||
class _AsyncBackend(TypedDict):
|
||||
backend: str
|
||||
backend_options: typing.Dict[str, typing.Any]
|
||||
backend_options: dict[str, Any]
|
||||
|
||||
|
||||
class _Upgrade(Exception):
|
||||
def __init__(self, session: "WebSocketTestSession") -> None:
|
||||
def __init__(self, session: WebSocketTestSession) -> None:
|
||||
self.session = session
|
||||
|
||||
|
||||
class WebSocketDenialResponse( # type: ignore[misc]
|
||||
httpx.Response,
|
||||
WebSocketDisconnect,
|
||||
):
|
||||
"""
|
||||
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
|
||||
`WebSocket` is closed before being accepted with a `send_denial_response()`.
|
||||
"""
|
||||
|
||||
|
||||
class WebSocketTestSession:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -77,65 +102,60 @@ class WebSocketTestSession:
|
||||
self.scope = scope
|
||||
self.accepted_subprotocol = None
|
||||
self.portal_factory = portal_factory
|
||||
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
|
||||
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
|
||||
self.extra_headers = None
|
||||
|
||||
def __enter__(self) -> "WebSocketTestSession":
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
self.portal = self.exit_stack.enter_context(self.portal_factory())
|
||||
|
||||
try:
|
||||
_: "Future[None]" = self.portal.start_task_soon(self._run)
|
||||
def __enter__(self) -> WebSocketTestSession:
|
||||
with contextlib.ExitStack() as stack:
|
||||
self.portal = portal = stack.enter_context(self.portal_factory())
|
||||
fut, cs = portal.start_task(self._run)
|
||||
stack.callback(fut.result)
|
||||
stack.callback(portal.call, cs.cancel)
|
||||
self.send({"type": "websocket.connect"})
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
except Exception:
|
||||
self.exit_stack.close()
|
||||
raise
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
self.extra_headers = message.get("headers", None)
|
||||
return self
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
self.extra_headers = message.get("headers", None)
|
||||
stack.callback(self.close, 1000)
|
||||
self.exit_stack = stack.pop_all()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
try:
|
||||
self.close(1000)
|
||||
finally:
|
||||
self.exit_stack.close()
|
||||
while not self._send_queue.empty():
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
def __exit__(self, *args: Any) -> bool | None:
|
||||
return self.exit_stack.__exit__(*args)
|
||||
|
||||
async def _run(self) -> None:
|
||||
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
|
||||
"""
|
||||
The sub-thread in which the websocket session runs.
|
||||
"""
|
||||
scope = self.scope
|
||||
receive = self._asgi_receive
|
||||
send = self._asgi_send
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except BaseException as exc:
|
||||
self._send_queue.put(exc)
|
||||
raise
|
||||
send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
|
||||
send_tx, send_rx = send
|
||||
receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
|
||||
receive_tx, receive_rx = receive
|
||||
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
|
||||
self._receive_tx = receive_tx
|
||||
self._send_rx = send_rx
|
||||
task_status.started(cs)
|
||||
await self.app(self.scope, receive_rx.receive, send_tx.send)
|
||||
|
||||
async def _asgi_receive(self) -> Message:
|
||||
while self._receive_queue.empty():
|
||||
await anyio.sleep(0)
|
||||
return self._receive_queue.get()
|
||||
|
||||
async def _asgi_send(self, message: Message) -> None:
|
||||
self._send_queue.put(message)
|
||||
# wait for cs.cancel to be called before closing streams
|
||||
await anyio.sleep_forever()
|
||||
|
||||
def _raise_on_close(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.close":
|
||||
raise WebSocketDisconnect(
|
||||
message.get("code", 1000), message.get("reason", "")
|
||||
)
|
||||
raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
|
||||
elif message["type"] == "websocket.http.response.start":
|
||||
status_code: int = message["status"]
|
||||
headers: list[tuple[bytes, bytes]] = message["headers"]
|
||||
body: list[bytes] = []
|
||||
while True:
|
||||
message = self.receive()
|
||||
assert message["type"] == "websocket.http.response.body"
|
||||
body.append(message["body"])
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
|
||||
|
||||
def send(self, message: Message) -> None:
|
||||
self._receive_queue.put(message)
|
||||
self.portal.call(self._receive_tx.send, message)
|
||||
|
||||
def send_text(self, data: str) -> None:
|
||||
self.send({"type": "websocket.receive", "text": data})
|
||||
@@ -143,35 +163,30 @@ class WebSocketTestSession:
|
||||
def send_bytes(self, data: bytes) -> None:
|
||||
self.send({"type": "websocket.receive", "bytes": data})
|
||||
|
||||
def send_json(self, data: typing.Any, mode: str = "text") -> None:
|
||||
assert mode in ["text", "binary"]
|
||||
text = json.dumps(data, separators=(",", ":"))
|
||||
def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
|
||||
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
if mode == "text":
|
||||
self.send({"type": "websocket.receive", "text": text})
|
||||
else:
|
||||
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
|
||||
|
||||
def close(self, code: int = 1000) -> None:
|
||||
self.send({"type": "websocket.disconnect", "code": code})
|
||||
def close(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
def receive(self) -> Message:
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
return message
|
||||
return self.portal.call(self._send_rx.receive)
|
||||
|
||||
def receive_text(self) -> str:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return message["text"]
|
||||
return cast(str, message["text"])
|
||||
|
||||
def receive_bytes(self) -> bytes:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return message["bytes"]
|
||||
return cast(bytes, message["bytes"])
|
||||
|
||||
def receive_json(self, mode: str = "text") -> typing.Any:
|
||||
assert mode in ["text", "binary"]
|
||||
def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
if mode == "text":
|
||||
@@ -189,13 +204,15 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
raise_server_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
*,
|
||||
app_state: typing.Dict[str, typing.Any],
|
||||
client: tuple[str, int],
|
||||
app_state: dict[str, Any],
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_server_exceptions = raise_server_exceptions
|
||||
self.root_path = root_path
|
||||
self.portal_factory = portal_factory
|
||||
self.app_state = app_state
|
||||
self.client = client
|
||||
|
||||
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||
scheme = request.url.scheme
|
||||
@@ -215,38 +232,36 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
|
||||
# Include the 'host' header.
|
||||
if "host" in request.headers:
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = []
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
elif port == default_port: # pragma: no cover
|
||||
headers = [(b"host", host.encode())]
|
||||
else: # pragma: no cover
|
||||
headers = [(b"host", (f"{host}:{port}").encode())]
|
||||
|
||||
# Include other request headers.
|
||||
headers += [
|
||||
(key.lower().encode(), value.encode())
|
||||
for key, value in request.headers.items()
|
||||
]
|
||||
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
|
||||
|
||||
scope: typing.Dict[str, typing.Any]
|
||||
scope: dict[str, Any]
|
||||
|
||||
if scheme in {"ws", "wss"}:
|
||||
subprotocol = request.headers.get("sec-websocket-protocol", None)
|
||||
if subprotocol is None:
|
||||
subprotocols: typing.Sequence[str] = []
|
||||
subprotocols: Sequence[str] = []
|
||||
else:
|
||||
subprotocols = [value.strip() for value in subprotocol.split(",")]
|
||||
scope = {
|
||||
"type": "websocket",
|
||||
"path": unquote(path),
|
||||
"raw_path": raw_path,
|
||||
"raw_path": raw_path.split(b"?", 1)[0],
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": ["testclient", 50000],
|
||||
"client": self.client,
|
||||
"server": [host, port],
|
||||
"subprotocols": subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
session = WebSocketTestSession(self.app, scope, self.portal_factory)
|
||||
raise _Upgrade(session)
|
||||
@@ -256,12 +271,12 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"path": unquote(path),
|
||||
"raw_path": raw_path,
|
||||
"raw_path": raw_path.split(b"?", 1)[0],
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": ["testclient", 50000],
|
||||
"client": self.client,
|
||||
"server": [host, port],
|
||||
"extensions": {"http.response.debug": {}},
|
||||
"state": self.app_state.copy(),
|
||||
@@ -270,7 +285,7 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
request_complete = False
|
||||
response_started = False
|
||||
response_complete: anyio.Event
|
||||
raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
|
||||
raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
|
||||
template = None
|
||||
context = None
|
||||
|
||||
@@ -306,22 +321,13 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
nonlocal raw_kwargs, response_started, template, context
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
assert (
|
||||
not response_started
|
||||
), 'Received multiple "http.response.start" messages.'
|
||||
assert not response_started, 'Received multiple "http.response.start" messages.'
|
||||
raw_kwargs["status_code"] = message["status"]
|
||||
raw_kwargs["headers"] = [
|
||||
(key.decode(), value.decode())
|
||||
for key, value in message.get("headers", [])
|
||||
]
|
||||
raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
|
||||
response_started = True
|
||||
elif message["type"] == "http.response.body":
|
||||
assert (
|
||||
response_started
|
||||
), 'Received "http.response.body" without "http.response.start".'
|
||||
assert (
|
||||
not response_complete.is_set()
|
||||
), 'Received "http.response.body" after response completed.'
|
||||
assert response_started, 'Received "http.response.body" without "http.response.start".'
|
||||
assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if request.method != "HEAD":
|
||||
@@ -361,8 +367,8 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
|
||||
class TestClient(httpx.Client):
|
||||
__test__ = False
|
||||
task: "Future[None]"
|
||||
portal: typing.Optional[anyio.abc.BlockingPortal] = None
|
||||
task: Future[None]
|
||||
portal: anyio.abc.BlockingPortal | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -370,110 +376,84 @@ class TestClient(httpx.Client):
|
||||
base_url: str = "http://testserver",
|
||||
raise_server_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
backend: str = "asyncio",
|
||||
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
cookies: httpx._client.CookieTypes = None,
|
||||
headers: typing.Dict[str, str] = None,
|
||||
backend: Literal["asyncio", "trio"] = "asyncio",
|
||||
backend_options: dict[str, Any] | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
follow_redirects: bool = True,
|
||||
client: tuple[str, int] = ("testclient", 50000),
|
||||
) -> None:
|
||||
self.async_backend = _AsyncBackend(
|
||||
backend=backend, backend_options=backend_options or {}
|
||||
)
|
||||
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
|
||||
if _is_asgi3(app):
|
||||
app = typing.cast(ASGI3App, app)
|
||||
asgi_app = app
|
||||
else:
|
||||
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
|
||||
app = cast(ASGI2App, app) # type: ignore[assignment]
|
||||
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
|
||||
self.app = asgi_app
|
||||
self.app_state: typing.Dict[str, typing.Any] = {}
|
||||
self.app_state: dict[str, Any] = {}
|
||||
transport = _TestClientTransport(
|
||||
self.app,
|
||||
portal_factory=self._portal_factory,
|
||||
raise_server_exceptions=raise_server_exceptions,
|
||||
root_path=root_path,
|
||||
app_state=self.app_state,
|
||||
client=client,
|
||||
)
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.setdefault("user-agent", "testclient")
|
||||
super().__init__(
|
||||
app=self.app,
|
||||
base_url=base_url,
|
||||
headers=headers,
|
||||
transport=transport,
|
||||
follow_redirects=True,
|
||||
follow_redirects=follow_redirects,
|
||||
cookies=cookies,
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
|
||||
def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
|
||||
if self.portal is not None:
|
||||
yield self.portal
|
||||
else:
|
||||
with anyio.from_thread.start_blocking_portal(
|
||||
**self.async_backend
|
||||
) as portal:
|
||||
with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
|
||||
yield portal
|
||||
|
||||
def _choose_redirect_arg(
|
||||
self,
|
||||
follow_redirects: typing.Optional[bool],
|
||||
allow_redirects: typing.Optional[bool],
|
||||
) -> typing.Union[bool, httpx._client.UseClientDefault]:
|
||||
redirect: typing.Union[
|
||||
bool, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT
|
||||
if allow_redirects is not None:
|
||||
message = (
|
||||
"The `allow_redirects` argument is deprecated. "
|
||||
"Use `follow_redirects` instead."
|
||||
)
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
redirect = allow_redirects
|
||||
if follow_redirects is not None:
|
||||
redirect = follow_redirects
|
||||
elif allow_redirects is not None and follow_redirects is not None:
|
||||
raise RuntimeError( # pragma: no cover
|
||||
"Cannot use both `allow_redirects` and `follow_redirects`."
|
||||
)
|
||||
return redirect
|
||||
|
||||
def request( # type: ignore[override]
|
||||
self,
|
||||
method: str,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: typing.Optional[httpx._types.RequestContent] = None,
|
||||
data: typing.Optional[_RequestData] = None,
|
||||
files: typing.Optional[httpx._types.RequestFiles] = None,
|
||||
json: typing.Any = None,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
url = self.base_url.join(url)
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
if timeout is not httpx.USE_CLIENT_DEFAULT:
|
||||
warnings.warn(
|
||||
"You should not use the 'timeout' argument with the TestClient. "
|
||||
"See https://github.com/Kludex/starlette/issues/1108 for more information.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
url = self._merge_url(url)
|
||||
return super().request(
|
||||
method,
|
||||
url,
|
||||
content=content,
|
||||
data=data, # type: ignore[arg-type]
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -482,27 +462,21 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().get(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -511,27 +485,21 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().options(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -540,27 +508,21 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().head(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -569,35 +531,29 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: typing.Optional[httpx._types.RequestContent] = None,
|
||||
data: typing.Optional[_RequestData] = None,
|
||||
files: typing.Optional[httpx._types.RequestFiles] = None,
|
||||
json: typing.Any = None,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().post(
|
||||
url,
|
||||
content=content,
|
||||
data=data, # type: ignore[arg-type]
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -606,35 +562,29 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: typing.Optional[httpx._types.RequestContent] = None,
|
||||
data: typing.Optional[_RequestData] = None,
|
||||
files: typing.Optional[httpx._types.RequestFiles] = None,
|
||||
json: typing.Any = None,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().put(
|
||||
url,
|
||||
content=content,
|
||||
data=data, # type: ignore[arg-type]
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -643,35 +593,29 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: typing.Optional[httpx._types.RequestContent] = None,
|
||||
data: typing.Optional[_RequestData] = None,
|
||||
files: typing.Optional[httpx._types.RequestFiles] = None,
|
||||
json: typing.Any = None,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().patch(
|
||||
url,
|
||||
content=content,
|
||||
data=data, # type: ignore[arg-type]
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -680,34 +624,31 @@ class TestClient(httpx.Client):
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: typing.Optional[httpx._types.QueryParamTypes] = None,
|
||||
headers: typing.Optional[httpx._types.HeaderTypes] = None,
|
||||
cookies: typing.Optional[httpx._types.CookieTypes] = None,
|
||||
auth: typing.Union[
|
||||
httpx._types.AuthTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: typing.Optional[bool] = None,
|
||||
allow_redirects: typing.Optional[bool] = None,
|
||||
timeout: typing.Union[
|
||||
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
|
||||
] = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().delete(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def websocket_connect(
|
||||
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
self,
|
||||
url: str,
|
||||
subprotocols: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> WebSocketTestSession:
|
||||
url = urljoin("ws://testserver", url)
|
||||
headers = kwargs.get("headers", {})
|
||||
headers.setdefault("connection", "upgrade")
|
||||
@@ -725,22 +666,24 @@ class TestClient(httpx.Client):
|
||||
|
||||
return session
|
||||
|
||||
def __enter__(self) -> "TestClient":
|
||||
def __enter__(self) -> Self:
|
||||
with contextlib.ExitStack() as stack:
|
||||
self.portal = portal = stack.enter_context(
|
||||
anyio.from_thread.start_blocking_portal(**self.async_backend)
|
||||
)
|
||||
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
|
||||
|
||||
@stack.callback
|
||||
def reset_portal() -> None:
|
||||
self.portal = None
|
||||
|
||||
self.stream_send = StapledObjectStream(
|
||||
*anyio.create_memory_object_stream(math.inf)
|
||||
send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
|
||||
anyio.create_memory_object_stream(math.inf)
|
||||
)
|
||||
self.stream_receive = StapledObjectStream(
|
||||
*anyio.create_memory_object_stream(math.inf)
|
||||
receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
|
||||
math.inf
|
||||
)
|
||||
for channel in (*send, *receive):
|
||||
stack.callback(channel.close)
|
||||
self.stream_send = StapledObjectStream(*send)
|
||||
self.stream_receive = StapledObjectStream(*receive)
|
||||
self.task = portal.start_task_soon(self.lifespan)
|
||||
portal.call(self.wait_startup)
|
||||
|
||||
@@ -752,7 +695,7 @@ class TestClient(httpx.Client):
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.exit_stack.close()
|
||||
|
||||
async def lifespan(self) -> None:
|
||||
@@ -765,7 +708,7 @@ class TestClient(httpx.Client):
|
||||
async def wait_startup(self) -> None:
|
||||
await self.stream_receive.send({"type": "lifespan.startup"})
|
||||
|
||||
async def receive() -> typing.Any:
|
||||
async def receive() -> Any:
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
@@ -780,18 +723,17 @@ class TestClient(httpx.Client):
|
||||
await receive()
|
||||
|
||||
async def wait_shutdown(self) -> None:
|
||||
async def receive() -> typing.Any:
|
||||
async def receive() -> Any:
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
return message
|
||||
|
||||
async with self.stream_send:
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.shutdown.failed":
|
||||
await receive()
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.shutdown.failed":
|
||||
await receive()
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
import typing
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
AppType = typing.TypeVar("AppType")
|
||||
if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
Scope = typing.MutableMapping[str, typing.Any]
|
||||
Message = typing.MutableMapping[str, typing.Any]
|
||||
AppType = TypeVar("AppType")
|
||||
|
||||
Receive = typing.Callable[[], typing.Awaitable[Message]]
|
||||
Send = typing.Callable[[Message], typing.Awaitable[None]]
|
||||
Scope = MutableMapping[str, Any]
|
||||
Message = MutableMapping[str, Any]
|
||||
|
||||
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
||||
Receive = Callable[[], Awaitable[Message]]
|
||||
Send = Callable[[Message], Awaitable[None]]
|
||||
|
||||
StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
|
||||
StatefulLifespan = typing.Callable[
|
||||
[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
|
||||
]
|
||||
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
|
||||
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
|
||||
|
||||
StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]]
|
||||
StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]]
|
||||
Lifespan = StatelessLifespan[AppType] | StatefulLifespan[AppType]
|
||||
|
||||
HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"]
|
||||
WebSocketExceptionHandler = Callable[["WebSocket", Exception], Awaitable[None]]
|
||||
ExceptionHandler = HTTPExceptionHandler | WebSocketExceptionHandler
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import json
|
||||
import typing
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from typing import Any, cast
|
||||
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
|
||||
@@ -10,10 +14,11 @@ class WebSocketState(enum.Enum):
|
||||
CONNECTING = 0
|
||||
CONNECTED = 1
|
||||
DISCONNECTED = 2
|
||||
RESPONSE = 3
|
||||
|
||||
|
||||
class WebSocketDisconnect(Exception):
|
||||
def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
|
||||
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
self.code = code
|
||||
self.reason = reason or ""
|
||||
|
||||
@@ -35,10 +40,7 @@ class WebSocket(HTTPConnection):
|
||||
message = await self._receive()
|
||||
message_type = message["type"]
|
||||
if message_type != "websocket.connect":
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.connect", '
|
||||
f"but got {message_type!r}"
|
||||
)
|
||||
raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
|
||||
self.client_state = WebSocketState.CONNECTED
|
||||
return message
|
||||
elif self.client_state == WebSocketState.CONNECTED:
|
||||
@@ -46,16 +48,13 @@ class WebSocket(HTTPConnection):
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.receive", "websocket.disconnect"}:
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.receive" or '
|
||||
f'"websocket.disconnect", but got {message_type!r}'
|
||||
f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
|
||||
)
|
||||
if message_type == "websocket.disconnect":
|
||||
self.client_state = WebSocketState.DISCONNECTED
|
||||
return message
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Cannot call "receive" once a disconnect message has been received.'
|
||||
)
|
||||
raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
|
||||
|
||||
async def send(self, message: Message) -> None:
|
||||
"""
|
||||
@@ -63,13 +62,15 @@ class WebSocket(HTTPConnection):
|
||||
"""
|
||||
if self.application_state == WebSocketState.CONNECTING:
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.accept", "websocket.close"}:
|
||||
if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.accept" or '
|
||||
f'"websocket.close", but got {message_type!r}'
|
||||
'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
|
||||
f"but got {message_type!r}"
|
||||
)
|
||||
if message_type == "websocket.close":
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
elif message_type == "websocket.http.response.start":
|
||||
self.application_state = WebSocketState.RESPONSE
|
||||
else:
|
||||
self.application_state = WebSocketState.CONNECTED
|
||||
await self._send(message)
|
||||
@@ -77,58 +78,60 @@ class WebSocket(HTTPConnection):
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.send", "websocket.close"}:
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.send" or "websocket.close", '
|
||||
f"but got {message_type!r}"
|
||||
f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
|
||||
)
|
||||
if message_type == "websocket.close":
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
try:
|
||||
await self._send(message)
|
||||
except OSError:
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
raise WebSocketDisconnect(code=1006)
|
||||
elif self.application_state == WebSocketState.RESPONSE:
|
||||
message_type = message["type"]
|
||||
if message_type != "websocket.http.response.body":
|
||||
raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
|
||||
if not message.get("more_body", False):
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
await self._send(message)
|
||||
else:
|
||||
raise RuntimeError('Cannot call "send" once a close message has been sent.')
|
||||
|
||||
async def accept(
|
||||
self,
|
||||
subprotocol: typing.Optional[str] = None,
|
||||
headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None,
|
||||
subprotocol: str | None = None,
|
||||
headers: Iterable[tuple[bytes, bytes]] | None = None,
|
||||
) -> None:
|
||||
headers = headers or []
|
||||
|
||||
if self.client_state == WebSocketState.CONNECTING:
|
||||
if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
|
||||
# If we haven't yet seen the 'connect' message, then wait for it first.
|
||||
await self.receive()
|
||||
await self.send(
|
||||
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
|
||||
)
|
||||
await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
|
||||
|
||||
def _raise_on_disconnect(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.disconnect":
|
||||
raise WebSocketDisconnect(message["code"])
|
||||
raise WebSocketDisconnect(message["code"], message.get("reason"))
|
||||
|
||||
async def receive_text(self) -> str:
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError(
|
||||
'WebSocket is not connected. Need to call "accept" first.'
|
||||
)
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
return message["text"]
|
||||
return cast(str, message["text"])
|
||||
|
||||
async def receive_bytes(self) -> bytes:
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError(
|
||||
'WebSocket is not connected. Need to call "accept" first.'
|
||||
)
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
return message["bytes"]
|
||||
return cast(bytes, message["bytes"])
|
||||
|
||||
async def receive_json(self, mode: str = "text") -> typing.Any:
|
||||
async def receive_json(self, mode: str = "text") -> Any:
|
||||
if mode not in {"text", "binary"}:
|
||||
raise RuntimeError('The "mode" argument should be "text" or "binary".')
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError(
|
||||
'WebSocket is not connected. Need to call "accept" first.'
|
||||
)
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
|
||||
@@ -138,21 +141,21 @@ class WebSocket(HTTPConnection):
|
||||
text = message["bytes"].decode("utf-8")
|
||||
return json.loads(text)
|
||||
|
||||
async def iter_text(self) -> typing.AsyncIterator[str]:
|
||||
async def iter_text(self) -> AsyncIterator[str]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
|
||||
async def iter_bytes(self) -> AsyncIterator[bytes]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_bytes()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
|
||||
async def iter_json(self) -> AsyncIterator[Any]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_json()
|
||||
@@ -165,29 +168,29 @@ class WebSocket(HTTPConnection):
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.send({"type": "websocket.send", "bytes": data})
|
||||
|
||||
async def send_json(self, data: typing.Any, mode: str = "text") -> None:
|
||||
async def send_json(self, data: Any, mode: str = "text") -> None:
|
||||
if mode not in {"text", "binary"}:
|
||||
raise RuntimeError('The "mode" argument should be "text" or "binary".')
|
||||
text = json.dumps(data, separators=(",", ":"))
|
||||
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
if mode == "text":
|
||||
await self.send({"type": "websocket.send", "text": text})
|
||||
else:
|
||||
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
|
||||
|
||||
async def close(
|
||||
self, code: int = 1000, reason: typing.Optional[str] = None
|
||||
) -> None:
|
||||
await self.send(
|
||||
{"type": "websocket.close", "code": code, "reason": reason or ""}
|
||||
)
|
||||
async def close(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
|
||||
|
||||
async def send_denial_response(self, response: Response) -> None:
|
||||
if "websocket.http.response" in self.scope.get("extensions", {}):
|
||||
await response(self.scope, self.receive, self.send)
|
||||
else:
|
||||
raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
|
||||
|
||||
|
||||
class WebSocketClose:
|
||||
def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
|
||||
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
self.code = code
|
||||
self.reason = reason or ""
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await send(
|
||||
{"type": "websocket.close", "code": self.code, "reason": self.reason}
|
||||
)
|
||||
await send({"type": "websocket.close", "code": self.code, "reason": self.reason})
|
||||
|
||||
Reference in New Issue
Block a user