This commit is contained in:
Iliyan Angelov
2025-09-19 11:58:53 +03:00
parent 306b20e24a
commit 6b247e5b9f
11423 changed files with 1500615 additions and 778 deletions

View File

@@ -0,0 +1,21 @@
VERSION = (6, 0, 0)
__version__ = ".".join(map(str, VERSION))
def get_redis_connection(alias="default", write=True):
"""
Helper used for obtaining a raw redis client.
"""
from django.core.cache import caches
cache = caches[alias]
error_message = "This backend does not support this feature"
if not hasattr(cache, "client"):
raise NotImplementedError(error_message)
if not hasattr(cache.client, "get_client"):
raise NotImplementedError(error_message)
return cache.client.get_client(write)

View File

@@ -0,0 +1,274 @@
import functools
import logging
from typing import Any, Callable, Optional
from django import VERSION as DJANGO_VERSION
from django.conf import settings
from django.core.cache.backends.base import BaseCache
from django.utils.module_loading import import_string
from django_redis.exceptions import ConnectionInterrupted
CONNECTION_INTERRUPTED = object()
def omit_exception(
method: Optional[Callable] = None, return_value: Optional[Any] = None
):
"""
Simple decorator that intercepts connection
errors and ignores these if settings specify this.
"""
if method is None:
return functools.partial(omit_exception, return_value=return_value)
@functools.wraps(method)
def _decorator(self, *args, **kwargs):
try:
return method(self, *args, **kwargs)
except ConnectionInterrupted as e:
if self._ignore_exceptions:
if self._log_ignored_exceptions:
self.logger.exception("Exception ignored")
return return_value
raise e.__cause__ # noqa: B904
return _decorator
class RedisCache(BaseCache):
def __init__(self, server: str, params: dict[str, Any]) -> None:
super().__init__(params)
self._server = server
self._params = params
self._default_scan_itersize = getattr(
settings, "DJANGO_REDIS_SCAN_ITERSIZE", 10
)
options = params.get("OPTIONS", {})
self._client_cls = options.get(
"CLIENT_CLASS", "django_redis.client.DefaultClient"
)
self._client_cls = import_string(self._client_cls)
self._client = None
self._ignore_exceptions = options.get(
"IGNORE_EXCEPTIONS",
getattr(settings, "DJANGO_REDIS_IGNORE_EXCEPTIONS", False),
)
self._log_ignored_exceptions = getattr(
settings, "DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS", False
)
self.logger = (
logging.getLogger(getattr(settings, "DJANGO_REDIS_LOGGER", __name__))
if self._log_ignored_exceptions
else None
)
@property
def client(self):
"""
Lazy client connection property.
"""
if self._client is None:
self._client = self._client_cls(self._server, self._params, self)
return self._client
@omit_exception
def set(self, *args, **kwargs):
return self.client.set(*args, **kwargs)
@omit_exception
def incr_version(self, *args, **kwargs):
return self.client.incr_version(*args, **kwargs)
@omit_exception
def add(self, *args, **kwargs):
return self.client.add(*args, **kwargs)
def get(self, key, default=None, version=None, client=None):
value = self._get(key, default, version, client)
if value is CONNECTION_INTERRUPTED:
value = default
return value
@omit_exception(return_value=CONNECTION_INTERRUPTED)
def _get(self, key, default, version, client):
return self.client.get(key, default=default, version=version, client=client)
@omit_exception
def delete(self, *args, **kwargs):
"""returns a boolean instead of int since django version 3.1"""
result = self.client.delete(*args, **kwargs)
return bool(result) if DJANGO_VERSION >= (3, 1, 0) else result
@omit_exception
def delete_pattern(self, *args, **kwargs):
kwargs.setdefault("itersize", self._default_scan_itersize)
return self.client.delete_pattern(*args, **kwargs)
@omit_exception
def delete_many(self, *args, **kwargs):
return self.client.delete_many(*args, **kwargs)
@omit_exception
def clear(self):
return self.client.clear()
@omit_exception(return_value={})
def get_many(self, *args, **kwargs):
return self.client.get_many(*args, **kwargs)
@omit_exception
def set_many(self, *args, **kwargs):
return self.client.set_many(*args, **kwargs)
@omit_exception
def incr(self, *args, **kwargs):
return self.client.incr(*args, **kwargs)
@omit_exception
def decr(self, *args, **kwargs):
return self.client.decr(*args, **kwargs)
@omit_exception
def has_key(self, *args, **kwargs):
return self.client.has_key(*args, **kwargs)
@omit_exception
def keys(self, *args, **kwargs):
return self.client.keys(*args, **kwargs)
@omit_exception
def iter_keys(self, *args, **kwargs):
return self.client.iter_keys(*args, **kwargs)
@omit_exception
def ttl(self, *args, **kwargs):
return self.client.ttl(*args, **kwargs)
@omit_exception
def pttl(self, *args, **kwargs):
return self.client.pttl(*args, **kwargs)
@omit_exception
def persist(self, *args, **kwargs):
return self.client.persist(*args, **kwargs)
@omit_exception
def expire(self, *args, **kwargs):
return self.client.expire(*args, **kwargs)
@omit_exception
def expire_at(self, *args, **kwargs):
return self.client.expire_at(*args, **kwargs)
@omit_exception
def pexpire(self, *args, **kwargs):
return self.client.pexpire(*args, **kwargs)
@omit_exception
def pexpire_at(self, *args, **kwargs):
return self.client.pexpire_at(*args, **kwargs)
@omit_exception
def lock(self, *args, **kwargs):
return self.client.lock(*args, **kwargs)
@omit_exception
def close(self, **kwargs):
self.client.close(**kwargs)
@omit_exception
def touch(self, *args, **kwargs):
return self.client.touch(*args, **kwargs)
@omit_exception
def sadd(self, *args, **kwargs):
return self.client.sadd(*args, **kwargs)
@omit_exception
def scard(self, *args, **kwargs):
return self.client.scard(*args, **kwargs)
@omit_exception
def sdiff(self, *args, **kwargs):
return self.client.sdiff(*args, **kwargs)
@omit_exception
def sdiffstore(self, *args, **kwargs):
return self.client.sdiffstore(*args, **kwargs)
@omit_exception
def sinter(self, *args, **kwargs):
return self.client.sinter(*args, **kwargs)
@omit_exception
def sinterstore(self, *args, **kwargs):
return self.client.sinterstore(*args, **kwargs)
@omit_exception
def sismember(self, *args, **kwargs):
return self.client.sismember(*args, **kwargs)
@omit_exception
def smembers(self, *args, **kwargs):
return self.client.smembers(*args, **kwargs)
@omit_exception
def smove(self, *args, **kwargs):
return self.client.smove(*args, **kwargs)
@omit_exception
def spop(self, *args, **kwargs):
return self.client.spop(*args, **kwargs)
@omit_exception
def srandmember(self, *args, **kwargs):
return self.client.srandmember(*args, **kwargs)
@omit_exception
def srem(self, *args, **kwargs):
return self.client.srem(*args, **kwargs)
@omit_exception
def sscan(self, *args, **kwargs):
return self.client.sscan(*args, **kwargs)
@omit_exception
def sscan_iter(self, *args, **kwargs):
return self.client.sscan_iter(*args, **kwargs)
@omit_exception
def smismember(self, *args, **kwargs):
return self.client.smismember(*args, **kwargs)
@omit_exception
def sunion(self, *args, **kwargs):
return self.client.sunion(*args, **kwargs)
@omit_exception
def sunionstore(self, *args, **kwargs):
return self.client.sunionstore(*args, **kwargs)
@omit_exception
def hset(self, *args, **kwargs):
return self.client.hset(*args, **kwargs)
@omit_exception
def hdel(self, *args, **kwargs):
return self.client.hdel(*args, **kwargs)
@omit_exception
def hlen(self, *args, **kwargs):
return self.client.hlen(*args, **kwargs)
@omit_exception
def hkeys(self, *args, **kwargs):
return self.client.hkeys(*args, **kwargs)
@omit_exception
def hexists(self, *args, **kwargs):
return self.client.hexists(*args, **kwargs)

View File

@@ -0,0 +1,6 @@
from django_redis.client.default import DefaultClient
from django_redis.client.herd import HerdClient
from django_redis.client.sentinel import SentinelClient
from django_redis.client.sharded import ShardClient
__all__ = ["DefaultClient", "HerdClient", "SentinelClient", "ShardClient"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,169 @@
import random
import socket
import time
from collections import OrderedDict
from django.conf import settings
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import ResponseError
from redis.exceptions import TimeoutError as RedisTimeoutError
from django_redis.client.default import DEFAULT_TIMEOUT, DefaultClient
from django_redis.exceptions import ConnectionInterrupted
_main_exceptions = (
RedisConnectionError,
RedisTimeoutError,
ResponseError,
socket.timeout,
)
class Marker:
"""
Dummy class for use as
marker for herded keys.
"""
pass
def _is_expired(x, herd_timeout: int) -> bool:
if x >= herd_timeout:
return True
val = x + random.randint(1, herd_timeout)
return val >= herd_timeout
class HerdClient(DefaultClient):
def __init__(self, *args, **kwargs):
self._marker = Marker()
self._herd_timeout = getattr(settings, "CACHE_HERD_TIMEOUT", 60)
super().__init__(*args, **kwargs)
def _pack(self, value, timeout):
herd_timeout = (timeout or self._backend.default_timeout) + int(time.time())
return self._marker, value, herd_timeout
def _unpack(self, value):
try:
marker, unpacked, herd_timeout = value
except (ValueError, TypeError):
return value, False
if not isinstance(marker, Marker):
return value, False
now = int(time.time())
if herd_timeout < now:
x = now - herd_timeout
return unpacked, _is_expired(x, self._herd_timeout)
return unpacked, False
def set(
self,
key,
value,
timeout=DEFAULT_TIMEOUT,
version=None,
client=None,
nx=False,
xx=False,
):
if timeout is DEFAULT_TIMEOUT:
timeout = self._backend.default_timeout
if timeout is None or timeout <= 0:
return super().set(
key,
value,
timeout=timeout,
version=version,
client=client,
nx=nx,
xx=xx,
)
packed = self._pack(value, timeout)
real_timeout = timeout + self._herd_timeout
return super().set(
key, packed, timeout=real_timeout, version=version, client=client, nx=nx
)
def get(self, key, default=None, version=None, client=None):
packed = super().get(key, default=default, version=version, client=client)
val, refresh = self._unpack(packed)
if refresh:
return default
return val
def get_many(self, keys, version=None, client=None):
if client is None:
client = self.get_client(write=False)
if not keys:
return {}
recovered_data = OrderedDict()
new_keys = [self.make_key(key, version=version) for key in keys]
map_keys = dict(zip(new_keys, keys))
try:
results = client.mget(*new_keys)
except _main_exceptions as e:
raise ConnectionInterrupted(connection=client) from e
for key, value in zip(new_keys, results):
if value is None:
continue
val, refresh = self._unpack(self.decode(value))
recovered_data[map_keys[key]] = None if refresh else val
return recovered_data
def set_many(
self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None, herd=True
):
"""
Set a bunch of values in the cache at once from a dict of key/value
pairs. This is much more efficient than calling set() multiple times.
If timeout is given, that timeout will be used for the key; otherwise
the default cache timeout will be used.
"""
if client is None:
client = self.get_client(write=True)
set_function = self.set if herd else super().set
try:
pipeline = client.pipeline()
for key, value in data.items():
set_function(key, value, timeout, version=version, client=pipeline)
pipeline.execute()
except _main_exceptions as e:
raise ConnectionInterrupted(connection=client) from e
def incr(self, *args, **kwargs):
raise NotImplementedError
def decr(self, *args, **kwargs):
raise NotImplementedError
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None):
if client is None:
client = self.get_client(write=True)
value = self.get(key, version=version, client=client)
if value is None:
return False
self.set(key, value, timeout=timeout, version=version, client=client)
return True

View File

@@ -0,0 +1,42 @@
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from django.core.exceptions import ImproperlyConfigured
from redis.sentinel import SentinelConnectionPool
from django_redis.client.default import DefaultClient
def replace_query(url, query):
return urlunparse((*url[:4], urlencode(query, doseq=True), url[5]))
class SentinelClient(DefaultClient):
"""
Sentinel client which uses the single redis URL specified by the CACHE's
LOCATION to create a LOCATION configuration for two connection pools; One
pool for the primaries and another pool for the replicas, and upon
connecting ensures the connection pool factory is configured correctly.
"""
def __init__(self, server, params, backend):
if isinstance(server, str):
url = urlparse(server)
primary_query = parse_qs(url.query, keep_blank_values=True)
replica_query = dict(primary_query)
primary_query["is_master"] = [1]
replica_query["is_master"] = [0]
server = [replace_query(url, i) for i in (primary_query, replica_query)]
super().__init__(server, params, backend)
def connect(self, *args, **kwargs):
connection = super().connect(*args, **kwargs)
if not isinstance(connection.connection_pool, SentinelConnectionPool):
error_message = (
"Settings DJANGO_REDIS_CONNECTION_FACTORY or "
"CACHE[].OPTIONS.CONNECTION_POOL_CLASS is not configured correctly."
)
raise ImproperlyConfigured(error_message)
return connection

View File

@@ -0,0 +1,487 @@
import builtins
import re
from collections import OrderedDict
from collections.abc import Iterator
from datetime import datetime
from typing import Any, Optional, Union
from redis import Redis
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.typing import KeyT
from django_redis.client.default import DEFAULT_TIMEOUT, DefaultClient
from django_redis.exceptions import ConnectionInterrupted
from django_redis.hash_ring import HashRing
from django_redis.util import CacheKey
class ShardClient(DefaultClient):
_findhash = re.compile(r".*\{(.*)\}.*", re.I)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not isinstance(self._server, (list, tuple)):
self._server = [self._server]
self._ring = HashRing(self._server)
self._serverdict = self.connect()
def get_client(self, *args, **kwargs):
raise NotImplementedError
def connect(self, index=0):
connection_dict = {}
for name in self._server:
connection_dict[name] = self.connection_factory.connect(name)
return connection_dict
def get_server_name(self, _key):
key = str(_key)
g = self._findhash.match(key)
if g is not None and len(g.groups()) > 0:
key = g.groups()[0]
return self._ring.get_node(key)
def get_server(self, key):
name = self.get_server_name(key)
return self._serverdict[name]
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().add(
key=key, value=value, version=version, client=client, timeout=timeout
)
def get(self, key, default=None, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().get(key=key, default=default, version=version, client=client)
def get_many(self, keys, version=None):
if not keys:
return {}
recovered_data = OrderedDict()
new_keys = [self.make_key(key, version=version) for key in keys]
map_keys = dict(zip(new_keys, keys))
for key in new_keys:
client = self.get_server(key)
value = self.get(key=key, version=version, client=client)
if value is None:
continue
recovered_data[map_keys[key]] = value
return recovered_data
def set(
self,
key,
value,
timeout=DEFAULT_TIMEOUT,
version=None,
client=None,
nx=False,
xx=False,
):
"""
Persist a value to the cache, and set an optional expiration time.
"""
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().set(
key=key,
value=value,
timeout=timeout,
version=version,
client=client,
nx=nx,
xx=xx,
)
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None):
"""
Set a bunch of values in the cache at once from a dict of key/value
pairs. This is much more efficient than calling set() multiple times.
If timeout is given, that timeout will be used for the key; otherwise
the default cache timeout will be used.
"""
for key, value in data.items():
self.set(key, value, timeout, version=version, client=client)
def has_key(self, key, version=None, client=None):
"""
Test if key exists.
"""
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
key = self.make_key(key, version=version)
try:
return client.exists(key) == 1
except RedisConnectionError as e:
raise ConnectionInterrupted(connection=client) from e
def delete(self, key, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().delete(key=key, version=version, client=client)
def ttl(self, key, version=None, client=None):
"""
Executes TTL redis command and return the "time-to-live" of specified key.
If key is a non volatile key, it returns None.
"""
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().ttl(key=key, version=version, client=client)
def pttl(self, key, version=None, client=None):
"""
Executes PTTL redis command and return the "time-to-live" of specified key
in milliseconds. If key is a non volatile key, it returns None.
"""
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().pttl(key=key, version=version, client=client)
def persist(self, key, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().persist(key=key, version=version, client=client)
def expire(self, key, timeout, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().expire(key=key, timeout=timeout, version=version, client=client)
def pexpire(self, key, timeout, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().pexpire(key=key, timeout=timeout, version=version, client=client)
def pexpire_at(self, key, when: Union[datetime, int], version=None, client=None):
"""
Set an expire flag on a ``key`` to ``when`` on a shard client.
``when`` which can be represented as an integer indicating unix
time or a Python datetime object.
"""
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().pexpire_at(key=key, when=when, version=version, client=client)
def expire_at(self, key, when: Union[datetime, int], version=None, client=None):
"""
Set an expire flag on a ``key`` to ``when`` on a shard client.
``when`` which can be represented as an integer indicating unix
time or a Python datetime object.
"""
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().expire_at(key=key, when=when, version=version, client=client)
def lock(
self,
key,
version=None,
timeout=None,
sleep=0.1,
blocking_timeout=None,
client=None,
thread_local=True,
):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
key = self.make_key(key, version=version)
return super().lock(
key,
timeout=timeout,
sleep=sleep,
client=client,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
def delete_many(self, keys, version=None):
"""
Remove multiple keys at once.
"""
res = 0
for key in [self.make_key(k, version=version) for k in keys]:
client = self.get_server(key)
res += self.delete(key, client=client)
return res
def incr_version(self, key, delta=1, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
if version is None:
version = self._backend.version
old_key = self.make_key(key, version)
value = self.get(old_key, version=version, client=client)
try:
ttl = self.ttl(old_key, version=version, client=client)
except RedisConnectionError as e:
raise ConnectionInterrupted(connection=client) from e
if value is None:
msg = f"Key '{key}' not found"
raise ValueError(msg)
if isinstance(key, CacheKey):
new_key = self.make_key(key.original_key(), version=version + delta)
else:
new_key = self.make_key(key, version=version + delta)
self.set(new_key, value, timeout=ttl, client=self.get_server(new_key))
self.delete(old_key, client=client)
return version + delta
def incr(self, key, delta=1, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().incr(key=key, delta=delta, version=version, client=client)
def decr(self, key, delta=1, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().decr(key=key, delta=delta, version=version, client=client)
def iter_keys(self, key, version=None):
error_message = "iter_keys not supported on sharded client"
raise NotImplementedError(error_message)
def keys(self, search, version=None):
pattern = self.make_pattern(search, version=version)
keys = []
try:
for connection in self._serverdict.values():
keys.extend(connection.keys(pattern))
except RedisConnectionError as e:
# FIXME: technically all clients should be passed as `connection`.
client = self.get_server(pattern)
raise ConnectionInterrupted(connection=client) from e
return [self.reverse_key(k.decode()) for k in keys]
def delete_pattern(
self, pattern, version=None, client=None, itersize=None, prefix=None
):
"""
Remove all keys matching pattern.
"""
pattern = self.make_pattern(pattern, version=version, prefix=prefix)
kwargs = {"match": pattern}
if itersize:
kwargs["count"] = itersize
keys = []
for connection in self._serverdict.values():
keys.extend(key for key in connection.scan_iter(**kwargs))
res = 0
if keys:
for connection in self._serverdict.values():
res += connection.delete(*keys)
return res
def do_close_clients(self):
for client in self._serverdict.values():
self.disconnect(client=client)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None):
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().touch(key=key, timeout=timeout, version=version, client=client)
def clear(self, client=None):
for connection in self._serverdict.values():
connection.flushdb()
def sadd(
self,
key: KeyT,
*values: Any,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().sadd(key, *values, version=version, client=client)
def scard(
self,
key: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().scard(key=key, version=version, client=client)
def smembers(
self,
key: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> builtins.set[Any]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().smembers(key=key, version=version, client=client)
def smove(
self,
source: KeyT,
destination: KeyT,
member: Any,
version: Optional[int] = None,
client: Optional[Redis] = None,
):
if client is None:
source = self.make_key(source, version=version)
client = self.get_server(source)
destination = self.make_key(destination, version=version)
return super().smove(
source=source,
destination=destination,
member=member,
version=version,
client=client,
)
def srem(
self,
key: KeyT,
*members,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().srem(key, *members, version=version, client=client)
def sscan(
self,
key: KeyT,
match: Optional[str] = None,
count: Optional[int] = 10,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> builtins.set[Any]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().sscan(
key=key, match=match, count=count, version=version, client=client
)
def sscan_iter(
self,
key: KeyT,
match: Optional[str] = None,
count: Optional[int] = 10,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Iterator[Any]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().sscan_iter(
key=key, match=match, count=count, version=version, client=client
)
def srandmember(
self,
key: KeyT,
count: Optional[int] = None,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Union[builtins.set, Any]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().srandmember(key=key, count=count, version=version, client=client)
def sismember(
self,
key: KeyT,
member: Any,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> bool:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().sismember(key, member, version=version, client=client)
def spop(
self,
key: KeyT,
count: Optional[int] = None,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Union[builtins.set, Any]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().spop(key=key, count=count, version=version, client=client)
def smismember(
self,
key: KeyT,
*members,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> list[bool]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().smismember(key, *members, version=version, client=client)

View File

@@ -0,0 +1,9 @@
class BaseCompressor:
def __init__(self, options):
self._options = options
def compress(self, value: bytes) -> bytes:
raise NotImplementedError
def decompress(self, value: bytes) -> bytes:
raise NotImplementedError

View File

@@ -0,0 +1,19 @@
import gzip
from django_redis.compressors.base import BaseCompressor
from django_redis.exceptions import CompressorError
class GzipCompressor(BaseCompressor):
min_length = 15
def compress(self, value: bytes) -> bytes:
if len(value) > self.min_length:
return gzip.compress(value)
return value
def decompress(self, value: bytes) -> bytes:
try:
return gzip.decompress(value)
except gzip.BadGzipFile as e:
raise CompressorError from e

View File

@@ -0,0 +1,9 @@
from django_redis.compressors.base import BaseCompressor
class IdentityCompressor(BaseCompressor):
def compress(self, value: bytes) -> bytes:
return value
def decompress(self, value: bytes) -> bytes:
return value

View File

@@ -0,0 +1,20 @@
from lz4.frame import compress as _compress
from lz4.frame import decompress as _decompress
from django_redis.compressors.base import BaseCompressor
from django_redis.exceptions import CompressorError
class Lz4Compressor(BaseCompressor):
min_length = 15
def compress(self, value: bytes) -> bytes:
if len(value) > self.min_length:
return _compress(value)
return value
def decompress(self, value: bytes) -> bytes:
try:
return _decompress(value)
except Exception as e:
raise CompressorError from e

View File

@@ -0,0 +1,20 @@
import lzma
from django_redis.compressors.base import BaseCompressor
from django_redis.exceptions import CompressorError
class LzmaCompressor(BaseCompressor):
min_length = 100
preset = 4
def compress(self, value: bytes) -> bytes:
if len(value) > self.min_length:
return lzma.compress(value, preset=self.preset)
return value
def decompress(self, value: bytes) -> bytes:
try:
return lzma.decompress(value)
except lzma.LZMAError as e:
raise CompressorError from e

View File

@@ -0,0 +1,20 @@
import zlib
from django_redis.compressors.base import BaseCompressor
from django_redis.exceptions import CompressorError
class ZlibCompressor(BaseCompressor):
min_length = 15
preset = 6
def compress(self, value: bytes) -> bytes:
if len(value) > self.min_length:
return zlib.compress(value, self.preset)
return value
def decompress(self, value: bytes) -> bytes:
try:
return zlib.decompress(value)
except zlib.error as e:
raise CompressorError from e

View File

@@ -0,0 +1,19 @@
import pyzstd
from django_redis.compressors.base import BaseCompressor
from django_redis.exceptions import CompressorError
class ZStdCompressor(BaseCompressor):
min_length = 15
def compress(self, value: bytes) -> bytes:
if len(value) > self.min_length:
return pyzstd.compress(value)
return value
def decompress(self, value: bytes) -> bytes:
try:
return pyzstd.decompress(value)
except pyzstd.ZstdError as e:
raise CompressorError from e

View File

@@ -0,0 +1,12 @@
class ConnectionInterrupted(Exception):
def __init__(self, connection, parent=None):
self.connection = connection
def __str__(self) -> str:
error_type = type(self.__cause__).__name__
error_msg = str(self.__cause__)
return f"Redis {error_type}: {error_msg}"
class CompressorError(Exception):
pass

View File

@@ -0,0 +1,59 @@
import bisect
import hashlib
from collections.abc import Iterable, Iterator
from typing import Optional
class HashRing:
nodes: list[str] = []
def __init__(self, nodes: Iterable[str] = (), replicas: int = 128) -> None:
self.replicas: int = replicas
self.ring: dict[str, str] = {}
self.sorted_keys: list[str] = []
for node in nodes:
self.add_node(node)
def add_node(self, node: str) -> None:
self.nodes.append(node)
for x in range(self.replicas):
_key = f"{node}:{x}"
_hash = hashlib.sha256(_key.encode()).hexdigest()
self.ring[_hash] = node
self.sorted_keys.append(_hash)
self.sorted_keys.sort()
def remove_node(self, node: str) -> None:
self.nodes.remove(node)
for x in range(self.replicas):
_hash = hashlib.sha256(f"{node}:{x}".encode()).hexdigest()
del self.ring[_hash]
self.sorted_keys.remove(_hash)
def get_node(self, key: str) -> Optional[str]:
n, i = self.get_node_pos(key)
return n
def get_node_pos(self, key: str) -> tuple[Optional[str], Optional[int]]:
if len(self.ring) == 0:
return None, None
_hash = hashlib.sha256(key.encode()).hexdigest()
idx = bisect.bisect(self.sorted_keys, _hash)
idx = min(idx - 1, (self.replicas * len(self.nodes)) - 1)
return self.ring[self.sorted_keys[idx]], idx
def iter_nodes(self, key: str) -> Iterator[tuple[Optional[str], Optional[str]]]:
if len(self.ring) == 0:
yield None, None
node, pos = self.get_node_pos(key)
for k in self.sorted_keys[pos:]:
yield k, self.ring[k]
def __call__(self, key: str) -> Optional[str]:
return self.get_node(key)

View File

@@ -0,0 +1,200 @@
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.utils.module_loading import import_string
from redis import Redis
from redis.connection import ConnectionPool, DefaultParser, to_bool
from redis.sentinel import Sentinel
class ConnectionFactory:
# Store connection pool by cache backend options.
#
# _pools is a process-global, as otherwise _pools is cleared every time
# ConnectionFactory is instantiated, as Django creates new cache client
# (DefaultClient) instance for every request.
_pools: dict[str, ConnectionPool] = {}
def __init__(self, options):
pool_cls_path = options.get(
"CONNECTION_POOL_CLASS", "redis.connection.ConnectionPool"
)
self.pool_cls = import_string(pool_cls_path)
self.pool_cls_kwargs = options.get("CONNECTION_POOL_KWARGS", {})
redis_client_cls_path = options.get("REDIS_CLIENT_CLASS", "redis.client.Redis")
self.redis_client_cls = import_string(redis_client_cls_path)
self.redis_client_cls_kwargs = options.get("REDIS_CLIENT_KWARGS", {})
self.options = options
def make_connection_params(self, url):
"""
Given a main connection parameters, build a complete
dict of connection parameters.
"""
kwargs = {
"url": url,
"parser_class": self.get_parser_cls(),
}
password = self.options.get("PASSWORD", None)
if password:
kwargs["password"] = password
socket_timeout = self.options.get("SOCKET_TIMEOUT", None)
if socket_timeout:
if not isinstance(socket_timeout, (int, float)):
error_message = "Socket timeout should be float or integer"
raise ImproperlyConfigured(error_message)
kwargs["socket_timeout"] = socket_timeout
socket_connect_timeout = self.options.get("SOCKET_CONNECT_TIMEOUT", None)
if socket_connect_timeout:
if not isinstance(socket_connect_timeout, (int, float)):
error_message = "Socket connect timeout should be float or integer"
raise ImproperlyConfigured(error_message)
kwargs["socket_connect_timeout"] = socket_connect_timeout
return kwargs
def connect(self, url: str) -> Redis:
"""
Given a basic connection parameters,
return a new connection.
"""
params = self.make_connection_params(url)
return self.get_connection(params)
def disconnect(self, connection: Redis) -> None:
"""
Given a not null client connection it disconnect from the Redis server.
The default implementation uses a pool to hold connections.
"""
connection.connection_pool.disconnect()
def get_connection(self, params):
"""
Given a now preformatted params, return a
new connection.
The default implementation uses a cached pools
for create new connection.
"""
pool = self.get_or_create_connection_pool(params)
return self.redis_client_cls(
connection_pool=pool, **self.redis_client_cls_kwargs
)
def get_parser_cls(self):
cls = self.options.get("PARSER_CLASS", None)
if cls is None:
return DefaultParser
return import_string(cls)
def get_or_create_connection_pool(self, params):
"""
Given a connection parameters and return a new
or cached connection pool for them.
Reimplement this method if you want distinct
connection pool instance caching behavior.
"""
key = params["url"]
if key not in self._pools:
self._pools[key] = self.get_connection_pool(params)
return self._pools[key]
def get_connection_pool(self, params):
"""
Given a connection parameters, return a new
connection pool for them.
Overwrite this method if you want a custom
behavior on creating connection pool.
"""
cp_params = dict(params)
cp_params.update(self.pool_cls_kwargs)
pool = self.pool_cls.from_url(**cp_params)
if pool.connection_kwargs.get("password", None) is None:
pool.connection_kwargs["password"] = params.get("password", None)
pool.reset()
return pool
class SentinelConnectionFactory(ConnectionFactory):
def __init__(self, options):
# allow overriding the default SentinelConnectionPool class
options.setdefault(
"CONNECTION_POOL_CLASS", "redis.sentinel.SentinelConnectionPool"
)
super().__init__(options)
sentinels = options.get("SENTINELS")
if not sentinels:
error_message = "SENTINELS must be provided as a list of (host, port)."
raise ImproperlyConfigured(error_message)
# provide the connection pool kwargs to the sentinel in case it
# needs to use the socket options for the sentinels themselves
connection_kwargs = self.make_connection_params(None)
connection_kwargs.pop("url")
connection_kwargs.update(self.pool_cls_kwargs)
self._sentinel = Sentinel(
sentinels,
sentinel_kwargs=options.get("SENTINEL_KWARGS"),
**connection_kwargs,
)
def get_connection_pool(self, params):
"""
Given a connection parameters, return a new sentinel connection pool
for them.
"""
url = urlparse(params["url"])
# explicitly set service_name and sentinel_manager for the
# SentinelConnectionPool constructor since will be called by from_url
cp_params = dict(params)
# convert "is_master" to a boolean if set on the URL, otherwise if not
# provided it defaults to True.
query_params = parse_qs(url.query)
is_master = query_params.get("is_master")
if is_master:
cp_params["is_master"] = to_bool(is_master[0])
# then remove the "is_master" query string from the URL
# so it doesn't interfere with the SentinelConnectionPool constructor
if "is_master" in query_params:
del query_params["is_master"]
new_query = urlencode(query_params, doseq=True)
new_url = urlunparse(
(url.scheme, url.netloc, url.path, url.params, new_query, url.fragment)
)
cp_params.update(
service_name=url.hostname, sentinel_manager=self._sentinel, url=new_url
)
return super().get_connection_pool(cp_params)
def get_connection_factory(path=None, options=None):
if path is None:
path = getattr(
settings,
"DJANGO_REDIS_CONNECTION_FACTORY",
"django_redis.pool.ConnectionFactory",
)
opt_conn_factory = options.get("CONNECTION_FACTORY")
if opt_conn_factory:
path = opt_conn_factory
cls = import_string(path)
return cls(options or {})

View File

@@ -0,0 +1,12 @@
from typing import Any
class BaseSerializer:
def __init__(self, options):
pass
def dumps(self, value: Any) -> bytes:
raise NotImplementedError
def loads(self, value: bytes) -> Any:
raise NotImplementedError

View File

@@ -0,0 +1,16 @@
import json
from typing import Any
from django.core.serializers.json import DjangoJSONEncoder
from django_redis.serializers.base import BaseSerializer
class JSONSerializer(BaseSerializer):
encoder_class = DjangoJSONEncoder
def dumps(self, value: Any) -> bytes:
return json.dumps(value, cls=self.encoder_class).encode()
def loads(self, value: bytes) -> Any:
return json.loads(value.decode())

View File

@@ -0,0 +1,13 @@
from typing import Any
import msgpack
from django_redis.serializers.base import BaseSerializer
class MSGPackSerializer(BaseSerializer):
def dumps(self, value: Any) -> bytes:
return msgpack.dumps(value)
def loads(self, value: bytes) -> Any:
return msgpack.loads(value, raw=False)

View File

@@ -0,0 +1,34 @@
import pickle
from typing import Any
from django.core.exceptions import ImproperlyConfigured
from django_redis.serializers.base import BaseSerializer
class PickleSerializer(BaseSerializer):
def __init__(self, options) -> None:
self._pickle_version = pickle.DEFAULT_PROTOCOL
self.setup_pickle_version(options)
super().__init__(options=options)
def setup_pickle_version(self, options) -> None:
if "PICKLE_VERSION" in options:
try:
self._pickle_version = int(options["PICKLE_VERSION"])
if self._pickle_version > pickle.HIGHEST_PROTOCOL:
error_message = (
f"PICKLE_VERSION can't be higher than pickle.HIGHEST_PROTOCOL:"
f" {pickle.HIGHEST_PROTOCOL}"
)
raise ImproperlyConfigured(error_message)
except (ValueError, TypeError) as e:
error_message = "PICKLE_VERSION value must be an integer"
raise ImproperlyConfigured(error_message) from e
def dumps(self, value: Any) -> bytes:
return pickle.dumps(value, self._pickle_version)
def loads(self, value: bytes) -> Any:
return pickle.loads(value)

View File

@@ -0,0 +1,11 @@
class CacheKey(str):
"""
A stub string class that we can use to check if a key was created already.
"""
def original_key(self) -> str:
return self.rsplit(":", 1)[1]
def default_reverse_key(key: str) -> str:
return key.split(":", 2)[2]