This commit is contained in:
Iliyan Angelov
2025-09-19 11:58:53 +03:00
parent 306b20e24a
commit 6b247e5b9f
11423 changed files with 1500615 additions and 778 deletions

View File

@@ -0,0 +1,53 @@
from django.urls import get_script_prefix, resolve
def get_breadcrumbs(url, request=None):
"""
Given a url returns a list of breadcrumbs, which are each a
tuple of (name, url).
"""
from rest_framework.reverse import preserve_builtin_query_params
from rest_framework.views import APIView
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
"""
Add tuples of (name, url) to the breadcrumbs list,
progressively chomping off parts of the url.
"""
try:
(view, unused_args, unused_kwargs) = resolve(url)
except Exception:
pass
else:
# Check if this is a REST framework view,
# and if so add it to the breadcrumbs
cls = getattr(view, 'cls', None)
initkwargs = getattr(view, 'initkwargs', {})
if cls is not None and issubclass(cls, APIView):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
c = cls(**initkwargs)
name = c.get_view_name()
insert_url = preserve_builtin_query_params(prefix + url, request)
breadcrumbs_list.insert(0, (name, insert_url))
seen.append(view)
if url == '':
# All done
return breadcrumbs_list
elif url.endswith('/'):
# Drop trailing slash off the end and continue to try to
# resolve more breadcrumbs
url = url.rstrip('/')
return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
# Drop trailing non-slash off the end and continue to try to
# resolve more breadcrumbs
url = url[:url.rfind('/') + 1]
return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/')
url = url[len(prefix):]
return breadcrumbs_recursive(url, [], prefix, [])

View File

@@ -0,0 +1,78 @@
"""
Helper classes for parsers.
"""
import contextlib
import datetime
import decimal
import json # noqa
import uuid
from django.db.models.query import QuerySet
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.functional import Promise
from rest_framework.compat import coreapi
class JSONEncoder(json.JSONEncoder):
"""
JSONEncoder subclass that knows how to encode date/time/timedelta,
decimal types, generators and other basic python objects.
"""
def default(self, obj):
# For Date Time string spec, see ECMA 262
# https://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
if isinstance(obj, Promise):
return force_str(obj)
elif isinstance(obj, datetime.datetime):
representation = obj.isoformat()
if representation.endswith('+00:00'):
representation = representation[:-6] + 'Z'
return representation
elif isinstance(obj, datetime.date):
return obj.isoformat()
elif isinstance(obj, datetime.time):
if timezone and timezone.is_aware(obj):
raise ValueError("JSON can't represent timezone-aware times.")
representation = obj.isoformat()
return representation
elif isinstance(obj, datetime.timedelta):
return str(obj.total_seconds())
elif isinstance(obj, decimal.Decimal):
# Serializers will coerce decimals to strings by default.
return float(obj)
elif isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, QuerySet):
return tuple(obj)
elif isinstance(obj, bytes):
# Best-effort for binary blobs. See #4187.
return obj.decode()
elif hasattr(obj, 'tolist'):
# Numpy arrays and array scalars.
return obj.tolist()
elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)):
raise RuntimeError(
'Cannot return a coreapi object from a JSON view. '
'You should be using a schema renderer instead for this view.'
)
elif hasattr(obj, '__getitem__'):
cls = (list if isinstance(obj, (list, tuple)) else dict)
with contextlib.suppress(Exception):
return cls(obj)
elif hasattr(obj, '__iter__'):
return tuple(item for item in obj)
return super().default(obj)
class CustomScalar:
"""
CustomScalar that knows how to encode timedelta that renderer
can understand.
"""
@classmethod
def represent_timedelta(cls, dumper, data):
value = str(data.total_seconds())
return dumper.represent_scalar('tag:yaml.org,2002:str', value)

View File

@@ -0,0 +1,327 @@
"""
Helper functions for mapping model fields to a dictionary of default
keyword arguments that should be used for their equivalent serializer fields.
"""
import inspect
from django.core import validators
from django.db import models
from django.utils.text import capfirst
from rest_framework.compat import postgres_fields
from rest_framework.validators import UniqueValidator
NUMERIC_FIELD_TYPES = (
models.IntegerField, models.FloatField, models.DecimalField, models.DurationField,
)
class ClassLookupDict:
"""
Takes a dictionary with classes as keys.
Lookups against this object will traverses the object's inheritance
hierarchy in method resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
def __init__(self, mapping):
self.mapping = mapping
def __getitem__(self, key):
if hasattr(key, '_proxy_class'):
# Deal with proxy classes. Ie. BoundField behaves as if it
# is a Field instance when using ClassLookupDict.
base_class = key._proxy_class
else:
base_class = key.__class__
for cls in inspect.getmro(base_class):
if cls in self.mapping:
return self.mapping[cls]
raise KeyError('Class %s not found in lookup.' % base_class.__name__)
def __setitem__(self, key, value):
self.mapping[key] = value
def needs_label(model_field, field_name):
"""
Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name.
"""
default_label = field_name.replace('_', ' ').capitalize()
return capfirst(model_field.verbose_name) != default_label
def get_detail_view_name(model):
"""
Given a model class, return the view name to use for URL relationships
that refer to instances of the model.
"""
return '%(model_name)s-detail' % {
'model_name': model._meta.object_name.lower()
}
def get_unique_validators(field_name, model_field):
"""
Returns a list of UniqueValidators that should be applied to the field.
"""
field_set = {field_name}
conditions = {
c.condition
for c in model_field.model._meta.constraints
if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set
}
if getattr(model_field, 'unique', False):
conditions.add(None)
if not conditions:
return
unique_error_message = get_unique_error_message(model_field)
queryset = model_field.model._default_manager
for condition in conditions:
yield UniqueValidator(
queryset=queryset if condition is None else queryset.filter(condition),
message=unique_error_message
)
def get_field_kwargs(field_name, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
validator_kwarg = list(model_field.validators)
# The following will only be used by ModelField classes.
# Gets removed for everything else.
kwargs['model_field'] = model_field
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if isinstance(model_field, models.SlugField):
kwargs['allow_unicode'] = model_field.allow_unicode
if isinstance(model_field, models.TextField) and not model_field.choices or \
(postgres_fields and isinstance(model_field, postgres_fields.JSONField)) or \
(hasattr(models, 'JSONField') and isinstance(model_field, models.JSONField)):
kwargs['style'] = {'base_template': 'textarea.html'}
if model_field.null:
kwargs['allow_null'] = True
if isinstance(model_field, models.AutoField) or not model_field.editable:
# If this field is read-only, then return early.
# Further keyword arguments are not valid.
kwargs['read_only'] = True
return kwargs
if model_field.has_default() or model_field.blank or model_field.null:
kwargs['required'] = False
if model_field.blank and (isinstance(model_field, (models.CharField, models.TextField))):
kwargs['allow_blank'] = True
if not model_field.blank and (postgres_fields and isinstance(model_field, postgres_fields.ArrayField)):
kwargs['allow_empty'] = False
if isinstance(model_field, models.FilePathField):
kwargs['path'] = model_field.path
if model_field.match is not None:
kwargs['match'] = model_field.match
if model_field.recursive is not False:
kwargs['recursive'] = model_field.recursive
if model_field.allow_files is not True:
kwargs['allow_files'] = model_field.allow_files
if model_field.allow_folders is not False:
kwargs['allow_folders'] = model_field.allow_folders
if model_field.choices:
kwargs['choices'] = model_field.choices
else:
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that min_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
# IPAddressField do not need to include the 'validate_ipv46_address' argument,
if isinstance(model_field, models.GenericIPAddressField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_ipv46_address
]
# Our decimal validation is handled in the field code, not validator code.
if isinstance(model_field, models.DecimalField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.DecimalValidator)
]
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None and (isinstance(model_field, (models.CharField, models.TextField, models.FileField))):
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinLengthValidator)
), None)
if min_length is not None and isinstance(model_field, models.CharField):
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
validator_kwarg += get_unique_validators(field_name, model_field)
if validator_kwarg:
kwargs['validators'] = validator_kwarg
return kwargs
def get_relation_kwargs(field_name, relation_info):
"""
Creates a default instance of a flat relational field.
"""
model_field, related_model, to_many, to_field, has_through_model, reverse = relation_info
kwargs = {
'queryset': related_model._default_manager,
'view_name': get_detail_view_name(related_model)
}
if to_many:
kwargs['many'] = True
if to_field:
kwargs['to_field'] = to_field
limit_choices_to = model_field and model_field.get_limit_choices_to()
if limit_choices_to:
if not isinstance(limit_choices_to, models.Q):
limit_choices_to = models.Q(**limit_choices_to)
kwargs['queryset'] = kwargs['queryset'].filter(limit_choices_to)
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
help_text = model_field.help_text
if help_text:
kwargs['help_text'] = help_text
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field.null:
kwargs['allow_null'] = True
if kwargs.get('read_only', False):
# If this field is read-only, then return early.
# No further keyword arguments are valid.
return kwargs
if model_field.has_default() or model_field.blank or model_field.null:
kwargs['required'] = False
if model_field.validators:
kwargs['validators'] = model_field.validators
if getattr(model_field, 'unique', False):
validator = UniqueValidator(
queryset=model_field.model._default_manager,
message=get_unique_error_message(model_field))
kwargs['validators'] = kwargs.get('validators', []) + [validator]
if to_many and not model_field.blank:
kwargs['allow_empty'] = False
return kwargs
def get_nested_relation_kwargs(relation_info):
kwargs = {'read_only': True}
if relation_info.to_many:
kwargs['many'] = True
return kwargs
def get_url_kwargs(model_field):
return {
'view_name': get_detail_view_name(model_field)
}
def get_unique_error_message(model_field):
unique_error_message = model_field.error_messages.get('unique', None)
if unique_error_message:
unique_error_message = unique_error_message % {
'model_name': model_field.model._meta.verbose_name,
'field_label': model_field.verbose_name
}
return unique_error_message

View File

@@ -0,0 +1,93 @@
"""
Utility functions to return a formatted name and description for a given view.
"""
import re
from django.utils.encoding import force_str
from django.utils.html import escape
from django.utils.safestring import mark_safe
from rest_framework.compat import apply_markdown
def remove_trailing_string(content, trailing):
"""
Strip trailing component `trailing` from `content` if it exists.
Used when generating names from view classes.
"""
if content.endswith(trailing) and content != trailing:
return content[:-len(trailing)]
return content
def dedent(content):
"""
Remove leading indent from a block of text.
Used when generating descriptions from docstrings.
Note that python's `textwrap.dedent` doesn't quite cut it,
as it fails to dedent multiline docstrings that include
unindented text on the initial line.
"""
content = force_str(content)
lines = [line for line in content.splitlines()[1:] if line.lstrip()]
# unindent the content if needed
if lines:
whitespace_counts = min([len(line) - len(line.lstrip(' ')) for line in lines])
tab_counts = min([len(line) - len(line.lstrip('\t')) for line in lines])
if whitespace_counts:
whitespace_pattern = '^' + (' ' * whitespace_counts)
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
elif tab_counts:
whitespace_pattern = '^' + ('\t' * tab_counts)
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
return content.strip()
def camelcase_to_spaces(content):
"""
Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view classes.
"""
camelcase_boundary = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
content = re.sub(camelcase_boundary, ' \\1', content).strip()
return ' '.join(content.split('_')).title()
def markup_description(description):
"""
Apply HTML markup to the given description.
"""
if apply_markdown:
description = apply_markdown(description)
else:
description = escape(description).replace('\n', '<br />')
description = '<p>' + description + '</p>'
return mark_safe(description)
class lazy_format:
"""
Delay formatting until it's actually needed.
Useful when the format string or one of the arguments is lazy.
Not using Django's lazy because it is too slow.
"""
__slots__ = ('format_string', 'args', 'kwargs', 'result')
def __init__(self, format_string, *args, **kwargs):
self.result = None
self.format_string = format_string
self.args = args
self.kwargs = kwargs
def __str__(self):
if self.result is None:
self.result = self.format_string.format(*self.args, **self.kwargs)
self.format_string, self.args, self.kwargs = None, None, None
return self.result
def __mod__(self, value):
return str(self) % value

View File

@@ -0,0 +1,95 @@
"""
Helpers for dealing with HTML input.
"""
import re
from django.utils.datastructures import MultiValueDict
def is_html_input(dictionary):
# MultiDict type datastructures are used to represent HTML form input,
# which may have more than one value for each key.
return hasattr(dictionary, 'getlist')
def parse_html_list(dictionary, prefix='', default=None):
"""
Used to support list values in HTML forms.
Supports lists of primitives and/or dictionaries.
* List of primitives.
{
'[0]': 'abc',
'[1]': 'def',
'[2]': 'hij'
}
-->
[
'abc',
'def',
'hij'
]
* List of dictionaries.
{
'[0]foo': 'abc',
'[0]bar': 'def',
'[1]foo': 'hij',
'[1]bar': 'klm',
}
-->
[
{'foo': 'abc', 'bar': 'def'},
{'foo': 'hij', 'bar': 'klm'}
]
:returns a list of objects, or the value specified in ``default`` if the list is empty
"""
ret = {}
regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
for field, value in dictionary.items():
match = regex.match(field)
if not match:
continue
index, key = match.groups()
index = int(index)
if not key:
ret[index] = value
elif isinstance(ret.get(index), dict):
ret[index][key] = value
else:
ret[index] = MultiValueDict({key: [value]})
# return the items of the ``ret`` dict, sorted by key, or ``default`` if the dict is empty
return [ret[item] for item in sorted(ret)] if ret else default
def parse_html_dict(dictionary, prefix=''):
"""
Used to support dictionary values in HTML forms.
{
'profile.username': 'example',
'profile.email': 'example@example.com',
}
-->
{
'profile': {
'username': 'example',
'email': 'example@example.com'
}
}
"""
ret = MultiValueDict()
regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
for field in dictionary:
match = regex.match(field)
if not match:
continue
key = match.groups()[0]
value = dictionary.getlist(field)
ret.setlist(key, value)
return ret

View File

@@ -0,0 +1,47 @@
"""
Helper functions that convert strftime formats into more readable representations.
"""
from rest_framework import ISO_8601
def datetime_formats(formats):
format = ', '.join(formats).replace(
ISO_8601,
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
)
return humanize_strptime(format)
def date_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DD')
return humanize_strptime(format)
def time_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
return humanize_strptime(format)
def humanize_strptime(format_string):
# Note that we're missing some of the locale specific mappings that
# don't really make sense.
mapping = {
"%Y": "YYYY",
"%y": "YY",
"%m": "MM",
"%b": "[Jan-Dec]",
"%B": "[January-December]",
"%d": "DD",
"%H": "hh",
"%I": "hh", # Requires '%p' to differentiate from '%H'.
"%M": "mm",
"%S": "ss",
"%f": "uuuuuu",
"%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]",
"%p": "[AM|PM]",
"%z": "[+HHMM|-HHMM]"
}
for key, val in mapping.items():
format_string = format_string.replace(key, val)
return format_string

View File

@@ -0,0 +1,37 @@
"""
Wrapper for the builtin json module that ensures compliance with the JSON spec.
REST framework should always import this wrapper module in order to maintain
spec-compliant encoding/decoding. Support for non-standard features should be
handled by users at the renderer and parser layer.
"""
import functools
import json # noqa
def strict_constant(o):
raise ValueError('Out of range float values are not JSON compliant: ' + repr(o))
@functools.wraps(json.dump)
def dump(*args, **kwargs):
kwargs.setdefault('allow_nan', False)
return json.dump(*args, **kwargs)
@functools.wraps(json.dumps)
def dumps(*args, **kwargs):
kwargs.setdefault('allow_nan', False)
return json.dumps(*args, **kwargs)
@functools.wraps(json.load)
def load(*args, **kwargs):
kwargs.setdefault('parse_constant', strict_constant)
return json.load(*args, **kwargs)
@functools.wraps(json.loads)
def loads(*args, **kwargs):
kwargs.setdefault('parse_constant', strict_constant)
return json.loads(*args, **kwargs)

View File

@@ -0,0 +1,81 @@
"""
Handling of media types, as found in HTTP Content-Type and Accept headers.
See https://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7
"""
from django.utils.http import parse_header_parameters
def media_type_matches(lhs, rhs):
"""
Returns ``True`` if the media type in the first argument <= the
media type in the second argument. The media types are strings
as described by the HTTP spec.
Valid media type strings include:
'application/json; indent=4'
'application/json'
'text/*'
'*/*'
"""
lhs = _MediaType(lhs)
rhs = _MediaType(rhs)
return lhs.match(rhs)
def order_by_precedence(media_type_lst):
"""
Returns a list of sets of media type strings, ordered by precedence.
Precedence is determined by how specific a media type is:
3. 'type/subtype; param=val'
2. 'type/subtype'
1. 'type/*'
0. '*/*'
"""
ret = [set(), set(), set(), set()]
for media_type in media_type_lst:
precedence = _MediaType(media_type).precedence
ret[3 - precedence].add(media_type)
return [media_types for media_types in ret if media_types]
class _MediaType:
def __init__(self, media_type_str):
self.orig = '' if (media_type_str is None) else media_type_str
self.full_type, self.params = parse_header_parameters(self.orig)
self.main_type, sep, self.sub_type = self.full_type.partition('/')
def match(self, other):
"""Return true if this MediaType satisfies the given MediaType."""
for key in self.params:
if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
return False
if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:
return False
return True
@property
def precedence(self):
"""
Return a precedence level from 0-3 for the media type given how specific it is.
"""
if self.main_type == '*':
return 0
elif self.sub_type == '*':
return 1
elif not self.params or list(self.params) == ['q']:
return 2
return 3
def __str__(self):
ret = "%s/%s" % (self.main_type, self.sub_type)
for key, val in self.params.items():
ret += "; %s=%s" % (key, val)
return ret

View File

@@ -0,0 +1,156 @@
"""
Helper function for returning the field information that is associated
with a model class. This includes returning all the forward and reverse
relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import namedtuple
FieldInfo = namedtuple('FieldInfo', [
'pk', # Model field instance
'fields', # Dict of field name -> model field instance
'forward_relations', # Dict of field name -> RelationInfo
'reverse_relations', # Dict of field name -> RelationInfo
'fields_and_pk', # Shortcut for 'pk' + 'fields'
'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
])
RelationInfo = namedtuple('RelationInfo', [
'model_field',
'related_model',
'to_many',
'to_field',
'has_through_model',
'reverse'
])
def get_field_info(model):
"""
Given a model class, returns a `FieldInfo` instance, which is a
`namedtuple`, containing metadata about the various field types on the model
including information about their relationships.
"""
opts = model._meta.concrete_model._meta
pk = _get_pk(opts)
fields = _get_fields(opts)
forward_relations = _get_forward_relationships(opts)
reverse_relations = _get_reverse_relationships(opts)
fields_and_pk = _merge_fields_and_pk(pk, fields)
relationships = _merge_relationships(forward_relations, reverse_relations)
return FieldInfo(pk, fields, forward_relations, reverse_relations,
fields_and_pk, relationships)
def _get_pk(opts):
pk = opts.pk
rel = pk.remote_field
while rel and rel.parent_link:
# If model is a child via multi-table inheritance, use parent's pk.
pk = pk.remote_field.model._meta.pk
rel = pk.remote_field
return pk
def _get_fields(opts):
fields = {}
for field in [field for field in opts.fields if field.serialize and not field.remote_field]:
fields[field.name] = field
return fields
def _get_to_field(field):
return getattr(field, 'to_fields', None) and field.to_fields[0]
def _get_forward_relationships(opts):
"""
Returns a dict of field names to `RelationInfo`.
"""
forward_relations = {}
for field in [field for field in opts.fields if field.serialize and field.remote_field]:
forward_relations[field.name] = RelationInfo(
model_field=field,
related_model=field.remote_field.model,
to_many=False,
to_field=_get_to_field(field),
has_through_model=False,
reverse=False
)
# Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
model_field=field,
related_model=field.remote_field.model,
to_many=True,
# manytomany do not have to_fields
to_field=None,
has_through_model=(
not field.remote_field.through._meta.auto_created
),
reverse=False
)
return forward_relations
def _get_reverse_relationships(opts):
"""
Returns a dict of field names to `RelationInfo`.
"""
reverse_relations = {}
all_related_objects = [r for r in opts.related_objects if not r.field.many_to_many]
for relation in all_related_objects:
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
related_model=relation.related_model,
to_many=relation.field.remote_field.multiple,
to_field=_get_to_field(relation.field),
has_through_model=False,
reverse=True
)
# Deal with reverse many-to-many relationships.
all_related_many_to_many_objects = [r for r in opts.related_objects if r.field.many_to_many]
for relation in all_related_many_to_many_objects:
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
related_model=relation.related_model,
to_many=True,
# manytomany do not have to_fields
to_field=None,
has_through_model=(
(getattr(relation.field.remote_field, 'through', None) is not None) and
not relation.field.remote_field.through._meta.auto_created
),
reverse=True
)
return reverse_relations
def _merge_fields_and_pk(pk, fields):
fields_and_pk = {'pk': pk, pk.name: pk}
fields_and_pk.update(fields)
return fields_and_pk
def _merge_relationships(forward_relations, reverse_relations):
return {**forward_relations, **reverse_relations}
def is_abstract_model(model):
"""
Given a model class, returns a boolean True if it is abstract and False if it is not.
"""
return hasattr(model, '_meta') and hasattr(model._meta, 'abstract') and model._meta.abstract

View File

@@ -0,0 +1,101 @@
"""
Helper functions for creating user-friendly representations
of serializer classes and serializer fields.
"""
import re
from django.db import models
from django.utils.encoding import force_str
from django.utils.functional import Promise
def manager_repr(value):
model = value.model
opts = model._meta
names_and_managers = [
(manager.name, manager)
for manager
in opts.managers
]
for manager_name, manager_instance in names_and_managers:
if manager_instance == value:
return '%s.%s.all()' % (model._meta.object_name, manager_name)
return repr(value)
def smart_repr(value):
if isinstance(value, models.Manager):
return manager_repr(value)
if isinstance(value, Promise):
value = force_str(value, strings_only=True)
value = repr(value)
# Representations like u'help text'
# should simply be presented as 'help text'
if value.startswith("u'") and value.endswith("'"):
return value[1:]
# Representations like
# <django.core.validators.RegexValidator object at 0x1047af050>
# Should be presented as
# <django.core.validators.RegexValidator object>
return re.sub(' at 0x[0-9A-Fa-f]{4,32}>', '>', value)
def field_repr(field, force_many=False):
kwargs = field._kwargs
if force_many:
kwargs = kwargs.copy()
kwargs['many'] = True
kwargs.pop('child', None)
arg_string = ', '.join([smart_repr(val) for val in field._args])
kwarg_string = ', '.join([
'%s=%s' % (key, smart_repr(val))
for key, val in sorted(kwargs.items())
])
if arg_string and kwarg_string:
arg_string += ', '
if force_many:
class_name = force_many.__class__.__name__
else:
class_name = field.__class__.__name__
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
def serializer_repr(serializer, indent, force_many=None):
ret = field_repr(serializer, force_many) + ':'
indent_str = ' ' * indent
if force_many:
fields = force_many.fields
else:
fields = serializer.fields
for field_name, field in fields.items():
ret += '\n' + indent_str + field_name + ' = '
if hasattr(field, 'fields'):
ret += serializer_repr(field, indent + 1)
elif hasattr(field, 'child'):
ret += list_repr(field, indent + 1)
elif hasattr(field, 'child_relation'):
ret += field_repr(field.child_relation, force_many=field.child_relation)
else:
ret += field_repr(field)
if serializer.validators:
ret += '\n' + indent_str + 'class Meta:'
ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators)
return ret
def list_repr(serializer, indent):
child = serializer.child
if hasattr(child, 'fields'):
return serializer_repr(serializer, indent, force_many=child)
return field_repr(serializer)

View File

@@ -0,0 +1,180 @@
import contextlib
from collections.abc import Mapping, MutableMapping
from django.utils.encoding import force_str
from rest_framework.utils import json
class ReturnDict(dict):
"""
Return object from `serializer.data` for the `Serializer` class.
Includes a backlink to the serializer instance for renderers
to use if they need richer field information.
"""
def __init__(self, *args, **kwargs):
self.serializer = kwargs.pop('serializer')
super().__init__(*args, **kwargs)
def copy(self):
return ReturnDict(self, serializer=self.serializer)
def __repr__(self):
return dict.__repr__(self)
def __reduce__(self):
# Pickling these objects will drop the .serializer backlink,
# but preserve the raw data.
return (dict, (dict(self),))
# These are basically copied from OrderedDict, with `serializer` added.
def __or__(self, other):
if not isinstance(other, dict):
return NotImplemented
new = self.__class__(self, serializer=self.serializer)
new.update(other)
return new
def __ror__(self, other):
if not isinstance(other, dict):
return NotImplemented
new = self.__class__(other, serializer=self.serializer)
new.update(self)
return new
class ReturnList(list):
"""
Return object from `serializer.data` for the `SerializerList` class.
Includes a backlink to the serializer instance for renderers
to use if they need richer field information.
"""
def __init__(self, *args, **kwargs):
self.serializer = kwargs.pop('serializer')
super().__init__(*args, **kwargs)
def __repr__(self):
return list.__repr__(self)
def __reduce__(self):
# Pickling these objects will drop the .serializer backlink,
# but preserve the raw data.
return (list, (list(self),))
class BoundField:
"""
A field object that also includes `.value` and `.error` properties.
Returned when iterating over a serializer instance,
providing an API similar to Django forms and form fields.
"""
def __init__(self, field, value, errors, prefix=''):
self._field = field
self._prefix = prefix
self.value = value
self.errors = errors
self.name = prefix + self.field_name
def __getattr__(self, attr_name):
return getattr(self._field, attr_name)
@property
def _proxy_class(self):
return self._field.__class__
def __repr__(self):
return '<%s value=%s errors=%s>' % (
self.__class__.__name__, self.value, self.errors
)
def as_form_field(self):
value = '' if (self.value is None or self.value is False) else self.value
return self.__class__(self._field, value, self.errors, self._prefix)
class JSONBoundField(BoundField):
def as_form_field(self):
value = self.value
# When HTML form input is used and the input is not valid
# value will be a JSONString, rather than a JSON primitive.
if not getattr(value, 'is_json_string', False):
with contextlib.suppress(TypeError, ValueError):
value = json.dumps(
self.value,
sort_keys=True,
indent=4,
separators=(',', ': '),
)
return self.__class__(self._field, value, self.errors, self._prefix)
class NestedBoundField(BoundField):
"""
This `BoundField` additionally implements __iter__ and __getitem__
in order to support nested bound fields. This class is the type of
`BoundField` that is used for serializer fields.
"""
def __init__(self, field, value, errors, prefix=''):
if value is None or value == '' or not isinstance(value, Mapping):
value = {}
super().__init__(field, value, errors, prefix)
def __iter__(self):
for field in self.fields.values():
yield self[field.field_name]
def __getitem__(self, key):
field = self.fields[key]
value = self.value.get(key) if self.value else None
error = self.errors.get(key) if isinstance(self.errors, dict) else None
if hasattr(field, 'fields'):
return NestedBoundField(field, value, error, prefix=self.name + '.')
elif getattr(field, '_is_jsonfield', False):
return JSONBoundField(field, value, error, prefix=self.name + '.')
return BoundField(field, value, error, prefix=self.name + '.')
def as_form_field(self):
values = {}
for key, value in self.value.items():
if isinstance(value, (list, dict)):
values[key] = value
else:
values[key] = '' if (value is None or value is False) else force_str(value)
return self.__class__(self._field, values, self.errors, self._prefix)
class BindingDict(MutableMapping):
"""
This dict-like object is used to store fields on a serializer.
This ensures that whenever fields are added to the serializer we call
`field.bind()` so that the `field_name` and `parent` attributes
can be set correctly.
"""
def __init__(self, serializer):
self.serializer = serializer
self.fields = {}
def __setitem__(self, key, field):
self.fields[key] = field
field.bind(field_name=key, parent=self.serializer)
def __getitem__(self, key):
return self.fields[key]
def __delitem__(self, key):
del self.fields[key]
def __iter__(self):
return iter(self.fields)
def __len__(self):
return len(self.fields)
def __repr__(self):
return dict.__repr__(self.fields)

View File

@@ -0,0 +1,25 @@
from datetime import datetime, timezone, tzinfo
def datetime_exists(dt):
"""Check if a datetime exists. Taken from: https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html"""
# There are no non-existent times in UTC, and comparisons between
# aware time zones always compare absolute times; if a datetime is
# not equal to the same datetime represented in UTC, it is imaginary.
return dt.astimezone(timezone.utc) == dt
def datetime_ambiguous(dt: datetime):
"""Check whether a datetime is ambiguous. Taken from: https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html"""
# If a datetime exists and its UTC offset changes in response to
# changing `fold`, it is ambiguous in the zone specified.
return datetime_exists(dt) and (
dt.replace(fold=not dt.fold).utcoffset() != dt.utcoffset()
)
def valid_datetime(dt):
"""Returns True if the datetime is not ambiguous or imaginary, False otherwise."""
if isinstance(dt.tzinfo, tzinfo) and not datetime_ambiguous(dt):
return True
return False

View File

@@ -0,0 +1,27 @@
from urllib import parse
from django.utils.encoding import force_str
def replace_query_param(url, key, val):
"""
Given a URL and a key/val pair, set or replace an item in the query
parameters of the URL, and return the new URL.
"""
(scheme, netloc, path, query, fragment) = parse.urlsplit(force_str(url))
query_dict = parse.parse_qs(query, keep_blank_values=True)
query_dict[force_str(key)] = [force_str(val)]
query = parse.urlencode(sorted(query_dict.items()), doseq=True)
return parse.urlunsplit((scheme, netloc, path, query, fragment))
def remove_query_param(url, key):
"""
Given a URL and a key/val pair, remove an item in the query
parameters of the URL, and return the new URL.
"""
(scheme, netloc, path, query, fragment) = parse.urlsplit(force_str(url))
query_dict = parse.parse_qs(query, keep_blank_values=True)
query_dict.pop(key, None)
query = parse.urlencode(sorted(query_dict.items()), doseq=True)
return parse.urlunsplit((scheme, netloc, path, query, fragment))