This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,30 @@
from marshmallow.constants import EXCLUDE, INCLUDE, RAISE, missing
from marshmallow.decorators import (
post_dump,
post_load,
pre_dump,
pre_load,
validates,
validates_schema,
)
from marshmallow.exceptions import ValidationError
from marshmallow.schema import Schema, SchemaOpts
from . import fields
# Names exported by ``from marshmallow import *``; these are the constants,
# decorators, exception and schema classes imported above, plus ``fields``.
__all__ = [
    "EXCLUDE",
    "INCLUDE",
    "RAISE",
    "Schema",
    "SchemaOpts",
    "ValidationError",
    "fields",
    "missing",
    "post_dump",
    "post_load",
    "pre_dump",
    "pre_load",
    "validates",
    "validates_schema",
]

View File

@@ -0,0 +1,103 @@
"""A registry of :class:`Schema <marshmallow.Schema>` classes. This allows for string
lookup of schemas, which may be used with
:class:`fields.Nested <marshmallow.fields.Nested>`.
.. warning::
This module is treated as private API.
Users should not need to use this module directly.
"""
# ruff: noqa: ERA001
from __future__ import annotations
import typing
from marshmallow.exceptions import RegistryError
if typing.TYPE_CHECKING:
from marshmallow import Schema
SchemaType = type[Schema]
# Module-level registry. Each key is either a bare class name or a
# module-qualified path; each value is the list of classes registered under it:
# {
#   <class_name>: <list of class objects>
#   <module_path_to_class>: <list of class objects>
# }
_registry = {}  # type: dict[str, list[SchemaType]]
def register(classname: str, cls: SchemaType) -> None:
    """Record ``cls`` in the registry under two keys.

    The class is stored under its bare ``classname`` and under its full,
    module-qualified path (``<cls.__module__>.<classname>``).

    Example: ::

        class MyClass:
            pass


        register("MyClass", MyClass)
        # Registry:
        # {
        #   'MyClass': [path.to.MyClass],
        #   'path.to.MyClass': [path.to.MyClass],
        # }

    :param classname: Name to register the class under.
    :param cls: The class object to register.
    """
    module = cls.__module__
    # e.g. user.schemas.UserSchema
    fullpath = f"{module}.{classname}"

    same_name = _registry.get(classname)
    if same_name is None:
        # First class seen with this name.
        _registry[classname] = [cls]
    elif not any(each.__module__ == module for each in same_name):
        # Only add a second entry when it comes from a different module, so
        # that re-registering from the same module does not duplicate entries.
        same_name.append(cls)

    # The full path always maps to exactly the most recently registered class,
    # whether or not an entry already existed.
    _registry[fullpath] = [cls]
@typing.overload
def get_class(classname: str, *, all: typing.Literal[False] = ...) -> SchemaType: ...


@typing.overload
def get_class(
    classname: str, *, all: typing.Literal[True] = ...
) -> list[SchemaType]: ...


def get_class(classname: str, *, all: bool = False) -> list[SchemaType] | SchemaType:  # noqa: A002
    """Look up a registered class by name or module-qualified path.

    :param classname: Bare class name or full module path used at registration.
    :param all: When several classes share ``classname``, return the whole
        list instead of raising.
    :return: The single registered class, or the list of classes when there
        is more than one match and ``all`` is true.
    :raises: `marshmallow.exceptions.RegistryError` if the class cannot be found
        or if there are multiple entries for the given class name and ``all``
        is false.
    """
    try:
        classes = _registry[classname]
    except KeyError as error:
        raise RegistryError(
            f"Class with name {classname!r} was not found. You may need "
            "to import the class."
        ) from error

    if len(classes) <= 1:
        return classes[0]
    if all:
        return classes
    raise RegistryError(
        f"Multiple classes with name {classname!r} "
        "were found. Please use the full, "
        "module-qualified path."
    )

View File

@@ -0,0 +1,25 @@
import typing
# String values of the three "unknown" options.
EXCLUDE: typing.Final = "exclude"
INCLUDE: typing.Final = "include"
RAISE: typing.Final = "raise"


class _Missing:
    """Type of the ``missing`` sentinel.

    Instances are falsy, have length zero, and are returned unchanged by
    ``copy.copy`` and ``copy.deepcopy``.
    """

    def __repr__(self):
        return "<marshmallow.missing>"

    def __bool__(self):
        return False

    def __len__(self):
        return 0

    def __copy__(self):
        # Copying hands back the very same object.
        return self

    def __deepcopy__(self, _):
        # The memo argument is ignored; deep copies are the same object too.
        return self


# The sentinel instance used throughout the package.
missing: typing.Final = _Missing()

View File

@@ -0,0 +1,254 @@
"""Decorators for registering schema pre-processing and post-processing methods.
These should be imported from the top-level `marshmallow` module.
Methods decorated with
`pre_load <marshmallow.decorators.pre_load>`, `post_load <marshmallow.decorators.post_load>`,
`pre_dump <marshmallow.decorators.pre_dump>`, `post_dump <marshmallow.decorators.post_dump>`,
and `validates_schema <marshmallow.decorators.validates_schema>` receive
``many`` as a keyword argument. In addition, `pre_load <marshmallow.decorators.pre_load>`,
`post_load <marshmallow.decorators.post_load>`,
and `validates_schema <marshmallow.decorators.validates_schema>` receive
``partial``. If you don't need these arguments, add ``**kwargs`` to your method
signature.
Example: ::
from marshmallow import (
Schema,
pre_load,
pre_dump,
post_load,
validates_schema,
validates,
fields,
ValidationError,
)
class UserSchema(Schema):
email = fields.Str(required=True)
age = fields.Integer(required=True)
@post_load
def lowerstrip_email(self, item, many, **kwargs):
item["email"] = item["email"].lower().strip()
return item
@pre_load(pass_collection=True)
def remove_envelope(self, data, many, **kwargs):
namespace = "results" if many else "result"
return data[namespace]
@post_dump(pass_collection=True)
def add_envelope(self, data, many, **kwargs):
namespace = "results" if many else "result"
return {namespace: data}
@validates_schema
def validate_email(self, data, **kwargs):
if len(data["email"]) < 3:
raise ValidationError("Email must be at least 3 characters", "email")
@validates("age")
def validate_age(self, data, **kwargs):
if data < 14:
raise ValidationError("Too young!")
.. note::
These decorators only work with instance methods. Class and static
methods are not supported.
.. warning::
The invocation order of decorated methods of the same type is not guaranteed.
If you need to guarantee order of different processing steps, you should put
them in the same processing method.
"""
from __future__ import annotations
import functools
import typing
from collections import defaultdict
# Hook tags. Each decorator in this module passes one of these to ``set_hook``
# as the key under which the decorated function's options are stored.
PRE_DUMP = "pre_dump"
POST_DUMP = "post_dump"
PRE_LOAD = "pre_load"
POST_LOAD = "post_load"
VALIDATES = "validates"
VALIDATES_SCHEMA = "validates_schema"
class MarshmallowHook:
    """Typing helper describing the attribute ``set_hook`` attaches to functions."""

    # Maps a hook tag to a list of ``(many, kwargs)`` tuples, one per
    # registration of the function under that tag.
    __marshmallow_hook__: dict[str, list[tuple[bool, typing.Any]]] | None = None
def validates(*field_names: str) -> typing.Callable[..., typing.Any]:
    """Mark a method as the validator for one or more fields.

    :param field_names: Names of the fields that the method validates.
    :return: A decorator that tags the method under the ``validates`` hook.

    .. versionchanged:: 4.0.0 Accepts multiple field names as positional arguments.
    .. versionchanged:: 4.0.0 Decorated methods receive ``data_key`` as a keyword argument.
    """
    # No function is available yet, so ``set_hook`` returns a decorator.
    decorator = set_hook(None, VALIDATES, field_names=field_names)
    return decorator
def validates_schema(
    fn: typing.Callable[..., typing.Any] | None = None,
    *,
    pass_collection: bool = False,
    pass_original: bool = False,
    skip_on_field_errors: bool = True,
) -> typing.Callable[..., typing.Any]:
    """Mark a method as a schema-level validator.

    The method normally receives one object at a time; the ``many`` argument
    passed to the `Schema <marshmallow.Schema>`'s :func:`~marshmallow.Schema.validate`
    call is handled transparently.

    :param fn: The method to register, or ``None`` when used as a decorator factory.
    :param pass_collection: If true, the raw data (which may be a collection) is passed.
    :param pass_original: If true, the original data (before unmarshalling) is
        passed as an additional argument to the method.
    :param skip_on_field_errors: If true, the method is skipped whenever
        validation errors have been detected when validating fields.

    .. versionchanged:: 3.0.0b1 ``skip_on_field_errors`` defaults to `True`.
    .. versionchanged:: 3.0.0 ``partial`` and ``many`` are always passed as keyword arguments to
        the decorated method.
    .. versionchanged:: 4.0.0 ``unknown`` is passed as a keyword argument to the decorated method.
    .. versionchanged:: 4.0.0 ``pass_many`` is renamed to ``pass_collection``.
    .. versionchanged:: 4.0.0 ``pass_collection``, ``pass_original``, and ``skip_on_field_errors``
        are keyword-only arguments.
    """
    # Everything except ``many`` is stored verbatim as the hook's kwargs.
    hook_options = {
        "pass_original": pass_original,
        "skip_on_field_errors": skip_on_field_errors,
    }
    return set_hook(fn, VALIDATES_SCHEMA, many=pass_collection, **hook_options)
def pre_dump(
    fn: typing.Callable[..., typing.Any] | None = None,
    *,
    pass_collection: bool = False,
) -> typing.Callable[..., typing.Any]:
    """Mark a method to run before an object is serialized.

    The method receives the object to be serialized and returns the processed
    object. It normally receives one object at a time; the ``many`` argument
    passed to the `Schema <marshmallow.Schema>`'s :func:`~marshmallow.Schema.dump`
    call is handled transparently.

    :param fn: The method to register, or ``None`` when used as a decorator factory.
    :param pass_collection: If true, the raw data (which may be a collection) is passed.

    .. versionchanged:: 3.0.0 ``many`` is always passed as a keyword argument to the decorated method.
    .. versionchanged:: 4.0.0 ``pass_many`` is renamed to ``pass_collection``.
    .. versionchanged:: 4.0.0 ``pass_collection`` is a keyword-only argument.
    """
    tagged = set_hook(fn, PRE_DUMP, many=pass_collection)
    return tagged
def post_dump(
    fn: typing.Callable[..., typing.Any] | None = None,
    *,
    pass_collection: bool = False,
    pass_original: bool = False,
) -> typing.Callable[..., typing.Any]:
    """Mark a method to run after an object is serialized.

    The method receives the serialized object and returns the processed
    object. It normally receives one object at a time; the ``many`` argument
    passed to the `Schema <marshmallow.Schema>`'s :func:`~marshmallow.Schema.dump`
    call is handled transparently.

    :param fn: The method to register, or ``None`` when used as a decorator factory.
    :param pass_collection: If true, the raw data (which may be a collection) is passed.
    :param pass_original: If true, the original data (before serializing) is
        passed as an additional argument to the method.

    .. versionchanged:: 3.0.0 ``many`` is always passed as a keyword argument to the decorated method.
    .. versionchanged:: 4.0.0 ``pass_many`` is renamed to ``pass_collection``.
    .. versionchanged:: 4.0.0 ``pass_collection`` and ``pass_original`` are keyword-only arguments.
    """
    tagged = set_hook(
        fn,
        POST_DUMP,
        many=pass_collection,
        pass_original=pass_original,
    )
    return tagged
def pre_load(
    fn: typing.Callable[..., typing.Any] | None = None,
    *,
    pass_collection: bool = False,
) -> typing.Callable[..., typing.Any]:
    """Mark a method to run before data is deserialized.

    The method receives the data to be deserialized and returns the processed
    data. It normally receives one object at a time; the ``many`` argument
    passed to the `Schema <marshmallow.Schema>`'s :func:`~marshmallow.Schema.load`
    call is handled transparently.

    :param fn: The method to register, or ``None`` when used as a decorator factory.
    :param pass_collection: If true, the raw data (which may be a collection) is passed.

    .. versionchanged:: 3.0.0 ``partial`` and ``many`` are always passed as keyword arguments to
        the decorated method.
    .. versionchanged:: 4.0.0 ``pass_many`` is renamed to ``pass_collection``.
    .. versionchanged:: 4.0.0 ``pass_collection`` is a keyword-only argument.
    .. versionchanged:: 4.0.0 ``unknown`` is passed as a keyword argument to the decorated method.
    """
    tagged = set_hook(fn, PRE_LOAD, many=pass_collection)
    return tagged
def post_load(
    fn: typing.Callable[..., typing.Any] | None = None,
    *,
    pass_collection: bool = False,
    pass_original: bool = False,
) -> typing.Callable[..., typing.Any]:
    """Mark a method to run after data is deserialized.

    The method receives the deserialized data and returns the processed data.
    It normally receives one object at a time; the ``many`` argument passed to
    the `Schema <marshmallow.Schema>`'s :func:`~marshmallow.Schema.load` call
    is handled transparently.

    :param fn: The method to register, or ``None`` when used as a decorator factory.
    :param pass_collection: If true, the raw data (which may be a collection) is passed.
    :param pass_original: If true, the original data (before deserializing) is
        passed as an additional argument to the method.

    .. versionchanged:: 3.0.0 ``partial`` and ``many`` are always passed as keyword arguments to
        the decorated method.
    .. versionchanged:: 4.0.0 ``pass_many`` is renamed to ``pass_collection``.
    .. versionchanged:: 4.0.0 ``pass_collection`` and ``pass_original`` are keyword-only arguments.
    .. versionchanged:: 4.0.0 ``unknown`` is passed as a keyword argument to the decorated method.
    """
    tagged = set_hook(
        fn,
        POST_LOAD,
        many=pass_collection,
        pass_original=pass_original,
    )
    return tagged
def set_hook(
    fn: typing.Callable[..., typing.Any] | None,
    tag: str,
    *,
    many: bool = False,
    **kwargs: typing.Any,
) -> typing.Callable[..., typing.Any]:
    """Tag a function as a hook so it can be picked up later.

    You should not need to use this function directly.

    .. note::
        Currently only works with functions and instance methods. Class and
        static methods are not supported.

    :param fn: Function to tag, or ``None`` to get a decorator back.
    :param tag: Hook tag under which the options are recorded.
    :param many: Stored as the first element of the recorded tuple.
    :param kwargs: Extra options stored as the second element of the tuple.
    :return: Decorated function if supplied, else this decorator with its args
        bound.
    """
    if fn is None:
        # Decorator-factory usage: bind the options now, take the function later.
        return functools.partial(set_hook, tag=tag, many=many, **kwargs)

    # The configuration lives on a ``__marshmallow_hook__`` attribute rather
    # than a wrapper object, so the function stays a normal (unbound) method.
    function = typing.cast("MarshmallowHook", fn)
    try:
        hook_config = function.__marshmallow_hook__
    except AttributeError:
        # First time this function is tagged: create the per-tag mapping.
        hook_config = defaultdict(list)
        function.__marshmallow_hook__ = hook_config

    if hook_config is None:
        # An existing attribute explicitly set to None records nothing.
        return fn

    # Append this registration's options, keyed by <tag>.
    hook_config[tag].append((many, kwargs))
    return fn

View File

@@ -0,0 +1,60 @@
"""Utilities for storing collections of error messages.
.. warning::
This module is treated as private API.
Users should not need to use this module directly.
"""
from marshmallow.exceptions import SCHEMA
class ErrorStore:
    """Accumulate error messages into a single nested structure."""

    def __init__(self):
        #: Errors collected so far; starts out as an empty dict
        self.errors = {}

    def store_error(self, messages, field_name=SCHEMA, index=None):
        """Merge ``messages`` into :attr:`errors`.

        :param messages: Error message(s) to store.
        :param field_name: Key to file the messages under. A dict stored under
            the schema-level key is merged with the top-level keys instead of
            being nested.
        :param index: When not ``None``, the messages are additionally nested
            under this index.
        """
        # field error  -> store/merge error messages under field name key
        # schema error -> if string or list, store/merge under _schema key
        #              -> if dict, store/merge with other top-level keys
        is_schema_level_dict = field_name == SCHEMA and isinstance(messages, dict)
        if not is_schema_level_dict:
            messages = {field_name: messages}
        if index is not None:
            messages = {index: messages}
        self.errors = merge_errors(self.errors, messages)
def merge_errors(errors1, errors2):  # noqa: PLR0911
    """Deeply merge two error messages.

    The format of ``errors1`` and ``errors2`` matches the ``message``
    parameter of :exc:`marshmallow.exceptions.ValidationError`.

    If either side is falsy the other is returned as is. Two dicts are merged
    key by key; a dict combined with a non-dict gets the non-dict merged into
    its schema-level key; anything else is combined into a list.
    """
    if not errors1:
        return errors2
    if not errors2:
        return errors1

    first_is_dict = isinstance(errors1, dict)
    second_is_dict = isinstance(errors2, dict)

    if first_is_dict and second_is_dict:
        merged = dict(errors1)
        for key, val in errors2.items():
            merged[key] = merge_errors(merged[key], val) if key in merged else val
        return merged
    if first_is_dict:
        # Non-dict on the right goes under the left dict's schema-level key.
        return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})
    if second_is_dict:
        # Non-dict on the left goes under the right dict's schema-level key.
        return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})

    # Neither side is a dict: build a flat list, keeping left-to-right order.
    if isinstance(errors1, list):
        if isinstance(errors2, list):
            return errors1 + errors2
        return [*errors1, errors2]
    if isinstance(errors2, list):
        return [errors1, *errors2]
    return [errors1, errors2]

View File

@@ -0,0 +1,70 @@
"""Exception classes for marshmallow-related errors."""
from __future__ import annotations
import typing
# Key used for schema-level validation errors; also the default ``field_name``
# of ``ValidationError``.
SCHEMA = "_schema"
class MarshmallowError(Exception):
    """Base class for all marshmallow-related errors.

    Note that ``RegistryError`` in this module derives from :exc:`NameError`
    rather than from this class.
    """
class ValidationError(MarshmallowError):
    """Raised when validation fails on a field or schema.

    Validators and custom fields should raise this exception.

    :param message: An error message, list of error messages, or dict of
        error messages. If a dict, the keys are subitems and the values are error messages.
    :param field_name: Field name to store the error on.
    :param data: Raw input data.
    :param valid_data: Valid (de)serialized data.
    :param kwargs: Any extra keyword arguments; kept on ``self.kwargs``.
    """

    def __init__(
        self,
        message: str | list | dict,
        field_name: str = SCHEMA,
        data: typing.Mapping[str, typing.Any]
        | typing.Iterable[typing.Mapping[str, typing.Any]]
        | None = None,
        valid_data: list[typing.Any] | dict[str, typing.Any] | None = None,
        **kwargs,
    ):
        # A single string (or bytes) message is wrapped in a list; lists and
        # dicts are stored as given.
        if isinstance(message, (str, bytes)):
            self.messages = [message]
        else:
            self.messages = message
        self.field_name = field_name
        self.data = data
        self.valid_data = valid_data
        self.kwargs = kwargs
        super().__init__(message)

    def normalized_messages(self):
        """Return the messages as a dict.

        A dict of schema-level messages is returned unchanged; anything else
        is wrapped as ``{field_name: messages}``.
        """
        schema_level_dict = self.field_name == SCHEMA and isinstance(
            self.messages, dict
        )
        if schema_level_dict:
            return self.messages
        return {self.field_name: self.messages}

    @property
    def messages_dict(self) -> dict[str, typing.Any]:
        """The messages, guaranteed to be a dict.

        :raises TypeError: if ``messages`` is not a dict.
        """
        if isinstance(self.messages, dict):
            return self.messages
        kind = type(self.messages).__name__
        raise TypeError(
            f"cannot access 'messages_dict' when 'messages' is of type {kind}"
        )
class RegistryError(NameError):
    """Raised when an invalid operation is performed on the serializer
    class registry.

    Derives from :exc:`NameError`, so ``except NameError`` also catches it.
    """
class StringNotCollectionError(MarshmallowError, TypeError):
    """Raised when a string is passed when a list of strings is expected.

    Also a :exc:`TypeError`, so it can be caught as either type.
    """
class _FieldInstanceResolutionError(MarshmallowError, TypeError):
    """Raised when an argument is passed to a field class that cannot be resolved to a Field instance.

    Also a :exc:`TypeError`, so it can be caught as either type.
    """

View File

@@ -0,0 +1,5 @@
"""Experimental features.
The features in this subpackage are experimental. Breaking changes may be
introduced in minor marshmallow versions.
"""

View File

@@ -0,0 +1,73 @@
"""Helper API for setting serialization/deserialization context.
Example usage:
.. code-block:: python
import typing
from marshmallow import Schema, fields
from marshmallow.experimental.context import Context
class UserContext(typing.TypedDict):
suffix: str
UserSchemaContext = Context[UserContext]
class UserSchema(Schema):
name_suffixed = fields.Function(
lambda user: user["name"] + UserSchemaContext.get()["suffix"]
)
with UserSchemaContext({"suffix": "bar"}):
print(UserSchema().dump({"name": "foo"}))
# {'name_suffixed': 'foobar'}
"""
from __future__ import annotations
import contextlib
import contextvars
import typing
try:
from types import EllipsisType
except ImportError: # Python<3.10
EllipsisType = type(Ellipsis) # type: ignore[misc]
_ContextT = typing.TypeVar("_ContextT")
_DefaultT = typing.TypeVar("_DefaultT")

# Single context variable shared by every ``Context`` instance and subclass.
_CURRENT_CONTEXT: contextvars.ContextVar = contextvars.ContextVar("context")


class Context(contextlib.AbstractContextManager, typing.Generic[_ContextT]):
    """Context manager for setting and retrieving context.

    :param context: The context to use within the context manager scope.
    """

    def __init__(self, context: _ContextT) -> None:
        self.context = context
        # Filled in by ``__enter__``; needed to restore the previous value.
        self.token: contextvars.Token | None = None

    def __enter__(self) -> Context[_ContextT]:
        """Install ``self.context`` as the current context and return ``self``."""
        token = _CURRENT_CONTEXT.set(self.context)
        self.token = token
        return self

    def __exit__(self, *args, **kwargs) -> None:
        """Restore whatever was current before ``__enter__``."""
        token = typing.cast("contextvars.Token", self.token)
        _CURRENT_CONTEXT.reset(token)

    @classmethod
    def get(cls, default: _DefaultT | EllipsisType = ...) -> _ContextT | _DefaultT:
        """Get the current context.

        :param default: Default value to return if no context is set.
            If not provided and no context is set, a :exc:`LookupError` is raised.
        """
        if default is ...:
            # No fallback supplied: let the ContextVar raise LookupError.
            return _CURRENT_CONTEXT.get()
        return _CURRENT_CONTEXT.get(default)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,89 @@
# OrderedSet
# Copyright (c) 2009 Raymond Hettinger
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.
from collections.abc import MutableSet
class OrderedSet(MutableSet):  # noqa: PLW1641
    """A mutable set that iterates in insertion order.

    Elements are kept in a circular doubly linked list of ``[key, prev, next]``
    nodes anchored by the sentinel ``self.end``; ``self.map`` maps each key to
    its node.
    """

    def __init__(self, iterable=None):
        sentinel = []
        # sentinel node for doubly linked list: points at itself both ways
        sentinel += [None, sentinel, sentinel]
        self.end = sentinel
        self.map = {}  # key --> [key, prev, next]
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        """Append ``key`` to the end unless it is already present."""
        if key in self.map:
            return
        sentinel = self.end
        tail = sentinel[1]
        node = [key, tail, sentinel]
        # Link the new node between the old tail and the sentinel.
        tail[2] = node
        sentinel[1] = node
        self.map[key] = node

    def discard(self, key):
        """Remove ``key`` if present; do nothing otherwise."""
        if key not in self.map:
            return
        _, before, after = self.map.pop(key)
        before[2] = after
        after[1] = before

    def __iter__(self):
        sentinel = self.end
        node = sentinel[2]
        while node is not sentinel:
            yield node[0]
            node = node[2]

    def __reversed__(self):
        sentinel = self.end
        node = sentinel[1]
        while node is not sentinel:
            yield node[0]
            node = node[1]

    def pop(self, last=True):
        """Remove and return the last element, or the first if ``last`` is false.

        :raises KeyError: if the set is empty.
        """
        if not self:
            raise KeyError("set is empty")
        node = self.end[1] if last else self.end[2]
        key = node[0]
        self.discard(key)
        return key

    def __repr__(self):
        name = self.__class__.__name__
        if self:
            return f"{name}({list(self)!r})"
        return f"{name}()"

    def __eq__(self, other):
        # Order matters only when comparing against another OrderedSet.
        if not isinstance(other, OrderedSet):
            return set(self) == set(other)
        return len(self) == len(other) and list(self) == list(other)
# Demo run only when the module is executed as a script: prints the union,
# intersection and difference of two sample sets.
if __name__ == "__main__":
    s = OrderedSet("abracadaba")
    t = OrderedSet("simsalabim")
    print(s | t)
    print(s & t)
    print(s - t)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,41 @@
"""Type aliases.
.. warning::
This module is provisional. Types may be modified, added, and removed between minor releases.
"""
from __future__ import annotations
import typing
#: A type that can be either a sequence of strings or a set of strings
StrSequenceOrSet: typing.TypeAlias = typing.Sequence[str] | typing.AbstractSet[str]

#: Type for validator functions: a callable taking one value
Validator: typing.TypeAlias = typing.Callable[[typing.Any], typing.Any]

#: A valid option for the ``unknown`` schema option and argument
UnknownOption: typing.TypeAlias = typing.Literal["exclude", "include", "raise"]
class SchemaValidator(typing.Protocol):
    """Protocol for a callable taking ``output`` and optional ``original_data``.

    ``partial``, ``unknown`` and ``many`` are keyword-only; the call returns ``None``.
    """

    def __call__(
        self,
        output: typing.Any,
        original_data: typing.Any = ...,
        *,
        partial: bool | StrSequenceOrSet | None = None,
        unknown: UnknownOption | None = None,
        many: bool = False,
    ) -> None: ...
class RenderModule(typing.Protocol):
    """Protocol for an object exposing ``dumps`` and ``loads``.

    ``dumps`` turns an object into a ``str``; ``loads`` accepts ``str``,
    ``bytes`` or ``bytearray`` and returns the decoded object.
    """

    def dumps(
        self, obj: typing.Any, *args: typing.Any, **kwargs: typing.Any
    ) -> str: ...

    def loads(
        self, s: str | bytes | bytearray, *args: typing.Any, **kwargs: typing.Any
    ) -> typing.Any: ...

View File

@@ -0,0 +1,164 @@
"""Utility methods for marshmallow."""
from __future__ import annotations
import datetime as dt
import inspect
import typing
from collections.abc import Mapping, Sequence
from marshmallow.constants import missing
def is_generator(obj) -> typing.TypeGuard[typing.Generator]:
    """Return True if ``obj`` is a generator function or a generator object."""
    if inspect.isgeneratorfunction(obj):
        return True
    return inspect.isgenerator(obj)
def is_iterable_but_not_string(obj) -> typing.TypeGuard[typing.Iterable]:
    """Return True if ``obj`` is an iterable object that isn't a string.

    Anything that has a ``strip`` attribute is treated as a string. Generators
    and generator functions always count as iterable.
    """
    iterable_non_string = hasattr(obj, "__iter__") and not hasattr(obj, "strip")
    return iterable_non_string or is_generator(obj)
def is_sequence_but_not_string(obj) -> typing.TypeGuard[Sequence]:
    """Return True if ``obj`` is a sequence that is neither ``str`` nor ``bytes``."""
    if isinstance(obj, (str, bytes)):
        return False
    return isinstance(obj, Sequence)
def is_collection(obj) -> typing.TypeGuard[typing.Iterable]:
    """Return True if ``obj`` is a collection type, e.g list, tuple, queryset.

    Mappings and strings are not considered collections.
    """
    if not is_iterable_but_not_string(obj):
        return False
    return not isinstance(obj, Mapping)
# https://stackoverflow.com/a/27596917
def is_aware(datetime: dt.datetime) -> bool:
    """Return True if ``datetime`` has a tzinfo that yields a UTC offset."""
    tz = datetime.tzinfo
    if tz is None:
        return False
    return tz.utcoffset(datetime) is not None
def from_timestamp(value: typing.Any) -> dt.datetime:
    """Convert a POSIX timestamp in seconds to a naive UTC ``datetime``.

    :param value: Anything ``float()`` accepts, except booleans.
    :return: A naive datetime holding the UTC wall-clock time.
    :raises ValueError: if ``value`` is a boolean, is negative, or cannot be
        converted to a datetime.
    """
    # ``float(True)`` is 1.0, so booleans must be rejected before conversion.
    if value is True or value is False:
        raise ValueError("Not a valid POSIX timestamp")
    value = float(value)
    if value < 0:
        raise ValueError("Not a valid POSIX timestamp")

    # Load a timestamp with utc as timezone to prevent using system timezone.
    # Then set timezone to None, to let the Field handle adding timezone info.
    try:
        return dt.datetime.fromtimestamp(value, tz=dt.timezone.utc).replace(tzinfo=None)
    except OverflowError as exc:
        raise ValueError("Timestamp is too large") from exc
    except OSError as exc:
        raise ValueError("Error converting value to datetime") from exc


def from_timestamp_ms(value: typing.Any) -> dt.datetime:
    """Convert a POSIX timestamp in milliseconds to a naive UTC ``datetime``.

    :param value: Anything ``float()`` accepts, except booleans.
    :raises ValueError: under the same conditions as :func:`from_timestamp`.
    """
    # Reject booleans here as well: once converted to float they would slip
    # past the boolean check in ``from_timestamp`` and be read as 1 ms / 0 ms.
    if value is True or value is False:
        raise ValueError("Not a valid POSIX timestamp")
    value = float(value)
    return from_timestamp(value / 1000)
def timestamp(
    value: dt.datetime,
) -> float:
    """Return the POSIX timestamp (seconds) of ``value``.

    Naive datetimes are interpreted as UTC rather than as system local time.
    """
    if is_aware(value):
        aware = value
    else:
        # When a date is naive, use UTC as zone info to prevent using system timezone.
        aware = value.replace(tzinfo=dt.timezone.utc)
    return aware.timestamp()
def timestamp_ms(value: dt.datetime) -> float:
    """Return the POSIX timestamp of ``value`` in milliseconds."""
    seconds = timestamp(value)
    return seconds * 1000
def ensure_text_type(val: str | bytes) -> str:
    """Return ``val`` as ``str``, decoding ``bytes`` as UTF-8 first."""
    text = val.decode("utf-8") if isinstance(val, bytes) else val
    return str(text)
def pluck(dictlist: list[dict[str, typing.Any]], key: str):
    """Collect the value stored under ``key`` from every dictionary in a list.

    ::

        >>> dlist = [{'id': 1, 'name': 'foo'}, {'id': 2, 'name': 'bar'}]
        >>> pluck(dlist, 'id')
        [1, 2]

    :raises KeyError: if any dictionary lacks ``key``.
    """
    values = []
    for entry in dictlist:
        values.append(entry[key])
    return values
# Various utilities for pulling keyed values from objects
def get_value(obj, key: int | str, default=missing):
    """Helper for pulling a keyed value off various types of objects. Fields use
    this method by default to access attributes of the source object. For object `x`
    and attribute `i`, this method first tries to access `x[i]`, and then falls back to
    `x.i` if an exception is raised.

    A string key containing ``.`` is split on the dots and resolved one
    segment at a time.

    .. warning::
        If an object `x` does not raise an exception when `x[i]` does not exist,
        `get_value` will never check the value `x.i`. Consider overriding
        `marshmallow.fields.Field.get_value` in this case.
    """
    # Integer keys and dot-free strings are a single lookup.
    if isinstance(key, int) or "." not in key:
        return _get_value_for_key(obj, key, default)
    segments = key.split(".")
    return _get_value_for_keys(obj, segments, default)
def _get_value_for_keys(obj, keys, default):
    """Resolve ``keys`` one after another, starting from ``obj``.

    Each lookup is applied to the result of the previous one; ``default`` is
    the fallback for every step.
    """
    current = _get_value_for_key(obj, keys[0], default)
    for key in keys[1:]:
        current = _get_value_for_key(current, key, default)
    return current
def _get_value_for_key(obj, key, default):
    """Return ``obj[key]``, falling back to ``getattr(obj, key, default)``.

    Item access is attempted only when ``obj`` defines ``__getitem__``; if it
    raises one of the handled lookup errors, attribute access is used instead.
    """
    if hasattr(obj, "__getitem__"):
        try:
            return obj[key]
        except (KeyError, IndexError, TypeError, AttributeError):
            return getattr(obj, key, default)
    return getattr(obj, key, default)
def set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):
    """Set a value in a dict. If `key` contains a '.', it is assumed
    to be a path (i.e. dot-delimited string) to the value's location.

    ::

        >>> d = {}
        >>> set_value(d, 'foo.bar', 42)
        >>> d
        {'foo': {'bar': 42}}

    :raises ValueError: if a segment of the path already holds a non-dict value.
    """
    node = dct
    remaining = key
    # Walk down one path segment at a time, creating dicts where missing.
    while "." in remaining:
        head, rest = remaining.split(".", 1)
        child = node.setdefault(head, {})
        if not isinstance(child, dict):
            raise ValueError(
                f"Cannot set {remaining} in {head} due to existing value: {child}"
            )
        node, remaining = child, rest
    node[remaining] = value
def callable_or_raise(obj):
    """Return ``obj`` if it is callable, else raise a :exc:`TypeError`."""
    if callable(obj):
        return obj
    raise TypeError(f"Object {obj!r} is not callable.")
def timedelta_to_microseconds(value: dt.timedelta) -> int:
    """Compute the total microseconds of a timedelta.

    Uses integer arithmetic on the ``days``, ``seconds`` and ``microseconds``
    components.

    https://github.com/python/cpython/blob/v3.13.1/Lib/_pydatetime.py#L805-L807
    """
    whole_seconds = value.days * (24 * 3600) + value.seconds
    return whole_seconds * 1000000 + value.microseconds

View File

@@ -0,0 +1,686 @@
"""Validation classes for various types of data."""
from __future__ import annotations
import re
import typing
from abc import ABC, abstractmethod
from itertools import zip_longest
from operator import attrgetter
from marshmallow.exceptions import ValidationError
if typing.TYPE_CHECKING:
from marshmallow import types
_T = typing.TypeVar("_T")
class Validator(ABC):
    """Abstract base class for validators.

    .. note::
        This class does not provide any validation behavior. It is only used to
        add a useful `__repr__` implementation for validators.
    """

    error: str | None = None

    def __repr__(self) -> str:
        described = self._repr_args()
        # Only add the separator when the subclass reports some arguments.
        prefix = f"{described}, " if described else ""
        return f"<{self.__class__.__name__}({prefix}error={self.error!r})>"

    def _repr_args(self) -> str:
        """A string representation of the args passed to this validator. Used by
        `__repr__`. The base implementation returns an empty string.
        """
        return ""

    @abstractmethod
    def __call__(self, value: typing.Any) -> typing.Any: ...
class And(Validator):
    """Compose multiple validators and combine their error messages.

    Every validator is run, even after one has failed; all collected messages
    are then raised together in a single error.

    Example: ::

        from marshmallow import validate, ValidationError


        def is_even(value):
            if value % 2 != 0:
                raise ValidationError("Not an even value.")


        validator = validate.And(validate.Range(min=0), is_even)
        validator(-1)
        # ValidationError: ['Must be greater than or equal to 0.', 'Not an even value.']

    :param validators: Validators to combine.
    """

    def __init__(self, *validators: types.Validator):
        self.validators = tuple(validators)

    def _repr_args(self) -> str:
        return f"validators={self.validators!r}"

    def __call__(self, value: typing.Any) -> typing.Any:
        collected: list[str | dict] = []
        extra: dict[str, typing.Any] = {}
        for check in self.validators:
            try:
                check(value)
            except ValidationError as failure:
                extra.update(failure.kwargs)
                messages = failure.messages
                if not isinstance(messages, dict):
                    # A list of messages is flattened into the result.
                    collected.extend(messages)
                else:
                    # A dict of messages is kept as one entry.
                    collected.append(messages)
        if not collected:
            return value
        raise ValidationError(collected, **extra)
class URL(Validator):
    """Validate a URL.

    :param relative: Whether to allow relative URLs.
    :param absolute: Whether to allow absolute URLs.
    :param error: Error message to raise in case of a validation error.
        Can be interpolated with `{input}`.
    :param schemes: Valid schemes. By default, ``http``, ``https``,
        ``ftp``, and ``ftps`` are allowed.
    :param require_tld: Whether to reject non-FQDN hostnames.
    """

    class RegexMemoizer:
        """Build and cache one compiled regex per combination of options."""

        def __init__(self):
            # Maps (relative, absolute, require_tld) to the compiled pattern.
            self._memoized = {}

        def _regex_generator(
            self, *, relative: bool, absolute: bool, require_tld: bool
        ) -> typing.Pattern:
            """Assemble and compile the regex for the given options."""
            hostname_variants = [
                # a normal domain name, expressed in [A-Z0-9] chars with hyphens allowed only in the middle
                # note that the regex will be compiled with IGNORECASE, so these are upper and lowercase chars
                (
                    r"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+"
                    r"(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)"
                ),
                # or the special string 'localhost'
                r"localhost",
                # or IPv4
                r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}",
                # or IPv6
                r"\[[A-F0-9]*:[A-F0-9:]+\]",
            ]
            if not require_tld:
                # allow dotless hostnames
                hostname_variants.append(r"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.?)")

            absolute_part = "".join(
                (
                    # scheme (e.g. 'https://', 'ftp://', etc)
                    # this is validated separately against allowed schemes, so in the regex
                    # we simply want to capture its existence
                    r"(?:[a-z0-9\.\-\+]*)://",
                    # userinfo, for URLs encoding authentication
                    # e.g. 'ftp://foo:bar@ftp.example.org/'
                    r"(?:(?:[a-z0-9\-._~!$&'()*+,;=:]|%[0-9a-f]{2})*@)?",
                    # netloc, the hostname/domain part of the URL plus the optional port
                    r"(?:",
                    "|".join(hostname_variants),
                    r")",
                    r"(?::\d+)?",
                )
            )
            relative_part = r"(?:/?|[/?]\S+)\Z"

            # When both forms are allowed the absolute prefix becomes optional;
            # otherwise only the permitted form is part of the pattern.
            if relative:
                if absolute:
                    parts: tuple[str, ...] = (
                        r"^(",
                        absolute_part,
                        r")?",
                        relative_part,
                    )
                else:
                    parts = (r"^", relative_part)
            else:
                parts = (r"^", absolute_part, relative_part)

            return re.compile("".join(parts), re.IGNORECASE)

        def __call__(
            self, *, relative: bool, absolute: bool, require_tld: bool
        ) -> typing.Pattern:
            """Return the cached pattern for these options, compiling it on first use."""
            key = (relative, absolute, require_tld)
            if key not in self._memoized:
                self._memoized[key] = self._regex_generator(
                    relative=relative, absolute=absolute, require_tld=require_tld
                )

            return self._memoized[key]

    # Class-level memoizer, so compiled patterns are shared by all instances.
    _regex = RegexMemoizer()

    default_message = "Not a valid URL."
    default_schemes = {"http", "https", "ftp", "ftps"}

    def __init__(
        self,
        *,
        relative: bool = False,
        absolute: bool = True,
        schemes: types.StrSequenceOrSet | None = None,
        require_tld: bool = True,
        error: str | None = None,
    ):
        """Store the options.

        :raises ValueError: if both ``relative`` and ``absolute`` are false.
        """
        if not relative and not absolute:
            raise ValueError(
                "URL validation cannot set both relative and absolute to False."
            )
        self.relative = relative
        self.absolute = absolute
        # A falsy ``error`` or ``schemes`` falls back to the class default.
        self.error: str = error or self.default_message
        self.schemes = schemes or self.default_schemes
        self.require_tld = require_tld

    def _repr_args(self) -> str:
        return f"relative={self.relative!r}, absolute={self.absolute!r}"

    def _format_error(self, value) -> str:
        """Interpolate ``value`` into the error message as ``{input}``."""
        return self.error.format(input=value)

    def __call__(self, value: str) -> str:
        """Return ``value`` unchanged if it is a valid URL.

        :raises ValidationError: if ``value`` is empty, uses a scheme that is
            not allowed, or does not match the URL pattern.
        """
        message = self._format_error(value)
        if not value:
            raise ValidationError(message)

        # Check first if the scheme is valid
        scheme = None
        if "://" in value:
            scheme = value.split("://")[0].lower()
            if scheme not in self.schemes:
                raise ValidationError(message)

        regex = self._regex(
            relative=self.relative, absolute=self.absolute, require_tld=self.require_tld
        )

        # Hostname is optional for file URLS. If absent it means `localhost`.
        # Fill it in for the validation if needed
        if scheme == "file" and value.startswith("file:///"):
            matched = regex.search(value.replace("file:///", "file://localhost/", 1))
        else:
            matched = regex.search(value)

        if not matched:
            raise ValidationError(message)

        return value
class Email(Validator):
    """Validate an email address.

    The address is split at its last ``@``. The local part must match
    ``USER_REGEX``; the domain must be listed in ``DOMAIN_WHITELIST`` or match
    ``DOMAIN_REGEX``, either as given or after IDNA encoding.

    :param error: Error message to raise in case of a validation error. Can be
        interpolated with `{input}`.
    """

    USER_REGEX = re.compile(
        r"(^[-!#$%&'*+/=?^`{}|~\w]+(\.[-!#$%&'*+/=?^`{}|~\w]+)*\Z"  # dot-atom
        # quoted-string
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]'
        r'|\\[\001-\011\013\014\016-\177])*"\Z)',
        re.IGNORECASE | re.UNICODE,
    )

    DOMAIN_REGEX = re.compile(
        # domain
        r"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+"
        r"(?:[A-Z]{2,6}|[A-Z0-9-]{2,})\Z"
        # literal form, ipv4 address (SMTP 4.1.3)
        r"|^\[(25[0-5]|2[0-4]\d|[0-1]?\d?\d)"
        r"(\.(25[0-5]|2[0-4]\d|[0-1]?\d?\d)){3}\]\Z",
        re.IGNORECASE | re.UNICODE,
    )

    # Domains accepted without being checked against DOMAIN_REGEX.
    DOMAIN_WHITELIST = ("localhost",)

    default_message = "Not a valid email address."

    def __init__(self, *, error: str | None = None):
        self.error: str = error if error else self.default_message

    def _format_error(self, value: str) -> str:
        """Fill the ``{input}`` placeholder of the error template."""
        template = self.error
        return template.format(input=value)

    def __call__(self, value: str) -> str:
        """Return ``value`` if it is a valid address.

        :raises ValidationError: If ``value`` is empty, has no ``@``, or has
            an invalid local part or domain.
        """
        message = self._format_error(value)
        if not value or "@" not in value:
            raise ValidationError(message)

        # Split on the last "@" so a quoted local part may contain one.
        user_part, domain_part = value.rsplit("@", 1)
        if not self.USER_REGEX.match(user_part):
            raise ValidationError(message)

        if domain_part in self.DOMAIN_WHITELIST:
            return value
        if self.DOMAIN_REGEX.match(domain_part):
            return value

        # Last resort: the domain may be internationalized, so retry with its
        # IDNA (ASCII) form. A domain that cannot be encoded is invalid.
        try:
            ascii_domain = domain_part.encode("idna").decode("ascii")
        except UnicodeError:
            ascii_domain = None
        if ascii_domain is not None and self.DOMAIN_REGEX.match(ascii_domain):
            return value
        raise ValidationError(message)
class Range(Validator):
    """Validator which succeeds if the value passed to it lies within a range.

    A bound that is omitted or given as `None` is not checked. Each bound is
    inclusive unless its ``*_inclusive`` flag is set to `False`.

    :param min: The minimum value (lower bound). If not provided, minimum
        value will not be checked.
    :param max: The maximum value (upper bound). If not provided, maximum
        value will not be checked.
    :param min_inclusive: Whether the `min` bound is included in the range.
    :param max_inclusive: Whether the `max` bound is included in the range.
    :param error: Error message to raise in case of a validation error.
        Can be interpolated with `{input}`, `{min}` and `{max}`.
    """

    # Templates are formatted twice: once in ``__init__`` to insert the
    # comparison wording, and once per error to insert the bounds (hence the
    # doubled braces around ``min`` and ``max``).
    message_min = "Must be {min_op} {{min}}."
    message_max = "Must be {max_op} {{max}}."
    message_all = "Must be {min_op} {{min}} and {max_op} {{max}}."

    message_gte = "greater than or equal to"
    message_gt = "greater than"
    message_lte = "less than or equal to"
    message_lt = "less than"

    def __init__(
        self,
        min=None,  # noqa: A002
        max=None,  # noqa: A002
        *,
        min_inclusive: bool = True,
        max_inclusive: bool = True,
        error: str | None = None,
    ):
        self.min, self.max = min, max
        self.error = error
        self.min_inclusive = min_inclusive
        self.max_inclusive = max_inclusive

        # Pick the comparison wording matching each bound's inclusivity, then
        # bake it into per-instance copies of the class-level templates.
        lower_op = self.message_gte if min_inclusive else self.message_gt
        upper_op = self.message_lte if max_inclusive else self.message_lt
        self.message_min = self.message_min.format(min_op=lower_op)
        self.message_max = self.message_max.format(max_op=upper_op)
        self.message_all = self.message_all.format(min_op=lower_op, max_op=upper_op)

    def _repr_args(self) -> str:
        names = ("min", "max", "min_inclusive", "max_inclusive")
        return ", ".join(f"{name}={getattr(self, name)!r}" for name in names)

    def _format_error(self, value: _T, message: str) -> str:
        """Format the custom ``error`` if set, otherwise ``message``."""
        template = self.error or message
        return template.format(input=value, min=self.min, max=self.max)

    def __call__(self, value: _T) -> _T:
        """Return ``value`` if it lies within the configured bounds.

        :raises ValidationError: If ``value`` falls outside either bound.
        """
        lower, upper = self.min, self.max

        if lower is not None:
            below = value < lower if self.min_inclusive else value <= lower
            if below:
                # Mention both bounds when both are configured.
                template = self.message_all if upper is not None else self.message_min
                raise ValidationError(self._format_error(value, template))

        if upper is not None:
            above = value > upper if self.max_inclusive else value >= upper
            if above:
                template = self.message_all if lower is not None else self.message_max
                raise ValidationError(self._format_error(value, template))

        return value
# Type variable bound to ``typing.Sized``; used in ``Length``'s annotations.
_SizedT = typing.TypeVar("_SizedT", bound=typing.Sized)
class Length(Validator):
    """Validator which succeeds if the length of the value passed to it
    lies between a minimum and a maximum. Relies on ``len()``, so it works
    for strings, lists, or anything else with a length.

    :param min: The minimum length. If not provided, minimum length
        will not be checked.
    :param max: The maximum length. If not provided, maximum length
        will not be checked.
    :param equal: The exact length. If provided, maximum and minimum
        length will not be checked.
    :param error: Error message to raise in case of a validation error.
        Can be interpolated with `{input}`, `{min}` and `{max}`.
    """

    message_min = "Shorter than minimum length {min}."
    message_max = "Longer than maximum length {max}."
    message_all = "Length must be between {min} and {max}."
    message_equal = "Length must be {equal}."

    def __init__(
        self,
        min: int | None = None,  # noqa: A002
        max: int | None = None,  # noqa: A002
        *,
        equal: int | None = None,
        error: str | None = None,
    ):
        # ``equal`` excludes the range bounds; a bound only counts as given
        # here when it is truthy.
        if equal is not None and (min or max):
            raise ValueError(
                "The `equal` parameter was provided, maximum or "
                "minimum parameter must not be provided."
            )

        self.min, self.max = min, max
        self.equal = equal
        self.error = error

    def _repr_args(self) -> str:
        names = ("min", "max", "equal")
        return ", ".join(f"{name}={getattr(self, name)!r}" for name in names)

    def _format_error(self, value: _SizedT, message: str) -> str:
        """Format the custom ``error`` if set, otherwise ``message``."""
        template = self.error or message
        return template.format(
            input=value, min=self.min, max=self.max, equal=self.equal
        )

    def __call__(self, value: _SizedT) -> _SizedT:
        """Return ``value`` if its length satisfies the configured limits.

        :raises ValidationError: If the length differs from ``equal``, or
            falls outside ``min``/``max``.
        """
        size = len(value)

        # An exact length takes precedence; the bounds are not consulted.
        if self.equal is not None:
            if size == self.equal:
                return value
            raise ValidationError(self._format_error(value, self.message_equal))

        if self.min is not None and size < self.min:
            # Mention both bounds when both are configured.
            template = self.message_all if self.max is not None else self.message_min
            raise ValidationError(self._format_error(value, template))

        if self.max is not None and size > self.max:
            template = self.message_all if self.min is not None else self.message_max
            raise ValidationError(self._format_error(value, template))

        return value
class Equal(Validator):
    """Validator which succeeds if the ``value`` passed to it is
    equal to ``comparable``.

    :param comparable: The object to compare to.
    :param error: Error message to raise in case of a validation error.
        Can be interpolated with `{input}` and `{other}`.
    """

    default_message = "Must be equal to {other}."

    def __init__(self, comparable, *, error: str | None = None):
        self.comparable = comparable
        self.error: str = error if error else self.default_message

    def _repr_args(self) -> str:
        return "comparable={!r}".format(self.comparable)

    def _format_error(self, value: _T) -> str:
        """Fill the ``{input}`` and ``{other}`` placeholders of the template."""
        template = self.error
        return template.format(input=value, other=self.comparable)

    def __call__(self, value: _T) -> _T:
        """Return ``value`` unless it compares unequal to ``comparable``.

        :raises ValidationError: If ``value != comparable``.
        """
        differs = value != self.comparable
        if differs:
            raise ValidationError(self._format_error(value))
        return value
class Regexp(Validator):
    """Validator which succeeds if the ``value`` matches ``regex``.

    .. note::

        Uses `re.match`, which searches for a match at the beginning of a string.

    :param regex: The regular expression string to use. Can also be a compiled
        regular expression pattern.
    :param flags: The regexp flags to use, for example re.IGNORECASE. Ignored
        if ``regex`` is not a string.
    :param error: Error message to raise in case of a validation error.
        Can be interpolated with `{input}` and `{regex}`.
    """

    default_message = "String does not match expected pattern."

    def __init__(
        self,
        regex: str | bytes | typing.Pattern,
        flags: int = 0,
        *,
        error: str | None = None,
    ):
        # Only raw patterns are compiled; anything else is stored as given,
        # in which case ``flags`` is not used.
        if isinstance(regex, (str, bytes)):
            regex = re.compile(regex, flags)
        self.regex = regex
        self.error: str = error if error else self.default_message

    def _repr_args(self) -> str:
        return "regex={!r}".format(self.regex)

    def _format_error(self, value: str | bytes) -> str:
        """Fill the ``{input}`` and ``{regex}`` placeholders of the template."""
        template = self.error
        return template.format(input=value, regex=self.regex.pattern)

    @typing.overload
    def __call__(self, value: str) -> str: ...

    @typing.overload
    def __call__(self, value: bytes) -> bytes: ...

    def __call__(self, value):
        """Return ``value`` if the pattern matches at its beginning.

        :raises ValidationError: If ``regex.match(value)`` finds no match.
        """
        if self.regex.match(value) is not None:
            return value
        raise ValidationError(self._format_error(value))
class Predicate(Validator):
    """Call the specified ``method`` of the ``value`` object. The
    validator succeeds if the invoked method returns an object that
    evaluates to True in a Boolean context. Any additional keyword
    argument will be passed to the method.

    :param method: The name of the method to invoke.
    :param error: Error message to raise in case of a validation error.
        Can be interpolated with `{input}` and `{method}`.
    :param kwargs: Additional keyword arguments to pass to the method.
    """

    default_message = "Invalid input."

    def __init__(self, method: str, *, error: str | None = None, **kwargs):
        self.method = method
        self.kwargs = kwargs
        self.error: str = error if error else self.default_message

    def _repr_args(self) -> str:
        return "method={!r}, kwargs={!r}".format(self.method, self.kwargs)

    def _format_error(self, value: typing.Any) -> str:
        """Fill the ``{input}`` and ``{method}`` placeholders of the template."""
        template = self.error
        return template.format(input=value, method=self.method)

    def __call__(self, value: _T) -> _T:
        """Return ``value`` if its named method returns a truthy result.

        :raises ValidationError: If the method's result is falsy.
        """
        bound_method = getattr(value, self.method)
        outcome = bound_method(**self.kwargs)
        if outcome:
            return value
        raise ValidationError(self._format_error(value))
class NoneOf(Validator):
    """Validator which fails if ``value`` is a member of ``iterable``.

    :param iterable: A sequence of invalid values. A one-shot iterator (such
        as a generator) is copied into a tuple so that it can be consulted on
        every validation.
    :param error: Error message to raise in case of a validation error. Can be
        interpolated using `{input}` and `{values}`.
    """

    default_message = "Invalid input."

    def __init__(self, iterable: typing.Iterable, *, error: str | None = None):
        # An iterator returns itself from ``iter()`` and can be walked only
        # once. Building ``values_text`` below would drain it, leaving the
        # membership test in ``__call__`` with nothing to check against, so
        # every value would pass. Snapshot it first.
        if iter(iterable) is iterable:
            iterable = tuple(iterable)
        self.iterable = iterable
        self.values_text = ", ".join(str(each) for each in self.iterable)
        self.error: str = error or self.default_message

    def _repr_args(self) -> str:
        return f"iterable={self.iterable!r}"

    def _format_error(self, value) -> str:
        """Fill the ``{input}`` and ``{values}`` placeholders of the template."""
        return self.error.format(input=value, values=self.values_text)

    def __call__(self, value: typing.Any) -> typing.Any:
        """Return ``value`` unless it is one of the forbidden values.

        :raises ValidationError: If ``value`` is a member of ``iterable``.
        """
        # A value whose membership cannot be tested (e.g. an unhashable value
        # against a set) is deliberately treated as "not a member". Only the
        # membership test is guarded, so a TypeError raised while formatting
        # the error message is not swallowed.
        try:
            forbidden = value in self.iterable
        except TypeError:
            forbidden = False
        if forbidden:
            raise ValidationError(self._format_error(value))
        return value
class OneOf(Validator):
    """Validator which succeeds if ``value`` is a member of ``choices``.

    :param choices: A sequence of valid values. A one-shot iterator (such as
        a generator) is copied into a tuple so that it can be consulted on
        every validation.
    :param labels: Optional sequence of labels to pair with the choices.
        One-shot iterators are copied into a tuple as well.
    :param error: Error message to raise in case of a validation error. Can be
        interpolated with `{input}`, `{choices}` and `{labels}`.
    """

    default_message = "Must be one of: {choices}."

    def __init__(
        self,
        choices: typing.Iterable,
        labels: typing.Iterable[str] | None = None,
        *,
        error: str | None = None,
    ):
        # An iterator returns itself from ``iter()`` and can be walked only
        # once. Building the ``*_text`` strings below would drain it, after
        # which membership tests would always fail and ``options()`` would
        # yield nothing. Snapshot iterators first.
        if iter(choices) is choices:
            choices = tuple(choices)
        if labels is None:
            labels = []
        elif iter(labels) is labels:
            labels = tuple(labels)

        self.choices = choices
        self.choices_text = ", ".join(str(choice) for choice in self.choices)
        self.labels = labels
        self.labels_text = ", ".join(str(label) for label in self.labels)
        self.error: str = error or self.default_message

    def _repr_args(self) -> str:
        return f"choices={self.choices!r}, labels={self.labels!r}"

    def _format_error(self, value) -> str:
        """Fill the ``{input}``, ``{choices}`` and ``{labels}`` placeholders."""
        return self.error.format(
            input=value, choices=self.choices_text, labels=self.labels_text
        )

    def __call__(self, value: typing.Any) -> typing.Any:
        """Return ``value`` if it is one of the configured choices.

        :raises ValidationError: If ``value`` is not a member of ``choices``,
            or if membership cannot be tested (``TypeError``).
        """
        # Only the membership test is guarded, so a TypeError raised while
        # formatting the message is not mistaken for an untestable value.
        try:
            allowed = value in self.choices
        except TypeError as error:
            raise ValidationError(self._format_error(value)) from error
        if not allowed:
            raise ValidationError(self._format_error(value))
        return value

    def options(
        self,
        valuegetter: str | typing.Callable[[typing.Any], typing.Any] = str,
    ) -> typing.Iterable[tuple[typing.Any, str]]:
        """Return a generator over the (value, label) pairs, where value
        is a string associated with each choice. This convenience method
        is useful to populate, for instance, a form select field.

        Choices and labels are paired positionally; whichever runs out first
        is padded with the empty string.

        :param valuegetter: Can be a callable or a string. In the former case, it must
            be a one-argument callable which returns the value of a
            choice. In the latter case, the string specifies the name
            of an attribute of the choice objects. Defaults to `str()`.
        """
        valuegetter = valuegetter if callable(valuegetter) else attrgetter(valuegetter)
        pairs = zip_longest(self.choices, self.labels, fillvalue="")

        return ((valuegetter(choice), label) for choice, label in pairs)
class ContainsOnly(OneOf):
    """Validator which succeeds if ``value`` is a sequence and each element
    in the sequence is also in the sequence passed as ``choices``. Empty input
    is considered valid.

    :param choices: Same as :class:`OneOf`.
    :param labels: Same as :class:`OneOf`.
    :param error: Same as :class:`OneOf`.

    .. versionchanged:: 3.0.0b2
        Duplicate values are considered valid.
    .. versionchanged:: 3.0.0b2
        Empty input is considered valid. Use `validate.Length(min=1) <marshmallow.validate.Length>`
        to validate against empty inputs.
    """

    default_message = "One or more of the choices you made was not in: {choices}."

    def _format_error(self, value) -> str:
        """Format the error with the input elements joined by ``", "``."""
        joined_input = ", ".join(map(str, value))
        return super()._format_error(joined_input)

    def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:
        """Return ``value`` if every element is one of the choices.

        :raises ValidationError: If any element is not in ``choices``.
        """
        # Elements are tested one by one rather than via set.issubset, which
        # cannot handle unhashable types.
        if any(item not in self.choices for item in value):
            raise ValidationError(self._format_error(value))
        return value
class ContainsNoneOf(NoneOf):
    """Validator which fails if ``value`` is a sequence and any element
    in the sequence is a member of the sequence passed as ``iterable``. Empty input
    is considered valid.

    :param iterable: Same as :class:`NoneOf`.
    :param error: Same as :class:`NoneOf`.

    .. versionadded:: 3.6.0
    """

    default_message = "One or more of the choices you made was in: {values}."

    def _format_error(self, value) -> str:
        """Format the error with the input elements joined by ``", "``."""
        joined_input = ", ".join(map(str, value))
        return super()._format_error(joined_input)

    def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:
        """Return ``value`` if none of its elements is a forbidden value.

        :raises ValidationError: If any element is in ``iterable``.
        """
        if any(item in self.iterable for item in value):
            raise ValidationError(self._format_error(value))
        return value