updates
This commit is contained in:
@@ -1,17 +1,37 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterator
|
||||
from typing import Any, ParamSpec, Protocol
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
_Scope = Any
|
||||
_Receive = Callable[[], Awaitable[Any]]
|
||||
_Send = Callable[[Any], Awaitable[None]]
|
||||
# Since `starlette.types.ASGIApp` type differs from `ASGIApplication` from `asgiref`
|
||||
# we need to define a more permissive version of ASGIApp that doesn't cause type errors.
|
||||
_ASGIApp = Callable[[_Scope, _Receive, _Send], Awaitable[None]]
|
||||
|
||||
|
||||
class _MiddlewareFactory(Protocol[P]):
|
||||
def __call__(self, app: _ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> _ASGIApp: ... # pragma: no cover
|
||||
|
||||
|
||||
class Middleware:
|
||||
def __init__(self, cls: type, **options: typing.Any) -> None:
|
||||
def __init__(self, cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
self.cls = cls
|
||||
self.options = options
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __iter__(self) -> typing.Iterator:
|
||||
as_tuple = (self.cls, self.options)
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
as_tuple = (self.cls, self.args, self.kwargs)
|
||||
return iter(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
|
||||
args_repr = ", ".join([self.cls.__name__] + option_strings)
|
||||
args_strings = [f"{value!r}" for value in self.args]
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
|
||||
name = getattr(self.cls, "__name__", "")
|
||||
args_repr = ", ".join([name] + args_strings + option_strings)
|
||||
return f"{class_name}({args_repr})"
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,4 +1,6 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
@@ -16,15 +18,13 @@ class AuthenticationMiddleware:
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Optional[
|
||||
typing.Callable[[HTTPConnection, AuthenticationError], Response]
|
||||
] = None,
|
||||
on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error: typing.Callable[
|
||||
[HTTPConnection, AuthenticationError], Response
|
||||
] = (on_error if on_error is not None else self.default_on_error)
|
||||
self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ["http", "websocket"]:
|
||||
|
||||
@@ -1,23 +1,100 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Mapping, MutableMapping
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import anyio
|
||||
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import ContentStream, Response, StreamingResponse
|
||||
from starlette._utils import collapse_excgroups
|
||||
from starlette.requests import ClientDisconnect, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
|
||||
DispatchFunction = typing.Callable[
|
||||
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
|
||||
]
|
||||
T = typing.TypeVar("T")
|
||||
RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
|
||||
DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
|
||||
BodyStreamGenerator = AsyncGenerator[bytes | MutableMapping[str, Any], None]
|
||||
AsyncContentStream = AsyncIterable[str | bytes | memoryview | MutableMapping[str, Any]]
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class _CachedRequest(Request):
|
||||
"""
|
||||
If the user calls Request.body() from their dispatch function
|
||||
we cache the entire request body in memory and pass that to downstream middlewares,
|
||||
but if they call Request.stream() then all we do is send an
|
||||
empty body so that downstream things don't hang forever.
|
||||
"""
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive):
|
||||
super().__init__(scope, receive)
|
||||
self._wrapped_rcv_disconnected = False
|
||||
self._wrapped_rcv_consumed = False
|
||||
self._wrapped_rc_stream = self.stream()
|
||||
|
||||
async def wrapped_receive(self) -> Message:
|
||||
# wrapped_rcv state 1: disconnected
|
||||
if self._wrapped_rcv_disconnected:
|
||||
# we've already sent a disconnect to the downstream app
|
||||
# we don't need to wait to get another one
|
||||
# (although most ASGI servers will just keep sending it)
|
||||
return {"type": "http.disconnect"}
|
||||
# wrapped_rcv state 1: consumed but not yet disconnected
|
||||
if self._wrapped_rcv_consumed:
|
||||
# since the downstream app has consumed us all that is left
|
||||
# is to send it a disconnect
|
||||
if self._is_disconnected:
|
||||
# the middleware has already seen the disconnect
|
||||
# since we know the client is disconnected no need to wait
|
||||
# for the message
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
# we don't know yet if the client is disconnected or not
|
||||
# so we'll wait until we get that message
|
||||
msg = await self.receive()
|
||||
if msg["type"] != "http.disconnect": # pragma: no cover
|
||||
# at this point a disconnect is all that we should be receiving
|
||||
# if we get something else, things went wrong somewhere
|
||||
raise RuntimeError(f"Unexpected message received: {msg['type']}")
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return msg
|
||||
|
||||
# wrapped_rcv state 3: not yet consumed
|
||||
if getattr(self, "_body", None) is not None:
|
||||
# body() was called, we return it even if the client disconnected
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": self._body,
|
||||
"more_body": False,
|
||||
}
|
||||
elif self._stream_consumed:
|
||||
# stream() was called to completion
|
||||
# return an empty body so that downstream apps don't hang
|
||||
# waiting for a disconnect
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
else:
|
||||
# body() was never called and stream() wasn't consumed
|
||||
try:
|
||||
stream = self.stream()
|
||||
chunk = await stream.__anext__()
|
||||
self._wrapped_rcv_consumed = self._stream_consumed
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": not self._stream_consumed,
|
||||
}
|
||||
except ClientDisconnect:
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
|
||||
class BaseHTTPMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
|
||||
) -> None:
|
||||
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
|
||||
self.app = app
|
||||
self.dispatch_func = self.dispatch if dispatch is None else dispatch
|
||||
|
||||
@@ -26,35 +103,32 @@ class BaseHTTPMiddleware:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = _CachedRequest(scope, receive)
|
||||
wrapped_receive = request.wrapped_receive
|
||||
response_sent = anyio.Event()
|
||||
app_exc: Exception | None = None
|
||||
exception_already_raised = False
|
||||
|
||||
async def call_next(request: Request) -> Response:
|
||||
app_exc: typing.Optional[Exception] = None
|
||||
send_stream, recv_stream = anyio.create_memory_object_stream()
|
||||
|
||||
async def receive_or_disconnect() -> Message:
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
|
||||
async def wrap(func: Callable[[], Awaitable[T]]) -> T:
|
||||
result = await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
return result
|
||||
|
||||
task_group.start_soon(wrap, response_sent.wait)
|
||||
message = await wrap(request.receive)
|
||||
message = await wrap(wrapped_receive)
|
||||
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
return message
|
||||
|
||||
async def close_recv_stream_on_response_sent() -> None:
|
||||
await response_sent.wait()
|
||||
recv_stream.close()
|
||||
|
||||
async def send_no_error(message: Message) -> None:
|
||||
try:
|
||||
await send_stream.send(message)
|
||||
@@ -65,13 +139,12 @@ class BaseHTTPMiddleware:
|
||||
async def coro() -> None:
|
||||
nonlocal app_exc
|
||||
|
||||
async with send_stream:
|
||||
with send_stream:
|
||||
try:
|
||||
await self.app(scope, receive_or_disconnect, send_no_error)
|
||||
except Exception as exc:
|
||||
app_exc = exc
|
||||
|
||||
task_group.start_soon(close_recv_stream_on_response_sent)
|
||||
task_group.start_soon(coro)
|
||||
|
||||
try:
|
||||
@@ -81,54 +154,91 @@ class BaseHTTPMiddleware:
|
||||
message = await recv_stream.receive()
|
||||
except anyio.EndOfStream:
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
nonlocal exception_already_raised
|
||||
exception_already_raised = True
|
||||
# Prevent `anyio.EndOfStream` from polluting app exception context.
|
||||
# If both cause and context are None then the context is suppressed
|
||||
# and `anyio.EndOfStream` is not present in the exception traceback.
|
||||
# If exception cause is not None then it is propagated with
|
||||
# reraising here.
|
||||
# If exception has no cause but has context set then the context is
|
||||
# propagated as a cause with the reraise. This is necessary in order
|
||||
# to prevent `anyio.EndOfStream` from polluting the exception
|
||||
# context.
|
||||
raise app_exc from app_exc.__cause__ or app_exc.__context__
|
||||
raise RuntimeError("No response returned.")
|
||||
|
||||
assert message["type"] == "http.response.start"
|
||||
|
||||
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
|
||||
async with recv_stream:
|
||||
async for message in recv_stream:
|
||||
assert message["type"] == "http.response.body"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
async def body_stream() -> BodyStreamGenerator:
|
||||
async for message in recv_stream:
|
||||
if message["type"] == "http.response.pathsend":
|
||||
yield message
|
||||
break
|
||||
assert message["type"] == "http.response.body", f"Unexpected message: {message}"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
|
||||
response = _StreamingResponse(
|
||||
status_code=message["status"], content=body_stream(), info=info
|
||||
)
|
||||
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
|
||||
response.raw_headers = message["headers"]
|
||||
return response
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
request = Request(scope, receive=receive)
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, receive, send)
|
||||
response_sent.set()
|
||||
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
|
||||
send_stream, recv_stream = streams
|
||||
with recv_stream, send_stream, collapse_excgroups():
|
||||
async with anyio.create_task_group() as task_group:
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, wrapped_receive, send)
|
||||
response_sent.set()
|
||||
recv_stream.close()
|
||||
if app_exc is not None and not exception_already_raised:
|
||||
raise app_exc
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class _StreamingResponse(StreamingResponse):
|
||||
class _StreamingResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
content: ContentStream,
|
||||
content: AsyncContentStream,
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
info: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._info = info
|
||||
super().__init__(content, status_code, headers, media_type, background)
|
||||
self.info = info
|
||||
self.body_iterator = content
|
||||
self.status_code = status_code
|
||||
self.media_type = media_type
|
||||
self.init_headers(headers)
|
||||
self.background = None
|
||||
|
||||
async def stream_response(self, send: Send) -> None:
|
||||
if self._info:
|
||||
await send({"type": "http.response.debug", "info": self._info})
|
||||
return await super().stream_response(send)
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.info is not None:
|
||||
await send({"type": "http.response.debug", "info": self.info})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
|
||||
should_close_body = True
|
||||
async for chunk in self.body_iterator:
|
||||
if isinstance(chunk, dict):
|
||||
# We got an ASGI message which is not response body (eg: pathsend)
|
||||
should_close_body = False
|
||||
await send(chunk)
|
||||
continue
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
if should_close_body:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
if self.background:
|
||||
await self.background()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
from collections.abc import Sequence
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
@@ -14,12 +16,12 @@ class CORSMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allow_origins: typing.Sequence[str] = (),
|
||||
allow_methods: typing.Sequence[str] = ("GET",),
|
||||
allow_headers: typing.Sequence[str] = (),
|
||||
allow_origins: Sequence[str] = (),
|
||||
allow_methods: Sequence[str] = ("GET",),
|
||||
allow_headers: Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
allow_origin_regex: typing.Optional[str] = None,
|
||||
expose_headers: typing.Sequence[str] = (),
|
||||
allow_origin_regex: str | None = None,
|
||||
expose_headers: Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
) -> None:
|
||||
if "*" in allow_methods:
|
||||
@@ -94,9 +96,7 @@ class CORSMiddleware:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
|
||||
origin
|
||||
):
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
|
||||
return True
|
||||
|
||||
return origin in self.allow_origins
|
||||
@@ -139,15 +139,11 @@ class CORSMiddleware:
|
||||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(
|
||||
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(
|
||||
self, message: Message, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
|
||||
|
||||
STYLES = """
|
||||
p {
|
||||
@@ -137,7 +139,7 @@ class ServerErrorMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handler: typing.Optional[typing.Callable] = None,
|
||||
handler: ExceptionHandler | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
@@ -183,9 +185,7 @@ class ServerErrorMiddleware:
|
||||
# to optionally raise the error within the test case.
|
||||
raise exc
|
||||
|
||||
def format_line(
|
||||
self, index: int, line: str, frame_lineno: int, frame_index: int
|
||||
) -> str:
|
||||
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
|
||||
values = {
|
||||
# HTML escape - line could contain < or >
|
||||
"line": html.escape(line).replace(" ", " "),
|
||||
@@ -199,7 +199,10 @@ class ServerErrorMiddleware:
|
||||
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
|
||||
code_context = "".join(
|
||||
self.format_line(
|
||||
index, line, frame.lineno, frame.index # type: ignore[arg-type]
|
||||
index,
|
||||
line,
|
||||
frame.lineno,
|
||||
frame.index, # type: ignore[arg-type]
|
||||
)
|
||||
for index, line in enumerate(frame.code_context or [])
|
||||
)
|
||||
@@ -219,9 +222,7 @@ class ServerErrorMiddleware:
|
||||
return FRAME_TEMPLATE.format(**values)
|
||||
|
||||
def generate_html(self, exc: Exception, limit: int = 7) -> str:
|
||||
traceback_obj = traceback.TracebackException.from_exception(
|
||||
exc, capture_locals=True
|
||||
)
|
||||
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
|
||||
|
||||
exc_html = ""
|
||||
is_collapsed = False
|
||||
@@ -232,11 +233,13 @@ class ServerErrorMiddleware:
|
||||
exc_html += self.generate_frame_html(frame, is_collapsed)
|
||||
is_collapsed = True
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type_str
|
||||
else: # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type.__name__
|
||||
|
||||
# escape error class and text
|
||||
error = (
|
||||
f"{html.escape(traceback_obj.exc_type.__name__)}: "
|
||||
f"{html.escape(str(traceback_obj))}"
|
||||
)
|
||||
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
|
||||
|
||||
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
|
||||
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from starlette._exception_handler import (
|
||||
ExceptionHandlers,
|
||||
StatusHandlers,
|
||||
wrap_app_handling_exceptions,
|
||||
)
|
||||
from starlette.exceptions import HTTPException, WebSocketException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
@@ -13,28 +19,24 @@ class ExceptionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handlers: typing.Optional[
|
||||
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
|
||||
] = None,
|
||||
handlers: Mapping[Any, ExceptionHandler] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
|
||||
self._status_handlers: typing.Dict[int, typing.Callable] = {}
|
||||
self._exception_handlers: typing.Dict[
|
||||
typing.Type[Exception], typing.Callable
|
||||
] = {
|
||||
self._status_handlers: StatusHandlers = {}
|
||||
self._exception_handlers: ExceptionHandlers = {
|
||||
HTTPException: self.http_exception,
|
||||
WebSocketException: self.websocket_exception,
|
||||
}
|
||||
if handlers is not None:
|
||||
if handlers is not None: # pragma: no branch
|
||||
for key, value in handlers.items():
|
||||
self.add_exception_handler(key, value)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
|
||||
handler: typing.Callable[[Request, Exception], Response],
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: ExceptionHandler,
|
||||
) -> None:
|
||||
if isinstance(exc_class_or_status_code, int):
|
||||
self._status_handlers[exc_class_or_status_code] = handler
|
||||
@@ -42,68 +44,30 @@ class ExceptionMiddleware:
|
||||
assert issubclass(exc_class_or_status_code, Exception)
|
||||
self._exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
def _lookup_exception_handler(
|
||||
self, exc: Exception
|
||||
) -> typing.Optional[typing.Callable]:
|
||||
for cls in type(exc).__mro__:
|
||||
if cls in self._exception_handlers:
|
||||
return self._exception_handlers[cls]
|
||||
return None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response_started = False
|
||||
|
||||
async def sender(message: Message) -> None:
|
||||
nonlocal response_started
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, sender)
|
||||
except Exception as exc:
|
||||
handler = None
|
||||
|
||||
if isinstance(exc, HTTPException):
|
||||
handler = self._status_handlers.get(exc.status_code)
|
||||
|
||||
if handler is None:
|
||||
handler = self._lookup_exception_handler(exc)
|
||||
|
||||
if handler is None:
|
||||
raise exc
|
||||
|
||||
if response_started:
|
||||
msg = "Caught handled exception, but response already started."
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
if scope["type"] == "http":
|
||||
request = Request(scope, receive=receive)
|
||||
if is_async_callable(handler):
|
||||
response = await handler(request, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, request, exc)
|
||||
await response(scope, receive, sender)
|
||||
elif scope["type"] == "websocket":
|
||||
websocket = WebSocket(scope, receive=receive, send=send)
|
||||
if is_async_callable(handler):
|
||||
await handler(websocket, exc)
|
||||
else:
|
||||
await run_in_threadpool(handler, websocket, exc)
|
||||
|
||||
def http_exception(self, request: Request, exc: HTTPException) -> Response:
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(
|
||||
exc.detail, status_code=exc.status_code, headers=exc.headers
|
||||
scope["starlette.exception_handlers"] = (
|
||||
self._exception_handlers,
|
||||
self._status_handlers,
|
||||
)
|
||||
|
||||
async def websocket_exception(
|
||||
self, websocket: WebSocket, exc: WebSocketException
|
||||
) -> None:
|
||||
await websocket.close(code=exc.code, reason=exc.reason)
|
||||
conn: Request | WebSocket
|
||||
if scope["type"] == "http":
|
||||
conn = Request(scope, receive, send)
|
||||
else:
|
||||
conn = WebSocket(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
|
||||
|
||||
async def http_exception(self, request: Request, exc: Exception) -> Response:
|
||||
assert isinstance(exc, HTTPException)
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
|
||||
|
||||
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
|
||||
assert isinstance(exc, WebSocketException)
|
||||
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
|
||||
|
||||
@@ -1,49 +1,51 @@
|
||||
import gzip
|
||||
import io
|
||||
import typing
|
||||
from typing import NoReturn
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
|
||||
|
||||
|
||||
class GZipMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
|
||||
) -> None:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.compresslevel = compresslevel
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "http":
|
||||
headers = Headers(scope=scope)
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(
|
||||
self.app, self.minimum_size, compresslevel=self.compresslevel
|
||||
)
|
||||
await responder(scope, receive, send)
|
||||
return
|
||||
await self.app(scope, receive, send)
|
||||
if scope["type"] != "http": # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
responder: ASGIApp
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
|
||||
else:
|
||||
responder = IdentityResponder(self.app, self.minimum_size)
|
||||
|
||||
await responder(scope, receive, send)
|
||||
|
||||
|
||||
class GZipResponder:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
|
||||
class IdentityResponder:
|
||||
content_encoding: str
|
||||
|
||||
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.send: Send = unattached_send
|
||||
self.initial_message: Message = {}
|
||||
self.started = False
|
||||
self.content_encoding_set = False
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(
|
||||
mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
|
||||
)
|
||||
self.content_type_is_excluded = False
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
await self.app(scope, receive, self.send_with_gzip)
|
||||
await self.app(scope, receive, self.send_with_compression)
|
||||
|
||||
async def send_with_gzip(self, message: Message) -> None:
|
||||
async def send_with_compression(self, message: Message) -> None:
|
||||
message_type = message["type"]
|
||||
if message_type == "http.response.start":
|
||||
# Don't send the initial message until we've determined how to
|
||||
@@ -51,7 +53,8 @@ class GZipResponder:
|
||||
self.initial_message = message
|
||||
headers = Headers(raw=self.initial_message["headers"])
|
||||
self.content_encoding_set = "content-encoding" in headers
|
||||
elif message_type == "http.response.body" and self.content_encoding_set:
|
||||
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
|
||||
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
|
||||
if not self.started:
|
||||
self.started = True
|
||||
await self.send(self.initial_message)
|
||||
@@ -61,53 +64,82 @@ class GZipResponder:
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if len(body) < self.minimum_size and not more_body:
|
||||
# Don't apply GZip to small outgoing responses.
|
||||
# Don't apply compression to small outgoing responses.
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif not more_body:
|
||||
# Standard GZip response.
|
||||
self.gzip_file.write(body)
|
||||
self.gzip_file.close()
|
||||
body = self.gzip_buffer.getvalue()
|
||||
# Standard response.
|
||||
body = self.apply_compression(body, more_body=False)
|
||||
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers["Content-Length"] = str(len(body))
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
message["body"] = body
|
||||
if body != message["body"]:
|
||||
headers["Content-Encoding"] = self.content_encoding
|
||||
headers["Content-Length"] = str(len(body))
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
else:
|
||||
# Initial body in streaming GZip response.
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
del headers["Content-Length"]
|
||||
# Initial body in streaming response.
|
||||
body = self.apply_compression(body, more_body=True)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
if body != message["body"]:
|
||||
headers["Content-Encoding"] = self.content_encoding
|
||||
del headers["Content-Length"]
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
elif message_type == "http.response.body":
|
||||
# Remaining body in streaming GZip response.
|
||||
# Remaining body in streaming response.
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
message["body"] = self.apply_compression(body, more_body=more_body)
|
||||
|
||||
await self.send(message)
|
||||
elif message_type == "http.response.pathsend": # pragma: no branch
|
||||
# Don't apply GZip to pathsend responses
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
|
||||
"""Apply compression on the response body.
|
||||
|
||||
If more_body is False, any compression file should be closed. If it
|
||||
isn't, it won't be closed automatically until all background tasks
|
||||
complete.
|
||||
"""
|
||||
return body
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> typing.NoReturn:
|
||||
class GZipResponder(IdentityResponder):
|
||||
content_encoding = "gzip"
|
||||
|
||||
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
|
||||
super().__init__(app, minimum_size)
|
||||
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
with self.gzip_buffer, self.gzip_file:
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
body = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
return body
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> NoReturn:
|
||||
raise RuntimeError("send awaitable not set") # pragma: no cover
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import typing
|
||||
from base64 import b64decode, b64encode
|
||||
from typing import Literal
|
||||
|
||||
import itsdangerous
|
||||
from itsdangerous.exc import BadSignature
|
||||
@@ -10,22 +11,18 @@ from starlette.datastructures import MutableHeaders, Secret
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
from typing import Literal
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class SessionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
secret_key: typing.Union[str, Secret],
|
||||
secret_key: str | Secret,
|
||||
session_cookie: str = "session",
|
||||
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
path: str = "/",
|
||||
same_site: Literal["lax", "strict", "none"] = "lax",
|
||||
https_only: bool = False,
|
||||
domain: str | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.signer = itsdangerous.TimestampSigner(str(secret_key))
|
||||
@@ -35,6 +32,8 @@ class SessionMiddleware:
|
||||
self.security_flags = "httponly; samesite=" + same_site
|
||||
if https_only: # Secure flag can be used with HTTPS only
|
||||
self.security_flags += "; secure"
|
||||
if domain is not None:
|
||||
self.security_flags += f"; domain={domain}"
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"): # pragma: no cover
|
||||
@@ -62,7 +61,7 @@ class SessionMiddleware:
|
||||
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
|
||||
data = self.signer.sign(data)
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data=data.decode("utf-8"),
|
||||
path=self.path,
|
||||
@@ -73,7 +72,7 @@ class SessionMiddleware:
|
||||
elif not initial_session_was_empty:
|
||||
# The session has been cleared.
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data="null",
|
||||
path=self.path,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
@@ -11,7 +13,7 @@ class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
|
||||
allowed_hosts: Sequence[str] | None = None,
|
||||
www_redirect: bool = True,
|
||||
) -> None:
|
||||
if allowed_hosts is None:
|
||||
@@ -39,9 +41,7 @@ class TrustedHostMiddleware:
|
||||
is_valid_host = False
|
||||
found_www_redirect = False
|
||||
for pattern in self.allowed_hosts:
|
||||
if host == pattern or (
|
||||
pattern.startswith("*") and host.endswith(pattern[1:])
|
||||
):
|
||||
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
|
||||
is_valid_host = True
|
||||
break
|
||||
elif "www." + host == pattern:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Callable, MutableMapping
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||||
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
@@ -15,14 +19,20 @@ warnings.warn(
|
||||
)
|
||||
|
||||
|
||||
def build_environ(scope: Scope, body: bytes) -> dict:
|
||||
def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
|
||||
"""
|
||||
Builds a scope and request body into a WSGI environ object.
|
||||
"""
|
||||
|
||||
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
|
||||
path_info = scope["path"].encode("utf8").decode("latin1")
|
||||
if path_info.startswith(script_name):
|
||||
path_info = path_info[len(script_name) :]
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
|
||||
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
|
||||
"SCRIPT_NAME": script_name,
|
||||
"PATH_INFO": path_info,
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
|
||||
"wsgi.version": (1, 0),
|
||||
@@ -62,7 +72,7 @@ def build_environ(scope: Scope, body: bytes) -> dict:
|
||||
|
||||
|
||||
class WSGIMiddleware:
|
||||
def __init__(self, app: typing.Callable) -> None:
|
||||
def __init__(self, app: Callable[..., Any]) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
@@ -72,16 +82,17 @@ class WSGIMiddleware:
|
||||
|
||||
|
||||
class WSGIResponder:
|
||||
def __init__(self, app: typing.Callable, scope: Scope) -> None:
|
||||
stream_send: ObjectSendStream[MutableMapping[str, Any]]
|
||||
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
|
||||
|
||||
def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
|
||||
self.app = app
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
|
||||
math.inf
|
||||
)
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
|
||||
self.response_started = False
|
||||
self.exc_info: typing.Any = None
|
||||
self.exc_info: Any = None
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
body = b""
|
||||
@@ -107,11 +118,11 @@ class WSGIResponder:
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: typing.List[typing.Tuple[str, str]],
|
||||
exc_info: typing.Any = None,
|
||||
response_headers: list[tuple[str, str]],
|
||||
exc_info: Any = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
if not self.response_started: # pragma: no branch
|
||||
self.response_started = True
|
||||
status_code_string, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_string)
|
||||
@@ -128,13 +139,15 @@ class WSGIResponder:
|
||||
},
|
||||
)
|
||||
|
||||
def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
|
||||
def wsgi(
|
||||
self,
|
||||
environ: dict[str, Any],
|
||||
start_response: Callable[..., Any],
|
||||
) -> None:
|
||||
for chunk in self.app(environ, start_response):
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send,
|
||||
{"type": "http.response.body", "body": chunk, "more_body": True},
|
||||
)
|
||||
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send, {"type": "http.response.body", "body": b""}
|
||||
)
|
||||
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
|
||||
|
||||
Reference in New Issue
Block a user