updates
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
|
||||
|
||||
__version__ = "0.104.1"
|
||||
__version__ = "0.123.0"
|
||||
|
||||
from starlette import status as status
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from fastapi.cli import main
|
||||
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,629 +0,0 @@
|
||||
from collections import deque
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
List,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi.exceptions import RequestErrorModel
|
||||
from fastapi.types import IncEx, ModelNameMap, UnionType
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from starlette.datastructures import UploadFile
|
||||
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
|
||||
|
||||
sequence_annotation_to_type = {
|
||||
Sequence: list,
|
||||
List: list,
|
||||
list: list,
|
||||
Tuple: tuple,
|
||||
tuple: tuple,
|
||||
Set: set,
|
||||
set: set,
|
||||
FrozenSet: frozenset,
|
||||
frozenset: frozenset,
|
||||
Deque: deque,
|
||||
deque: deque,
|
||||
}
|
||||
|
||||
sequence_types = tuple(sequence_annotation_to_type.keys())
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import ValidationError as ValidationError
|
||||
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
|
||||
GetJsonSchemaHandler as GetJsonSchemaHandler,
|
||||
)
|
||||
from pydantic._internal._typing_extra import eval_type_lenient
|
||||
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
|
||||
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
|
||||
from pydantic_core import CoreSchema as CoreSchema
|
||||
from pydantic_core import PydanticUndefined, PydanticUndefinedType
|
||||
from pydantic_core import Url as Url
|
||||
|
||||
try:
|
||||
from pydantic_core.core_schema import (
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
)
|
||||
except ImportError: # pragma: no cover
|
||||
from pydantic_core.core_schema import (
|
||||
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
|
||||
)
|
||||
|
||||
Required = PydanticUndefined
|
||||
Undefined = PydanticUndefined
|
||||
UndefinedType = PydanticUndefinedType
|
||||
evaluate_forwardref = eval_type_lenient
|
||||
Validator = Any
|
||||
|
||||
class BaseConfig:
|
||||
pass
|
||||
|
||||
class ErrorWrapper(Exception):
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class ModelField:
|
||||
field_info: FieldInfo
|
||||
name: str
|
||||
mode: Literal["validation", "serialization"] = "validation"
|
||||
|
||||
@property
|
||||
def alias(self) -> str:
|
||||
a = self.field_info.alias
|
||||
return a if a is not None else self.name
|
||||
|
||||
@property
|
||||
def required(self) -> bool:
|
||||
return self.field_info.is_required()
|
||||
|
||||
@property
|
||||
def default(self) -> Any:
|
||||
return self.get_default()
|
||||
|
||||
@property
|
||||
def type_(self) -> Any:
|
||||
return self.field_info.annotation
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
|
||||
Annotated[self.field_info.annotation, self.field_info]
|
||||
)
|
||||
|
||||
def get_default(self) -> Any:
|
||||
if self.field_info.is_required():
|
||||
return Undefined
|
||||
return self.field_info.get_default(call_default_factory=True)
|
||||
|
||||
def validate(
|
||||
self,
|
||||
value: Any,
|
||||
values: Dict[str, Any] = {}, # noqa: B006
|
||||
*,
|
||||
loc: Tuple[Union[int, str], ...] = (),
|
||||
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
|
||||
try:
|
||||
return (
|
||||
self._type_adapter.validate_python(value, from_attributes=True),
|
||||
None,
|
||||
)
|
||||
except ValidationError as exc:
|
||||
return None, _regenerate_error_with_loc(
|
||||
errors=exc.errors(), loc_prefix=loc
|
||||
)
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
value: Any,
|
||||
*,
|
||||
mode: Literal["json", "python"] = "json",
|
||||
include: Union[IncEx, None] = None,
|
||||
exclude: Union[IncEx, None] = None,
|
||||
by_alias: bool = True,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
) -> Any:
|
||||
# What calls this code passes a value that already called
|
||||
# self._type_adapter.validate_python(value)
|
||||
return self._type_adapter.dump_python(
|
||||
value,
|
||||
mode=mode,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Each ModelField is unique for our purposes, to allow making a dict from
|
||||
# ModelField to its JSON Schema.
|
||||
return id(self)
|
||||
|
||||
def get_annotation_from_field_info(
|
||||
annotation: Any, field_info: FieldInfo, field_name: str
|
||||
) -> Any:
|
||||
return annotation
|
||||
|
||||
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
|
||||
return errors # type: ignore[return-value]
|
||||
|
||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||
model.model_rebuild()
|
||||
|
||||
def _model_dump(
|
||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
||||
) -> Any:
|
||||
return model.model_dump(mode=mode, **kwargs)
|
||||
|
||||
def _get_model_config(model: BaseModel) -> Any:
|
||||
return model.model_config
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
override_mode: Union[Literal["validation"], None] = (
|
||||
None if separate_input_output_schemas else "validation"
|
||||
)
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
json_schema = field_mapping[(field, override_mode or field.mode)]
|
||||
if "$ref" not in json_schema:
|
||||
# TODO remove when deprecating Pydantic v1
|
||||
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
|
||||
json_schema["title"] = (
|
||||
field.field_info.title or field.alias.title().replace("_", " ")
|
||||
)
|
||||
return json_schema
|
||||
|
||||
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
|
||||
return {}
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: List[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Tuple[
|
||||
Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
Dict[str, Dict[str, Any]],
|
||||
]:
|
||||
override_mode: Union[Literal["validation"], None] = (
|
||||
None if separate_input_output_schemas else "validation"
|
||||
)
|
||||
inputs = [
|
||||
(field, override_mode or field.mode, field._type_adapter.core_schema)
|
||||
for field in fields
|
||||
]
|
||||
field_mapping, definitions = schema_generator.generate_definitions(
|
||||
inputs=inputs
|
||||
)
|
||||
return field_mapping, definitions # type: ignore[return-value]
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
from fastapi import params
|
||||
|
||||
return field_annotation_is_scalar(
|
||||
field.field_info.annotation
|
||||
) and not isinstance(field.field_info, params.Body)
|
||||
|
||||
def is_sequence_field(field: ModelField) -> bool:
|
||||
return field_annotation_is_sequence(field.field_info.annotation)
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
return field_annotation_is_scalar_sequence(field.field_info.annotation)
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
return is_bytes_or_nonable_bytes_annotation(field.type_)
|
||||
|
||||
def is_bytes_sequence_field(field: ModelField) -> bool:
|
||||
return is_bytes_sequence_annotation(field.type_)
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
return type(field_info).from_annotation(annotation)
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
origin_type = (
|
||||
get_origin(field.field_info.annotation) or field.field_info.annotation
|
||||
)
|
||||
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
|
||||
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
|
||||
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
error = ValidationError.from_exception_data(
|
||||
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
|
||||
).errors()[0]
|
||||
error["input"] = None
|
||||
return error # type: ignore[return-value]
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> Type[BaseModel]:
|
||||
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
|
||||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
|
||||
else:
|
||||
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||||
from pydantic import AnyUrl as Url # noqa: F401
|
||||
from pydantic import ( # type: ignore[assignment]
|
||||
BaseConfig as BaseConfig, # noqa: F401
|
||||
)
|
||||
from pydantic import ValidationError as ValidationError # noqa: F401
|
||||
from pydantic.class_validators import ( # type: ignore[no-redef]
|
||||
Validator as Validator, # noqa: F401
|
||||
)
|
||||
from pydantic.error_wrappers import ( # type: ignore[no-redef]
|
||||
ErrorWrapper as ErrorWrapper, # noqa: F401
|
||||
)
|
||||
from pydantic.errors import MissingError
|
||||
from pydantic.fields import ( # type: ignore[attr-defined]
|
||||
SHAPE_FROZENSET,
|
||||
SHAPE_LIST,
|
||||
SHAPE_SEQUENCE,
|
||||
SHAPE_SET,
|
||||
SHAPE_SINGLETON,
|
||||
SHAPE_TUPLE,
|
||||
SHAPE_TUPLE_ELLIPSIS,
|
||||
)
|
||||
from pydantic.fields import FieldInfo as FieldInfo
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
ModelField as ModelField, # noqa: F401
|
||||
)
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
Required as Required, # noqa: F401
|
||||
)
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
Undefined as Undefined,
|
||||
)
|
||||
from pydantic.fields import ( # type: ignore[no-redef, attr-defined]
|
||||
UndefinedType as UndefinedType, # noqa: F401
|
||||
)
|
||||
from pydantic.schema import (
|
||||
field_schema,
|
||||
get_flat_models_from_fields,
|
||||
get_model_name_map,
|
||||
model_process_schema,
|
||||
)
|
||||
from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401
|
||||
get_annotation_from_field_info as get_annotation_from_field_info,
|
||||
)
|
||||
from pydantic.typing import ( # type: ignore[no-redef]
|
||||
evaluate_forwardref as evaluate_forwardref, # noqa: F401
|
||||
)
|
||||
from pydantic.utils import ( # type: ignore[no-redef]
|
||||
lenient_issubclass as lenient_issubclass, # noqa: F401
|
||||
)
|
||||
|
||||
GetJsonSchemaHandler = Any # type: ignore[assignment,misc]
|
||||
JsonSchemaValue = Dict[str, Any] # type: ignore[misc]
|
||||
CoreSchema = Any # type: ignore[assignment,misc]
|
||||
|
||||
sequence_shapes = {
|
||||
SHAPE_LIST,
|
||||
SHAPE_SET,
|
||||
SHAPE_FROZENSET,
|
||||
SHAPE_TUPLE,
|
||||
SHAPE_SEQUENCE,
|
||||
SHAPE_TUPLE_ELLIPSIS,
|
||||
}
|
||||
sequence_shape_to_type = {
|
||||
SHAPE_LIST: list,
|
||||
SHAPE_SET: set,
|
||||
SHAPE_TUPLE: tuple,
|
||||
SHAPE_SEQUENCE: list,
|
||||
SHAPE_TUPLE_ELLIPSIS: list,
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class GenerateJsonSchema: # type: ignore[no-redef]
|
||||
ref_template: str
|
||||
|
||||
class PydanticSchemaGenerationError(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
def with_info_plain_validator_function( # type: ignore[misc]
|
||||
function: Callable[..., Any],
|
||||
*,
|
||||
ref: Union[str, None] = None,
|
||||
metadata: Any = None,
|
||||
serialization: Any = None,
|
||||
) -> Any:
|
||||
return {}
|
||||
|
||||
def get_model_definitions(
|
||||
*,
|
||||
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
|
||||
) -> Dict[str, Any]:
|
||||
definitions: Dict[str, Dict[str, Any]] = {}
|
||||
for model in flat_models:
|
||||
m_schema, m_definitions, m_nested_models = model_process_schema(
|
||||
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)
|
||||
definitions.update(m_definitions)
|
||||
model_name = model_name_map[model]
|
||||
if "description" in m_schema:
|
||||
m_schema["description"] = m_schema["description"].split("\f")[0]
|
||||
definitions[model_name] = m_schema
|
||||
return definitions
|
||||
|
||||
def is_pv1_scalar_field(field: ModelField) -> bool:
|
||||
from fastapi import params
|
||||
|
||||
field_info = field.field_info
|
||||
if not (
|
||||
field.shape == SHAPE_SINGLETON # type: ignore[attr-defined]
|
||||
and not lenient_issubclass(field.type_, BaseModel)
|
||||
and not lenient_issubclass(field.type_, dict)
|
||||
and not field_annotation_is_sequence(field.type_)
|
||||
and not is_dataclass(field.type_)
|
||||
and not isinstance(field_info, params.Body)
|
||||
):
|
||||
return False
|
||||
if field.sub_fields: # type: ignore[attr-defined]
|
||||
if not all(
|
||||
is_pv1_scalar_field(f)
|
||||
for f in field.sub_fields # type: ignore[attr-defined]
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
|
||||
if (field.shape in sequence_shapes) and not lenient_issubclass( # type: ignore[attr-defined]
|
||||
field.type_, BaseModel
|
||||
):
|
||||
if field.sub_fields is not None: # type: ignore[attr-defined]
|
||||
for sub_field in field.sub_fields: # type: ignore[attr-defined]
|
||||
if not is_pv1_scalar_field(sub_field):
|
||||
return False
|
||||
return True
|
||||
if _annotation_is_sequence(field.type_):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
|
||||
use_errors: List[Any] = []
|
||||
for error in errors:
|
||||
if isinstance(error, ErrorWrapper):
|
||||
new_errors = ValidationError( # type: ignore[call-arg]
|
||||
errors=[error], model=RequestErrorModel
|
||||
).errors()
|
||||
use_errors.extend(new_errors)
|
||||
elif isinstance(error, list):
|
||||
use_errors.extend(_normalize_errors(error))
|
||||
else:
|
||||
use_errors.append(error)
|
||||
return use_errors
|
||||
|
||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||
model.update_forward_refs()
|
||||
|
||||
def _model_dump(
|
||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
||||
) -> Any:
|
||||
return model.dict(**kwargs)
|
||||
|
||||
def _get_model_config(model: BaseModel) -> Any:
|
||||
return model.__config__ # type: ignore[attr-defined]
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
return field_schema( # type: ignore[no-any-return]
|
||||
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)[0]
|
||||
|
||||
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
|
||||
models = get_flat_models_from_fields(fields, known_models=set())
|
||||
return get_model_name_map(models) # type: ignore[no-any-return]
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: List[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Tuple[
|
||||
Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
Dict[str, Dict[str, Any]],
|
||||
]:
|
||||
models = get_flat_models_from_fields(fields, known_models=set())
|
||||
return {}, get_model_definitions(
|
||||
flat_models=models, model_name_map=model_name_map
|
||||
)
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_field(field)
|
||||
|
||||
def is_sequence_field(field: ModelField) -> bool:
|
||||
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_sequence_field(field)
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
return lenient_issubclass(field.type_, bytes)
|
||||
|
||||
def is_bytes_sequence_field(field: ModelField) -> bool:
|
||||
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined]
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
return copy(field_info)
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
|
||||
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
|
||||
new_error = ValidationError([missing_field_error], RequestErrorModel)
|
||||
return new_error.errors()[0] # type: ignore[return-value]
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> Type[BaseModel]:
|
||||
BodyModel = create_model(model_name)
|
||||
for f in fields:
|
||||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||
return BodyModel
|
||||
|
||||
|
||||
def _regenerate_error_with_loc(
|
||||
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
||||
) -> List[Dict[str, Any]]:
|
||||
updated_loc_errors: List[Any] = [
|
||||
{**err, "loc": loc_prefix + err.get("loc", ())}
|
||||
for err in _normalize_errors(errors)
|
||||
]
|
||||
|
||||
return updated_loc_errors
|
||||
|
||||
|
||||
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
if lenient_issubclass(annotation, (str, bytes)):
|
||||
return False
|
||||
return lenient_issubclass(annotation, sequence_types)
|
||||
|
||||
|
||||
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
||||
get_origin(annotation)
|
||||
)
|
||||
|
||||
|
||||
def value_is_sequence(value: Any) -> bool:
|
||||
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
|
||||
return (
|
||||
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
|
||||
or _annotation_is_sequence(annotation)
|
||||
or is_dataclass(annotation)
|
||||
)
|
||||
|
||||
|
||||
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
|
||||
|
||||
return (
|
||||
_annotation_is_complex(annotation)
|
||||
or _annotation_is_complex(origin)
|
||||
or hasattr(origin, "__pydantic_core_schema__")
|
||||
or hasattr(origin, "__get_pydantic_core_schema__")
|
||||
)
|
||||
|
||||
|
||||
def field_annotation_is_scalar(annotation: Any) -> bool:
|
||||
# handle Ellipsis here to make tuple[int, ...] work nicely
|
||||
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
|
||||
|
||||
|
||||
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one_scalar_sequence = False
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_scalar_sequence(arg):
|
||||
at_least_one_scalar_sequence = True
|
||||
continue
|
||||
elif not field_annotation_is_scalar(arg):
|
||||
return False
|
||||
return at_least_one_scalar_sequence
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
field_annotation_is_scalar(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, bytes):
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, bytes):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, UploadFile):
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, UploadFile):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_bytes_sequence_annotation(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one = False
|
||||
for arg in get_args(annotation):
|
||||
if is_bytes_sequence_annotation(arg):
|
||||
at_least_one = True
|
||||
continue
|
||||
return at_least_one
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
is_bytes_or_nonable_bytes_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one = False
|
||||
for arg in get_args(annotation):
|
||||
if is_uploadfile_sequence_annotation(arg):
|
||||
at_least_one = True
|
||||
continue
|
||||
return at_least_one
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
@@ -0,0 +1,50 @@
|
||||
from .main import BaseConfig as BaseConfig
|
||||
from .main import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from .main import RequiredParam as RequiredParam
|
||||
from .main import Undefined as Undefined
|
||||
from .main import UndefinedType as UndefinedType
|
||||
from .main import Url as Url
|
||||
from .main import Validator as Validator
|
||||
from .main import _get_model_config as _get_model_config
|
||||
from .main import _is_error_wrapper as _is_error_wrapper
|
||||
from .main import _is_model_class as _is_model_class
|
||||
from .main import _is_model_field as _is_model_field
|
||||
from .main import _is_undefined as _is_undefined
|
||||
from .main import _model_dump as _model_dump
|
||||
from .main import _model_rebuild as _model_rebuild
|
||||
from .main import copy_field_info as copy_field_info
|
||||
from .main import create_body_model as create_body_model
|
||||
from .main import evaluate_forwardref as evaluate_forwardref
|
||||
from .main import get_annotation_from_field_info as get_annotation_from_field_info
|
||||
from .main import get_cached_model_fields as get_cached_model_fields
|
||||
from .main import get_compat_model_name_map as get_compat_model_name_map
|
||||
from .main import get_definitions as get_definitions
|
||||
from .main import get_missing_field_error as get_missing_field_error
|
||||
from .main import get_schema_from_model_field as get_schema_from_model_field
|
||||
from .main import is_bytes_field as is_bytes_field
|
||||
from .main import is_bytes_sequence_field as is_bytes_sequence_field
|
||||
from .main import is_scalar_field as is_scalar_field
|
||||
from .main import is_scalar_sequence_field as is_scalar_sequence_field
|
||||
from .main import is_sequence_field as is_sequence_field
|
||||
from .main import serialize_sequence_value as serialize_sequence_value
|
||||
from .main import (
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
)
|
||||
from .may_v1 import CoreSchema as CoreSchema
|
||||
from .may_v1 import GetJsonSchemaHandler as GetJsonSchemaHandler
|
||||
from .may_v1 import JsonSchemaValue as JsonSchemaValue
|
||||
from .may_v1 import _normalize_errors as _normalize_errors
|
||||
from .model_field import ModelField as ModelField
|
||||
from .shared import PYDANTIC_V2 as PYDANTIC_V2
|
||||
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
|
||||
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
|
||||
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
|
||||
from .shared import (
|
||||
is_uploadfile_or_nonable_uploadfile_annotation as is_uploadfile_or_nonable_uploadfile_annotation,
|
||||
)
|
||||
from .shared import (
|
||||
is_uploadfile_sequence_annotation as is_uploadfile_sequence_annotation,
|
||||
)
|
||||
from .shared import lenient_issubclass as lenient_issubclass
|
||||
from .shared import sequence_types as sequence_types
|
||||
from .shared import value_is_sequence as value_is_sequence
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,362 @@
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
from fastapi._compat import may_v1
|
||||
from fastapi._compat.shared import PYDANTIC_V2, lenient_issubclass
|
||||
from fastapi.types import ModelNameMap
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .model_field import ModelField
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from .v2 import BaseConfig as BaseConfig
|
||||
from .v2 import FieldInfo as FieldInfo
|
||||
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from .v2 import RequiredParam as RequiredParam
|
||||
from .v2 import Undefined as Undefined
|
||||
from .v2 import UndefinedType as UndefinedType
|
||||
from .v2 import Url as Url
|
||||
from .v2 import Validator as Validator
|
||||
from .v2 import evaluate_forwardref as evaluate_forwardref
|
||||
from .v2 import get_missing_field_error as get_missing_field_error
|
||||
from .v2 import (
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
)
|
||||
else:
|
||||
from .v1 import BaseConfig as BaseConfig # type: ignore[assignment]
|
||||
from .v1 import FieldInfo as FieldInfo
|
||||
from .v1 import ( # type: ignore[assignment]
|
||||
PydanticSchemaGenerationError as PydanticSchemaGenerationError,
|
||||
)
|
||||
from .v1 import RequiredParam as RequiredParam
|
||||
from .v1 import Undefined as Undefined
|
||||
from .v1 import UndefinedType as UndefinedType
|
||||
from .v1 import Url as Url # type: ignore[assignment]
|
||||
from .v1 import Validator as Validator
|
||||
from .v1 import evaluate_forwardref as evaluate_forwardref
|
||||
from .v1 import get_missing_field_error as get_missing_field_error
|
||||
from .v1 import ( # type: ignore[assignment]
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
if lenient_issubclass(model, may_v1.BaseModel):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.get_model_fields(model)
|
||||
else:
|
||||
from . import v2
|
||||
|
||||
return v2.get_model_fields(model) # type: ignore[return-value]
|
||||
|
||||
|
||||
def _is_undefined(value: object) -> bool:
|
||||
if isinstance(value, may_v1.UndefinedType):
|
||||
return True
|
||||
elif PYDANTIC_V2:
|
||||
from . import v2
|
||||
|
||||
return isinstance(value, v2.UndefinedType)
|
||||
return False
|
||||
|
||||
|
||||
def _get_model_config(model: BaseModel) -> Any:
|
||||
if isinstance(model, may_v1.BaseModel):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1._get_model_config(model)
|
||||
elif PYDANTIC_V2:
|
||||
from . import v2
|
||||
|
||||
return v2._get_model_config(model)
|
||||
|
||||
|
||||
def _model_dump(
|
||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
||||
) -> Any:
|
||||
if isinstance(model, may_v1.BaseModel):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1._model_dump(model, mode=mode, **kwargs)
|
||||
elif PYDANTIC_V2:
|
||||
from . import v2
|
||||
|
||||
return v2._model_dump(model, mode=mode, **kwargs)
|
||||
|
||||
|
||||
def _is_error_wrapper(exc: Exception) -> bool:
|
||||
if isinstance(exc, may_v1.ErrorWrapper):
|
||||
return True
|
||||
elif PYDANTIC_V2:
|
||||
from . import v2
|
||||
|
||||
return isinstance(exc, v2.ErrorWrapper)
|
||||
return False
|
||||
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
if isinstance(field_info, may_v1.FieldInfo):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.copy_field_info(field_info=field_info, annotation=annotation)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.copy_field_info(field_info=field_info, annotation=annotation)
|
||||
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> Type[BaseModel]:
|
||||
if fields and isinstance(fields[0], may_v1.ModelField):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.create_body_model(fields=fields, model_name=model_name)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.create_body_model(fields=fields, model_name=model_name) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def get_annotation_from_field_info(
|
||||
annotation: Any, field_info: FieldInfo, field_name: str
|
||||
) -> Any:
|
||||
if isinstance(field_info, may_v1.FieldInfo):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.get_annotation_from_field_info(
|
||||
annotation=annotation, field_info=field_info, field_name=field_name
|
||||
)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.get_annotation_from_field_info(
|
||||
annotation=annotation, field_info=field_info, field_name=field_name
|
||||
)
|
||||
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
if isinstance(field, may_v1.ModelField):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.is_bytes_field(field)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.is_bytes_field(field) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def is_bytes_sequence_field(field: ModelField) -> bool:
|
||||
if isinstance(field, may_v1.ModelField):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.is_bytes_sequence_field(field)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.is_bytes_sequence_field(field) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
if isinstance(field, may_v1.ModelField):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.is_scalar_field(field)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.is_scalar_field(field) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
if isinstance(field, may_v1.ModelField):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.is_scalar_sequence_field(field)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.is_scalar_sequence_field(field) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def is_sequence_field(field: ModelField) -> bool:
|
||||
if isinstance(field, may_v1.ModelField):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.is_sequence_field(field)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.is_sequence_field(field) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
if isinstance(field, may_v1.ModelField):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.serialize_sequence_value(field=field, value=value)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.serialize_sequence_value(field=field, value=value) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||
if lenient_issubclass(model, may_v1.BaseModel):
|
||||
from fastapi._compat import v1
|
||||
|
||||
v1._model_rebuild(model)
|
||||
elif PYDANTIC_V2:
|
||||
from . import v2
|
||||
|
||||
v2._model_rebuild(model)
|
||||
|
||||
|
||||
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
|
||||
v1_model_fields = [
|
||||
field for field in fields if isinstance(field, may_v1.ModelField)
|
||||
]
|
||||
if v1_model_fields:
|
||||
from fastapi._compat import v1
|
||||
|
||||
v1_flat_models = v1.get_flat_models_from_fields(
|
||||
v1_model_fields, known_models=set()
|
||||
)
|
||||
all_flat_models = v1_flat_models
|
||||
else:
|
||||
all_flat_models = set()
|
||||
if PYDANTIC_V2:
|
||||
from . import v2
|
||||
|
||||
v2_model_fields = [
|
||||
field for field in fields if isinstance(field, v2.ModelField)
|
||||
]
|
||||
v2_flat_models = v2.get_flat_models_from_fields(
|
||||
v2_model_fields, known_models=set()
|
||||
)
|
||||
all_flat_models = all_flat_models.union(v2_flat_models)
|
||||
|
||||
model_name_map = v2.get_model_name_map(all_flat_models)
|
||||
return model_name_map
|
||||
from fastapi._compat import v1
|
||||
|
||||
model_name_map = v1.get_model_name_map(all_flat_models)
|
||||
return model_name_map
|
||||
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: List[ModelField],
|
||||
model_name_map: ModelNameMap,
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Tuple[
|
||||
Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]],
|
||||
may_v1.JsonSchemaValue,
|
||||
],
|
||||
Dict[str, Dict[str, Any]],
|
||||
]:
|
||||
if sys.version_info < (3, 14):
|
||||
v1_fields = [field for field in fields if isinstance(field, may_v1.ModelField)]
|
||||
v1_field_maps, v1_definitions = may_v1.get_definitions(
|
||||
fields=v1_fields,
|
||||
model_name_map=model_name_map,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
if not PYDANTIC_V2:
|
||||
return v1_field_maps, v1_definitions
|
||||
else:
|
||||
from . import v2
|
||||
|
||||
v2_fields = [field for field in fields if isinstance(field, v2.ModelField)]
|
||||
v2_field_maps, v2_definitions = v2.get_definitions(
|
||||
fields=v2_fields,
|
||||
model_name_map=model_name_map,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
all_definitions = {**v1_definitions, **v2_definitions}
|
||||
all_field_maps = {**v1_field_maps, **v2_field_maps}
|
||||
return all_field_maps, all_definitions
|
||||
|
||||
# Pydantic v1 is not supported since Python 3.14
|
||||
else:
|
||||
from . import v2
|
||||
|
||||
v2_fields = [field for field in fields if isinstance(field, v2.ModelField)]
|
||||
v2_field_maps, v2_definitions = v2.get_definitions(
|
||||
fields=v2_fields,
|
||||
model_name_map=model_name_map,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
return v2_field_maps, v2_definitions
|
||||
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]],
|
||||
may_v1.JsonSchemaValue,
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
if isinstance(field, may_v1.ModelField):
|
||||
from fastapi._compat import v1
|
||||
|
||||
return v1.get_schema_from_model_field(
|
||||
field=field,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
else:
|
||||
assert PYDANTIC_V2
|
||||
from . import v2
|
||||
|
||||
return v2.get_schema_from_model_field(
|
||||
field=field, # type: ignore[arg-type]
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping, # type: ignore[arg-type]
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
|
||||
|
||||
def _is_model_field(value: Any) -> bool:
|
||||
if isinstance(value, may_v1.ModelField):
|
||||
return True
|
||||
elif PYDANTIC_V2:
|
||||
from . import v2
|
||||
|
||||
return isinstance(value, v2.ModelField)
|
||||
return False
|
||||
|
||||
|
||||
def _is_model_class(value: Any) -> bool:
|
||||
if lenient_issubclass(value, may_v1.BaseModel):
|
||||
return True
|
||||
elif PYDANTIC_V2:
|
||||
from . import v2
|
||||
|
||||
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
|
||||
return False
|
||||
@@ -0,0 +1,123 @@
|
||||
import sys
|
||||
from typing import Any, Dict, List, Literal, Sequence, Tuple, Type, Union
|
||||
|
||||
from fastapi.types import ModelNameMap
|
||||
|
||||
if sys.version_info >= (3, 14):
|
||||
|
||||
class AnyUrl:
|
||||
pass
|
||||
|
||||
class BaseConfig:
|
||||
pass
|
||||
|
||||
class BaseModel:
|
||||
pass
|
||||
|
||||
class Color:
|
||||
pass
|
||||
|
||||
class CoreSchema:
|
||||
pass
|
||||
|
||||
class ErrorWrapper:
|
||||
pass
|
||||
|
||||
class FieldInfo:
|
||||
pass
|
||||
|
||||
class GetJsonSchemaHandler:
|
||||
pass
|
||||
|
||||
class JsonSchemaValue:
|
||||
pass
|
||||
|
||||
class ModelField:
|
||||
pass
|
||||
|
||||
class NameEmail:
|
||||
pass
|
||||
|
||||
class RequiredParam:
|
||||
pass
|
||||
|
||||
class SecretBytes:
|
||||
pass
|
||||
|
||||
class SecretStr:
|
||||
pass
|
||||
|
||||
class Undefined:
|
||||
pass
|
||||
|
||||
class UndefinedType:
|
||||
pass
|
||||
|
||||
class Url:
|
||||
pass
|
||||
|
||||
from .v2 import ValidationError, create_model
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: List[ModelField],
|
||||
model_name_map: ModelNameMap,
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Tuple[
|
||||
Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
Dict[str, Dict[str, Any]],
|
||||
]:
|
||||
return {}, {} # pragma: no cover
|
||||
|
||||
|
||||
else:
|
||||
from .v1 import AnyUrl as AnyUrl
|
||||
from .v1 import BaseConfig as BaseConfig
|
||||
from .v1 import BaseModel as BaseModel
|
||||
from .v1 import Color as Color
|
||||
from .v1 import CoreSchema as CoreSchema
|
||||
from .v1 import ErrorWrapper as ErrorWrapper
|
||||
from .v1 import FieldInfo as FieldInfo
|
||||
from .v1 import GetJsonSchemaHandler as GetJsonSchemaHandler
|
||||
from .v1 import JsonSchemaValue as JsonSchemaValue
|
||||
from .v1 import ModelField as ModelField
|
||||
from .v1 import NameEmail as NameEmail
|
||||
from .v1 import RequiredParam as RequiredParam
|
||||
from .v1 import SecretBytes as SecretBytes
|
||||
from .v1 import SecretStr as SecretStr
|
||||
from .v1 import Undefined as Undefined
|
||||
from .v1 import UndefinedType as UndefinedType
|
||||
from .v1 import Url as Url
|
||||
from .v1 import ValidationError, create_model
|
||||
from .v1 import get_definitions as get_definitions
|
||||
|
||||
|
||||
RequestErrorModel: Type[BaseModel] = create_model("Request")
|
||||
|
||||
|
||||
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
|
||||
use_errors: List[Any] = []
|
||||
for error in errors:
|
||||
if isinstance(error, ErrorWrapper):
|
||||
new_errors = ValidationError( # type: ignore[call-arg]
|
||||
errors=[error], model=RequestErrorModel
|
||||
).errors()
|
||||
use_errors.extend(new_errors)
|
||||
elif isinstance(error, list):
|
||||
use_errors.extend(_normalize_errors(error))
|
||||
else:
|
||||
use_errors.append(error)
|
||||
return use_errors
|
||||
|
||||
|
||||
def _regenerate_error_with_loc(
|
||||
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
||||
) -> List[Dict[str, Any]]:
|
||||
updated_loc_errors: List[Any] = [
|
||||
{**err, "loc": loc_prefix + err.get("loc", ())}
|
||||
for err in _normalize_errors(errors)
|
||||
]
|
||||
|
||||
return updated_loc_errors
|
||||
@@ -0,0 +1,53 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi.types import IncEx
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
|
||||
class ModelField(Protocol):
|
||||
field_info: "FieldInfo"
|
||||
name: str
|
||||
mode: Literal["validation", "serialization"] = "validation"
|
||||
_version: Literal["v1", "v2"] = "v1"
|
||||
|
||||
@property
|
||||
def alias(self) -> str: ...
|
||||
|
||||
@property
|
||||
def required(self) -> bool: ...
|
||||
|
||||
@property
|
||||
def default(self) -> Any: ...
|
||||
|
||||
@property
|
||||
def type_(self) -> Any: ...
|
||||
|
||||
def get_default(self) -> Any: ...
|
||||
|
||||
def validate(
|
||||
self,
|
||||
value: Any,
|
||||
values: Dict[str, Any] = {}, # noqa: B006
|
||||
*,
|
||||
loc: Tuple[Union[int, str], ...] = (),
|
||||
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: ...
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
value: Any,
|
||||
*,
|
||||
mode: Literal["json", "python"] = "json",
|
||||
include: Union[IncEx, None] = None,
|
||||
exclude: Union[IncEx, None] = None,
|
||||
by_alias: bool = True,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
) -> Any: ...
|
||||
@@ -0,0 +1,211 @@
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from collections import deque
|
||||
from dataclasses import is_dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Deque,
|
||||
FrozenSet,
|
||||
List,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi._compat import may_v1
|
||||
from fastapi.types import UnionType
|
||||
from pydantic import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from starlette.datastructures import UploadFile
|
||||
from typing_extensions import Annotated, get_args, get_origin
|
||||
|
||||
# Copy from Pydantic v2, compatible with v1
|
||||
if sys.version_info < (3, 9):
|
||||
# Pydantic no longer supports Python 3.8, this might be incorrect, but the code
|
||||
# this is used for is also never reached in this codebase, as it's a copy of
|
||||
# Pydantic's lenient_issubclass, just for compatibility with v1
|
||||
# TODO: remove when dropping support for Python 3.8
|
||||
WithArgsTypes: Tuple[Any, ...] = ()
|
||||
elif sys.version_info < (3, 10):
|
||||
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # type: ignore[attr-defined]
|
||||
else:
|
||||
WithArgsTypes: tuple[Any, ...] = (
|
||||
typing._GenericAlias, # type: ignore[attr-defined]
|
||||
types.GenericAlias,
|
||||
types.UnionType,
|
||||
) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
|
||||
|
||||
|
||||
sequence_annotation_to_type = {
|
||||
Sequence: list,
|
||||
List: list,
|
||||
list: list,
|
||||
Tuple: tuple,
|
||||
tuple: tuple,
|
||||
Set: set,
|
||||
set: set,
|
||||
FrozenSet: frozenset,
|
||||
frozenset: frozenset,
|
||||
Deque: deque,
|
||||
deque: deque,
|
||||
}
|
||||
|
||||
sequence_types = tuple(sequence_annotation_to_type.keys())
|
||||
|
||||
Url: Type[Any]
|
||||
|
||||
|
||||
# Copy of Pydantic v2, compatible with v1
|
||||
def lenient_issubclass(
|
||||
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]
|
||||
) -> bool:
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError: # pragma: no cover
|
||||
if isinstance(cls, WithArgsTypes):
|
||||
return False
|
||||
raise # pragma: no cover
|
||||
|
||||
|
||||
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
if lenient_issubclass(annotation, (str, bytes)):
|
||||
return False
|
||||
return lenient_issubclass(annotation, sequence_types) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_sequence(arg):
|
||||
return True
|
||||
return False
|
||||
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
||||
get_origin(annotation)
|
||||
)
|
||||
|
||||
|
||||
def value_is_sequence(value: Any) -> bool:
|
||||
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
|
||||
return (
|
||||
lenient_issubclass(
|
||||
annotation, (BaseModel, may_v1.BaseModel, Mapping, UploadFile)
|
||||
)
|
||||
or _annotation_is_sequence(annotation)
|
||||
or is_dataclass(annotation)
|
||||
)
|
||||
|
||||
|
||||
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
|
||||
|
||||
if origin is Annotated:
|
||||
return field_annotation_is_complex(get_args(annotation)[0])
|
||||
|
||||
return (
|
||||
_annotation_is_complex(annotation)
|
||||
or _annotation_is_complex(origin)
|
||||
or hasattr(origin, "__pydantic_core_schema__")
|
||||
or hasattr(origin, "__get_pydantic_core_schema__")
|
||||
)
|
||||
|
||||
|
||||
def field_annotation_is_scalar(annotation: Any) -> bool:
|
||||
# handle Ellipsis here to make tuple[int, ...] work nicely
|
||||
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
|
||||
|
||||
|
||||
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one_scalar_sequence = False
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_scalar_sequence(arg):
|
||||
at_least_one_scalar_sequence = True
|
||||
continue
|
||||
elif not field_annotation_is_scalar(arg):
|
||||
return False
|
||||
return at_least_one_scalar_sequence
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
field_annotation_is_scalar(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, bytes):
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, bytes):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, UploadFile):
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, UploadFile):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_bytes_sequence_annotation(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one = False
|
||||
for arg in get_args(annotation):
|
||||
if is_bytes_sequence_annotation(arg):
|
||||
at_least_one = True
|
||||
continue
|
||||
return at_least_one
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
is_bytes_or_nonable_bytes_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one = False
|
||||
for arg in get_args(annotation):
|
||||
if is_uploadfile_sequence_annotation(arg):
|
||||
at_least_one = True
|
||||
continue
|
||||
return at_least_one
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def annotation_is_pydantic_v1(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, may_v1.BaseModel):
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, may_v1.BaseModel):
|
||||
return True
|
||||
if field_annotation_is_sequence(annotation):
|
||||
for sub_annotation in get_args(annotation):
|
||||
if annotation_is_pydantic_v1(sub_annotation):
|
||||
return True
|
||||
return False
|
||||
312
Backend/venv/lib/python3.12/site-packages/fastapi/_compat/v1.py
Normal file
312
Backend/venv/lib/python3.12/site-packages/fastapi/_compat/v1.py
Normal file
@@ -0,0 +1,312 @@
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi._compat import shared
|
||||
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||||
from fastapi.types import ModelNameMap
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from typing_extensions import Literal
|
||||
|
||||
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
|
||||
# Keeping old "Required" functionality from Pydantic V1, without
|
||||
# shadowing typing.Required.
|
||||
RequiredParam: Any = Ellipsis
|
||||
|
||||
if not PYDANTIC_V2:
|
||||
from pydantic import BaseConfig as BaseConfig
|
||||
from pydantic import BaseModel as BaseModel
|
||||
from pydantic import ValidationError as ValidationError
|
||||
from pydantic import create_model as create_model
|
||||
from pydantic.class_validators import Validator as Validator
|
||||
from pydantic.color import Color as Color
|
||||
from pydantic.error_wrappers import ErrorWrapper as ErrorWrapper
|
||||
from pydantic.errors import MissingError
|
||||
from pydantic.fields import ( # type: ignore[attr-defined]
|
||||
SHAPE_FROZENSET,
|
||||
SHAPE_LIST,
|
||||
SHAPE_SEQUENCE,
|
||||
SHAPE_SET,
|
||||
SHAPE_SINGLETON,
|
||||
SHAPE_TUPLE,
|
||||
SHAPE_TUPLE_ELLIPSIS,
|
||||
)
|
||||
from pydantic.fields import FieldInfo as FieldInfo
|
||||
from pydantic.fields import ModelField as ModelField # type: ignore[attr-defined]
|
||||
from pydantic.fields import Undefined as Undefined # type: ignore[attr-defined]
|
||||
from pydantic.fields import ( # type: ignore[attr-defined]
|
||||
UndefinedType as UndefinedType,
|
||||
)
|
||||
from pydantic.networks import AnyUrl as AnyUrl
|
||||
from pydantic.networks import NameEmail as NameEmail
|
||||
from pydantic.schema import TypeModelSet as TypeModelSet
|
||||
from pydantic.schema import (
|
||||
field_schema,
|
||||
model_process_schema,
|
||||
)
|
||||
from pydantic.schema import (
|
||||
get_annotation_from_field_info as get_annotation_from_field_info,
|
||||
)
|
||||
from pydantic.schema import get_flat_models_from_field as get_flat_models_from_field
|
||||
from pydantic.schema import (
|
||||
get_flat_models_from_fields as get_flat_models_from_fields,
|
||||
)
|
||||
from pydantic.schema import get_model_name_map as get_model_name_map
|
||||
from pydantic.types import SecretBytes as SecretBytes
|
||||
from pydantic.types import SecretStr as SecretStr
|
||||
from pydantic.typing import evaluate_forwardref as evaluate_forwardref
|
||||
from pydantic.utils import lenient_issubclass as lenient_issubclass
|
||||
|
||||
|
||||
else:
|
||||
from pydantic.v1 import BaseConfig as BaseConfig # type: ignore[assignment]
|
||||
from pydantic.v1 import BaseModel as BaseModel # type: ignore[assignment]
|
||||
from pydantic.v1 import ( # type: ignore[assignment]
|
||||
ValidationError as ValidationError,
|
||||
)
|
||||
from pydantic.v1 import create_model as create_model # type: ignore[no-redef]
|
||||
from pydantic.v1.class_validators import Validator as Validator
|
||||
from pydantic.v1.color import Color as Color # type: ignore[assignment]
|
||||
from pydantic.v1.error_wrappers import ErrorWrapper as ErrorWrapper
|
||||
from pydantic.v1.errors import MissingError
|
||||
from pydantic.v1.fields import (
|
||||
SHAPE_FROZENSET,
|
||||
SHAPE_LIST,
|
||||
SHAPE_SEQUENCE,
|
||||
SHAPE_SET,
|
||||
SHAPE_SINGLETON,
|
||||
SHAPE_TUPLE,
|
||||
SHAPE_TUPLE_ELLIPSIS,
|
||||
)
|
||||
from pydantic.v1.fields import FieldInfo as FieldInfo # type: ignore[assignment]
|
||||
from pydantic.v1.fields import ModelField as ModelField
|
||||
from pydantic.v1.fields import Undefined as Undefined
|
||||
from pydantic.v1.fields import UndefinedType as UndefinedType
|
||||
from pydantic.v1.networks import AnyUrl as AnyUrl
|
||||
from pydantic.v1.networks import ( # type: ignore[assignment]
|
||||
NameEmail as NameEmail,
|
||||
)
|
||||
from pydantic.v1.schema import TypeModelSet as TypeModelSet
|
||||
from pydantic.v1.schema import (
|
||||
field_schema,
|
||||
model_process_schema,
|
||||
)
|
||||
from pydantic.v1.schema import (
|
||||
get_annotation_from_field_info as get_annotation_from_field_info,
|
||||
)
|
||||
from pydantic.v1.schema import (
|
||||
get_flat_models_from_field as get_flat_models_from_field,
|
||||
)
|
||||
from pydantic.v1.schema import (
|
||||
get_flat_models_from_fields as get_flat_models_from_fields,
|
||||
)
|
||||
from pydantic.v1.schema import get_model_name_map as get_model_name_map
|
||||
from pydantic.v1.types import ( # type: ignore[assignment]
|
||||
SecretBytes as SecretBytes,
|
||||
)
|
||||
from pydantic.v1.types import ( # type: ignore[assignment]
|
||||
SecretStr as SecretStr,
|
||||
)
|
||||
from pydantic.v1.typing import evaluate_forwardref as evaluate_forwardref
|
||||
from pydantic.v1.utils import lenient_issubclass as lenient_issubclass
|
||||
|
||||
|
||||
GetJsonSchemaHandler = Any
|
||||
JsonSchemaValue = Dict[str, Any]
|
||||
CoreSchema = Any
|
||||
Url = AnyUrl
|
||||
|
||||
sequence_shapes = {
|
||||
SHAPE_LIST,
|
||||
SHAPE_SET,
|
||||
SHAPE_FROZENSET,
|
||||
SHAPE_TUPLE,
|
||||
SHAPE_SEQUENCE,
|
||||
SHAPE_TUPLE_ELLIPSIS,
|
||||
}
|
||||
sequence_shape_to_type = {
|
||||
SHAPE_LIST: list,
|
||||
SHAPE_SET: set,
|
||||
SHAPE_TUPLE: tuple,
|
||||
SHAPE_SEQUENCE: list,
|
||||
SHAPE_TUPLE_ELLIPSIS: list,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateJsonSchema:
|
||||
ref_template: str
|
||||
|
||||
|
||||
class PydanticSchemaGenerationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
RequestErrorModel: Type[BaseModel] = create_model("Request")
|
||||
|
||||
|
||||
def with_info_plain_validator_function(
|
||||
function: Callable[..., Any],
|
||||
*,
|
||||
ref: Union[str, None] = None,
|
||||
metadata: Any = None,
|
||||
serialization: Any = None,
|
||||
) -> Any:
|
||||
return {}
|
||||
|
||||
|
||||
def get_model_definitions(
|
||||
*,
|
||||
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
|
||||
) -> Dict[str, Any]:
|
||||
definitions: Dict[str, Dict[str, Any]] = {}
|
||||
for model in flat_models:
|
||||
m_schema, m_definitions, m_nested_models = model_process_schema(
|
||||
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)
|
||||
definitions.update(m_definitions)
|
||||
model_name = model_name_map[model]
|
||||
definitions[model_name] = m_schema
|
||||
for m_schema in definitions.values():
|
||||
if "description" in m_schema:
|
||||
m_schema["description"] = m_schema["description"].split("\f")[0]
|
||||
return definitions
|
||||
|
||||
|
||||
def is_pv1_scalar_field(field: ModelField) -> bool:
|
||||
from fastapi import params
|
||||
|
||||
field_info = field.field_info
|
||||
if not (
|
||||
field.shape == SHAPE_SINGLETON
|
||||
and not lenient_issubclass(field.type_, BaseModel)
|
||||
and not lenient_issubclass(field.type_, dict)
|
||||
and not shared.field_annotation_is_sequence(field.type_)
|
||||
and not is_dataclass(field.type_)
|
||||
and not isinstance(field_info, params.Body)
|
||||
):
|
||||
return False
|
||||
if field.sub_fields:
|
||||
if not all(is_pv1_scalar_field(f) for f in field.sub_fields):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
|
||||
if (field.shape in sequence_shapes) and not lenient_issubclass(
|
||||
field.type_, BaseModel
|
||||
):
|
||||
if field.sub_fields is not None:
|
||||
for sub_field in field.sub_fields:
|
||||
if not is_pv1_scalar_field(sub_field):
|
||||
return False
|
||||
return True
|
||||
if shared._annotation_is_sequence(field.type_):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||
model.update_forward_refs()
|
||||
|
||||
|
||||
def _model_dump(
|
||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
||||
) -> Any:
|
||||
return model.dict(**kwargs)
|
||||
|
||||
|
||||
def _get_model_config(model: BaseModel) -> Any:
|
||||
return model.__config__ # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
return field_schema( # type: ignore[no-any-return]
|
||||
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)[0]
|
||||
|
||||
|
||||
# def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
|
||||
# models = get_flat_models_from_fields(fields, known_models=set())
|
||||
# return get_model_name_map(models) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: List[ModelField],
|
||||
model_name_map: ModelNameMap,
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Tuple[
|
||||
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
|
||||
Dict[str, Dict[str, Any]],
|
||||
]:
|
||||
models = get_flat_models_from_fields(fields, known_models=set())
|
||||
return {}, get_model_definitions(flat_models=models, model_name_map=model_name_map)
|
||||
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_field(field)
|
||||
|
||||
|
||||
def is_sequence_field(field: ModelField) -> bool:
|
||||
return field.shape in sequence_shapes or shared._annotation_is_sequence(field.type_)
|
||||
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_sequence_field(field)
|
||||
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
return lenient_issubclass(field.type_, bytes) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_bytes_sequence_field(field: ModelField) -> bool:
|
||||
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes)
|
||||
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
return copy(field_info)
|
||||
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
|
||||
new_error = ValidationError([missing_field_error], RequestErrorModel)
|
||||
return new_error.errors()[0] # type: ignore[return-value]
|
||||
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> Type[BaseModel]:
|
||||
BodyModel = create_model(model_name)
|
||||
for f in fields:
|
||||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||
return BodyModel
|
||||
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
||||
479
Backend/venv/lib/python3.12/site-packages/fastapi/_compat/v2.py
Normal file
479
Backend/venv/lib/python3.12/site-packages/fastapi/_compat/v2.py
Normal file
@@ -0,0 +1,479 @@
|
||||
import re
|
||||
import warnings
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from fastapi._compat import may_v1, shared
|
||||
from fastapi.openapi.constants import REF_TEMPLATE
|
||||
from fastapi.types import IncEx, ModelNameMap
|
||||
from pydantic import BaseModel, TypeAdapter, create_model
|
||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
|
||||
from pydantic import ValidationError as ValidationError
|
||||
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
|
||||
GetJsonSchemaHandler as GetJsonSchemaHandler,
|
||||
)
|
||||
from pydantic._internal._typing_extra import eval_type_lenient
|
||||
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
|
||||
from pydantic.fields import FieldInfo as FieldInfo
|
||||
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
|
||||
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
|
||||
from pydantic_core import CoreSchema as CoreSchema
|
||||
from pydantic_core import PydanticUndefined, PydanticUndefinedType
|
||||
from pydantic_core import Url as Url
|
||||
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||
|
||||
try:
|
||||
from pydantic_core.core_schema import (
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
)
|
||||
except ImportError: # pragma: no cover
|
||||
from pydantic_core.core_schema import (
|
||||
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
|
||||
)
|
||||
|
||||
RequiredParam = PydanticUndefined
|
||||
Undefined = PydanticUndefined
|
||||
UndefinedType = PydanticUndefinedType
|
||||
evaluate_forwardref = eval_type_lenient
|
||||
Validator = Any
|
||||
|
||||
|
||||
class BaseConfig:
|
||||
pass
|
||||
|
||||
|
||||
class ErrorWrapper(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelField:
|
||||
field_info: FieldInfo
|
||||
name: str
|
||||
mode: Literal["validation", "serialization"] = "validation"
|
||||
|
||||
@property
|
||||
def alias(self) -> str:
|
||||
a = self.field_info.alias
|
||||
return a if a is not None else self.name
|
||||
|
||||
@property
|
||||
def required(self) -> bool:
|
||||
return self.field_info.is_required()
|
||||
|
||||
@property
|
||||
def default(self) -> Any:
|
||||
return self.get_default()
|
||||
|
||||
@property
|
||||
def type_(self) -> Any:
|
||||
return self.field_info.annotation
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
with warnings.catch_warnings():
|
||||
# Pydantic >= 2.12.0 warns about field specific metadata that is unused
|
||||
# (e.g. `TypeAdapter(Annotated[int, Field(alias='b')])`). In some cases, we
|
||||
# end up building the type adapter from a model field annotation so we
|
||||
# need to ignore the warning:
|
||||
if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 12):
|
||||
from pydantic.warnings import UnsupportedFieldAttributeWarning
|
||||
|
||||
warnings.simplefilter(
|
||||
"ignore", category=UnsupportedFieldAttributeWarning
|
||||
)
|
||||
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
|
||||
Annotated[self.field_info.annotation, self.field_info]
|
||||
)
|
||||
|
||||
def get_default(self) -> Any:
|
||||
if self.field_info.is_required():
|
||||
return Undefined
|
||||
return self.field_info.get_default(call_default_factory=True)
|
||||
|
||||
def validate(
|
||||
self,
|
||||
value: Any,
|
||||
values: Dict[str, Any] = {}, # noqa: B006
|
||||
*,
|
||||
loc: Tuple[Union[int, str], ...] = (),
|
||||
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
|
||||
try:
|
||||
return (
|
||||
self._type_adapter.validate_python(value, from_attributes=True),
|
||||
None,
|
||||
)
|
||||
except ValidationError as exc:
|
||||
return None, may_v1._regenerate_error_with_loc(
|
||||
errors=exc.errors(include_url=False), loc_prefix=loc
|
||||
)
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
value: Any,
|
||||
*,
|
||||
mode: Literal["json", "python"] = "json",
|
||||
include: Union[IncEx, None] = None,
|
||||
exclude: Union[IncEx, None] = None,
|
||||
by_alias: bool = True,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
) -> Any:
|
||||
# What calls this code passes a value that already called
|
||||
# self._type_adapter.validate_python(value)
|
||||
return self._type_adapter.dump_python(
|
||||
value,
|
||||
mode=mode,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Each ModelField is unique for our purposes, to allow making a dict from
|
||||
# ModelField to its JSON Schema.
|
||||
return id(self)
|
||||
|
||||
|
||||
def get_annotation_from_field_info(
|
||||
annotation: Any, field_info: FieldInfo, field_name: str
|
||||
) -> Any:
|
||||
return annotation
|
||||
|
||||
|
||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||
model.model_rebuild()
|
||||
|
||||
|
||||
def _model_dump(
|
||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
||||
) -> Any:
|
||||
return model.model_dump(mode=mode, **kwargs)
|
||||
|
||||
|
||||
def _get_model_config(model: BaseModel) -> Any:
|
||||
return model.model_config
|
||||
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
override_mode: Union[Literal["validation"], None] = (
|
||||
None if separate_input_output_schemas else "validation"
|
||||
)
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
json_schema = field_mapping[(field, override_mode or field.mode)]
|
||||
if "$ref" not in json_schema:
|
||||
# TODO remove when deprecating Pydantic v1
|
||||
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
|
||||
json_schema["title"] = field.field_info.title or field.alias.title().replace(
|
||||
"_", " "
|
||||
)
|
||||
return json_schema
|
||||
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: Sequence[ModelField],
|
||||
model_name_map: ModelNameMap,
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Tuple[
|
||||
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
|
||||
Dict[str, Dict[str, Any]],
|
||||
]:
|
||||
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
|
||||
override_mode: Union[Literal["validation"], None] = (
|
||||
None if separate_input_output_schemas else "validation"
|
||||
)
|
||||
validation_fields = [field for field in fields if field.mode == "validation"]
|
||||
serialization_fields = [field for field in fields if field.mode == "serialization"]
|
||||
flat_validation_models = get_flat_models_from_fields(
|
||||
validation_fields, known_models=set()
|
||||
)
|
||||
flat_serialization_models = get_flat_models_from_fields(
|
||||
serialization_fields, known_models=set()
|
||||
)
|
||||
flat_validation_model_fields = [
|
||||
ModelField(
|
||||
field_info=FieldInfo(annotation=model),
|
||||
name=model.__name__,
|
||||
mode="validation",
|
||||
)
|
||||
for model in flat_validation_models
|
||||
]
|
||||
flat_serialization_model_fields = [
|
||||
ModelField(
|
||||
field_info=FieldInfo(annotation=model),
|
||||
name=model.__name__,
|
||||
mode="serialization",
|
||||
)
|
||||
for model in flat_serialization_models
|
||||
]
|
||||
flat_model_fields = flat_validation_model_fields + flat_serialization_model_fields
|
||||
input_types = {f.type_ for f in fields}
|
||||
unique_flat_model_fields = {
|
||||
f for f in flat_model_fields if f.type_ not in input_types
|
||||
}
|
||||
|
||||
inputs = [
|
||||
(field, override_mode or field.mode, field._type_adapter.core_schema)
|
||||
for field in list(fields) + list(unique_flat_model_fields)
|
||||
]
|
||||
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
|
||||
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
|
||||
if "description" in item_def:
|
||||
item_description = cast(str, item_def["description"]).split("\f")[0]
|
||||
item_def["description"] = item_description
|
||||
new_mapping, new_definitions = _remap_definitions_and_field_mappings(
|
||||
model_name_map=model_name_map,
|
||||
definitions=definitions, # type: ignore[arg-type]
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
return new_mapping, new_definitions
|
||||
|
||||
|
||||
def _replace_refs(
|
||||
*,
|
||||
schema: Dict[str, Any],
|
||||
old_name_to_new_name_map: Dict[str, str],
|
||||
) -> Dict[str, Any]:
|
||||
new_schema = deepcopy(schema)
|
||||
for key, value in new_schema.items():
|
||||
if key == "$ref":
|
||||
value = schema["$ref"]
|
||||
if isinstance(value, str):
|
||||
ref_name = schema["$ref"].split("/")[-1]
|
||||
if ref_name in old_name_to_new_name_map:
|
||||
new_name = old_name_to_new_name_map[ref_name]
|
||||
new_schema["$ref"] = REF_TEMPLATE.format(model=new_name)
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
new_schema[key] = _replace_refs(
|
||||
schema=value,
|
||||
old_name_to_new_name_map=old_name_to_new_name_map,
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
new_value = []
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
new_item = _replace_refs(
|
||||
schema=item,
|
||||
old_name_to_new_name_map=old_name_to_new_name_map,
|
||||
)
|
||||
new_value.append(new_item)
|
||||
|
||||
else:
|
||||
new_value.append(item)
|
||||
new_schema[key] = new_value
|
||||
return new_schema
|
||||
|
||||
|
||||
def _remap_definitions_and_field_mappings(
|
||||
*,
|
||||
model_name_map: ModelNameMap,
|
||||
definitions: Dict[str, Any],
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
) -> Tuple[
|
||||
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
|
||||
Dict[str, Any],
|
||||
]:
|
||||
old_name_to_new_name_map = {}
|
||||
for field_key, schema in field_mapping.items():
|
||||
model = field_key[0].type_
|
||||
if model not in model_name_map:
|
||||
continue
|
||||
new_name = model_name_map[model]
|
||||
old_name = schema["$ref"].split("/")[-1]
|
||||
if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
|
||||
continue
|
||||
old_name_to_new_name_map[old_name] = new_name
|
||||
|
||||
new_field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
] = {}
|
||||
for field_key, schema in field_mapping.items():
|
||||
new_schema = _replace_refs(
|
||||
schema=schema,
|
||||
old_name_to_new_name_map=old_name_to_new_name_map,
|
||||
)
|
||||
new_field_mapping[field_key] = new_schema
|
||||
|
||||
new_definitions = {}
|
||||
for key, value in definitions.items():
|
||||
if key in old_name_to_new_name_map:
|
||||
new_key = old_name_to_new_name_map[key]
|
||||
else:
|
||||
new_key = key
|
||||
new_value = _replace_refs(
|
||||
schema=value,
|
||||
old_name_to_new_name_map=old_name_to_new_name_map,
|
||||
)
|
||||
new_definitions[new_key] = new_value
|
||||
return new_field_mapping, new_definitions
|
||||
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
from fastapi import params
|
||||
|
||||
return shared.field_annotation_is_scalar(
|
||||
field.field_info.annotation
|
||||
) and not isinstance(field.field_info, params.Body)
|
||||
|
||||
|
||||
def is_sequence_field(field: ModelField) -> bool:
|
||||
return shared.field_annotation_is_sequence(field.field_info.annotation)
|
||||
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
return shared.field_annotation_is_scalar_sequence(field.field_info.annotation)
|
||||
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
return shared.is_bytes_or_nonable_bytes_annotation(field.type_)
|
||||
|
||||
|
||||
def is_bytes_sequence_field(field: ModelField) -> bool:
|
||||
return shared.is_bytes_sequence_annotation(field.type_)
|
||||
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
cls = type(field_info)
|
||||
merged_field_info = cls.from_annotation(annotation)
|
||||
new_field_info = copy(field_info)
|
||||
new_field_info.metadata = merged_field_info.metadata
|
||||
new_field_info.annotation = merged_field_info.annotation
|
||||
return new_field_info
|
||||
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
|
||||
assert issubclass(origin_type, shared.sequence_types) # type: ignore[arg-type]
|
||||
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
error = ValidationError.from_exception_data(
|
||||
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
|
||||
).errors(include_url=False)[0]
|
||||
error["input"] = None
|
||||
return error # type: ignore[return-value]
|
||||
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> Type[BaseModel]:
|
||||
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
|
||||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return [
|
||||
ModelField(field_info=field_info, name=name)
|
||||
for name, field_info in model.model_fields.items()
|
||||
]
|
||||
|
||||
|
||||
# Duplicate of several schema functions from Pydantic v1 to make them compatible with
|
||||
# Pydantic v2 and allow mixing the models
|
||||
|
||||
TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]]
|
||||
TypeModelSet = Set[TypeModelOrEnum]
|
||||
|
||||
|
||||
def normalize_name(name: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
|
||||
|
||||
|
||||
def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]:
|
||||
name_model_map = {}
|
||||
conflicting_names: Set[str] = set()
|
||||
for model in unique_models:
|
||||
model_name = normalize_name(model.__name__)
|
||||
if model_name in conflicting_names:
|
||||
model_name = get_long_model_name(model)
|
||||
name_model_map[model_name] = model
|
||||
elif model_name in name_model_map:
|
||||
conflicting_names.add(model_name)
|
||||
conflicting_model = name_model_map.pop(model_name)
|
||||
name_model_map[get_long_model_name(conflicting_model)] = conflicting_model
|
||||
name_model_map[get_long_model_name(model)] = model
|
||||
else:
|
||||
name_model_map[model_name] = model
|
||||
return {v: k for k, v in name_model_map.items()}
|
||||
|
||||
|
||||
def get_flat_models_from_model(
|
||||
model: Type["BaseModel"], known_models: Union[TypeModelSet, None] = None
|
||||
) -> TypeModelSet:
|
||||
known_models = known_models or set()
|
||||
fields = get_model_fields(model)
|
||||
get_flat_models_from_fields(fields, known_models=known_models)
|
||||
return known_models
|
||||
|
||||
|
||||
def get_flat_models_from_annotation(
|
||||
annotation: Any, known_models: TypeModelSet
|
||||
) -> TypeModelSet:
|
||||
origin = get_origin(annotation)
|
||||
if origin is not None:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, (BaseModel, Enum)) and arg not in known_models:
|
||||
known_models.add(arg)
|
||||
if lenient_issubclass(arg, BaseModel):
|
||||
get_flat_models_from_model(arg, known_models=known_models)
|
||||
else:
|
||||
get_flat_models_from_annotation(arg, known_models=known_models)
|
||||
return known_models
|
||||
|
||||
|
||||
def get_flat_models_from_field(
|
||||
field: ModelField, known_models: TypeModelSet
|
||||
) -> TypeModelSet:
|
||||
field_type = field.type_
|
||||
if lenient_issubclass(field_type, BaseModel):
|
||||
if field_type in known_models:
|
||||
return known_models
|
||||
known_models.add(field_type)
|
||||
get_flat_models_from_model(field_type, known_models=known_models)
|
||||
elif lenient_issubclass(field_type, Enum):
|
||||
known_models.add(field_type)
|
||||
else:
|
||||
get_flat_models_from_annotation(field_type, known_models=known_models)
|
||||
return known_models
|
||||
|
||||
|
||||
def get_flat_models_from_fields(
|
||||
fields: Sequence[ModelField], known_models: TypeModelSet
|
||||
) -> TypeModelSet:
|
||||
for field in fields:
|
||||
get_flat_models_from_field(field, known_models=known_models)
|
||||
return known_models
|
||||
|
||||
|
||||
def get_long_model_name(model: TypeModelOrEnum) -> str:
|
||||
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
|
||||
@@ -13,6 +13,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi import routing
|
||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
from fastapi.exception_handlers import (
|
||||
@@ -42,8 +43,8 @@ from starlette.middleware.exceptions import ExceptionMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse, Response
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
|
||||
from typing_extensions import Annotated, deprecated
|
||||
|
||||
AppType = TypeVar("AppType", bound="FastAPI")
|
||||
|
||||
@@ -75,7 +76,7 @@ class FastAPI(Starlette):
|
||||
errors.
|
||||
|
||||
Read more in the
|
||||
[Starlette docs for Applications](https://www.starlette.io/applications/#instantiating-the-application).
|
||||
[Starlette docs for Applications](https://www.starlette.dev/applications/#instantiating-the-application).
|
||||
"""
|
||||
),
|
||||
] = False,
|
||||
@@ -300,7 +301,7 @@ class FastAPI(Starlette):
|
||||
browser tabs open). Or if you want to leave fixed the possible URLs.
|
||||
|
||||
If the servers `list` is not provided, or is an empty `list`, the
|
||||
default value would be a a `dict` with a `url` value of `/`.
|
||||
default value would be a `dict` with a `url` value of `/`.
|
||||
|
||||
Each item in the `list` is a `dict` containing:
|
||||
|
||||
@@ -751,7 +752,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -813,6 +814,32 @@ class FastAPI(Starlette):
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
openapi_external_docs: Annotated[
|
||||
Optional[Dict[str, Any]],
|
||||
Doc(
|
||||
"""
|
||||
This field allows you to provide additional external documentation links.
|
||||
If provided, it must be a dictionary containing:
|
||||
|
||||
* `description`: A brief description of the external documentation.
|
||||
* `url`: The URL pointing to the external documentation. The value **MUST**
|
||||
be a valid URL format.
|
||||
|
||||
**Example**:
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
|
||||
external_docs = {
|
||||
"description": "Detailed API Reference",
|
||||
"url": "https://example.com/api-docs",
|
||||
}
|
||||
|
||||
app = FastAPI(openapi_external_docs=external_docs)
|
||||
```
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
**extra: Annotated[
|
||||
Any,
|
||||
Doc(
|
||||
@@ -841,6 +868,7 @@ class FastAPI(Starlette):
|
||||
self.swagger_ui_parameters = swagger_ui_parameters
|
||||
self.servers = servers or []
|
||||
self.separate_input_output_schemas = separate_input_output_schemas
|
||||
self.openapi_external_docs = openapi_external_docs
|
||||
self.extra = extra
|
||||
self.openapi_version: Annotated[
|
||||
str,
|
||||
@@ -905,13 +933,13 @@ class FastAPI(Starlette):
|
||||
A state object for the application. This is the same object for the
|
||||
entire application, it doesn't change from request to request.
|
||||
|
||||
You normally woudln't use this in FastAPI, for most of the cases you
|
||||
You normally wouldn't use this in FastAPI, for most of the cases you
|
||||
would instead use FastAPI dependencies.
|
||||
|
||||
This is simply inherited from Starlette.
|
||||
|
||||
Read more about it in the
|
||||
[Starlette docs for Applications](https://www.starlette.io/applications/#storing-state-on-the-app-instance).
|
||||
[Starlette docs for Applications](https://www.starlette.dev/applications/#storing-state-on-the-app-instance).
|
||||
"""
|
||||
),
|
||||
] = State()
|
||||
@@ -971,7 +999,7 @@ class FastAPI(Starlette):
|
||||
# inside of ExceptionMiddleware, inside of custom user middlewares
|
||||
debug = self.debug
|
||||
error_handler = None
|
||||
exception_handlers = {}
|
||||
exception_handlers: dict[Any, ExceptionHandler] = {}
|
||||
|
||||
for key, value in self.exception_handlers.items():
|
||||
if key in (500, Exception):
|
||||
@@ -986,33 +1014,32 @@ class FastAPI(Starlette):
|
||||
Middleware(
|
||||
ExceptionMiddleware, handlers=exception_handlers, debug=debug
|
||||
),
|
||||
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
|
||||
# contextvars.
|
||||
# Add FastAPI-specific AsyncExitStackMiddleware for closing files.
|
||||
# Before this was also used for closing dependencies with yield but
|
||||
# those now have their own AsyncExitStack, to properly support
|
||||
# streaming responses while keeping compatibility with the previous
|
||||
# versions (as of writing 0.117.1) that allowed doing
|
||||
# except HTTPException inside a dependency with yield.
|
||||
# This needs to happen after user middlewares because those create a
|
||||
# new contextvars context copy by using a new AnyIO task group.
|
||||
# The initial part of dependencies with 'yield' is executed in the
|
||||
# FastAPI code, inside all the middlewares. However, the teardown part
|
||||
# (after 'yield') is executed in the AsyncExitStack in this middleware.
|
||||
# This AsyncExitStack preserves the context for contextvars, not
|
||||
# strictly necessary for closing files but it was one of the original
|
||||
# intentions.
|
||||
# If the AsyncExitStack lived outside of the custom middlewares and
|
||||
# contextvars were set in a dependency with 'yield' in that internal
|
||||
# contextvars context, the values would not be available in the
|
||||
# outer context of the AsyncExitStack.
|
||||
# contextvars were set, for example in a dependency with 'yield'
|
||||
# in that internal contextvars context, the values would not be
|
||||
# available in the outer context of the AsyncExitStack.
|
||||
# By placing the middleware and the AsyncExitStack here, inside all
|
||||
# user middlewares, the code before and after 'yield' in dependencies
|
||||
# with 'yield' is executed in the same contextvars context. Thus, all values
|
||||
# set in contextvars before 'yield' are still available after 'yield,' as
|
||||
# expected.
|
||||
# Additionally, by having this AsyncExitStack here, after the
|
||||
# ExceptionMiddleware, dependencies can now catch handled exceptions,
|
||||
# e.g. HTTPException, to customize the teardown code (e.g. DB session
|
||||
# rollback).
|
||||
# user middlewares, the same context is used.
|
||||
# This is currently not needed, only for closing files, but used to be
|
||||
# important when dependencies with yield were closed here.
|
||||
Middleware(AsyncExitStackMiddleware),
|
||||
]
|
||||
)
|
||||
|
||||
app = self.router
|
||||
for cls, options in reversed(middleware):
|
||||
app = cls(app=app, **options)
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
app = cls(app, *args, **kwargs)
|
||||
return app
|
||||
|
||||
def openapi(self) -> Dict[str, Any]:
|
||||
@@ -1044,6 +1071,7 @@ class FastAPI(Starlette):
|
||||
tags=self.openapi_tags,
|
||||
servers=self.servers,
|
||||
separate_input_output_schemas=self.separate_input_output_schemas,
|
||||
external_docs=self.openapi_external_docs,
|
||||
)
|
||||
return self.openapi_schema
|
||||
|
||||
@@ -1071,7 +1099,7 @@ class FastAPI(Starlette):
|
||||
oauth2_redirect_url = root_path + oauth2_redirect_url
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=openapi_url,
|
||||
title=self.title + " - Swagger UI",
|
||||
title=f"{self.title} - Swagger UI",
|
||||
oauth2_redirect_url=oauth2_redirect_url,
|
||||
init_oauth=self.swagger_ui_init_oauth,
|
||||
swagger_ui_parameters=self.swagger_ui_parameters,
|
||||
@@ -1095,7 +1123,7 @@ class FastAPI(Starlette):
|
||||
root_path = req.scope.get("root_path", "").rstrip("/")
|
||||
openapi_url = root_path + self.openapi_url
|
||||
return get_redoc_html(
|
||||
openapi_url=openapi_url, title=self.title + " - ReDoc"
|
||||
openapi_url=openapi_url, title=f"{self.title} - ReDoc"
|
||||
)
|
||||
|
||||
self.add_route(self.redoc_url, redoc_html, include_in_schema=False)
|
||||
@@ -1108,7 +1136,7 @@ class FastAPI(Starlette):
|
||||
def add_api_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable[..., Coroutine[Any, Any, Response]],
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
response_model: Any = Default(None),
|
||||
status_code: Optional[int] = None,
|
||||
@@ -1772,7 +1800,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2145,7 +2173,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2523,7 +2551,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2901,7 +2929,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3274,7 +3302,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3647,7 +3675,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4020,7 +4048,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4398,7 +4426,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4477,7 +4505,7 @@ class FastAPI(Starlette):
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.put("/items/{item_id}")
|
||||
@app.trace("/items/{item_id}")
|
||||
def trace_item(item_id: str):
|
||||
return None
|
||||
```
|
||||
@@ -4567,14 +4595,17 @@ class FastAPI(Starlette):
|
||||
|
||||
```python
|
||||
import time
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI, Request, Response
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def add_process_time_header(request: Request, call_next):
|
||||
async def add_process_time_header(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
start_time = time.time()
|
||||
response = await call_next(request)
|
||||
process_time = time.time() - start_time
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from annotated_doc import Doc
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from typing_extensions import Annotated, Doc, ParamSpec # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
13
Backend/venv/lib/python3.12/site-packages/fastapi/cli.py
Normal file
13
Backend/venv/lib/python3.12/site-packages/fastapi/cli.py
Normal file
@@ -0,0 +1,13 @@
|
||||
try:
|
||||
from fastapi_cli.cli import main as cli_main
|
||||
|
||||
except ImportError: # pragma: no cover
|
||||
cli_main = None # type: ignore
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not cli_main: # type: ignore[truthy-function]
|
||||
message = 'To use the fastapi command, please install "fastapi[standard]":\n\n\tpip install "fastapi[standard]"\n'
|
||||
print(message)
|
||||
raise RuntimeError(message) # noqa: B904
|
||||
cli_main()
|
||||
@@ -1,8 +1,7 @@
|
||||
from contextlib import AsyncExitStack as AsyncExitStack # noqa
|
||||
from contextlib import asynccontextmanager as asynccontextmanager
|
||||
from typing import AsyncGenerator, ContextManager, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
from anyio import CapacityLimiter
|
||||
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
|
||||
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
|
||||
@@ -29,7 +28,7 @@ async def contextmanager_in_threadpool(
|
||||
except Exception as e:
|
||||
ok = bool(
|
||||
await anyio.to_thread.run_sync(
|
||||
cm.__exit__, type(e), e, None, limiter=exit_limiter
|
||||
cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
|
||||
)
|
||||
)
|
||||
if not ok:
|
||||
|
||||
@@ -10,12 +10,11 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi._compat import (
|
||||
PYDANTIC_V2,
|
||||
CoreSchema,
|
||||
GetJsonSchemaHandler,
|
||||
JsonSchemaValue,
|
||||
with_info_plain_validator_function,
|
||||
)
|
||||
from starlette.datastructures import URL as URL # noqa: F401
|
||||
from starlette.datastructures import Address as Address # noqa: F401
|
||||
@@ -24,7 +23,7 @@ from starlette.datastructures import Headers as Headers # noqa: F401
|
||||
from starlette.datastructures import QueryParams as QueryParams # noqa: F401
|
||||
from starlette.datastructures import State as State # noqa: F401
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class UploadFile(StarletteUploadFile):
|
||||
@@ -154,11 +153,10 @@ class UploadFile(StarletteUploadFile):
|
||||
raise ValueError(f"Expected UploadFile, received: {type(__input_value)}")
|
||||
return cast(UploadFile, __input_value)
|
||||
|
||||
if not PYDANTIC_V2:
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update({"type": "string", "format": "binary"})
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update({"type": "string", "format": "binary"})
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
@@ -170,6 +168,8 @@ class UploadFile(StarletteUploadFile):
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
|
||||
) -> CoreSchema:
|
||||
from ._compat.v2 import with_info_plain_validator_function
|
||||
|
||||
return with_info_plain_validator_function(cls._validate)
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,58 +1,107 @@
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
import inspect
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, List, Optional, Sequence, Union
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.types import DependencyCacheKey
|
||||
from typing_extensions import Literal
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
from inspect import iscoroutinefunction
|
||||
else: # pragma: no cover
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityRequirement:
|
||||
def __init__(
|
||||
self, security_scheme: SecurityBase, scopes: Optional[Sequence[str]] = None
|
||||
):
|
||||
self.security_scheme = security_scheme
|
||||
self.scopes = scopes
|
||||
security_scheme: SecurityBase
|
||||
scopes: Optional[Sequence[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dependant:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
path_params: Optional[List[ModelField]] = None,
|
||||
query_params: Optional[List[ModelField]] = None,
|
||||
header_params: Optional[List[ModelField]] = None,
|
||||
cookie_params: Optional[List[ModelField]] = None,
|
||||
body_params: Optional[List[ModelField]] = None,
|
||||
dependencies: Optional[List["Dependant"]] = None,
|
||||
security_schemes: Optional[List[SecurityRequirement]] = None,
|
||||
name: Optional[str] = None,
|
||||
call: Optional[Callable[..., Any]] = None,
|
||||
request_param_name: Optional[str] = None,
|
||||
websocket_param_name: Optional[str] = None,
|
||||
http_connection_param_name: Optional[str] = None,
|
||||
response_param_name: Optional[str] = None,
|
||||
background_tasks_param_name: Optional[str] = None,
|
||||
security_scopes_param_name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
use_cache: bool = True,
|
||||
path: Optional[str] = None,
|
||||
) -> None:
|
||||
self.path_params = path_params or []
|
||||
self.query_params = query_params or []
|
||||
self.header_params = header_params or []
|
||||
self.cookie_params = cookie_params or []
|
||||
self.body_params = body_params or []
|
||||
self.dependencies = dependencies or []
|
||||
self.security_requirements = security_schemes or []
|
||||
self.request_param_name = request_param_name
|
||||
self.websocket_param_name = websocket_param_name
|
||||
self.http_connection_param_name = http_connection_param_name
|
||||
self.response_param_name = response_param_name
|
||||
self.background_tasks_param_name = background_tasks_param_name
|
||||
self.security_scopes = security_scopes
|
||||
self.security_scopes_param_name = security_scopes_param_name
|
||||
self.name = name
|
||||
self.call = call
|
||||
self.use_cache = use_cache
|
||||
# Store the path to be able to re-generate a dependable from it in overrides
|
||||
self.path = path
|
||||
# Save the cache key at creation to optimize performance
|
||||
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
|
||||
path_params: List[ModelField] = field(default_factory=list)
|
||||
query_params: List[ModelField] = field(default_factory=list)
|
||||
header_params: List[ModelField] = field(default_factory=list)
|
||||
cookie_params: List[ModelField] = field(default_factory=list)
|
||||
body_params: List[ModelField] = field(default_factory=list)
|
||||
dependencies: List["Dependant"] = field(default_factory=list)
|
||||
security_requirements: List[SecurityRequirement] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
call: Optional[Callable[..., Any]] = None
|
||||
request_param_name: Optional[str] = None
|
||||
websocket_param_name: Optional[str] = None
|
||||
http_connection_param_name: Optional[str] = None
|
||||
response_param_name: Optional[str] = None
|
||||
background_tasks_param_name: Optional[str] = None
|
||||
security_scopes_param_name: Optional[str] = None
|
||||
own_oauth_scopes: Optional[List[str]] = None
|
||||
parent_oauth_scopes: Optional[List[str]] = None
|
||||
use_cache: bool = True
|
||||
path: Optional[str] = None
|
||||
scope: Union[Literal["function", "request"], None] = None
|
||||
|
||||
@cached_property
|
||||
def oauth_scopes(self) -> List[str]:
|
||||
scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
|
||||
# This doesn't use a set to preserve order, just in case
|
||||
for scope in self.own_oauth_scopes or []:
|
||||
if scope not in scopes:
|
||||
scopes.append(scope)
|
||||
return scopes
|
||||
|
||||
@cached_property
|
||||
def cache_key(self) -> DependencyCacheKey:
|
||||
scopes_for_cache = (
|
||||
tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
|
||||
)
|
||||
return (
|
||||
self.call,
|
||||
scopes_for_cache,
|
||||
self.computed_scope or "",
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _uses_scopes(self) -> bool:
|
||||
if self.own_oauth_scopes:
|
||||
return True
|
||||
if self.security_scopes_param_name is not None:
|
||||
return True
|
||||
for sub_dep in self.dependencies:
|
||||
if sub_dep._uses_scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def is_gen_callable(self) -> bool:
|
||||
if inspect.isgeneratorfunction(self.call):
|
||||
return True
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return inspect.isgeneratorfunction(dunder_call)
|
||||
|
||||
@cached_property
|
||||
def is_async_gen_callable(self) -> bool:
|
||||
if inspect.isasyncgenfunction(self.call):
|
||||
return True
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return inspect.isasyncgenfunction(dunder_call)
|
||||
|
||||
@cached_property
|
||||
def is_coroutine_callable(self) -> bool:
|
||||
if inspect.isroutine(self.call):
|
||||
return iscoroutinefunction(self.call)
|
||||
if inspect.isclass(self.call):
|
||||
return False
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return iscoroutinefunction(dunder_call)
|
||||
|
||||
@cached_property
|
||||
def computed_scope(self) -> Union[str, None]:
|
||||
if self.scope:
|
||||
return self.scope
|
||||
if self.is_gen_callable or self.is_async_gen_callable:
|
||||
return "request"
|
||||
return None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,14 +17,16 @@ from types import GeneratorType
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi._compat import may_v1
|
||||
from fastapi.types import IncEx
|
||||
from pydantic import BaseModel
|
||||
from pydantic.color import Color
|
||||
from pydantic.networks import AnyUrl, NameEmail
|
||||
from pydantic.types import SecretBytes, SecretStr
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from ._compat import PYDANTIC_V2, Url, _model_dump
|
||||
from ._compat import Url, _is_undefined, _model_dump
|
||||
|
||||
|
||||
# Taken from Pydantic v1 as is
|
||||
@@ -58,6 +60,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
|
||||
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
bytes: lambda o: o.decode(),
|
||||
Color: str,
|
||||
may_v1.Color: str,
|
||||
datetime.date: isoformat,
|
||||
datetime.datetime: isoformat,
|
||||
datetime.time: isoformat,
|
||||
@@ -74,19 +77,24 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
IPv6Interface: str,
|
||||
IPv6Network: str,
|
||||
NameEmail: str,
|
||||
may_v1.NameEmail: str,
|
||||
Path: str,
|
||||
Pattern: lambda o: o.pattern,
|
||||
SecretBytes: str,
|
||||
may_v1.SecretBytes: str,
|
||||
SecretStr: str,
|
||||
may_v1.SecretStr: str,
|
||||
set: list,
|
||||
UUID: str,
|
||||
Url: str,
|
||||
may_v1.Url: str,
|
||||
AnyUrl: str,
|
||||
may_v1.AnyUrl: str,
|
||||
}
|
||||
|
||||
|
||||
def generate_encoders_by_class_tuples(
|
||||
type_encoder_map: Dict[Any, Callable[[Any], Any]]
|
||||
type_encoder_map: Dict[Any, Callable[[Any], Any]],
|
||||
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
|
||||
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
|
||||
tuple
|
||||
@@ -213,13 +221,13 @@ def jsonable_encoder(
|
||||
include = set(include)
|
||||
if exclude is not None and not isinstance(exclude, (set, dict)):
|
||||
exclude = set(exclude)
|
||||
if isinstance(obj, BaseModel):
|
||||
if isinstance(obj, (BaseModel, may_v1.BaseModel)):
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
encoders: Dict[Any, Any] = {}
|
||||
if not PYDANTIC_V2:
|
||||
if isinstance(obj, may_v1.BaseModel):
|
||||
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
|
||||
if custom_encoder:
|
||||
encoders.update(custom_encoder)
|
||||
encoders = {**encoders, **custom_encoder}
|
||||
obj_dict = _model_dump(
|
||||
obj,
|
||||
mode="json",
|
||||
@@ -241,6 +249,7 @@ def jsonable_encoder(
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if dataclasses.is_dataclass(obj):
|
||||
assert not isinstance(obj, type)
|
||||
obj_dict = dataclasses.asdict(obj)
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
@@ -259,6 +268,8 @@ def jsonable_encoder(
|
||||
return str(obj)
|
||||
if isinstance(obj, (str, int, float, type(None))):
|
||||
return obj
|
||||
if _is_undefined(obj):
|
||||
return None
|
||||
if isinstance(obj, dict):
|
||||
encoded_dict = {}
|
||||
allowed_keys = set(obj.keys())
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi.websockets import WebSocket
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
|
||||
from starlette.status import WS_1008_POLICY_VIOLATION
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
||||
@@ -21,7 +21,7 @@ async def request_validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code=422,
|
||||
content={"detail": jsonable_encoder(exc.errors())},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import Any, Dict, Optional, Sequence, Type, Union
|
||||
|
||||
from annotated_doc import Doc
|
||||
from pydantic import BaseModel, create_model
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.exceptions import WebSocketException as StarletteWebSocketException
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class HTTPException(StarletteHTTPException):
|
||||
@@ -146,6 +147,13 @@ class FastAPIError(RuntimeError):
|
||||
"""
|
||||
|
||||
|
||||
class DependencyScopeError(FastAPIError):
|
||||
"""
|
||||
A dependency declared that it depends on another dependency with an invalid
|
||||
(narrower) scope.
|
||||
"""
|
||||
|
||||
|
||||
class ValidationException(Exception):
|
||||
def __init__(self, errors: Sequence[Any]) -> None:
|
||||
self._errors = errors
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,25 +1,18 @@
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from fastapi.concurrency import AsyncExitStack
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
# Used mainly to close files after the request is done, dependencies are closed
|
||||
# in their own AsyncExitStack
|
||||
class AsyncExitStackMiddleware:
|
||||
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
|
||||
def __init__(
|
||||
self, app: ASGIApp, context_name: str = "fastapi_middleware_astack"
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.context_name = context_name
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
dependency_exception: Optional[Exception] = None
|
||||
async with AsyncExitStack() as stack:
|
||||
scope[self.context_name] = stack
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as e:
|
||||
dependency_exception = e
|
||||
raise e
|
||||
if dependency_exception:
|
||||
# This exception was possibly handled by the dependency but it should
|
||||
# still bubble up so that the ServerErrorMiddleware can return a 500
|
||||
# or the ExceptionMiddleware can catch and handle any other exceptions
|
||||
raise dependency_exception
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from starlette.responses import HTMLResponse
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated
|
||||
|
||||
swagger_ui_default_parameters: Annotated[
|
||||
Dict[str, Any],
|
||||
@@ -53,7 +54,7 @@ def get_swagger_ui_html(
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
|
||||
swagger_css_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
@@ -63,7 +64,7 @@ def get_swagger_ui_html(
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
|
||||
swagger_favicon_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
@@ -188,7 +189,7 @@ def get_redoc_html(
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
|
||||
] = "https://cdn.jsdelivr.net/npm/redoc@2/bundles/redoc.standalone.js",
|
||||
redoc_favicon_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
|
||||
@@ -55,35 +55,29 @@ except ImportError: # pragma: no cover
|
||||
return with_info_plain_validator_function(cls._validate)
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Contact(BaseModelWithConfig):
|
||||
name: Optional[str] = None
|
||||
url: Optional[AnyUrl] = None
|
||||
email: Optional[EmailStr] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class License(BaseModel):
|
||||
class License(BaseModelWithConfig):
|
||||
name: str
|
||||
identifier: Optional[str] = None
|
||||
url: Optional[AnyUrl] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Info(BaseModel):
|
||||
class Info(BaseModelWithConfig):
|
||||
title: str
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -92,42 +86,18 @@ class Info(BaseModel):
|
||||
license: Optional[License] = None
|
||||
version: str
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ServerVariable(BaseModel):
|
||||
class ServerVariable(BaseModelWithConfig):
|
||||
enum: Annotated[Optional[List[str]], Field(min_length=1)] = None
|
||||
default: str
|
||||
description: Optional[str] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Server(BaseModel):
|
||||
class Server(BaseModelWithConfig):
|
||||
url: Union[AnyUrl, str]
|
||||
description: Optional[str] = None
|
||||
variables: Optional[Dict[str, ServerVariable]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Reference(BaseModel):
|
||||
ref: str = Field(alias="$ref")
|
||||
@@ -138,36 +108,26 @@ class Discriminator(BaseModel):
|
||||
mapping: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
class XML(BaseModel):
|
||||
class XML(BaseModelWithConfig):
|
||||
name: Optional[str] = None
|
||||
namespace: Optional[str] = None
|
||||
prefix: Optional[str] = None
|
||||
attribute: Optional[bool] = None
|
||||
wrapped: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ExternalDocumentation(BaseModel):
|
||||
class ExternalDocumentation(BaseModelWithConfig):
|
||||
description: Optional[str] = None
|
||||
url: AnyUrl
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
# Ref JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation#name-type
|
||||
SchemaType = Literal[
|
||||
"array", "boolean", "integer", "null", "number", "object", "string"
|
||||
]
|
||||
|
||||
|
||||
class Schema(BaseModel):
|
||||
class Schema(BaseModelWithConfig):
|
||||
# Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu
|
||||
# Core Vocabulary
|
||||
schema_: Optional[str] = Field(default=None, alias="$schema")
|
||||
@@ -191,7 +151,7 @@ class Schema(BaseModel):
|
||||
dependentSchemas: Optional[Dict[str, "SchemaOrBool"]] = None
|
||||
prefixItems: Optional[List["SchemaOrBool"]] = None
|
||||
# TODO: uncomment and remove below when deprecating Pydantic v1
|
||||
# It generales a list of schemas for tuples, before prefixItems was available
|
||||
# It generates a list of schemas for tuples, before prefixItems was available
|
||||
# items: Optional["SchemaOrBool"] = None
|
||||
items: Optional[Union["SchemaOrBool", List["SchemaOrBool"]]] = None
|
||||
contains: Optional["SchemaOrBool"] = None
|
||||
@@ -203,7 +163,7 @@ class Schema(BaseModel):
|
||||
unevaluatedProperties: Optional["SchemaOrBool"] = None
|
||||
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural
|
||||
# A Vocabulary for Structural Validation
|
||||
type: Optional[str] = None
|
||||
type: Optional[Union[SchemaType, List[SchemaType]]] = None
|
||||
enum: Optional[List[Any]] = None
|
||||
const: Optional[Any] = None
|
||||
multipleOf: Optional[float] = Field(default=None, gt=0)
|
||||
@@ -253,14 +213,6 @@ class Schema(BaseModel):
|
||||
),
|
||||
] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
# Ref: https://json-schema.org/draft/2020-12/json-schema-core.html#name-json-schema-documents
|
||||
# A JSON Schema MUST be an object or a boolean.
|
||||
@@ -289,38 +241,22 @@ class ParameterInType(Enum):
|
||||
cookie = "cookie"
|
||||
|
||||
|
||||
class Encoding(BaseModel):
|
||||
class Encoding(BaseModelWithConfig):
|
||||
contentType: Optional[str] = None
|
||||
headers: Optional[Dict[str, Union["Header", Reference]]] = None
|
||||
style: Optional[str] = None
|
||||
explode: Optional[bool] = None
|
||||
allowReserved: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class MediaType(BaseModel):
|
||||
class MediaType(BaseModelWithConfig):
|
||||
schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema")
|
||||
example: Optional[Any] = None
|
||||
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
||||
encoding: Optional[Dict[str, Encoding]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ParameterBase(BaseModel):
|
||||
class ParameterBase(BaseModelWithConfig):
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
deprecated: Optional[bool] = None
|
||||
@@ -334,14 +270,6 @@ class ParameterBase(BaseModel):
|
||||
# Serialization rules for more complex scenarios
|
||||
content: Optional[Dict[str, MediaType]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Parameter(ParameterBase):
|
||||
name: str
|
||||
@@ -352,21 +280,13 @@ class Header(ParameterBase):
|
||||
pass
|
||||
|
||||
|
||||
class RequestBody(BaseModel):
|
||||
class RequestBody(BaseModelWithConfig):
|
||||
description: Optional[str] = None
|
||||
content: Dict[str, MediaType]
|
||||
required: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Link(BaseModel):
|
||||
class Link(BaseModelWithConfig):
|
||||
operationRef: Optional[str] = None
|
||||
operationId: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Union[Any, str]]] = None
|
||||
@@ -374,31 +294,15 @@ class Link(BaseModel):
|
||||
description: Optional[str] = None
|
||||
server: Optional[Server] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
class Response(BaseModelWithConfig):
|
||||
description: str
|
||||
headers: Optional[Dict[str, Union[Header, Reference]]] = None
|
||||
content: Optional[Dict[str, MediaType]] = None
|
||||
links: Optional[Dict[str, Union[Link, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Operation(BaseModel):
|
||||
class Operation(BaseModelWithConfig):
|
||||
tags: Optional[List[str]] = None
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -413,16 +317,8 @@ class Operation(BaseModel):
|
||||
security: Optional[List[Dict[str, List[str]]]] = None
|
||||
servers: Optional[List[Server]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class PathItem(BaseModel):
|
||||
class PathItem(BaseModelWithConfig):
|
||||
ref: Optional[str] = Field(default=None, alias="$ref")
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -437,14 +333,6 @@ class PathItem(BaseModel):
|
||||
servers: Optional[List[Server]] = None
|
||||
parameters: Optional[List[Union[Parameter, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class SecuritySchemeType(Enum):
|
||||
apiKey = "apiKey"
|
||||
@@ -453,18 +341,10 @@ class SecuritySchemeType(Enum):
|
||||
openIdConnect = "openIdConnect"
|
||||
|
||||
|
||||
class SecurityBase(BaseModel):
|
||||
class SecurityBase(BaseModelWithConfig):
|
||||
type_: SecuritySchemeType = Field(alias="type")
|
||||
description: Optional[str] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class APIKeyIn(Enum):
|
||||
query = "query"
|
||||
@@ -488,18 +368,10 @@ class HTTPBearer(HTTPBase):
|
||||
bearerFormat: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthFlow(BaseModel):
|
||||
class OAuthFlow(BaseModelWithConfig):
|
||||
refreshUrl: Optional[str] = None
|
||||
scopes: Dict[str, str] = {}
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OAuthFlowImplicit(OAuthFlow):
|
||||
authorizationUrl: str
|
||||
@@ -518,20 +390,12 @@ class OAuthFlowAuthorizationCode(OAuthFlow):
|
||||
tokenUrl: str
|
||||
|
||||
|
||||
class OAuthFlows(BaseModel):
|
||||
class OAuthFlows(BaseModelWithConfig):
|
||||
implicit: Optional[OAuthFlowImplicit] = None
|
||||
password: Optional[OAuthFlowPassword] = None
|
||||
clientCredentials: Optional[OAuthFlowClientCredentials] = None
|
||||
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OAuth2(SecurityBase):
|
||||
type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type")
|
||||
@@ -548,7 +412,7 @@ class OpenIdConnect(SecurityBase):
|
||||
SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer]
|
||||
|
||||
|
||||
class Components(BaseModel):
|
||||
class Components(BaseModelWithConfig):
|
||||
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
|
||||
responses: Optional[Dict[str, Union[Response, Reference]]] = None
|
||||
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
|
||||
@@ -561,30 +425,14 @@ class Components(BaseModel):
|
||||
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None
|
||||
pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Tag(BaseModel):
|
||||
class Tag(BaseModelWithConfig):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
externalDocs: Optional[ExternalDocumentation] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OpenAPI(BaseModel):
|
||||
class OpenAPI(BaseModelWithConfig):
|
||||
openapi: str
|
||||
info: Info
|
||||
jsonSchemaDialect: Optional[str] = None
|
||||
@@ -597,14 +445,6 @@ class OpenAPI(BaseModel):
|
||||
tags: Optional[List[Tag]] = None
|
||||
externalDocs: Optional[ExternalDocumentation] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
_model_rebuild(Schema)
|
||||
_model_rebuild(Operation)
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union,
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi._compat import (
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaValue,
|
||||
ModelField,
|
||||
Undefined,
|
||||
@@ -16,11 +15,15 @@ from fastapi._compat import (
|
||||
)
|
||||
from fastapi.datastructures import DefaultPlaceholder
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
|
||||
from fastapi.dependencies.utils import (
|
||||
_get_flat_fields_from_params,
|
||||
get_flat_dependant,
|
||||
get_flat_params,
|
||||
)
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
|
||||
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
|
||||
from fastapi.openapi.models import OpenAPI
|
||||
from fastapi.params import Body, Param
|
||||
from fastapi.params import Body, ParamTypes
|
||||
from fastapi.responses import Response
|
||||
from fastapi.types import ModelNameMap
|
||||
from fastapi.utils import (
|
||||
@@ -28,11 +31,13 @@ from fastapi.utils import (
|
||||
generate_operation_id_for_path,
|
||||
is_body_allowed_for_status_code,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .._compat import _is_model_field
|
||||
|
||||
validation_error_definition = {
|
||||
"title": "ValidationError",
|
||||
"type": "object",
|
||||
@@ -87,10 +92,9 @@ def get_openapi_security_definitions(
|
||||
return security_definitions, operation_security
|
||||
|
||||
|
||||
def get_openapi_operation_parameters(
|
||||
def _get_openapi_operation_parameters(
|
||||
*,
|
||||
all_route_params: Sequence[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
dependant: Dependant,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
@@ -98,40 +102,72 @@ def get_openapi_operation_parameters(
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
parameters = []
|
||||
for param in all_route_params:
|
||||
field_info = param.field_info
|
||||
field_info = cast(Param, field_info)
|
||||
if not field_info.include_in_schema:
|
||||
continue
|
||||
param_schema = get_schema_from_model_field(
|
||||
field=param,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
"in": field_info.in_.value,
|
||||
"required": param.required,
|
||||
"schema": param_schema,
|
||||
}
|
||||
if field_info.description:
|
||||
parameter["description"] = field_info.description
|
||||
if field_info.openapi_examples:
|
||||
parameter["examples"] = jsonable_encoder(field_info.openapi_examples)
|
||||
elif field_info.example != Undefined:
|
||||
parameter["example"] = jsonable_encoder(field_info.example)
|
||||
if field_info.deprecated:
|
||||
parameter["deprecated"] = field_info.deprecated
|
||||
parameters.append(parameter)
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
|
||||
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
|
||||
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
|
||||
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
|
||||
parameter_groups = [
|
||||
(ParamTypes.path, path_params),
|
||||
(ParamTypes.query, query_params),
|
||||
(ParamTypes.header, header_params),
|
||||
(ParamTypes.cookie, cookie_params),
|
||||
]
|
||||
default_convert_underscores = True
|
||||
if len(flat_dependant.header_params) == 1:
|
||||
first_field = flat_dependant.header_params[0]
|
||||
if lenient_issubclass(first_field.type_, BaseModel):
|
||||
default_convert_underscores = getattr(
|
||||
first_field.field_info, "convert_underscores", True
|
||||
)
|
||||
for param_type, param_group in parameter_groups:
|
||||
for param in param_group:
|
||||
field_info = param.field_info
|
||||
# field_info = cast(Param, field_info)
|
||||
if not getattr(field_info, "include_in_schema", True):
|
||||
continue
|
||||
param_schema = get_schema_from_model_field(
|
||||
field=param,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
name = param.alias
|
||||
convert_underscores = getattr(
|
||||
param.field_info,
|
||||
"convert_underscores",
|
||||
default_convert_underscores,
|
||||
)
|
||||
if (
|
||||
param_type == ParamTypes.header
|
||||
and param.alias == param.name
|
||||
and convert_underscores
|
||||
):
|
||||
name = param.name.replace("_", "-")
|
||||
|
||||
parameter = {
|
||||
"name": name,
|
||||
"in": param_type.value,
|
||||
"required": param.required,
|
||||
"schema": param_schema,
|
||||
}
|
||||
if field_info.description:
|
||||
parameter["description"] = field_info.description
|
||||
openapi_examples = getattr(field_info, "openapi_examples", None)
|
||||
example = getattr(field_info, "example", None)
|
||||
if openapi_examples:
|
||||
parameter["examples"] = jsonable_encoder(openapi_examples)
|
||||
elif example != Undefined:
|
||||
parameter["example"] = jsonable_encoder(example)
|
||||
if getattr(field_info, "deprecated", None):
|
||||
parameter["deprecated"] = True
|
||||
parameters.append(parameter)
|
||||
return parameters
|
||||
|
||||
|
||||
def get_openapi_operation_request_body(
|
||||
*,
|
||||
body_field: Optional[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
@@ -140,10 +176,9 @@ def get_openapi_operation_request_body(
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if not body_field:
|
||||
return None
|
||||
assert isinstance(body_field, ModelField)
|
||||
assert _is_model_field(body_field)
|
||||
body_schema = get_schema_from_model_field(
|
||||
field=body_field,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
@@ -216,7 +251,6 @@ def get_openapi_path(
|
||||
*,
|
||||
route: routing.APIRoute,
|
||||
operation_ids: Set[str],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
@@ -247,10 +281,8 @@ def get_openapi_path(
|
||||
operation.setdefault("security", []).extend(operation_security)
|
||||
if security_definitions:
|
||||
security_schemes.update(security_definitions)
|
||||
all_route_params = get_flat_params(route.dependant)
|
||||
operation_parameters = get_openapi_operation_parameters(
|
||||
all_route_params=all_route_params,
|
||||
schema_generator=schema_generator,
|
||||
operation_parameters = _get_openapi_operation_parameters(
|
||||
dependant=route.dependant,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
@@ -272,7 +304,6 @@ def get_openapi_path(
|
||||
if method in METHODS_WITH_BODY:
|
||||
request_body_oai = get_openapi_operation_request_body(
|
||||
body_field=route.body_field,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
@@ -290,7 +321,6 @@ def get_openapi_path(
|
||||
) = get_openapi_path(
|
||||
route=callback,
|
||||
operation_ids=operation_ids,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
@@ -321,7 +351,6 @@ def get_openapi_path(
|
||||
if route.response_field:
|
||||
response_schema = get_schema_from_model_field(
|
||||
field=route.response_field,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
@@ -347,15 +376,14 @@ def get_openapi_path(
|
||||
openapi_response = operation_responses.setdefault(
|
||||
status_code_key, {}
|
||||
)
|
||||
assert isinstance(
|
||||
process_response, dict
|
||||
), "An additional response must be a dict"
|
||||
assert isinstance(process_response, dict), (
|
||||
"An additional response must be a dict"
|
||||
)
|
||||
field = route.response_fields.get(additional_status_code)
|
||||
additional_field_schema: Optional[Dict[str, Any]] = None
|
||||
if field:
|
||||
additional_field_schema = get_schema_from_model_field(
|
||||
field=field,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
@@ -378,7 +406,8 @@ def get_openapi_path(
|
||||
)
|
||||
deep_dict_update(openapi_response, process_response)
|
||||
openapi_response["description"] = description
|
||||
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
http422 = "422"
|
||||
all_route_params = get_flat_params(route.dependant)
|
||||
if (all_route_params or route.body_field) and not any(
|
||||
status in operation["responses"]
|
||||
for status in [http422, "4XX", "default"]
|
||||
@@ -416,9 +445,9 @@ def get_fields_from_routes(
|
||||
route, routing.APIRoute
|
||||
):
|
||||
if route.body_field:
|
||||
assert isinstance(
|
||||
route.body_field, ModelField
|
||||
), "A request body must be a Pydantic Field"
|
||||
assert _is_model_field(route.body_field), (
|
||||
"A request body must be a Pydantic Field"
|
||||
)
|
||||
body_fields_from_routes.append(route.body_field)
|
||||
if route.response_field:
|
||||
responses_from_routes.append(route.response_field)
|
||||
@@ -450,6 +479,7 @@ def get_openapi(
|
||||
contact: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
license_info: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
separate_input_output_schemas: bool = True,
|
||||
external_docs: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
info: Dict[str, Any] = {"title": title, "version": version}
|
||||
if summary:
|
||||
@@ -471,10 +501,8 @@ def get_openapi(
|
||||
operation_ids: Set[str] = set()
|
||||
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
|
||||
model_name_map = get_compat_model_name_map(all_fields)
|
||||
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
|
||||
field_mapping, definitions = get_definitions(
|
||||
fields=all_fields,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
@@ -483,7 +511,6 @@ def get_openapi(
|
||||
result = get_openapi_path(
|
||||
route=route,
|
||||
operation_ids=operation_ids,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
@@ -503,7 +530,6 @@ def get_openapi(
|
||||
result = get_openapi_path(
|
||||
route=webhook,
|
||||
operation_ids=operation_ids,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
@@ -527,4 +553,6 @@ def get_openapi(
|
||||
output["webhooks"] = webhook_paths
|
||||
if tags:
|
||||
output["tags"] = tags
|
||||
if external_docs:
|
||||
output["externalDocs"] = external_docs
|
||||
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi import params
|
||||
from fastapi._compat import Undefined
|
||||
from fastapi.openapi.models import Example
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Literal, deprecated
|
||||
|
||||
_Unset: Any = Undefined
|
||||
|
||||
@@ -240,7 +241,7 @@ def Path( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -565,7 +566,7 @@ def Query( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -880,7 +881,7 @@ def Header( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1185,7 +1186,7 @@ def Cookie( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1282,7 +1283,7 @@ def Body( # noqa: N802
|
||||
),
|
||||
] = _Unset,
|
||||
embed: Annotated[
|
||||
bool,
|
||||
Union[bool, None],
|
||||
Doc(
|
||||
"""
|
||||
When `embed` is `True`, the parameter will be expected in a JSON body as a
|
||||
@@ -1294,7 +1295,7 @@ def Body( # noqa: N802
|
||||
[FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter).
|
||||
"""
|
||||
),
|
||||
] = False,
|
||||
] = None,
|
||||
media_type: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
@@ -1512,7 +1513,7 @@ def Body( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1827,7 +1828,7 @@ def Form( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -2141,7 +2142,7 @@ def File( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -2244,6 +2245,26 @@ def Depends( # noqa: N802
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
scope: Annotated[
|
||||
Union[Literal["function", "request"], None],
|
||||
Doc(
|
||||
"""
|
||||
Mainly for dependencies with `yield`, define when the dependency function
|
||||
should start (the code before `yield`) and when it should end (the code
|
||||
after `yield`).
|
||||
|
||||
* `"function"`: start the dependency before the *path operation function*
|
||||
that handles the request, end the dependency after the *path operation
|
||||
function* ends, but **before** the response is sent back to the client.
|
||||
So, the dependency function will be executed **around** the *path operation
|
||||
**function***.
|
||||
* `"request"`: start the dependency before the *path operation function*
|
||||
that handles the request (similar to when using `"function"`), but end
|
||||
**after** the response is sent back to the client. So, the dependency
|
||||
function will be executed **around** the **request** and response cycle.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Declare a FastAPI dependency.
|
||||
@@ -2274,7 +2295,7 @@ def Depends( # noqa: N802
|
||||
return commons
|
||||
```
|
||||
"""
|
||||
return params.Depends(dependency=dependency, use_cache=use_cache)
|
||||
return params.Depends(dependency=dependency, use_cache=use_cache, scope=scope)
|
||||
|
||||
|
||||
def Security( # noqa: N802
|
||||
@@ -2298,7 +2319,7 @@ def Security( # noqa: N802
|
||||
dependency.
|
||||
|
||||
The term "scope" comes from the OAuth2 specification, it seems to be
|
||||
intentionaly vague and interpretable. It normally refers to permissions,
|
||||
intentionally vague and interpretable. It normally refers to permissions,
|
||||
in cases to roles.
|
||||
|
||||
These scopes are integrated with OpenAPI (and the API docs at `/docs`).
|
||||
@@ -2343,7 +2364,7 @@ def Security( # noqa: N802
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi import Security, FastAPI
|
||||
|
||||
from .db import User
|
||||
from .security import get_current_active_user
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from fastapi.openapi.models import Example
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Annotated, deprecated
|
||||
from typing_extensions import Annotated, Literal, deprecated
|
||||
|
||||
from ._compat import PYDANTIC_V2, Undefined
|
||||
from ._compat import (
|
||||
PYDANTIC_V2,
|
||||
PYDANTIC_VERSION_MINOR_TUPLE,
|
||||
Undefined,
|
||||
)
|
||||
|
||||
_Unset: Any = Undefined
|
||||
|
||||
@@ -18,7 +23,7 @@ class ParamTypes(Enum):
|
||||
cookie = "cookie"
|
||||
|
||||
|
||||
class Param(FieldInfo):
|
||||
class Param(FieldInfo): # type: ignore[misc]
|
||||
in_: ParamTypes
|
||||
|
||||
def __init__(
|
||||
@@ -63,12 +68,11 @@ class Param(FieldInfo):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
self.deprecated = deprecated
|
||||
if example is not _Unset:
|
||||
warnings.warn(
|
||||
"`example` has been deprecated, please use `examples` instead",
|
||||
@@ -92,7 +96,7 @@ class Param(FieldInfo):
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
multiple_of=multiple_of,
|
||||
allow_nan=allow_inf_nan,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
**extra,
|
||||
@@ -106,6 +110,10 @@ class Param(FieldInfo):
|
||||
stacklevel=4,
|
||||
)
|
||||
current_json_schema_extra = json_schema_extra or extra
|
||||
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
|
||||
self.deprecated = deprecated
|
||||
else:
|
||||
kwargs["deprecated"] = deprecated
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update(
|
||||
{
|
||||
@@ -129,7 +137,7 @@ class Param(FieldInfo):
|
||||
return f"{self.__class__.__name__}({self.default})"
|
||||
|
||||
|
||||
class Path(Param):
|
||||
class Path(Param): # type: ignore[misc]
|
||||
in_ = ParamTypes.path
|
||||
|
||||
def __init__(
|
||||
@@ -174,7 +182,7 @@ class Path(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -215,7 +223,7 @@ class Path(Param):
|
||||
)
|
||||
|
||||
|
||||
class Query(Param):
|
||||
class Query(Param): # type: ignore[misc]
|
||||
in_ = ParamTypes.query
|
||||
|
||||
def __init__(
|
||||
@@ -260,7 +268,7 @@ class Query(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -299,7 +307,7 @@ class Query(Param):
|
||||
)
|
||||
|
||||
|
||||
class Header(Param):
|
||||
class Header(Param): # type: ignore[misc]
|
||||
in_ = ParamTypes.header
|
||||
|
||||
def __init__(
|
||||
@@ -345,7 +353,7 @@ class Header(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -385,7 +393,7 @@ class Header(Param):
|
||||
)
|
||||
|
||||
|
||||
class Cookie(Param):
|
||||
class Cookie(Param): # type: ignore[misc]
|
||||
in_ = ParamTypes.cookie
|
||||
|
||||
def __init__(
|
||||
@@ -430,7 +438,7 @@ class Cookie(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -469,14 +477,14 @@ class Cookie(Param):
|
||||
)
|
||||
|
||||
|
||||
class Body(FieldInfo):
|
||||
class Body(FieldInfo): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
embed: bool = False,
|
||||
embed: Union[bool, None] = None,
|
||||
media_type: str = "application/json",
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
@@ -514,14 +522,13 @@ class Body(FieldInfo):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
self.embed = embed
|
||||
self.media_type = media_type
|
||||
self.deprecated = deprecated
|
||||
if example is not _Unset:
|
||||
warnings.warn(
|
||||
"`example` has been deprecated, please use `examples` instead",
|
||||
@@ -545,7 +552,7 @@ class Body(FieldInfo):
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
multiple_of=multiple_of,
|
||||
allow_nan=allow_inf_nan,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
**extra,
|
||||
@@ -554,11 +561,15 @@ class Body(FieldInfo):
|
||||
kwargs["examples"] = examples
|
||||
if regex is not None:
|
||||
warnings.warn(
|
||||
"`regex` has been depreacated, please use `pattern` instead",
|
||||
"`regex` has been deprecated, please use `pattern` instead",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
current_json_schema_extra = json_schema_extra or extra
|
||||
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
|
||||
self.deprecated = deprecated
|
||||
else:
|
||||
kwargs["deprecated"] = deprecated
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update(
|
||||
{
|
||||
@@ -583,7 +594,7 @@ class Body(FieldInfo):
|
||||
return f"{self.__class__.__name__}({self.default})"
|
||||
|
||||
|
||||
class Form(Body):
|
||||
class Form(Body): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
@@ -627,7 +638,7 @@ class Form(Body):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -636,7 +647,6 @@ class Form(Body):
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
embed=True,
|
||||
media_type=media_type,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
@@ -668,7 +678,7 @@ class Form(Body):
|
||||
)
|
||||
|
||||
|
||||
class File(Form):
|
||||
class File(Form): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
@@ -712,7 +722,7 @@ class File(Form):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -752,26 +762,13 @@ class File(Form):
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Depends:
|
||||
def __init__(
|
||||
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
|
||||
):
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
|
||||
cache = "" if self.use_cache else ", use_cache=False"
|
||||
return f"{self.__class__.__name__}({attr}{cache})"
|
||||
dependency: Optional[Callable[..., Any]] = None
|
||||
use_cache: bool = True
|
||||
scope: Union[Literal["function", "request"], None] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Security(Depends):
|
||||
def __init__(
|
||||
self,
|
||||
dependency: Optional[Callable[..., Any]] = None,
|
||||
*,
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
use_cache: bool = True,
|
||||
):
|
||||
super().__init__(dependency=dependency, use_cache=use_cache)
|
||||
self.scopes = scopes or []
|
||||
scopes: Optional[Sequence[str]] = None
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import email.message
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
from contextlib import AsyncExitStack
|
||||
import sys
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from enum import Enum, IntEnum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
@@ -19,7 +24,8 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi import params
|
||||
from annotated_doc import Doc
|
||||
from fastapi import params, temp_pydantic_v1_params
|
||||
from fastapi._compat import (
|
||||
ModelField,
|
||||
Undefined,
|
||||
@@ -31,8 +37,10 @@ from fastapi._compat import (
|
||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import (
|
||||
_should_embed_body_fields,
|
||||
get_body_field,
|
||||
get_dependant,
|
||||
get_flat_dependant,
|
||||
get_parameterless_sub_dependant,
|
||||
get_typed_return_annotation,
|
||||
solve_dependencies,
|
||||
@@ -47,13 +55,15 @@ from fastapi.exceptions import (
|
||||
from fastapi.types import DecoratedCallable, IncEx
|
||||
from fastapi.utils import (
|
||||
create_cloned_field,
|
||||
create_response_field,
|
||||
create_model_field,
|
||||
generate_unique_id,
|
||||
get_value_or_default,
|
||||
is_body_allowed_for_status_code,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from starlette import routing
|
||||
from starlette._exception_handler import wrap_app_handling_exceptions
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
@@ -63,13 +73,84 @@ from starlette.routing import (
|
||||
Match,
|
||||
compile_path,
|
||||
get_name,
|
||||
request_response,
|
||||
websocket_session,
|
||||
)
|
||||
from starlette.routing import Mount as Mount # noqa
|
||||
from starlette.types import ASGIApp, Lifespan, Scope
|
||||
from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, deprecated
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
from inspect import iscoroutinefunction
|
||||
else: # pragma: no cover
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
|
||||
# Copy of starlette.routing.request_response modified to include the
|
||||
# dependencies' AsyncExitStack
|
||||
def request_response(
|
||||
func: Callable[[Request], Union[Awaitable[Response], Response]],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a function or coroutine `func(request) -> response`,
|
||||
and returns an ASGI application.
|
||||
"""
|
||||
f: Callable[[Request], Awaitable[Response]] = (
|
||||
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
|
||||
)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = Request(scope, receive, send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# Starts customization
|
||||
response_awaited = False
|
||||
async with AsyncExitStack() as request_stack:
|
||||
scope["fastapi_inner_astack"] = request_stack
|
||||
async with AsyncExitStack() as function_stack:
|
||||
scope["fastapi_function_astack"] = function_stack
|
||||
response = await f(request)
|
||||
await response(scope, receive, send)
|
||||
# Continues customization
|
||||
response_awaited = True
|
||||
if not response_awaited:
|
||||
raise FastAPIError(
|
||||
"Response not awaited. There's a high chance that the "
|
||||
"application code is raising an exception and a dependency with yield "
|
||||
"has a block with a bare except, or a block with except Exception, "
|
||||
"and is not raising the exception again. Read more about it in the "
|
||||
"docs: https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#dependencies-with-yield-and-except"
|
||||
)
|
||||
|
||||
# Same as in Starlette
|
||||
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# Copy of starlette.routing.websocket_session modified to include the
|
||||
# dependencies' AsyncExitStack
|
||||
def websocket_session(
|
||||
func: Callable[[WebSocket], Awaitable[None]],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a coroutine `func(session)`, and returns an ASGI application.
|
||||
"""
|
||||
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
session = WebSocket(scope, receive=receive, send=send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
async with AsyncExitStack() as request_stack:
|
||||
scope["fastapi_inner_astack"] = request_stack
|
||||
async with AsyncExitStack() as function_stack:
|
||||
scope["fastapi_function_astack"] = function_stack
|
||||
await func(session)
|
||||
|
||||
# Same as in Starlette
|
||||
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _prepare_response_content(
|
||||
@@ -115,10 +196,28 @@ def _prepare_response_content(
|
||||
for k, v in res.items()
|
||||
}
|
||||
elif dataclasses.is_dataclass(res):
|
||||
assert not isinstance(res, type)
|
||||
return dataclasses.asdict(res)
|
||||
return res
|
||||
|
||||
|
||||
def _merge_lifespan_context(
|
||||
original_context: Lifespan[Any], nested_context: Lifespan[Any]
|
||||
) -> Lifespan[Any]:
|
||||
@asynccontextmanager
|
||||
async def merged_lifespan(
|
||||
app: AppType,
|
||||
) -> AsyncIterator[Optional[Mapping[str, Any]]]:
|
||||
async with original_context(app) as maybe_original_state:
|
||||
async with nested_context(app) as maybe_nested_state:
|
||||
if maybe_nested_state is None and maybe_original_state is None:
|
||||
yield None # old ASGI compatibility
|
||||
else:
|
||||
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
|
||||
|
||||
return merged_lifespan # type: ignore[return-value]
|
||||
|
||||
|
||||
async def serialize_response(
|
||||
*,
|
||||
field: Optional[ModelField] = None,
|
||||
@@ -206,24 +305,32 @@ def get_request_handler(
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
embed_body_fields: bool = False,
|
||||
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
||||
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
|
||||
is_coroutine = iscoroutinefunction(dependant.call)
|
||||
is_body_form = body_field and isinstance(
|
||||
body_field.field_info, (params.Form, temp_pydantic_v1_params.Form)
|
||||
)
|
||||
if isinstance(response_class, DefaultPlaceholder):
|
||||
actual_response_class: Type[Response] = response_class.value
|
||||
else:
|
||||
actual_response_class = response_class
|
||||
|
||||
async def app(request: Request) -> Response:
|
||||
response: Union[Response, None] = None
|
||||
file_stack = request.scope.get("fastapi_middleware_astack")
|
||||
assert isinstance(file_stack, AsyncExitStack), (
|
||||
"fastapi_middleware_astack not found in request scope"
|
||||
)
|
||||
|
||||
# Read body and auto-close files
|
||||
try:
|
||||
body: Any = None
|
||||
if body_field:
|
||||
if is_body_form:
|
||||
body = await request.form()
|
||||
stack = request.scope.get("fastapi_astack")
|
||||
assert isinstance(stack, AsyncExitStack)
|
||||
stack.push_async_callback(body.close)
|
||||
file_stack.push_async_callback(body.close)
|
||||
else:
|
||||
body_bytes = await request.body()
|
||||
if body_bytes:
|
||||
@@ -243,7 +350,7 @@ def get_request_handler(
|
||||
else:
|
||||
body = body_bytes
|
||||
except json.JSONDecodeError as e:
|
||||
raise RequestValidationError(
|
||||
validation_error = RequestValidationError(
|
||||
[
|
||||
{
|
||||
"type": "json_invalid",
|
||||
@@ -254,75 +361,106 @@ def get_request_handler(
|
||||
}
|
||||
],
|
||||
body=e.doc,
|
||||
) from e
|
||||
)
|
||||
raise validation_error from e
|
||||
except HTTPException:
|
||||
# If a middleware raises an HTTPException, it should be raised again
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
http_error = HTTPException(
|
||||
status_code=400, detail="There was an error parsing the body"
|
||||
) from e
|
||||
)
|
||||
raise http_error from e
|
||||
|
||||
# Solve dependencies and run path operation function, auto-closing dependencies
|
||||
errors: List[Any] = []
|
||||
async_exit_stack = request.scope.get("fastapi_inner_astack")
|
||||
assert isinstance(async_exit_stack, AsyncExitStack), (
|
||||
"fastapi_inner_astack not found in request scope"
|
||||
)
|
||||
solved_result = await solve_dependencies(
|
||||
request=request,
|
||||
dependant=dependant,
|
||||
body=body,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
values, errors, background_tasks, sub_response, _ = solved_result
|
||||
if errors:
|
||||
raise RequestValidationError(_normalize_errors(errors), body=body)
|
||||
else:
|
||||
errors = solved_result.errors
|
||||
if not errors:
|
||||
raw_response = await run_endpoint_function(
|
||||
dependant=dependant, values=values, is_coroutine=is_coroutine
|
||||
)
|
||||
|
||||
if isinstance(raw_response, Response):
|
||||
if raw_response.background is None:
|
||||
raw_response.background = background_tasks
|
||||
return raw_response
|
||||
response_args: Dict[str, Any] = {"background": background_tasks}
|
||||
# If status_code was set, use it, otherwise use the default from the
|
||||
# response class, in the case of redirect it's 307
|
||||
current_status_code = (
|
||||
status_code if status_code else sub_response.status_code
|
||||
)
|
||||
if current_status_code is not None:
|
||||
response_args["status_code"] = current_status_code
|
||||
if sub_response.status_code:
|
||||
response_args["status_code"] = sub_response.status_code
|
||||
content = await serialize_response(
|
||||
field=response_field,
|
||||
response_content=raw_response,
|
||||
include=response_model_include,
|
||||
exclude=response_model_exclude,
|
||||
by_alias=response_model_by_alias,
|
||||
exclude_unset=response_model_exclude_unset,
|
||||
exclude_defaults=response_model_exclude_defaults,
|
||||
exclude_none=response_model_exclude_none,
|
||||
dependant=dependant,
|
||||
values=solved_result.values,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
response = actual_response_class(content, **response_args)
|
||||
if not is_body_allowed_for_status_code(response.status_code):
|
||||
response.body = b""
|
||||
response.headers.raw.extend(sub_response.headers.raw)
|
||||
return response
|
||||
if isinstance(raw_response, Response):
|
||||
if raw_response.background is None:
|
||||
raw_response.background = solved_result.background_tasks
|
||||
response = raw_response
|
||||
else:
|
||||
response_args: Dict[str, Any] = {
|
||||
"background": solved_result.background_tasks
|
||||
}
|
||||
# If status_code was set, use it, otherwise use the default from the
|
||||
# response class, in the case of redirect it's 307
|
||||
current_status_code = (
|
||||
status_code if status_code else solved_result.response.status_code
|
||||
)
|
||||
if current_status_code is not None:
|
||||
response_args["status_code"] = current_status_code
|
||||
if solved_result.response.status_code:
|
||||
response_args["status_code"] = solved_result.response.status_code
|
||||
content = await serialize_response(
|
||||
field=response_field,
|
||||
response_content=raw_response,
|
||||
include=response_model_include,
|
||||
exclude=response_model_exclude,
|
||||
by_alias=response_model_by_alias,
|
||||
exclude_unset=response_model_exclude_unset,
|
||||
exclude_defaults=response_model_exclude_defaults,
|
||||
exclude_none=response_model_exclude_none,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
response = actual_response_class(content, **response_args)
|
||||
if not is_body_allowed_for_status_code(response.status_code):
|
||||
response.body = b""
|
||||
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||
if errors:
|
||||
validation_error = RequestValidationError(
|
||||
_normalize_errors(errors), body=body
|
||||
)
|
||||
raise validation_error
|
||||
|
||||
# Return response
|
||||
assert response
|
||||
return response
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_websocket_app(
|
||||
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
|
||||
dependant: Dependant,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
embed_body_fields: bool = False,
|
||||
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
||||
async def app(websocket: WebSocket) -> None:
|
||||
async_exit_stack = websocket.scope.get("fastapi_inner_astack")
|
||||
assert isinstance(async_exit_stack, AsyncExitStack), (
|
||||
"fastapi_inner_astack not found in request scope"
|
||||
)
|
||||
solved_result = await solve_dependencies(
|
||||
request=websocket,
|
||||
dependant=dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
values, errors, _, _2, _3 = solved_result
|
||||
if errors:
|
||||
raise WebSocketRequestValidationError(_normalize_errors(errors))
|
||||
if solved_result.errors:
|
||||
raise WebSocketRequestValidationError(
|
||||
_normalize_errors(solved_result.errors)
|
||||
)
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
await dependant.call(**values)
|
||||
await dependant.call(**solved_result.values)
|
||||
|
||||
return app
|
||||
|
||||
@@ -342,17 +480,23 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.dependencies = list(dependencies or [])
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||
self.dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
)
|
||||
for depends in self.dependencies[::-1]:
|
||||
self.dependant.dependencies.insert(
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
self.app = websocket_session(
|
||||
get_websocket_app(
|
||||
dependant=self.dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -431,9 +575,9 @@ class APIRoute(routing.Route):
|
||||
methods = ["GET"]
|
||||
self.methods: Set[str] = {method.upper() for method in methods}
|
||||
if isinstance(generate_unique_id_function, DefaultPlaceholder):
|
||||
current_generate_unique_id: Callable[
|
||||
["APIRoute"], str
|
||||
] = generate_unique_id_function.value
|
||||
current_generate_unique_id: Callable[[APIRoute], str] = (
|
||||
generate_unique_id_function.value
|
||||
)
|
||||
else:
|
||||
current_generate_unique_id = generate_unique_id_function
|
||||
self.unique_id = self.operation_id or current_generate_unique_id(self)
|
||||
@@ -442,11 +586,11 @@ class APIRoute(routing.Route):
|
||||
status_code = int(status_code)
|
||||
self.status_code = status_code
|
||||
if self.response_model:
|
||||
assert is_body_allowed_for_status_code(
|
||||
status_code
|
||||
), f"Status code {status_code} must not have a response body"
|
||||
assert is_body_allowed_for_status_code(status_code), (
|
||||
f"Status code {status_code} must not have a response body"
|
||||
)
|
||||
response_name = "Response_" + self.unique_id
|
||||
self.response_field = create_response_field(
|
||||
self.response_field = create_model_field(
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
mode="serialization",
|
||||
@@ -459,9 +603,9 @@ class APIRoute(routing.Route):
|
||||
# By being a new field, no inheritance will be passed as is. A new model
|
||||
# will always be created.
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
self.secure_cloned_response_field: Optional[
|
||||
ModelField
|
||||
] = create_cloned_field(self.response_field)
|
||||
self.secure_cloned_response_field: Optional[ModelField] = (
|
||||
create_cloned_field(self.response_field)
|
||||
)
|
||||
else:
|
||||
self.response_field = None # type: ignore
|
||||
self.secure_cloned_response_field = None
|
||||
@@ -475,11 +619,13 @@ class APIRoute(routing.Route):
|
||||
assert isinstance(response, dict), "An additional response must be a dict"
|
||||
model = response.get("model")
|
||||
if model:
|
||||
assert is_body_allowed_for_status_code(
|
||||
additional_status_code
|
||||
), f"Status code {additional_status_code} must not have a response body"
|
||||
assert is_body_allowed_for_status_code(additional_status_code), (
|
||||
f"Status code {additional_status_code} must not have a response body"
|
||||
)
|
||||
response_name = f"Response_{additional_status_code}_{self.unique_id}"
|
||||
response_field = create_response_field(name=response_name, type_=model)
|
||||
response_field = create_model_field(
|
||||
name=response_name, type_=model, mode="serialization"
|
||||
)
|
||||
response_fields[additional_status_code] = response_field
|
||||
if response_fields:
|
||||
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
|
||||
@@ -487,13 +633,23 @@ class APIRoute(routing.Route):
|
||||
self.response_fields = {}
|
||||
|
||||
assert callable(endpoint), "An endpoint must be a callable"
|
||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||
self.dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
)
|
||||
for depends in self.dependencies[::-1]:
|
||||
self.dependant.dependencies.insert(
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
self.body_field = get_body_field(
|
||||
flat_dependant=self._flat_dependant,
|
||||
name=self.unique_id,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
@@ -510,6 +666,7 @@ class APIRoute(routing.Route):
|
||||
response_model_exclude_defaults=self.response_model_exclude_defaults,
|
||||
response_model_exclude_none=self.response_model_exclude_none,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
|
||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||
@@ -741,7 +898,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -771,9 +928,9 @@ class APIRouter(routing.Router):
|
||||
)
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
assert not prefix.endswith(
|
||||
"/"
|
||||
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||
assert not prefix.endswith("/"), (
|
||||
"A path prefix must not end with '/', as the routes will start with '/'"
|
||||
)
|
||||
self.prefix = prefix
|
||||
self.tags: List[Union[str, Enum]] = tags or []
|
||||
self.dependencies = list(dependencies or [])
|
||||
@@ -789,7 +946,7 @@ class APIRouter(routing.Router):
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: Optional[List[str]] = None,
|
||||
methods: Optional[Collection[str]] = None,
|
||||
name: Optional[str] = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
@@ -1183,9 +1340,9 @@ class APIRouter(routing.Router):
|
||||
"""
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
assert not prefix.endswith(
|
||||
"/"
|
||||
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||
assert not prefix.endswith("/"), (
|
||||
"A path prefix must not end with '/', as the routes will start with '/'"
|
||||
)
|
||||
else:
|
||||
for r in router.routes:
|
||||
path = getattr(r, "path") # noqa: B009
|
||||
@@ -1285,6 +1442,10 @@ class APIRouter(routing.Router):
|
||||
self.add_event_handler("startup", handler)
|
||||
for handler in router.on_shutdown:
|
||||
self.add_event_handler("shutdown", handler)
|
||||
self.lifespan_context = _merge_lifespan_context(
|
||||
self.lifespan_context,
|
||||
router.lifespan_context,
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
@@ -1549,7 +1710,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -1926,7 +2087,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2308,7 +2469,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2690,7 +2851,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3067,7 +3228,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3444,7 +3605,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3826,7 +3987,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4208,7 +4369,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4293,7 +4454,7 @@ class APIRouter(routing.Router):
|
||||
app = FastAPI()
|
||||
router = APIRouter()
|
||||
|
||||
@router.put("/items/{item_id}")
|
||||
@router.trace("/items/{item_id}")
|
||||
def trace_item(item_id: str):
|
||||
return None
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,15 +1,54 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||
from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class APIKeyBase(SecurityBase):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
location: APIKeyIn,
|
||||
name: str,
|
||||
description: Union[str, None],
|
||||
scheme_name: Union[str, None],
|
||||
auto_error: bool,
|
||||
):
|
||||
self.auto_error = auto_error
|
||||
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": location},
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
"""
|
||||
The WWW-Authenticate header is not standardized for API Key authentication but
|
||||
the HTTP specification requires that an error of 401 "Unauthorized" must
|
||||
include a WWW-Authenticate header.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc9110#name-401-unauthorized
|
||||
|
||||
For this, this method sends a custom challenge `APIKey`.
|
||||
"""
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "APIKey"},
|
||||
)
|
||||
|
||||
def check_api_key(self, api_key: Optional[str]) -> Optional[str]:
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise self.make_not_authenticated_error()
|
||||
return None
|
||||
return api_key
|
||||
|
||||
|
||||
class APIKeyQuery(APIKeyBase):
|
||||
@@ -76,7 +115,7 @@ class APIKeyQuery(APIKeyBase):
|
||||
Doc(
|
||||
"""
|
||||
By default, if the query parameter is not provided, `APIKeyQuery` will
|
||||
automatically cancel the request and sebd the client an error.
|
||||
automatically cancel the request and send the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the query parameter is not
|
||||
available, instead of erroring out, the dependency result will be
|
||||
@@ -91,24 +130,17 @@ class APIKeyQuery(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.query}, # type: ignore[arg-type]
|
||||
super().__init__(
|
||||
location=APIKeyIn.query,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.query_params.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
|
||||
class APIKeyHeader(APIKeyBase):
|
||||
@@ -186,24 +218,17 @@ class APIKeyHeader(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.header}, # type: ignore[arg-type]
|
||||
super().__init__(
|
||||
location=APIKeyIn.header,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.headers.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
|
||||
class APIKeyCookie(APIKeyBase):
|
||||
@@ -281,21 +306,14 @@ class APIKeyCookie(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.cookie}, # type: ignore[arg-type]
|
||||
super().__init__(
|
||||
location=APIKeyIn.cookie,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.cookies.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import binascii
|
||||
from base64 import b64decode
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
|
||||
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
|
||||
@@ -9,13 +10,13 @@ from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from pydantic import BaseModel
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class HTTPBasicCredentials(BaseModel):
|
||||
"""
|
||||
The HTTP Basic credendials given as the result of using `HTTPBasic` in a
|
||||
The HTTP Basic credentials given as the result of using `HTTPBasic` in a
|
||||
dependency.
|
||||
|
||||
Read more about it in the
|
||||
@@ -75,10 +76,22 @@ class HTTPBase(SecurityBase):
|
||||
description: Optional[str] = None,
|
||||
auto_error: bool = True,
|
||||
):
|
||||
self.model = HTTPBaseModel(scheme=scheme, description=description)
|
||||
self.model: HTTPBaseModel = HTTPBaseModel(
|
||||
scheme=scheme, description=description
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_authenticate_headers(self) -> Dict[str, str]:
|
||||
return {"WWW-Authenticate": f"{self.model.scheme.title()}"}
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers=self.make_authenticate_headers(),
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request
|
||||
) -> Optional[HTTPAuthorizationCredentials]:
|
||||
@@ -86,9 +99,7 @@ class HTTPBase(SecurityBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
@@ -98,6 +109,8 @@ class HTTPBasic(HTTPBase):
|
||||
"""
|
||||
HTTP Basic authentication.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc7617
|
||||
|
||||
## Usage
|
||||
|
||||
Create an instance object and use that object as the dependency in `Depends()`.
|
||||
@@ -184,36 +197,28 @@ class HTTPBasic(HTTPBase):
|
||||
self.realm = realm
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_authenticate_headers(self) -> Dict[str, str]:
|
||||
if self.realm:
|
||||
return {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
|
||||
return {"WWW-Authenticate": "Basic"}
|
||||
|
||||
async def __call__( # type: ignore
|
||||
self, request: Request
|
||||
) -> Optional[HTTPBasicCredentials]:
|
||||
authorization = request.headers.get("Authorization")
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if self.realm:
|
||||
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
|
||||
else:
|
||||
unauthorized_headers = {"WWW-Authenticate": "Basic"}
|
||||
if not authorization or scheme.lower() != "basic":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers=unauthorized_headers,
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
invalid_user_credentials_exc = HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers=unauthorized_headers,
|
||||
)
|
||||
try:
|
||||
data = b64decode(param).decode("ascii")
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error):
|
||||
raise invalid_user_credentials_exc # noqa: B904
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error) as e:
|
||||
raise self.make_not_authenticated_error() from e
|
||||
username, separator, password = data.partition(":")
|
||||
if not separator:
|
||||
raise invalid_user_credentials_exc
|
||||
raise self.make_not_authenticated_error()
|
||||
return HTTPBasicCredentials(username=username, password=password)
|
||||
|
||||
|
||||
@@ -277,7 +282,7 @@ class HTTPBearer(HTTPBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if the HTTP Bearer token not provided (in an
|
||||
By default, if the HTTP Bearer token is not provided (in an
|
||||
`Authorization` header), `HTTPBearer` will automatically cancel the
|
||||
request and send the client an error.
|
||||
|
||||
@@ -305,17 +310,12 @@ class HTTPBearer(HTTPBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
if scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
@@ -325,6 +325,12 @@ class HTTPDigest(HTTPBase):
|
||||
"""
|
||||
HTTP Digest authentication.
|
||||
|
||||
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
|
||||
but it doesn't implement the full Digest scheme, you would need to to subclass it
|
||||
and implement it in your code.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc7616
|
||||
|
||||
## Usage
|
||||
|
||||
Create an instance object and use that object as the dependency in `Depends()`.
|
||||
@@ -380,7 +386,7 @@ class HTTPDigest(HTTPBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if the HTTP Digest not provided, `HTTPDigest` will
|
||||
By default, if the HTTP Digest is not provided, `HTTPDigest` will
|
||||
automatically cancel the request and send the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the HTTP Digest is not
|
||||
@@ -407,14 +413,12 @@ class HTTPDigest(HTTPBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
if scheme.lower() != "digest":
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
if self.auto_error:
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.openapi.models import OAuth2 as OAuth2Model
|
||||
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
|
||||
@@ -7,10 +8,10 @@ from fastapi.param_functions import Form
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
# TODO: import from typing when deprecating Python 3.9
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class OAuth2PasswordRequestForm:
|
||||
@@ -52,9 +53,9 @@ class OAuth2PasswordRequestForm:
|
||||
```
|
||||
|
||||
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
|
||||
You could have custom internal logic to separate it by colon caracters (`:`) or
|
||||
You could have custom internal logic to separate it by colon characters (`:`) or
|
||||
similar, and get the two parts `items` and `read`. Many applications do that to
|
||||
group and organize permisions, you could do it as well in your application, just
|
||||
group and organize permissions, you could do it as well in your application, just
|
||||
know that that it is application specific, it's not part of the specification.
|
||||
"""
|
||||
|
||||
@@ -63,7 +64,7 @@ class OAuth2PasswordRequestForm:
|
||||
*,
|
||||
grant_type: Annotated[
|
||||
Union[str, None],
|
||||
Form(pattern="password"),
|
||||
Form(pattern="^password$"),
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 spec says it is required and MUST be the fixed string
|
||||
@@ -85,11 +86,11 @@ class OAuth2PasswordRequestForm:
|
||||
],
|
||||
password: Annotated[
|
||||
str,
|
||||
Form(),
|
||||
Form(json_schema_extra={"format": "password"}),
|
||||
Doc(
|
||||
"""
|
||||
`password` string. The OAuth2 spec requires the exact field name
|
||||
`password".
|
||||
`password`.
|
||||
"""
|
||||
),
|
||||
],
|
||||
@@ -130,7 +131,7 @@ class OAuth2PasswordRequestForm:
|
||||
] = None,
|
||||
client_secret: Annotated[
|
||||
Union[str, None],
|
||||
Form(),
|
||||
Form(json_schema_extra={"format": "password"}),
|
||||
Doc(
|
||||
"""
|
||||
If there's a `client_password` (and a `client_id`), they can be sent
|
||||
@@ -194,9 +195,9 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
|
||||
```
|
||||
|
||||
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
|
||||
You could have custom internal logic to separate it by colon caracters (`:`) or
|
||||
You could have custom internal logic to separate it by colon characters (`:`) or
|
||||
similar, and get the two parts `items` and `read`. Many applications do that to
|
||||
group and organize permisions, you could do it as well in your application, just
|
||||
group and organize permissions, you could do it as well in your application, just
|
||||
know that that it is application specific, it's not part of the specification.
|
||||
|
||||
|
||||
@@ -217,7 +218,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
|
||||
self,
|
||||
grant_type: Annotated[
|
||||
str,
|
||||
Form(pattern="password"),
|
||||
Form(pattern="^password$"),
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 spec says it is required and MUST be the fixed string
|
||||
@@ -243,7 +244,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
|
||||
Doc(
|
||||
"""
|
||||
`password` string. The OAuth2 spec requires the exact field name
|
||||
`password".
|
||||
`password`.
|
||||
"""
|
||||
),
|
||||
],
|
||||
@@ -353,7 +354,7 @@ class OAuth2(SecurityBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
@@ -376,13 +377,33 @@ class OAuth2(SecurityBase):
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
"""
|
||||
The OAuth 2 specification doesn't define the challenge that should be used,
|
||||
because a `Bearer` token is not really the only option to authenticate.
|
||||
|
||||
But declaring any other authentication challenge would be application-specific
|
||||
as it's not defined in the specification.
|
||||
|
||||
For practical reasons, this method uses the `Bearer` challenge by default, as
|
||||
it's probably the most common one.
|
||||
|
||||
If you are implementing an OAuth2 authentication scheme other than the provided
|
||||
ones in FastAPI (based on bearer tokens), you might want to override this.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc6749
|
||||
"""
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return authorization
|
||||
@@ -441,7 +462,7 @@ class OAuth2PasswordBearer(OAuth2):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
@@ -457,11 +478,26 @@ class OAuth2PasswordBearer(OAuth2):
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
refreshUrl: Annotated[
|
||||
Optional[str],
|
||||
Doc(
|
||||
"""
|
||||
The URL to refresh the token and obtain a new one.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
):
|
||||
if not scopes:
|
||||
scopes = {}
|
||||
flows = OAuthFlowsModel(
|
||||
password=cast(Any, {"tokenUrl": tokenUrl, "scopes": scopes})
|
||||
password=cast(
|
||||
Any,
|
||||
{
|
||||
"tokenUrl": tokenUrl,
|
||||
"refreshUrl": refreshUrl,
|
||||
"scopes": scopes,
|
||||
},
|
||||
)
|
||||
)
|
||||
super().__init__(
|
||||
flows=flows,
|
||||
@@ -475,11 +511,7 @@ class OAuth2PasswordBearer(OAuth2):
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return param
|
||||
@@ -543,7 +575,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
@@ -585,11 +617,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None # pragma: nocover
|
||||
return param
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
from typing import Optional
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
|
||||
from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class OpenIdConnect(SecurityBase):
|
||||
"""
|
||||
OpenID Connect authentication class. An instance of it would be used as a
|
||||
dependency.
|
||||
|
||||
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
|
||||
but it doesn't implement the full OpenIdConnect scheme, for example, it doesn't use
|
||||
the OpenIDConnect URL. You would need to to subclass it and implement it in your
|
||||
code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -49,7 +55,7 @@ class OpenIdConnect(SecurityBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
OpenID Connect authentication, it will automatically cancel the request
|
||||
and send the client an error.
|
||||
|
||||
@@ -72,13 +78,18 @@ class OpenIdConnect(SecurityBase):
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return authorization
|
||||
|
||||
@@ -0,0 +1,724 @@
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from fastapi.openapi.models import Example
|
||||
from fastapi.params import ParamTypes
|
||||
from typing_extensions import Annotated, deprecated
|
||||
|
||||
from ._compat.may_v1 import FieldInfo, Undefined
|
||||
from ._compat.shared import PYDANTIC_VERSION_MINOR_TUPLE
|
||||
|
||||
_Unset: Any = Undefined
|
||||
|
||||
|
||||
class Param(FieldInfo): # type: ignore[misc]
|
||||
in_: ParamTypes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
# TODO: update when deprecating Pydantic v1, import these types
|
||||
# validation_alias: str | AliasPath | AliasChoices | None
|
||||
validation_alias: Union[str, None] = None,
|
||||
serialization_alias: Union[str, None] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pattern: Optional[str] = None,
|
||||
regex: Annotated[
|
||||
Optional[str],
|
||||
deprecated(
|
||||
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
|
||||
),
|
||||
] = None,
|
||||
discriminator: Union[str, None] = None,
|
||||
strict: Union[bool, None] = _Unset,
|
||||
multiple_of: Union[float, None] = _Unset,
|
||||
allow_inf_nan: Union[bool, None] = _Unset,
|
||||
max_digits: Union[int, None] = _Unset,
|
||||
decimal_places: Union[int, None] = _Unset,
|
||||
examples: Optional[List[Any]] = None,
|
||||
example: Annotated[
|
||||
Optional[Any],
|
||||
deprecated(
|
||||
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
|
||||
"although still supported. Use examples instead."
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
if example is not _Unset:
|
||||
warnings.warn(
|
||||
"`example` has been deprecated, please use `examples` instead",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
self.example = example
|
||||
self.include_in_schema = include_in_schema
|
||||
self.openapi_examples = openapi_examples
|
||||
kwargs = dict(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
**extra,
|
||||
)
|
||||
if examples is not None:
|
||||
kwargs["examples"] = examples
|
||||
if regex is not None:
|
||||
warnings.warn(
|
||||
"`regex` has been deprecated, please use `pattern` instead",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
current_json_schema_extra = json_schema_extra or extra
|
||||
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
|
||||
self.deprecated = deprecated
|
||||
else:
|
||||
kwargs["deprecated"] = deprecated
|
||||
kwargs["regex"] = pattern or regex
|
||||
kwargs.update(**current_json_schema_extra)
|
||||
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
|
||||
|
||||
super().__init__(**use_kwargs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.default})"
|
||||
|
||||
|
||||
class Path(Param): # type: ignore[misc]
|
||||
in_ = ParamTypes.path
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = ...,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
# TODO: update when deprecating Pydantic v1, import these types
|
||||
# validation_alias: str | AliasPath | AliasChoices | None
|
||||
validation_alias: Union[str, None] = None,
|
||||
serialization_alias: Union[str, None] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pattern: Optional[str] = None,
|
||||
regex: Annotated[
|
||||
Optional[str],
|
||||
deprecated(
|
||||
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
|
||||
),
|
||||
] = None,
|
||||
discriminator: Union[str, None] = None,
|
||||
strict: Union[bool, None] = _Unset,
|
||||
multiple_of: Union[float, None] = _Unset,
|
||||
allow_inf_nan: Union[bool, None] = _Unset,
|
||||
max_digits: Union[int, None] = _Unset,
|
||||
decimal_places: Union[int, None] = _Unset,
|
||||
examples: Optional[List[Any]] = None,
|
||||
example: Annotated[
|
||||
Optional[Any],
|
||||
deprecated(
|
||||
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
|
||||
"although still supported. Use examples instead."
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
assert default is ..., "Path parameters cannot have a default value"
|
||||
self.in_ = self.in_
|
||||
super().__init__(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
validation_alias=validation_alias,
|
||||
serialization_alias=serialization_alias,
|
||||
title=title,
|
||||
description=description,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
pattern=pattern,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
strict=strict,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
deprecated=deprecated,
|
||||
example=example,
|
||||
examples=examples,
|
||||
openapi_examples=openapi_examples,
|
||||
include_in_schema=include_in_schema,
|
||||
json_schema_extra=json_schema_extra,
|
||||
**extra,
|
||||
)
|
||||
|
||||
|
||||
class Query(Param): # type: ignore[misc]
|
||||
in_ = ParamTypes.query
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
# TODO: update when deprecating Pydantic v1, import these types
|
||||
# validation_alias: str | AliasPath | AliasChoices | None
|
||||
validation_alias: Union[str, None] = None,
|
||||
serialization_alias: Union[str, None] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pattern: Optional[str] = None,
|
||||
regex: Annotated[
|
||||
Optional[str],
|
||||
deprecated(
|
||||
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
|
||||
),
|
||||
] = None,
|
||||
discriminator: Union[str, None] = None,
|
||||
strict: Union[bool, None] = _Unset,
|
||||
multiple_of: Union[float, None] = _Unset,
|
||||
allow_inf_nan: Union[bool, None] = _Unset,
|
||||
max_digits: Union[int, None] = _Unset,
|
||||
decimal_places: Union[int, None] = _Unset,
|
||||
examples: Optional[List[Any]] = None,
|
||||
example: Annotated[
|
||||
Optional[Any],
|
||||
deprecated(
|
||||
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
|
||||
"although still supported. Use examples instead."
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
super().__init__(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
validation_alias=validation_alias,
|
||||
serialization_alias=serialization_alias,
|
||||
title=title,
|
||||
description=description,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
pattern=pattern,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
strict=strict,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
deprecated=deprecated,
|
||||
example=example,
|
||||
examples=examples,
|
||||
openapi_examples=openapi_examples,
|
||||
include_in_schema=include_in_schema,
|
||||
json_schema_extra=json_schema_extra,
|
||||
**extra,
|
||||
)
|
||||
|
||||
|
||||
class Header(Param): # type: ignore[misc]
|
||||
in_ = ParamTypes.header
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
# TODO: update when deprecating Pydantic v1, import these types
|
||||
# validation_alias: str | AliasPath | AliasChoices | None
|
||||
validation_alias: Union[str, None] = None,
|
||||
serialization_alias: Union[str, None] = None,
|
||||
convert_underscores: bool = True,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pattern: Optional[str] = None,
|
||||
regex: Annotated[
|
||||
Optional[str],
|
||||
deprecated(
|
||||
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
|
||||
),
|
||||
] = None,
|
||||
discriminator: Union[str, None] = None,
|
||||
strict: Union[bool, None] = _Unset,
|
||||
multiple_of: Union[float, None] = _Unset,
|
||||
allow_inf_nan: Union[bool, None] = _Unset,
|
||||
max_digits: Union[int, None] = _Unset,
|
||||
decimal_places: Union[int, None] = _Unset,
|
||||
examples: Optional[List[Any]] = None,
|
||||
example: Annotated[
|
||||
Optional[Any],
|
||||
deprecated(
|
||||
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
|
||||
"although still supported. Use examples instead."
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
self.convert_underscores = convert_underscores
|
||||
super().__init__(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
validation_alias=validation_alias,
|
||||
serialization_alias=serialization_alias,
|
||||
title=title,
|
||||
description=description,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
pattern=pattern,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
strict=strict,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
deprecated=deprecated,
|
||||
example=example,
|
||||
examples=examples,
|
||||
openapi_examples=openapi_examples,
|
||||
include_in_schema=include_in_schema,
|
||||
json_schema_extra=json_schema_extra,
|
||||
**extra,
|
||||
)
|
||||
|
||||
|
||||
class Cookie(Param): # type: ignore[misc]
|
||||
in_ = ParamTypes.cookie
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
# TODO: update when deprecating Pydantic v1, import these types
|
||||
# validation_alias: str | AliasPath | AliasChoices | None
|
||||
validation_alias: Union[str, None] = None,
|
||||
serialization_alias: Union[str, None] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pattern: Optional[str] = None,
|
||||
regex: Annotated[
|
||||
Optional[str],
|
||||
deprecated(
|
||||
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
|
||||
),
|
||||
] = None,
|
||||
discriminator: Union[str, None] = None,
|
||||
strict: Union[bool, None] = _Unset,
|
||||
multiple_of: Union[float, None] = _Unset,
|
||||
allow_inf_nan: Union[bool, None] = _Unset,
|
||||
max_digits: Union[int, None] = _Unset,
|
||||
decimal_places: Union[int, None] = _Unset,
|
||||
examples: Optional[List[Any]] = None,
|
||||
example: Annotated[
|
||||
Optional[Any],
|
||||
deprecated(
|
||||
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
|
||||
"although still supported. Use examples instead."
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
super().__init__(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
validation_alias=validation_alias,
|
||||
serialization_alias=serialization_alias,
|
||||
title=title,
|
||||
description=description,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
pattern=pattern,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
strict=strict,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
deprecated=deprecated,
|
||||
example=example,
|
||||
examples=examples,
|
||||
openapi_examples=openapi_examples,
|
||||
include_in_schema=include_in_schema,
|
||||
json_schema_extra=json_schema_extra,
|
||||
**extra,
|
||||
)
|
||||
|
||||
|
||||
class Body(FieldInfo): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
embed: Union[bool, None] = None,
|
||||
media_type: str = "application/json",
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
# TODO: update when deprecating Pydantic v1, import these types
|
||||
# validation_alias: str | AliasPath | AliasChoices | None
|
||||
validation_alias: Union[str, None] = None,
|
||||
serialization_alias: Union[str, None] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pattern: Optional[str] = None,
|
||||
regex: Annotated[
|
||||
Optional[str],
|
||||
deprecated(
|
||||
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
|
||||
),
|
||||
] = None,
|
||||
discriminator: Union[str, None] = None,
|
||||
strict: Union[bool, None] = _Unset,
|
||||
multiple_of: Union[float, None] = _Unset,
|
||||
allow_inf_nan: Union[bool, None] = _Unset,
|
||||
max_digits: Union[int, None] = _Unset,
|
||||
decimal_places: Union[int, None] = _Unset,
|
||||
examples: Optional[List[Any]] = None,
|
||||
example: Annotated[
|
||||
Optional[Any],
|
||||
deprecated(
|
||||
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
|
||||
"although still supported. Use examples instead."
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
self.embed = embed
|
||||
self.media_type = media_type
|
||||
if example is not _Unset:
|
||||
warnings.warn(
|
||||
"`example` has been deprecated, please use `examples` instead",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
self.example = example
|
||||
self.include_in_schema = include_in_schema
|
||||
self.openapi_examples = openapi_examples
|
||||
kwargs = dict(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
**extra,
|
||||
)
|
||||
if examples is not None:
|
||||
kwargs["examples"] = examples
|
||||
if regex is not None:
|
||||
warnings.warn(
|
||||
"`regex` has been deprecated, please use `pattern` instead",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
current_json_schema_extra = json_schema_extra or extra
|
||||
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
|
||||
self.deprecated = deprecated
|
||||
else:
|
||||
kwargs["deprecated"] = deprecated
|
||||
kwargs["regex"] = pattern or regex
|
||||
kwargs.update(**current_json_schema_extra)
|
||||
|
||||
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
|
||||
|
||||
super().__init__(**use_kwargs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.default})"
|
||||
|
||||
|
||||
class Form(Body): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
media_type: str = "application/x-www-form-urlencoded",
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
# TODO: update when deprecating Pydantic v1, import these types
|
||||
# validation_alias: str | AliasPath | AliasChoices | None
|
||||
validation_alias: Union[str, None] = None,
|
||||
serialization_alias: Union[str, None] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pattern: Optional[str] = None,
|
||||
regex: Annotated[
|
||||
Optional[str],
|
||||
deprecated(
|
||||
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
|
||||
),
|
||||
] = None,
|
||||
discriminator: Union[str, None] = None,
|
||||
strict: Union[bool, None] = _Unset,
|
||||
multiple_of: Union[float, None] = _Unset,
|
||||
allow_inf_nan: Union[bool, None] = _Unset,
|
||||
max_digits: Union[int, None] = _Unset,
|
||||
decimal_places: Union[int, None] = _Unset,
|
||||
examples: Optional[List[Any]] = None,
|
||||
example: Annotated[
|
||||
Optional[Any],
|
||||
deprecated(
|
||||
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
|
||||
"although still supported. Use examples instead."
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
super().__init__(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
media_type=media_type,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
validation_alias=validation_alias,
|
||||
serialization_alias=serialization_alias,
|
||||
title=title,
|
||||
description=description,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
pattern=pattern,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
strict=strict,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
deprecated=deprecated,
|
||||
example=example,
|
||||
examples=examples,
|
||||
openapi_examples=openapi_examples,
|
||||
include_in_schema=include_in_schema,
|
||||
json_schema_extra=json_schema_extra,
|
||||
**extra,
|
||||
)
|
||||
|
||||
|
||||
class File(Form): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
media_type: str = "multipart/form-data",
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
# TODO: update when deprecating Pydantic v1, import these types
|
||||
# validation_alias: str | AliasPath | AliasChoices | None
|
||||
validation_alias: Union[str, None] = None,
|
||||
serialization_alias: Union[str, None] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pattern: Optional[str] = None,
|
||||
regex: Annotated[
|
||||
Optional[str],
|
||||
deprecated(
|
||||
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
|
||||
),
|
||||
] = None,
|
||||
discriminator: Union[str, None] = None,
|
||||
strict: Union[bool, None] = _Unset,
|
||||
multiple_of: Union[float, None] = _Unset,
|
||||
allow_inf_nan: Union[bool, None] = _Unset,
|
||||
max_digits: Union[int, None] = _Unset,
|
||||
decimal_places: Union[int, None] = _Unset,
|
||||
examples: Optional[List[Any]] = None,
|
||||
example: Annotated[
|
||||
Optional[Any],
|
||||
deprecated(
|
||||
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
|
||||
"although still supported. Use examples instead."
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
super().__init__(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
media_type=media_type,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
validation_alias=validation_alias,
|
||||
serialization_alias=serialization_alias,
|
||||
title=title,
|
||||
description=description,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
pattern=pattern,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
strict=strict,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
deprecated=deprecated,
|
||||
example=example,
|
||||
examples=examples,
|
||||
openapi_examples=openapi_examples,
|
||||
include_in_schema=include_in_schema,
|
||||
json_schema_extra=json_schema_extra,
|
||||
**extra,
|
||||
)
|
||||
@@ -1,11 +1,11 @@
|
||||
import types
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Set, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
|
||||
UnionType = getattr(types, "UnionType", Union)
|
||||
NoneType = getattr(types, "UnionType", None)
|
||||
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
|
||||
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
|
||||
DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str]
|
||||
|
||||
@@ -23,10 +23,12 @@ from fastapi._compat import (
|
||||
Undefined,
|
||||
UndefinedType,
|
||||
Validator,
|
||||
annotation_is_pydantic_v1,
|
||||
lenient_issubclass,
|
||||
may_v1,
|
||||
)
|
||||
from fastapi.datastructures import DefaultPlaceholder, DefaultType
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Literal
|
||||
|
||||
@@ -34,9 +36,9 @@ if TYPE_CHECKING: # pragma: nocover
|
||||
from .routing import APIRoute
|
||||
|
||||
# Cache for `create_cloned_field`
|
||||
_CLONED_TYPES_CACHE: MutableMapping[
|
||||
Type[BaseModel], Type[BaseModel]
|
||||
] = WeakKeyDictionary()
|
||||
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
|
||||
WeakKeyDictionary()
|
||||
)
|
||||
|
||||
|
||||
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
|
||||
@@ -53,60 +55,81 @@ def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
|
||||
}:
|
||||
return True
|
||||
current_status_code = int(status_code)
|
||||
return not (current_status_code < 200 or current_status_code in {204, 304})
|
||||
return not (current_status_code < 200 or current_status_code in {204, 205, 304})
|
||||
|
||||
|
||||
def get_path_param_names(path: str) -> Set[str]:
|
||||
return set(re.findall("{(.*?)}", path))
|
||||
|
||||
|
||||
def create_response_field(
|
||||
_invalid_args_message = (
|
||||
"Invalid args for response field! Hint: "
|
||||
"check that {type_} is a valid Pydantic field type. "
|
||||
"If you are using a return type annotation that is not a valid Pydantic "
|
||||
"field (e.g. Union[Response, dict, None]) you can disable generating the "
|
||||
"response model from the type annotation with the path operation decorator "
|
||||
"parameter response_model=None. Read more: "
|
||||
"https://fastapi.tiangolo.com/tutorial/response-model/"
|
||||
)
|
||||
|
||||
|
||||
def create_model_field(
|
||||
name: str,
|
||||
type_: Type[Any],
|
||||
type_: Any,
|
||||
class_validators: Optional[Dict[str, Validator]] = None,
|
||||
default: Optional[Any] = Undefined,
|
||||
required: Union[bool, UndefinedType] = Undefined,
|
||||
model_config: Type[BaseConfig] = BaseConfig,
|
||||
model_config: Union[Type[BaseConfig], None] = None,
|
||||
field_info: Optional[FieldInfo] = None,
|
||||
alias: Optional[str] = None,
|
||||
mode: Literal["validation", "serialization"] = "validation",
|
||||
version: Literal["1", "auto"] = "auto",
|
||||
) -> ModelField:
|
||||
"""
|
||||
Create a new response field. Raises if type_ is invalid.
|
||||
"""
|
||||
class_validators = class_validators or {}
|
||||
if PYDANTIC_V2:
|
||||
|
||||
v1_model_config = may_v1.BaseConfig
|
||||
v1_field_info = field_info or may_v1.FieldInfo()
|
||||
v1_kwargs = {
|
||||
"name": name,
|
||||
"field_info": v1_field_info,
|
||||
"type_": type_,
|
||||
"class_validators": class_validators,
|
||||
"default": default,
|
||||
"required": required,
|
||||
"model_config": v1_model_config,
|
||||
"alias": alias,
|
||||
}
|
||||
|
||||
if (
|
||||
annotation_is_pydantic_v1(type_)
|
||||
or isinstance(field_info, may_v1.FieldInfo)
|
||||
or version == "1"
|
||||
):
|
||||
from fastapi._compat import v1
|
||||
|
||||
try:
|
||||
return v1.ModelField(**v1_kwargs) # type: ignore[no-any-return]
|
||||
except RuntimeError:
|
||||
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
|
||||
elif PYDANTIC_V2:
|
||||
from ._compat import v2
|
||||
|
||||
field_info = field_info or FieldInfo(
|
||||
annotation=type_, default=default, alias=alias
|
||||
)
|
||||
else:
|
||||
field_info = field_info or FieldInfo()
|
||||
kwargs = {"name": name, "field_info": field_info}
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update({"mode": mode})
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"type_": type_,
|
||||
"class_validators": class_validators,
|
||||
"default": default,
|
||||
"required": required,
|
||||
"model_config": model_config,
|
||||
"alias": alias,
|
||||
}
|
||||
)
|
||||
kwargs = {"mode": mode, "name": name, "field_info": field_info}
|
||||
try:
|
||||
return v2.ModelField(**kwargs) # type: ignore[return-value,arg-type]
|
||||
except PydanticSchemaGenerationError:
|
||||
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
|
||||
# Pydantic v2 is not installed, but it's not a Pydantic v1 ModelField, it could be
|
||||
# a Pydantic v1 type, like a constrained int
|
||||
from fastapi._compat import v1
|
||||
|
||||
try:
|
||||
return ModelField(**kwargs) # type: ignore[arg-type]
|
||||
except (RuntimeError, PydanticSchemaGenerationError):
|
||||
raise fastapi.exceptions.FastAPIError(
|
||||
"Invalid args for response field! Hint: "
|
||||
f"check that {type_} is a valid Pydantic field type. "
|
||||
"If you are using a return type annotation that is not a valid Pydantic "
|
||||
"field (e.g. Union[Response, dict, None]) you can disable generating the "
|
||||
"response model from the type annotation with the path operation decorator "
|
||||
"parameter response_model=None. Read more: "
|
||||
"https://fastapi.tiangolo.com/tutorial/response-model/"
|
||||
) from None
|
||||
return v1.ModelField(**v1_kwargs) # type: ignore[no-any-return]
|
||||
except RuntimeError:
|
||||
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
|
||||
|
||||
|
||||
def create_cloned_field(
|
||||
@@ -115,7 +138,13 @@ def create_cloned_field(
|
||||
cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
|
||||
) -> ModelField:
|
||||
if PYDANTIC_V2:
|
||||
return field
|
||||
from ._compat import v2
|
||||
|
||||
if isinstance(field, v2.ModelField):
|
||||
return field
|
||||
|
||||
from fastapi._compat import v1
|
||||
|
||||
# cloned_types caches already cloned types to support recursive models and improve
|
||||
# performance by avoiding unnecessary cloning
|
||||
if cloned_types is None:
|
||||
@@ -125,21 +154,23 @@ def create_cloned_field(
|
||||
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
|
||||
original_type = original_type.__pydantic_model__
|
||||
use_type = original_type
|
||||
if lenient_issubclass(original_type, BaseModel):
|
||||
original_type = cast(Type[BaseModel], original_type)
|
||||
if lenient_issubclass(original_type, v1.BaseModel):
|
||||
original_type = cast(Type[v1.BaseModel], original_type)
|
||||
use_type = cloned_types.get(original_type)
|
||||
if use_type is None:
|
||||
use_type = create_model(original_type.__name__, __base__=original_type)
|
||||
use_type = v1.create_model(original_type.__name__, __base__=original_type)
|
||||
cloned_types[original_type] = use_type
|
||||
for f in original_type.__fields__.values():
|
||||
use_type.__fields__[f.name] = create_cloned_field(
|
||||
f, cloned_types=cloned_types
|
||||
f,
|
||||
cloned_types=cloned_types,
|
||||
)
|
||||
new_field = create_response_field(name=field.name, type_=use_type)
|
||||
new_field = create_model_field(name=field.name, type_=use_type, version="1")
|
||||
new_field.has_alias = field.has_alias # type: ignore[attr-defined]
|
||||
new_field.alias = field.alias # type: ignore[misc]
|
||||
new_field.class_validators = field.class_validators # type: ignore[attr-defined]
|
||||
new_field.default = field.default # type: ignore[misc]
|
||||
new_field.default_factory = field.default_factory # type: ignore[attr-defined]
|
||||
new_field.required = field.required # type: ignore[misc]
|
||||
new_field.model_config = field.model_config # type: ignore[attr-defined]
|
||||
new_field.field_info = field.field_info
|
||||
@@ -173,17 +204,17 @@ def generate_operation_id_for_path(
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
operation_id = name + path
|
||||
operation_id = f"{name}{path}"
|
||||
operation_id = re.sub(r"\W", "_", operation_id)
|
||||
operation_id = operation_id + "_" + method.lower()
|
||||
operation_id = f"{operation_id}_{method.lower()}"
|
||||
return operation_id
|
||||
|
||||
|
||||
def generate_unique_id(route: "APIRoute") -> str:
|
||||
operation_id = route.name + route.path_format
|
||||
operation_id = f"{route.name}{route.path_format}"
|
||||
operation_id = re.sub(r"\W", "_", operation_id)
|
||||
assert route.methods
|
||||
operation_id = operation_id + "_" + list(route.methods)[0].lower()
|
||||
operation_id = f"{operation_id}_{list(route.methods)[0].lower()}"
|
||||
return operation_id
|
||||
|
||||
|
||||
@@ -221,9 +252,3 @@ def get_value_or_default(
|
||||
if not isinstance(item, DefaultPlaceholder):
|
||||
return item
|
||||
return first_item
|
||||
|
||||
|
||||
def match_pydantic_error_url(error_type: str) -> Any:
|
||||
from dirty_equals import IsStr
|
||||
|
||||
return IsStr(regex=rf"^https://errors\.pydantic\.dev/.*/v/{error_type}")
|
||||
|
||||
Reference in New Issue
Block a user