commit c67067a2a4
Author: Iliyan Angelov
Date:   2025-09-14 23:24:25 +03:00

71311 changed files with 6800714 additions and 0 deletions

View File: django/core/asgi.py

@@ -0,0 +1,13 @@
import django
from django.core.handlers.asgi import ASGIHandler
def get_asgi_application():
"""
The public interface to Django's ASGI support. Return an ASGI 3 callable.
Avoids making django.core.handlers.ASGIHandler a public API, in case the
internal implementation changes or moves in the future.
"""
django.setup(set_prefix=False)
return ASGIHandler()
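
For context, a minimal sketch of how a project typically wires this up in its asgi.py entrypoint; the settings module name "mysite.settings" is a hypothetical placeholder.

import os

from django.core.asgi import get_asgi_application

# Hypothetical settings module; replace with the project's own.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings")

application = get_asgi_application()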

View File: django/core/cache/__init__.py

@@ -0,0 +1,66 @@
"""
Caching framework.
This package defines a set of cache backends that all conform to a simple API.
In a nutshell, a cache is a set of values -- which can be any object that
may be pickled -- identified by string keys. For the complete API, see
the abstract BaseCache class in django.core.cache.backends.base.
Client code should use the `cache` variable defined here to access the default
cache backend and look up non-default cache backends in the `caches` dict-like
object.
See docs/topics/cache.txt for information on the public API.
"""
from django.core import signals
from django.core.cache.backends.base import (
BaseCache,
CacheKeyWarning,
InvalidCacheBackendError,
InvalidCacheKey,
)
from django.utils.connection import BaseConnectionHandler, ConnectionProxy
from django.utils.module_loading import import_string
__all__ = [
"cache",
"caches",
"DEFAULT_CACHE_ALIAS",
"InvalidCacheBackendError",
"CacheKeyWarning",
"BaseCache",
"InvalidCacheKey",
]
DEFAULT_CACHE_ALIAS = "default"
class CacheHandler(BaseConnectionHandler):
settings_name = "CACHES"
exception_class = InvalidCacheBackendError
def create_connection(self, alias):
params = self.settings[alias].copy()
backend = params.pop("BACKEND")
location = params.pop("LOCATION", "")
try:
backend_cls = import_string(backend)
except ImportError as e:
raise InvalidCacheBackendError(
"Could not find backend '%s': %s" % (backend, e)
) from e
return backend_cls(location, params)
caches = CacheHandler()
cache = ConnectionProxy(caches, DEFAULT_CACHE_ALIAS)
def close_caches(**kwargs):
# Some caches need to do a cleanup at the end of a request cycle. If not
# implemented in a particular backend cache.close() is a no-op.
caches.close_all()
signals.request_finished.connect(close_caches)
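
A short usage sketch of the module-level handles defined above; the "sessions" alias is a hypothetical non-default entry in CACHES.

from django.core.cache import cache, caches

cache.set("greeting", "hello", timeout=30)  # default alias via the ConnectionProxy
assert cache.get("greeting") == "hello"

sessions_cache = caches["sessions"]  # hypothetical non-default alias from CACHES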

View File: django/core/cache/backends/base.py

@@ -0,0 +1,403 @@
"Base Cache class."
import time
import warnings
from asgiref.sync import sync_to_async
from django.core.exceptions import ImproperlyConfigured
from django.utils.module_loading import import_string
class InvalidCacheBackendError(ImproperlyConfigured):
pass
class CacheKeyWarning(RuntimeWarning):
pass
class InvalidCacheKey(ValueError):
pass
# Sentinel object to ensure that not passing in a `timeout` argument
# results in the default timeout.
DEFAULT_TIMEOUT = object()
# Memcached does not accept keys longer than this.
MEMCACHE_MAX_KEY_LENGTH = 250
def default_key_func(key, key_prefix, version):
"""
Default function to generate keys.
Construct the key used by all other methods. By default, prepend
the `key_prefix`. KEY_FUNCTION can be used to specify an alternate
function with custom key making behavior.
"""
return "%s:%s:%s" % (key_prefix, version, key)
def get_key_func(key_func):
"""
Function to decide which key function to use.
Default to ``default_key_func``.
"""
if key_func is not None:
if callable(key_func):
return key_func
else:
return import_string(key_func)
return default_key_func
class BaseCache:
_missing_key = object()
def __init__(self, params):
timeout = params.get("timeout", params.get("TIMEOUT", 300))
if timeout is not None:
try:
timeout = int(timeout)
except (ValueError, TypeError):
timeout = 300
self.default_timeout = timeout
options = params.get("OPTIONS", {})
max_entries = params.get("max_entries", options.get("MAX_ENTRIES", 300))
try:
self._max_entries = int(max_entries)
except (ValueError, TypeError):
self._max_entries = 300
cull_frequency = params.get("cull_frequency", options.get("CULL_FREQUENCY", 3))
try:
self._cull_frequency = int(cull_frequency)
except (ValueError, TypeError):
self._cull_frequency = 3
self.key_prefix = params.get("KEY_PREFIX", "")
self.version = params.get("VERSION", 1)
self.key_func = get_key_func(params.get("KEY_FUNCTION"))
def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
"""
Return the timeout value usable by this backend based upon the provided
timeout.
"""
if timeout == DEFAULT_TIMEOUT:
timeout = self.default_timeout
elif timeout == 0:
# ticket 21147 - avoid time.time() related precision issues
timeout = -1
return None if timeout is None else time.time() + timeout
def make_key(self, key, version=None):
"""
Construct the key used by all other methods. By default, use the
key_func to generate a key (which, by default, prepends the
`key_prefix` and `version`). A different key function can be provided
at the time of cache construction; alternatively, you can subclass the
cache backend to provide custom key making behavior.
"""
if version is None:
version = self.version
return self.key_func(key, self.key_prefix, version)
def validate_key(self, key):
"""
Warn about keys that would not be portable to the memcached
backend. This encourages (but does not force) writing backend-portable
cache code.
"""
for warning in memcache_key_warnings(key):
warnings.warn(warning, CacheKeyWarning)
def make_and_validate_key(self, key, version=None):
"""Helper to make and validate keys."""
key = self.make_key(key, version=version)
self.validate_key(key)
return key
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
"""
Set a value in the cache if the key does not already exist. If
timeout is given, use that timeout for the key; otherwise use the
default cache timeout.
Return True if the value was stored, False otherwise.
"""
raise NotImplementedError(
"subclasses of BaseCache must provide an add() method"
)
async def aadd(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
return await sync_to_async(self.add, thread_sensitive=True)(
key, value, timeout, version
)
def get(self, key, default=None, version=None):
"""
Fetch a given key from the cache. If the key does not exist, return
default, which itself defaults to None.
"""
raise NotImplementedError("subclasses of BaseCache must provide a get() method")
async def aget(self, key, default=None, version=None):
return await sync_to_async(self.get, thread_sensitive=True)(
key, default, version
)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
"""
Set a value in the cache. If timeout is given, use that timeout for the
key; otherwise use the default cache timeout.
"""
raise NotImplementedError("subclasses of BaseCache must provide a set() method")
async def aset(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
return await sync_to_async(self.set, thread_sensitive=True)(
key, value, timeout, version
)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
"""
Update the key's expiry time using timeout. Return True if successful
or False if the key does not exist.
"""
raise NotImplementedError(
"subclasses of BaseCache must provide a touch() method"
)
async def atouch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
return await sync_to_async(self.touch, thread_sensitive=True)(
key, timeout, version
)
def delete(self, key, version=None):
"""
Delete a key from the cache and return whether it succeeded, failing
silently.
"""
raise NotImplementedError(
"subclasses of BaseCache must provide a delete() method"
)
async def adelete(self, key, version=None):
return await sync_to_async(self.delete, thread_sensitive=True)(key, version)
def get_many(self, keys, version=None):
"""
Fetch a bunch of keys from the cache. For certain backends (memcached,
pgsql) this can be *much* faster when fetching multiple values.
Return a dict mapping each key in keys to its value. If the given
key is missing, it will be missing from the response dict.
"""
d = {}
for k in keys:
val = self.get(k, self._missing_key, version=version)
if val is not self._missing_key:
d[k] = val
return d
async def aget_many(self, keys, version=None):
"""See get_many()."""
d = {}
for k in keys:
val = await self.aget(k, self._missing_key, version=version)
if val is not self._missing_key:
d[k] = val
return d
def get_or_set(self, key, default, timeout=DEFAULT_TIMEOUT, version=None):
"""
Fetch a given key from the cache. If the key does not exist,
add the key and set it to the default value. The default value can
also be any callable. If timeout is given, use that timeout for the
key; otherwise use the default cache timeout.
Return the value of the key stored or retrieved.
"""
val = self.get(key, self._missing_key, version=version)
if val is self._missing_key:
if callable(default):
default = default()
self.add(key, default, timeout=timeout, version=version)
# Fetch the value again to avoid a race condition if another caller
# added a value between the first get() and the add() above.
return self.get(key, default, version=version)
return val
async def aget_or_set(self, key, default, timeout=DEFAULT_TIMEOUT, version=None):
"""See get_or_set()."""
val = await self.aget(key, self._missing_key, version=version)
if val is self._missing_key:
if callable(default):
default = default()
await self.aadd(key, default, timeout=timeout, version=version)
# Fetch the value again to avoid a race condition if another caller
# added a value between the first aget() and the aadd() above.
return await self.aget(key, default, version=version)
return val
def has_key(self, key, version=None):
"""
Return True if the key is in the cache and has not expired.
"""
return (
self.get(key, self._missing_key, version=version) is not self._missing_key
)
async def ahas_key(self, key, version=None):
return (
await self.aget(key, self._missing_key, version=version)
is not self._missing_key
)
def incr(self, key, delta=1, version=None):
"""
Add delta to value in the cache. If the key does not exist, raise a
ValueError exception.
"""
value = self.get(key, self._missing_key, version=version)
if value is self._missing_key:
raise ValueError("Key '%s' not found" % key)
new_value = value + delta
self.set(key, new_value, version=version)
return new_value
async def aincr(self, key, delta=1, version=None):
"""See incr()."""
value = await self.aget(key, self._missing_key, version=version)
if value is self._missing_key:
raise ValueError("Key '%s' not found" % key)
new_value = value + delta
await self.aset(key, new_value, version=version)
return new_value
def decr(self, key, delta=1, version=None):
"""
Subtract delta from value in the cache. If the key does not exist, raise
a ValueError exception.
"""
return self.incr(key, -delta, version=version)
async def adecr(self, key, delta=1, version=None):
return await self.aincr(key, -delta, version=version)
def __contains__(self, key):
"""
Return True if the key is in the cache and has not expired.
"""
# This is a separate method, rather than just a copy of has_key(),
# so that it always has the same functionality as has_key(), even
# if a subclass overrides it.
return self.has_key(key)
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
"""
Set a bunch of values in the cache at once from a dict of key/value
pairs. For certain backends (memcached), this is much more efficient
than calling set() multiple times.
If timeout is given, use that timeout for the key; otherwise use the
default cache timeout.
On backends that support it, return a list of keys that failed
insertion, or an empty list if all keys were inserted successfully.
"""
for key, value in data.items():
self.set(key, value, timeout=timeout, version=version)
return []
async def aset_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
for key, value in data.items():
await self.aset(key, value, timeout=timeout, version=version)
return []
def delete_many(self, keys, version=None):
"""
Delete a bunch of values in the cache at once. For certain backends
(memcached), this is much more efficient than calling delete() multiple
times.
"""
for key in keys:
self.delete(key, version=version)
async def adelete_many(self, keys, version=None):
for key in keys:
await self.adelete(key, version=version)
def clear(self):
"""Remove *all* values from the cache at once."""
raise NotImplementedError(
"subclasses of BaseCache must provide a clear() method"
)
async def aclear(self):
return await sync_to_async(self.clear, thread_sensitive=True)()
def incr_version(self, key, delta=1, version=None):
"""
Add delta to the cache version for the supplied key. Return the new
version.
"""
if version is None:
version = self.version
value = self.get(key, self._missing_key, version=version)
if value is self._missing_key:
raise ValueError("Key '%s' not found" % key)
self.set(key, value, version=version + delta)
self.delete(key, version=version)
return version + delta
async def aincr_version(self, key, delta=1, version=None):
"""See incr_version()."""
if version is None:
version = self.version
value = await self.aget(key, self._missing_key, version=version)
if value is self._missing_key:
raise ValueError("Key '%s' not found" % key)
await self.aset(key, value, version=version + delta)
await self.adelete(key, version=version)
return version + delta
def decr_version(self, key, delta=1, version=None):
"""
Subtract delta from the cache version for the supplied key. Return the
new version.
"""
return self.incr_version(key, -delta, version)
async def adecr_version(self, key, delta=1, version=None):
return await self.aincr_version(key, -delta, version)
def close(self, **kwargs):
"""Close the cache connection"""
pass
async def aclose(self, **kwargs):
pass
def memcache_key_warnings(key):
if len(key) > MEMCACHE_MAX_KEY_LENGTH:
yield (
"Cache key will cause errors if used with memcached: %r "
"(longer than %s)" % (key, MEMCACHE_MAX_KEY_LENGTH)
)
for char in key:
if ord(char) < 33 or ord(char) == 127:
yield (
"Cache key contains characters that will cause errors if "
"used with memcached: %r" % key
)
break
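
To illustrate the abstract API above, a toy dict-backed subclass that implements the required methods and ignores expiry; a sketch for illustration only, not a real backend.

from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache

class SimpleDictCache(BaseCache):
    """Toy backend: a plain dict, no expiry handling."""

    def __init__(self, location, params):
        super().__init__(params)
        self._data = {}

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        if key in self._data:
            return False
        self._data[key] = value
        return True

    def get(self, key, default=None, version=None):
        key = self.make_and_validate_key(key, version=version)
        return self._data.get(key, default)

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        self._data[key] = value

    def delete(self, key, version=None):
        key = self.make_and_validate_key(key, version=version)
        return self._data.pop(key, self._missing_key) is not self._missing_key

    def clear(self):
        self._data.clear()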

View File: django/core/cache/backends/db.py

@@ -0,0 +1,293 @@
"Database cache backend."
import base64
import pickle
from datetime import datetime, timezone
from django.conf import settings
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.db import DatabaseError, connections, models, router, transaction
from django.utils.timezone import now as tz_now
class Options:
"""A class that will quack like a Django model _meta class.
This allows cache operations to be controlled by the router.
"""
def __init__(self, table):
self.db_table = table
self.app_label = "django_cache"
self.model_name = "cacheentry"
self.verbose_name = "cache entry"
self.verbose_name_plural = "cache entries"
self.object_name = "CacheEntry"
self.abstract = False
self.managed = True
self.proxy = False
self.swapped = False
class BaseDatabaseCache(BaseCache):
def __init__(self, table, params):
super().__init__(params)
self._table = table
class CacheEntry:
_meta = Options(table)
self.cache_model_class = CacheEntry
class DatabaseCache(BaseDatabaseCache):
# This class uses cursors provided by the database connection. This means
# it reads expiration values as aware or naive datetimes, depending on the
# value of USE_TZ and whether the database supports time zones. The ORM's
# conversion and adaptation infrastructure is then used to avoid comparing
# aware and naive datetimes accidentally.
pickle_protocol = pickle.HIGHEST_PROTOCOL
def get(self, key, default=None, version=None):
return self.get_many([key], version).get(key, default)
def get_many(self, keys, version=None):
if not keys:
return {}
key_map = {
self.make_and_validate_key(key, version=version): key for key in keys
}
db = router.db_for_read(self.cache_model_class)
connection = connections[db]
quote_name = connection.ops.quote_name
table = quote_name(self._table)
with connection.cursor() as cursor:
cursor.execute(
"SELECT %s, %s, %s FROM %s WHERE %s IN (%s)"
% (
quote_name("cache_key"),
quote_name("value"),
quote_name("expires"),
table,
quote_name("cache_key"),
", ".join(["%s"] * len(key_map)),
),
list(key_map),
)
rows = cursor.fetchall()
result = {}
expired_keys = []
expression = models.Expression(output_field=models.DateTimeField())
converters = connection.ops.get_db_converters(
expression
) + expression.get_db_converters(connection)
for key, value, expires in rows:
for converter in converters:
expires = converter(expires, expression, connection)
if expires < tz_now():
expired_keys.append(key)
else:
value = connection.ops.process_clob(value)
value = pickle.loads(base64.b64decode(value.encode()))
result[key_map.get(key)] = value
self._base_delete_many(expired_keys)
return result
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
self._base_set("set", key, value, timeout)
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._base_set("add", key, value, timeout)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._base_set("touch", key, None, timeout)
def _base_set(self, mode, key, value, timeout=DEFAULT_TIMEOUT):
timeout = self.get_backend_timeout(timeout)
db = router.db_for_write(self.cache_model_class)
connection = connections[db]
quote_name = connection.ops.quote_name
table = quote_name(self._table)
with connection.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM %s" % table)
num = cursor.fetchone()[0]
now = tz_now()
now = now.replace(microsecond=0)
if timeout is None:
exp = datetime.max
else:
tz = timezone.utc if settings.USE_TZ else None
exp = datetime.fromtimestamp(timeout, tz=tz)
exp = exp.replace(microsecond=0)
if num > self._max_entries:
self._cull(db, cursor, now, num)
pickled = pickle.dumps(value, self.pickle_protocol)
# The DB column is expecting a string, so make sure the value is a
# string, not bytes. Refs #19274.
b64encoded = base64.b64encode(pickled).decode("latin1")
try:
# Note: typecasting for datetimes is needed by some 3rd party
# database backends. All core backends work without typecasting,
# so be careful about changes here - the test suite will NOT pick up
# regressions.
with transaction.atomic(using=db):
cursor.execute(
"SELECT %s, %s FROM %s WHERE %s = %%s"
% (
quote_name("cache_key"),
quote_name("expires"),
table,
quote_name("cache_key"),
),
[key],
)
result = cursor.fetchone()
if result:
current_expires = result[1]
expression = models.Expression(
output_field=models.DateTimeField()
)
for converter in connection.ops.get_db_converters(
expression
) + expression.get_db_converters(connection):
current_expires = converter(
current_expires, expression, connection
)
exp = connection.ops.adapt_datetimefield_value(exp)
if result and mode == "touch":
cursor.execute(
"UPDATE %s SET %s = %%s WHERE %s = %%s"
% (table, quote_name("expires"), quote_name("cache_key")),
[exp, key],
)
elif result and (
mode == "set" or (mode == "add" and current_expires < now)
):
cursor.execute(
"UPDATE %s SET %s = %%s, %s = %%s WHERE %s = %%s"
% (
table,
quote_name("value"),
quote_name("expires"),
quote_name("cache_key"),
),
[b64encoded, exp, key],
)
elif mode != "touch":
cursor.execute(
"INSERT INTO %s (%s, %s, %s) VALUES (%%s, %%s, %%s)"
% (
table,
quote_name("cache_key"),
quote_name("value"),
quote_name("expires"),
),
[key, b64encoded, exp],
)
else:
return False # touch failed.
except DatabaseError:
# To be threadsafe, updates/inserts are allowed to fail silently
return False
else:
return True
def delete(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
return self._base_delete_many([key])
def delete_many(self, keys, version=None):
keys = [self.make_and_validate_key(key, version=version) for key in keys]
self._base_delete_many(keys)
def _base_delete_many(self, keys):
if not keys:
return False
db = router.db_for_write(self.cache_model_class)
connection = connections[db]
quote_name = connection.ops.quote_name
table = quote_name(self._table)
with connection.cursor() as cursor:
cursor.execute(
"DELETE FROM %s WHERE %s IN (%s)"
% (
table,
quote_name("cache_key"),
", ".join(["%s"] * len(keys)),
),
keys,
)
return bool(cursor.rowcount)
def has_key(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
db = router.db_for_read(self.cache_model_class)
connection = connections[db]
quote_name = connection.ops.quote_name
now = tz_now().replace(microsecond=0, tzinfo=None)
with connection.cursor() as cursor:
cursor.execute(
"SELECT %s FROM %s WHERE %s = %%s and %s > %%s"
% (
quote_name("cache_key"),
quote_name(self._table),
quote_name("cache_key"),
quote_name("expires"),
),
[key, connection.ops.adapt_datetimefield_value(now)],
)
return cursor.fetchone() is not None
def _cull(self, db, cursor, now, num):
if self._cull_frequency == 0:
self.clear()
else:
connection = connections[db]
table = connection.ops.quote_name(self._table)
cursor.execute(
"DELETE FROM %s WHERE %s < %%s"
% (
table,
connection.ops.quote_name("expires"),
),
[connection.ops.adapt_datetimefield_value(now)],
)
deleted_count = cursor.rowcount
remaining_num = num - deleted_count
if remaining_num > self._max_entries:
cull_num = remaining_num // self._cull_frequency
cursor.execute(
connection.ops.cache_key_culling_sql() % table, [cull_num]
)
last_cache_key = cursor.fetchone()
if last_cache_key:
cursor.execute(
"DELETE FROM %s WHERE %s < %%s"
% (
table,
connection.ops.quote_name("cache_key"),
),
[last_cache_key[0]],
)
def clear(self):
db = router.db_for_write(self.cache_model_class)
connection = connections[db]
table = connection.ops.quote_name(self._table)
with connection.cursor() as cursor:
cursor.execute("DELETE FROM %s" % table)

View File: django/core/cache/backends/dummy.py

@@ -0,0 +1,34 @@
"Dummy cache backend"
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
class DummyCache(BaseCache):
def __init__(self, host, *args, **kwargs):
super().__init__(*args, **kwargs)
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
self.make_and_validate_key(key, version=version)
return True
def get(self, key, default=None, version=None):
self.make_and_validate_key(key, version=version)
return default
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
self.make_and_validate_key(key, version=version)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
self.make_and_validate_key(key, version=version)
return False
def delete(self, key, version=None):
self.make_and_validate_key(key, version=version)
return False
def has_key(self, key, version=None):
self.make_and_validate_key(key, version=version)
return False
def clear(self):
pass
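
A configuration sketch; the dummy backend validates keys but stores nothing, which is useful in development.

CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.dummy.DummyCache",
    }
}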

View File: django/core/cache/backends/filebased.py

@@ -0,0 +1,170 @@
"File-based cache backend"
import glob
import os
import pickle
import random
import tempfile
import time
import zlib
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.core.files import locks
from django.core.files.move import file_move_safe
from django.utils.crypto import md5
class FileBasedCache(BaseCache):
cache_suffix = ".djcache"
pickle_protocol = pickle.HIGHEST_PROTOCOL
def __init__(self, dir, params):
super().__init__(params)
self._dir = os.path.abspath(dir)
self._createdir()
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
if self.has_key(key, version):
return False
self.set(key, value, timeout, version)
return True
def get(self, key, default=None, version=None):
fname = self._key_to_file(key, version)
try:
with open(fname, "rb") as f:
if not self._is_expired(f):
return pickle.loads(zlib.decompress(f.read()))
except FileNotFoundError:
pass
return default
def _write_content(self, file, timeout, value):
expiry = self.get_backend_timeout(timeout)
file.write(pickle.dumps(expiry, self.pickle_protocol))
file.write(zlib.compress(pickle.dumps(value, self.pickle_protocol)))
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
self._createdir() # Cache dir can be deleted at any time.
fname = self._key_to_file(key, version)
self._cull() # make some room if necessary
fd, tmp_path = tempfile.mkstemp(dir=self._dir)
renamed = False
try:
with open(fd, "wb") as f:
self._write_content(f, timeout, value)
file_move_safe(tmp_path, fname, allow_overwrite=True)
renamed = True
finally:
if not renamed:
os.remove(tmp_path)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
try:
with open(self._key_to_file(key, version), "r+b") as f:
try:
locks.lock(f, locks.LOCK_EX)
if self._is_expired(f):
return False
else:
previous_value = pickle.loads(zlib.decompress(f.read()))
f.seek(0)
self._write_content(f, timeout, previous_value)
return True
finally:
locks.unlock(f)
except FileNotFoundError:
return False
def delete(self, key, version=None):
return self._delete(self._key_to_file(key, version))
def _delete(self, fname):
if not fname.startswith(self._dir) or not os.path.exists(fname):
return False
try:
os.remove(fname)
except FileNotFoundError:
# The file may have been removed by another process.
return False
return True
def has_key(self, key, version=None):
fname = self._key_to_file(key, version)
try:
with open(fname, "rb") as f:
return not self._is_expired(f)
except FileNotFoundError:
return False
def _cull(self):
"""
Remove random cache entries if max_entries is reached at a ratio
of num_entries / cull_frequency. A value of 0 for CULL_FREQUENCY means
that the entire cache will be purged.
"""
filelist = self._list_cache_files()
num_entries = len(filelist)
if num_entries < self._max_entries:
return # return early if no culling is required
if self._cull_frequency == 0:
return self.clear() # Clear the cache when CULL_FREQUENCY = 0
# Delete a random selection of entries
filelist = random.sample(filelist, int(num_entries / self._cull_frequency))
for fname in filelist:
self._delete(fname)
def _createdir(self):
# Set the umask because os.makedirs() doesn't apply the "mode" argument
# to intermediate-level directories.
old_umask = os.umask(0o077)
try:
os.makedirs(self._dir, 0o700, exist_ok=True)
finally:
os.umask(old_umask)
def _key_to_file(self, key, version=None):
"""
Convert a key into a cache file path. Basically this is the
root cache path joined with the md5sum of the key and a suffix.
"""
key = self.make_and_validate_key(key, version=version)
return os.path.join(
self._dir,
"".join(
[
md5(key.encode(), usedforsecurity=False).hexdigest(),
self.cache_suffix,
]
),
)
def clear(self):
"""
Remove all the cache files.
"""
for fname in self._list_cache_files():
self._delete(fname)
def _is_expired(self, f):
"""
Take an open cache file `f` and delete it if it's expired.
"""
try:
exp = pickle.load(f)
except EOFError:
exp = 0 # An empty file is considered expired.
if exp is not None and exp < time.time():
f.close() # On Windows a file has to be closed before deleting
self._delete(f.name)
return True
return False
def _list_cache_files(self):
"""
Get a list of paths to all the cache files. These are all the files
in the root cache dir that end with the cache_suffix.
"""
return [
os.path.join(self._dir, fname)
for fname in glob.glob1(self._dir, "*%s" % self.cache_suffix)
]
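
A configuration sketch, assuming a hypothetical cache directory; LOCATION should be an absolute path (see the caches.W003 check later in this commit).

CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.filebased.FileBasedCache",
        "LOCATION": "/var/tmp/django_cache",  # hypothetical absolute path
    }
}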

View File: django/core/cache/backends/locmem.py

@@ -0,0 +1,117 @@
"Thread-safe in-memory cache backend."
import pickle
import time
from collections import OrderedDict
from threading import Lock
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
# Global in-memory store of cache data. Keyed by name, to provide
# multiple named local memory caches.
_caches = {}
_expire_info = {}
_locks = {}
class LocMemCache(BaseCache):
pickle_protocol = pickle.HIGHEST_PROTOCOL
def __init__(self, name, params):
super().__init__(params)
self._cache = _caches.setdefault(name, OrderedDict())
self._expire_info = _expire_info.setdefault(name, {})
self._lock = _locks.setdefault(name, Lock())
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
pickled = pickle.dumps(value, self.pickle_protocol)
with self._lock:
if self._has_expired(key):
self._set(key, pickled, timeout)
return True
return False
def get(self, key, default=None, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
return default
pickled = self._cache[key]
self._cache.move_to_end(key, last=False)
return pickle.loads(pickled)
def _set(self, key, value, timeout=DEFAULT_TIMEOUT):
if len(self._cache) >= self._max_entries:
self._cull()
self._cache[key] = value
self._cache.move_to_end(key, last=False)
self._expire_info[key] = self.get_backend_timeout(timeout)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
pickled = pickle.dumps(value, self.pickle_protocol)
with self._lock:
self._set(key, pickled, timeout)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
return False
self._expire_info[key] = self.get_backend_timeout(timeout)
return True
def incr(self, key, delta=1, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
raise ValueError("Key '%s' not found" % key)
pickled = self._cache[key]
value = pickle.loads(pickled)
new_value = value + delta
pickled = pickle.dumps(new_value, self.pickle_protocol)
self._cache[key] = pickled
self._cache.move_to_end(key, last=False)
return new_value
def has_key(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
return False
return True
def _has_expired(self, key):
exp = self._expire_info.get(key, -1)
return exp is not None and exp <= time.time()
def _cull(self):
if self._cull_frequency == 0:
self._cache.clear()
self._expire_info.clear()
else:
count = len(self._cache) // self._cull_frequency
for i in range(count):
key, _ = self._cache.popitem()
del self._expire_info[key]
def _delete(self, key):
try:
del self._cache[key]
del self._expire_info[key]
except KeyError:
return False
return True
def delete(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
return self._delete(key)
def clear(self):
with self._lock:
self._cache.clear()
self._expire_info.clear()
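
A configuration sketch; LOCATION names one of the per-process stores in _caches above, so distinct names give isolated in-memory caches.

CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.locmem.LocMemCache",
        "LOCATION": "unique-snowflake",  # hypothetical store name
    }
}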

View File: django/core/cache/backends/memcached.py

@@ -0,0 +1,188 @@
"Memcached cache backend"
import re
import time
from django.core.cache.backends.base import (
DEFAULT_TIMEOUT,
BaseCache,
InvalidCacheKey,
memcache_key_warnings,
)
from django.utils.functional import cached_property
class BaseMemcachedCache(BaseCache):
def __init__(self, server, params, library, value_not_found_exception):
super().__init__(params)
if isinstance(server, str):
self._servers = re.split("[;,]", server)
else:
self._servers = server
# Exception type raised by the underlying client library for a
# nonexistent key.
self.LibraryValueNotFoundException = value_not_found_exception
self._lib = library
self._class = library.Client
self._options = params.get("OPTIONS") or {}
@property
def client_servers(self):
return self._servers
@cached_property
def _cache(self):
"""
Implement transparent thread-safe access to a memcached client.
"""
return self._class(self.client_servers, **self._options)
def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
"""
Memcached deals with long (> 30 days) timeouts in a special
way. Call this function to obtain a safe value for your timeout.
"""
if timeout == DEFAULT_TIMEOUT:
timeout = self.default_timeout
if timeout is None:
# Using 0 in memcache sets a non-expiring timeout.
return 0
elif int(timeout) == 0:
# Other cache backends treat 0 as set-and-expire. To achieve this
# in memcache backends, a negative timeout must be passed.
timeout = -1
if timeout > 2592000: # 60*60*24*30, 30 days
# See https://github.com/memcached/memcached/wiki/Programming#expiration
# "Expiration times can be set from 0, meaning "never expire", to
# 30 days. Any time higher than 30 days is interpreted as a Unix
# timestamp date. If you want to expire an object on January 1st of
# next year, this is how you do that."
#
# This means that we have to switch to absolute timestamps.
timeout += int(time.time())
return int(timeout)
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.add(key, value, self.get_backend_timeout(timeout))
def get(self, key, default=None, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.get(key, default)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
if not self._cache.set(key, value, self.get_backend_timeout(timeout)):
# Make sure the key doesn't keep its old value in case of failure
# to set (memcached's 1MB limit).
self._cache.delete(key)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return bool(self._cache.touch(key, self.get_backend_timeout(timeout)))
def delete(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
return bool(self._cache.delete(key))
def get_many(self, keys, version=None):
key_map = {
self.make_and_validate_key(key, version=version): key for key in keys
}
ret = self._cache.get_multi(key_map.keys())
return {key_map[k]: v for k, v in ret.items()}
def close(self, **kwargs):
# Many clients don't clean up connections properly.
self._cache.disconnect_all()
def incr(self, key, delta=1, version=None):
key = self.make_and_validate_key(key, version=version)
try:
# Memcached doesn't support negative delta.
if delta < 0:
val = self._cache.decr(key, -delta)
else:
val = self._cache.incr(key, delta)
# Normalize an exception raised by the underlying client library to
# ValueError in the event of a nonexistent key when calling
# incr()/decr().
except self.LibraryValueNotFoundException:
val = None
if val is None:
raise ValueError("Key '%s' not found" % key)
return val
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
safe_data = {}
original_keys = {}
for key, value in data.items():
safe_key = self.make_and_validate_key(key, version=version)
safe_data[safe_key] = value
original_keys[safe_key] = key
failed_keys = self._cache.set_multi(
safe_data, self.get_backend_timeout(timeout)
)
return [original_keys[k] for k in failed_keys]
def delete_many(self, keys, version=None):
keys = [self.make_and_validate_key(key, version=version) for key in keys]
self._cache.delete_multi(keys)
def clear(self):
self._cache.flush_all()
def validate_key(self, key):
for warning in memcache_key_warnings(key):
raise InvalidCacheKey(warning)
class PyLibMCCache(BaseMemcachedCache):
"An implementation of a cache binding using pylibmc"
def __init__(self, server, params):
import pylibmc
super().__init__(
server, params, library=pylibmc, value_not_found_exception=pylibmc.NotFound
)
@property
def client_servers(self):
output = []
for server in self._servers:
output.append(server[5:] if server.startswith("unix:") else server)
return output
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
if timeout == 0:
return self._cache.delete(key)
return self._cache.touch(key, self.get_backend_timeout(timeout))
def close(self, **kwargs):
# libmemcached manages its own connections. Don't call disconnect_all()
# as it resets the failover state and creates unnecessary reconnects.
pass
class PyMemcacheCache(BaseMemcachedCache):
"""An implementation of a cache binding using pymemcache."""
def __init__(self, server, params):
import pymemcache.serde
super().__init__(
server, params, library=pymemcache, value_not_found_exception=KeyError
)
self._class = self._lib.HashClient
self._options = {
"allow_unicode_keys": True,
"default_noreply": False,
"serde": pymemcache.serde.pickle_serde,
**self._options,
}
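
A configuration sketch for the pymemcache binding; server addresses are hypothetical, and per __init__ above multiple servers may also be given as a single ";"- or ","-separated string.

CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.memcached.PyMemcacheCache",
        "LOCATION": ["127.0.0.1:11211", "127.0.0.1:11212"],  # hypothetical servers
    }
}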

View File: django/core/cache/backends/redis.py

@@ -0,0 +1,233 @@
"""Redis cache backend."""
import pickle
import random
import re
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
class RedisSerializer:
def __init__(self, protocol=None):
self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol
def dumps(self, obj):
# Only skip pickling for exact integers; int subclasses such as bool
# must still be pickled.
if type(obj) is int:
return obj
return pickle.dumps(obj, self.protocol)
def loads(self, data):
try:
return int(data)
except ValueError:
return pickle.loads(data)
class RedisCacheClient:
def __init__(
self,
servers,
serializer=None,
pool_class=None,
parser_class=None,
**options,
):
import redis
self._lib = redis
self._servers = servers
self._pools = {}
self._client = self._lib.Redis
if isinstance(pool_class, str):
pool_class = import_string(pool_class)
self._pool_class = pool_class or self._lib.ConnectionPool
if isinstance(serializer, str):
serializer = import_string(serializer)
if callable(serializer):
serializer = serializer()
self._serializer = serializer or RedisSerializer()
if isinstance(parser_class, str):
parser_class = import_string(parser_class)
parser_class = parser_class or self._lib.connection.DefaultParser
self._pool_options = {"parser_class": parser_class, **options}
def _get_connection_pool_index(self, write):
# Write to the first server. Read from other servers if there are more,
# otherwise read from the first server.
if write or len(self._servers) == 1:
return 0
return random.randint(1, len(self._servers) - 1)
def _get_connection_pool(self, write):
index = self._get_connection_pool_index(write)
if index not in self._pools:
self._pools[index] = self._pool_class.from_url(
self._servers[index],
**self._pool_options,
)
return self._pools[index]
def get_client(self, key=None, *, write=False):
# key is used so that the method signature remains the same and a
# custom cache client can be implemented that might require the key
# to select the server, e.g. for sharding.
pool = self._get_connection_pool(write)
return self._client(connection_pool=pool)
def add(self, key, value, timeout):
client = self.get_client(key, write=True)
value = self._serializer.dumps(value)
if timeout == 0:
if ret := bool(client.set(key, value, nx=True)):
client.delete(key)
return ret
else:
return bool(client.set(key, value, ex=timeout, nx=True))
def get(self, key, default):
client = self.get_client(key)
value = client.get(key)
return default if value is None else self._serializer.loads(value)
def set(self, key, value, timeout):
client = self.get_client(key, write=True)
value = self._serializer.dumps(value)
if timeout == 0:
client.delete(key)
else:
client.set(key, value, ex=timeout)
def touch(self, key, timeout):
client = self.get_client(key, write=True)
if timeout is None:
return bool(client.persist(key))
else:
return bool(client.expire(key, timeout))
def delete(self, key):
client = self.get_client(key, write=True)
return bool(client.delete(key))
def get_many(self, keys):
client = self.get_client(None)
ret = client.mget(keys)
return {
k: self._serializer.loads(v) for k, v in zip(keys, ret) if v is not None
}
def has_key(self, key):
client = self.get_client(key)
return bool(client.exists(key))
def incr(self, key, delta):
client = self.get_client(key, write=True)
if not client.exists(key):
raise ValueError("Key '%s' not found." % key)
return client.incr(key, delta)
def set_many(self, data, timeout):
client = self.get_client(None, write=True)
pipeline = client.pipeline()
pipeline.mset({k: self._serializer.dumps(v) for k, v in data.items()})
if timeout is not None:
# Setting timeout for each key as redis does not support timeout
# with mset().
for key in data:
pipeline.expire(key, timeout)
pipeline.execute()
def delete_many(self, keys):
client = self.get_client(None, write=True)
client.delete(*keys)
def clear(self):
client = self.get_client(None, write=True)
return bool(client.flushdb())
class RedisCache(BaseCache):
def __init__(self, server, params):
super().__init__(params)
if isinstance(server, str):
self._servers = re.split("[;,]", server)
else:
self._servers = server
self._class = RedisCacheClient
self._options = params.get("OPTIONS", {})
@cached_property
def _cache(self):
return self._class(self._servers, **self._options)
def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
if timeout == DEFAULT_TIMEOUT:
timeout = self.default_timeout
# The key will be made persistent if None is used as a timeout.
# Non-positive values will cause the key to be deleted.
return None if timeout is None else max(0, int(timeout))
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.add(key, value, self.get_backend_timeout(timeout))
def get(self, key, default=None, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.get(key, default)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
self._cache.set(key, value, self.get_backend_timeout(timeout))
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.touch(key, self.get_backend_timeout(timeout))
def delete(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.delete(key)
def get_many(self, keys, version=None):
key_map = {
self.make_and_validate_key(key, version=version): key for key in keys
}
ret = self._cache.get_many(key_map.keys())
return {key_map[k]: v for k, v in ret.items()}
def has_key(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.has_key(key)
def incr(self, key, delta=1, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.incr(key, delta)
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
if not data:
return []
safe_data = {}
for key, value in data.items():
key = self.make_and_validate_key(key, version=version)
safe_data[key] = value
self._cache.set_many(safe_data, self.get_backend_timeout(timeout))
return []
def delete_many(self, keys, version=None):
if not keys:
return
safe_keys = [self.make_and_validate_key(key, version=version) for key in keys]
self._cache.delete_many(safe_keys)
def clear(self):
return self._cache.clear()
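
A configuration sketch; the URLs are hypothetical. Per _get_connection_pool_index above, writes always go to the first server and reads are spread across the remaining ones.

CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.redis.RedisCache",
        "LOCATION": [
            "redis://127.0.0.1:6379",  # leader: all writes go here
            "redis://127.0.0.1:6380",  # hypothetical replica used for reads
        ],
    }
}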

View File: django/core/cache/utils.py

@@ -0,0 +1,12 @@
from django.utils.crypto import md5
TEMPLATE_FRAGMENT_KEY_TEMPLATE = "template.cache.%s.%s"
def make_template_fragment_key(fragment_name, vary_on=None):
hasher = md5(usedforsecurity=False)
if vary_on is not None:
for arg in vary_on:
hasher.update(str(arg).encode())
hasher.update(b":")
return TEMPLATE_FRAGMENT_KEY_TEMPLATE % (fragment_name, hasher.hexdigest())
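
A usage sketch for invalidating one cached {% cache %} template fragment; the fragment name and vary_on values are hypothetical.

from django.core.cache import cache
from django.core.cache.utils import make_template_fragment_key

key = make_template_fragment_key("sidebar", ["alice"])  # hypothetical fragment/user
cache.delete(key)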

View File: django/core/checks/__init__.py

@@ -0,0 +1,47 @@
from .messages import (
CRITICAL,
DEBUG,
ERROR,
INFO,
WARNING,
CheckMessage,
Critical,
Debug,
Error,
Info,
Warning,
)
from .registry import Tags, register, run_checks, tag_exists
# Import these to force registration of checks
import django.core.checks.async_checks # NOQA isort:skip
import django.core.checks.caches # NOQA isort:skip
import django.core.checks.compatibility.django_4_0 # NOQA isort:skip
import django.core.checks.database # NOQA isort:skip
import django.core.checks.files # NOQA isort:skip
import django.core.checks.model_checks # NOQA isort:skip
import django.core.checks.security.base # NOQA isort:skip
import django.core.checks.security.csrf # NOQA isort:skip
import django.core.checks.security.sessions # NOQA isort:skip
import django.core.checks.templates # NOQA isort:skip
import django.core.checks.translation # NOQA isort:skip
import django.core.checks.urls # NOQA isort:skip
__all__ = [
"CheckMessage",
"Debug",
"Info",
"Warning",
"Error",
"Critical",
"DEBUG",
"INFO",
"WARNING",
"ERROR",
"CRITICAL",
"register",
"run_checks",
"tag_exists",
"Tags",
]

View File: django/core/checks/async_checks.py

@@ -0,0 +1,16 @@
import os
from . import Error, Tags, register
E001 = Error(
"You should not set the DJANGO_ALLOW_ASYNC_UNSAFE environment variable in "
"deployment. This disables async safety protection.",
id="async.E001",
)
@register(Tags.async_support, deploy=True)
def check_async_unsafe(app_configs, **kwargs):
if os.environ.get("DJANGO_ALLOW_ASYNC_UNSAFE"):
return [E001]
return []

View File: django/core/checks/caches.py

@@ -0,0 +1,76 @@
import pathlib
from django.conf import settings
from django.core.cache import DEFAULT_CACHE_ALIAS, caches
from django.core.cache.backends.filebased import FileBasedCache
from . import Error, Tags, Warning, register
E001 = Error(
"You must define a '%s' cache in your CACHES setting." % DEFAULT_CACHE_ALIAS,
id="caches.E001",
)
@register(Tags.caches)
def check_default_cache_is_configured(app_configs, **kwargs):
if DEFAULT_CACHE_ALIAS not in settings.CACHES:
return [E001]
return []
@register(Tags.caches, deploy=True)
def check_cache_location_not_exposed(app_configs, **kwargs):
errors = []
for name in ("MEDIA_ROOT", "STATIC_ROOT", "STATICFILES_DIRS"):
setting = getattr(settings, name, None)
if not setting:
continue
if name == "STATICFILES_DIRS":
paths = set()
for staticfiles_dir in setting:
if isinstance(staticfiles_dir, (list, tuple)):
_, staticfiles_dir = staticfiles_dir
paths.add(pathlib.Path(staticfiles_dir).resolve())
else:
paths = {pathlib.Path(setting).resolve()}
for alias in settings.CACHES:
cache = caches[alias]
if not isinstance(cache, FileBasedCache):
continue
cache_path = pathlib.Path(cache._dir).resolve()
if any(path == cache_path for path in paths):
relation = "matches"
elif any(path in cache_path.parents for path in paths):
relation = "is inside"
elif any(cache_path in path.parents for path in paths):
relation = "contains"
else:
continue
errors.append(
Warning(
f"Your '{alias}' cache configuration might expose your cache "
f"or lead to corruption of your data because its LOCATION "
f"{relation} {name}.",
id="caches.W002",
)
)
return errors
@register(Tags.caches)
def check_file_based_cache_is_absolute(app_configs, **kwargs):
errors = []
for alias, config in settings.CACHES.items():
cache = caches[alias]
if not isinstance(cache, FileBasedCache):
continue
if not pathlib.Path(config["LOCATION"]).is_absolute():
errors.append(
Warning(
f"Your '{alias}' cache LOCATION path is relative. Use an "
f"absolute path instead.",
id="caches.W003",
)
)
return errors

View File: django/core/checks/compatibility/django_4_0.py

@@ -0,0 +1,20 @@
from django.conf import settings
from .. import Error, Tags, register
@register(Tags.compatibility)
def check_csrf_trusted_origins(app_configs, **kwargs):
errors = []
for origin in settings.CSRF_TRUSTED_ORIGINS:
if "://" not in origin:
errors.append(
Error(
"As of Django 4.0, the values in the CSRF_TRUSTED_ORIGINS "
"setting must start with a scheme (usually http:// or "
"https://) but found %s. See the release notes for details."
% origin,
id="4_0.E001",
)
)
return errors

View File: django/core/checks/database.py

@@ -0,0 +1,14 @@
from django.db import connections
from . import Tags, register
@register(Tags.database)
def check_database_backends(databases=None, **kwargs):
if databases is None:
return []
issues = []
for alias in databases:
conn = connections[alias]
issues.extend(conn.validation.check(**kwargs))
return issues

View File: django/core/checks/files.py

@@ -0,0 +1,19 @@
from pathlib import Path
from django.conf import settings
from . import Error, Tags, register
@register(Tags.files)
def check_setting_file_upload_temp_dir(app_configs, **kwargs):
setting = getattr(settings, "FILE_UPLOAD_TEMP_DIR", None)
if setting and not Path(setting).is_dir():
return [
Error(
f"The FILE_UPLOAD_TEMP_DIR setting refers to the nonexistent "
f"directory '{setting}'.",
id="files.E001",
),
]
return []

View File: django/core/checks/messages.py

@@ -0,0 +1,81 @@
# Levels
DEBUG = 10
INFO = 20
WARNING = 30
ERROR = 40
CRITICAL = 50
class CheckMessage:
def __init__(self, level, msg, hint=None, obj=None, id=None):
if not isinstance(level, int):
raise TypeError("The first argument should be level.")
self.level = level
self.msg = msg
self.hint = hint
self.obj = obj
self.id = id
def __eq__(self, other):
return isinstance(other, self.__class__) and all(
getattr(self, attr) == getattr(other, attr)
for attr in ["level", "msg", "hint", "obj", "id"]
)
def __str__(self):
from django.db import models
if self.obj is None:
obj = "?"
elif isinstance(self.obj, models.base.ModelBase):
# We need to hardcode the ModelBase case because its __str__ method
# doesn't return "applabel.modellabel" and cannot be changed.
obj = self.obj._meta.label
else:
obj = str(self.obj)
id = "(%s) " % self.id if self.id else ""
hint = "\n\tHINT: %s" % self.hint if self.hint else ""
return "%s: %s%s%s" % (obj, id, self.msg, hint)
def __repr__(self):
return "<%s: level=%r, msg=%r, hint=%r, obj=%r, id=%r>" % (
self.__class__.__name__,
self.level,
self.msg,
self.hint,
self.obj,
self.id,
)
def is_serious(self, level=ERROR):
return self.level >= level
def is_silenced(self):
from django.conf import settings
return self.id in settings.SILENCED_SYSTEM_CHECKS
class Debug(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(DEBUG, *args, **kwargs)
class Info(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(INFO, *args, **kwargs)
class Warning(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(WARNING, *args, **kwargs)
class Error(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(ERROR, *args, **kwargs)
class Critical(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(CRITICAL, *args, **kwargs)
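
A small sketch of how these message classes render; the id, obj, and hint values are hypothetical.

from django.core.checks import Warning

msg = Warning(
    "Example problem.",
    hint="Example remedy.",
    obj="myapp.MyModel",  # hypothetical object label
    id="myapp.W001",      # hypothetical check id
)
# str(msg) -> "myapp.MyModel: (myapp.W001) Example problem.\n\tHINT: Example remedy."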

View File: django/core/checks/model_checks.py

@@ -0,0 +1,227 @@
import inspect
import types
from collections import defaultdict
from itertools import chain
from django.apps import apps
from django.conf import settings
from django.core.checks import Error, Tags, Warning, register
@register(Tags.models)
def check_all_models(app_configs=None, **kwargs):
db_table_models = defaultdict(list)
indexes = defaultdict(list)
constraints = defaultdict(list)
errors = []
if app_configs is None:
models = apps.get_models()
else:
models = chain.from_iterable(
app_config.get_models() for app_config in app_configs
)
for model in models:
if model._meta.managed and not model._meta.proxy:
db_table_models[model._meta.db_table].append(model._meta.label)
if not inspect.ismethod(model.check):
errors.append(
Error(
"The '%s.check()' class method is currently overridden by %r."
% (model.__name__, model.check),
obj=model,
id="models.E020",
)
)
else:
errors.extend(model.check(**kwargs))
for model_index in model._meta.indexes:
indexes[model_index.name].append(model._meta.label)
for model_constraint in model._meta.constraints:
constraints[model_constraint.name].append(model._meta.label)
if settings.DATABASE_ROUTERS:
error_class, error_id = Warning, "models.W035"
error_hint = (
"You have configured settings.DATABASE_ROUTERS. Verify that %s "
"are correctly routed to separate databases."
)
else:
error_class, error_id = Error, "models.E028"
error_hint = None
for db_table, model_labels in db_table_models.items():
if len(model_labels) != 1:
model_labels_str = ", ".join(model_labels)
errors.append(
error_class(
"db_table '%s' is used by multiple models: %s."
% (db_table, model_labels_str),
obj=db_table,
hint=(error_hint % model_labels_str) if error_hint else None,
id=error_id,
)
)
for index_name, model_labels in indexes.items():
if len(model_labels) > 1:
model_labels = set(model_labels)
errors.append(
Error(
"index name '%s' is not unique %s %s."
% (
index_name,
"for model" if len(model_labels) == 1 else "among models:",
", ".join(sorted(model_labels)),
),
id="models.E029" if len(model_labels) == 1 else "models.E030",
),
)
for constraint_name, model_labels in constraints.items():
if len(model_labels) > 1:
model_labels = set(model_labels)
errors.append(
Error(
"constraint name '%s' is not unique %s %s."
% (
constraint_name,
"for model" if len(model_labels) == 1 else "among models:",
", ".join(sorted(model_labels)),
),
id="models.E031" if len(model_labels) == 1 else "models.E032",
),
)
return errors
def _check_lazy_references(apps, ignore=None):
"""
Ensure all lazy (i.e. string) model references have been resolved.
Lazy references are used in various places throughout Django, primarily in
related fields and model signals. Identify those common cases and provide
more helpful error messages for them.
The ignore parameter is used by StateApps to exclude swappable models from
this check.
"""
pending_models = set(apps._pending_operations) - (ignore or set())
# Short circuit if there aren't any errors.
if not pending_models:
return []
from django.db.models import signals
model_signals = {
signal: name
for name, signal in vars(signals).items()
if isinstance(signal, signals.ModelSignal)
}
def extract_operation(obj):
"""
Take a callable found in Apps._pending_operations and identify the
original callable passed to Apps.lazy_model_operation(). If that
callable was a partial, return the inner, non-partial function and
any arguments and keyword arguments that were supplied with it.
obj is a callback defined locally in Apps.lazy_model_operation() and
annotated there with a `func` attribute so as to imitate a partial.
"""
operation, args, keywords = obj, [], {}
while hasattr(operation, "func"):
args.extend(getattr(operation, "args", []))
keywords.update(getattr(operation, "keywords", {}))
operation = operation.func
return operation, args, keywords
def app_model_error(model_key):
try:
apps.get_app_config(model_key[0])
model_error = "app '%s' doesn't provide model '%s'" % model_key
except LookupError:
model_error = "app '%s' isn't installed" % model_key[0]
return model_error
# Here are several functions which return CheckMessage instances for the
# most common usages of lazy operations throughout Django. These functions
# take the model that was being waited on as an (app_label, modelname)
# pair, the original lazy function, and its positional and keyword args as
# determined by extract_operation().
def field_error(model_key, func, args, keywords):
error_msg = (
"The field %(field)s was declared with a lazy reference "
"to '%(model)s', but %(model_error)s."
)
params = {
"model": ".".join(model_key),
"field": keywords["field"],
"model_error": app_model_error(model_key),
}
return Error(error_msg % params, obj=keywords["field"], id="fields.E307")
def signal_connect_error(model_key, func, args, keywords):
error_msg = (
"%(receiver)s was connected to the '%(signal)s' signal with a "
"lazy reference to the sender '%(model)s', but %(model_error)s."
)
receiver = args[0]
# The receiver is either a function or an instance of a class
# defining a `__call__` method.
if isinstance(receiver, types.FunctionType):
description = "The function '%s'" % receiver.__name__
elif isinstance(receiver, types.MethodType):
description = "Bound method '%s.%s'" % (
receiver.__self__.__class__.__name__,
receiver.__name__,
)
else:
description = "An instance of class '%s'" % receiver.__class__.__name__
signal_name = model_signals.get(func.__self__, "unknown")
params = {
"model": ".".join(model_key),
"receiver": description,
"signal": signal_name,
"model_error": app_model_error(model_key),
}
return Error(error_msg % params, obj=receiver.__module__, id="signals.E001")
def default_error(model_key, func, args, keywords):
error_msg = (
"%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
)
params = {
"op": func,
"model": ".".join(model_key),
"model_error": app_model_error(model_key),
}
return Error(error_msg % params, obj=func, id="models.E022")
# Maps common uses of lazy operations to corresponding error functions
# defined above. If a key maps to None, no error will be produced.
# default_error() will be used for usages that don't appear in this dict.
known_lazy = {
("django.db.models.fields.related", "resolve_related_class"): field_error,
("django.db.models.fields.related", "set_managed"): None,
("django.dispatch.dispatcher", "connect"): signal_connect_error,
}
def build_error(model_key, func, args, keywords):
key = (func.__module__, func.__name__)
error_fn = known_lazy.get(key, default_error)
return error_fn(model_key, func, args, keywords) if error_fn else None
return sorted(
filter(
None,
(
build_error(model_key, *extract_operation(func))
for model_key in pending_models
for func in apps._pending_operations[model_key]
),
),
key=lambda error: error.msg,
)
@register(Tags.models)
def check_lazy_references(app_configs=None, **kwargs):
return _check_lazy_references(apps)

View File: django/core/checks/registry.py

@@ -0,0 +1,117 @@
from itertools import chain
from django.utils.inspect import func_accepts_kwargs
from django.utils.itercompat import is_iterable
class Tags:
"""
Built-in tags for internal checks.
"""
admin = "admin"
async_support = "async_support"
caches = "caches"
compatibility = "compatibility"
database = "database"
files = "files"
models = "models"
security = "security"
signals = "signals"
sites = "sites"
staticfiles = "staticfiles"
templates = "templates"
translation = "translation"
urls = "urls"
class CheckRegistry:
def __init__(self):
self.registered_checks = set()
self.deployment_checks = set()
def register(self, check=None, *tags, **kwargs):
"""
Can be used as a function or a decorator. Register given function
        `f` labeled with given `tags`. The function should receive **kwargs
        and return a list of Errors and Warnings.
Example::
registry = CheckRegistry()
@registry.register('mytag', 'anothertag')
def my_check(app_configs, **kwargs):
# ... perform checks and collect `errors` ...
return errors
# or
registry.register(my_check, 'mytag', 'anothertag')
"""
def inner(check):
if not func_accepts_kwargs(check):
raise TypeError(
"Check functions must accept keyword arguments (**kwargs)."
)
check.tags = tags
checks = (
self.deployment_checks
if kwargs.get("deploy")
else self.registered_checks
)
checks.add(check)
return check
if callable(check):
return inner(check)
else:
if check:
tags += (check,)
return inner
def run_checks(
self,
app_configs=None,
tags=None,
include_deployment_checks=False,
databases=None,
):
"""
        Run all registered checks and return a list of Errors and Warnings.
"""
errors = []
checks = self.get_checks(include_deployment_checks)
if tags is not None:
checks = [check for check in checks if not set(check.tags).isdisjoint(tags)]
for check in checks:
new_errors = check(app_configs=app_configs, databases=databases)
if not is_iterable(new_errors):
raise TypeError(
"The function %r did not return a list. All functions "
"registered with the checks registry must return a list." % check,
)
errors.extend(new_errors)
return errors
def tag_exists(self, tag, include_deployment_checks=False):
return tag in self.tags_available(include_deployment_checks)
def tags_available(self, deployment_checks=False):
return set(
chain.from_iterable(
check.tags for check in self.get_checks(deployment_checks)
)
)
def get_checks(self, include_deployment_checks=False):
checks = list(self.registered_checks)
if include_deployment_checks:
checks.extend(self.deployment_checks)
return checks
registry = CheckRegistry()
register = registry.register
run_checks = registry.run_checks
tag_exists = registry.tag_exists
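# Illustrative sketch (not part of Django): registering and running a check
# against a private registry instance; the tag name "mytag" is made up.
if __name__ == "__main__":
    private_registry = CheckRegistry()

    @private_registry.register("mytag")
    def my_check(app_configs, **kwargs):
        # A real check would inspect app_configs and return messages.
        return []

    assert private_registry.tag_exists("mytag")
    assert private_registry.run_checks(tags=["mytag"]) == []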

View File

@@ -0,0 +1,283 @@
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from .. import Error, Tags, Warning, register
CROSS_ORIGIN_OPENER_POLICY_VALUES = {
"same-origin",
"same-origin-allow-popups",
"unsafe-none",
}
REFERRER_POLICY_VALUES = {
"no-referrer",
"no-referrer-when-downgrade",
"origin",
"origin-when-cross-origin",
"same-origin",
"strict-origin",
"strict-origin-when-cross-origin",
"unsafe-url",
}
SECRET_KEY_INSECURE_PREFIX = "django-insecure-"
SECRET_KEY_MIN_LENGTH = 50
SECRET_KEY_MIN_UNIQUE_CHARACTERS = 5
SECRET_KEY_WARNING_MSG = (
f"Your %s has less than {SECRET_KEY_MIN_LENGTH} characters, less than "
f"{SECRET_KEY_MIN_UNIQUE_CHARACTERS} unique characters, or it's prefixed "
f"with '{SECRET_KEY_INSECURE_PREFIX}' indicating that it was generated "
f"automatically by Django. Please generate a long and random value, "
f"otherwise many of Django's security-critical features will be "
f"vulnerable to attack."
)
W001 = Warning(
"You do not have 'django.middleware.security.SecurityMiddleware' "
"in your MIDDLEWARE so the SECURE_HSTS_SECONDS, "
"SECURE_CONTENT_TYPE_NOSNIFF, SECURE_REFERRER_POLICY, "
"SECURE_CROSS_ORIGIN_OPENER_POLICY, and SECURE_SSL_REDIRECT settings will "
"have no effect.",
id="security.W001",
)
W002 = Warning(
"You do not have "
"'django.middleware.clickjacking.XFrameOptionsMiddleware' in your "
"MIDDLEWARE, so your pages will not be served with an "
"'x-frame-options' header. Unless there is a good reason for your "
"site to be served in a frame, you should consider enabling this "
"header to help prevent clickjacking attacks.",
id="security.W002",
)
W004 = Warning(
"You have not set a value for the SECURE_HSTS_SECONDS setting. "
"If your entire site is served only over SSL, you may want to consider "
"setting a value and enabling HTTP Strict Transport Security. "
"Be sure to read the documentation first; enabling HSTS carelessly "
"can cause serious, irreversible problems.",
id="security.W004",
)
W005 = Warning(
"You have not set the SECURE_HSTS_INCLUDE_SUBDOMAINS setting to True. "
"Without this, your site is potentially vulnerable to attack "
"via an insecure connection to a subdomain. Only set this to True if "
"you are certain that all subdomains of your domain should be served "
"exclusively via SSL.",
id="security.W005",
)
W006 = Warning(
"Your SECURE_CONTENT_TYPE_NOSNIFF setting is not set to True, "
"so your pages will not be served with an "
"'X-Content-Type-Options: nosniff' header. "
"You should consider enabling this header to prevent the "
"browser from identifying content types incorrectly.",
id="security.W006",
)
W008 = Warning(
"Your SECURE_SSL_REDIRECT setting is not set to True. "
"Unless your site should be available over both SSL and non-SSL "
"connections, you may want to either set this setting True "
"or configure a load balancer or reverse-proxy server "
"to redirect all connections to HTTPS.",
id="security.W008",
)
W009 = Warning(
SECRET_KEY_WARNING_MSG % "SECRET_KEY",
id="security.W009",
)
W018 = Warning(
"You should not have DEBUG set to True in deployment.",
id="security.W018",
)
W019 = Warning(
"You have "
"'django.middleware.clickjacking.XFrameOptionsMiddleware' in your "
"MIDDLEWARE, but X_FRAME_OPTIONS is not set to 'DENY'. "
"Unless there is a good reason for your site to serve other parts of "
"itself in a frame, you should change it to 'DENY'.",
id="security.W019",
)
W020 = Warning(
"ALLOWED_HOSTS must not be empty in deployment.",
id="security.W020",
)
W021 = Warning(
"You have not set the SECURE_HSTS_PRELOAD setting to True. Without this, "
"your site cannot be submitted to the browser preload list.",
id="security.W021",
)
W022 = Warning(
"You have not set the SECURE_REFERRER_POLICY setting. Without this, your "
"site will not send a Referrer-Policy header. You should consider "
"enabling this header to protect user privacy.",
id="security.W022",
)
E023 = Error(
"You have set the SECURE_REFERRER_POLICY setting to an invalid value.",
hint="Valid values are: {}.".format(", ".join(sorted(REFERRER_POLICY_VALUES))),
id="security.E023",
)
E024 = Error(
"You have set the SECURE_CROSS_ORIGIN_OPENER_POLICY setting to an invalid "
"value.",
hint="Valid values are: {}.".format(
", ".join(sorted(CROSS_ORIGIN_OPENER_POLICY_VALUES)),
),
id="security.E024",
)
W025 = Warning(SECRET_KEY_WARNING_MSG, id="security.W025")
def _security_middleware():
return "django.middleware.security.SecurityMiddleware" in settings.MIDDLEWARE
def _xframe_middleware():
return (
"django.middleware.clickjacking.XFrameOptionsMiddleware" in settings.MIDDLEWARE
)
@register(Tags.security, deploy=True)
def check_security_middleware(app_configs, **kwargs):
passed_check = _security_middleware()
return [] if passed_check else [W001]
@register(Tags.security, deploy=True)
def check_xframe_options_middleware(app_configs, **kwargs):
passed_check = _xframe_middleware()
return [] if passed_check else [W002]
@register(Tags.security, deploy=True)
def check_sts(app_configs, **kwargs):
passed_check = not _security_middleware() or settings.SECURE_HSTS_SECONDS
return [] if passed_check else [W004]
@register(Tags.security, deploy=True)
def check_sts_include_subdomains(app_configs, **kwargs):
passed_check = (
not _security_middleware()
or not settings.SECURE_HSTS_SECONDS
or settings.SECURE_HSTS_INCLUDE_SUBDOMAINS is True
)
return [] if passed_check else [W005]
@register(Tags.security, deploy=True)
def check_sts_preload(app_configs, **kwargs):
passed_check = (
not _security_middleware()
or not settings.SECURE_HSTS_SECONDS
or settings.SECURE_HSTS_PRELOAD is True
)
return [] if passed_check else [W021]
@register(Tags.security, deploy=True)
def check_content_type_nosniff(app_configs, **kwargs):
passed_check = (
not _security_middleware() or settings.SECURE_CONTENT_TYPE_NOSNIFF is True
)
return [] if passed_check else [W006]
@register(Tags.security, deploy=True)
def check_ssl_redirect(app_configs, **kwargs):
passed_check = not _security_middleware() or settings.SECURE_SSL_REDIRECT is True
return [] if passed_check else [W008]
def _check_secret_key(secret_key):
return (
len(set(secret_key)) >= SECRET_KEY_MIN_UNIQUE_CHARACTERS
and len(secret_key) >= SECRET_KEY_MIN_LENGTH
and not secret_key.startswith(SECRET_KEY_INSECURE_PREFIX)
)
@register(Tags.security, deploy=True)
def check_secret_key(app_configs, **kwargs):
try:
secret_key = settings.SECRET_KEY
except (ImproperlyConfigured, AttributeError):
passed_check = False
else:
passed_check = _check_secret_key(secret_key)
return [] if passed_check else [W009]
@register(Tags.security, deploy=True)
def check_secret_key_fallbacks(app_configs, **kwargs):
warnings = []
try:
fallbacks = settings.SECRET_KEY_FALLBACKS
except (ImproperlyConfigured, AttributeError):
warnings.append(Warning(W025.msg % "SECRET_KEY_FALLBACKS", id=W025.id))
else:
for index, key in enumerate(fallbacks):
if not _check_secret_key(key):
warnings.append(
Warning(W025.msg % f"SECRET_KEY_FALLBACKS[{index}]", id=W025.id)
)
return warnings
@register(Tags.security, deploy=True)
def check_debug(app_configs, **kwargs):
passed_check = not settings.DEBUG
return [] if passed_check else [W018]
@register(Tags.security, deploy=True)
def check_xframe_deny(app_configs, **kwargs):
passed_check = not _xframe_middleware() or settings.X_FRAME_OPTIONS == "DENY"
return [] if passed_check else [W019]
@register(Tags.security, deploy=True)
def check_allowed_hosts(app_configs, **kwargs):
return [] if settings.ALLOWED_HOSTS else [W020]
@register(Tags.security, deploy=True)
def check_referrer_policy(app_configs, **kwargs):
if _security_middleware():
if settings.SECURE_REFERRER_POLICY is None:
return [W022]
# Support a comma-separated string or iterable of values to allow fallback.
if isinstance(settings.SECURE_REFERRER_POLICY, str):
values = {v.strip() for v in settings.SECURE_REFERRER_POLICY.split(",")}
else:
values = set(settings.SECURE_REFERRER_POLICY)
if not values <= REFERRER_POLICY_VALUES:
return [E023]
return []
@register(Tags.security, deploy=True)
def check_cross_origin_opener_policy(app_configs, **kwargs):
if (
_security_middleware()
and settings.SECURE_CROSS_ORIGIN_OPENER_POLICY is not None
and settings.SECURE_CROSS_ORIGIN_OPENER_POLICY
not in CROSS_ORIGIN_OPENER_POLICY_VALUES
):
return [E024]
return []
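# Illustrative sketch (not part of Django): what _check_secret_key() accepts
# and rejects; both literal keys below are throwaway examples.
if __name__ == "__main__":
    import secrets

    assert not _check_secret_key("django-insecure-" + "x" * 60)  # bad prefix
    assert not _check_secret_key("short")  # too short, too few unique chars
    assert _check_secret_key(secrets.token_urlsafe(50))  # long, random key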

View File

@@ -0,0 +1,67 @@
import inspect
from django.conf import settings
from .. import Error, Tags, Warning, register
W003 = Warning(
"You don't appear to be using Django's built-in "
"cross-site request forgery protection via the middleware "
"('django.middleware.csrf.CsrfViewMiddleware' is not in your "
"MIDDLEWARE). Enabling the middleware is the safest approach "
"to ensure you don't leave any holes.",
id="security.W003",
)
W016 = Warning(
"You have 'django.middleware.csrf.CsrfViewMiddleware' in your "
"MIDDLEWARE, but you have not set CSRF_COOKIE_SECURE to True. "
"Using a secure-only CSRF cookie makes it more difficult for network "
"traffic sniffers to steal the CSRF token.",
id="security.W016",
)
def _csrf_middleware():
return "django.middleware.csrf.CsrfViewMiddleware" in settings.MIDDLEWARE
@register(Tags.security, deploy=True)
def check_csrf_middleware(app_configs, **kwargs):
passed_check = _csrf_middleware()
return [] if passed_check else [W003]
@register(Tags.security, deploy=True)
def check_csrf_cookie_secure(app_configs, **kwargs):
passed_check = (
settings.CSRF_USE_SESSIONS
or not _csrf_middleware()
or settings.CSRF_COOKIE_SECURE is True
)
return [] if passed_check else [W016]
@register(Tags.security)
def check_csrf_failure_view(app_configs, **kwargs):
from django.middleware.csrf import _get_failure_view
errors = []
try:
view = _get_failure_view()
except ImportError:
msg = (
"The CSRF failure view '%s' could not be imported."
% settings.CSRF_FAILURE_VIEW
)
errors.append(Error(msg, id="security.E102"))
else:
try:
inspect.signature(view).bind(None, reason=None)
except TypeError:
msg = (
"The CSRF failure view '%s' does not take the correct number of "
"arguments." % settings.CSRF_FAILURE_VIEW
)
errors.append(Error(msg, id="security.E101"))
return errors
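# Illustrative sketch (not part of Django): a custom failure view that passes
# the signature check above, since it binds (request, reason=...). The dotted
# path in the settings comment is hypothetical.
from django.http import HttpResponseForbidden

def csrf_failure(request, reason=""):
    return HttpResponseForbidden("CSRF verification failed: %s" % reason)

# settings.py: CSRF_FAILURE_VIEW = "myproject.views.csrf_failure"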

View File

@@ -0,0 +1,99 @@
from django.conf import settings
from .. import Tags, Warning, register
def add_session_cookie_message(message):
return message + (
" Using a secure-only session cookie makes it more difficult for "
"network traffic sniffers to hijack user sessions."
)
W010 = Warning(
add_session_cookie_message(
"You have 'django.contrib.sessions' in your INSTALLED_APPS, "
"but you have not set SESSION_COOKIE_SECURE to True."
),
id="security.W010",
)
W011 = Warning(
add_session_cookie_message(
"You have 'django.contrib.sessions.middleware.SessionMiddleware' "
"in your MIDDLEWARE, but you have not set "
"SESSION_COOKIE_SECURE to True."
),
id="security.W011",
)
W012 = Warning(
add_session_cookie_message("SESSION_COOKIE_SECURE is not set to True."),
id="security.W012",
)
def add_httponly_message(message):
return message + (
" Using an HttpOnly session cookie makes it more difficult for "
"cross-site scripting attacks to hijack user sessions."
)
W013 = Warning(
add_httponly_message(
"You have 'django.contrib.sessions' in your INSTALLED_APPS, "
"but you have not set SESSION_COOKIE_HTTPONLY to True.",
),
id="security.W013",
)
W014 = Warning(
add_httponly_message(
"You have 'django.contrib.sessions.middleware.SessionMiddleware' "
"in your MIDDLEWARE, but you have not set "
"SESSION_COOKIE_HTTPONLY to True."
),
id="security.W014",
)
W015 = Warning(
add_httponly_message("SESSION_COOKIE_HTTPONLY is not set to True."),
id="security.W015",
)
@register(Tags.security, deploy=True)
def check_session_cookie_secure(app_configs, **kwargs):
if settings.SESSION_COOKIE_SECURE is True:
return []
errors = []
if _session_app():
errors.append(W010)
if _session_middleware():
errors.append(W011)
if len(errors) > 1:
errors = [W012]
return errors
@register(Tags.security, deploy=True)
def check_session_cookie_httponly(app_configs, **kwargs):
if settings.SESSION_COOKIE_HTTPONLY is True:
return []
errors = []
if _session_app():
errors.append(W013)
if _session_middleware():
errors.append(W014)
if len(errors) > 1:
errors = [W015]
return errors
def _session_middleware():
return "django.contrib.sessions.middleware.SessionMiddleware" in settings.MIDDLEWARE
def _session_app():
return "django.contrib.sessions" in settings.INSTALLED_APPS

View File

@@ -0,0 +1,75 @@
import copy
from collections import defaultdict
from django.conf import settings
from django.template.backends.django import get_template_tag_modules
from . import Error, Tags, Warning, register
E001 = Error(
"You have 'APP_DIRS': True in your TEMPLATES but also specify 'loaders' "
"in OPTIONS. Either remove APP_DIRS or remove the 'loaders' option.",
id="templates.E001",
)
E002 = Error(
"'string_if_invalid' in TEMPLATES OPTIONS must be a string but got: {} ({}).",
id="templates.E002",
)
W003 = Warning(
"{} is used for multiple template tag modules: {}",
id="templates.E003",
)
@register(Tags.templates)
def check_setting_app_dirs_loaders(app_configs, **kwargs):
return (
[E001]
if any(
conf.get("APP_DIRS") and "loaders" in conf.get("OPTIONS", {})
for conf in settings.TEMPLATES
)
else []
)
@register(Tags.templates)
def check_string_if_invalid_is_string(app_configs, **kwargs):
errors = []
for conf in settings.TEMPLATES:
string_if_invalid = conf.get("OPTIONS", {}).get("string_if_invalid", "")
if not isinstance(string_if_invalid, str):
error = copy.copy(E002)
error.msg = error.msg.format(
string_if_invalid, type(string_if_invalid).__name__
)
errors.append(error)
return errors
@register(Tags.templates)
def check_for_template_tags_with_the_same_name(app_configs, **kwargs):
errors = []
libraries = defaultdict(set)
for conf in settings.TEMPLATES:
custom_libraries = conf.get("OPTIONS", {}).get("libraries", {})
for module_name, module_path in custom_libraries.items():
libraries[module_name].add(module_path)
for module_name, module_path in get_template_tag_modules():
libraries[module_name].add(module_path)
for library_name, items in libraries.items():
if len(items) > 1:
errors.append(
Warning(
W003.msg.format(
repr(library_name),
", ".join(repr(item) for item in sorted(items)),
),
id=W003.id,
)
)
return errors
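# Illustrative sketch (not part of Django): a TEMPLATES value that would make
# check_for_template_tags_with_the_same_name() report templates.E003, because
# "myapp_tags" maps to two different (hypothetical) modules:
#
#     TEMPLATES = [
#         {
#             "BACKEND": "django.template.backends.django.DjangoTemplates",
#             "OPTIONS": {"libraries": {"myapp_tags": "myapp.templatetags.tags"}},
#         },
#         {
#             "BACKEND": "django.template.backends.django.DjangoTemplates",
#             "OPTIONS": {"libraries": {"myapp_tags": "otherapp.templatetags.tags"}},
#         },
#     ]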

View File

@@ -0,0 +1,66 @@
from django.conf import settings
from django.utils.translation import get_supported_language_variant
from django.utils.translation.trans_real import language_code_re
from . import Error, Tags, register
E001 = Error(
"You have provided an invalid value for the LANGUAGE_CODE setting: {!r}.",
id="translation.E001",
)
E002 = Error(
"You have provided an invalid language code in the LANGUAGES setting: {!r}.",
id="translation.E002",
)
E003 = Error(
"You have provided an invalid language code in the LANGUAGES_BIDI setting: {!r}.",
id="translation.E003",
)
E004 = Error(
"You have provided a value for the LANGUAGE_CODE setting that is not in "
"the LANGUAGES setting.",
id="translation.E004",
)
@register(Tags.translation)
def check_setting_language_code(app_configs, **kwargs):
"""Error if LANGUAGE_CODE setting is invalid."""
tag = settings.LANGUAGE_CODE
if not isinstance(tag, str) or not language_code_re.match(tag):
return [Error(E001.msg.format(tag), id=E001.id)]
return []
@register(Tags.translation)
def check_setting_languages(app_configs, **kwargs):
"""Error if LANGUAGES setting is invalid."""
return [
Error(E002.msg.format(tag), id=E002.id)
for tag, _ in settings.LANGUAGES
if not isinstance(tag, str) or not language_code_re.match(tag)
]
@register(Tags.translation)
def check_setting_languages_bidi(app_configs, **kwargs):
"""Error if LANGUAGES_BIDI setting is invalid."""
return [
Error(E003.msg.format(tag), id=E003.id)
for tag in settings.LANGUAGES_BIDI
if not isinstance(tag, str) or not language_code_re.match(tag)
]
@register(Tags.translation)
def check_language_settings_consistent(app_configs, **kwargs):
"""Error if language settings are not consistent with each other."""
try:
get_supported_language_variant(settings.LANGUAGE_CODE)
except LookupError:
return [E004]
else:
return []
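# Illustrative sketch (not part of Django): settings that pass all four
# checks above; "de-at" is acceptable because the generic "de" entry in
# LANGUAGES lets get_supported_language_variant() succeed.
#
#     LANGUAGE_CODE = "de-at"
#     LANGUAGES = [("de", "German"), ("en", "English")]
#     LANGUAGES_BIDI = []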

View File

@@ -0,0 +1,117 @@
from collections import Counter
from django.conf import settings
from . import Error, Tags, Warning, register
@register(Tags.urls)
def check_url_config(app_configs, **kwargs):
if getattr(settings, "ROOT_URLCONF", None):
from django.urls import get_resolver
resolver = get_resolver()
return check_resolver(resolver)
return []
def check_resolver(resolver):
"""
Recursively check the resolver.
"""
check_method = getattr(resolver, "check", None)
if check_method is not None:
return check_method()
elif not hasattr(resolver, "resolve"):
return get_warning_for_invalid_pattern(resolver)
else:
return []
@register(Tags.urls)
def check_url_namespaces_unique(app_configs, **kwargs):
"""
Warn if URL namespaces used in applications aren't unique.
"""
if not getattr(settings, "ROOT_URLCONF", None):
return []
from django.urls import get_resolver
resolver = get_resolver()
all_namespaces = _load_all_namespaces(resolver)
counter = Counter(all_namespaces)
non_unique_namespaces = [n for n, count in counter.items() if count > 1]
errors = []
for namespace in non_unique_namespaces:
errors.append(
Warning(
"URL namespace '{}' isn't unique. You may not be able to reverse "
"all URLs in this namespace".format(namespace),
id="urls.W005",
)
)
return errors
def _load_all_namespaces(resolver, parents=()):
"""
Recursively load all namespaces from URL patterns.
"""
url_patterns = getattr(resolver, "url_patterns", [])
namespaces = [
":".join(parents + (url.namespace,))
for url in url_patterns
if getattr(url, "namespace", None) is not None
]
for pattern in url_patterns:
namespace = getattr(pattern, "namespace", None)
current = parents
if namespace is not None:
current += (namespace,)
namespaces.extend(_load_all_namespaces(pattern, current))
return namespaces
def get_warning_for_invalid_pattern(pattern):
"""
Return a list containing a warning that the pattern is invalid.
describe_pattern() cannot be used here, because we cannot rely on the
urlpattern having regex or name attributes.
"""
if isinstance(pattern, str):
hint = (
"Try removing the string '{}'. The list of urlpatterns should not "
"have a prefix string as the first element.".format(pattern)
)
elif isinstance(pattern, tuple):
hint = "Try using path() instead of a tuple."
else:
hint = None
return [
Error(
"Your URL pattern {!r} is invalid. Ensure that urlpatterns is a list "
"of path() and/or re_path() instances.".format(pattern),
hint=hint,
id="urls.E004",
)
]
@register(Tags.urls)
def check_url_settings(app_configs, **kwargs):
errors = []
for name in ("STATIC_URL", "MEDIA_URL"):
value = getattr(settings, name)
if value and not value.endswith("/"):
errors.append(E006(name))
return errors
def E006(name):
return Error(
"The {} setting must end with a slash.".format(name),
id="urls.E006",
)
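# Illustrative sketch (not part of Django): the bare tuple below is exactly
# what get_warning_for_invalid_pattern() flags as urls.E004, and the path()
# form is the fix its hint suggests. The view import is hypothetical.
#
#     from django.urls import path
#     from myapp.views import home
#
#     urlpatterns = [
#         # ("home/", home),      # invalid: triggers urls.E004
#         path("home/", home),    # valid
#     ]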

View File

@@ -0,0 +1,254 @@
"""
Global Django exception and warning classes.
"""
import operator
from django.utils.hashable import make_hashable
class FieldDoesNotExist(Exception):
"""The requested model field does not exist"""
pass
class AppRegistryNotReady(Exception):
"""The django.apps registry is not populated yet"""
pass
class ObjectDoesNotExist(Exception):
"""The requested object does not exist"""
silent_variable_failure = True
class MultipleObjectsReturned(Exception):
"""The query returned multiple objects when only one was expected."""
pass
class SuspiciousOperation(Exception):
"""The user did something suspicious"""
class SuspiciousMultipartForm(SuspiciousOperation):
"""Suspect MIME request in multipart form data"""
pass
class SuspiciousFileOperation(SuspiciousOperation):
"""A Suspicious filesystem operation was attempted"""
pass
class DisallowedHost(SuspiciousOperation):
"""HTTP_HOST header contains invalid value"""
pass
class DisallowedRedirect(SuspiciousOperation):
"""Redirect to scheme not in allowed list"""
pass
class TooManyFieldsSent(SuspiciousOperation):
"""
The number of fields in a GET or POST request exceeded
settings.DATA_UPLOAD_MAX_NUMBER_FIELDS.
"""
pass
class TooManyFilesSent(SuspiciousOperation):
"""
    The number of files in a GET or POST request exceeded
settings.DATA_UPLOAD_MAX_NUMBER_FILES.
"""
pass
class RequestDataTooBig(SuspiciousOperation):
"""
The size of the request (excluding any file uploads) exceeded
settings.DATA_UPLOAD_MAX_MEMORY_SIZE.
"""
pass
class RequestAborted(Exception):
"""The request was closed before it was completed, or timed out."""
pass
class BadRequest(Exception):
"""The request is malformed and cannot be processed."""
pass
class PermissionDenied(Exception):
"""The user did not have permission to do that"""
pass
class ViewDoesNotExist(Exception):
"""The requested view does not exist"""
pass
class MiddlewareNotUsed(Exception):
"""This middleware is not used in this server configuration"""
pass
class ImproperlyConfigured(Exception):
"""Django is somehow improperly configured"""
pass
class FieldError(Exception):
"""Some kind of problem with a model field."""
pass
NON_FIELD_ERRORS = "__all__"
class ValidationError(Exception):
"""An error while validating data."""
def __init__(self, message, code=None, params=None):
"""
The `message` argument can be a single error, a list of errors, or a
dictionary that maps field names to lists of errors. What we define as
an "error" can be either a simple string or an instance of
ValidationError with its message attribute set, and what we define as
list or dictionary can be an actual `list` or `dict` or an instance
of ValidationError with its `error_list` or `error_dict` attribute set.
"""
super().__init__(message, code, params)
if isinstance(message, ValidationError):
if hasattr(message, "error_dict"):
message = message.error_dict
elif not hasattr(message, "message"):
message = message.error_list
else:
message, code, params = message.message, message.code, message.params
if isinstance(message, dict):
self.error_dict = {}
for field, messages in message.items():
if not isinstance(messages, ValidationError):
messages = ValidationError(messages)
self.error_dict[field] = messages.error_list
elif isinstance(message, list):
self.error_list = []
for message in message:
# Normalize plain strings to instances of ValidationError.
if not isinstance(message, ValidationError):
message = ValidationError(message)
if hasattr(message, "error_dict"):
self.error_list.extend(sum(message.error_dict.values(), []))
else:
self.error_list.extend(message.error_list)
else:
self.message = message
self.code = code
self.params = params
self.error_list = [self]
@property
def message_dict(self):
# Trigger an AttributeError if this ValidationError
# doesn't have an error_dict.
getattr(self, "error_dict")
return dict(self)
@property
def messages(self):
if hasattr(self, "error_dict"):
return sum(dict(self).values(), [])
return list(self)
def update_error_dict(self, error_dict):
if hasattr(self, "error_dict"):
for field, error_list in self.error_dict.items():
error_dict.setdefault(field, []).extend(error_list)
else:
error_dict.setdefault(NON_FIELD_ERRORS, []).extend(self.error_list)
return error_dict
def __iter__(self):
if hasattr(self, "error_dict"):
for field, errors in self.error_dict.items():
yield field, list(ValidationError(errors))
else:
for error in self.error_list:
message = error.message
if error.params:
message %= error.params
yield str(message)
def __str__(self):
if hasattr(self, "error_dict"):
return repr(dict(self))
return repr(list(self))
def __repr__(self):
return "ValidationError(%s)" % self
def __eq__(self, other):
if not isinstance(other, ValidationError):
return NotImplemented
return hash(self) == hash(other)
def __hash__(self):
if hasattr(self, "message"):
return hash(
(
self.message,
self.code,
make_hashable(self.params),
)
)
if hasattr(self, "error_dict"):
return hash(make_hashable(self.error_dict))
return hash(tuple(sorted(self.error_list, key=operator.attrgetter("message"))))
class EmptyResultSet(Exception):
"""A database query predicate is impossible."""
pass
class FullResultSet(Exception):
"""A database query predicate is matches everything."""
pass
class SynchronousOnlyOperation(Exception):
"""The user tried to call a sync-only function from an async context."""
pass
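# Illustrative sketch (not part of Django): how ValidationError normalizes
# its input and what the convenience properties return.
if __name__ == "__main__":
    e = ValidationError({"name": ["Too short.", "No digits allowed."]})
    assert e.message_dict == {"name": ["Too short.", "No digits allowed."]}
    assert e.messages == ["Too short.", "No digits allowed."]
    single = ValidationError("Invalid value: %(value)s", params={"value": 42})
    assert list(single) == ["Invalid value: 42"]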

View File

@@ -0,0 +1,3 @@
from django.core.files.base import File
__all__ = ["File"]

View File

@@ -0,0 +1,161 @@
import os
from io import BytesIO, StringIO, UnsupportedOperation
from django.core.files.utils import FileProxyMixin
from django.utils.functional import cached_property
class File(FileProxyMixin):
DEFAULT_CHUNK_SIZE = 64 * 2**10
def __init__(self, file, name=None):
self.file = file
if name is None:
name = getattr(file, "name", None)
self.name = name
if hasattr(file, "mode"):
self.mode = file.mode
def __str__(self):
return self.name or ""
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self or "None")
def __bool__(self):
return bool(self.name)
def __len__(self):
return self.size
@cached_property
def size(self):
if hasattr(self.file, "size"):
return self.file.size
if hasattr(self.file, "name"):
try:
return os.path.getsize(self.file.name)
except (OSError, TypeError):
pass
if hasattr(self.file, "tell") and hasattr(self.file, "seek"):
pos = self.file.tell()
self.file.seek(0, os.SEEK_END)
size = self.file.tell()
self.file.seek(pos)
return size
raise AttributeError("Unable to determine the file's size.")
def chunks(self, chunk_size=None):
"""
Read the file and yield chunks of ``chunk_size`` bytes (defaults to
``File.DEFAULT_CHUNK_SIZE``).
"""
chunk_size = chunk_size or self.DEFAULT_CHUNK_SIZE
try:
self.seek(0)
except (AttributeError, UnsupportedOperation):
pass
while True:
data = self.read(chunk_size)
if not data:
break
yield data
def multiple_chunks(self, chunk_size=None):
"""
Return ``True`` if you can expect multiple chunks.
NB: If a particular file representation is in memory, subclasses should
always return ``False`` -- there's no good reason to read from memory in
chunks.
"""
return self.size > (chunk_size or self.DEFAULT_CHUNK_SIZE)
def __iter__(self):
# Iterate over this file-like object by newlines
buffer_ = None
for chunk in self.chunks():
for line in chunk.splitlines(True):
if buffer_:
if endswith_cr(buffer_) and not equals_lf(line):
# Line split after a \r newline; yield buffer_.
yield buffer_
# Continue with line.
else:
# Line either split without a newline (line
# continues after buffer_) or with \r\n
# newline (line == b'\n').
line = buffer_ + line
# buffer_ handled, clear it.
buffer_ = None
# If this is the end of a \n or \r\n line, yield.
if endswith_lf(line):
yield line
else:
buffer_ = line
if buffer_ is not None:
yield buffer_
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, tb):
self.close()
def open(self, mode=None):
if not self.closed:
self.seek(0)
elif self.name and os.path.exists(self.name):
self.file = open(self.name, mode or self.mode)
else:
raise ValueError("The file cannot be reopened.")
return self
def close(self):
self.file.close()
class ContentFile(File):
"""
A File-like object that takes just raw content, rather than an actual file.
"""
def __init__(self, content, name=None):
stream_class = StringIO if isinstance(content, str) else BytesIO
super().__init__(stream_class(content), name=name)
self.size = len(content)
def __str__(self):
return "Raw content"
def __bool__(self):
return True
def open(self, mode=None):
self.seek(0)
return self
def close(self):
pass
def write(self, data):
self.__dict__.pop("size", None) # Clear the computed size.
return self.file.write(data)
def endswith_cr(line):
"""Return True if line (a text or bytestring) ends with '\r'."""
return line.endswith("\r" if isinstance(line, str) else b"\r")
def endswith_lf(line):
"""Return True if line (a text or bytestring) ends with '\n'."""
return line.endswith("\n" if isinstance(line, str) else b"\n")
def equals_lf(line):
"""Return True if line (a text or bytestring) equals '\n'."""
return line == ("\n" if isinstance(line, str) else b"\n")
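# Illustrative sketch (not part of Django): a ContentFile round-trip through
# the chunking and line-iteration APIs defined above.
if __name__ == "__main__":
    f = ContentFile(b"line one\nline two\n", name="demo.txt")
    assert f.size == 18
    assert list(f) == [b"line one\n", b"line two\n"]  # split on newlines
    f.open()  # rewind the in-memory stream
    assert b"".join(f.chunks(chunk_size=4)) == b"line one\nline two\n"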

View File

@@ -0,0 +1,88 @@
"""
Utility functions for handling images.
Requires Pillow as you might imagine.
"""
import struct
import zlib
from django.core.files import File
class ImageFile(File):
"""
A mixin for use alongside django.core.files.base.File, which provides
additional features for dealing with images.
"""
@property
def width(self):
return self._get_image_dimensions()[0]
@property
def height(self):
return self._get_image_dimensions()[1]
def _get_image_dimensions(self):
if not hasattr(self, "_dimensions_cache"):
close = self.closed
self.open()
self._dimensions_cache = get_image_dimensions(self, close=close)
return self._dimensions_cache
def get_image_dimensions(file_or_path, close=False):
"""
Return the (width, height) of an image, given an open file or a path. Set
'close' to True to close the file at the end if it is initially in an open
state.
"""
from PIL import ImageFile as PillowImageFile
p = PillowImageFile.Parser()
if hasattr(file_or_path, "read"):
file = file_or_path
file_pos = file.tell()
file.seek(0)
else:
try:
file = open(file_or_path, "rb")
except OSError:
return (None, None)
close = True
try:
# Most of the time Pillow only needs a small chunk to parse the image
# and get the dimensions, but with some TIFF files Pillow needs to
# parse the whole file.
chunk_size = 1024
        while True:
data = file.read(chunk_size)
if not data:
break
try:
p.feed(data)
except zlib.error as e:
# ignore zlib complaining on truncated stream, just feed more
# data to parser (ticket #19457).
if e.args[0].startswith("Error -5"):
pass
else:
raise
except struct.error:
# Ignore PIL failing on a too short buffer when reads return
# less bytes than expected. Skip and feed more data to the
# parser (ticket #24544).
pass
except RuntimeError:
# e.g. "RuntimeError: could not create decoder object" for
# WebP files. A different chunk_size may work.
pass
if p.image:
return p.image.size
chunk_size *= 2
return (None, None)
finally:
if close:
file.close()
else:
file.seek(file_pos)
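# Illustrative sketch (not part of Django): probing the dimensions of an
# in-memory PNG; requires Pillow, like the rest of this module.
if __name__ == "__main__":
    from io import BytesIO
    from PIL import Image

    buf = BytesIO()
    Image.new("RGB", (120, 80)).save(buf, format="PNG")
    assert get_image_dimensions(buf) == (120, 80)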

View File

@@ -0,0 +1,127 @@
"""
Portable file locking utilities.
Based partially on an example by Jonathan Feinberg in the Python
Cookbook [1] (licensed under the Python Software License) and a ctypes port by
Anatoly Techtonik for Roundup [2] (license [3]).
[1] https://code.activestate.com/recipes/65203/
[2] https://sourceforge.net/p/roundup/code/ci/default/tree/roundup/backends/portalocker.py # NOQA
[3] https://sourceforge.net/p/roundup/code/ci/default/tree/COPYING.txt
Example Usage::
>>> from django.core.files import locks
>>> with open('./file', 'wb') as f:
... locks.lock(f, locks.LOCK_EX)
    ...     f.write(b'Django')
"""
import os
__all__ = ("LOCK_EX", "LOCK_SH", "LOCK_NB", "lock", "unlock")
def _fd(f):
"""Get a filedescriptor from something which could be a file or an fd."""
return f.fileno() if hasattr(f, "fileno") else f
if os.name == "nt":
import msvcrt
from ctypes import (
POINTER,
Structure,
Union,
WinDLL,
byref,
c_int64,
c_ulong,
c_void_p,
sizeof,
)
from ctypes.wintypes import BOOL, DWORD, HANDLE
LOCK_SH = 0 # the default
LOCK_NB = 0x1 # LOCKFILE_FAIL_IMMEDIATELY
LOCK_EX = 0x2 # LOCKFILE_EXCLUSIVE_LOCK
# --- Adapted from the pyserial project ---
# detect size of ULONG_PTR
if sizeof(c_ulong) != sizeof(c_void_p):
ULONG_PTR = c_int64
else:
ULONG_PTR = c_ulong
PVOID = c_void_p
# --- Union inside Structure by stackoverflow:3480240 ---
class _OFFSET(Structure):
_fields_ = [("Offset", DWORD), ("OffsetHigh", DWORD)]
class _OFFSET_UNION(Union):
_anonymous_ = ["_offset"]
_fields_ = [("_offset", _OFFSET), ("Pointer", PVOID)]
class OVERLAPPED(Structure):
_anonymous_ = ["_offset_union"]
_fields_ = [
("Internal", ULONG_PTR),
("InternalHigh", ULONG_PTR),
("_offset_union", _OFFSET_UNION),
("hEvent", HANDLE),
]
LPOVERLAPPED = POINTER(OVERLAPPED)
# --- Define function prototypes for extra safety ---
kernel32 = WinDLL("kernel32")
LockFileEx = kernel32.LockFileEx
LockFileEx.restype = BOOL
LockFileEx.argtypes = [HANDLE, DWORD, DWORD, DWORD, DWORD, LPOVERLAPPED]
UnlockFileEx = kernel32.UnlockFileEx
UnlockFileEx.restype = BOOL
UnlockFileEx.argtypes = [HANDLE, DWORD, DWORD, DWORD, LPOVERLAPPED]
def lock(f, flags):
hfile = msvcrt.get_osfhandle(_fd(f))
overlapped = OVERLAPPED()
ret = LockFileEx(hfile, flags, 0, 0, 0xFFFF0000, byref(overlapped))
return bool(ret)
def unlock(f):
hfile = msvcrt.get_osfhandle(_fd(f))
overlapped = OVERLAPPED()
ret = UnlockFileEx(hfile, 0, 0, 0xFFFF0000, byref(overlapped))
return bool(ret)
else:
try:
import fcntl
LOCK_SH = fcntl.LOCK_SH # shared lock
LOCK_NB = fcntl.LOCK_NB # non-blocking
LOCK_EX = fcntl.LOCK_EX
except (ImportError, AttributeError):
# File locking is not supported.
LOCK_EX = LOCK_SH = LOCK_NB = 0
# Dummy functions that don't do anything.
def lock(f, flags):
# File is not locked
return False
def unlock(f):
# File is unlocked
return True
else:
def lock(f, flags):
try:
fcntl.flock(_fd(f), flags)
return True
except BlockingIOError:
return False
def unlock(f):
fcntl.flock(_fd(f), fcntl.LOCK_UN)
return True
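# Illustrative sketch (not part of Django): taking a non-blocking exclusive
# lock; lock() returns False instead of blocking when LOCK_NB is set and the
# lock is already held (or when locking is unsupported on this platform).
if __name__ == "__main__":
    import tempfile

    with tempfile.NamedTemporaryFile() as f:
        if lock(f, LOCK_EX | LOCK_NB):
            f.write(b"exclusive access")
            unlock(f)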

View File

@@ -0,0 +1,102 @@
"""
Move a file in the safest way possible::
>>> from django.core.files.move import file_move_safe
>>> file_move_safe("/tmp/old_file", "/tmp/new_file")
"""
import os
from shutil import copymode, copystat
from django.core.files import locks
__all__ = ["file_move_safe"]
def _samefile(src, dst):
# Macintosh, Unix.
if hasattr(os.path, "samefile"):
try:
return os.path.samefile(src, dst)
except OSError:
return False
# All other platforms: check for same pathname.
return os.path.normcase(os.path.abspath(src)) == os.path.normcase(
os.path.abspath(dst)
)
def file_move_safe(
old_file_name, new_file_name, chunk_size=1024 * 64, allow_overwrite=False
):
"""
Move a file from one location to another in the safest way possible.
First, try ``os.rename``, which is simple but will break across filesystems.
If that fails, stream manually from one file to another in pure Python.
If the destination file exists and ``allow_overwrite`` is ``False``, raise
``FileExistsError``.
"""
# There's no reason to move if we don't have to.
if _samefile(old_file_name, new_file_name):
return
try:
if not allow_overwrite and os.access(new_file_name, os.F_OK):
raise FileExistsError(
"Destination file %s exists and allow_overwrite is False."
% new_file_name
)
os.rename(old_file_name, new_file_name)
return
except OSError:
# OSError happens with os.rename() if moving to another filesystem or
# when moving opened files on certain operating systems.
pass
# first open the old file, so that it won't go away
with open(old_file_name, "rb") as old_file:
# now open the new file, not forgetting allow_overwrite
fd = os.open(
new_file_name,
(
os.O_WRONLY
| os.O_CREAT
| getattr(os, "O_BINARY", 0)
| (os.O_EXCL if not allow_overwrite else 0)
),
)
try:
locks.lock(fd, locks.LOCK_EX)
current_chunk = None
while current_chunk != b"":
current_chunk = old_file.read(chunk_size)
os.write(fd, current_chunk)
finally:
locks.unlock(fd)
os.close(fd)
try:
copystat(old_file_name, new_file_name)
except PermissionError:
# Certain filesystems (e.g. CIFS) fail to copy the file's metadata if
# the type of the destination filesystem isn't the same as the source
# filesystem. This also happens with some SELinux-enabled systems.
# Ignore that, but try to set basic permissions.
try:
copymode(old_file_name, new_file_name)
except PermissionError:
pass
try:
os.remove(old_file_name)
except PermissionError as e:
# Certain operating systems (Cygwin and Windows)
# fail when deleting opened files, ignore it. (For the
# systems where this happens, temporary files will be auto-deleted
# on close anyway.)
if getattr(e, "winerror", 0) != 32:
raise
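# Illustrative sketch (not part of Django): a move, then a second move that
# raises because the destination exists and allow_overwrite is False.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        src, dst = os.path.join(tmp, "a.txt"), os.path.join(tmp, "b.txt")
        with open(src, "wb") as f:
            f.write(b"payload")
        file_move_safe(src, dst)  # plain os.rename() on the same filesystem
        with open(src, "wb") as f:
            f.write(b"payload again")
        try:
            file_move_safe(src, dst)
        except FileExistsError:
            pass  # expected: destination exists, allow_overwrite is False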

View File

@@ -0,0 +1,46 @@
import warnings
from django.conf import DEFAULT_STORAGE_ALIAS, settings
from django.utils.deprecation import RemovedInDjango51Warning
from django.utils.functional import LazyObject
from django.utils.module_loading import import_string
from .base import Storage
from .filesystem import FileSystemStorage
from .handler import InvalidStorageError, StorageHandler
from .memory import InMemoryStorage
__all__ = (
"FileSystemStorage",
"InMemoryStorage",
"Storage",
"DefaultStorage",
"default_storage",
"get_storage_class",
"InvalidStorageError",
"StorageHandler",
"storages",
)
GET_STORAGE_CLASS_DEPRECATED_MSG = (
"django.core.files.storage.get_storage_class is deprecated in favor of "
"using django.core.files.storage.storages."
)
def get_storage_class(import_path=None):
warnings.warn(
GET_STORAGE_CLASS_DEPRECATED_MSG,
RemovedInDjango51Warning,
stacklevel=2,
)
return import_string(import_path or settings.DEFAULT_FILE_STORAGE)
class DefaultStorage(LazyObject):
def _setup(self):
self._wrapped = storages[DEFAULT_STORAGE_ALIAS]
storages = StorageHandler()
default_storage = DefaultStorage()

View File

@@ -0,0 +1,190 @@
import os
import pathlib
from django.core.exceptions import SuspiciousFileOperation
from django.core.files import File
from django.core.files.utils import validate_file_name
from django.utils.crypto import get_random_string
from django.utils.text import get_valid_filename
class Storage:
"""
A base storage class, providing some default behaviors that all other
storage systems can inherit or override, as necessary.
"""
# The following methods represent a public interface to private methods.
# These shouldn't be overridden by subclasses unless absolutely necessary.
def open(self, name, mode="rb"):
"""Retrieve the specified file from storage."""
return self._open(name, mode)
def save(self, name, content, max_length=None):
"""
Save new content to the file specified by name. The content should be
a proper File object or any Python file-like object, ready to be read
from the beginning.
"""
# Get the proper name for the file, as it will actually be saved.
if name is None:
name = content.name
if not hasattr(content, "chunks"):
content = File(content, name)
name = self.get_available_name(name, max_length=max_length)
name = self._save(name, content)
# Ensure that the name returned from the storage system is still valid.
validate_file_name(name, allow_relative_path=True)
return name
# These methods are part of the public API, with default implementations.
def get_valid_name(self, name):
"""
Return a filename, based on the provided filename, that's suitable for
use in the target storage system.
"""
return get_valid_filename(name)
def get_alternative_name(self, file_root, file_ext):
"""
Return an alternative filename, by adding an underscore and a random 7
character alphanumeric string (before the file extension, if one
exists) to the filename.
"""
return "%s_%s%s" % (file_root, get_random_string(7), file_ext)
def get_available_name(self, name, max_length=None):
"""
Return a filename that's free on the target storage system and
available for new content to be written to.
"""
name = str(name).replace("\\", "/")
dir_name, file_name = os.path.split(name)
if ".." in pathlib.PurePath(dir_name).parts:
raise SuspiciousFileOperation(
"Detected path traversal attempt in '%s'" % dir_name
)
validate_file_name(file_name)
file_root, file_ext = os.path.splitext(file_name)
# If the filename already exists, generate an alternative filename
# until it doesn't exist.
# Truncate original name if required, so the new filename does not
# exceed the max_length.
while self.exists(name) or (max_length and len(name) > max_length):
# file_ext includes the dot.
name = os.path.join(
dir_name, self.get_alternative_name(file_root, file_ext)
)
if max_length is None:
continue
# Truncate file_root if max_length exceeded.
truncation = len(name) - max_length
if truncation > 0:
file_root = file_root[:-truncation]
# Entire file_root was truncated in attempt to find an
# available filename.
if not file_root:
raise SuspiciousFileOperation(
'Storage can not find an available filename for "%s". '
"Please make sure that the corresponding file field "
'allows sufficient "max_length".' % name
)
name = os.path.join(
dir_name, self.get_alternative_name(file_root, file_ext)
)
return name
def generate_filename(self, filename):
"""
Validate the filename by calling get_valid_name() and return a filename
to be passed to the save() method.
"""
filename = str(filename).replace("\\", "/")
# `filename` may include a path as returned by FileField.upload_to.
dirname, filename = os.path.split(filename)
if ".." in pathlib.PurePath(dirname).parts:
raise SuspiciousFileOperation(
"Detected path traversal attempt in '%s'" % dirname
)
return os.path.normpath(os.path.join(dirname, self.get_valid_name(filename)))
def path(self, name):
"""
Return a local filesystem path where the file can be retrieved using
Python's built-in open() function. Storage systems that can't be
accessed using open() should *not* implement this method.
"""
raise NotImplementedError("This backend doesn't support absolute paths.")
# The following methods form the public API for storage systems, but with
# no default implementations. Subclasses must implement *all* of these.
def delete(self, name):
"""
Delete the specified file from the storage system.
"""
raise NotImplementedError(
"subclasses of Storage must provide a delete() method"
)
def exists(self, name):
"""
Return True if a file referenced by the given name already exists in the
storage system, or False if the name is available for a new file.
"""
raise NotImplementedError(
"subclasses of Storage must provide an exists() method"
)
def listdir(self, path):
"""
List the contents of the specified path. Return a 2-tuple of lists:
the first item being directories, the second item being files.
"""
raise NotImplementedError(
"subclasses of Storage must provide a listdir() method"
)
def size(self, name):
"""
Return the total size, in bytes, of the file specified by name.
"""
raise NotImplementedError("subclasses of Storage must provide a size() method")
def url(self, name):
"""
Return an absolute URL where the file's contents can be accessed
directly by a web browser.
"""
raise NotImplementedError("subclasses of Storage must provide a url() method")
def get_accessed_time(self, name):
"""
Return the last accessed time (as a datetime) of the file specified by
name. The datetime will be timezone-aware if USE_TZ=True.
"""
raise NotImplementedError(
"subclasses of Storage must provide a get_accessed_time() method"
)
def get_created_time(self, name):
"""
Return the creation time (as a datetime) of the file specified by name.
The datetime will be timezone-aware if USE_TZ=True.
"""
raise NotImplementedError(
"subclasses of Storage must provide a get_created_time() method"
)
def get_modified_time(self, name):
"""
Return the last modified time (as a datetime) of the file specified by
name. The datetime will be timezone-aware if USE_TZ=True.
"""
raise NotImplementedError(
"subclasses of Storage must provide a get_modified_time() method"
)
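# Illustrative sketch (not part of Django): a toy subclass showing the
# minimum surface a usable backend needs -- _open(), _save(), and exists()
# (which get_available_name() calls). All names here are hypothetical.
class DictStorage(Storage):
    """Keep file contents in a plain dict, keyed by name."""

    def __init__(self):
        self._files = {}

    def _open(self, name, mode="rb"):
        from io import BytesIO

        return File(BytesIO(self._files[name]), name=name)

    def _save(self, name, content):
        self._files[name] = b"".join(content.chunks())
        return name

    def exists(self, name):
        return name in self._files

# Usage (ContentFile comes from django.core.files.base):
#     DictStorage().save("a.txt", ContentFile(b"hi"))  -> "a.txt"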

View File

@@ -0,0 +1,207 @@
import os
from datetime import datetime, timezone
from urllib.parse import urljoin
from django.conf import settings
from django.core.files import File, locks
from django.core.files.move import file_move_safe
from django.core.signals import setting_changed
from django.utils._os import safe_join
from django.utils.deconstruct import deconstructible
from django.utils.encoding import filepath_to_uri
from django.utils.functional import cached_property
from .base import Storage
from .mixins import StorageSettingsMixin
@deconstructible(path="django.core.files.storage.FileSystemStorage")
class FileSystemStorage(Storage, StorageSettingsMixin):
"""
Standard filesystem storage
"""
# The combination of O_CREAT and O_EXCL makes os.open() raise OSError if
# the file already exists before it's opened.
OS_OPEN_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0)
def __init__(
self,
location=None,
base_url=None,
file_permissions_mode=None,
directory_permissions_mode=None,
):
self._location = location
self._base_url = base_url
self._file_permissions_mode = file_permissions_mode
self._directory_permissions_mode = directory_permissions_mode
setting_changed.connect(self._clear_cached_properties)
@cached_property
def base_location(self):
return self._value_or_setting(self._location, settings.MEDIA_ROOT)
@cached_property
def location(self):
return os.path.abspath(self.base_location)
@cached_property
def base_url(self):
if self._base_url is not None and not self._base_url.endswith("/"):
self._base_url += "/"
return self._value_or_setting(self._base_url, settings.MEDIA_URL)
@cached_property
def file_permissions_mode(self):
return self._value_or_setting(
self._file_permissions_mode, settings.FILE_UPLOAD_PERMISSIONS
)
@cached_property
def directory_permissions_mode(self):
return self._value_or_setting(
self._directory_permissions_mode, settings.FILE_UPLOAD_DIRECTORY_PERMISSIONS
)
def _open(self, name, mode="rb"):
return File(open(self.path(name), mode))
def _save(self, name, content):
full_path = self.path(name)
# Create any intermediate directories that do not exist.
directory = os.path.dirname(full_path)
try:
if self.directory_permissions_mode is not None:
# Set the umask because os.makedirs() doesn't apply the "mode"
# argument to intermediate-level directories.
old_umask = os.umask(0o777 & ~self.directory_permissions_mode)
try:
os.makedirs(
directory, self.directory_permissions_mode, exist_ok=True
)
finally:
os.umask(old_umask)
else:
os.makedirs(directory, exist_ok=True)
except FileExistsError:
raise FileExistsError("%s exists and is not a directory." % directory)
# There's a potential race condition between get_available_name and
# saving the file; it's possible that two threads might return the
# same name, at which point all sorts of fun happens. So we need to
# try to create the file, but if it already exists we have to go back
# to get_available_name() and try again.
while True:
try:
# This file has a file path that we can move.
if hasattr(content, "temporary_file_path"):
file_move_safe(content.temporary_file_path(), full_path)
# This is a normal uploadedfile that we can stream.
else:
# The current umask value is masked out by os.open!
fd = os.open(full_path, self.OS_OPEN_FLAGS, 0o666)
_file = None
try:
locks.lock(fd, locks.LOCK_EX)
for chunk in content.chunks():
if _file is None:
mode = "wb" if isinstance(chunk, bytes) else "wt"
_file = os.fdopen(fd, mode)
_file.write(chunk)
finally:
locks.unlock(fd)
if _file is not None:
_file.close()
else:
os.close(fd)
except FileExistsError:
# A new name is needed if the file exists.
name = self.get_available_name(name)
full_path = self.path(name)
else:
# OK, the file save worked. Break out of the loop.
break
if self.file_permissions_mode is not None:
os.chmod(full_path, self.file_permissions_mode)
# Ensure the saved path is always relative to the storage root.
name = os.path.relpath(full_path, self.location)
# Ensure the moved file has the same gid as the storage root.
self._ensure_location_group_id(full_path)
# Store filenames with forward slashes, even on Windows.
return str(name).replace("\\", "/")
def _ensure_location_group_id(self, full_path):
if os.name == "posix":
file_gid = os.stat(full_path).st_gid
location_gid = os.stat(self.location).st_gid
if file_gid != location_gid:
try:
os.chown(full_path, uid=-1, gid=location_gid)
except PermissionError:
pass
def delete(self, name):
if not name:
raise ValueError("The name must be given to delete().")
name = self.path(name)
# If the file or directory exists, delete it from the filesystem.
try:
if os.path.isdir(name):
os.rmdir(name)
else:
os.remove(name)
except FileNotFoundError:
# FileNotFoundError is raised if the file or directory was removed
# concurrently.
pass
def exists(self, name):
return os.path.lexists(self.path(name))
def listdir(self, path):
path = self.path(path)
directories, files = [], []
with os.scandir(path) as entries:
for entry in entries:
if entry.is_dir():
directories.append(entry.name)
else:
files.append(entry.name)
return directories, files
def path(self, name):
return safe_join(self.location, name)
def size(self, name):
return os.path.getsize(self.path(name))
def url(self, name):
if self.base_url is None:
raise ValueError("This file is not accessible via a URL.")
url = filepath_to_uri(name)
if url is not None:
url = url.lstrip("/")
return urljoin(self.base_url, url)
def _datetime_from_timestamp(self, ts):
"""
If timezone support is enabled, make an aware datetime object in UTC;
otherwise make a naive one in the local timezone.
"""
tz = timezone.utc if settings.USE_TZ else None
return datetime.fromtimestamp(ts, tz=tz)
def get_accessed_time(self, name):
return self._datetime_from_timestamp(os.path.getatime(self.path(name)))
def get_created_time(self, name):
return self._datetime_from_timestamp(os.path.getctime(self.path(name)))
def get_modified_time(self, name):
return self._datetime_from_timestamp(os.path.getmtime(self.path(name)))
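# Illustrative sketch (not part of Django): a save/read round-trip against a
# throwaway directory; storage properties resolve settings lazily, hence the
# configure() fallback for standalone runs.
if __name__ == "__main__":
    import tempfile

    from django.core.files.base import ContentFile

    if not settings.configured:
        settings.configure()
    with tempfile.TemporaryDirectory() as tmp:
        fs = FileSystemStorage(location=tmp)
        name = fs.save("hello.txt", ContentFile(b"hello"))
        assert fs.exists(name) and fs.size(name) == 5
        with fs.open(name) as fh:
            assert fh.read() == b"hello"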

View File

@@ -0,0 +1,55 @@
from django.conf import DEFAULT_STORAGE_ALIAS, STATICFILES_STORAGE_ALIAS, settings
from django.core.exceptions import ImproperlyConfigured
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
class InvalidStorageError(ImproperlyConfigured):
pass
class StorageHandler:
def __init__(self, backends=None):
# backends is an optional dict of storage backend definitions
# (structured like settings.STORAGES).
self._backends = backends
self._storages = {}
@cached_property
def backends(self):
if self._backends is None:
self._backends = settings.STORAGES.copy()
# RemovedInDjango51Warning.
if settings.is_overridden("DEFAULT_FILE_STORAGE"):
self._backends[DEFAULT_STORAGE_ALIAS] = {
"BACKEND": settings.DEFAULT_FILE_STORAGE
}
if settings.is_overridden("STATICFILES_STORAGE"):
self._backends[STATICFILES_STORAGE_ALIAS] = {
"BACKEND": settings.STATICFILES_STORAGE
}
return self._backends
def __getitem__(self, alias):
try:
return self._storages[alias]
except KeyError:
try:
params = self.backends[alias]
except KeyError:
raise InvalidStorageError(
f"Could not find config for '{alias}' in settings.STORAGES."
)
storage = self.create_storage(params)
self._storages[alias] = storage
return storage
def create_storage(self, params):
params = params.copy()
backend = params.pop("BACKEND")
options = params.pop("OPTIONS", {})
try:
storage_cls = import_string(backend)
except ImportError as e:
raise InvalidStorageError(f"Could not find backend {backend!r}: {e}") from e
return storage_cls(**options)
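# Illustrative sketch (not part of Django): backends are instantiated on
# first lookup and cached afterwards. The alias "demo" is made up; in a
# project you would normally go through django.core.files.storage.storages.
if __name__ == "__main__":
    if not settings.configured:
        settings.configure()  # the backend reads MEDIA_ROOT lazily
    handler = StorageHandler(
        {"demo": {"BACKEND": "django.core.files.storage.InMemoryStorage"}}
    )
    storage = handler["demo"]
    assert handler["demo"] is storage  # same cached instance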

View File

@@ -0,0 +1,290 @@
"""
Based on dj-inmemorystorage (BSD) by Cody Soyland, Seán Hayes, Tore Birkeland,
and Nick Presta.
"""
import errno
import io
import os
import pathlib
from urllib.parse import urljoin
from django.conf import settings
from django.core.files.base import ContentFile
from django.core.signals import setting_changed
from django.utils._os import safe_join
from django.utils.deconstruct import deconstructible
from django.utils.encoding import filepath_to_uri
from django.utils.functional import cached_property
from django.utils.timezone import now
from .base import Storage
from .mixins import StorageSettingsMixin
__all__ = ("InMemoryStorage",)
class TimingMixin:
def _initialize_times(self):
self.created_time = now()
self.accessed_time = self.created_time
self.modified_time = self.created_time
def _update_accessed_time(self):
self.accessed_time = now()
def _update_modified_time(self):
self.modified_time = now()
class InMemoryFileNode(ContentFile, TimingMixin):
"""
Helper class representing an in-memory file node.
Handle unicode/bytes conversion during I/O operations and record creation,
modification, and access times.
"""
def __init__(self, content="", name=""):
self.file = None
self._content_type = type(content)
self._initialize_stream()
self._initialize_times()
def open(self, mode):
self._convert_stream_content(mode)
self._update_accessed_time()
return super().open(mode)
def write(self, data):
super().write(data)
self._update_modified_time()
def _initialize_stream(self):
"""Initialize underlying stream according to the content type."""
self.file = io.BytesIO() if self._content_type == bytes else io.StringIO()
def _convert_stream_content(self, mode):
"""Convert actual file content according to the opening mode."""
new_content_type = bytes if "b" in mode else str
# No conversion needed.
if self._content_type == new_content_type:
return
content = self.file.getvalue()
content = content.encode() if isinstance(content, str) else content.decode()
self._content_type = new_content_type
self._initialize_stream()
self.file.write(content)
class InMemoryDirNode(TimingMixin):
"""
Helper class representing an in-memory directory node.
Handle path navigation of directory trees, creating missing nodes if
needed.
"""
def __init__(self):
self._children = {}
self._initialize_times()
def resolve(self, path, create_if_missing=False, leaf_cls=None, check_exists=True):
"""
Navigate current directory tree, returning node matching path or
creating a new one, if missing.
- path: path of the node to search
- create_if_missing: create nodes if not exist. Defaults to False.
- leaf_cls: expected type of leaf node. Defaults to None.
- check_exists: if True and the leaf node does not exist, raise a
FileNotFoundError. Defaults to True.
"""
path_segments = list(pathlib.Path(path).parts)
current_node = self
while path_segments:
path_segment = path_segments.pop(0)
# If current node is a file node and there are unprocessed
# segments, raise an error.
if isinstance(current_node, InMemoryFileNode):
path_segments = os.path.split(path)
current_path = "/".join(
path_segments[: path_segments.index(path_segment)]
)
raise NotADirectoryError(
errno.ENOTDIR, os.strerror(errno.ENOTDIR), current_path
)
current_node = current_node._resolve_child(
path_segment,
create_if_missing,
leaf_cls if len(path_segments) == 0 else InMemoryDirNode,
)
if current_node is None:
break
if current_node is None and check_exists:
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), path)
# If a leaf_cls is not None, check if leaf node is of right type.
if leaf_cls and not isinstance(current_node, leaf_cls):
error_cls, error_code = (
(NotADirectoryError, errno.ENOTDIR)
if leaf_cls is InMemoryDirNode
else (IsADirectoryError, errno.EISDIR)
)
raise error_cls(error_code, os.strerror(error_code), path)
return current_node
def _resolve_child(self, path_segment, create_if_missing, child_cls):
if create_if_missing:
self._update_accessed_time()
self._update_modified_time()
return self._children.setdefault(path_segment, child_cls())
return self._children.get(path_segment)
def listdir(self):
directories, files = [], []
for name, entry in self._children.items():
if isinstance(entry, InMemoryDirNode):
directories.append(name)
else:
files.append(name)
return directories, files
def remove_child(self, name):
if name in self._children:
self._update_accessed_time()
self._update_modified_time()
del self._children[name]
@deconstructible(path="django.core.files.storage.InMemoryStorage")
class InMemoryStorage(Storage, StorageSettingsMixin):
"""A storage saving files in memory."""
def __init__(
self,
location=None,
base_url=None,
file_permissions_mode=None,
directory_permissions_mode=None,
):
self._location = location
self._base_url = base_url
self._file_permissions_mode = file_permissions_mode
self._directory_permissions_mode = directory_permissions_mode
self._root = InMemoryDirNode()
self._resolve(
self.base_location, create_if_missing=True, leaf_cls=InMemoryDirNode
)
setting_changed.connect(self._clear_cached_properties)
@cached_property
def base_location(self):
return self._value_or_setting(self._location, settings.MEDIA_ROOT)
@cached_property
def location(self):
return os.path.abspath(self.base_location)
@cached_property
def base_url(self):
if self._base_url is not None and not self._base_url.endswith("/"):
self._base_url += "/"
return self._value_or_setting(self._base_url, settings.MEDIA_URL)
@cached_property
def file_permissions_mode(self):
return self._value_or_setting(
self._file_permissions_mode, settings.FILE_UPLOAD_PERMISSIONS
)
@cached_property
def directory_permissions_mode(self):
return self._value_or_setting(
self._directory_permissions_mode, settings.FILE_UPLOAD_DIRECTORY_PERMISSIONS
)
def _relative_path(self, name):
full_path = self.path(name)
return os.path.relpath(full_path, self.location)
def _resolve(self, name, create_if_missing=False, leaf_cls=None, check_exists=True):
try:
relative_path = self._relative_path(name)
return self._root.resolve(
relative_path,
create_if_missing=create_if_missing,
leaf_cls=leaf_cls,
check_exists=check_exists,
)
except NotADirectoryError as exc:
absolute_path = self.path(exc.filename)
raise FileExistsError(f"{absolute_path} exists and is not a directory.")
def _open(self, name, mode="rb"):
create_if_missing = "w" in mode
file_node = self._resolve(
name, create_if_missing=create_if_missing, leaf_cls=InMemoryFileNode
)
return file_node.open(mode)
def _save(self, name, content):
file_node = self._resolve(
name, create_if_missing=True, leaf_cls=InMemoryFileNode
)
fd = None
for chunk in content.chunks():
if fd is None:
mode = "wb" if isinstance(chunk, bytes) else "wt"
fd = file_node.open(mode)
fd.write(chunk)
if hasattr(content, "temporary_file_path"):
os.remove(content.temporary_file_path())
file_node.modified_time = now()
return self._relative_path(name).replace("\\", "/")
def path(self, name):
return safe_join(self.location, name)
def delete(self, name):
path, filename = os.path.split(name)
dir_node = self._resolve(path, check_exists=False)
if dir_node is None:
return None
dir_node.remove_child(filename)
def exists(self, name):
return self._resolve(name, check_exists=False) is not None
def listdir(self, path):
node = self._resolve(path, leaf_cls=InMemoryDirNode)
return node.listdir()
def size(self, name):
return len(self._open(name, "rb").file.getvalue())
def url(self, name):
if self.base_url is None:
raise ValueError("This file is not accessible via a URL.")
url = filepath_to_uri(name)
if url is not None:
url = url.lstrip("/")
return urljoin(self.base_url, url)
def get_accessed_time(self, name):
file_node = self._resolve(name)
return file_node.accessed_time
def get_created_time(self, name):
file_node = self._resolve(name)
return file_node.created_time
def get_modified_time(self, name):
file_node = self._resolve(name)
return file_node.modified_time
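A minimal end-to-end sketch of InMemoryStorage (available since Django 4.2); settings.configure() below stands in for a real settings module:

    from django.conf import settings
    settings.configure(MEDIA_ROOT="media", MEDIA_URL="/media/")

    from django.core.files.base import ContentFile
    from django.core.files.storage import InMemoryStorage

    storage = InMemoryStorage()
    name = storage.save("notes/hello.txt", ContentFile(b"hi"))
    assert storage.exists(name)
    assert storage.open(name).read() == b"hi"
    assert storage.listdir("notes") == ([], ["hello.txt"])
    assert storage.url(name) == "/media/notes/hello.txt"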

View File

@@ -0,0 +1,15 @@
class StorageSettingsMixin:
def _clear_cached_properties(self, setting, **kwargs):
"""Reset setting based property values."""
if setting == "MEDIA_ROOT":
self.__dict__.pop("base_location", None)
self.__dict__.pop("location", None)
elif setting == "MEDIA_URL":
self.__dict__.pop("base_url", None)
elif setting == "FILE_UPLOAD_PERMISSIONS":
self.__dict__.pop("file_permissions_mode", None)
elif setting == "FILE_UPLOAD_DIRECTORY_PERMISSIONS":
self.__dict__.pop("directory_permissions_mode", None)
def _value_or_setting(self, value, setting):
return setting if value is None else value
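A hedged sketch of the invalidation path, reusing the storage instance from the previous sketch; override_settings fires the setting_changed signal that pops the cached property:

    from django.test import override_settings

    with override_settings(MEDIA_URL="/cdn/"):
        assert storage.base_url == "/cdn/"   # recomputed from the new setting
    assert storage.base_url == "/media/"     # restored once the override exits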

View File

@@ -0,0 +1,79 @@
"""
The temp module provides a NamedTemporaryFile that can be reopened in the same
process on any platform. Most platforms use the standard Python
tempfile.NamedTemporaryFile class, but Windows users are given a custom class.
This is needed because the Python implementation of NamedTemporaryFile uses the
O_TEMPORARY flag under Windows, which prevents the file from being reopened
if the same flag is not provided [1][2]. Note that this does not address the
more general issue of opening a file for writing and reading in multiple
processes in a manner that works across platforms.
The custom version of NamedTemporaryFile doesn't support the same keyword
arguments available in tempfile.NamedTemporaryFile.
1: https://mail.python.org/pipermail/python-list/2005-December/336957.html
2: https://bugs.python.org/issue14243
"""
import os
import tempfile
from django.core.files.utils import FileProxyMixin
__all__ = (
"NamedTemporaryFile",
"gettempdir",
)
if os.name == "nt":
class TemporaryFile(FileProxyMixin):
"""
        Temporary file object constructor that supports reopening the
        temporary file on Windows.
Unlike tempfile.NamedTemporaryFile from the standard library,
__init__() doesn't support the 'delete', 'buffering', 'encoding', or
'newline' keyword arguments.
"""
def __init__(self, mode="w+b", bufsize=-1, suffix="", prefix="", dir=None):
fd, name = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=dir)
self.name = name
self.file = os.fdopen(fd, mode, bufsize)
self.close_called = False
# Because close can be called during shutdown
# we need to cache os.unlink and access it
# as self.unlink only
unlink = os.unlink
def close(self):
if not self.close_called:
self.close_called = True
try:
self.file.close()
except OSError:
pass
try:
self.unlink(self.name)
except OSError:
pass
def __del__(self):
self.close()
def __enter__(self):
self.file.__enter__()
return self
def __exit__(self, exc, value, tb):
self.file.__exit__(exc, value, tb)
NamedTemporaryFile = TemporaryFile
else:
NamedTemporaryFile = tempfile.NamedTemporaryFile
gettempdir = tempfile.gettempdir
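A short sketch of the portability contract described in the module docstring: the temporary file can be reopened by name while still open, on Windows as well as POSIX:

    from django.core.files import temp as tempfile

    tmp = tempfile.NamedTemporaryFile(suffix=".txt")
    tmp.write(b"data")
    tmp.flush()
    with open(tmp.name, "rb") as f:   # reopen by path while tmp is still open
        assert f.read() == b"data"
    tmp.close()                       # closing also unlinks the file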

View File

@@ -0,0 +1,150 @@
"""
Classes representing uploaded files.
"""
import os
from io import BytesIO
from django.conf import settings
from django.core.files import temp as tempfile
from django.core.files.base import File
from django.core.files.utils import validate_file_name
__all__ = (
"UploadedFile",
"TemporaryUploadedFile",
"InMemoryUploadedFile",
"SimpleUploadedFile",
)
class UploadedFile(File):
"""
An abstract uploaded file (``TemporaryUploadedFile`` and
``InMemoryUploadedFile`` are the built-in concrete subclasses).
An ``UploadedFile`` object behaves somewhat like a file object and
represents some file data that the user submitted with a form.
"""
def __init__(
self,
file=None,
name=None,
content_type=None,
size=None,
charset=None,
content_type_extra=None,
):
super().__init__(file, name)
self.size = size
self.content_type = content_type
self.charset = charset
self.content_type_extra = content_type_extra
def __repr__(self):
return "<%s: %s (%s)>" % (self.__class__.__name__, self.name, self.content_type)
def _get_name(self):
return self._name
def _set_name(self, name):
# Sanitize the file name so that it can't be dangerous.
if name is not None:
# Just use the basename of the file -- anything else is dangerous.
name = os.path.basename(name)
# File names longer than 255 characters can cause problems on older OSes.
if len(name) > 255:
name, ext = os.path.splitext(name)
ext = ext[:255]
name = name[: 255 - len(ext)] + ext
name = validate_file_name(name)
self._name = name
name = property(_get_name, _set_name)
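A hedged sketch of the sanitization above; assigning a name goes through the property setter, so only the basename survives (the path is illustrative):

    f = UploadedFile(name="../../tmp/evil.txt")
    assert f.name == "evil.txt"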
class TemporaryUploadedFile(UploadedFile):
"""
A file uploaded to a temporary location (i.e. stream-to-disk).
"""
def __init__(self, name, content_type, size, charset, content_type_extra=None):
_, ext = os.path.splitext(name)
file = tempfile.NamedTemporaryFile(
suffix=".upload" + ext, dir=settings.FILE_UPLOAD_TEMP_DIR
)
super().__init__(file, name, content_type, size, charset, content_type_extra)
def temporary_file_path(self):
"""Return the full path of this file."""
return self.file.name
def close(self):
try:
return self.file.close()
except FileNotFoundError:
# The file was moved or deleted before the tempfile could unlink
# it. Still sets self.file.close_called and calls
# self.file.file.close() before the exception.
pass
class InMemoryUploadedFile(UploadedFile):
"""
A file uploaded into memory (i.e. stream-to-memory).
"""
def __init__(
self,
file,
field_name,
name,
content_type,
size,
charset,
content_type_extra=None,
):
super().__init__(file, name, content_type, size, charset, content_type_extra)
self.field_name = field_name
def open(self, mode=None):
self.file.seek(0)
return self
def chunks(self, chunk_size=None):
self.file.seek(0)
yield self.read()
def multiple_chunks(self, chunk_size=None):
# Since it's in memory, we'll never have multiple chunks.
return False
class SimpleUploadedFile(InMemoryUploadedFile):
"""
A simple representation of a file, which just has content, size, and a name.
"""
def __init__(self, name, content, content_type="text/plain"):
content = content or b""
super().__init__(
BytesIO(content), None, name, content_type, len(content), None, None
)
@classmethod
def from_dict(cls, file_dict):
"""
Create a SimpleUploadedFile object from a dictionary with keys:
- filename
- content-type
- content
"""
return cls(
file_dict["filename"],
file_dict["content"],
file_dict.get("content-type", "text/plain"),
)
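SimpleUploadedFile is the usual way to fake an upload in tests; a minimal sketch (the posted URL is hypothetical):

    from django.core.files.uploadedfile import SimpleUploadedFile

    upload = SimpleUploadedFile("report.csv", b"a,b\n1,2\n", content_type="text/csv")
    assert upload.name == "report.csv"
    assert upload.read() == b"a,b\n1,2\n"
    # Typical use with the test client (hypothetical view):
    # client.post("/upload/", {"file": upload})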

View File

@@ -0,0 +1,251 @@
"""
Base file upload handler classes, and the built-in concrete subclasses
"""
import os
from io import BytesIO
from django.conf import settings
from django.core.files.uploadedfile import InMemoryUploadedFile, TemporaryUploadedFile
from django.utils.module_loading import import_string
__all__ = [
"UploadFileException",
"StopUpload",
"SkipFile",
"FileUploadHandler",
"TemporaryFileUploadHandler",
"MemoryFileUploadHandler",
"load_handler",
"StopFutureHandlers",
]
class UploadFileException(Exception):
"""
Any error having to do with uploading files.
"""
pass
class StopUpload(UploadFileException):
"""
This exception is raised when an upload must abort.
"""
def __init__(self, connection_reset=False):
"""
        If ``connection_reset`` is ``True``, Django will halt the upload
        without consuming the rest of the upload. This will cause the browser
        to show a "connection reset" error.
"""
self.connection_reset = connection_reset
def __str__(self):
if self.connection_reset:
return "StopUpload: Halt current upload."
else:
return "StopUpload: Consume request data, then halt."
class SkipFile(UploadFileException):
"""
This exception is raised by an upload handler that wants to skip a given file.
"""
pass
class StopFutureHandlers(UploadFileException):
"""
Upload handlers that have handled a file and do not want future handlers to
run should raise this exception instead of returning None.
"""
pass
class FileUploadHandler:
"""
Base class for streaming upload handlers.
"""
chunk_size = 64 * 2**10 # : The default chunk size is 64 KB.
def __init__(self, request=None):
self.file_name = None
self.content_type = None
self.content_length = None
self.charset = None
self.content_type_extra = None
self.request = request
def handle_raw_input(
self, input_data, META, content_length, boundary, encoding=None
):
"""
Handle the raw input from the client.
Parameters:
:input_data:
An object that supports reading via .read().
:META:
``request.META``.
:content_length:
The (integer) value of the Content-Length header from the
client.
:boundary: The boundary from the Content-Type header. Be sure to
prepend two '--'.
"""
pass
def new_file(
self,
field_name,
file_name,
content_type,
content_length,
charset=None,
content_type_extra=None,
):
"""
Signal that a new file has been started.
Warning: As with any data from the client, you should not trust
content_length (and sometimes won't even get it).
"""
self.field_name = field_name
self.file_name = file_name
self.content_type = content_type
self.content_length = content_length
self.charset = charset
self.content_type_extra = content_type_extra
def receive_data_chunk(self, raw_data, start):
"""
Receive data from the streamed upload parser. ``start`` is the position
in the file of the chunk.
"""
raise NotImplementedError(
"subclasses of FileUploadHandler must provide a receive_data_chunk() method"
)
def file_complete(self, file_size):
"""
Signal that a file has completed. File size corresponds to the actual
size accumulated by all the chunks.
Subclasses should return a valid ``UploadedFile`` object.
"""
raise NotImplementedError(
"subclasses of FileUploadHandler must provide a file_complete() method"
)
def upload_complete(self):
"""
Signal that the upload is complete. Subclasses should perform cleanup
that is necessary for this handler.
"""
pass
def upload_interrupted(self):
"""
Signal that the upload was interrupted. Subclasses should perform
cleanup that is necessary for this handler.
"""
pass
class TemporaryFileUploadHandler(FileUploadHandler):
"""
Upload handler that streams data into a temporary file.
"""
def new_file(self, *args, **kwargs):
"""
Create the file object to append to as data is coming in.
"""
super().new_file(*args, **kwargs)
self.file = TemporaryUploadedFile(
self.file_name, self.content_type, 0, self.charset, self.content_type_extra
)
def receive_data_chunk(self, raw_data, start):
self.file.write(raw_data)
def file_complete(self, file_size):
self.file.seek(0)
self.file.size = file_size
return self.file
def upload_interrupted(self):
if hasattr(self, "file"):
temp_location = self.file.temporary_file_path()
try:
self.file.close()
os.remove(temp_location)
except FileNotFoundError:
pass
class MemoryFileUploadHandler(FileUploadHandler):
"""
File upload handler to stream uploads into memory (used for small files).
"""
def handle_raw_input(
self, input_data, META, content_length, boundary, encoding=None
):
"""
Use the content_length to signal whether or not this handler should be
used.
"""
        # Check the content-length header to see if this handler should be
        # used. If the post is too large, the memory handler cannot be used.
self.activated = content_length <= settings.FILE_UPLOAD_MAX_MEMORY_SIZE
def new_file(self, *args, **kwargs):
super().new_file(*args, **kwargs)
if self.activated:
self.file = BytesIO()
raise StopFutureHandlers()
def receive_data_chunk(self, raw_data, start):
"""Add the data to the BytesIO file."""
if self.activated:
self.file.write(raw_data)
else:
return raw_data
def file_complete(self, file_size):
"""Return a file object if this handler is activated."""
if not self.activated:
return
self.file.seek(0)
return InMemoryUploadedFile(
file=self.file,
field_name=self.field_name,
name=self.file_name,
content_type=self.content_type,
size=file_size,
charset=self.charset,
content_type_extra=self.content_type_extra,
)
def load_handler(path, *args, **kwargs):
"""
Given a path to a handler, return an instance of that handler.
E.g.::
>>> from django.http import HttpRequest
>>> request = HttpRequest()
>>> load_handler(
... 'django.core.files.uploadhandler.TemporaryFileUploadHandler',
... request,
... )
<TemporaryFileUploadHandler object at 0x...>
"""
return import_string(path)(*args, **kwargs)
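A hedged sketch of a custom handler built on this API, in the spirit of the quota example in the Django documentation; MAX_BYTES and QuotaUploadHandler are assumed names:

    from django.core.files.uploadhandler import FileUploadHandler, StopUpload

    MAX_BYTES = 10 * 2**20  # assumed 10 MB cap

    class QuotaUploadHandler(FileUploadHandler):
        """Abort any upload whose declared Content-Length exceeds MAX_BYTES."""

        def handle_raw_input(
            self, input_data, META, content_length, boundary, encoding=None
        ):
            if content_length > MAX_BYTES:
                raise StopUpload(connection_reset=True)

        def receive_data_chunk(self, raw_data, start):
            return raw_data  # pass the chunk through to later handlers

        def file_complete(self, file_size):
            return None  # defer file construction to the next handler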

View File

@@ -0,0 +1,78 @@
import os
import pathlib
from django.core.exceptions import SuspiciousFileOperation
def validate_file_name(name, allow_relative_path=False):
# Remove potentially dangerous names
if os.path.basename(name) in {"", ".", ".."}:
raise SuspiciousFileOperation("Could not derive file name from '%s'" % name)
if allow_relative_path:
# Use PurePosixPath() because this branch is checked only in
# FileField.generate_filename() where all file paths are expected to be
# Unix style (with forward slashes).
path = pathlib.PurePosixPath(name)
if path.is_absolute() or ".." in path.parts:
raise SuspiciousFileOperation(
"Detected path traversal attempt in '%s'" % name
)
elif name != os.path.basename(name):
raise SuspiciousFileOperation("File name '%s' includes path elements" % name)
return name
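A quick sketch of validate_file_name's contract (names are illustrative):

    from django.core.exceptions import SuspiciousFileOperation

    assert validate_file_name("photo.png") == "photo.png"
    assert (
        validate_file_name("albums/photo.png", allow_relative_path=True)
        == "albums/photo.png"
    )
    try:
        validate_file_name("../etc/passwd", allow_relative_path=True)
    except SuspiciousFileOperation:
        pass  # traversal attempts are rejected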
class FileProxyMixin:
"""
    A mixin class used to forward file methods to an underlying file
object. The internal file object has to be called "file"::
class FileProxy(FileProxyMixin):
def __init__(self, file):
self.file = file
"""
encoding = property(lambda self: self.file.encoding)
fileno = property(lambda self: self.file.fileno)
flush = property(lambda self: self.file.flush)
isatty = property(lambda self: self.file.isatty)
newlines = property(lambda self: self.file.newlines)
read = property(lambda self: self.file.read)
readinto = property(lambda self: self.file.readinto)
readline = property(lambda self: self.file.readline)
readlines = property(lambda self: self.file.readlines)
seek = property(lambda self: self.file.seek)
tell = property(lambda self: self.file.tell)
truncate = property(lambda self: self.file.truncate)
write = property(lambda self: self.file.write)
writelines = property(lambda self: self.file.writelines)
@property
def closed(self):
return not self.file or self.file.closed
def readable(self):
if self.closed:
return False
if hasattr(self.file, "readable"):
return self.file.readable()
return True
def writable(self):
if self.closed:
return False
if hasattr(self.file, "writable"):
return self.file.writable()
return "w" in getattr(self.file, "mode", "")
def seekable(self):
if self.closed:
return False
if hasattr(self.file, "seekable"):
return self.file.seekable()
return True
def __iter__(self):
return iter(self.file)
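A minimal sketch of the proxying contract, instantiating the FileProxy example from the docstring above:

    import io

    class FileProxy(FileProxyMixin):
        def __init__(self, file):
            self.file = file

    p = FileProxy(io.BytesIO(b"abc"))
    assert p.read() == b"abc"
    assert p.readable() and p.seekable() and not p.closed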

Some files were not shown because too many files have changed in this diff.