Updates
@@ -0,0 +1,189 @@
import redis

from ...asyncio.client import Pipeline as AsyncioPipeline
from .commands import (
    AGGREGATE_CMD,
    CONFIG_CMD,
    INFO_CMD,
    PROFILE_CMD,
    SEARCH_CMD,
    SPELLCHECK_CMD,
    SYNDUMP_CMD,
    AsyncSearchCommands,
    SearchCommands,
)


class Search(SearchCommands):
    """
    Create a client for talking to search.
    It abstracts the API of the module and lets you just use the engine.
    """

    class BatchIndexer:
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        def __init__(self, client, chunk_size=1000):
            self.client = client
            self.execute_command = client.execute_command
            self._pipeline = client.pipeline(transaction=False, shard_hint=None)
            self.total = 0
            self.chunk_size = chunk_size
            self.current_chunk = 0

        def __del__(self):
            if self.current_chunk:
                self.commit()

        def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def add_document_hash(self, doc_id, score=1.0, replace=False):
            """
            Add a hash to the batch query
            """
            self.client._add_document_hash(
                doc_id, conn=self._pipeline, score=score, replace=replace
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            self._pipeline.execute()
            self.current_chunk = 0

    def __init__(self, client, index_name="idx"):
        """
        Create a new Client for the given index_name.
        The default name is `idx`

        If conn is not None, we employ an already existing redis connection
        """
        self._MODULE_CALLBACKS = {}
        self.client = client
        self.index_name = index_name
        self.execute_command = client.execute_command
        self._pipeline = client.pipeline
        self._RESP2_MODULE_CALLBACKS = {
            INFO_CMD: self._parse_info,
            SEARCH_CMD: self._parse_search,
            AGGREGATE_CMD: self._parse_aggregate,
            PROFILE_CMD: self._parse_profile,
            SPELLCHECK_CMD: self._parse_spellcheck,
            CONFIG_CMD: self._parse_config_get,
            SYNDUMP_CMD: self._parse_syndump,
        }

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = Pipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        p.index_name = self.index_name
        return p


class AsyncSearch(Search, AsyncSearchCommands):
    class BatchIndexer(Search.BatchIndexer):
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        async def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                await self.commit()

        async def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            await self._pipeline.execute()
            self.current_chunk = 0

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = AsyncPipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        p.index_name = self.index_name
        return p


class Pipeline(SearchCommands, redis.client.Pipeline):
    """Pipeline for the module."""


class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline):
    """AsyncPipeline for the module."""
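A minimal usage sketch of the classes above. Assumptions not in the diff itself: a RediSearch-enabled server on localhost, an already-created index named `idx`, and `Redis.ft()` as the redis-py accessor that constructs this `Search` client. Note that `add_document` drives the legacy FT.ADD path, so the batch indexer only applies to servers that still support it.

```
import redis
from redis.commands.search import Search

r = redis.Redis()
search = r.ft("idx")  # Search client bound to index "idx"

# A pipeline created here accepts both SEARCH and classic core commands.
pipe = search.pipeline(transaction=False)
pipe.set("greeting", "hello")  # core command
pipe.search("hello")           # FT.SEARCH against "idx"
set_ok, search_res = pipe.execute()

# BatchIndexer flushes its internal pipeline every `chunk_size` documents.
indexer = Search.BatchIndexer(search, chunk_size=500)
for i in range(1000):
    indexer.add_document(f"doc{i}", title=f"Document {i}")
indexer.commit()  # flush whatever is left in the current chunk
```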
Binary file not shown.
@@ -0,0 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
    if isinstance(s, str):
        return s
    elif isinstance(s, bytes):
        return s.decode(encoding, "ignore")
    else:
        return s  # Not a string we care about
@@ -0,0 +1,401 @@
from typing import List, Union

from redis.commands.search.dialect import DEFAULT_DIALECT

FIELDNAME = object()


class Limit:
    def __init__(self, offset: int = 0, count: int = 0) -> None:
        self.offset = offset
        self.count = count

    def build_args(self):
        if self.count:
            return ["LIMIT", str(self.offset), str(self.count)]
        else:
            return []


class Reducer:
    """
    Base reducer object for all reducers.

    See the `redisearch.reducers` module for the actual reducers.
    """

    NAME = None

    def __init__(self, *args: str) -> None:
        self._args = args
        self._field = None
        self._alias = None

    def alias(self, alias: str) -> "Reducer":
        """
        Set the alias for this reducer.

        ### Parameters

        - **alias**: The value of the alias for this reducer. If this is the
            special value `aggregation.FIELDNAME` then this reducer will be
            aliased using the same name as the field upon which it operates.
            Note that using `FIELDNAME` is only possible on reducers which
            operate on a single field value.

        This method returns the `Reducer` object making it suitable for
        chaining.
        """
        if alias is FIELDNAME:
            if not self._field:
                raise ValueError("Cannot use FIELDNAME alias with no field")
            # Chop off initial '@'
            alias = self._field[1:]
        self._alias = alias
        return self

    @property
    def args(self) -> List[str]:
        return self._args


class SortDirection:
    """
    This special class is used to indicate sort direction.
    """

    DIRSTRING = None

    def __init__(self, field: str) -> None:
        self.field = field


class Asc(SortDirection):
    """
    Indicate that the given field should be sorted in ascending order
    """

    DIRSTRING = "ASC"


class Desc(SortDirection):
    """
    Indicate that the given field should be sorted in descending order
    """

    DIRSTRING = "DESC"


class AggregateRequest:
    """
    Aggregation request which can be passed to `Client.aggregate`.
    """

    def __init__(self, query: str = "*") -> None:
        """
        Create an aggregation request. This request may then be passed to
        `client.aggregate()`.

        In order for the request to be usable, it must contain at least one
        group.

        - **query** Query string for filtering records.

        All member methods (except `build_args()`)
        return the object itself, making them useful for chaining.
        """
        self._query = query
        self._aggregateplan = []
        self._loadfields = []
        self._loadall = False
        self._max = 0
        self._with_schema = False
        self._verbatim = False
        self._cursor = []
        self._dialect = DEFAULT_DIALECT
        self._add_scores = False
        self._scorer = "TFIDF"

    def load(self, *fields: str) -> "AggregateRequest":
        """
        Indicate the fields to be returned in the response. These fields are
        returned in addition to any others implicitly specified.

        ### Parameters

        - **fields**: If fields are not specified, all the fields will be
            loaded. Otherwise, fields should be given in the format of
            `@field`.
        """
        if fields:
            self._loadfields.extend(fields)
        else:
            self._loadall = True
        return self

    def group_by(
        self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
    ) -> "AggregateRequest":
        """
        Specify by which fields to group the aggregation.

        ### Parameters

        - **fields**: Fields to group by. This can either be a single string,
            or a list of strings. In both cases, the field should be specified
            as `@field`.
        - **reducers**: One or more reducers. Reducers may be found in the
            `aggregation` module.
        """
        fields = [fields] if isinstance(fields, str) else fields
        reducers = [reducers] if isinstance(reducers, Reducer) else reducers

        ret = ["GROUPBY", str(len(fields)), *fields]
        for reducer in reducers:
            ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
            ret.extend(reducer.args)
            if reducer._alias is not None:
                ret += ["AS", reducer._alias]

        self._aggregateplan.extend(ret)
        return self

    def apply(self, **kwexpr) -> "AggregateRequest":
        """
        Specify one or more projection expressions to add to each result

        ### Parameters

        - **kwexpr**: One or more key-value pairs for a projection. The key is
            the alias for the projection, and the value is the projection
            expression itself, for example `apply(square_root="sqrt(@foo)")`
        """
        for alias, expr in kwexpr.items():
            ret = ["APPLY", expr]
            if alias is not None:
                ret += ["AS", alias]
            self._aggregateplan.extend(ret)

        return self

    def limit(self, offset: int, num: int) -> "AggregateRequest":
        """
        Sets the limit for the most recent group or query.

        If no group has been defined yet (via `group_by()`) then this sets
        the limit for the initial pool of results from the query. Otherwise,
        this limits the number of items operated on from the previous group.

        Setting a limit on the initial search results may be useful when
        attempting to execute an aggregation on a sample of a large data set.

        ### Parameters

        - **offset**: Result offset from which to begin paging
        - **num**: Number of results to return


        Example of sorting the initial results:

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 10)\
            .group_by("@state", r.count())
        ```

        Will only group by the states found in the first 10 results of the
        query `@sale_amount:[10000, inf]`. On the other hand,

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 1000)\
            .group_by("@state", r.count())\
            .limit(0, 10)
        ```

        Will group all the results matching the query, but only return the
        first 10 groups.

        If you only wish to return a *top-N* style query, consider using
        `sort_by()` instead.
        """
        _limit = Limit(offset, num)
        self._aggregateplan.extend(_limit.build_args())
        return self

    def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
        """
        Indicate how the results should be sorted. This can also be used for
        *top-N* style queries

        ### Parameters

        - **fields**: The fields by which to sort. This can be either a single
            field or a list of fields. If you wish to specify order, you can
            use the `Asc` or `Desc` wrapper classes.
        - **max**: Maximum number of results to return. This can be
            used instead of `LIMIT` and is also faster.


        Example of sorting by `foo` ascending and `bar` descending:

        ```
        sort_by(Asc("@foo"), Desc("@bar"))
        ```

        Return the top 10 customers:

        ```
        AggregateRequest()\
            .group_by("@customer", r.sum("@paid").alias(FIELDNAME))\
            .sort_by(Desc("@paid"), max=10)
        ```
        """
        if isinstance(fields, (str, SortDirection)):
            fields = [fields]

        fields_args = []
        for f in fields:
            if isinstance(f, SortDirection):
                fields_args += [f.field, f.DIRSTRING]
            else:
                fields_args += [f]

        ret = ["SORTBY", str(len(fields_args))]
        ret.extend(fields_args)
        max = kwargs.get("max", 0)
        if max > 0:
            ret += ["MAX", str(max)]

        self._aggregateplan.extend(ret)
        return self

    def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
        """
        Specify filter for post-query results using predicates relating to
        values in the result set.

        ### Parameters

        - **expressions**: A predicate expression, or a list of predicate
            expressions, to filter the results by.
        """
        if isinstance(expressions, str):
            expressions = [expressions]

        for expression in expressions:
            self._aggregateplan.extend(["FILTER", expression])

        return self

    def with_schema(self) -> "AggregateRequest":
        """
        If set, the `schema` property will contain a list of `[field, type]`
        entries in the result object.
        """
        self._with_schema = True
        return self

    def add_scores(self) -> "AggregateRequest":
        """
        If set, includes the score as an ordinary field of the row.
        """
        self._add_scores = True
        return self

    def scorer(self, scorer: str) -> "AggregateRequest":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        :param scorer: The scoring function to use
                       (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def verbatim(self) -> "AggregateRequest":
        self._verbatim = True
        return self

    def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
        args = ["WITHCURSOR"]
        if count:
            args += ["COUNT", str(count)]
        if max_idle:
            args += ["MAXIDLE", str(max_idle * 1000)]
        self._cursor = args
        return self

    def build_args(self) -> List[str]:
        # @foo:bar ...
        ret = [self._query]

        if self._with_schema:
            ret.append("WITHSCHEMA")

        if self._verbatim:
            ret.append("VERBATIM")

        if self._scorer:
            ret.extend(["SCORER", self._scorer])

        if self._add_scores:
            ret.append("ADDSCORES")

        if self._cursor:
            ret += self._cursor

        if self._loadall:
            ret.append("LOAD")
            ret.append("*")
        elif self._loadfields:
            ret.append("LOAD")
            ret.append(str(len(self._loadfields)))
            ret.extend(self._loadfields)

        if self._dialect:
            ret.extend(["DIALECT", self._dialect])

        ret.extend(self._aggregateplan)

        return ret

    def dialect(self, dialect: int) -> "AggregateRequest":
        """
        Add a dialect field to the aggregate command.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self


class Cursor:
    def __init__(self, cid: int) -> None:
        self.cid = cid
        self.max_idle = 0
        self.count = 0

    def build_args(self):
        args = [str(self.cid)]
        if self.max_idle:
            args += ["MAXIDLE", str(self.max_idle)]
        if self.count:
            args += ["COUNT", str(self.count)]
        return args


class AggregateResult:
    def __init__(self, rows, cursor: Cursor, schema) -> None:
        self.rows = rows
        self.cursor = cursor
        self.schema = schema

    def __repr__(self) -> str:
        cid = self.cursor.cid if self.cursor else -1
        return (
            f"<{self.__class__.__name__} at 0x{id(self):x} "
            f"Rows={len(self.rows)}, Cursor={cid}>"
        )
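To make the request plumbing above concrete, here is a minimal sketch of building and running an aggregation. The index name `idx` and the `sale_amount`/`state` fields are illustrative assumptions; `r.ft()` is the standard redis-py accessor for the search client.

```
import redis
from redis.commands.search import reducers
from redis.commands.search.aggregation import AggregateRequest, Desc

r = redis.Redis()

# Count matching documents per state and keep the ten largest groups.
req = (
    AggregateRequest("@sale_amount:[10000 +inf]")
    .group_by("@state", reducers.count().alias("count"))
    .sort_by(Desc("@count"), max=10)
)
res = r.ft("idx").aggregate(req)
for row in res.rows:
    print(row)
```

`build_args()` is what ultimately serializes this chain into the FT.AGGREGATE argument list.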
File diff suppressed because it is too large
@@ -0,0 +1,3 @@
# Value for the default dialect to be used as a part of
# Search or Aggregate query.
DEFAULT_DIALECT = 2
@@ -0,0 +1,17 @@
class Document:
    """
    Represents a single document in a result set
    """

    def __init__(self, id, payload=None, **fields):
        self.id = id
        self.payload = payload
        for k, v in fields.items():
            setattr(self, k, v)

    def __repr__(self):
        return f"Document {self.__dict__}"

    def __getitem__(self, item):
        value = getattr(self, item)
        return value
@@ -0,0 +1,210 @@
from typing import List

from redis import DataError


class Field:
    """
    A class representing a field in a document.
    """

    NUMERIC = "NUMERIC"
    TEXT = "TEXT"
    WEIGHT = "WEIGHT"
    GEO = "GEO"
    TAG = "TAG"
    VECTOR = "VECTOR"
    SORTABLE = "SORTABLE"
    NOINDEX = "NOINDEX"
    AS = "AS"
    GEOSHAPE = "GEOSHAPE"
    INDEX_MISSING = "INDEXMISSING"
    INDEX_EMPTY = "INDEXEMPTY"

    def __init__(
        self,
        name: str,
        args: List[str] = None,
        sortable: bool = False,
        no_index: bool = False,
        index_missing: bool = False,
        index_empty: bool = False,
        as_name: str = None,
    ):
        """
        Create a new field object.

        Args:
            name: The name of the field.
            args: Raw schema arguments for the field type.
            sortable: If `True`, the field will be sortable.
            no_index: If `True`, the field will not be indexed.
            index_missing: If `True`, it will be possible to search for documents that
                           have this field missing.
            index_empty: If `True`, it will be possible to search for documents that
                         have this field empty.
            as_name: If provided, this alias will be used for the field.
        """
        if args is None:
            args = []
        self.name = name
        self.args = args
        self.args_suffix = list()
        self.as_name = as_name

        if sortable:
            self.args_suffix.append(Field.SORTABLE)
        if no_index:
            self.args_suffix.append(Field.NOINDEX)
        if index_missing:
            self.args_suffix.append(Field.INDEX_MISSING)
        if index_empty:
            self.args_suffix.append(Field.INDEX_EMPTY)

        if no_index and not sortable:
            raise ValueError("Non-Sortable non-Indexable fields are ignored")

    def append_arg(self, value):
        self.args.append(value)

    def redis_args(self):
        args = [self.name]
        if self.as_name:
            args += [self.AS, self.as_name]
        args += self.args
        args += self.args_suffix
        return args


class TextField(Field):
    """
    TextField is used to define a text field in a schema definition
    """

    NOSTEM = "NOSTEM"
    PHONETIC = "PHONETIC"

    def __init__(
        self,
        name: str,
        weight: float = 1.0,
        no_stem: bool = False,
        phonetic_matcher: str = None,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs)

        if no_stem:
            Field.append_arg(self, self.NOSTEM)
        if phonetic_matcher and phonetic_matcher in [
            "dm:en",
            "dm:fr",
            "dm:pt",
            "dm:es",
        ]:
            Field.append_arg(self, self.PHONETIC)
            Field.append_arg(self, phonetic_matcher)
        if withsuffixtrie:
            Field.append_arg(self, "WITHSUFFIXTRIE")


class NumericField(Field):
    """
    NumericField is used to define a numeric field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        Field.__init__(self, name, args=[Field.NUMERIC], **kwargs)


class GeoShapeField(Field):
    """
    GeoShapeField is used to enable within/contain indexing/searching
    """

    SPHERICAL = "SPHERICAL"
    FLAT = "FLAT"

    def __init__(self, name: str, coord_system=None, **kwargs):
        args = [Field.GEOSHAPE]
        if coord_system:
            args.append(coord_system)
        Field.__init__(self, name, args=args, **kwargs)


class GeoField(Field):
    """
    GeoField is used to define a geo-indexing field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        Field.__init__(self, name, args=[Field.GEO], **kwargs)


class TagField(Field):
    """
    TagField is a tag-indexing field with simpler compression and tokenization.
    See http://redisearch.io/Tags/
    """

    SEPARATOR = "SEPARATOR"
    CASESENSITIVE = "CASESENSITIVE"

    def __init__(
        self,
        name: str,
        separator: str = ",",
        case_sensitive: bool = False,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        args = [Field.TAG, self.SEPARATOR, separator]
        if case_sensitive:
            args.append(self.CASESENSITIVE)
        if withsuffixtrie:
            args.append("WITHSUFFIXTRIE")

        Field.__init__(self, name, args=args, **kwargs)


class VectorField(Field):
    """
    Allows vector similarity queries against the value in this attribute.
    See https://oss.redis.com/redisearch/Vectors/#vector_fields.
    """

    def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs):
        """
        Create a vector field. Note that a vector field cannot have the
        sortable or no_index tag, even though it is also a Field.

        ``name`` is the name of the field.

        ``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA".

        ``attributes`` each algorithm can have specific attributes. Some of them
        are mandatory and some of them are optional. See
        https://oss.redis.com/redisearch/master/Vectors/#specific_creation_attributes_per_algorithm
        for more information.
        """
        sort = kwargs.get("sortable", False)
        noindex = kwargs.get("no_index", False)

        if sort or noindex:
            raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")

        if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]:
            raise DataError(
                "Realtime vector indexing supports 3 indexing methods: "
                "'FLAT', 'HNSW', and 'SVS-VAMANA'."
            )

        attr_li = []

        for key, value in attributes.items():
            attr_li.extend([key, value])

        Field.__init__(
            self, name, args=[Field.VECTOR, algorithm, len(attr_li), *attr_li], **kwargs
        )
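A short sketch of composing a schema from these field types. The field names are illustrative; `TYPE`, `DIM`, and `DISTANCE_METRIC` are standard RediSearch vector creation attributes.

```
from redis.commands.search.field import (
    NumericField,
    TagField,
    TextField,
    VectorField,
)

schema = (
    TextField("title", weight=5.0, sortable=True),
    NumericField("price"),
    TagField("genre", separator=","),
    VectorField(
        "embedding",
        "HNSW",
        {"TYPE": "FLOAT32", "DIM": 128, "DISTANCE_METRIC": "COSINE"},
    ),
)

# Each field serializes itself into FT.CREATE arguments:
print(TextField("title", weight=5.0, sortable=True).redis_args())
# ['title', 'TEXT', 'WEIGHT', 5.0, 'SORTABLE']
```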
@@ -0,0 +1,79 @@
from enum import Enum


class IndexType(Enum):
    """Enum of the currently supported index types."""

    HASH = 1
    JSON = 2


class IndexDefinition:
    """IndexDefinition is used to define an index definition for automatic
    indexing on Hash or JSON document updates."""

    def __init__(
        self,
        prefix=[],
        filter=None,
        language_field=None,
        language=None,
        score_field=None,
        score=1.0,
        payload_field=None,
        index_type=None,
    ):
        self.args = []
        self._append_index_type(index_type)
        self._append_prefix(prefix)
        self._append_filter(filter)
        self._append_language(language_field, language)
        self._append_score(score_field, score)
        self._append_payload(payload_field)

    def _append_index_type(self, index_type):
        """Append `ON HASH` or `ON JSON` according to the enum."""
        if index_type is IndexType.HASH:
            self.args.extend(["ON", "HASH"])
        elif index_type is IndexType.JSON:
            self.args.extend(["ON", "JSON"])
        elif index_type is not None:
            raise RuntimeError(f"index_type must be one of {list(IndexType)}")

    def _append_prefix(self, prefix):
        """Append PREFIX."""
        if len(prefix) > 0:
            self.args.append("PREFIX")
            self.args.append(len(prefix))
            for p in prefix:
                self.args.append(p)

    def _append_filter(self, filter):
        """Append FILTER."""
        if filter is not None:
            self.args.append("FILTER")
            self.args.append(filter)

    def _append_language(self, language_field, language):
        """Append LANGUAGE_FIELD and LANGUAGE."""
        if language_field is not None:
            self.args.append("LANGUAGE_FIELD")
            self.args.append(language_field)
        if language is not None:
            self.args.append("LANGUAGE")
            self.args.append(language)

    def _append_score(self, score_field, score):
        """Append SCORE_FIELD and SCORE."""
        if score_field is not None:
            self.args.append("SCORE_FIELD")
            self.args.append(score_field)
        if score is not None:
            self.args.append("SCORE")
            self.args.append(score)

    def _append_payload(self, payload_field):
        """Append PAYLOAD_FIELD."""
        if payload_field is not None:
            self.args.append("PAYLOAD_FIELD")
            self.args.append(payload_field)
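A sketch of wiring an `IndexDefinition` into index creation. The import path `redis.commands.search.index_definition` is an assumption matching recent redis-py releases (older releases shipped the module as `indexDefinition`), and a RediSearch-enabled server is assumed.

```
import redis
from redis.commands.search.field import TextField
from redis.commands.search.index_definition import IndexDefinition, IndexType

r = redis.Redis()
definition = IndexDefinition(prefix=["doc:"], index_type=IndexType.HASH, score=0.5)
r.ft("idx").create_index((TextField("title"),), definition=definition)

# Any hash written under the prefix is now indexed automatically:
r.hset("doc:1", mapping={"title": "hello world"})
```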
@@ -0,0 +1,14 @@
from typing import Any


class ProfileInformation:
    """
    Wrapper around FT.PROFILE response
    """

    def __init__(self, info: Any) -> None:
        self._info: Any = info

    @property
    def info(self) -> Any:
        return self._info
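A hedged sketch of where this wrapper surfaces. The `profile()` command itself lives in the suppressed `commands.py` diff above; in recent redis-py versions it returns the parsed result together with a `ProfileInformation`, though the exact return shape has varied across releases.

```
import redis
from redis.commands.search.query import Query

r = redis.Redis()
result, profile = r.ft("idx").profile(Query("hello"))
print(profile.info)  # raw FT.PROFILE details as returned by the server
```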
@@ -0,0 +1,381 @@
from typing import List, Optional, Union

from redis.commands.search.dialect import DEFAULT_DIALECT


class Query:
    """
    Query is used to build complex queries that have more parameters than just
    the query string. The query string is set in the constructor, and other
    options have setter functions.

    The setter functions return the query object, so they can be chained,
    i.e. `Query("foo").verbatim().filter(...)` etc.
    """

    def __init__(self, query_string: str) -> None:
        """
        Create a new query object.
        The query string is set in the constructor, and other options have
        setter functions.
        """

        self._query_string: str = query_string
        self._offset: int = 0
        self._num: int = 10
        self._no_content: bool = False
        self._no_stopwords: bool = False
        self._fields: Optional[List[str]] = None
        self._verbatim: bool = False
        self._with_payloads: bool = False
        self._with_scores: bool = False
        self._scorer: Optional[str] = None
        self._filters: List = list()
        self._ids: Optional[List[str]] = None
        self._slop: int = -1
        self._timeout: Optional[float] = None
        self._in_order: bool = False
        self._sortby: Optional[SortbyField] = None
        self._return_fields: List = []
        self._return_fields_decode_as: dict = {}
        self._summarize_fields: List = []
        self._highlight_fields: List = []
        self._language: Optional[str] = None
        self._expander: Optional[str] = None
        self._dialect: int = DEFAULT_DIALECT

    def query_string(self) -> str:
        """Return the query string of this query only."""
        return self._query_string

    def limit_ids(self, *ids) -> "Query":
        """Limit the results to a specific set of pre-known document
        ids of any length."""
        self._ids = ids
        return self

    def return_fields(self, *fields) -> "Query":
        """Add fields to return fields."""
        for field in fields:
            self.return_field(field)
        return self

    def return_field(
        self,
        field: str,
        as_field: Optional[str] = None,
        decode_field: Optional[bool] = True,
        encoding: Optional[str] = "utf8",
    ) -> "Query":
        """
        Add a field to the list of fields to return.

        - **field**: The field to include in query results
        - **as_field**: The alias for the field
        - **decode_field**: Whether to decode the field from bytes to string
        - **encoding**: The encoding to use when decoding the field
        """
        self._return_fields.append(field)
        self._return_fields_decode_as[field] = encoding if decode_field else None
        if as_field is not None:
            self._return_fields += ("AS", as_field)
        return self

    def _mk_field_list(self, fields: List[str]) -> List:
        if not fields:
            return []
        return [fields] if isinstance(fields, str) else list(fields)

    def summarize(
        self,
        fields: Optional[List] = None,
        context_len: Optional[int] = None,
        num_frags: Optional[int] = None,
        sep: Optional[str] = None,
    ) -> "Query":
        """
        Return an abridged format of the field, containing only the segments of
        the field which contain the matching term(s).

        If `fields` is specified, then only the mentioned fields are
        summarized; otherwise all results are summarized.

        Server side defaults are used for each option (except `fields`)
        if not specified

        - **fields** List of fields to summarize. All fields are summarized
            if not specified
        - **context_len** Amount of context to include with each fragment
        - **num_frags** Number of fragments per document
        - **sep** Separator string to separate fragments
        """
        args = ["SUMMARIZE"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields

        if context_len is not None:
            args += ["LEN", str(context_len)]
        if num_frags is not None:
            args += ["FRAGS", str(num_frags)]
        if sep is not None:
            args += ["SEPARATOR", sep]

        self._summarize_fields = args
        return self

    def highlight(
        self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
    ) -> "Query":
        """
        Apply specified markup to matched term(s) within the returned field(s).

        - **fields** If specified then only those mentioned fields are
            highlighted, otherwise all fields are highlighted
        - **tags** A list of two strings to surround the match.
        """
        args = ["HIGHLIGHT"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields
        if tags:
            args += ["TAGS"] + list(tags)

        self._highlight_fields = args
        return self

    def language(self, language: str) -> "Query":
        """
        Analyze the query as being in the specified language.

        :param language: The language (e.g. `chinese` or `english`)
        """
        self._language = language
        return self

    def slop(self, slop: int) -> "Query":
        """Allow a maximum of N intervening non-matched terms between
        phrase terms (0 means exact phrase).
        """
        self._slop = slop
        return self

    def timeout(self, timeout: float) -> "Query":
        """Overrides the timeout parameter of the module."""
        self._timeout = timeout
        return self

    def in_order(self) -> "Query":
        """
        Match only documents where the query terms appear in
        the same order in the document.
        i.e. for the query "hello world", we do not match "world hello"
        """
        self._in_order = True
        return self

    def scorer(self, scorer: str) -> "Query":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        Since Redis 8.0 the default was changed to BM25STD.

        :param scorer: The scoring function to use
                       (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def get_args(self) -> List[str]:
        """Format the redis arguments for this query and return them."""
        args = [self._query_string]
        args += self._get_args_tags()
        args += self._summarize_fields + self._highlight_fields
        args += ["LIMIT", self._offset, self._num]
        return args

    def _get_args_tags(self) -> List[str]:
        args = []
        if self._no_content:
            args.append("NOCONTENT")
        if self._fields:
            args.append("INFIELDS")
            args.append(len(self._fields))
            args += self._fields
        if self._verbatim:
            args.append("VERBATIM")
        if self._no_stopwords:
            args.append("NOSTOPWORDS")
        if self._filters:
            for flt in self._filters:
                if not isinstance(flt, Filter):
                    raise AttributeError("Did not receive a Filter object.")
                args += flt.args
        if self._with_payloads:
            args.append("WITHPAYLOADS")
        if self._scorer:
            args += ["SCORER", self._scorer]
        if self._with_scores:
            args.append("WITHSCORES")
        if self._ids:
            args.append("INKEYS")
            args.append(len(self._ids))
            args += self._ids
        if self._slop >= 0:
            args += ["SLOP", self._slop]
        if self._timeout is not None:
            args += ["TIMEOUT", self._timeout]
        if self._in_order:
            args.append("INORDER")
        if self._return_fields:
            args.append("RETURN")
            args.append(len(self._return_fields))
            args += self._return_fields
        if self._sortby:
            if not isinstance(self._sortby, SortbyField):
                raise AttributeError("Did not receive a SortByField.")
            args.append("SORTBY")
            args += self._sortby.args
        if self._language:
            args += ["LANGUAGE", self._language]
        if self._expander:
            args += ["EXPANDER", self._expander]
        if self._dialect:
            args += ["DIALECT", self._dialect]

        return args

    def paging(self, offset: int, num: int) -> "Query":
        """
        Set the paging for the query (defaults to 0..10).

        - **offset**: Paging offset for the results. Defaults to 0
        - **num**: How many results do we want
        """
        self._offset = offset
        self._num = num
        return self

    def verbatim(self) -> "Query":
        """Set the query to be verbatim, i.e. use no query expansion
        or stemming.
        """
        self._verbatim = True
        return self

    def no_content(self) -> "Query":
        """Set the query to only return ids and not the document content."""
        self._no_content = True
        return self

    def no_stopwords(self) -> "Query":
        """
        Prevent the query from being filtered for stopwords.
        Only useful in very big queries that you are certain contain
        no stopwords.
        """
        self._no_stopwords = True
        return self

    def with_payloads(self) -> "Query":
        """Ask the engine to return document payloads."""
        self._with_payloads = True
        return self

    def with_scores(self) -> "Query":
        """Ask the engine to return document search scores."""
        self._with_scores = True
        return self

    def limit_fields(self, *fields: List[str]) -> "Query":
        """
        Limit the search to specific TEXT fields only.

        - **fields**: A list of strings, case sensitive field names
            from the defined schema.
        """
        self._fields = fields
        return self

    def add_filter(self, flt: "Filter") -> "Query":
        """
        Add a numeric or geo filter to the query.
        **Currently only one of each filter is supported by the engine**

        - **flt**: A NumericFilter or GeoFilter object, used on a
            corresponding field
        """

        self._filters.append(flt)
        return self

    def sort_by(self, field: str, asc: bool = True) -> "Query":
        """
        Add a sortby field to the query.

        - **field** - the name of the field to sort by
        - **asc** - when `True`, sorting will be done in ascending order
        """
        self._sortby = SortbyField(field, asc)
        return self

    def expander(self, expander: str) -> "Query":
        """
        Add an expander field to the query.

        - **expander** - the name of the expander
        """
        self._expander = expander
        return self

    def dialect(self, dialect: int) -> "Query":
        """
        Add a dialect field to the query.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self


class Filter:
    def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
        self.args = [keyword, field] + list(args)


class NumericFilter(Filter):
    INF = "+inf"
    NEG_INF = "-inf"

    def __init__(
        self,
        field: str,
        minval: Union[int, str],
        maxval: Union[int, str],
        minExclusive: bool = False,
        maxExclusive: bool = False,
    ) -> None:
        args = [
            minval if not minExclusive else f"({minval}",
            maxval if not maxExclusive else f"({maxval}",
        ]

        Filter.__init__(self, "FILTER", field, *args)


class GeoFilter(Filter):
    METERS = "m"
    KILOMETERS = "km"
    FEET = "ft"
    MILES = "mi"

    def __init__(
        self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS
    ) -> None:
        Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit)


class SortbyField:
    def __init__(self, field: str, asc=True) -> None:
        self.args = [field, "ASC" if asc else "DESC"]
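A minimal sketch chaining the setters above. The index name `idx` and the `title`/`price` fields are illustrative assumptions.

```
import redis
from redis.commands.search.query import NumericFilter, Query

r = redis.Redis()

q = (
    Query("search engine")
    .verbatim()
    .paging(0, 20)
    .sort_by("price", asc=False)
    .add_filter(NumericFilter("price", 10, 100))
    .return_fields("title", "price")
)
res = r.ft("idx").search(q)
print(res.total, [doc.title for doc in res.docs])
```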
@@ -0,0 +1,317 @@
def tags(*t):
    """
    Indicate that the values should be matched to a tag field

    ### Parameters

    - **t**: Tags to search for
    """
    if not t:
        raise ValueError("At least one tag must be specified")
    return TagValue(*t)


def between(a, b, inclusive_min=True, inclusive_max=True):
    """
    Indicate that value is a numeric range
    """
    return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max)


def equal(n):
    """
    Match a numeric value
    """
    return between(n, n)


def lt(n):
    """
    Match any value less than n
    """
    return between(None, n, inclusive_max=False)


def le(n):
    """
    Match any value less than or equal to n
    """
    return between(None, n, inclusive_max=True)


def gt(n):
    """
    Match any value greater than n
    """
    return between(n, None, inclusive_min=False)


def ge(n):
    """
    Match any value greater than or equal to n
    """
    return between(n, None, inclusive_min=True)


def geo(lat, lon, radius, unit="km"):
    """
    Indicate that value is a geo region
    """
    return GeoValue(lat, lon, radius, unit)


class Value:
    @property
    def combinable(self):
        """
        Whether this type of value may be combined with other values
        for the same field. This makes the filter potentially more efficient
        """
        return False

    @staticmethod
    def make_value(v):
        """
        Convert an object to a value, if it is not a value already
        """
        if isinstance(v, Value):
            return v
        return ScalarValue(v)

    def to_string(self):
        raise NotImplementedError()

    def __str__(self):
        return self.to_string()


class RangeValue(Value):
    combinable = False

    def __init__(self, a, b, inclusive_min=False, inclusive_max=False):
        if a is None:
            a = "-inf"
        if b is None:
            b = "inf"
        self.range = [str(a), str(b)]
        self.inclusive_min = inclusive_min
        self.inclusive_max = inclusive_max

    def to_string(self):
        return "[{1}{0[0]} {2}{0[1]}]".format(
            self.range,
            "(" if not self.inclusive_min else "",
            "(" if not self.inclusive_max else "",
        )


class ScalarValue(Value):
    combinable = True

    def __init__(self, v):
        self.v = str(v)

    def to_string(self):
        return self.v


class TagValue(Value):
    combinable = False

    def __init__(self, *tags):
        self.tags = tags

    def to_string(self):
        return "{" + " | ".join(str(t) for t in self.tags) + "}"


class GeoValue(Value):
    def __init__(self, lon, lat, radius, unit="km"):
        self.lon = lon
        self.lat = lat
        self.radius = radius
        self.unit = unit

    def to_string(self):
        return f"[{self.lon} {self.lat} {self.radius} {self.unit}]"


class Node:
    def __init__(self, *children, **kwparams):
        """
        Create a node

        ### Parameters

        - **children**: One or more sub-conditions. These can be additional
            `intersect`, `disjunct`, `union`, `optional`, or any other `Node`
            type.

            The semantics of multiple conditions are dependent on the type of
            query. For an `intersection` node, this amounts to a logical AND,
            for a `union` node, this amounts to a logical `OR`.

        - **kwparams**: key-value parameters. Each key is the name of a field,
            and the value should be a field value. This can be one of the
            following:

            - Simple string (for text field matches)
            - value returned by one of the helper functions
            - list of either a string or a value


        ### Examples

        Field `num` should be between 1 and 10
        ```
        intersect(num=between(1, 10))
        ```

        Name can either be `bob` or `john`

        ```
        union(name=("bob", "john"))
        ```

        Don't select the countries Israel, Japan, or the US

        ```
        disjunct_union(country=("il", "jp", "us"))
        ```
        """

        self.params = []

        kvparams = {}
        for k, v in kwparams.items():
            curvals = kvparams.setdefault(k, [])
            if isinstance(v, (str, int, float)):
                curvals.append(Value.make_value(v))
            elif isinstance(v, Value):
                curvals.append(v)
            else:
                curvals.extend(Value.make_value(subv) for subv in v)

        self.params += [Node.to_node(p) for p in children]

        for k, v in kvparams.items():
            self.params.extend(self.join_fields(k, v))

    def join_fields(self, key, vals):
        if len(vals) == 1:
            return [BaseNode(f"@{key}:{vals[0].to_string()}")]
        if not vals[0].combinable:
            return [BaseNode(f"@{key}:{v.to_string()}") for v in vals]
        s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})")
        return [s]

    @classmethod
    def to_node(cls, obj):  # noqa
        if isinstance(obj, Node):
            return obj
        return BaseNode(obj)

    @property
    def JOINSTR(self):
        raise NotImplementedError()

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        pre, post = ("(", ")") if with_parens else ("", "")
        return f"{pre}{self.JOINSTR.join(n.to_string() for n in self.params)}{post}"

    def _should_use_paren(self, optval):
        if optval is not None:
            return optval
        return len(self.params) > 1

    def __str__(self):
        return self.to_string()


class BaseNode(Node):
    def __init__(self, s):
        super().__init__()
        self.s = str(s)

    def to_string(self, with_parens=None):
        return self.s


class IntersectNode(Node):
    """
    Create an intersection node. All children need to be satisfied in order for
    this node to evaluate as true
    """

    JOINSTR = " "


class UnionNode(Node):
    """
    Create a union node. Any of the children need to be satisfied in order for
    this node to evaluate as true
    """

    JOINSTR = "|"


class DisjunctNode(IntersectNode):
    """
    Create a disjunct node. In order for this node to be true, all of its
    children must evaluate to false
    """

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        ret = super().to_string(with_parens=False)
        if with_parens:
            return "(-" + ret + ")"
        else:
            return "-" + ret


class DistjunctUnion(DisjunctNode):
    """
    This node is true if *all* of its children are false. This is equivalent to
    ```
    disjunct(union(...))
    ```
    """

    JOINSTR = "|"


class OptionalNode(IntersectNode):
    """
    Create an optional node. If this node evaluates to true, then the document
    will be rated higher in score/rank.
    """

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        ret = super().to_string(with_parens=False)
        if with_parens:
            return "(~" + ret + ")"
        else:
            return "~" + ret


def intersect(*args, **kwargs):
    return IntersectNode(*args, **kwargs)


def union(*args, **kwargs):
    return UnionNode(*args, **kwargs)


def disjunct(*args, **kwargs):
    return DisjunctNode(*args, **kwargs)


def disjunct_union(*args, **kwargs):
    return DistjunctUnion(*args, **kwargs)


def querystring(*args, **kwargs):
    return intersect(*args, **kwargs).to_string()
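A sketch of what these builders render. The outputs in the comments follow from the `to_string()` implementations above; treat the exact strings as expected rather than guaranteed.

```
from redis.commands.search.querystring import between, querystring, tags, union

print(querystring(num=between(1, 10), country=tags("us", "uk")))
# (@num:[1 10] @country:{us | uk})

print(union(name=("bob", "john")).to_string())
# @name:(bob|john)
```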
@@ -0,0 +1,182 @@
from typing import Union

from .aggregation import Asc, Desc, Reducer, SortDirection


class FieldOnlyReducer(Reducer):
    """See https://redis.io/docs/interact/search-and-query/search/aggregations/"""

    def __init__(self, field: str) -> None:
        super().__init__(field)
        self._field = field


class count(Reducer):
    """
    Counts the number of results in the group
    """

    NAME = "COUNT"

    def __init__(self) -> None:
        super().__init__()


class sum(FieldOnlyReducer):
    """
    Calculates the sum of all the values in the given fields within the group
    """

    NAME = "SUM"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class min(FieldOnlyReducer):
    """
    Calculates the smallest value in the given field within the group
    """

    NAME = "MIN"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class max(FieldOnlyReducer):
    """
    Calculates the largest value in the given field within the group
    """

    NAME = "MAX"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class avg(FieldOnlyReducer):
    """
    Calculates the mean value in the given field within the group
    """

    NAME = "AVG"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class tolist(FieldOnlyReducer):
    """
    Returns all the matched properties in a list
    """

    NAME = "TOLIST"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class count_distinct(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in
    the group for the given field
    """

    NAME = "COUNT_DISTINCT"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class count_distinctish(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in the
    group for the given field. This uses a faster algorithm than
    `count_distinct` but is less accurate
    """

    NAME = "COUNT_DISTINCTISH"


class quantile(Reducer):
    """
    Return the value for the nth percentile within the range of values for the
    field within the group.
    """

    NAME = "QUANTILE"

    def __init__(self, field: str, pct: float) -> None:
        super().__init__(field, str(pct))
        self._field = field


class stddev(FieldOnlyReducer):
    """
    Return the standard deviation for the values within the group
    """

    NAME = "STDDEV"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class first_value(Reducer):
    """
    Selects the first value within the group according to sorting parameters
    """

    NAME = "FIRST_VALUE"

    def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None:
        """
        Selects the first value of the given field within the group.

        ### Parameters

        - **field**: Source field used for the value
        - **byfields**: How to sort the results. This can be either the
            *class* of `aggregation.Asc` or `aggregation.Desc` in which
            case the field `field` is also used as the sort input.

            `byfields` can also be one or more *instances* of `Asc` or `Desc`
            indicating the sort order for these fields
        """

        fieldstrs = []
        if (
            len(byfields) == 1
            and isinstance(byfields[0], type)
            and issubclass(byfields[0], SortDirection)
        ):
            byfields = [byfields[0](field)]

        for f in byfields:
            fieldstrs += [f.field, f.DIRSTRING]

        args = [field]
        if fieldstrs:
            args += ["BY"] + fieldstrs
        super().__init__(*args)
        self._field = field


class random_sample(Reducer):
    """
    Returns a random sample of items from the dataset, from the given property
    """

    NAME = "RANDOM_SAMPLE"

    def __init__(self, field: str, size: int) -> None:
        """
        ### Parameters

        - **field**: Field to sample from
        - **size**: Return this many items (can be less)
        """
        args = [field, str(size)]
        super().__init__(*args)
        self._field = field
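A sketch combining several of these reducers in one group. The index name `idx` and the `state`/`paid`/`customer` fields are illustrative assumptions.

```
import redis
from redis.commands.search import reducers
from redis.commands.search.aggregation import AggregateRequest, Desc

r = redis.Redis()
req = (
    AggregateRequest("*")
    .group_by(
        "@state",
        reducers.avg("@paid").alias("avg_paid"),
        reducers.quantile("@paid", 0.95).alias("p95_paid"),
        reducers.count_distinct("@customer").alias("customers"),
    )
    .sort_by(Desc("@avg_paid"), max=5)
)
res = r.ft("idx").aggregate(req)
```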
@@ -0,0 +1,87 @@
from typing import Optional

from ._util import to_string
from .document import Document


class Result:
    """
    Represents the result of a search query, and has an array of Document
    objects
    """

    def __init__(
        self,
        res,
        hascontent,
        duration=0,
        has_payload=False,
        with_scores=False,
        field_encodings: Optional[dict] = None,
    ):
        """
        - duration: the execution time of the query
        - has_payload: whether the query has payloads
        - with_scores: whether the query has scores
        - field_encodings: a dictionary of field encodings if any is provided
        """

        self.total = res[0]
        self.duration = duration
        self.docs = []

        step = 1
        if hascontent:
            step = step + 1
        if has_payload:
            step = step + 1
        if with_scores:
            step = step + 1

        offset = 2 if with_scores else 1

        for i in range(1, len(res), step):
            id = to_string(res[i])
            payload = to_string(res[i + offset]) if has_payload else None
            # fields_offset = 2 if has_payload else 1
            fields_offset = offset + 1 if has_payload else offset
            score = float(res[i + 1]) if with_scores else None

            fields = {}
            if hascontent and res[i + fields_offset] is not None:
                keys = map(to_string, res[i + fields_offset][::2])
                values = res[i + fields_offset][1::2]

                for key, value in zip(keys, values):
                    if field_encodings is None or key not in field_encodings:
                        fields[key] = to_string(value)
                        continue

                    encoding = field_encodings[key]

                    # If the encoding is None, we don't need to decode the value
                    if encoding is None:
                        fields[key] = value
                    else:
                        fields[key] = to_string(value, encoding=encoding)

            try:
                del fields["id"]
            except KeyError:
                pass

            try:
                fields["json"] = fields["$"]
                del fields["$"]
            except KeyError:
                pass

            doc = (
                Document(id, score=score, payload=payload, **fields)
                if with_scores
                else Document(id, payload=payload, **fields)
            )
            self.docs.append(doc)

    def __repr__(self) -> str:
        return f"Result{{{self.total} total, docs: {self.docs}}}"
@@ -0,0 +1,55 @@
from typing import Optional

from ._util import to_string


class Suggestion:
    """
    Represents a single suggestion being sent or returned from the
    autocomplete server
    """

    def __init__(
        self, string: str, score: float = 1.0, payload: Optional[str] = None
    ) -> None:
        self.string = to_string(string)
        self.payload = to_string(payload)
        self.score = score

    def __repr__(self) -> str:
        return self.string


class SuggestionParser:
    """
    Internal class used to parse results from the `SUGGET` command.
    This needs to consume either 1, 2, or 3 values at a time from
    the return value depending on what objects were requested
    """

    def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None:
        self.with_scores = with_scores
        self.with_payloads = with_payloads

        if with_scores and with_payloads:
            self.sugsize = 3
            self._scoreidx = 1
            self._payloadidx = 2
        elif with_scores:
            self.sugsize = 2
            self._scoreidx = 1
        elif with_payloads:
            self.sugsize = 2
            self._payloadidx = 1
        else:
            self.sugsize = 1
            self._scoreidx = -1

        self._sugs = ret

    def __iter__(self):
        for i in range(0, len(self._sugs), self.sugsize):
            ss = self._sugs[i]
            score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0
            payload = self._sugs[i + self._payloadidx] if self.with_payloads else None
            yield Suggestion(ss, score, payload)