updates
@@ -2,30 +2,26 @@ from __future__ import annotations as _annotations

import warnings
from contextlib import contextmanager
from re import Pattern
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Literal,
    cast,
)

from pydantic_core import core_schema
from typing_extensions import (
    Literal,
    Self,
)
from typing_extensions import Self

from ..aliases import AliasGenerator
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
from ..warnings import PydanticDeprecatedSince20

if not TYPE_CHECKING:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20
from ..warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince210

if TYPE_CHECKING:
    from .._internal._schema_generation_shared import GenerateSchema
    from ..fields import ComputedFieldInfo, FieldInfo

DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'

@@ -55,7 +51,9 @@ class ConfigWrapper:
    # whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
    # to construct error `loc`s, default `True`
    loc_by_alias: bool
    alias_generator: Callable[[str], str] | None
    alias_generator: Callable[[str], str] | AliasGenerator | None
    model_title_generator: Callable[[type], str] | None
    field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
    ignored_types: tuple[type, ...]
    allow_inf_nan: bool
    json_schema_extra: JsonDict | JsonSchemaExtraCallable | None

@@ -66,11 +64,15 @@ class ConfigWrapper:
    # whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
    revalidate_instances: Literal['always', 'never', 'subclass-instances']
    ser_json_timedelta: Literal['iso8601', 'float']
    ser_json_bytes: Literal['utf8', 'base64']
    ser_json_temporal: Literal['iso8601', 'seconds', 'milliseconds']
    val_temporal_unit: Literal['seconds', 'milliseconds', 'infer']
    ser_json_bytes: Literal['utf8', 'base64', 'hex']
    val_json_bytes: Literal['utf8', 'base64', 'hex']
    ser_json_inf_nan: Literal['null', 'constants', 'strings']
    # whether to validate default values during validation, default False
    validate_default: bool
    validate_return: bool
    protected_namespaces: tuple[str, ...]
    protected_namespaces: tuple[str | Pattern[str], ...]
    hide_input_in_errors: bool
    defer_build: bool
    plugin_settings: dict[str, object] | None

@@ -80,6 +82,12 @@ class ConfigWrapper:
    coerce_numbers_to_str: bool
    regex_engine: Literal['rust-regex', 'python-re']
    validation_error_cause: bool
    use_attribute_docstrings: bool
    cache_strings: bool | Literal['all', 'keys', 'none']
    validate_by_alias: bool
    validate_by_name: bool
    serialize_by_alias: bool
    url_preserve_empty_path: bool

    def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
        if check:

@@ -88,7 +96,13 @@ class ConfigWrapper:
            self.config_dict = cast(ConfigDict, config)

    @classmethod
    def for_model(cls, bases: tuple[type[Any], ...], namespace: dict[str, Any], kwargs: dict[str, Any]) -> Self:
    def for_model(
        cls,
        bases: tuple[type[Any], ...],
        namespace: dict[str, Any],
        raw_annotations: dict[str, Any],
        kwargs: dict[str, Any],
    ) -> Self:
        """Build a new `ConfigWrapper` instance for a `BaseModel`.

        The config wrapper built based on (in descending order of priority):

@@ -99,6 +113,7 @@ class ConfigWrapper:
        Args:
            bases: A tuple of base classes.
            namespace: The namespace of the class being created.
            raw_annotations: The (non-evaluated) annotations of the model.
            kwargs: The kwargs passed to the class being created.

        Returns:

@@ -113,6 +128,12 @@ class ConfigWrapper:
        config_class_from_namespace = namespace.get('Config')
        config_dict_from_namespace = namespace.get('model_config')

        if raw_annotations.get('model_config') and config_dict_from_namespace is None:
            raise PydanticUserError(
                '`model_config` cannot be used as a model field name. Use `model_config` for model configuration.',
                code='model-config-invalid-field-name',
            )

        if config_class_from_namespace and config_dict_from_namespace:
            raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')

@@ -138,48 +159,80 @@ class ConfigWrapper:
        except KeyError:
            raise AttributeError(f'Config has no attribute {name!r}') from None

    def core_config(self, obj: Any) -> core_schema.CoreConfig:
        """Create a pydantic-core config, `obj` is just used to populate `title` if not set in config.

        Pass `obj=None` if you do not want to attempt to infer the `title`.
    def core_config(self, title: str | None) -> core_schema.CoreConfig:
        """Create a pydantic-core config.

        We don't use getattr here since we don't want to populate with defaults.

        Args:
            obj: An object used to populate `title` if not set in config.
            title: The title to use if not set in config.

        Returns:
            A `CoreConfig` object created from config.
        """
        config = self.config_dict

        def dict_not_none(**kwargs: Any) -> Any:
            return {k: v for k, v in kwargs.items() if v is not None}

        core_config = core_schema.CoreConfig(
            **dict_not_none(
                title=self.config_dict.get('title') or (obj and obj.__name__),
                extra_fields_behavior=self.config_dict.get('extra'),
                allow_inf_nan=self.config_dict.get('allow_inf_nan'),
                populate_by_name=self.config_dict.get('populate_by_name'),
                str_strip_whitespace=self.config_dict.get('str_strip_whitespace'),
                str_to_lower=self.config_dict.get('str_to_lower'),
                str_to_upper=self.config_dict.get('str_to_upper'),
                strict=self.config_dict.get('strict'),
                ser_json_timedelta=self.config_dict.get('ser_json_timedelta'),
                ser_json_bytes=self.config_dict.get('ser_json_bytes'),
                from_attributes=self.config_dict.get('from_attributes'),
                loc_by_alias=self.config_dict.get('loc_by_alias'),
                revalidate_instances=self.config_dict.get('revalidate_instances'),
                validate_default=self.config_dict.get('validate_default'),
                str_max_length=self.config_dict.get('str_max_length'),
                str_min_length=self.config_dict.get('str_min_length'),
                hide_input_in_errors=self.config_dict.get('hide_input_in_errors'),
                coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'),
                regex_engine=self.config_dict.get('regex_engine'),
                validation_error_cause=self.config_dict.get('validation_error_cause'),
        if config.get('schema_generator') is not None:
            warnings.warn(
                'The `schema_generator` setting has been deprecated since v2.10. This setting no longer has any effect.',
                PydanticDeprecatedSince210,
                stacklevel=2,
            )

        if (populate_by_name := config.get('populate_by_name')) is not None:
            # We include this patch for backwards compatibility purposes, but this config setting will be deprecated in v3.0, and likely removed in v4.0.
            # Thus, the above warning and this patch can be removed then as well.
            if config.get('validate_by_name') is None:
                config['validate_by_alias'] = True
                config['validate_by_name'] = populate_by_name

        # We dynamically patch validate_by_name to be True if validate_by_alias is set to False
        # and validate_by_name is not explicitly set.
        if config.get('validate_by_alias') is False and config.get('validate_by_name') is None:
            config['validate_by_name'] = True

        if (not config.get('validate_by_alias', True)) and (not config.get('validate_by_name', False)):
            raise PydanticUserError(
                'At least one of `validate_by_alias` or `validate_by_name` must be set to True.',
                code='validate-by-alias-and-name-false',
            )

        return core_schema.CoreConfig(
            **{  # pyright: ignore[reportArgumentType]
                k: v
                for k, v in (
                    ('title', config.get('title') or title or None),
                    ('extra_fields_behavior', config.get('extra')),
                    ('allow_inf_nan', config.get('allow_inf_nan')),
                    ('str_strip_whitespace', config.get('str_strip_whitespace')),
                    ('str_to_lower', config.get('str_to_lower')),
                    ('str_to_upper', config.get('str_to_upper')),
                    ('strict', config.get('strict')),
                    ('ser_json_timedelta', config.get('ser_json_timedelta')),
                    ('ser_json_temporal', config.get('ser_json_temporal')),
                    ('val_temporal_unit', config.get('val_temporal_unit')),
                    ('ser_json_bytes', config.get('ser_json_bytes')),
                    ('val_json_bytes', config.get('val_json_bytes')),
                    ('ser_json_inf_nan', config.get('ser_json_inf_nan')),
                    ('from_attributes', config.get('from_attributes')),
                    ('loc_by_alias', config.get('loc_by_alias')),
                    ('revalidate_instances', config.get('revalidate_instances')),
                    ('validate_default', config.get('validate_default')),
                    ('str_max_length', config.get('str_max_length')),
                    ('str_min_length', config.get('str_min_length')),
                    ('hide_input_in_errors', config.get('hide_input_in_errors')),
                    ('coerce_numbers_to_str', config.get('coerce_numbers_to_str')),
                    ('regex_engine', config.get('regex_engine')),
                    ('validation_error_cause', config.get('validation_error_cause')),
                    ('cache_strings', config.get('cache_strings')),
                    ('validate_by_alias', config.get('validate_by_alias')),
                    ('validate_by_name', config.get('validate_by_name')),
                    ('serialize_by_alias', config.get('serialize_by_alias')),
                    ('url_preserve_empty_path', config.get('url_preserve_empty_path')),
                )
                if v is not None
            }
        )
        return core_config

    def __repr__(self):
        c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())

@@ -229,26 +282,38 @@ config_defaults = ConfigDict(
    from_attributes=False,
    loc_by_alias=True,
    alias_generator=None,
    model_title_generator=None,
    field_title_generator=None,
    ignored_types=(),
    allow_inf_nan=True,
    json_schema_extra=None,
    strict=False,
    revalidate_instances='never',
    ser_json_timedelta='iso8601',
    ser_json_temporal='iso8601',
    val_temporal_unit='infer',
    ser_json_bytes='utf8',
    val_json_bytes='utf8',
    ser_json_inf_nan='null',
    validate_default=False,
    validate_return=False,
    protected_namespaces=('model_',),
    protected_namespaces=('model_validate', 'model_dump'),
    hide_input_in_errors=False,
    json_encoders=None,
    defer_build=False,
    plugin_settings=None,
    schema_generator=None,
    plugin_settings=None,
    json_schema_serialization_defaults_required=False,
    json_schema_mode_override=None,
    coerce_numbers_to_str=False,
    regex_engine='rust-regex',
    validation_error_cause=False,
    use_attribute_docstrings=False,
    cache_strings=True,
    validate_by_alias=True,
    validate_by_name=False,
    serialize_by_alias=False,
    url_preserve_empty_path=False,
)

@@ -265,7 +330,7 @@ def prepare_config(config: ConfigDict | dict[str, Any] | type[Any] | None) -> ConfigDict:
        return ConfigDict()

    if not isinstance(config, dict):
        warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
        warnings.warn(DEPRECATION_MESSAGE, PydanticDeprecatedSince20, stacklevel=4)
        config = {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}

    config_dict = cast(ConfigDict, config)

@@ -289,7 +354,7 @@ V2_REMOVED_KEYS = {
    'post_init_call',
}
V2_RENAMED_KEYS = {
    'allow_population_by_field_name': 'populate_by_name',
    'allow_population_by_field_name': 'validate_by_name',
    'anystr_lower': 'str_to_lower',
    'anystr_strip_whitespace': 'str_strip_whitespace',
    'anystr_upper': 'str_to_upper',
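Note on the config changes above: the commit threads the newer `validate_by_alias` / `validate_by_name` settings through `core_config`, keeps `populate_by_name` working as a backwards-compatible spelling of `validate_by_name`, and rejects configs where both are disabled. A minimal sketch of what this means for user code, assuming a current pydantic v2 install (the model and field below are made up for illustration, not part of the diff):

    from pydantic import BaseModel, ConfigDict, Field


    class User(BaseModel):
        # Validation accepts the attribute name only; the alias is still available for serialization.
        model_config = ConfigDict(validate_by_alias=False, validate_by_name=True)

        full_name: str = Field(alias='fullName')


    print(User(full_name='Ada Lovelace'))  # accepted via the attribute name
    # ConfigDict(validate_by_alias=False, validate_by_name=False) would instead raise
    # PydanticUserError with code 'validate-by-alias-and-name-false', matching the check added above,
    # and legacy ConfigDict(populate_by_name=True) is rewritten to validate_by_name=True.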
@@ -1,92 +1,97 @@
from __future__ import annotations as _annotations

import typing
from typing import Any
from typing import TYPE_CHECKING, Any, TypedDict, cast
from warnings import warn

import typing_extensions

if typing.TYPE_CHECKING:
    from ._schema_generation_shared import (
        CoreSchemaOrField as CoreSchemaOrField,
    )
if TYPE_CHECKING:
    from ..config import JsonDict, JsonSchemaExtraCallable
    from ._schema_generation_shared import (
        GetJsonSchemaFunction,
    )


class CoreMetadata(typing_extensions.TypedDict, total=False):
class CoreMetadata(TypedDict, total=False):
    """A `TypedDict` for holding the metadata dict of the schema.

    Attributes:
        pydantic_js_functions: List of JSON schema functions.
        pydantic_js_functions: List of JSON schema functions that resolve refs during application.
        pydantic_js_annotation_functions: List of JSON schema functions that don't resolve refs during application.
        pydantic_js_prefer_positional_arguments: Whether JSON schema generator will
            prefer positional over keyword arguments for an 'arguments' schema.
            custom validation function. Only applies to before, plain, and wrap validators.
        pydantic_js_updates: key / value pair updates to apply to the JSON schema for a type.
        pydantic_js_extra: WIP, either key/value pair updates to apply to the JSON schema, or a custom callable.
        pydantic_internal_union_tag_key: Used internally by the `Tag` metadata to specify the tag used for a discriminated union.
        pydantic_internal_union_discriminator: Used internally to specify the discriminator value for a discriminated union
            when the discriminator was applied to a `'definition-ref'` schema, and that reference was missing at the time
            of the annotation application.

    TODO: Perhaps we should move this structure to pydantic-core. At the moment, though,
    it's easier to iterate on if we leave it in pydantic until we feel there is a semi-stable API.

    TODO: It's unfortunate how functionally oriented JSON schema generation is, especially that which occurs during
    the core schema generation process. It's inevitable that we need to store some json schema related information
    on core schemas, given that we generate JSON schemas directly from core schemas. That being said, debugging related
    issues is quite difficult when JSON schema information is disguised via dynamically defined functions.
    """

    pydantic_js_functions: list[GetJsonSchemaFunction]
    pydantic_js_annotation_functions: list[GetJsonSchemaFunction]

    # If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will
    # prefer positional over keyword arguments for an 'arguments' schema.
    pydantic_js_prefer_positional_arguments: bool | None

    pydantic_typed_dict_cls: type[Any] | None  # TODO: Consider moving this into the pydantic-core TypedDictSchema
    pydantic_js_prefer_positional_arguments: bool
    pydantic_js_updates: JsonDict
    pydantic_js_extra: JsonDict | JsonSchemaExtraCallable
    pydantic_internal_union_tag_key: str
    pydantic_internal_union_discriminator: str


class CoreMetadataHandler:
    """Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents.
def update_core_metadata(
    core_metadata: Any,
    /,
    *,
    pydantic_js_functions: list[GetJsonSchemaFunction] | None = None,
    pydantic_js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
    pydantic_js_updates: JsonDict | None = None,
    pydantic_js_extra: JsonDict | JsonSchemaExtraCallable | None = None,
) -> None:
    from ..json_schema import PydanticJsonSchemaWarning

    This class is used to interact with the metadata field on a CoreSchema object in a consistent
    way throughout pydantic.
    """Update CoreMetadata instance in place. When we make modifications in this function, they
    take effect on the `core_metadata` reference passed in as the first (and only) positional argument.

    First, cast to `CoreMetadata`, then finish with a cast to `dict[str, Any]` for core schema compatibility.
    We do this here, instead of before / after each call to this function so that this typing hack
    can be easily removed if/when we move `CoreMetadata` to `pydantic-core`.

    For parameter descriptions, see `CoreMetadata` above.
    """
    core_metadata = cast(CoreMetadata, core_metadata)

    __slots__ = ('_schema',)
    if pydantic_js_functions:
        core_metadata.setdefault('pydantic_js_functions', []).extend(pydantic_js_functions)

    def __init__(self, schema: CoreSchemaOrField):
        self._schema = schema
    if pydantic_js_annotation_functions:
        core_metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions)

        metadata = schema.get('metadata')
        if metadata is None:
            schema['metadata'] = CoreMetadata()
        elif not isinstance(metadata, dict):
            raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
    if pydantic_js_updates:
        if (existing_updates := core_metadata.get('pydantic_js_updates')) is not None:
            core_metadata['pydantic_js_updates'] = {**existing_updates, **pydantic_js_updates}
        else:
            core_metadata['pydantic_js_updates'] = pydantic_js_updates

    @property
    def metadata(self) -> CoreMetadata:
        """Retrieves the metadata dict from the schema, initializing it to a dict if it is None
        and raises an error if it is not a dict.
        """
        metadata = self._schema.get('metadata')
        if metadata is None:
            self._schema['metadata'] = metadata = CoreMetadata()
        if not isinstance(metadata, dict):
            raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
        return metadata


def build_metadata_dict(
    *,  # force keyword arguments to make it easier to modify this signature in a backwards-compatible way
    js_functions: list[GetJsonSchemaFunction] | None = None,
    js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
    js_prefer_positional_arguments: bool | None = None,
    typed_dict_cls: type[Any] | None = None,
    initial_metadata: Any | None = None,
) -> Any:
    """Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent
    with the CoreMetadataHandler class.
    """
    if initial_metadata is not None and not isinstance(initial_metadata, dict):
        raise TypeError(f'CoreSchema metadata should be a dict; got {initial_metadata!r}.')

    metadata = CoreMetadata(
        pydantic_js_functions=js_functions or [],
        pydantic_js_annotation_functions=js_annotation_functions or [],
        pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
        pydantic_typed_dict_cls=typed_dict_cls,
    )
    metadata = {k: v for k, v in metadata.items() if v is not None}

    if initial_metadata is not None:
        metadata = {**initial_metadata, **metadata}

    return metadata
    if pydantic_js_extra is not None:
        existing_pydantic_js_extra = core_metadata.get('pydantic_js_extra')
        if existing_pydantic_js_extra is None:
            core_metadata['pydantic_js_extra'] = pydantic_js_extra
        if isinstance(existing_pydantic_js_extra, dict):
            if isinstance(pydantic_js_extra, dict):
                core_metadata['pydantic_js_extra'] = {**existing_pydantic_js_extra, **pydantic_js_extra}
            if callable(pydantic_js_extra):
                warn(
                    'Composing `dict` and `callable` type `json_schema_extra` is not supported.'
                    'The `callable` type is being ignored.'
                    "If you'd like support for this behavior, please open an issue on pydantic.",
                    PydanticJsonSchemaWarning,
                )
        if callable(existing_pydantic_js_extra):
            # if ever there's a case of a callable, we'll just keep the last json schema extra spec
            core_metadata['pydantic_js_extra'] = pydantic_js_extra
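For context, the new `update_core_metadata` above mutates a schema's metadata dict in place: it extends the JSON-schema function lists and merges `pydantic_js_updates` / `pydantic_js_extra`. A rough usage sketch, assuming the helper is importable from the internal `pydantic._internal._core_metadata` module as in the diff (the metadata values are invented examples):

    from pydantic._internal._core_metadata import update_core_metadata  # assumed internal path

    metadata: dict = {'pydantic_js_updates': {'title': 'Old title'}}

    update_core_metadata(
        metadata,
        pydantic_js_updates={'examples': [1, 2, 3]},  # merged into the existing updates dict
        pydantic_js_extra={'deprecated': True},       # simply set, since nothing was there before
    )

    assert metadata['pydantic_js_updates'] == {'title': 'Old title', 'examples': [1, 2, 3]}
    assert metadata['pydantic_js_extra'] == {'deprecated': True}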
@@ -1,23 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Hashable,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
import inspect
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic_core import validate_core_schema as _validate_core_schema
|
||||
from typing_extensions import TypeAliasType, TypeGuard, get_args, get_origin
|
||||
from typing_extensions import TypeGuard, get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
|
||||
from . import _repr
|
||||
from ._typing_extra import is_generic_alias
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
AnyFunctionSchema = Union[
|
||||
core_schema.AfterValidatorFunctionSchema,
|
||||
core_schema.BeforeValidatorFunctionSchema,
|
||||
@@ -39,23 +35,7 @@ CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
|
||||
|
||||
_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
|
||||
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
|
||||
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'}
|
||||
|
||||
_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'
|
||||
|
||||
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'
|
||||
"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of
|
||||
schema building because one of it's members refers to a definition that was not yet defined when the union
|
||||
was first encountered.
|
||||
"""
|
||||
TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
|
||||
"""
|
||||
Used in a `Tag` schema to specify the tag used for a discriminated union.
|
||||
"""
|
||||
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
|
||||
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
|
||||
schema was first encountered.
|
||||
"""
|
||||
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
|
||||
|
||||
|
||||
def is_core_schema(
|
||||
@@ -78,13 +58,11 @@ def is_function_with_inner_schema(
|
||||
|
||||
def is_list_like_schema_with_items_schema(
|
||||
schema: CoreSchema,
|
||||
) -> TypeGuard[
|
||||
core_schema.ListSchema | core_schema.TupleVariableSchema | core_schema.SetSchema | core_schema.FrozenSetSchema
|
||||
]:
|
||||
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
|
||||
return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
|
||||
|
||||
|
||||
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
|
||||
def get_type_ref(type_: Any, args_override: tuple[type[Any], ...] | None = None) -> str:
|
||||
"""Produces the ref to be used for this type by pydantic_core's core schemas.
|
||||
|
||||
This `args_override` argument was added for the purpose of creating valid recursive references
|
||||
@@ -99,7 +77,7 @@ def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None =
|
||||
args = generic_metadata['args'] or args
|
||||
|
||||
module_name = getattr(origin, '__module__', '<No __module__>')
|
||||
if isinstance(origin, TypeAliasType):
|
||||
if typing_objects.is_typealiastype(origin):
|
||||
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
|
||||
else:
|
||||
try:
|
||||
@@ -129,457 +107,68 @@ def get_ref(s: core_schema.CoreSchema) -> None | str:
|
||||
return s.get('ref', None)
|
||||
|
||||
|
||||
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
|
||||
defs: dict[str, CoreSchema] = {}
|
||||
def _clean_schema_for_pretty_print(obj: Any, strip_metadata: bool = True) -> Any: # pragma: no cover
|
||||
"""A utility function to remove irrelevant information from a core schema."""
|
||||
if isinstance(obj, Mapping):
|
||||
new_dct = {}
|
||||
for k, v in obj.items():
|
||||
if k == 'metadata' and strip_metadata:
|
||||
new_metadata = {}
|
||||
|
||||
def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
ref = get_ref(s)
|
||||
if ref:
|
||||
defs[ref] = s
|
||||
return recurse(s, _record_valid_refs)
|
||||
for meta_k, meta_v in v.items():
|
||||
if meta_k in ('pydantic_js_functions', 'pydantic_js_annotation_functions'):
|
||||
new_metadata['js_metadata'] = '<stripped>'
|
||||
else:
|
||||
new_metadata[meta_k] = _clean_schema_for_pretty_print(meta_v, strip_metadata=strip_metadata)
|
||||
|
||||
walk_core_schema(schema, _record_valid_refs)
|
||||
if list(new_metadata.keys()) == ['js_metadata']:
|
||||
new_metadata = {'<stripped>'}
|
||||
|
||||
return defs
|
||||
|
||||
|
||||
def define_expected_missing_refs(
|
||||
schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
|
||||
) -> core_schema.CoreSchema | None:
|
||||
if not allowed_missing_refs:
|
||||
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
|
||||
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
|
||||
return None
|
||||
|
||||
refs = collect_definitions(schema).keys()
|
||||
|
||||
expected_missing_refs = allowed_missing_refs.difference(refs)
|
||||
if expected_missing_refs:
|
||||
definitions: list[core_schema.CoreSchema] = [
|
||||
# TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
|
||||
# Issue: https://github.com/pydantic/pydantic-core/issues/619
|
||||
core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
|
||||
for ref in expected_missing_refs
|
||||
]
|
||||
return core_schema.definitions_schema(schema, definitions)
|
||||
return None
|
||||
|
||||
|
||||
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
|
||||
invalid = False
|
||||
|
||||
def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
nonlocal invalid
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
|
||||
invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]
|
||||
return s
|
||||
return recurse(s, _is_schema_valid)
|
||||
|
||||
walk_core_schema(schema, _is_schema_valid)
|
||||
return invalid
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema]
|
||||
Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema]
|
||||
|
||||
# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
|
||||
# Issue: https://github.com/pydantic/pydantic-core/issues/615
|
||||
|
||||
|
||||
class _WalkCoreSchema:
|
||||
def __init__(self):
|
||||
self._schema_type_to_method = self._build_schema_type_to_method()
|
||||
|
||||
def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
|
||||
mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
|
||||
key: core_schema.CoreSchemaType
|
||||
for key in get_args(core_schema.CoreSchemaType):
|
||||
method_name = f"handle_{key.replace('-', '_')}_schema"
|
||||
mapping[key] = getattr(self, method_name, self._handle_other_schemas)
|
||||
return mapping
|
||||
|
||||
def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
return f(schema, self._walk)
|
||||
|
||||
def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
|
||||
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
|
||||
if ser_schema:
|
||||
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
|
||||
return schema
|
||||
|
||||
def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
sub_schema = schema.get('schema', None)
|
||||
if sub_schema is not None:
|
||||
schema['schema'] = self.walk(sub_schema, f) # type: ignore
|
||||
return schema
|
||||
|
||||
def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
|
||||
schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
|
||||
if schema is not None:
|
||||
ser_schema['schema'] = self.walk(schema, f) # type: ignore
|
||||
return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
|
||||
if return_schema is not None:
|
||||
ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore
|
||||
return ser_schema
|
||||
|
||||
def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_definitions: list[core_schema.CoreSchema] = []
|
||||
for definition in schema['definitions']:
|
||||
updated_definition = self.walk(definition, f)
|
||||
if 'ref' in updated_definition:
|
||||
# If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
|
||||
# This is most likely to happen due to replacing something with a definition reference, in
|
||||
# which case it should certainly not go in the definitions list
|
||||
new_definitions.append(updated_definition)
|
||||
new_inner_schema = self.walk(schema['schema'], f)
|
||||
|
||||
if not new_definitions and len(schema) == 3:
|
||||
# This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema
|
||||
return new_inner_schema
|
||||
|
||||
new_schema = schema.copy()
|
||||
new_schema['schema'] = new_inner_schema
|
||||
new_schema['definitions'] = new_definitions
|
||||
return new_schema
|
||||
|
||||
def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_tuple_variable_schema(
|
||||
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = cast(core_schema.TupleVariableSchema, schema)
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_tuple_positional_schema(
|
||||
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = cast(core_schema.TuplePositionalSchema, schema)
|
||||
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
keys_schema = schema.get('keys_schema')
|
||||
if keys_schema is not None:
|
||||
schema['keys_schema'] = self.walk(keys_schema, f)
|
||||
values_schema = schema.get('values_schema')
|
||||
if values_schema:
|
||||
schema['values_schema'] = self.walk(values_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
if not is_function_with_inner_schema(schema):
|
||||
return schema
|
||||
schema['schema'] = self.walk(schema['schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
|
||||
for v in schema['choices']:
|
||||
if isinstance(v, tuple):
|
||||
new_choices.append((self.walk(v[0], f), v[1]))
|
||||
new_dct[k] = new_metadata
|
||||
# Remove some defaults:
|
||||
elif k in ('custom_init', 'root_model') and not v:
|
||||
continue
|
||||
else:
|
||||
new_choices.append(self.walk(v, f))
|
||||
schema['choices'] = new_choices
|
||||
return schema
|
||||
new_dct[k] = _clean_schema_for_pretty_print(v, strip_metadata=strip_metadata)
|
||||
|
||||
def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_choices: dict[Hashable, core_schema.CoreSchema] = {}
|
||||
for k, v in schema['choices'].items():
|
||||
new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
|
||||
schema['choices'] = new_choices
|
||||
return schema
|
||||
|
||||
def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['steps'] = [self.walk(v, f) for v in schema['steps']]
|
||||
return schema
|
||||
|
||||
def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['lax_schema'] = self.walk(schema['lax_schema'], f)
|
||||
schema['strict_schema'] = self.walk(schema['strict_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['json_schema'] = self.walk(schema['json_schema'], f)
|
||||
schema['python_schema'] = self.walk(schema['python_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
replaced_fields: dict[str, core_schema.ModelField] = {}
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
for k, v in schema['fields'].items():
|
||||
replaced_field = v.copy()
|
||||
replaced_field['schema'] = self.walk(v['schema'], f)
|
||||
replaced_fields[k] = replaced_field
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
replaced_fields: dict[str, core_schema.TypedDictField] = {}
|
||||
for k, v in schema['fields'].items():
|
||||
replaced_field = v.copy()
|
||||
replaced_field['schema'] = self.walk(v['schema'], f)
|
||||
replaced_fields[k] = replaced_field
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
replaced_fields: list[core_schema.DataclassField] = []
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
for field in schema['fields']:
|
||||
replaced_field = field.copy()
|
||||
replaced_field['schema'] = self.walk(field['schema'], f)
|
||||
replaced_fields.append(replaced_field)
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
|
||||
for param in schema['arguments_schema']:
|
||||
replaced_param = param.copy()
|
||||
replaced_param['schema'] = self.walk(param['schema'], f)
|
||||
replaced_arguments_schema.append(replaced_param)
|
||||
schema['arguments_schema'] = replaced_arguments_schema
|
||||
if 'var_args_schema' in schema:
|
||||
schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
|
||||
if 'var_kwargs_schema' in schema:
|
||||
schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
|
||||
if 'return_schema' in schema:
|
||||
schema['return_schema'] = self.walk(schema['return_schema'], f)
|
||||
return schema
|
||||
|
||||
|
||||
_dispatch = _WalkCoreSchema().walk
|
||||
|
||||
|
||||
def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
"""Recursively traverse a CoreSchema.
|
||||
|
||||
Args:
|
||||
schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified.
|
||||
f (Walk): A function to apply. This function takes two arguments:
|
||||
1. The current CoreSchema that is being processed
|
||||
(not the same one you passed into this function, one level down).
|
||||
2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)`
|
||||
to pass data down the recursive calls without using globals or other mutable state.
|
||||
|
||||
Returns:
|
||||
core_schema.CoreSchema: A processed CoreSchema.
|
||||
"""
|
||||
return f(schema.copy(), _dispatch)
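The `walk_core_schema` docstring above describes the `f(schema, recurse)` callback protocol. A small sketch of how the (now-removed) walker could be driven, assuming it runs in the same module so `walk_core_schema` is in scope; the counting logic is only an example, not part of the commit:

    from collections import Counter

    from pydantic_core import core_schema

    counts: Counter[str] = Counter()


    def count_types(s, recurse):
        # record the node type, then keep walking into child schemas
        counts[s['type']] += 1
        return recurse(s, count_types)


    schema = core_schema.list_schema(core_schema.int_schema())
    walk_core_schema(schema, count_types)
    print(counts)  # expected: Counter({'list': 1, 'int': 1})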
|
||||
|
||||
|
||||
def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
|
||||
definitions: dict[str, core_schema.CoreSchema] = {}
|
||||
ref_counts: dict[str, int] = defaultdict(int)
|
||||
involved_in_recursion: dict[str, bool] = {}
|
||||
current_recursion_ref_count: dict[str, int] = defaultdict(int)
|
||||
|
||||
def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] == 'definitions':
|
||||
for definition in s['definitions']:
|
||||
ref = get_ref(definition)
|
||||
assert ref is not None
|
||||
if ref not in definitions:
|
||||
definitions[ref] = definition
|
||||
recurse(definition, collect_refs)
|
||||
return recurse(s['schema'], collect_refs)
|
||||
else:
|
||||
ref = get_ref(s)
|
||||
if ref is not None:
|
||||
new = recurse(s, collect_refs)
|
||||
new_ref = get_ref(new)
|
||||
if new_ref:
|
||||
definitions[new_ref] = new
|
||||
return core_schema.definition_reference_schema(schema_ref=ref)
|
||||
else:
|
||||
return recurse(s, collect_refs)
|
||||
|
||||
schema = walk_core_schema(schema, collect_refs)
|
||||
|
||||
def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] != 'definition-ref':
|
||||
return recurse(s, count_refs)
|
||||
ref = s['schema_ref']
|
||||
ref_counts[ref] += 1
|
||||
|
||||
if ref_counts[ref] >= 2:
|
||||
# If this model is involved in a recursion this should be detected
|
||||
# on its second encounter, we can safely stop the walk here.
|
||||
if current_recursion_ref_count[ref] != 0:
|
||||
involved_in_recursion[ref] = True
|
||||
return s
|
||||
|
||||
current_recursion_ref_count[ref] += 1
|
||||
recurse(definitions[ref], count_refs)
|
||||
current_recursion_ref_count[ref] -= 1
|
||||
return s
|
||||
|
||||
schema = walk_core_schema(schema, count_refs)
|
||||
|
||||
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
|
||||
|
||||
def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
|
||||
if ref_counts[ref] > 1:
|
||||
return False
|
||||
if involved_in_recursion.get(ref, False):
|
||||
return False
|
||||
if 'serialization' in s:
|
||||
return False
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
for k in (
|
||||
'pydantic_js_functions',
|
||||
'pydantic_js_annotation_functions',
|
||||
'pydantic.internal.union_discriminator',
|
||||
):
|
||||
if k in metadata:
|
||||
# we need to keep this as a ref
|
||||
return False
|
||||
return True
|
||||
|
||||
def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] == 'definition-ref':
|
||||
ref = s['schema_ref']
|
||||
# Check if the reference is only used once, not involved in recursion and does not have
|
||||
# any extra keys (like 'serialization')
|
||||
if can_be_inlined(s, ref):
|
||||
# Inline the reference by replacing the reference with the actual schema
|
||||
new = definitions.pop(ref)
|
||||
ref_counts[ref] -= 1 # because we just replaced it!
|
||||
# put all other keys that were on the def-ref schema into the inlined version
|
||||
# in particular this is needed for `serialization`
|
||||
if 'serialization' in s:
|
||||
new['serialization'] = s['serialization']
|
||||
s = recurse(new, inline_refs)
|
||||
return s
|
||||
else:
|
||||
return recurse(s, inline_refs)
|
||||
else:
|
||||
return recurse(s, inline_refs)
|
||||
|
||||
schema = walk_core_schema(schema, inline_refs)
|
||||
|
||||
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore
|
||||
|
||||
if def_values:
|
||||
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
|
||||
return schema
|
||||
|
||||
|
||||
def _strip_metadata(schema: CoreSchema) -> CoreSchema:
|
||||
def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
|
||||
s = s.copy()
|
||||
s.pop('metadata', None)
|
||||
if s['type'] == 'model-fields':
|
||||
s = s.copy()
|
||||
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
|
||||
for field_name, field_schema in s['fields'].items():
|
||||
field_schema.pop('metadata', None)
|
||||
s['fields'][field_name] = field_schema
|
||||
computed_fields = s.get('computed_fields', None)
|
||||
if computed_fields:
|
||||
s['computed_fields'] = [cf.copy() for cf in computed_fields]
|
||||
for cf in computed_fields:
|
||||
cf.pop('metadata', None)
|
||||
else:
|
||||
s.pop('computed_fields', None)
|
||||
elif s['type'] == 'model':
|
||||
# remove some defaults
|
||||
if s.get('custom_init', True) is False:
|
||||
s.pop('custom_init')
|
||||
if s.get('root_model', True) is False:
|
||||
s.pop('root_model')
|
||||
if {'title'}.issuperset(s.get('config', {}).keys()):
|
||||
s.pop('config', None)
|
||||
|
||||
return recurse(s, strip_metadata)
|
||||
|
||||
return walk_core_schema(schema, strip_metadata)
|
||||
return new_dct
|
||||
elif isinstance(obj, Sequence) and not isinstance(obj, str):
|
||||
return [_clean_schema_for_pretty_print(v, strip_metadata=strip_metadata) for v in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
def pretty_print_core_schema(
|
||||
schema: CoreSchema,
|
||||
include_metadata: bool = False,
|
||||
) -> None:
|
||||
"""Pretty print a CoreSchema using rich.
|
||||
This is intended for debugging purposes.
|
||||
val: Any,
|
||||
*,
|
||||
console: Console | None = None,
|
||||
max_depth: int | None = None,
|
||||
strip_metadata: bool = True,
|
||||
) -> None: # pragma: no cover
|
||||
"""Pretty-print a core schema using the `rich` library.
|
||||
|
||||
Args:
|
||||
schema: The CoreSchema to print.
|
||||
include_metadata: Whether to include metadata in the output. Defaults to `False`.
|
||||
val: The core schema to print, or a Pydantic model/dataclass/type adapter
|
||||
(in which case the cached core schema is fetched and printed).
|
||||
console: A rich console to use when printing. Defaults to the global rich console instance.
|
||||
max_depth: The number of nesting levels which may be printed.
|
||||
strip_metadata: Whether to strip metadata in the output. If `True` any known core metadata
|
||||
attributes will be stripped (but custom attributes are kept). Defaults to `True`.
|
||||
"""
|
||||
from rich import print # type: ignore # install it manually in your dev env
|
||||
# lazy import:
|
||||
from rich.pretty import pprint
|
||||
|
||||
if not include_metadata:
|
||||
schema = _strip_metadata(schema)
|
||||
# circ. imports:
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
|
||||
return print(schema)
|
||||
if (inspect.isclass(val) and issubclass(val, BaseModel)) or is_pydantic_dataclass(val):
|
||||
val = val.__pydantic_core_schema__
|
||||
if isinstance(val, TypeAdapter):
|
||||
val = val.core_schema
|
||||
cleaned_schema = _clean_schema_for_pretty_print(val, strip_metadata=strip_metadata)
|
||||
|
||||
pprint(cleaned_schema, console=console, max_depth=max_depth)
|
||||
|
||||
|
||||
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
|
||||
if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ:
|
||||
return schema
|
||||
return _validate_core_schema(schema)
|
||||
pps = pretty_print_core_schema
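The reworked `pretty_print_core_schema` (aliased to `pps` on the line above) now accepts a model class, pydantic dataclass, or `TypeAdapter` in addition to a raw core schema. A hedged sketch of how it might be used for debugging, assuming `rich` is installed and the helper stays importable from the internal `_core_utils` module (the `Point` model is made up):

    from pydantic import BaseModel

    from pydantic._internal._core_utils import pps  # assumed internal path for the alias


    class Point(BaseModel):
        x: int
        y: int


    pps(Point)                          # prints the cached core schema with known metadata stripped
    pps(Point, strip_metadata=False)    # keep the pydantic_js_* metadata in the output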
|
||||
|
||||
@@ -1,48 +1,43 @@
|
||||
"""Private logic for creating pydantic dataclasses."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import inspect
|
||||
import typing
|
||||
import sys
|
||||
import warnings
|
||||
from functools import partial, wraps
|
||||
from inspect import Parameter, Signature
|
||||
from typing import Any, Callable, ClassVar
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast
|
||||
|
||||
from pydantic_core import (
|
||||
ArgsKwargs,
|
||||
PydanticUndefined,
|
||||
SchemaSerializer,
|
||||
SchemaValidator,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import TypeGuard
|
||||
from typing_extensions import TypeAlias, TypeIs
|
||||
|
||||
from ..errors import PydanticUndefinedAnnotation
|
||||
from ..fields import FieldInfo
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
from . import _config, _decorators, _typing_extra
|
||||
from ._config import ConfigWrapper
|
||||
from . import _config, _decorators
|
||||
from ._fields import collect_dataclass_fields
|
||||
from ._generate_schema import GenerateSchema, generate_pydantic_signature
|
||||
from ._generate_schema import GenerateSchema, InvalidSchemaError
|
||||
from ._generics import get_standard_typevars_map
|
||||
from ._mock_val_ser import set_dataclass_mocks
|
||||
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
|
||||
from ._utils import is_valid_identifier
|
||||
from ._namespace_utils import NsResolver
|
||||
from ._signature import generate_pydantic_signature
|
||||
from ._utils import LazyClassAttribute
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance as StandardDataclass
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..config import ConfigDict
|
||||
|
||||
class StandardDataclass(typing.Protocol):
|
||||
__dataclass_fields__: ClassVar[dict[str, Any]]
|
||||
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
|
||||
__post_init__: ClassVar[Callable[..., None]]
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
class PydanticDataclass(StandardDataclass, typing.Protocol):
|
||||
class PydanticDataclass(StandardDataclass, Protocol):
|
||||
"""A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.
|
||||
|
||||
Attributes:
|
||||
@@ -61,23 +56,28 @@ if typing.TYPE_CHECKING:
|
||||
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
|
||||
__pydantic_fields__: ClassVar[dict[str, FieldInfo]]
|
||||
__pydantic_serializer__: ClassVar[SchemaSerializer]
|
||||
__pydantic_validator__: ClassVar[SchemaValidator]
|
||||
__pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]
|
||||
|
||||
else:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
@classmethod
|
||||
def __pydantic_fields_complete__(cls) -> bool: ...
|
||||
|
||||
|
||||
def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None:
|
||||
def set_dataclass_fields(
|
||||
cls: type[StandardDataclass],
|
||||
config_wrapper: _config.ConfigWrapper,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
) -> None:
|
||||
"""Collect and set `cls.__pydantic_fields__`.
|
||||
|
||||
Args:
|
||||
cls: The class.
|
||||
types_namespace: The types namespace, defaults to `None`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
ns_resolver: Namespace resolver to use when getting dataclass annotations.
|
||||
"""
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map)
|
||||
fields = collect_dataclass_fields(
|
||||
cls, ns_resolver=ns_resolver, typevars_map=typevars_map, config_wrapper=config_wrapper
|
||||
)
|
||||
|
||||
cls.__pydantic_fields__ = fields # type: ignore
|
||||
|
||||
@@ -87,7 +87,8 @@ def complete_dataclass(
|
||||
config_wrapper: _config.ConfigWrapper,
|
||||
*,
|
||||
raise_errors: bool = True,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
_force_build: bool = False,
|
||||
) -> bool:
|
||||
"""Finish building a pydantic dataclass.
|
||||
|
||||
@@ -99,7 +100,10 @@ def complete_dataclass(
|
||||
cls: The class.
|
||||
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors, defaults to `True`.
types_namespace: The types namespace.
ns_resolver: The namespace resolver instance to use when collecting dataclass fields
and during schema building.
_force_build: Whether to force building the dataclass, no matter if
[`defer_build`][pydantic.config.ConfigDict.defer_build] is set.

Returns:
`True` if building a pydantic dataclass is successfully completed, `False` otherwise.
@@ -107,27 +111,10 @@ def complete_dataclass(
Raises:
PydanticUndefinedAnnotation: If `raise_errors` is `True` and an annotation is undefined.
"""
if hasattr(cls, '__post_init_post_parse__'):
warnings.warn(
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
)
original_init = cls.__init__

if types_namespace is None:
types_namespace = _typing_extra.get_cls_types_namespace(cls)

set_dataclass_fields(cls, types_namespace)

typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
types_namespace,
typevars_map,
)

# This needs to be called before we change the __init__
sig = generate_dataclass_signature(cls, cls.__pydantic_fields__, config_wrapper) # type: ignore

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied,
# and so that the mock validator is used if building was deferred:
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
__tracebackhide__ = True
s = __dataclass_self__
@@ -137,136 +124,77 @@ def complete_dataclass(

cls.__init__ = __init__ # type: ignore
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore
cls.__signature__ = sig # type: ignore
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)

set_dataclass_fields(cls, config_wrapper=config_wrapper, ns_resolver=ns_resolver)

if not _force_build and config_wrapper.defer_build:
set_dataclass_mocks(cls)
return False

if hasattr(cls, '__post_init_post_parse__'):
warnings.warn(
'Support for `__post_init_post_parse__` has been dropped, the method will not be called',
PydanticDeprecatedSince20,
)

typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
ns_resolver=ns_resolver,
typevars_map=typevars_map,
)

# set __signature__ attr only for the class, but not for its instances
# (because instances can define `__call__`, and `inspect.signature` shouldn't
# use the `__signature__` attribute and instead generate from `__call__`).
cls.__signature__ = LazyClassAttribute(
'__signature__',
partial(
generate_pydantic_signature,
# It's important that we reference the `original_init` here,
# as it is the one synthesized by the stdlib `dataclass` module:
init=original_init,
fields=cls.__pydantic_fields__, # type: ignore
validate_by_name=config_wrapper.validate_by_name,
extra=config_wrapper.extra,
is_dataclass=True,
),
)

try:
if get_core_schema:
schema = get_core_schema(
cls,
CallbackGetCoreSchemaHandler(
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
gen_schema,
ref_mode='unpack',
),
)
else:
schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False)
schema = gen_schema.generate_schema(cls)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_dataclass_mocks(cls, cls.__name__, f'`{e.name}`')
set_dataclass_mocks(cls, f'`{e.name}`')
return False

core_config = config_wrapper.core_config(cls)
core_config = config_wrapper.core_config(title=cls.__name__)

try:
schema = gen_schema.clean_schema(schema)
except gen_schema.CollectedInvalid:
set_dataclass_mocks(cls, cls.__name__, 'all referenced types')
except InvalidSchemaError:
set_dataclass_mocks(cls)
return False

# We are about to set all the remaining required properties expected for this cast;
# __pydantic_decorators__ and __pydantic_fields__ should already be set
cls = typing.cast('type[PydanticDataclass]', cls)
# debug(schema)
cls = cast('type[PydanticDataclass]', cls)

cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
cls.__pydantic_validator__ = create_schema_validator(
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

if config_wrapper.validate_assignment:

@wraps(cls.__setattr__)
def validated_setattr(instance: Any, __field: str, __value: str) -> None:
validator.validate_assignment(instance, __field, __value)

cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore

cls.__pydantic_complete__ = True
return True


def process_param_defaults(param: Parameter) -> Parameter:
"""Custom processing where the parameter default is of type FieldInfo
def is_stdlib_dataclass(cls: type[Any], /) -> TypeIs[type[StandardDataclass]]:
"""Returns `True` if the class is a stdlib dataclass and *not* a Pydantic dataclass.

Args:
param (Parameter): The parameter

Returns:
Parameter: The custom processed parameter
"""
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(annotation=annotation, name=name, default=default)
return param


def generate_dataclass_signature(
cls: type[StandardDataclass], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for a pydantic dataclass.

Args:
cls: The dataclass.
fields: The model fields.
config_wrapper: The config wrapper instance.

Returns:
The dataclass signature.
"""
return generate_pydantic_signature(
init=cls.__init__, fields=fields, config_wrapper=config_wrapper, post_process_parameter=process_param_defaults
)


def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.

We check that
- `_cls` is a dataclass
- `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`)
- `_cls` does not have any annotations that are not dataclass fields
e.g.
```py
import dataclasses

import pydantic.dataclasses

@dataclasses.dataclass
class A:
x: int

@pydantic.dataclasses.dataclass
class B(A):
y: int
```
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
Unlike the stdlib `dataclasses.is_dataclass()` function, this does *not* include subclasses
of a dataclass that are themselves not dataclasses.

Args:
cls: The class.
@@ -274,8 +202,114 @@ def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
Returns:
`True` if the class is a stdlib dataclass, `False` otherwise.
"""
return (
dataclasses.is_dataclass(_cls)
and not hasattr(_cls, '__pydantic_validator__')
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
)
return '__dataclass_fields__' in cls.__dict__ and not hasattr(cls, '__pydantic_validator__')
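
A minimal sketch of the distinction both checks rely on (illustrative only; the internal helper additionally inspects `cls.__dict__` so that undecorated subclasses don't qualify):

```python {test="skip" lint="skip"}
import dataclasses

import pydantic.dataclasses


@dataclasses.dataclass
class Plain:
    x: int


@pydantic.dataclasses.dataclass
class Validated:
    x: int


# A stdlib dataclass has `__dataclass_fields__` but no Pydantic validator attached:
assert not hasattr(Plain, '__pydantic_validator__')
# A Pydantic dataclass has both:
assert hasattr(Validated, '__pydantic_validator__')
```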


def as_dataclass_field(pydantic_field: FieldInfo) -> dataclasses.Field[Any]:
field_args: dict[str, Any] = {'default': pydantic_field}

# Needed because if `doc` is set, the dataclass slots will be a dict (field name -> doc) instead of a tuple:
if sys.version_info >= (3, 14) and pydantic_field.description is not None:
field_args['doc'] = pydantic_field.description

# Needed as the stdlib dataclass module processes kw_only in a specific way during class construction:
if sys.version_info >= (3, 10) and pydantic_field.kw_only:
field_args['kw_only'] = True

# Needed as the stdlib dataclass module generates `__repr__()` during class construction:
if pydantic_field.repr is not True:
field_args['repr'] = pydantic_field.repr

return dataclasses.field(**field_args)
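
Roughly, the conversion keeps the `FieldInfo` as the default while forwarding the attributes the stdlib needs at class-construction time (a hedged sketch, not the exact internal call):

```python {test="skip" lint="skip"}
import dataclasses

from pydantic import Field
from pydantic.fields import FieldInfo

pydantic_field = Field(default=1, repr=False)
assert isinstance(pydantic_field, FieldInfo)

# The FieldInfo stays the default; `repr` is forwarded so the stdlib-generated
# `__repr__()` skips the field:
dc_field = dataclasses.field(default=pydantic_field, repr=False)
```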


DcFields: TypeAlias = dict[str, dataclasses.Field[Any]]


@contextmanager
def patch_base_fields(cls: type[Any]) -> Generator[None]:
"""Temporarily patch the stdlib dataclasses bases of `cls` if the Pydantic `Field()` function is used.

When creating a Pydantic dataclass, it is possible to inherit from stdlib dataclasses, where
the Pydantic `Field()` function is used. To create this Pydantic dataclass, we first apply
the stdlib `@dataclass` decorator on it. During the construction of the stdlib dataclass,
the `kw_only` and `repr` field arguments need to be understood by the stdlib *during* the
dataclass construction. To do so, we temporarily patch the fields dictionary of the affected
bases.

For instance, with the following example:

```python {test="skip" lint="skip"}
import dataclasses as stdlib_dc

import pydantic
import pydantic.dataclasses as pydantic_dc

@stdlib_dc.dataclass
class A:
a: int = pydantic.Field(repr=False)

# Notice that the `repr` attribute of the dataclass field is `True`:
A.__dataclass_fields__['a']
#> dataclass.Field(default=FieldInfo(repr=False), repr=True, ...)

@pydantic_dc.dataclass
class B(A):
b: int = pydantic.Field(repr=False)
```

When passing `B` to the stdlib `@dataclass` decorator, it will look for fields in the parent classes
and reuse them directly. When this context manager is active, `A` will be temporarily patched to be
equivalent to:

```python {test="skip" lint="skip"}
@stdlib_dc.dataclass
class A:
a: int = stdlib_dc.field(default=Field(repr=False), repr=False)
```

!!! note
This is only applied to the bases of `cls`, and not `cls` itself. The reason is that the Pydantic
dataclass decorator "owns" `cls` (in the previous example, `B`). As such, we instead modify the fields
directly (in the previous example, we simply do `setattr(B, 'b', as_dataclass_field(pydantic_field))`).

!!! note
This approach is far from ideal, and can probably be the source of unwanted side effects/race conditions.
The previously implemented approach mutated the `__annotations__` dict of `cls`, which is no longer a
safe operation in Python 3.14+, and resulted in unexpected behavior with field ordering anyway.
"""
# A list of two-tuples, the first element being a reference to the
# dataclass fields dictionary, the second element being a mapping between
# the field names that were modified, and their original `Field`:
original_fields_list: list[tuple[DcFields, DcFields]] = []

for base in cls.__mro__[1:]:
dc_fields: dict[str, dataclasses.Field[Any]] = base.__dict__.get('__dataclass_fields__', {})
dc_fields_with_pydantic_field_defaults = {
field_name: field
for field_name, field in dc_fields.items()
if isinstance(field.default, FieldInfo)
# Only do the patching if one of the affected attributes is set:
and (field.default.description is not None or field.default.kw_only or field.default.repr is not True)
}
if dc_fields_with_pydantic_field_defaults:
original_fields_list.append((dc_fields, dc_fields_with_pydantic_field_defaults))
for field_name, field in dc_fields_with_pydantic_field_defaults.items():
default = cast(FieldInfo, field.default)
# `dataclasses.Field` isn't documented as working with `copy.copy()`.
# It is a class with `__slots__`, so should work (and we hope for the best):
new_dc_field = copy.copy(field)
# For base fields, no need to set `doc` from `FieldInfo.description`, this is only relevant
# for the class under construction and handled in `as_dataclass_field()`.
if sys.version_info >= (3, 10) and default.kw_only:
new_dc_field.kw_only = True
if default.repr is not True:
new_dc_field.repr = default.repr
dc_fields[field_name] = new_dc_field

try:
yield
finally:
for fields, original_fields in original_fields_list:
for field_name, original_field in original_fields.items():
fields[field_name] = original_field
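
For illustration, a simplified sketch of how the dataclass machinery is assumed to use this internal context manager (hypothetical class names; the real call site lives in the Pydantic dataclass decorator):

```python {test="skip" lint="skip"}
import dataclasses

import pydantic


@dataclasses.dataclass
class Base:
    a: int = pydantic.Field(default=0, repr=False)


class Child(Base):
    b: int = 1


# While the patch is active, the stdlib decorator sees `repr=False` on the inherited
# field `a`, so the generated `__repr__()` skips it; on exit the bases are restored.
with patch_base_fields(Child):
    Child = dataclasses.dataclass(Child)
```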

@@ -1,30 +1,31 @@
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""

from __future__ import annotations as _annotations

import sys
import types
from collections import deque
from collections.abc import Iterable
from dataclasses import dataclass, field
from functools import partial, partialmethod
from functools import cached_property, partial, partialmethod
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
from itertools import islice
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union

from pydantic_core import PydanticUndefined, core_schema
from typing_extensions import Literal, TypeAlias, is_typeddict
from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema
from typing_extensions import TypeAlias, is_typeddict

from ..errors import PydanticUserError
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._namespace_utils import GlobalsNamespace, MappingNamespace
from ._typing_extra import get_function_type_hints
from ._utils import can_be_positional

if TYPE_CHECKING:
from ..fields import ComputedFieldInfo
from ..functional_validators import FieldValidatorModes

try:
from functools import cached_property # type: ignore
except ImportError:
# python 3.7
cached_property = None
from ._config import ConfigWrapper


@dataclass(**slots_true)
@@ -61,6 +62,9 @@ class FieldValidatorDecoratorInfo:
fields: A tuple of field names the validator should be called on.
mode: The proposed validator mode.
check_fields: Whether to check that the fields actually exist on the model.
json_schema_input_type: The input type of the function. This is only used to generate
the appropriate JSON Schema (in validation mode) and can only be specified
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
"""

decorator_repr: ClassVar[str] = '@field_validator'
@@ -68,6 +72,7 @@ class FieldValidatorDecoratorInfo:
fields: tuple[str, ...]
mode: FieldValidatorModes
check_fields: bool | None
json_schema_input_type: Any


@dataclass(**slots_true)
@@ -132,7 +137,7 @@ class ModelValidatorDecoratorInfo:
while building the pydantic-core schema.

Attributes:
decorator_repr: A class variable representing the decorator string, '@model_serializer'.
decorator_repr: A class variable representing the decorator string, '@model_validator'.
mode: The proposed serializer mode.
"""

@@ -183,22 +188,28 @@ class PydanticDescriptorProxy(Generic[ReturnType]):

def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
self.wrapped = getattr(self.wrapped, name)(func)
if isinstance(self.wrapped, property):
# update ComputedFieldInfo.wrapped_property
from ..fields import ComputedFieldInfo

if isinstance(self.decorator_info, ComputedFieldInfo):
self.decorator_info.wrapped_property = self.wrapped
return self

def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
try:
return self.wrapped.__get__(obj, obj_type)
return self.wrapped.__get__(obj, obj_type) # pyright: ignore[reportReturnType]
except AttributeError:
# not a descriptor, e.g. a partial object
return self.wrapped # type: ignore[return-value]

def __set_name__(self, instance: Any, name: str) -> None:
if hasattr(self.wrapped, '__set_name__'):
self.wrapped.__set_name__(instance, name)
self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess]

def __getattr__(self, __name: str) -> Any:
def __getattr__(self, name: str, /) -> Any:
"""Forward checks for __isabstractmethod__ and such."""
return getattr(self.wrapped, __name)
return getattr(self.wrapped, name)


DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
@@ -500,13 +511,20 @@ class DecoratorInfos:
# so then we don't need to re-process the type, which means we can discard our descriptor wrappers
# and replace them with the thing they are wrapping (see the other setattr call below)
# which allows validator class methods to also function as regular class methods
setattr(model_dc, '__pydantic_decorators__', res)
model_dc.__pydantic_decorators__ = res
for name, value in to_replace:
setattr(model_dc, name, value)
return res

def update_from_config(self, config_wrapper: ConfigWrapper) -> None:
"""Update the decorator infos from the configuration of the class they are attached to."""
for name, computed_field_dec in self.computed_fields.items():
computed_field_dec.info._update_from_config(config_wrapper, name)

def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:

def inspect_validator(
validator: Callable[..., Any], *, mode: FieldValidatorModes, type: Literal['field', 'model']
) -> bool:
"""Look at a field or model validator function and determine whether it takes an info argument.

An error is raised if the function has an invalid signature.
@@ -514,18 +532,18 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
Args:
validator: The validator function to inspect.
mode: The proposed validator mode.
type: The type of validator, either 'field' or 'model'.

Returns:
Whether the validator takes an info argument.
"""
try:
sig = signature(validator)
except ValueError:
# builtins and some C extensions don't have signatures
# assume that they don't take an info argument and only take a single argument
# e.g. `str.strip` or `datetime.datetime`
sig = _signature_no_eval(validator)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present:
return False
n_positional = count_positional_params(sig)
n_positional = count_positional_required_params(sig)
if mode == 'wrap':
if n_positional == 3:
return True
@@ -539,14 +557,12 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
return False

raise PydanticUserError(
f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
f'Unrecognized {type} validator function signature for {validator} with `mode={mode}`: {sig}',
code='validator-signature',
)


def inspect_field_serializer(
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
) -> tuple[bool, bool]:
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
"""Look at a field serializer function and determine if it is a field serializer,
and whether it takes an info argument.

@@ -555,18 +571,21 @@ def inspect_field_serializer(
Args:
serializer: The serializer function to inspect.
mode: The serializer mode, either 'plain' or 'wrap'.
computed_field: When serializer is applied on computed_field. It doesn't require
info signature.

Returns:
Tuple of (is_field_serializer, info_arg).
"""
sig = signature(serializer)
try:
sig = _signature_no_eval(serializer)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present and this is not a method:
return (False, False)

first = next(iter(sig.parameters.values()), None)
is_field_serializer = first is not None and first.name == 'self'

n_positional = count_positional_params(sig)
n_positional = count_positional_required_params(sig)
if is_field_serializer:
# -1 to correct for self parameter
info_arg = _serializer_info_arg(mode, n_positional - 1)
@@ -578,13 +597,8 @@ def inspect_field_serializer(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
if info_arg and computed_field:
raise PydanticUserError(
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
)

else:
return is_field_serializer, info_arg
return is_field_serializer, info_arg


def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
@@ -599,8 +613,13 @@ def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['
Returns:
info_arg
"""
sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
try:
sig = _signature_no_eval(serializer)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present:
return False
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
@@ -627,8 +646,8 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai
'`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
)

sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
sig = _signature_no_eval(serializer)
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
@@ -641,18 +660,18 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai
def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
if mode == 'plain':
if n_positional == 1:
# (__input_value: Any) -> Any
# (input_value: Any, /) -> Any
return False
elif n_positional == 2:
# (__model: Any, __input_value: Any) -> Any
# (model: Any, input_value: Any, /) -> Any
return True
else:
assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
if n_positional == 2:
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any
# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
return False
elif n_positional == 3:
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any
# (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
return True

return None
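
For illustration, the argument-count rules above correspond to the following public serializer shapes (a hedged sketch; `self` is counted and subtracted separately for bound field serializers):

```python {test="skip" lint="skip"}
from pydantic import BaseModel, field_serializer


class Model(BaseModel):
    x: int
    y: int

    # plain mode, one value argument -> no info argument detected
    @field_serializer('x')
    def ser_x(self, value: int) -> str:
        return str(value)

    # plain mode, value + info -> info argument detected
    @field_serializer('y', mode='plain')
    def ser_y(self, value: int, info) -> str:
        return f'{value} ({info.field_name})'
```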
@@ -675,7 +694,7 @@ def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
Returns:
`True` if the function is an instance method, `False` otherwise.
"""
sig = signature(unwrap_wrapped_function(function))
sig = _signature_no_eval(unwrap_wrapped_function(function))
first = next(iter(sig.parameters.values()), None)
if first and first.name == 'self':
return True
@@ -699,7 +718,7 @@ def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any


def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
sig = signature(unwrap_wrapped_function(function))
sig = _signature_no_eval(unwrap_wrapped_function(function))
first = next(iter(sig.parameters.values()), None)
if first and first.name == 'cls':
return True
@@ -713,34 +732,25 @@ def unwrap_wrapped_function(
unwrap_class_static_method: bool = True,
) -> Any:
"""Recursively unwraps a wrapped function until the underlying function is reached.
This handles property, functools.partial, functools.partialmethod, staticmethod and classmethod.
This handles property, functools.partial, functools.partialmethod, staticmethod, and classmethod.

Args:
func: The function to unwrap.
unwrap_partial: If True (default), unwrap partial and partialmethod decorators, otherwise don't.
decorators.
unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
decorators. If False, only unwrap partial and partialmethod decorators.

Returns:
The underlying function of the wrapped function.
"""
all: set[Any] = {property}
# Define the types we want to check against as a single tuple.
unwrap_types = (
(property, cached_property)
+ ((partial, partialmethod) if unwrap_partial else ())
+ ((staticmethod, classmethod) if unwrap_class_static_method else ())
)

if unwrap_partial:
all.update({partial, partialmethod})

try:
from functools import cached_property # type: ignore
except ImportError:
cached_property = type('', (), {})
else:
all.add(cached_property)

if unwrap_class_static_method:
all.update({staticmethod, classmethod})

while isinstance(func, tuple(all)):
while isinstance(func, unwrap_types):
if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
func = func.__func__
elif isinstance(func, (partial, partialmethod)):
@@ -755,38 +765,72 @@ def unwrap_wrapped_function(
return func


def get_function_return_type(
func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None
) -> Any:
"""Get the function return type.
_function_like = (
partial,
partialmethod,
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
types.WrapperDescriptorType,
types.MethodWrapperType,
types.MemberDescriptorType,
)

It gets the return type from the type annotation if `explicit_return_type` is `None`.
Otherwise, it returns `explicit_return_type`.

def get_callable_return_type(
callable_obj: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> Any | PydanticUndefinedType:
"""Get the callable return type.

Args:
func: The function to get its return type.
explicit_return_type: The explicit return type.
types_namespace: The types namespace, defaults to `None`.
callable_obj: The callable to analyze.
globalns: The globals namespace to use during type annotation evaluation.
localns: The locals namespace to use during type annotation evaluation.

Returns:
The function return type.
"""
if explicit_return_type is PydanticUndefined:
# try to get it from the type annotation
hints = get_function_type_hints(
unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace
)
return hints.get('return', PydanticUndefined)
else:
return explicit_return_type
if isinstance(callable_obj, type):
# types are callables, and we assume the return type
# is the type itself (e.g. `int()` results in an instance of `int`).
return callable_obj

if not isinstance(callable_obj, _function_like):
call_func = getattr(type(callable_obj), '__call__', None) # noqa: B004
if call_func is not None:
callable_obj = call_func

hints = get_function_type_hints(
unwrap_wrapped_function(callable_obj),
include_keys={'return'},
globalns=globalns,
localns=localns,
)
return hints.get('return', PydanticUndefined)
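
For illustration, the `__call__` fallback described above boils down to reading the return annotation off `type(obj).__call__` (a sketch using only the standard library):

```python {test="skip" lint="skip"}
from typing import get_type_hints


class Factory:
    def __call__(self) -> dict[str, int]:
        return {}


# For non-function callables, the return type is taken from the class's `__call__`:
assert get_type_hints(type(Factory()).__call__)['return'] == dict[str, int]
```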


def count_positional_params(sig: Signature) -> int:
return sum(1 for param in sig.parameters.values() if can_be_positional(param))
def count_positional_required_params(sig: Signature) -> int:
"""Get the number of positional (required) arguments of a signature.

This function should only be used to inspect signatures of validation and serialization functions.
The first argument (the value being serialized or validated) is counted as a required argument
even if a default value exists.

def can_be_positional(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
Returns:
The number of positional arguments of a signature.
"""
parameters = list(sig.parameters.values())
return sum(
1
for param in parameters
if can_be_positional(param)
# First argument is the value being validated/serialized, and can have a default value
# (e.g. `float`, which has signature `(x=0, /)`). We assume other parameters (the info arg
# for instance) should be required, and thus without any default value.
and (param.default is Parameter.empty or param is parameters[0])
)
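
A small sketch of the counting rule (illustrative; it uses the internal helper above and assumes a recent Python where `inspect.signature(float)` reports `(x=0, /)`):

```python {test="skip" lint="skip"}
from inspect import signature

# The single value argument counts even though it has a default:
assert count_positional_required_params(signature(float)) == 1


def wrap_validator(value, handler, info=None):
    ...


# `info` has a default, so only `value` and `handler` are counted as required:
assert count_positional_required_params(signature(wrap_validator)) == 2
```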


def ensure_property(f: Any) -> Any:
@@ -802,3 +846,13 @@ def ensure_property(f: Any) -> Any:
return f
else:
return property(f)


def _signature_no_eval(f: Callable[..., Any]) -> Signature:
"""Get the signature of a callable without evaluating any annotations."""
if sys.version_info >= (3, 14):
from annotationlib import Format

return signature(f, annotation_format=Format.FORWARDREF)
else:
return signature(f)

@@ -1,49 +1,45 @@
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""

from __future__ import annotations as _annotations

from inspect import Parameter, signature
from typing import Any, Dict, Tuple, Union, cast
from typing import Any, Union, cast

from pydantic_core import core_schema
from typing_extensions import Protocol

from ..errors import PydanticUserError
from ._decorators import can_be_positional
from ._utils import can_be_positional


class V1OnlyValueValidator(Protocol):
"""A simple validator, supported for V1 validators and V2 validators."""

def __call__(self, __value: Any) -> Any:
...
def __call__(self, __value: Any) -> Any: ...


class V1ValidatorWithValues(Protocol):
"""A validator with `values` argument, supported for V1 validators and V2 validators."""

def __call__(self, __value: Any, values: dict[str, Any]) -> Any:
...
def __call__(self, __value: Any, values: dict[str, Any]) -> Any: ...


class V1ValidatorWithValuesKwOnly(Protocol):
"""A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""

def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any:
...
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: ...


class V1ValidatorWithKwargs(Protocol):
"""A validator with `kwargs` argument, supported for V1 validators and V2 validators."""

def __call__(self, __value: Any, **kwargs: Any) -> Any:
...
def __call__(self, __value: Any, **kwargs: Any) -> Any: ...


class V1ValidatorWithValuesAndKwargs(Protocol):
"""A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""

def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any:
...
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...


V1Validator = Union[
@@ -109,23 +105,21 @@ def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithI
return wrapper2


RootValidatorValues = Dict[str, Any]
RootValidatorValues = dict[str, Any]
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
RootValidatorFieldsTuple = Tuple[Any, ...]
RootValidatorFieldsTuple = tuple[Any, ...]


class V1RootValidatorFunction(Protocol):
"""A simple root validator, supported for V1 validators and V2 validators."""

def __call__(self, __values: RootValidatorValues) -> RootValidatorValues:
...
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: ...


class V2CoreBeforeRootValidator(Protocol):
"""V2 validator with mode='before'."""

def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues:
...
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: ...


class V2CoreAfterRootValidator(Protocol):
@@ -133,8 +127,7 @@ class V2CoreAfterRootValidator(Protocol):

def __call__(
self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
) -> RootValidatorFieldsTuple:
...
) -> RootValidatorFieldsTuple: ...


def make_v1_generic_root_validator(

@@ -1,22 +1,19 @@
from __future__ import annotations as _annotations

from typing import TYPE_CHECKING, Any, Hashable, Sequence
from collections.abc import Hashable, Sequence
from typing import TYPE_CHECKING, Any, cast

from pydantic_core import CoreSchema, core_schema

from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaField,
collect_definitions,
simplify_schema_references,
)

if TYPE_CHECKING:
from ..types import Discriminator

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
from ._core_metadata import CoreMetadata


class MissingDefinitionForUnionRef(Exception):
@@ -29,35 +26,9 @@ class MissingDefinitionForUnionRef(Exception):
super().__init__(f'Missing definition for ref {self.ref!r}')


def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:
schema.setdefault('metadata', {})
metadata = schema.get('metadata')
assert metadata is not None
metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator


def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
definitions: dict[str, CoreSchema] | None = None

def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
if 'metadata' in s:
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False:
return s

s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s

metadata = s.get('metadata', {})
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
return s

return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
metadata = cast('CoreMetadata', schema.setdefault('metadata', {}))
metadata['pydantic_internal_union_discriminator'] = discriminator


def apply_discriminator(
@@ -163,7 +134,7 @@ class _ApplyInferredDiscriminator:
# in the output TaggedUnionSchema that will replace the union from the input schema
self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}

# `_used` is changed to True after applying the discriminator to prevent accidental re-use
# `_used` is changed to True after applying the discriminator to prevent accidental reuse
self._used = False

def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
@@ -189,16 +160,11 @@ class _ApplyInferredDiscriminator:
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
self.definitions.update(collect_definitions(schema))
assert not self._used
schema = self._apply_to_root(schema)
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True
new_defs = collect_definitions(schema)
missing_defs = self.definitions.keys() - new_defs.keys()
if missing_defs:
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
return schema

def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
@@ -268,6 +234,10 @@ class _ApplyInferredDiscriminator:
* Validating that each allowed discriminator value maps to a unique choice
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
"""
if choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])

if choice['type'] == 'none':
self._should_be_nullable = True
elif choice['type'] == 'definitions':
@@ -279,10 +249,6 @@ class _ApplyInferredDiscriminator:
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
self._handle_choice(self.definitions[choice['schema_ref']])
elif choice['type'] not in {
'model',
'typed-dict',
@@ -290,12 +256,16 @@ class _ApplyInferredDiscriminator:
'lax-or-strict',
'dataclass',
'dataclass-args',
'definition-ref',
} and not _core_utils.is_function_with_inner_schema(choice):
# We should eventually handle 'definition-ref' as well
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
if choice['type'] == 'list':
err_str += (
' If you are making use of a list of union types, make sure the discriminator is applied to the '
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
)
raise TypeError(err_str)
else:
if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
# In this case, this inner tagged-union is compatible with the outer tagged-union,
@@ -329,13 +299,10 @@ class _ApplyInferredDiscriminator:
"""
if choice['type'] == 'definitions':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif choice['type'] == 'function-plain':
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)

elif _core_utils.is_function_with_inner_schema(choice):
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)

elif choice['type'] == 'lax-or-strict':
return sorted(
set(
@@ -386,10 +353,13 @@ class _ApplyInferredDiscriminator:
raise MissingDefinitionForUnionRef(schema_ref)
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
else:
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
if choice['type'] == 'list':
err_str += (
' If you are making use of a list of union types, make sure the discriminator is applied to the '
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
)
raise TypeError(err_str)
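
The improved error message points at the usual fix; for illustration, the discriminator belongs on the union member inside the list, not on the list annotation itself:

```python {test="skip" lint="skip"}
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class Cat(BaseModel):
    pet_type: Literal['cat']


class Dog(BaseModel):
    pet_type: Literal['dog']


class Home(BaseModel):
    # Correct: the discriminator is applied to the union type inside the list.
    pets: list[Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]]
```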

def _infer_discriminator_values_for_typed_dict_choice(
self, choice: core_schema.TypedDictSchema, source_name: str | None = None

@@ -0,0 +1,113 @@
"""Utilities related to attribute docstring extraction."""

from __future__ import annotations

import ast
import inspect
import sys
import textwrap
from typing import Any


class DocstringVisitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()

self.target: str | None = None
self.attrs: dict[str, str] = {}
self.previous_node_type: type[ast.AST] | None = None

def visit(self, node: ast.AST) -> Any:
node_result = super().visit(node)
self.previous_node_type = type(node)
return node_result

def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
if isinstance(node.target, ast.Name):
self.target = node.target.id

def visit_Expr(self, node: ast.Expr) -> Any:
if (
isinstance(node.value, ast.Constant)
and isinstance(node.value.value, str)
and self.previous_node_type is ast.AnnAssign
):
docstring = inspect.cleandoc(node.value.value)
if self.target:
self.attrs[self.target] = docstring
self.target = None


def _dedent_source_lines(source: list[str]) -> str:
# Required for nested class definitions, e.g. in a function block
dedent_source = textwrap.dedent(''.join(source))
if dedent_source.startswith((' ', '\t')):
# We are in the case where there's a dedented (usually multiline) string
# at a lower indentation level than the class itself. We wrap our class
# in a function as a workaround.
dedent_source = f'def dedent_workaround():\n{dedent_source}'
return dedent_source


def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
frame = inspect.currentframe()

while frame:
if inspect.getmodule(frame) is inspect.getmodule(cls):
lnum = frame.f_lineno
try:
lines, _ = inspect.findsource(frame)
except OSError: # pragma: no cover
# Source can't be retrieved (maybe because running in an interactive terminal),
# we don't want to error here.
pass
else:
block_lines = inspect.getblock(lines[lnum - 1 :])
dedent_source = _dedent_source_lines(block_lines)
try:
block_tree = ast.parse(dedent_source)
except SyntaxError:
pass
else:
stmt = block_tree.body[0]
if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
# `_dedent_source_lines` wrapped the class around the workaround function
stmt = stmt.body[0]
if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
return block_lines

frame = frame.f_back


def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
"""Map model attributes and their corresponding docstring.

Args:
cls: The class of the Pydantic model to inspect.
use_inspect: Whether to skip usage of frames to find the object and use
the `inspect` module instead.

Returns:
A mapping containing attribute names and their corresponding docstring.
"""
if use_inspect or sys.version_info >= (3, 13):
# On Python < 3.13, `inspect.getsourcelines()` might not work as expected
# if two classes have the same name in the same source file.
# On Python 3.13+, it will use the new `__firstlineno__` class attribute,
# making it way more robust.
try:
source, _ = inspect.getsourcelines(cls)
except OSError: # pragma: no cover
return {}
else:
# TODO remove this implementation when we drop support for Python 3.12:
source = _extract_source_from_frame(cls)

if not source:
return {}

dedent_source = _dedent_source_lines(source)

visitor = DocstringVisitor()
visitor.visit(ast.parse(dedent_source))
return visitor.attrs
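
On the user side, this extraction is what powers attribute docstrings becoming field descriptions; a short illustrative example using the public config flag:

```python {test="skip" lint="skip"}
from pydantic import BaseModel, ConfigDict


class User(BaseModel):
    model_config = ConfigDict(use_attribute_docstrings=True)

    name: str
    """The user's full name."""


assert User.model_fields['name'].description == "The user's full name."
```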
@@ -1,58 +1,40 @@
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""

from __future__ import annotations as _annotations

import dataclasses
import sys
import warnings
from copy import copy
from functools import lru_cache
from typing import TYPE_CHECKING, Any
from collections.abc import Mapping
from functools import cache
from inspect import Parameter, ismethoddescriptor, signature
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from pydantic_core import PydanticUndefined
from typing_extensions import TypeIs
from typing_inspection.introspection import AnnotationSource

from . import _typing_extra
from pydantic import PydanticDeprecatedSince211
from pydantic.errors import PydanticUserError

from ..aliases import AliasGenerator
from . import _generics, _typing_extra
from ._config import ConfigWrapper
from ._docs_extraction import extract_docstrings_from_cls
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._namespace_utils import NsResolver
from ._repr import Representation
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar
from ._utils import can_be_positional, get_first_not_none

if TYPE_CHECKING:
from annotated_types import BaseMetadata

from ..fields import FieldInfo
from ..main import BaseModel
from ._dataclasses import StandardDataclass
from ._dataclasses import PydanticDataclass, StandardDataclass
from ._decorators import DecoratorInfos


def get_type_hints_infer_globalns(
obj: Any,
localns: dict[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]:
"""Gets type hints for an object by inferring the global namespace.

It uses `typing.get_type_hints`. The only thing that we do here is fetching the
global namespace from `obj.__module__` if it is not `None`.

Args:
obj: The object to get its type hints.
localns: The local namespaces.
include_extras: Whether to recursively include annotation metadata.

Returns:
The object type hints.
"""
module_name = getattr(obj, '__module__', None)
globalns: dict[str, Any] | None = None
if module_name:
try:
globalns = sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)


class PydanticMetadata(Representation):
"""Base class for annotation markers like `Strict`."""

@@ -71,7 +53,7 @@ def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
return _general_metadata_cls()(metadata) # type: ignore


@lru_cache(maxsize=None)
@cache
def _general_metadata_cls() -> type[BaseMetadata]:
"""Do it this way to avoid importing `annotated_types` at import time."""
from annotated_types import BaseMetadata
@@ -85,29 +67,176 @@ def _general_metadata_cls() -> type[BaseMetadata]:
return _PydanticGeneralMetadata # type: ignore


def _check_protected_namespaces(
protected_namespaces: tuple[str | Pattern[str], ...],
ann_name: str,
bases: tuple[type[Any], ...],
cls_name: str,
) -> None:
BaseModel = import_cached_base_model()

for protected_namespace in protected_namespaces:
ns_violation = False
if isinstance(protected_namespace, Pattern):
ns_violation = protected_namespace.match(ann_name) is not None
elif isinstance(protected_namespace, str):
ns_violation = ann_name.startswith(protected_namespace)

if ns_violation:
for b in bases:
if hasattr(b, ann_name):
if not (issubclass(b, BaseModel) and ann_name in getattr(b, '__pydantic_fields__', {})):
raise ValueError(
f'Field {ann_name!r} conflicts with member {getattr(b, ann_name)}'
f' of protected namespace {protected_namespace!r}.'
)
else:
valid_namespaces: list[str] = []
for pn in protected_namespaces:
if isinstance(pn, Pattern):
if not pn.match(ann_name):
valid_namespaces.append(f're.compile({pn.pattern!r})')
else:
if not ann_name.startswith(pn):
valid_namespaces.append(f"'{pn}'")

valid_namespaces_str = f'({", ".join(valid_namespaces)}{",)" if len(valid_namespaces) == 1 else ")"}'

warnings.warn(
f'Field {ann_name!r} in {cls_name!r} conflicts with protected namespace {protected_namespace!r}.\n\n'
f"You may be able to solve this by setting the 'protected_namespaces' configuration to {valid_namespaces_str}.",
UserWarning,
stacklevel=5,
)
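
From the user's perspective, the check above surfaces as the protected-namespace warning; a minimal sketch of the suggested configuration change (pattern-based namespaces are supported per the `Pattern` branch above):

```python {test="skip" lint="skip"}
import re

from pydantic import BaseModel, ConfigDict


class Model(BaseModel):
    # By default 'model_' is protected, so `model_name` would trigger the warning above.
    # Narrowing the protected namespaces (here with a pattern) avoids the conflict:
    model_config = ConfigDict(protected_namespaces=(re.compile(r'^model_computed_'),))

    model_name: str
```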


def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], use_inspect: bool = False) -> None:
fields_docs = extract_docstrings_from_cls(cls, use_inspect=use_inspect)
for ann_name, field_info in fields.items():
if field_info.description is None and ann_name in fields_docs:
field_info.description = fields_docs[ann_name]


def _apply_field_title_generator_to_field_info(
title_generator: Callable[[str, FieldInfo], str],
field_name: str,
field_info: FieldInfo,
):
if field_info.title is None:
title = title_generator(field_name, field_info)
if not isinstance(title, str):
raise TypeError(f'field_title_generator {title_generator} must return str, not {title.__class__}')

field_info.title = title


def _apply_alias_generator_to_field_info(
alias_generator: Callable[[str], str] | AliasGenerator, field_name: str, field_info: FieldInfo
):
"""Apply an alias generator to aliases on a `FieldInfo` instance if appropriate.

Args:
alias_generator: A callable that takes a string and returns a string, or an `AliasGenerator` instance.
field_name: The name of the field from which to generate the alias.
field_info: The `FieldInfo` instance to which the alias generator is (maybe) applied.
"""
# Apply an alias_generator if
# 1. An alias is not specified
# 2. An alias is specified, but the priority is <= 1
if (
field_info.alias_priority is None
or field_info.alias_priority <= 1
or field_info.alias is None
or field_info.validation_alias is None
or field_info.serialization_alias is None
):
alias, validation_alias, serialization_alias = None, None, None

if isinstance(alias_generator, AliasGenerator):
alias, validation_alias, serialization_alias = alias_generator.generate_aliases(field_name)
elif callable(alias_generator):
alias = alias_generator(field_name)
if not isinstance(alias, str):
raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}')

# if priority is not set, we set to 1
# which supports the case where the alias_generator from a child class is used
# to generate an alias for a field in a parent class
if field_info.alias_priority is None or field_info.alias_priority <= 1:
field_info.alias_priority = 1

# if the priority is 1, then we set the aliases to the generated alias
if field_info.alias_priority == 1:
field_info.serialization_alias = get_first_not_none(serialization_alias, alias)
field_info.validation_alias = get_first_not_none(validation_alias, alias)
field_info.alias = alias

# if any of the aliases are not set, then we set them to the corresponding generated alias
if field_info.alias is None:
field_info.alias = alias
if field_info.serialization_alias is None:
field_info.serialization_alias = get_first_not_none(serialization_alias, alias)
if field_info.validation_alias is None:
field_info.validation_alias = get_first_not_none(validation_alias, alias)
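
For illustration, the priority rules above are what make a config-level alias generator fill in aliases only where the user hasn't set one explicitly:

```python {test="skip" lint="skip"}
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class User(BaseModel):
    model_config = ConfigDict(alias_generator=to_camel)

    first_name: str


# The generated alias is applied because no explicit alias was set (priority <= 1):
assert User.model_fields['first_name'].alias == 'firstName'
```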
|
||||
|
||||
|
||||
def update_field_from_config(config_wrapper: ConfigWrapper, field_name: str, field_info: FieldInfo) -> None:
|
||||
"""Update the `FieldInfo` instance from the configuration set on the model it belongs to.
|
||||
|
||||
This will apply the title and alias generators from the configuration.
|
||||
|
||||
Args:
|
||||
config_wrapper: The configuration from the model.
|
||||
field_name: The field name the `FieldInfo` instance is attached to.
|
||||
field_info: The `FieldInfo` instance to update.
|
||||
"""
|
||||
field_title_generator = field_info.field_title_generator or config_wrapper.field_title_generator
|
||||
if field_title_generator is not None:
|
||||
_apply_field_title_generator_to_field_info(field_title_generator, field_name, field_info)
|
||||
if config_wrapper.alias_generator is not None:
|
||||
_apply_alias_generator_to_field_info(config_wrapper.alias_generator, field_name, field_info)
|
||||
|
||||
|
||||
_deprecated_method_names = {'dict', 'json', 'copy', '_iter', '_copy_and_set_values', '_calculate_keys'}
|
||||
|
||||
_deprecated_classmethod_names = {
|
||||
'parse_obj',
|
||||
'parse_raw',
|
||||
'parse_file',
|
||||
'from_orm',
|
||||
'construct',
|
||||
'schema',
|
||||
'schema_json',
|
||||
'validate',
|
||||
'update_forward_refs',
|
||||
'_get_value',
|
||||
}
|
||||
|
||||
|
||||
def collect_model_fields( # noqa: C901
|
||||
cls: type[BaseModel],
|
||||
bases: tuple[type[Any], ...],
|
||||
config_wrapper: ConfigWrapper,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
ns_resolver: NsResolver | None,
|
||||
*,
|
||||
typevars_map: dict[Any, Any] | None = None,
|
||||
typevars_map: Mapping[TypeVar, Any] | None = None,
|
||||
) -> tuple[dict[str, FieldInfo], set[str]]:
|
||||
"""Collect the fields of a nascent pydantic model.
|
||||
"""Collect the fields and class variables names of a nascent Pydantic model.
|
||||
|
||||
Also collect the names of any ClassVars present in the type hints.
|
||||
The fields collection process is *lenient*, meaning it won't error if string annotations
|
||||
fail to evaluate. If this happens, the original annotation (and assigned value, if any)
|
||||
is stored on the created `FieldInfo` instance.
|
||||
|
||||
The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.
|
||||
The `rebuild_model_fields()` function should be called at a later point (e.g. when rebuilding the model),
|
||||
and will make use of these stored attributes.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
bases: Parents of the class, generally `cls.__bases__`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
ns_resolver: Namespace resolver to use when getting model annotations.
|
||||
typevars_map: A dictionary mapping type variables to their concrete types.
|
||||
|
||||
Returns:
|
||||
A tuple contains fields and class variables.
|
||||
A two-tuple containing the model fields and the class variable names.
|
||||
|
||||
Raises:
|
||||
NameError:
|
||||
@@ -115,49 +244,58 @@ def collect_model_fields( # noqa: C901
|
||||
- If there is a field other than `root` in `RootModel`.
|
||||
- If a field shadows an attribute in the parent model.
|
||||
"""
|
||||
from ..fields import FieldInfo
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
BaseModel_ = import_cached_base_model()
|
||||
|
||||
type_hints = get_cls_type_hints_lenient(cls, types_namespace)
|
||||
bases = cls.__bases__
|
||||
parent_fields_lookup: dict[str, FieldInfo] = {}
|
||||
for base in reversed(bases):
|
||||
if model_fields := getattr(base, '__pydantic_fields__', None):
|
||||
parent_fields_lookup.update(model_fields)
|
||||
|
||||
type_hints = _typing_extra.get_model_type_hints(cls, ns_resolver=ns_resolver)
|
||||
|
||||
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
|
||||
# annotations is only used for finding fields in parent classes
|
||||
annotations = cls.__dict__.get('__annotations__', {})
|
||||
annotations = _typing_extra.safe_get_annotations(cls)
|
||||
|
||||
fields: dict[str, FieldInfo] = {}
|
||||
|
||||
class_vars: set[str] = set()
|
||||
for ann_name, ann_type in type_hints.items():
|
||||
for ann_name, (ann_type, evaluated) in type_hints.items():
|
||||
if ann_name == 'model_config':
|
||||
# We never want to treat `model_config` as a field
|
||||
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
|
||||
# protected namespaces (where `model_config` might be allowed as a field name)
|
||||
continue
|
||||
for protected_namespace in config_wrapper.protected_namespaces:
|
||||
if ann_name.startswith(protected_namespace):
|
||||
for b in bases:
|
||||
if hasattr(b, ann_name):
|
||||
from ..main import BaseModel
|
||||
|
||||
if not (issubclass(b, BaseModel) and ann_name in b.model_fields):
|
||||
raise NameError(
|
||||
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
|
||||
f' of protected namespace "{protected_namespace}".'
|
||||
)
|
||||
else:
|
||||
valid_namespaces = tuple(
|
||||
x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x)
|
||||
)
|
||||
warnings.warn(
|
||||
f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".'
|
||||
'\n\nYou may be able to resolve this warning by setting'
|
||||
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
|
||||
UserWarning,
|
||||
)
|
||||
if is_classvar(ann_type):
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
|
||||
_check_protected_namespaces(
|
||||
protected_namespaces=config_wrapper.protected_namespaces,
|
||||
ann_name=ann_name,
|
||||
bases=bases,
|
||||
cls_name=cls.__name__,
|
||||
)
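As a hedged illustration of the protected-namespace check invoked above: a field name that starts with a configured protected prefix triggers a `UserWarning` unless the namespaces are overridden. The `'secret_'` prefix and field name below are made up so the sketch does not depend on the default namespaces.

```python
import warnings

from pydantic import BaseModel, ConfigDict

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')

    class Config(BaseModel):
        model_config = ConfigDict(protected_namespaces=('secret_',))

        secret_key: str  # starts with a protected prefix -> UserWarning

print(any(issubclass(w.category, UserWarning) for w in caught))
#> True
```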
|
||||
|
||||
if _typing_extra.is_classvar_annotation(ann_type):
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
|
||||
assigned_value = getattr(cls, ann_name, PydanticUndefined)
|
||||
if assigned_value is not PydanticUndefined and (
|
||||
# One of the deprecated instance methods was used as a field name (e.g. `dict()`):
|
||||
any(getattr(BaseModel_, depr_name, None) is assigned_value for depr_name in _deprecated_method_names)
|
||||
# One of the deprecated class methods was used as a field name (e.g. `schema()`):
|
||||
or (
|
||||
hasattr(assigned_value, '__func__')
|
||||
and any(
|
||||
getattr(getattr(BaseModel_, depr_name, None), '__func__', None) is assigned_value.__func__ # pyright: ignore[reportAttributeAccessIssue]
|
||||
for depr_name in _deprecated_classmethod_names
|
||||
)
|
||||
)
|
||||
):
|
||||
# Then `assigned_value` would be the method, even though no default was specified:
|
||||
assigned_value = PydanticUndefined
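A sketch of the edge case handled above: reusing a deprecated V1-style method name (here `schema`) as an annotation-only field must not pick up the bound method as a default. The "shadows an attribute" warning emitted further down in this function is silenced for brevity; behaviour assumes current V2 semantics where shadowing warns rather than errors.

```python
import warnings

from pydantic import BaseModel

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    class Model(BaseModel):
        schema: str  # must stay a *required* field, not default to BaseModel.schema

print(Model.model_fields['schema'].is_required())
#> True
```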
|
||||
|
||||
if not is_valid_field_name(ann_name):
|
||||
continue
|
||||
if cls.__pydantic_root_model__ and ann_name != 'root':
|
||||
@@ -166,7 +304,7 @@ def collect_model_fields( # noqa: C901
|
||||
)
|
||||
|
||||
# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
|
||||
# "... shadows an attribute" errors
|
||||
# "... shadows an attribute" warnings
|
||||
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
|
||||
for base in bases:
|
||||
dataclass_fields = {
|
||||
@@ -174,42 +312,74 @@ def collect_model_fields( # noqa: C901
|
||||
}
|
||||
if hasattr(base, ann_name):
|
||||
if base is generic_origin:
|
||||
# Don't error when "shadowing" of attributes in parametrized generics
|
||||
# Don't warn when "shadowing" of attributes in parametrized generics
|
||||
continue
|
||||
|
||||
if ann_name in dataclass_fields:
|
||||
# Don't error when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
|
||||
# Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
|
||||
# on the class instance.
|
||||
continue
|
||||
|
||||
if ann_name not in annotations:
|
||||
# Don't warn when a field exists in a parent class but has not been defined in the current class
|
||||
continue
|
||||
|
||||
warnings.warn(
|
||||
f'Field name "{ann_name}" shadows an attribute in parent "{base.__qualname__}"; ',
|
||||
f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent '
|
||||
f'"{base.__qualname__}"',
|
||||
UserWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
|
||||
try:
|
||||
default = getattr(cls, ann_name, PydanticUndefined)
|
||||
if default is PydanticUndefined:
|
||||
raise AttributeError
|
||||
except AttributeError:
|
||||
if ann_name in annotations:
|
||||
field_info = FieldInfo.from_annotation(ann_type)
|
||||
if assigned_value is PydanticUndefined: # no assignment, just a plain annotation
|
||||
if ann_name in annotations or ann_name not in parent_fields_lookup:
|
||||
# field is either:
|
||||
# - present in the current model's annotations (and *not* from parent classes)
|
||||
# - not found on any base classes; this seems to be caused by fields not getting
|
||||
# generated due to models not being fully defined while initializing recursive models.
|
||||
# Nothing stops us from just creating a `FieldInfo` for this type hint, so we do this.
|
||||
field_info = FieldInfo_.from_annotation(ann_type, _source=AnnotationSource.CLASS)
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
# Store the original annotation that should be used to rebuild
|
||||
# the field info later:
|
||||
field_info._original_annotation = ann_type
|
||||
else:
|
||||
# if field has no default value and is not in __annotations__ this means that it is
|
||||
# defined in a base class and we can take it from there
|
||||
model_fields_lookup: dict[str, FieldInfo] = {}
|
||||
for x in cls.__bases__[::-1]:
|
||||
model_fields_lookup.update(getattr(x, 'model_fields', {}))
|
||||
if ann_name in model_fields_lookup:
|
||||
# The field was present on one of the (possibly multiple) base classes
|
||||
# copy the field to make sure typevar substitutions don't cause issues with the base classes
|
||||
field_info = copy(model_fields_lookup[ann_name])
|
||||
else:
|
||||
# The field was not found on any base classes; this seems to be caused by fields not getting
|
||||
# generated thanks to models not being fully defined while initializing recursive models.
|
||||
# Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
|
||||
field_info = FieldInfo.from_annotation(ann_type)
|
||||
else:
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, default)
|
||||
# The field was present on one of the (possibly multiple) base classes
|
||||
# copy the field to make sure typevar substitutions don't cause issues with the base classes
|
||||
field_info = parent_fields_lookup[ann_name]._copy()
|
||||
|
||||
else: # An assigned value is present (either the default value, or a `Field()` function)
|
||||
if isinstance(assigned_value, FieldInfo_) and ismethoddescriptor(assigned_value.default):
|
||||
# `assigned_value` was fetched using `getattr`, which triggers a call to `__get__`
|
||||
# for descriptors, so we do the same if the `= field(default=...)` form is used.
|
||||
# Note that we only do this for method descriptors for now, we might want to
|
||||
# extend this to any descriptor in the future (by simply checking for
|
||||
# `hasattr(assigned_value.default, '__get__')`).
|
||||
default = assigned_value.default.__get__(None, cls)
|
||||
assigned_value.default = default
|
||||
assigned_value._attributes_set['default'] = default
|
||||
|
||||
field_info = FieldInfo_.from_annotated_attribute(ann_type, assigned_value, _source=AnnotationSource.CLASS)
|
||||
# Store the original annotation and assignment value that should be used to rebuild the field info later.
|
||||
# Note that the assignment is always stored as the annotation might contain a type var that is later
|
||||
# parameterized with an unknown forward reference (and we'll need it to rebuild the field info):
|
||||
field_info._original_assignment = assigned_value
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
field_info._original_annotation = ann_type
|
||||
elif 'final' in field_info._qualifiers and not field_info.is_required():
|
||||
warnings.warn(
|
||||
f'Annotation {ann_name!r} is marked as final and has a default value. Pydantic treats {ann_name!r} as a '
|
||||
'class variable, but it will be considered as a normal field in V3 to be aligned with dataclasses. If you '
|
||||
f'still want {ann_name!r} to be considered as a class variable, annotate it as: `ClassVar[<type>] = <default>`.',
|
||||
category=PydanticDeprecatedSince211,
|
||||
# Incorrect when `create_model` is used, but the chance that final with a default is used is low in that case:
|
||||
stacklevel=4,
|
||||
)
|
||||
class_vars.add(ann_name)
|
||||
continue
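An illustrative sketch of the deprecation path above (assuming Pydantic 2.11+ semantics): a `Final` annotation with a default is still collected as a class variable rather than a field, but now emits the deprecation warning shown above. The `Settings` model is hypothetical.

```python
from typing import Final

from pydantic import BaseModel


class Settings(BaseModel):
    retries: Final[int] = 3  # warns (PydanticDeprecatedSince211); treated as a ClassVar


print('retries' in Settings.model_fields)
#> False
```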
|
||||
|
||||
# attributes which are fields are removed from the class namespace:
|
||||
# 1. To match the behaviour of annotation-only fields
|
||||
# 2. To avoid false positives in the NameError check above
|
||||
@@ -222,85 +392,244 @@ def collect_model_fields( # noqa: C901
|
||||
# to make sure the decorators have already been built for this exact class
|
||||
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
|
||||
if ann_name in decorators.computed_fields:
|
||||
raise ValueError("you can't override a field with a computed field")
|
||||
raise TypeError(
|
||||
f'Field {ann_name!r} of class {cls.__name__!r} overrides symbol of same name in a parent class. '
|
||||
'This override with a computed_field is incompatible.'
|
||||
)
|
||||
fields[ann_name] = field_info
|
||||
|
||||
if field_info._complete:
|
||||
# If not complete, this will be called in `rebuild_model_fields()`:
|
||||
update_field_from_config(config_wrapper, ann_name, field_info)
|
||||
|
||||
if typevars_map:
|
||||
for field in fields.values():
|
||||
field.apply_typevars_map(typevars_map, types_namespace)
|
||||
if field._complete:
|
||||
field.apply_typevars_map(typevars_map)
|
||||
|
||||
if config_wrapper.use_attribute_docstrings:
|
||||
_update_fields_from_docstrings(cls, fields)
|
||||
return fields, class_vars
|
||||
|
||||
|
||||
def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
|
||||
from ..fields import FieldInfo
|
||||
def rebuild_model_fields(
|
||||
cls: type[BaseModel],
|
||||
*,
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver,
|
||||
typevars_map: Mapping[TypeVar, Any],
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Rebuild the (already present) model fields by trying to reevaluate annotations.
|
||||
|
||||
if not is_finalvar(type_):
|
||||
return False
|
||||
elif val is PydanticUndefined:
|
||||
return False
|
||||
elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
This function should be called whenever a model with incomplete fields is encountered.
|
||||
|
||||
Raises:
|
||||
NameError: If one of the annotations failed to evaluate.
|
||||
|
||||
Note:
|
||||
This function *doesn't* mutate the model fields in place, as it can be called during
|
||||
schema generation, where you don't want to mutate other models' fields.
|
||||
"""
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
rebuilt_fields: dict[str, FieldInfo] = {}
|
||||
with ns_resolver.push(cls):
|
||||
for f_name, field_info in cls.__pydantic_fields__.items():
|
||||
if field_info._complete:
|
||||
rebuilt_fields[f_name] = field_info
|
||||
else:
|
||||
existing_desc = field_info.description
|
||||
ann = _typing_extra.eval_type(
|
||||
field_info._original_annotation,
|
||||
*ns_resolver.types_namespace,
|
||||
)
|
||||
ann = _generics.replace_types(ann, typevars_map)
|
||||
|
||||
if (assign := field_info._original_assignment) is PydanticUndefined:
|
||||
new_field = FieldInfo_.from_annotation(ann, _source=AnnotationSource.CLASS)
|
||||
else:
|
||||
new_field = FieldInfo_.from_annotated_attribute(ann, assign, _source=AnnotationSource.CLASS)
|
||||
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
|
||||
new_field.description = new_field.description if new_field.description is not None else existing_desc
|
||||
update_field_from_config(config_wrapper, f_name, new_field)
|
||||
rebuilt_fields[f_name] = new_field
|
||||
|
||||
return rebuilt_fields
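A hedged sketch of when incomplete fields get rebuilt: a forward reference that cannot be resolved at class creation leaves the field marked incomplete until `model_rebuild()` (or first use) re-evaluates the stored annotation. The model names below are illustrative.

```python
from typing import Optional

from pydantic import BaseModel


class Node(BaseModel):
    value: int
    parent: Optional['Tree'] = None  # 'Tree' is undefined here, so the field stays incomplete


class Tree(BaseModel):
    root: Node


Node.model_rebuild()  # re-evaluates the stored annotation and rebuilds the field
print(Node(value=1).parent)
#> None
```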
|
||||
|
||||
|
||||
def collect_dataclass_fields(
|
||||
cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None
|
||||
cls: type[StandardDataclass],
|
||||
*,
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
typevars_map: dict[Any, Any] | None = None,
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Collect the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
cls: dataclass.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
config_wrapper: The config wrapper instance.
|
||||
ns_resolver: Namespace resolver to use when getting dataclass annotations.
|
||||
Defaults to an empty instance.
|
||||
typevars_map: A dictionary mapping type variables to their concrete types.
|
||||
|
||||
Returns:
|
||||
The dataclass fields.
|
||||
"""
|
||||
from ..fields import FieldInfo
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
fields: dict[str, FieldInfo] = {}
|
||||
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
|
||||
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
dataclass_fields = cls.__dataclass_fields__
|
||||
|
||||
for ann_name, dataclass_field in dataclass_fields.items():
|
||||
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
|
||||
if is_classvar(ann_type):
|
||||
# The logic here is similar to `_typing_extra.get_cls_type_hints`,
|
||||
# although we do it manually as stdlib dataclasses already have annotations
|
||||
# collected in each class:
|
||||
for base in reversed(cls.__mro__):
|
||||
if not dataclasses.is_dataclass(base):
|
||||
continue
|
||||
|
||||
if (
|
||||
not dataclass_field.init
|
||||
and dataclass_field.default == dataclasses.MISSING
|
||||
and dataclass_field.default_factory == dataclasses.MISSING
|
||||
):
|
||||
# TODO: We should probably do something with this so that validate_assignment behaves properly
|
||||
# Issue: https://github.com/pydantic/pydantic/issues/5470
|
||||
continue
|
||||
with ns_resolver.push(base):
|
||||
for ann_name, dataclass_field in dataclass_fields.items():
|
||||
base_anns = _typing_extra.safe_get_annotations(base)
|
||||
|
||||
if isinstance(dataclass_field.default, FieldInfo):
|
||||
if dataclass_field.default.init_var:
|
||||
# TODO: same note as above
|
||||
continue
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default)
|
||||
else:
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field)
|
||||
fields[ann_name] = field_info
|
||||
if ann_name not in base_anns:
|
||||
# `__dataclass_fields__` contains every field, even the ones from base classes.
|
||||
# Only collect the ones defined on `base`.
|
||||
continue
|
||||
|
||||
if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo):
|
||||
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
|
||||
setattr(cls, ann_name, field_info.default)
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
ann_type, evaluated = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)
|
||||
|
||||
if _typing_extra.is_classvar_annotation(ann_type):
|
||||
continue
|
||||
|
||||
if (
|
||||
not dataclass_field.init
|
||||
and dataclass_field.default is dataclasses.MISSING
|
||||
and dataclass_field.default_factory is dataclasses.MISSING
|
||||
):
|
||||
# TODO: We should probably do something with this so that validate_assignment behaves properly
|
||||
# Issue: https://github.com/pydantic/pydantic/issues/5470
|
||||
continue
|
||||
|
||||
if isinstance(dataclass_field.default, FieldInfo_):
|
||||
if dataclass_field.default.init_var:
|
||||
if dataclass_field.default.init is False:
|
||||
raise PydanticUserError(
|
||||
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
|
||||
code='clashing-init-and-init-var',
|
||||
)
|
||||
|
||||
# TODO: same note as above re validate_assignment
|
||||
continue
|
||||
field_info = FieldInfo_.from_annotated_attribute(
|
||||
ann_type, dataclass_field.default, _source=AnnotationSource.DATACLASS
|
||||
)
|
||||
field_info._original_assignment = dataclass_field.default
|
||||
else:
|
||||
field_info = FieldInfo_.from_annotated_attribute(
|
||||
ann_type, dataclass_field, _source=AnnotationSource.DATACLASS
|
||||
)
|
||||
field_info._original_assignment = dataclass_field
|
||||
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
field_info._original_annotation = ann_type
|
||||
|
||||
fields[ann_name] = field_info
|
||||
update_field_from_config(config_wrapper, ann_name, field_info)
|
||||
|
||||
if field_info.default is not PydanticUndefined and isinstance(
|
||||
getattr(cls, ann_name, field_info), FieldInfo_
|
||||
):
|
||||
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
|
||||
setattr(cls, ann_name, field_info.default)
|
||||
|
||||
if typevars_map:
|
||||
for field in fields.values():
|
||||
field.apply_typevars_map(typevars_map, types_namespace)
|
||||
# We don't pass any ns, as `field.annotation`
|
||||
# was already evaluated. TODO: is this method relevant?
|
||||
# Can't we just use `_generics.replace_types`?
|
||||
field.apply_typevars_map(typevars_map)
|
||||
|
||||
if config_wrapper.use_attribute_docstrings:
|
||||
_update_fields_from_docstrings(
|
||||
cls,
|
||||
fields,
|
||||
# We can't rely on the (more reliable) frame inspection method
|
||||
# for stdlib dataclasses:
|
||||
use_inspect=not hasattr(cls, '__is_pydantic_dataclass__'),
|
||||
)
|
||||
|
||||
return fields
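A small usage sketch of the dataclass collection path above: a `Field(...)` default on a Pydantic dataclass ends up as a `FieldInfo` entry in `__pydantic_fields__`. The `Point` dataclass and its field names are illustrative only.

```python
import pydantic


@pydantic.dataclasses.dataclass
class Point:
    x: int
    y: int = pydantic.Field(default=0, ge=0)


print(sorted(Point.__pydantic_fields__))
#> ['x', 'y']
```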
|
||||
|
||||
|
||||
def rebuild_dataclass_fields(
|
||||
cls: type[PydanticDataclass],
|
||||
*,
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver,
|
||||
typevars_map: Mapping[TypeVar, Any],
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Rebuild the (already present) dataclass fields by trying to reevaluate annotations.
|
||||
|
||||
This function should be called whenever a dataclass with incomplete fields is encountered.
|
||||
|
||||
Raises:
|
||||
NameError: If one of the annotations failed to evaluate.
|
||||
|
||||
Note:
|
||||
This function *doesn't* mutate the dataclass fields in place, as it can be called during
|
||||
schema generation, where you don't want to mutate other dataclasses' fields.
|
||||
"""
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
rebuilt_fields: dict[str, FieldInfo] = {}
|
||||
with ns_resolver.push(cls):
|
||||
for f_name, field_info in cls.__pydantic_fields__.items():
|
||||
if field_info._complete:
|
||||
rebuilt_fields[f_name] = field_info
|
||||
else:
|
||||
existing_desc = field_info.description
|
||||
ann = _typing_extra.eval_type(
|
||||
field_info._original_annotation,
|
||||
*ns_resolver.types_namespace,
|
||||
)
|
||||
ann = _generics.replace_types(ann, typevars_map)
|
||||
new_field = FieldInfo_.from_annotated_attribute(
|
||||
ann,
|
||||
field_info._original_assignment,
|
||||
_source=AnnotationSource.DATACLASS,
|
||||
)
|
||||
|
||||
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
|
||||
new_field.description = new_field.description if new_field.description is not None else existing_desc
|
||||
update_field_from_config(config_wrapper, f_name, new_field)
|
||||
rebuilt_fields[f_name] = new_field
|
||||
|
||||
return rebuilt_fields
|
||||
|
||||
|
||||
def is_valid_field_name(name: str) -> bool:
|
||||
return not name.startswith('_')
|
||||
|
||||
|
||||
def is_valid_privateattr_name(name: str) -> bool:
|
||||
return name.startswith('_') and not name.startswith('__')
|
||||
|
||||
|
||||
def takes_validated_data_argument(
|
||||
default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any],
|
||||
) -> TypeIs[Callable[[dict[str, Any]], Any]]:
|
||||
"""Whether the provided default factory callable has a validated data parameter."""
|
||||
try:
|
||||
sig = signature(default_factory)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no data argument is present:
|
||||
return False
|
||||
|
||||
parameters = list(sig.parameters.values())
|
||||
|
||||
return len(parameters) == 1 and can_be_positional(parameters[0]) and parameters[0].default is Parameter.empty
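A sketch of the behaviour this helper enables (available in recent Pydantic 2.x releases): a default factory may declare a single positional parameter to receive the already-validated data. The `User` model is illustrative.

```python
from pydantic import BaseModel, Field


class User(BaseModel):
    first: str
    last: str
    full: str = Field(default_factory=lambda data: f'{data["first"]} {data["last"]}')


print(User(first='Ada', last='Lovelace').full)
#> Ada Lovelace
```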
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,29 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from collections import ChainMap
|
||||
from collections.abc import Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from functools import reduce
|
||||
from itertools import zip_longest
|
||||
from types import prepare_class
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar
|
||||
from typing import TYPE_CHECKING, Annotated, Any, TypedDict, TypeVar, cast
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import typing_extensions
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from . import _typing_extra
|
||||
from ._core_utils import get_type_ref
|
||||
from ._forward_ref import PydanticRecursiveRef
|
||||
from ._typing_extra import TypeVarType, typing_base
|
||||
from ._utils import all_identical, is_model_class
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import _UnionGenericAlias # type: ignore[attr-defined]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..main import BaseModel
|
||||
|
||||
GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
|
||||
GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]]
|
||||
|
||||
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
|
||||
# Right now, to handle recursive generics, some types must remain cached for brief periods without references.
|
||||
@@ -34,43 +37,25 @@ GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
|
||||
KT = TypeVar('KT')
|
||||
VT = TypeVar('VT')
|
||||
_LIMITED_DICT_SIZE = 100
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class LimitedDict(dict, MutableMapping[KT, VT]):
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
|
||||
...
|
||||
|
||||
else:
|
||||
class LimitedDict(dict[KT, VT]):
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE) -> None:
|
||||
self.size_limit = size_limit
|
||||
super().__init__()
|
||||
|
||||
class LimitedDict(dict):
|
||||
"""Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
|
||||
|
||||
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
|
||||
"""
|
||||
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
|
||||
self.size_limit = size_limit
|
||||
super().__init__()
|
||||
|
||||
def __setitem__(self, __key: Any, __value: Any) -> None:
|
||||
super().__setitem__(__key, __value)
|
||||
if len(self) > self.size_limit:
|
||||
excess = len(self) - self.size_limit + self.size_limit // 10
|
||||
to_remove = list(self.keys())[:excess]
|
||||
for key in to_remove:
|
||||
del self[key]
|
||||
|
||||
def __class_getitem__(cls, *args: Any) -> Any:
|
||||
# to avoid errors with 3.7
|
||||
return cls
|
||||
def __setitem__(self, key: KT, value: VT, /) -> None:
|
||||
super().__setitem__(key, value)
|
||||
if len(self) > self.size_limit:
|
||||
excess = len(self) - self.size_limit + self.size_limit // 10
|
||||
to_remove = list(self.keys())[:excess]
|
||||
for k in to_remove:
|
||||
del self[k]
|
||||
|
||||
|
||||
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
|
||||
# once they are no longer referenced by the caller.
|
||||
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
|
||||
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
|
||||
else:
|
||||
GenericTypesCache = WeakValueDictionary
|
||||
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -108,13 +93,13 @@ else:
|
||||
# and discover later on that we need to re-add all this infrastructure...
|
||||
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
|
||||
|
||||
_GENERIC_TYPES_CACHE = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE: ContextVar[GenericTypesCache | None] = ContextVar('_GENERIC_TYPES_CACHE', default=None)
|
||||
|
||||
|
||||
class PydanticGenericMetadata(typing_extensions.TypedDict):
|
||||
class PydanticGenericMetadata(TypedDict):
|
||||
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
|
||||
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
|
||||
parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__
|
||||
parameters: tuple[TypeVar, ...] # analogous to typing.Generic.__parameters__
|
||||
|
||||
|
||||
def create_generic_submodel(
|
||||
@@ -171,7 +156,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
|
||||
depth: The depth to get the frame.
|
||||
|
||||
Returns:
|
||||
A tuple contains `module_nam` and `called_globally`.
|
||||
A tuple containing `module_name` and `called_globally`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the function is not called inside a function.
|
||||
@@ -189,7 +174,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
|
||||
DictValues: type[Any] = {}.values().__class__
|
||||
|
||||
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVar]:
|
||||
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
|
||||
|
||||
This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
|
||||
@@ -222,7 +207,7 @@ def get_origin(v: Any) -> Any:
|
||||
return typing_extensions.get_origin(v)
|
||||
|
||||
|
||||
def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
|
||||
def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
|
||||
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
|
||||
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
|
||||
"""
|
||||
@@ -235,11 +220,11 @@ def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
|
||||
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
|
||||
# So it is safe to access cls.__args__ and origin.__parameters__
|
||||
args: tuple[Any, ...] = cls.__args__ # type: ignore
|
||||
parameters: tuple[TypeVarType, ...] = origin.__parameters__
|
||||
parameters: tuple[TypeVar, ...] = origin.__parameters__
|
||||
return dict(zip(parameters, args))
|
||||
|
||||
|
||||
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None:
|
||||
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]:
|
||||
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
|
||||
with the `replace_types` function.
|
||||
|
||||
@@ -251,10 +236,13 @@ def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | Non
|
||||
generic_metadata = cls.__pydantic_generic_metadata__
|
||||
origin = generic_metadata['origin']
|
||||
args = generic_metadata['args']
|
||||
if not args:
|
||||
# No need to go into `iter_contained_typevars`:
|
||||
return {}
|
||||
return dict(zip(iter_contained_typevars(origin), args))
|
||||
|
||||
|
||||
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
|
||||
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
|
||||
|
||||
Args:
|
||||
@@ -266,13 +254,13 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
`typevar_map` keys recursively replaced.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from typing import List, Tuple, Union
|
||||
```python
|
||||
from typing import Union
|
||||
|
||||
from pydantic._internal._generics import replace_types
|
||||
|
||||
replace_types(Tuple[str, Union[List[str], float]], {str: int})
|
||||
#> Tuple[int, Union[List[int], float]]
|
||||
replace_types(tuple[str, Union[list[str], float]], {str: int})
|
||||
#> tuple[int, Union[list[int], float]]
|
||||
```
|
||||
"""
|
||||
if not type_map:
|
||||
@@ -281,25 +269,25 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is typing_extensions.Annotated:
|
||||
if typing_objects.is_annotated(origin_type):
|
||||
annotated_type, *annotations = type_args
|
||||
annotated = replace_types(annotated_type, type_map)
|
||||
for annotation in annotations:
|
||||
annotated = typing_extensions.Annotated[annotated, annotation]
|
||||
return annotated
|
||||
annotated_type = replace_types(annotated_type, type_map)
|
||||
# TODO remove parentheses when we drop support for Python 3.10:
|
||||
return Annotated[(annotated_type, *annotations)]
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
# Having type args is a good indicator that this is a typing special form
|
||||
# instance or a generic alias of some sort.
|
||||
if type_args:
|
||||
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
|
||||
if all_identical(type_args, resolved_type_args):
|
||||
# If all arguments are the same, there is no need to modify the
|
||||
# type or create a new object at all
|
||||
return type_
|
||||
|
||||
if (
|
||||
origin_type is not None
|
||||
and isinstance(type_, typing_base)
|
||||
and not isinstance(origin_type, typing_base)
|
||||
and isinstance(type_, _typing_extra.typing_base)
|
||||
and not isinstance(origin_type, _typing_extra.typing_base)
|
||||
and getattr(type_, '_name', None) is not None
|
||||
):
|
||||
# In python < 3.9 generic aliases don't exist so any of these like `list`,
|
||||
@@ -307,10 +295,22 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
# See: https://www.python.org/dev/peps/pep-0585
|
||||
origin_type = getattr(typing, type_._name)
|
||||
assert origin_type is not None
|
||||
|
||||
if is_union_origin(origin_type):
|
||||
if any(typing_objects.is_any(arg) for arg in resolved_type_args):
|
||||
# `Any | T` ~ `Any`:
|
||||
resolved_type_args = (Any,)
|
||||
# `Never | T` ~ `T`:
|
||||
resolved_type_args = tuple(
|
||||
arg
|
||||
for arg in resolved_type_args
|
||||
if not (typing_objects.is_noreturn(arg) or typing_objects.is_never(arg))
|
||||
)
|
||||
|
||||
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
|
||||
# We also cannot use isinstance() since we have to compare types.
|
||||
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
|
||||
return _UnionGenericAlias(origin_type, resolved_type_args)
|
||||
return reduce(operator.or_, resolved_type_args)
|
||||
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
|
||||
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
|
||||
|
||||
@@ -328,8 +328,8 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)):
|
||||
resolved_list = list(replace_types(element, type_map) for element in type_)
|
||||
if isinstance(type_, list):
|
||||
resolved_list = [replace_types(element, type_map) for element in type_]
|
||||
if all_identical(type_, resolved_list):
|
||||
return type_
|
||||
return resolved_list
|
||||
@@ -339,49 +339,57 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
return type_map.get(type_, type_)
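A hedged sketch of the new union simplification above, calling the private helper from this module purely for illustration: an `Any` member absorbs the whole union, while `Never`/`NoReturn` members are dropped.

```python
from typing import Any, NoReturn, TypeVar, Union

from pydantic._internal._generics import replace_types

T = TypeVar('T')

print(replace_types(Union[int, T], {T: Any}))
#> typing.Any
print(replace_types(Union[int, T], {T: NoReturn}))
#> <class 'int'>
```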
|
||||
|
||||
|
||||
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
|
||||
"""Checks if the type, or any of its arbitrary nested args, satisfy
|
||||
`isinstance(<type>, isinstance_target)`.
|
||||
"""
|
||||
if isinstance(type_, isinstance_target):
|
||||
return True
|
||||
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is typing_extensions.Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
return has_instance_in_type(annotated_type, isinstance_target)
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
if any(has_instance_in_type(a, isinstance_target) for a in type_args):
|
||||
return True
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec):
|
||||
if any(has_instance_in_type(element, isinstance_target) for element in type_):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
|
||||
"""Check the generic model parameters count is equal.
|
||||
|
||||
Args:
|
||||
cls: The generic model.
|
||||
parameters: A tuple of passed parameters to the generic model.
|
||||
def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]:
|
||||
"""Return a mapping between the parameters of a generic model and the provided arguments during parameterization.
|
||||
|
||||
Raises:
|
||||
TypeError: If the passed parameters count is not equal to generic model parameters count.
|
||||
TypeError: If the number of arguments does not match the parameters (i.e. if providing too few or too many arguments).
|
||||
|
||||
Example:
|
||||
```python {test="skip" lint="skip"}
|
||||
class Model[T, U, V = int](BaseModel): ...
|
||||
|
||||
map_generic_model_arguments(Model, (str, bytes))
|
||||
#> {T: str, U: bytes, V: int}
|
||||
|
||||
map_generic_model_arguments(Model, (str,))
|
||||
#> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2
|
||||
|
||||
map_generic_model_arguments(Model, (str, bytes, int, complex))
|
||||
#> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3
|
||||
```
|
||||
|
||||
Note:
|
||||
This function is analogous to the private `typing._check_generic_specialization` function.
|
||||
"""
|
||||
actual = len(parameters)
|
||||
expected = len(cls.__pydantic_generic_metadata__['parameters'])
|
||||
if actual != expected:
|
||||
description = 'many' if actual > expected else 'few'
|
||||
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
|
||||
parameters = cls.__pydantic_generic_metadata__['parameters']
|
||||
expected_len = len(parameters)
|
||||
typevars_map: dict[TypeVar, Any] = {}
|
||||
|
||||
_missing = object()
|
||||
for parameter, argument in zip_longest(parameters, args, fillvalue=_missing):
|
||||
if parameter is _missing:
|
||||
raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}')
|
||||
|
||||
if argument is _missing:
|
||||
param = cast(TypeVar, parameter)
|
||||
try:
|
||||
has_default = param.has_default() # pyright: ignore[reportAttributeAccessIssue]
|
||||
except AttributeError:
|
||||
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13.
|
||||
has_default = False
|
||||
if has_default:
|
||||
# The default might refer to other type parameters. For an example, see:
|
||||
# https://typing.python.org/en/latest/spec/generics.html#type-parameters-as-parameters-to-generics
|
||||
typevars_map[param] = replace_types(param.__default__, typevars_map) # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters) # pyright: ignore[reportAttributeAccessIssue]
|
||||
raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}')
|
||||
else:
|
||||
param = cast(TypeVar, parameter)
|
||||
typevars_map[param] = argument
|
||||
|
||||
return typevars_map
|
||||
|
||||
|
||||
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
|
||||
@@ -412,7 +420,8 @@ def generic_recursion_self_type(
|
||||
yield self_type
|
||||
else:
|
||||
previously_seen_type_refs.add(type_ref)
|
||||
yield None
|
||||
yield
|
||||
previously_seen_type_refs.remove(type_ref)
|
||||
finally:
|
||||
if token:
|
||||
_generic_recursion_cache.reset(token)
|
||||
@@ -443,14 +452,24 @@ def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any)
|
||||
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
|
||||
equal.
|
||||
"""
|
||||
return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if generic_types_cache is None:
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
return generic_types_cache.get(_early_cache_key(parent, typevar_values))
|
||||
|
||||
|
||||
def get_cached_generic_type_late(
|
||||
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
|
||||
) -> type[BaseModel] | None:
|
||||
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
|
||||
cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if (
|
||||
generic_types_cache is None
|
||||
): # pragma: no cover (early cache is guaranteed to run first and initialize the cache)
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
cached = generic_types_cache.get(_late_cache_key(origin, args, typevar_values))
|
||||
if cached is not None:
|
||||
set_cached_generic_type(parent, typevar_values, cached, origin, args)
|
||||
return cached
|
||||
@@ -466,11 +485,17 @@ def set_cached_generic_type(
|
||||
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
|
||||
two different keys.
|
||||
"""
|
||||
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if (
|
||||
generic_types_cache is None
|
||||
): # pragma: no cover (cache lookup is guaranteed to run first and initialize the cache)
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
generic_types_cache[_early_cache_key(parent, typevar_values)] = type_
|
||||
if len(typevar_values) == 1:
|
||||
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
|
||||
generic_types_cache[_early_cache_key(parent, typevar_values[0])] = type_
|
||||
if origin and args:
|
||||
_GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
|
||||
generic_types_cache[_late_cache_key(origin, args, typevar_values)] = type_
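A short illustration of what the two-stage cache above buys (the model name is made up): parametrizing the same generic model with the same arguments returns the identical class object instead of rebuilding it.

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar('T')


class Box(BaseModel, Generic[T]):
    content: T


print(Box[int] is Box[int])
#> True
```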
|
||||
|
||||
|
||||
def _union_orderings_key(typevar_values: Any) -> Any:
|
||||
@@ -487,11 +512,8 @@ def _union_orderings_key(typevar_values: Any) -> Any:
|
||||
(See https://github.com/python/cpython/issues/86483 for reference.)
|
||||
"""
|
||||
if isinstance(typevar_values, tuple):
|
||||
args_data = []
|
||||
for value in typevar_values:
|
||||
args_data.append(_union_orderings_key(value))
|
||||
return tuple(args_data)
|
||||
elif typing_extensions.get_origin(typevar_values) is typing.Union:
|
||||
return tuple(_union_orderings_key(value) for value in typevar_values)
|
||||
elif typing_objects.is_union(typing_extensions.get_origin(typevar_values)):
|
||||
return get_args(typevar_values)
|
||||
else:
|
||||
return ()
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def is_git_repo(dir: Path) -> bool:
|
||||
"""Is the given directory version-controlled with git?"""
|
||||
return dir.joinpath('.git').exists()
|
||||
|
||||
|
||||
def have_git() -> bool: # pragma: no cover
|
||||
"""Can we run the git executable?"""
|
||||
try:
|
||||
subprocess.check_output(['git', '--help'])
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def git_revision(dir: Path) -> str:
|
||||
"""Get the SHA-1 of the HEAD of a git repository."""
|
||||
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()
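A brief usage sketch of the helpers defined above, assuming they are in scope, the working directory is a git checkout, and the `git` executable is on `PATH`.

```python
from pathlib import Path

repo = Path('.')
if have_git() and is_git_repo(repo):
    print(git_revision(repo))  # e.g. 'a1b2c3d' (short SHA of HEAD)
```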
|
||||
@@ -0,0 +1,20 @@
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
@cache
|
||||
def import_cached_base_model() -> type['BaseModel']:
|
||||
from pydantic import BaseModel
|
||||
|
||||
return BaseModel
|
||||
|
||||
|
||||
@cache
|
||||
def import_cached_field_info() -> type['FieldInfo']:
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
return FieldInfo
|
||||
@@ -1,7 +1,4 @@
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
dataclass_kwargs: Dict[str, Any]
|
||||
|
||||
# `slots` is available on Python >= 3.10
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
@@ -1,35 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, to_jsonable_python
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, ValidationError, to_jsonable_python
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
from ._fields import PydanticMetadata
|
||||
from ._import_utils import import_cached_field_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..annotated_handlers import GetJsonSchemaHandler
|
||||
|
||||
pass
|
||||
|
||||
STRICT = {'strict'}
|
||||
SEQUENCE_CONSTRAINTS = {'min_length', 'max_length'}
|
||||
FAIL_FAST = {'fail_fast'}
|
||||
LENGTH_CONSTRAINTS = {'min_length', 'max_length'}
|
||||
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
|
||||
NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY}
|
||||
NUMERIC_CONSTRAINTS = {'multiple_of', *INEQUALITY}
|
||||
ALLOW_INF_NAN = {'allow_inf_nan'}
|
||||
|
||||
STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'}
|
||||
BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
STR_CONSTRAINTS = {
|
||||
*LENGTH_CONSTRAINTS,
|
||||
*STRICT,
|
||||
'strip_whitespace',
|
||||
'to_lower',
|
||||
'to_upper',
|
||||
'pattern',
|
||||
'coerce_numbers_to_str',
|
||||
}
|
||||
BYTES_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
|
||||
LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
TUPLE_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
SET_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
DICT_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
GENERATOR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
LIST_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
TUPLE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
SET_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
DICT_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
GENERATOR_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
SEQUENCE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *FAIL_FAST}
|
||||
|
||||
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
|
||||
DECIMAL_CONSTRAINTS = {'max_digits', 'decimal_places', *FLOAT_CONSTRAINTS}
|
||||
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
|
||||
BOOL_CONSTRAINTS = STRICT
|
||||
UUID_CONSTRAINTS = STRICT
|
||||
|
||||
@@ -37,6 +50,8 @@ DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
LAX_OR_STRICT_CONSTRAINTS = STRICT
|
||||
ENUM_CONSTRAINTS = STRICT
|
||||
COMPLEX_CONSTRAINTS = STRICT
|
||||
|
||||
UNION_CONSTRAINTS = {'union_mode'}
|
||||
URL_CONSTRAINTS = {
|
||||
@@ -53,58 +68,33 @@ SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT
|
||||
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')
|
||||
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
|
||||
for constraint in STR_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES)
|
||||
for constraint in BYTES_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',))
|
||||
for constraint in LIST_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',))
|
||||
for constraint in TUPLE_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',))
|
||||
for constraint in SET_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset'))
|
||||
for constraint in DICT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',))
|
||||
for constraint in GENERATOR_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',))
|
||||
for constraint in FLOAT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',))
|
||||
for constraint in INT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',))
|
||||
for constraint in DATE_TIME_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime'))
|
||||
for constraint in TIMEDELTA_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',))
|
||||
for constraint in TIME_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
|
||||
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
|
||||
for constraint in UNION_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',))
|
||||
for constraint in URL_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
|
||||
for constraint in BOOL_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',))
|
||||
for constraint in UUID_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('uuid',))
|
||||
for constraint in LAX_OR_STRICT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('lax-or-strict',))
|
||||
|
||||
constraint_schema_pairings: list[tuple[set[str], tuple[str, ...]]] = [
|
||||
(STR_CONSTRAINTS, TEXT_SCHEMA_TYPES),
|
||||
(BYTES_CONSTRAINTS, ('bytes',)),
|
||||
(LIST_CONSTRAINTS, ('list',)),
|
||||
(TUPLE_CONSTRAINTS, ('tuple',)),
|
||||
(SET_CONSTRAINTS, ('set', 'frozenset')),
|
||||
(DICT_CONSTRAINTS, ('dict',)),
|
||||
(GENERATOR_CONSTRAINTS, ('generator',)),
|
||||
(FLOAT_CONSTRAINTS, ('float',)),
|
||||
(INT_CONSTRAINTS, ('int',)),
|
||||
(DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime', 'timedelta')),
|
||||
# TODO: this is a bit redundant, we could probably avoid some of these
|
||||
(STRICT, (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model')),
|
||||
(UNION_CONSTRAINTS, ('union',)),
|
||||
(URL_CONSTRAINTS, ('url', 'multi-host-url')),
|
||||
(BOOL_CONSTRAINTS, ('bool',)),
|
||||
(UUID_CONSTRAINTS, ('uuid',)),
|
||||
(LAX_OR_STRICT_CONSTRAINTS, ('lax-or-strict',)),
|
||||
(ENUM_CONSTRAINTS, ('enum',)),
|
||||
(DECIMAL_CONSTRAINTS, ('decimal',)),
|
||||
(COMPLEX_CONSTRAINTS, ('complex',)),
|
||||
]
|
||||
|
||||
def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None:
|
||||
def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]:
|
||||
js_schema = handler(s)
|
||||
js_schema.update(f())
|
||||
return js_schema
|
||||
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
if 'pydantic_js_functions' in s:
|
||||
metadata['pydantic_js_functions'].append(update_js_schema)
|
||||
else:
|
||||
metadata['pydantic_js_functions'] = [update_js_schema]
|
||||
else:
|
||||
s['metadata'] = {'pydantic_js_functions': [update_js_schema]}
|
||||
for constraints, schemas in constraint_schema_pairings:
|
||||
for c in constraints:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[c].update(schemas)
|
||||
|
||||
|
||||
def as_jsonable_value(v: Any) -> Any:
|
||||
@@ -123,7 +113,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
An iterable of expanded annotations.
|
||||
|
||||
Example:
|
||||
```py
|
||||
```python
|
||||
from annotated_types import Ge, Len
|
||||
|
||||
from pydantic._internal._known_annotated_metadata import expand_grouped_metadata
|
||||
@@ -134,7 +124,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
from pydantic.fields import FieldInfo # circular import
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, at.GroupedMetadata):
|
||||
@@ -153,6 +143,28 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
yield annotation
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_at_to_constraint_map() -> dict[type, str]:
|
||||
"""Return a mapping of annotated types to constraints.
|
||||
|
||||
Normally, we would define a mapping like this in the module scope, but we can't do that
|
||||
because we don't permit module level imports of `annotated_types`, in an attempt to speed up
|
||||
the import time of `pydantic`. We still only want to have this dictionary defined in one place,
|
||||
so we use this function to cache the result.
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
return {
|
||||
at.Gt: 'gt',
|
||||
at.Ge: 'ge',
|
||||
at.Lt: 'lt',
|
||||
at.Le: 'le',
|
||||
at.MultipleOf: 'multiple_of',
|
||||
at.MinLen: 'min_length',
|
||||
at.MaxLen: 'max_length',
|
||||
}
|
||||
|
||||
|
||||
def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901
|
||||
"""Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
|
||||
Otherwise return `None`.
|
||||
@@ -170,20 +182,40 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
|
||||
An updated schema with annotation if it is an annotation we know about, `None` otherwise.
|
||||
|
||||
Raises:
|
||||
PydanticCustomError: If `Predicate` fails.
|
||||
RuntimeError: If a constraint can't be applied to a specific schema type.
|
||||
ValueError: If an unknown constraint is encountered.
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
from . import _validators
|
||||
from ._validators import NUMERIC_VALIDATOR_LOOKUP, forbid_inf_nan_check
|
||||
|
||||
schema = schema.copy()
|
||||
schema_update, other_metadata = collect_known_metadata([annotation])
|
||||
schema_type = schema['type']
|
||||
|
||||
chain_schema_constraints: set[str] = {
|
||||
'pattern',
|
||||
'strip_whitespace',
|
||||
'to_lower',
|
||||
'to_upper',
|
||||
'coerce_numbers_to_str',
|
||||
}
|
||||
chain_schema_steps: list[CoreSchema] = []
|
||||
|
||||
    for constraint, value in schema_update.items():
        if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
            raise ValueError(f'Unknown constraint {constraint}')
        allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]

        # if it becomes necessary to handle more than one constraint
        # in this recursive case with function-after or function-wrap, we should refactor
        # this is a bit challenging because we sometimes want to apply constraints to the inner schema,
        # whereas other times we want to wrap the existing schema with a new one that enforces a new constraint.
        if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict':
            schema['schema'] = apply_known_metadata(annotation, schema['schema'])  # type: ignore  # schema is function schema
            return schema

        # if we're allowed to apply constraint directly to the schema, like le to int, do that
        if schema_type in allowed_schemas:
            if constraint == 'union_mode' and schema_type == 'union':
                schema['mode'] = value  # type: ignore  # schema is UnionSchema
@@ -191,145 +223,116 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
                schema[constraint] = value
            continue

        if constraint == 'allow_inf_nan' and value is False:
            return cs.no_info_after_validator_function(
                _validators.forbid_inf_nan_check,
                schema,
        # else, apply a function after validator to the schema to enforce the corresponding constraint
        if constraint in chain_schema_constraints:

            def _apply_constraint_with_incompatibility_info(
                value: Any, handler: cs.ValidatorFunctionWrapHandler
            ) -> Any:
                try:
                    x = handler(value)
                except ValidationError as ve:
                    # if the error is about the type, it's likely that the constraint is incompatible the type of the field
                    # for example, the following invalid schema wouldn't be caught during schema build, but rather at this point
                    # with a cryptic 'string_type' error coming from the string validator,
                    # that we'd rather express as a constraint incompatibility error (TypeError)
                    # Annotated[list[int], Field(pattern='abc')]
                    if 'type' in ve.errors()[0]['type']:
                        raise TypeError(
                            f"Unable to apply constraint '{constraint}' to supplied value {value} for schema of type '{schema_type}'"  # noqa: B023
                        )
                    raise ve
                return x

            chain_schema_steps.append(
                cs.no_info_wrap_validator_function(
                    _apply_constraint_with_incompatibility_info, cs.str_schema(**{constraint: value})
                )
            )
        elif constraint == 'pattern':
            # insert a str schema to make sure the regex engine matches
            return cs.chain_schema(
                [
                    schema,
                    cs.str_schema(pattern=value),
                ]
        elif constraint in NUMERIC_VALIDATOR_LOOKUP:
            if constraint in LENGTH_CONSTRAINTS:
                inner_schema = schema
                while inner_schema['type'] in {'function-before', 'function-wrap', 'function-after'}:
                    inner_schema = inner_schema['schema']  # type: ignore
                inner_schema_type = inner_schema['type']
                if inner_schema_type == 'list' or (
                    inner_schema_type == 'json-or-python' and inner_schema['json_schema']['type'] == 'list'  # type: ignore
                ):
                    js_constraint_key = 'minItems' if constraint == 'min_length' else 'maxItems'
                else:
                    js_constraint_key = 'minLength' if constraint == 'min_length' else 'maxLength'
            else:
                js_constraint_key = constraint

            schema = cs.no_info_after_validator_function(
                partial(NUMERIC_VALIDATOR_LOOKUP[constraint], **{constraint: value}), schema
            )
        elif constraint == 'gt':
            s = cs.no_info_after_validator_function(
                partial(_validators.greater_than_validator, gt=value),
                schema,
            )
            add_js_update_schema(s, lambda: {'gt': as_jsonable_value(value)})
            return s
        elif constraint == 'ge':
            return cs.no_info_after_validator_function(
                partial(_validators.greater_than_or_equal_validator, ge=value),
                schema,
            )
        elif constraint == 'lt':
            return cs.no_info_after_validator_function(
                partial(_validators.less_than_validator, lt=value),
                schema,
            )
        elif constraint == 'le':
            return cs.no_info_after_validator_function(
                partial(_validators.less_than_or_equal_validator, le=value),
                schema,
            )
        elif constraint == 'multiple_of':
            return cs.no_info_after_validator_function(
                partial(_validators.multiple_of_validator, multiple_of=value),
                schema,
            )
        elif constraint == 'min_length':
            s = cs.no_info_after_validator_function(
                partial(_validators.min_length_validator, min_length=value),
                schema,
            )
            add_js_update_schema(s, lambda: {'minLength': (as_jsonable_value(value))})
            return s
        elif constraint == 'max_length':
            s = cs.no_info_after_validator_function(
                partial(_validators.max_length_validator, max_length=value),
                schema,
            )
            add_js_update_schema(s, lambda: {'maxLength': (as_jsonable_value(value))})
            return s
        elif constraint == 'strip_whitespace':
            return cs.chain_schema(
                [
                    schema,
                    cs.str_schema(strip_whitespace=True),
                ]
            )
        elif constraint == 'to_lower':
            return cs.chain_schema(
                [
                    schema,
                    cs.str_schema(to_lower=True),
                ]
            )
        elif constraint == 'to_upper':
            return cs.chain_schema(
                [
                    schema,
                    cs.str_schema(to_upper=True),
                ]
            )
        elif constraint == 'min_length':
            return cs.no_info_after_validator_function(
                partial(_validators.min_length_validator, min_length=annotation.min_length),
                schema,
            )
        elif constraint == 'max_length':
            return cs.no_info_after_validator_function(
                partial(_validators.max_length_validator, max_length=annotation.max_length),
            metadata = schema.get('metadata', {})
            if (existing_json_schema_updates := metadata.get('pydantic_js_updates')) is not None:
                metadata['pydantic_js_updates'] = {
                    **existing_json_schema_updates,
                    **{js_constraint_key: as_jsonable_value(value)},
                }
            else:
                metadata['pydantic_js_updates'] = {js_constraint_key: as_jsonable_value(value)}
            schema['metadata'] = metadata
        elif constraint == 'allow_inf_nan' and value is False:
            schema = cs.no_info_after_validator_function(
                forbid_inf_nan_check,
                schema,
            )
        else:
            raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}')
            # It's rare that we'd get here, but it's possible if we add a new constraint and forget to handle it
            # Most constraint errors are caught at runtime during attempted application
            raise RuntimeError(f"Unable to apply constraint '{constraint}' to schema of type '{schema_type}'")

    for annotation in other_metadata:
        if isinstance(annotation, at.Gt):
            return cs.no_info_after_validator_function(
                partial(_validators.greater_than_validator, gt=annotation.gt),
                schema,
        if (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
            constraint = at_to_constraint_map[annotation_type]
            validator = NUMERIC_VALIDATOR_LOOKUP.get(constraint)
            if validator is None:
                raise ValueError(f'Unknown constraint {constraint}')
            schema = cs.no_info_after_validator_function(
                partial(validator, {constraint: getattr(annotation, constraint)}), schema
            )
        elif isinstance(annotation, at.Ge):
            return cs.no_info_after_validator_function(
                partial(_validators.greater_than_or_equal_validator, ge=annotation.ge),
                schema,
            )
        elif isinstance(annotation, at.Lt):
            return cs.no_info_after_validator_function(
                partial(_validators.less_than_validator, lt=annotation.lt),
                schema,
            )
        elif isinstance(annotation, at.Le):
            return cs.no_info_after_validator_function(
                partial(_validators.less_than_or_equal_validator, le=annotation.le),
                schema,
            )
        elif isinstance(annotation, at.MultipleOf):
            return cs.no_info_after_validator_function(
                partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of),
                schema,
            )
        elif isinstance(annotation, at.MinLen):
            return cs.no_info_after_validator_function(
                partial(_validators.min_length_validator, min_length=annotation.min_length),
                schema,
            )
        elif isinstance(annotation, at.MaxLen):
            return cs.no_info_after_validator_function(
                partial(_validators.max_length_validator, max_length=annotation.max_length),
                schema,
            )
        elif isinstance(annotation, at.Predicate):
            predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''
            continue
        elif isinstance(annotation, (at.Predicate, at.Not)):
            predicate_name = f'{annotation.func.__qualname__!r} ' if hasattr(annotation.func, '__qualname__') else ''

            def val_func(v: Any) -> Any:
                # annotation.func may also raise an exception, let it pass through
                if not annotation.func(v):
                    raise PydanticCustomError(
                        'predicate_failed',
                        f'Predicate {predicate_name}failed',  # type: ignore
                    )
                return v
            # Note: B023 is ignored because even though we iterate over `other_metadata`, it is guaranteed
            # to be of length 1. `apply_known_metadata()` is called from `GenerateSchema`, where annotations
            # were already expanded via `expand_grouped_metadata()`. Confusing, but this falls into the annotations
            # refactor.
            if isinstance(annotation, at.Predicate):

            return cs.no_info_after_validator_function(val_func, schema)
        # ignore any other unknown metadata
        return None

                def val_func(v: Any) -> Any:
                    predicate_satisfied = annotation.func(v)  # noqa: B023
                    if not predicate_satisfied:
                        raise PydanticCustomError(
                            'predicate_failed',
                            f'Predicate {predicate_name}failed',  # pyright: ignore[reportArgumentType]  # noqa: B023
                        )
                    return v

            else:

                def val_func(v: Any) -> Any:
                    predicate_satisfied = annotation.func(v)  # noqa: B023
                    if predicate_satisfied:
                        raise PydanticCustomError(
                            'not_operation_failed',
                            f'Not of {predicate_name}failed',  # pyright: ignore[reportArgumentType]  # noqa: B023
                        )
                    return v

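            # `Predicate` rejects values for which `annotation.func(v)` is falsy, while `Not` rejects
            # values for which it is truthy; either way the chosen `val_func` is attached below as an
            # after-validator.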
            schema = cs.no_info_after_validator_function(val_func, schema)
        else:
            # ignore any other unknown metadata
            return None

    if chain_schema_steps:
        chain_schema_steps = [schema] + chain_schema_steps
        return cs.chain_schema(chain_schema_steps)

    return schema

@@ -344,7 +347,7 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
        A tuple contains a dict of known metadata and a list of unknown annotations.

    Example:
        ```py
        ```python
        from annotated_types import Gt, Len

        from pydantic._internal._known_annotated_metadata import collect_known_metadata
@@ -353,31 +356,19 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
        #> ({'gt': 1, 'min_length': 42}, [Ellipsis])
        ```
    """
    import annotated_types as at

    annotations = expand_grouped_metadata(annotations)

    res: dict[str, Any] = {}
    remaining: list[Any] = []

    for annotation in annotations:
        # isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
        if isinstance(annotation, PydanticMetadata):
            res.update(annotation.__dict__)
        # we don't use dataclasses.asdict because that recursively calls asdict on the field values
        elif isinstance(annotation, at.MinLen):
            res.update({'min_length': annotation.min_length})
        elif isinstance(annotation, at.MaxLen):
            res.update({'max_length': annotation.max_length})
        elif isinstance(annotation, at.Gt):
            res.update({'gt': annotation.gt})
        elif isinstance(annotation, at.Ge):
            res.update({'ge': annotation.ge})
        elif isinstance(annotation, at.Lt):
            res.update({'lt': annotation.lt})
        elif isinstance(annotation, at.Le):
            res.update({'le': annotation.le})
        elif isinstance(annotation, at.MultipleOf):
            res.update({'multiple_of': annotation.multiple_of})
        elif (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
            constraint = at_to_constraint_map[annotation_type]
            res[constraint] = getattr(annotation, constraint)
        elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
            # also support PydanticMetadata classes being used without initialisation,
            # e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`

@@ -1,18 +1,71 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Generic, TypeVar
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union

from pydantic_core import SchemaSerializer, SchemaValidator
from typing_extensions import Literal
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator

from ..errors import PydanticErrorCodes, PydanticUserError
from ..plugin._schema_validator import PluggableSchemaValidator

if TYPE_CHECKING:
    from ..dataclasses import PydanticDataclass
    from ..main import BaseModel
    from ..type_adapter import TypeAdapter


ValSer = TypeVar('ValSer', SchemaValidator, SchemaSerializer)
ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer])
T = TypeVar('T')


class MockCoreSchema(Mapping[str, Any]):
    """Mocker for `pydantic_core.CoreSchema` which optionally attempts to
    rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
    """

    __slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo'

    def __init__(
        self,
        error_message: str,
        *,
        code: PydanticErrorCodes,
        attempt_rebuild: Callable[[], CoreSchema | None] | None = None,
    ) -> None:
        self._error_message = error_message
        self._code: PydanticErrorCodes = code
        self._attempt_rebuild = attempt_rebuild
        self._built_memo: CoreSchema | None = None

    def __getitem__(self, key: str) -> Any:
        return self._get_built().__getitem__(key)

    def __len__(self) -> int:
        return self._get_built().__len__()

    def __iter__(self) -> Iterator[str]:
        return self._get_built().__iter__()

    def _get_built(self) -> CoreSchema:
        if self._built_memo is not None:
            return self._built_memo

        if self._attempt_rebuild:
            schema = self._attempt_rebuild()
            if schema is not None:
                self._built_memo = schema
                return schema
        raise PydanticUserError(self._error_message, code=self._code)

    def rebuild(self) -> CoreSchema | None:
        self._built_memo = None
        if self._attempt_rebuild:
            schema = self._attempt_rebuild()
            if schema is not None:
                return schema
            else:
                raise PydanticUserError(self._error_message, code=self._code)
        return None

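# A `MockCoreSchema` stands in for a not-yet-built core schema: any mapping access goes through
# `_get_built()`, which either rebuilds and memoizes the real schema via `attempt_rebuild` or raises
# a `PydanticUserError` carrying the stored error message and code.
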
class MockValSer(Generic[ValSer]):
@@ -56,85 +109,120 @@ class MockValSer(Generic[ValSer]):
        return None


def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
    """Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model.
def set_type_adapter_mocks(adapter: TypeAdapter) -> None:
    """Set `core_schema`, `validator` and `serializer` to mock core types on a type adapter instance.

    Args:
        cls: The model class to set the mocks on
        cls_name: Name of the model class, used in error messages
        undefined_name: Name of the undefined thing, used in error messages
        adapter: The type adapter instance to set the mocks on
    """
    type_repr = str(adapter._type)
    undefined_type_error_message = (
        f'`{cls_name}` is not fully defined; you should define {undefined_name},'
        f' then call `{cls_name}.model_rebuild()`.'
        f'`TypeAdapter[{type_repr}]` is not fully defined; you should define `{type_repr}` and all referenced types,'
        f' then call `.rebuild()` on the instance.'
    )

    def attempt_rebuild_validator() -> SchemaValidator | None:
        if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
            return cls.__pydantic_validator__
        else:
    def attempt_rebuild_fn(attr_fn: Callable[[TypeAdapter], T]) -> Callable[[], T | None]:
        def handler() -> T | None:
            if adapter.rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
                return attr_fn(adapter)
            return None

    cls.__pydantic_validator__ = MockValSer(  # type: ignore[assignment]
        return handler

    adapter.core_schema = MockCoreSchema(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.core_schema),
    )
    adapter.validator = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=attempt_rebuild_validator,
        attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.validator),
    )

    def attempt_rebuild_serializer() -> SchemaSerializer | None:
        if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
            return cls.__pydantic_serializer__
        else:
            return None

    cls.__pydantic_serializer__ = MockValSer(  # type: ignore[assignment]
    adapter.serializer = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='serializer',
        attempt_rebuild=attempt_rebuild_serializer,
        attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.serializer),
    )

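# The model and dataclass helpers below reuse the same lazy-rebuild pattern: each mock receives an
# `attempt_rebuild` callback that retries the rebuild with `raise_errors=False` and, on success,
# returns the real attribute from the rebuilt class so the mock can hand it back transparently.
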
def set_dataclass_mocks(
    cls: type[PydanticDataclass], cls_name: str, undefined_name: str = 'all referenced types'
) -> None:
def set_model_mocks(cls: type[BaseModel], undefined_name: str = 'all referenced types') -> None:
    """Set `__pydantic_core_schema__`, `__pydantic_validator__` and `__pydantic_serializer__` to mock core types on a model.

    Args:
        cls: The model class to set the mocks on
        undefined_name: Name of the undefined thing, used in error messages
    """
    undefined_type_error_message = (
        f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
        f' then call `{cls.__name__}.model_rebuild()`.'
    )

    def attempt_rebuild_fn(attr_fn: Callable[[type[BaseModel]], T]) -> Callable[[], T | None]:
        def handler() -> T | None:
            if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
                return attr_fn(cls)
            return None

        return handler

    cls.__pydantic_core_schema__ = MockCoreSchema(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
    )
    cls.__pydantic_validator__ = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
    )
    cls.__pydantic_serializer__ = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='serializer',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
    )


def set_dataclass_mocks(cls: type[PydanticDataclass], undefined_name: str = 'all referenced types') -> None:
    """Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.

    Args:
        cls: The model class to set the mocks on
        cls_name: Name of the model class, used in error messages
        undefined_name: Name of the undefined thing, used in error messages
    """
    from ..dataclasses import rebuild_dataclass

    undefined_type_error_message = (
        f'`{cls_name}` is not fully defined; you should define {undefined_name},'
        f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
        f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
        f' then call `pydantic.dataclasses.rebuild_dataclass({cls.__name__})`.'
    )

    def attempt_rebuild_validator() -> SchemaValidator | None:
        if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
            return cls.__pydantic_validator__
        else:
    def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]:
        def handler() -> T | None:
            if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
                return attr_fn(cls)
            return None

    cls.__pydantic_validator__ = MockValSer(  # type: ignore[assignment]
        return handler

    cls.__pydantic_core_schema__ = MockCoreSchema(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
    )
    cls.__pydantic_validator__ = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=attempt_rebuild_validator,
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
    )

    def attempt_rebuild_serializer() -> SchemaSerializer | None:
        if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
            return cls.__pydantic_serializer__
        else:
            return None

    cls.__pydantic_serializer__ = MockValSer(  # type: ignore[assignment]
    cls.__pydantic_serializer__ = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=attempt_rebuild_serializer,
        val_or_ser='serializer',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
    )

@@ -1,43 +1,49 @@
"""Private logic for creating models."""

from __future__ import annotations as _annotations

import operator
import sys
import typing
import warnings
import weakref
from abc import ABCMeta
from functools import partial
from functools import cache, partial, wraps
from types import FunctionType
from typing import Any, Callable, Generic, Mapping
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NoReturn, TypeVar, cast

import typing_extensions
from pydantic_core import PydanticUndefined, SchemaSerializer
from typing_extensions import dataclass_transform, deprecated
from typing_extensions import TypeAliasType, dataclass_transform, deprecated, get_args, get_origin
from typing_inspection import typing_objects

from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases, unwrap_wrapped_function
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name, rebuild_model_fields
from ._generate_schema import GenerateSchema, InvalidSchemaError
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute
from ._validate_call import ValidateCallWrapper

if typing.TYPE_CHECKING:
    from inspect import Signature
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._mock_val_ser import set_model_mocks
from ._namespace_utils import NsResolver
from ._signature import generate_pydantic_signature
from ._typing_extra import (
    _make_forward_ref,
    eval_type_backport,
    is_classvar_annotation,
    parent_frame_namespace,
)
from ._utils import LazyClassAttribute, SafeGetItemProxy

if TYPE_CHECKING:
    from ..fields import Field as PydanticModelField
    from ..fields import FieldInfo, ModelPrivateAttr
    from ..fields import PrivateAttr as PydanticModelPrivateAttr
    from ..main import BaseModel
else:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20
    PydanticModelField = object()
    PydanticModelPrivateAttr = object()

object_setattr = object.__setattr__

@@ -50,12 +56,29 @@ class _ModelNamespaceDict(dict):
    def __setitem__(self, k: str, v: object) -> None:
        existing: Any = self.get(k, None)
        if existing and v is not existing and isinstance(existing, PydanticDescriptorProxy):
            warnings.warn(f'`{k}` overrides an existing Pydantic `{existing.decorator_info.decorator_repr}` decorator')
            warnings.warn(
                f'`{k}` overrides an existing Pydantic `{existing.decorator_info.decorator_repr}` decorator',
                stacklevel=2,
            )

        return super().__setitem__(k, v)


@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField,))
def NoInitField(
    *,
    init: Literal[False] = False,
) -> Any:
    """Only for typing purposes. Used as default value of `__pydantic_fields_set__`,
    `__pydantic_extra__`, `__pydantic_private__`, so they could be ignored when
    synthesizing the `__init__` signature.
    """


# For ModelMetaclass.register():
_T = TypeVar('_T')


@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr, NoInitField))
class ModelMetaclass(ABCMeta):
    def __new__(
        mcs,
@@ -85,24 +108,42 @@ class ModelMetaclass(ABCMeta):
        # that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
        # call we're in the middle of is for the `BaseModel` class.
        if bases:
            raw_annotations: dict[str, Any]
            if sys.version_info >= (3, 14):
                if (
                    '__annotations__' in namespace
                ):  # `from __future__ import annotations` was used in the model's module
                    raw_annotations = namespace['__annotations__']
                else:
                    # See https://docs.python.org/3.14/library/annotationlib.html#using-annotations-in-a-metaclass:
                    from annotationlib import Format, call_annotate_function, get_annotate_from_class_namespace

                    if annotate := get_annotate_from_class_namespace(namespace):
                        raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF)
                    else:
                        raw_annotations = {}
            else:
                raw_annotations = namespace.get('__annotations__', {})

            base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)

            config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs)
            config_wrapper = ConfigWrapper.for_model(bases, namespace, raw_annotations, kwargs)
            namespace['model_config'] = config_wrapper.config_dict
            private_attributes = inspect_namespace(
                namespace, config_wrapper.ignored_types, class_vars, base_field_names
                namespace, raw_annotations, config_wrapper.ignored_types, class_vars, base_field_names
            )
            if private_attributes:
            if private_attributes or base_private_attributes:
                original_model_post_init = get_model_post_init(namespace, bases)
                if original_model_post_init is not None:
                    # if there are private_attributes and a model_post_init function, we handle both

                    def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
                    @wraps(original_model_post_init)
                    def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
                        """We need to both initialize private attributes and call the user-defined model_post_init
                        method.
                        """
                        init_private_attributes(self, __context)
                        original_model_post_init(self, __context)
                        init_private_attributes(self, context)
                        original_model_post_init(self, context)

                    namespace['model_post_init'] = wrapped_model_post_init
                else:
@@ -111,15 +152,11 @@ class ModelMetaclass(ABCMeta):
            namespace['__class_vars__'] = class_vars
            namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}

            if config_wrapper.frozen:
                set_default_hash_func(namespace, bases)

            cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs)  # type: ignore

            from ..main import BaseModel
            cls = cast('type[BaseModel]', super().__new__(mcs, cls_name, bases, namespace, **kwargs))
            BaseModel_ = import_cached_base_model()

            mro = cls.__mro__
            if Generic in mro and mro.index(Generic) < mro.index(BaseModel):
            if Generic in mro and mro.index(Generic) < mro.index(BaseModel_):
                warnings.warn(
                    GenericBeforeBaseModelWarning(
                        'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
@@ -129,9 +166,14 @@ class ModelMetaclass(ABCMeta):
                )

            cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
            cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'
            cls.__pydantic_post_init__ = (
                None if cls.model_post_init is BaseModel_.model_post_init else 'model_post_init'
            )

            cls.__pydantic_setattr_handlers__ = {}

            cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
            cls.__pydantic_decorators__.update_from_config(config_wrapper)

            # Use the getattr below to grab the __parameters__ from the `typing.Generic` parent class
            if __pydantic_generic_metadata__:
@@ -140,22 +182,40 @@ class ModelMetaclass(ABCMeta):
                parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
                parameters = getattr(cls, '__parameters__', None) or parent_parameters
                if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
                    combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters)
                    parameters_str = ', '.join([str(x) for x in combined_parameters])
                    generic_type_label = f'typing.Generic[{parameters_str}]'
                    error_message = (
                        f'All parameters must be present on typing.Generic;'
                        f' you should inherit from {generic_type_label}.'
                    )
                    if Generic not in bases:  # pragma: no cover
                        # We raise an error here not because it is desirable, but because some cases are mishandled.
                        # It would be nice to remove this error and still have things behave as expected, it's just
                        # challenging because we are using a custom `__class_getitem__` to parametrize generic models,
                        # and not returning a typing._GenericAlias from it.
                        bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
                        error_message += (
                            f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
                    from ..root_model import RootModelRootType

                    missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
                    if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
                        # This is a special case where the user has subclassed `RootModel`, but has not parametrized
                        # RootModel with the generic type identifiers being used. Ex:
                        # class MyModel(RootModel, Generic[T]):
                        #    root: T
                        # Should instead just be:
                        # class MyModel(RootModel[T]):
                        #    root: T
                        parameters_str = ', '.join([x.__name__ for x in missing_parameters])
                        error_message = (
                            f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
                            f'{parameters_str} in its parameters. '
                            f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
                        )
                    else:
                        combined_parameters = parent_parameters + missing_parameters
                        parameters_str = ', '.join([str(x) for x in combined_parameters])
                        generic_type_label = f'typing.Generic[{parameters_str}]'
                        error_message = (
                            f'All parameters must be present on typing.Generic;'
                            f' you should inherit from {generic_type_label}.'
                        )
                        if Generic not in bases:  # pragma: no cover
                            # We raise an error here not because it is desirable, but because some cases are mishandled.
                            # It would be nice to remove this error and still have things behave as expected, it's just
                            # challenging because we are using a custom `__class_getitem__` to parametrize generic models,
                            # and not returning a typing._GenericAlias from it.
                            bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
                            error_message += (
                                f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
                            )
                    raise TypeError(error_message)

            cls.__pydantic_generic_metadata__ = {
@@ -173,30 +233,52 @@ class ModelMetaclass(ABCMeta):

            if __pydantic_reset_parent_namespace__:
                cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
            parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None)
            parent_namespace: dict[str, Any] | None = getattr(cls, '__pydantic_parent_namespace__', None)
            if isinstance(parent_namespace, dict):
                parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)

            types_namespace = get_cls_types_namespace(cls, parent_namespace)
            set_model_fields(cls, bases, config_wrapper, types_namespace)
            complete_model_class(
                cls,
                cls_name,
                config_wrapper,
                raise_errors=False,
                types_namespace=types_namespace,
                create_model_module=_create_model_module,
            )
            ns_resolver = NsResolver(parent_namespace=parent_namespace)

            set_model_fields(cls, config_wrapper=config_wrapper, ns_resolver=ns_resolver)

            # This is also set in `complete_model_class()`, after schema gen because they are recreated.
            # We set them here as well for backwards compatibility:
            cls.__pydantic_computed_fields__ = {
                k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()
            }

            if config_wrapper.defer_build:
                set_model_mocks(cls)
            else:
                # Any operation that requires accessing the field infos instances should be put inside
                # `complete_model_class()`:
                complete_model_class(
                    cls,
                    config_wrapper,
                    ns_resolver,
                    raise_errors=False,
                    create_model_module=_create_model_module,
                )

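            # When `defer_build` is enabled, only the mocks above are installed; the real core schema,
            # validator and serializer are produced later, on first use or via an explicit
            # `model_rebuild()` call.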
            if config_wrapper.frozen and '__hash__' not in namespace:
                set_default_hash_func(cls, bases)

            # using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
            # I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
            # only hit for _proper_ subclasses of BaseModel
            super(cls, cls).__pydantic_init_subclass__(**kwargs)  # type: ignore[misc]
            return cls
        else:
            # this is the BaseModel class itself being created, no logic required
            # These are instance variables, but have been assigned to `NoInitField` to trick the type checker.
            for instance_slot in '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__':
                namespace.pop(
                    instance_slot,
                    None,  # In case the metaclass is used with a class other than `BaseModel`.
                )
            namespace.get('__annotations__', {}).clear()
            return super().__new__(mcs, cls_name, bases, namespace, **kwargs)

    if not typing.TYPE_CHECKING:  # pragma: no branch
    if not TYPE_CHECKING:  # pragma: no branch
        # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access

        def __getattr__(self, item: str) -> Any:
@@ -204,30 +286,30 @@ class ModelMetaclass(ABCMeta):
            private_attributes = self.__dict__.get('__private_attributes__')
            if private_attributes and item in private_attributes:
                return private_attributes[item]
            if item == '__pydantic_core_schema__':
                # This means the class didn't get a schema generated for it, likely because there was an undefined reference
                maybe_mock_validator = getattr(self, '__pydantic_validator__', None)
                if isinstance(maybe_mock_validator, MockValSer):
                    rebuilt_validator = maybe_mock_validator.rebuild()
                    if rebuilt_validator is not None:
                        # In this case, a validator was built, and so `__pydantic_core_schema__` should now be set
                        return getattr(self, '__pydantic_core_schema__')
            raise AttributeError(item)

    @classmethod
    def __prepare__(cls, *args: Any, **kwargs: Any) -> Mapping[str, object]:
    def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
        return _ModelNamespaceDict()

    def __instancecheck__(self, instance: Any) -> bool:
        """Avoid calling ABC _abc_subclasscheck unless we're pretty sure.
    # Due to performance and memory issues, in the ABCMeta.__subclasscheck__ implementation, we don't support
    # registered virtual subclasses. See https://github.com/python/cpython/issues/92810#issuecomment-2762454345.
    # This may change once the CPython gets fixed (possibly in 3.15), in which case we should conditionally
    # define `register()`.
    def register(self, subclass: type[_T]) -> type[_T]:
        warnings.warn(
            f"For performance reasons, virtual subclasses registered using '{self.__qualname__}.register()' "
            "are not supported in 'isinstance()' and 'issubclass()' checks.",
            stacklevel=2,
        )
        return super().register(subclass)

        See #3829 and python/cpython#92810
        """
        return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance)
    __instancecheck__ = type.__instancecheck__  # pyright: ignore[reportAssignmentType]
    __subclasscheck__ = type.__subclasscheck__  # pyright: ignore[reportAssignmentType]

    @staticmethod
    def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
        from ..main import BaseModel
        BaseModel = import_cached_base_model()

        field_names: set[str] = set()
        class_vars: set[str] = set()
@@ -235,28 +317,51 @@ class ModelMetaclass(ABCMeta):
        for base in bases:
            if issubclass(base, BaseModel) and base is not BaseModel:
                # model_fields might not be defined yet in the case of generics, so we use getattr here:
                field_names.update(getattr(base, 'model_fields', {}).keys())
                field_names.update(getattr(base, '__pydantic_fields__', {}).keys())
                class_vars.update(base.__class_vars__)
                private_attributes.update(base.__private_attributes__)
        return field_names, class_vars, private_attributes

    @property
    @deprecated(
        'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20
        'The `__fields__` attribute is deprecated, use the `model_fields` class property instead.', category=None
    )
    def __fields__(self) -> dict[str, FieldInfo]:
        warnings.warn('The `__fields__` attribute is deprecated, use `model_fields` instead.', DeprecationWarning)
        return self.model_fields  # type: ignore
        warnings.warn(
            'The `__fields__` attribute is deprecated, use the `model_fields` class property instead.',
            PydanticDeprecatedSince20,
            stacklevel=2,
        )
        return getattr(self, '__pydantic_fields__', {})

    @property
    def __pydantic_fields_complete__(self) -> bool:
        """Whether the fields where successfully collected (i.e. type hints were successfully resolves).

        This is a private attribute, not meant to be used outside Pydantic.
        """
        if '__pydantic_fields__' not in self.__dict__:
            return False

        field_infos = cast('dict[str, FieldInfo]', self.__pydantic_fields__)  # pyright: ignore[reportAttributeAccessIssue]

        return all(field_info._complete for field_info in field_infos.values())

    def __dir__(self) -> list[str]:
        attributes = list(super().__dir__())
        if '__fields__' in attributes:
            attributes.remove('__fields__')
        return attributes


def init_private_attributes(self: BaseModel, __context: Any) -> None:
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        __context: The context.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
@@ -272,7 +377,7 @@ def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...])
    if 'model_post_init' in namespace:
        return namespace['model_post_init']

    from ..main import BaseModel
    BaseModel = import_cached_base_model()

    model_post_init = get_attribute_from_bases(bases, 'model_post_init')
    if model_post_init is not BaseModel.model_post_init:
@@ -281,6 +386,7 @@ def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...])

def inspect_namespace(  # noqa C901
    namespace: dict[str, Any],
    raw_annotations: dict[str, Any],
    ignored_types: tuple[type[Any], ...],
    base_class_vars: set[str],
    base_class_fields: set[str],
@@ -291,6 +397,7 @@ def inspect_namespace(  # noqa C901

    Args:
        namespace: The attribute dictionary of the class to be created.
        raw_annotations: The (non-evaluated) annotations of the model.
        ignored_types: A tuple of ignore types.
        base_class_vars: A set of base class class variables.
        base_class_fields: A set of base class fields.
@@ -305,24 +412,26 @@ def inspect_namespace(  # noqa C901
        - If a field does not have a type annotation.
        - If a field on base class was overridden by a non-annotated attribute.
    """
    from ..fields import FieldInfo, ModelPrivateAttr, PrivateAttr
    from ..fields import ModelPrivateAttr, PrivateAttr

    FieldInfo = import_cached_field_info()

    all_ignored_types = ignored_types + default_ignored_types()

    private_attributes: dict[str, ModelPrivateAttr] = {}
    raw_annotations = namespace.get('__annotations__', {})

    if '__root__' in raw_annotations or '__root__' in namespace:
        raise TypeError("To define root models, use `pydantic.RootModel` rather than a field called '__root__'")

    ignored_names: set[str] = set()
    for var_name, value in list(namespace.items()):
        if var_name == 'model_config':
        if var_name == 'model_config' or var_name == '__pydantic_extra__':
            continue
        elif (
            isinstance(value, type)
            and value.__module__ == namespace['__module__']
            and value.__qualname__.startswith(namespace['__qualname__'])
            and '__qualname__' in namespace
            and value.__qualname__.startswith(f'{namespace["__qualname__"]}.')
        ):
            # `value` is a nested type defined in this namespace; don't error
            continue
@@ -352,8 +461,8 @@ def inspect_namespace(  # noqa C901
        elif var_name.startswith('__'):
            continue
        elif is_valid_privateattr_name(var_name):
            if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]):
                private_attributes[var_name] = PrivateAttr(default=value)
            if var_name not in raw_annotations or not is_classvar_annotation(raw_annotations[var_name]):
                private_attributes[var_name] = cast(ModelPrivateAttr, PrivateAttr(default=value))
            del namespace[var_name]
        elif var_name in base_class_vars:
            continue
@@ -381,12 +490,28 @@ def inspect_namespace(  # noqa C901
            is_valid_privateattr_name(ann_name)
            and ann_name not in private_attributes
            and ann_name not in ignored_names
            and not is_classvar(ann_type)
            # This condition can be a false negative when `ann_type` is stringified,
            # but it is handled in most cases in `set_model_fields`:
            and not is_classvar_annotation(ann_type)
            and ann_type not in all_ignored_types
            and getattr(ann_type, '__module__', None) != 'functools'
        ):
            if is_annotated(ann_type):
                _, *metadata = typing_extensions.get_args(ann_type)
            if isinstance(ann_type, str):
                # Walking up the frames to get the module namespace where the model is defined
                # (as the model class wasn't created yet, we unfortunately can't use `cls.__module__`):
                frame = sys._getframe(2)
                if frame is not None:
                    try:
                        ann_type = eval_type_backport(
                            _make_forward_ref(ann_type, is_argument=False, is_class=True),
                            globalns=frame.f_globals,
                            localns=frame.f_locals,
                        )
                    except (NameError, TypeError):
                        pass

            if typing_objects.is_annotated(get_origin(ann_type)):
                _, *metadata = get_args(ann_type)
                private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
                if private_attr is not None:
                    private_attributes[ann_name] = private_attr
@@ -396,36 +521,51 @@ def inspect_namespace(  # noqa C901
    return private_attributes


def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> None:
    if '__hash__' in namespace:
        return

def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
    base_hash_func = get_attribute_from_bases(bases, '__hash__')
    if base_hash_func in {None, object.__hash__}:
        # If `__hash__` is None _or_ `object.__hash__`, we generate a hash function.
        # It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another
    new_hash_func = make_hash_func(cls)
    if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
        # If `__hash__` is some default, we generate a hash function.
        # It will be `None` if not overridden from BaseModel.
        # It may be `object.__hash__` if there is another
        # parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
        def hash_func(self: Any) -> int:
            return hash(self.__class__) + hash(tuple(self.__dict__.values()))
        # It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
        # In the last case we still need a new hash function to account for new `model_fields`.
        cls.__hash__ = new_hash_func

        namespace['__hash__'] = hash_func

def make_hash_func(cls: type[BaseModel]) -> Any:
    getter = operator.itemgetter(*cls.__pydantic_fields__.keys()) if cls.__pydantic_fields__ else lambda _: 0

    def hash_func(self: Any) -> int:
        try:
            return hash(getter(self.__dict__))
        except KeyError:
            # In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
            # all model fields, which is how we can get here.
            # getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
            # and wrapping it in a `try` doesn't slow things down much in the common case.
            return hash(getter(SafeGetItemProxy(self.__dict__)))

    return hash_func

|
||||
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
|
||||
cls: type[BaseModel],
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver | None,
|
||||
) -> None:
|
||||
"""Collect and set `cls.model_fields` and `cls.__class_vars__`.
|
||||
"""Collect and set `cls.__pydantic_fields__` and `cls.__class_vars__`.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
bases: Parents of the class, generally `cls.__bases__`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
ns_resolver: Namespace resolver to use when getting model annotations.
|
||||
"""
|
||||
typevars_map = get_model_typevars_map(cls)
|
||||
fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map)
|
||||
fields, class_vars = collect_model_fields(cls, config_wrapper, ns_resolver, typevars_map=typevars_map)
|
||||
|
||||
cls.model_fields = fields
|
||||
cls.__pydantic_fields__ = fields
|
||||
cls.__class_vars__.update(class_vars)
|
||||
|
||||
for k in class_vars:
|
||||
@@ -443,11 +583,11 @@ def set_model_fields(
|
||||
|
||||
def complete_model_class(
|
||||
cls: type[BaseModel],
|
||||
cls_name: str,
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver,
|
||||
*,
|
||||
raise_errors: bool = True,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
call_on_complete_hook: bool = True,
|
||||
create_model_module: str | None = None,
|
||||
) -> bool:
|
||||
"""Finish building a model class.
|
||||
@@ -457,10 +597,10 @@ def complete_model_class(
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
cls_name: The model or dataclass name.
|
||||
config_wrapper: The config wrapper instance.
|
||||
ns_resolver: The namespace resolver instance to use during schema building.
|
||||
raise_errors: Whether to raise errors.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
call_on_complete_hook: Whether to call the `__pydantic_on_complete__` hook.
|
||||
create_model_module: The module of the class to be created, if created by `create_model`.
|
||||
|
||||
Returns:
|
||||
@@ -471,39 +611,61 @@ def complete_model_class(
|
||||
and `raise_errors=True`.
|
||||
"""
|
||||
typevars_map = get_model_typevars_map(cls)
|
||||
|
||||
if not cls.__pydantic_fields_complete__:
|
||||
# Note: when coming from `ModelMetaclass.__new__()`, this results in fields being built twice.
|
||||
# We do so a second time here so that we can get the `NameError` for the specific undefined annotation.
|
||||
# Alternatively, we could let `GenerateSchema()` raise the error, but there are cases where incomplete
|
||||
# fields are inherited in `collect_model_fields()` and can actually have their annotation resolved in the
|
||||
# generate schema process. As we want to avoid having `__pydantic_fields_complete__` set to `False`
|
||||
# when `__pydantic_complete__` is `True`, we rebuild here:
|
||||
try:
|
||||
cls.__pydantic_fields__ = rebuild_model_fields(
|
||||
cls,
|
||||
config_wrapper=config_wrapper,
|
||||
ns_resolver=ns_resolver,
|
||||
typevars_map=typevars_map,
|
||||
)
|
||||
except NameError as e:
|
||||
exc = PydanticUndefinedAnnotation.from_name_error(e)
|
||||
set_model_mocks(cls, f'`{exc.name}`')
|
||||
if raise_errors:
|
||||
raise exc from e
|
||||
|
||||
if not raise_errors and not cls.__pydantic_fields_complete__:
|
||||
# No need to continue with schema gen, it is guaranteed to fail
|
||||
return False
|
||||
|
||||
assert cls.__pydantic_fields_complete__
|
||||
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
types_namespace,
|
||||
ns_resolver,
|
||||
typevars_map,
|
||||
)
|
||||
|
||||
handler = CallbackGetCoreSchemaHandler(
|
||||
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
|
||||
gen_schema,
|
||||
ref_mode='unpack',
|
||||
)
|
||||
|
||||
if config_wrapper.defer_build:
|
||||
set_model_mocks(cls, cls_name)
|
||||
return False
|
||||
|
||||
try:
|
||||
schema = cls.__get_pydantic_core_schema__(cls, handler)
|
||||
schema = gen_schema.generate_schema(cls)
|
||||
except PydanticUndefinedAnnotation as e:
|
||||
if raise_errors:
|
||||
raise
|
||||
set_model_mocks(cls, cls_name, f'`{e.name}`')
|
||||
set_model_mocks(cls, f'`{e.name}`')
|
||||
return False
|
||||
|
||||
core_config = config_wrapper.core_config(cls)
|
||||
core_config = config_wrapper.core_config(title=cls.__name__)
|
||||
|
||||
try:
|
||||
schema = gen_schema.clean_schema(schema)
|
||||
except gen_schema.CollectedInvalid:
|
||||
set_model_mocks(cls, cls_name)
|
||||
except InvalidSchemaError:
|
||||
set_model_mocks(cls)
|
||||
return False
|
||||
|
||||
# debug(schema)
|
||||
# This needs to happen *after* model schema generation, as the return type
|
||||
# of the properties are evaluated and the `ComputedFieldInfo` are recreated:
|
||||
cls.__pydantic_computed_fields__ = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()}
|
||||
|
||||
set_deprecated_descriptors(cls)
|
||||
|
||||
cls.__pydantic_core_schema__ = schema
|
||||
|
||||
cls.__pydantic_validator__ = create_schema_validator(
|
||||
@@ -516,29 +678,83 @@ def complete_model_class(
|
||||
config_wrapper.plugin_settings,
|
||||
)
|
||||
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
|
||||
cls.__pydantic_complete__ = True
|
||||
|
||||
# set __signature__ attr only for model class, but not for its instances
|
||||
cls.__signature__ = ClassAttribute(
|
||||
'__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper)
|
||||
# (because instances can define `__call__`, and `inspect.signature` shouldn't
|
||||
# use the `__signature__` attribute and instead generate from `__call__`).
|
||||
cls.__signature__ = LazyClassAttribute(
|
||||
'__signature__',
|
||||
partial(
|
||||
generate_pydantic_signature,
|
||||
init=cls.__init__,
|
||||
fields=cls.__pydantic_fields__,
|
||||
validate_by_name=config_wrapper.validate_by_name,
|
||||
extra=config_wrapper.extra,
|
||||
),
|
||||
)
|
||||
|
||||
cls.__pydantic_complete__ = True
|
||||
|
||||
if call_on_complete_hook:
|
||||
cls.__pydantic_on_complete__()
|
||||
|
||||
return True


def generate_model_signature(
    init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
    """Generate signature for model based on its fields.
def set_deprecated_descriptors(cls: type[BaseModel]) -> None:
    """Set data descriptors on the class for deprecated fields."""
    for field, field_info in cls.__pydantic_fields__.items():
        if (msg := field_info.deprecation_message) is not None:
            desc = _DeprecatedFieldDescriptor(msg)
            desc.__set_name__(cls, field)
            setattr(cls, field, desc)

    Args:
        init: The class init.
        fields: The model fields.
        config_wrapper: The config wrapper instance.
    for field, computed_field_info in cls.__pydantic_computed_fields__.items():
        if (
            (msg := computed_field_info.deprecation_message) is not None
            # Avoid having two warnings emitted:
            and not hasattr(unwrap_wrapped_function(computed_field_info.wrapped_property), '__deprecated__')
        ):
            desc = _DeprecatedFieldDescriptor(msg, computed_field_info.wrapped_property)
            desc.__set_name__(cls, field)
            setattr(cls, field, desc)

    Returns:
        The model signature.

class _DeprecatedFieldDescriptor:
    """Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.

    Attributes:
        msg: The deprecation message to be emitted.
        wrapped_property: The property instance if the deprecated field is a computed field, or `None`.
        field_name: The name of the field being deprecated.
    """
    return generate_pydantic_signature(init, fields, config_wrapper)

    field_name: str

    def __init__(self, msg: str, wrapped_property: property | None = None) -> None:
        self.msg = msg
        self.wrapped_property = wrapped_property

    def __set_name__(self, cls: type[BaseModel], name: str) -> None:
        self.field_name = name

    def __get__(self, obj: BaseModel | None, obj_type: type[BaseModel] | None = None) -> Any:
        if obj is None:
            if self.wrapped_property is not None:
                return self.wrapped_property.__get__(None, obj_type)
            raise AttributeError(self.field_name)

        warnings.warn(self.msg, DeprecationWarning, stacklevel=2)

        if self.wrapped_property is not None:
            return self.wrapped_property.__get__(obj, obj_type)
        return obj.__dict__[self.field_name]

    # Defined to make it a data descriptor and take precedence over the instance's dictionary.
    # Note that it will not be called when setting a value on a model instance
    # as `BaseModel.__setattr__` is defined and takes priority.
    def __set__(self, obj: Any, value: Any) -> NoReturn:
        raise AttributeError(self.field_name)
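
For reference, a small sketch (not from the diff) of how this descriptor surfaces to users. It assumes the `deprecated` argument of `Field`, which populates `FieldInfo.deprecation_message`; the field names are made up.

```python
import warnings

from pydantic import BaseModel, Field


class Settings(BaseModel):
    timeout: int = Field(default=5, deprecated='`timeout` is deprecated, use `read_timeout` instead')


s = Settings()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    _ = s.timeout  # attribute access goes through _DeprecatedFieldDescriptor.__get__
assert 'deprecated' in str(caught[0].message)
```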


class _PydanticWeakRef:
@@ -612,15 +828,21 @@ def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | N
    return result


@cache
def default_ignored_types() -> tuple[type[Any], ...]:
    from ..fields import ComputedFieldInfo

    return (
    ignored_types = [
        FunctionType,
        property,
        classmethod,
        staticmethod,
        PydanticDescriptorProxy,
        ComputedFieldInfo,
        ValidateCallWrapper,
    )
        TypeAliasType,  # from `typing_extensions`
    ]

    if sys.version_info >= (3, 12):
        ignored_types.append(typing.TypeAliasType)

    return tuple(ignored_types)
@@ -0,0 +1,293 @@
from __future__ import annotations

import sys
from collections.abc import Generator, Iterator, Mapping
from contextlib import contextmanager
from functools import cached_property
from typing import Any, Callable, NamedTuple, TypeVar

from typing_extensions import ParamSpec, TypeAlias, TypeAliasType, TypeVarTuple

GlobalsNamespace: TypeAlias = 'dict[str, Any]'
"""A global namespace.

In most cases, this is a reference to the `__dict__` attribute of a module.
This namespace type is expected as the `globals` argument during annotations evaluation.
"""

MappingNamespace: TypeAlias = Mapping[str, Any]
"""Any kind of namespace.

In most cases, this is a local namespace (e.g. the `__dict__` attribute of a class,
the [`f_locals`][frame.f_locals] attribute of a frame object, when dealing with types
defined inside functions).
This namespace type is expected as the `locals` argument during annotations evaluation.
"""

_TypeVarLike: TypeAlias = 'TypeVar | ParamSpec | TypeVarTuple'


class NamespacesTuple(NamedTuple):
    """A tuple of globals and locals to be used during annotations evaluation.

    This datastructure is defined as a named tuple so that it can easily be unpacked:

    ```python {lint="skip" test="skip"}
    def eval_type(typ: type[Any], ns: NamespacesTuple) -> None:
        return eval(typ, *ns)
    ```
    """

    globals: GlobalsNamespace
    """The namespace to be used as the `globals` argument during annotations evaluation."""

    locals: MappingNamespace
    """The namespace to be used as the `locals` argument during annotations evaluation."""


def get_module_ns_of(obj: Any) -> dict[str, Any]:
    """Get the namespace of the module where the object is defined.

    Caution: this function does not return a copy of the module namespace, so the result
    should not be mutated. The burden of enforcing this is on the caller.
    """
    module_name = getattr(obj, '__module__', None)
    if module_name:
        try:
            return sys.modules[module_name].__dict__
        except KeyError:
            # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
            return {}
    return {}


# Note that this class is almost identical to `collections.ChainMap`, but need to enforce
# immutable mappings here:
class LazyLocalNamespace(Mapping[str, Any]):
    """A lazily evaluated mapping, to be used as the `locals` argument during annotations evaluation.

    While the [`eval`][eval] function expects a mapping as the `locals` argument, it only
    performs `__getitem__` calls. The [`Mapping`][collections.abc.Mapping] abstract base class
    is fully implemented only for type checking purposes.

    Args:
        *namespaces: The namespaces to consider, in ascending order of priority.

    Example:
        ```python {lint="skip" test="skip"}
        ns = LazyLocalNamespace({'a': 1, 'b': 2}, {'a': 3})
        ns['a']
        #> 3
        ns['b']
        #> 2
        ```
    """

    def __init__(self, *namespaces: MappingNamespace) -> None:
        self._namespaces = namespaces

    @cached_property
    def data(self) -> dict[str, Any]:
        return {k: v for ns in self._namespaces for k, v in ns.items()}

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, key: str) -> Any:
        return self.data[key]

    def __contains__(self, key: object) -> bool:
        return key in self.data

    def __iter__(self) -> Iterator[str]:
        return iter(self.data)


def ns_for_function(obj: Callable[..., Any], parent_namespace: MappingNamespace | None = None) -> NamespacesTuple:
    """Return the global and local namespaces to be used when evaluating annotations for the provided function.

    The global namespace will be the `__dict__` attribute of the module the function was defined in.
    The local namespace will contain the `__type_params__` introduced by PEP 695.

    Args:
        obj: The object to use when building namespaces.
        parent_namespace: Optional namespace to be added with the lowest priority in the local namespace.
            If the passed function is a method, the `parent_namespace` will be the namespace of the class
            the method is defined in. Thus, we also fetch type `__type_params__` from there (i.e. the
            class-scoped type variables).
    """
    locals_list: list[MappingNamespace] = []
    if parent_namespace is not None:
        locals_list.append(parent_namespace)

    # Get the `__type_params__` attribute introduced by PEP 695.
    # Note that the `typing._eval_type` function expects type params to be
    # passed as a separate argument. However, internally, `_eval_type` calls
    # `ForwardRef._evaluate` which will merge type params with the localns,
    # essentially mimicking what we do here.
    type_params: tuple[_TypeVarLike, ...] = getattr(obj, '__type_params__', ())
    if parent_namespace is not None:
        # We also fetch type params from the parent namespace. If present, it probably
        # means the function was defined in a class. This is to support the following:
        # https://github.com/python/cpython/issues/124089.
        type_params += parent_namespace.get('__type_params__', ())

    locals_list.append({t.__name__: t for t in type_params})

    # What about short-circuiting to `obj.__globals__`?
    globalns = get_module_ns_of(obj)

    return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))


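A tiny sketch of what `ns_for_function` returns, placed here for illustration only (`Point` and `translate` are made-up names): the defining module's `__dict__` is used as globals, so string annotations referring to module-level names can be evaluated.

```python
class Point:
    x: int
    y: int


def translate(p: 'Point', dx: int) -> 'Point':
    ...


globalns, localns = ns_for_function(translate)
# `eval` accepts any mapping for locals, so the lazily-built LazyLocalNamespace works as-is:
assert eval('Point', globalns, localns) is Point
```
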
class NsResolver:
|
||||
"""A class responsible for the namespaces resolving logic for annotations evaluation.
|
||||
|
||||
This class handles the namespace logic when evaluating annotations mainly for class objects.
|
||||
|
||||
It holds a stack of classes that are being inspected during the core schema building,
|
||||
and the `types_namespace` property exposes the globals and locals to be used for
|
||||
type annotation evaluation. Additionally -- if no class is present in the stack -- a
|
||||
fallback globals and locals can be provided using the `namespaces_tuple` argument
|
||||
(this is useful when generating a schema for a simple annotation, e.g. when using
|
||||
`TypeAdapter`).
|
||||
|
||||
The namespace creation logic is unfortunately flawed in some cases, for backwards
|
||||
compatibility reasons and to better support valid edge cases. See the description
|
||||
for the `parent_namespace` argument and the example for more details.
|
||||
|
||||
Args:
|
||||
namespaces_tuple: The default globals and locals to use if no class is present
|
||||
on the stack. This can be useful when using the `GenerateSchema` class
|
||||
with `TypeAdapter`, where the "type" being analyzed is a simple annotation.
|
||||
parent_namespace: An optional parent namespace that will be added to the locals
|
||||
with the lowest priority. For a given class defined in a function, the locals
|
||||
of this function are usually used as the parent namespace:
|
||||
|
||||
```python {lint="skip" test="skip"}
|
||||
from pydantic import BaseModel
|
||||
|
||||
def func() -> None:
|
||||
SomeType = int
|
||||
|
||||
class Model(BaseModel):
|
||||
f: 'SomeType'
|
||||
|
||||
# when collecting fields, a namespace resolver instance will be created
|
||||
# this way:
|
||||
# ns_resolver = NsResolver(parent_namespace={'SomeType': SomeType})
|
||||
```
|
||||
|
||||
For backwards compatibility reasons and to support valid edge cases, this parent
|
||||
namespace will be used for *every* type being pushed to the stack. In the future,
|
||||
we might want to be smarter by only doing so when the type being pushed is defined
|
||||
in the same module as the parent namespace.
|
||||
|
||||
Example:
|
||||
```python {lint="skip" test="skip"}
|
||||
ns_resolver = NsResolver(
|
||||
parent_namespace={'fallback': 1},
|
||||
)
|
||||
|
||||
class Sub:
|
||||
m: 'Model'
|
||||
|
||||
class Model:
|
||||
some_local = 1
|
||||
sub: Sub
|
||||
|
||||
ns_resolver = NsResolver()
|
||||
|
||||
# This is roughly what happens when we build a core schema for `Model`:
|
||||
with ns_resolver.push(Model):
|
||||
ns_resolver.types_namespace
|
||||
#> NamespacesTuple({'Sub': Sub}, {'Model': Model, 'some_local': 1})
|
||||
# First thing to notice here, the model being pushed is added to the locals.
|
||||
# Because `NsResolver` is being used during the model definition, it is not
|
||||
# yet added to the globals. This is useful when resolving self-referencing annotations.
|
||||
|
||||
with ns_resolver.push(Sub):
|
||||
ns_resolver.types_namespace
|
||||
#> NamespacesTuple({'Sub': Sub}, {'Sub': Sub, 'Model': Model})
|
||||
# Second thing to notice: `Sub` is present in both the globals and locals.
|
||||
# This is not an issue, just that as described above, the model being pushed
|
||||
# is added to the locals, but it happens to be present in the globals as well
|
||||
# because it is already defined.
|
||||
# Third thing to notice: `Model` is also added in locals. This is a backwards
|
||||
# compatibility workaround that allows for `Sub` to be able to resolve `'Model'`
|
||||
# correctly (as otherwise models would have to be rebuilt even though this
|
||||
# doesn't look necessary).
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
namespaces_tuple: NamespacesTuple | None = None,
|
||||
parent_namespace: MappingNamespace | None = None,
|
||||
) -> None:
|
||||
self._base_ns_tuple = namespaces_tuple or NamespacesTuple({}, {})
|
||||
self._parent_ns = parent_namespace
|
||||
self._types_stack: list[type[Any] | TypeAliasType] = []
|
||||
|
||||
@cached_property
|
||||
def types_namespace(self) -> NamespacesTuple:
|
||||
"""The current global and local namespaces to be used for annotations evaluation."""
|
||||
if not self._types_stack:
|
||||
# TODO: should we merge the parent namespace here?
|
||||
# This is relevant for TypeAdapter, where there are no types on the stack, and we might
|
||||
# need access to the parent_ns. Right now, we sidestep this in `type_adapter.py` by passing
|
||||
# locals to both parent_ns and the base_ns_tuple, but this is a bit hacky.
|
||||
# we might consider something like:
|
||||
# if self._parent_ns is not None:
|
||||
# # Hacky workarounds, see class docstring:
|
||||
# # An optional parent namespace that will be added to the locals with the lowest priority
|
||||
# locals_list: list[MappingNamespace] = [self._parent_ns, self._base_ns_tuple.locals]
|
||||
# return NamespacesTuple(self._base_ns_tuple.globals, LazyLocalNamespace(*locals_list))
|
||||
return self._base_ns_tuple
|
||||
|
||||
typ = self._types_stack[-1]
|
||||
|
||||
globalns = get_module_ns_of(typ)
|
||||
|
||||
locals_list: list[MappingNamespace] = []
|
||||
# Hacky workarounds, see class docstring:
|
||||
# An optional parent namespace that will be added to the locals with the lowest priority
|
||||
if self._parent_ns is not None:
|
||||
locals_list.append(self._parent_ns)
|
||||
if len(self._types_stack) > 1:
|
||||
first_type = self._types_stack[0]
|
||||
locals_list.append({first_type.__name__: first_type})
|
||||
|
||||
# Adding `__type_params__` *before* `vars(typ)`, as the latter takes priority
|
||||
# (see https://github.com/python/cpython/pull/120272).
|
||||
# TODO `typ.__type_params__` when we drop support for Python 3.11:
|
||||
type_params: tuple[_TypeVarLike, ...] = getattr(typ, '__type_params__', ())
|
||||
if type_params:
|
||||
# Adding `__type_params__` is mostly useful for generic classes defined using
|
||||
# PEP 695 syntax *and* using forward annotations (see the example in
|
||||
# https://github.com/python/cpython/issues/114053). For TypeAliasType instances,
|
||||
# it is way less common, but still required if using a string annotation in the alias
|
||||
# value, e.g. `type A[T] = 'T'` (which is not necessary in most cases).
|
||||
locals_list.append({t.__name__: t for t in type_params})
|
||||
|
||||
# TypeAliasType instances don't have a `__dict__` attribute, so the check
|
||||
# is necessary:
|
||||
if hasattr(typ, '__dict__'):
|
||||
locals_list.append(vars(typ))
|
||||
|
||||
# The `len(self._types_stack) > 1` check above prevents this from being added twice:
|
||||
locals_list.append({typ.__name__: typ})
|
||||
|
||||
return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))
|
||||
|
||||
@contextmanager
|
||||
def push(self, typ: type[Any] | TypeAliasType, /) -> Generator[None]:
|
||||
"""Push a type to the stack."""
|
||||
self._types_stack.append(typ)
|
||||
# Reset the cached property:
|
||||
self.__dict__.pop('types_namespace', None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._types_stack.pop()
|
||||
self.__dict__.pop('types_namespace', None)
|
||||
@@ -1,19 +1,22 @@
|
||||
"""Tools to provide pretty/human-readable display of objects."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import types
|
||||
import typing
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Collection, Generator, Iterable
|
||||
from typing import TYPE_CHECKING, Any, ForwardRef, cast
|
||||
|
||||
import typing_extensions
|
||||
from typing_extensions import TypeAlias
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from . import _typing_extra
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
ReprArgs: typing_extensions.TypeAlias = 'typing.Iterable[tuple[str | None, Any]]'
|
||||
RichReprResult: typing_extensions.TypeAlias = (
|
||||
'typing.Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]'
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
# TODO remove type error comments when we drop support for Python 3.9
|
||||
ReprArgs: TypeAlias = Iterable[tuple[str | None, Any]] # pyright: ignore[reportGeneralTypeIssues]
|
||||
RichReprResult: TypeAlias = Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]] # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
|
||||
class PlainRepr(str):
|
||||
@@ -31,8 +34,7 @@ class Representation:
|
||||
# `__rich_repr__` is used by [rich](https://rich.readthedocs.io/en/stable/pretty.html).
|
||||
# (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)
|
||||
|
||||
# we don't want to use a type annotation here as it can break get_type_hints
|
||||
__slots__ = tuple() # type: typing.Collection[str]
|
||||
__slots__ = ()
|
||||
|
||||
def __repr_args__(self) -> ReprArgs:
|
||||
"""Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
|
||||
@@ -41,20 +43,25 @@ class Representation:
|
||||
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
|
||||
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
|
||||
"""
|
||||
attrs_names = self.__slots__
|
||||
attrs_names = cast(Collection[str], self.__slots__)
|
||||
if not attrs_names and hasattr(self, '__dict__'):
|
||||
attrs_names = self.__dict__.keys()
|
||||
attrs = ((s, getattr(self, s)) for s in attrs_names)
|
||||
return [(a, v) for a, v in attrs if v is not None]
|
||||
return [(a, v if v is not self else self.__repr_recursion__(v)) for a, v in attrs if v is not None]
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
"""Name of the instance's class, used in __repr__."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr_recursion__(self, object: Any) -> str:
|
||||
"""Returns the string representation of a recursive object."""
|
||||
# This is copied over from the stdlib `pprint` module:
|
||||
return f'<Recursion on {type(object).__name__} with id={id(object)}>'
|
||||
|
||||
def __repr_str__(self, join_str: str) -> str:
|
||||
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
|
||||
|
||||
def __pretty__(self, fmt: typing.Callable[[Any], Any], **kwargs: Any) -> typing.Generator[Any, None, None]:
|
||||
def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any]:
|
||||
"""Used by devtools (https://python-devtools.helpmanual.io/) to pretty print objects."""
|
||||
yield self.__repr_name__() + '('
|
||||
yield 1
|
||||
@@ -87,28 +94,30 @@ def display_as_type(obj: Any) -> str:
|
||||
|
||||
Takes some logic from `typing._type_repr`.
|
||||
"""
|
||||
if isinstance(obj, types.FunctionType):
|
||||
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
|
||||
return obj.__name__
|
||||
elif obj is ...:
|
||||
return '...'
|
||||
elif isinstance(obj, Representation):
|
||||
return repr(obj)
|
||||
elif isinstance(obj, ForwardRef) or typing_objects.is_typealiastype(obj):
|
||||
return str(obj)
|
||||
|
||||
if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
|
||||
obj = obj.__class__
|
||||
|
||||
if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
|
||||
if is_union_origin(typing_extensions.get_origin(obj)):
|
||||
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
|
||||
return f'Union[{args}]'
|
||||
elif isinstance(obj, _typing_extra.WithArgsTypes):
|
||||
if typing_extensions.get_origin(obj) == typing_extensions.Literal:
|
||||
if typing_objects.is_literal(typing_extensions.get_origin(obj)):
|
||||
args = ', '.join(map(repr, typing_extensions.get_args(obj)))
|
||||
else:
|
||||
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
|
||||
try:
|
||||
return f'{obj.__qualname__}[{args}]'
|
||||
except AttributeError:
|
||||
return str(obj) # handles TypeAliasType in 3.12
|
||||
return str(obj).replace('typing.', '').replace('typing_extensions.', '') # handles TypeAliasType in 3.12
|
||||
elif isinstance(obj, type):
|
||||
return obj.__qualname__
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'
|
||||
|
||||
|
||||
class GatherResult(TypedDict):
|
||||
"""Schema traversing result."""
|
||||
|
||||
collected_references: dict[str, DefinitionReferenceSchema | None]
|
||||
"""The collected definition references.
|
||||
|
||||
If a definition reference schema can be inlined, it means that there is
|
||||
only one in the whole core schema. As such, it is stored as the value.
|
||||
Otherwise, the value is set to `None`.
|
||||
"""
|
||||
|
||||
deferred_discriminator_schemas: list[CoreSchema]
|
||||
"""The list of core schemas having the discriminator application deferred."""
|
||||
|
||||
|
||||
class MissingDefinitionError(LookupError):
|
||||
"""A reference was pointing to a non-existing core schema."""
|
||||
|
||||
def __init__(self, schema_reference: str, /) -> None:
|
||||
self.schema_reference = schema_reference
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatherContext:
|
||||
"""The current context used during core schema traversing.
|
||||
|
||||
Context instances should only be used during schema traversing.
|
||||
"""
|
||||
|
||||
definitions: dict[str, CoreSchema]
|
||||
"""The available definitions."""
|
||||
|
||||
deferred_discriminator_schemas: list[CoreSchema] = field(init=False, default_factory=list)
|
||||
"""The list of core schemas having the discriminator application deferred.
|
||||
|
||||
Internally, these core schemas have a specific key set in the core metadata dict.
|
||||
"""
|
||||
|
||||
collected_references: dict[str, DefinitionReferenceSchema | None] = field(init=False, default_factory=dict)
|
||||
"""The collected definition references.
|
||||
|
||||
If a definition reference schema can be inlined, it means that there is
|
||||
only one in the whole core schema. As such, it is stored as the value.
|
||||
Otherwise, the value is set to `None`.
|
||||
|
||||
During schema traversing, definition reference schemas can be added as candidates, or removed
|
||||
(by setting the value to `None`).
|
||||
"""
|
||||
|
||||
|
||||
def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None:
|
||||
meta = schema.get('metadata')
|
||||
if meta is not None and 'pydantic_internal_union_discriminator' in meta:
|
||||
ctx.deferred_discriminator_schemas.append(schema) # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None:
|
||||
schema_ref = def_ref_schema['schema_ref']
|
||||
|
||||
if schema_ref not in ctx.collected_references:
|
||||
definition = ctx.definitions.get(schema_ref)
|
||||
if definition is None:
|
||||
raise MissingDefinitionError(schema_ref)
|
||||
|
||||
# The `'definition-ref'` schema was only encountered once, make it
|
||||
# a candidate to be inlined:
|
||||
ctx.collected_references[schema_ref] = def_ref_schema
|
||||
traverse_schema(definition, ctx)
|
||||
if 'serialization' in def_ref_schema:
|
||||
traverse_schema(def_ref_schema['serialization'], ctx)
|
||||
traverse_metadata(def_ref_schema, ctx)
|
||||
else:
|
||||
# The `'definition-ref'` schema was already encountered, meaning
|
||||
# the previously encountered schema (and this one) can't be inlined:
|
||||
ctx.collected_references[schema_ref] = None
|
||||
|
||||
|
||||
def traverse_schema(schema: AllSchemas, context: GatherContext) -> None:
|
||||
# TODO When we drop 3.9, use a match statement to get better type checking and remove
|
||||
# file-level type ignore.
|
||||
# (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance).
|
||||
schema_type = schema['type']
|
||||
|
||||
if schema_type == 'definition-ref':
|
||||
traverse_definition_ref(schema, context)
|
||||
# `traverse_definition_ref` handles the possible serialization and metadata schemas:
|
||||
return
|
||||
elif schema_type == 'definitions':
|
||||
traverse_schema(schema['schema'], context)
|
||||
for definition in schema['definitions']:
|
||||
traverse_schema(definition, context)
|
||||
elif schema_type in {'list', 'set', 'frozenset', 'generator'}:
|
||||
if 'items_schema' in schema:
|
||||
traverse_schema(schema['items_schema'], context)
|
||||
elif schema_type == 'tuple':
|
||||
if 'items_schema' in schema:
|
||||
for s in schema['items_schema']:
|
||||
traverse_schema(s, context)
|
||||
elif schema_type == 'dict':
|
||||
if 'keys_schema' in schema:
|
||||
traverse_schema(schema['keys_schema'], context)
|
||||
if 'values_schema' in schema:
|
||||
traverse_schema(schema['values_schema'], context)
|
||||
elif schema_type == 'union':
|
||||
for choice in schema['choices']:
|
||||
if isinstance(choice, tuple):
|
||||
traverse_schema(choice[0], context)
|
||||
else:
|
||||
traverse_schema(choice, context)
|
||||
elif schema_type == 'tagged-union':
|
||||
for v in schema['choices'].values():
|
||||
traverse_schema(v, context)
|
||||
elif schema_type == 'chain':
|
||||
for step in schema['steps']:
|
||||
traverse_schema(step, context)
|
||||
elif schema_type == 'lax-or-strict':
|
||||
traverse_schema(schema['lax_schema'], context)
|
||||
traverse_schema(schema['strict_schema'], context)
|
||||
elif schema_type == 'json-or-python':
|
||||
traverse_schema(schema['json_schema'], context)
|
||||
traverse_schema(schema['python_schema'], context)
|
||||
elif schema_type in {'model-fields', 'typed-dict'}:
|
||||
if 'extras_schema' in schema:
|
||||
traverse_schema(schema['extras_schema'], context)
|
||||
if 'computed_fields' in schema:
|
||||
for s in schema['computed_fields']:
|
||||
traverse_schema(s, context)
|
||||
for s in schema['fields'].values():
|
||||
traverse_schema(s, context)
|
||||
elif schema_type == 'dataclass-args':
|
||||
if 'computed_fields' in schema:
|
||||
for s in schema['computed_fields']:
|
||||
traverse_schema(s, context)
|
||||
for s in schema['fields']:
|
||||
traverse_schema(s, context)
|
||||
elif schema_type == 'arguments':
|
||||
for s in schema['arguments_schema']:
|
||||
traverse_schema(s['schema'], context)
|
||||
if 'var_args_schema' in schema:
|
||||
traverse_schema(schema['var_args_schema'], context)
|
||||
if 'var_kwargs_schema' in schema:
|
||||
traverse_schema(schema['var_kwargs_schema'], context)
|
||||
elif schema_type == 'arguments-v3':
|
||||
for s in schema['arguments_schema']:
|
||||
traverse_schema(s['schema'], context)
|
||||
elif schema_type == 'call':
|
||||
traverse_schema(schema['arguments_schema'], context)
|
||||
if 'return_schema' in schema:
|
||||
traverse_schema(schema['return_schema'], context)
|
||||
elif schema_type == 'computed-field':
|
||||
traverse_schema(schema['return_schema'], context)
|
||||
elif schema_type == 'function-before':
|
||||
if 'schema' in schema:
|
||||
traverse_schema(schema['schema'], context)
|
||||
if 'json_schema_input_schema' in schema:
|
||||
traverse_schema(schema['json_schema_input_schema'], context)
|
||||
elif schema_type == 'function-plain':
|
||||
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
|
||||
if 'return_schema' in schema:
|
||||
traverse_schema(schema['return_schema'], context)
|
||||
if 'json_schema_input_schema' in schema:
|
||||
traverse_schema(schema['json_schema_input_schema'], context)
|
||||
elif schema_type == 'function-wrap':
|
||||
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
|
||||
if 'return_schema' in schema:
|
||||
traverse_schema(schema['return_schema'], context)
|
||||
if 'schema' in schema:
|
||||
traverse_schema(schema['schema'], context)
|
||||
if 'json_schema_input_schema' in schema:
|
||||
traverse_schema(schema['json_schema_input_schema'], context)
|
||||
else:
|
||||
if 'schema' in schema:
|
||||
traverse_schema(schema['schema'], context)
|
||||
|
||||
if 'serialization' in schema:
|
||||
traverse_schema(schema['serialization'], context)
|
||||
traverse_metadata(schema, context)
|
||||
|
||||
|
||||
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
|
||||
"""Traverse the core schema and definitions and return the necessary information for schema cleaning.
|
||||
|
||||
During the core schema traversing, any `'definition-ref'` schema is:
|
||||
|
||||
- Validated: the reference must point to an existing definition. If this is not the case, a
|
||||
`MissingDefinitionError` exception is raised.
|
||||
- Stored in the context: the actual reference is stored in the context. Depending on whether
|
||||
the `'definition-ref'` schema is encountered more than once, the schema itself is also
|
||||
saved in the context to be inlined (i.e. replaced by the definition it points to).
|
||||
"""
|
||||
context = GatherContext(definitions)
|
||||
traverse_schema(schema, context)
|
||||
|
||||
return {
|
||||
'collected_references': context.collected_references,
|
||||
'deferred_discriminator_schemas': context.deferred_discriminator_schemas,
|
||||
}
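
As a rough sketch of how this helper is meant to be called (the import path below is an assumption, since the diff does not show the module name), a schema containing a single `'definition-ref'` yields an inlining candidate:

```python
from pydantic_core import core_schema

# Assumed import path for the helper defined above; adjust to wherever this module actually lives.
from pydantic._internal._schema_gather import gather_schemas_for_cleaning

definitions = {'IntRef': core_schema.int_schema(ref='IntRef')}
schema = core_schema.list_schema(core_schema.definition_reference_schema('IntRef'))

result = gather_schemas_for_cleaning(schema, definitions)
# 'IntRef' was only referenced once, so the reference schema itself is stored as an inlining candidate:
assert result['collected_references']['IntRef'] is not None
assert result['deferred_discriminator_schemas'] == []
```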
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Types and utility functions used by various other internal tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
|
||||
from ..json_schema import GenerateJsonSchema, JsonSchemaValue
|
||||
from ._core_utils import CoreSchemaOrField
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._namespace_utils import NamespacesTuple
|
||||
|
||||
GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
|
||||
HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
|
||||
@@ -32,8 +33,8 @@ class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
|
||||
self.handler = handler_override or generate_json_schema.generate_inner
|
||||
self.mode = generate_json_schema.mode
|
||||
|
||||
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
|
||||
return self.handler(__core_schema)
|
||||
def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
|
||||
return self.handler(core_schema)
|
||||
|
||||
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
|
||||
"""Resolves `$ref` in the json schema.
|
||||
@@ -78,22 +79,21 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
|
||||
self._generate_schema = generate_schema
|
||||
self._ref_mode = ref_mode
|
||||
|
||||
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
|
||||
schema = self._handler(__source_type)
|
||||
ref = schema.get('ref')
|
||||
def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
schema = self._handler(source_type)
|
||||
if self._ref_mode == 'to-def':
|
||||
ref = schema.get('ref')
|
||||
if ref is not None:
|
||||
self._generate_schema.defs.definitions[ref] = schema
|
||||
return core_schema.definition_reference_schema(ref)
|
||||
return self._generate_schema.defs.create_definition_reference_schema(schema)
|
||||
return schema
|
||||
else: # ref_mode = 'unpack
|
||||
else: # ref_mode = 'unpack'
|
||||
return self.resolve_ref_schema(schema)
|
||||
|
||||
def _get_types_namespace(self) -> dict[str, Any] | None:
|
||||
def _get_types_namespace(self) -> NamespacesTuple:
|
||||
return self._generate_schema._types_namespace
|
||||
|
||||
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
|
||||
return self._generate_schema.generate_schema(__source_type)
|
||||
def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
return self._generate_schema.generate_schema(source_type)
|
||||
|
||||
@property
|
||||
def field_name(self) -> str | None:
|
||||
@@ -113,12 +113,13 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
|
||||
"""
|
||||
if maybe_ref_schema['type'] == 'definition-ref':
|
||||
ref = maybe_ref_schema['schema_ref']
|
||||
if ref not in self._generate_schema.defs.definitions:
|
||||
definition = self._generate_schema.defs.get_schema_from_ref(ref)
|
||||
if definition is None:
|
||||
raise LookupError(
|
||||
f'Could not find a ref for {ref}.'
|
||||
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
|
||||
)
|
||||
return self._generate_schema.defs.definitions[ref]
|
||||
return definition
|
||||
elif maybe_ref_schema['type'] == 'definitions':
|
||||
return self.resolve_ref_schema(maybe_ref_schema['schema'])
|
||||
return maybe_ref_schema
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticOmit, core_schema
|
||||
|
||||
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.Deque: collections.deque, # noqa: UP006
|
||||
collections.deque: collections.deque,
|
||||
list: list,
|
||||
typing.List: list, # noqa: UP006
|
||||
tuple: tuple,
|
||||
typing.Tuple: tuple, # noqa: UP006
|
||||
set: set,
|
||||
typing.AbstractSet: set,
|
||||
typing.Set: set, # noqa: UP006
|
||||
frozenset: frozenset,
|
||||
typing.FrozenSet: frozenset, # noqa: UP006
|
||||
typing.Sequence: list,
|
||||
typing.MutableSequence: list,
|
||||
typing.MutableSet: set,
|
||||
# this doesn't handle subclasses of these
|
||||
# parametrized typing.Set creates one of these
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
}
|
||||
|
||||
|
||||
def serialize_sequence_via_list(
|
||||
v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
|
||||
) -> Any:
|
||||
items: list[Any] = []
|
||||
|
||||
mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None)
|
||||
if mapped_origin is None:
|
||||
# we shouldn't hit this branch, should probably add a serialization error or something
|
||||
return v
|
||||
|
||||
for index, item in enumerate(v):
|
||||
try:
|
||||
v = handler(item, index)
|
||||
except PydanticOmit: # noqa: PERF203
|
||||
pass
|
||||
else:
|
||||
items.append(v)
|
||||
|
||||
if info.mode_is_json():
|
||||
return items
|
||||
else:
|
||||
return mapped_origin(items)
|
||||
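
The user-visible effect of `serialize_sequence_via_list` is that sequence-like fields keep their Python container type in python mode but become JSON arrays in JSON mode; a small sketch (field names are illustrative):

```python
from pydantic import BaseModel


class Tags(BaseModel):
    values: frozenset[str]


t = Tags(values=['a', 'b'])
print(type(t.model_dump()['values']))  # <class 'frozenset'> - python mode rebuilds the mapped origin
print(t.model_dump_json())             # e.g. '{"values":["a","b"]}' - JSON mode keeps the intermediate list
```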
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from inspect import Parameter, Signature, signature
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from ._utils import is_valid_identifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..config import ExtraValues
|
||||
from ..fields import FieldInfo
|
||||
|
||||
|
||||
# Copied over from stdlib dataclasses
|
||||
class _HAS_DEFAULT_FACTORY_CLASS:
|
||||
def __repr__(self):
|
||||
return '<factory>'
|
||||
|
||||
|
||||
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
|
||||
|
||||
|
||||
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
|
||||
"""Extract the correct name to use for the field when generating a signature.
|
||||
|
||||
Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
|
||||
First priority is given to the alias, then the validation_alias, then the field name.
|
||||
|
||||
Args:
|
||||
field_name: The name of the field
|
||||
field_info: The corresponding FieldInfo object.
|
||||
|
||||
Returns:
|
||||
The correct name to use when generating a signature.
|
||||
"""
|
||||
if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias):
|
||||
return field_info.alias
|
||||
if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias):
|
||||
return field_info.validation_alias
|
||||
|
||||
return field_name
|
||||
|
||||
|
||||
def _process_param_defaults(param: Parameter) -> Parameter:
|
||||
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
|
||||
|
||||
Args:
|
||||
param (Parameter): The parameter
|
||||
|
||||
Returns:
|
||||
Parameter: The custom processed parameter
|
||||
"""
|
||||
from ..fields import FieldInfo
|
||||
|
||||
param_default = param.default
|
||||
if isinstance(param_default, FieldInfo):
|
||||
annotation = param.annotation
|
||||
# Replace the annotation if appropriate
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if annotation == 'Any':
|
||||
annotation = Any
|
||||
|
||||
# Replace the field default
|
||||
default = param_default.default
|
||||
if default is PydanticUndefined:
|
||||
if param_default.default_factory is PydanticUndefined:
|
||||
default = Signature.empty
|
||||
else:
|
||||
# this is used by dataclasses to indicate a factory exists:
|
||||
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
|
||||
return param.replace(
|
||||
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
|
||||
)
|
||||
return param
|
||||
|
||||
|
||||
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
|
||||
init: Callable[..., None],
|
||||
fields: dict[str, FieldInfo],
|
||||
validate_by_name: bool,
|
||||
extra: ExtraValues | None,
|
||||
) -> dict[str, Parameter]:
|
||||
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
|
||||
from itertools import islice
|
||||
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: dict[str, Parameter] = {}
|
||||
var_kw = None
|
||||
use_var_kw = False
|
||||
|
||||
for param in islice(present_params, 1, None): # skip self arg
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if fields.get(param.name):
|
||||
# exclude params with init=False
|
||||
if getattr(fields[param.name], 'init', True) is False:
|
||||
continue
|
||||
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
|
||||
if param.annotation == 'Any':
|
||||
param = param.replace(annotation=Any)
|
||||
if param.kind is param.VAR_KEYWORD:
|
||||
var_kw = param
|
||||
continue
|
||||
merged_params[param.name] = param
|
||||
|
||||
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
|
||||
allow_names = validate_by_name
|
||||
for field_name, field in fields.items():
|
||||
# when alias is a str it should be used for signature generation
|
||||
param_name = _field_name_for_signature(field_name, field)
|
||||
|
||||
if field_name in merged_params or param_name in merged_params:
|
||||
continue
|
||||
|
||||
if not is_valid_identifier(param_name):
|
||||
if allow_names:
|
||||
param_name = field_name
|
||||
else:
|
||||
use_var_kw = True
|
||||
continue
|
||||
|
||||
if field.is_required():
|
||||
default = Parameter.empty
|
||||
elif field.default_factory is not None:
|
||||
# Mimics stdlib dataclasses:
|
||||
default = _HAS_DEFAULT_FACTORY
|
||||
else:
|
||||
default = field.default
|
||||
merged_params[param_name] = Parameter(
|
||||
param_name,
|
||||
Parameter.KEYWORD_ONLY,
|
||||
annotation=field.rebuild_annotation(),
|
||||
default=default,
|
||||
)
|
||||
|
||||
if extra == 'allow':
|
||||
use_var_kw = True
|
||||
|
||||
if var_kw and use_var_kw:
|
||||
# Make sure the parameter for extra kwargs
|
||||
# does not have the same name as a field
|
||||
default_model_signature = [
|
||||
('self', Parameter.POSITIONAL_ONLY),
|
||||
('data', Parameter.VAR_KEYWORD),
|
||||
]
|
||||
if [(p.name, p.kind) for p in present_params] == default_model_signature:
|
||||
# if this is the standard model signature, use extra_data as the extra args name
|
||||
var_kw_name = 'extra_data'
|
||||
else:
|
||||
# else start from var_kw
|
||||
var_kw_name = var_kw.name
|
||||
|
||||
# generate a name that's definitely unique
|
||||
while var_kw_name in fields:
|
||||
var_kw_name += '_'
|
||||
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
|
||||
|
||||
return merged_params
|
||||
|
||||
|
||||
def generate_pydantic_signature(
|
||||
init: Callable[..., None],
|
||||
fields: dict[str, FieldInfo],
|
||||
validate_by_name: bool,
|
||||
extra: ExtraValues | None,
|
||||
is_dataclass: bool = False,
|
||||
) -> Signature:
|
||||
"""Generate signature for a pydantic BaseModel or dataclass.
|
||||
|
||||
Args:
|
||||
init: The class init.
|
||||
fields: The model fields.
|
||||
validate_by_name: The `validate_by_name` value of the config.
|
||||
extra: The `extra` value of the config.
|
||||
is_dataclass: Whether the model is a dataclass.
|
||||
|
||||
Returns:
|
||||
The dataclass/BaseModel subclass signature.
|
||||
"""
|
||||
merged_params = _generate_signature_parameters(init, fields, validate_by_name, extra)
|
||||
|
||||
if is_dataclass:
|
||||
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
|
||||
|
||||
return Signature(parameters=list(merged_params.values()), return_annotation=None)
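
A short sketch of the resulting signature for a model (field names below are made up): when an alias is a valid identifier it is used as the parameter name, and the parameters are keyword-only.

```python
import inspect

from pydantic import BaseModel, ConfigDict, Field


class User(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    name: str = Field(alias='full_name')
    age: int = 0


print(inspect.signature(User))
#> (*, full_name: str, age: int = 0) -> None
```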
|
||||
@@ -1,714 +0,0 @@
|
||||
"""Logic for generating pydantic-core schemas for standard library types.
|
||||
|
||||
Import of this module is deferred since it contains imports of many standard library modules.
|
||||
"""
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import dataclasses
|
||||
import decimal
|
||||
import inspect
|
||||
import os
|
||||
import typing
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from typing import Any, Callable, Iterable, TypeVar
|
||||
|
||||
import typing_extensions
|
||||
from pydantic_core import (
|
||||
CoreSchema,
|
||||
MultiHostUrl,
|
||||
PydanticCustomError,
|
||||
PydanticOmit,
|
||||
Url,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
from pydantic.errors import PydanticSchemaGenerationError
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.types import Strict
|
||||
|
||||
from ..config import ConfigDict
|
||||
from ..json_schema import JsonSchemaValue, update_json_schema
|
||||
from . import _known_annotated_metadata, _typing_extra, _validators
|
||||
from ._core_utils import get_type_ref
|
||||
from ._internal_dataclass import slots_true
|
||||
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._generate_schema import GenerateSchema
|
||||
|
||||
StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema]
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class SchemaTransformer:
|
||||
get_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema]
|
||||
get_json_schema: Callable[[CoreSchema, GetJsonSchemaHandler], JsonSchemaValue]
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
return self.get_core_schema(source_type, handler)
|
||||
|
||||
def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
return self.get_json_schema(schema, handler)
|
||||
|
||||
|
||||
def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
|
||||
cases: list[Any] = list(enum_type.__members__.values())
|
||||
|
||||
enum_ref = get_type_ref(enum_type)
|
||||
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
|
||||
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
|
||||
description = None
|
||||
updates = {'title': enum_type.__name__, 'description': description}
|
||||
updates = {k: v for k, v in updates.items() if v is not None}
|
||||
|
||||
def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
|
||||
original_schema = handler.resolve_ref_schema(json_schema)
|
||||
update_json_schema(original_schema, updates)
|
||||
return json_schema
|
||||
|
||||
if not cases:
|
||||
# Use an isinstance check for enums with no cases.
|
||||
# The most important use case for this is creating TypeVar bounds for generics that should
|
||||
# be restricted to enums. This is more consistent than it might seem at first, since you can only
|
||||
# subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
|
||||
# We use the get_json_schema function when an Enum subclass has been declared with no cases
|
||||
# so that we can still generate a valid json schema.
|
||||
return core_schema.is_instance_schema(enum_type, metadata={'pydantic_js_functions': [get_json_schema]})
|
||||
|
||||
use_enum_values = config.get('use_enum_values', False)
|
||||
|
||||
if len(cases) == 1:
|
||||
expected = repr(cases[0].value)
|
||||
else:
|
||||
expected = ', '.join([repr(case.value) for case in cases[:-1]]) + f' or {cases[-1].value!r}'
|
||||
|
||||
def to_enum(__input_value: Any) -> Enum:
|
||||
try:
|
||||
enum_field = enum_type(__input_value)
|
||||
if use_enum_values:
|
||||
return enum_field.value
|
||||
return enum_field
|
||||
except ValueError:
|
||||
# The type: ignore on the next line is to ignore the requirement of LiteralString
|
||||
raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore
|
||||
|
||||
strict_python_schema = core_schema.is_instance_schema(enum_type)
|
||||
if use_enum_values:
|
||||
strict_python_schema = core_schema.chain_schema(
|
||||
[strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
|
||||
)
|
||||
|
||||
to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
|
||||
if issubclass(enum_type, int):
|
||||
# this handles `IntEnum`, and also `Foobar(int, Enum)`
|
||||
updates['type'] = 'integer'
|
||||
lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator])
|
||||
# Disallow float from JSON due to strict mode
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
elif issubclass(enum_type, str):
|
||||
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
|
||||
updates['type'] = 'string'
|
||||
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
elif issubclass(enum_type, float):
|
||||
updates['type'] = 'numeric'
|
||||
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
else:
|
||||
lax = to_enum_validator
|
||||
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
|
||||
return core_schema.lax_or_strict_schema(
|
||||
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class InnerSchemaValidator:
|
||||
"""Use a fixed CoreSchema, avoiding interference from outward annotations."""
|
||||
|
||||
core_schema: CoreSchema
|
||||
js_schema: JsonSchemaValue | None = None
|
||||
js_core_schema: CoreSchema | None = None
|
||||
js_schema_update: JsonSchemaValue | None = None
|
||||
|
||||
def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
if self.js_schema is not None:
|
||||
return self.js_schema
|
||||
js_schema = handler(self.js_core_schema or self.core_schema)
|
||||
if self.js_schema_update is not None:
|
||||
js_schema.update(self.js_schema_update)
|
||||
return js_schema
|
||||
|
||||
def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
return self.core_schema
|
||||
|
||||
|
||||
def decimal_prepare_pydantic_annotations(
|
||||
source: Any, annotations: Iterable[Any], config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
if source is not decimal.Decimal:
|
||||
return None
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
|
||||
config_allow_inf_nan = config.get('allow_inf_nan')
|
||||
if config_allow_inf_nan is not None:
|
||||
metadata.setdefault('allow_inf_nan', config_allow_inf_nan)
|
||||
|
||||
_known_annotated_metadata.check_metadata(
|
||||
metadata, {*_known_annotated_metadata.FLOAT_CONSTRAINTS, 'max_digits', 'decimal_places'}, decimal.Decimal
|
||||
)
|
||||
return source, [InnerSchemaValidator(core_schema.decimal_schema(**metadata)), *remaining_annotations]
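
A minimal sketch of the constraints this wires up for `Decimal` fields; `max_digits` and `decimal_places` are the knobs collected from the annotated metadata:

```python
from decimal import Decimal

from pydantic import BaseModel, Field, ValidationError


class Price(BaseModel):
    amount: Decimal = Field(max_digits=5, decimal_places=2)


print(Price(amount='123.45').amount)  # Decimal('123.45')

try:
    Price(amount='1.234')  # too many decimal places
except ValidationError as e:
    print(e.errors()[0]['type'])  # 'decimal_max_places'
```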
|
||||
|
||||
|
||||
def datetime_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
import datetime
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
if source_type is datetime.date:
|
||||
sv = InnerSchemaValidator(core_schema.date_schema(**metadata))
|
||||
elif source_type is datetime.datetime:
|
||||
sv = InnerSchemaValidator(core_schema.datetime_schema(**metadata))
|
||||
elif source_type is datetime.time:
|
||||
sv = InnerSchemaValidator(core_schema.time_schema(**metadata))
|
||||
elif source_type is datetime.timedelta:
|
||||
sv = InnerSchemaValidator(core_schema.timedelta_schema(**metadata))
|
||||
else:
|
||||
return None
|
||||
# check now that we know the source type is correct
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.DATE_TIME_CONSTRAINTS, source_type)
|
||||
return (source_type, [sv, *remaining_annotations])
|
||||
|
||||
|
||||
def uuid_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
# UUIDs have no constraints - they are fixed length, constructing a UUID instance checks the length
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
if source_type is not UUID:
|
||||
return None
|
||||
|
||||
return (source_type, [InnerSchemaValidator(core_schema.uuid_schema()), *annotations])
|
||||
|
||||
|
||||
def path_schema_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
import pathlib
|
||||
|
||||
if source_type not in {
|
||||
os.PathLike,
|
||||
pathlib.Path,
|
||||
pathlib.PurePath,
|
||||
pathlib.PosixPath,
|
||||
pathlib.PurePosixPath,
|
||||
pathlib.PureWindowsPath,
|
||||
}:
|
||||
return None
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, source_type)
|
||||
|
||||
construct_path = pathlib.PurePath if source_type is os.PathLike else source_type
|
||||
|
||||
def path_validator(input_value: str) -> os.PathLike[Any]:
|
||||
try:
|
||||
return construct_path(input_value)
|
||||
except TypeError as e:
|
||||
raise PydanticCustomError('path_type', 'Input is not a valid path') from e
|
||||
|
||||
constrained_str_schema = core_schema.str_schema(**metadata)
|
||||
|
||||
instance_schema = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
|
||||
python_schema=core_schema.is_instance_schema(source_type),
|
||||
)
|
||||
|
||||
strict: bool | None = None
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, Strict):
|
||||
strict = annotation.strict
|
||||
|
||||
schema = core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.union_schema(
|
||||
[
|
||||
instance_schema,
|
||||
core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
|
||||
],
|
||||
custom_error_type='path_type',
|
||||
custom_error_message='Input is not a valid path',
|
||||
strict=True,
|
||||
),
|
||||
strict_schema=instance_schema,
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
return (
|
||||
source_type,
|
||||
[
|
||||
InnerSchemaValidator(schema, js_core_schema=constrained_str_schema, js_schema_update={'format': 'path'}),
|
||||
*remaining_annotations,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def dequeue_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int
|
||||
) -> collections.deque[Any]:
|
||||
if isinstance(input_value, collections.deque):
|
||||
maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None]
|
||||
if maxlens:
|
||||
maxlen = min(maxlens)
|
||||
return collections.deque(handler(input_value), maxlen=maxlen)
|
||||
else:
|
||||
return collections.deque(handler(input_value), maxlen=maxlen)
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class SequenceValidator:
|
||||
mapped_origin: type[Any]
|
||||
item_source_type: type[Any]
|
||||
min_length: int | None = None
|
||||
max_length: int | None = None
|
||||
strict: bool = False
|
||||
|
||||
def serialize_sequence_via_list(
|
||||
self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
|
||||
) -> Any:
|
||||
items: list[Any] = []
|
||||
for index, item in enumerate(v):
|
||||
try:
|
||||
v = handler(item, index)
|
||||
except PydanticOmit:
|
||||
pass
|
||||
else:
|
||||
items.append(v)
|
||||
|
||||
if info.mode_is_json():
|
||||
return items
|
||||
else:
|
||||
return self.mapped_origin(items)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
if self.item_source_type is Any:
|
||||
items_schema = None
|
||||
else:
|
||||
items_schema = handler.generate_schema(self.item_source_type)
|
||||
|
||||
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
|
||||
|
||||
if self.mapped_origin in (list, set, frozenset):
|
||||
if self.mapped_origin is list:
|
||||
constrained_schema = core_schema.list_schema(items_schema, **metadata)
|
||||
elif self.mapped_origin is set:
|
||||
constrained_schema = core_schema.set_schema(items_schema, **metadata)
|
||||
else:
|
||||
assert self.mapped_origin is frozenset # safety check in case we forget to add a case
|
||||
constrained_schema = core_schema.frozenset_schema(items_schema, **metadata)
|
||||
|
||||
schema = constrained_schema
|
||||
else:
|
||||
# safety check in case we forget to add a case
|
||||
assert self.mapped_origin in (collections.deque, collections.Counter)
|
||||
|
||||
if self.mapped_origin is collections.deque:
|
||||
# if we have a MaxLen annotation might as well set that as the default maxlen on the deque
|
||||
# this lets us re-use existing metadata annotations to let users set the maxlen on a deque
|
||||
# that e.g. comes from JSON
|
||||
coerce_instance_wrap = partial(
|
||||
core_schema.no_info_wrap_validator_function,
|
||||
partial(dequeue_validator, maxlen=metadata.get('max_length', None)),
|
||||
)
|
||||
else:
|
||||
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
|
||||
|
||||
constrained_schema = core_schema.list_schema(items_schema, **metadata)
|
||||
|
||||
check_instance = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.list_schema(),
|
||||
python_schema=core_schema.is_instance_schema(self.mapped_origin),
|
||||
)
|
||||
|
||||
serialization = core_schema.wrap_serializer_function_ser_schema(
|
||||
self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
|
||||
)
|
||||
|
||||
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
|
||||
|
||||
if metadata.get('strict', False):
|
||||
schema = strict
|
||||
else:
|
||||
lax = coerce_instance_wrap(constrained_schema)
|
||||
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
|
||||
schema['serialization'] = serialization
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.Deque: collections.deque,
|
||||
collections.deque: collections.deque,
|
||||
list: list,
|
||||
typing.List: list,
|
||||
set: set,
|
||||
typing.AbstractSet: set,
|
||||
typing.Set: set,
|
||||
frozenset: frozenset,
|
||||
typing.FrozenSet: frozenset,
|
||||
typing.Sequence: list,
|
||||
typing.MutableSequence: list,
|
||||
typing.MutableSet: set,
|
||||
# this doesn't handle subclasses of these
|
||||
# parametrized typing.Set creates one of these
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
}
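
# Illustrative sketch (not part of the original change): bare aliases are looked up in the
# map directly, while parametrized forms are first reduced to their origin with `get_origin`.
def _sequence_origin_map_sketch() -> None:
    assert SEQUENCE_ORIGIN_MAP[typing.Sequence] is list
    assert SEQUENCE_ORIGIN_MAP[get_origin(typing.Deque[int])] is collections.deque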
|
||||
|
||||
|
||||
def identity(s: CoreSchema) -> CoreSchema:
|
||||
return s
|
||||
|
||||
|
||||
def sequence_like_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
origin: Any = get_origin(source_type)
|
||||
|
||||
mapped_origin = SEQUENCE_ORIGIN_MAP.get(origin, None) if origin else SEQUENCE_ORIGIN_MAP.get(source_type, None)
|
||||
if mapped_origin is None:
|
||||
return None
|
||||
|
||||
args = get_args(source_type)
|
||||
|
||||
if not args:
|
||||
args = (Any,)
|
||||
elif len(args) != 1:
|
||||
raise ValueError('Expected sequence to have exactly 1 generic parameter')
|
||||
|
||||
item_source_type = args[0]
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
|
||||
|
||||
return (source_type, [SequenceValidator(mapped_origin, item_source_type, **metadata), *remaining_annotations])
|
||||
|
||||
|
||||
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.DefaultDict: collections.defaultdict,
|
||||
collections.defaultdict: collections.defaultdict,
|
||||
collections.OrderedDict: collections.OrderedDict,
|
||||
typing_extensions.OrderedDict: collections.OrderedDict,
|
||||
dict: dict,
|
||||
typing.Dict: dict,
|
||||
collections.Counter: collections.Counter,
|
||||
typing.Counter: collections.Counter,
|
||||
# this doesn't handle subclasses of these
|
||||
typing.Mapping: dict,
|
||||
typing.MutableMapping: dict,
|
||||
# parametrized typing.{Mutable}Mapping creates one of these
|
||||
collections.abc.MutableMapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
}
|
||||
|
||||
|
||||
def defaultdict_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
|
||||
) -> collections.defaultdict[Any, Any]:
|
||||
if isinstance(input_value, collections.defaultdict):
|
||||
default_factory = input_value.default_factory
|
||||
return collections.defaultdict(default_factory, handler(input_value))
|
||||
else:
|
||||
return collections.defaultdict(default_default_factory, handler(input_value))
|
||||
|
||||
|
||||
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
|
||||
def infer_default() -> Callable[[], Any]:
|
||||
allowed_default_types: dict[Any, Any] = {
|
||||
typing.Tuple: tuple,
|
||||
tuple: tuple,
|
||||
collections.abc.Sequence: tuple,
|
||||
collections.abc.MutableSequence: list,
|
||||
typing.List: list,
|
||||
list: list,
|
||||
typing.Sequence: list,
|
||||
typing.Set: set,
|
||||
set: set,
|
||||
typing.MutableSet: set,
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
typing.MutableMapping: dict,
|
||||
typing.Mapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
collections.abc.MutableMapping: dict,
|
||||
float: float,
|
||||
int: int,
|
||||
str: str,
|
||||
bool: bool,
|
||||
}
|
||||
values_type_origin = get_origin(values_source_type) or values_source_type
|
||||
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
|
||||
if isinstance(values_type_origin, TypeVar):
|
||||
|
||||
def type_var_default_factory() -> None:
|
||||
raise RuntimeError(
|
||||
'Generic defaultdict cannot be used without a concrete value type or an'
|
||||
' explicit default factory, ' + instructions
|
||||
)
|
||||
|
||||
return type_var_default_factory
|
||||
elif values_type_origin not in allowed_default_types:
|
||||
# a somewhat subjective set of types that have reasonable default values
|
||||
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
|
||||
raise PydanticSchemaGenerationError(
|
||||
f'Unable to infer a default factory for values of type {values_source_type}.'
|
||||
f' Only {allowed_msg} are supported, other types require an explicit default factory'
|
||||
' ' + instructions
|
||||
)
|
||||
return allowed_default_types[values_type_origin]
|
||||
|
||||
# Assume Annotated[..., Field(...)]
|
||||
if _typing_extra.is_annotated(values_source_type):
|
||||
field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
|
||||
else:
|
||||
field_info = None
|
||||
if field_info and field_info.default_factory:
|
||||
default_default_factory = field_info.default_factory
|
||||
else:
|
||||
default_default_factory = infer_default()
|
||||
return default_default_factory
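
# Illustrative sketch (not part of the original change): for supported value types the
# inferred factory is simply the (origin of the) type itself.
def _defaultdict_default_factory_sketch() -> None:
    assert get_defaultdict_default_default_factory(int) is int
    assert get_defaultdict_default_default_factory(typing.List[str]) is list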
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class MappingValidator:
|
||||
mapped_origin: type[Any]
|
||||
keys_source_type: type[Any]
|
||||
values_source_type: type[Any]
|
||||
min_length: int | None = None
|
||||
max_length: int | None = None
|
||||
strict: bool = False
|
||||
|
||||
def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any:
|
||||
return handler(v)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
if self.keys_source_type is Any:
|
||||
keys_schema = None
|
||||
else:
|
||||
keys_schema = handler.generate_schema(self.keys_source_type)
|
||||
if self.values_source_type is Any:
|
||||
values_schema = None
|
||||
else:
|
||||
values_schema = handler.generate_schema(self.values_source_type)
|
||||
|
||||
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
|
||||
|
||||
if self.mapped_origin is dict:
|
||||
schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
|
||||
else:
|
||||
constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
|
||||
check_instance = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.dict_schema(),
|
||||
python_schema=core_schema.is_instance_schema(self.mapped_origin),
|
||||
)
|
||||
|
||||
if self.mapped_origin is collections.defaultdict:
|
||||
default_default_factory = get_defaultdict_default_default_factory(self.values_source_type)
|
||||
coerce_instance_wrap = partial(
|
||||
core_schema.no_info_wrap_validator_function,
|
||||
partial(defaultdict_validator, default_default_factory=default_default_factory),
|
||||
)
|
||||
else:
|
||||
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
|
||||
|
||||
serialization = core_schema.wrap_serializer_function_ser_schema(
|
||||
self.serialize_mapping_via_dict,
|
||||
schema=core_schema.dict_schema(
|
||||
keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema()
|
||||
),
|
||||
info_arg=False,
|
||||
)
|
||||
|
||||
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
|
||||
|
||||
if metadata.get('strict', False):
|
||||
schema = strict
|
||||
else:
|
||||
lax = coerce_instance_wrap(constrained_schema)
|
||||
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
|
||||
schema['serialization'] = serialization
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def mapping_like_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
origin: Any = get_origin(source_type)
|
||||
|
||||
mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None)
|
||||
if mapped_origin is None:
|
||||
return None
|
||||
|
||||
args = get_args(source_type)
|
||||
|
||||
if not args:
|
||||
args = (Any, Any)
|
||||
elif mapped_origin is collections.Counter:
|
||||
# `Counter` takes a single generic parameter
if len(args) != 1:
    raise ValueError('Expected Counter to have exactly 1 generic parameter')
args = (args[0], int)  # the values of a Counter are always `int`
|
||||
elif len(args) != 2:
|
||||
raise ValueError('Expected mapping to have exactly 2 generic parameters')
|
||||
|
||||
keys_source_type, values_source_type = args
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
|
||||
|
||||
return (
|
||||
source_type,
|
||||
[
|
||||
MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata),
|
||||
*remaining_annotations,
|
||||
],
|
||||
)
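
# Illustrative sketch (not part of the original change): `Counter[str]` is normalized to
# `str` keys with `int` values before a `MappingValidator` is attached.
def _counter_annotation_sketch() -> None:
    prepared = mapping_like_prepare_pydantic_annotations(typing.Counter[str], [], {})
    assert prepared is not None
    validator = prepared[1][0]
    assert validator.keys_source_type is str and validator.values_source_type is int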
|
||||
|
||||
|
||||
def ip_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
def make_strict_ip_schema(tp: type[Any]) -> CoreSchema:
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()),
|
||||
python_schema=core_schema.is_instance_schema(tp),
|
||||
)
|
||||
|
||||
if source_type is IPv4Address:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_address_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv4Address),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv4'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is IPv4Network:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_network_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv4Network),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv4network'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is IPv4Interface:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_interface_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv4Interface),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv4interface'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
|
||||
if source_type is IPv6Address:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_address_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv6Address),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv6'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is IPv6Network:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_network_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv6Network),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv6network'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is IPv6Interface:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_interface_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv6Interface),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv6interface'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def url_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
if source_type is Url:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.url_schema(),
|
||||
lambda cs, handler: handler(cs),
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is MultiHostUrl:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.multi_host_url_schema(),
|
||||
lambda cs, handler: handler(cs),
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
|
||||
|
||||
PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = (
|
||||
decimal_prepare_pydantic_annotations,
|
||||
sequence_like_prepare_pydantic_annotations,
|
||||
datetime_prepare_pydantic_annotations,
|
||||
uuid_prepare_pydantic_annotations,
|
||||
path_schema_prepare_pydantic_annotations,
|
||||
mapping_like_prepare_pydantic_annotations,
|
||||
ip_prepare_pydantic_annotations,
|
||||
url_prepare_pydantic_annotations,
|
||||
)
|
||||
@@ -1,244 +1,589 @@
|
||||
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module."""
|
||||
from __future__ import annotations as _annotations
|
||||
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap Python's typing module."""
|
||||
|
||||
import dataclasses
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from types import GetSetDescriptorType
|
||||
from typing import TYPE_CHECKING, Any, ForwardRef
|
||||
from typing import TYPE_CHECKING, Any, Callable, cast
|
||||
|
||||
from typing_extensions import Annotated, Final, Literal, TypeAliasType, TypeGuard, get_args, get_origin
|
||||
import typing_extensions
|
||||
from typing_extensions import deprecated, get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._dataclasses import StandardDataclass
|
||||
|
||||
try:
|
||||
from typing import _TypingBase # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
from typing import _Final as _TypingBase # type: ignore[attr-defined]
|
||||
|
||||
typing_base = _TypingBase
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
TypingGenericAlias = ()
|
||||
else:
|
||||
from typing import GenericAlias as TypingGenericAlias # type: ignore
|
||||
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import NotRequired, Required
|
||||
else:
|
||||
from typing import NotRequired, Required # noqa: F401
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
|
||||
def origin_is_union(tp: type[Any] | None) -> bool:
|
||||
return tp is typing.Union
|
||||
|
||||
WithArgsTypes = (TypingGenericAlias,)
|
||||
|
||||
else:
|
||||
|
||||
def origin_is_union(tp: type[Any] | None) -> bool:
|
||||
return tp is typing.Union or tp is types.UnionType
|
||||
|
||||
WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined]
|
||||
from pydantic.version import version_short
|
||||
|
||||
from ._namespace_utils import GlobalsNamespace, MappingNamespace, NsResolver, get_module_ns_of
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
NoneType = type(None)
|
||||
EllipsisType = type(Ellipsis)
|
||||
else:
|
||||
from types import EllipsisType as EllipsisType
|
||||
from types import NoneType as NoneType
|
||||
|
||||
if sys.version_info >= (3, 14):
|
||||
import annotationlib
|
||||
|
||||
LITERAL_TYPES: set[Any] = {Literal}
|
||||
if hasattr(typing, 'Literal'):
|
||||
LITERAL_TYPES.add(typing.Literal) # type: ignore
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES))
|
||||
# As per https://typing-extensions.readthedocs.io/en/latest/#runtime-use-of-types,
|
||||
# always check for both `typing` and `typing_extensions` variants of a typing construct.
|
||||
# (this is implemented differently than the suggested approach in the `typing_extensions`
|
||||
# docs for performance).
|
||||
|
||||
|
||||
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
|
||||
_t_annotated = typing.Annotated
|
||||
_te_annotated = typing_extensions.Annotated
|
||||
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
return type_ in NONE_TYPES
|
||||
def is_annotated(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Annotated` special form.
|
||||
|
||||
|
||||
def is_callable_type(type_: type[Any]) -> bool:
|
||||
return type_ is Callable or get_origin(type_) is Callable
|
||||
|
||||
|
||||
def is_literal_type(type_: type[Any]) -> bool:
|
||||
return Literal is not None and get_origin(type_) in LITERAL_TYPES
|
||||
|
||||
|
||||
def literal_values(type_: type[Any]) -> tuple[Any, ...]:
|
||||
return get_args(type_)
|
||||
|
||||
|
||||
def all_literal_values(type_: type[Any]) -> list[Any]:
|
||||
"""This method is used to retrieve all Literal values as
|
||||
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
|
||||
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`.
|
||||
```python {test="skip" lint="skip"}
|
||||
is_annotated(Annotated[int, ...])
|
||||
#> True
|
||||
```
|
||||
"""
|
||||
if not is_literal_type(type_):
|
||||
return [type_]
|
||||
|
||||
values = literal_values(type_)
|
||||
return list(x for value in values for x in all_literal_values(value))
|
||||
origin = get_origin(tp)
|
||||
return origin is _t_annotated or origin is _te_annotated
|
||||
|
||||
|
||||
def is_annotated(ann_type: Any) -> bool:
|
||||
from ._utils import lenient_issubclass
|
||||
|
||||
origin = get_origin(ann_type)
|
||||
return origin is not None and lenient_issubclass(origin, Annotated)
|
||||
def annotated_type(tp: Any, /) -> Any | None:
|
||||
"""Return the type of the `Annotated` special form, or `None`."""
|
||||
return tp.__origin__ if typing_objects.is_annotated(get_origin(tp)) else None
|
||||
|
||||
|
||||
def is_namedtuple(type_: type[Any]) -> bool:
|
||||
"""Check if a given class is a named tuple.
|
||||
It can be either a `typing.NamedTuple` or `collections.namedtuple`.
|
||||
def unpack_type(tp: Any, /) -> Any | None:
|
||||
"""Return the type wrapped by the `Unpack` special form, or `None`."""
|
||||
return get_args(tp)[0] if typing_objects.is_unpack(get_origin(tp)) else None
|
||||
|
||||
|
||||
def is_hashable(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is the `Hashable` class.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_hashable(Hashable)
|
||||
#> True
|
||||
```
|
||||
"""
|
||||
from ._utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
|
||||
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
|
||||
# hence the second check:
|
||||
return tp is collections.abc.Hashable or get_origin(tp) is collections.abc.Hashable
|
||||
|
||||
|
||||
test_new_type = typing.NewType('test_new_type', str)
|
||||
def is_callable(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Callable`, parametrized or not.
|
||||
|
||||
|
||||
def is_new_type(type_: type[Any]) -> bool:
|
||||
"""Check whether type_ was created using typing.NewType.
|
||||
|
||||
Can't use isinstance because it fails <3.10.
|
||||
```python {test="skip" lint="skip"}
|
||||
is_callable(Callable[[int], str])
|
||||
#> True
|
||||
is_callable(typing.Callable)
|
||||
#> True
|
||||
is_callable(collections.abc.Callable)
|
||||
#> True
|
||||
```
|
||||
"""
|
||||
return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type]
|
||||
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
|
||||
# hence the second check:
|
||||
return tp is collections.abc.Callable or get_origin(tp) is collections.abc.Callable
|
||||
|
||||
|
||||
def _check_classvar(v: type[Any] | None) -> bool:
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
|
||||
_classvar_re = re.compile(r'((\w+\.)?Annotated\[)?(\w+\.)?ClassVar\[')
|
||||
|
||||
|
||||
def is_classvar(ann_type: type[Any]) -> bool:
|
||||
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
|
||||
def is_classvar_annotation(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument represents a class variable annotation.
|
||||
|
||||
Although not explicitly stated by the typing specification, `ClassVar` can be used
|
||||
inside `Annotated` and as such, this function checks for this specific scenario.
|
||||
|
||||
Because this function is used to detect class variables before evaluating forward references
|
||||
(or because evaluation failed), we also implement a naive regex match implementation. This is
|
||||
required because class variables are inspected before fields are collected, so we try to be
|
||||
as accurate as possible.
|
||||
"""
|
||||
if typing_objects.is_classvar(tp):
|
||||
return True
|
||||
|
||||
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
|
||||
# forward references, see #3679
|
||||
if ann_type.__class__ == typing.ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): # type: ignore
|
||||
origin = get_origin(tp)
|
||||
|
||||
if typing_objects.is_classvar(origin):
|
||||
return True
|
||||
|
||||
if typing_objects.is_annotated(origin):
|
||||
annotated_type = tp.__origin__
|
||||
if typing_objects.is_classvar(annotated_type) or typing_objects.is_classvar(get_origin(annotated_type)):
|
||||
return True
|
||||
|
||||
str_ann: str | None = None
|
||||
if isinstance(tp, typing.ForwardRef):
|
||||
str_ann = tp.__forward_arg__
|
||||
if isinstance(tp, str):
|
||||
str_ann = tp
|
||||
|
||||
if str_ann is not None and _classvar_re.match(str_ann):
|
||||
# stdlib dataclasses do something similar, although a bit more advanced
|
||||
# (see `dataclass._is_type`).
|
||||
return True
|
||||
|
||||
return False
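
# Illustrative sketch (not part of the original change): the evaluated form and the
# still-unresolved string form are both detected (the latter via the regex fallback).
def _is_classvar_annotation_sketch() -> None:
    assert is_classvar_annotation(typing.ClassVar[int])
    assert is_classvar_annotation('ClassVar[SomeNotYetDefinedType]')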
|
||||
|
||||
|
||||
def _check_finalvar(v: type[Any] | None) -> bool:
|
||||
"""Check if a given type is a `typing.Final` type."""
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
|
||||
_t_final = typing.Final
|
||||
_te_final = typing_extensions.Final
|
||||
|
||||
|
||||
def is_finalvar(ann_type: Any) -> bool:
|
||||
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
|
||||
# TODO implement `is_finalvar_annotation` as Final can be wrapped with other special forms:
|
||||
def is_finalvar(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Final` special form, parametrized or not.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_finalvar(Final[int])
|
||||
#> True
|
||||
is_finalvar(Final)
|
||||
#> True
|
||||
"""
|
||||
# Final is not necessarily parametrized:
|
||||
if tp is _t_final or tp is _te_final:
|
||||
return True
|
||||
origin = get_origin(tp)
|
||||
return origin is _t_final or origin is _te_final
|
||||
|
||||
|
||||
def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None:
|
||||
"""We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
|
||||
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
|
||||
and suggestion at the end of the next comment by @gvanrossum.
|
||||
_NONE_TYPES: tuple[Any, ...] = (None, NoneType, typing.Literal[None], typing_extensions.Literal[None])
|
||||
|
||||
WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
|
||||
parent of where it is called.
|
||||
|
||||
WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
|
||||
dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
|
||||
other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.
|
||||
def is_none_type(tp: Any, /) -> bool:
|
||||
"""Return whether the argument represents the `None` type as part of an annotation.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_none_type(None)
|
||||
#> True
|
||||
is_none_type(NoneType)
|
||||
#> True
|
||||
is_none_type(Literal[None])
|
||||
#> True
|
||||
is_none_type(type[None])
|
||||
#> False
|
||||
"""
|
||||
return tp in _NONE_TYPES
|
||||
|
||||
|
||||
def is_namedtuple(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a named tuple class.
|
||||
|
||||
The class can be created using `typing.NamedTuple` or `collections.namedtuple`.
|
||||
Parametrized generic classes are *not* assumed to be named tuples.
|
||||
"""
|
||||
from ._utils import lenient_issubclass # circ. import
|
||||
|
||||
return lenient_issubclass(tp, tuple) and hasattr(tp, '_fields')
|
||||
|
||||
|
||||
# TODO In 2.12, delete this export. It is currently defined only to not break
|
||||
# pydantic-settings which relies on it:
|
||||
origin_is_union = is_union_origin
|
||||
|
||||
|
||||
def is_generic_alias(tp: Any, /) -> bool:
|
||||
return isinstance(tp, (types.GenericAlias, typing._GenericAlias)) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
# TODO: Ideally, we should avoid relying on the private `typing` constructs:
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias, types.UnionType) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
# Similarly, we shouldn't rely on this `_Final` class, which is even more private than `_GenericAlias`:
|
||||
typing_base: Any = typing._Final # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
### Annotation evaluations functions:
|
||||
|
||||
|
||||
def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None:
|
||||
"""Fetch the local namespace of the parent frame where this function is called.
|
||||
|
||||
Using this function is mostly useful to resolve forward annotations pointing to members defined in a local namespace,
|
||||
such as assignments inside a function. Using the standard library tools, it is currently not possible to resolve
|
||||
such annotations:
|
||||
|
||||
```python {lint="skip" test="skip"}
|
||||
from typing import get_type_hints
|
||||
|
||||
def func() -> None:
|
||||
Alias = int
|
||||
|
||||
class C:
|
||||
a: 'Alias'
|
||||
|
||||
# Raises a `NameError: 'Alias' is not defined`
|
||||
get_type_hints(C)
|
||||
```
|
||||
|
||||
Pydantic uses this function when a Pydantic model is being defined to fetch the parent frame locals. However,
|
||||
this only allows us to fetch the parent frame namespace and not other parents (e.g. a model defined in a function,
|
||||
itself defined in another function). Inspecting the next outer frames (using `f_back`) is not reliable enough
|
||||
(see https://discuss.python.org/t/20659).
|
||||
|
||||
Because this function is mostly used to better resolve forward annotations, nothing is returned if the parent frame's
|
||||
code object is defined at the module level. In this case, the locals of the frame will be the same as the module
|
||||
globals where the class is defined (see `_namespace_utils.get_module_ns_of`). However, if you still want to fetch
|
||||
the module globals (e.g. when rebuilding a model, where the frame where the rebuild call is performed might contain
|
||||
members that you want to use for forward annotations evaluation), you can use the `force` parameter.
|
||||
|
||||
Args:
|
||||
parent_depth: The depth at which to get the frame. Defaults to 2, meaning the parent frame where this function
|
||||
is called will be used.
|
||||
force: Whether to always return the frame locals, even if the frame's code object is defined at the module level.
|
||||
|
||||
Returns:
|
||||
The locals of the namespace, or `None` if it was skipped as per the described logic.
|
||||
"""
|
||||
frame = sys._getframe(parent_depth)
|
||||
# if f_back is None, it's the global module namespace and we don't need to include it here
|
||||
if frame.f_back is None:
|
||||
return None
|
||||
else:
|
||||
|
||||
if frame.f_code.co_name.startswith('<generic parameters of'):
|
||||
# As `parent_frame_namespace` is mostly called in `ModelMetaclass.__new__`,
|
||||
# the parent frame can be the annotation scope if the PEP 695 generic syntax is used.
|
||||
# (see https://docs.python.org/3/reference/executionmodel.html#annotation-scopes,
|
||||
# https://docs.python.org/3/reference/compound_stmts.html#generic-classes).
|
||||
# In this case, the code name is set to `<generic parameters of MyClass>`,
|
||||
# and we need to skip this frame as it is irrelevant.
|
||||
frame = cast(types.FrameType, frame.f_back) # guaranteed to not be `None`
|
||||
|
||||
# Note: we don't copy frame.f_locals here (or during the last return call) because we don't expect the namespace to be
# modified down the line. If this becomes a problem, we could implement some sort of frozen mapping structure to enforce this.
|
||||
if force:
|
||||
return frame.f_locals
|
||||
|
||||
# If either of the following conditions are true, the class is defined at the top module level.
|
||||
# To better understand why we need both of these checks, see
|
||||
# https://github.com/pydantic/pydantic/pull/10113#discussion_r1714981531.
|
||||
if frame.f_back is None or frame.f_code.co_name == '<module>':
|
||||
return None
|
||||
|
||||
def add_module_globals(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
module_name = getattr(obj, '__module__', None)
|
||||
if module_name:
|
||||
try:
|
||||
module_globalns = sys.modules[module_name].__dict__
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
pass
|
||||
else:
|
||||
if globalns:
|
||||
return {**module_globalns, **globalns}
|
||||
else:
|
||||
# copy module globals to make sure it can't be updated later
|
||||
return module_globalns.copy()
|
||||
|
||||
return globalns or {}
|
||||
return frame.f_locals
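
# Illustrative sketch (not part of the original change): with the default parent_depth=2,
# the helper returns the locals of the frame that called the current function.
def _parent_frame_namespace_sketch() -> None:
    local_marker = 1  # only referenced through the parent frame's locals

    def _inner() -> dict[str, Any] | None:
        return parent_frame_namespace()

    ns = _inner()
    assert ns is not None and 'local_marker' in ns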
|
||||
|
||||
|
||||
def get_cls_types_namespace(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
ns = add_module_globals(cls, parent_namespace)
|
||||
ns[cls.__name__] = cls
|
||||
return ns
|
||||
def _type_convert(arg: Any) -> Any:
|
||||
"""Convert `None` to `NoneType` and strings to `ForwardRef` instances.
|
||||
|
||||
|
||||
def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Collect annotations from a class, including those from parent classes.
|
||||
|
||||
Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable.
|
||||
This is a backport of the private `typing._type_convert` function. When
|
||||
evaluating a type, `ForwardRef._evaluate` ends up being called, and is
|
||||
responsible for making this conversion. However, we still have to apply
|
||||
it for the first argument passed to our type evaluation functions, similarly
|
||||
to the `typing.get_type_hints` function.
|
||||
"""
|
||||
hints = {}
|
||||
if arg is None:
|
||||
return NoneType
|
||||
if isinstance(arg, str):
|
||||
# Like `typing.get_type_hints`, assume the arg can be in any context,
|
||||
# hence the proper `is_argument` and `is_class` args:
|
||||
return _make_forward_ref(arg, is_argument=False, is_class=True)
|
||||
return arg
|
||||
|
||||
|
||||
def safe_get_annotations(cls: type[Any]) -> dict[str, Any]:
|
||||
"""Get the annotations for the provided class, accounting for potential deferred forward references.
|
||||
|
||||
Starting with Python 3.14, accessing the `__annotations__` attribute might raise a `NameError` if
|
||||
a referenced symbol isn't defined yet. In this case, we return the annotation in the *forward ref*
|
||||
format.
|
||||
"""
|
||||
if sys.version_info >= (3, 14):
|
||||
return annotationlib.get_annotations(cls, format=annotationlib.Format.FORWARDREF)
|
||||
else:
|
||||
return cls.__dict__.get('__annotations__', {})
|
||||
|
||||
|
||||
def get_model_type_hints(
|
||||
obj: type[BaseModel],
|
||||
*,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
) -> dict[str, tuple[Any, bool]]:
|
||||
"""Collect annotations from a Pydantic model class, including those from parent classes.
|
||||
|
||||
Args:
|
||||
obj: The Pydantic model to inspect.
|
||||
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping annotation names to a two-tuple: the first element is the evaluated
|
||||
type or the original annotation if a `NameError` occurred; the second element is a boolean
indicating whether the evaluation succeeded.
|
||||
"""
|
||||
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
|
||||
for base in reversed(obj.__mro__):
|
||||
ann = base.__dict__.get('__annotations__')
|
||||
localns = dict(vars(base))
|
||||
if ann is not None and ann is not GetSetDescriptorType:
|
||||
# For Python 3.14, we could also use `Format.VALUE` and pass the globals/locals
|
||||
# from the ns_resolver, but we want to be able to know which specific field failed
|
||||
# to evaluate:
|
||||
ann = safe_get_annotations(base)
|
||||
|
||||
if not ann:
|
||||
continue
|
||||
|
||||
with ns_resolver.push(base):
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
for name, value in ann.items():
|
||||
hints[name] = eval_type_lenient(value, globalns, localns)
|
||||
if name.startswith('_'):
|
||||
# For private attributes, we only need the annotation to detect the `ClassVar` special form.
|
||||
# For this reason, we still try to evaluate it, but we also catch any possible exception (on
|
||||
# top of the `NameError`s caught in `try_eval_type`) that could happen so that users are free
|
||||
# to use any kind of forward annotation for private fields (e.g. circular imports, new typing
|
||||
# syntax, etc).
|
||||
try:
|
||||
hints[name] = try_eval_type(value, globalns, localns)
|
||||
except Exception:
|
||||
hints[name] = (value, False)
|
||||
else:
|
||||
hints[name] = try_eval_type(value, globalns, localns)
|
||||
return hints
|
||||
|
||||
|
||||
def eval_type_lenient(value: Any, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any:
|
||||
"""Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
|
||||
if value is None:
|
||||
value = NoneType
|
||||
elif isinstance(value, str):
|
||||
value = _make_forward_ref(value, is_argument=False, is_class=True)
|
||||
def get_cls_type_hints(
|
||||
obj: type[Any],
|
||||
*,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Collect annotations from a class, including those from parent classes.
|
||||
|
||||
Args:
|
||||
obj: The class to inspect.
|
||||
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
|
||||
"""
|
||||
hints: dict[str, Any] = {}
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
|
||||
for base in reversed(obj.__mro__):
|
||||
# For Python 3.14, we could also use `Format.VALUE` and pass the globals/locals
|
||||
# from the ns_resolver, but we want to be able to know which specific field failed
|
||||
# to evaluate:
|
||||
ann = safe_get_annotations(base)
|
||||
|
||||
if not ann:
|
||||
continue
|
||||
|
||||
with ns_resolver.push(base):
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
for name, value in ann.items():
|
||||
hints[name] = eval_type(value, globalns, localns)
|
||||
return hints
|
||||
|
||||
|
||||
def try_eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> tuple[Any, bool]:
|
||||
"""Try evaluating the annotation using the provided namespaces.
|
||||
|
||||
Args:
|
||||
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
|
||||
of `str`, it will be converted to a `ForwardRef`.
|
||||
globalns: The global namespace to use during annotation evaluation.
localns: The local namespace to use during annotation evaluation.
|
||||
|
||||
Returns:
|
||||
A two-tuple containing the possibly evaluated type and a boolean indicating
|
||||
whether the evaluation succeeded or not.
|
||||
"""
|
||||
value = _type_convert(value)
|
||||
|
||||
try:
|
||||
return typing._eval_type(value, globalns, localns) # type: ignore
|
||||
return eval_type_backport(value, globalns, localns), True
|
||||
except NameError:
|
||||
# the point of this function is to be tolerant to this case
|
||||
return value
|
||||
return value, False
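
# Illustrative sketch (not part of the original change): a resolvable annotation comes back
# evaluated, an unresolvable one comes back untouched together with `False`.
def _try_eval_type_sketch() -> None:
    evaluated, ok = try_eval_type('int', globals(), None)
    assert ok and evaluated is int
    unresolved, ok = try_eval_type('NotDefinedAnywhere', globals(), None)
    assert not ok and isinstance(unresolved, typing.ForwardRef)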
|
||||
|
||||
|
||||
def eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> Any:
|
||||
"""Evaluate the annotation using the provided namespaces.
|
||||
|
||||
Args:
|
||||
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
|
||||
of `str`, it will be converted to a `ForwardRef`.
|
||||
globalns: The global namespace to use during annotation evaluation.
localns: The local namespace to use during annotation evaluation.
|
||||
"""
|
||||
value = _type_convert(value)
|
||||
return eval_type_backport(value, globalns, localns)
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`eval_type_lenient` is deprecated, use `try_eval_type` instead.',
|
||||
category=None,
|
||||
)
|
||||
def eval_type_lenient(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> Any:
|
||||
ev, _ = try_eval_type(value, globalns, localns)
|
||||
return ev
|
||||
|
||||
|
||||
def eval_type_backport(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
"""An enhanced version of `typing._eval_type` which will fall back to using the `eval_type_backport`
|
||||
package if it's installed to let older Python versions use newer typing constructs.
|
||||
|
||||
Specifically, this transforms `X | Y` into `typing.Union[X, Y]` and `list[X]` into `typing.List[X]`
|
||||
(as well as all the types made generic in PEP 585) if the original syntax is not supported in the
|
||||
current Python version.
|
||||
|
||||
This function will also display a helpful error if the value passed fails to evaluate.
|
||||
"""
|
||||
try:
|
||||
return _eval_type_backport(value, globalns, localns, type_params)
|
||||
except TypeError as e:
|
||||
if 'Unable to evaluate type annotation' in str(e):
|
||||
raise
|
||||
|
||||
# If we got a `TypeError` here, `value` must be a `ForwardRef`: a non-`ForwardRef` value would already
# have failed when the annotation was defined. We assert this for type checking purposes:
|
||||
assert isinstance(value, typing.ForwardRef)
|
||||
|
||||
message = f'Unable to evaluate type annotation {value.__forward_arg__!r}.'
|
||||
if sys.version_info >= (3, 11):
|
||||
e.add_note(message)
|
||||
raise
|
||||
else:
|
||||
raise TypeError(message) from e
|
||||
except RecursionError as e:
|
||||
# TODO ideally recursion errors should be checked in `eval_type` above, but `eval_type_backport`
|
||||
# is used directly in some places.
|
||||
message = (
|
||||
"If you made use of an implicit recursive type alias (e.g. `MyType = list['MyType']), "
|
||||
'consider using PEP 695 type aliases instead. For more details, refer to the documentation: '
|
||||
f'https://docs.pydantic.dev/{version_short()}/concepts/types/#named-recursive-types'
|
||||
)
|
||||
if sys.version_info >= (3, 11):
|
||||
e.add_note(message)
|
||||
raise
|
||||
else:
|
||||
raise RecursionError(f'{e.args[0]}\n{message}')
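
# Illustrative sketch (not part of the original change), assuming Python 3.10+ or the optional
# `eval_type_backport` package on 3.9: PEP 604 unions written as strings still evaluate.
def _eval_type_backport_sketch() -> None:
    ref = typing.ForwardRef('int | None')
    assert eval_type_backport(ref, globals(), None) == typing.Optional[int]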
|
||||
|
||||
|
||||
def _eval_type_backport(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
try:
|
||||
return _eval_type(value, globalns, localns, type_params)
|
||||
except TypeError as e:
|
||||
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
|
||||
raise
|
||||
|
||||
try:
|
||||
from eval_type_backport import eval_type_backport
|
||||
except ImportError:
|
||||
raise TypeError(
|
||||
f'Unable to evaluate type annotation {value.__forward_arg__!r}. If you are making use '
|
||||
'of the new typing syntax (unions using `|` since Python 3.10 or builtins subscripting '
|
||||
'since Python 3.9), you should either replace the use of new syntax with the existing '
|
||||
'`typing` constructs or install the `eval_type_backport` package.'
|
||||
) from e
|
||||
|
||||
return eval_type_backport(
|
||||
value,
|
||||
globalns,
|
||||
localns, # pyright: ignore[reportArgumentType], waiting on a new `eval_type_backport` release.
|
||||
try_default=False,
|
||||
)
|
||||
|
||||
|
||||
def _eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
if sys.version_info >= (3, 14):
|
||||
# Starting in 3.14, `_eval_type()` does *not* apply `_type_convert()`
|
||||
# anymore. This means the `None` -> `type(None)` conversion does not apply:
|
||||
evaluated = typing._eval_type( # type: ignore
|
||||
value,
|
||||
globalns,
|
||||
localns,
|
||||
type_params=type_params,
|
||||
# This is relevant when evaluating types from `TypedDict` classes, where string annotations
|
||||
# are automatically converted to `ForwardRef` instances with a module set. In this case,
|
||||
# Our `globalns` is irrelevant and we need to indicate `typing._eval_type()` that it should
|
||||
# infer it from the `ForwardRef.__forward_module__` attribute instead (`typing.get_type_hints()`
|
||||
# does the same). Note that this would probably be unnecessary if we properly iterated over the
|
||||
# `__orig_bases__` for TypedDicts in `get_cls_type_hints()`:
|
||||
prefer_fwd_module=True,
|
||||
)
|
||||
if evaluated is None:
|
||||
evaluated = type(None)
|
||||
return evaluated
|
||||
elif sys.version_info >= (3, 13):
|
||||
return typing._eval_type( # type: ignore
|
||||
value, globalns, localns, type_params=type_params
|
||||
)
|
||||
else:
|
||||
return typing._eval_type( # type: ignore
|
||||
value, globalns, localns
|
||||
)
|
||||
|
||||
|
||||
def is_backport_fixable_error(e: TypeError) -> bool:
|
||||
msg = str(e)
|
||||
|
||||
return sys.version_info < (3, 10) and msg.startswith('unsupported operand type(s) for |: ')
|
||||
|
||||
|
||||
def get_function_type_hints(
|
||||
function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None
|
||||
function: Callable[..., Any],
|
||||
*,
|
||||
include_keys: set[str] | None = None,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also
|
||||
copes with `partial`.
|
||||
"""
|
||||
if isinstance(function, partial):
|
||||
annotations = function.func.__annotations__
|
||||
else:
|
||||
annotations = function.__annotations__
|
||||
"""Return type hints for a function.
|
||||
|
||||
This is similar to the `typing.get_type_hints` function, with a few differences:
|
||||
- Support `functools.partial` by using the underlying `func` attribute.
|
||||
- Do not wrap type annotation of a parameter with `Optional` if it has a default value of `None`
|
||||
(related bug: https://github.com/python/cpython/issues/90353, only fixed in 3.11+).
|
||||
"""
|
||||
try:
|
||||
if isinstance(function, partial):
|
||||
annotations = function.func.__annotations__
|
||||
else:
|
||||
annotations = function.__annotations__
|
||||
except AttributeError:
|
||||
# Some functions (e.g. builtins) don't have annotations:
|
||||
return {}
|
||||
|
||||
if globalns is None:
|
||||
globalns = get_module_ns_of(function)
|
||||
type_params: tuple[Any, ...] | None = None
|
||||
if localns is None:
|
||||
# If localns was specified, it is assumed to already contain type params. This is because
|
||||
# Pydantic has more advanced logic to do so (see `_namespace_utils.ns_for_function`).
|
||||
type_params = getattr(function, '__type_params__', ())
|
||||
|
||||
globalns = add_module_globals(function)
|
||||
type_hints = {}
|
||||
for name, value in annotations.items():
|
||||
if include_keys is not None and name not in include_keys:
|
||||
@@ -248,11 +593,12 @@ def get_function_type_hints(
|
||||
elif isinstance(value, str):
|
||||
value = _make_forward_ref(value)
|
||||
|
||||
type_hints[name] = typing._eval_type(value, globalns, types_namespace) # type: ignore
|
||||
type_hints[name] = eval_type_backport(value, globalns, localns, type_params)
|
||||
|
||||
return type_hints
|
||||
|
||||
|
||||
# TODO use typing.ForwardRef directly when we stop supporting 3.9:
|
||||
if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1):
|
||||
|
||||
def _make_forward_ref(
|
||||
@@ -272,10 +618,10 @@ if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1):
|
||||
|
||||
Implemented as EAFP with memory.
|
||||
"""
|
||||
return typing.ForwardRef(arg, is_argument)
|
||||
return typing.ForwardRef(arg, is_argument) # pyright: ignore[reportCallIssue]
|
||||
|
||||
else:
|
||||
_make_forward_ref = typing.ForwardRef
|
||||
_make_forward_ref = typing.ForwardRef # pyright: ignore[reportAssignmentType]
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
@@ -363,11 +709,15 @@ else:
|
||||
if isinstance(value, str):
|
||||
value = _make_forward_ref(value, is_argument=False, is_class=True)
|
||||
|
||||
value = typing._eval_type(value, base_globals, base_locals) # type: ignore
|
||||
value = eval_type_backport(value, base_globals, base_locals)
|
||||
hints[name] = value
|
||||
return (
|
||||
hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
|
||||
)
|
||||
if not include_extras and hasattr(typing, '_strip_annotations'):
|
||||
return {
|
||||
k: typing._strip_annotations(t) # type: ignore
|
||||
for k, t in hints.items()
|
||||
}
|
||||
else:
|
||||
return hints
|
||||
|
||||
if globalns is None:
|
||||
if isinstance(obj, types.ModuleType):
|
||||
@@ -388,7 +738,7 @@ else:
|
||||
if isinstance(obj, typing._allowed_types): # type: ignore
|
||||
return {}
|
||||
else:
|
||||
raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.')
|
||||
raise TypeError(f'{obj!r} is not a module, class, method, or function.')
|
||||
defaults = typing._get_defaults(obj) # type: ignore
|
||||
hints = dict(hints)
|
||||
for name, value in hints.items():
|
||||
@@ -403,44 +753,8 @@ else:
|
||||
is_argument=not isinstance(obj, types.ModuleType),
|
||||
is_class=False,
|
||||
)
|
||||
value = typing._eval_type(value, globalns, localns) # type: ignore
|
||||
value = eval_type_backport(value, globalns, localns)
|
||||
if name in defaults and defaults[name] is None:
|
||||
value = typing.Optional[value]
|
||||
hints[name] = value
|
||||
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
|
||||
def evaluate_fwd_ref(
|
||||
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
return ref._evaluate(globalns=globalns, localns=localns)
|
||||
|
||||
else:
|
||||
|
||||
def evaluate_fwd_ref(
|
||||
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
return ref._evaluate(globalns=globalns, localns=localns, recursive_guard=frozenset())
|
||||
|
||||
|
||||
def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
# The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality,
|
||||
# so I created this convenience function
|
||||
return dataclasses.is_dataclass(_cls)
|
||||
|
||||
|
||||
def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]:
|
||||
return isinstance(origin, TypeAliasType)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
def is_generic_alias(type_: type[Any]) -> bool:
|
||||
return isinstance(type_, (types.GenericAlias, typing._GenericAlias)) # type: ignore[attr-defined]
|
||||
|
||||
else:
|
||||
|
||||
def is_generic_alias(type_: type[Any]) -> bool:
|
||||
return isinstance(type_, typing._GenericAlias) # type: ignore
|
||||
|
||||
@@ -2,24 +2,36 @@
|
||||
|
||||
This should be reduced as much as possible: functions used in only one place should be moved to that place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import keyword
|
||||
import typing
|
||||
import sys
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from collections.abc import Set as AbstractSet
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from inspect import Parameter
|
||||
from itertools import zip_longest
|
||||
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
|
||||
from typing import Any, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
|
||||
|
||||
from typing_extensions import TypeAlias, TypeGuard
|
||||
from pydantic_core import MISSING
|
||||
from typing_extensions import TypeAlias, TypeGuard, deprecated
|
||||
|
||||
from pydantic import PydanticDeprecatedSince211
|
||||
|
||||
from . import _repr, _typing_extra
|
||||
from ._import_utils import import_cached_base_model
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
|
||||
AbstractSetIntStr: TypeAlias = 'typing.AbstractSet[int] | typing.AbstractSet[str]'
|
||||
if TYPE_CHECKING:
|
||||
# TODO remove type error comments when we drop support for Python 3.9
|
||||
MappingIntStrAny: TypeAlias = Mapping[int, Any] | Mapping[str, Any] # pyright: ignore[reportGeneralTypeIssues]
|
||||
AbstractSetIntStr: TypeAlias = AbstractSet[int] | AbstractSet[str] # pyright: ignore[reportGeneralTypeIssues]
|
||||
from ..main import BaseModel
|
||||
|
||||
|
||||
@@ -59,6 +71,25 @@ BUILTIN_COLLECTIONS: set[type[Any]] = {
|
||||
}
|
||||
|
||||
|
||||
def can_be_positional(param: Parameter) -> bool:
|
||||
"""Return whether the parameter accepts a positional argument.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
def func(a, /, b, *, c):
|
||||
pass
|
||||
|
||||
params = inspect.signature(func).parameters
|
||||
can_be_positional(params['a'])
|
||||
#> True
|
||||
can_be_positional(params['b'])
|
||||
#> True
|
||||
can_be_positional(params['c'])
|
||||
#> False
|
||||
```
|
||||
"""
|
||||
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
||||
|
||||
|
||||
def sequence_like(v: Any) -> bool:
|
||||
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
|
||||
|
||||
@@ -83,7 +114,7 @@ def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
|
||||
"""Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
|
||||
unlike raw calls to lenient_issubclass.
|
||||
"""
|
||||
from ..main import BaseModel
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
return lenient_issubclass(cls, BaseModel) and cls is not BaseModel
|
||||
|
||||
@@ -120,7 +151,7 @@ T = TypeVar('T')
|
||||
def unique_list(
|
||||
input_list: list[T] | tuple[T, ...],
|
||||
*,
|
||||
name_factory: typing.Callable[[T], str] = str,
|
||||
name_factory: Callable[[T], str] = str,
|
||||
) -> list[T]:
|
||||
"""Make a list unique while maintaining order.
|
||||
We update the list if another one with the same name is set
|
||||
@@ -185,7 +216,7 @@ class ValueItems(_repr.Representation):
|
||||
normalized_items: dict[int | str, Any] = {}
|
||||
all_items = None
|
||||
for i, v in items.items():
|
||||
if not (isinstance(v, typing.Mapping) or isinstance(v, typing.AbstractSet) or self.is_true(v)):
|
||||
if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
|
||||
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
|
||||
if i == '__all__':
|
||||
all_items = self._coerce_value(v)
|
||||
@@ -250,9 +281,9 @@ class ValueItems(_repr.Representation):
|
||||
|
||||
@staticmethod
|
||||
def _coerce_items(items: AbstractSetIntStr | MappingIntStrAny) -> MappingIntStrAny:
|
||||
if isinstance(items, typing.Mapping):
|
||||
if isinstance(items, Mapping):
|
||||
pass
|
||||
elif isinstance(items, typing.AbstractSet):
|
||||
elif isinstance(items, AbstractSet):
|
||||
items = dict.fromkeys(items, ...) # type: ignore
|
||||
else:
|
||||
class_name = getattr(items, '__class__', '???')
|
||||
@@ -273,21 +304,25 @@ class ValueItems(_repr.Representation):
|
||||
return [(None, self._items)]
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def ClassAttribute(name: str, value: T) -> T:
|
||||
...
|
||||
def LazyClassAttribute(name: str, get_value: Callable[[], T]) -> T: ...
|
||||
|
||||
else:
|
||||
|
||||
class ClassAttribute:
|
||||
"""Hide class attribute from its instances."""
|
||||
class LazyClassAttribute:
|
||||
"""A descriptor exposing an attribute only accessible on a class (hidden from instances).
|
||||
|
||||
__slots__ = 'name', 'value'
|
||||
The attribute is lazily computed and cached during the first access.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, value: Any) -> None:
|
||||
def __init__(self, name: str, get_value: Callable[[], Any]) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.get_value = get_value
|
||||
|
||||
@cached_property
|
||||
def value(self) -> Any:
|
||||
return self.get_value()
|
||||
|
||||
def __get__(self, instance: Any, owner: type[Any]) -> None:
|
||||
if instance is None:
|
||||
@@ -303,6 +338,8 @@ def smart_deepcopy(obj: Obj) -> Obj:
    Use obj.copy() for built-in empty collections
    Use copy.deepcopy() for non-empty collections and unknown objects.
    """
    if obj is MISSING:
        return obj  # pyright: ignore[reportReturnType]
    obj_type = obj.__class__
    if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
        return obj  # fastest case: obj is immutable and not collection therefore will not be copied anyway
@@ -317,10 +354,10 @@ def smart_deepcopy(obj: Obj) -> Obj:
    return deepcopy(obj)  # slowest way when we actually might need a deepcopy


_EMPTY = object()
_SENTINEL = object()


def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
    """Check that the items of `left` are the same objects as those in `right`.

    >>> a, b = object(), object()
@@ -329,7 +366,81 @@ def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bo
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
    for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
        if left_item is not right_item:
            return False
    return True


def get_first_not_none(a: Any, b: Any) -> Any:
    """Return the first argument if it is not `None`, otherwise return the second argument."""
    return a if a is not None else b


@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
    """Wrapper redirecting `__getitem__` to `get` with a sentinel value as default.

    This makes it safe to use in `operator.itemgetter` when some keys may be missing.
    """

    # Define __slots__ manually for performance
    # @dataclasses.dataclass() only supports slots=True in python>=3.10
    __slots__ = ('wrapped',)

    wrapped: Mapping[str, Any]

    def __getitem__(self, key: str, /) -> Any:
        return self.wrapped.get(key, _SENTINEL)

    # required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
    # https://github.com/python/mypy/issues/13713
    # https://github.com/python/typeshed/pull/8785
    # Since this is typing-only, hide it in a typing.TYPE_CHECKING block
    if TYPE_CHECKING:

        def __contains__(self, key: str, /) -> bool:
            return self.wrapped.__contains__(key)
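A short, self-contained illustration of why a `get`-backed proxy like `SafeGetItemProxy` is useful: `operator.itemgetter` raises `KeyError` on a plain mapping with a missing key, while the proxy returns a sentinel instead. The class below is a stand-in, not the diff's implementation.

```python
import operator
from collections.abc import Mapping
from typing import Any

_SENTINEL = object()


class _GetItemProxy:
    """Illustrative stand-in: missing keys yield a sentinel instead of raising KeyError."""

    __slots__ = ('wrapped',)

    def __init__(self, wrapped: Mapping[str, Any]) -> None:
        self.wrapped = wrapped

    def __getitem__(self, key: str) -> Any:
        return self.wrapped.get(key, _SENTINEL)


getter = operator.itemgetter('a', 'missing')
a, missing = getter(_GetItemProxy({'a': 1}))
print(a)                     # 1
print(missing is _SENTINEL)  # True, instead of a KeyError
```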
_ModelT = TypeVar('_ModelT', bound='BaseModel')
_RT = TypeVar('_RT')


class deprecated_instance_property(Generic[_ModelT, _RT]):
    """A decorator exposing the decorated class method as a property, with a warning on instance access.

    This decorator takes a class method defined on the `BaseModel` class and transforms it into
    an attribute. The attribute can be accessed on both the class and instances of the class. If accessed
    via an instance, a deprecation warning is emitted stating that instance access will be removed in V3.
    """

    def __init__(self, fget: Callable[[type[_ModelT]], _RT], /) -> None:
        # Note: fget should be a classmethod:
        self.fget = fget

    @overload
    def __get__(self, instance: None, objtype: type[_ModelT]) -> _RT: ...
    @overload
    @deprecated(
        'Accessing this attribute on the instance is deprecated, and will be removed in Pydantic V3. '
        'Instead, you should access this attribute from the model class.',
        category=None,
    )
    def __get__(self, instance: _ModelT, objtype: type[_ModelT]) -> _RT: ...
    def __get__(self, instance: _ModelT | None, objtype: type[_ModelT]) -> _RT:
        if instance is not None:
            # fmt: off
            attr_name = (
                self.fget.__name__
                if sys.version_info >= (3, 10)
                else self.fget.__func__.__name__  # pyright: ignore[reportFunctionMemberAccess]
            )
            # fmt: on
            warnings.warn(
                f'Accessing the {attr_name!r} attribute on the instance is deprecated. '
                'Instead, you should access this attribute from the model class.',
                category=PydanticDeprecatedSince211,
                stacklevel=2,
            )
        return self.fget.__get__(instance, objtype)()
@@ -1,101 +1,122 @@
from __future__ import annotations as _annotations

import functools
import inspect
from dataclasses import dataclass
from collections.abc import Awaitable
from functools import partial
from typing import Any, Awaitable, Callable
from typing import Any, Callable

import pydantic_core

from ..config import ConfigDict
from ..plugin._schema_validator import create_schema_validator
from . import _generate_schema, _typing_extra
from ._config import ConfigWrapper
from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function


@dataclass
class CallMarker:
    function: Callable[..., Any]
    validate_return: bool
def extract_function_name(func: ValidateCallSupportedTypes) -> str:
    """Extract the name of a `ValidateCallSupportedTypes` object."""
    return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__


def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
    """Extract the qualname of a `ValidateCallSupportedTypes` object."""
    return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__


def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
    """Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function."""
    if inspect.iscoroutinefunction(wrapped):

        @functools.wraps(wrapped)
        async def wrapper_function(*args, **kwargs):  # type: ignore
            return await wrapper(*args, **kwargs)
    else:

        @functools.wraps(wrapped)
        def wrapper_function(*args, **kwargs):
            return wrapper(*args, **kwargs)

    # We need to manually update this because `partial` object has no `__name__` and `__qualname__`.
    wrapper_function.__name__ = extract_function_name(wrapped)
    wrapper_function.__qualname__ = extract_function_qualname(wrapped)
    wrapper_function.raw_function = wrapped  # type: ignore

    return wrapper_function
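The manual `__name__`/`__qualname__` assignment above is needed because `functools.partial` objects carry neither attribute, and `functools.wraps` silently skips attributes the wrapped object lacks. A small sketch demonstrating that behaviour:

```python
import functools


def add(x: int, y: int) -> int:
    return x + y


add_one = functools.partial(add, 1)
print(hasattr(add_one, '__name__'))  # False: partial objects carry no __name__/__qualname__


@functools.wraps(add_one)
def wrapper(*args, **kwargs):
    return add_one(*args, **kwargs)


# wraps() skips missing attributes, so the wrapper keeps its own generic name;
# the diff therefore sets __name__/__qualname__ explicitly, e.g. to 'partial(add)'.
print(wrapper.__name__)  # 'wrapper', not 'partial(add)'
```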
class ValidateCallWrapper:
    """This is a wrapper around a function that validates the arguments passed to it, and optionally the return value.

    It's partially inspired by `wraps` which in turn uses `partial`, but extended to be a descriptor so
    these functions can be applied to instance methods, class methods, static methods, as well as normal functions.
    """
    """This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""

    __slots__ = (
        'raw_function',
        '_config',
        '_validate_return',
        '__pydantic_core_schema__',
        'function',
        'validate_return',
        'schema_type',
        'module',
        'qualname',
        'ns_resolver',
        'config_wrapper',
        '__pydantic_complete__',
        '__pydantic_validator__',
        '__signature__',
        '__name__',
        '__qualname__',
        '__annotations__',
        '__dict__',  # required for __module__
        '__return_pydantic_validator__',
    )

    def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
        self.raw_function = function
        self._config = config
        self._validate_return = validate_return
        self.__signature__ = inspect.signature(function)
    def __init__(
        self,
        function: ValidateCallSupportedTypes,
        config: ConfigDict | None,
        validate_return: bool,
        parent_namespace: MappingNamespace | None,
    ) -> None:
        self.function = function
        self.validate_return = validate_return
        if isinstance(function, partial):
            func = function.func
            schema_type = func
            self.__name__ = f'partial({func.__name__})'
            self.__qualname__ = f'partial({func.__qualname__})'
            self.__annotations__ = func.__annotations__
            self.__module__ = func.__module__
            self.__doc__ = func.__doc__
            self.schema_type = function.func
            self.module = function.func.__module__
        else:
            schema_type = function
            self.__name__ = function.__name__
            self.__qualname__ = function.__qualname__
            self.__annotations__ = function.__annotations__
            self.__module__ = function.__module__
            self.__doc__ = function.__doc__
            self.schema_type = function
            self.module = function.__module__
        self.qualname = extract_function_qualname(function)

        namespace = _typing_extra.add_module_globals(function, None)
        config_wrapper = ConfigWrapper(config)
        gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
        schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
        self.__pydantic_core_schema__ = schema
        core_config = config_wrapper.core_config(self)
        self.ns_resolver = NsResolver(
            namespaces_tuple=ns_for_function(self.schema_type, parent_namespace=parent_namespace)
        )
        self.config_wrapper = ConfigWrapper(config)
        if not self.config_wrapper.defer_build:
            self._create_validators()
        else:
            self.__pydantic_complete__ = False

    def _create_validators(self) -> None:
        gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
        schema = gen_schema.clean_schema(gen_schema.generate_schema(self.function))
        core_config = self.config_wrapper.core_config(title=self.qualname)

        self.__pydantic_validator__ = create_schema_validator(
            schema,
            schema_type,
            self.__module__,
            self.__qualname__,
            self.schema_type,
            self.module,
            self.qualname,
            'validate_call',
            core_config,
            config_wrapper.plugin_settings,
            self.config_wrapper.plugin_settings,
        )

        if self._validate_return:
            return_type = (
                self.__signature__.return_annotation
                if self.__signature__.return_annotation is not self.__signature__.empty
                else Any
            )
            gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
        if self.validate_return:
            signature = inspect.signature(self.function)
            return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
            gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
            schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
            self.__return_pydantic_core_schema__ = schema
            validator = create_schema_validator(
                schema,
                schema_type,
                self.__module__,
                self.__qualname__,
                self.schema_type,
                self.module,
                self.qualname,
                'validate_call',
                core_config,
                config_wrapper.plugin_settings,
                self.config_wrapper.plugin_settings,
            )
            if inspect.iscoroutinefunction(self.raw_function):
            if inspect.iscoroutinefunction(self.function):

                async def return_val_wrapper(aw: Awaitable[Any]) -> None:
                    return validator.validate_python(await aw)
@@ -104,46 +125,16 @@ class ValidateCallWrapper:
            else:
                self.__return_pydantic_validator__ = validator.validate_python
        else:
            self.__return_pydantic_core_schema__ = None
            self.__return_pydantic_validator__ = None

        self._name: str | None = None  # set by __get__, used to set the instance attribute when decorating methods
        self.__pydantic_complete__ = True

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if not self.__pydantic_complete__:
            self._create_validators()

        res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
        if self.__return_pydantic_validator__:
            return self.__return_pydantic_validator__(res)
        return res

    def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper:
        """Bind the raw function and return another ValidateCallWrapper wrapping that."""
        if obj is None:
            try:
                # Handle the case where a method is accessed as a class attribute
                return objtype.__getattribute__(objtype, self._name)  # type: ignore
            except AttributeError:
                # This will happen the first time the attribute is accessed
                pass

        bound_function = self.raw_function.__get__(obj, objtype)
        result = self.__class__(bound_function, self._config, self._validate_return)

        # skip binding to instance when obj or objtype has __slots__ attribute
        if hasattr(obj, '__slots__') or hasattr(objtype, '__slots__'):
            return result

        if self._name is not None:
            if obj is not None:
                object.__setattr__(obj, self._name, result)
            else:
                object.__setattr__(objtype, self._name, result)
        return result

    def __set_name__(self, owner: Any, name: str) -> None:
        self._name = name

    def __repr__(self) -> str:
        return f'ValidateCallWrapper({self.raw_function})'

    def __eq__(self, other):
        return self.raw_function == other.raw_function
        else:
            return res
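The new `defer_build` path postpones schema generation and validator construction until the wrapper is first called. A rough, self-contained sketch of that lazy-build pattern (plain Python, no pydantic imports; the names are illustrative):

```python
from typing import Any, Callable


class LazyWrapper:
    """Illustrative sketch: expensive setup is skipped at decoration time and done on first call."""

    def __init__(self, func: Callable[..., Any], defer_build: bool) -> None:
        self.func = func
        self.ready = False
        if not defer_build:
            self._build()

    def _build(self) -> None:
        # stands in for schema generation + validator creation in the diff
        self.validated_call = lambda *args, **kwargs: self.func(*args, **kwargs)
        self.ready = True

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if not self.ready:
            self._build()
        return self.validated_call(*args, **kwargs)


wrapped = LazyWrapper(lambda x: x * 2, defer_build=True)
print(wrapped.ready)  # False: nothing built yet
print(wrapped(21))    # 42: built lazily on first use
print(wrapped.ready)  # True
```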
@@ -5,22 +5,33 @@ Import of this module is deferred since it contains imports of many standard lib

from __future__ import annotations as _annotations

import collections.abc
import math
import re
import typing
from collections.abc import Sequence
from decimal import Decimal
from fractions import Fraction
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any
from typing import Any, Callable, TypeVar, Union, cast
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

from pydantic_core import PydanticCustomError, core_schema
from pydantic_core._pydantic_core import PydanticKnownError
import typing_extensions
from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema
from typing_extensions import get_args, get_origin
from typing_inspection import typing_objects

from pydantic._internal._import_utils import import_cached_field_info
from pydantic.errors import PydanticSchemaGenerationError


def sequence_validator(
    __input_value: typing.Sequence[Any],
    input_value: Sequence[Any],
    /,
    validator: core_schema.ValidatorFunctionWrapHandler,
) -> typing.Sequence[Any]:
) -> Sequence[Any]:
    """Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
    value_type = type(__input_value)
    value_type = type(input_value)

    # We don't accept any plain string as a sequence
    # Relevant issue: https://github.com/pydantic/pydantic/issues/5595
@@ -31,14 +42,24 @@ def sequence_validator(
        {'type_name': value_type.__name__},
    )

    v_list = validator(__input_value)
    # TODO: refactor sequence validation to validate with either a list or a tuple
    # schema, depending on the type of the value.
    # Additionally, we should be able to remove one of either this validator or the
    # SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
    # Effectively, a refactor for sequence validation is needed.
    if value_type is tuple:
        input_value = list(input_value)

    v_list = validator(input_value)

    # the rest of the logic is just re-creating the original type from `v_list`
    if value_type == list:
    if value_type is list:
        return v_list
    elif issubclass(value_type, range):
        # return the list as we probably can't re-create the range
        return v_list
    elif value_type is tuple:
        return tuple(v_list)
    else:
        # best guess at how to re-create the original type, more custom construction logic might be required
        return value_type(v_list)  # type: ignore[call-arg]
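The tail of `sequence_validator` rebuilds the caller's original sequence type from the validated list. A standalone sketch of that re-creation logic (a plain helper written for illustration, outside the diff):

```python
from collections import deque
from typing import Any


def rebuild_sequence(original: Any, validated: list[Any]) -> Any:
    """Sketch of the re-creation step: lists pass through, tuples are rebuilt,
    ranges stay as lists, and other types get a best-guess constructor call."""
    value_type = type(original)
    if value_type is list:
        return validated
    elif issubclass(value_type, range):
        return validated  # a range cannot be re-created from arbitrary items
    elif value_type is tuple:
        return tuple(validated)
    else:
        return value_type(validated)


print(rebuild_sequence((1, 2), [1, 2]))         # (1, 2)
print(rebuild_sequence(range(3), [0, 1, 2]))    # [0, 1, 2]
print(rebuild_sequence(deque([1]), [1, 2, 3]))  # deque([1, 2, 3])
```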
@@ -106,173 +127,407 @@ def _import_string_logic(dotted_path: str) -> Any:
    return module


def pattern_either_validator(__input_value: Any) -> typing.Pattern[Any]:
    if isinstance(__input_value, typing.Pattern):
        return __input_value
    elif isinstance(__input_value, (str, bytes)):
def pattern_either_validator(input_value: Any, /) -> re.Pattern[Any]:
    if isinstance(input_value, re.Pattern):
        return input_value
    elif isinstance(input_value, (str, bytes)):
        # todo strict mode
        return compile_pattern(__input_value)  # type: ignore
        return compile_pattern(input_value)  # type: ignore
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')


def pattern_str_validator(__input_value: Any) -> typing.Pattern[str]:
    if isinstance(__input_value, typing.Pattern):
        if isinstance(__input_value.pattern, str):
            return __input_value
def pattern_str_validator(input_value: Any, /) -> re.Pattern[str]:
    if isinstance(input_value, re.Pattern):
        if isinstance(input_value.pattern, str):
            return input_value
        else:
            raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
    elif isinstance(__input_value, str):
        return compile_pattern(__input_value)
    elif isinstance(__input_value, bytes):
    elif isinstance(input_value, str):
        return compile_pattern(input_value)
    elif isinstance(input_value, bytes):
        raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')


def pattern_bytes_validator(__input_value: Any) -> typing.Pattern[bytes]:
    if isinstance(__input_value, typing.Pattern):
        if isinstance(__input_value.pattern, bytes):
            return __input_value
def pattern_bytes_validator(input_value: Any, /) -> re.Pattern[bytes]:
    if isinstance(input_value, re.Pattern):
        if isinstance(input_value.pattern, bytes):
            return input_value
        else:
            raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
    elif isinstance(__input_value, bytes):
        return compile_pattern(__input_value)
    elif isinstance(__input_value, str):
    elif isinstance(input_value, bytes):
        return compile_pattern(input_value)
    elif isinstance(input_value, str):
        raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')


PatternType = typing.TypeVar('PatternType', str, bytes)
PatternType = TypeVar('PatternType', str, bytes)


def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
def compile_pattern(pattern: PatternType) -> re.Pattern[PatternType]:
    try:
        return re.compile(pattern)
    except re.error:
        raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')


def ip_v4_address_validator(__input_value: Any) -> IPv4Address:
    if isinstance(__input_value, IPv4Address):
        return __input_value
def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
    if isinstance(input_value, IPv4Address):
        return input_value

    try:
        return IPv4Address(__input_value)
        return IPv4Address(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')


def ip_v6_address_validator(__input_value: Any) -> IPv6Address:
    if isinstance(__input_value, IPv6Address):
        return __input_value
def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
    if isinstance(input_value, IPv6Address):
        return input_value

    try:
        return IPv6Address(__input_value)
        return IPv6Address(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')


def ip_v4_network_validator(__input_value: Any) -> IPv4Network:
def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
    """Assume IPv4Network initialised with a default `strict` argument.

    See more:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
    """
    if isinstance(__input_value, IPv4Network):
        return __input_value
    if isinstance(input_value, IPv4Network):
        return input_value

    try:
        return IPv4Network(__input_value)
        return IPv4Network(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')


def ip_v6_network_validator(__input_value: Any) -> IPv6Network:
def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
    """Assume IPv6Network initialised with a default `strict` argument.

    See more:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
    """
    if isinstance(__input_value, IPv6Network):
        return __input_value
    if isinstance(input_value, IPv6Network):
        return input_value

    try:
        return IPv6Network(__input_value)
        return IPv6Network(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')


def ip_v4_interface_validator(__input_value: Any) -> IPv4Interface:
    if isinstance(__input_value, IPv4Interface):
        return __input_value
def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
    if isinstance(input_value, IPv4Interface):
        return input_value

    try:
        return IPv4Interface(__input_value)
        return IPv4Interface(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')


def ip_v6_interface_validator(__input_value: Any) -> IPv6Interface:
    if isinstance(__input_value, IPv6Interface):
        return __input_value
def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
    if isinstance(input_value, IPv6Interface):
        return input_value

    try:
        return IPv6Interface(__input_value)
        return IPv6Interface(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
def greater_than_validator(x: Any, gt: Any) -> Any:
    if not (x > gt):
        raise PydanticKnownError('greater_than', {'gt': gt})
    return x
def fraction_validator(input_value: Any, /) -> Fraction:
    if isinstance(input_value, Fraction):
        return input_value


def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
    if not (x >= ge):
        raise PydanticKnownError('greater_than_equal', {'ge': ge})
    return x


def less_than_validator(x: Any, lt: Any) -> Any:
    if not (x < lt):
        raise PydanticKnownError('less_than', {'lt': lt})
    return x


def less_than_or_equal_validator(x: Any, le: Any) -> Any:
    if not (x <= le):
        raise PydanticKnownError('less_than_equal', {'le': le})
    return x


def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
    if not (x % multiple_of == 0):
        raise PydanticKnownError('multiple_of', {'multiple_of': multiple_of})
    return x


def min_length_validator(x: Any, min_length: Any) -> Any:
    if not (len(x) >= min_length):
        raise PydanticKnownError(
            'too_short',
            {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)},
        )
    return x


def max_length_validator(x: Any, max_length: Any) -> Any:
    if len(x) > max_length:
        raise PydanticKnownError(
            'too_long',
            {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
        )
    return x
    try:
        return Fraction(input_value)
    except ValueError:
        raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')


def forbid_inf_nan_check(x: Any) -> Any:
    if not math.isfinite(x):
        raise PydanticKnownError('finite_number')
    return x


def _safe_repr(v: Any) -> int | float | str:
    """The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta.

    See tests/test_types.py::test_annotated_metadata_any_order for some context.
    """
    if isinstance(v, (int, float, str)):
        return v
    return repr(v)


def greater_than_validator(x: Any, gt: Any) -> Any:
    try:
        if not (x > gt):
            raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")


def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
    try:
        if not (x >= ge):
            raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")


def less_than_validator(x: Any, lt: Any) -> Any:
    try:
        if not (x < lt):
            raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")


def less_than_or_equal_validator(x: Any, le: Any) -> Any:
    try:
        if not (x <= le):
            raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")


def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
    try:
        if x % multiple_of:
            raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")


def min_length_validator(x: Any, min_length: Any) -> Any:
    try:
        if not (len(x) >= min_length):
            raise PydanticKnownError(
                'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")


def max_length_validator(x: Any, max_length: Any) -> Any:
    try:
        if len(x) > max_length:
            raise PydanticKnownError(
                'too_long',
                {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")


def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
    """Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance.

    This function handles both normalized and non-normalized Decimal instances.
    Example: Decimal('1.230') -> 4 digits, 3 decimal places

    Args:
        decimal (Decimal): The decimal number to analyze.

    Returns:
        tuple[int, int]: A tuple containing the number of decimal places and total digits.

    Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
    of the number of decimals and digits together.
    """
    try:
        decimal_tuple = decimal.as_tuple()

        assert isinstance(decimal_tuple.exponent, int)

        exponent = decimal_tuple.exponent
        num_digits = len(decimal_tuple.digits)

        if exponent >= 0:
            # A positive exponent adds that many trailing zeros
            # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
            num_digits += exponent
            decimal_places = 0
        else:
            # If the absolute value of the negative exponent is larger than the
            # number of digits, then it's the same as the number of digits,
            # because it'll consume all the digits in digit_tuple and then
            # add abs(exponent) - len(digit_tuple) leading zeros after the decimal point.
            # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
            # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
            decimal_places = abs(exponent)
            num_digits = max(num_digits, decimal_places)

        return decimal_places, num_digits
    except (AssertionError, AttributeError):
        raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')
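As a worked check of the docstring's example, the arithmetic in `_extract_decimal_digits_info` can be re-derived directly from `Decimal.as_tuple()` (the helper below is illustrative and ignores the NaN/infinity case):

```python
from decimal import Decimal


def digits_info(d: Decimal) -> tuple[int, int]:
    """Re-derivation of the (decimal_places, num_digits) computation for illustration."""
    _sign, digits, exponent = d.as_tuple()
    num_digits = len(digits)
    if exponent >= 0:
        return 0, num_digits + exponent  # trailing zeros add digits
    decimal_places = abs(exponent)
    return decimal_places, max(num_digits, decimal_places)


print(digits_info(Decimal('1.230')))              # (3, 4): 3 decimal places, 4 digits
print(digits_info(Decimal('1.230').normalize()))  # (2, 3): trailing zero dropped
print(digits_info(Decimal('0.0123')))             # (4, 4): leading zeros consume digits
```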
def max_digits_validator(x: Any, max_digits: Any) -> Any:
    try:
        _, num_digits = _extract_decimal_digits_info(x)
        _, normalized_num_digits = _extract_decimal_digits_info(x.normalize())
        if (num_digits > max_digits) and (normalized_num_digits > max_digits):
            raise PydanticKnownError(
                'decimal_max_digits',
                {'max_digits': max_digits},
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")


def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
    try:
        decimal_places_, _ = _extract_decimal_digits_info(x)
        if decimal_places_ > decimal_places:
            normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize())
            if normalized_decimal_places > decimal_places:
                raise PydanticKnownError(
                    'decimal_max_places',
                    {'decimal_places': decimal_places},
                )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")


def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
    return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None))


def defaultdict_validator(
    input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
    if isinstance(input_value, collections.defaultdict):
        default_factory = input_value.default_factory
        return collections.defaultdict(default_factory, handler(input_value))
    else:
        return collections.defaultdict(default_default_factory, handler(input_value))
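`defaultdict_validator` above preserves the incoming `default_factory` when re-building a `defaultdict`, and only falls back to the inferred factory for plain mappings. A small self-contained sketch of that behaviour (the handler here is a stand-in for the wrapped core-schema handler):

```python
import collections
from typing import Any, Callable, Mapping


def revalidate_defaultdict(
    value: Any, handler: Callable[[Any], Mapping[Any, Any]], fallback_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
    """Keep the input's default_factory if it already is a defaultdict, else use the fallback."""
    if isinstance(value, collections.defaultdict):
        return collections.defaultdict(value.default_factory, handler(value))
    return collections.defaultdict(fallback_factory, handler(value))


def passthrough(v: Any) -> dict[Any, Any]:
    return dict(v)


dd = collections.defaultdict(set, {'a': {1}})
print(revalidate_defaultdict(dd, passthrough, list).default_factory)          # <class 'set'>
print(revalidate_defaultdict({'a': [1]}, passthrough, list).default_factory)  # <class 'list'>
```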
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
    FieldInfo = import_cached_field_info()

    values_type_origin = get_origin(values_source_type)

    def infer_default() -> Callable[[], Any]:
        allowed_default_types: dict[Any, Any] = {
            tuple: tuple,
            collections.abc.Sequence: tuple,
            collections.abc.MutableSequence: list,
            list: list,
            typing.Sequence: list,
            set: set,
            typing.MutableSet: set,
            collections.abc.MutableSet: set,
            collections.abc.Set: frozenset,
            typing.MutableMapping: dict,
            typing.Mapping: dict,
            collections.abc.Mapping: dict,
            collections.abc.MutableMapping: dict,
            float: float,
            int: int,
            str: str,
            bool: bool,
        }
        values_type = values_type_origin or values_source_type
        instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
        if typing_objects.is_typevar(values_type):

            def type_var_default_factory() -> None:
                raise RuntimeError(
                    'Generic defaultdict cannot be used without a concrete value type or an'
                    ' explicit default factory, ' + instructions
                )

            return type_var_default_factory
        elif values_type not in allowed_default_types:
            # a somewhat subjective set of types that have reasonable default values
            allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
            raise PydanticSchemaGenerationError(
                f'Unable to infer a default factory for keys of type {values_source_type}.'
                f' Only {allowed_msg} are supported, other types require an explicit default factory'
                ' ' + instructions
            )
        return allowed_default_types[values_type]

    # Assume Annotated[..., Field(...)]
    if typing_objects.is_annotated(values_type_origin):
        field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
    else:
        field_info = None
    if field_info and field_info.default_factory:
        # Assume the default factory does not take any argument:
        default_default_factory = cast(Callable[[], Any], field_info.default_factory)
    else:
        default_default_factory = infer_default()
    return default_default_factory


def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
    if isinstance(value, ZoneInfo):
        return value
    try:
        return ZoneInfo(value)
    except (ZoneInfoNotFoundError, ValueError, TypeError):
        raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})


NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
    'gt': greater_than_validator,
    'ge': greater_than_or_equal_validator,
    'lt': less_than_validator,
    'le': less_than_or_equal_validator,
    'multiple_of': multiple_of_validator,
    'min_length': min_length_validator,
    'max_length': max_length_validator,
    'max_digits': max_digits_validator,
    'decimal_places': decimal_places_validator,
}

IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]

IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
    IPv4Address: ip_v4_address_validator,
    IPv6Address: ip_v6_address_validator,
    IPv4Network: ip_v4_network_validator,
    IPv6Network: ip_v6_network_validator,
    IPv4Interface: ip_v4_interface_validator,
    IPv6Interface: ip_v6_interface_validator,
}

MAPPING_ORIGIN_MAP: dict[Any, Any] = {
    typing.DefaultDict: collections.defaultdict,  # noqa: UP006
    collections.defaultdict: collections.defaultdict,
    typing.OrderedDict: collections.OrderedDict,  # noqa: UP006
    collections.OrderedDict: collections.OrderedDict,
    typing_extensions.OrderedDict: collections.OrderedDict,
    typing.Counter: collections.Counter,
    collections.Counter: collections.Counter,
    # this doesn't handle subclasses of these
    typing.Mapping: dict,
    typing.MutableMapping: dict,
    # parametrized typing.{Mutable}Mapping creates one of these
    collections.abc.Mapping: dict,
    collections.abc.MutableMapping: dict,
}
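The lookup tables above let schema generation dispatch constraint checks by name rather than hard-coding each one. A compact sketch of the same dispatch pattern with stand-in validators (not the diff's functions):

```python
from typing import Any, Callable


def greater_than(x: Any, bound: Any) -> Any:
    if not (x > bound):
        raise ValueError(f'{x!r} is not greater than {bound!r}')
    return x


def max_length(x: Any, bound: Any) -> Any:
    if len(x) > bound:
        raise ValueError(f'length {len(x)} exceeds {bound}')
    return x


# constraint names map to callables, so checks can be attached generically by key
CONSTRAINT_LOOKUP: dict[str, Callable[[Any, Any], Any]] = {
    'gt': greater_than,
    'max_length': max_length,
}


def apply_constraints(value: Any, **constraints: Any) -> Any:
    for name, bound in constraints.items():
        value = CONSTRAINT_LOOKUP[name](value, bound)
    return value


print(apply_constraints(5, gt=3))              # 5
print(apply_constraints('abc', max_length=5))  # abc
```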