import builtins
import re
from collections import OrderedDict
from collections.abc import Iterator
from datetime import datetime
from typing import Any, Optional, Union

from redis import Redis
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.typing import KeyT

from django_redis.client.default import DEFAULT_TIMEOUT, DefaultClient
from django_redis.exceptions import ConnectionInterrupted
from django_redis.hash_ring import HashRing
from django_redis.util import CacheKey


class ShardClient(DefaultClient):
    """
    Cache client that spreads keys over several servers via a hash ring.

    Every per-key operation follows the same shape: when no ``client`` is
    supplied, the key is built with ``make_key`` and the connection owning
    it is looked up with ``get_server``; the call is then delegated to the
    parent implementation with that connection.
    """

    # Captures the text between braces in a key (e.g. "{user1}:profile").
    # When present, that text alone is hashed to pick the server.
    _findhash = re.compile(r".*\{(.*)\}.*", re.I)

    def __init__(self, *args, **kwargs):
        """Normalise the server setting to a list and connect to each one."""
        super().__init__(*args, **kwargs)

        if not isinstance(self._server, (list, tuple)):
            self._server = [self._server]

        self._ring = HashRing(self._server)
        self._serverdict = self.connect()

    def get_client(self, *args, **kwargs):
        """Not supported here; connections are resolved with ``get_server``."""
        raise NotImplementedError

    def connect(self, index=0):
        """
        Open one connection per configured server.

        ``index`` is accepted for signature compatibility and is not used.
        Returns a dict mapping each server name to its connection.
        """
        return {
            name: self.connection_factory.connect(name) for name in self._server
        }

    def get_server_name(self, _key):
        """
        Return the ring node responsible for ``_key``.

        If the key contains a ``{...}`` section, only the text inside the
        braces is hashed; otherwise the whole key (as ``str``) is hashed.
        """
        key = str(_key)
        found = self._findhash.match(key)
        if found is not None:
            key = found.group(1)
        return self._ring.get_node(key)

    def get_server(self, key):
        """Return the connection object for the server that owns ``key``."""
        name = self.get_server_name(key)
        return self._serverdict[name]

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """Delegate ``add`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().add(
            key=key, value=value, version=version, client=client, timeout=timeout
        )

    def get(self, key, default=None, version=None, client=None):
        """Delegate ``get`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().get(key=key, default=default, version=version, client=client)

    def get_many(self, keys, version=None):
        """
        Fetch several keys, one ``get`` per key on its owning server.

        Returns an ``OrderedDict`` keyed by the caller's original keys.
        Keys whose fetched value is ``None`` are left out of the result.
        """
        if not keys:
            return {}

        recovered_data = OrderedDict()

        new_keys = [self.make_key(key, version=version) for key in keys]
        # Maps each built key back to the key the caller passed in.
        map_keys = dict(zip(new_keys, keys))

        for key in new_keys:
            client = self.get_server(key)
            value = self.get(key=key, version=version, client=client)

            if value is None:
                continue

            recovered_data[map_keys[key]] = value
        return recovered_data

    def set(
        self,
        key,
        value,
        timeout=DEFAULT_TIMEOUT,
        version=None,
        client=None,
        nx=False,
        xx=False,
    ):
        """
        Persist a value to the cache, and set an optional expiration time.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().set(
            key=key,
            value=value,
            timeout=timeout,
            version=version,
            client=client,
            nx=nx,
            xx=xx,
        )

    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """
        Set a bunch of values in the cache at once from a dict of key/value
        pairs.

        Each pair is written with a separate ``set`` call. ``timeout``,
        ``version`` and ``client`` are forwarded unchanged to every call.
        """
        for key, value in data.items():
            self.set(key, value, timeout, version=version, client=client)

    def has_key(self, key, version=None, client=None):
        """
        Test if key exists.

        Raises ``ConnectionInterrupted`` if the server connection fails.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        # Needed when the caller supplied ``client`` (key not built yet).
        # NOTE(review): presumably a no-op for an already-built key - confirm.
        key = self.make_key(key, version=version)
        try:
            return client.exists(key) == 1
        except RedisConnectionError as e:
            raise ConnectionInterrupted(connection=client) from e

    def delete(self, key, version=None, client=None):
        """Delegate ``delete`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().delete(key=key, version=version, client=client)

    def ttl(self, key, version=None, client=None):
        """
        Executes TTL redis command and return the "time-to-live" of specified
        key. If key is a non volatile key, it returns None.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().ttl(key=key, version=version, client=client)

    def pttl(self, key, version=None, client=None):
        """
        Executes PTTL redis command and return the "time-to-live" of specified
        key in milliseconds. If key is a non volatile key, it returns None.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().pttl(key=key, version=version, client=client)

    def persist(self, key, version=None, client=None):
        """Delegate ``persist`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().persist(key=key, version=version, client=client)

    def expire(self, key, timeout, version=None, client=None):
        """Delegate ``expire`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().expire(key=key, timeout=timeout, version=version, client=client)

    def pexpire(self, key, timeout, version=None, client=None):
        """Delegate ``pexpire`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().pexpire(key=key, timeout=timeout, version=version, client=client)

    def pexpire_at(self, key, when: Union[datetime, int], version=None, client=None):
        """
        Set an expire flag on a ``key`` to ``when`` on a shard client.
        ``when`` which can be represented as an integer indicating unix
        time or a Python datetime object.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().pexpire_at(key=key, when=when, version=version, client=client)

    def expire_at(self, key, when: Union[datetime, int], version=None, client=None):
        """
        Set an expire flag on a ``key`` to ``when`` on a shard client.
        ``when`` which can be represented as an integer indicating unix
        time or a Python datetime object.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().expire_at(key=key, when=when, version=version, client=client)

    def lock(
        self,
        key,
        version=None,
        timeout=None,
        sleep=0.1,
        blocking_timeout=None,
        client=None,
        thread_local=True,
    ):
        """
        Delegate ``lock`` to the parent using the connection owning ``key``.

        ``version`` is only used to build the key here; it is not forwarded
        to the parent call.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        # Needed when the caller supplied ``client`` (key not built yet).
        # NOTE(review): presumably a no-op for an already-built key - confirm.
        key = self.make_key(key, version=version)
        return super().lock(
            key,
            timeout=timeout,
            sleep=sleep,
            client=client,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

    def delete_many(self, keys, version=None):
        """
        Remove multiple keys at once.

        Issues one ``delete`` per key on its owning server and returns the
        sum of the individual results.
        """
        res = 0
        for key in (self.make_key(k, version=version) for k in keys):
            client = self.get_server(key)
            res += self.delete(key, client=client)
        return res

    def incr_version(self, key, delta=1, version=None, client=None):
        """
        Move the value stored under ``key`` to version ``version + delta``.

        The value and TTL are read from the old key, written under the new
        key on whichever server owns the new key, and the old key is then
        deleted. Returns the new version number.

        Raises ``ValueError`` if the old key holds no value and
        ``ConnectionInterrupted`` if reading the TTL hits a connection error.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        if version is None:
            version = self._backend.version

        old_key = self.make_key(key, version)
        value = self.get(old_key, version=version, client=client)

        try:
            ttl = self.ttl(old_key, version=version, client=client)
        except RedisConnectionError as e:
            raise ConnectionInterrupted(connection=client) from e

        if value is None:
            msg = f"Key '{key}' not found"
            raise ValueError(msg)

        # A built key must be rebuilt from its original text so that the new
        # version is embedded in it.
        if isinstance(key, CacheKey):
            new_key = self.make_key(key.original_key(), version=version + delta)
        else:
            new_key = self.make_key(key, version=version + delta)

        # The new key may hash to a different server than the old one.
        self.set(new_key, value, timeout=ttl, client=self.get_server(new_key))
        self.delete(old_key, client=client)
        return version + delta

    def incr(self, key, delta=1, version=None, client=None):
        """Delegate ``incr`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().incr(key=key, delta=delta, version=version, client=client)

    def decr(self, key, delta=1, version=None, client=None):
        """Delegate ``decr`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().decr(key=key, delta=delta, version=version, client=client)

    def iter_keys(self, key, version=None):
        """Always raises ``NotImplementedError``: unsupported when sharded."""
        error_message = "iter_keys not supported on sharded client"
        raise NotImplementedError(error_message)

    def keys(self, search, version=None):
        """
        Return the keys matching ``search`` collected from every server.

        Each raw key is decoded and passed through ``reverse_key``. Raises
        ``ConnectionInterrupted`` on a connection error.
        """
        pattern = self.make_pattern(search, version=version)
        keys = []
        try:
            for connection in self._serverdict.values():
                keys.extend(connection.keys(pattern))
        except RedisConnectionError as e:
            # FIXME: technically all clients should be passed as `connection`.
            client = self.get_server(pattern)
            raise ConnectionInterrupted(connection=client) from e

        return [self.reverse_key(k.decode()) for k in keys]

    def delete_pattern(
        self, pattern, version=None, client=None, itersize=None, prefix=None
    ):
        """
        Remove all keys matching pattern.

        Every server is scanned first; afterwards each server is asked to
        delete only the keys that were found on it. Returns the total number
        of keys the servers report as deleted. ``itersize``, when truthy, is
        passed as the scan ``count``. ``client`` is accepted but not used.
        """
        pattern = self.make_pattern(pattern, version=version, prefix=prefix)
        kwargs = {"match": pattern}
        if itersize:
            kwargs["count"] = itersize

        # Finish all scans before deleting anything, keeping each server's
        # matches separate so no server receives keys it does not hold.
        matches = [
            (connection, list(connection.scan_iter(**kwargs)))
            for connection in self._serverdict.values()
        ]

        res = 0
        for connection, found in matches:
            if found:
                res += connection.delete(*found)
        return res

    def do_close_clients(self):
        """Disconnect every server connection."""
        for client in self._serverdict.values():
            self.disconnect(client=client)

    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """Delegate ``touch`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)

        return super().touch(key=key, timeout=timeout, version=version, client=client)

    def clear(self, client=None):
        """Run ``flushdb`` on every server. ``client`` is accepted but not used."""
        for connection in self._serverdict.values():
            connection.flushdb()

    def sadd(
        self,
        key: KeyT,
        *values: Any,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """Delegate ``sadd`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().sadd(key, *values, version=version, client=client)

    def scard(
        self,
        key: KeyT,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """Delegate ``scard`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().scard(key=key, version=version, client=client)

    def smembers(
        self,
        key: KeyT,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> builtins.set[Any]:
        """Delegate ``smembers`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().smembers(key=key, version=version, client=client)

    def smove(
        self,
        source: KeyT,
        destination: KeyT,
        member: Any,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ):
        """
        Delegate ``smove`` to the parent using the connection owning ``source``.

        NOTE(review): the connection is chosen from ``source`` only; whether
        ``destination`` lives on the same server is not checked here.
        """
        if client is None:
            source = self.make_key(source, version=version)
            client = self.get_server(source)
            destination = self.make_key(destination, version=version)

        return super().smove(
            source=source,
            destination=destination,
            member=member,
            version=version,
            client=client,
        )

    def srem(
        self,
        key: KeyT,
        *members,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """Delegate ``srem`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().srem(key, *members, version=version, client=client)

    def sscan(
        self,
        key: KeyT,
        match: Optional[str] = None,
        count: Optional[int] = 10,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> builtins.set[Any]:
        """Delegate ``sscan`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().sscan(
            key=key, match=match, count=count, version=version, client=client
        )

    def sscan_iter(
        self,
        key: KeyT,
        match: Optional[str] = None,
        count: Optional[int] = 10,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> Iterator[Any]:
        """Delegate ``sscan_iter`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().sscan_iter(
            key=key, match=match, count=count, version=version, client=client
        )

    def srandmember(
        self,
        key: KeyT,
        count: Optional[int] = None,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> Union[builtins.set, Any]:
        """Delegate ``srandmember`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().srandmember(key=key, count=count, version=version, client=client)

    def sismember(
        self,
        key: KeyT,
        member: Any,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """Delegate ``sismember`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().sismember(key, member, version=version, client=client)

    def spop(
        self,
        key: KeyT,
        count: Optional[int] = None,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> Union[builtins.set, Any]:
        """Delegate ``spop`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().spop(key=key, count=count, version=version, client=client)

    def smismember(
        self,
        key: KeyT,
        *members,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> list[bool]:
        """Delegate ``smismember`` to the parent using the connection owning ``key``."""
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return super().smismember(key, *members, version=version, client=client)