Updates
This commit is contained in:
@@ -0,0 +1,638 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Django Extensions additional model fields
|
||||
|
||||
Some fields might require additional dependencies to be installed.
|
||||
"""
|
||||
|
||||
import re
|
||||
import string
|
||||
|
||||
try:
|
||||
import uuid
|
||||
|
||||
HAS_UUID = True
|
||||
except ImportError:
|
||||
HAS_UUID = False
|
||||
|
||||
try:
|
||||
import shortuuid
|
||||
|
||||
HAS_SHORT_UUID = True
|
||||
except ImportError:
|
||||
HAS_SHORT_UUID = False
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db.models import DateTimeField, CharField, SlugField, Q, UniqueConstraint
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.template.defaultfilters import slugify
|
||||
from django.utils.crypto import get_random_string
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
|
||||
MAX_UNIQUE_QUERY_ATTEMPTS = getattr(
|
||||
settings, "EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS", 100
|
||||
)
|
||||
|
||||
|
||||
class UniqueFieldMixin:
|
||||
def check_is_bool(self, attrname):
|
||||
if not isinstance(getattr(self, attrname), bool):
|
||||
raise ValueError("'{}' argument must be True or False".format(attrname))
|
||||
|
||||
@staticmethod
|
||||
def _get_fields(model_cls):
|
||||
return [
|
||||
(f, f.model if f.model != model_cls else None)
|
||||
for f in model_cls._meta.get_fields()
|
||||
if not f.is_relation or f.one_to_one or (f.many_to_one and f.related_model)
|
||||
]
|
||||
|
||||
def get_queryset(self, model_cls, slug_field):
|
||||
for field, model in self._get_fields(model_cls):
|
||||
if model and field == slug_field:
|
||||
return model._default_manager.all()
|
||||
return model_cls._default_manager.all()
|
||||
|
||||
def find_unique(self, model_instance, field, iterator, *args):
|
||||
# exclude the current model instance from the queryset used in finding
|
||||
# next valid hash
|
||||
queryset = self.get_queryset(model_instance.__class__, field)
|
||||
if model_instance.pk:
|
||||
queryset = queryset.exclude(pk=model_instance.pk)
|
||||
|
||||
# form a kwarg dict used to implement any unique_together constraints
|
||||
kwargs = {}
|
||||
for params in model_instance._meta.unique_together:
|
||||
if self.attname in params:
|
||||
for param in params:
|
||||
kwargs[param] = getattr(model_instance, param, None)
|
||||
|
||||
# for support django 2.2+
|
||||
query = Q()
|
||||
constraints = getattr(model_instance._meta, "constraints", None)
|
||||
if constraints:
|
||||
unique_constraints = filter(
|
||||
lambda c: isinstance(c, UniqueConstraint), constraints
|
||||
)
|
||||
for unique_constraint in unique_constraints:
|
||||
if self.attname in unique_constraint.fields:
|
||||
condition = {
|
||||
field: getattr(model_instance, field, None)
|
||||
for field in unique_constraint.fields
|
||||
if field != self.attname
|
||||
}
|
||||
query &= Q(**condition)
|
||||
|
||||
new = next(iterator)
|
||||
kwargs[self.attname] = new
|
||||
while not new or queryset.filter(query, **kwargs):
|
||||
new = next(iterator)
|
||||
kwargs[self.attname] = new
|
||||
setattr(model_instance, self.attname, new)
|
||||
return new
|
||||
|
||||
|
||||
class AutoSlugField(UniqueFieldMixin, SlugField):
|
||||
"""
|
||||
AutoSlugField
|
||||
|
||||
By default, sets editable=False, blank=True.
|
||||
|
||||
Required arguments:
|
||||
|
||||
populate_from
|
||||
Specifies which field, list of fields, or model method
|
||||
the slug will be populated from.
|
||||
|
||||
populate_from can traverse a ForeignKey relationship
|
||||
by using Django ORM syntax:
|
||||
populate_from = 'related_model__field'
|
||||
|
||||
Optional arguments:
|
||||
|
||||
separator
|
||||
Defines the used separator (default: '-')
|
||||
|
||||
overwrite
|
||||
If set to True, overwrites the slug on every save (default: False)
|
||||
|
||||
slugify_function
|
||||
Defines the function which will be used to "slugify" a content
|
||||
(default: :py:func:`~django.template.defaultfilters.slugify` )
|
||||
|
||||
It is possible to provide custom "slugify" function with
|
||||
the ``slugify_function`` function in a model class.
|
||||
|
||||
``slugify_function`` function in a model class takes priority over
|
||||
``slugify_function`` given as an argument to :py:class:`~AutoSlugField`.
|
||||
|
||||
Example
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# models.py
|
||||
|
||||
from django.db import models
|
||||
|
||||
from django_extensions.db.fields import AutoSlugField
|
||||
|
||||
|
||||
class MyModel(models.Model):
|
||||
def slugify_function(self, content):
|
||||
return content.replace('_', '-').lower()
|
||||
|
||||
title = models.CharField(max_length=42)
|
||||
slug = AutoSlugField(populate_from='title')
|
||||
|
||||
Inspired by SmileyChris' Unique Slugify snippet:
|
||||
https://www.djangosnippets.org/snippets/690/
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("blank", True)
|
||||
kwargs.setdefault("editable", False)
|
||||
|
||||
populate_from = kwargs.pop("populate_from", None)
|
||||
if populate_from is None:
|
||||
raise ValueError("missing 'populate_from' argument")
|
||||
else:
|
||||
self._populate_from = populate_from
|
||||
|
||||
if not callable(populate_from):
|
||||
if not isinstance(populate_from, (list, tuple)):
|
||||
populate_from = (populate_from,)
|
||||
|
||||
if not all(isinstance(e, str) for e in populate_from):
|
||||
raise TypeError(
|
||||
"'populate_from' must be str or list[str] or tuple[str], found `%s`"
|
||||
% populate_from
|
||||
)
|
||||
|
||||
self.slugify_function = kwargs.pop("slugify_function", slugify)
|
||||
self.separator = kwargs.pop("separator", "-")
|
||||
self.overwrite = kwargs.pop("overwrite", False)
|
||||
self.check_is_bool("overwrite")
|
||||
self.overwrite_on_add = kwargs.pop("overwrite_on_add", True)
|
||||
self.check_is_bool("overwrite_on_add")
|
||||
self.allow_duplicates = kwargs.pop("allow_duplicates", False)
|
||||
self.check_is_bool("allow_duplicates")
|
||||
self.max_unique_query_attempts = kwargs.pop(
|
||||
"max_unique_query_attempts", MAX_UNIQUE_QUERY_ATTEMPTS
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _slug_strip(self, value):
|
||||
"""
|
||||
Clean up a slug by removing slug separator characters that occur at
|
||||
the beginning or end of a slug.
|
||||
|
||||
If an alternate separator is used, it will also replace any instances
|
||||
of the default '-' separator with the new separator.
|
||||
"""
|
||||
re_sep = "(?:-|%s)" % re.escape(self.separator)
|
||||
value = re.sub("%s+" % re_sep, self.separator, value)
|
||||
return re.sub(r"^%s+|%s+$" % (re_sep, re_sep), "", value)
|
||||
|
||||
@staticmethod
|
||||
def slugify_func(content, slugify_function):
|
||||
if content:
|
||||
return slugify_function(content)
|
||||
return ""
|
||||
|
||||
def slug_generator(self, original_slug, start):
|
||||
yield original_slug
|
||||
for i in range(start, self.max_unique_query_attempts):
|
||||
slug = original_slug
|
||||
end = "%s%s" % (self.separator, i)
|
||||
end_len = len(end)
|
||||
if self.slug_len and len(slug) + end_len > self.slug_len:
|
||||
slug = slug[: self.slug_len - end_len]
|
||||
slug = self._slug_strip(slug)
|
||||
slug = "%s%s" % (slug, end)
|
||||
yield slug
|
||||
raise RuntimeError(
|
||||
"max slug attempts for %s exceeded (%s)"
|
||||
% (original_slug, self.max_unique_query_attempts)
|
||||
)
|
||||
|
||||
def create_slug(self, model_instance, add):
|
||||
slug = getattr(model_instance, self.attname)
|
||||
use_existing_slug = False
|
||||
if slug and not self.overwrite:
|
||||
# Existing slug and not configured to overwrite - Short-circuit
|
||||
# here to prevent slug generation when not required.
|
||||
use_existing_slug = True
|
||||
|
||||
if self.overwrite_on_add and add:
|
||||
use_existing_slug = False
|
||||
|
||||
if use_existing_slug:
|
||||
return slug
|
||||
|
||||
# get fields to populate from and slug field to set
|
||||
populate_from = self._populate_from
|
||||
if not isinstance(populate_from, (list, tuple)):
|
||||
populate_from = (populate_from,)
|
||||
|
||||
slug_field = model_instance._meta.get_field(self.attname)
|
||||
slugify_function = getattr(
|
||||
model_instance, "slugify_function", self.slugify_function
|
||||
)
|
||||
|
||||
# slugify the original field content and set next step to 2
|
||||
slug_for_field = lambda lookup_value: self.slugify_func(
|
||||
self.get_slug_fields(model_instance, lookup_value),
|
||||
slugify_function=slugify_function,
|
||||
)
|
||||
slug = self.separator.join(map(slug_for_field, populate_from))
|
||||
start = 2
|
||||
|
||||
# strip slug depending on max_length attribute of the slug field
|
||||
# and clean-up
|
||||
self.slug_len = slug_field.max_length
|
||||
if self.slug_len:
|
||||
slug = slug[: self.slug_len]
|
||||
slug = self._slug_strip(slug)
|
||||
original_slug = slug
|
||||
|
||||
if self.allow_duplicates:
|
||||
setattr(model_instance, self.attname, slug)
|
||||
return slug
|
||||
|
||||
return self.find_unique(
|
||||
model_instance, slug_field, self.slug_generator(original_slug, start)
|
||||
)
|
||||
|
||||
def get_slug_fields(self, model_instance, lookup_value):
|
||||
if callable(lookup_value):
|
||||
# A function has been provided
|
||||
return "%s" % lookup_value(model_instance)
|
||||
|
||||
lookup_value_path = lookup_value.split(LOOKUP_SEP)
|
||||
attr = model_instance
|
||||
for elem in lookup_value_path:
|
||||
try:
|
||||
attr = getattr(attr, elem)
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"value {} in AutoSlugField's 'populate_from' argument {} returned an error - {} has no attribute {}".format( # noqa: E501
|
||||
elem, lookup_value, attr, elem
|
||||
)
|
||||
)
|
||||
if callable(attr):
|
||||
return "%s" % attr()
|
||||
|
||||
return attr
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
value = force_str(self.create_slug(model_instance, add))
|
||||
return value
|
||||
|
||||
def get_internal_type(self):
|
||||
return "SlugField"
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
kwargs["populate_from"] = self._populate_from
|
||||
if not self.separator == "-":
|
||||
kwargs["separator"] = self.separator
|
||||
if self.overwrite is not False:
|
||||
kwargs["overwrite"] = True
|
||||
if self.allow_duplicates is not False:
|
||||
kwargs["allow_duplicates"] = True
|
||||
return name, path, args, kwargs
|
||||
|
||||
|
||||
class RandomCharField(UniqueFieldMixin, CharField):
|
||||
"""
|
||||
RandomCharField
|
||||
|
||||
By default, sets editable=False, blank=True, unique=False.
|
||||
|
||||
Required arguments:
|
||||
|
||||
length
|
||||
Specifies the length of the field
|
||||
|
||||
Optional arguments:
|
||||
|
||||
unique
|
||||
If set to True, duplicate entries are not allowed (default: False)
|
||||
|
||||
lowercase
|
||||
If set to True, lowercase the alpha characters (default: False)
|
||||
|
||||
uppercase
|
||||
If set to True, uppercase the alpha characters (default: False)
|
||||
|
||||
include_alpha
|
||||
If set to True, include alpha characters (default: True)
|
||||
|
||||
include_digits
|
||||
If set to True, include digit characters (default: True)
|
||||
|
||||
include_punctuation
|
||||
If set to True, include punctuation characters (default: False)
|
||||
|
||||
keep_default
|
||||
If set to True, keeps the default initialization value (default: False)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("blank", True)
|
||||
kwargs.setdefault("editable", False)
|
||||
|
||||
self.length = kwargs.pop("length", None)
|
||||
if self.length is None:
|
||||
raise ValueError("missing 'length' argument")
|
||||
kwargs["max_length"] = self.length
|
||||
|
||||
self.lowercase = kwargs.pop("lowercase", False)
|
||||
self.check_is_bool("lowercase")
|
||||
self.uppercase = kwargs.pop("uppercase", False)
|
||||
self.check_is_bool("uppercase")
|
||||
if self.uppercase and self.lowercase:
|
||||
raise ValueError(
|
||||
"the 'lowercase' and 'uppercase' arguments are mutually exclusive"
|
||||
)
|
||||
self.include_digits = kwargs.pop("include_digits", True)
|
||||
self.check_is_bool("include_digits")
|
||||
self.include_alpha = kwargs.pop("include_alpha", True)
|
||||
self.check_is_bool("include_alpha")
|
||||
self.include_punctuation = kwargs.pop("include_punctuation", False)
|
||||
self.keep_default = kwargs.pop("keep_default", False)
|
||||
self.check_is_bool("include_punctuation")
|
||||
self.max_unique_query_attempts = kwargs.pop(
|
||||
"max_unique_query_attempts", MAX_UNIQUE_QUERY_ATTEMPTS
|
||||
)
|
||||
|
||||
# Set unique=False unless it's been set manually.
|
||||
if "unique" not in kwargs:
|
||||
kwargs["unique"] = False
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def random_char_generator(self, chars):
|
||||
for i in range(self.max_unique_query_attempts):
|
||||
yield "".join(get_random_string(self.length, chars))
|
||||
raise RuntimeError(
|
||||
"max random character attempts exceeded (%s)"
|
||||
% self.max_unique_query_attempts
|
||||
)
|
||||
|
||||
def in_unique_together(self, model_instance):
|
||||
for params in model_instance._meta.unique_together:
|
||||
if self.attname in params:
|
||||
return True
|
||||
return False
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
if (not add or self.keep_default) and getattr(
|
||||
model_instance, self.attname
|
||||
) != "":
|
||||
return getattr(model_instance, self.attname)
|
||||
|
||||
population = ""
|
||||
if self.include_alpha:
|
||||
if self.lowercase:
|
||||
population += string.ascii_lowercase
|
||||
elif self.uppercase:
|
||||
population += string.ascii_uppercase
|
||||
else:
|
||||
population += string.ascii_letters
|
||||
|
||||
if self.include_digits:
|
||||
population += string.digits
|
||||
|
||||
if self.include_punctuation:
|
||||
population += string.punctuation
|
||||
|
||||
random_chars = self.random_char_generator(population)
|
||||
if not self.unique and not self.in_unique_together(model_instance):
|
||||
new = next(random_chars)
|
||||
setattr(model_instance, self.attname, new)
|
||||
return new
|
||||
|
||||
return self.find_unique(
|
||||
model_instance,
|
||||
model_instance._meta.get_field(self.attname),
|
||||
random_chars,
|
||||
)
|
||||
|
||||
def internal_type(self):
|
||||
return "CharField"
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
kwargs["length"] = self.length
|
||||
del kwargs["max_length"]
|
||||
if self.lowercase is True:
|
||||
kwargs["lowercase"] = self.lowercase
|
||||
if self.uppercase is True:
|
||||
kwargs["uppercase"] = self.uppercase
|
||||
if self.include_alpha is False:
|
||||
kwargs["include_alpha"] = self.include_alpha
|
||||
if self.include_digits is False:
|
||||
kwargs["include_digits"] = self.include_digits
|
||||
if self.include_punctuation is True:
|
||||
kwargs["include_punctuation"] = self.include_punctuation
|
||||
if self.unique is True:
|
||||
kwargs["unique"] = self.unique
|
||||
return name, path, args, kwargs
|
||||
|
||||
|
||||
class CreationDateTimeField(DateTimeField):
|
||||
"""
|
||||
CreationDateTimeField
|
||||
|
||||
By default, sets editable=False, blank=True, auto_now_add=True
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("editable", False)
|
||||
kwargs.setdefault("blank", True)
|
||||
kwargs.setdefault("auto_now_add", True)
|
||||
DateTimeField.__init__(self, *args, **kwargs)
|
||||
|
||||
def get_internal_type(self):
|
||||
return "DateTimeField"
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.editable is not False:
|
||||
kwargs["editable"] = True
|
||||
if self.blank is not True:
|
||||
kwargs["blank"] = False
|
||||
if self.auto_now_add is not False:
|
||||
kwargs["auto_now_add"] = True
|
||||
return name, path, args, kwargs
|
||||
|
||||
|
||||
class ModificationDateTimeField(CreationDateTimeField):
|
||||
"""
|
||||
ModificationDateTimeField
|
||||
|
||||
By default, sets editable=False, blank=True, auto_now=True
|
||||
|
||||
Sets value to now every time the object is saved.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("auto_now", True)
|
||||
DateTimeField.__init__(self, *args, **kwargs)
|
||||
|
||||
def get_internal_type(self):
|
||||
return "DateTimeField"
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.auto_now is not False:
|
||||
kwargs["auto_now"] = True
|
||||
return name, path, args, kwargs
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
if not getattr(model_instance, "update_modified", True):
|
||||
return getattr(model_instance, self.attname)
|
||||
return super().pre_save(model_instance, add)
|
||||
|
||||
|
||||
class UUIDVersionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UUIDFieldMixin:
|
||||
"""
|
||||
UUIDFieldMixin
|
||||
|
||||
By default uses UUID version 4 (randomly generated UUID).
|
||||
|
||||
The field support all uuid versions which are natively supported by the uuid python module, except version 2.
|
||||
For more information see: https://docs.python.org/lib/module-uuid.html
|
||||
""" # noqa: E501
|
||||
|
||||
DEFAULT_MAX_LENGTH = 36
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
auto=True,
|
||||
version=4,
|
||||
node=None,
|
||||
clock_seq=None,
|
||||
namespace=None,
|
||||
uuid_name=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not HAS_UUID:
|
||||
raise ImproperlyConfigured(
|
||||
"'uuid' module is required for UUIDField. "
|
||||
"(Do you have Python 2.5 or higher installed ?)"
|
||||
)
|
||||
|
||||
kwargs.setdefault("max_length", self.DEFAULT_MAX_LENGTH)
|
||||
|
||||
if auto:
|
||||
self.empty_strings_allowed = False
|
||||
kwargs["blank"] = True
|
||||
kwargs.setdefault("editable", False)
|
||||
|
||||
self.auto = auto
|
||||
self.version = version
|
||||
self.node = node
|
||||
self.clock_seq = clock_seq
|
||||
self.namespace = namespace
|
||||
self.uuid_name = uuid_name or name
|
||||
|
||||
super().__init__(verbose_name=verbose_name, *args, **kwargs)
|
||||
|
||||
def create_uuid(self):
|
||||
if not self.version or self.version == 4:
|
||||
return uuid.uuid4()
|
||||
elif self.version == 1:
|
||||
return uuid.uuid1(self.node, self.clock_seq)
|
||||
elif self.version == 2:
|
||||
raise UUIDVersionError("UUID version 2 is not supported.")
|
||||
elif self.version == 3:
|
||||
return uuid.uuid3(self.namespace, self.uuid_name)
|
||||
elif self.version == 5:
|
||||
return uuid.uuid5(self.namespace, self.uuid_name)
|
||||
else:
|
||||
raise UUIDVersionError("UUID version %s is not valid." % self.version)
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
value = super().pre_save(model_instance, add)
|
||||
|
||||
if self.auto and add and value is None:
|
||||
value = force_str(self.create_uuid())
|
||||
setattr(model_instance, self.attname, value)
|
||||
return value
|
||||
else:
|
||||
if self.auto and not value:
|
||||
value = force_str(self.create_uuid())
|
||||
setattr(model_instance, self.attname, value)
|
||||
|
||||
return value
|
||||
|
||||
def formfield(self, form_class=None, choices_form_class=None, **kwargs):
|
||||
if self.auto:
|
||||
return None
|
||||
return super().formfield(form_class, choices_form_class, **kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
|
||||
if kwargs.get("max_length", None) == self.DEFAULT_MAX_LENGTH:
|
||||
del kwargs["max_length"]
|
||||
if self.auto is not True:
|
||||
kwargs["auto"] = self.auto
|
||||
if self.version != 4:
|
||||
kwargs["version"] = self.version
|
||||
if self.node is not None:
|
||||
kwargs["node"] = self.node
|
||||
if self.clock_seq is not None:
|
||||
kwargs["clock_seq"] = self.clock_seq
|
||||
if self.namespace is not None:
|
||||
kwargs["namespace"] = self.namespace
|
||||
if self.uuid_name is not None:
|
||||
kwargs["uuid_name"] = self.name
|
||||
|
||||
return name, path, args, kwargs
|
||||
|
||||
|
||||
class ShortUUIDField(UUIDFieldMixin, CharField):
|
||||
"""
|
||||
ShortUUIDFied
|
||||
|
||||
Generates concise (22 characters instead of 36), unambiguous, URL-safe UUIDs.
|
||||
|
||||
Based on `shortuuid`: https://github.com/stochastic-technologies/shortuuid
|
||||
"""
|
||||
|
||||
DEFAULT_MAX_LENGTH = 22
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if not HAS_SHORT_UUID:
|
||||
raise ImproperlyConfigured(
|
||||
"'shortuuid' module is required for ShortUUIDField. "
|
||||
"(Do you have Python 2.5 or higher installed ?)"
|
||||
)
|
||||
kwargs.setdefault("max_length", self.DEFAULT_MAX_LENGTH)
|
||||
|
||||
def create_uuid(self):
|
||||
if not self.version or self.version == 4:
|
||||
return shortuuid.uuid()
|
||||
elif self.version == 1:
|
||||
return shortuuid.uuid()
|
||||
elif self.version == 2:
|
||||
raise UUIDVersionError("UUID version 2 is not supported.")
|
||||
elif self.version == 3:
|
||||
raise UUIDVersionError("UUID version 3 is not supported.")
|
||||
elif self.version == 5:
|
||||
return shortuuid.uuid(name=self.namespace)
|
||||
else:
|
||||
raise UUIDVersionError("UUID version %s is not valid." % self.version)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,115 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
JSONField automatically serializes most Python terms to JSON data.
|
||||
Creates a TEXT field with a default value of "{}". See test_json.py for
|
||||
more information.
|
||||
|
||||
from django.db import models
|
||||
from django_extensions.db.fields import json
|
||||
|
||||
class LOL(models.Model):
|
||||
extra = json.JSONField()
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.db import models
|
||||
from django.db.models import expressions
|
||||
|
||||
|
||||
def dumps(value):
|
||||
return DjangoJSONEncoder().encode(value)
|
||||
|
||||
|
||||
def loads(txt):
|
||||
return json.loads(txt)
|
||||
|
||||
|
||||
class JSONDict(dict):
|
||||
"""
|
||||
Hack so repr() called by dumpdata will output JSON instead of
|
||||
Python formatted data. This way fixtures will work!
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return dumps(self)
|
||||
|
||||
|
||||
class JSONList(list):
|
||||
"""
|
||||
Hack so repr() called by dumpdata will output JSON instead of
|
||||
Python formatted data. This way fixtures will work!
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return dumps(self)
|
||||
|
||||
|
||||
class JSONField(models.TextField):
|
||||
"""
|
||||
JSONField is a generic textfield that neatly serializes/unserializes
|
||||
JSON objects seamlessly. Main thingy must be a dict object.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["default"] = kwargs.get("default", dict)
|
||||
models.TextField.__init__(self, *args, **kwargs)
|
||||
|
||||
def get_default(self):
|
||||
if self.has_default():
|
||||
default = self.default
|
||||
|
||||
if callable(default):
|
||||
default = default()
|
||||
|
||||
return self.to_python(default)
|
||||
return super().get_default()
|
||||
|
||||
def to_python(self, value):
|
||||
"""Convert our string value to JSON after we load it from the DB"""
|
||||
if value is None or value == "":
|
||||
return {}
|
||||
|
||||
if isinstance(value, str):
|
||||
res = loads(value)
|
||||
else:
|
||||
res = value
|
||||
|
||||
if isinstance(res, dict):
|
||||
return JSONDict(**res)
|
||||
elif isinstance(res, list):
|
||||
return JSONList(res)
|
||||
|
||||
return res
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if not isinstance(value, str):
|
||||
return dumps(value)
|
||||
return super(models.TextField, self).get_prep_value(value)
|
||||
|
||||
def from_db_value(self, value, expression, connection): # type: ignore
|
||||
return self.to_python(value)
|
||||
|
||||
def get_db_prep_save(self, value, connection, **kwargs):
|
||||
"""Convert our JSON object to a string before we save"""
|
||||
if value is None and self.null:
|
||||
return None
|
||||
|
||||
# default values come in as strings; only non-strings should be
|
||||
# run through `dumps`
|
||||
if (
|
||||
not isinstance(value, str)
|
||||
# https://github.com/django-extensions/django-extensions/issues/1924
|
||||
# https://code.djangoproject.com/ticket/35167
|
||||
and not isinstance(value, expressions.Expression)
|
||||
):
|
||||
value = dumps(value)
|
||||
|
||||
return super().get_db_prep_save(value, connection)
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.default == "{}":
|
||||
del kwargs["default"]
|
||||
return name, path, args, kwargs
|
||||
Reference in New Issue
Block a user