This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -1,17 +1,37 @@
import typing
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterator
from typing import Any, ParamSpec, Protocol
P = ParamSpec("P")
_Scope = Any
_Receive = Callable[[], Awaitable[Any]]
_Send = Callable[[Any], Awaitable[None]]
# Since `starlette.types.ASGIApp` type differs from `ASGIApplication` from `asgiref`
# we need to define a more permissive version of ASGIApp that doesn't cause type errors.
_ASGIApp = Callable[[_Scope, _Receive, _Send], Awaitable[None]]
class _MiddlewareFactory(Protocol[P]):
def __call__(self, app: _ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> _ASGIApp: ... # pragma: no cover
class Middleware:
def __init__(self, cls: type, **options: typing.Any) -> None:
def __init__(self, cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None:
self.cls = cls
self.options = options
self.args = args
self.kwargs = kwargs
def __iter__(self) -> typing.Iterator:
as_tuple = (self.cls, self.options)
def __iter__(self) -> Iterator[Any]:
as_tuple = (self.cls, self.args, self.kwargs)
return iter(as_tuple)
def __repr__(self) -> str:
class_name = self.__class__.__name__
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
args_repr = ", ".join([self.cls.__name__] + option_strings)
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
name = getattr(self.cls, "__name__", "")
args_repr = ", ".join([name] + args_strings + option_strings)
return f"{class_name}({args_repr})"

View File

@@ -1,4 +1,6 @@
import typing
from __future__ import annotations
from collections.abc import Callable
from starlette.authentication import (
AuthCredentials,
@@ -16,15 +18,13 @@ class AuthenticationMiddleware:
self,
app: ASGIApp,
backend: AuthenticationBackend,
on_error: typing.Optional[
typing.Callable[[HTTPConnection, AuthenticationError], Response]
] = None,
on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
) -> None:
self.app = app
self.backend = backend
self.on_error: typing.Callable[
[HTTPConnection, AuthenticationError], Response
] = (on_error if on_error is not None else self.default_on_error)
self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = (
on_error if on_error is not None else self.default_on_error
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ["http", "websocket"]:

View File

@@ -1,23 +1,100 @@
import typing
from __future__ import annotations
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Mapping, MutableMapping
from typing import Any, TypeVar
import anyio
from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette._utils import collapse_excgroups
from starlette.requests import ClientDisconnect, Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")
RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
BodyStreamGenerator = AsyncGenerator[bytes | MutableMapping[str, Any], None]
AsyncContentStream = AsyncIterable[str | bytes | memoryview | MutableMapping[str, Any]]
T = TypeVar("T")
class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
we cache the entire request body in memory and pass that to downstream middlewares,
but if they call Request.stream() then all we do is send an
empty body so that downstream things don't hang forever.
"""
def __init__(self, scope: Scope, receive: Receive):
super().__init__(scope, receive)
self._wrapped_rcv_disconnected = False
self._wrapped_rcv_consumed = False
self._wrapped_rc_stream = self.stream()
async def wrapped_receive(self) -> Message:
# wrapped_rcv state 1: disconnected
if self._wrapped_rcv_disconnected:
# we've already sent a disconnect to the downstream app
# we don't need to wait to get another one
# (although most ASGI servers will just keep sending it)
return {"type": "http.disconnect"}
# wrapped_rcv state 1: consumed but not yet disconnected
if self._wrapped_rcv_consumed:
# since the downstream app has consumed us all that is left
# is to send it a disconnect
if self._is_disconnected:
# the middleware has already seen the disconnect
# since we know the client is disconnected no need to wait
# for the message
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# we don't know yet if the client is disconnected or not
# so we'll wait until we get that message
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
self._wrapped_rcv_disconnected = True
return msg
# wrapped_rcv state 3: not yet consumed
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion
# return an empty body so that downstream apps don't hang
# waiting for a disconnect
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": b"",
"more_body": False,
}
else:
# body() was never called and stream() wasn't consumed
try:
stream = self.stream()
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": chunk,
"more_body": not self._stream_consumed,
}
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
) -> None:
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch is None else dispatch
@@ -26,35 +103,32 @@ class BaseHTTPMiddleware:
await self.app(scope, receive, send)
return
request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()
app_exc: Exception | None = None
exception_already_raised = False
async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()
async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}
async with anyio.create_task_group() as task_group:
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
async def wrap(func: Callable[[], Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result
task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
message = await wrap(wrapped_receive)
if response_sent.is_set():
return {"type": "http.disconnect"}
return message
async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()
async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
@@ -65,13 +139,12 @@ class BaseHTTPMiddleware:
async def coro() -> None:
nonlocal app_exc
async with send_stream:
with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc
task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)
try:
@@ -81,54 +154,91 @@ class BaseHTTPMiddleware:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
nonlocal exception_already_raised
exception_already_raised = True
# Prevent `anyio.EndOfStream` from polluting app exception context.
# If both cause and context are None then the context is suppressed
# and `anyio.EndOfStream` is not present in the exception traceback.
# If exception cause is not None then it is propagated with
# reraising here.
# If exception has no cause but has context set then the context is
# propagated as a cause with the reraise. This is necessary in order
# to prevent `anyio.EndOfStream` from polluting the exception
# context.
raise app_exc from app_exc.__cause__ or app_exc.__context__
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
async def body_stream() -> BodyStreamGenerator:
async for message in recv_stream:
if message["type"] == "http.response.pathsend":
yield message
break
assert message["type"] == "http.response.body", f"Unexpected message: {message}"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break
if app_exc is not None:
raise app_exc
response = _StreamingResponse(
status_code=message["status"], content=body_stream(), info=info
)
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
response.raw_headers = message["headers"]
return response
async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
response_sent.set()
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
send_stream, recv_stream = streams
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
recv_stream.close()
if app_exc is not None and not exception_already_raised:
raise app_exc
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
class _StreamingResponse(StreamingResponse):
class _StreamingResponse(Response):
def __init__(
self,
content: ContentStream,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
info: Mapping[str, Any] | None = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)
self.background = None
async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
should_close_body = True
async for chunk in self.body_iterator:
if isinstance(chunk, dict):
# We got an ASGI message which is not response body (eg: pathsend)
should_close_body = False
await send(chunk)
continue
await send({"type": "http.response.body", "body": chunk, "more_body": True})
if should_close_body:
await send({"type": "http.response.body", "body": b"", "more_body": False})
if self.background:
await self.background()

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import functools
import re
import typing
from collections.abc import Sequence
from starlette.datastructures import Headers, MutableHeaders
from starlette.responses import PlainTextResponse, Response
@@ -14,12 +16,12 @@ class CORSMiddleware:
def __init__(
self,
app: ASGIApp,
allow_origins: typing.Sequence[str] = (),
allow_methods: typing.Sequence[str] = ("GET",),
allow_headers: typing.Sequence[str] = (),
allow_origins: Sequence[str] = (),
allow_methods: Sequence[str] = ("GET",),
allow_headers: Sequence[str] = (),
allow_credentials: bool = False,
allow_origin_regex: typing.Optional[str] = None,
expose_headers: typing.Sequence[str] = (),
allow_origin_regex: str | None = None,
expose_headers: Sequence[str] = (),
max_age: int = 600,
) -> None:
if "*" in allow_methods:
@@ -94,9 +96,7 @@ class CORSMiddleware:
if self.allow_all_origins:
return True
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
origin
):
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
return True
return origin in self.allow_origins
@@ -139,15 +139,11 @@ class CORSMiddleware:
return PlainTextResponse("OK", status_code=200, headers=headers)
async def simple_response(
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
) -> None:
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
send = functools.partial(self.send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)
async def send(
self, message: Message, send: Send, request_headers: Headers
) -> None:
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
if message["type"] != "http.response.start":
await send(message)
return

View File

@@ -1,13 +1,15 @@
from __future__ import annotations
import html
import inspect
import sys
import traceback
import typing
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
STYLES = """
p {
@@ -137,7 +139,7 @@ class ServerErrorMiddleware:
def __init__(
self,
app: ASGIApp,
handler: typing.Optional[typing.Callable] = None,
handler: ExceptionHandler | None = None,
debug: bool = False,
) -> None:
self.app = app
@@ -183,9 +185,7 @@ class ServerErrorMiddleware:
# to optionally raise the error within the test case.
raise exc
def format_line(
self, index: int, line: str, frame_lineno: int, frame_index: int
) -> str:
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
values = {
# HTML escape - line could contain < or >
"line": html.escape(line).replace(" ", "&nbsp"),
@@ -199,7 +199,10 @@ class ServerErrorMiddleware:
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
code_context = "".join(
self.format_line(
index, line, frame.lineno, frame.index # type: ignore[arg-type]
index,
line,
frame.lineno,
frame.index, # type: ignore[arg-type]
)
for index, line in enumerate(frame.code_context or [])
)
@@ -219,9 +222,7 @@ class ServerErrorMiddleware:
return FRAME_TEMPLATE.format(**values)
def generate_html(self, exc: Exception, limit: int = 7) -> str:
traceback_obj = traceback.TracebackException.from_exception(
exc, capture_locals=True
)
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
exc_html = ""
is_collapsed = False
@@ -232,11 +233,13 @@ class ServerErrorMiddleware:
exc_html += self.generate_frame_html(frame, is_collapsed)
is_collapsed = True
if sys.version_info >= (3, 13): # pragma: no cover
exc_type_str = traceback_obj.exc_type_str
else: # pragma: no cover
exc_type_str = traceback_obj.exc_type.__name__
# escape error class and text
error = (
f"{html.escape(traceback_obj.exc_type.__name__)}: "
f"{html.escape(str(traceback_obj))}"
)
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)

View File

@@ -1,11 +1,17 @@
import typing
from __future__ import annotations
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from collections.abc import Mapping
from typing import Any
from starlette._exception_handler import (
ExceptionHandlers,
StatusHandlers,
wrap_app_handling_exceptions,
)
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
from starlette.websockets import WebSocket
@@ -13,28 +19,24 @@ class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: typing.Optional[
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
] = None,
handlers: Mapping[Any, ExceptionHandler] | None = None,
debug: bool = False,
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
}
if handlers is not None:
if handlers is not None: # pragma: no branch
for key, value in handlers.items():
self.add_exception_handler(key, value)
def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable[[Request, Exception], Response],
exc_class_or_status_code: int | type[Exception],
handler: ExceptionHandler,
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
@@ -42,68 +44,30 @@ class ExceptionMiddleware:
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler
def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return
response_started = False
async def sender(message: Message) -> None:
nonlocal response_started
if message["type"] == "http.response.start":
response_started = True
await send(message)
try:
await self.app(scope, receive, sender)
except Exception as exc:
handler = None
if isinstance(exc, HTTPException):
handler = self._status_handlers.get(exc.status_code)
if handler is None:
handler = self._lookup_exception_handler(exc)
if handler is None:
raise exc
if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc
if scope["type"] == "http":
request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if is_async_callable(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)
def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)
async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
conn: Request | WebSocket
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
conn = WebSocket(scope, receive, send)
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
async def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover

View File

@@ -1,49 +1,51 @@
import gzip
import io
import typing
from typing import NoReturn
from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
class GZipMiddleware:
def __init__(
self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
) -> None:
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
self.app = app
self.minimum_size = minimum_size
self.compresslevel = compresslevel
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
headers = Headers(scope=scope)
if "gzip" in headers.get("Accept-Encoding", ""):
responder = GZipResponder(
self.app, self.minimum_size, compresslevel=self.compresslevel
)
await responder(scope, receive, send)
return
await self.app(scope, receive, send)
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
responder: ASGIApp
if "gzip" in headers.get("Accept-Encoding", ""):
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
else:
responder = IdentityResponder(self.app, self.minimum_size)
await responder(scope, receive, send)
class GZipResponder:
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
class IdentityResponder:
content_encoding: str
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
self.app = app
self.minimum_size = minimum_size
self.send: Send = unattached_send
self.initial_message: Message = {}
self.started = False
self.content_encoding_set = False
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(
mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
)
self.content_type_is_excluded = False
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
await self.app(scope, receive, self.send_with_gzip)
await self.app(scope, receive, self.send_with_compression)
async def send_with_gzip(self, message: Message) -> None:
async def send_with_compression(self, message: Message) -> None:
message_type = message["type"]
if message_type == "http.response.start":
# Don't send the initial message until we've determined how to
@@ -51,7 +53,8 @@ class GZipResponder:
self.initial_message = message
headers = Headers(raw=self.initial_message["headers"])
self.content_encoding_set = "content-encoding" in headers
elif message_type == "http.response.body" and self.content_encoding_set:
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
if not self.started:
self.started = True
await self.send(self.initial_message)
@@ -61,53 +64,82 @@ class GZipResponder:
body = message.get("body", b"")
more_body = message.get("more_body", False)
if len(body) < self.minimum_size and not more_body:
# Don't apply GZip to small outgoing responses.
# Don't apply compression to small outgoing responses.
await self.send(self.initial_message)
await self.send(message)
elif not more_body:
# Standard GZip response.
self.gzip_file.write(body)
self.gzip_file.close()
body = self.gzip_buffer.getvalue()
# Standard response.
body = self.apply_compression(body, more_body=False)
headers = MutableHeaders(raw=self.initial_message["headers"])
headers["Content-Encoding"] = "gzip"
headers["Content-Length"] = str(len(body))
headers.add_vary_header("Accept-Encoding")
message["body"] = body
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
headers["Content-Length"] = str(len(body))
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
else:
# Initial body in streaming GZip response.
headers = MutableHeaders(raw=self.initial_message["headers"])
headers["Content-Encoding"] = "gzip"
headers.add_vary_header("Accept-Encoding")
del headers["Content-Length"]
# Initial body in streaming response.
body = self.apply_compression(body, more_body=True)
self.gzip_file.write(body)
message["body"] = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
headers = MutableHeaders(raw=self.initial_message["headers"])
headers.add_vary_header("Accept-Encoding")
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
del headers["Content-Length"]
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body":
# Remaining body in streaming GZip response.
# Remaining body in streaming response.
body = message.get("body", b"")
more_body = message.get("more_body", False)
self.gzip_file.write(body)
if not more_body:
self.gzip_file.close()
message["body"] = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
message["body"] = self.apply_compression(body, more_body=more_body)
await self.send(message)
elif message_type == "http.response.pathsend": # pragma: no branch
# Don't apply GZip to pathsend responses
await self.send(self.initial_message)
await self.send(message)
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
"""Apply compression on the response body.
If more_body is False, any compression file should be closed. If it
isn't, it won't be closed automatically until all background tasks
complete.
"""
return body
async def unattached_send(message: Message) -> typing.NoReturn:
class GZipResponder(IdentityResponder):
content_encoding = "gzip"
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
super().__init__(app, minimum_size)
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
with self.gzip_buffer, self.gzip_file:
await super().__call__(scope, receive, send)
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
self.gzip_file.write(body)
if not more_body:
self.gzip_file.close()
body = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
return body
async def unattached_send(message: Message) -> NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
import json
import sys
import typing
from base64 import b64decode, b64encode
from typing import Literal
import itsdangerous
from itsdangerous.exc import BadSignature
@@ -10,22 +11,18 @@ from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
if sys.version_info >= (3, 8): # pragma: no cover
from typing import Literal
else: # pragma: no cover
from typing_extensions import Literal
class SessionMiddleware:
def __init__(
self,
app: ASGIApp,
secret_key: typing.Union[str, Secret],
secret_key: str | Secret,
session_cookie: str = "session",
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
path: str = "/",
same_site: Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
domain: str | None = None,
) -> None:
self.app = app
self.signer = itsdangerous.TimestampSigner(str(secret_key))
@@ -35,6 +32,8 @@ class SessionMiddleware:
self.security_flags = "httponly; samesite=" + same_site
if https_only: # Secure flag can be used with HTTPS only
self.security_flags += "; secure"
if domain is not None:
self.security_flags += f"; domain={domain}"
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
@@ -62,7 +61,7 @@ class SessionMiddleware:
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
@@ -73,7 +72,7 @@ class SessionMiddleware:
elif not initial_session_was_empty:
# The session has been cleared.
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
session_cookie=self.session_cookie,
data="null",
path=self.path,

View File

@@ -1,4 +1,6 @@
import typing
from __future__ import annotations
from collections.abc import Sequence
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse, Response
@@ -11,7 +13,7 @@ class TrustedHostMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
allowed_hosts: Sequence[str] | None = None,
www_redirect: bool = True,
) -> None:
if allowed_hosts is None:
@@ -39,9 +41,7 @@ class TrustedHostMiddleware:
is_valid_host = False
found_www_redirect = False
for pattern in self.allowed_hosts:
if host == pattern or (
pattern.startswith("*") and host.endswith(pattern[1:])
):
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
is_valid_host = True
break
elif "www." + host == pattern:

View File

@@ -1,10 +1,14 @@
from __future__ import annotations
import io
import math
import sys
import typing
import warnings
from collections.abc import Callable, MutableMapping
from typing import Any
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from starlette.types import Receive, Scope, Send
@@ -15,14 +19,20 @@ warnings.warn(
)
def build_environ(scope: Scope, body: bytes) -> dict:
def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
path_info = scope["path"].encode("utf8").decode("latin1")
if path_info.startswith(script_name):
path_info = path_info[len(script_name) :]
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
"SCRIPT_NAME": script_name,
"PATH_INFO": path_info,
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
"wsgi.version": (1, 0),
@@ -62,7 +72,7 @@ def build_environ(scope: Scope, body: bytes) -> dict:
class WSGIMiddleware:
def __init__(self, app: typing.Callable) -> None:
def __init__(self, app: Callable[..., Any]) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -72,16 +82,17 @@ class WSGIMiddleware:
class WSGIResponder:
def __init__(self, app: typing.Callable, scope: Scope) -> None:
stream_send: ObjectSendStream[MutableMapping[str, Any]]
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
self.app = app
self.scope = scope
self.status = None
self.response_headers = None
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
math.inf
)
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
self.response_started = False
self.exc_info: typing.Any = None
self.exc_info: Any = None
async def __call__(self, receive: Receive, send: Send) -> None:
body = b""
@@ -107,11 +118,11 @@ class WSGIResponder:
def start_response(
self,
status: str,
response_headers: typing.List[typing.Tuple[str, str]],
exc_info: typing.Any = None,
response_headers: list[tuple[str, str]],
exc_info: Any = None,
) -> None:
self.exc_info = exc_info
if not self.response_started:
if not self.response_started: # pragma: no branch
self.response_started = True
status_code_string, _ = status.split(" ", 1)
status_code = int(status_code_string)
@@ -128,13 +139,15 @@ class WSGIResponder:
},
)
def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
def wsgi(
self,
environ: dict[str, Any],
start_response: Callable[..., Any],
) -> None:
for chunk in self.app(environ, start_response):
anyio.from_thread.run(
self.stream_send.send,
{"type": "http.response.body", "body": chunk, "more_body": True},
)
anyio.from_thread.run(
self.stream_send.send, {"type": "http.response.body", "body": b""}
)
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})