This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -1,16 +1,21 @@
import asyncio
import dataclasses
import email.message
import functools
import inspect
import json
from contextlib import AsyncExitStack
import sys
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
Coroutine,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
@@ -19,7 +24,8 @@ from typing import (
Union,
)
from fastapi import params
from annotated_doc import Doc
from fastapi import params, temp_pydantic_v1_params
from fastapi._compat import (
ModelField,
Undefined,
@@ -31,8 +37,10 @@ from fastapi._compat import (
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
_should_embed_body_fields,
get_body_field,
get_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
get_typed_return_annotation,
solve_dependencies,
@@ -47,13 +55,15 @@ from fastapi.exceptions import (
from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import (
create_cloned_field,
create_response_field,
create_model_field,
generate_unique_id,
get_value_or_default,
is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from starlette import routing
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
@@ -63,13 +73,84 @@ from starlette.routing import (
Match,
compile_path,
get_name,
request_response,
websocket_session,
)
from starlette.routing import Mount as Mount # noqa
from starlette.types import ASGIApp, Lifespan, Scope
from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
from typing_extensions import Annotated, deprecated
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
# Copy of starlette.routing.request_response modified to include the
# dependencies' AsyncExitStack
def request_response(
func: Callable[[Request], Union[Awaitable[Response], Response]],
) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
f: Callable[[Request], Awaitable[Response]] = (
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive, send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
# Starts customization
response_awaited = False
async with AsyncExitStack() as request_stack:
scope["fastapi_inner_astack"] = request_stack
async with AsyncExitStack() as function_stack:
scope["fastapi_function_astack"] = function_stack
response = await f(request)
await response(scope, receive, send)
# Continues customization
response_awaited = True
if not response_awaited:
raise FastAPIError(
"Response not awaited. There's a high chance that the "
"application code is raising an exception and a dependency with yield "
"has a block with a bare except, or a block with except Exception, "
"and is not raising the exception again. Read more about it in the "
"docs: https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#dependencies-with-yield-and-except"
)
# Same as in Starlette
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
return app
# Copy of starlette.routing.websocket_session modified to include the
# dependencies' AsyncExitStack
def websocket_session(
func: Callable[[WebSocket], Awaitable[None]],
) -> ASGIApp:
"""
Takes a coroutine `func(session)`, and returns an ASGI application.
"""
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
async with AsyncExitStack() as request_stack:
scope["fastapi_inner_astack"] = request_stack
async with AsyncExitStack() as function_stack:
scope["fastapi_function_astack"] = function_stack
await func(session)
# Same as in Starlette
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
return app
def _prepare_response_content(
@@ -115,10 +196,28 @@ def _prepare_response_content(
for k, v in res.items()
}
elif dataclasses.is_dataclass(res):
assert not isinstance(res, type)
return dataclasses.asdict(res)
return res
def _merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]:
@asynccontextmanager
async def merged_lifespan(
app: AppType,
) -> AsyncIterator[Optional[Mapping[str, Any]]]:
async with original_context(app) as maybe_original_state:
async with nested_context(app) as maybe_nested_state:
if maybe_nested_state is None and maybe_original_state is None:
yield None # old ASGI compatibility
else:
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
return merged_lifespan # type: ignore[return-value]
async def serialize_response(
*,
field: Optional[ModelField] = None,
@@ -206,24 +305,32 @@ def get_request_handler(
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
is_coroutine = iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(
body_field.field_info, (params.Form, temp_pydantic_v1_params.Form)
)
if isinstance(response_class, DefaultPlaceholder):
actual_response_class: Type[Response] = response_class.value
else:
actual_response_class = response_class
async def app(request: Request) -> Response:
response: Union[Response, None] = None
file_stack = request.scope.get("fastapi_middleware_astack")
assert isinstance(file_stack, AsyncExitStack), (
"fastapi_middleware_astack not found in request scope"
)
# Read body and auto-close files
try:
body: Any = None
if body_field:
if is_body_form:
body = await request.form()
stack = request.scope.get("fastapi_astack")
assert isinstance(stack, AsyncExitStack)
stack.push_async_callback(body.close)
file_stack.push_async_callback(body.close)
else:
body_bytes = await request.body()
if body_bytes:
@@ -243,7 +350,7 @@ def get_request_handler(
else:
body = body_bytes
except json.JSONDecodeError as e:
raise RequestValidationError(
validation_error = RequestValidationError(
[
{
"type": "json_invalid",
@@ -254,75 +361,106 @@ def get_request_handler(
}
],
body=e.doc,
) from e
)
raise validation_error from e
except HTTPException:
# If a middleware raises an HTTPException, it should be raised again
raise
except Exception as e:
raise HTTPException(
http_error = HTTPException(
status_code=400, detail="There was an error parsing the body"
) from e
)
raise http_error from e
# Solve dependencies and run path operation function, auto-closing dependencies
errors: List[Any] = []
async_exit_stack = request.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
)
solved_result = await solve_dependencies(
request=request,
dependant=dependant,
body=body,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
values, errors, background_tasks, sub_response, _ = solved_result
if errors:
raise RequestValidationError(_normalize_errors(errors), body=body)
else:
errors = solved_result.errors
if not errors:
raw_response = await run_endpoint_function(
dependant=dependant, values=values, is_coroutine=is_coroutine
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = background_tasks
return raw_response
response_args: Dict[str, Any] = {"background": background_tasks}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else sub_response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if sub_response.status_code:
response_args["status_code"] = sub_response.status_code
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
dependant=dependant,
values=solved_result.values,
is_coroutine=is_coroutine,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(sub_response.headers.raw)
return response
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: Dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if solved_result.response.status_code:
response_args["status_code"] = solved_result.response.status_code
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
)
raise validation_error
# Return response
assert response
return response
return app
def get_websocket_app(
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
dependant: Dependant,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
async_exit_stack = websocket.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
)
solved_result = await solve_dependencies(
request=websocket,
dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
values, errors, _, _2, _3 = solved_result
if errors:
raise WebSocketRequestValidationError(_normalize_errors(errors))
if solved_result.errors:
raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors)
)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**values)
await dependant.call(**solved_result.values)
return app
@@ -342,17 +480,23 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function"
)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
)
)
@@ -431,9 +575,9 @@ class APIRoute(routing.Route):
methods = ["GET"]
self.methods: Set[str] = {method.upper() for method in methods}
if isinstance(generate_unique_id_function, DefaultPlaceholder):
current_generate_unique_id: Callable[
["APIRoute"], str
] = generate_unique_id_function.value
current_generate_unique_id: Callable[[APIRoute], str] = (
generate_unique_id_function.value
)
else:
current_generate_unique_id = generate_unique_id_function
self.unique_id = self.operation_id or current_generate_unique_id(self)
@@ -442,11 +586,11 @@ class APIRoute(routing.Route):
status_code = int(status_code)
self.status_code = status_code
if self.response_model:
assert is_body_allowed_for_status_code(
status_code
), f"Status code {status_code} must not have a response body"
assert is_body_allowed_for_status_code(status_code), (
f"Status code {status_code} must not have a response body"
)
response_name = "Response_" + self.unique_id
self.response_field = create_response_field(
self.response_field = create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
@@ -459,9 +603,9 @@ class APIRoute(routing.Route):
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
self.secure_cloned_response_field: Optional[
ModelField
] = create_cloned_field(self.response_field)
self.secure_cloned_response_field: Optional[ModelField] = (
create_cloned_field(self.response_field)
)
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
@@ -475,11 +619,13 @@ class APIRoute(routing.Route):
assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model")
if model:
assert is_body_allowed_for_status_code(
additional_status_code
), f"Status code {additional_status_code} must not have a response body"
assert is_body_allowed_for_status_code(additional_status_code), (
f"Status code {additional_status_code} must not have a response body"
)
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_response_field(name=response_name, type_=model)
response_field = create_model_field(
name=response_name, type_=model, mode="serialization"
)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
@@ -487,13 +633,23 @@ class APIRoute(routing.Route):
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function"
)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.body_field = get_body_field(
flat_dependant=self._flat_dependant,
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
self.app = request_response(self.get_route_handler())
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
@@ -510,6 +666,7 @@ class APIRoute(routing.Route):
response_model_exclude_defaults=self.response_model_exclude_defaults,
response_model_exclude_none=self.response_model_exclude_none,
dependency_overrides_provider=self.dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
@@ -741,7 +898,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -771,9 +928,9 @@ class APIRouter(routing.Router):
)
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
assert not prefix.endswith("/"), (
"A path prefix must not end with '/', as the routes will start with '/'"
)
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or [])
@@ -789,7 +946,7 @@ class APIRouter(routing.Router):
def route(
self,
path: str,
methods: Optional[List[str]] = None,
methods: Optional[Collection[str]] = None,
name: Optional[str] = None,
include_in_schema: bool = True,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
@@ -1183,9 +1340,9 @@ class APIRouter(routing.Router):
"""
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
assert not prefix.endswith("/"), (
"A path prefix must not end with '/', as the routes will start with '/'"
)
else:
for r in router.routes:
path = getattr(r, "path") # noqa: B009
@@ -1285,6 +1442,10 @@ class APIRouter(routing.Router):
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)
def get(
self,
@@ -1549,7 +1710,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -1926,7 +2087,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -2308,7 +2469,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -2690,7 +2851,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -3067,7 +3228,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -3444,7 +3605,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -3826,7 +3987,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -4208,7 +4369,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -4293,7 +4454,7 @@ class APIRouter(routing.Router):
app = FastAPI()
router = APIRouter()
@router.put("/items/{item_id}")
@router.trace("/items/{item_id}")
def trace_item(item_id: str):
return None