updates
This commit is contained in:
@@ -23,10 +23,12 @@ from fastapi._compat import (
|
||||
Undefined,
|
||||
UndefinedType,
|
||||
Validator,
|
||||
annotation_is_pydantic_v1,
|
||||
lenient_issubclass,
|
||||
may_v1,
|
||||
)
|
||||
from fastapi.datastructures import DefaultPlaceholder, DefaultType
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Literal
|
||||
|
||||
@@ -34,9 +36,9 @@ if TYPE_CHECKING: # pragma: nocover
|
||||
from .routing import APIRoute
|
||||
|
||||
# Cache for `create_cloned_field`
|
||||
_CLONED_TYPES_CACHE: MutableMapping[
|
||||
Type[BaseModel], Type[BaseModel]
|
||||
] = WeakKeyDictionary()
|
||||
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
|
||||
WeakKeyDictionary()
|
||||
)
|
||||
|
||||
|
||||
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
|
||||
@@ -53,60 +55,81 @@ def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
|
||||
}:
|
||||
return True
|
||||
current_status_code = int(status_code)
|
||||
return not (current_status_code < 200 or current_status_code in {204, 304})
|
||||
return not (current_status_code < 200 or current_status_code in {204, 205, 304})
|
||||
|
||||
|
||||
def get_path_param_names(path: str) -> Set[str]:
|
||||
return set(re.findall("{(.*?)}", path))
|
||||
|
||||
|
||||
def create_response_field(
|
||||
_invalid_args_message = (
|
||||
"Invalid args for response field! Hint: "
|
||||
"check that {type_} is a valid Pydantic field type. "
|
||||
"If you are using a return type annotation that is not a valid Pydantic "
|
||||
"field (e.g. Union[Response, dict, None]) you can disable generating the "
|
||||
"response model from the type annotation with the path operation decorator "
|
||||
"parameter response_model=None. Read more: "
|
||||
"https://fastapi.tiangolo.com/tutorial/response-model/"
|
||||
)
|
||||
|
||||
|
||||
def create_model_field(
|
||||
name: str,
|
||||
type_: Type[Any],
|
||||
type_: Any,
|
||||
class_validators: Optional[Dict[str, Validator]] = None,
|
||||
default: Optional[Any] = Undefined,
|
||||
required: Union[bool, UndefinedType] = Undefined,
|
||||
model_config: Type[BaseConfig] = BaseConfig,
|
||||
model_config: Union[Type[BaseConfig], None] = None,
|
||||
field_info: Optional[FieldInfo] = None,
|
||||
alias: Optional[str] = None,
|
||||
mode: Literal["validation", "serialization"] = "validation",
|
||||
version: Literal["1", "auto"] = "auto",
|
||||
) -> ModelField:
|
||||
"""
|
||||
Create a new response field. Raises if type_ is invalid.
|
||||
"""
|
||||
class_validators = class_validators or {}
|
||||
if PYDANTIC_V2:
|
||||
|
||||
v1_model_config = may_v1.BaseConfig
|
||||
v1_field_info = field_info or may_v1.FieldInfo()
|
||||
v1_kwargs = {
|
||||
"name": name,
|
||||
"field_info": v1_field_info,
|
||||
"type_": type_,
|
||||
"class_validators": class_validators,
|
||||
"default": default,
|
||||
"required": required,
|
||||
"model_config": v1_model_config,
|
||||
"alias": alias,
|
||||
}
|
||||
|
||||
if (
|
||||
annotation_is_pydantic_v1(type_)
|
||||
or isinstance(field_info, may_v1.FieldInfo)
|
||||
or version == "1"
|
||||
):
|
||||
from fastapi._compat import v1
|
||||
|
||||
try:
|
||||
return v1.ModelField(**v1_kwargs) # type: ignore[no-any-return]
|
||||
except RuntimeError:
|
||||
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
|
||||
elif PYDANTIC_V2:
|
||||
from ._compat import v2
|
||||
|
||||
field_info = field_info or FieldInfo(
|
||||
annotation=type_, default=default, alias=alias
|
||||
)
|
||||
else:
|
||||
field_info = field_info or FieldInfo()
|
||||
kwargs = {"name": name, "field_info": field_info}
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update({"mode": mode})
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"type_": type_,
|
||||
"class_validators": class_validators,
|
||||
"default": default,
|
||||
"required": required,
|
||||
"model_config": model_config,
|
||||
"alias": alias,
|
||||
}
|
||||
)
|
||||
kwargs = {"mode": mode, "name": name, "field_info": field_info}
|
||||
try:
|
||||
return v2.ModelField(**kwargs) # type: ignore[return-value,arg-type]
|
||||
except PydanticSchemaGenerationError:
|
||||
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
|
||||
# Pydantic v2 is not installed, but it's not a Pydantic v1 ModelField, it could be
|
||||
# a Pydantic v1 type, like a constrained int
|
||||
from fastapi._compat import v1
|
||||
|
||||
try:
|
||||
return ModelField(**kwargs) # type: ignore[arg-type]
|
||||
except (RuntimeError, PydanticSchemaGenerationError):
|
||||
raise fastapi.exceptions.FastAPIError(
|
||||
"Invalid args for response field! Hint: "
|
||||
f"check that {type_} is a valid Pydantic field type. "
|
||||
"If you are using a return type annotation that is not a valid Pydantic "
|
||||
"field (e.g. Union[Response, dict, None]) you can disable generating the "
|
||||
"response model from the type annotation with the path operation decorator "
|
||||
"parameter response_model=None. Read more: "
|
||||
"https://fastapi.tiangolo.com/tutorial/response-model/"
|
||||
) from None
|
||||
return v1.ModelField(**v1_kwargs) # type: ignore[no-any-return]
|
||||
except RuntimeError:
|
||||
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
|
||||
|
||||
|
||||
def create_cloned_field(
|
||||
@@ -115,7 +138,13 @@ def create_cloned_field(
|
||||
cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
|
||||
) -> ModelField:
|
||||
if PYDANTIC_V2:
|
||||
return field
|
||||
from ._compat import v2
|
||||
|
||||
if isinstance(field, v2.ModelField):
|
||||
return field
|
||||
|
||||
from fastapi._compat import v1
|
||||
|
||||
# cloned_types caches already cloned types to support recursive models and improve
|
||||
# performance by avoiding unnecessary cloning
|
||||
if cloned_types is None:
|
||||
@@ -125,21 +154,23 @@ def create_cloned_field(
|
||||
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
|
||||
original_type = original_type.__pydantic_model__
|
||||
use_type = original_type
|
||||
if lenient_issubclass(original_type, BaseModel):
|
||||
original_type = cast(Type[BaseModel], original_type)
|
||||
if lenient_issubclass(original_type, v1.BaseModel):
|
||||
original_type = cast(Type[v1.BaseModel], original_type)
|
||||
use_type = cloned_types.get(original_type)
|
||||
if use_type is None:
|
||||
use_type = create_model(original_type.__name__, __base__=original_type)
|
||||
use_type = v1.create_model(original_type.__name__, __base__=original_type)
|
||||
cloned_types[original_type] = use_type
|
||||
for f in original_type.__fields__.values():
|
||||
use_type.__fields__[f.name] = create_cloned_field(
|
||||
f, cloned_types=cloned_types
|
||||
f,
|
||||
cloned_types=cloned_types,
|
||||
)
|
||||
new_field = create_response_field(name=field.name, type_=use_type)
|
||||
new_field = create_model_field(name=field.name, type_=use_type, version="1")
|
||||
new_field.has_alias = field.has_alias # type: ignore[attr-defined]
|
||||
new_field.alias = field.alias # type: ignore[misc]
|
||||
new_field.class_validators = field.class_validators # type: ignore[attr-defined]
|
||||
new_field.default = field.default # type: ignore[misc]
|
||||
new_field.default_factory = field.default_factory # type: ignore[attr-defined]
|
||||
new_field.required = field.required # type: ignore[misc]
|
||||
new_field.model_config = field.model_config # type: ignore[attr-defined]
|
||||
new_field.field_info = field.field_info
|
||||
@@ -173,17 +204,17 @@ def generate_operation_id_for_path(
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
operation_id = name + path
|
||||
operation_id = f"{name}{path}"
|
||||
operation_id = re.sub(r"\W", "_", operation_id)
|
||||
operation_id = operation_id + "_" + method.lower()
|
||||
operation_id = f"{operation_id}_{method.lower()}"
|
||||
return operation_id
|
||||
|
||||
|
||||
def generate_unique_id(route: "APIRoute") -> str:
|
||||
operation_id = route.name + route.path_format
|
||||
operation_id = f"{route.name}{route.path_format}"
|
||||
operation_id = re.sub(r"\W", "_", operation_id)
|
||||
assert route.methods
|
||||
operation_id = operation_id + "_" + list(route.methods)[0].lower()
|
||||
operation_id = f"{operation_id}_{list(route.methods)[0].lower()}"
|
||||
return operation_id
|
||||
|
||||
|
||||
@@ -221,9 +252,3 @@ def get_value_or_default(
|
||||
if not isinstance(item, DefaultPlaceholder):
|
||||
return item
|
||||
return first_item
|
||||
|
||||
|
||||
def match_pydantic_error_url(error_type: str) -> Any:
|
||||
from dirty_equals import IsStr
|
||||
|
||||
return IsStr(regex=rf"^https://errors\.pydantic\.dev/.*/v/{error_type}")
|
||||
|
||||
Reference in New Issue
Block a user