This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -1,6 +1,6 @@
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
__version__ = "0.104.1"
__version__ = "0.123.0"
from starlette import status as status

View File

@@ -0,0 +1,3 @@
from fastapi.cli import main
main()

View File

@@ -1,629 +0,0 @@
from collections import deque
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import (
Any,
Callable,
Deque,
Dict,
FrozenSet,
List,
Mapping,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi.exceptions import RequestErrorModel
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
sequence_annotation_to_type = {
Sequence: list,
List: list,
list: list,
Tuple: tuple,
tuple: tuple,
Set: set,
set: set,
FrozenSet: frozenset,
frozenset: frozenset,
Deque: deque,
deque: deque,
}
sequence_types = tuple(sequence_annotation_to_type.keys())
if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler,
)
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url
try:
from pydantic_core.core_schema import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
except ImportError: # pragma: no cover
from pydantic_core.core_schema import (
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
)
Required = PydanticUndefined
Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
Validator = Any
class BaseConfig:
pass
class ErrorWrapper(Exception):
pass
@dataclass
class ModelField:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"
@property
def alias(self) -> str:
a = self.field_info.alias
return a if a is not None else self.name
@property
def required(self) -> bool:
return self.field_info.is_required()
@property
def default(self) -> Any:
return self.get_default()
@property
def type_(self) -> Any:
return self.field_info.annotation
def __post_init__(self) -> None:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
)
def get_default(self) -> Any:
if self.field_info.is_required():
return Undefined
return self.field_info.get_default(call_default_factory=True)
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
None,
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(), loc_prefix=loc
)
def serialize(
self,
value: Any,
*,
mode: Literal["json", "python"] = "json",
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
return self._type_adapter.dump_python(
value,
mode=mode,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def __hash__(self) -> int:
# Each ModelField is unique for our purposes, to allow making a dict from
# ModelField to its JSON Schema.
return id(self)
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
return annotation
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
return errors # type: ignore[return-value]
def _model_rebuild(model: Type[BaseModel]) -> None:
model.model_rebuild()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.model_dump(mode=mode, **kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.model_config
def get_schema_from_model_field(
*,
field: ModelField,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
override_mode: Union[Literal["validation"], None] = (
None if separate_input_output_schemas else "validation"
)
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = field_mapping[(field, override_mode or field.mode)]
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema["title"] = (
field.field_info.title or field.alias.title().replace("_", " ")
)
return json_schema
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
return {}
def get_definitions(
*,
fields: List[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
Dict[str, Dict[str, Any]],
]:
override_mode: Union[Literal["validation"], None] = (
None if separate_input_output_schemas else "validation"
)
inputs = [
(field, override_mode or field.mode, field._type_adapter.core_schema)
for field in fields
]
field_mapping, definitions = schema_generator.generate_definitions(
inputs=inputs
)
return field_mapping, definitions # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
return field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
return field_annotation_is_sequence(field.field_info.annotation)
def is_scalar_sequence_field(field: ModelField) -> bool:
return field_annotation_is_scalar_sequence(field.field_info.annotation)
def is_bytes_field(field: ModelField) -> bool:
return is_bytes_or_nonable_bytes_annotation(field.type_)
def is_bytes_sequence_field(field: ModelField) -> bool:
return is_bytes_sequence_annotation(field.type_)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return type(field_info).from_annotation(annotation)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = (
get_origin(field.field_info.annotation) or field.field_info.annotation
)
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors()[0]
error["input"] = None
return error # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
from pydantic import ( # type: ignore[assignment]
BaseConfig as BaseConfig, # noqa: F401
)
from pydantic import ValidationError as ValidationError # noqa: F401
from pydantic.class_validators import ( # type: ignore[no-redef]
Validator as Validator, # noqa: F401
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorWrapper as ErrorWrapper, # noqa: F401
)
from pydantic.errors import MissingError
from pydantic.fields import ( # type: ignore[attr-defined]
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
ModelField as ModelField, # noqa: F401
)
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
Required as Required, # noqa: F401
)
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
Undefined as Undefined,
)
from pydantic.fields import ( # type: ignore[no-redef, attr-defined]
UndefinedType as UndefinedType, # noqa: F401
)
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
model_process_schema,
)
from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.typing import ( # type: ignore[no-redef]
evaluate_forwardref as evaluate_forwardref, # noqa: F401
)
from pydantic.utils import ( # type: ignore[no-redef]
lenient_issubclass as lenient_issubclass, # noqa: F401
)
GetJsonSchemaHandler = Any # type: ignore[assignment,misc]
JsonSchemaValue = Dict[str, Any] # type: ignore[misc]
CoreSchema = Any # type: ignore[assignment,misc]
sequence_shapes = {
SHAPE_LIST,
SHAPE_SET,
SHAPE_FROZENSET,
SHAPE_TUPLE,
SHAPE_SEQUENCE,
SHAPE_TUPLE_ELLIPSIS,
}
sequence_shape_to_type = {
SHAPE_LIST: list,
SHAPE_SET: set,
SHAPE_TUPLE: tuple,
SHAPE_SEQUENCE: list,
SHAPE_TUPLE_ELLIPSIS: list,
}
@dataclass
class GenerateJsonSchema: # type: ignore[no-redef]
ref_template: str
class PydanticSchemaGenerationError(Exception): # type: ignore[no-redef]
pass
def with_info_plain_validator_function( # type: ignore[misc]
function: Callable[..., Any],
*,
ref: Union[str, None] = None,
metadata: Any = None,
serialization: Any = None,
) -> Any:
return {}
def get_model_definitions(
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict[str, Any]] = {}
for model in flat_models:
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
model_name = model_name_map[model]
if "description" in m_schema:
m_schema["description"] = m_schema["description"].split("\f")[0]
definitions[model_name] = m_schema
return definitions
def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params
field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON # type: ignore[attr-defined]
and not lenient_issubclass(field.type_, BaseModel)
and not lenient_issubclass(field.type_, dict)
and not field_annotation_is_sequence(field.type_)
and not is_dataclass(field.type_)
and not isinstance(field_info, params.Body)
):
return False
if field.sub_fields: # type: ignore[attr-defined]
if not all(
is_pv1_scalar_field(f)
for f in field.sub_fields # type: ignore[attr-defined]
):
return False
return True
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
if (field.shape in sequence_shapes) and not lenient_issubclass( # type: ignore[attr-defined]
field.type_, BaseModel
):
if field.sub_fields is not None: # type: ignore[attr-defined]
for sub_field in field.sub_fields: # type: ignore[attr-defined]
if not is_pv1_scalar_field(sub_field):
return False
return True
if _annotation_is_sequence(field.type_):
return True
return False
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
use_errors: List[Any] = []
for error in errors:
if isinstance(error, ErrorWrapper):
new_errors = ValidationError( # type: ignore[call-arg]
errors=[error], model=RequestErrorModel
).errors()
use_errors.extend(new_errors)
elif isinstance(error, list):
use_errors.extend(_normalize_errors(error))
else:
use_errors.append(error)
return use_errors
def _model_rebuild(model: Type[BaseModel]) -> None:
model.update_forward_refs()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.dict(**kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.__config__ # type: ignore[attr-defined]
def get_schema_from_model_field(
*,
field: ModelField,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions
return field_schema( # type: ignore[no-any-return]
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0]
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
models = get_flat_models_from_fields(fields, known_models=set())
return get_model_name_map(models) # type: ignore[no-any-return]
def get_definitions(
*,
fields: List[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
Dict[str, Dict[str, Any]],
]:
models = get_flat_models_from_fields(fields, known_models=set())
return {}, get_model_definitions(
flat_models=models, model_name_map=model_name_map
)
def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
def is_scalar_sequence_field(field: ModelField) -> bool:
return is_pv1_scalar_sequence_field(field)
def is_bytes_field(field: ModelField) -> bool:
return lenient_issubclass(field.type_, bytes)
def is_bytes_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined]
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return copy(field_info)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
BodyModel = create_model(model_name)
for f in fields:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, sequence_types)
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
get_origin(annotation)
)
def value_is_sequence(value: Any) -> bool:
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
return (
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
return (
_annotation_is_complex(annotation)
or _annotation_is_complex(origin)
or hasattr(origin, "__pydantic_core_schema__")
or hasattr(origin, "__get_pydantic_core_schema__")
)
def field_annotation_is_scalar(annotation: Any) -> bool:
# handle Ellipsis here to make tuple[int, ...] work nicely
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False
for arg in get_args(annotation):
if field_annotation_is_scalar_sequence(arg):
at_least_one_scalar_sequence = True
continue
elif not field_annotation_is_scalar(arg):
return False
return at_least_one_scalar_sequence
return field_annotation_is_sequence(annotation) and all(
field_annotation_is_scalar(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, bytes):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, bytes):
return True
return False
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, UploadFile):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, UploadFile):
return True
return False
def is_bytes_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one = False
for arg in get_args(annotation):
if is_bytes_sequence_annotation(arg):
at_least_one = True
continue
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_bytes_or_nonable_bytes_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one = False
for arg in get_args(annotation):
if is_uploadfile_sequence_annotation(arg):
at_least_one = True
continue
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)

View File

@@ -0,0 +1,50 @@
from .main import BaseConfig as BaseConfig
from .main import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .main import RequiredParam as RequiredParam
from .main import Undefined as Undefined
from .main import UndefinedType as UndefinedType
from .main import Url as Url
from .main import Validator as Validator
from .main import _get_model_config as _get_model_config
from .main import _is_error_wrapper as _is_error_wrapper
from .main import _is_model_class as _is_model_class
from .main import _is_model_field as _is_model_field
from .main import _is_undefined as _is_undefined
from .main import _model_dump as _model_dump
from .main import _model_rebuild as _model_rebuild
from .main import copy_field_info as copy_field_info
from .main import create_body_model as create_body_model
from .main import evaluate_forwardref as evaluate_forwardref
from .main import get_annotation_from_field_info as get_annotation_from_field_info
from .main import get_cached_model_fields as get_cached_model_fields
from .main import get_compat_model_name_map as get_compat_model_name_map
from .main import get_definitions as get_definitions
from .main import get_missing_field_error as get_missing_field_error
from .main import get_schema_from_model_field as get_schema_from_model_field
from .main import is_bytes_field as is_bytes_field
from .main import is_bytes_sequence_field as is_bytes_sequence_field
from .main import is_scalar_field as is_scalar_field
from .main import is_scalar_sequence_field as is_scalar_sequence_field
from .main import is_sequence_field as is_sequence_field
from .main import serialize_sequence_value as serialize_sequence_value
from .main import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
from .may_v1 import CoreSchema as CoreSchema
from .may_v1 import GetJsonSchemaHandler as GetJsonSchemaHandler
from .may_v1 import JsonSchemaValue as JsonSchemaValue
from .may_v1 import _normalize_errors as _normalize_errors
from .model_field import ModelField as ModelField
from .shared import PYDANTIC_V2 as PYDANTIC_V2
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
from .shared import (
is_uploadfile_or_nonable_uploadfile_annotation as is_uploadfile_or_nonable_uploadfile_annotation,
)
from .shared import (
is_uploadfile_sequence_annotation as is_uploadfile_sequence_annotation,
)
from .shared import lenient_issubclass as lenient_issubclass
from .shared import sequence_types as sequence_types
from .shared import value_is_sequence as value_is_sequence

View File

@@ -0,0 +1,362 @@
import sys
from functools import lru_cache
from typing import (
Any,
Dict,
List,
Sequence,
Tuple,
Type,
)
from fastapi._compat import may_v1
from fastapi._compat.shared import PYDANTIC_V2, lenient_issubclass
from fastapi.types import ModelNameMap
from pydantic import BaseModel
from typing_extensions import Literal
from .model_field import ModelField
if PYDANTIC_V2:
from .v2 import BaseConfig as BaseConfig
from .v2 import FieldInfo as FieldInfo
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .v2 import RequiredParam as RequiredParam
from .v2 import Undefined as Undefined
from .v2 import UndefinedType as UndefinedType
from .v2 import Url as Url
from .v2 import Validator as Validator
from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
else:
from .v1 import BaseConfig as BaseConfig # type: ignore[assignment]
from .v1 import FieldInfo as FieldInfo
from .v1 import ( # type: ignore[assignment]
PydanticSchemaGenerationError as PydanticSchemaGenerationError,
)
from .v1 import RequiredParam as RequiredParam
from .v1 import Undefined as Undefined
from .v1 import UndefinedType as UndefinedType
from .v1 import Url as Url # type: ignore[assignment]
from .v1 import Validator as Validator
from .v1 import evaluate_forwardref as evaluate_forwardref
from .v1 import get_missing_field_error as get_missing_field_error
from .v1 import ( # type: ignore[assignment]
with_info_plain_validator_function as with_info_plain_validator_function,
)
@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
if lenient_issubclass(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1.get_model_fields(model)
else:
from . import v2
return v2.get_model_fields(model) # type: ignore[return-value]
def _is_undefined(value: object) -> bool:
if isinstance(value, may_v1.UndefinedType):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(value, v2.UndefinedType)
return False
def _get_model_config(model: BaseModel) -> Any:
if isinstance(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1._get_model_config(model)
elif PYDANTIC_V2:
from . import v2
return v2._get_model_config(model)
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
if isinstance(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1._model_dump(model, mode=mode, **kwargs)
elif PYDANTIC_V2:
from . import v2
return v2._model_dump(model, mode=mode, **kwargs)
def _is_error_wrapper(exc: Exception) -> bool:
if isinstance(exc, may_v1.ErrorWrapper):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(exc, v2.ErrorWrapper)
return False
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
if isinstance(field_info, may_v1.FieldInfo):
from fastapi._compat import v1
return v1.copy_field_info(field_info=field_info, annotation=annotation)
else:
assert PYDANTIC_V2
from . import v2
return v2.copy_field_info(field_info=field_info, annotation=annotation)
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
if fields and isinstance(fields[0], may_v1.ModelField):
from fastapi._compat import v1
return v1.create_body_model(fields=fields, model_name=model_name)
else:
assert PYDANTIC_V2
from . import v2
return v2.create_body_model(fields=fields, model_name=model_name) # type: ignore[arg-type]
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
if isinstance(field_info, may_v1.FieldInfo):
from fastapi._compat import v1
return v1.get_annotation_from_field_info(
annotation=annotation, field_info=field_info, field_name=field_name
)
else:
assert PYDANTIC_V2
from . import v2
return v2.get_annotation_from_field_info(
annotation=annotation, field_info=field_info, field_name=field_name
)
def is_bytes_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_bytes_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_bytes_field(field) # type: ignore[arg-type]
def is_bytes_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_bytes_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_bytes_sequence_field(field) # type: ignore[arg-type]
def is_scalar_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_scalar_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_scalar_field(field) # type: ignore[arg-type]
def is_scalar_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_scalar_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_scalar_sequence_field(field) # type: ignore[arg-type]
def is_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_sequence_field(field) # type: ignore[arg-type]
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.serialize_sequence_value(field=field, value=value)
else:
assert PYDANTIC_V2
from . import v2
return v2.serialize_sequence_value(field=field, value=value) # type: ignore[arg-type]
def _model_rebuild(model: Type[BaseModel]) -> None:
if lenient_issubclass(model, may_v1.BaseModel):
from fastapi._compat import v1
v1._model_rebuild(model)
elif PYDANTIC_V2:
from . import v2
v2._model_rebuild(model)
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
v1_model_fields = [
field for field in fields if isinstance(field, may_v1.ModelField)
]
if v1_model_fields:
from fastapi._compat import v1
v1_flat_models = v1.get_flat_models_from_fields(
v1_model_fields, known_models=set()
)
all_flat_models = v1_flat_models
else:
all_flat_models = set()
if PYDANTIC_V2:
from . import v2
v2_model_fields = [
field for field in fields if isinstance(field, v2.ModelField)
]
v2_flat_models = v2.get_flat_models_from_fields(
v2_model_fields, known_models=set()
)
all_flat_models = all_flat_models.union(v2_flat_models)
model_name_map = v2.get_model_name_map(all_flat_models)
return model_name_map
from fastapi._compat import v1
model_name_map = v1.get_model_name_map(all_flat_models)
return model_name_map
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]],
may_v1.JsonSchemaValue,
],
Dict[str, Dict[str, Any]],
]:
if sys.version_info < (3, 14):
v1_fields = [field for field in fields if isinstance(field, may_v1.ModelField)]
v1_field_maps, v1_definitions = may_v1.get_definitions(
fields=v1_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
if not PYDANTIC_V2:
return v1_field_maps, v1_definitions
else:
from . import v2
v2_fields = [field for field in fields if isinstance(field, v2.ModelField)]
v2_field_maps, v2_definitions = v2.get_definitions(
fields=v2_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
all_definitions = {**v1_definitions, **v2_definitions}
all_field_maps = {**v1_field_maps, **v2_field_maps}
return all_field_maps, all_definitions
# Pydantic v1 is not supported since Python 3.14
else:
from . import v2
v2_fields = [field for field in fields if isinstance(field, v2.ModelField)]
v2_field_maps, v2_definitions = v2.get_definitions(
fields=v2_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
return v2_field_maps, v2_definitions
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]],
may_v1.JsonSchemaValue,
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.get_schema_from_model_field(
field=field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
else:
assert PYDANTIC_V2
from . import v2
return v2.get_schema_from_model_field(
field=field, # type: ignore[arg-type]
model_name_map=model_name_map,
field_mapping=field_mapping, # type: ignore[arg-type]
separate_input_output_schemas=separate_input_output_schemas,
)
def _is_model_field(value: Any) -> bool:
if isinstance(value, may_v1.ModelField):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(value, v2.ModelField)
return False
def _is_model_class(value: Any) -> bool:
if lenient_issubclass(value, may_v1.BaseModel):
return True
elif PYDANTIC_V2:
from . import v2
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
return False

View File

@@ -0,0 +1,123 @@
import sys
from typing import Any, Dict, List, Literal, Sequence, Tuple, Type, Union
from fastapi.types import ModelNameMap
if sys.version_info >= (3, 14):
class AnyUrl:
pass
class BaseConfig:
pass
class BaseModel:
pass
class Color:
pass
class CoreSchema:
pass
class ErrorWrapper:
pass
class FieldInfo:
pass
class GetJsonSchemaHandler:
pass
class JsonSchemaValue:
pass
class ModelField:
pass
class NameEmail:
pass
class RequiredParam:
pass
class SecretBytes:
pass
class SecretStr:
pass
class Undefined:
pass
class UndefinedType:
pass
class Url:
pass
from .v2 import ValidationError, create_model
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
Dict[str, Dict[str, Any]],
]:
return {}, {} # pragma: no cover
else:
from .v1 import AnyUrl as AnyUrl
from .v1 import BaseConfig as BaseConfig
from .v1 import BaseModel as BaseModel
from .v1 import Color as Color
from .v1 import CoreSchema as CoreSchema
from .v1 import ErrorWrapper as ErrorWrapper
from .v1 import FieldInfo as FieldInfo
from .v1 import GetJsonSchemaHandler as GetJsonSchemaHandler
from .v1 import JsonSchemaValue as JsonSchemaValue
from .v1 import ModelField as ModelField
from .v1 import NameEmail as NameEmail
from .v1 import RequiredParam as RequiredParam
from .v1 import SecretBytes as SecretBytes
from .v1 import SecretStr as SecretStr
from .v1 import Undefined as Undefined
from .v1 import UndefinedType as UndefinedType
from .v1 import Url as Url
from .v1 import ValidationError, create_model
from .v1 import get_definitions as get_definitions
RequestErrorModel: Type[BaseModel] = create_model("Request")
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
use_errors: List[Any] = []
for error in errors:
if isinstance(error, ErrorWrapper):
new_errors = ValidationError( # type: ignore[call-arg]
errors=[error], model=RequestErrorModel
).errors()
use_errors.extend(new_errors)
elif isinstance(error, list):
use_errors.extend(_normalize_errors(error))
else:
use_errors.append(error)
return use_errors
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors

View File

@@ -0,0 +1,53 @@
from typing import (
Any,
Dict,
List,
Tuple,
Union,
)
from fastapi.types import IncEx
from pydantic.fields import FieldInfo
from typing_extensions import Literal, Protocol
class ModelField(Protocol):
field_info: "FieldInfo"
name: str
mode: Literal["validation", "serialization"] = "validation"
_version: Literal["v1", "v2"] = "v1"
@property
def alias(self) -> str: ...
@property
def required(self) -> bool: ...
@property
def default(self) -> Any: ...
@property
def type_(self) -> Any: ...
def get_default(self) -> Any: ...
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: ...
def serialize(
self,
value: Any,
*,
mode: Literal["json", "python"] = "json",
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any: ...

View File

@@ -0,0 +1,211 @@
import sys
import types
import typing
from collections import deque
from dataclasses import is_dataclass
from typing import (
Any,
Deque,
FrozenSet,
List,
Mapping,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi._compat import may_v1
from fastapi.types import UnionType
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, get_args, get_origin
# Copy from Pydantic v2, compatible with v1
if sys.version_info < (3, 9):
# Pydantic no longer supports Python 3.8, this might be incorrect, but the code
# this is used for is also never reached in this codebase, as it's a copy of
# Pydantic's lenient_issubclass, just for compatibility with v1
# TODO: remove when dropping support for Python 3.8
WithArgsTypes: Tuple[Any, ...] = ()
elif sys.version_info < (3, 10):
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # type: ignore[attr-defined]
else:
WithArgsTypes: tuple[Any, ...] = (
typing._GenericAlias, # type: ignore[attr-defined]
types.GenericAlias,
types.UnionType,
) # pyright: ignore[reportAttributeAccessIssue]
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
sequence_annotation_to_type = {
Sequence: list,
List: list,
list: list,
Tuple: tuple,
tuple: tuple,
Set: set,
set: set,
FrozenSet: frozenset,
frozenset: frozenset,
Deque: deque,
deque: deque,
}
sequence_types = tuple(sequence_annotation_to_type.keys())
Url: Type[Any]
# Copy of Pydantic v2, compatible with v1
def lenient_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]
) -> bool:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
except TypeError: # pragma: no cover
if isinstance(cls, WithArgsTypes):
return False
raise # pragma: no cover
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, sequence_types) # type: ignore[arg-type]
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if field_annotation_is_sequence(arg):
return True
return False
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
get_origin(annotation)
)
def value_is_sequence(value: Any) -> bool:
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
return (
lenient_issubclass(
annotation, (BaseModel, may_v1.BaseModel, Mapping, UploadFile)
)
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
if origin is Annotated:
return field_annotation_is_complex(get_args(annotation)[0])
return (
_annotation_is_complex(annotation)
or _annotation_is_complex(origin)
or hasattr(origin, "__pydantic_core_schema__")
or hasattr(origin, "__get_pydantic_core_schema__")
)
def field_annotation_is_scalar(annotation: Any) -> bool:
# handle Ellipsis here to make tuple[int, ...] work nicely
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False
for arg in get_args(annotation):
if field_annotation_is_scalar_sequence(arg):
at_least_one_scalar_sequence = True
continue
elif not field_annotation_is_scalar(arg):
return False
return at_least_one_scalar_sequence
return field_annotation_is_sequence(annotation) and all(
field_annotation_is_scalar(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, bytes):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, bytes):
return True
return False
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, UploadFile):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, UploadFile):
return True
return False
def is_bytes_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one = False
for arg in get_args(annotation):
if is_bytes_sequence_annotation(arg):
at_least_one = True
continue
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_bytes_or_nonable_bytes_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one = False
for arg in get_args(annotation):
if is_uploadfile_sequence_annotation(arg):
at_least_one = True
continue
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def annotation_is_pydantic_v1(annotation: Any) -> bool:
if lenient_issubclass(annotation, may_v1.BaseModel):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, may_v1.BaseModel):
return True
if field_annotation_is_sequence(annotation):
for sub_annotation in get_args(annotation):
if annotation_is_pydantic_v1(sub_annotation):
return True
return False

View File

@@ -0,0 +1,312 @@
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import (
Any,
Callable,
Dict,
List,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi._compat import shared
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from fastapi.types import ModelNameMap
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import Literal
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
# Keeping old "Required" functionality from Pydantic V1, without
# shadowing typing.Required.
RequiredParam: Any = Ellipsis
if not PYDANTIC_V2:
from pydantic import BaseConfig as BaseConfig
from pydantic import BaseModel as BaseModel
from pydantic import ValidationError as ValidationError
from pydantic import create_model as create_model
from pydantic.class_validators import Validator as Validator
from pydantic.color import Color as Color
from pydantic.error_wrappers import ErrorWrapper as ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import ( # type: ignore[attr-defined]
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.fields import ModelField as ModelField # type: ignore[attr-defined]
from pydantic.fields import Undefined as Undefined # type: ignore[attr-defined]
from pydantic.fields import ( # type: ignore[attr-defined]
UndefinedType as UndefinedType,
)
from pydantic.networks import AnyUrl as AnyUrl
from pydantic.networks import NameEmail as NameEmail
from pydantic.schema import TypeModelSet as TypeModelSet
from pydantic.schema import (
field_schema,
model_process_schema,
)
from pydantic.schema import (
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.schema import get_flat_models_from_field as get_flat_models_from_field
from pydantic.schema import (
get_flat_models_from_fields as get_flat_models_from_fields,
)
from pydantic.schema import get_model_name_map as get_model_name_map
from pydantic.types import SecretBytes as SecretBytes
from pydantic.types import SecretStr as SecretStr
from pydantic.typing import evaluate_forwardref as evaluate_forwardref
from pydantic.utils import lenient_issubclass as lenient_issubclass
else:
from pydantic.v1 import BaseConfig as BaseConfig # type: ignore[assignment]
from pydantic.v1 import BaseModel as BaseModel # type: ignore[assignment]
from pydantic.v1 import ( # type: ignore[assignment]
ValidationError as ValidationError,
)
from pydantic.v1 import create_model as create_model # type: ignore[no-redef]
from pydantic.v1.class_validators import Validator as Validator
from pydantic.v1.color import Color as Color # type: ignore[assignment]
from pydantic.v1.error_wrappers import ErrorWrapper as ErrorWrapper
from pydantic.v1.errors import MissingError
from pydantic.v1.fields import (
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.v1.fields import FieldInfo as FieldInfo # type: ignore[assignment]
from pydantic.v1.fields import ModelField as ModelField
from pydantic.v1.fields import Undefined as Undefined
from pydantic.v1.fields import UndefinedType as UndefinedType
from pydantic.v1.networks import AnyUrl as AnyUrl
from pydantic.v1.networks import ( # type: ignore[assignment]
NameEmail as NameEmail,
)
from pydantic.v1.schema import TypeModelSet as TypeModelSet
from pydantic.v1.schema import (
field_schema,
model_process_schema,
)
from pydantic.v1.schema import (
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.v1.schema import (
get_flat_models_from_field as get_flat_models_from_field,
)
from pydantic.v1.schema import (
get_flat_models_from_fields as get_flat_models_from_fields,
)
from pydantic.v1.schema import get_model_name_map as get_model_name_map
from pydantic.v1.types import ( # type: ignore[assignment]
SecretBytes as SecretBytes,
)
from pydantic.v1.types import ( # type: ignore[assignment]
SecretStr as SecretStr,
)
from pydantic.v1.typing import evaluate_forwardref as evaluate_forwardref
from pydantic.v1.utils import lenient_issubclass as lenient_issubclass
GetJsonSchemaHandler = Any
JsonSchemaValue = Dict[str, Any]
CoreSchema = Any
Url = AnyUrl
sequence_shapes = {
SHAPE_LIST,
SHAPE_SET,
SHAPE_FROZENSET,
SHAPE_TUPLE,
SHAPE_SEQUENCE,
SHAPE_TUPLE_ELLIPSIS,
}
sequence_shape_to_type = {
SHAPE_LIST: list,
SHAPE_SET: set,
SHAPE_TUPLE: tuple,
SHAPE_SEQUENCE: list,
SHAPE_TUPLE_ELLIPSIS: list,
}
@dataclass
class GenerateJsonSchema:
ref_template: str
class PydanticSchemaGenerationError(Exception):
pass
RequestErrorModel: Type[BaseModel] = create_model("Request")
def with_info_plain_validator_function(
function: Callable[..., Any],
*,
ref: Union[str, None] = None,
metadata: Any = None,
serialization: Any = None,
) -> Any:
return {}
def get_model_definitions(
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict[str, Any]] = {}
for model in flat_models:
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
for m_schema in definitions.values():
if "description" in m_schema:
m_schema["description"] = m_schema["description"].split("\f")[0]
return definitions
def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params
field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON
and not lenient_issubclass(field.type_, BaseModel)
and not lenient_issubclass(field.type_, dict)
and not shared.field_annotation_is_sequence(field.type_)
and not is_dataclass(field.type_)
and not isinstance(field_info, params.Body)
):
return False
if field.sub_fields:
if not all(is_pv1_scalar_field(f) for f in field.sub_fields):
return False
return True
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
if (field.shape in sequence_shapes) and not lenient_issubclass(
field.type_, BaseModel
):
if field.sub_fields is not None:
for sub_field in field.sub_fields:
if not is_pv1_scalar_field(sub_field):
return False
return True
if shared._annotation_is_sequence(field.type_):
return True
return False
def _model_rebuild(model: Type[BaseModel]) -> None:
model.update_forward_refs()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.dict(**kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.__config__ # type: ignore[attr-defined]
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
return field_schema( # type: ignore[no-any-return]
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0]
# def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
# models = get_flat_models_from_fields(fields, known_models=set())
# return get_model_name_map(models) # type: ignore[no-any-return]
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Dict[str, Any]],
]:
models = get_flat_models_from_fields(fields, known_models=set())
return {}, get_model_definitions(flat_models=models, model_name_map=model_name_map)
def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or shared._annotation_is_sequence(field.type_)
def is_scalar_sequence_field(field: ModelField) -> bool:
return is_pv1_scalar_sequence_field(field)
def is_bytes_field(field: ModelField) -> bool:
return lenient_issubclass(field.type_, bytes) # type: ignore[no-any-return]
def is_bytes_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return copy(field_info)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
BodyModel = create_model(model_name)
for f in fields:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]

View File

@@ -0,0 +1,479 @@
import re
import warnings
from copy import copy, deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import (
Any,
Dict,
List,
Sequence,
Set,
Tuple,
Type,
Union,
cast,
)
from fastapi._compat import may_v1, shared
from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap
from pydantic import BaseModel, TypeAdapter, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler,
)
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url
from typing_extensions import Annotated, Literal, get_args, get_origin
try:
from pydantic_core.core_schema import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
except ImportError: # pragma: no cover
from pydantic_core.core_schema import (
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
)
RequiredParam = PydanticUndefined
Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
Validator = Any
class BaseConfig:
pass
class ErrorWrapper(Exception):
pass
@dataclass
class ModelField:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"
@property
def alias(self) -> str:
a = self.field_info.alias
return a if a is not None else self.name
@property
def required(self) -> bool:
return self.field_info.is_required()
@property
def default(self) -> Any:
return self.get_default()
@property
def type_(self) -> Any:
return self.field_info.annotation
def __post_init__(self) -> None:
with warnings.catch_warnings():
# Pydantic >= 2.12.0 warns about field specific metadata that is unused
# (e.g. `TypeAdapter(Annotated[int, Field(alias='b')])`). In some cases, we
# end up building the type adapter from a model field annotation so we
# need to ignore the warning:
if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 12):
from pydantic.warnings import UnsupportedFieldAttributeWarning
warnings.simplefilter(
"ignore", category=UnsupportedFieldAttributeWarning
)
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
)
def get_default(self) -> Any:
if self.field_info.is_required():
return Undefined
return self.field_info.get_default(call_default_factory=True)
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
None,
)
except ValidationError as exc:
return None, may_v1._regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
def serialize(
self,
value: Any,
*,
mode: Literal["json", "python"] = "json",
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
return self._type_adapter.dump_python(
value,
mode=mode,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def __hash__(self) -> int:
# Each ModelField is unique for our purposes, to allow making a dict from
# ModelField to its JSON Schema.
return id(self)
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
return annotation
def _model_rebuild(model: Type[BaseModel]) -> None:
model.model_rebuild()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.model_dump(mode=mode, **kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.model_config
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
override_mode: Union[Literal["validation"], None] = (
None if separate_input_output_schemas else "validation"
)
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = field_mapping[(field, override_mode or field.mode)]
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema["title"] = field.field_info.title or field.alias.title().replace(
"_", " "
)
return json_schema
def get_definitions(
*,
fields: Sequence[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Dict[str, Any]],
]:
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
override_mode: Union[Literal["validation"], None] = (
None if separate_input_output_schemas else "validation"
)
validation_fields = [field for field in fields if field.mode == "validation"]
serialization_fields = [field for field in fields if field.mode == "serialization"]
flat_validation_models = get_flat_models_from_fields(
validation_fields, known_models=set()
)
flat_serialization_models = get_flat_models_from_fields(
serialization_fields, known_models=set()
)
flat_validation_model_fields = [
ModelField(
field_info=FieldInfo(annotation=model),
name=model.__name__,
mode="validation",
)
for model in flat_validation_models
]
flat_serialization_model_fields = [
ModelField(
field_info=FieldInfo(annotation=model),
name=model.__name__,
mode="serialization",
)
for model in flat_serialization_models
]
flat_model_fields = flat_validation_model_fields + flat_serialization_model_fields
input_types = {f.type_ for f in fields}
unique_flat_model_fields = {
f for f in flat_model_fields if f.type_ not in input_types
}
inputs = [
(field, override_mode or field.mode, field._type_adapter.core_schema)
for field in list(fields) + list(unique_flat_model_fields)
]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
if "description" in item_def:
item_description = cast(str, item_def["description"]).split("\f")[0]
item_def["description"] = item_description
new_mapping, new_definitions = _remap_definitions_and_field_mappings(
model_name_map=model_name_map,
definitions=definitions, # type: ignore[arg-type]
field_mapping=field_mapping,
)
return new_mapping, new_definitions
def _replace_refs(
*,
schema: Dict[str, Any],
old_name_to_new_name_map: Dict[str, str],
) -> Dict[str, Any]:
new_schema = deepcopy(schema)
for key, value in new_schema.items():
if key == "$ref":
value = schema["$ref"]
if isinstance(value, str):
ref_name = schema["$ref"].split("/")[-1]
if ref_name in old_name_to_new_name_map:
new_name = old_name_to_new_name_map[ref_name]
new_schema["$ref"] = REF_TEMPLATE.format(model=new_name)
continue
if isinstance(value, dict):
new_schema[key] = _replace_refs(
schema=value,
old_name_to_new_name_map=old_name_to_new_name_map,
)
elif isinstance(value, list):
new_value = []
for item in value:
if isinstance(item, dict):
new_item = _replace_refs(
schema=item,
old_name_to_new_name_map=old_name_to_new_name_map,
)
new_value.append(new_item)
else:
new_value.append(item)
new_schema[key] = new_value
return new_schema
def _remap_definitions_and_field_mappings(
*,
model_name_map: ModelNameMap,
definitions: Dict[str, Any],
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Any],
]:
old_name_to_new_name_map = {}
for field_key, schema in field_mapping.items():
model = field_key[0].type_
if model not in model_name_map:
continue
new_name = model_name_map[model]
old_name = schema["$ref"].split("/")[-1]
if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
continue
old_name_to_new_name_map[old_name] = new_name
new_field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
] = {}
for field_key, schema in field_mapping.items():
new_schema = _replace_refs(
schema=schema,
old_name_to_new_name_map=old_name_to_new_name_map,
)
new_field_mapping[field_key] = new_schema
new_definitions = {}
for key, value in definitions.items():
if key in old_name_to_new_name_map:
new_key = old_name_to_new_name_map[key]
else:
new_key = key
new_value = _replace_refs(
schema=value,
old_name_to_new_name_map=old_name_to_new_name_map,
)
new_definitions[new_key] = new_value
return new_field_mapping, new_definitions
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
return shared.field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
return shared.field_annotation_is_sequence(field.field_info.annotation)
def is_scalar_sequence_field(field: ModelField) -> bool:
return shared.field_annotation_is_scalar_sequence(field.field_info.annotation)
def is_bytes_field(field: ModelField) -> bool:
return shared.is_bytes_or_nonable_bytes_annotation(field.type_)
def is_bytes_sequence_field(field: ModelField) -> bool:
return shared.is_bytes_sequence_annotation(field.type_)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
cls = type(field_info)
merged_field_info = cls.from_annotation(annotation)
new_field_info = copy(field_info)
new_field_info.metadata = merged_field_info.metadata
new_field_info.annotation = merged_field_info.annotation
return new_field_info
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
assert issubclass(origin_type, shared.sequence_types) # type: ignore[arg-type]
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
error["input"] = None
return error # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return [
ModelField(field_info=field_info, name=name)
for name, field_info in model.model_fields.items()
]
# Duplicate of several schema functions from Pydantic v1 to make them compatible with
# Pydantic v2 and allow mixing the models
TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]]
TypeModelSet = Set[TypeModelOrEnum]
def normalize_name(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]:
name_model_map = {}
conflicting_names: Set[str] = set()
for model in unique_models:
model_name = normalize_name(model.__name__)
if model_name in conflicting_names:
model_name = get_long_model_name(model)
name_model_map[model_name] = model
elif model_name in name_model_map:
conflicting_names.add(model_name)
conflicting_model = name_model_map.pop(model_name)
name_model_map[get_long_model_name(conflicting_model)] = conflicting_model
name_model_map[get_long_model_name(model)] = model
else:
name_model_map[model_name] = model
return {v: k for k, v in name_model_map.items()}
def get_flat_models_from_model(
model: Type["BaseModel"], known_models: Union[TypeModelSet, None] = None
) -> TypeModelSet:
known_models = known_models or set()
fields = get_model_fields(model)
get_flat_models_from_fields(fields, known_models=known_models)
return known_models
def get_flat_models_from_annotation(
annotation: Any, known_models: TypeModelSet
) -> TypeModelSet:
origin = get_origin(annotation)
if origin is not None:
for arg in get_args(annotation):
if lenient_issubclass(arg, (BaseModel, Enum)) and arg not in known_models:
known_models.add(arg)
if lenient_issubclass(arg, BaseModel):
get_flat_models_from_model(arg, known_models=known_models)
else:
get_flat_models_from_annotation(arg, known_models=known_models)
return known_models
def get_flat_models_from_field(
field: ModelField, known_models: TypeModelSet
) -> TypeModelSet:
field_type = field.type_
if lenient_issubclass(field_type, BaseModel):
if field_type in known_models:
return known_models
known_models.add(field_type)
get_flat_models_from_model(field_type, known_models=known_models)
elif lenient_issubclass(field_type, Enum):
known_models.add(field_type)
else:
get_flat_models_from_annotation(field_type, known_models=known_models)
return known_models
def get_flat_models_from_fields(
fields: Sequence[ModelField], known_models: TypeModelSet
) -> TypeModelSet:
for field in fields:
get_flat_models_from_field(field, known_models=known_models)
return known_models
def get_long_model_name(model: TypeModelOrEnum) -> str:
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")

View File

@@ -13,6 +13,7 @@ from typing import (
Union,
)
from annotated_doc import Doc
from fastapi import routing
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.exception_handlers import (
@@ -42,8 +43,8 @@ from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
from typing_extensions import Annotated, deprecated
AppType = TypeVar("AppType", bound="FastAPI")
@@ -75,7 +76,7 @@ class FastAPI(Starlette):
errors.
Read more in the
[Starlette docs for Applications](https://www.starlette.io/applications/#instantiating-the-application).
[Starlette docs for Applications](https://www.starlette.dev/applications/#instantiating-the-application).
"""
),
] = False,
@@ -300,7 +301,7 @@ class FastAPI(Starlette):
browser tabs open). Or if you want to leave fixed the possible URLs.
If the servers `list` is not provided, or is an empty `list`, the
default value would be a a `dict` with a `url` value of `/`.
default value would be a `dict` with a `url` value of `/`.
Each item in the `list` is a `dict` containing:
@@ -751,7 +752,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -813,6 +814,32 @@ class FastAPI(Starlette):
"""
),
] = True,
openapi_external_docs: Annotated[
Optional[Dict[str, Any]],
Doc(
"""
This field allows you to provide additional external documentation links.
If provided, it must be a dictionary containing:
* `description`: A brief description of the external documentation.
* `url`: The URL pointing to the external documentation. The value **MUST**
be a valid URL format.
**Example**:
```python
from fastapi import FastAPI
external_docs = {
"description": "Detailed API Reference",
"url": "https://example.com/api-docs",
}
app = FastAPI(openapi_external_docs=external_docs)
```
"""
),
] = None,
**extra: Annotated[
Any,
Doc(
@@ -841,6 +868,7 @@ class FastAPI(Starlette):
self.swagger_ui_parameters = swagger_ui_parameters
self.servers = servers or []
self.separate_input_output_schemas = separate_input_output_schemas
self.openapi_external_docs = openapi_external_docs
self.extra = extra
self.openapi_version: Annotated[
str,
@@ -905,13 +933,13 @@ class FastAPI(Starlette):
A state object for the application. This is the same object for the
entire application, it doesn't change from request to request.
You normally woudln't use this in FastAPI, for most of the cases you
You normally wouldn't use this in FastAPI, for most of the cases you
would instead use FastAPI dependencies.
This is simply inherited from Starlette.
Read more about it in the
[Starlette docs for Applications](https://www.starlette.io/applications/#storing-state-on-the-app-instance).
[Starlette docs for Applications](https://www.starlette.dev/applications/#storing-state-on-the-app-instance).
"""
),
] = State()
@@ -971,7 +999,7 @@ class FastAPI(Starlette):
# inside of ExceptionMiddleware, inside of custom user middlewares
debug = self.debug
error_handler = None
exception_handlers = {}
exception_handlers: dict[Any, ExceptionHandler] = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
@@ -986,33 +1014,32 @@ class FastAPI(Starlette):
Middleware(
ExceptionMiddleware, handlers=exception_handlers, debug=debug
),
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
# contextvars.
# Add FastAPI-specific AsyncExitStackMiddleware for closing files.
# Before this was also used for closing dependencies with yield but
# those now have their own AsyncExitStack, to properly support
# streaming responses while keeping compatibility with the previous
# versions (as of writing 0.117.1) that allowed doing
# except HTTPException inside a dependency with yield.
# This needs to happen after user middlewares because those create a
# new contextvars context copy by using a new AnyIO task group.
# The initial part of dependencies with 'yield' is executed in the
# FastAPI code, inside all the middlewares. However, the teardown part
# (after 'yield') is executed in the AsyncExitStack in this middleware.
# This AsyncExitStack preserves the context for contextvars, not
# strictly necessary for closing files but it was one of the original
# intentions.
# If the AsyncExitStack lived outside of the custom middlewares and
# contextvars were set in a dependency with 'yield' in that internal
# contextvars context, the values would not be available in the
# outer context of the AsyncExitStack.
# contextvars were set, for example in a dependency with 'yield'
# in that internal contextvars context, the values would not be
# available in the outer context of the AsyncExitStack.
# By placing the middleware and the AsyncExitStack here, inside all
# user middlewares, the code before and after 'yield' in dependencies
# with 'yield' is executed in the same contextvars context. Thus, all values
# set in contextvars before 'yield' are still available after 'yield,' as
# expected.
# Additionally, by having this AsyncExitStack here, after the
# ExceptionMiddleware, dependencies can now catch handled exceptions,
# e.g. HTTPException, to customize the teardown code (e.g. DB session
# rollback).
# user middlewares, the same context is used.
# This is currently not needed, only for closing files, but used to be
# important when dependencies with yield were closed here.
Middleware(AsyncExitStackMiddleware),
]
)
app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
for cls, args, kwargs in reversed(middleware):
app = cls(app, *args, **kwargs)
return app
def openapi(self) -> Dict[str, Any]:
@@ -1044,6 +1071,7 @@ class FastAPI(Starlette):
tags=self.openapi_tags,
servers=self.servers,
separate_input_output_schemas=self.separate_input_output_schemas,
external_docs=self.openapi_external_docs,
)
return self.openapi_schema
@@ -1071,7 +1099,7 @@ class FastAPI(Starlette):
oauth2_redirect_url = root_path + oauth2_redirect_url
return get_swagger_ui_html(
openapi_url=openapi_url,
title=self.title + " - Swagger UI",
title=f"{self.title} - Swagger UI",
oauth2_redirect_url=oauth2_redirect_url,
init_oauth=self.swagger_ui_init_oauth,
swagger_ui_parameters=self.swagger_ui_parameters,
@@ -1095,7 +1123,7 @@ class FastAPI(Starlette):
root_path = req.scope.get("root_path", "").rstrip("/")
openapi_url = root_path + self.openapi_url
return get_redoc_html(
openapi_url=openapi_url, title=self.title + " - ReDoc"
openapi_url=openapi_url, title=f"{self.title} - ReDoc"
)
self.add_route(self.redoc_url, redoc_html, include_in_schema=False)
@@ -1108,7 +1136,7 @@ class FastAPI(Starlette):
def add_api_route(
self,
path: str,
endpoint: Callable[..., Coroutine[Any, Any, Response]],
endpoint: Callable[..., Any],
*,
response_model: Any = Default(None),
status_code: Optional[int] = None,
@@ -1772,7 +1800,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -2145,7 +2173,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -2523,7 +2551,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -2901,7 +2929,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -3274,7 +3302,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -3647,7 +3675,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -4020,7 +4048,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -4398,7 +4426,7 @@ class FastAPI(Starlette):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -4477,7 +4505,7 @@ class FastAPI(Starlette):
app = FastAPI()
@app.put("/items/{item_id}")
@app.trace("/items/{item_id}")
def trace_item(item_id: str):
return None
```
@@ -4567,14 +4595,17 @@ class FastAPI(Starlette):
```python
import time
from typing import Awaitable, Callable
from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, Response
app = FastAPI()
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
async def add_process_time_header(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time

View File

@@ -1,7 +1,8 @@
from typing import Any, Callable
from annotated_doc import Doc
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from typing_extensions import Annotated, Doc, ParamSpec # type: ignore [attr-defined]
from typing_extensions import Annotated, ParamSpec
P = ParamSpec("P")

View File

@@ -0,0 +1,13 @@
try:
from fastapi_cli.cli import main as cli_main
except ImportError: # pragma: no cover
cli_main = None # type: ignore
def main() -> None:
if not cli_main: # type: ignore[truthy-function]
message = 'To use the fastapi command, please install "fastapi[standard]":\n\n\tpip install "fastapi[standard]"\n'
print(message)
raise RuntimeError(message) # noqa: B904
cli_main()

View File

@@ -1,8 +1,7 @@
from contextlib import AsyncExitStack as AsyncExitStack # noqa
from contextlib import asynccontextmanager as asynccontextmanager
from typing import AsyncGenerator, ContextManager, TypeVar
import anyio
import anyio.to_thread
from anyio import CapacityLimiter
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
@@ -29,7 +28,7 @@ async def contextmanager_in_threadpool(
except Exception as e:
ok = bool(
await anyio.to_thread.run_sync(
cm.__exit__, type(e), e, None, limiter=exit_limiter
cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
)
)
if not ok:

View File

@@ -10,12 +10,11 @@ from typing import (
cast,
)
from annotated_doc import Doc
from fastapi._compat import (
PYDANTIC_V2,
CoreSchema,
GetJsonSchemaHandler,
JsonSchemaValue,
with_info_plain_validator_function,
)
from starlette.datastructures import URL as URL # noqa: F401
from starlette.datastructures import Address as Address # noqa: F401
@@ -24,7 +23,7 @@ from starlette.datastructures import Headers as Headers # noqa: F401
from starlette.datastructures import QueryParams as QueryParams # noqa: F401
from starlette.datastructures import State as State # noqa: F401
from starlette.datastructures import UploadFile as StarletteUploadFile
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from typing_extensions import Annotated
class UploadFile(StarletteUploadFile):
@@ -154,11 +153,10 @@ class UploadFile(StarletteUploadFile):
raise ValueError(f"Expected UploadFile, received: {type(__input_value)}")
return cast(UploadFile, __input_value)
if not PYDANTIC_V2:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update({"type": "string", "format": "binary"})
# TODO: remove when deprecating Pydantic v1
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update({"type": "string", "format": "binary"})
@classmethod
def __get_pydantic_json_schema__(
@@ -170,6 +168,8 @@ class UploadFile(StarletteUploadFile):
def __get_pydantic_core_schema__(
cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
) -> CoreSchema:
from ._compat.v2 import with_info_plain_validator_function
return with_info_plain_validator_function(cls._validate)

View File

@@ -1,58 +1,107 @@
from typing import Any, Callable, List, Optional, Sequence
import inspect
import sys
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Callable, List, Optional, Sequence, Union
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
from fastapi.types import DependencyCacheKey
from typing_extensions import Literal
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
@dataclass
class SecurityRequirement:
def __init__(
self, security_scheme: SecurityBase, scopes: Optional[Sequence[str]] = None
):
self.security_scheme = security_scheme
self.scopes = scopes
security_scheme: SecurityBase
scopes: Optional[Sequence[str]] = None
@dataclass
class Dependant:
def __init__(
self,
*,
path_params: Optional[List[ModelField]] = None,
query_params: Optional[List[ModelField]] = None,
header_params: Optional[List[ModelField]] = None,
cookie_params: Optional[List[ModelField]] = None,
body_params: Optional[List[ModelField]] = None,
dependencies: Optional[List["Dependant"]] = None,
security_schemes: Optional[List[SecurityRequirement]] = None,
name: Optional[str] = None,
call: Optional[Callable[..., Any]] = None,
request_param_name: Optional[str] = None,
websocket_param_name: Optional[str] = None,
http_connection_param_name: Optional[str] = None,
response_param_name: Optional[str] = None,
background_tasks_param_name: Optional[str] = None,
security_scopes_param_name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
path: Optional[str] = None,
) -> None:
self.path_params = path_params or []
self.query_params = query_params or []
self.header_params = header_params or []
self.cookie_params = cookie_params or []
self.body_params = body_params or []
self.dependencies = dependencies or []
self.security_requirements = security_schemes or []
self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name
self.http_connection_param_name = http_connection_param_name
self.response_param_name = response_param_name
self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes
self.security_scopes_param_name = security_scopes_param_name
self.name = name
self.call = call
self.use_cache = use_cache
# Store the path to be able to re-generate a dependable from it in overrides
self.path = path
# Save the cache key at creation to optimize performance
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
path_params: List[ModelField] = field(default_factory=list)
query_params: List[ModelField] = field(default_factory=list)
header_params: List[ModelField] = field(default_factory=list)
cookie_params: List[ModelField] = field(default_factory=list)
body_params: List[ModelField] = field(default_factory=list)
dependencies: List["Dependant"] = field(default_factory=list)
security_requirements: List[SecurityRequirement] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
request_param_name: Optional[str] = None
websocket_param_name: Optional[str] = None
http_connection_param_name: Optional[str] = None
response_param_name: Optional[str] = None
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
own_oauth_scopes: Optional[List[str]] = None
parent_oauth_scopes: Optional[List[str]] = None
use_cache: bool = True
path: Optional[str] = None
scope: Union[Literal["function", "request"], None] = None
@cached_property
def oauth_scopes(self) -> List[str]:
scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
# This doesn't use a set to preserve order, just in case
for scope in self.own_oauth_scopes or []:
if scope not in scopes:
scopes.append(scope)
return scopes
@cached_property
def cache_key(self) -> DependencyCacheKey:
scopes_for_cache = (
tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
)
return (
self.call,
scopes_for_cache,
self.computed_scope or "",
)
@cached_property
def _uses_scopes(self) -> bool:
if self.own_oauth_scopes:
return True
if self.security_scopes_param_name is not None:
return True
for sub_dep in self.dependencies:
if sub_dep._uses_scopes:
return True
return False
@cached_property
def is_gen_callable(self) -> bool:
if inspect.isgeneratorfunction(self.call):
return True
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)
@cached_property
def is_async_gen_callable(self) -> bool:
if inspect.isasyncgenfunction(self.call):
return True
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
@cached_property
def is_coroutine_callable(self) -> bool:
if inspect.isroutine(self.call):
return iscoroutinefunction(self.call)
if inspect.isclass(self.call):
return False
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return iscoroutinefunction(dunder_call)
@cached_property
def computed_scope(self) -> Union[str, None]:
if self.scope:
return self.scope
if self.is_gen_callable or self.is_async_gen_callable:
return "request"
return None

View File

@@ -17,14 +17,16 @@ from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
from annotated_doc import Doc
from fastapi._compat import may_v1
from fastapi.types import IncEx
from pydantic import BaseModel
from pydantic.color import Color
from pydantic.networks import AnyUrl, NameEmail
from pydantic.types import SecretBytes, SecretStr
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from typing_extensions import Annotated
from ._compat import PYDANTIC_V2, Url, _model_dump
from ._compat import Url, _is_undefined, _model_dump
# Taken from Pydantic v1 as is
@@ -58,6 +60,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
may_v1.Color: str,
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
@@ -74,19 +77,24 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
may_v1.NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
may_v1.SecretBytes: str,
SecretStr: str,
may_v1.SecretStr: str,
set: list,
UUID: str,
Url: str,
may_v1.Url: str,
AnyUrl: str,
may_v1.AnyUrl: str,
}
def generate_encoders_by_class_tuples(
type_encoder_map: Dict[Any, Callable[[Any], Any]]
type_encoder_map: Dict[Any, Callable[[Any], Any]],
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
tuple
@@ -213,13 +221,13 @@ def jsonable_encoder(
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude)
if isinstance(obj, BaseModel):
if isinstance(obj, (BaseModel, may_v1.BaseModel)):
# TODO: remove when deprecating Pydantic v1
encoders: Dict[Any, Any] = {}
if not PYDANTIC_V2:
if isinstance(obj, may_v1.BaseModel):
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders.update(custom_encoder)
encoders = {**encoders, **custom_encoder}
obj_dict = _model_dump(
obj,
mode="json",
@@ -241,6 +249,7 @@ def jsonable_encoder(
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
assert not isinstance(obj, type)
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
obj_dict,
@@ -259,6 +268,8 @@ def jsonable_encoder(
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if _is_undefined(obj):
return None
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())

View File

@@ -5,7 +5,7 @@ from fastapi.websockets import WebSocket
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
from starlette.status import WS_1008_POLICY_VIOLATION
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
@@ -21,7 +21,7 @@ async def request_validation_exception_handler(
request: Request, exc: RequestValidationError
) -> JSONResponse:
return JSONResponse(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
status_code=422,
content={"detail": jsonable_encoder(exc.errors())},
)

View File

@@ -1,9 +1,10 @@
from typing import Any, Dict, Optional, Sequence, Type, Union
from annotated_doc import Doc
from pydantic import BaseModel, create_model
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.exceptions import WebSocketException as StarletteWebSocketException
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from typing_extensions import Annotated
class HTTPException(StarletteHTTPException):
@@ -146,6 +147,13 @@ class FastAPIError(RuntimeError):
"""
class DependencyScopeError(FastAPIError):
"""
A dependency declared that it depends on another dependency with an invalid
(narrower) scope.
"""
class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
self._errors = errors

View File

@@ -1,25 +1,18 @@
from typing import Optional
from contextlib import AsyncExitStack
from fastapi.concurrency import AsyncExitStack
from starlette.types import ASGIApp, Receive, Scope, Send
# Used mainly to close files after the request is done, dependencies are closed
# in their own AsyncExitStack
class AsyncExitStackMiddleware:
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
def __init__(
self, app: ASGIApp, context_name: str = "fastapi_middleware_astack"
) -> None:
self.app = app
self.context_name = context_name
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
dependency_exception: Optional[Exception] = None
async with AsyncExitStack() as stack:
scope[self.context_name] = stack
try:
await self.app(scope, receive, send)
except Exception as e:
dependency_exception = e
raise e
if dependency_exception:
# This exception was possibly handled by the dependency but it should
# still bubble up so that the ServerErrorMiddleware can return a 500
# or the ExceptionMiddleware can catch and handle any other exceptions
raise dependency_exception
await self.app(scope, receive, send)

View File

@@ -1,9 +1,10 @@
import json
from typing import Any, Dict, Optional
from annotated_doc import Doc
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from typing_extensions import Annotated
swagger_ui_default_parameters: Annotated[
Dict[str, Any],
@@ -53,7 +54,7 @@ def get_swagger_ui_html(
It is normally set to a CDN URL.
"""
),
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
swagger_css_url: Annotated[
str,
Doc(
@@ -63,7 +64,7 @@ def get_swagger_ui_html(
It is normally set to a CDN URL.
"""
),
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
swagger_favicon_url: Annotated[
str,
Doc(
@@ -188,7 +189,7 @@ def get_redoc_html(
It is normally set to a CDN URL.
"""
),
] = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
] = "https://cdn.jsdelivr.net/npm/redoc@2/bundles/redoc.standalone.js",
redoc_favicon_url: Annotated[
str,
Doc(

View File

@@ -55,35 +55,29 @@ except ImportError: # pragma: no cover
return with_info_plain_validator_function(cls._validate)
class Contact(BaseModel):
class BaseModelWithConfig(BaseModel):
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Contact(BaseModelWithConfig):
name: Optional[str] = None
url: Optional[AnyUrl] = None
email: Optional[EmailStr] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class License(BaseModel):
class License(BaseModelWithConfig):
name: str
identifier: Optional[str] = None
url: Optional[AnyUrl] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Info(BaseModel):
class Info(BaseModelWithConfig):
title: str
summary: Optional[str] = None
description: Optional[str] = None
@@ -92,42 +86,18 @@ class Info(BaseModel):
license: Optional[License] = None
version: str
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class ServerVariable(BaseModel):
class ServerVariable(BaseModelWithConfig):
enum: Annotated[Optional[List[str]], Field(min_length=1)] = None
default: str
description: Optional[str] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Server(BaseModel):
class Server(BaseModelWithConfig):
url: Union[AnyUrl, str]
description: Optional[str] = None
variables: Optional[Dict[str, ServerVariable]] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Reference(BaseModel):
ref: str = Field(alias="$ref")
@@ -138,36 +108,26 @@ class Discriminator(BaseModel):
mapping: Optional[Dict[str, str]] = None
class XML(BaseModel):
class XML(BaseModelWithConfig):
name: Optional[str] = None
namespace: Optional[str] = None
prefix: Optional[str] = None
attribute: Optional[bool] = None
wrapped: Optional[bool] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class ExternalDocumentation(BaseModel):
class ExternalDocumentation(BaseModelWithConfig):
description: Optional[str] = None
url: AnyUrl
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
# Ref JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation#name-type
SchemaType = Literal[
"array", "boolean", "integer", "null", "number", "object", "string"
]
class Schema(BaseModel):
class Schema(BaseModelWithConfig):
# Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu
# Core Vocabulary
schema_: Optional[str] = Field(default=None, alias="$schema")
@@ -191,7 +151,7 @@ class Schema(BaseModel):
dependentSchemas: Optional[Dict[str, "SchemaOrBool"]] = None
prefixItems: Optional[List["SchemaOrBool"]] = None
# TODO: uncomment and remove below when deprecating Pydantic v1
# It generales a list of schemas for tuples, before prefixItems was available
# It generates a list of schemas for tuples, before prefixItems was available
# items: Optional["SchemaOrBool"] = None
items: Optional[Union["SchemaOrBool", List["SchemaOrBool"]]] = None
contains: Optional["SchemaOrBool"] = None
@@ -203,7 +163,7 @@ class Schema(BaseModel):
unevaluatedProperties: Optional["SchemaOrBool"] = None
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural
# A Vocabulary for Structural Validation
type: Optional[str] = None
type: Optional[Union[SchemaType, List[SchemaType]]] = None
enum: Optional[List[Any]] = None
const: Optional[Any] = None
multipleOf: Optional[float] = Field(default=None, gt=0)
@@ -253,14 +213,6 @@ class Schema(BaseModel):
),
] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
# Ref: https://json-schema.org/draft/2020-12/json-schema-core.html#name-json-schema-documents
# A JSON Schema MUST be an object or a boolean.
@@ -289,38 +241,22 @@ class ParameterInType(Enum):
cookie = "cookie"
class Encoding(BaseModel):
class Encoding(BaseModelWithConfig):
contentType: Optional[str] = None
headers: Optional[Dict[str, Union["Header", Reference]]] = None
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class MediaType(BaseModel):
class MediaType(BaseModelWithConfig):
schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
encoding: Optional[Dict[str, Encoding]] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class ParameterBase(BaseModel):
class ParameterBase(BaseModelWithConfig):
description: Optional[str] = None
required: Optional[bool] = None
deprecated: Optional[bool] = None
@@ -334,14 +270,6 @@ class ParameterBase(BaseModel):
# Serialization rules for more complex scenarios
content: Optional[Dict[str, MediaType]] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Parameter(ParameterBase):
name: str
@@ -352,21 +280,13 @@ class Header(ParameterBase):
pass
class RequestBody(BaseModel):
class RequestBody(BaseModelWithConfig):
description: Optional[str] = None
content: Dict[str, MediaType]
required: Optional[bool] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Link(BaseModel):
class Link(BaseModelWithConfig):
operationRef: Optional[str] = None
operationId: Optional[str] = None
parameters: Optional[Dict[str, Union[Any, str]]] = None
@@ -374,31 +294,15 @@ class Link(BaseModel):
description: Optional[str] = None
server: Optional[Server] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Response(BaseModel):
class Response(BaseModelWithConfig):
description: str
headers: Optional[Dict[str, Union[Header, Reference]]] = None
content: Optional[Dict[str, MediaType]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Operation(BaseModel):
class Operation(BaseModelWithConfig):
tags: Optional[List[str]] = None
summary: Optional[str] = None
description: Optional[str] = None
@@ -413,16 +317,8 @@ class Operation(BaseModel):
security: Optional[List[Dict[str, List[str]]]] = None
servers: Optional[List[Server]] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class PathItem(BaseModel):
class PathItem(BaseModelWithConfig):
ref: Optional[str] = Field(default=None, alias="$ref")
summary: Optional[str] = None
description: Optional[str] = None
@@ -437,14 +333,6 @@ class PathItem(BaseModel):
servers: Optional[List[Server]] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class SecuritySchemeType(Enum):
apiKey = "apiKey"
@@ -453,18 +341,10 @@ class SecuritySchemeType(Enum):
openIdConnect = "openIdConnect"
class SecurityBase(BaseModel):
class SecurityBase(BaseModelWithConfig):
type_: SecuritySchemeType = Field(alias="type")
description: Optional[str] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class APIKeyIn(Enum):
query = "query"
@@ -488,18 +368,10 @@ class HTTPBearer(HTTPBase):
bearerFormat: Optional[str] = None
class OAuthFlow(BaseModel):
class OAuthFlow(BaseModelWithConfig):
refreshUrl: Optional[str] = None
scopes: Dict[str, str] = {}
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class OAuthFlowImplicit(OAuthFlow):
authorizationUrl: str
@@ -518,20 +390,12 @@ class OAuthFlowAuthorizationCode(OAuthFlow):
tokenUrl: str
class OAuthFlows(BaseModel):
class OAuthFlows(BaseModelWithConfig):
implicit: Optional[OAuthFlowImplicit] = None
password: Optional[OAuthFlowPassword] = None
clientCredentials: Optional[OAuthFlowClientCredentials] = None
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class OAuth2(SecurityBase):
type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type")
@@ -548,7 +412,7 @@ class OpenIdConnect(SecurityBase):
SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer]
class Components(BaseModel):
class Components(BaseModelWithConfig):
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
responses: Optional[Dict[str, Union[Response, Reference]]] = None
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
@@ -561,30 +425,14 @@ class Components(BaseModel):
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None
pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class Tag(BaseModel):
class Tag(BaseModelWithConfig):
name: str
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
class OpenAPI(BaseModel):
class OpenAPI(BaseModelWithConfig):
openapi: str
info: Info
jsonSchemaDialect: Optional[str] = None
@@ -597,14 +445,6 @@ class OpenAPI(BaseModel):
tags: Optional[List[Tag]] = None
externalDocs: Optional[ExternalDocumentation] = None
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
_model_rebuild(Schema)
_model_rebuild(Operation)

View File

@@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union,
from fastapi import routing
from fastapi._compat import (
GenerateJsonSchema,
JsonSchemaValue,
ModelField,
Undefined,
@@ -16,11 +15,15 @@ from fastapi._compat import (
)
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.dependencies.utils import (
_get_flat_fields_from_params,
get_flat_dependant,
get_flat_params,
)
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.params import Body, ParamTypes
from fastapi.responses import Response
from fastapi.types import ModelNameMap
from fastapi.utils import (
@@ -28,11 +31,13 @@ from fastapi.utils import (
generate_operation_id_for_path,
is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from typing_extensions import Literal
from .._compat import _is_model_field
validation_error_definition = {
"title": "ValidationError",
"type": "object",
@@ -87,10 +92,9 @@ def get_openapi_security_definitions(
return security_definitions, operation_security
def get_openapi_operation_parameters(
def _get_openapi_operation_parameters(
*,
all_route_params: Sequence[ModelField],
schema_generator: GenerateJsonSchema,
dependant: Dependant,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
@@ -98,40 +102,72 @@ def get_openapi_operation_parameters(
separate_input_output_schemas: bool = True,
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
if not field_info.include_in_schema:
continue
param_schema = get_schema_from_model_field(
field=param,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": param_schema,
}
if field_info.description:
parameter["description"] = field_info.description
if field_info.openapi_examples:
parameter["examples"] = jsonable_encoder(field_info.openapi_examples)
elif field_info.example != Undefined:
parameter["example"] = jsonable_encoder(field_info.example)
if field_info.deprecated:
parameter["deprecated"] = field_info.deprecated
parameters.append(parameter)
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
parameter_groups = [
(ParamTypes.path, path_params),
(ParamTypes.query, query_params),
(ParamTypes.header, header_params),
(ParamTypes.cookie, cookie_params),
]
default_convert_underscores = True
if len(flat_dependant.header_params) == 1:
first_field = flat_dependant.header_params[0]
if lenient_issubclass(first_field.type_, BaseModel):
default_convert_underscores = getattr(
first_field.field_info, "convert_underscores", True
)
for param_type, param_group in parameter_groups:
for param in param_group:
field_info = param.field_info
# field_info = cast(Param, field_info)
if not getattr(field_info, "include_in_schema", True):
continue
param_schema = get_schema_from_model_field(
field=param,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
name = param.alias
convert_underscores = getattr(
param.field_info,
"convert_underscores",
default_convert_underscores,
)
if (
param_type == ParamTypes.header
and param.alias == param.name
and convert_underscores
):
name = param.name.replace("_", "-")
parameter = {
"name": name,
"in": param_type.value,
"required": param.required,
"schema": param_schema,
}
if field_info.description:
parameter["description"] = field_info.description
openapi_examples = getattr(field_info, "openapi_examples", None)
example = getattr(field_info, "example", None)
if openapi_examples:
parameter["examples"] = jsonable_encoder(openapi_examples)
elif example != Undefined:
parameter["example"] = jsonable_encoder(example)
if getattr(field_info, "deprecated", None):
parameter["deprecated"] = True
parameters.append(parameter)
return parameters
def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
@@ -140,10 +176,9 @@ def get_openapi_operation_request_body(
) -> Optional[Dict[str, Any]]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
assert _is_model_field(body_field)
body_schema = get_schema_from_model_field(
field=body_field,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@@ -216,7 +251,6 @@ def get_openapi_path(
*,
route: routing.APIRoute,
operation_ids: Set[str],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
@@ -247,10 +281,8 @@ def get_openapi_path(
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(security_definitions)
all_route_params = get_flat_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params,
schema_generator=schema_generator,
operation_parameters = _get_openapi_operation_parameters(
dependant=route.dependant,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@@ -272,7 +304,6 @@ def get_openapi_path(
if method in METHODS_WITH_BODY:
request_body_oai = get_openapi_operation_request_body(
body_field=route.body_field,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@@ -290,7 +321,6 @@ def get_openapi_path(
) = get_openapi_path(
route=callback,
operation_ids=operation_ids,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@@ -321,7 +351,6 @@ def get_openapi_path(
if route.response_field:
response_schema = get_schema_from_model_field(
field=route.response_field,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@@ -347,15 +376,14 @@ def get_openapi_path(
openapi_response = operation_responses.setdefault(
status_code_key, {}
)
assert isinstance(
process_response, dict
), "An additional response must be a dict"
assert isinstance(process_response, dict), (
"An additional response must be a dict"
)
field = route.response_fields.get(additional_status_code)
additional_field_schema: Optional[Dict[str, Any]] = None
if field:
additional_field_schema = get_schema_from_model_field(
field=field,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@@ -378,7 +406,8 @@ def get_openapi_path(
)
deep_dict_update(openapi_response, process_response)
openapi_response["description"] = description
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
http422 = "422"
all_route_params = get_flat_params(route.dependant)
if (all_route_params or route.body_field) and not any(
status in operation["responses"]
for status in [http422, "4XX", "default"]
@@ -416,9 +445,9 @@ def get_fields_from_routes(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
assert _is_model_field(route.body_field), (
"A request body must be a Pydantic Field"
)
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
@@ -450,6 +479,7 @@ def get_openapi(
contact: Optional[Dict[str, Union[str, Any]]] = None,
license_info: Optional[Dict[str, Union[str, Any]]] = None,
separate_input_output_schemas: bool = True,
external_docs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
info: Dict[str, Any] = {"title": title, "version": version}
if summary:
@@ -471,10 +501,8 @@ def get_openapi(
operation_ids: Set[str] = set()
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
model_name_map = get_compat_model_name_map(all_fields)
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
field_mapping, definitions = get_definitions(
fields=all_fields,
schema_generator=schema_generator,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
@@ -483,7 +511,6 @@ def get_openapi(
result = get_openapi_path(
route=route,
operation_ids=operation_ids,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@@ -503,7 +530,6 @@ def get_openapi(
result = get_openapi_path(
route=webhook,
operation_ids=operation_ids,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
@@ -527,4 +553,6 @@ def get_openapi(
output["webhooks"] = webhook_paths
if tags:
output["tags"] = tags
if external_docs:
output["externalDocs"] = external_docs
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore

View File

@@ -1,9 +1,10 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from annotated_doc import Doc
from fastapi import params
from fastapi._compat import Undefined
from fastapi.openapi.models import Example
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
from typing_extensions import Annotated, Literal, deprecated
_Unset: Any = Undefined
@@ -240,7 +241,7 @@ def Path( # noqa: N802
),
] = None,
deprecated: Annotated[
Optional[bool],
Union[deprecated, str, bool, None],
Doc(
"""
Mark this parameter field as deprecated.
@@ -565,7 +566,7 @@ def Query( # noqa: N802
),
] = None,
deprecated: Annotated[
Optional[bool],
Union[deprecated, str, bool, None],
Doc(
"""
Mark this parameter field as deprecated.
@@ -880,7 +881,7 @@ def Header( # noqa: N802
),
] = None,
deprecated: Annotated[
Optional[bool],
Union[deprecated, str, bool, None],
Doc(
"""
Mark this parameter field as deprecated.
@@ -1185,7 +1186,7 @@ def Cookie( # noqa: N802
),
] = None,
deprecated: Annotated[
Optional[bool],
Union[deprecated, str, bool, None],
Doc(
"""
Mark this parameter field as deprecated.
@@ -1282,7 +1283,7 @@ def Body( # noqa: N802
),
] = _Unset,
embed: Annotated[
bool,
Union[bool, None],
Doc(
"""
When `embed` is `True`, the parameter will be expected in a JSON body as a
@@ -1294,7 +1295,7 @@ def Body( # noqa: N802
[FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter).
"""
),
] = False,
] = None,
media_type: Annotated[
str,
Doc(
@@ -1512,7 +1513,7 @@ def Body( # noqa: N802
),
] = None,
deprecated: Annotated[
Optional[bool],
Union[deprecated, str, bool, None],
Doc(
"""
Mark this parameter field as deprecated.
@@ -1827,7 +1828,7 @@ def Form( # noqa: N802
),
] = None,
deprecated: Annotated[
Optional[bool],
Union[deprecated, str, bool, None],
Doc(
"""
Mark this parameter field as deprecated.
@@ -2141,7 +2142,7 @@ def File( # noqa: N802
),
] = None,
deprecated: Annotated[
Optional[bool],
Union[deprecated, str, bool, None],
Doc(
"""
Mark this parameter field as deprecated.
@@ -2244,6 +2245,26 @@ def Depends( # noqa: N802
"""
),
] = True,
scope: Annotated[
Union[Literal["function", "request"], None],
Doc(
"""
Mainly for dependencies with `yield`, define when the dependency function
should start (the code before `yield`) and when it should end (the code
after `yield`).
* `"function"`: start the dependency before the *path operation function*
that handles the request, end the dependency after the *path operation
function* ends, but **before** the response is sent back to the client.
So, the dependency function will be executed **around** the *path operation
**function***.
* `"request"`: start the dependency before the *path operation function*
that handles the request (similar to when using `"function"`), but end
**after** the response is sent back to the client. So, the dependency
function will be executed **around** the **request** and response cycle.
"""
),
] = None,
) -> Any:
"""
Declare a FastAPI dependency.
@@ -2274,7 +2295,7 @@ def Depends( # noqa: N802
return commons
```
"""
return params.Depends(dependency=dependency, use_cache=use_cache)
return params.Depends(dependency=dependency, use_cache=use_cache, scope=scope)
def Security( # noqa: N802
@@ -2298,7 +2319,7 @@ def Security( # noqa: N802
dependency.
The term "scope" comes from the OAuth2 specification, it seems to be
intentionaly vague and interpretable. It normally refers to permissions,
intentionally vague and interpretable. It normally refers to permissions,
in cases to roles.
These scopes are integrated with OpenAPI (and the API docs at `/docs`).
@@ -2343,7 +2364,7 @@ def Security( # noqa: N802
```python
from typing import Annotated
from fastapi import Depends, FastAPI
from fastapi import Security, FastAPI
from .db import User
from .security import get_current_active_user

View File

@@ -1,12 +1,17 @@
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo
from typing_extensions import Annotated, deprecated
from typing_extensions import Annotated, Literal, deprecated
from ._compat import PYDANTIC_V2, Undefined
from ._compat import (
PYDANTIC_V2,
PYDANTIC_VERSION_MINOR_TUPLE,
Undefined,
)
_Unset: Any = Undefined
@@ -18,7 +23,7 @@ class ParamTypes(Enum):
cookie = "cookie"
class Param(FieldInfo):
class Param(FieldInfo): # type: ignore[misc]
in_: ParamTypes
def __init__(
@@ -63,12 +68,11 @@ class Param(FieldInfo):
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Optional[bool] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
self.deprecated = deprecated
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
@@ -92,7 +96,7 @@ class Param(FieldInfo):
max_length=max_length,
discriminator=discriminator,
multiple_of=multiple_of,
allow_nan=allow_inf_nan,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
**extra,
@@ -106,6 +110,10 @@ class Param(FieldInfo):
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
self.deprecated = deprecated
else:
kwargs["deprecated"] = deprecated
if PYDANTIC_V2:
kwargs.update(
{
@@ -129,7 +137,7 @@ class Param(FieldInfo):
return f"{self.__class__.__name__}({self.default})"
class Path(Param):
class Path(Param): # type: ignore[misc]
in_ = ParamTypes.path
def __init__(
@@ -174,7 +182,7 @@ class Path(Param):
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Optional[bool] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
@@ -215,7 +223,7 @@ class Path(Param):
)
class Query(Param):
class Query(Param): # type: ignore[misc]
in_ = ParamTypes.query
def __init__(
@@ -260,7 +268,7 @@ class Query(Param):
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Optional[bool] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
@@ -299,7 +307,7 @@ class Query(Param):
)
class Header(Param):
class Header(Param): # type: ignore[misc]
in_ = ParamTypes.header
def __init__(
@@ -345,7 +353,7 @@ class Header(Param):
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Optional[bool] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
@@ -385,7 +393,7 @@ class Header(Param):
)
class Cookie(Param):
class Cookie(Param): # type: ignore[misc]
in_ = ParamTypes.cookie
def __init__(
@@ -430,7 +438,7 @@ class Cookie(Param):
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Optional[bool] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
@@ -469,14 +477,14 @@ class Cookie(Param):
)
class Body(FieldInfo):
class Body(FieldInfo): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
embed: bool = False,
embed: Union[bool, None] = None,
media_type: str = "application/json",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
@@ -514,14 +522,13 @@ class Body(FieldInfo):
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Optional[bool] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
self.embed = embed
self.media_type = media_type
self.deprecated = deprecated
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
@@ -545,7 +552,7 @@ class Body(FieldInfo):
max_length=max_length,
discriminator=discriminator,
multiple_of=multiple_of,
allow_nan=allow_inf_nan,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
**extra,
@@ -554,11 +561,15 @@ class Body(FieldInfo):
kwargs["examples"] = examples
if regex is not None:
warnings.warn(
"`regex` has been depreacated, please use `pattern` instead",
"`regex` has been deprecated, please use `pattern` instead",
category=DeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
self.deprecated = deprecated
else:
kwargs["deprecated"] = deprecated
if PYDANTIC_V2:
kwargs.update(
{
@@ -583,7 +594,7 @@ class Body(FieldInfo):
return f"{self.__class__.__name__}({self.default})"
class Form(Body):
class Form(Body): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
@@ -627,7 +638,7 @@ class Form(Body):
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Optional[bool] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
@@ -636,7 +647,6 @@ class Form(Body):
default=default,
default_factory=default_factory,
annotation=annotation,
embed=True,
media_type=media_type,
alias=alias,
alias_priority=alias_priority,
@@ -668,7 +678,7 @@ class Form(Body):
)
class File(Form):
class File(Form): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
@@ -712,7 +722,7 @@ class File(Form):
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Optional[bool] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
@@ -752,26 +762,13 @@ class File(Form):
)
@dataclass(frozen=True)
class Depends:
def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
):
self.dependency = dependency
self.use_cache = use_cache
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"
dependency: Optional[Callable[..., Any]] = None
use_cache: bool = True
scope: Union[Literal["function", "request"], None] = None
@dataclass(frozen=True)
class Security(Depends):
def __init__(
self,
dependency: Optional[Callable[..., Any]] = None,
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
self.scopes = scopes or []
scopes: Optional[Sequence[str]] = None

View File

@@ -1,16 +1,21 @@
import asyncio
import dataclasses
import email.message
import functools
import inspect
import json
from contextlib import AsyncExitStack
import sys
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
Coroutine,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
@@ -19,7 +24,8 @@ from typing import (
Union,
)
from fastapi import params
from annotated_doc import Doc
from fastapi import params, temp_pydantic_v1_params
from fastapi._compat import (
ModelField,
Undefined,
@@ -31,8 +37,10 @@ from fastapi._compat import (
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
_should_embed_body_fields,
get_body_field,
get_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
get_typed_return_annotation,
solve_dependencies,
@@ -47,13 +55,15 @@ from fastapi.exceptions import (
from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import (
create_cloned_field,
create_response_field,
create_model_field,
generate_unique_id,
get_value_or_default,
is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from starlette import routing
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
@@ -63,13 +73,84 @@ from starlette.routing import (
Match,
compile_path,
get_name,
request_response,
websocket_session,
)
from starlette.routing import Mount as Mount # noqa
from starlette.types import ASGIApp, Lifespan, Scope
from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
from typing_extensions import Annotated, deprecated
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
# Copy of starlette.routing.request_response modified to include the
# dependencies' AsyncExitStack
def request_response(
func: Callable[[Request], Union[Awaitable[Response], Response]],
) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
f: Callable[[Request], Awaitable[Response]] = (
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive, send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
# Starts customization
response_awaited = False
async with AsyncExitStack() as request_stack:
scope["fastapi_inner_astack"] = request_stack
async with AsyncExitStack() as function_stack:
scope["fastapi_function_astack"] = function_stack
response = await f(request)
await response(scope, receive, send)
# Continues customization
response_awaited = True
if not response_awaited:
raise FastAPIError(
"Response not awaited. There's a high chance that the "
"application code is raising an exception and a dependency with yield "
"has a block with a bare except, or a block with except Exception, "
"and is not raising the exception again. Read more about it in the "
"docs: https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#dependencies-with-yield-and-except"
)
# Same as in Starlette
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
return app
# Copy of starlette.routing.websocket_session modified to include the
# dependencies' AsyncExitStack
def websocket_session(
func: Callable[[WebSocket], Awaitable[None]],
) -> ASGIApp:
"""
Takes a coroutine `func(session)`, and returns an ASGI application.
"""
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
async with AsyncExitStack() as request_stack:
scope["fastapi_inner_astack"] = request_stack
async with AsyncExitStack() as function_stack:
scope["fastapi_function_astack"] = function_stack
await func(session)
# Same as in Starlette
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
return app
def _prepare_response_content(
@@ -115,10 +196,28 @@ def _prepare_response_content(
for k, v in res.items()
}
elif dataclasses.is_dataclass(res):
assert not isinstance(res, type)
return dataclasses.asdict(res)
return res
def _merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]:
@asynccontextmanager
async def merged_lifespan(
app: AppType,
) -> AsyncIterator[Optional[Mapping[str, Any]]]:
async with original_context(app) as maybe_original_state:
async with nested_context(app) as maybe_nested_state:
if maybe_nested_state is None and maybe_original_state is None:
yield None # old ASGI compatibility
else:
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
return merged_lifespan # type: ignore[return-value]
async def serialize_response(
*,
field: Optional[ModelField] = None,
@@ -206,24 +305,32 @@ def get_request_handler(
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
is_coroutine = iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(
body_field.field_info, (params.Form, temp_pydantic_v1_params.Form)
)
if isinstance(response_class, DefaultPlaceholder):
actual_response_class: Type[Response] = response_class.value
else:
actual_response_class = response_class
async def app(request: Request) -> Response:
response: Union[Response, None] = None
file_stack = request.scope.get("fastapi_middleware_astack")
assert isinstance(file_stack, AsyncExitStack), (
"fastapi_middleware_astack not found in request scope"
)
# Read body and auto-close files
try:
body: Any = None
if body_field:
if is_body_form:
body = await request.form()
stack = request.scope.get("fastapi_astack")
assert isinstance(stack, AsyncExitStack)
stack.push_async_callback(body.close)
file_stack.push_async_callback(body.close)
else:
body_bytes = await request.body()
if body_bytes:
@@ -243,7 +350,7 @@ def get_request_handler(
else:
body = body_bytes
except json.JSONDecodeError as e:
raise RequestValidationError(
validation_error = RequestValidationError(
[
{
"type": "json_invalid",
@@ -254,75 +361,106 @@ def get_request_handler(
}
],
body=e.doc,
) from e
)
raise validation_error from e
except HTTPException:
# If a middleware raises an HTTPException, it should be raised again
raise
except Exception as e:
raise HTTPException(
http_error = HTTPException(
status_code=400, detail="There was an error parsing the body"
) from e
)
raise http_error from e
# Solve dependencies and run path operation function, auto-closing dependencies
errors: List[Any] = []
async_exit_stack = request.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
)
solved_result = await solve_dependencies(
request=request,
dependant=dependant,
body=body,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
values, errors, background_tasks, sub_response, _ = solved_result
if errors:
raise RequestValidationError(_normalize_errors(errors), body=body)
else:
errors = solved_result.errors
if not errors:
raw_response = await run_endpoint_function(
dependant=dependant, values=values, is_coroutine=is_coroutine
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = background_tasks
return raw_response
response_args: Dict[str, Any] = {"background": background_tasks}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else sub_response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if sub_response.status_code:
response_args["status_code"] = sub_response.status_code
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
dependant=dependant,
values=solved_result.values,
is_coroutine=is_coroutine,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(sub_response.headers.raw)
return response
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: Dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if solved_result.response.status_code:
response_args["status_code"] = solved_result.response.status_code
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
)
raise validation_error
# Return response
assert response
return response
return app
def get_websocket_app(
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
dependant: Dependant,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
async_exit_stack = websocket.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
)
solved_result = await solve_dependencies(
request=websocket,
dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
values, errors, _, _2, _3 = solved_result
if errors:
raise WebSocketRequestValidationError(_normalize_errors(errors))
if solved_result.errors:
raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors)
)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**values)
await dependant.call(**solved_result.values)
return app
@@ -342,17 +480,23 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function"
)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
)
)
@@ -431,9 +575,9 @@ class APIRoute(routing.Route):
methods = ["GET"]
self.methods: Set[str] = {method.upper() for method in methods}
if isinstance(generate_unique_id_function, DefaultPlaceholder):
current_generate_unique_id: Callable[
["APIRoute"], str
] = generate_unique_id_function.value
current_generate_unique_id: Callable[[APIRoute], str] = (
generate_unique_id_function.value
)
else:
current_generate_unique_id = generate_unique_id_function
self.unique_id = self.operation_id or current_generate_unique_id(self)
@@ -442,11 +586,11 @@ class APIRoute(routing.Route):
status_code = int(status_code)
self.status_code = status_code
if self.response_model:
assert is_body_allowed_for_status_code(
status_code
), f"Status code {status_code} must not have a response body"
assert is_body_allowed_for_status_code(status_code), (
f"Status code {status_code} must not have a response body"
)
response_name = "Response_" + self.unique_id
self.response_field = create_response_field(
self.response_field = create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
@@ -459,9 +603,9 @@ class APIRoute(routing.Route):
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
self.secure_cloned_response_field: Optional[
ModelField
] = create_cloned_field(self.response_field)
self.secure_cloned_response_field: Optional[ModelField] = (
create_cloned_field(self.response_field)
)
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
@@ -475,11 +619,13 @@ class APIRoute(routing.Route):
assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model")
if model:
assert is_body_allowed_for_status_code(
additional_status_code
), f"Status code {additional_status_code} must not have a response body"
assert is_body_allowed_for_status_code(additional_status_code), (
f"Status code {additional_status_code} must not have a response body"
)
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_response_field(name=response_name, type_=model)
response_field = create_model_field(
name=response_name, type_=model, mode="serialization"
)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
@@ -487,13 +633,23 @@ class APIRoute(routing.Route):
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function"
)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.body_field = get_body_field(
flat_dependant=self._flat_dependant,
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
self.app = request_response(self.get_route_handler())
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
@@ -510,6 +666,7 @@ class APIRoute(routing.Route):
response_model_exclude_defaults=self.response_model_exclude_defaults,
response_model_exclude_none=self.response_model_exclude_none,
dependency_overrides_provider=self.dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
@@ -741,7 +898,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -771,9 +928,9 @@ class APIRouter(routing.Router):
)
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
assert not prefix.endswith("/"), (
"A path prefix must not end with '/', as the routes will start with '/'"
)
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or [])
@@ -789,7 +946,7 @@ class APIRouter(routing.Router):
def route(
self,
path: str,
methods: Optional[List[str]] = None,
methods: Optional[Collection[str]] = None,
name: Optional[str] = None,
include_in_schema: bool = True,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
@@ -1183,9 +1340,9 @@ class APIRouter(routing.Router):
"""
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
assert not prefix.endswith("/"), (
"A path prefix must not end with '/', as the routes will start with '/'"
)
else:
for r in router.routes:
path = getattr(r, "path") # noqa: B009
@@ -1285,6 +1442,10 @@ class APIRouter(routing.Router):
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)
def get(
self,
@@ -1549,7 +1710,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -1926,7 +2087,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -2308,7 +2469,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -2690,7 +2851,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -3067,7 +3228,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -3444,7 +3605,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -3826,7 +3987,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -4208,7 +4369,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
"""
),
] = True,
@@ -4293,7 +4454,7 @@ class APIRouter(routing.Router):
app = FastAPI()
router = APIRouter()
@router.put("/items/{item_id}")
@router.trace("/items/{item_id}")
def trace_item(item_id: str):
return None

View File

@@ -1,15 +1,54 @@
from typing import Optional
from typing import Optional, Union
from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
class APIKeyBase(SecurityBase):
pass
def __init__(
self,
location: APIKeyIn,
name: str,
description: Union[str, None],
scheme_name: Union[str, None],
auto_error: bool,
):
self.auto_error = auto_error
self.model: APIKey = APIKey(
**{"in": location},
name=name,
description=description,
)
self.scheme_name = scheme_name or self.__class__.__name__
def make_not_authenticated_error(self) -> HTTPException:
"""
The WWW-Authenticate header is not standardized for API Key authentication but
the HTTP specification requires that an error of 401 "Unauthorized" must
include a WWW-Authenticate header.
Ref: https://datatracker.ietf.org/doc/html/rfc9110#name-401-unauthorized
For this, this method sends a custom challenge `APIKey`.
"""
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "APIKey"},
)
def check_api_key(self, api_key: Optional[str]) -> Optional[str]:
if not api_key:
if self.auto_error:
raise self.make_not_authenticated_error()
return None
return api_key
class APIKeyQuery(APIKeyBase):
@@ -76,7 +115,7 @@ class APIKeyQuery(APIKeyBase):
Doc(
"""
By default, if the query parameter is not provided, `APIKeyQuery` will
automatically cancel the request and sebd the client an error.
automatically cancel the request and send the client an error.
If `auto_error` is set to `False`, when the query parameter is not
available, instead of erroring out, the dependency result will be
@@ -91,24 +130,17 @@ class APIKeyQuery(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.query}, # type: ignore[arg-type]
super().__init__(
location=APIKeyIn.query,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key
return self.check_api_key(api_key)
class APIKeyHeader(APIKeyBase):
@@ -186,24 +218,17 @@ class APIKeyHeader(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.header}, # type: ignore[arg-type]
super().__init__(
location=APIKeyIn.header,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.headers.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key
return self.check_api_key(api_key)
class APIKeyCookie(APIKeyBase):
@@ -281,21 +306,14 @@ class APIKeyCookie(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.cookie}, # type: ignore[arg-type]
super().__init__(
location=APIKeyIn.cookie,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key
return self.check_api_key(api_key)

View File

@@ -1,7 +1,8 @@
import binascii
from base64 import b64decode
from typing import Optional
from typing import Dict, Optional
from annotated_doc import Doc
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
@@ -9,13 +10,13 @@ from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
class HTTPBasicCredentials(BaseModel):
"""
The HTTP Basic credendials given as the result of using `HTTPBasic` in a
The HTTP Basic credentials given as the result of using `HTTPBasic` in a
dependency.
Read more about it in the
@@ -75,10 +76,22 @@ class HTTPBase(SecurityBase):
description: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme=scheme, description=description)
self.model: HTTPBaseModel = HTTPBaseModel(
scheme=scheme, description=description
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_authenticate_headers(self) -> Dict[str, str]:
return {"WWW-Authenticate": f"{self.model.scheme.title()}"}
def make_not_authenticated_error(self) -> HTTPException:
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers=self.make_authenticate_headers(),
)
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
@@ -86,9 +99,7 @@ class HTTPBase(SecurityBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
@@ -98,6 +109,8 @@ class HTTPBasic(HTTPBase):
"""
HTTP Basic authentication.
Ref: https://datatracker.ietf.org/doc/html/rfc7617
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
@@ -184,36 +197,28 @@ class HTTPBasic(HTTPBase):
self.realm = realm
self.auto_error = auto_error
def make_authenticate_headers(self) -> Dict[str, str]:
if self.realm:
return {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore
self, request: Request
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if self.realm:
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
else:
unauthorized_headers = {"WWW-Authenticate": "Basic"}
if not authorization or scheme.lower() != "basic":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers=unauthorized_headers,
)
raise self.make_not_authenticated_error()
else:
return None
invalid_user_credentials_exc = HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers=unauthorized_headers,
)
try:
data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise invalid_user_credentials_exc # noqa: B904
except (ValueError, UnicodeDecodeError, binascii.Error) as e:
raise self.make_not_authenticated_error() from e
username, separator, password = data.partition(":")
if not separator:
raise invalid_user_credentials_exc
raise self.make_not_authenticated_error()
return HTTPBasicCredentials(username=username, password=password)
@@ -277,7 +282,7 @@ class HTTPBearer(HTTPBase):
bool,
Doc(
"""
By default, if the HTTP Bearer token not provided (in an
By default, if the HTTP Bearer token is not provided (in an
`Authorization` header), `HTTPBearer` will automatically cancel the
request and send the client an error.
@@ -305,17 +310,12 @@ class HTTPBearer(HTTPBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
if scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
@@ -325,6 +325,12 @@ class HTTPDigest(HTTPBase):
"""
HTTP Digest authentication.
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
but it doesn't implement the full Digest scheme, you would need to to subclass it
and implement it in your code.
Ref: https://datatracker.ietf.org/doc/html/rfc7616
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
@@ -380,7 +386,7 @@ class HTTPDigest(HTTPBase):
bool,
Doc(
"""
By default, if the HTTP Digest not provided, `HTTPDigest` will
By default, if the HTTP Digest is not provided, `HTTPDigest` will
automatically cancel the request and send the client an error.
If `auto_error` is set to `False`, when the HTTP Digest is not
@@ -407,14 +413,12 @@ class HTTPDigest(HTTPBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
if scheme.lower() != "digest":
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional, Union, cast
from annotated_doc import Doc
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
@@ -7,10 +8,10 @@ from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from starlette.status import HTTP_401_UNAUTHORIZED
# TODO: import from typing when deprecating Python 3.9
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from typing_extensions import Annotated
class OAuth2PasswordRequestForm:
@@ -52,9 +53,9 @@ class OAuth2PasswordRequestForm:
```
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
You could have custom internal logic to separate it by colon caracters (`:`) or
You could have custom internal logic to separate it by colon characters (`:`) or
similar, and get the two parts `items` and `read`. Many applications do that to
group and organize permisions, you could do it as well in your application, just
group and organize permissions, you could do it as well in your application, just
know that that it is application specific, it's not part of the specification.
"""
@@ -63,7 +64,7 @@ class OAuth2PasswordRequestForm:
*,
grant_type: Annotated[
Union[str, None],
Form(pattern="password"),
Form(pattern="^password$"),
Doc(
"""
The OAuth2 spec says it is required and MUST be the fixed string
@@ -85,11 +86,11 @@ class OAuth2PasswordRequestForm:
],
password: Annotated[
str,
Form(),
Form(json_schema_extra={"format": "password"}),
Doc(
"""
`password` string. The OAuth2 spec requires the exact field name
`password".
`password`.
"""
),
],
@@ -130,7 +131,7 @@ class OAuth2PasswordRequestForm:
] = None,
client_secret: Annotated[
Union[str, None],
Form(),
Form(json_schema_extra={"format": "password"}),
Doc(
"""
If there's a `client_password` (and a `client_id`), they can be sent
@@ -194,9 +195,9 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
```
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
You could have custom internal logic to separate it by colon caracters (`:`) or
You could have custom internal logic to separate it by colon characters (`:`) or
similar, and get the two parts `items` and `read`. Many applications do that to
group and organize permisions, you could do it as well in your application, just
group and organize permissions, you could do it as well in your application, just
know that that it is application specific, it's not part of the specification.
@@ -217,7 +218,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
self,
grant_type: Annotated[
str,
Form(pattern="password"),
Form(pattern="^password$"),
Doc(
"""
The OAuth2 spec says it is required and MUST be the fixed string
@@ -243,7 +244,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
Doc(
"""
`password` string. The OAuth2 spec requires the exact field name
`password".
`password`.
"""
),
],
@@ -353,7 +354,7 @@ class OAuth2(SecurityBase):
bool,
Doc(
"""
By default, if no HTTP Auhtorization header is provided, required for
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
@@ -376,13 +377,33 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_not_authenticated_error(self) -> HTTPException:
"""
The OAuth 2 specification doesn't define the challenge that should be used,
because a `Bearer` token is not really the only option to authenticate.
But declaring any other authentication challenge would be application-specific
as it's not defined in the specification.
For practical reasons, this method uses the `Bearer` challenge by default, as
it's probably the most common one.
If you are implementing an OAuth2 authentication scheme other than the provided
ones in FastAPI (based on bearer tokens), you might want to override this.
Ref: https://datatracker.ietf.org/doc/html/rfc6749
"""
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
async def __call__(self, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return authorization
@@ -441,7 +462,7 @@ class OAuth2PasswordBearer(OAuth2):
bool,
Doc(
"""
By default, if no HTTP Auhtorization header is provided, required for
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
@@ -457,11 +478,26 @@ class OAuth2PasswordBearer(OAuth2):
"""
),
] = True,
refreshUrl: Annotated[
Optional[str],
Doc(
"""
The URL to refresh the token and obtain a new one.
"""
),
] = None,
):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(
password=cast(Any, {"tokenUrl": tokenUrl, "scopes": scopes})
password=cast(
Any,
{
"tokenUrl": tokenUrl,
"refreshUrl": refreshUrl,
"scopes": scopes,
},
)
)
super().__init__(
flows=flows,
@@ -475,11 +511,7 @@ class OAuth2PasswordBearer(OAuth2):
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
raise self.make_not_authenticated_error()
else:
return None
return param
@@ -543,7 +575,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
bool,
Doc(
"""
By default, if no HTTP Auhtorization header is provided, required for
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
@@ -585,11 +617,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
raise self.make_not_authenticated_error()
else:
return None # pragma: nocover
return param

View File

@@ -1,17 +1,23 @@
from typing import Optional
from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
class OpenIdConnect(SecurityBase):
"""
OpenID Connect authentication class. An instance of it would be used as a
dependency.
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
but it doesn't implement the full OpenIdConnect scheme, for example, it doesn't use
the OpenIDConnect URL. You would need to to subclass it and implement it in your
code.
"""
def __init__(
@@ -49,7 +55,7 @@ class OpenIdConnect(SecurityBase):
bool,
Doc(
"""
By default, if no HTTP Auhtorization header is provided, required for
By default, if no HTTP Authorization header is provided, required for
OpenID Connect authentication, it will automatically cancel the request
and send the client an error.
@@ -72,13 +78,18 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_not_authenticated_error(self) -> HTTPException:
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
async def __call__(self, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return authorization

View File

@@ -0,0 +1,724 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
from fastapi.openapi.models import Example
from fastapi.params import ParamTypes
from typing_extensions import Annotated, deprecated
from ._compat.may_v1 import FieldInfo, Undefined
from ._compat.shared import PYDANTIC_VERSION_MINOR_TUPLE
_Unset: Any = Undefined
class Param(FieldInfo): # type: ignore[misc]
in_: ParamTypes
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
category=DeprecationWarning,
stacklevel=4,
)
self.example = example
self.include_in_schema = include_in_schema
self.openapi_examples = openapi_examples
kwargs = dict(
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
**extra,
)
if examples is not None:
kwargs["examples"] = examples
if regex is not None:
warnings.warn(
"`regex` has been deprecated, please use `pattern` instead",
category=DeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
self.deprecated = deprecated
else:
kwargs["deprecated"] = deprecated
kwargs["regex"] = pattern or regex
kwargs.update(**current_json_schema_extra)
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
class Path(Param): # type: ignore[misc]
in_ = ParamTypes.path
def __init__(
self,
default: Any = ...,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
assert default is ..., "Path parameters cannot have a default value"
self.in_ = self.in_
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Query(Param): # type: ignore[misc]
in_ = ParamTypes.query
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Header(Param): # type: ignore[misc]
in_ = ParamTypes.header
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
convert_underscores: bool = True,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
self.convert_underscores = convert_underscores
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Cookie(Param): # type: ignore[misc]
in_ = ParamTypes.cookie
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Body(FieldInfo): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
embed: Union[bool, None] = None,
media_type: str = "application/json",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
self.embed = embed
self.media_type = media_type
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
category=DeprecationWarning,
stacklevel=4,
)
self.example = example
self.include_in_schema = include_in_schema
self.openapi_examples = openapi_examples
kwargs = dict(
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
**extra,
)
if examples is not None:
kwargs["examples"] = examples
if regex is not None:
warnings.warn(
"`regex` has been deprecated, please use `pattern` instead",
category=DeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
self.deprecated = deprecated
else:
kwargs["deprecated"] = deprecated
kwargs["regex"] = pattern or regex
kwargs.update(**current_json_schema_extra)
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
class Form(Body): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
media_type: str = "application/x-www-form-urlencoded",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
media_type=media_type,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class File(Form): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
media_type: str = "multipart/form-data",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
media_type=media_type,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)

View File

@@ -1,11 +1,11 @@
import types
from enum import Enum
from typing import Any, Callable, Dict, Set, Type, TypeVar, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union
from pydantic import BaseModel
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
UnionType = getattr(types, "UnionType", Union)
NoneType = getattr(types, "UnionType", None)
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str]

View File

@@ -23,10 +23,12 @@ from fastapi._compat import (
Undefined,
UndefinedType,
Validator,
annotation_is_pydantic_v1,
lenient_issubclass,
may_v1,
)
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from pydantic import BaseModel, create_model
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing_extensions import Literal
@@ -34,9 +36,9 @@ if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute
# Cache for `create_cloned_field`
_CLONED_TYPES_CACHE: MutableMapping[
Type[BaseModel], Type[BaseModel]
] = WeakKeyDictionary()
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
WeakKeyDictionary()
)
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
@@ -53,60 +55,81 @@ def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
}:
return True
current_status_code = int(status_code)
return not (current_status_code < 200 or current_status_code in {204, 304})
return not (current_status_code < 200 or current_status_code in {204, 205, 304})
def get_path_param_names(path: str) -> Set[str]:
return set(re.findall("{(.*?)}", path))
def create_response_field(
_invalid_args_message = (
"Invalid args for response field! Hint: "
"check that {type_} is a valid Pydantic field type. "
"If you are using a return type annotation that is not a valid Pydantic "
"field (e.g. Union[Response, dict, None]) you can disable generating the "
"response model from the type annotation with the path operation decorator "
"parameter response_model=None. Read more: "
"https://fastapi.tiangolo.com/tutorial/response-model/"
)
def create_model_field(
name: str,
type_: Type[Any],
type_: Any,
class_validators: Optional[Dict[str, Validator]] = None,
default: Optional[Any] = Undefined,
required: Union[bool, UndefinedType] = Undefined,
model_config: Type[BaseConfig] = BaseConfig,
model_config: Union[Type[BaseConfig], None] = None,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
mode: Literal["validation", "serialization"] = "validation",
version: Literal["1", "auto"] = "auto",
) -> ModelField:
"""
Create a new response field. Raises if type_ is invalid.
"""
class_validators = class_validators or {}
if PYDANTIC_V2:
v1_model_config = may_v1.BaseConfig
v1_field_info = field_info or may_v1.FieldInfo()
v1_kwargs = {
"name": name,
"field_info": v1_field_info,
"type_": type_,
"class_validators": class_validators,
"default": default,
"required": required,
"model_config": v1_model_config,
"alias": alias,
}
if (
annotation_is_pydantic_v1(type_)
or isinstance(field_info, may_v1.FieldInfo)
or version == "1"
):
from fastapi._compat import v1
try:
return v1.ModelField(**v1_kwargs) # type: ignore[no-any-return]
except RuntimeError:
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
elif PYDANTIC_V2:
from ._compat import v2
field_info = field_info or FieldInfo(
annotation=type_, default=default, alias=alias
)
else:
field_info = field_info or FieldInfo()
kwargs = {"name": name, "field_info": field_info}
if PYDANTIC_V2:
kwargs.update({"mode": mode})
else:
kwargs.update(
{
"type_": type_,
"class_validators": class_validators,
"default": default,
"required": required,
"model_config": model_config,
"alias": alias,
}
)
kwargs = {"mode": mode, "name": name, "field_info": field_info}
try:
return v2.ModelField(**kwargs) # type: ignore[return-value,arg-type]
except PydanticSchemaGenerationError:
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
# Pydantic v2 is not installed, but it's not a Pydantic v1 ModelField, it could be
# a Pydantic v1 type, like a constrained int
from fastapi._compat import v1
try:
return ModelField(**kwargs) # type: ignore[arg-type]
except (RuntimeError, PydanticSchemaGenerationError):
raise fastapi.exceptions.FastAPIError(
"Invalid args for response field! Hint: "
f"check that {type_} is a valid Pydantic field type. "
"If you are using a return type annotation that is not a valid Pydantic "
"field (e.g. Union[Response, dict, None]) you can disable generating the "
"response model from the type annotation with the path operation decorator "
"parameter response_model=None. Read more: "
"https://fastapi.tiangolo.com/tutorial/response-model/"
) from None
return v1.ModelField(**v1_kwargs) # type: ignore[no-any-return]
except RuntimeError:
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
def create_cloned_field(
@@ -115,7 +138,13 @@ def create_cloned_field(
cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
if PYDANTIC_V2:
return field
from ._compat import v2
if isinstance(field, v2.ModelField):
return field
from fastapi._compat import v1
# cloned_types caches already cloned types to support recursive models and improve
# performance by avoiding unnecessary cloning
if cloned_types is None:
@@ -125,21 +154,23 @@ def create_cloned_field(
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)
if lenient_issubclass(original_type, v1.BaseModel):
original_type = cast(Type[v1.BaseModel], original_type)
use_type = cloned_types.get(original_type)
if use_type is None:
use_type = create_model(original_type.__name__, __base__=original_type)
use_type = v1.create_model(original_type.__name__, __base__=original_type)
cloned_types[original_type] = use_type
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(
f, cloned_types=cloned_types
f,
cloned_types=cloned_types,
)
new_field = create_response_field(name=field.name, type_=use_type)
new_field = create_model_field(name=field.name, type_=use_type, version="1")
new_field.has_alias = field.has_alias # type: ignore[attr-defined]
new_field.alias = field.alias # type: ignore[misc]
new_field.class_validators = field.class_validators # type: ignore[attr-defined]
new_field.default = field.default # type: ignore[misc]
new_field.default_factory = field.default_factory # type: ignore[attr-defined]
new_field.required = field.required # type: ignore[misc]
new_field.model_config = field.model_config # type: ignore[attr-defined]
new_field.field_info = field.field_info
@@ -173,17 +204,17 @@ def generate_operation_id_for_path(
DeprecationWarning,
stacklevel=2,
)
operation_id = name + path
operation_id = f"{name}{path}"
operation_id = re.sub(r"\W", "_", operation_id)
operation_id = operation_id + "_" + method.lower()
operation_id = f"{operation_id}_{method.lower()}"
return operation_id
def generate_unique_id(route: "APIRoute") -> str:
operation_id = route.name + route.path_format
operation_id = f"{route.name}{route.path_format}"
operation_id = re.sub(r"\W", "_", operation_id)
assert route.methods
operation_id = operation_id + "_" + list(route.methods)[0].lower()
operation_id = f"{operation_id}_{list(route.methods)[0].lower()}"
return operation_id
@@ -221,9 +252,3 @@ def get_value_or_default(
if not isinstance(item, DefaultPlaceholder):
return item
return first_item
def match_pydantic_error_url(error_type: str) -> Any:
from dirty_equals import IsStr
return IsStr(regex=rf"^https://errors\.pydantic\.dev/.*/v/{error_type}")