updates
This commit is contained in:
@@ -1,22 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
import traceback
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import Awaitable, Callable, Collection, Generator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
|
||||
from enum import Enum
|
||||
from re import Pattern
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette._exception_handler import wrap_app_handling_exceptions
|
||||
from starlette._utils import get_route_path, is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.convertors import CONVERTOR_TYPES, Convertor
|
||||
from starlette.datastructures import URL, Headers, URLPath
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket, WebSocketClose
|
||||
|
||||
@@ -27,7 +32,7 @@ class NoMatchFound(Exception):
|
||||
if no matching route exists.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, path_params: typing.Dict[str, typing.Any]) -> None:
|
||||
def __init__(self, name: str, path_params: dict[str, Any]) -> None:
|
||||
params = ", ".join(list(path_params.keys()))
|
||||
super().__init__(f'No route exists for name "{name}" and params "{params}".')
|
||||
|
||||
@@ -38,14 +43,13 @@ class Match(Enum):
|
||||
FULL = 2
|
||||
|
||||
|
||||
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
|
||||
def iscoroutinefunction_or_partial(obj: Any) -> bool: # pragma: no cover
|
||||
"""
|
||||
Correctly determines if an object is a coroutine function,
|
||||
including those wrapped in functools.partial objects.
|
||||
"""
|
||||
warnings.warn(
|
||||
"iscoroutinefunction_or_partial is deprecated, "
|
||||
"and will be removed in a future release.",
|
||||
"iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
while isinstance(obj, functools.partial):
|
||||
@@ -53,25 +57,32 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
|
||||
return inspect.iscoroutinefunction(obj)
|
||||
|
||||
|
||||
def request_response(func: typing.Callable) -> ASGIApp:
|
||||
def request_response(
|
||||
func: Callable[[Request], Awaitable[Response] | Response],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a function or coroutine `func(request) -> response`,
|
||||
and returns an ASGI application.
|
||||
"""
|
||||
is_coroutine = is_async_callable(func)
|
||||
f: Callable[[Request], Awaitable[Response]] = (
|
||||
func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
|
||||
)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = Request(scope, receive=receive, send=send)
|
||||
if is_coroutine:
|
||||
response = await func(request)
|
||||
else:
|
||||
response = await run_in_threadpool(func, request)
|
||||
await response(scope, receive, send)
|
||||
request = Request(scope, receive, send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
response = await f(request)
|
||||
await response(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def websocket_session(func: typing.Callable) -> ASGIApp:
|
||||
def websocket_session(
|
||||
func: Callable[[WebSocket], Awaitable[None]],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a coroutine `func(session)`, and returns an ASGI application.
|
||||
"""
|
||||
@@ -79,22 +90,24 @@ def websocket_session(func: typing.Callable) -> ASGIApp:
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
session = WebSocket(scope, receive=receive, send=send)
|
||||
await func(session)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await func(session)
|
||||
|
||||
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_name(endpoint: typing.Callable) -> str:
|
||||
if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
|
||||
return endpoint.__name__
|
||||
return endpoint.__class__.__name__
|
||||
def get_name(endpoint: Callable[..., Any]) -> str:
|
||||
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
|
||||
|
||||
|
||||
def replace_params(
|
||||
path: str,
|
||||
param_convertors: typing.Dict[str, Convertor],
|
||||
path_params: typing.Dict[str, str],
|
||||
) -> typing.Tuple[str, dict]:
|
||||
param_convertors: dict[str, Convertor[Any]],
|
||||
path_params: dict[str, str],
|
||||
) -> tuple[str, dict[str, str]]:
|
||||
for key, value in list(path_params.items()):
|
||||
if "{" + key + "}" in path:
|
||||
convertor = param_convertors[key]
|
||||
@@ -110,7 +123,7 @@ PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
|
||||
|
||||
def compile_path(
|
||||
path: str,
|
||||
) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
|
||||
) -> tuple[Pattern[str], str, dict[str, Convertor[Any]]]:
|
||||
"""
|
||||
Given a path string, like: "/{username:str}",
|
||||
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
|
||||
@@ -131,9 +144,7 @@ def compile_path(
|
||||
for match in PARAM_REGEX.finditer(path):
|
||||
param_name, convertor_type = match.groups("str")
|
||||
convertor_type = convertor_type.lstrip(":")
|
||||
assert (
|
||||
convertor_type in CONVERTOR_TYPES
|
||||
), f"Unknown path convertor '{convertor_type}'"
|
||||
assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
|
||||
convertor = CONVERTOR_TYPES[convertor_type]
|
||||
|
||||
path_regex += re.escape(path[idx : match.start()])
|
||||
@@ -167,10 +178,10 @@ def compile_path(
|
||||
|
||||
|
||||
class BaseRoute:
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
@@ -187,7 +198,7 @@ class BaseRoute:
|
||||
if scope["type"] == "http":
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
elif scope["type"] == "websocket":
|
||||
elif scope["type"] == "websocket": # pragma: no branch
|
||||
websocket_close = WebSocketClose()
|
||||
await websocket_close(scope, receive, send)
|
||||
return
|
||||
@@ -200,11 +211,12 @@ class Route(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable,
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
methods: Collection[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path.startswith("/"), "Routed paths must start with '/'"
|
||||
self.path = path
|
||||
@@ -224,6 +236,10 @@ class Route(BaseRoute):
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
|
||||
if methods is None:
|
||||
self.methods = None
|
||||
else:
|
||||
@@ -233,9 +249,11 @@ class Route(BaseRoute):
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, Any]
|
||||
if scope["type"] == "http":
|
||||
match = self.path_regex.match(scope["path"])
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
@@ -249,16 +267,14 @@ class Route(BaseRoute):
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
seen_params = set(path_params.keys())
|
||||
expected_params = set(self.param_convertors.keys())
|
||||
|
||||
if __name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(__name, path_params)
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="http")
|
||||
|
||||
@@ -268,14 +284,12 @@ class Route(BaseRoute):
|
||||
if "app" in scope:
|
||||
raise HTTPException(status_code=405, headers=headers)
|
||||
else:
|
||||
response = PlainTextResponse(
|
||||
"Method Not Allowed", status_code=405, headers=headers
|
||||
)
|
||||
response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Route)
|
||||
and self.path == other.path
|
||||
@@ -292,7 +306,12 @@ class Route(BaseRoute):
|
||||
|
||||
class WebSocketRoute(BaseRoute):
|
||||
def __init__(
|
||||
self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
name: str | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path.startswith("/"), "Routed paths must start with '/'"
|
||||
self.path = path
|
||||
@@ -309,11 +328,17 @@ class WebSocketRoute(BaseRoute):
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, Any]
|
||||
if scope["type"] == "websocket":
|
||||
match = self.path_regex.match(scope["path"])
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
@@ -324,28 +349,22 @@ class WebSocketRoute(BaseRoute):
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
seen_params = set(path_params.keys())
|
||||
expected_params = set(self.param_convertors.keys())
|
||||
|
||||
if __name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(__name, path_params)
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="websocket")
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, WebSocketRoute)
|
||||
and self.path == other.path
|
||||
and self.endpoint == other.endpoint
|
||||
)
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
|
||||
@@ -355,16 +374,14 @@ class Mount(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
app: typing.Optional[ASGIApp] = None,
|
||||
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
app: ASGIApp | None = None,
|
||||
routes: Sequence[BaseRoute] | None = None,
|
||||
name: str | None = None,
|
||||
*,
|
||||
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
|
||||
assert (
|
||||
app is not None or routes is not None
|
||||
), "Either 'app=...', or 'routes=' must be specified"
|
||||
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
|
||||
self.path = path.rstrip("/")
|
||||
if app is not None:
|
||||
self._base_app: ASGIApp = app
|
||||
@@ -372,82 +389,80 @@ class Mount(BaseRoute):
|
||||
self._base_app = Router(routes=routes)
|
||||
self.app = self._base_app
|
||||
if middleware is not None:
|
||||
for cls, options in reversed(middleware):
|
||||
self.app = cls(app=self.app, **options)
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
self.name = name
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(
|
||||
self.path + "/{path:path}"
|
||||
)
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return getattr(self._base_app, "routes", [])
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
path = scope["path"]
|
||||
match = self.path_regex.match(path)
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, Any]
|
||||
if scope["type"] in ("http", "websocket"): # pragma: no branch
|
||||
root_path = scope.get("root_path", "")
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
remaining_path = "/" + matched_params.pop("path")
|
||||
matched_path = path[: -len(remaining_path)]
|
||||
matched_path = route_path[: -len(remaining_path)]
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
root_path = scope.get("root_path", "")
|
||||
child_scope = {
|
||||
"path_params": path_params,
|
||||
# app_root_path will only be set at the top level scope,
|
||||
# initialized with the (optional) value of a root_path
|
||||
# set above/before Starlette. And even though any
|
||||
# mount will have its own child scope with its own respective
|
||||
# root_path, the app_root_path will always be available in all
|
||||
# the child scopes with the same top level value because it's
|
||||
# set only once here with a default, any other child scope will
|
||||
# just inherit that app_root_path default value stored in the
|
||||
# scope. All this is needed to support Request.url_for(), as it
|
||||
# uses the app_root_path to build the URL path.
|
||||
"app_root_path": scope.get("app_root_path", root_path),
|
||||
"root_path": root_path + matched_path,
|
||||
"path": remaining_path,
|
||||
"endpoint": self.app,
|
||||
}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
if self.name is not None and __name == self.name and "path" in path_params:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path_params["path"] = path_params["path"].lstrip("/")
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path)
|
||||
elif self.name is None or __name.startswith(self.name + ":"):
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
if self.name is None:
|
||||
# No mount name.
|
||||
remaining_name = __name
|
||||
remaining_name = name
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = __name[len(self.name) + 1 :]
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
path_kwarg = path_params.get("path")
|
||||
path_params["path"] = ""
|
||||
path_prefix, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
if path_kwarg is not None:
|
||||
remaining_params["path"] = path_kwarg
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(
|
||||
path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
|
||||
)
|
||||
return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(__name, path_params)
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Mount)
|
||||
and self.path == other.path
|
||||
and self.app == other.app
|
||||
)
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
@@ -456,9 +471,7 @@ class Mount(BaseRoute):
|
||||
|
||||
|
||||
class Host(BaseRoute):
|
||||
def __init__(
|
||||
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
|
||||
) -> None:
|
||||
def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
|
||||
assert not host.startswith("/"), "Host must not start with '/'"
|
||||
self.host = host
|
||||
self.app = app
|
||||
@@ -466,11 +479,11 @@ class Host(BaseRoute):
|
||||
self.host_regex, self.host_format, self.param_convertors = compile_path(host)
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return getattr(self.app, "routes", [])
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"): # pragma:no branch
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
match = self.host_regex.match(host)
|
||||
@@ -484,42 +497,34 @@ class Host(BaseRoute):
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
if self.name is not None and __name == self.name and "path" in path_params:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path = path_params.pop("path")
|
||||
host, remaining_params = replace_params(
|
||||
self.host_format, self.param_convertors, path_params
|
||||
)
|
||||
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path, host=host)
|
||||
elif self.name is None or __name.startswith(self.name + ":"):
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
if self.name is None:
|
||||
# No mount name.
|
||||
remaining_name = __name
|
||||
remaining_name = name
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = __name[len(self.name) + 1 :]
|
||||
host, remaining_params = replace_params(
|
||||
self.host_format, self.param_convertors, path_params
|
||||
)
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(path=str(url), protocol=url.protocol, host=host)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(__name, path_params)
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Host)
|
||||
and self.host == other.host
|
||||
and self.app == other.app
|
||||
)
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Host) and self.host == other.host and self.app == other.app
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
@@ -527,11 +532,11 @@ class Host(BaseRoute):
|
||||
return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
|
||||
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
|
||||
def __init__(self, cm: typing.ContextManager[_T]):
|
||||
class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
|
||||
def __init__(self, cm: AbstractContextManager[_T]):
|
||||
self._cm = cm
|
||||
|
||||
async def __aenter__(self) -> _T:
|
||||
@@ -539,27 +544,27 @@ class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]],
|
||||
exc_value: typing.Optional[BaseException],
|
||||
traceback: typing.Optional[types.TracebackType],
|
||||
) -> typing.Optional[bool]:
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
return self._cm.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
|
||||
def _wrap_gen_lifespan_context(
|
||||
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
|
||||
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
|
||||
lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
|
||||
) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
|
||||
cmgr = contextlib.contextmanager(lifespan_context)
|
||||
|
||||
@functools.wraps(cmgr)
|
||||
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
|
||||
def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
|
||||
return _AsyncLiftContextManager(cmgr(app))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class _DefaultLifespan:
|
||||
def __init__(self, router: "Router"):
|
||||
def __init__(self, router: Router):
|
||||
self._router = router
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
@@ -575,14 +580,16 @@ class _DefaultLifespan:
|
||||
class Router:
|
||||
def __init__(
|
||||
self,
|
||||
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
|
||||
routes: Sequence[BaseRoute] | None = None,
|
||||
redirect_slashes: bool = True,
|
||||
default: typing.Optional[ASGIApp] = None,
|
||||
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
|
||||
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
|
||||
default: ASGIApp | None = None,
|
||||
on_startup: Sequence[Callable[[], Any]] | None = None,
|
||||
on_shutdown: Sequence[Callable[[], Any]] | None = None,
|
||||
# the generic to Lifespan[AppType] is the type of the top level application
|
||||
# which the router cannot know statically, so we use typing.Any
|
||||
lifespan: typing.Optional[Lifespan[typing.Any]] = None,
|
||||
# which the router cannot know statically, so we use Any
|
||||
lifespan: Lifespan[Any] | None = None,
|
||||
*,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
self.routes = [] if routes is None else list(routes)
|
||||
self.redirect_slashes = redirect_slashes
|
||||
@@ -594,12 +601,18 @@ class Router:
|
||||
warnings.warn(
|
||||
"The on_startup and on_shutdown parameters are deprecated, and they "
|
||||
"will be removed on version 1.0. Use the lifespan parameter instead. "
|
||||
"See more about it on https://www.starlette.io/lifespan/.",
|
||||
"See more about it on https://starlette.dev/lifespan/.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if lifespan:
|
||||
warnings.warn(
|
||||
"The `lifespan` parameter cannot be used with `on_startup` or "
|
||||
"`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
|
||||
"ignored."
|
||||
)
|
||||
|
||||
if lifespan is None:
|
||||
self.lifespan_context: Lifespan = _DefaultLifespan(self)
|
||||
self.lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
|
||||
|
||||
elif inspect.isasyncgenfunction(lifespan):
|
||||
warnings.warn(
|
||||
@@ -608,20 +621,24 @@ class Router:
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.lifespan_context = asynccontextmanager(
|
||||
lifespan, # type: ignore[arg-type]
|
||||
lifespan,
|
||||
)
|
||||
elif inspect.isgeneratorfunction(lifespan):
|
||||
warnings.warn(
|
||||
"generator function lifespans are deprecated, "
|
||||
"use an @contextlib.asynccontextmanager function instead",
|
||||
"generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.lifespan_context = _wrap_gen_lifespan_context(
|
||||
lifespan, # type: ignore[arg-type]
|
||||
lifespan,
|
||||
)
|
||||
else:
|
||||
self.lifespan_context = lifespan
|
||||
|
||||
self.middleware_stack = self.app
|
||||
if middleware:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
|
||||
|
||||
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "websocket":
|
||||
websocket_close = WebSocketClose()
|
||||
@@ -637,13 +654,13 @@ class Router:
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
|
||||
def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
|
||||
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
|
||||
for route in self.routes:
|
||||
try:
|
||||
return route.url_path_for(__name, **path_params)
|
||||
return route.url_path_for(name, **path_params)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(__name, path_params)
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def startup(self) -> None:
|
||||
"""
|
||||
@@ -671,15 +688,13 @@ class Router:
|
||||
startup and shutdown events.
|
||||
"""
|
||||
started = False
|
||||
app: typing.Any = scope.get("app")
|
||||
app: Any = scope.get("app")
|
||||
await receive()
|
||||
try:
|
||||
async with self.lifespan_context(app) as maybe_state:
|
||||
if maybe_state is not None:
|
||||
if "state" not in scope:
|
||||
raise RuntimeError(
|
||||
'The server does not support "state" in the lifespan scope.'
|
||||
)
|
||||
raise RuntimeError('The server does not support "state" in the lifespan scope.')
|
||||
scope["state"].update(maybe_state)
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
started = True
|
||||
@@ -698,6 +713,9 @@ class Router:
|
||||
"""
|
||||
The main entry point to the Router class.
|
||||
"""
|
||||
await self.middleware_stack(scope, receive, send)
|
||||
|
||||
async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] in ("http", "websocket", "lifespan")
|
||||
|
||||
if "router" not in scope:
|
||||
@@ -729,9 +747,10 @@ class Router:
|
||||
await partial.handle(scope, receive, send)
|
||||
return
|
||||
|
||||
if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
|
||||
route_path = get_route_path(scope)
|
||||
if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
|
||||
redirect_scope = dict(scope)
|
||||
if scope["path"].endswith("/"):
|
||||
if route_path.endswith("/"):
|
||||
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
|
||||
else:
|
||||
redirect_scope["path"] = redirect_scope["path"] + "/"
|
||||
@@ -746,29 +765,25 @@ class Router:
|
||||
|
||||
await self.default(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Router) and self.routes == other.routes
|
||||
|
||||
def mount(
|
||||
self, path: str, app: ASGIApp, name: typing.Optional[str] = None
|
||||
) -> None: # pragma: nocover
|
||||
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
|
||||
route = Mount(path, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def host(
|
||||
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
|
||||
) -> None: # pragma: no cover
|
||||
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
|
||||
route = Host(host, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable,
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
endpoint: Callable[[Request], Awaitable[Response] | Response],
|
||||
methods: Collection[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None: # pragma: nocover
|
||||
) -> None: # pragma: no cover
|
||||
route = Route(
|
||||
path,
|
||||
endpoint=endpoint,
|
||||
@@ -779,7 +794,10 @@ class Router:
|
||||
self.routes.append(route)
|
||||
|
||||
def add_websocket_route(
|
||||
self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable[[WebSocket], Awaitable[None]],
|
||||
name: str | None = None,
|
||||
) -> None: # pragma: no cover
|
||||
route = WebSocketRoute(path, endpoint=endpoint, name=name)
|
||||
self.routes.append(route)
|
||||
@@ -787,10 +805,10 @@ class Router:
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
methods: Collection[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> typing.Callable:
|
||||
) -> Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
@@ -800,11 +818,11 @@ class Router:
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
|
||||
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", # noqa: E501
|
||||
"Refer to https://starlette.dev/routing/#http-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.add_route(
|
||||
path,
|
||||
func,
|
||||
@@ -816,9 +834,7 @@ class Router:
|
||||
|
||||
return decorator
|
||||
|
||||
def websocket_route(
|
||||
self, path: str, name: typing.Optional[str] = None
|
||||
) -> typing.Callable:
|
||||
def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
@@ -827,20 +843,18 @@ class Router:
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " # noqa: E501
|
||||
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
|
||||
"https://starlette.dev/routing/#websocket-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def add_event_handler(
|
||||
self, event_type: str, func: typing.Callable
|
||||
) -> None: # pragma: no cover
|
||||
def add_event_handler(self, event_type: str, func: Callable[[], Any]) -> None: # pragma: no cover
|
||||
assert event_type in ("startup", "shutdown")
|
||||
|
||||
if event_type == "startup":
|
||||
@@ -848,14 +862,14 @@ class Router:
|
||||
else:
|
||||
self.on_shutdown.append(func)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable:
|
||||
def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
|
||||
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://starlette.dev/lifespan/ for recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
|
||||
self.add_event_handler(event_type, func)
|
||||
return func
|
||||
|
||||
|
||||
Reference in New Issue
Block a user