This commit is contained in:
Iliyan Angelov
2025-09-19 11:58:53 +03:00
parent 306b20e24a
commit 6b247e5b9f
11423 changed files with 1500615 additions and 778 deletions

View File

@@ -0,0 +1,18 @@
from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands
from .core import AsyncCoreCommands, CoreCommands
from .helpers import list_or_args
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
from .sentinel import AsyncSentinelCommands, SentinelCommands
# Public API of the commands package; names are kept in ASCII sort order.
__all__ = [
    "AsyncCoreCommands",
    "AsyncRedisClusterCommands",
    "AsyncRedisModuleCommands",
    "AsyncSentinelCommands",
    "CoreCommands",
    "READ_COMMANDS",
    "RedisClusterCommands",
    "RedisModuleCommands",
    "SentinelCommands",
    "list_or_args",
]

View File

@@ -0,0 +1,253 @@
from redis._parsers.helpers import bool_ok
from ..helpers import get_protocol_version, parse_to_list
from .commands import * # noqa
from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo
class AbstractBloom:
    """
    The client allows to interact with RedisBloom and use all of
    it's functionality.

    - BF for Bloom Filter
    - CF for Cuckoo Filter
    - CMS for Count-Min Sketch
    - TOPK for TopK Data Structure
    - TDIGEST for estimate rank statistics

    Each ``append_*`` helper mutates ``params`` (the argument list being
    built for an ``execute_command`` call) in place and returns ``None``.
    """

    @staticmethod
    def append_items(params, items):
        """Append the ITEMS keyword followed by every element of `items`."""
        params.append("ITEMS")
        params.extend(items)

    @staticmethod
    def append_error(params, error):
        """Append ERROR and its value to params when `error` is given."""
        if error is not None:
            params.extend(["ERROR", error])

    @staticmethod
    def append_capacity(params, capacity):
        """Append CAPACITY and its value to params when `capacity` is given."""
        if capacity is not None:
            params.extend(["CAPACITY", capacity])

    @staticmethod
    def append_expansion(params, expansion):
        """Append EXPANSION and its value to params when `expansion` is given."""
        if expansion is not None:
            params.extend(["EXPANSION", expansion])

    @staticmethod
    def append_no_scale(params, noScale):
        """Append the NONSCALING tag to params when `noScale` is truthy.

        Fix: the previous ``is not None`` check appended the tag even for an
        explicit ``noScale=False``; a falsy value now correctly omits it.
        """
        if noScale:
            params.extend(["NONSCALING"])

    @staticmethod
    def append_weights(params, weights):
        """Append the WEIGHTS keyword and all weights when any are given."""
        if weights:
            params.append("WEIGHTS")
            params.extend(weights)

    @staticmethod
    def append_no_create(params, noCreate):
        """Append the NOCREATE tag to params when `noCreate` is truthy.

        Fix: the previous ``is not None`` check appended the tag even for an
        explicit ``noCreate=False``; a falsy value now correctly omits it.
        """
        if noCreate:
            params.extend(["NOCREATE"])

    @staticmethod
    def append_items_and_increments(params, items, increments):
        """Append interleaved (item, increment) pairs to params.

        `items` and `increments` are expected to be of equal length.
        """
        for item, increment in zip(items, increments):
            params.append(item)
            params.append(increment)

    @staticmethod
    def append_values_and_weights(params, items, weights):
        """Append interleaved (item, weight) pairs to params.

        `items` and `weights` are expected to be of equal length.
        """
        for item, weight in zip(items, weights):
            params.append(item)
            params.append(weight)

    @staticmethod
    def append_max_iterations(params, max_iterations):
        """Append MAXITERATIONS and its value when `max_iterations` is given."""
        if max_iterations is not None:
            params.extend(["MAXITERATIONS", max_iterations])

    @staticmethod
    def append_bucket_size(params, bucket_size):
        """Append BUCKETSIZE and its value when `bucket_size` is given."""
        if bucket_size is not None:
            params.extend(["BUCKETSIZE", bucket_size])
class CMSBloom(CMSCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client for Count-Min Sketch commands."""
        self.client = client
        self.commandmixin = CMSCommands
        self.execute_command = client.execute_command

        # Response callbacks applied for every protocol version.
        callbacks = {
            CMS_INITBYDIM: bool_ok,
            CMS_INITBYPROB: bool_ok,
            # CMS_INCRBY: spaceHolder,
            # CMS_QUERY: spaceHolder,
            CMS_MERGE: bool_ok,
        }
        # RESP2 replies need explicit parsing; RESP3 adds nothing extra here.
        if get_protocol_version(self.client) not in ["3", 3]:
            callbacks.update({CMS_INFO: CMSInfo})
        for command, handler in callbacks.items():
            self.client.set_response_callback(command, handler)
class TOPKBloom(TOPKCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client for Top-K commands."""
        self.client = client
        self.commandmixin = TOPKCommands
        self.execute_command = client.execute_command

        # Response callbacks applied for every protocol version.
        callbacks = {
            TOPK_RESERVE: bool_ok,
            # TOPK_QUERY: spaceHolder,
            # TOPK_COUNT: spaceHolder,
        }
        # RESP2 replies need explicit parsing; RESP3 adds nothing extra here.
        if get_protocol_version(self.client) not in ["3", 3]:
            callbacks.update(
                {
                    TOPK_ADD: parse_to_list,
                    TOPK_INCRBY: parse_to_list,
                    TOPK_INFO: TopKInfo,
                    TOPK_LIST: parse_to_list,
                }
            )
        for command, handler in callbacks.items():
            self.client.set_response_callback(command, handler)
class CFBloom(CFCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client for Cuckoo Filter commands."""
        self.client = client
        self.commandmixin = CFCommands
        self.execute_command = client.execute_command

        # Response callbacks applied for every protocol version.
        callbacks = {
            CF_RESERVE: bool_ok,
            # CF_ADD: spaceHolder,
            # CF_ADDNX: spaceHolder,
            # CF_INSERT: spaceHolder,
            # CF_INSERTNX: spaceHolder,
            # CF_EXISTS: spaceHolder,
            # CF_DEL: spaceHolder,
            # CF_COUNT: spaceHolder,
            # CF_SCANDUMP: spaceHolder,
            # CF_LOADCHUNK: spaceHolder,
        }
        # RESP2 replies need explicit parsing; RESP3 adds nothing extra here.
        if get_protocol_version(self.client) not in ["3", 3]:
            callbacks.update({CF_INFO: CFInfo})
        for command, handler in callbacks.items():
            self.client.set_response_callback(command, handler)
class TDigestBloom(TDigestCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client for t-digest commands."""
        self.client = client
        self.commandmixin = TDigestCommands
        self.execute_command = client.execute_command

        # Response callbacks applied for every protocol version.
        callbacks = {
            TDIGEST_CREATE: bool_ok,
            # TDIGEST_RESET: bool_ok,
            # TDIGEST_ADD: spaceHolder,
            # TDIGEST_MERGE: spaceHolder,
        }
        # RESP2 replies need explicit parsing; RESP3 adds nothing extra here.
        if get_protocol_version(self.client) not in ["3", 3]:
            callbacks.update(
                {
                    TDIGEST_BYRANK: parse_to_list,
                    TDIGEST_BYREVRANK: parse_to_list,
                    TDIGEST_CDF: parse_to_list,
                    TDIGEST_INFO: TDigestInfo,
                    TDIGEST_MIN: float,
                    TDIGEST_MAX: float,
                    TDIGEST_TRIMMED_MEAN: float,
                    TDIGEST_QUANTILE: parse_to_list,
                }
            )
        for command, handler in callbacks.items():
            self.client.set_response_callback(command, handler)
class BFBloom(BFCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client for Bloom Filter commands."""
        self.client = client
        self.commandmixin = BFCommands
        self.execute_command = client.execute_command

        # Response callbacks applied for every protocol version.
        callbacks = {
            BF_RESERVE: bool_ok,
            # BF_ADD: spaceHolder,
            # BF_MADD: spaceHolder,
            # BF_INSERT: spaceHolder,
            # BF_EXISTS: spaceHolder,
            # BF_MEXISTS: spaceHolder,
            # BF_SCANDUMP: spaceHolder,
            # BF_LOADCHUNK: spaceHolder,
            # BF_CARD: spaceHolder,
        }
        # RESP2 replies need explicit parsing; RESP3 adds nothing extra here.
        if get_protocol_version(self.client) not in ["3", 3]:
            callbacks.update({BF_INFO: BFInfo})
        for command, handler in callbacks.items():
            self.client.set_response_callback(command, handler)

View File

@@ -0,0 +1,538 @@
from redis.client import NEVER_DECODE
from redis.utils import deprecated_function
# Bloom Filter (BF.*) command names.
BF_RESERVE = "BF.RESERVE"
BF_ADD = "BF.ADD"
BF_MADD = "BF.MADD"
BF_INSERT = "BF.INSERT"
BF_EXISTS = "BF.EXISTS"
BF_MEXISTS = "BF.MEXISTS"
BF_SCANDUMP = "BF.SCANDUMP"
BF_LOADCHUNK = "BF.LOADCHUNK"
BF_INFO = "BF.INFO"
BF_CARD = "BF.CARD"

# Cuckoo Filter (CF.*) command names.
CF_RESERVE = "CF.RESERVE"
CF_ADD = "CF.ADD"
CF_ADDNX = "CF.ADDNX"
CF_INSERT = "CF.INSERT"
CF_INSERTNX = "CF.INSERTNX"
CF_EXISTS = "CF.EXISTS"
CF_MEXISTS = "CF.MEXISTS"
CF_DEL = "CF.DEL"
CF_COUNT = "CF.COUNT"
CF_SCANDUMP = "CF.SCANDUMP"
CF_LOADCHUNK = "CF.LOADCHUNK"
CF_INFO = "CF.INFO"

# Count-Min Sketch (CMS.*) command names.
CMS_INITBYDIM = "CMS.INITBYDIM"
CMS_INITBYPROB = "CMS.INITBYPROB"
CMS_INCRBY = "CMS.INCRBY"
CMS_QUERY = "CMS.QUERY"
CMS_MERGE = "CMS.MERGE"
CMS_INFO = "CMS.INFO"

# Top-K (TOPK.*) command names.
TOPK_RESERVE = "TOPK.RESERVE"
TOPK_ADD = "TOPK.ADD"
TOPK_INCRBY = "TOPK.INCRBY"
TOPK_QUERY = "TOPK.QUERY"
TOPK_COUNT = "TOPK.COUNT"
TOPK_LIST = "TOPK.LIST"
TOPK_INFO = "TOPK.INFO"

# t-digest (TDIGEST.*) command names.
TDIGEST_CREATE = "TDIGEST.CREATE"
TDIGEST_RESET = "TDIGEST.RESET"
TDIGEST_ADD = "TDIGEST.ADD"
TDIGEST_MERGE = "TDIGEST.MERGE"
TDIGEST_CDF = "TDIGEST.CDF"
TDIGEST_QUANTILE = "TDIGEST.QUANTILE"
TDIGEST_MIN = "TDIGEST.MIN"
TDIGEST_MAX = "TDIGEST.MAX"
TDIGEST_INFO = "TDIGEST.INFO"
TDIGEST_TRIMMED_MEAN = "TDIGEST.TRIMMED_MEAN"
TDIGEST_RANK = "TDIGEST.RANK"
TDIGEST_REVRANK = "TDIGEST.REVRANK"
TDIGEST_BYRANK = "TDIGEST.BYRANK"
TDIGEST_BYREVRANK = "TDIGEST.BYREVRANK"
class BFCommands:
    """Bloom Filter commands."""

    def create(self, key, errorRate, capacity, expansion=None, noScale=None):
        """
        Reserve a new Bloom Filter under `key`, sized for `capacity` expected
        entries with a target false-positive probability of `errorRate`.
        The default expansion value is 2 and the filter auto-scales by default.
        For more information see `BF.RESERVE <https://redis.io/commands/bf.reserve>`_.
        """  # noqa
        args = [key, errorRate, capacity]
        self.append_expansion(args, expansion)
        self.append_no_scale(args, noScale)
        return self.execute_command(BF_RESERVE, *args)

    reserve = create

    def add(self, key, item):
        """
        Add a single `item` to the Bloom Filter at `key`.
        For more information see `BF.ADD <https://redis.io/commands/bf.add>`_.
        """  # noqa
        return self.execute_command(BF_ADD, key, item)

    def madd(self, key, *items):
        """
        Add one or more `items` to the Bloom Filter at `key`.
        For more information see `BF.MADD <https://redis.io/commands/bf.madd>`_.
        """  # noqa
        return self.execute_command(BF_MADD, key, *items)

    def insert(
        self,
        key,
        items,
        capacity=None,
        error=None,
        noCreate=None,
        expansion=None,
        noScale=None,
    ):
        """
        Add multiple `items` to the Bloom Filter at `key`. When `noCreate`
        is left as `None` and the filter does not exist, it is created with
        the desired false-positive rate `error` and expected size `capacity`.
        For more information see `BF.INSERT <https://redis.io/commands/bf.insert>`_.
        """  # noqa
        args = [key]
        # The argument order below is mandated by the BF.INSERT syntax.
        self.append_capacity(args, capacity)
        self.append_error(args, error)
        self.append_expansion(args, expansion)
        self.append_no_create(args, noCreate)
        self.append_no_scale(args, noScale)
        self.append_items(args, items)
        return self.execute_command(BF_INSERT, *args)

    def exists(self, key, item):
        """
        Check whether `item` may exist in the Bloom Filter at `key`.
        For more information see `BF.EXISTS <https://redis.io/commands/bf.exists>`_.
        """  # noqa
        return self.execute_command(BF_EXISTS, key, item)

    def mexists(self, key, *items):
        """
        Check whether each of `items` may exist in the Bloom Filter at `key`.
        For more information see `BF.MEXISTS <https://redis.io/commands/bf.mexists>`_.
        """  # noqa
        return self.execute_command(BF_MEXISTS, key, *items)

    def scandump(self, key, iter):
        """
        Begin an incremental save of the bloom filter `key`, useful for
        filters too large for the normal SAVE/RESTORE model.
        Pass `iter` = 0 on the first call; successive (iter, data) pairs are
        returned until (0, NULL) signals completion.
        For more information see `BF.SCANDUMP <https://redis.io/commands/bf.scandump>`_.
        """  # noqa
        # Dump chunks are opaque binary data and must never be decoded.
        options = {NEVER_DECODE: []}
        return self.execute_command(BF_SCANDUMP, key, iter, **options)

    def loadchunk(self, key, iter, data):
        """
        Restore a filter previously saved using SCANDUMP (see that command
        for example usage). Overwrites any bloom filter stored under `key`;
        ensure the filter is not modified between invocations.
        For more information see `BF.LOADCHUNK <https://redis.io/commands/bf.loadchunk>`_.
        """  # noqa
        return self.execute_command(BF_LOADCHUNK, key, iter, data)

    def info(self, key):
        """
        Return capacity, size, number of filters, number of items inserted,
        and expansion rate.
        For more information see `BF.INFO <https://redis.io/commands/bf.info>`_.
        """  # noqa
        return self.execute_command(BF_INFO, key)

    def card(self, key):
        """
        Return the cardinality of the Bloom filter at `key` - the number of
        added items that were detected as unique (items that set at least one
        bit in at least one sub-filter).
        For more information see `BF.CARD <https://redis.io/commands/bf.card>`_.
        """  # noqa
        return self.execute_command(BF_CARD, key)
class CFCommands:
    """Cuckoo Filter commands."""

    def create(
        self, key, capacity, expansion=None, bucket_size=None, max_iterations=None
    ):
        """
        Create a new Cuckoo Filter at `key` with an initial `capacity` of items.
        For more information see `CF.RESERVE <https://redis.io/commands/cf.reserve>`_.
        """  # noqa
        args = [key, capacity]
        self.append_expansion(args, expansion)
        self.append_bucket_size(args, bucket_size)
        self.append_max_iterations(args, max_iterations)
        return self.execute_command(CF_RESERVE, *args)

    reserve = create

    def add(self, key, item):
        """
        Add an `item` to the Cuckoo Filter at `key`.
        For more information see `CF.ADD <https://redis.io/commands/cf.add>`_.
        """  # noqa
        return self.execute_command(CF_ADD, key, item)

    def addnx(self, key, item):
        """
        Add an `item` to the Cuckoo Filter at `key` only if it does not yet
        exist. May be slower than `add`.
        For more information see `CF.ADDNX <https://redis.io/commands/cf.addnx>`_.
        """  # noqa
        return self.execute_command(CF_ADDNX, key, item)

    def _build_insert(self, command, key, items, capacity, nocreate):
        """Build and send a CF.INSERT / CF.INSERTNX invocation."""
        args = [key]
        self.append_capacity(args, capacity)
        self.append_no_create(args, nocreate)
        self.append_items(args, items)
        return self.execute_command(command, *args)

    def insert(self, key, items, capacity=None, nocreate=None):
        """
        Add multiple `items` (a list) to the Cuckoo Filter at `key`, allowing
        the filter to be created with a custom `capacity` if it does not yet
        exist.
        For more information see `CF.INSERT <https://redis.io/commands/cf.insert>`_.
        """  # noqa
        return self._build_insert(CF_INSERT, key, items, capacity, nocreate)

    def insertnx(self, key, items, capacity=None, nocreate=None):
        """
        Add multiple `items` (a list) to the Cuckoo Filter at `key` only if
        they do not exist yet, allowing the filter to be created with a custom
        `capacity` if it does not yet exist.
        For more information see `CF.INSERTNX <https://redis.io/commands/cf.insertnx>`_.
        """  # noqa
        return self._build_insert(CF_INSERTNX, key, items, capacity, nocreate)

    def exists(self, key, item):
        """
        Check whether `item` may exist in the Cuckoo Filter at `key`.
        For more information see `CF.EXISTS <https://redis.io/commands/cf.exists>`_.
        """  # noqa
        return self.execute_command(CF_EXISTS, key, item)

    def mexists(self, key, *items):
        """
        Check whether each of `items` may exist in the Cuckoo Filter at `key`.
        For more information see `CF.MEXISTS <https://redis.io/commands/cf.mexists>`_.
        """  # noqa
        return self.execute_command(CF_MEXISTS, key, *items)

    def delete(self, key, item):
        """
        Delete `item` from `key`.
        For more information see `CF.DEL <https://redis.io/commands/cf.del>`_.
        """  # noqa
        return self.execute_command(CF_DEL, key, item)

    def count(self, key, item):
        """
        Return the number of times `item` may be present in `key`.
        For more information see `CF.COUNT <https://redis.io/commands/cf.count>`_.
        """  # noqa
        return self.execute_command(CF_COUNT, key, item)

    def scandump(self, key, iter):
        """
        Begin an incremental save of the Cuckoo filter at `key`, useful for
        filters too large for the normal SAVE/RESTORE model.
        Pass `iter` = 0 on the first call; successive (iter, data) pairs are
        returned until (0, NULL) signals completion.
        For more information see `CF.SCANDUMP <https://redis.io/commands/cf.scandump>`_.
        """  # noqa
        return self.execute_command(CF_SCANDUMP, key, iter)

    def loadchunk(self, key, iter, data):
        """
        Restore a filter previously saved using SCANDUMP (see that command
        for example usage). Overwrites any Cuckoo filter stored under `key`;
        ensure the filter is not modified between invocations.
        For more information see `CF.LOADCHUNK <https://redis.io/commands/cf.loadchunk>`_.
        """  # noqa
        return self.execute_command(CF_LOADCHUNK, key, iter, data)

    def info(self, key):
        """
        Return size, number of buckets, number of filters, number of items
        inserted, number of items deleted, bucket size, expansion rate, and
        max iterations.
        For more information see `CF.INFO <https://redis.io/commands/cf.info>`_.
        """  # noqa
        return self.execute_command(CF_INFO, key)
class TOPKCommands:
    """TOP-k Filter commands."""

    def reserve(self, key, k, width, depth, decay):
        """
        Create a new Top-K sketch at `key` tracking the top `k` items, backed
        by a sketch of dimensions `width` x `depth` with the given `decay`.
        For more information see `TOPK.RESERVE <https://redis.io/commands/topk.reserve>`_.
        """  # noqa
        return self.execute_command(TOPK_RESERVE, key, k, width, depth, decay)

    def add(self, key, *items):
        """
        Add one or more `items` to the Top-K sketch at `key`.
        For more information see `TOPK.ADD <https://redis.io/commands/topk.add>`_.
        """  # noqa
        return self.execute_command(TOPK_ADD, key, *items)

    def incrby(self, key, items, increments):
        """
        Increase the counts of `items` in the Top-K sketch at `key` by the
        matching `increments`. Both arguments are lists of the same length.
        For more information see `TOPK.INCRBY <https://redis.io/commands/topk.incrby>`_.
        Example:
        >>> topkincrby('A', ['foo'], [1])
        """  # noqa
        args = [key]
        self.append_items_and_increments(args, items, increments)
        return self.execute_command(TOPK_INCRBY, *args)

    def query(self, key, *items):
        """
        Check whether one or more `items` are currently in the Top-K at `key`.
        For more information see `TOPK.QUERY <https://redis.io/commands/topk.query>`_.
        """  # noqa
        return self.execute_command(TOPK_QUERY, key, *items)

    @deprecated_function(version="4.4.0", reason="deprecated since redisbloom 2.4.0")
    def count(self, key, *items):
        """
        Return the count estimate for one or more `items` from `key`.
        For more information see `TOPK.COUNT <https://redis.io/commands/topk.count>`_.
        """  # noqa
        return self.execute_command(TOPK_COUNT, key, *items)

    def list(self, key, withcount=False):
        """
        Return the full list of items in the Top-K at `key`; when `withcount`
        is True, each item is accompanied by its probabilistic count.
        For more information see `TOPK.LIST <https://redis.io/commands/topk.list>`_.
        """  # noqa
        args = [key, "WITHCOUNT"] if withcount else [key]
        return self.execute_command(TOPK_LIST, *args)

    def info(self, key):
        """
        Return the k, width, depth and decay values of `key`.
        For more information see `TOPK.INFO <https://redis.io/commands/topk.info>`_.
        """  # noqa
        return self.execute_command(TOPK_INFO, key)
class TDigestCommands:
    def create(self, key, compression=100):
        """
        Allocate memory for and initialize a new t-digest sketch at `key`.
        For more information see `TDIGEST.CREATE <https://redis.io/commands/tdigest.create>`_.
        """  # noqa
        return self.execute_command(TDIGEST_CREATE, key, "COMPRESSION", compression)

    def reset(self, key):
        """
        Empty the sketch at `key` and re-initialize it.
        For more information see `TDIGEST.RESET <https://redis.io/commands/tdigest.reset>`_.
        """  # noqa
        return self.execute_command(TDIGEST_RESET, key)

    def add(self, key, values):
        """
        Add one or more observations (`values`, a list) to the t-digest at `key`.
        For more information see `TDIGEST.ADD <https://redis.io/commands/tdigest.add>`_.
        """  # noqa
        return self.execute_command(TDIGEST_ADD, key, *values)

    def merge(self, destination_key, num_keys, *keys, compression=None, override=False):
        """
        Merge the values of all `keys` (there must be `num_keys` of them,
        given before the input keys and optional arguments) into the sketch at
        `destination_key`. Existing destination values are merged with the
        input keys; pass `override` to replace the destination contents.
        For more information see `TDIGEST.MERGE <https://redis.io/commands/tdigest.merge>`_.
        """  # noqa
        args = [destination_key, num_keys, *keys]
        if compression is not None:
            args += ["COMPRESSION", compression]
        if override:
            args.append("OVERRIDE")
        return self.execute_command(TDIGEST_MERGE, *args)

    def min(self, key):
        """
        Return the minimum observation in the sketch at `key`; DBL_MAX when
        the sketch is empty.
        For more information see `TDIGEST.MIN <https://redis.io/commands/tdigest.min>`_.
        """  # noqa
        return self.execute_command(TDIGEST_MIN, key)

    def max(self, key):
        """
        Return the maximum observation in the sketch at `key`; DBL_MIN when
        the sketch is empty.
        For more information see `TDIGEST.MAX <https://redis.io/commands/tdigest.max>`_.
        """  # noqa
        return self.execute_command(TDIGEST_MAX, key)

    def quantile(self, key, quantile, *quantiles):
        """
        For each requested quantile, estimate the cutoff such that that
        fraction of the observations added to the t-digest is less than or
        equal to the cutoff. Multiple quantiles may be requested in one call.
        For more information see `TDIGEST.QUANTILE <https://redis.io/commands/tdigest.quantile>`_.
        """  # noqa
        return self.execute_command(TDIGEST_QUANTILE, key, quantile, *quantiles)

    def cdf(self, key, value, *values):
        """
        Return, for each value, the fraction of all added points that are
        less than or equal to it.
        For more information see `TDIGEST.CDF <https://redis.io/commands/tdigest.cdf>`_.
        """  # noqa
        return self.execute_command(TDIGEST_CDF, key, value, *values)

    def info(self, key):
        """
        Return Compression, Capacity, Merged Nodes, Unmerged Nodes, Merged
        Weight, Unmerged Weight and Total Compressions for the sketch at `key`.
        For more information see `TDIGEST.INFO <https://redis.io/commands/tdigest.info>`_.
        """  # noqa
        return self.execute_command(TDIGEST_INFO, key)

    def trimmed_mean(self, key, low_cut_quantile, high_cut_quantile):
        """
        Return the mean of observations, excluding values outside the low and
        high cutoff quantiles.
        For more information see `TDIGEST.TRIMMED_MEAN <https://redis.io/commands/tdigest.trimmed_mean>`_.
        """  # noqa
        return self.execute_command(
            TDIGEST_TRIMMED_MEAN, key, low_cut_quantile, high_cut_quantile
        )

    def rank(self, key, value, *values):
        """
        Estimate the rank of each value: the number of observations smaller
        than the value plus half of those equal to it.
        For more information see `TDIGEST.RANK <https://redis.io/commands/tdigest.rank>`_.
        """  # noqa
        return self.execute_command(TDIGEST_RANK, key, value, *values)

    def revrank(self, key, value, *values):
        """
        Estimate the reverse rank of each value: the number of observations
        larger than the value plus half of those equal to it.
        For more information see `TDIGEST.REVRANK <https://redis.io/commands/tdigest.revrank>`_.
        """  # noqa
        return self.execute_command(TDIGEST_REVRANK, key, value, *values)

    def byrank(self, key, rank, *ranks):
        """
        Estimate the value at each of the given ranks.
        For more information see `TDIGEST.BY_RANK <https://redis.io/commands/tdigest.by_rank>`_.
        """  # noqa
        return self.execute_command(TDIGEST_BYRANK, key, rank, *ranks)

    def byrevrank(self, key, rank, *ranks):
        """
        Estimate the value at each of the given reverse ranks.
        For more information see `TDIGEST.BY_REVRANK <https://redis.io/commands/tdigest.by_revrank>`_.
        """  # noqa
        return self.execute_command(TDIGEST_BYREVRANK, key, rank, *ranks)
class CMSCommands:
    """Count-Min Sketch Commands"""

    def initbydim(self, key, width, depth):
        """
        Initialize a Count-Min Sketch `key` to dimensions (`width`, `depth`) specified by user.
        For more information see `CMS.INITBYDIM <https://redis.io/commands/cms.initbydim>`_.
        """  # noqa
        return self.execute_command(CMS_INITBYDIM, key, width, depth)

    def initbyprob(self, key, error, probability):
        """
        Initialize a Count-Min Sketch `key` to characteristics (`error`, `probability`) specified by user.
        For more information see `CMS.INITBYPROB <https://redis.io/commands/cms.initbyprob>`_.
        """  # noqa
        return self.execute_command(CMS_INITBYPROB, key, error, probability)

    def incrby(self, key, items, increments):
        """
        Add/increase `items` to a Count-Min Sketch `key` by ''increments''.
        Both `items` and `increments` are lists.
        For more information see `CMS.INCRBY <https://redis.io/commands/cms.incrby>`_.
        Example:
        >>> cmsincrby('A', ['foo'], [1])
        """  # noqa
        params = [key]
        self.append_items_and_increments(params, items, increments)
        return self.execute_command(CMS_INCRBY, *params)

    def query(self, key, *items):
        """
        Return count for an `item` from `key`. Multiple items can be queried with one call.
        For more information see `CMS.QUERY <https://redis.io/commands/cms.query>`_.
        """  # noqa
        return self.execute_command(CMS_QUERY, key, *items)

    def merge(self, destKey, numKeys, srcKeys, weights=None):
        """
        Merge `numKeys` of sketches into `destKey`. Sketches specified in `srcKeys`.
        All sketches must have identical width and depth.
        `Weights` can be used to multiply certain sketches. Default weight is 1.
        Both `srcKeys` and `weights` are lists.
        For more information see `CMS.MERGE <https://redis.io/commands/cms.merge>`_.
        """  # noqa
        # Fix: `weights` previously defaulted to a shared mutable list ([]),
        # the classic mutable-default-argument pitfall; None is used instead
        # and normalized here. Behavior for callers is unchanged.
        params = [destKey, numKeys]
        params += srcKeys
        self.append_weights(params, weights if weights is not None else [])
        return self.execute_command(CMS_MERGE, *params)

    def info(self, key):
        """
        Return width, depth and total count of the sketch.
        For more information see `CMS.INFO <https://redis.io/commands/cms.info>`_.
        """  # noqa
        return self.execute_command(CMS_INFO, key)

View File

@@ -0,0 +1,120 @@
from ..helpers import nativestr
class BFInfo:
    # Parsed fields of a BF.INFO reply; None until __init__ runs.
    capacity = None
    size = None
    filterNum = None
    insertedNum = None
    expansionRate = None

    def __init__(self, args):
        """Parse a BF.INFO reply (a flat [name, value, ...] list)."""
        info = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.capacity = info["Capacity"]
        self.size = info["Size"]
        self.filterNum = info["Number of filters"]
        self.insertedNum = info["Number of items inserted"]
        self.expansionRate = info["Expansion rate"]

    def get(self, item):
        """Return the attribute named `item`, or None when it is absent."""
        try:
            return self[item]
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)
class CFInfo:
    # Parsed fields of a CF.INFO reply; None until __init__ runs.
    size = None
    bucketNum = None
    filterNum = None
    insertedNum = None
    deletedNum = None
    bucketSize = None
    expansionRate = None
    maxIteration = None

    def __init__(self, args):
        """Parse a CF.INFO reply (a flat [name, value, ...] list)."""
        info = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.size = info["Size"]
        self.bucketNum = info["Number of buckets"]
        self.filterNum = info["Number of filters"]
        self.insertedNum = info["Number of items inserted"]
        self.deletedNum = info["Number of items deleted"]
        self.bucketSize = info["Bucket size"]
        self.expansionRate = info["Expansion rate"]
        self.maxIteration = info["Max iterations"]

    def get(self, item):
        """Return the attribute named `item`, or None when it is absent."""
        try:
            return self[item]
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)
class CMSInfo:
    # Parsed fields of a CMS.INFO reply; None until __init__ runs.
    width = None
    depth = None
    count = None

    def __init__(self, args):
        """Parse a CMS.INFO reply (a flat [name, value, ...] list)."""
        response = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.width = response["width"]
        self.depth = response["depth"]
        self.count = response["count"]

    def get(self, item):
        """Return the attribute named `item`, or None when it is absent.

        Added for consistency with BFInfo / CFInfo / TDigestInfo, which all
        expose a dict-like ``get``.
        """
        try:
            return self.__getitem__(item)
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)
class TopKInfo:
    # Parsed fields of a TOPK.INFO reply; None until __init__ runs.
    k = None
    width = None
    depth = None
    decay = None

    def __init__(self, args):
        """Parse a TOPK.INFO reply (a flat [name, value, ...] list)."""
        response = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.k = response["k"]
        self.width = response["width"]
        self.depth = response["depth"]
        self.decay = response["decay"]

    def get(self, item):
        """Return the attribute named `item`, or None when it is absent.

        Added for consistency with BFInfo / CFInfo / TDigestInfo, which all
        expose a dict-like ``get``.
        """
        try:
            return self.__getitem__(item)
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)
class TDigestInfo:
    # Parsed fields of a TDIGEST.INFO reply; None until __init__ runs.
    compression = None
    capacity = None
    merged_nodes = None
    unmerged_nodes = None
    merged_weight = None
    unmerged_weight = None
    total_compressions = None
    memory_usage = None

    def __init__(self, args):
        """Parse a TDIGEST.INFO reply (a flat [name, value, ...] list)."""
        info = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.compression = info["Compression"]
        self.capacity = info["Capacity"]
        self.merged_nodes = info["Merged nodes"]
        self.unmerged_nodes = info["Unmerged nodes"]
        self.merged_weight = info["Merged weight"]
        self.unmerged_weight = info["Unmerged weight"]
        self.total_compressions = info["Total compressions"]
        self.memory_usage = info["Memory usage"]

    def get(self, item):
        """Return the attribute named `item`, or None when it is absent."""
        try:
            return self[item]
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)

View File

@@ -0,0 +1,919 @@
import asyncio
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
NoReturn,
Optional,
Union,
)
from redis.crc import key_slot
from redis.exceptions import RedisClusterException, RedisError
from redis.typing import (
AnyKeyT,
ClusterCommandsProtocol,
EncodableT,
KeysT,
KeyT,
PatternT,
ResponseT,
)
from .core import (
ACLCommands,
AsyncACLCommands,
AsyncDataAccessCommands,
AsyncFunctionCommands,
AsyncManagementCommands,
AsyncModuleCommands,
AsyncScriptCommands,
DataAccessCommands,
FunctionCommands,
ManagementCommands,
ModuleCommands,
PubSubCommands,
ScriptCommands,
)
from .helpers import list_or_args
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
if TYPE_CHECKING:
    # Imported only for type annotations; avoids a runtime import cycle.
    from redis.asyncio.cluster import TargetNodesT

# Not complete, but covers the major ones
# https://redis.io/commands
# Read-only commands; membership here decides whether a command may be
# routed to replica nodes (see ClusterMultiKeyCommands below).
READ_COMMANDS = frozenset(
    [
        "BITCOUNT",
        "BITPOS",
        "EVAL_RO",
        "EVALSHA_RO",
        "EXISTS",
        "GEODIST",
        "GEOHASH",
        "GEOPOS",
        "GEORADIUS",
        "GEORADIUSBYMEMBER",
        "GET",
        "GETBIT",
        "GETRANGE",
        "HEXISTS",
        "HGET",
        "HGETALL",
        "HKEYS",
        "HLEN",
        "HMGET",
        "HSTRLEN",
        "HVALS",
        "KEYS",
        "LINDEX",
        "LLEN",
        "LRANGE",
        "MGET",
        "PTTL",
        "RANDOMKEY",
        "SCARD",
        "SDIFF",
        "SINTER",
        "SISMEMBER",
        "SMEMBERS",
        "SRANDMEMBER",
        "STRLEN",
        "SUNION",
        "TTL",
        "ZCARD",
        "ZCOUNT",
        "ZRANGE",
        "ZSCORE",
    ]
)
class ClusterMultiKeyCommands(ClusterCommandsProtocol):
"""
A class containing commands that handle more than one key
"""
def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]:
"""Split keys into a dictionary that maps a slot to a list of keys."""
slots_to_keys = {}
for key in keys:
slot = key_slot(self.encoder.encode(key))
slots_to_keys.setdefault(slot, []).append(key)
return slots_to_keys
def _partition_pairs_by_slot(
self, mapping: Mapping[AnyKeyT, EncodableT]
) -> Dict[int, List[EncodableT]]:
"""Split pairs into a dictionary that maps a slot to a list of pairs."""
slots_to_pairs = {}
for pair in mapping.items():
slot = key_slot(self.encoder.encode(pair[0]))
slots_to_pairs.setdefault(slot, []).extend(pair)
return slots_to_pairs
def _execute_pipeline_by_slot(
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
) -> List[Any]:
read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
pipe = self.pipeline()
[
pipe.execute_command(
command,
*slot_args,
target_nodes=[
self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
],
)
for slot, slot_args in slots_to_args.items()
]
return pipe.execute()
def _reorder_keys_by_command(
self,
keys: Iterable[KeyT],
slots_to_args: Mapping[int, Iterable[EncodableT]],
responses: Iterable[Any],
) -> List[Any]:
results = {
k: v
for slot_values, response in zip(slots_to_args.values(), responses)
for k, v in zip(slot_values, response)
}
return [results[key] for key in keys]
def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
"""
Splits the keys into different slots and then calls MGET
for the keys of every slot. This operation will not be atomic
if keys belong to more than one slot.
Returns a list of values ordered identically to ``keys``
For more information see https://redis.io/commands/mget
"""
# Concatenate all keys into a list
keys = list_or_args(keys, args)
# Split keys into slots
slots_to_keys = self._partition_keys_by_slot(keys)
# Execute commands using a pipeline
res = self._execute_pipeline_by_slot("MGET", slots_to_keys)
# Reorder keys in the order the user provided & return
return self._reorder_keys_by_command(keys, slots_to_keys, res)
def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]:
"""
Sets key/values based on a mapping. Mapping is a dictionary of
key/value pairs. Both keys and values should be strings or types that
can be cast to a string via str().
Splits the keys into different slots and then calls MSET
for the keys of every slot. This operation will not be atomic
if keys belong to more than one slot.
For more information see https://redis.io/commands/mset
"""
# Partition the keys by slot
slots_to_pairs = self._partition_pairs_by_slot(mapping)
# Execute commands using a pipeline & return list of replies
return self._execute_pipeline_by_slot("MSET", slots_to_pairs)
def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
"""
Runs the given command once for the keys
of each slot. Returns the sum of the return values.
"""
# Partition the keys by slot
slots_to_keys = self._partition_keys_by_slot(keys)
# Sum up the reply from each command
return sum(self._execute_pipeline_by_slot(command, slots_to_keys))
    def exists(self, *keys: KeyT) -> ResponseT:
        """
        Returns the number of ``keys`` that exist in the
        whole cluster. The keys are first split up into slots
        and then an EXISTS command is sent for every slot
        For more information see https://redis.io/commands/exists
        """
        return self._split_command_across_slots("EXISTS", *keys)
    def delete(self, *keys: KeyT) -> ResponseT:
        """
        Deletes the given keys in the cluster.
        The keys are first split up into slots
        and then a DEL command is sent for every slot
        Non-existent keys are ignored.
        Returns the number of keys that were deleted.
        For more information see https://redis.io/commands/del
        """
        return self._split_command_across_slots("DEL", *keys)
    def touch(self, *keys: KeyT) -> ResponseT:
        """
        Updates the last access time of given keys across the
        cluster.
        The keys are first split up into slots
        and then a TOUCH command is sent for every slot
        Non-existent keys are ignored.
        Returns the number of keys that were touched.
        For more information see https://redis.io/commands/touch
        """
        return self._split_command_across_slots("TOUCH", *keys)
    def unlink(self, *keys: KeyT) -> ResponseT:
        """
        Remove the specified keys in a different thread.
        The keys are first split up into slots
        and then an UNLINK command is sent for every slot
        Non-existent keys are ignored.
        Returns the number of keys that were unlinked.
        For more information see https://redis.io/commands/unlink
        """
        return self._split_command_across_slots("UNLINK", *keys)
class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands):
    """
    A class containing commands that handle more than one key

    Async variant: overrides the pipeline-driven helpers of
    ClusterMultiKeyCommands with awaitable implementations.
    """
    async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
        """
        Splits the keys into different slots and then calls MGET
        for the keys of every slot. This operation will not be atomic
        if keys belong to more than one slot.
        Returns a list of values ordered identically to ``keys``
        For more information see https://redis.io/commands/mget
        """
        # Concatenate all keys into a list
        keys = list_or_args(keys, args)
        # Split keys into slots
        slots_to_keys = self._partition_keys_by_slot(keys)
        # Execute commands using a pipeline
        res = await self._execute_pipeline_by_slot("MGET", slots_to_keys)
        # Reorder keys in the order the user provided & return
        return self._reorder_keys_by_command(keys, slots_to_keys, res)
    async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]:
        """
        Sets key/values based on a mapping. Mapping is a dictionary of
        key/value pairs. Both keys and values should be strings or types that
        can be cast to a string via str().
        Splits the keys into different slots and then calls MSET
        for the keys of every slot. This operation will not be atomic
        if keys belong to more than one slot.
        For more information see https://redis.io/commands/mset
        """
        # Partition the keys by slot
        slots_to_pairs = self._partition_pairs_by_slot(mapping)
        # Execute commands using a pipeline & return list of replies
        return await self._execute_pipeline_by_slot("MSET", slots_to_pairs)
    async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
        """
        Runs the given command once for the keys
        of each slot. Returns the sum of the return values.
        """
        # Partition the keys by slot
        slots_to_keys = self._partition_keys_by_slot(keys)
        # Sum up the reply from each command
        return sum(await self._execute_pipeline_by_slot(command, slots_to_keys))
    async def _execute_pipeline_by_slot(
        self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
    ) -> List[Any]:
        # Make sure the cluster topology is loaded before routing by slot.
        if self._initialize:
            await self.initialize()
        # Only read commands may be routed to replica nodes.
        read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
        pipe = self.pipeline()
        # Queue one command per slot, each targeted at the node owning it.
        [
            pipe.execute_command(
                command,
                *slot_args,
                target_nodes=[
                    self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
                ],
            )
            for slot, slot_args in slots_to_args.items()
        ]
        return await pipe.execute()
class ClusterManagementCommands(ManagementCommands):
    """
    A class for Redis Cluster management commands
    The class inherits from Redis's core ManagementCommands class and does the
    required adjustments to work with cluster mode
    """
    def slaveof(self, *args, **kwargs) -> NoReturn:
        """
        Make the server a replica of another instance, or promote it as master.
        For more information see https://redis.io/commands/slaveof

        Always raises: replication is managed by the cluster itself.
        """
        raise RedisClusterException("SLAVEOF is not supported in cluster mode")
    def replicaof(self, *args, **kwargs) -> NoReturn:
        """
        Make the server a replica of another instance, or promote it as master.
        For more information see https://redis.io/commands/replicaof

        Always raises: replication is managed by the cluster itself.
        """
        raise RedisClusterException("REPLICAOF is not supported in cluster mode")
    def swapdb(self, *args, **kwargs) -> NoReturn:
        """
        Swaps two Redis databases.
        For more information see https://redis.io/commands/swapdb

        Always raises: cluster mode only supports database 0.
        """
        raise RedisClusterException("SWAPDB is not supported in cluster mode")
    def cluster_myid(self, target_node: "TargetNodesT") -> ResponseT:
        """
        Returns the node's id.
        :target_node: 'ClusterNode'
            The node to execute the command on
        For more information check https://redis.io/commands/cluster-myid/
        """
        return self.execute_command("CLUSTER MYID", target_nodes=target_node)
    def cluster_addslots(
        self, target_node: "TargetNodesT", *slots: EncodableT
    ) -> ResponseT:
        """
        Assign new hash slots to receiving node. Sends to specified node.
        :target_node: 'ClusterNode'
            The node to execute the command on
        For more information see https://redis.io/commands/cluster-addslots
        """
        return self.execute_command(
            "CLUSTER ADDSLOTS", *slots, target_nodes=target_node
        )
    def cluster_addslotsrange(
        self, target_node: "TargetNodesT", *slots: EncodableT
    ) -> ResponseT:
        """
        Similar to the CLUSTER ADDSLOTS command.
        The difference between the two commands is that ADDSLOTS takes a list of slots
        to assign to the node, while ADDSLOTSRANGE takes a list of slot ranges
        (specified by start and end slots) to assign to the node.
        :target_node: 'ClusterNode'
            The node to execute the command on
        For more information see https://redis.io/commands/cluster-addslotsrange
        """
        return self.execute_command(
            "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node
        )
    def cluster_countkeysinslot(self, slot_id: int) -> ResponseT:
        """
        Return the number of local keys in the specified hash slot
        Send to node based on specified slot_id
        For more information see https://redis.io/commands/cluster-countkeysinslot
        """
        return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id)
    def cluster_count_failure_report(self, node_id: str) -> ResponseT:
        """
        Return the number of failure reports active for a given node
        Sends to a random node
        For more information see https://redis.io/commands/cluster-count-failure-reports
        """
        return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id)
    def cluster_delslots(self, *slots: EncodableT) -> List[bool]:
        """
        Set hash slots as unbound in the cluster.
        It determines by itself what node the slot is in and sends it there
        Returns a list of the results for each processed slot.
        For more information see https://redis.io/commands/cluster-delslots
        """
        # One DELSLOTS command per slot so each can be routed independently.
        return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots]
    def cluster_delslotsrange(self, *slots: EncodableT) -> ResponseT:
        """
        Similar to the CLUSTER DELSLOTS command.
        The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove
        from the node, while CLUSTER DELSLOTSRANGE takes a list of slot ranges to remove
        from the node.
        For more information see https://redis.io/commands/cluster-delslotsrange
        """
        return self.execute_command("CLUSTER DELSLOTSRANGE", *slots)
    def cluster_failover(
        self, target_node: "TargetNodesT", option: Optional[str] = None
    ) -> ResponseT:
        """
        Forces a slave to perform a manual failover of its master
        Sends to specified node
        :target_node: 'ClusterNode'
            The node to execute the command on
        :option: 'FORCE' or 'TAKEOVER' (case-insensitive); any other value
            raises RedisError.
        For more information see https://redis.io/commands/cluster-failover
        """
        if option:
            # Validate before sending: only FORCE and TAKEOVER are accepted.
            if option.upper() not in ["FORCE", "TAKEOVER"]:
                raise RedisError(
                    f"Invalid option for CLUSTER FAILOVER command: {option}"
                )
            else:
                return self.execute_command(
                    "CLUSTER FAILOVER", option, target_nodes=target_node
                )
        else:
            return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node)
    def cluster_info(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
        """
        Provides info about Redis Cluster node state.
        The command will be sent to a random node in the cluster if no target
        node is specified.
        For more information see https://redis.io/commands/cluster-info
        """
        return self.execute_command("CLUSTER INFO", target_nodes=target_nodes)
    def cluster_keyslot(self, key: str) -> ResponseT:
        """
        Returns the hash slot of the specified key
        Sends to random node in the cluster
        For more information see https://redis.io/commands/cluster-keyslot
        """
        return self.execute_command("CLUSTER KEYSLOT", key)
    def cluster_meet(
        self, host: str, port: int, target_nodes: Optional["TargetNodesT"] = None
    ) -> ResponseT:
        """
        Force a node cluster to handshake with another node.
        Sends to specified node.
        For more information see https://redis.io/commands/cluster-meet
        """
        return self.execute_command(
            "CLUSTER MEET", host, port, target_nodes=target_nodes
        )
    def cluster_nodes(self) -> ResponseT:
        """
        Get Cluster config for the node.
        Sends to random node in the cluster
        For more information see https://redis.io/commands/cluster-nodes
        """
        return self.execute_command("CLUSTER NODES")
    def cluster_replicate(
        self, target_nodes: "TargetNodesT", node_id: str
    ) -> ResponseT:
        """
        Reconfigure a node as a slave of the specified master node
        For more information see https://redis.io/commands/cluster-replicate
        """
        return self.execute_command(
            "CLUSTER REPLICATE", node_id, target_nodes=target_nodes
        )
    def cluster_reset(
        self, soft: bool = True, target_nodes: Optional["TargetNodesT"] = None
    ) -> ResponseT:
        """
        Reset a Redis Cluster node
        If 'soft' is True then it will send 'SOFT' argument
        If 'soft' is False then it will send 'HARD' argument
        For more information see https://redis.io/commands/cluster-reset
        """
        return self.execute_command(
            "CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes
        )
    def cluster_save_config(
        self, target_nodes: Optional["TargetNodesT"] = None
    ) -> ResponseT:
        """
        Forces the node to save cluster state on disk
        For more information see https://redis.io/commands/cluster-saveconfig
        """
        return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes)
    def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> ResponseT:
        """
        Returns up to ``num_keys`` key names from the specified cluster slot
        For more information see https://redis.io/commands/cluster-getkeysinslot
        """
        return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys)
    def cluster_set_config_epoch(
        self, epoch: int, target_nodes: Optional["TargetNodesT"] = None
    ) -> ResponseT:
        """
        Set the configuration epoch in a new node
        For more information see https://redis.io/commands/cluster-set-config-epoch
        """
        return self.execute_command(
            "CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes
        )
    def cluster_setslot(
        self, target_node: "TargetNodesT", node_id: str, slot_id: int, state: str
    ) -> ResponseT:
        """
        Bind an hash slot to a specific node
        :target_node: 'ClusterNode'
            The node to execute the command on
        :state: one of 'IMPORTING', 'NODE', 'MIGRATING' (case-insensitive);
            'STABLE' must go through cluster_setslot_stable instead.
        For more information see https://redis.io/commands/cluster-setslot
        """
        if state.upper() in ("IMPORTING", "NODE", "MIGRATING"):
            return self.execute_command(
                "CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node
            )
        elif state.upper() == "STABLE":
            raise RedisError('For "stable" state please use cluster_setslot_stable')
        else:
            raise RedisError(f"Invalid slot state: {state}")
    def cluster_setslot_stable(self, slot_id: int) -> ResponseT:
        """
        Clears migrating / importing state from the slot.
        It determines by itself what node the slot is in and sends it there.
        For more information see https://redis.io/commands/cluster-setslot
        """
        return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE")
    def cluster_replicas(
        self, node_id: str, target_nodes: Optional["TargetNodesT"] = None
    ) -> ResponseT:
        """
        Provides a list of replica nodes replicating from the specified primary
        target node.
        For more information see https://redis.io/commands/cluster-replicas
        """
        return self.execute_command(
            "CLUSTER REPLICAS", node_id, target_nodes=target_nodes
        )
    def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
        """
        Get array of Cluster slot to node mappings
        For more information see https://redis.io/commands/cluster-slots
        """
        return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes)
    def cluster_shards(self, target_nodes=None):
        """
        Returns details about the shards of the cluster.
        For more information see https://redis.io/commands/cluster-shards
        """
        return self.execute_command("CLUSTER SHARDS", target_nodes=target_nodes)
    def cluster_myshardid(self, target_nodes=None):
        """
        Returns the shard ID of the node.
        For more information see https://redis.io/commands/cluster-myshardid/
        """
        return self.execute_command("CLUSTER MYSHARDID", target_nodes=target_nodes)
    def cluster_links(self, target_node: "TargetNodesT") -> ResponseT:
        """
        Each node in a Redis Cluster maintains a pair of long-lived TCP link with each
        peer in the cluster: One for sending outbound messages towards the peer and one
        for receiving inbound messages from the peer.
        This command outputs information of all such peer links as an array.
        For more information see https://redis.io/commands/cluster-links
        """
        return self.execute_command("CLUSTER LINKS", target_nodes=target_node)
    def cluster_flushslots(self, target_nodes: Optional["TargetNodesT"] = None) -> None:
        # Deliberately unsupported: always raises NotImplementedError.
        raise NotImplementedError(
            "CLUSTER FLUSHSLOTS is intentionally not implemented in the client."
        )
    def cluster_bumpepoch(self, target_nodes: Optional["TargetNodesT"] = None) -> None:
        # Deliberately unsupported: always raises NotImplementedError.
        raise NotImplementedError(
            "CLUSTER BUMPEPOCH is intentionally not implemented in the client."
        )
    def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
        """
        Enables read queries.
        The command will be sent to the default cluster node if target_nodes is
        not specified.
        For more information see https://redis.io/commands/readonly
        """
        if target_nodes == "replicas" or target_nodes == "all":
            # read_from_replicas will only be enabled if the READONLY command
            # is sent to all replicas
            self.read_from_replicas = True
        return self.execute_command("READONLY", target_nodes=target_nodes)
    def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
        """
        Disables read queries.
        The command will be sent to the default cluster node if target_nodes is
        not specified.
        For more information see https://redis.io/commands/readwrite
        """
        # Reset read from replicas flag
        self.read_from_replicas = False
        return self.execute_command("READWRITE", target_nodes=target_nodes)
class AsyncClusterManagementCommands(
    ClusterManagementCommands, AsyncManagementCommands
):
    """
    A class for Redis Cluster management commands
    The class inherits from Redis's core ManagementCommands class and does the
    required adjustments to work with cluster mode
    """
    async def cluster_delslots(self, *slots: EncodableT) -> List[bool]:
        """
        Set hash slots as unbound in the cluster.
        It determines by itself what node the slot is in and sends it there
        Returns a list of the results for each processed slot.
        For more information see https://redis.io/commands/cluster-delslots
        """
        # Fire one DELSLOTS per slot concurrently; gather preserves the
        # original slot order in the returned list.
        return await asyncio.gather(
            *(
                asyncio.create_task(self.execute_command("CLUSTER DELSLOTS", slot))
                for slot in slots
            )
        )
class ClusterDataAccessCommands(DataAccessCommands):
    """
    A class for Redis Cluster Data Access Commands
    The class inherits from Redis's core DataAccessCommand class and does the
    required adjustments to work with cluster mode
    """
    def stralgo(
        self,
        algo: Literal["LCS"],
        value1: KeyT,
        value2: KeyT,
        specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings",
        len: bool = False,
        idx: bool = False,
        minmatchlen: Optional[int] = None,
        withmatchlen: bool = False,
        **kwargs,
    ) -> ResponseT:
        """
        Implements complex algorithms that operate on strings.
        Right now the only algorithm implemented is the LCS algorithm
        (longest common substring). However new algorithms could be
        implemented in the future.
        ``algo`` Right now must be LCS
        ``value1`` and ``value2`` Can be two strings or two keys
        ``specific_argument`` Specifying if the arguments to the algorithm
        will be keys or strings. strings is the default.
        ``len`` Returns just the len of the match.
        ``idx`` Returns the match positions in each string.
        ``minmatchlen`` Restrict the list of matches to the ones of a given
        minimal length. Can be provided only when ``idx`` set to True.
        ``withmatchlen`` Returns the matches with the len of the match.
        Can be provided only when ``idx`` set to True.
        For more information see https://redis.io/commands/stralgo
        """
        target_nodes = kwargs.pop("target_nodes", None)
        # With literal strings there are no keys to route by, so fall back
        # to the default node unless the caller picked a target explicitly.
        if specific_argument == "strings" and target_nodes is None:
            target_nodes = "default-node"
        kwargs.update({"target_nodes": target_nodes})
        return super().stralgo(
            algo,
            value1,
            value2,
            specific_argument,
            len,
            idx,
            minmatchlen,
            withmatchlen,
            **kwargs,
        )
    def scan_iter(
        self,
        match: Optional[PatternT] = None,
        count: Optional[int] = None,
        _type: Optional[str] = None,
        **kwargs,
    ) -> Iterator:
        # Do the first query with cursor=0 for all nodes
        cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs)
        yield from data
        # Keep only the nodes whose cursor has not been exhausted yet
        cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0}
        if cursors:
            # Get nodes by name
            nodes = {name: self.get_node(node_name=name) for name in cursors.keys()}
            # Iterate over each node till its cursor is 0
            kwargs.pop("target_nodes", None)
            while cursors:
                for name, cursor in cursors.items():
                    cur, data = self.scan(
                        cursor=cursor,
                        match=match,
                        count=count,
                        _type=_type,
                        target_nodes=nodes[name],
                        **kwargs,
                    )
                    yield from data
                    cursors[name] = cur[name]
                cursors = {
                    name: cursor for name, cursor in cursors.items() if cursor != 0
                }
class AsyncClusterDataAccessCommands(
    ClusterDataAccessCommands, AsyncDataAccessCommands
):
    """
    A class for Redis Cluster Data Access Commands
    The class inherits from Redis's core DataAccessCommand class and does the
    required adjustments to work with cluster mode
    """
    async def scan_iter(
        self,
        match: Optional[PatternT] = None,
        count: Optional[int] = None,
        _type: Optional[str] = None,
        **kwargs,
    ) -> AsyncIterator:
        # Do the first query with cursor=0 for all nodes
        cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs)
        for value in data:
            yield value
        # Keep only the nodes whose cursor has not been exhausted yet
        cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0}
        if cursors:
            # Get nodes by name
            nodes = {name: self.get_node(node_name=name) for name in cursors.keys()}
            # Iterate over each node till its cursor is 0
            kwargs.pop("target_nodes", None)
            while cursors:
                for name, cursor in cursors.items():
                    cur, data = await self.scan(
                        cursor=cursor,
                        match=match,
                        count=count,
                        _type=_type,
                        target_nodes=nodes[name],
                        **kwargs,
                    )
                    for value in data:
                        yield value
                    cursors[name] = cur[name]
                cursors = {
                    name: cursor for name, cursor in cursors.items() if cursor != 0
                }
class RedisClusterCommands(
    ClusterMultiKeyCommands,
    ClusterManagementCommands,
    ACLCommands,
    PubSubCommands,
    ClusterDataAccessCommands,
    ScriptCommands,
    FunctionCommands,
    ModuleCommands,
    RedisModuleCommands,
):
    """
    A class for all Redis Cluster commands
    For key-based commands, the target node(s) will be internally determined
    by the keys' hash slot.
    Non-key-based commands can be executed with the 'target_nodes' argument to
    target specific nodes. By default, if target_nodes is not specified, the
    command will be executed on the default cluster node.
    :param target_nodes: type can be one of the following:
        - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM
        - 'ClusterNode'
        - 'list(ClusterNodes)'
        - 'dict(any:clusterNodes)'
    for example:
        r.cluster_info(target_nodes=RedisCluster.ALL_NODES)
    """
class AsyncRedisClusterCommands(
    AsyncClusterMultiKeyCommands,
    AsyncClusterManagementCommands,
    AsyncACLCommands,
    AsyncClusterDataAccessCommands,
    AsyncScriptCommands,
    AsyncFunctionCommands,
    AsyncModuleCommands,
    AsyncRedisModuleCommands,
):
    """
    A class for all Redis Cluster commands
    For key-based commands, the target node(s) will be internally determined
    by the keys' hash slot.
    Non-key-based commands can be executed with the 'target_nodes' argument to
    target specific nodes. By default, if target_nodes is not specified, the
    command will be executed on the default cluster node.
    :param target_nodes: type can be one of the following:
        - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM
        - 'ClusterNode'
        - 'list(ClusterNodes)'
        - 'dict(any:clusterNodes)'
    for example:
        r.cluster_info(target_nodes=RedisCluster.ALL_NODES)
    """

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,118 @@
import copy
import random
import string
from typing import List, Tuple
import redis
from redis.typing import KeysT, KeyT
def list_or_args(keys: KeysT, args: Tuple[KeyT, ...]) -> List[KeyT]:
    """Return a single new list combining ``keys`` and ``args``.

    A ``str``/``bytes`` value (iterable, but meant as one key) or a
    non-iterable ``keys`` is wrapped in a one-element list; any ``args``
    are appended afterwards.
    """
    if isinstance(keys, (bytes, str)):
        combined = [keys]
    else:
        try:
            combined = list(keys)
        except TypeError:
            # keys is a single non-iterable value
            combined = [keys]
    if args:
        combined.extend(args)
    return combined
def nativestr(x):
    """Return the decoded binary string, or a string, depending on type.

    Bytes are decoded as UTF-8 (invalid sequences replaced); the literal
    string "null" maps to None.
    """
    decoded = x.decode("utf-8", "replace") if isinstance(x, bytes) else x
    return None if decoded == "null" else decoded
def delist(x):
    """Given a list of binaries, return the stringified version.

    None passes through unchanged.
    """
    return None if x is None else [nativestr(item) for item in x]
def parse_to_list(response):
    """Optimistically parse the response to a list.

    Each item is coerced to int, then float, falling back to its string
    form; None and undecodable items become None; the special float words
    ("infinity", "nan", "-infinity") are kept as strings.
    """
    if response is None:
        return []
    special_values = {"infinity", "nan", "-infinity"}

    def _coerce(raw):
        if raw is None:
            return None
        try:
            as_str = nativestr(raw)
        except TypeError:
            return None
        # Keep special float words textual so they round-trip losslessly.
        if isinstance(as_str, str) and as_str.lower() in special_values:
            return as_str
        try:
            return int(raw)
        except ValueError:
            pass
        try:
            return float(raw)
        except ValueError:
            return as_str

    return [_coerce(item) for item in response]
def parse_list_to_dict(response):
    """Convert a flat ``[key, value, key, value, ...]`` reply into a dict.

    Nested list elements are parsed recursively and collected under the
    "Child iterators" key. Scalar values are coerced to float when
    possible; otherwise they are stored unchanged.
    """
    res = {}
    for i in range(0, len(response), 2):
        if isinstance(response[i], list):
            # A list in key position is a nested reply. setdefault guards
            # the first occurrence, which previously raised KeyError when
            # "Child iterators" had not been created yet.
            res.setdefault("Child iterators", []).append(
                parse_list_to_dict(response[i])
            )
            try:
                if isinstance(response[i + 1], list):
                    res.setdefault("Child iterators", []).append(
                        parse_list_to_dict(response[i + 1])
                    )
            except IndexError:
                # Trailing nested list with no paired element.
                pass
        elif isinstance(response[i + 1], list):
            res["Child iterators"] = [parse_list_to_dict(response[i + 1])]
        else:
            try:
                res[response[i]] = float(response[i + 1])
            except (TypeError, ValueError):
                res[response[i]] = response[i + 1]
    return res
def random_string(length=10):
    """
    Returns a random N character long string.
    """
    chars = (random.choice(string.ascii_lowercase) for _ in range(length))
    return "".join(chars)  # nosec
def decode_dict_keys(obj):
    """Return a shallow copy of ``obj`` whose bytes keys are UTF-8 decoded."""
    decoded = copy.copy(obj)
    for key in obj.keys():
        if isinstance(key, bytes):
            # pop returns the stored value, so this moves it to the str key.
            decoded[key.decode("utf-8")] = decoded.pop(key)
    return decoded
def get_protocol_version(client):
    """Return the RESP ``protocol`` connection kwarg configured on ``client``.

    Reads from the connection pool of a standalone (sync or async) client,
    or from the nodes manager of a cluster client. Any other client type
    falls through and implicitly returns None.
    """
    if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
        return client.connection_pool.connection_kwargs.get("protocol")
    elif isinstance(client, redis.cluster.AbstractRedisCluster):
        return client.nodes_manager.connection_kwargs.get("protocol")

View File

@@ -0,0 +1,147 @@
from json import JSONDecodeError, JSONDecoder, JSONEncoder
import redis
from ..helpers import get_protocol_version, nativestr
from .commands import JSONCommands
from .decoders import bulk_of_jsons, decode_list
class JSON(JSONCommands):
    """
    Create a client for talking to json.
    :param decoder:
    :type json.JSONDecoder: An instance of json.JSONDecoder
    :param encoder:
    :type json.JSONEncoder: An instance of json.JSONEncoder
    """
    def __init__(
        self, client, version=None, decoder=JSONDecoder(), encoder=JSONEncoder()
    ):
        """
        Create a client for talking to json.
        :param decoder:
        :type json.JSONDecoder: An instance of json.JSONDecoder
        :param encoder:
        :type json.JSONEncoder: An instance of json.JSONEncoder

        Note: the shared default decoder/encoder instances are evaluated once
        at definition time; JSONDecoder/JSONEncoder hold no mutable state, so
        sharing them across JSON clients is safe.
        """
        # Set the module commands' callbacks
        self._MODULE_CALLBACKS = {
            "JSON.ARRPOP": self._decode,
            "JSON.DEBUG": self._decode,
            "JSON.GET": self._decode,
            "JSON.MERGE": lambda r: r and nativestr(r) == "OK",
            "JSON.MGET": bulk_of_jsons(self._decode),
            "JSON.MSET": lambda r: r and nativestr(r) == "OK",
            "JSON.RESP": self._decode,
            "JSON.SET": lambda r: r and nativestr(r) == "OK",
            "JSON.TOGGLE": self._decode,
        }
        # Extra client-side decoding only needed for RESP2 replies.
        _RESP2_MODULE_CALLBACKS = {
            "JSON.ARRAPPEND": self._decode,
            "JSON.ARRINDEX": self._decode,
            "JSON.ARRINSERT": self._decode,
            "JSON.ARRLEN": self._decode,
            "JSON.ARRTRIM": self._decode,
            "JSON.CLEAR": int,
            "JSON.DEL": int,
            "JSON.FORGET": int,
            "JSON.GET": self._decode,
            "JSON.NUMINCRBY": self._decode,
            "JSON.NUMMULTBY": self._decode,
            "JSON.OBJKEYS": self._decode,
            "JSON.STRAPPEND": self._decode,
            "JSON.OBJLEN": self._decode,
            "JSON.STRLEN": self._decode,
            "JSON.TOGGLE": self._decode,
        }
        # No RESP3-specific overrides at present.
        _RESP3_MODULE_CALLBACKS = {}
        self.client = client
        self.execute_command = client.execute_command
        self.MODULE_VERSION = version
        # Pick the callback set matching the negotiated protocol version.
        if get_protocol_version(self.client) in ["3", 3]:
            self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
        for key, value in self._MODULE_CALLBACKS.items():
            self.client.set_response_callback(key, value)
        self.__encoder__ = encoder
        self.__decoder__ = decoder
    def _decode(self, obj):
        """Get the decoder."""
        if obj is None:
            return obj
        try:
            x = self.__decoder__.decode(obj)
            if x is None:
                raise TypeError
            return x
        except TypeError:
            try:
                # obj may be bytes; decode to str first, then parse.
                return self.__decoder__.decode(obj.decode())
            except AttributeError:
                return decode_list(obj)
        except (AttributeError, JSONDecodeError):
            return decode_list(obj)
    def _encode(self, obj):
        """Get the encoder."""
        return self.__encoder__.encode(obj)
    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the JSON module, that can be used for executing
        JSON commands, as well as classic core commands.
        Usage example:
        r = redis.Redis()
        pipe = r.json().pipeline()
        pipe.jsonset('foo', '.', {'hello!': 'world'})
        pipe.jsonget('foo')
        pipe.jsonget('notakey')
        """
        if isinstance(self.client, redis.RedisCluster):
            # Mirror the cluster client's routing configuration on the pipeline.
            p = ClusterPipeline(
                nodes_manager=self.client.nodes_manager,
                commands_parser=self.client.commands_parser,
                startup_nodes=self.client.nodes_manager.startup_nodes,
                result_callbacks=self.client.result_callbacks,
                cluster_response_callbacks=self.client.cluster_response_callbacks,
                cluster_error_retry_attempts=self.client.retry.get_retries(),
                read_from_replicas=self.client.read_from_replicas,
                reinitialize_steps=self.client.reinitialize_steps,
                lock=self.client._lock,
            )
        else:
            p = Pipeline(
                connection_pool=self.client.connection_pool,
                response_callbacks=self._MODULE_CALLBACKS,
                transaction=transaction,
                shard_hint=shard_hint,
            )
        # Share this client's JSON (de)serialization with the pipeline.
        p._encode = self._encode
        p._decode = self._decode
        return p
class ClusterPipeline(JSONCommands, redis.cluster.ClusterPipeline):
    """Cluster pipeline for the module: JSON commands on a cluster pipeline."""
class Pipeline(JSONCommands, redis.client.Pipeline):
    """Pipeline for the module: JSON commands on a standalone pipeline."""

View File

@@ -0,0 +1,5 @@
from typing import List, Mapping, Union
# Recursive alias covering every value representable in a JSON document.
JsonType = Union[
    str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"]
]

View File

@@ -0,0 +1,431 @@
import os
from json import JSONDecodeError, loads
from typing import Dict, List, Optional, Tuple, Union
from redis.exceptions import DataError
from redis.utils import deprecated_function
from ._util import JsonType
from .decoders import decode_dict_keys
from .path import Path
class JSONCommands:
"""json commands."""
def arrappend(
self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType]
) -> List[Optional[int]]:
"""Append the objects ``args`` to the array under the
``path` in key ``name``.
For more information see `JSON.ARRAPPEND <https://redis.io/commands/json.arrappend>`_..
""" # noqa
pieces = [name, str(path)]
for o in args:
pieces.append(self._encode(o))
return self.execute_command("JSON.ARRAPPEND", *pieces)
def arrindex(
self,
name: str,
path: str,
scalar: int,
start: Optional[int] = None,
stop: Optional[int] = None,
) -> List[Optional[int]]:
"""
Return the index of ``scalar`` in the JSON array under ``path`` at key
``name``.
The search can be limited using the optional inclusive ``start``
and exclusive ``stop`` indices.
For more information see `JSON.ARRINDEX <https://redis.io/commands/json.arrindex>`_.
""" # noqa
pieces = [name, str(path), self._encode(scalar)]
if start is not None:
pieces.append(start)
if stop is not None:
pieces.append(stop)
return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name])
def arrinsert(
self, name: str, path: str, index: int, *args: List[JsonType]
) -> List[Optional[int]]:
"""Insert the objects ``args`` to the array at index ``index``
under the ``path` in key ``name``.
For more information see `JSON.ARRINSERT <https://redis.io/commands/json.arrinsert>`_.
""" # noqa
pieces = [name, str(path), index]
for o in args:
pieces.append(self._encode(o))
return self.execute_command("JSON.ARRINSERT", *pieces)
    def arrlen(
        self, name: str, path: Optional[str] = Path.root_path()
    ) -> List[Optional[int]]:
        """Return the length of the array JSON value under ``path``
        at key ``name``.
        For more information see `JSON.ARRLEN <https://redis.io/commands/json.arrlen>`_.
        """  # noqa
        # keys=[name] lets a cluster client route this read by key.
        return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name])
    def arrpop(
        self,
        name: str,
        path: Optional[str] = Path.root_path(),
        index: Optional[int] = -1,
    ) -> List[Optional[str]]:
        """Pop the element at ``index`` in the array JSON value under
        ``path`` at key ``name``.

        ``index`` defaults to -1, i.e. the last element of the array.
        For more information see `JSON.ARRPOP <https://redis.io/commands/json.arrpop>`_.
        """  # noqa
        return self.execute_command("JSON.ARRPOP", name, str(path), index)
    def arrtrim(
        self, name: str, path: str, start: int, stop: int
    ) -> List[Optional[int]]:
        """Trim the array JSON value under ``path`` at key ``name`` to the
        inclusive range given by ``start`` and ``stop``.
        For more information see `JSON.ARRTRIM <https://redis.io/commands/json.arrtrim>`_.
        """  # noqa
        return self.execute_command("JSON.ARRTRIM", name, str(path), start, stop)
    # NOTE: intentionally named after the command; shadows builtins.type
    # only within this class namespace.
    def type(self, name: str, path: Optional[str] = Path.root_path()) -> List[str]:
        """Get the type of the JSON value under ``path`` from key ``name``.
        For more information see `JSON.TYPE <https://redis.io/commands/json.type>`_.
        """  # noqa
        return self.execute_command("JSON.TYPE", name, str(path), keys=[name])
    def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List:
        """Return the JSON value under ``path`` at key ``name``.
        For more information see `JSON.RESP <https://redis.io/commands/json.resp>`_.
        """  # noqa
        return self.execute_command("JSON.RESP", name, str(path), keys=[name])
    def objkeys(
        self, name: str, path: Optional[str] = Path.root_path()
    ) -> List[Optional[List[str]]]:
        """Return the key names in the dictionary JSON value under ``path`` at
        key ``name``.
        For more information see `JSON.OBJKEYS <https://redis.io/commands/json.objkeys>`_.
        """  # noqa
        return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name])
    def objlen(
        self, name: str, path: Optional[str] = Path.root_path()
    ) -> List[Optional[int]]:
        """Return the length of the dictionary JSON value under ``path`` at key
        ``name``.
        For more information see `JSON.OBJLEN <https://redis.io/commands/json.objlen>`_.
        """  # noqa
        return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name])
    def numincrby(self, name: str, path: str, number: int) -> str:
        """Increment the numeric (integer or floating point) JSON value under
        ``path`` at key ``name`` by the provided ``number``.
        For more information see `JSON.NUMINCRBY <https://redis.io/commands/json.numincrby>`_.
        """  # noqa
        return self.execute_command(
            "JSON.NUMINCRBY", name, str(path), self._encode(number)
        )
@deprecated_function(version="4.0.0", reason="deprecated since redisjson 1.0.0")
def nummultby(self, name: str, path: str, number: int) -> str:
    """Multiply the numeric (integer or floating point) JSON value under
    ``path`` at key ``name`` by ``number``.

    For more information see `JSON.NUMMULTBY <https://redis.io/commands/json.nummultby>`_.
    """  # noqa
    encoded = self._encode(number)
    return self.execute_command("JSON.NUMMULTBY", name, str(path), encoded)
def clear(self, name: str, path: Optional[str] = Path.root_path()) -> int:
    """Reset matching arrays and objects under ``path`` at key ``name`` to
    empty (zero slots / zero keys) without deleting the containers
    themselves, returning the count of cleared paths (non-array and
    non-object paths are ignored).

    For more information see `JSON.CLEAR <https://redis.io/commands/json.clear>`_.
    """  # noqa
    pieces = [name, str(path)]
    return self.execute_command("JSON.CLEAR", *pieces)
def delete(self, key: str, path: Optional[str] = Path.root_path()) -> int:
    """Delete the JSON value stored at key ``key`` under ``path``.

    For more information see `JSON.DEL <https://redis.io/commands/json.del>`_.
    """
    return self.execute_command("JSON.DEL", key, str(path))

# `forget` is kept as a historical alias of `delete` (JSON.FORGET).
forget = delete
def get(
    self, name: str, *args, no_escape: Optional[bool] = False
) -> Optional[List[JsonType]]:
    """
    Get the object stored as a JSON value at key ``name``.

    ``args`` is zero or more paths, and defaults to the root path.
    ``no_escape`` is a boolean flag adding the NOESCAPE option, so
    non-ascii characters are returned unescaped.

    For more information see `JSON.GET <https://redis.io/commands/json.get>`_.
    """  # noqa
    pieces = [name]
    if no_escape:
        pieces.append("noescape")
    # Default to the root path when no explicit path is given.
    if args:
        pieces.extend(str(p) for p in args)
    else:
        pieces.append(Path.root_path())
    # A missing key decodes as None, which the JSONDecoder cannot handle;
    # translate the resulting TypeError into a plain None.
    try:
        return self.execute_command("JSON.GET", *pieces, keys=[name])
    except TypeError:
        return None
def mget(self, keys: List[str], path: str) -> List[JsonType]:
    """
    Get the JSON values stored under ``path`` for every key in ``keys``.

    For more information see `JSON.MGET <https://redis.io/commands/json.mget>`_.
    """  # noqa
    pieces = [*keys, str(path)]
    return self.execute_command("JSON.MGET", *pieces, keys=keys)
def set(
    self,
    name: str,
    path: str,
    obj: JsonType,
    nx: Optional[bool] = False,
    xx: Optional[bool] = False,
    decode_keys: Optional[bool] = False,
) -> Optional[str]:
    """
    Set the JSON value at key ``name`` under the ``path`` to ``obj``.

    ``nx`` if set to True, set ``value`` only if it does not exist.
    ``xx`` if set to True, set ``value`` only if it exists.
    ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
    with utf-8.

    For the purpose of using this within a pipeline, this command is also
    aliased to JSON.SET.

    For more information see `JSON.SET <https://redis.io/commands/json.set>`_.
    """
    if decode_keys:
        obj = decode_dict_keys(obj)
    pieces = [name, str(path), self._encode(obj)]
    # Handle existential modifiers: NX and XX cannot be combined.
    if nx and xx:
        raise Exception(
            "nx and xx are mutually exclusive: use one, the "
            "other or neither - but not both"
        )
    if nx:
        pieces.append("NX")
    elif xx:
        pieces.append("XX")
    return self.execute_command("JSON.SET", *pieces)
def mset(self, triplets: List[Tuple[str, str, JsonType]]) -> Optional[str]:
    """
    Atomically set the JSON value for each ``(key, path, value)`` triplet
    in ``triplets``.

    For the purpose of using this within a pipeline, this command is also
    aliased to JSON.MSET.

    For more information see `JSON.MSET <https://redis.io/commands/json.mset>`_.
    """
    pieces = []
    for key, path, value in triplets:
        pieces += [key, str(path), self._encode(value)]
    return self.execute_command("JSON.MSET", *pieces)
def merge(
    self,
    name: str,
    path: str,
    obj: JsonType,
    decode_keys: Optional[bool] = False,
) -> Optional[str]:
    """
    Merge ``obj`` into the JSON value at key ``name`` under matching
    ``path``: matching values are updated, deleted, or expanded with new
    children.

    ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
    with utf-8.

    For more information see `JSON.MERGE <https://redis.io/commands/json.merge>`_.
    """
    if decode_keys:
        obj = decode_dict_keys(obj)
    return self.execute_command("JSON.MERGE", name, str(path), self._encode(obj))
def set_file(
    self,
    name: str,
    path: str,
    file_name: str,
    nx: Optional[bool] = False,
    xx: Optional[bool] = False,
    decode_keys: Optional[bool] = False,
) -> Optional[str]:
    """
    Set the JSON value at key ``name`` under ``path`` to the parsed
    content of the JSON file ``file_name``.

    ``nx`` if set to True, set ``value`` only if it does not exist.
    ``xx`` if set to True, set ``value`` only if it exists.
    ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
    with utf-8.
    """
    with open(file_name) as fp:
        content = loads(fp.read())
    return self.set(name, path, content, nx=nx, xx=xx, decode_keys=decode_keys)
def set_path(
    self,
    json_path: str,
    root_folder: str,
    nx: Optional[bool] = False,
    xx: Optional[bool] = False,
    decode_keys: Optional[bool] = False,
) -> Dict[str, bool]:
    """
    Recursively iterate over ``root_folder`` and set the content of each
    JSON file found to a value under ``json_path``, keyed by the file's
    path without its extension.

    ``nx`` if set to True, set ``value`` only if it does not exist.
    ``xx`` if set to True, set ``value`` only if it exists.
    ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
    with utf-8.

    Returns a mapping of file path -> True on success, or False when the
    file did not contain valid JSON.
    """
    set_files_result = {}
    for root, dirs, files in os.walk(root_folder):
        for file in files:
            file_path = os.path.join(root, file)
            try:
                # Strip only the final extension. The previous
                # rsplit(".")[0] truncated at the FIRST dot, which broke
                # dotted file names and any path containing a dot
                # (e.g. "./data/doc.json" produced an empty key).
                file_name = file_path.rsplit(".", 1)[0]
                self.set_file(
                    file_name,
                    json_path,
                    file_path,
                    nx=nx,
                    xx=xx,
                    decode_keys=decode_keys,
                )
                set_files_result[file_path] = True
            except JSONDecodeError:
                set_files_result[file_path] = False
    return set_files_result
def strlen(self, name: str, path: Optional[str] = None) -> List[Optional[int]]:
    """Return the length of the string JSON value under ``path`` at key
    ``name``. When ``path`` is omitted the server default path is used.

    For more information see `JSON.STRLEN <https://redis.io/commands/json.strlen>`_.
    """  # noqa
    pieces = [name] if path is None else [name, str(path)]
    return self.execute_command("JSON.STRLEN", *pieces, keys=[name])
def toggle(
    self, name: str, path: Optional[str] = Path.root_path()
) -> Union[bool, List[Optional[int]]]:
    """Flip the boolean JSON value under ``path`` at key ``name`` and
    return the new value.

    For more information see `JSON.TOGGLE <https://redis.io/commands/json.toggle>`_.
    """  # noqa
    pieces = [name, str(path)]
    return self.execute_command("JSON.TOGGLE", *pieces)
def strappend(
    self, name: str, value: str, path: Optional[str] = Path.root_path()
) -> Union[int, List[Optional[int]]]:
    """Append ``value`` to the string JSON value under ``path`` at key
    ``name``. If ``path`` is omitted the root path is used.

    For more information see `JSON.STRAPPEND <https://redis.io/commands/json.strappend>`_.
    """  # noqa
    return self.execute_command(
        "JSON.STRAPPEND", name, str(path), self._encode(value)
    )
def debug(
    self,
    subcommand: str,
    key: Optional[str] = None,
    path: Optional[str] = Path.root_path(),
) -> Union[int, List[str]]:
    """Run a JSON.DEBUG subcommand; with ``MEMORY``, report the memory
    usage in bytes of the value under ``path`` at ``key``.

    Raises:
        DataError: if ``subcommand`` is not MEMORY/HELP, or MEMORY is
            requested without a key.

    For more information see `JSON.DEBUG <https://redis.io/commands/json.debug>`_.
    """  # noqa
    valid_subcommands = ["MEMORY", "HELP"]
    if subcommand not in valid_subcommands:
        # Single formatted message: the previous code passed two positional
        # args to DataError, producing a confusing tuple-style message.
        raise DataError(f"The only valid subcommands are {valid_subcommands}")
    pieces = [subcommand]
    if subcommand == "MEMORY":
        if key is None:
            raise DataError("No key specified")
        pieces.append(key)
        pieces.append(str(path))
    return self.execute_command("JSON.DEBUG", *pieces)
@deprecated_function(
    version="4.0.0", reason="redisjson-py supported this, call get directly."
)
def jsonget(self, *args, **kwargs):
    """Deprecated redisjson-py compatibility alias for :meth:`get`."""
    return self.get(*args, **kwargs)
@deprecated_function(
    version="4.0.0", reason="redisjson-py supported this, call mget directly."
)
def jsonmget(self, *args, **kwargs):
    """Deprecated redisjson-py compatibility alias for :meth:`mget`.

    The deprecation reason previously said "call get directly" — a
    copy-paste error; this wrapper delegates to ``mget``.
    """
    return self.mget(*args, **kwargs)
@deprecated_function(
    version="4.0.0", reason="redisjson-py supported this, call set directly."
)
def jsonset(self, *args, **kwargs):
    """Deprecated redisjson-py compatibility alias for :meth:`set`.

    The deprecation reason previously said "call get directly" — a
    copy-paste error; this wrapper delegates to ``set``.
    """
    return self.set(*args, **kwargs)

View File

@@ -0,0 +1,60 @@
import copy
import re
from ..helpers import nativestr
def bulk_of_jsons(d):
    """Return a callback that maps decoder ``d`` over every non-None item
    of a bulk array response (list), replacing the items in place.
    """

    def _decode_all(reply):
        for i, raw in enumerate(reply):
            if raw is not None:
                reply[i] = d(raw)
        return reply

    return _decode_all
def decode_dict_keys(obj):
    """Return a shallow copy of ``obj`` whose ``bytes`` keys are decoded
    to ``str`` with utf-8; non-bytes keys are kept unchanged.
    """
    decoded = copy.copy(obj)
    for key in obj.keys():
        if isinstance(key, bytes):
            decoded[key.decode("utf-8")] = decoded.pop(key)
    return decoded
def unstring(obj):
    """
    Attempt to parse string to native integer formats.

    One can't simply call int/float in a try/catch because there is a
    semantic difference between (for example) 15.0 and 15.
    """
    # Escape the dot: the previous pattern "^\d+.\d+$" let '.' match any
    # character, so inputs like "12a34" were fed to float() and raised
    # ValueError instead of being returned unchanged.
    if re.fullmatch(r"^\d+\.\d+$", obj):
        return float(obj)
    if re.fullmatch(r"^\d+$", obj):
        return int(obj)
    return obj
def decode_list(b):
    """
    Best-effort decoding of a non-deserializable reply: lists are mapped
    through ``nativestr``; bytes/str values are parsed to native numbers
    when possible; anything else is returned untouched.
    """
    if isinstance(b, list):
        return [nativestr(item) for item in b]
    if isinstance(b, bytes):
        return unstring(nativestr(b))
    if isinstance(b, str):
        return unstring(b)
    return b

View File

@@ -0,0 +1,16 @@
class Path:
    """Represents a path inside a JSON document."""

    strPath = ""

    @staticmethod
    def root_path():
        """Return the string form of the root path."""
        return "."

    def __init__(self, path):
        """Create a path from its string representation ``path``."""
        self.strPath = path

    def __repr__(self):
        # The repr IS the raw path string, so a Path can be dropped
        # anywhere a path string is formatted.
        return self.strPath

View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from json import JSONDecoder, JSONEncoder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .bf import BFBloom, CFBloom, CMSBloom, TDigestBloom, TOPKBloom
from .json import JSON
from .search import AsyncSearch, Search
from .timeseries import TimeSeries
from .vectorset import VectorSet
class RedisModuleCommands:
    """Mixin exposing the supported Redis module namespaces (JSON, search,
    timeseries, probabilistic data structures, vector sets) on a client.
    Each accessor imports its module lazily to avoid import cycles.
    """

    def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()) -> JSON:
        """Access the json namespace, providing support for redis json."""
        from .json import JSON

        return JSON(client=self, encoder=encoder, decoder=decoder)

    def ft(self, index_name="idx") -> Search:
        """Access the search namespace, providing support for redis search."""
        from .search import Search

        return Search(client=self, index_name=index_name)

    def ts(self) -> TimeSeries:
        """Access the timeseries namespace, providing support for
        redis timeseries data.
        """
        from .timeseries import TimeSeries

        return TimeSeries(client=self)

    def bf(self) -> BFBloom:
        """Access the Bloom filter namespace."""
        from .bf import BFBloom

        return BFBloom(client=self)

    def cf(self) -> CFBloom:
        """Access the Cuckoo filter namespace."""
        from .bf import CFBloom

        return CFBloom(client=self)

    def cms(self) -> CMSBloom:
        """Access the Count-Min Sketch namespace."""
        from .bf import CMSBloom

        return CMSBloom(client=self)

    def topk(self) -> TOPKBloom:
        """Access the TopK namespace."""
        from .bf import TOPKBloom

        return TOPKBloom(client=self)

    def tdigest(self) -> TDigestBloom:
        """Access the t-digest namespace."""
        from .bf import TDigestBloom

        return TDigestBloom(client=self)

    def vset(self) -> VectorSet:
        """Access the VectorSet commands namespace."""
        from .vectorset import VectorSet

        return VectorSet(client=self)
class AsyncRedisModuleCommands(RedisModuleCommands):
    """Async variant of the module mixin; only ft() differs, returning an
    AsyncSearch client."""

    def ft(self, index_name="idx") -> AsyncSearch:
        """Access the search namespace, providing support for redis search."""
        from .search import AsyncSearch

        return AsyncSearch(client=self, index_name=index_name)

View File

@@ -0,0 +1,189 @@
import redis
from ...asyncio.client import Pipeline as AsyncioPipeline
from .commands import (
AGGREGATE_CMD,
CONFIG_CMD,
INFO_CMD,
PROFILE_CMD,
SEARCH_CMD,
SPELLCHECK_CMD,
SYNDUMP_CMD,
AsyncSearchCommands,
SearchCommands,
)
class Search(SearchCommands):
    """
    Create a client for talking to search.

    It abstracts the API of the module and lets you just use the engine.
    """

    class BatchIndexer:
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        def __init__(self, client, chunk_size=1000):
            # Documents are buffered in a non-transactional pipeline and
            # sent to the server in chunks of `chunk_size`.
            self.client = client
            self.execute_command = client.execute_command
            self._pipeline = client.pipeline(transaction=False, shard_hint=None)
            self.total = 0          # documents added over the indexer's lifetime
            self.chunk_size = chunk_size  # flush threshold
            self.current_chunk = 0  # documents buffered since the last flush

        def __del__(self):
            # Best-effort flush of any still-buffered documents when the
            # indexer is garbage collected.
            if self.current_chunk:
                self.commit()

        def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            # Flush once the configured batch size is reached.
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def add_document_hash(self, doc_id, score=1.0, replace=False):
            """
            Add a hash to the batch query
            """
            self.client._add_document_hash(
                doc_id, conn=self._pipeline, score=score, replace=replace
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            self._pipeline.execute()
            self.current_chunk = 0

    def __init__(self, client, index_name="idx"):
        """
        Create a new Client for the given index_name.
        The default name is `idx`

        If conn is not None, we employ an already existing redis connection
        """
        self._MODULE_CALLBACKS = {}
        self.client = client
        self.index_name = index_name
        self.execute_command = client.execute_command
        self._pipeline = client.pipeline
        # Reply post-processors applied when the connection speaks RESP2;
        # each FT.* command maps to its parser on SearchCommands.
        self._RESP2_MODULE_CALLBACKS = {
            INFO_CMD: self._parse_info,
            SEARCH_CMD: self._parse_search,
            AGGREGATE_CMD: self._parse_aggregate,
            PROFILE_CMD: self._parse_profile,
            SPELLCHECK_CMD: self._parse_spellcheck,
            CONFIG_CMD: self._parse_config_get,
            SYNDUMP_CMD: self._parse_syndump,
        }

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = Pipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        # Carry the index name over so FT.* commands on the pipeline
        # target the same index as this client.
        p.index_name = self.index_name
        return p
class AsyncSearch(Search, AsyncSearchCommands):
    """Async flavor of :class:`Search`: overrides document batching and
    pipeline creation to use awaitable execution."""

    class BatchIndexer(Search.BatchIndexer):
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        async def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            # Flush asynchronously once the configured batch size is reached.
            if self.current_chunk >= self.chunk_size:
                await self.commit()

        async def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            await self._pipeline.execute()
            self.current_chunk = 0

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = AsyncPipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        # Carry the index name over so FT.* commands on the pipeline
        # target the same index as this client.
        p.index_name = self.index_name
        return p
class Pipeline(SearchCommands, redis.client.Pipeline):
    """Pipeline for the module."""

    # A core redis pipeline that also accepts FT.* commands;
    # instances are produced by Search.pipeline().
class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline):
    """AsyncPipeline for the module."""

    # Async counterpart of Pipeline; produced by AsyncSearch.pipeline().

View File

@@ -0,0 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
    """Coerce ``bytes`` to ``str`` using ``encoding`` (undecodable bytes are
    ignored); any other value — including ``str`` — is returned unchanged.
    """
    if isinstance(s, bytes):
        return s.decode(encoding, "ignore")
    return s  # str, or not a string we care about

View File

@@ -0,0 +1,401 @@
from typing import List, Union
from redis.commands.search.dialect import DEFAULT_DIALECT
# Sentinel used with Reducer.alias(): requests that the reducer be aliased
# by the name of the field it operates on (compared by identity, hence a
# unique object()).
FIELDNAME = object()
class Limit:
    """Paging clause for aggregations: offset plus result count."""

    def __init__(self, offset: int = 0, count: int = 0) -> None:
        self.offset = offset
        self.count = count

    def build_args(self):
        """Return ``["LIMIT", offset, count]`` when a count is set,
        otherwise an empty list (no clause emitted)."""
        if not self.count:
            return []
        return ["LIMIT", str(self.offset), str(self.count)]
class Reducer:
    """
    Base reducer object for all reducers.

    See the `redisearch.reducers` module for the actual reducers.
    """

    # Concrete reducers override NAME with the REDUCE function name
    # (e.g. "COUNT", "SUM"); consumed by AggregateRequest.group_by().
    NAME = None

    def __init__(self, *args: str) -> None:
        self._args = args      # raw argument tokens for the REDUCE clause
        self._field = None     # '@field' the reducer operates on, if single-field
        self._alias = None     # output alias, emitted as 'AS <alias>'

    def alias(self, alias: str) -> "Reducer":
        """
        Set the alias for this reducer.

        ### Parameters

        - **alias**: The value of the alias for this reducer. If this is the
            special value `aggregation.FIELDNAME` then this reducer will be
            aliased using the same name as the field upon which it operates.
            Note that using `FIELDNAME` is only possible on reducers which
            operate on a single field value.

        This method returns the `Reducer` object making it suitable for
        chaining.
        """
        if alias is FIELDNAME:
            if not self._field:
                raise ValueError("Cannot use FIELDNAME alias with no field")
            # Chop off initial '@'
            alias = self._field[1:]
        self._alias = alias
        return self

    @property
    def args(self) -> List[str]:
        # Raw argument tokens supplied at construction time.
        return self._args
class SortDirection:
    """Wraps a field name with a sort direction for SORTBY clauses.

    Subclasses set ``DIRSTRING`` to the token emitted after the field.
    """

    DIRSTRING = None  # "ASC" / "DESC" in subclasses

    def __init__(self, field: str) -> None:
        self.field = field
class Asc(SortDirection):
    """
    Indicate that the given field should be sorted in ascending order
    """

    # Token emitted after the field name in the SORTBY clause.
    DIRSTRING = "ASC"
class Desc(SortDirection):
    """
    Indicate that the given field should be sorted in descending order
    """

    # Token emitted after the field name in the SORTBY clause.
    DIRSTRING = "DESC"
class AggregateRequest:
    """
    Aggregation request which can be passed to `Client.aggregate`.
    """

    def __init__(self, query: str = "*") -> None:
        """
        Create an aggregation request. This request may then be passed to
        `client.aggregate()`.

        In order for the request to be usable, it must contain at least one
        group.

        - **query** Query string for filtering records.

        All member methods (except `build_args()`)
        return the object itself, making them useful for chaining.
        """
        self._query = query
        self._aggregateplan = []   # ordered GROUPBY/APPLY/SORTBY/... tokens
        self._loadfields = []      # explicit LOAD fields
        self._loadall = False      # emit `LOAD *` when True
        self._max = 0
        self._with_schema = False
        self._verbatim = False
        self._cursor = []          # WITHCURSOR clause tokens, if requested
        self._dialect = DEFAULT_DIALECT
        self._add_scores = False
        self._scorer = "TFIDF"     # default scoring function

    def load(self, *fields: str) -> "AggregateRequest":
        """
        Indicate the fields to be returned in the response. These fields are
        returned in addition to any others implicitly specified.

        ### Parameters

        - **fields**: If fields not specified, all the fields will be loaded.
        Otherwise, fields should be given in the format of `@field`.
        """
        if fields:
            self._loadfields.extend(fields)
        else:
            self._loadall = True
        return self

    def group_by(
        self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
    ) -> "AggregateRequest":
        """
        Specify by which fields to group the aggregation.

        ### Parameters

        - **fields**: Fields to group by. This can either be a single string,
        or a list of strings. both cases, the field should be specified as
        `@field`.
        - **reducers**: One or more reducers. Reducers may be found in the
        `aggregation` module.
        """
        # Normalize both arguments to lists before building the clause.
        fields = [fields] if isinstance(fields, str) else fields
        reducers = [reducers] if isinstance(reducers, Reducer) else reducers

        ret = ["GROUPBY", str(len(fields)), *fields]
        for reducer in reducers:
            # REDUCE <name> <nargs> <args...> [AS <alias>]
            ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
            ret.extend(reducer.args)
            if reducer._alias is not None:
                ret += ["AS", reducer._alias]

        self._aggregateplan.extend(ret)
        return self

    def apply(self, **kwexpr) -> "AggregateRequest":
        """
        Specify one or more projection expressions to add to each result

        ### Parameters

        - **kwexpr**: One or more key-value pairs for a projection. The key is
        the alias for the projection, and the value is the projection
        expression itself, for example `apply(square_root="sqrt(@foo)")`
        """
        for alias, expr in kwexpr.items():
            ret = ["APPLY", expr]
            if alias is not None:
                ret += ["AS", alias]
            self._aggregateplan.extend(ret)

        return self

    def limit(self, offset: int, num: int) -> "AggregateRequest":
        """
        Sets the limit for the most recent group or query.

        If no group has been defined yet (via `group_by()`) then this sets
        the limit for the initial pool of results from the query. Otherwise,
        this limits the number of items operated on from the previous group.

        Setting a limit on the initial search results may be useful when
        attempting to execute an aggregation on a sample of a large data set.

        ### Parameters

        - **offset**: Result offset from which to begin paging
        - **num**: Number of results to return

        Example of sorting the initial results:

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 10)\
            .group_by("@state", r.count())
        ```

        Will only group by the states found in the first 10 results of the
        query `@sale_amount:[10000, inf]`. On the other hand,

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 1000)\
            .group_by("@state", r.count()\
            .limit(0, 10)
        ```

        Will group all the results matching the query, but only return the
        first 10 groups.

        If you only wish to return a *top-N* style query, consider using
        `sort_by()` instead.
        """
        _limit = Limit(offset, num)
        self._aggregateplan.extend(_limit.build_args())
        return self

    def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
        """
        Indicate how the results should be sorted. This can also be used for
        *top-N* style queries

        ### Parameters

        - **fields**: The fields by which to sort. This can be either a single
        field or a list of fields. If you wish to specify order, you can
        use the `Asc` or `Desc` wrapper classes.
        - **max**: Maximum number of results to return. This can be
        used instead of `LIMIT` and is also faster.


        Example of sorting by `foo` ascending and `bar` descending:

        ```
        sort_by(Asc("@foo"), Desc("@bar"))
        ```

        Return the top 10 customers:

        ```
        AggregateRequest()\
            .group_by("@customer", r.sum("@paid").alias(FIELDNAME))\
            .sort_by(Desc("@paid"), max=10)
        ```
        """
        if isinstance(fields, (str, SortDirection)):
            fields = [fields]

        fields_args = []
        for f in fields:
            if isinstance(f, SortDirection):
                # SortDirection wrappers expand to "<field> ASC|DESC".
                fields_args += [f.field, f.DIRSTRING]
            else:
                fields_args += [f]

        ret = ["SORTBY", str(len(fields_args))]
        ret.extend(fields_args)
        max = kwargs.get("max", 0)
        if max > 0:
            ret += ["MAX", str(max)]

        self._aggregateplan.extend(ret)
        return self

    def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
        """
        Specify filter for post-query results using predicates relating to
        values in the result set.

        ### Parameters

        - **fields**: Fields to group by. This can either be a single string,
        or a list of strings.
        """
        if isinstance(expressions, str):
            expressions = [expressions]

        for expression in expressions:
            self._aggregateplan.extend(["FILTER", expression])

        return self

    def with_schema(self) -> "AggregateRequest":
        """
        If set, the `schema` property will contain a list of `[field, type]`
        entries in the result object.
        """
        self._with_schema = True
        return self

    def add_scores(self) -> "AggregateRequest":
        """
        If set, includes the score as an ordinary field of the row.
        """
        self._add_scores = True
        return self

    def scorer(self, scorer: str) -> "AggregateRequest":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        :param scorer: The scoring function to use
                       (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def verbatim(self) -> "AggregateRequest":
        # Disable query expansion: terms are matched exactly as given.
        self._verbatim = True
        return self

    def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
        # Request cursor-based paging; max_idle is given in seconds and
        # converted to the server's milliseconds.
        args = ["WITHCURSOR"]
        if count:
            args += ["COUNT", str(count)]
        if max_idle:
            args += ["MAXIDLE", str(max_idle * 1000)]
        self._cursor = args
        return self

    def build_args(self) -> List[str]:
        # Assemble the FT.AGGREGATE argument list; clause order matters:
        # query, WITHSCHEMA, VERBATIM, SCORER, ADDSCORES, WITHCURSOR, LOAD,
        # DIALECT, then the accumulated pipeline plan.
        # @foo:bar ...
        ret = [self._query]

        if self._with_schema:
            ret.append("WITHSCHEMA")

        if self._verbatim:
            ret.append("VERBATIM")

        if self._scorer:
            ret.extend(["SCORER", self._scorer])

        if self._add_scores:
            ret.append("ADDSCORES")

        if self._cursor:
            ret += self._cursor

        if self._loadall:
            ret.append("LOAD")
            ret.append("*")
        elif self._loadfields:
            ret.append("LOAD")
            ret.append(str(len(self._loadfields)))
            ret.extend(self._loadfields)

        if self._dialect:
            ret.extend(["DIALECT", self._dialect])

        ret.extend(self._aggregateplan)

        return ret

    def dialect(self, dialect: int) -> "AggregateRequest":
        """
        Add a dialect field to the aggregate command.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self
class Cursor:
    """Cursor handle for paging through FT.AGGREGATE results."""

    def __init__(self, cid: int) -> None:
        self.cid = cid       # server-assigned cursor id
        self.max_idle = 0    # optional MAXIDLE override
        self.count = 0       # optional COUNT override

    def build_args(self):
        """Return the CURSOR READ argument tokens for this cursor."""
        args = [str(self.cid)]
        if self.max_idle:
            args += ["MAXIDLE", str(self.max_idle)]
        if self.count:
            args += ["COUNT", str(self.count)]
        return args
class AggregateResult:
    """One page of aggregation rows, plus the cursor (if the query used
    WITHCURSOR) and the schema (if it used WITHSCHEMA)."""

    def __init__(self, rows, cursor: "Cursor", schema) -> None:
        self.rows = rows
        self.cursor = cursor
        self.schema = schema

    def __repr__(self) -> str:
        # Fixed wrong return annotation: __repr__ returns a single str,
        # not a (str, str) tuple. -1 signals "no cursor".
        cid = self.cursor.cid if self.cursor else -1
        return (
            f"<{self.__class__.__name__} at 0x{id(self):x} "
            f"Rows={len(self.rows)}, Cursor={cid}>"
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
# Value for the default dialect to be used as a part of
# Search or Aggregate query.
# NOTE(review): dialect 2 is the modern RediSearch query dialect — see the
# RediSearch query-dialect documentation before changing this default.
DEFAULT_DIALECT = 2

View File

@@ -0,0 +1,17 @@
class Document:
    """
    Represents a single document in a result set
    """

    def __init__(self, id, payload=None, **fields):
        self.id = id
        self.payload = payload
        # Expose every returned field as an attribute of the document.
        for field_name, value in fields.items():
            setattr(self, field_name, value)

    def __repr__(self):
        return f"Document {self.__dict__}"

    def __getitem__(self, item):
        # Dict-style access delegates to attribute access.
        return getattr(self, item)

View File

@@ -0,0 +1,210 @@
from typing import List
from redis import DataError
class Field:
    """
    A single attribute definition inside an index schema.
    """

    NUMERIC = "NUMERIC"
    TEXT = "TEXT"
    WEIGHT = "WEIGHT"
    GEO = "GEO"
    TAG = "TAG"
    VECTOR = "VECTOR"
    SORTABLE = "SORTABLE"
    NOINDEX = "NOINDEX"
    AS = "AS"
    GEOSHAPE = "GEOSHAPE"
    INDEX_MISSING = "INDEXMISSING"
    INDEX_EMPTY = "INDEXEMPTY"

    def __init__(
        self,
        name: str,
        args: List[str] = None,
        sortable: bool = False,
        no_index: bool = False,
        index_missing: bool = False,
        index_empty: bool = False,
        as_name: str = None,
    ):
        """
        Create a new field object.

        Args:
            name: The name of the field.
            args: Type-specific argument tokens (e.g. ``["TEXT"]``).
            sortable: If `True`, the field will be sortable.
            no_index: If `True`, the field will not be indexed.
            index_missing: If `True`, it will be possible to search for
                documents that have this field missing.
            index_empty: If `True`, it will be possible to search for
                documents that have this field empty.
            as_name: If provided, this alias will be used for the field.
        """
        self.name = name
        self.args = [] if args is None else args
        self.args_suffix = []
        self.as_name = as_name
        # Flag tokens are emitted after the type arguments, in fixed order.
        for enabled, token in (
            (sortable, Field.SORTABLE),
            (no_index, Field.NOINDEX),
            (index_missing, Field.INDEX_MISSING),
            (index_empty, Field.INDEX_EMPTY),
        ):
            if enabled:
                self.args_suffix.append(token)

        if no_index and not sortable:
            raise ValueError("Non-Sortable non-Indexable fields are ignored")

    def append_arg(self, value):
        self.args.append(value)

    def redis_args(self):
        # FT.CREATE token order: name, [AS alias], type args, flag suffix.
        args = [self.name]
        if self.as_name:
            args += [self.AS, self.as_name]
        args += self.args
        args += self.args_suffix
        return args
class TextField(Field):
    """
    TextField is used to define a text field in a schema definition
    """

    NOSTEM = "NOSTEM"
    PHONETIC = "PHONETIC"

    def __init__(
        self,
        name: str,
        weight: float = 1.0,
        no_stem: bool = False,
        phonetic_matcher: str = None,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        super().__init__(name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs)

        if no_stem:
            self.append_arg(self.NOSTEM)
        # Only the four supported double-metaphone matchers are honored;
        # any other value is silently ignored.
        if phonetic_matcher and phonetic_matcher in ("dm:en", "dm:fr", "dm:pt", "dm:es"):
            self.append_arg(self.PHONETIC)
            self.append_arg(phonetic_matcher)
        if withsuffixtrie:
            self.append_arg("WITHSUFFIXTRIE")
class NumericField(Field):
    """
    NumericField is used to define a numeric field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        super().__init__(name, args=[Field.NUMERIC], **kwargs)
class GeoShapeField(Field):
    """
    GeoShapeField is used to enable within/contain indexing/searching
    """

    SPHERICAL = "SPHERICAL"
    FLAT = "FLAT"

    def __init__(self, name: str, coord_system=None, **kwargs):
        # When no coordinate system is given, none is emitted and the
        # server-side default applies.
        type_args = [Field.GEOSHAPE]
        if coord_system:
            type_args.append(coord_system)
        super().__init__(name, args=type_args, **kwargs)
class GeoField(Field):
    """
    GeoField is used to define a geo-indexing field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        super().__init__(name, args=[Field.GEO], **kwargs)
class TagField(Field):
    """
    TagField is a tag-indexing field with simpler compression and tokenization.
    See http://redisearch.io/Tags/
    """

    SEPARATOR = "SEPARATOR"
    CASESENSITIVE = "CASESENSITIVE"

    def __init__(
        self,
        name: str,
        separator: str = ",",
        case_sensitive: bool = False,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        type_args = [Field.TAG, self.SEPARATOR, separator]
        if case_sensitive:
            type_args.append(self.CASESENSITIVE)
        if withsuffixtrie:
            type_args.append("WITHSUFFIXTRIE")
        super().__init__(name, args=type_args, **kwargs)
class VectorField(Field):
    """
    Allows vector similarity queries against the value in this attribute.
    See https://oss.redis.com/redisearch/Vectors/#vector_fields.
    """

    def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs):
        """
        Create Vector Field. Notice that Vector cannot have sortable or no_index tag,
        although it's also a Field.

        ``name`` is the name of the field.

        ``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA".

        ``attributes`` each algorithm can have specific attributes. Some of them
        are mandatory and some of them are optional. See
        https://oss.redis.com/redisearch/master/Vectors/#specific_creation_attributes_per_algorithm
        for more information.
        """
        # Vector fields reject both flags Field would otherwise accept.
        if kwargs.get("sortable", False) or kwargs.get("no_index", False):
            raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")

        if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]:
            raise DataError(
                "Realtime vector indexing supporting 3 Indexing Methods:"
                "'FLAT', 'HNSW', and 'SVS-VAMANA'."
            )

        # Flatten the attribute dict into alternating key/value tokens.
        attr_tokens = []
        for key, value in attributes.items():
            attr_tokens += [key, value]

        super().__init__(
            name,
            args=[Field.VECTOR, algorithm, len(attr_tokens), *attr_tokens],
            **kwargs,
        )

View File

@@ -0,0 +1,79 @@
from enum import Enum
class IndexType(Enum):
    """Enum of the currently supported index types."""

    HASH = 1
    JSON = 2


class IndexDefinition:
    """IndexDefinition is used to define a index definition for automatic
    indexing on Hash or Json update."""

    def __init__(
        self,
        prefix=None,
        filter=None,
        language_field=None,
        language=None,
        score_field=None,
        score=1.0,
        payload_field=None,
        index_type=None,
    ):
        """
        Args:
            prefix: Key prefixes to index (default: no PREFIX clause).
            filter: FILTER expression restricting which keys are indexed.
            language_field: Document field naming the stemming language.
            language: Default stemming language.
            score_field: Document field holding the document score.
            score: Default document score.
            payload_field: Document field holding the payload.
            index_type: ``IndexType.HASH`` or ``IndexType.JSON``.
        """
        # `prefix` previously defaulted to a mutable `[]` (shared-default
        # pitfall); use None and normalize here — behavior is unchanged.
        if prefix is None:
            prefix = []
        self.args = []
        self._append_index_type(index_type)
        self._append_prefix(prefix)
        self._append_filter(filter)
        self._append_language(language_field, language)
        self._append_score(score_field, score)
        self._append_payload(payload_field)

    def _append_index_type(self, index_type):
        """Append `ON HASH` or `ON JSON` according to the enum."""
        if index_type is IndexType.HASH:
            self.args.extend(["ON", "HASH"])
        elif index_type is IndexType.JSON:
            self.args.extend(["ON", "JSON"])
        elif index_type is not None:
            raise RuntimeError(f"index_type must be one of {list(IndexType)}")

    def _append_prefix(self, prefix):
        """Append PREFIX."""
        if len(prefix) > 0:
            self.args.append("PREFIX")
            self.args.append(len(prefix))
            for p in prefix:
                self.args.append(p)

    def _append_filter(self, filter):
        """Append FILTER."""
        if filter is not None:
            self.args.append("FILTER")
            self.args.append(filter)

    def _append_language(self, language_field, language):
        """Append LANGUAGE_FIELD and LANGUAGE."""
        if language_field is not None:
            self.args.append("LANGUAGE_FIELD")
            self.args.append(language_field)
        if language is not None:
            self.args.append("LANGUAGE")
            self.args.append(language)

    def _append_score(self, score_field, score):
        """Append SCORE_FIELD and SCORE."""
        if score_field is not None:
            self.args.append("SCORE_FIELD")
            self.args.append(score_field)
        if score is not None:
            self.args.append("SCORE")
            self.args.append(score)

    def _append_payload(self, payload_field):
        """Append PAYLOAD_FIELD."""
        if payload_field is not None:
            self.args.append("PAYLOAD_FIELD")
            self.args.append(payload_field)

View File

@@ -0,0 +1,14 @@
from typing import Any
class ProfileInformation:
    """Thin, read-only wrapper around a raw FT.PROFILE response."""

    def __init__(self, info: Any) -> None:
        # Keep the raw server reply untouched; expose it via the property.
        self._info: Any = info

    @property
    def info(self) -> Any:
        """The raw profiling payload exactly as returned by the server."""
        return self._info

View File

@@ -0,0 +1,381 @@
from typing import List, Optional, Union
from redis.commands.search.dialect import DEFAULT_DIALECT
class Query:
    """
    Query is used to build complex queries that have more parameters than just
    the query string. The query string is set in the constructor, and other
    options have setter functions.

    The setter functions return the query object, so they can be chained,
    i.e. `Query("foo").verbatim().filter(...)` etc.
    """

    def __init__(self, query_string: str) -> None:
        """
        Create a new query object.
        The query string is set in the constructor, and other options have
        setter functions.
        """
        self._query_string: str = query_string
        self._offset: int = 0  # paging offset (LIMIT <offset> <num>)
        self._num: int = 10  # page size; mirrors the server default
        self._no_content: bool = False
        self._no_stopwords: bool = False
        self._fields: Optional[List[str]] = None
        self._verbatim: bool = False
        self._with_payloads: bool = False
        self._with_scores: bool = False
        self._scorer: Optional[str] = None
        self._filters: List = list()
        self._ids: Optional[List[str]] = None
        self._slop: int = -1  # -1 means "not set"; only emitted when >= 0
        self._timeout: Optional[float] = None
        self._in_order: bool = False
        self._sortby: Optional[SortbyField] = None
        self._return_fields: List = []
        self._return_fields_decode_as: dict = {}
        self._summarize_fields: List = []
        self._highlight_fields: List = []
        self._language: Optional[str] = None
        self._expander: Optional[str] = None
        self._dialect: int = DEFAULT_DIALECT

    def query_string(self) -> str:
        """Return the query string of this query only."""
        return self._query_string

    def limit_ids(self, *ids) -> "Query":
        """Limit the results to a specific set of pre-known document
        ids of any length."""
        self._ids = ids
        return self

    def return_fields(self, *fields) -> "Query":
        """Add fields to return fields."""
        for field in fields:
            self.return_field(field)
        return self

    def return_field(
        self,
        field: str,
        as_field: Optional[str] = None,
        decode_field: Optional[bool] = True,
        encoding: Optional[str] = "utf8",
    ) -> "Query":
        """
        Add a field to the list of fields to return.

        - **field**: The field to include in query results
        - **as_field**: The alias for the field
        - **decode_field**: Whether to decode the field from bytes to string
        - **encoding**: The encoding to use when decoding the field
        """
        self._return_fields.append(field)
        # A None value means "leave this field as raw bytes" when the
        # response is decoded (see Result's field_encodings handling).
        self._return_fields_decode_as[field] = encoding if decode_field else None
        if as_field is not None:
            self._return_fields += ("AS", as_field)
        return self

    def _mk_field_list(self, fields: List[str]) -> List:
        # Normalize a single field name or an iterable of names to a list.
        if not fields:
            return []
        return [fields] if isinstance(fields, str) else list(fields)

    def summarize(
        self,
        fields: Optional[List] = None,
        context_len: Optional[int] = None,
        num_frags: Optional[int] = None,
        sep: Optional[str] = None,
    ) -> "Query":
        """
        Return an abridged format of the field, containing only the segments of
        the field which contain the matching term(s).

        If `fields` is specified, then only the mentioned fields are
        summarized; otherwise all results are summarized.

        Server side defaults are used for each option (except `fields`)
        if not specified

        - **fields** List of fields to summarize. All fields are summarized
          if not specified
        - **context_len** Amount of context to include with each fragment
        - **num_frags** Number of fragments per document
        - **sep** Separator string to separate fragments
        """
        args = ["SUMMARIZE"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields
        if context_len is not None:
            args += ["LEN", str(context_len)]
        if num_frags is not None:
            args += ["FRAGS", str(num_frags)]
        if sep is not None:
            args += ["SEPARATOR", sep]
        self._summarize_fields = args
        return self

    def highlight(
        self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
    ) -> "Query":
        """
        Apply specified markup to matched term(s) within the returned field(s).

        - **fields** If specified then only those mentioned fields are
          highlighted, otherwise all fields are highlighted
        - **tags** A list of two strings to surround the match.
        """
        args = ["HIGHLIGHT"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields
        if tags:
            args += ["TAGS"] + list(tags)
        self._highlight_fields = args
        return self

    def language(self, language: str) -> "Query":
        """
        Analyze the query as being in the specified language.

        :param language: The language (e.g. `chinese` or `english`)
        """
        self._language = language
        return self

    def slop(self, slop: int) -> "Query":
        """Allow a maximum of N intervening non matched terms between
        phrase terms (0 means exact phrase).
        """
        self._slop = slop
        return self

    def timeout(self, timeout: float) -> "Query":
        """overrides the timeout parameter of the module"""
        self._timeout = timeout
        return self

    def in_order(self) -> "Query":
        """
        Match only documents where the query terms appear in
        the same order in the document.
        i.e. for the query "hello world", we do not match "world hello"
        """
        self._in_order = True
        return self

    def scorer(self, scorer: str) -> "Query":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        Since Redis 8.0 default was changed to BM25STD.

        :param scorer: The scoring function to use
                       (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def get_args(self) -> List[str]:
        """Format the redis arguments for this query and return them."""
        # Order matters: query string, flags/options, then LIMIT last.
        args = [self._query_string]
        args += self._get_args_tags()
        args += self._summarize_fields + self._highlight_fields
        args += ["LIMIT", self._offset, self._num]
        return args

    def _get_args_tags(self) -> List[str]:
        # Assemble all optional FT.SEARCH arguments in the order RediSearch
        # expects them; each option is only emitted when it was set.
        args = []
        if self._no_content:
            args.append("NOCONTENT")
        if self._fields:
            args.append("INFIELDS")
            args.append(len(self._fields))
            args += self._fields
        if self._verbatim:
            args.append("VERBATIM")
        if self._no_stopwords:
            args.append("NOSTOPWORDS")
        if self._filters:
            for flt in self._filters:
                if not isinstance(flt, Filter):
                    raise AttributeError("Did not receive a Filter object.")
                args += flt.args
        if self._with_payloads:
            args.append("WITHPAYLOADS")
        if self._scorer:
            args += ["SCORER", self._scorer]
        if self._with_scores:
            args.append("WITHSCORES")
        if self._ids:
            args.append("INKEYS")
            args.append(len(self._ids))
            args += self._ids
        if self._slop >= 0:
            args += ["SLOP", self._slop]
        if self._timeout is not None:
            args += ["TIMEOUT", self._timeout]
        if self._in_order:
            args.append("INORDER")
        if self._return_fields:
            args.append("RETURN")
            args.append(len(self._return_fields))
            args += self._return_fields
        if self._sortby:
            if not isinstance(self._sortby, SortbyField):
                raise AttributeError("Did not receive a SortByField.")
            args.append("SORTBY")
            args += self._sortby.args
        if self._language:
            args += ["LANGUAGE", self._language]
        if self._expander:
            args += ["EXPANDER", self._expander]
        if self._dialect:
            args += ["DIALECT", self._dialect]
        return args

    def paging(self, offset: int, num: int) -> "Query":
        """
        Set the paging for the query (defaults to 0..10).

        - **offset**: Paging offset for the results. Defaults to 0
        - **num**: How many results do we want
        """
        self._offset = offset
        self._num = num
        return self

    def verbatim(self) -> "Query":
        """Set the query to be verbatim, i.e. use no query expansion
        or stemming.
        """
        self._verbatim = True
        return self

    def no_content(self) -> "Query":
        """Set the query to only return ids and not the document content."""
        self._no_content = True
        return self

    def no_stopwords(self) -> "Query":
        """
        Prevent the query from being filtered for stopwords.
        Only useful in very big queries that you are certain contain
        no stopwords.
        """
        self._no_stopwords = True
        return self

    def with_payloads(self) -> "Query":
        """Ask the engine to return document payloads."""
        self._with_payloads = True
        return self

    def with_scores(self) -> "Query":
        """Ask the engine to return document search scores."""
        self._with_scores = True
        return self

    def limit_fields(self, *fields: List[str]) -> "Query":
        """
        Limit the search to specific TEXT fields only.

        - **fields**: A list of strings, case sensitive field names
          from the defined schema.
        """
        self._fields = fields
        return self

    def add_filter(self, flt: "Filter") -> "Query":
        """
        Add a numeric or geo filter to the query.
        **Currently only one of each filter is supported by the engine**

        - **flt**: A NumericFilter or GeoFilter object, used on a
          corresponding field
        """
        self._filters.append(flt)
        return self

    def sort_by(self, field: str, asc: bool = True) -> "Query":
        """
        Add a sortby field to the query.

        - **field** - the name of the field to sort by
        - **asc** - when `True`, sorting will be done in ascending order
        """
        self._sortby = SortbyField(field, asc)
        return self

    def expander(self, expander: str) -> "Query":
        """
        Add a expander field to the query.

        - **expander** - the name of the expander
        """
        self._expander = expander
        return self

    def dialect(self, dialect: int) -> "Query":
        """
        Add a dialect field to the query.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self
class Filter:
    """Base filter: stores the raw argument list appended to FT.SEARCH."""

    def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
        self.args = [keyword, field, *args]


class NumericFilter(Filter):
    """Numeric range filter: ``FILTER <field> <min> <max>``."""

    INF = "+inf"
    NEG_INF = "-inf"

    def __init__(
        self,
        field: str,
        minval: Union[int, str],
        maxval: Union[int, str],
        minExclusive: bool = False,
        maxExclusive: bool = False,
    ) -> None:
        # Exclusive bounds are rendered with a leading "(" per Redis syntax.
        lower = f"({minval}" if minExclusive else minval
        upper = f"({maxval}" if maxExclusive else maxval
        Filter.__init__(self, "FILTER", field, lower, upper)


class GeoFilter(Filter):
    """Geo radius filter: ``GEOFILTER <field> <lon> <lat> <radius> <unit>``."""

    METERS = "m"
    KILOMETERS = "km"
    FEET = "ft"
    MILES = "mi"

    def __init__(
        self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS
    ) -> None:
        Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit)


class SortbyField:
    """SORTBY clause: field name followed by the sort direction."""

    def __init__(self, field: str, asc=True) -> None:
        direction = "ASC" if asc else "DESC"
        self.args = [field, direction]

View File

@@ -0,0 +1,317 @@
def tags(*t):
    """Indicate that the values should be matched against a tag field.

    ### Parameters

    - **t**: Tags to search for

    Raises ``ValueError`` when called without any tag.
    """
    if not t:
        raise ValueError("At least one tag must be specified")
    return TagValue(*t)
def between(a, b, inclusive_min=True, inclusive_max=True):
    """Match a numeric range between ``a`` and ``b`` (inclusive by default)."""
    return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max)


def equal(n):
    """Match exactly the numeric value ``n``."""
    return between(n, n)


def lt(n):
    """Match any value strictly less than ``n``."""
    return between(None, n, inclusive_max=False)


def le(n):
    """Match any value less than or equal to ``n``."""
    return between(None, n, inclusive_max=True)


def gt(n):
    """Match any value strictly greater than ``n``."""
    return between(n, None, inclusive_min=False)


def ge(n):
    """Match any value greater than or equal to ``n``."""
    return between(n, None, inclusive_min=True)
def geo(lat, lon, radius, unit="km"):
    """
    Indicate that value is a geo region.

    - **lat**: latitude of the center point
    - **lon**: longitude of the center point
    - **radius**: radius of the region
    - **unit**: distance unit (m, km, ft, mi)
    """
    # BUGFIX: GeoValue's positional parameter order is (lon, lat, ...).
    # Passing (lat, lon) positionally silently swapped the coordinates, so
    # the rendered "[lon lat radius unit]" fragment had them reversed.
    return GeoValue(lon, lat, radius, unit)
class Value:
    """Base class for field values used when building query strings."""

    @property
    def combinable(self):
        """
        Whether this type of value may be combined with other values
        for the same field. This makes the filter potentially more efficient
        """
        return False

    @staticmethod
    def make_value(v):
        """Wrap ``v`` in a ``ScalarValue`` unless it is already a ``Value``."""
        return v if isinstance(v, Value) else ScalarValue(v)

    def to_string(self):
        raise NotImplementedError()

    def __str__(self):
        return self.to_string()


class RangeValue(Value):
    """Numeric range value; renders as ``[min max]`` with optional ``(``."""

    combinable = False

    def __init__(self, a, b, inclusive_min=False, inclusive_max=False):
        # None on either side means an open-ended bound.
        lower = "-inf" if a is None else a
        upper = "inf" if b is None else b
        self.range = [str(lower), str(upper)]
        self.inclusive_min = inclusive_min
        self.inclusive_max = inclusive_max

    def to_string(self):
        lo_paren = "" if self.inclusive_min else "("
        hi_paren = "" if self.inclusive_max else "("
        return f"[{lo_paren}{self.range[0]} {hi_paren}{self.range[1]}]"


class ScalarValue(Value):
    """Plain scalar value; renders as its string form."""

    combinable = True

    def __init__(self, v):
        self.v = str(v)

    def to_string(self):
        return self.v


class TagValue(Value):
    """Tag-field value; renders as ``{tag1 | tag2 | ...}``."""

    combinable = False

    def __init__(self, *tags):
        self.tags = tags

    def to_string(self):
        joined = " | ".join(str(t) for t in self.tags)
        return "{" + joined + "}"


class GeoValue(Value):
    """Geo-region value; renders as ``[lon lat radius unit]``."""

    def __init__(self, lon, lat, radius, unit="km"):
        self.lon = lon
        self.lat = lat
        self.radius = radius
        self.unit = unit

    def to_string(self):
        return f"[{self.lon} {self.lat} {self.radius} {self.unit}]"
class Node:
    """Base query-tree node; children are joined with the subclass JOINSTR."""

    def __init__(self, *children, **kwparams):
        """
        Create a node

        ### Parameters

        - **children**: One or more sub-conditions. These can be additional
            `intersect`, `disjunct`, `union`, `optional`, or any other `Node`
            type.

            The semantics of multiple conditions are dependent on the type of
            query. For an `intersection` node, this amounts to a logical AND,
            for a `union` node, this amounts to a logical `OR`.

        - **kwparams**: key-value parameters. Each key is the name of a field,
            and the value should be a field value. This can be one of the
            following:

            - Simple string (for text field matches)
            - value returned by one of the helper functions
            - list of either a string or a value


        ### Examples

        Field `num` should be between 1 and 10
        ```
        intersect(num=between(1, 10))
        ```

        Name can either be `bob` or `john`

        ```
        union(name=("bob", "john"))
        ```

        Don't select countries in Israel, Japan, or US

        ```
        disjunct_union(country=("il", "jp", "us"))
        ```
        """
        self.params = []

        kvparams = {}
        # Normalize each keyword value to a list of Value objects.
        for k, v in kwparams.items():
            curvals = kvparams.setdefault(k, [])
            if isinstance(v, (str, int, float)):
                curvals.append(Value.make_value(v))
            elif isinstance(v, Value):
                curvals.append(v)
            else:
                # Assume an iterable mixing raw values and Value objects.
                curvals.extend(Value.make_value(subv) for subv in v)

        # Positional children first, then one clause per keyword field.
        self.params += [Node.to_node(p) for p in children]
        for k, v in kvparams.items():
            self.params.extend(self.join_fields(k, v))

    def join_fields(self, key, vals):
        """Render `@key:value` leaf nodes for a field's value list."""
        if len(vals) == 1:
            return [BaseNode(f"@{key}:{vals[0].to_string()}")]
        if not vals[0].combinable:
            # Non-combinable values (ranges, tags, geo) each need their own
            # clause; they cannot share one `@key:(...)` group.
            return [BaseNode(f"@{key}:{v.to_string()}") for v in vals]
        s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})")
        return [s]

    @classmethod
    def to_node(cls, obj):  # noqa
        """Pass Node instances through; wrap anything else in a BaseNode."""
        if isinstance(obj, Node):
            return obj
        return BaseNode(obj)

    @property
    def JOINSTR(self):
        # Subclasses supply the separator (" " for AND, "|" for OR).
        raise NotImplementedError()

    def to_string(self, with_parens=None):
        """Render this node and its children as a query-string fragment."""
        with_parens = self._should_use_paren(with_parens)
        pre, post = ("(", ")") if with_parens else ("", "")
        return f"{pre}{self.JOINSTR.join(n.to_string() for n in self.params)}{post}"

    def _should_use_paren(self, optval):
        # An explicit caller preference wins; otherwise parenthesize only
        # when there is more than one child to group.
        if optval is not None:
            return optval
        return len(self.params) > 1

    def __str__(self):
        return self.to_string()
class BaseNode(Node):
    """Leaf node wrapping an already-rendered query-string fragment."""

    def __init__(self, s):
        super().__init__()
        self.s = str(s)

    def to_string(self, with_parens=None):
        # A leaf renders verbatim; parenthesization never applies to it.
        return self.s
class IntersectNode(Node):
    """
    Create an intersection node. All children need to be satisfied in order for
    this node to evaluate as true
    """

    # Children joined with a space, which RediSearch treats as logical AND.
    JOINSTR = " "


class UnionNode(Node):
    """
    Create a union node. Any of the children need to be satisfied in order for
    this node to evaluate as true
    """

    # Children joined with "|", which RediSearch treats as logical OR.
    JOINSTR = "|"
class DisjunctNode(IntersectNode):
    """
    Create a disjunct node. In order for this node to be true, all of its
    children must evaluate to false
    """

    def to_string(self, with_parens=None):
        """Render as the negation ("-") of the AND of the children."""
        use_parens = self._should_use_paren(with_parens)
        inner = super().to_string(with_parens=False)
        return f"(-{inner})" if use_parens else f"-{inner}"
class DistjunctUnion(DisjunctNode):
    """
    This node is true if *all* of its children are false. This is equivalent to

    ```
    disjunct(union(...))
    ```

    NOTE: the misspelled class name is historical and kept for backward
    compatibility.
    """

    # Children are OR-ed together before the whole group is negated.
    JOINSTR = "|"
class OptionalNode(IntersectNode):
    """
    Create an optional node. If this nodes evaluates to true, then the document
    will be rated higher in score/rank.
    """

    def to_string(self, with_parens=None):
        """Render with the optional ("~") prefix applied to the children."""
        use_parens = self._should_use_paren(with_parens)
        inner = super().to_string(with_parens=False)
        return f"(~{inner})" if use_parens else f"~{inner}"
def intersect(*args, **kwargs):
    """Build an ``IntersectNode`` (logical AND) from the given conditions."""
    return IntersectNode(*args, **kwargs)


def union(*args, **kwargs):
    """Build a ``UnionNode`` (logical OR) from the given conditions."""
    return UnionNode(*args, **kwargs)


def disjunct(*args, **kwargs):
    """Build a ``DisjunctNode`` (negated AND) from the given conditions."""
    return DisjunctNode(*args, **kwargs)


def disjunct_union(*args, **kwargs):
    """Build a ``DistjunctUnion`` (negated OR) from the given conditions."""
    return DistjunctUnion(*args, **kwargs)


def querystring(*args, **kwargs):
    """Render an intersection of the given conditions as a query string."""
    return intersect(*args, **kwargs).to_string()

View File

@@ -0,0 +1,182 @@
from typing import Union
from .aggregation import Asc, Desc, Reducer, SortDirection
class FieldOnlyReducer(Reducer):
    """See https://redis.io/docs/interact/search-and-query/search/aggregations/"""

    # Base class for reducers whose only argument is the source field name.
    def __init__(self, field: str) -> None:
        super().__init__(field)
        self._field = field
class count(Reducer):
    """
    Counts the number of results in the group
    """

    NAME = "COUNT"

    def __init__(self) -> None:
        super().__init__()


class sum(FieldOnlyReducer):
    """
    Calculates the sum of all the values in the given fields within the group
    """

    NAME = "SUM"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class min(FieldOnlyReducer):
    """
    Calculates the smallest value in the given field within the group
    """

    NAME = "MIN"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class max(FieldOnlyReducer):
    """
    Calculates the largest value in the given field within the group
    """

    NAME = "MAX"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class avg(FieldOnlyReducer):
    """
    Calculates the mean value in the given field within the group
    """

    NAME = "AVG"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class tolist(FieldOnlyReducer):
    """
    Returns all the matched properties in a list
    """

    NAME = "TOLIST"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class count_distinct(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in
    the group for the given field
    """

    NAME = "COUNT_DISTINCT"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class count_distinctish(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in the
    group for the given field. This uses a faster algorithm than
    `count_distinct` but is less accurate
    """

    # Inherits FieldOnlyReducer.__init__; only the command name differs.
    NAME = "COUNT_DISTINCTISH"


class quantile(Reducer):
    """
    Return the value for the nth percentile within the range of values for the
    field within the group.
    """

    NAME = "QUANTILE"

    def __init__(self, field: str, pct: float) -> None:
        # pct is sent as a string argument after the field name.
        super().__init__(field, str(pct))
        self._field = field


class stddev(FieldOnlyReducer):
    """
    Return the standard deviation for the values within the group
    """

    NAME = "STDDEV"

    def __init__(self, field: str) -> None:
        super().__init__(field)
class first_value(Reducer):
    """
    Selects the first value within the group according to sorting parameters
    """

    NAME = "FIRST_VALUE"

    def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None:
        """
        Selects the first value of the given field within the group.

        ### Parameter

        - **field**: Source field used for the value
        - **byfields**: How to sort the results. This can be either the
            *class* of `aggregation.Asc` or `aggregation.Desc` in which
            case the field `field` is also used as the sort input.

            `byfields` can also be one or more *instances* of `Asc` or `Desc`
            indicating the sort order for these fields
        """

        fieldstrs = []
        # A bare Asc/Desc *class* means "sort by `field` in that direction":
        # instantiate it with the source field.
        if (
            len(byfields) == 1
            and isinstance(byfields[0], type)
            and issubclass(byfields[0], SortDirection)
        ):
            byfields = [byfields[0](field)]

        for f in byfields:
            fieldstrs += [f.field, f.DIRSTRING]

        args = [field]
        if fieldstrs:
            args += ["BY"] + fieldstrs
        super().__init__(*args)
        self._field = field


class random_sample(Reducer):
    """
    Returns a random sample of items from the dataset, from the given property
    """

    NAME = "RANDOM_SAMPLE"

    def __init__(self, field: str, size: int) -> None:
        """
        ### Parameter

        **field**: Field to sample from
        **size**: Return this many items (can be less)
        """
        args = [field, str(size)]
        super().__init__(*args)
        self._field = field

View File

@@ -0,0 +1,87 @@
from typing import Optional
from ._util import to_string
from .document import Document
class Result:
    """
    Represents the result of a search query, and has an array of Document
    objects
    """

    def __init__(
        self,
        res,
        hascontent,
        duration=0,
        has_payload=False,
        with_scores=False,
        field_encodings: Optional[dict] = None,
    ):
        """
        - **res**: raw FT.SEARCH reply: ``[total, id1, (score1), (payload1),
          (fields1), id2, ...]`` where the optional entries depend on the
          query flags below
        - duration: the execution time of the query
        - has_payload: whether the query has payloads
        - with_scores: whether the query has scores
        - field_encodings: a dictionary of field encodings if any is provided
        """
        self.total = res[0]
        self.duration = duration
        self.docs = []

        # Each document occupies `step` consecutive reply entries: the id
        # plus one entry each for content, payload and score when requested.
        step = 1
        if hascontent:
            step = step + 1
        if has_payload:
            step = step + 1
        if with_scores:
            step = step + 1

        # Offset (relative to the id) of the payload entry; when scores are
        # present they sit immediately after the id and shift it by one.
        offset = 2 if with_scores else 1
        for i in range(1, len(res), step):
            id = to_string(res[i])
            payload = to_string(res[i + offset]) if has_payload else None
            # fields_offset = 2 if has_payload else 1
            fields_offset = offset + 1 if has_payload else offset
            score = float(res[i + 1]) if with_scores else None

            fields = {}
            if hascontent and res[i + fields_offset] is not None:
                # The fields entry is a flat [key, value, key, value, ...]
                # list; split it into parallel key/value streams.
                keys = map(to_string, res[i + fields_offset][::2])
                values = res[i + fields_offset][1::2]

                for key, value in zip(keys, values):
                    if field_encodings is None or key not in field_encodings:
                        fields[key] = to_string(value)
                        continue

                    encoding = field_encodings[key]

                    # If the encoding is None, we don't need to decode the value
                    if encoding is None:
                        fields[key] = value
                    else:
                        fields[key] = to_string(value, encoding=encoding)

            try:
                del fields["id"]
            except KeyError:
                pass

            try:
                # JSON indexes return the whole document under the "$" key;
                # re-expose it under the friendlier "json" key.
                fields["json"] = fields["$"]
                del fields["$"]
            except KeyError:
                pass

            doc = (
                Document(id, score=score, payload=payload, **fields)
                if with_scores
                else Document(id, payload=payload, **fields)
            )
            self.docs.append(doc)

    def __repr__(self) -> str:
        return f"Result{{{self.total} total, docs: {self.docs}}}"

View File

@@ -0,0 +1,55 @@
from typing import Optional
from ._util import to_string
class Suggestion:
    """
    Represents a single suggestion being sent or returned from the
    autocomplete server
    """

    def __init__(
        self, string: str, score: float = 1.0, payload: Optional[str] = None
    ) -> None:
        # Normalize both text fields to native strings; the score is kept
        # as-is (defaults to 1.0 when the server did not return one).
        self.score = score
        self.string = to_string(string)
        self.payload = to_string(payload)

    def __repr__(self) -> str:
        return self.string
class SuggestionParser:
    """
    Internal class used to parse results from the `SUGGET` command.
    This needs to consume either 1, 2, or 3 values at a time from
    the return value depending on what objects were requested
    """

    def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None:
        self.with_scores = with_scores
        self.with_payloads = with_payloads

        # `sugsize` is the number of reply entries per suggestion; the
        # *_idx attributes locate score/payload relative to the string.
        if with_scores and with_payloads:
            self.sugsize = 3
            self._scoreidx = 1
            self._payloadidx = 2
        elif with_scores:
            self.sugsize = 2
            self._scoreidx = 1
        elif with_payloads:
            self.sugsize = 2
            self._payloadidx = 1
        else:
            self.sugsize = 1
            self._scoreidx = -1  # unused: scores default to 1.0 in __iter__
        self._sugs = ret

    def __iter__(self):
        # Walk the flat reply in strides of `sugsize`, yielding one
        # Suggestion per group of entries.
        for i in range(0, len(self._sugs), self.sugsize):
            ss = self._sugs[i]
            score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0
            payload = self._sugs[i + self._payloadidx] if self.with_payloads else None
            yield Suggestion(ss, score, payload)

View File

@@ -0,0 +1,129 @@
import warnings
class SentinelCommands:
    """
    A class containing the commands specific to redis sentinel. This class is
    to be used as a mixin.
    """

    def sentinel(self, *args):
        """Redis Sentinel's SENTINEL command."""
        # Deprecated catch-all; prefer the dedicated sentinel_* methods below.
        warnings.warn(DeprecationWarning("Use the individual sentinel_* methods"))

    def sentinel_get_master_addr_by_name(self, service_name, return_responses=False):
        """
        Returns a (host, port) pair for the given ``service_name`` when return_responses is True,
        otherwise returns a boolean value that indicates if the command was successful.
        """
        return self.execute_command(
            "SENTINEL GET-MASTER-ADDR-BY-NAME",
            service_name,
            once=True,
            return_responses=return_responses,
        )

    def sentinel_master(self, service_name, return_responses=False):
        """
        Returns a dictionary containing the specified master's state, when return_responses is True,
        otherwise returns a boolean value that indicates if the command was successful.
        """
        return self.execute_command(
            "SENTINEL MASTER", service_name, return_responses=return_responses
        )

    def sentinel_masters(self):
        """
        Returns a list of dictionaries containing each master's state.

        Important: This function is called by the Sentinel implementation and is
        called directly on the Redis standalone client for sentinels,
        so it doesn't support the "once" and "return_responses" options.
        """
        return self.execute_command("SENTINEL MASTERS")

    def sentinel_monitor(self, name, ip, port, quorum):
        """Add a new master to Sentinel to be monitored"""
        return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum)

    def sentinel_remove(self, name):
        """Remove a master from Sentinel's monitoring"""
        return self.execute_command("SENTINEL REMOVE", name)

    def sentinel_sentinels(self, service_name, return_responses=False):
        """
        Returns a list of sentinels for ``service_name``, when return_responses is True,
        otherwise returns a boolean value that indicates if the command was successful.
        """
        return self.execute_command(
            "SENTINEL SENTINELS", service_name, return_responses=return_responses
        )

    def sentinel_set(self, name, option, value):
        """Set Sentinel monitoring parameters for a given master"""
        return self.execute_command("SENTINEL SET", name, option, value)

    def sentinel_slaves(self, service_name):
        """
        Returns a list of slaves for ``service_name``

        Important: This function is called by the Sentinel implementation and is
        called directly on the Redis standalone client for sentinels,
        so it doesn't support the "once" and "return_responses" options.
        """
        return self.execute_command("SENTINEL SLAVES", service_name)

    def sentinel_reset(self, pattern):
        """
        This command will reset all the masters with matching name.
        The pattern argument is a glob-style pattern.

        The reset process clears any previous state in a master (including a
        failover in progress), and removes every slave and sentinel already
        discovered and associated with the master.
        """
        return self.execute_command("SENTINEL RESET", pattern, once=True)

    def sentinel_failover(self, new_master_name):
        """
        Force a failover as if the master was not reachable, and without
        asking for agreement to other Sentinels (however a new version of the
        configuration will be published so that the other Sentinels will
        update their configurations).
        """
        return self.execute_command("SENTINEL FAILOVER", new_master_name)

    def sentinel_ckquorum(self, new_master_name):
        """
        Check if the current Sentinel configuration is able to reach the
        quorum needed to failover a master, and the majority needed to
        authorize the failover.

        This command should be used in monitoring systems to check if a
        Sentinel deployment is ok.
        """
        return self.execute_command("SENTINEL CKQUORUM", new_master_name, once=True)

    def sentinel_flushconfig(self):
        """
        Force Sentinel to rewrite its configuration on disk, including the
        current Sentinel state.

        Normally Sentinel rewrites the configuration every time something
        changes in its state (in the context of the subset of the state which
        is persisted on disk across restart).

        However sometimes it is possible that the configuration file is lost
        because of operation errors, disk failures, package upgrade scripts or
        configuration managers. In those cases a way to force Sentinel to
        rewrite the configuration file is handy.

        This command works even if the previous configuration file is
        completely missing.
        """
        return self.execute_command("SENTINEL FLUSHCONFIG")
class AsyncSentinelCommands(SentinelCommands):
    """Async variant of the sentinel mixin."""

    async def sentinel(self, *args) -> None:
        """Redis Sentinel's SENTINEL command."""
        # The sync implementation only emits a deprecation warning, so there
        # is nothing to await here.
        super().sentinel(*args)

View File

@@ -0,0 +1,108 @@
import redis
from redis._parsers.helpers import bool_ok
from ..helpers import get_protocol_version, parse_to_list
from .commands import (
ALTER_CMD,
CREATE_CMD,
CREATERULE_CMD,
DEL_CMD,
DELETERULE_CMD,
GET_CMD,
INFO_CMD,
MGET_CMD,
MRANGE_CMD,
MREVRANGE_CMD,
QUERYINDEX_CMD,
RANGE_CMD,
REVRANGE_CMD,
TimeSeriesCommands,
)
from .info import TSInfo
from .utils import parse_get, parse_m_get, parse_m_range, parse_range
class TimeSeries(TimeSeriesCommands):
    """
    This class subclasses redis-py's `Redis` and implements RedisTimeSeries's
    commands (prefixed with "ts").
    The client allows to interact with RedisTimeSeries and use all of it's
    functionality.
    """

    def __init__(self, client=None, **kwargs):
        """Create a new RedisTimeSeries client."""
        # Set the module commands' callbacks
        self._MODULE_CALLBACKS = {
            ALTER_CMD: bool_ok,
            CREATE_CMD: bool_ok,
            CREATERULE_CMD: bool_ok,
            DELETERULE_CMD: bool_ok,
        }

        # RESP2 replies are flat arrays that need explicit parsing; RESP3
        # replies arrive already structured, so no extra callbacks are set.
        _RESP2_MODULE_CALLBACKS = {
            DEL_CMD: int,
            GET_CMD: parse_get,
            INFO_CMD: TSInfo,
            MGET_CMD: parse_m_get,
            MRANGE_CMD: parse_m_range,
            MREVRANGE_CMD: parse_m_range,
            RANGE_CMD: parse_range,
            REVRANGE_CMD: parse_range,
            QUERYINDEX_CMD: parse_to_list,
        }
        _RESP3_MODULE_CALLBACKS = {}

        self.client = client
        # Delegate command execution to the wrapped client.
        self.execute_command = client.execute_command

        if get_protocol_version(self.client) in ["3", 3]:
            self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)

        for k, v in self._MODULE_CALLBACKS.items():
            self.client.set_response_callback(k, v)

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the TimeSeries module, that can be used
        for executing only TimeSeries commands and core commands.

        Usage example:

        r = redis.Redis()
        pipe = r.ts().pipeline()
        for i in range(100):
            pipe.add("with_pipeline", i, 1.1 * i)
        pipe.execute()
        """
        if isinstance(self.client, redis.RedisCluster):
            p = ClusterPipeline(
                nodes_manager=self.client.nodes_manager,
                commands_parser=self.client.commands_parser,
                startup_nodes=self.client.nodes_manager.startup_nodes,
                result_callbacks=self.client.result_callbacks,
                cluster_response_callbacks=self.client.cluster_response_callbacks,
                cluster_error_retry_attempts=self.client.retry.get_retries(),
                read_from_replicas=self.client.read_from_replicas,
                reinitialize_steps=self.client.reinitialize_steps,
                lock=self.client._lock,
            )

        else:
            p = Pipeline(
                connection_pool=self.client.connection_pool,
                response_callbacks=self._MODULE_CALLBACKS,
                transaction=transaction,
                shard_hint=shard_hint,
            )
        return p
class ClusterPipeline(TimeSeriesCommands, redis.cluster.ClusterPipeline):
    """Cluster pipeline for the TimeSeries module: batches TS.* commands
    against a Redis Cluster."""


class Pipeline(TimeSeriesCommands, redis.client.Pipeline):
    """Pipeline for the TimeSeries module: batches TS.* commands against a
    standalone Redis."""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,91 @@
from ..helpers import nativestr
from .utils import list_to_dict
class TSInfo:
"""
Hold information and statistics on the time-series.
Can be created using ``tsinfo`` command
https://redis.io/docs/latest/commands/ts.info/
"""
rules = []
labels = []
sourceKey = None
chunk_count = None
memory_usage = None
total_samples = None
retention_msecs = None
last_time_stamp = None
first_time_stamp = None
max_samples_per_chunk = None
chunk_size = None
duplicate_policy = None
def __init__(self, args):
"""
Hold information and statistics on the time-series.
The supported params that can be passed as args:
rules:
A list of compaction rules of the time series.
sourceKey:
Key name for source time series in case the current series
is a target of a rule.
chunkCount:
Number of Memory Chunks used for the time series.
memoryUsage:
Total number of bytes allocated for the time series.
totalSamples:
Total number of samples in the time series.
labels:
A list of label-value pairs that represent the metadata
labels of the time series.
retentionTime:
Retention time, in milliseconds, for the time series.
lastTimestamp:
Last timestamp present in the time series.
firstTimestamp:
First timestamp present in the time series.
maxSamplesPerChunk:
Deprecated.
chunkSize:
Amount of memory, in bytes, allocated for data.
duplicatePolicy:
Policy that will define handling of duplicate samples.
Can read more about on
https://redis.io/docs/latest/develop/data-types/timeseries/configuration/#duplicate_policy
"""
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.rules = response.get("rules")
self.source_key = response.get("sourceKey")
self.chunk_count = response.get("chunkCount")
self.memory_usage = response.get("memoryUsage")
self.total_samples = response.get("totalSamples")
self.labels = list_to_dict(response.get("labels"))
self.retention_msecs = response.get("retentionTime")
self.last_timestamp = response.get("lastTimestamp")
self.first_timestamp = response.get("firstTimestamp")
if "maxSamplesPerChunk" in response:
self.max_samples_per_chunk = response["maxSamplesPerChunk"]
self.chunk_size = (
self.max_samples_per_chunk * 16
) # backward compatible changes
if "chunkSize" in response:
self.chunk_size = response["chunkSize"]
if "duplicatePolicy" in response:
self.duplicate_policy = response["duplicatePolicy"]
if isinstance(self.duplicate_policy, bytes):
self.duplicate_policy = self.duplicate_policy.decode()
def get(self, item):
    """Dict-style safe lookup: like ``self[item]`` but returning
    None instead of raising when the attribute does not exist."""
    try:
        value = self.__getitem__(item)
    except AttributeError:
        return None
    return value
def __getitem__(self, item):
    # Expose parsed fields via subscription, e.g. info["rules"];
    # unknown names raise AttributeError (see get() for the safe variant).
    return getattr(self, item)

View File

@@ -0,0 +1,44 @@
from ..helpers import nativestr
def list_to_dict(aList):
    """Convert a flat list of ``[key, value]`` pairs into a dict,
    normalizing both keys and values with ``nativestr``.

    Idiom fix: iterate the pairs directly instead of indexing via
    ``range(len(...))``.
    """
    return {nativestr(item[0]): nativestr(item[1]) for item in aList}
def parse_range(response, **kwargs):
    """Parse a range reply into a list of ``(timestamp, value)`` tuples,
    coercing values to float. Used by TS.RANGE and TS.REVRANGE.

    Fix: drop the redundant ``tuple((...))`` wrapper around what is
    already a tuple literal.
    """
    return [(sample[0], float(sample[1])) for sample in response]
def parse_m_range(response):
    """Parse a multi-range reply. Used by TS.MRANGE and TS.MREVRANGE.

    Returns a list of single-key dicts
    ``{series_key: [labels_dict, [(timestamp, value), ...]]}``,
    sorted by series key.

    Idiom fix: build the list with a comprehension instead of a
    ``for``/``append`` loop.
    """
    res = [
        {nativestr(item[0]): [list_to_dict(item[1]), parse_range(item[2])]}
        for item in response
    ]
    return sorted(res, key=lambda d: list(d.keys()))
def parse_get(response):
    """Parse a TS.GET reply into ``(timestamp, value)``; an empty
    reply yields None."""
    if response:
        return int(response[0]), float(response[1])
    return None
def parse_m_get(response):
    """Parse a multi-get reply. Used by TS.MGET.

    Returns a list of single-key dicts
    ``{series_key: [labels_dict, timestamp, value]}`` sorted by series
    key; ``timestamp`` and ``value`` are None when the series holds no
    sample.

    Fix: deduplicate the two branches — the key and labels conversions
    were repeated in both.
    """
    res = []
    for item in response:
        key = nativestr(item[0])
        labels = list_to_dict(item[1])
        sample = item[2]
        if sample:
            res.append({key: [labels, int(sample[0]), float(sample[1])]})
        else:
            # Empty sample: the series exists but has no data point.
            res.append({key: [labels, None, None]})
    return sorted(res, key=lambda d: list(d.keys()))

View File

@@ -0,0 +1,46 @@
import json
from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.utils import (
parse_vemb_result,
parse_vlinks_result,
parse_vsim_result,
)
from ..helpers import get_protocol_version
from .commands import (
VEMB_CMD,
VGETATTR_CMD,
VINFO_CMD,
VLINKS_CMD,
VSIM_CMD,
VectorSetCommands,
)
class VectorSet(VectorSetCommands):
    def __init__(self, client, **kwargs):
        """Create a new VectorSet client bound to ``client``,
        registering the module's response callbacks on it."""
        self.client = client
        self.execute_command = client.execute_command

        # Callbacks applied regardless of protocol version.
        self._MODULE_CALLBACKS = {
            VEMB_CMD: parse_vemb_result,
            VGETATTR_CMD: lambda r: r and json.loads(r) or None,
        }
        # RESP2 replies arrive as flat pair lists and need reshaping.
        self._RESP2_MODULE_CALLBACKS = {
            VINFO_CMD: lambda r: r and pairs_to_dict(r) or None,
            VSIM_CMD: parse_vsim_result,
            VLINKS_CMD: parse_vlinks_result,
        }
        # RESP3 already delivers structured replies.
        self._RESP3_MODULE_CALLBACKS = {}

        if get_protocol_version(client) in ["3", 3]:
            self._MODULE_CALLBACKS.update(self._RESP3_MODULE_CALLBACKS)
        else:
            self._MODULE_CALLBACKS.update(self._RESP2_MODULE_CALLBACKS)

        for command, handler in self._MODULE_CALLBACKS.items():
            self.client.set_response_callback(command, handler)

View File

@@ -0,0 +1,374 @@
import json
from enum import Enum
from typing import Awaitable, Dict, List, Optional, Union
from redis.client import NEVER_DECODE
from redis.commands.helpers import get_protocol_version
from redis.exceptions import DataError
from redis.typing import CommandsProtocol, EncodableT, KeyT, Number
# Command-name constants for the Redis vector set module.
VADD_CMD = "VADD"
VSIM_CMD = "VSIM"
VREM_CMD = "VREM"
VDIM_CMD = "VDIM"
VCARD_CMD = "VCARD"
VEMB_CMD = "VEMB"
VLINKS_CMD = "VLINKS"
VINFO_CMD = "VINFO"
VSETATTR_CMD = "VSETATTR"
VGETATTR_CMD = "VGETATTR"
VRANDMEMBER_CMD = "VRANDMEMBER"
class QuantizationOptions(Enum):
    """Quantization options for the VADD command."""

    NOQUANT = "NOQUANT"  # no quantization
    BIN = "BIN"  # binary quantization
    Q8 = "Q8"  # signed 8-bit quantization
class CallbacksOptions(Enum):
    """Options that can be set for the commands callbacks"""

    RAW = "RAW"  # VEMB was called with RAW: reply is the internal representation
    WITHSCORES = "WITHSCORES"  # reply carries similarity scores to fold into dicts
    ALLOW_DECODING = "ALLOW_DECODING"  # client uses decode_responses=True
    RESP3 = "RESP3"  # connection speaks RESP3; reply is already structured
class VectorSetCommands(CommandsProtocol):
    """Redis VectorSet commands"""

    def vadd(
        self,
        key: KeyT,
        vector: Union[List[float], bytes],
        element: str,
        reduce_dim: Optional[int] = None,
        cas: Optional[bool] = False,
        quantization: Optional[QuantizationOptions] = None,
        ef: Optional[Number] = None,
        attributes: Optional[Union[dict, str]] = None,
        numlinks: Optional[int] = None,
    ) -> Union[Awaitable[int], int]:
        """
        Add vector ``vector`` for element ``element`` to a vector set ``key``.

        ``reduce_dim`` sets the dimensions to reduce the vector to.
        If not provided, the vector is not reduced.

        ``cas`` is a boolean flag that indicates whether to use CAS
        (check-and-set style) when adding the vector.
        If not provided, CAS is not used.

        ``quantization`` sets the quantization type to use.
        If not provided, int8 quantization is used.
        The options are:
        - NOQUANT: No quantization
        - BIN: Binary quantization
        - Q8: Signed 8-bit quantization

        ``ef`` sets the exploration factor to use.
        If not provided, the default exploration factor is used.

        ``attributes`` is a dictionary or json string that contains the
        attributes to set for the vector.
        If not provided, no attributes are set.

        ``numlinks`` sets the number of links to create for the vector.
        If not provided, the default number of links is used.

        For more information see https://redis.io/commands/vadd
        """
        if not vector or not element:
            raise DataError("Both vector and element must be provided")

        pieces = []
        # REDUCE must be emitted before the vector payload.
        if reduce_dim:
            pieces.extend(["REDUCE", reduce_dim])

        values_pieces = []
        if isinstance(vector, bytes):
            # Pre-packed FP32 blob.
            values_pieces.extend(["FP32", vector])
        else:
            # Plain floats: VALUES <count> <v1> <v2> ...
            values_pieces.extend(["VALUES", len(vector)])
            values_pieces.extend(vector)
        pieces.extend(values_pieces)

        pieces.append(element)

        # Optional flags follow the element name.
        if cas:
            pieces.append("CAS")
        if quantization:
            pieces.append(quantization.value)
        if ef:
            pieces.extend(["EF", ef])
        if attributes:
            if isinstance(attributes, dict):
                # transform attributes to json string
                attributes_json = json.dumps(attributes)
            else:
                attributes_json = attributes
            pieces.extend(["SETATTR", attributes_json])
        if numlinks:
            pieces.extend(["M", numlinks])

        return self.execute_command(VADD_CMD, key, *pieces)

    def vsim(
        self,
        key: KeyT,
        input: Union[List[float], bytes, str],
        with_scores: Optional[bool] = False,
        count: Optional[int] = None,
        ef: Optional[Number] = None,
        filter: Optional[str] = None,
        filter_ef: Optional[str] = None,
        truth: Optional[bool] = False,
        no_thread: Optional[bool] = False,
        epsilon: Optional[Number] = None,
    ) -> Union[
        Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]],
        Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]],
    ]:
        """
        Compare a vector or element ``input`` with the other vectors
        in a vector set ``key``.

        ``with_scores`` sets if the results should be returned with the
        similarity scores of the elements in the result.

        ``count`` sets the number of results to return.

        ``ef`` sets the exploration factor.

        ``filter`` sets filter that should be applied for the search.

        ``filter_ef`` sets the max filtering effort.

        ``truth`` when enabled forces the command to perform linear scan.

        ``no_thread`` when enabled forces the command to execute the search
        on the data structure in the main thread.

        ``epsilon`` floating point between 0 and 1, if specified will return
        only elements with distance no further than the specified one.

        For more information see https://redis.io/commands/vsim
        """
        if not input:
            raise DataError("'input' should be provided")

        pieces = []
        options = {}
        # Three input forms: packed FP32 bytes, a float list, or an
        # existing element name (ELE).
        if isinstance(input, bytes):
            pieces.extend(["FP32", input])
        elif isinstance(input, list):
            pieces.extend(["VALUES", len(input)])
            pieces.extend(input)
        else:
            pieces.extend(["ELE", input])

        if with_scores:
            pieces.append("WITHSCORES")
            # Tell the response callback to fold scores into a dict.
            options[CallbacksOptions.WITHSCORES.value] = True
        if count:
            pieces.extend(["COUNT", count])
        if epsilon:
            pieces.extend(["EPSILON", epsilon])
        if ef:
            pieces.extend(["EF", ef])
        if filter:
            pieces.extend(["FILTER", filter])
        if filter_ef:
            pieces.extend(["FILTER-EF", filter_ef])
        if truth:
            pieces.append("TRUTH")
        if no_thread:
            pieces.append("NOTHREAD")

        return self.execute_command(VSIM_CMD, key, *pieces, **options)

    def vdim(self, key: KeyT) -> Union[Awaitable[int], int]:
        """
        Get the dimension of a vector set.

        In the case of vectors that were populated using the `REDUCE`
        option, for random projection, the vector set will report the size of
        the projected (reduced) dimension.

        Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.

        For more information see https://redis.io/commands/vdim
        """
        return self.execute_command(VDIM_CMD, key)

    def vcard(self, key: KeyT) -> Union[Awaitable[int], int]:
        """
        Get the cardinality (the number of elements) of a vector set
        with key ``key``.

        Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.

        For more information see https://redis.io/commands/vcard
        """
        return self.execute_command(VCARD_CMD, key)

    def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]:
        """
        Remove an element from a vector set.

        For more information see https://redis.io/commands/vrem
        """
        return self.execute_command(VREM_CMD, key, element)

    def vemb(
        self, key: KeyT, element: str, raw: Optional[bool] = False
    ) -> Union[
        Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]],
        Optional[Union[List[EncodableT], Dict[str, EncodableT]]],
    ]:
        """
        Get the approximated vector of an element ``element`` from
        vector set ``key``.

        ``raw`` is a boolean flag that indicates whether to return the
        internal representation used by the vector.

        For more information see https://redis.io/commands/vemb
        """
        options = {}
        pieces = []
        pieces.extend([key, element])

        # The parsing callback needs to know which protocol produced
        # the reply (see parse_vemb_result).
        if get_protocol_version(self.client) in ["3", 3]:
            options[CallbacksOptions.RESP3.value] = True

        if raw:
            pieces.append("RAW")

            # Keep the raw blob undecoded; the callback decodes
            # selectively afterwards.
            options[NEVER_DECODE] = True
            # NOTE(review): direct ["decode_responses"] indexing assumes
            # the key is always present in connection_kwargs — confirm.
            if (
                hasattr(self.client, "connection_pool")
                and self.client.connection_pool.connection_kwargs["decode_responses"]
            ) or (
                hasattr(self.client, "nodes_manager")
                and self.client.nodes_manager.connection_kwargs["decode_responses"]
            ):
                # allow decoding in the postprocessing callback
                # if the user set decode_responses=True
                # in the connection pool
                options[CallbacksOptions.ALLOW_DECODING.value] = True

            options[CallbacksOptions.RAW.value] = True

        return self.execute_command(VEMB_CMD, *pieces, **options)

    def vlinks(
        self, key: KeyT, element: str, with_scores: Optional[bool] = False
    ) -> Union[
        Awaitable[
            Optional[
                List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]
            ]
        ],
        Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]],
    ]:
        """
        Returns the neighbors for each level the element ``element``
        exists in the vector set ``key``.

        The result is a list of lists, where each list contains the
        neighbors for one level.
        If the element does not exist, or if the vector set does not exist,
        None is returned.

        If the ``WITHSCORES`` option is provided, the result is a list of
        dicts, where each dict contains the neighbors for one level, with
        the scores as values.

        For more information see https://redis.io/commands/vlinks
        """
        options = {}
        pieces = []
        pieces.extend([key, element])

        if with_scores:
            pieces.append("WITHSCORES")
            # Tell the response callback to fold scores into dicts.
            options[CallbacksOptions.WITHSCORES.value] = True

        return self.execute_command(VLINKS_CMD, *pieces, **options)

    def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]:
        """
        Get information about a vector set.

        For more information see https://redis.io/commands/vinfo
        """
        return self.execute_command(VINFO_CMD, key)

    def vsetattr(
        self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None
    ) -> Union[Awaitable[int], int]:
        """
        Associate or remove JSON attributes ``attributes`` of element
        ``element`` for vector set ``key``.

        Passing None clears the attributes (an empty JSON object is sent).

        For more information see https://redis.io/commands/vsetattr
        """
        if attributes is None:
            attributes_json = "{}"
        elif isinstance(attributes, dict):
            # transform attributes to json string
            attributes_json = json.dumps(attributes)
        else:
            attributes_json = attributes

        return self.execute_command(VSETATTR_CMD, key, element, attributes_json)

    def vgetattr(
        self, key: KeyT, element: str
    ) -> Union[Optional[Awaitable[dict]], Optional[dict]]:
        """
        Retrieve the JSON attributes of an element ``element`` for
        vector set ``key``.

        If the element does not exist, or if the vector set does not exist,
        None is returned.

        For more information see https://redis.io/commands/vgetattr
        """
        return self.execute_command(VGETATTR_CMD, key, element)

    def vrandmember(
        self, key: KeyT, count: Optional[int] = None
    ) -> Union[
        Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]]
    ]:
        """
        Returns random elements from a vector set ``key``.

        ``count`` is the number of elements to return.
        If ``count`` is not provided, a single element is returned
        as a single string.
        If ``count`` is positive (smaller than the number of elements
        in the vector set), the command returns a list with up to ``count``
        distinct elements from the vector set.
        If ``count`` is negative, the command returns a list with ``count``
        random elements, potentially with duplicates.
        If ``count`` is greater than the number of elements in the vector
        set, only the entire set is returned as a list.

        If the vector set does not exist, ``None`` is returned.

        For more information see https://redis.io/commands/vrandmember
        """
        pieces = []
        pieces.append(key)
        if count is not None:
            pieces.append(count)
        return self.execute_command(VRANDMEMBER_CMD, *pieces)

View File

@@ -0,0 +1,94 @@
from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.commands import CallbacksOptions
def parse_vemb_result(response, **options):
    """
    Handle a VEMB reply, whose structure depends on the input options
    and on the quantization type of the vector set.

    Parses the VEMB result into:
        - List[Union[bytes, Union[int, float]]]
        - Dict[str, Union[bytes, str, float]]

    Idiom fix: iterate the reply directly instead of indexing via
    ``range(len(...))``.
    """
    if response is None:
        return response

    if options.get(CallbacksOptions.RAW.value):
        # RAW mode: [quantization, raw-blob, l2-norm, (range)?]
        result = {}
        result["quantization"] = (
            response[0].decode("utf-8")
            if options.get(CallbacksOptions.ALLOW_DECODING.value)
            else response[0]
        )
        result["raw"] = response[1]
        result["l2"] = float(response[2])
        # The range element is returned only for BIN quantization.
        if len(response) > 3:
            result["range"] = float(response[3])
        return result
    else:
        if options.get(CallbacksOptions.RESP3.value):
            # RESP3 already delivers typed values.
            return response

        # RESP2: coerce each string to int when possible, else float.
        result = []
        for value in response:
            try:
                result.append(int(value))
            except ValueError:
                # not an integer — must be a float
                result.append(float(value))
        return result
def parse_vlinks_result(response, **options):
    """
    Handle a VLINKS reply, whose structure depends on the input options.

    Parses the VLINKS result into:
        - List[List[str]]
        - List[Dict[str, Number]]

    Idiom fix: build each level's score mapping with a dict
    comprehension instead of a manual loop.
    """
    if response is None:
        return response

    if options.get(CallbacksOptions.WITHSCORES.value):
        # Redis returns one flat [elem, score, ...] list per level;
        # transform each into a {elem: float(score)} dict.
        return [
            {key: float(value) for key, value in pairs_to_dict(level_item).items()}
            for level_item in response
        ]
    else:
        # return the list of elements for each level
        # list of lists
        return response
def parse_vsim_result(response, **options):
    """
    Handle a VSIM reply, whose structure depends on the input options.

    Parses the VSIM result into:
        - List[str]
        - Dict[str, Number]

    Idiom fix: build the score mapping with a dict comprehension; the
    trailing comment also wrongly said "for each level" (copied from
    the VLINKS parser) — VSIM returns a single flat element list.
    """
    if response is None:
        return response

    if options.get(CallbacksOptions.WITHSCORES.value):
        # Redis returns a flat [elem, score, ...] list;
        # transform it into a {elem: float(score)} dict.
        return {key: float(value) for key, value in pairs_to_dict(response).items()}
    else:
        # return the plain list of similar elements
        return response