Updates
This commit is contained in:
@@ -0,0 +1,487 @@
|
||||
import builtins
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from redis import Redis
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||
from redis.typing import KeyT
|
||||
|
||||
from django_redis.client.default import DEFAULT_TIMEOUT, DefaultClient
|
||||
from django_redis.exceptions import ConnectionInterrupted
|
||||
from django_redis.hash_ring import HashRing
|
||||
from django_redis.util import CacheKey
|
||||
|
||||
|
||||
class ShardClient(DefaultClient):
|
||||
_findhash = re.compile(r".*\{(.*)\}.*", re.I)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if not isinstance(self._server, (list, tuple)):
|
||||
self._server = [self._server]
|
||||
|
||||
self._ring = HashRing(self._server)
|
||||
self._serverdict = self.connect()
|
||||
|
||||
def get_client(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def connect(self, index=0):
|
||||
connection_dict = {}
|
||||
for name in self._server:
|
||||
connection_dict[name] = self.connection_factory.connect(name)
|
||||
return connection_dict
|
||||
|
||||
def get_server_name(self, _key):
|
||||
key = str(_key)
|
||||
g = self._findhash.match(key)
|
||||
if g is not None and len(g.groups()) > 0:
|
||||
key = g.groups()[0]
|
||||
return self._ring.get_node(key)
|
||||
|
||||
def get_server(self, key):
|
||||
name = self.get_server_name(key)
|
||||
return self._serverdict[name]
|
||||
|
||||
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().add(
|
||||
key=key, value=value, version=version, client=client, timeout=timeout
|
||||
)
|
||||
|
||||
def get(self, key, default=None, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().get(key=key, default=default, version=version, client=client)
|
||||
|
||||
def get_many(self, keys, version=None):
|
||||
if not keys:
|
||||
return {}
|
||||
|
||||
recovered_data = OrderedDict()
|
||||
|
||||
new_keys = [self.make_key(key, version=version) for key in keys]
|
||||
map_keys = dict(zip(new_keys, keys))
|
||||
|
||||
for key in new_keys:
|
||||
client = self.get_server(key)
|
||||
value = self.get(key=key, version=version, client=client)
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
recovered_data[map_keys[key]] = value
|
||||
return recovered_data
|
||||
|
||||
def set(
|
||||
self,
|
||||
key,
|
||||
value,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
version=None,
|
||||
client=None,
|
||||
nx=False,
|
||||
xx=False,
|
||||
):
|
||||
"""
|
||||
Persist a value to the cache, and set an optional expiration time.
|
||||
"""
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().set(
|
||||
key=key,
|
||||
value=value,
|
||||
timeout=timeout,
|
||||
version=version,
|
||||
client=client,
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
)
|
||||
|
||||
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None):
|
||||
"""
|
||||
Set a bunch of values in the cache at once from a dict of key/value
|
||||
pairs. This is much more efficient than calling set() multiple times.
|
||||
|
||||
If timeout is given, that timeout will be used for the key; otherwise
|
||||
the default cache timeout will be used.
|
||||
"""
|
||||
for key, value in data.items():
|
||||
self.set(key, value, timeout, version=version, client=client)
|
||||
|
||||
def has_key(self, key, version=None, client=None):
|
||||
"""
|
||||
Test if key exists.
|
||||
"""
|
||||
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
key = self.make_key(key, version=version)
|
||||
try:
|
||||
return client.exists(key) == 1
|
||||
except RedisConnectionError as e:
|
||||
raise ConnectionInterrupted(connection=client) from e
|
||||
|
||||
def delete(self, key, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().delete(key=key, version=version, client=client)
|
||||
|
||||
def ttl(self, key, version=None, client=None):
|
||||
"""
|
||||
Executes TTL redis command and return the "time-to-live" of specified key.
|
||||
If key is a non volatile key, it returns None.
|
||||
"""
|
||||
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().ttl(key=key, version=version, client=client)
|
||||
|
||||
def pttl(self, key, version=None, client=None):
|
||||
"""
|
||||
Executes PTTL redis command and return the "time-to-live" of specified key
|
||||
in milliseconds. If key is a non volatile key, it returns None.
|
||||
"""
|
||||
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().pttl(key=key, version=version, client=client)
|
||||
|
||||
def persist(self, key, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().persist(key=key, version=version, client=client)
|
||||
|
||||
def expire(self, key, timeout, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().expire(key=key, timeout=timeout, version=version, client=client)
|
||||
|
||||
def pexpire(self, key, timeout, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().pexpire(key=key, timeout=timeout, version=version, client=client)
|
||||
|
||||
def pexpire_at(self, key, when: Union[datetime, int], version=None, client=None):
|
||||
"""
|
||||
Set an expire flag on a ``key`` to ``when`` on a shard client.
|
||||
``when`` which can be represented as an integer indicating unix
|
||||
time or a Python datetime object.
|
||||
"""
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().pexpire_at(key=key, when=when, version=version, client=client)
|
||||
|
||||
def expire_at(self, key, when: Union[datetime, int], version=None, client=None):
|
||||
"""
|
||||
Set an expire flag on a ``key`` to ``when`` on a shard client.
|
||||
``when`` which can be represented as an integer indicating unix
|
||||
time or a Python datetime object.
|
||||
"""
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().expire_at(key=key, when=when, version=version, client=client)
|
||||
|
||||
def lock(
|
||||
self,
|
||||
key,
|
||||
version=None,
|
||||
timeout=None,
|
||||
sleep=0.1,
|
||||
blocking_timeout=None,
|
||||
client=None,
|
||||
thread_local=True,
|
||||
):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
key = self.make_key(key, version=version)
|
||||
return super().lock(
|
||||
key,
|
||||
timeout=timeout,
|
||||
sleep=sleep,
|
||||
client=client,
|
||||
blocking_timeout=blocking_timeout,
|
||||
thread_local=thread_local,
|
||||
)
|
||||
|
||||
def delete_many(self, keys, version=None):
|
||||
"""
|
||||
Remove multiple keys at once.
|
||||
"""
|
||||
res = 0
|
||||
for key in [self.make_key(k, version=version) for k in keys]:
|
||||
client = self.get_server(key)
|
||||
res += self.delete(key, client=client)
|
||||
return res
|
||||
|
||||
def incr_version(self, key, delta=1, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
if version is None:
|
||||
version = self._backend.version
|
||||
|
||||
old_key = self.make_key(key, version)
|
||||
value = self.get(old_key, version=version, client=client)
|
||||
|
||||
try:
|
||||
ttl = self.ttl(old_key, version=version, client=client)
|
||||
except RedisConnectionError as e:
|
||||
raise ConnectionInterrupted(connection=client) from e
|
||||
|
||||
if value is None:
|
||||
msg = f"Key '{key}' not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if isinstance(key, CacheKey):
|
||||
new_key = self.make_key(key.original_key(), version=version + delta)
|
||||
else:
|
||||
new_key = self.make_key(key, version=version + delta)
|
||||
|
||||
self.set(new_key, value, timeout=ttl, client=self.get_server(new_key))
|
||||
self.delete(old_key, client=client)
|
||||
return version + delta
|
||||
|
||||
def incr(self, key, delta=1, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().incr(key=key, delta=delta, version=version, client=client)
|
||||
|
||||
def decr(self, key, delta=1, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().decr(key=key, delta=delta, version=version, client=client)
|
||||
|
||||
def iter_keys(self, key, version=None):
|
||||
error_message = "iter_keys not supported on sharded client"
|
||||
raise NotImplementedError(error_message)
|
||||
|
||||
def keys(self, search, version=None):
|
||||
pattern = self.make_pattern(search, version=version)
|
||||
keys = []
|
||||
try:
|
||||
for connection in self._serverdict.values():
|
||||
keys.extend(connection.keys(pattern))
|
||||
except RedisConnectionError as e:
|
||||
# FIXME: technically all clients should be passed as `connection`.
|
||||
client = self.get_server(pattern)
|
||||
raise ConnectionInterrupted(connection=client) from e
|
||||
|
||||
return [self.reverse_key(k.decode()) for k in keys]
|
||||
|
||||
def delete_pattern(
|
||||
self, pattern, version=None, client=None, itersize=None, prefix=None
|
||||
):
|
||||
"""
|
||||
Remove all keys matching pattern.
|
||||
"""
|
||||
pattern = self.make_pattern(pattern, version=version, prefix=prefix)
|
||||
kwargs = {"match": pattern}
|
||||
if itersize:
|
||||
kwargs["count"] = itersize
|
||||
|
||||
keys = []
|
||||
for connection in self._serverdict.values():
|
||||
keys.extend(key for key in connection.scan_iter(**kwargs))
|
||||
|
||||
res = 0
|
||||
if keys:
|
||||
for connection in self._serverdict.values():
|
||||
res += connection.delete(*keys)
|
||||
return res
|
||||
|
||||
def do_close_clients(self):
|
||||
for client in self._serverdict.values():
|
||||
self.disconnect(client=client)
|
||||
|
||||
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None):
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
|
||||
return super().touch(key=key, timeout=timeout, version=version, client=client)
|
||||
|
||||
def clear(self, client=None):
|
||||
for connection in self._serverdict.values():
|
||||
connection.flushdb()
|
||||
|
||||
def sadd(
|
||||
self,
|
||||
key: KeyT,
|
||||
*values: Any,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> int:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().sadd(key, *values, version=version, client=client)
|
||||
|
||||
def scard(
|
||||
self,
|
||||
key: KeyT,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> int:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().scard(key=key, version=version, client=client)
|
||||
|
||||
def smembers(
|
||||
self,
|
||||
key: KeyT,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> builtins.set[Any]:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().smembers(key=key, version=version, client=client)
|
||||
|
||||
def smove(
|
||||
self,
|
||||
source: KeyT,
|
||||
destination: KeyT,
|
||||
member: Any,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
):
|
||||
if client is None:
|
||||
source = self.make_key(source, version=version)
|
||||
client = self.get_server(source)
|
||||
destination = self.make_key(destination, version=version)
|
||||
|
||||
return super().smove(
|
||||
source=source,
|
||||
destination=destination,
|
||||
member=member,
|
||||
version=version,
|
||||
client=client,
|
||||
)
|
||||
|
||||
def srem(
|
||||
self,
|
||||
key: KeyT,
|
||||
*members,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> int:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().srem(key, *members, version=version, client=client)
|
||||
|
||||
def sscan(
|
||||
self,
|
||||
key: KeyT,
|
||||
match: Optional[str] = None,
|
||||
count: Optional[int] = 10,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> builtins.set[Any]:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().sscan(
|
||||
key=key, match=match, count=count, version=version, client=client
|
||||
)
|
||||
|
||||
def sscan_iter(
|
||||
self,
|
||||
key: KeyT,
|
||||
match: Optional[str] = None,
|
||||
count: Optional[int] = 10,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> Iterator[Any]:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().sscan_iter(
|
||||
key=key, match=match, count=count, version=version, client=client
|
||||
)
|
||||
|
||||
def srandmember(
|
||||
self,
|
||||
key: KeyT,
|
||||
count: Optional[int] = None,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> Union[builtins.set, Any]:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().srandmember(key=key, count=count, version=version, client=client)
|
||||
|
||||
def sismember(
|
||||
self,
|
||||
key: KeyT,
|
||||
member: Any,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> bool:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().sismember(key, member, version=version, client=client)
|
||||
|
||||
def spop(
|
||||
self,
|
||||
key: KeyT,
|
||||
count: Optional[int] = None,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> Union[builtins.set, Any]:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().spop(key=key, count=count, version=version, client=client)
|
||||
|
||||
def smismember(
|
||||
self,
|
||||
key: KeyT,
|
||||
*members,
|
||||
version: Optional[int] = None,
|
||||
client: Optional[Redis] = None,
|
||||
) -> list[bool]:
|
||||
if client is None:
|
||||
key = self.make_key(key, version=version)
|
||||
client = self.get_server(key)
|
||||
return super().smismember(key, *members, version=version, client=client)
|
||||
Reference in New Issue
Block a user