This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import json
import typing
from collections.abc import AsyncGenerator, Iterator, Mapping
from http import cookies as http_cookies
from typing import TYPE_CHECKING, Any, NoReturn, cast
import anyio
@@ -10,14 +13,19 @@ from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
try:
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: nocover
parse_options_header = None
if TYPE_CHECKING:
from python_multipart.multipart import parse_options_header
if typing.TYPE_CHECKING:
from starlette.applications import Starlette
from starlette.routing import Router
else:
try:
try:
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
parse_options_header = None
SERVER_PUSH_HEADERS_TO_COPY = {
@@ -29,7 +37,7 @@ SERVER_PUSH_HEADERS_TO_COPY = {
}
def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
def cookie_parser(cookie_string: str) -> dict[str, str]:
"""
This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
@@ -41,7 +49,7 @@ def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
on an outdated spec and will fail on lots of input we want to support
"""
cookie_dict: typing.Dict[str, str] = {}
cookie_dict: dict[str, str] = {}
for chunk in cookie_string.split(";"):
if "=" in chunk:
key, val = chunk.split("=", 1)
@@ -60,20 +68,20 @@ class ClientDisconnect(Exception):
pass
class HTTPConnection(typing.Mapping[str, typing.Any]):
class HTTPConnection(Mapping[str, Any]):
"""
A base class for incoming HTTP connections, that is used to provide
any functionality that is common to both `Request` and `WebSocket`.
"""
def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None:
def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
assert scope["type"] in ("http", "websocket")
self.scope = scope
def __getitem__(self, key: str) -> typing.Any:
def __getitem__(self, key: str) -> Any:
return self.scope[key]
def __iter__(self) -> typing.Iterator[str]:
def __iter__(self) -> Iterator[str]:
return iter(self.scope)
def __len__(self) -> int:
@@ -86,12 +94,12 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
__hash__ = object.__hash__
@property
def app(self) -> typing.Any:
def app(self) -> Any:
return self.scope["app"]
@property
def url(self) -> URL:
if not hasattr(self, "_url"):
if not hasattr(self, "_url"): # pragma: no branch
self._url = URL(scope=self.scope)
return self._url
@@ -99,11 +107,16 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
def base_url(self) -> URL:
if not hasattr(self, "_base_url"):
base_url_scope = dict(self.scope)
base_url_scope["path"] = "/"
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
path = app_root_path
if not path.endswith("/"):
path += "/"
base_url_scope["path"] = path
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
base_url_scope["root_path"] = app_root_path
self._base_url = URL(scope=base_url_scope)
return self._base_url
@@ -115,52 +128,47 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
@property
def query_params(self) -> QueryParams:
if not hasattr(self, "_query_params"):
if not hasattr(self, "_query_params"): # pragma: no branch
self._query_params = QueryParams(self.scope["query_string"])
return self._query_params
@property
def path_params(self) -> typing.Dict[str, typing.Any]:
def path_params(self) -> dict[str, Any]:
return self.scope.get("path_params", {})
@property
def cookies(self) -> typing.Dict[str, str]:
def cookies(self) -> dict[str, str]:
if not hasattr(self, "_cookies"):
cookies: typing.Dict[str, str] = {}
cookie_header = self.headers.get("cookie")
cookies: dict[str, str] = {}
cookie_headers = self.headers.getlist("cookie")
for header in cookie_headers:
cookies.update(cookie_parser(header))
if cookie_header:
cookies = cookie_parser(cookie_header)
self._cookies = cookies
return self._cookies
@property
def client(self) -> typing.Optional[Address]:
# client is a 2 item tuple of (host, port), None or missing
def client(self) -> Address | None:
# client is a 2 item tuple of (host, port), None if missing
host_port = self.scope.get("client")
if host_port is not None:
return Address(*host_port)
return None
@property
def session(self) -> typing.Dict[str, typing.Any]:
assert (
"session" in self.scope
), "SessionMiddleware must be installed to access request.session"
return self.scope["session"]
def session(self) -> dict[str, Any]:
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
return self.scope["session"] # type: ignore[no-any-return]
@property
def auth(self) -> typing.Any:
assert (
"auth" in self.scope
), "AuthenticationMiddleware must be installed to access request.auth"
def auth(self) -> Any:
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
return self.scope["auth"]
@property
def user(self) -> typing.Any:
assert (
"user" in self.scope
), "AuthenticationMiddleware must be installed to access request.user"
def user(self) -> Any:
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
return self.scope["user"]
@property
@@ -173,26 +181,26 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
self._state = State(self.scope["state"])
return self._state
def url_for(self, __name: str, **path_params: typing.Any) -> URL:
router: Router = self.scope["router"]
url_path = router.url_path_for(__name, **path_params)
def url_for(self, name: str, /, **path_params: Any) -> URL:
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
if url_path_provider is None:
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
url_path = url_path_provider.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.base_url)
async def empty_receive() -> typing.NoReturn:
async def empty_receive() -> NoReturn:
raise RuntimeError("Receive channel has not been made available")
async def empty_send(message: Message) -> typing.NoReturn:
async def empty_send(message: Message) -> NoReturn:
raise RuntimeError("Send channel has not been made available")
class Request(HTTPConnection):
_form: typing.Optional[FormData]
_form: FormData | None
def __init__(
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
):
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
super().__init__(scope)
assert scope["type"] == "http"
self._receive = receive
@@ -203,43 +211,42 @@ class Request(HTTPConnection):
@property
def method(self) -> str:
return self.scope["method"]
return cast(str, self.scope["method"])
@property
def receive(self) -> Receive:
return self._receive
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
async def stream(self) -> AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):
yield self._body
yield b""
return
if self._stream_consumed:
raise RuntimeError("Stream consumed")
self._stream_consumed = True
while True:
while not self._stream_consumed:
message = await self._receive()
if message["type"] == "http.request":
body = message.get("body", b"")
if not message.get("more_body", False):
self._stream_consumed = True
if body:
yield body
if not message.get("more_body", False):
break
elif message["type"] == "http.disconnect":
elif message["type"] == "http.disconnect": # pragma: no branch
self._is_disconnected = True
raise ClientDisconnect()
yield b""
async def body(self) -> bytes:
if not hasattr(self, "_body"):
chunks: "typing.List[bytes]" = []
chunks: list[bytes] = []
async for chunk in self.stream():
chunks.append(chunk)
self._body = b"".join(chunks)
return self._body
async def json(self) -> typing.Any:
if not hasattr(self, "_json"):
async def json(self) -> Any:
if not hasattr(self, "_json"): # pragma: no branch
body = await self.body()
self._json = json.loads(body)
return self._json
@@ -247,13 +254,14 @@ class Request(HTTPConnection):
async def _get_form(
self,
*,
max_files: typing.Union[int, float] = 1000,
max_fields: typing.Union[int, float] = 1000,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
), "The `python-multipart` library must be installed to use form parsing."
if self._form is None: # pragma: no branch
assert parse_options_header is not None, (
"The `python-multipart` library must be installed to use form parsing."
)
content_type_header = self.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
@@ -264,6 +272,7 @@ class Request(HTTPConnection):
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_part_size=max_part_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
@@ -280,15 +289,16 @@ class Request(HTTPConnection):
def form(
self,
*,
max_files: typing.Union[int, float] = 1000,
max_fields: typing.Union[int, float] = 1000,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields)
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
)
async def close(self) -> None:
if self._form is not None:
if self._form is not None: # pragma: no branch
await self._form.close()
async def is_disconnected(self) -> bool:
@@ -307,12 +317,8 @@ class Request(HTTPConnection):
async def send_push_promise(self, path: str) -> None:
if "http.response.push" in self.scope.get("extensions", {}):
raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = []
raw_headers: list[tuple[bytes, bytes]] = []
for name in SERVER_PUSH_HEADERS_TO_COPY:
for value in self.headers.getlist(name):
raw_headers.append(
(name.encode("latin-1"), value.encode("latin-1"))
)
await self._send(
{"type": "http.response.push", "path": path, "headers": raw_headers}
)
raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})