Updates
@@ -0,0 +1 @@
"""Result Backends."""
@@ -0,0 +1,190 @@
"""ArangoDb result store backend."""

# pylint: disable=W1202,W0703

from datetime import timedelta

from kombu.utils.objects import cached_property
from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured

from .base import KeyValueStoreBackend

try:
    from pyArango import connection as py_arango_connection
    from pyArango.theExceptions import AQLQueryError
except ImportError:
    py_arango_connection = AQLQueryError = None

__all__ = ('ArangoDbBackend',)


class ArangoDbBackend(KeyValueStoreBackend):
    """ArangoDb backend.

    Sample url:
        "arangodb://username:password@host:port/database/collection"

    Settings can also be provided via *arangodb_backend_settings*, a dict in
    ``app.conf`` containing the host, port, username, password, database name
    and collection name; any value not given falls back to the defaults below.
    The default database name and collection name are both ``celery``.

    Raises
    ------
    celery.exceptions.ImproperlyConfigured:
        if module :pypi:`pyArango` is not available.

    """

    host = '127.0.0.1'
    port = '8529'
    database = 'celery'
    collection = 'celery'
    username = None
    password = None
    # protocol is not supported in backend url (http is taken as default)
    http_protocol = 'http'
    verify = False

    # Use str as arangodb key not bytes
    key_t = str

    def __init__(self, url=None, *args, **kwargs):
        """Parse the url or load the settings from settings object."""
        super().__init__(*args, **kwargs)

        if py_arango_connection is None:
            raise ImproperlyConfigured(
                'You need to install the pyArango library to use the '
                'ArangoDb backend.',
            )

        self.url = url

        if url is None:
            host = port = database = collection = username = password = None
        else:
            (
                _schema, host, port, username, password,
                database_collection, _query
            ) = _parse_url(url)
            if database_collection is None:
                database = collection = None
            else:
                database, collection = database_collection.split('/')

        config = self.app.conf.get('arangodb_backend_settings', None)
        if config is not None:
            if not isinstance(config, dict):
                raise ImproperlyConfigured(
                    'ArangoDb backend settings should be grouped in a dict',
                )
        else:
            config = {}

        self.host = host or config.get('host', self.host)
        self.port = int(port or config.get('port', self.port))
        self.http_protocol = config.get('http_protocol', self.http_protocol)
        self.verify = config.get('verify', self.verify)
        self.database = database or config.get('database', self.database)
        self.collection = \
            collection or config.get('collection', self.collection)
        self.username = username or config.get('username', self.username)
        self.password = password or config.get('password', self.password)
        self.arangodb_url = "{http_protocol}://{host}:{port}".format(
            http_protocol=self.http_protocol, host=self.host, port=self.port
        )
        self._connection = None

    @property
    def connection(self):
        """Connect to the arangodb server."""
        if self._connection is None:
            self._connection = py_arango_connection.Connection(
                arangoURL=self.arangodb_url, username=self.username,
                password=self.password, verify=self.verify
            )
        return self._connection

    @property
    def db(self):
        """Database Object to the given database."""
        return self.connection[self.database]

    @cached_property
    def expires_delta(self):
        return timedelta(seconds=0 if self.expires is None else self.expires)

    def get(self, key):
        if key is None:
            return None
        query = self.db.AQLQuery(
            "RETURN DOCUMENT(@@collection, @key).task",
            rawResults=True,
            bindVars={
                "@collection": self.collection,
                "key": key,
            },
        )
        return next(query) if len(query) > 0 else None

    def set(self, key, value):
        self.db.AQLQuery(
            """
            UPSERT {_key: @key}
            INSERT {_key: @key, task: @value}
            UPDATE {task: @value} IN @@collection
            """,
            bindVars={
                "@collection": self.collection,
                "key": key,
                "value": value,
            },
        )

    def mget(self, keys):
        if keys is None:
            return
        query = self.db.AQLQuery(
            "FOR k IN @keys RETURN DOCUMENT(@@collection, k).task",
            rawResults=True,
            bindVars={
                "@collection": self.collection,
                "keys": keys if isinstance(keys, list) else list(keys),
            },
        )
        while True:
            yield from query
            try:
                query.nextBatch()
            except StopIteration:
                break

    def delete(self, key):
        if key is None:
            return
        self.db.AQLQuery(
            "REMOVE {_key: @key} IN @@collection",
            bindVars={
                "@collection": self.collection,
                "key": key,
            },
        )

    def cleanup(self):
        if not self.expires:
            return
        checkpoint = (self.app.now() - self.expires_delta).isoformat()
        self.db.AQLQuery(
            """
            FOR record IN @@collection
                FILTER record.task.date_done < @checkpoint
                REMOVE record IN @@collection
            """,
            bindVars={
                "@collection": self.collection,
                "checkpoint": checkpoint,
            },
        )
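Usage sketch (illustrative, not taken from this commit): pointing a Celery app at this backend using the URL format documented in the class docstring, or the equivalent arangodb_backend_settings dict read by __init__ above. The host, credentials and database/collection names are placeholders.

from celery import Celery

# Placeholder URL following the documented format:
# arangodb://username:password@host:port/database/collection
app = Celery('tasks', backend='arangodb://user:secret@127.0.0.1:8529/celery/celery')

# Equivalent explicit settings, read by ArangoDbBackend.__init__ above:
app.conf.arangodb_backend_settings = {
    'host': '127.0.0.1',
    'port': '8529',
    'username': 'user',      # placeholder
    'password': 'secret',    # placeholder
    'database': 'celery',
    'collection': 'celery',
}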
@@ -0,0 +1,333 @@
"""Async I/O backend support utilities."""
import socket
import threading
import time
from collections import deque
from queue import Empty
from time import sleep
from weakref import WeakKeyDictionary

from kombu.utils.compat import detect_environment

from celery import states
from celery.exceptions import TimeoutError
from celery.utils.threads import THREAD_TIMEOUT_MAX

__all__ = (
    'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
    'register_drainer',
)

drainers = {}


def register_drainer(name):
    """Decorator used to register a new result drainer type."""
    def _inner(cls):
        drainers[name] = cls
        return cls
    return _inner


@register_drainer('default')
class Drainer:
    """Result draining service."""

    def __init__(self, result_consumer):
        self.result_consumer = result_consumer

    def start(self):
        pass

    def stop(self):
        pass

    def drain_events_until(self, p, timeout=None, interval=1, on_interval=None, wait=None):
        wait = wait or self.result_consumer.drain_events
        time_start = time.monotonic()

        while 1:
            # Total time spent may exceed a single call to wait()
            if timeout and time.monotonic() - time_start >= timeout:
                raise socket.timeout()
            try:
                yield self.wait_for(p, wait, timeout=interval)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # got event on the wanted channel.
                break

    def wait_for(self, p, wait, timeout=None):
        wait(timeout=timeout)


class greenletDrainer(Drainer):
    spawn = None
    _g = None
    # Event, sent (and recreated) after every drain_events iteration.
    _drain_complete_event = None

    def _create_drain_complete_event(self):
        """Create a new self._drain_complete_event object."""
        pass

    def _send_drain_complete_event(self):
        """Fire self._drain_complete_event to wake up .wait_for()."""
        pass

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._started = threading.Event()
        self._stopped = threading.Event()
        self._shutdown = threading.Event()
        self._create_drain_complete_event()

    def run(self):
        self._started.set()
        while not self._stopped.is_set():
            try:
                self.result_consumer.drain_events(timeout=1)
                self._send_drain_complete_event()
                self._create_drain_complete_event()
            except socket.timeout:
                pass
        self._shutdown.set()

    def start(self):
        if not self._started.is_set():
            self._g = self.spawn(self.run)
            self._started.wait()

    def stop(self):
        self._stopped.set()
        self._send_drain_complete_event()
        self._shutdown.wait(THREAD_TIMEOUT_MAX)

    def wait_for(self, p, wait, timeout=None):
        self.start()
        if not p.ready:
            self._drain_complete_event.wait(timeout=timeout)


@register_drainer('eventlet')
class eventletDrainer(greenletDrainer):

    def spawn(self, func):
        from eventlet import sleep, spawn
        g = spawn(func)
        sleep(0)
        return g

    def _create_drain_complete_event(self):
        from eventlet.event import Event
        self._drain_complete_event = Event()

    def _send_drain_complete_event(self):
        self._drain_complete_event.send()


@register_drainer('gevent')
class geventDrainer(greenletDrainer):

    def spawn(self, func):
        import gevent
        g = gevent.spawn(func)
        gevent.sleep(0)
        return g

    def _create_drain_complete_event(self):
        from gevent.event import Event
        self._drain_complete_event = Event()

    def _send_drain_complete_event(self):
        self._drain_complete_event.set()
        self._create_drain_complete_event()


class AsyncBackendMixin:
    """Mixin for backends that enables the async API."""

    def _collect_into(self, result, bucket):
        self.result_consumer.buckets[result] = bucket

    def iter_native(self, result, no_ack=True, **kwargs):
        self._ensure_not_eager()

        results = result.results
        if not results:
            raise StopIteration()

        # we tell the result consumer to put consumed results
        # into these buckets.
        bucket = deque()
        for node in results:
            if not hasattr(node, '_cache'):
                bucket.append(node)
            elif node._cache:
                bucket.append(node)
            else:
                self._collect_into(node, bucket)

        for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
            while bucket:
                node = bucket.popleft()
                if not hasattr(node, '_cache'):
                    yield node.id, node.children
                else:
                    yield node.id, node._cache
        while bucket:
            node = bucket.popleft()
            yield node.id, node._cache

    def add_pending_result(self, result, weak=False, start_drainer=True):
        if start_drainer:
            self.result_consumer.drainer.start()
        try:
            self._maybe_resolve_from_buffer(result)
        except Empty:
            self._add_pending_result(result.id, result, weak=weak)
        return result

    def _maybe_resolve_from_buffer(self, result):
        result._maybe_set_cache(self._pending_messages.take(result.id))

    def _add_pending_result(self, task_id, result, weak=False):
        concrete, weak_ = self._pending_results
        if task_id not in weak_ and result.id not in concrete:
            (weak_ if weak else concrete)[task_id] = result
            self.result_consumer.consume_from(task_id)

    def add_pending_results(self, results, weak=False):
        self.result_consumer.drainer.start()
        return [self.add_pending_result(result, weak=weak, start_drainer=False)
                for result in results]

    def remove_pending_result(self, result):
        self._remove_pending_result(result.id)
        self.on_result_fulfilled(result)
        return result

    def _remove_pending_result(self, task_id):
        for mapping in self._pending_results:
            mapping.pop(task_id, None)

    def on_result_fulfilled(self, result):
        self.result_consumer.cancel_for(result.id)

    def wait_for_pending(self, result,
                         callback=None, propagate=True, **kwargs):
        self._ensure_not_eager()
        for _ in self._wait_for_pending(result, **kwargs):
            pass
        return result.maybe_throw(callback=callback, propagate=propagate)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        return self.result_consumer._wait_for_pending(
            result, timeout=timeout,
            on_interval=on_interval, on_message=on_message,
            **kwargs
        )

    @property
    def is_async(self):
        return True


class BaseResultConsumer:
    """Manager responsible for consuming result messages."""

    def __init__(self, backend, app, accept,
                 pending_results, pending_messages):
        self.backend = backend
        self.app = app
        self.accept = accept
        self._pending_results = pending_results
        self._pending_messages = pending_messages
        self.on_message = None
        self.buckets = WeakKeyDictionary()
        self.drainer = drainers[detect_environment()](self)

    def start(self, initial_task_id, **kwargs):
        raise NotImplementedError()

    def stop(self):
        pass

    def drain_events(self, timeout=None):
        raise NotImplementedError()

    def consume_from(self, task_id):
        raise NotImplementedError()

    def cancel_for(self, task_id):
        raise NotImplementedError()

    def _after_fork(self):
        self.buckets.clear()
        self.buckets = WeakKeyDictionary()
        self.on_message = None
        self.on_after_fork()

    def on_after_fork(self):
        pass

    def drain_events_until(self, p, timeout=None, on_interval=None):
        return self.drainer.drain_events_until(
            p, timeout=timeout, on_interval=on_interval)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
        prev_on_m, self.on_message = self.on_message, on_message
        try:
            for _ in self.drain_events_until(
                    result.on_ready, timeout=timeout,
                    on_interval=on_interval):
                yield
                sleep(0)
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = prev_on_m

    def on_wait_for_pending(self, result, timeout=None, **kwargs):
        pass

    def on_out_of_band_result(self, message):
        self.on_state_change(message.payload, message)

    def _get_pending_result(self, task_id):
        for mapping in self._pending_results:
            try:
                return mapping[task_id]
            except KeyError:
                pass
        raise KeyError(task_id)

    def on_state_change(self, meta, message):
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            task_id = meta['task_id']
            try:
                result = self._get_pending_result(task_id)
            except KeyError:
                # send to buffer in case we received this result
                # before it was added to _pending_results.
                self._pending_messages.put(task_id, meta)
            else:
                result._maybe_set_cache(meta)
                buckets = self.buckets
                try:
                    # remove bucket for this result, since it's fulfilled
                    bucket = buckets.pop(result)
                except KeyError:
                    pass
                else:
                    # send to waiter via bucket
                    bucket.append(result)
        sleep(0)
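Illustration (hypothetical, not part of this commit): register_drainer() above keys drainer classes by the string kombu's detect_environment() returns, which is how the 'eventlet' and 'gevent' drainers are selected. A custom drainer would be registered the same way; the 'custom' name below is only a placeholder and would only be picked up if detect_environment() returned it.

from celery.backends.asynchronous import Drainer, register_drainer

@register_drainer('custom')  # key must match what detect_environment() returns
class CustomDrainer(Drainer):
    """Hypothetical drainer; behaves exactly like the default Drainer."""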
@@ -0,0 +1,188 @@
"""The Azure Storage Block Blob backend for Celery."""
from kombu.transport.azurestoragequeues import Transport as AzureStorageQueuesTransport
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str

from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger

from .base import KeyValueStoreBackend

try:
    import azure.storage.blob as azurestorage
    from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
    from azure.storage.blob import BlobServiceClient
except ImportError:
    azurestorage = None

__all__ = ("AzureBlockBlobBackend",)

LOGGER = get_logger(__name__)
AZURE_BLOCK_BLOB_CONNECTION_PREFIX = 'azureblockblob://'


class AzureBlockBlobBackend(KeyValueStoreBackend):
    """Azure Storage Block Blob backend for Celery."""

    def __init__(self,
                 url=None,
                 container_name=None,
                 *args,
                 **kwargs):
        """
        Supported URL formats:

        azureblockblob://CONNECTION_STRING
        azureblockblob://DefaultAzureCredential@STORAGE_ACCOUNT_URL
        azureblockblob://ManagedIdentityCredential@STORAGE_ACCOUNT_URL
        """
        super().__init__(*args, **kwargs)

        if azurestorage is None or azurestorage.__version__ < '12':
            raise ImproperlyConfigured(
                "You need to install the azure-storage-blob v12 library to "
                "use the AzureBlockBlob backend")

        conf = self.app.conf

        self._connection_string = self._parse_url(url)

        self._container_name = (
            container_name or
            conf["azureblockblob_container_name"])

        self.base_path = conf.get('azureblockblob_base_path', '')
        self._connection_timeout = conf.get(
            'azureblockblob_connection_timeout', 20
        )
        self._read_timeout = conf.get('azureblockblob_read_timeout', 120)

    @classmethod
    def _parse_url(cls, url, prefix=AZURE_BLOCK_BLOB_CONNECTION_PREFIX):
        connection_string = url[len(prefix):]
        if not connection_string:
            raise ImproperlyConfigured("Invalid URL")

        return connection_string

    @cached_property
    def _blob_service_client(self):
        """Return the Azure Storage Blob service client.

        If this is the first call to the property, the client is created and
        the container is created if it doesn't yet exist.

        """
        if (
            "DefaultAzureCredential" in self._connection_string or
            "ManagedIdentityCredential" in self._connection_string
        ):
            # Leveraging the work that Kombu already did for us
            credential_, url = AzureStorageQueuesTransport.parse_uri(
                self._connection_string
            )
            client = BlobServiceClient(
                account_url=url,
                credential=credential_,
                connection_timeout=self._connection_timeout,
                read_timeout=self._read_timeout,
            )
        else:
            client = BlobServiceClient.from_connection_string(
                self._connection_string,
                connection_timeout=self._connection_timeout,
                read_timeout=self._read_timeout,
            )

        try:
            client.create_container(name=self._container_name)
            msg = f"Container created with name {self._container_name}."
        except ResourceExistsError:
            msg = f"Container with name {self._container_name} already " \
                "exists. It will not be created."
        LOGGER.info(msg)

        return client

    def get(self, key):
        """Read the value stored at the given key.

        Args:
            key: The key for which to read the value.
        """
        key = bytes_to_str(key)
        LOGGER.debug("Getting Azure Block Blob %s/%s", self._container_name, key)

        blob_client = self._blob_service_client.get_blob_client(
            container=self._container_name,
            blob=f'{self.base_path}{key}',
        )

        try:
            return blob_client.download_blob().readall().decode()
        except ResourceNotFoundError:
            return None

    def set(self, key, value):
        """Store a value for a given key.

        Args:
            key: The key at which to store the value.
            value: The value to store.

        """
        key = bytes_to_str(key)
        LOGGER.debug(f"Creating azure blob at {self._container_name}/{key}")

        blob_client = self._blob_service_client.get_blob_client(
            container=self._container_name,
            blob=f'{self.base_path}{key}',
        )

        blob_client.upload_blob(value, overwrite=True)

    def mget(self, keys):
        """Read all the values for the provided keys.

        Args:
            keys: The list of keys to read.

        """
        return [self.get(key) for key in keys]

    def delete(self, key):
        """Delete the value at a given key.

        Args:
            key: The key of the value to delete.

        """
        key = bytes_to_str(key)
        LOGGER.debug(f"Deleting azure blob at {self._container_name}/{key}")

        blob_client = self._blob_service_client.get_blob_client(
            container=self._container_name,
            blob=f'{self.base_path}{key}',
        )

        blob_client.delete_blob()

    def as_uri(self, include_password=False):
        if include_password:
            return (
                f'{AZURE_BLOCK_BLOB_CONNECTION_PREFIX}'
                f'{self._connection_string}'
            )

        connection_string_parts = self._connection_string.split(';')
        account_key_prefix = 'AccountKey='
        redacted_connection_string_parts = [
            f'{account_key_prefix}**' if part.startswith(account_key_prefix)
            else part
            for part in connection_string_parts
        ]

        return (
            f'{AZURE_BLOCK_BLOB_CONNECTION_PREFIX}'
            f'{";".join(redacted_connection_string_parts)}'
        )
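Usage sketch (illustrative placeholders, not from this commit): selecting this backend with one of the URL formats listed in the __init__ docstring, plus the container and timeout settings the constructor reads. The connection string and account names below are invented placeholders.

from celery import Celery

# Placeholder connection string; real ones carry full AccountName/AccountKey pairs.
app = Celery(
    'tasks',
    backend='azureblockblob://DefaultEndpointsProtocol=https;AccountName=acct;AccountKey=***',
)
app.conf.azureblockblob_container_name = 'celery'   # read in __init__ above
app.conf.azureblockblob_connection_timeout = 20     # seconds (default shown above)
app.conf.azureblockblob_read_timeout = 120          # seconds (default shown above)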
ETB-API/venv/lib/python3.12/site-packages/celery/backends/base.py (new file, 1112 lines)
File diff suppressed because it is too large
@@ -0,0 +1,163 @@
"""Memcached and in-memory cache result backend."""
from kombu.utils.encoding import bytes_to_str, ensure_bytes
from kombu.utils.objects import cached_property

from celery.exceptions import ImproperlyConfigured
from celery.utils.functional import LRUCache

from .base import KeyValueStoreBackend

__all__ = ('CacheBackend',)

_imp = [None]

REQUIRES_BACKEND = """\
The Memcached backend requires either pylibmc or python-memcached.\
"""

UNKNOWN_BACKEND = """\
The cache backend {0!r} is unknown,
Please use one of the following backends instead: {1}\
"""

# Global shared in-memory cache for in-memory cache client
# This is to share cache between threads
_DUMMY_CLIENT_CACHE = LRUCache(limit=5000)


def import_best_memcache():
    if _imp[0] is None:
        is_pylibmc, memcache_key_t = False, bytes_to_str
        try:
            import pylibmc as memcache
            is_pylibmc = True
        except ImportError:
            try:
                import memcache
            except ImportError:
                raise ImproperlyConfigured(REQUIRES_BACKEND)
        _imp[0] = (is_pylibmc, memcache, memcache_key_t)
    return _imp[0]


def get_best_memcache(*args, **kwargs):
    # pylint: disable=unpacking-non-sequence
    #   This is most definitely a sequence, but pylint thinks it's not.
    is_pylibmc, memcache, key_t = import_best_memcache()
    Client = _Client = memcache.Client

    if not is_pylibmc:
        def Client(*args, **kwargs):  # noqa: F811
            kwargs.pop('behaviors', None)
            return _Client(*args, **kwargs)

    return Client, key_t


class DummyClient:

    def __init__(self, *args, **kwargs):
        self.cache = _DUMMY_CLIENT_CACHE

    def get(self, key, *args, **kwargs):
        return self.cache.get(key)

    def get_multi(self, keys):
        cache = self.cache
        return {k: cache[k] for k in keys if k in cache}

    def set(self, key, value, *args, **kwargs):
        self.cache[key] = value

    def delete(self, key, *args, **kwargs):
        self.cache.pop(key, None)

    def incr(self, key, delta=1):
        return self.cache.incr(key, delta)

    def touch(self, key, expire):
        pass


backends = {
    'memcache': get_best_memcache,
    'memcached': get_best_memcache,
    'pylibmc': get_best_memcache,
    'memory': lambda: (DummyClient, ensure_bytes),
}


class CacheBackend(KeyValueStoreBackend):
    """Cache result backend."""

    servers = None
    supports_autoexpire = True
    supports_native_join = True
    implements_incr = True

    def __init__(self, app, expires=None, backend=None,
                 options=None, url=None, **kwargs):
        options = {} if not options else options
        super().__init__(app, **kwargs)
        self.url = url

        self.options = dict(self.app.conf.cache_backend_options,
                            **options)

        self.backend = url or backend or self.app.conf.cache_backend
        if self.backend:
            self.backend, _, servers = self.backend.partition('://')
            self.servers = servers.rstrip('/').split(';')
        self.expires = self.prepare_expires(expires, type=int)
        try:
            self.Client, self.key_t = backends[self.backend]()
        except KeyError:
            raise ImproperlyConfigured(UNKNOWN_BACKEND.format(
                self.backend, ', '.join(backends)))
        self._encode_prefixes()  # re-encode the key prefixes

    def get(self, key):
        return self.client.get(key)

    def mget(self, keys):
        return self.client.get_multi(keys)

    def set(self, key, value):
        return self.client.set(key, value, self.expires)

    def delete(self, key):
        return self.client.delete(key)

    def _apply_chord_incr(self, header_result_args, body, **kwargs):
        chord_key = self.get_key_for_chord(header_result_args[0])
        self.client.set(chord_key, 0, time=self.expires)
        return super()._apply_chord_incr(
            header_result_args, body, **kwargs)

    def incr(self, key):
        return self.client.incr(key)

    def expire(self, key, value):
        return self.client.touch(key, value)

    @cached_property
    def client(self):
        return self.Client(self.servers, **self.options)

    def __reduce__(self, args=(), kwargs=None):
        kwargs = {} if not kwargs else kwargs
        servers = ';'.join(self.servers)
        backend = f'{self.backend}://{servers}/'
        kwargs.update(
            {'backend': backend,
             'expires': self.expires,
             'options': self.options})
        return super().__reduce__(args, kwargs)

    def as_uri(self, *args, **kwargs):
        """Return the backend as an URI.

        This properly handles the case of multiple servers.
        """
        servers = ';'.join(self.servers)
        return f'{self.backend}://{servers}/'
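Usage sketch (illustrative): per the backends mapping above, a backend string of 'memory://' selects DummyClient (the shared in-process LRU cache), while 'memcached://host:port;host2:port/' selects the best available memcached client, with servers split on ';' as in __init__. The 'cache+' prefix and the addresses below are assumptions about how the app-level setting is usually written, not something this diff shows.

from celery import Celery

# In-process result cache (DummyClient backed by _DUMMY_CLIENT_CACHE above):
app = Celery('tasks', backend='cache+memory://')

# Memcached with multiple servers separated by ';' (placeholder addresses):
app.conf.result_backend = 'cache+memcached://10.0.0.1:11211;10.0.0.2:11211/'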
@@ -0,0 +1,256 @@
"""Apache Cassandra result store backend using the DataStax driver."""
import threading

from celery import states
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger

from .base import BaseBackend

try:  # pragma: no cover
    import cassandra
    import cassandra.auth
    import cassandra.cluster
    import cassandra.query
except ImportError:
    cassandra = None


__all__ = ('CassandraBackend',)

logger = get_logger(__name__)

E_NO_CASSANDRA = """
You need to install the cassandra-driver library to
use the Cassandra backend. See https://github.com/datastax/python-driver
"""

E_NO_SUCH_CASSANDRA_AUTH_PROVIDER = """
CASSANDRA_AUTH_PROVIDER you provided is not a valid auth_provider class.
See https://datastax.github.io/python-driver/api/cassandra/auth.html.
"""

E_CASSANDRA_MISCONFIGURED = 'Cassandra backend improperly configured.'

E_CASSANDRA_NOT_CONFIGURED = 'Cassandra backend not configured.'

Q_INSERT_RESULT = """
INSERT INTO {table} (
    task_id, status, result, date_done, traceback, children) VALUES (
        %s, %s, %s, %s, %s, %s) {expires};
"""

Q_SELECT_RESULT = """
SELECT status, result, date_done, traceback, children
FROM {table}
WHERE task_id=%s
LIMIT 1
"""

Q_CREATE_RESULT_TABLE = """
CREATE TABLE {table} (
    task_id text,
    status text,
    result blob,
    date_done timestamp,
    traceback blob,
    children blob,
    PRIMARY KEY ((task_id), date_done)
) WITH CLUSTERING ORDER BY (date_done DESC);
"""

Q_EXPIRES = """
    USING TTL {0}
"""


def buf_t(x):
    return bytes(x, 'utf8')


class CassandraBackend(BaseBackend):
    """Cassandra/AstraDB backend utilizing DataStax driver.

    Raises:
        celery.exceptions.ImproperlyConfigured:
            if module :pypi:`cassandra-driver` is not available,
            or not-exactly-one of the :setting:`cassandra_servers` and
            the :setting:`cassandra_secure_bundle_path` settings is set.
    """

    #: List of Cassandra servers with format: ``hostname``.
    servers = None
    #: Location of the secure connect bundle zipfile (absolute path).
    bundle_path = None

    supports_autoexpire = True  # autoexpire supported via entry_ttl

    def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None,
                 port=None, bundle_path=None, **kwargs):
        super().__init__(**kwargs)

        if not cassandra:
            raise ImproperlyConfigured(E_NO_CASSANDRA)

        conf = self.app.conf
        self.servers = servers or conf.get('cassandra_servers', None)
        self.bundle_path = bundle_path or conf.get(
            'cassandra_secure_bundle_path', None)
        self.port = port or conf.get('cassandra_port', None) or 9042
        self.keyspace = keyspace or conf.get('cassandra_keyspace', None)
        self.table = table or conf.get('cassandra_table', None)
        self.cassandra_options = conf.get('cassandra_options', {})

        # either servers or bundle path must be provided...
        db_directions = self.servers or self.bundle_path
        if not db_directions or not self.keyspace or not self.table:
            raise ImproperlyConfigured(E_CASSANDRA_NOT_CONFIGURED)
        # ...but not both:
        if self.servers and self.bundle_path:
            raise ImproperlyConfigured(E_CASSANDRA_MISCONFIGURED)

        expires = entry_ttl or conf.get('cassandra_entry_ttl', None)

        self.cqlexpires = (
            Q_EXPIRES.format(expires) if expires is not None else '')

        read_cons = conf.get('cassandra_read_consistency') or 'LOCAL_QUORUM'
        write_cons = conf.get('cassandra_write_consistency') or 'LOCAL_QUORUM'

        self.read_consistency = getattr(
            cassandra.ConsistencyLevel, read_cons,
            cassandra.ConsistencyLevel.LOCAL_QUORUM)
        self.write_consistency = getattr(
            cassandra.ConsistencyLevel, write_cons,
            cassandra.ConsistencyLevel.LOCAL_QUORUM)

        self.auth_provider = None
        auth_provider = conf.get('cassandra_auth_provider', None)
        auth_kwargs = conf.get('cassandra_auth_kwargs', None)
        if auth_provider and auth_kwargs:
            auth_provider_class = getattr(cassandra.auth, auth_provider, None)
            if not auth_provider_class:
                raise ImproperlyConfigured(E_NO_SUCH_CASSANDRA_AUTH_PROVIDER)
            self.auth_provider = auth_provider_class(**auth_kwargs)

        self._cluster = None
        self._session = None
        self._write_stmt = None
        self._read_stmt = None
        self._lock = threading.RLock()

    def _get_connection(self, write=False):
        """Prepare the connection for action.

        Arguments:
            write (bool): are we a writer?
        """
        if self._session is not None:
            return
        self._lock.acquire()
        try:
            if self._session is not None:
                return
            # using either 'servers' or 'bundle_path' here:
            if self.servers:
                self._cluster = cassandra.cluster.Cluster(
                    self.servers, port=self.port,
                    auth_provider=self.auth_provider,
                    **self.cassandra_options)
            else:
                # 'bundle_path' is guaranteed to be set
                self._cluster = cassandra.cluster.Cluster(
                    cloud={
                        'secure_connect_bundle': self.bundle_path,
                    },
                    auth_provider=self.auth_provider,
                    **self.cassandra_options)
            self._session = self._cluster.connect(self.keyspace)

            # We're forced to do concatenation below, as formatting would
            # blow up on superficial %s that'll be processed by Cassandra
            self._write_stmt = cassandra.query.SimpleStatement(
                Q_INSERT_RESULT.format(
                    table=self.table, expires=self.cqlexpires),
            )
            self._write_stmt.consistency_level = self.write_consistency

            self._read_stmt = cassandra.query.SimpleStatement(
                Q_SELECT_RESULT.format(table=self.table),
            )
            self._read_stmt.consistency_level = self.read_consistency

            if write:
                # Only possible writers "workers" are allowed to issue
                # CREATE TABLE. This is to prevent conflicting situations
                # where both task-creator and task-executor would issue it
                # at the same time.

                # Anyway; if you're doing anything critical, you should
                # have created this table in advance, in which case
                # this query will be a no-op (AlreadyExists)
                make_stmt = cassandra.query.SimpleStatement(
                    Q_CREATE_RESULT_TABLE.format(table=self.table),
                )
                make_stmt.consistency_level = self.write_consistency

                try:
                    self._session.execute(make_stmt)
                except cassandra.AlreadyExists:
                    pass

        except cassandra.OperationTimedOut:
            # a heavily loaded or gone Cassandra cluster failed to respond.
            # leave this class in a consistent state
            if self._cluster is not None:
                self._cluster.shutdown()  # also shuts down _session

            self._cluster = None
            self._session = None
            raise  # we did fail after all - reraise
        finally:
            self._lock.release()

    def _store_result(self, task_id, result, state,
                      traceback=None, request=None, **kwargs):
        """Store return value and state of an executed task."""
        self._get_connection(write=True)

        self._session.execute(self._write_stmt, (
            task_id,
            state,
            buf_t(self.encode(result)),
            self.app.now(),
            buf_t(self.encode(traceback)),
            buf_t(self.encode(self.current_task_children(request)))
        ))

    def as_uri(self, include_password=True):
        return 'cassandra://'

    def _get_task_meta_for(self, task_id):
        """Get task meta-data for a task by id."""
        self._get_connection()

        res = self._session.execute(self._read_stmt, (task_id, )).one()
        if not res:
            return {'status': states.PENDING, 'result': None}

        status, result, date_done, traceback, children = res

        return self.meta_from_decoded({
            'task_id': task_id,
            'status': status,
            'result': self.decode(result),
            'date_done': date_done,
            'traceback': self.decode(traceback),
            'children': self.decode(children),
        })

    def __reduce__(self, args=(), kwargs=None):
        kwargs = {} if not kwargs else kwargs
        kwargs.update(
            {'servers': self.servers,
             'keyspace': self.keyspace,
             'table': self.table})
        return super().__reduce__(args, kwargs)
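Usage sketch (illustrative placeholders): the constructor above reads these settings from app.conf; exactly one of cassandra_servers or cassandra_secure_bundle_path must be set, alongside a keyspace and table.

from celery import Celery

app = Celery('tasks', backend='cassandra://')
app.conf.update(
    cassandra_servers=['127.0.0.1'],        # or cassandra_secure_bundle_path=...
    cassandra_port=9042,
    cassandra_keyspace='celery',            # placeholder keyspace
    cassandra_table='tasks',                # placeholder table
    cassandra_entry_ttl=86400,              # optional; enables the USING TTL clause above
    cassandra_read_consistency='LOCAL_QUORUM',
    cassandra_write_consistency='LOCAL_QUORUM',
)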
@@ -0,0 +1,116 @@
"""Consul result store backend.

- :class:`ConsulBackend` implements KeyValueStoreBackend to store results
  in the key-value store of Consul.
"""
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import parse_url

from celery.backends.base import KeyValueStoreBackend
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger

try:
    import consul
except ImportError:
    consul = None

logger = get_logger(__name__)

__all__ = ('ConsulBackend',)

CONSUL_MISSING = """\
You need to install the python-consul library in order to use \
the Consul result store backend."""


class ConsulBackend(KeyValueStoreBackend):
    """Consul.io K/V store backend for Celery."""

    consul = consul

    supports_autoexpire = True

    consistency = 'consistent'
    path = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.consul is None:
            raise ImproperlyConfigured(CONSUL_MISSING)
        #
        # By default, for correctness, we use a client connection per
        # operation. If set, self.one_client will be used for all operations.
        # This provides for the original behaviour to be selected, and is
        # also convenient for mocking in the unit tests.
        #
        self.one_client = None
        self._init_from_params(**parse_url(self.url))

    def _init_from_params(self, hostname, port, virtual_host, **params):
        logger.debug('Setting up Consul client to connect to %s:%d',
                     hostname, port)
        self.path = virtual_host
        self.hostname = hostname
        self.port = port
        #
        # Optionally, allow a single client connection to be used to reduce
        # the connection load on Consul by adding a "one_client=1" parameter
        # to the URL.
        #
        if params.get('one_client', None):
            self.one_client = self.client()

    def client(self):
        return self.one_client or consul.Consul(host=self.hostname,
                                                port=self.port,
                                                consistency=self.consistency)

    def _key_to_consul_key(self, key):
        key = bytes_to_str(key)
        return key if self.path is None else f'{self.path}/{key}'

    def get(self, key):
        key = self._key_to_consul_key(key)
        logger.debug('Trying to fetch key %s from Consul', key)
        try:
            _, data = self.client().kv.get(key)
            return data['Value']
        except TypeError:
            pass

    def mget(self, keys):
        for key in keys:
            yield self.get(key)

    def set(self, key, value):
        """Set a key in Consul.

        Before creating the key, a session with a TTL is created in Consul.
        The key is then written referencing that session's ID, so when the
        session expires Consul removes the key and results auto-expire
        from the K/V store.
        """
        session_name = bytes_to_str(key)

        key = self._key_to_consul_key(key)

        logger.debug('Trying to create Consul session %s with TTL %d',
                     session_name, self.expires)
        client = self.client()
        session_id = client.session.create(name=session_name,
                                           behavior='delete',
                                           ttl=self.expires)
        logger.debug('Created Consul session %s', session_id)

        logger.debug('Writing key %s to Consul', key)
        return client.kv.put(key=key, value=value, acquire=session_id)

    def delete(self, key):
        key = self._key_to_consul_key(key)
        logger.debug('Removing key %s from Consul', key)
        return self.client().kv.delete(key)
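Usage sketch (placeholders): _init_from_params() above takes the host, port and K/V path prefix from the backend URL, and the optional one_client=1 query parameter makes all operations reuse a single Consul client. The result expiry feeds the Consul session TTL used in set().

from celery import Celery

# 'celery_results' becomes self.path (the K/V prefix); one_client=1 is optional.
app = Celery('tasks', backend='consul://localhost:8500/celery_results?one_client=1')

# Results are stored under TTL sessions, so an expiry should be configured:
app.conf.result_expires = 3600  # seconds, used as the session TTL in set() above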
@@ -0,0 +1,218 @@
"""The CosmosDB/SQL backend for Celery (experimental)."""
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger

from .base import KeyValueStoreBackend

try:
    import pydocumentdb
    from pydocumentdb.document_client import DocumentClient
    from pydocumentdb.documents import ConnectionPolicy, ConsistencyLevel, PartitionKind
    from pydocumentdb.errors import HTTPFailure
    from pydocumentdb.retry_options import RetryOptions
except ImportError:
    pydocumentdb = DocumentClient = ConsistencyLevel = PartitionKind = \
        HTTPFailure = ConnectionPolicy = RetryOptions = None

__all__ = ("CosmosDBSQLBackend",)


ERROR_NOT_FOUND = 404
ERROR_EXISTS = 409

LOGGER = get_logger(__name__)


class CosmosDBSQLBackend(KeyValueStoreBackend):
    """CosmosDB/SQL backend for Celery."""

    def __init__(self,
                 url=None,
                 database_name=None,
                 collection_name=None,
                 consistency_level=None,
                 max_retry_attempts=None,
                 max_retry_wait_time=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        if pydocumentdb is None:
            raise ImproperlyConfigured(
                "You need to install the pydocumentdb library to use the "
                "CosmosDB backend.")

        conf = self.app.conf

        self._endpoint, self._key = self._parse_url(url)

        self._database_name = (
            database_name or
            conf["cosmosdbsql_database_name"])

        self._collection_name = (
            collection_name or
            conf["cosmosdbsql_collection_name"])

        try:
            self._consistency_level = getattr(
                ConsistencyLevel,
                consistency_level or
                conf["cosmosdbsql_consistency_level"])
        except AttributeError:
            raise ImproperlyConfigured("Unknown CosmosDB consistency level")

        self._max_retry_attempts = (
            max_retry_attempts or
            conf["cosmosdbsql_max_retry_attempts"])

        self._max_retry_wait_time = (
            max_retry_wait_time or
            conf["cosmosdbsql_max_retry_wait_time"])

    @classmethod
    def _parse_url(cls, url):
        _, host, port, _, password, _, _ = _parse_url(url)

        if not host or not password:
            raise ImproperlyConfigured("Invalid URL")

        if not port:
            port = 443

        scheme = "https" if port == 443 else "http"
        endpoint = f"{scheme}://{host}:{port}"
        return endpoint, password

    @cached_property
    def _client(self):
        """Return the CosmosDB/SQL client.

        If this is the first call to the property, the client is created and
        the database and collection are initialized if they don't yet exist.

        """
        connection_policy = ConnectionPolicy()
        connection_policy.RetryOptions = RetryOptions(
            max_retry_attempt_count=self._max_retry_attempts,
            max_wait_time_in_seconds=self._max_retry_wait_time)

        client = DocumentClient(
            self._endpoint,
            {"masterKey": self._key},
            connection_policy=connection_policy,
            consistency_level=self._consistency_level)

        self._create_database_if_not_exists(client)
        self._create_collection_if_not_exists(client)

        return client

    def _create_database_if_not_exists(self, client):
        try:
            client.CreateDatabase({"id": self._database_name})
        except HTTPFailure as ex:
            if ex.status_code != ERROR_EXISTS:
                raise
        else:
            LOGGER.info("Created CosmosDB database %s",
                        self._database_name)

    def _create_collection_if_not_exists(self, client):
        try:
            client.CreateCollection(
                self._database_link,
                {"id": self._collection_name,
                 "partitionKey": {"paths": ["/id"],
                                  "kind": PartitionKind.Hash}})
        except HTTPFailure as ex:
            if ex.status_code != ERROR_EXISTS:
                raise
        else:
            LOGGER.info("Created CosmosDB collection %s/%s",
                        self._database_name, self._collection_name)

    @cached_property
    def _database_link(self):
        return "dbs/" + self._database_name

    @cached_property
    def _collection_link(self):
        return self._database_link + "/colls/" + self._collection_name

    def _get_document_link(self, key):
        return self._collection_link + "/docs/" + key

    @classmethod
    def _get_partition_key(cls, key):
        if not key or key.isspace():
            raise ValueError("Key cannot be none, empty or whitespace.")

        return {"partitionKey": key}

    def get(self, key):
        """Read the value stored at the given key.

        Args:
            key: The key for which to read the value.

        """
        key = bytes_to_str(key)
        LOGGER.debug("Getting CosmosDB document %s/%s/%s",
                     self._database_name, self._collection_name, key)

        try:
            document = self._client.ReadDocument(
                self._get_document_link(key),
                self._get_partition_key(key))
        except HTTPFailure as ex:
            if ex.status_code != ERROR_NOT_FOUND:
                raise
            return None
        else:
            return document.get("value")

    def set(self, key, value):
        """Store a value for a given key.

        Args:
            key: The key at which to store the value.
            value: The value to store.

        """
        key = bytes_to_str(key)
        LOGGER.debug("Creating CosmosDB document %s/%s/%s",
                     self._database_name, self._collection_name, key)

        self._client.CreateDocument(
            self._collection_link,
            {"id": key, "value": value},
            self._get_partition_key(key))

    def mget(self, keys):
        """Read all the values for the provided keys.

        Args:
            keys: The list of keys to read.

        """
        return [self.get(key) for key in keys]

    def delete(self, key):
        """Delete the value at a given key.

        Args:
            key: The key of the value to delete.

        """
        key = bytes_to_str(key)
        LOGGER.debug("Deleting CosmosDB document %s/%s/%s",
                     self._database_name, self._collection_name, key)

        self._client.DeleteDocument(
            self._get_document_link(key),
            self._get_partition_key(key))
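Usage sketch (illustrative placeholders): _parse_url() above takes the account host from the URL host and the account key from the URL password; the remaining settings are read from app.conf. The account name, key and database/collection names below are invented.

from celery import Celery

# Placeholder account host and key:
app = Celery('tasks', backend='cosmosdbsql://:ACCOUNT_KEY@myaccount.documents.azure.com')
app.conf.update(
    cosmosdbsql_database_name='celerydb',
    cosmosdbsql_collection_name='celerycol',
    cosmosdbsql_consistency_level='Session',   # must be a pydocumentdb ConsistencyLevel name
    cosmosdbsql_max_retry_attempts=9,
    cosmosdbsql_max_retry_wait_time=30,
)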
@@ -0,0 +1,114 @@
"""Couchbase result store backend."""

from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured

from .base import KeyValueStoreBackend

try:
    from couchbase.auth import PasswordAuthenticator
    from couchbase.cluster import Cluster
except ImportError:
    Cluster = PasswordAuthenticator = None

try:
    from couchbase_core._libcouchbase import FMT_AUTO
except ImportError:
    FMT_AUTO = None

__all__ = ('CouchbaseBackend',)


class CouchbaseBackend(KeyValueStoreBackend):
    """Couchbase backend.

    Raises:
        celery.exceptions.ImproperlyConfigured:
            if module :pypi:`couchbase` is not available.
    """

    bucket = 'default'
    host = 'localhost'
    port = 8091
    username = None
    password = None
    quiet = False
    supports_autoexpire = True

    timeout = 2.5

    # Use str as couchbase key not bytes
    key_t = str

    def __init__(self, url=None, *args, **kwargs):
        kwargs.setdefault('expires_type', int)
        super().__init__(*args, **kwargs)
        self.url = url

        if Cluster is None:
            raise ImproperlyConfigured(
                'You need to install the couchbase library to use the '
                'Couchbase backend.',
            )

        uhost = uport = uname = upass = ubucket = None
        if url:
            _, uhost, uport, uname, upass, ubucket, _ = _parse_url(url)
            ubucket = ubucket.strip('/') if ubucket else None

        config = self.app.conf.get('couchbase_backend_settings', None)
        if config is not None:
            if not isinstance(config, dict):
                raise ImproperlyConfigured(
                    'Couchbase backend settings should be grouped in a dict',
                )
        else:
            config = {}

        self.host = uhost or config.get('host', self.host)
        self.port = int(uport or config.get('port', self.port))
        self.bucket = ubucket or config.get('bucket', self.bucket)
        self.username = uname or config.get('username', self.username)
        self.password = upass or config.get('password', self.password)

        self._connection = None

    def _get_connection(self):
        """Connect to the Couchbase server."""
        if self._connection is None:
            if self.host and self.port:
                uri = f"couchbase://{self.host}:{self.port}"
            else:
                uri = f"couchbase://{self.host}"
            if self.username and self.password:
                opt = PasswordAuthenticator(self.username, self.password)
            else:
                opt = None

            cluster = Cluster(uri, opt)

            bucket = cluster.bucket(self.bucket)

            self._connection = bucket.default_collection()
        return self._connection

    @property
    def connection(self):
        return self._get_connection()

    def get(self, key):
        return self.connection.get(key).content

    def set(self, key, value):
        # Since 4.0.0 value is JSONType in couchbase lib, so parameter format isn't needed
        if FMT_AUTO is not None:
            self.connection.upsert(key, value, ttl=self.expires, format=FMT_AUTO)
        else:
            self.connection.upsert(key, value, ttl=self.expires)

    def mget(self, keys):
        return self.connection.get_multi(keys)

    def delete(self, key):
        self.connection.remove(key)
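Usage sketch (placeholders): the URL parsed in __init__ above carries host, port, credentials and bucket; the same values may instead be grouped in the couchbase_backend_settings dict. Hosts and credentials below are illustrative.

from celery import Celery

app = Celery('tasks', backend='couchbase://user:secret@localhost:8091/default')

# Equivalent dict form, read in __init__ above:
app.conf.couchbase_backend_settings = {
    'host': 'localhost',
    'port': 8091,
    'bucket': 'default',
    'username': 'user',     # placeholder
    'password': 'secret',   # placeholder
}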
@@ -0,0 +1,99 @@
"""CouchDB result store backend."""
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured

from .base import KeyValueStoreBackend

try:
    import pycouchdb
except ImportError:
    pycouchdb = None

__all__ = ('CouchBackend',)

ERR_LIB_MISSING = """\
You need to install the pycouchdb library to use the CouchDB result backend\
"""


class CouchBackend(KeyValueStoreBackend):
    """CouchDB backend.

    Raises:
        celery.exceptions.ImproperlyConfigured:
            if module :pypi:`pycouchdb` is not available.
    """

    container = 'default'
    scheme = 'http'
    host = 'localhost'
    port = 5984
    username = None
    password = None

    def __init__(self, url=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.url = url

        if pycouchdb is None:
            raise ImproperlyConfigured(ERR_LIB_MISSING)

        uscheme = uhost = uport = uname = upass = ucontainer = None
        if url:
            _, uhost, uport, uname, upass, ucontainer, _ = _parse_url(url)
            ucontainer = ucontainer.strip('/') if ucontainer else None

        self.scheme = uscheme or self.scheme
        self.host = uhost or self.host
        self.port = int(uport or self.port)
        self.container = ucontainer or self.container
        self.username = uname or self.username
        self.password = upass or self.password

        self._connection = None

    def _get_connection(self):
        """Connect to the CouchDB server."""
        if self.username and self.password:
            conn_string = f'{self.scheme}://{self.username}:{self.password}@{self.host}:{self.port}'
            server = pycouchdb.Server(conn_string, authmethod='basic')
        else:
            conn_string = f'{self.scheme}://{self.host}:{self.port}'
            server = pycouchdb.Server(conn_string)

        try:
            return server.database(self.container)
        except pycouchdb.exceptions.NotFound:
            return server.create(self.container)

    @property
    def connection(self):
        if self._connection is None:
            self._connection = self._get_connection()
        return self._connection

    def get(self, key):
        key = bytes_to_str(key)
        try:
            return self.connection.get(key)['value']
        except pycouchdb.exceptions.NotFound:
            return None

    def set(self, key, value):
        key = bytes_to_str(key)
        data = {'_id': key, 'value': value}
        try:
            self.connection.save(data)
        except pycouchdb.exceptions.Conflict:
            # document already exists, update it
            data = self.connection.get(key)
            data['value'] = value
            self.connection.save(data)

    def mget(self, keys):
        return [self.get(key) for key in keys]

    def delete(self, key):
        self.connection.delete(key)
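Usage sketch (placeholders): the URL parsed in __init__ above supplies host, port, credentials and the container (database) name, and _get_connection() creates the database on first use if it does not exist. The credentials and database name below are illustrative.

from celery import Celery

# The trailing path segment is the CouchDB database ("container") used above.
app = Celery('tasks', backend='couchdb://admin:secret@localhost:5984/celery_results')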
@@ -0,0 +1,234 @@
"""SQLAlchemy result store backend."""
import logging
from contextlib import contextmanager

from vine.utils import wraps

from celery import states
from celery.backends.base import BaseBackend
from celery.exceptions import ImproperlyConfigured
from celery.utils.time import maybe_timedelta

from .models import Task, TaskExtended, TaskSet
from .session import SessionManager

try:
    from sqlalchemy.exc import DatabaseError, InvalidRequestError
    from sqlalchemy.orm.exc import StaleDataError
except ImportError:
    raise ImproperlyConfigured(
        'The database result backend requires SQLAlchemy to be installed. '
        'See https://pypi.org/project/SQLAlchemy/')

logger = logging.getLogger(__name__)

__all__ = ('DatabaseBackend',)


@contextmanager
def session_cleanup(session):
    try:
        yield
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def retry(fun):

    @wraps(fun)
    def _inner(*args, **kwargs):
        max_retries = kwargs.pop('max_retries', 3)

        for retries in range(max_retries):
            try:
                return fun(*args, **kwargs)
            except (DatabaseError, InvalidRequestError, StaleDataError):
                logger.warning(
                    'Failed operation %s. Retrying %s more times.',
                    fun.__name__, max_retries - retries - 1,
                    exc_info=True)
                if retries + 1 >= max_retries:
                    raise

    return _inner


class DatabaseBackend(BaseBackend):
    """The database result backend."""

    # ResultSet.iterate should sleep this much between each poll,
    # to not bombard the database with queries.
    subpolling_interval = 0.5

    task_cls = Task
    taskset_cls = TaskSet

    def __init__(self, dburi=None, engine_options=None, url=None, **kwargs):
        # The `url` argument was added later and is used by
        # the app to set backend by url (celery.app.backends.by_url)
        super().__init__(expires_type=maybe_timedelta,
                         url=url, **kwargs)
        conf = self.app.conf

        if self.extended_result:
            self.task_cls = TaskExtended

        self.url = url or dburi or conf.database_url
        self.engine_options = dict(
            engine_options or {},
            **conf.database_engine_options or {})
        self.short_lived_sessions = kwargs.get(
            'short_lived_sessions',
            conf.database_short_lived_sessions)

        schemas = conf.database_table_schemas or {}
        tablenames = conf.database_table_names or {}
        self.task_cls.configure(
            schema=schemas.get('task'),
            name=tablenames.get('task'))
        self.taskset_cls.configure(
            schema=schemas.get('group'),
            name=tablenames.get('group'))

        if not self.url:
            raise ImproperlyConfigured(
                'Missing connection string! Do you have the'
                ' database_url setting set to a real value?')

        self.session_manager = SessionManager()

        create_tables_at_setup = conf.database_create_tables_at_setup
        if create_tables_at_setup is True:
            self._create_tables()

    @property
    def extended_result(self):
        return self.app.conf.find_value_for_key('extended', 'result')

    def _create_tables(self):
        """Create the task and taskset tables."""
        self.ResultSession()

    def ResultSession(self, session_manager=None):
        if session_manager is None:
            session_manager = self.session_manager
        return session_manager.session_factory(
            dburi=self.url,
            short_lived_sessions=self.short_lived_sessions,
            **self.engine_options)

    @retry
    def _store_result(self, task_id, result, state, traceback=None,
                      request=None, **kwargs):
        """Store return value and state of an executed task."""
        session = self.ResultSession()
        with session_cleanup(session):
            task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id))
            task = task and task[0]
            if not task:
                task = self.task_cls(task_id)
                task.task_id = task_id
                session.add(task)
                session.flush()

            self._update_result(task, result, state, traceback=traceback, request=request)
            session.commit()

    def _update_result(self, task, result, state, traceback=None,
                       request=None):

        meta = self._get_result_meta(result=result, state=state,
                                     traceback=traceback, request=request,
                                     format_date=False, encode=True)

        # Exclude the primary key id and task_id columns
        # as we should not set it None
        columns = [column.name for column in self.task_cls.__table__.columns
                   if column.name not in {'id', 'task_id'}]

        # Iterate through the columns name of the table
        # to set the value from meta.
        # If the value is not present in meta, set None
        for column in columns:
            value = meta.get(column)
            setattr(task, column, value)

    @retry
    def _get_task_meta_for(self, task_id):
        """Get task meta-data for a task by id."""
        session = self.ResultSession()
        with session_cleanup(session):
            task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id))
            task = task and task[0]
            if not task:
                task = self.task_cls(task_id)
                task.status = states.PENDING
                task.result = None
            data = task.to_dict()
            if data.get('args', None) is not None:
                data['args'] = self.decode(data['args'])
            if data.get('kwargs', None) is not None:
                data['kwargs'] = self.decode(data['kwargs'])
            return self.meta_from_decoded(data)

    @retry
    def _save_group(self, group_id, result):
        """Store the result of an executed group."""
        session = self.ResultSession()
        with session_cleanup(session):
|
||||
group = self.taskset_cls(group_id, result)
|
||||
session.add(group)
|
||||
session.flush()
|
||||
session.commit()
|
||||
return result
|
||||
|
||||
@retry
|
||||
def _restore_group(self, group_id):
|
||||
"""Get meta-data for group by id."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
group = session.query(self.taskset_cls).filter(
|
||||
self.taskset_cls.taskset_id == group_id).first()
|
||||
if group:
|
||||
return group.to_dict()
|
||||
|
||||
@retry
|
||||
def _delete_group(self, group_id):
|
||||
"""Delete meta-data for group by id."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
session.query(self.taskset_cls).filter(
|
||||
self.taskset_cls.taskset_id == group_id).delete()
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
@retry
|
||||
def _forget(self, task_id):
|
||||
"""Forget about result."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
session.query(self.task_cls).filter(self.task_cls.task_id == task_id).delete()
|
||||
session.commit()
|
||||
|
||||
def cleanup(self):
|
||||
"""Delete expired meta-data."""
|
||||
session = self.ResultSession()
|
||||
expires = self.expires
|
||||
now = self.app.now()
|
||||
with session_cleanup(session):
|
||||
session.query(self.task_cls).filter(
|
||||
self.task_cls.date_done < (now - expires)).delete()
|
||||
session.query(self.taskset_cls).filter(
|
||||
self.taskset_cls.date_done < (now - expires)).delete()
|
||||
session.commit()
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
kwargs.update(
|
||||
{'dburi': self.url,
|
||||
'expires': self.expires,
|
||||
'engine_options': self.engine_options})
|
||||
return super().__reduce__(args, kwargs)
|
||||
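# Hedged configuration sketch for the DatabaseBackend defined above. The
# 'db+' prefix and the Postgres DSN are assumptions for illustration; the
# other setting names mirror the conf keys read in __init__
# (database_engine_options, database_short_lived_sessions,
# database_table_names, database_create_tables_at_setup).
from celery import Celery

app = Celery('tasks', broker='amqp://localhost')  # assumed broker
app.conf.update(
    result_backend='db+postgresql://scott:tiger@localhost/celery_results',
    database_engine_options={'echo': False},
    database_short_lived_sessions=True,   # new session per operation
    database_table_names={'task': 'myapp_taskmeta',
                          'group': 'myapp_tasksetmeta'},
    database_create_tables_at_setup=True,
)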
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,108 @@
|
||||
"""Database models used by the SQLAlchemy result store backend."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.types import PickleType
|
||||
|
||||
from celery import states
|
||||
|
||||
from .session import ResultModelBase
|
||||
|
||||
__all__ = ('Task', 'TaskExtended', 'TaskSet')
|
||||
|
||||
|
||||
class Task(ResultModelBase):
|
||||
"""Task result/status."""
|
||||
|
||||
__tablename__ = 'celery_taskmeta'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = sa.Column(sa.Integer, sa.Sequence('task_id_sequence'),
|
||||
primary_key=True, autoincrement=True)
|
||||
task_id = sa.Column(sa.String(155), unique=True)
|
||||
status = sa.Column(sa.String(50), default=states.PENDING)
|
||||
result = sa.Column(PickleType, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
|
||||
onupdate=datetime.now(timezone.utc), nullable=True)
|
||||
traceback = sa.Column(sa.Text, nullable=True)
|
||||
|
||||
def __init__(self, task_id):
|
||||
self.task_id = task_id
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'task_id': self.task_id,
|
||||
'status': self.status,
|
||||
'result': self.result,
|
||||
'traceback': self.traceback,
|
||||
'date_done': self.date_done,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return '<Task {0.task_id} state: {0.status}>'.format(self)
|
||||
|
||||
@classmethod
|
||||
def configure(cls, schema=None, name=None):
|
||||
cls.__table__.schema = schema
|
||||
cls.id.default.schema = schema
|
||||
cls.__table__.name = name or cls.__tablename__
|
||||
|
||||
|
||||
class TaskExtended(Task):
|
||||
"""For the extend result."""
|
||||
|
||||
__tablename__ = 'celery_taskmeta'
|
||||
__table_args__ = {'sqlite_autoincrement': True, 'extend_existing': True}
|
||||
|
||||
name = sa.Column(sa.String(155), nullable=True)
|
||||
args = sa.Column(sa.LargeBinary, nullable=True)
|
||||
kwargs = sa.Column(sa.LargeBinary, nullable=True)
|
||||
worker = sa.Column(sa.String(155), nullable=True)
|
||||
retries = sa.Column(sa.Integer, nullable=True)
|
||||
queue = sa.Column(sa.String(155), nullable=True)
|
||||
|
||||
def to_dict(self):
|
||||
task_dict = super().to_dict()
|
||||
task_dict.update({
|
||||
'name': self.name,
|
||||
'args': self.args,
|
||||
'kwargs': self.kwargs,
|
||||
'worker': self.worker,
|
||||
'retries': self.retries,
|
||||
'queue': self.queue,
|
||||
})
|
||||
return task_dict
|
||||
|
||||
|
||||
class TaskSet(ResultModelBase):
|
||||
"""TaskSet result."""
|
||||
|
||||
__tablename__ = 'celery_tasksetmeta'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = sa.Column(sa.Integer, sa.Sequence('taskset_id_sequence'),
|
||||
autoincrement=True, primary_key=True)
|
||||
taskset_id = sa.Column(sa.String(155), unique=True)
|
||||
result = sa.Column(PickleType, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
|
||||
nullable=True)
|
||||
|
||||
def __init__(self, taskset_id, result):
|
||||
self.taskset_id = taskset_id
|
||||
self.result = result
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'taskset_id': self.taskset_id,
|
||||
'result': self.result,
|
||||
'date_done': self.date_done,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return f'<TaskSet: {self.taskset_id}>'
|
||||
|
||||
@classmethod
|
||||
def configure(cls, schema=None, name=None):
|
||||
cls.__table__.schema = schema
|
||||
cls.id.default.schema = schema
|
||||
cls.__table__.name = name or cls.__tablename__
|
||||
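# Hedged sketch of the configure() hooks above, which is how DatabaseBackend
# points the declarative models at custom schemas and table names (see the
# database_table_schemas / database_table_names handling earlier). The
# 'results' schema and table names below are illustrative only.
from celery.backends.database.models import Task, TaskSet

Task.configure(schema='results', name='myapp_taskmeta')
TaskSet.configure(schema='results', name='myapp_tasksetmeta')

print(Task.__table__.fullname)     # 'results.myapp_taskmeta'
print(TaskSet.__table__.fullname)  # 'results.myapp_tasksetmeta'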
@@ -0,0 +1,89 @@
"""SQLAlchemy session."""
import time

from kombu.utils.compat import register_after_fork
from sqlalchemy import create_engine
from sqlalchemy.exc import DatabaseError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

from celery.utils.time import get_exponential_backoff_interval

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    # TODO: Remove this once we drop support for SQLAlchemy < 1.4.
    from sqlalchemy.ext.declarative import declarative_base

ResultModelBase = declarative_base()

__all__ = ('SessionManager',)

PREPARE_MODELS_MAX_RETRIES = 10


def _after_fork_cleanup_session(session):
    session._after_fork()


class SessionManager:
    """Manage SQLAlchemy sessions."""

    def __init__(self):
        self._engines = {}
        self._sessions = {}
        self.forked = False
        self.prepared = False
        if register_after_fork is not None:
            register_after_fork(self, _after_fork_cleanup_session)

    def _after_fork(self):
        self.forked = True

    def get_engine(self, dburi, **kwargs):
        if self.forked:
            try:
                return self._engines[dburi]
            except KeyError:
                engine = self._engines[dburi] = create_engine(dburi, **kwargs)
                return engine
        else:
            kwargs = {k: v for k, v in kwargs.items() if
                      not k.startswith('pool')}
            return create_engine(dburi, poolclass=NullPool, **kwargs)

    def create_session(self, dburi, short_lived_sessions=False, **kwargs):
        engine = self.get_engine(dburi, **kwargs)
        if self.forked:
            if short_lived_sessions or dburi not in self._sessions:
                self._sessions[dburi] = sessionmaker(bind=engine)
            return engine, self._sessions[dburi]
        return engine, sessionmaker(bind=engine)

    def prepare_models(self, engine):
        if not self.prepared:
            # SQLAlchemy will check if the items exist before trying to
            # create them, which is a race condition. If it raises an error
            # in one iteration, the next may pass all the existence checks
            # and the call will succeed.
            retries = 0
            while True:
                try:
                    ResultModelBase.metadata.create_all(engine)
                except DatabaseError:
                    if retries < PREPARE_MODELS_MAX_RETRIES:
                        sleep_amount_ms = get_exponential_backoff_interval(
                            10, retries, 1000, True
                        )
                        time.sleep(sleep_amount_ms / 1000)
                        retries += 1
                    else:
                        raise
                else:
                    break
            self.prepared = True

    def session_factory(self, dburi, **kwargs):
        engine, session = self.create_session(dburi, **kwargs)
        self.prepare_models(engine)
        return session()
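# Hedged sketch of how DatabaseBackend drives SessionManager above: it asks
# for a session bound to its dburi and relies on prepare_models() to create
# the result tables on first use. The SQLite URL is a placeholder.
from celery.backends.database.models import Task  # registers the result tables
from celery.backends.database.session import SessionManager

manager = SessionManager()
session = manager.session_factory(
    dburi='sqlite:///results.db',
    short_lived_sessions=False,
)
try:
    print(session.query(Task).count())  # number of stored task results
finally:
    session.close()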
@@ -0,0 +1,556 @@
|
||||
"""AWS DynamoDB result store backend."""
|
||||
from collections import namedtuple
|
||||
from ipaddress import ip_address
|
||||
from time import sleep, time
|
||||
from typing import Any, Dict
|
||||
|
||||
from kombu.utils.url import _parse_url as parse_url
|
||||
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
boto3 = ClientError = None
|
||||
|
||||
__all__ = ('DynamoDBBackend',)
|
||||
|
||||
|
||||
# Helper class that describes a DynamoDB attribute
|
||||
DynamoDBAttribute = namedtuple('DynamoDBAttribute', ('name', 'data_type'))
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DynamoDBBackend(KeyValueStoreBackend):
|
||||
"""AWS DynamoDB result backend.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.ImproperlyConfigured:
|
||||
if module :pypi:`boto3` is not available.
|
||||
"""
|
||||
|
||||
#: default DynamoDB table name (`default`)
|
||||
table_name = 'celery'
|
||||
|
||||
#: Read Provisioned Throughput (`default`)
|
||||
read_capacity_units = 1
|
||||
|
||||
#: Write Provisioned Throughput (`default`)
|
||||
write_capacity_units = 1
|
||||
|
||||
#: AWS region (`default`)
|
||||
aws_region = None
|
||||
|
||||
#: The endpoint URL that is passed to boto3 (local DynamoDB) (`default`)
|
||||
endpoint_url = None
|
||||
|
||||
#: Item time-to-live in seconds (`default`)
|
||||
time_to_live_seconds = None
|
||||
|
||||
# DynamoDB supports Time to Live as an auto-expiry mechanism.
|
||||
supports_autoexpire = True
|
||||
|
||||
_key_field = DynamoDBAttribute(name='id', data_type='S')
|
||||
# Each record has either a value field or a count field
|
||||
_value_field = DynamoDBAttribute(name='result', data_type='B')
|
||||
_count_filed = DynamoDBAttribute(name="chord_count", data_type='N')
|
||||
_timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
|
||||
_ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
|
||||
_available_fields = None
|
||||
|
||||
implements_incr = True
|
||||
|
||||
def __init__(self, url=None, table_name=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.url = url
|
||||
self.table_name = table_name or self.table_name
|
||||
|
||||
if not boto3:
|
||||
raise ImproperlyConfigured(
|
||||
'You need to install the boto3 library to use the '
|
||||
'DynamoDB backend.')
|
||||
|
||||
aws_credentials_given = False
|
||||
aws_access_key_id = None
|
||||
aws_secret_access_key = None
|
||||
|
||||
if url is not None:
|
||||
scheme, region, port, username, password, table, query = \
|
||||
parse_url(url)
|
||||
|
||||
aws_access_key_id = username
|
||||
aws_secret_access_key = password
|
||||
|
||||
access_key_given = aws_access_key_id is not None
|
||||
secret_key_given = aws_secret_access_key is not None
|
||||
|
||||
if access_key_given != secret_key_given:
|
||||
raise ImproperlyConfigured(
|
||||
'You need to specify both the Access Key ID '
|
||||
'and Secret.')
|
||||
|
||||
aws_credentials_given = access_key_given
|
||||
|
||||
if region == 'localhost' or DynamoDBBackend._is_valid_ip(region):
|
||||
# We are using the downloadable, local version of DynamoDB
|
||||
self.endpoint_url = f'http://{region}:{port}'
|
||||
self.aws_region = 'us-east-1'
|
||||
logger.warning(
|
||||
'Using local-only DynamoDB endpoint URL: {}'.format(
|
||||
self.endpoint_url
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.aws_region = region
|
||||
|
||||
# If endpoint_url is explicitly set use it instead
|
||||
_get = self.app.conf.get
|
||||
config_endpoint_url = _get('dynamodb_endpoint_url')
|
||||
if config_endpoint_url:
|
||||
self.endpoint_url = config_endpoint_url
|
||||
|
||||
self.read_capacity_units = int(
|
||||
query.get(
|
||||
'read',
|
||||
self.read_capacity_units
|
||||
)
|
||||
)
|
||||
self.write_capacity_units = int(
|
||||
query.get(
|
||||
'write',
|
||||
self.write_capacity_units
|
||||
)
|
||||
)
|
||||
|
||||
ttl = query.get('ttl_seconds', self.time_to_live_seconds)
|
||||
if ttl:
|
||||
try:
|
||||
self.time_to_live_seconds = int(ttl)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f'TTL must be a number; got "{ttl}"',
|
||||
exc_info=e
|
||||
)
|
||||
raise e
|
||||
|
||||
self.table_name = table or self.table_name
|
||||
|
||||
self._available_fields = (
|
||||
self._key_field,
|
||||
self._value_field,
|
||||
self._timestamp_field
|
||||
)
|
||||
|
||||
self._client = None
|
||||
if aws_credentials_given:
|
||||
self._get_client(
|
||||
access_key_id=aws_access_key_id,
|
||||
secret_access_key=aws_secret_access_key
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_ip(ip):
|
||||
try:
|
||||
ip_address(ip)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _get_client(self, access_key_id=None, secret_access_key=None):
|
||||
"""Get client connection."""
|
||||
if self._client is None:
|
||||
client_parameters = {
|
||||
'region_name': self.aws_region
|
||||
}
|
||||
if access_key_id is not None:
|
||||
client_parameters.update({
|
||||
'aws_access_key_id': access_key_id,
|
||||
'aws_secret_access_key': secret_access_key
|
||||
})
|
||||
|
||||
if self.endpoint_url is not None:
|
||||
client_parameters['endpoint_url'] = self.endpoint_url
|
||||
|
||||
self._client = boto3.client(
|
||||
'dynamodb',
|
||||
**client_parameters
|
||||
)
|
||||
self._get_or_create_table()
|
||||
|
||||
if self._has_ttl() is not None:
|
||||
self._validate_ttl_methods()
|
||||
self._set_table_ttl()
|
||||
|
||||
return self._client
|
||||
|
||||
def _get_table_schema(self):
|
||||
"""Get the boto3 structure describing the DynamoDB table schema."""
|
||||
return {
|
||||
'AttributeDefinitions': [
|
||||
{
|
||||
'AttributeName': self._key_field.name,
|
||||
'AttributeType': self._key_field.data_type
|
||||
}
|
||||
],
|
||||
'TableName': self.table_name,
|
||||
'KeySchema': [
|
||||
{
|
||||
'AttributeName': self._key_field.name,
|
||||
'KeyType': 'HASH'
|
||||
}
|
||||
],
|
||||
'ProvisionedThroughput': {
|
||||
'ReadCapacityUnits': self.read_capacity_units,
|
||||
'WriteCapacityUnits': self.write_capacity_units
|
||||
}
|
||||
}
|
||||
|
||||
def _get_or_create_table(self):
|
||||
"""Create table if not exists, otherwise return the description."""
|
||||
table_schema = self._get_table_schema()
|
||||
try:
|
||||
return self._client.describe_table(TableName=self.table_name)
|
||||
except ClientError as e:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
|
||||
if error_code == 'ResourceNotFoundException':
|
||||
table_description = self._client.create_table(**table_schema)
|
||||
logger.info(
|
||||
'DynamoDB Table {} did not exist, creating.'.format(
|
||||
self.table_name
|
||||
)
|
||||
)
|
||||
# In case we created the table, wait until it becomes available.
|
||||
self._wait_for_table_status('ACTIVE')
|
||||
logger.info(
|
||||
'DynamoDB Table {} is now available.'.format(
|
||||
self.table_name
|
||||
)
|
||||
)
|
||||
return table_description
|
||||
else:
|
||||
raise e
|
||||
|
||||
def _has_ttl(self):
|
||||
"""Return the desired Time to Live config.
|
||||
|
||||
- True: Enable TTL on the table; use expiry.
|
||||
- False: Disable TTL on the table; don't use expiry.
|
||||
- None: Ignore TTL on the table; don't use expiry.
|
||||
"""
|
||||
return None if self.time_to_live_seconds is None \
|
||||
else self.time_to_live_seconds >= 0
|
||||
|
||||
def _validate_ttl_methods(self):
|
||||
"""Verify boto support for the DynamoDB Time to Live methods."""
|
||||
# Required TTL methods.
|
||||
required_methods = (
|
||||
'update_time_to_live',
|
||||
'describe_time_to_live',
|
||||
)
|
||||
|
||||
# Find missing methods.
|
||||
missing_methods = []
|
||||
for method in list(required_methods):
|
||||
if not hasattr(self._client, method):
|
||||
missing_methods.append(method)
|
||||
|
||||
if missing_methods:
|
||||
logger.error(
|
||||
(
|
||||
'boto3 method(s) {methods} not found; ensure that '
|
||||
'boto3>=1.9.178 and botocore>=1.12.178 are installed'
|
||||
).format(
|
||||
methods=','.join(missing_methods)
|
||||
)
|
||||
)
|
||||
raise AttributeError(
|
||||
'boto3 method(s) {methods} not found'.format(
|
||||
methods=','.join(missing_methods)
|
||||
)
|
||||
)
|
||||
|
||||
def _get_ttl_specification(self, ttl_attr_name):
|
||||
"""Get the boto3 structure describing the DynamoDB TTL specification."""
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'TimeToLiveSpecification': {
|
||||
'Enabled': self._has_ttl(),
|
||||
'AttributeName': ttl_attr_name
|
||||
}
|
||||
}
|
||||
|
||||
def _get_table_ttl_description(self):
|
||||
# Get the current TTL description.
|
||||
try:
|
||||
description = self._client.describe_time_to_live(
|
||||
TableName=self.table_name
|
||||
)
|
||||
except ClientError as e:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', 'Unknown')
|
||||
logger.error((
|
||||
'Error describing Time to Live on DynamoDB table {table}: '
|
||||
'{code}: {message}'
|
||||
).format(
|
||||
table=self.table_name,
|
||||
code=error_code,
|
||||
message=error_message,
|
||||
))
|
||||
raise e
|
||||
|
||||
return description
|
||||
|
||||
def _set_table_ttl(self):
|
||||
"""Enable or disable Time to Live on the table."""
|
||||
# Get the table TTL description, and return early when possible.
|
||||
description = self._get_table_ttl_description()
|
||||
status = description['TimeToLiveDescription']['TimeToLiveStatus']
|
||||
if status in ('ENABLED', 'ENABLING'):
|
||||
cur_attr_name = \
|
||||
description['TimeToLiveDescription']['AttributeName']
|
||||
if self._has_ttl():
|
||||
if cur_attr_name == self._ttl_field.name:
|
||||
# We want TTL enabled, and it is currently enabled or being
|
||||
# enabled, and on the correct attribute.
|
||||
logger.debug((
|
||||
'DynamoDB Time to Live is {situation} '
|
||||
'on table {table}'
|
||||
).format(
|
||||
situation='already enabled'
|
||||
if status == 'ENABLED'
|
||||
else 'currently being enabled',
|
||||
table=self.table_name
|
||||
))
|
||||
return description
|
||||
|
||||
elif status in ('DISABLED', 'DISABLING'):
|
||||
if not self._has_ttl():
|
||||
# We want TTL disabled, and it is currently disabled or being
|
||||
# disabled.
|
||||
logger.debug((
|
||||
'DynamoDB Time to Live is {situation} '
|
||||
'on table {table}'
|
||||
).format(
|
||||
situation='already disabled'
|
||||
if status == 'DISABLED'
|
||||
else 'currently being disabled',
|
||||
table=self.table_name
|
||||
))
|
||||
return description
|
||||
|
||||
# The state shouldn't ever have any value beyond the four handled
|
||||
# above, but to ease troubleshooting of potential future changes, emit
|
||||
# a log showing the unknown state.
|
||||
else: # pragma: no cover
|
||||
logger.warning((
|
||||
'Unknown DynamoDB Time to Live status {status} '
|
||||
'on table {table}. Attempting to continue.'
|
||||
).format(
|
||||
status=status,
|
||||
table=self.table_name
|
||||
))
|
||||
|
||||
# At this point, we have one of the following situations:
|
||||
#
|
||||
# We want TTL enabled,
|
||||
#
|
||||
# - and it's currently disabled: Try to enable.
|
||||
#
|
||||
# - and it's being disabled: Try to enable, but this is almost sure to
|
||||
# raise ValidationException with message:
|
||||
#
|
||||
# Time to live has been modified multiple times within a fixed
|
||||
# interval
|
||||
#
|
||||
# - and it's currently enabling or being enabled, but on the wrong
|
||||
# attribute: Try to enable, but this will raise ValidationException
|
||||
# with message:
|
||||
#
|
||||
# TimeToLive is active on a different AttributeName: current
|
||||
# AttributeName is ttlx
|
||||
#
|
||||
# We want TTL disabled,
|
||||
#
|
||||
# - and it's currently enabled: Try to disable.
|
||||
#
|
||||
# - and it's being enabled: Try to disable, but this is almost sure to
|
||||
# raise ValidationException with message:
|
||||
#
|
||||
# Time to live has been modified multiple times within a fixed
|
||||
# interval
|
||||
#
|
||||
attr_name = \
|
||||
cur_attr_name if status == 'ENABLED' else self._ttl_field.name
|
||||
try:
|
||||
specification = self._client.update_time_to_live(
|
||||
**self._get_ttl_specification(
|
||||
ttl_attr_name=attr_name
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
(
|
||||
'DynamoDB table Time to Live updated: '
|
||||
'table={table} enabled={enabled} attribute={attr}'
|
||||
).format(
|
||||
table=self.table_name,
|
||||
enabled=self._has_ttl(),
|
||||
attr=self._ttl_field.name
|
||||
)
|
||||
)
|
||||
return specification
|
||||
except ClientError as e:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', 'Unknown')
|
||||
logger.error((
|
||||
'Error {action} Time to Live on DynamoDB table {table}: '
|
||||
'{code}: {message}'
|
||||
).format(
|
||||
action='enabling' if self._has_ttl() else 'disabling',
|
||||
table=self.table_name,
|
||||
code=error_code,
|
||||
message=error_message,
|
||||
))
|
||||
raise e
|
||||
|
||||
def _wait_for_table_status(self, expected='ACTIVE'):
|
||||
"""Poll for the expected table status."""
|
||||
achieved_state = False
|
||||
while not achieved_state:
|
||||
table_description = self.client.describe_table(
|
||||
TableName=self.table_name
|
||||
)
|
||||
logger.debug(
|
||||
'Waiting for DynamoDB table {} to become {}.'.format(
|
||||
self.table_name,
|
||||
expected
|
||||
)
|
||||
)
|
||||
current_status = table_description['Table']['TableStatus']
|
||||
achieved_state = current_status == expected
|
||||
sleep(1)
|
||||
|
||||
def _prepare_get_request(self, key):
|
||||
"""Construct the item retrieval request parameters."""
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'Key': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _prepare_put_request(self, key, value):
|
||||
"""Construct the item creation request parameters."""
|
||||
timestamp = time()
|
||||
put_request = {
|
||||
'TableName': self.table_name,
|
||||
'Item': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
},
|
||||
self._value_field.name: {
|
||||
self._value_field.data_type: value
|
||||
},
|
||||
self._timestamp_field.name: {
|
||||
self._timestamp_field.data_type: str(timestamp)
|
||||
}
|
||||
}
|
||||
}
|
||||
if self._has_ttl():
|
||||
put_request['Item'].update({
|
||||
self._ttl_field.name: {
|
||||
self._ttl_field.data_type:
|
||||
str(int(timestamp + self.time_to_live_seconds))
|
||||
}
|
||||
})
|
||||
return put_request
|
||||
|
||||
def _prepare_init_count_request(self, key: str) -> Dict[str, Any]:
|
||||
"""Construct the counter initialization request parameters"""
|
||||
timestamp = time()
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'Item': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
},
|
||||
self._count_filed.name: {
|
||||
self._count_filed.data_type: "0"
|
||||
},
|
||||
self._timestamp_field.name: {
|
||||
self._timestamp_field.data_type: str(timestamp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _prepare_inc_count_request(self, key: str) -> Dict[str, Any]:
|
||||
"""Construct the counter increment request parameters"""
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'Key': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
}
|
||||
},
|
||||
'UpdateExpression': f"set {self._count_filed.name} = {self._count_filed.name} + :num",
|
||||
"ExpressionAttributeValues": {
|
||||
":num": {"N": "1"},
|
||||
},
|
||||
"ReturnValues": "UPDATED_NEW",
|
||||
}
|
||||
|
||||
def _item_to_dict(self, raw_response):
|
||||
"""Convert get_item() response to field-value pairs."""
|
||||
if 'Item' not in raw_response:
|
||||
return {}
|
||||
return {
|
||||
field.name: raw_response['Item'][field.name][field.data_type]
|
||||
for field in self._available_fields
|
||||
}
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return self._get_client()
|
||||
|
||||
def get(self, key):
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_get_request(key)
|
||||
item_response = self.client.get_item(**request_parameters)
|
||||
item = self._item_to_dict(item_response)
|
||||
return item.get(self._value_field.name)
|
||||
|
||||
def set(self, key, value):
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_put_request(key, value)
|
||||
self.client.put_item(**request_parameters)
|
||||
|
||||
def mget(self, keys):
|
||||
return [self.get(key) for key in keys]
|
||||
|
||||
def delete(self, key):
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_get_request(key)
|
||||
self.client.delete_item(**request_parameters)
|
||||
|
||||
def incr(self, key: bytes) -> int:
|
||||
"""Atomically increase the chord_count and return the new count"""
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_inc_count_request(key)
|
||||
item_response = self.client.update_item(**request_parameters)
|
||||
new_count: str = item_response["Attributes"][self._count_filed.name][self._count_filed.data_type]
|
||||
return int(new_count)
|
||||
|
||||
def _apply_chord_incr(self, header_result_args, body, **kwargs):
|
||||
chord_key = self.get_key_for_chord(header_result_args[0])
|
||||
init_count_request = self._prepare_init_count_request(str(chord_key))
|
||||
self.client.put_item(**init_count_request)
|
||||
return super()._apply_chord_incr(
|
||||
header_result_args, body, **kwargs)
|
||||
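# Hedged sketch of result backend URLs accepted by the parsing code above.
# All values are placeholders; the query parameters map to read/write
# capacity units and ttl_seconds exactly as handled in __init__.
from celery import Celery

# Real AWS endpoint: the region is the host part, the table name is the path.
app = Celery(
    'tasks',
    backend='dynamodb://AKIAEXAMPLE:SECRETEXAMPLE@us-east-1/celery_results'
            '?read=5&write=5&ttl_seconds=86400',
)

# Downloadable/local DynamoDB: the host is 'localhost' or an IP address, which
# _is_valid_ip() detects and turns into an http endpoint URL.
app_local = Celery('tasks', backend='dynamodb://@localhost:8000/celery')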
@@ -0,0 +1,283 @@
|
||||
"""Elasticsearch result store backend."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.url import _parse_url
|
||||
|
||||
from celery import states
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import elasticsearch
|
||||
except ImportError:
|
||||
elasticsearch = None
|
||||
|
||||
try:
|
||||
import elastic_transport
|
||||
except ImportError:
|
||||
elastic_transport = None
|
||||
|
||||
__all__ = ('ElasticsearchBackend',)
|
||||
|
||||
E_LIB_MISSING = """\
|
||||
You need to install the elasticsearch library to use the Elasticsearch \
|
||||
result backend.\
|
||||
"""
|
||||
|
||||
|
||||
class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
"""Elasticsearch Backend.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.ImproperlyConfigured:
|
||||
if module :pypi:`elasticsearch` is not available.
|
||||
"""
|
||||
|
||||
index = 'celery'
|
||||
doc_type = None
|
||||
scheme = 'http'
|
||||
host = 'localhost'
|
||||
port = 9200
|
||||
username = None
|
||||
password = None
|
||||
es_retry_on_timeout = False
|
||||
es_timeout = 10
|
||||
es_max_retries = 3
|
||||
|
||||
def __init__(self, url=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.url = url
|
||||
_get = self.app.conf.get
|
||||
|
||||
if elasticsearch is None:
|
||||
raise ImproperlyConfigured(E_LIB_MISSING)
|
||||
|
||||
index = doc_type = scheme = host = port = username = password = None
|
||||
|
||||
if url:
|
||||
scheme, host, port, username, password, path, _ = _parse_url(url)
|
||||
if scheme == 'elasticsearch':
|
||||
scheme = None
|
||||
if path:
|
||||
path = path.strip('/')
|
||||
index, _, doc_type = path.partition('/')
|
||||
|
||||
self.index = index or self.index
|
||||
self.doc_type = doc_type or self.doc_type
|
||||
self.scheme = scheme or self.scheme
|
||||
self.host = host or self.host
|
||||
self.port = port or self.port
|
||||
self.username = username or self.username
|
||||
self.password = password or self.password
|
||||
|
||||
self.es_retry_on_timeout = (
|
||||
_get('elasticsearch_retry_on_timeout') or self.es_retry_on_timeout
|
||||
)
|
||||
|
||||
es_timeout = _get('elasticsearch_timeout')
|
||||
if es_timeout is not None:
|
||||
self.es_timeout = es_timeout
|
||||
|
||||
es_max_retries = _get('elasticsearch_max_retries')
|
||||
if es_max_retries is not None:
|
||||
self.es_max_retries = es_max_retries
|
||||
|
||||
self.es_save_meta_as_text = _get('elasticsearch_save_meta_as_text', True)
|
||||
self._server = None
|
||||
|
||||
def exception_safe_to_retry(self, exc):
|
||||
if isinstance(exc, elasticsearch.exceptions.ApiError):
|
||||
# 401: Unauthorized
|
||||
# 409: Conflict
|
||||
# 500: Internal Server Error
|
||||
# 502: Bad Gateway
|
||||
# 504: Gateway Timeout
|
||||
# N/A: Low level exception (i.e. socket exception)
|
||||
if exc.status_code in {401, 409, 500, 502, 504, 'N/A'}:
|
||||
return True
|
||||
if isinstance(exc, elasticsearch.exceptions.TransportError):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, key):
|
||||
try:
|
||||
res = self._get(key)
|
||||
try:
|
||||
if res['found']:
|
||||
return res['_source']['result']
|
||||
except (TypeError, KeyError):
|
||||
pass
|
||||
except elasticsearch.exceptions.NotFoundError:
|
||||
pass
|
||||
|
||||
def _get(self, key):
|
||||
if self.doc_type:
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
id=key,
|
||||
doc_type=self.doc_type,
|
||||
)
|
||||
else:
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
id=key,
|
||||
)
|
||||
|
||||
def _set_with_state(self, key, value, state):
|
||||
body = {
|
||||
'result': value,
|
||||
'@timestamp': '{}Z'.format(
|
||||
datetime.now(timezone.utc).isoformat()[:-9]
|
||||
),
|
||||
}
|
||||
try:
|
||||
self._index(
|
||||
id=key,
|
||||
body=body,
|
||||
)
|
||||
except elasticsearch.exceptions.ConflictError:
|
||||
# document already exists, update it
|
||||
self._update(key, body, state)
|
||||
|
||||
def set(self, key, value):
|
||||
return self._set_with_state(key, value, None)
|
||||
|
||||
def _index(self, id, body, **kwargs):
|
||||
body = {bytes_to_str(k): v for k, v in body.items()}
|
||||
if self.doc_type:
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _update(self, id, body, state, **kwargs):
|
||||
"""Update state in a conflict free manner.
|
||||
|
||||
If state is defined (not None), this will not update ES server if either:
|
||||
* existing state is success
|
||||
* existing state is a ready state and current state is not a ready state
|
||||
|
||||
This way, a Retry state cannot override a Success or Failure, and chord_unlock
|
||||
will not retry indefinitely.
|
||||
"""
|
||||
body = {bytes_to_str(k): v for k, v in body.items()}
|
||||
|
||||
try:
|
||||
res_get = self._get(key=id)
|
||||
if not res_get.get('found'):
|
||||
return self._index(id, body, **kwargs)
|
||||
# document disappeared between index and get calls.
|
||||
except elasticsearch.exceptions.NotFoundError:
|
||||
return self._index(id, body, **kwargs)
|
||||
|
||||
try:
|
||||
meta_present_on_backend = self.decode_result(res_get['_source']['result'])
|
||||
except (TypeError, KeyError):
|
||||
pass
|
||||
else:
|
||||
if meta_present_on_backend['status'] == states.SUCCESS:
|
||||
# if stored state is already in success, do nothing
|
||||
return {'result': 'noop'}
|
||||
elif meta_present_on_backend['status'] in states.READY_STATES and state in states.UNREADY_STATES:
|
||||
# if stored state is a ready state and the current state is not, do nothing
|
||||
return {'result': 'noop'}
|
||||
|
||||
# get current sequence number and primary term
|
||||
# https://www.elastic.co/guide/en/elasticsearch/reference/current/optimistic-concurrency-control.html
|
||||
seq_no = res_get.get('_seq_no', 1)
|
||||
prim_term = res_get.get('_primary_term', 1)
|
||||
|
||||
# try to update document with current seq_no and primary_term
|
||||
if self.doc_type:
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
# result is elastic search update query result
|
||||
# noop = query did not update any document
|
||||
# updated = at least one document got updated
|
||||
if res['result'] == 'noop':
|
||||
raise elasticsearch.exceptions.ConflictError(
|
||||
"conflicting update occurred concurrently",
|
||||
elastic_transport.ApiResponseMeta(409, "HTTP/1.1",
|
||||
elastic_transport.HttpHeaders(), 0, elastic_transport.NodeConfig(
|
||||
self.scheme, self.host, self.port)), None)
|
||||
return res
|
||||
|
||||
def encode(self, data):
|
||||
if self.es_save_meta_as_text:
|
||||
return super().encode(data)
|
||||
else:
|
||||
if not isinstance(data, dict):
|
||||
return super().encode(data)
|
||||
if data.get("result"):
|
||||
data["result"] = self._encode(data["result"])[2]
|
||||
if data.get("traceback"):
|
||||
data["traceback"] = self._encode(data["traceback"])[2]
|
||||
return data
|
||||
|
||||
def decode(self, payload):
|
||||
if self.es_save_meta_as_text:
|
||||
return super().decode(payload)
|
||||
else:
|
||||
if not isinstance(payload, dict):
|
||||
return super().decode(payload)
|
||||
if payload.get("result"):
|
||||
payload["result"] = super().decode(payload["result"])
|
||||
if payload.get("traceback"):
|
||||
payload["traceback"] = super().decode(payload["traceback"])
|
||||
return payload
|
||||
|
||||
def mget(self, keys):
|
||||
return [self.get(key) for key in keys]
|
||||
|
||||
def delete(self, key):
|
||||
if self.doc_type:
|
||||
self.server.delete(index=self.index, id=key, doc_type=self.doc_type)
|
||||
else:
|
||||
self.server.delete(index=self.index, id=key)
|
||||
|
||||
def _get_server(self):
|
||||
"""Connect to the Elasticsearch server."""
|
||||
http_auth = None
|
||||
if self.username and self.password:
|
||||
http_auth = (self.username, self.password)
|
||||
return elasticsearch.Elasticsearch(
|
||||
f'{self.scheme}://{self.host}:{self.port}',
|
||||
retry_on_timeout=self.es_retry_on_timeout,
|
||||
max_retries=self.es_max_retries,
|
||||
timeout=self.es_timeout,
|
||||
http_auth=http_auth,
|
||||
)
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
if self._server is None:
|
||||
self._server = self._get_server()
|
||||
return self._server
|
||||
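# Hedged configuration sketch for the Elasticsearch backend above. Host,
# credentials and index name are placeholders; the URL path maps to
# index[/doc_type] as in the _parse_url handling, and the extra settings
# mirror the conf keys read in __init__.
from celery import Celery

app = Celery(
    'tasks',
    backend='elasticsearch://elastic:changeme@localhost:9200/celery_results',
)
app.conf.update(
    elasticsearch_retry_on_timeout=True,
    elasticsearch_timeout=10,
    elasticsearch_max_retries=3,
    elasticsearch_save_meta_as_text=False,  # store result/traceback as objects
)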
@@ -0,0 +1,112 @@
|
||||
"""File-system result store backend."""
|
||||
import locale
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from kombu.utils.encoding import ensure_bytes
|
||||
|
||||
from celery import uuid
|
||||
from celery.backends.base import KeyValueStoreBackend
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
|
||||
default_encoding = locale.getpreferredencoding(False)
|
||||
|
||||
E_NO_PATH_SET = 'You need to configure a path for the file-system backend'
|
||||
E_PATH_NON_CONFORMING_SCHEME = (
|
||||
'A path for the file-system backend should conform to the file URI scheme'
|
||||
)
|
||||
E_PATH_INVALID = """\
|
||||
The configured path for the file-system backend does not
|
||||
work correctly, please make sure that it exists and has
|
||||
the correct permissions.\
|
||||
"""
|
||||
|
||||
|
||||
class FilesystemBackend(KeyValueStoreBackend):
|
||||
"""File-system result backend.
|
||||
|
||||
Arguments:
|
||||
url (str): URL to the directory we should use
|
||||
open (Callable): open function to use when opening files
|
||||
unlink (Callable): unlink function to use when deleting files
|
||||
sep (str): directory separator (to join the directory with the key)
|
||||
encoding (str): encoding used on the file-system
|
||||
"""
|
||||
|
||||
def __init__(self, url=None, open=open, unlink=os.unlink, sep=os.sep,
|
||||
encoding=default_encoding, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.url = url
|
||||
path = self._find_path(url)
|
||||
|
||||
# Remove the leading "/" on Windows
|
||||
if os.name == "nt" and path.startswith("/"):
|
||||
path = path[1:]
|
||||
|
||||
# We need the path and separator as bytes objects
|
||||
self.path = path.encode(encoding)
|
||||
self.sep = sep.encode(encoding)
|
||||
|
||||
self.open = open
|
||||
self.unlink = unlink
|
||||
|
||||
# Let's verify that we've got everything set up right
|
||||
self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding))
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return super().__reduce__(args, {**kwargs, 'url': self.url})
|
||||
|
||||
def _find_path(self, url):
|
||||
if not url:
|
||||
raise ImproperlyConfigured(E_NO_PATH_SET)
|
||||
if url.startswith('file://localhost/'):
|
||||
return url[16:]
|
||||
if url.startswith('file://'):
|
||||
return url[7:]
|
||||
raise ImproperlyConfigured(E_PATH_NON_CONFORMING_SCHEME)
|
||||
|
||||
def _do_directory_test(self, key):
|
||||
try:
|
||||
self.set(key, b'test value')
|
||||
assert self.get(key) == b'test value'
|
||||
self.delete(key)
|
||||
except OSError:
|
||||
raise ImproperlyConfigured(E_PATH_INVALID)
|
||||
|
||||
def _filename(self, key):
|
||||
return self.sep.join((self.path, key))
|
||||
|
||||
def get(self, key):
|
||||
try:
|
||||
with self.open(self._filename(key), 'rb') as infile:
|
||||
return infile.read()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def set(self, key, value):
|
||||
with self.open(self._filename(key), 'wb') as outfile:
|
||||
outfile.write(ensure_bytes(value))
|
||||
|
||||
def mget(self, keys):
|
||||
for key in keys:
|
||||
yield self.get(key)
|
||||
|
||||
def delete(self, key):
|
||||
self.unlink(self._filename(key))
|
||||
|
||||
def cleanup(self):
|
||||
"""Delete expired meta-data."""
|
||||
if not self.expires:
|
||||
return
|
||||
epoch = datetime(1970, 1, 1, tzinfo=self.app.timezone)
|
||||
now_ts = (self.app.now() - epoch).total_seconds()
|
||||
cutoff_ts = now_ts - self.expires
|
||||
for filename in os.listdir(self.path):
|
||||
for prefix in (self.task_keyprefix, self.group_keyprefix,
|
||||
self.chord_keyprefix):
|
||||
if filename.startswith(prefix):
|
||||
path = os.path.join(self.path, filename)
|
||||
if os.stat(path).st_mtime < cutoff_ts:
|
||||
self.unlink(path)
|
||||
break
|
||||
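# Hedged sketch for the file-system backend above. The URL must use the
# file:// scheme checked in _find_path(); the directory is a placeholder and
# must already exist with the right permissions (see _do_directory_test()).
from celery import Celery

app = Celery('tasks', backend='file:///var/lib/celery/results')
app.conf.result_expires = 3600  # lets cleanup() unlink result files older than 1h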
ETB-API/venv/lib/python3.12/site-packages/celery/backends/gcs.py (new file, 352 lines)
@@ -0,0 +1,352 @@
|
||||
"""Google Cloud Storage result store backend for Celery."""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from os import getpid
|
||||
from threading import RLock
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.functional import dictfilter
|
||||
from kombu.utils.url import url_to_parts
|
||||
|
||||
from celery.canvas import maybe_signature
|
||||
from celery.exceptions import ChordError, ImproperlyConfigured
|
||||
from celery.result import GroupResult, allow_join_result
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import requests
|
||||
from google.api_core import retry
|
||||
from google.api_core.exceptions import Conflict
|
||||
from google.api_core.retry import if_exception_type
|
||||
from google.cloud import storage
|
||||
from google.cloud.storage import Client
|
||||
from google.cloud.storage.retry import DEFAULT_RETRY
|
||||
except ImportError:
|
||||
storage = None
|
||||
|
||||
try:
|
||||
from google.cloud import firestore, firestore_admin_v1
|
||||
except ImportError:
|
||||
firestore = None
|
||||
firestore_admin_v1 = None
|
||||
|
||||
|
||||
__all__ = ('GCSBackend',)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GCSBackendBase(KeyValueStoreBackend):
|
||||
"""Google Cloud Storage task result backend."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not storage:
|
||||
raise ImproperlyConfigured(
|
||||
'You must install google-cloud-storage to use gcs backend'
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
self._client_lock = RLock()
|
||||
self._pid = getpid()
|
||||
self._retry_policy = DEFAULT_RETRY
|
||||
self._client = None
|
||||
|
||||
conf = self.app.conf
|
||||
if self.url:
|
||||
url_params = self._params_from_url()
|
||||
conf.update(**dictfilter(url_params))
|
||||
|
||||
self.bucket_name = conf.get('gcs_bucket')
|
||||
if not self.bucket_name:
|
||||
raise ImproperlyConfigured(
|
||||
'Missing bucket name: specify gcs_bucket to use gcs backend'
|
||||
)
|
||||
self.project = conf.get('gcs_project')
|
||||
if not self.project:
|
||||
raise ImproperlyConfigured(
|
||||
'Missing project: specify gcs_project to use gcs backend'
|
||||
)
|
||||
self.base_path = conf.get('gcs_base_path', '').strip('/')
|
||||
self._threadpool_maxsize = int(conf.get('gcs_threadpool_maxsize', 10))
|
||||
self.ttl = float(conf.get('gcs_ttl') or 0)
|
||||
if self.ttl < 0:
|
||||
raise ImproperlyConfigured(
|
||||
f'Invalid ttl: {self.ttl} must be greater than or equal to 0'
|
||||
)
|
||||
elif self.ttl:
|
||||
if not self._is_bucket_lifecycle_rule_exists():
|
||||
raise ImproperlyConfigured(
|
||||
f'Missing lifecycle rule to use gcs backend with ttl on '
|
||||
f'bucket: {self.bucket_name}'
|
||||
)
|
||||
|
||||
def get(self, key):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
try:
|
||||
return blob.download_as_bytes(retry=self._retry_policy)
|
||||
except storage.blob.NotFound:
|
||||
return None
|
||||
|
||||
def set(self, key, value):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
if self.ttl:
|
||||
blob.custom_time = datetime.utcnow() + timedelta(seconds=self.ttl)
|
||||
blob.upload_from_string(value, retry=self._retry_policy)
|
||||
|
||||
def delete(self, key):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
if blob.exists():
|
||||
blob.delete(retry=self._retry_policy)
|
||||
|
||||
def mget(self, keys):
|
||||
with ThreadPoolExecutor() as pool:
|
||||
return list(pool.map(self.get, keys))
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Returns a storage client."""
|
||||
|
||||
# make sure it's thread-safe, as creating a new client is expensive
|
||||
with self._client_lock:
|
||||
if self._client and self._pid == getpid():
|
||||
return self._client
|
||||
# make sure each process gets its own connection after a fork
|
||||
self._client = Client(project=self.project)
|
||||
self._pid = getpid()
|
||||
|
||||
# configure the number of connections to the server
|
||||
adapter = requests.adapters.HTTPAdapter(
|
||||
pool_connections=self._threadpool_maxsize,
|
||||
pool_maxsize=self._threadpool_maxsize,
|
||||
max_retries=3,
|
||||
)
|
||||
client_http = self._client._http
|
||||
client_http.mount("https://", adapter)
|
||||
client_http._auth_request.session.mount("https://", adapter)
|
||||
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def bucket(self):
|
||||
return self.client.bucket(self.bucket_name)
|
||||
|
||||
def _get_blob(self, key):
|
||||
key_bucket_path = f'{self.base_path}/{key}' if self.base_path else key
|
||||
return self.bucket.blob(key_bucket_path)
|
||||
|
||||
def _is_bucket_lifecycle_rule_exists(self):
|
||||
bucket = self.bucket
|
||||
bucket.reload()
|
||||
for rule in bucket.lifecycle_rules:
|
||||
if rule['action']['type'] == 'Delete':
|
||||
return True
|
||||
return False
|
||||
|
||||
def _params_from_url(self):
|
||||
url_parts = url_to_parts(self.url)
|
||||
|
||||
return {
|
||||
'gcs_bucket': url_parts.hostname,
|
||||
'gcs_base_path': url_parts.path,
|
||||
**url_parts.query,
|
||||
}
|
||||
|
||||
|
||||
class GCSBackend(GCSBackendBase):
|
||||
"""Google Cloud Storage task result backend.
|
||||
|
||||
Uses Firestore for chord ref count.
|
||||
"""
|
||||
|
||||
implements_incr = True
|
||||
supports_native_join = True
|
||||
|
||||
# Firestore parameters
|
||||
_collection_name = 'celery'
|
||||
_field_count = 'chord_count'
|
||||
_field_expires = 'expires_at'
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not (firestore and firestore_admin_v1):
|
||||
raise ImproperlyConfigured(
|
||||
'You must install google-cloud-firestore to use gcs backend'
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._firestore_lock = RLock()
|
||||
self._firestore_client = None
|
||||
|
||||
self.firestore_project = self.app.conf.get(
|
||||
'firestore_project', self.project
|
||||
)
|
||||
if not self._is_firestore_ttl_policy_enabled():
|
||||
raise ImproperlyConfigured(
|
||||
f'Missing TTL policy to use gcs backend with ttl on '
|
||||
f'Firestore collection: {self._collection_name} '
|
||||
f'project: {self.firestore_project}'
|
||||
)
|
||||
|
||||
@property
|
||||
def firestore_client(self):
|
||||
"""Returns a firestore client."""
|
||||
|
||||
# make sure it's thread-safe, as creating a new client is expensive
|
||||
with self._firestore_lock:
|
||||
if self._firestore_client and self._pid == getpid():
|
||||
return self._firestore_client
|
||||
# make sure each process gets its own connection after a fork
|
||||
self._firestore_client = firestore.Client(
|
||||
project=self.firestore_project
|
||||
)
|
||||
self._pid = getpid()
|
||||
return self._firestore_client
|
||||
|
||||
def _is_firestore_ttl_policy_enabled(self):
|
||||
client = firestore_admin_v1.FirestoreAdminClient()
|
||||
|
||||
name = (
|
||||
f"projects/{self.firestore_project}"
|
||||
f"/databases/(default)/collectionGroups/{self._collection_name}"
|
||||
f"/fields/{self._field_expires}"
|
||||
)
|
||||
request = firestore_admin_v1.GetFieldRequest(name=name)
|
||||
field = client.get_field(request=request)
|
||||
|
||||
ttl_config = field.ttl_config
|
||||
return ttl_config and ttl_config.state in {
|
||||
firestore_admin_v1.Field.TtlConfig.State.ACTIVE,
|
||||
firestore_admin_v1.Field.TtlConfig.State.CREATING,
|
||||
}
|
||||
|
||||
def _apply_chord_incr(self, header_result_args, body, **kwargs):
|
||||
key = self.get_key_for_chord(header_result_args[0]).decode()
|
||||
self._expire_chord_key(key, 86400)
|
||||
return super()._apply_chord_incr(header_result_args, body, **kwargs)
|
||||
|
||||
def incr(self, key: bytes) -> int:
|
||||
doc = self._firestore_document(key)
|
||||
resp = doc.set(
|
||||
{self._field_count: firestore.Increment(1)},
|
||||
merge=True,
|
||||
retry=retry.Retry(
|
||||
predicate=if_exception_type(Conflict),
|
||||
initial=1.0,
|
||||
maximum=180.0,
|
||||
multiplier=2.0,
|
||||
timeout=180.0,
|
||||
),
|
||||
)
|
||||
return resp.transform_results[0].integer_value
|
||||
|
||||
def on_chord_part_return(self, request, state, result, **kwargs):
|
||||
"""Chord part return callback.
|
||||
|
||||
Called for each task in the chord.
|
||||
Increments the counter stored in Firestore.
|
||||
If the counter reaches the number of tasks in the chord, the callback
|
||||
is called.
|
||||
If the callback raises an exception, the chord is marked as errored.
|
||||
If the callback returns a value, the chord is marked as successful.
|
||||
"""
|
||||
app = self.app
|
||||
gid = request.group
|
||||
if not gid:
|
||||
return
|
||||
key = self.get_key_for_chord(gid)
|
||||
val = self.incr(key)
|
||||
size = request.chord.get("chord_size")
|
||||
if size is None:
|
||||
deps = self._restore_deps(gid, request)
|
||||
if deps is None:
|
||||
return
|
||||
size = len(deps)
|
||||
if val > size: # pragma: no cover
|
||||
logger.warning(
|
||||
'Chord counter incremented too many times for %r', gid
|
||||
)
|
||||
elif val == size:
|
||||
# Read the deps once, to reduce the number of reads from GCS ($$)
|
||||
deps = self._restore_deps(gid, request)
|
||||
if deps is None:
|
||||
return
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
j = deps.join_native
|
||||
try:
|
||||
with allow_join_result():
|
||||
ret = j(
|
||||
timeout=app.conf.result_chord_join_timeout,
|
||||
propagate=True,
|
||||
)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
try:
|
||||
culprit = next(deps._failed_join_report())
|
||||
reason = 'Dependency {0.id} raised {1!r}'.format(
|
||||
culprit,
|
||||
exc,
|
||||
)
|
||||
except StopIteration:
|
||||
reason = repr(exc)
|
||||
|
||||
logger.exception('Chord %r raised: %r', gid, reason)
|
||||
self.chord_error_from_stack(callback, ChordError(reason))
|
||||
else:
|
||||
try:
|
||||
callback.delay(ret)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception('Chord %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Callback error: {exc!r}'),
|
||||
)
|
||||
finally:
|
||||
deps.delete()
|
||||
# Firestore doesn't have an exact ttl policy, so delete the key.
|
||||
self._delete_chord_key(key)
|
||||
|
||||
def _restore_deps(self, gid, request):
|
||||
app = self.app
|
||||
try:
|
||||
deps = GroupResult.restore(gid, backend=self)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
logger.exception('Chord %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Cannot restore group: {exc!r}'),
|
||||
)
|
||||
return
|
||||
if deps is None:
|
||||
try:
|
||||
raise ValueError(gid)
|
||||
except ValueError as exc:
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
logger.exception('Chord callback %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'GroupResult {gid} no longer exists'),
|
||||
)
|
||||
return deps
|
||||
|
||||
def _delete_chord_key(self, key):
|
||||
doc = self._firestore_document(key)
|
||||
doc.delete()
|
||||
|
||||
def _expire_chord_key(self, key, expires):
|
||||
"""Set TTL policy for a Firestore document.
|
||||
|
||||
Firestore ttl data is typically deleted within 24 hours after its
|
||||
expiration date.
|
||||
"""
|
||||
val_expires = datetime.utcnow() + timedelta(seconds=expires)
|
||||
doc = self._firestore_document(key)
|
||||
doc.set({self._field_expires: val_expires}, merge=True)
|
||||
|
||||
def _firestore_document(self, key):
|
||||
return self.firestore_client.collection(
|
||||
self._collection_name
|
||||
).document(bytes_to_str(key))
|
||||
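# Hedged configuration sketch for the GCS backend above, using the conf keys
# read in GCSBackendBase.__init__ and GCSBackend.__init__. Bucket and project
# values are placeholders, the 'gs://' scheme is an assumption (see
# _params_from_url()), and the ttl path presumes the bucket lifecycle rule
# and Firestore TTL policy checked above are already in place.
from celery import Celery

app = Celery('tasks', broker='redis://localhost:6379/0')  # assumed broker
app.conf.update(
    result_backend='gs://my-results-bucket/celery',
    gcs_project='my-gcp-project',
    gcs_threadpool_maxsize=10,
    gcs_ttl=86400,                       # requires a Delete lifecycle rule on the bucket
    firestore_project='my-gcp-project',  # used for chord counters
)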
@@ -0,0 +1,333 @@
|
||||
"""MongoDB result store backend."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from kombu.exceptions import EncodeError
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.url import maybe_sanitize_url, urlparse
|
||||
|
||||
from celery import states
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
|
||||
from .base import BaseBackend
|
||||
|
||||
try:
|
||||
import pymongo
|
||||
except ImportError:
|
||||
pymongo = None
|
||||
|
||||
if pymongo:
|
||||
try:
|
||||
from bson.binary import Binary
|
||||
except ImportError:
|
||||
from pymongo.binary import Binary
|
||||
from pymongo.errors import InvalidDocument
|
||||
else: # pragma: no cover
|
||||
Binary = None
|
||||
|
||||
class InvalidDocument(Exception):
|
||||
pass
|
||||
|
||||
__all__ = ('MongoBackend',)
|
||||
|
||||
BINARY_CODECS = frozenset(['pickle', 'msgpack'])
|
||||
|
||||
|
||||
class MongoBackend(BaseBackend):
|
||||
"""MongoDB result backend.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.ImproperlyConfigured:
|
||||
if module :pypi:`pymongo` is not available.
|
||||
"""
|
||||
|
||||
mongo_host = None
|
||||
host = 'localhost'
|
||||
port = 27017
|
||||
user = None
|
||||
password = None
|
||||
database_name = 'celery'
|
||||
taskmeta_collection = 'celery_taskmeta'
|
||||
groupmeta_collection = 'celery_groupmeta'
|
||||
max_pool_size = 10
|
||||
options = None
|
||||
|
||||
supports_autoexpire = False
|
||||
|
||||
_connection = None
|
||||
|
||||
def __init__(self, app=None, **kwargs):
|
||||
self.options = {}
|
||||
|
||||
super().__init__(app, **kwargs)
|
||||
|
||||
if not pymongo:
|
||||
raise ImproperlyConfigured(
|
||||
'You need to install the pymongo library to use the '
|
||||
'MongoDB backend.')
|
||||
|
||||
# Set option defaults
|
||||
for key, value in self._prepare_client_options().items():
|
||||
self.options.setdefault(key, value)
|
||||
|
||||
# update conf with mongo uri data, only if uri was given
|
||||
if self.url:
|
||||
self.url = self._ensure_mongodb_uri_compliance(self.url)
|
||||
|
||||
uri_data = pymongo.uri_parser.parse_uri(self.url)
|
||||
# build the hosts list to create a mongo connection
|
||||
hostslist = [
|
||||
f'{x[0]}:{x[1]}' for x in uri_data['nodelist']
|
||||
]
|
||||
self.user = uri_data['username']
|
||||
self.password = uri_data['password']
|
||||
self.mongo_host = hostslist
|
||||
if uri_data['database']:
|
||||
# if no database is provided in the uri, use default
|
||||
self.database_name = uri_data['database']
|
||||
|
||||
self.options.update(uri_data['options'])
|
||||
|
||||
# update conf with specific settings
|
||||
config = self.app.conf.get('mongodb_backend_settings')
|
||||
if config is not None:
|
||||
if not isinstance(config, dict):
|
||||
raise ImproperlyConfigured(
|
||||
'MongoDB backend settings should be grouped in a dict')
|
||||
config = dict(config) # don't modify original
|
||||
|
||||
if 'host' in config or 'port' in config:
|
||||
# these should take over uri conf
|
||||
self.mongo_host = None
|
||||
|
||||
self.host = config.pop('host', self.host)
|
||||
self.port = config.pop('port', self.port)
|
||||
self.mongo_host = config.pop('mongo_host', self.mongo_host)
|
||||
self.user = config.pop('user', self.user)
|
||||
self.password = config.pop('password', self.password)
|
||||
self.database_name = config.pop('database', self.database_name)
|
||||
self.taskmeta_collection = config.pop(
|
||||
'taskmeta_collection', self.taskmeta_collection,
|
||||
)
|
||||
self.groupmeta_collection = config.pop(
|
||||
'groupmeta_collection', self.groupmeta_collection,
|
||||
)
|
||||
|
||||
self.options.update(config.pop('options', {}))
|
||||
self.options.update(config)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_mongodb_uri_compliance(url):
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme.startswith('mongodb'):
|
||||
url = f'mongodb+{url}'
|
||||
|
||||
if url == 'mongodb://':
|
||||
url += 'localhost'
|
||||
|
||||
return url
|
||||
|
||||
def _prepare_client_options(self):
|
||||
if pymongo.version_tuple >= (3,):
|
||||
return {'maxPoolSize': self.max_pool_size}
|
||||
else: # pragma: no cover
|
||||
return {'max_pool_size': self.max_pool_size,
|
||||
'auto_start_request': False}
|
||||
|
||||
def _get_connection(self):
|
||||
"""Connect to the MongoDB server."""
|
||||
if self._connection is None:
|
||||
from pymongo import MongoClient
|
||||
|
||||
host = self.mongo_host
|
||||
if not host:
|
||||
# The first pymongo.Connection() argument (host) can be
|
||||
# a list of ['host:port'] elements or a mongodb connection
|
||||
# URI. If this is the case, don't use self.port
|
||||
# but let pymongo get the port(s) from the URI instead.
|
||||
# This enables the use of replica sets and sharding.
|
||||
# See pymongo.Connection() for more info.
|
||||
host = self.host
|
||||
if isinstance(host, str) \
|
||||
and not host.startswith('mongodb://'):
|
||||
host = f'mongodb://{host}:{self.port}'
|
||||
# don't change self.options
|
||||
conf = dict(self.options)
|
||||
conf['host'] = host
|
||||
if self.user:
|
||||
conf['username'] = self.user
|
||||
if self.password:
|
||||
conf['password'] = self.password
|
||||
|
||||
self._connection = MongoClient(**conf)
|
||||
|
||||
return self._connection
|
||||
|
||||
def encode(self, data):
|
||||
if self.serializer == 'bson':
|
||||
# mongodb handles serialization
|
||||
return data
|
||||
payload = super().encode(data)
|
||||
|
||||
# serializer which are in a unsupported format (pickle/binary)
|
||||
if self.serializer in BINARY_CODECS:
|
||||
payload = Binary(payload)
|
||||
return payload
|
||||
|
||||
def decode(self, data):
|
||||
if self.serializer == 'bson':
|
||||
return data
|
||||
return super().decode(data)
|
||||
|
||||
def _store_result(self, task_id, result, state,
|
||||
traceback=None, request=None, **kwargs):
|
||||
"""Store return value and state of an executed task."""
|
||||
meta = self._get_result_meta(result=self.encode(result), state=state,
|
||||
traceback=traceback, request=request,
|
||||
format_date=False)
|
||||
# Add the _id for mongodb
|
||||
meta['_id'] = task_id
|
||||
|
||||
try:
|
||||
self.collection.replace_one({'_id': task_id}, meta, upsert=True)
|
||||
except InvalidDocument as exc:
|
||||
raise EncodeError(exc)
|
||||
|
||||
return result
|
||||
|
||||
def _get_task_meta_for(self, task_id):
|
||||
"""Get task meta-data for a task by id."""
|
||||
obj = self.collection.find_one({'_id': task_id})
|
||||
if obj:
|
||||
if self.app.conf.find_value_for_key('extended', 'result'):
|
||||
return self.meta_from_decoded({
|
||||
'name': obj['name'],
|
||||
'args': obj['args'],
|
||||
'task_id': obj['_id'],
|
||||
'queue': obj['queue'],
|
||||
'kwargs': obj['kwargs'],
|
||||
'status': obj['status'],
|
||||
'worker': obj['worker'],
|
||||
'retries': obj['retries'],
|
||||
'children': obj['children'],
|
||||
'date_done': obj['date_done'],
|
||||
'traceback': obj['traceback'],
|
||||
'result': self.decode(obj['result']),
|
||||
})
|
||||
return self.meta_from_decoded({
|
||||
'task_id': obj['_id'],
|
||||
'status': obj['status'],
|
||||
'result': self.decode(obj['result']),
|
||||
'date_done': obj['date_done'],
|
||||
'traceback': obj['traceback'],
|
||||
'children': obj['children'],
|
||||
})
|
||||
return {'status': states.PENDING, 'result': None}
|
||||
|
||||
def _save_group(self, group_id, result):
|
||||
"""Save the group result."""
|
||||
meta = {
|
||||
'_id': group_id,
|
||||
'result': self.encode([i.id for i in result]),
|
||||
'date_done': datetime.now(timezone.utc),
|
||||
}
|
||||
self.group_collection.replace_one({'_id': group_id}, meta, upsert=True)
|
||||
return result
|
||||
|
||||
def _restore_group(self, group_id):
|
||||
"""Get the result for a group by id."""
|
||||
obj = self.group_collection.find_one({'_id': group_id})
|
||||
if obj:
|
||||
return {
|
||||
'task_id': obj['_id'],
|
||||
'date_done': obj['date_done'],
|
||||
'result': [
|
||||
self.app.AsyncResult(task)
|
||||
for task in self.decode(obj['result'])
|
||||
],
|
||||
}
|
||||
|
||||
def _delete_group(self, group_id):
|
||||
"""Delete a group by id."""
|
||||
self.group_collection.delete_one({'_id': group_id})
|
||||
|
||||
def _forget(self, task_id):
|
||||
"""Remove result from MongoDB.
|
||||
|
||||
Raises:
|
||||
pymongo.exceptions.OperationsError:
|
||||
if the task_id could not be removed.
|
||||
"""
|
||||
# By using safe=True, this will wait until it receives a response from
|
||||
# the server. Likewise, it will raise an OperationsError if the
|
||||
# response was unable to be completed.
|
||||
self.collection.delete_one({'_id': task_id})
|
||||
|
||||
def cleanup(self):
|
||||
"""Delete expired meta-data."""
|
||||
if not self.expires:
|
||||
return
|
||||
|
||||
self.collection.delete_many(
|
||||
{'date_done': {'$lt': self.app.now() - self.expires_delta}},
|
||||
)
|
||||
self.group_collection.delete_many(
|
||||
{'date_done': {'$lt': self.app.now() - self.expires_delta}},
|
||||
)
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return super().__reduce__(
|
||||
args, dict(kwargs, expires=self.expires, url=self.url))
|
||||
|
||||
def _get_database(self):
|
||||
conn = self._get_connection()
|
||||
return conn[self.database_name]
|
||||
|
||||
@cached_property
|
||||
def database(self):
|
||||
"""Get database from MongoDB connection.
|
||||
|
||||
performs authentication if necessary.
|
||||
"""
|
||||
return self._get_database()
|
||||
|
||||
@cached_property
|
||||
def collection(self):
|
||||
"""Get the meta-data task collection."""
|
||||
collection = self.database[self.taskmeta_collection]
|
||||
|
||||
# Ensure an index on date_done is there, if not process the index
|
||||
# in the background. Once completed cleanup will be much faster
|
||||
collection.create_index('date_done', background=True)
|
||||
return collection
|
||||
|
||||
@cached_property
|
||||
def group_collection(self):
|
||||
"""Get the meta-data task collection."""
|
||||
collection = self.database[self.groupmeta_collection]
|
||||
|
||||
# Ensure an index on date_done is there, if not process the index
|
||||
# in the background. Once completed cleanup will be much faster
|
||||
collection.create_index('date_done', background=True)
|
||||
return collection
|
||||
|
||||
@cached_property
|
||||
def expires_delta(self):
|
||||
return timedelta(seconds=self.expires)
|
||||
|
||||
def as_uri(self, include_password=False):
|
||||
"""Return the backend as an URI.
|
||||
|
||||
Arguments:
|
||||
include_password (bool): Password censored if disabled.
|
||||
"""
|
||||
if not self.url:
|
||||
return 'mongodb://'
|
||||
if include_password:
|
||||
return self.url
|
||||
|
||||
if ',' not in self.url:
|
||||
return maybe_sanitize_url(self.url)
|
||||
|
||||
uri1, remainder = self.url.split(',', 1)
|
||||
return ','.join([maybe_sanitize_url(uri1), remainder])
|
||||
@@ -0,0 +1,673 @@
|
||||
"""Redis result store backend."""
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
|
||||
from urllib.parse import unquote
|
||||
|
||||
from kombu.utils.functional import retry_over_time
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.url import _parse_url, maybe_sanitize_url
|
||||
|
||||
from celery import states
|
||||
from celery._state import task_join_will_block
|
||||
from celery.canvas import maybe_signature
|
||||
from celery.exceptions import BackendStoreError, ChordError, ImproperlyConfigured
|
||||
from celery.result import GroupResult, allow_join_result
|
||||
from celery.utils.functional import _regen, dictfilter
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.time import humanize_seconds
|
||||
|
||||
from .asynchronous import AsyncBackendMixin, BaseResultConsumer
|
||||
from .base import BaseKeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import redis.connection
|
||||
from kombu.transport.redis import get_redis_error_classes
|
||||
except ImportError:
|
||||
redis = None
|
||||
get_redis_error_classes = None
|
||||
|
||||
try:
|
||||
import redis.sentinel
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__all__ = ('RedisBackend', 'SentinelBackend')
|
||||
|
||||
E_REDIS_MISSING = """
|
||||
You need to install the redis library in order to use \
|
||||
the Redis result store backend.
|
||||
"""
|
||||
|
||||
E_REDIS_SENTINEL_MISSING = """
|
||||
You need to install the redis library with support of \
|
||||
sentinel in order to use the Redis result store backend.
|
||||
"""
|
||||
|
||||
W_REDIS_SSL_CERT_OPTIONAL = """
|
||||
Setting ssl_cert_reqs=CERT_OPTIONAL when connecting to redis means that \
|
||||
celery might not validate the identity of the redis broker when connecting. \
|
||||
This leaves you vulnerable to man in the middle attacks.
|
||||
"""
|
||||
|
||||
W_REDIS_SSL_CERT_NONE = """
|
||||
Setting ssl_cert_reqs=CERT_NONE when connecting to redis means that celery \
|
||||
will not validate the identity of the redis broker when connecting. This \
|
||||
leaves you vulnerable to man in the middle attacks.
|
||||
"""
|
||||
|
||||
E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH = """
|
||||
SSL connection parameters have been provided but the specified URL scheme \
|
||||
is redis://. A Redis SSL connection URL should use the scheme rediss://.
|
||||
"""
|
||||
|
||||
E_REDIS_SSL_CERT_REQS_MISSING_INVALID = """
|
||||
A rediss:// URL must have parameter ssl_cert_reqs and this must be set to \
|
||||
CERT_REQUIRED, CERT_OPTIONAL, or CERT_NONE
|
||||
"""
|
||||
|
||||
E_LOST = 'Connection to Redis lost: Retry (%s/%s) %s.'
|
||||
|
||||
E_RETRY_LIMIT_EXCEEDED = """
|
||||
Retry limit exceeded while trying to reconnect to the Celery redis result \
|
||||
store backend. The Celery application must be restarted.
|
||||
"""
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ResultConsumer(BaseResultConsumer):
|
||||
_pubsub = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._get_key_for_task = self.backend.get_key_for_task
|
||||
self._decode_result = self.backend.decode_result
|
||||
self._ensure = self.backend.ensure
|
||||
self._connection_errors = self.backend.connection_errors
|
||||
self.subscribed_to = set()
|
||||
|
||||
def on_after_fork(self):
|
||||
try:
|
||||
self.backend.client.connection_pool.reset()
|
||||
if self._pubsub is not None:
|
||||
self._pubsub.close()
|
||||
except KeyError as e:
|
||||
logger.warning(str(e))
|
||||
super().on_after_fork()
|
||||
|
||||
def _reconnect_pubsub(self):
|
||||
self._pubsub = None
|
||||
self.backend.client.connection_pool.reset()
|
||||
# task state might have changed when the connection was down so we
|
||||
# retrieve meta for all subscribed tasks before going into pubsub mode
|
||||
if self.subscribed_to:
|
||||
metas = self.backend.client.mget(self.subscribed_to)
|
||||
metas = [meta for meta in metas if meta]
|
||||
for meta in metas:
|
||||
self.on_state_change(self._decode_result(meta), None)
|
||||
self._pubsub = self.backend.client.pubsub(
|
||||
ignore_subscribe_messages=True,
|
||||
)
|
||||
# subscribed_to maybe empty after on_state_change
|
||||
if self.subscribed_to:
|
||||
self._pubsub.subscribe(*self.subscribed_to)
|
||||
else:
|
||||
self._pubsub.connection = self._pubsub.connection_pool.get_connection(
|
||||
'pubsub', self._pubsub.shard_hint
|
||||
)
|
||||
# even if there is nothing to subscribe, we should not lose the callback after connecting.
|
||||
# The on_connect callback will re-subscribe to any channels we previously subscribed to.
|
||||
self._pubsub.connection.register_connect_callback(self._pubsub.on_connect)
|
||||
|
||||
@contextmanager
|
||||
def reconnect_on_error(self):
|
||||
try:
|
||||
yield
|
||||
except self._connection_errors:
|
||||
try:
|
||||
self._ensure(self._reconnect_pubsub, ())
|
||||
except self._connection_errors:
|
||||
logger.critical(E_RETRY_LIMIT_EXCEEDED)
|
||||
raise
|
||||
|
||||
def _maybe_cancel_ready_task(self, meta):
|
||||
if meta['status'] in states.READY_STATES:
|
||||
self.cancel_for(meta['task_id'])
|
||||
|
||||
def on_state_change(self, meta, message):
|
||||
super().on_state_change(meta, message)
|
||||
self._maybe_cancel_ready_task(meta)
|
||||
|
||||
def start(self, initial_task_id, **kwargs):
|
||||
self._pubsub = self.backend.client.pubsub(
|
||||
ignore_subscribe_messages=True,
|
||||
)
|
||||
self._consume_from(initial_task_id)
|
||||
|
||||
def on_wait_for_pending(self, result, **kwargs):
|
||||
for meta in result._iter_meta(**kwargs):
|
||||
if meta is not None:
|
||||
self.on_state_change(meta, None)
|
||||
|
||||
def stop(self):
|
||||
if self._pubsub is not None:
|
||||
self._pubsub.close()
|
||||
|
||||
def drain_events(self, timeout=None):
|
||||
if self._pubsub:
|
||||
with self.reconnect_on_error():
|
||||
message = self._pubsub.get_message(timeout=timeout)
|
||||
if message and message['type'] == 'message':
|
||||
self.on_state_change(self._decode_result(message['data']), message)
|
||||
elif timeout:
|
||||
time.sleep(timeout)
|
||||
|
||||
def consume_from(self, task_id):
|
||||
if self._pubsub is None:
|
||||
return self.start(task_id)
|
||||
self._consume_from(task_id)
|
||||
|
||||
def _consume_from(self, task_id):
|
||||
key = self._get_key_for_task(task_id)
|
||||
if key not in self.subscribed_to:
|
||||
self.subscribed_to.add(key)
|
||||
with self.reconnect_on_error():
|
||||
self._pubsub.subscribe(key)
|
||||
|
||||
def cancel_for(self, task_id):
|
||||
key = self._get_key_for_task(task_id)
|
||||
self.subscribed_to.discard(key)
|
||||
if self._pubsub:
|
||||
with self.reconnect_on_error():
|
||||
self._pubsub.unsubscribe(key)
|
||||
|
||||
|
||||
class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
|
||||
"""Redis task result store.
|
||||
|
||||
It makes use of the following commands:
|
||||
GET, MGET, DEL, INCRBY, EXPIRE, SET, SETEX
|
||||
"""
|
||||
|
||||
ResultConsumer = ResultConsumer
|
||||
|
||||
#: :pypi:`redis` client module.
|
||||
redis = redis
|
||||
connection_class_ssl = redis.SSLConnection if redis else None
|
||||
|
||||
#: Maximum number of connections in the pool.
|
||||
max_connections = None
|
||||
|
||||
supports_autoexpire = True
|
||||
supports_native_join = True
|
||||
|
||||
#: Maximal length of string value in Redis.
|
||||
#: 512 MB - https://redis.io/topics/data-types
|
||||
_MAX_STR_VALUE_SIZE = 536870912
|
||||
|
||||
def __init__(self, host=None, port=None, db=None, password=None,
|
||||
max_connections=None, url=None,
|
||||
connection_pool=None, **kwargs):
|
||||
super().__init__(expires_type=int, **kwargs)
|
||||
_get = self.app.conf.get
|
||||
if self.redis is None:
|
||||
raise ImproperlyConfigured(E_REDIS_MISSING.strip())
|
||||
|
||||
if host and '://' in host:
|
||||
url, host = host, None
|
||||
|
||||
self.max_connections = (
|
||||
max_connections or
|
||||
_get('redis_max_connections') or
|
||||
self.max_connections)
|
||||
self._ConnectionPool = connection_pool
|
||||
|
||||
socket_timeout = _get('redis_socket_timeout')
|
||||
socket_connect_timeout = _get('redis_socket_connect_timeout')
|
||||
retry_on_timeout = _get('redis_retry_on_timeout')
|
||||
socket_keepalive = _get('redis_socket_keepalive')
|
||||
health_check_interval = _get('redis_backend_health_check_interval')
|
||||
|
||||
self.connparams = {
|
||||
'host': _get('redis_host') or 'localhost',
|
||||
'port': _get('redis_port') or 6379,
|
||||
'db': _get('redis_db') or 0,
|
||||
'password': _get('redis_password'),
|
||||
'max_connections': self.max_connections,
|
||||
'socket_timeout': socket_timeout and float(socket_timeout),
|
||||
'retry_on_timeout': retry_on_timeout or False,
|
||||
'socket_connect_timeout':
|
||||
socket_connect_timeout and float(socket_connect_timeout),
|
||||
}
|
||||
|
||||
username = _get('redis_username')
|
||||
if username:
|
||||
# We're extra careful to avoid including this configuration value
|
||||
# if it wasn't specified since older versions of py-redis
|
||||
# don't support specifying a username.
|
||||
# Only Redis>6.0 supports username/password authentication.
|
||||
|
||||
# TODO: Include this in connparams' definition once we drop
|
||||
# support for py-redis<3.4.0.
|
||||
self.connparams['username'] = username
|
||||
|
||||
if health_check_interval:
|
||||
self.connparams["health_check_interval"] = health_check_interval
|
||||
|
||||
# absent in redis.connection.UnixDomainSocketConnection
|
||||
if socket_keepalive:
|
||||
self.connparams['socket_keepalive'] = socket_keepalive
|
||||
|
||||
# "redis_backend_use_ssl" must be a dict with the keys:
|
||||
# 'ssl_cert_reqs', 'ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile'
|
||||
# (the same as "broker_use_ssl")
|
||||
ssl = _get('redis_backend_use_ssl')
|
||||
if ssl:
|
||||
self.connparams.update(ssl)
|
||||
self.connparams['connection_class'] = self.connection_class_ssl
|
||||
|
||||
if url:
|
||||
self.connparams = self._params_from_url(url, self.connparams)
|
||||
|
||||
# If we've received SSL parameters via query string or the
|
||||
# redis_backend_use_ssl dict, check ssl_cert_reqs is valid. If set
|
||||
# via query string ssl_cert_reqs will be a string so convert it here
|
||||
if ('connection_class' in self.connparams and
|
||||
issubclass(self.connparams['connection_class'], redis.SSLConnection)):
|
||||
ssl_cert_reqs_missing = 'MISSING'
|
||||
ssl_string_to_constant = {'CERT_REQUIRED': CERT_REQUIRED,
|
||||
'CERT_OPTIONAL': CERT_OPTIONAL,
|
||||
'CERT_NONE': CERT_NONE,
|
||||
'required': CERT_REQUIRED,
|
||||
'optional': CERT_OPTIONAL,
|
||||
'none': CERT_NONE}
|
||||
ssl_cert_reqs = self.connparams.get('ssl_cert_reqs', ssl_cert_reqs_missing)
|
||||
ssl_cert_reqs = ssl_string_to_constant.get(ssl_cert_reqs, ssl_cert_reqs)
|
||||
if ssl_cert_reqs not in ssl_string_to_constant.values():
|
||||
raise ValueError(E_REDIS_SSL_CERT_REQS_MISSING_INVALID)
|
||||
|
||||
if ssl_cert_reqs == CERT_OPTIONAL:
|
||||
logger.warning(W_REDIS_SSL_CERT_OPTIONAL)
|
||||
elif ssl_cert_reqs == CERT_NONE:
|
||||
logger.warning(W_REDIS_SSL_CERT_NONE)
|
||||
self.connparams['ssl_cert_reqs'] = ssl_cert_reqs
|
||||
|
||||
self.url = url
|
||||
|
||||
self.connection_errors, self.channel_errors = (
|
||||
get_redis_error_classes() if get_redis_error_classes
|
||||
else ((), ()))
|
||||
self.result_consumer = self.ResultConsumer(
|
||||
self, self.app, self.accept,
|
||||
self._pending_results, self._pending_messages,
|
||||
)
|
||||
|
||||
def _params_from_url(self, url, defaults):
|
||||
scheme, host, port, username, password, path, query = _parse_url(url)
|
||||
connparams = dict(
|
||||
defaults, **dictfilter({
|
||||
'host': host, 'port': port, 'username': username,
|
||||
'password': password, 'db': query.pop('virtual_host', None)})
|
||||
)
|
||||
|
||||
if scheme == 'socket':
|
||||
# use 'path' as path to the socket… in this case
|
||||
# the database number should be given in 'query'
|
||||
connparams.update({
|
||||
'connection_class': self.redis.UnixDomainSocketConnection,
|
||||
'path': '/' + path,
|
||||
})
|
||||
# host+port are invalid options when using this connection type.
|
||||
connparams.pop('host', None)
|
||||
connparams.pop('port', None)
|
||||
connparams.pop('socket_connect_timeout')
|
||||
else:
|
||||
connparams['db'] = path
|
||||
|
||||
ssl_param_keys = ['ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile',
|
||||
'ssl_cert_reqs']
|
||||
|
||||
if scheme == 'redis':
|
||||
# If connparams or query string contain ssl params, raise error
|
||||
if (any(key in connparams for key in ssl_param_keys) or
|
||||
any(key in query for key in ssl_param_keys)):
|
||||
raise ValueError(E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH)
|
||||
|
||||
if scheme == 'rediss':
|
||||
connparams['connection_class'] = redis.SSLConnection
|
||||
# The following parameters, if present in the URL, are encoded. We
|
||||
# must add the decoded values to connparams.
|
||||
for ssl_setting in ssl_param_keys:
|
||||
ssl_val = query.pop(ssl_setting, None)
|
||||
if ssl_val:
|
||||
connparams[ssl_setting] = unquote(ssl_val)
|
||||
|
||||
# db may be string and start with / like in kombu.
|
||||
db = connparams.get('db') or 0
|
||||
db = db.strip('/') if isinstance(db, str) else db
|
||||
connparams['db'] = int(db)
|
||||
|
||||
for key, value in query.items():
|
||||
if key in redis.connection.URL_QUERY_ARGUMENT_PARSERS:
|
||||
query[key] = redis.connection.URL_QUERY_ARGUMENT_PARSERS[key](
|
||||
value
|
||||
)
|
||||
|
||||
# Query parameters override other parameters
|
||||
connparams.update(query)
|
||||
return connparams
|
||||
|
||||
def exception_safe_to_retry(self, exc):
|
||||
if isinstance(exc, self.connection_errors):
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def retry_policy(self):
|
||||
retry_policy = super().retry_policy
|
||||
if "retry_policy" in self._transport_options:
|
||||
retry_policy = retry_policy.copy()
|
||||
retry_policy.update(self._transport_options['retry_policy'])
|
||||
|
||||
return retry_policy
|
||||
|
||||
def on_task_call(self, producer, task_id):
|
||||
if not task_join_will_block():
|
||||
self.result_consumer.consume_from(task_id)
|
||||
|
||||
def get(self, key):
|
||||
return self.client.get(key)
|
||||
|
||||
def mget(self, keys):
|
||||
return self.client.mget(keys)
|
||||
|
||||
def ensure(self, fun, args, **policy):
|
||||
retry_policy = dict(self.retry_policy, **policy)
|
||||
max_retries = retry_policy.get('max_retries')
|
||||
return retry_over_time(
|
||||
fun, self.connection_errors, args, {},
|
||||
partial(self.on_connection_error, max_retries),
|
||||
**retry_policy)
|
||||
|
||||
def on_connection_error(self, max_retries, exc, intervals, retries):
|
||||
tts = next(intervals)
|
||||
logger.error(
|
||||
E_LOST.strip(),
|
||||
retries, max_retries or 'Inf', humanize_seconds(tts, 'in '))
|
||||
return tts
|
||||
|
||||
def set(self, key, value, **retry_policy):
|
||||
if isinstance(value, str) and len(value) > self._MAX_STR_VALUE_SIZE:
|
||||
raise BackendStoreError('value too large for Redis backend')
|
||||
|
||||
return self.ensure(self._set, (key, value), **retry_policy)
|
||||
|
||||
def _set(self, key, value):
|
||||
with self.client.pipeline() as pipe:
|
||||
if self.expires:
|
||||
pipe.setex(key, self.expires, value)
|
||||
else:
|
||||
pipe.set(key, value)
|
||||
pipe.publish(key, value)
|
||||
pipe.execute()
|
||||
|
||||
def forget(self, task_id):
|
||||
super().forget(task_id)
|
||||
self.result_consumer.cancel_for(task_id)
|
||||
|
||||
def delete(self, key):
|
||||
self.client.delete(key)
|
||||
|
||||
def incr(self, key):
|
||||
return self.client.incr(key)
|
||||
|
||||
def expire(self, key, value):
|
||||
return self.client.expire(key, value)
|
||||
|
||||
def add_to_chord(self, group_id, result):
|
||||
self.client.incr(self.get_key_for_group(group_id, '.t'), 1)
|
||||
|
||||
def _unpack_chord_result(self, tup, decode,
|
||||
EXCEPTION_STATES=states.EXCEPTION_STATES,
|
||||
PROPAGATE_STATES=states.PROPAGATE_STATES):
|
||||
_, tid, state, retval = decode(tup)
|
||||
if state in EXCEPTION_STATES:
|
||||
retval = self.exception_to_python(retval)
|
||||
if state in PROPAGATE_STATES:
|
||||
raise ChordError(f'Dependency {tid} raised {retval!r}')
|
||||
return retval
|
||||
|
||||
def set_chord_size(self, group_id, chord_size):
|
||||
self.set(self.get_key_for_group(group_id, '.s'), chord_size)
|
||||
|
||||
def apply_chord(self, header_result_args, body, **kwargs):
|
||||
# If any of the child results of this chord are complex (ie. group
|
||||
# results themselves), we need to save `header_result` to ensure that
|
||||
# the expected structure is retained when we finish the chord and pass
|
||||
# the results onward to the body in `on_chord_part_return()`. We don't
|
||||
# do this is all cases to retain an optimisation in the common case
|
||||
# where a chord header is comprised of simple result objects.
|
||||
if not isinstance(header_result_args[1], _regen):
|
||||
header_result = self.app.GroupResult(*header_result_args)
|
||||
if any(isinstance(nr, GroupResult) for nr in header_result.results):
|
||||
header_result.save(backend=self)
|
||||
|
||||
@cached_property
|
||||
def _chord_zset(self):
|
||||
return self._transport_options.get('result_chord_ordered', True)
|
||||
|
||||
@cached_property
|
||||
def _transport_options(self):
|
||||
return self.app.conf.get('result_backend_transport_options', {})
|
||||
|
||||
def on_chord_part_return(self, request, state, result,
|
||||
propagate=None, **kwargs):
|
||||
app = self.app
|
||||
tid, gid, group_index = request.id, request.group, request.group_index
|
||||
if not gid or not tid:
|
||||
return
|
||||
if group_index is None:
|
||||
group_index = '+inf'
|
||||
|
||||
client = self.client
|
||||
jkey = self.get_key_for_group(gid, '.j')
|
||||
tkey = self.get_key_for_group(gid, '.t')
|
||||
skey = self.get_key_for_group(gid, '.s')
|
||||
result = self.encode_result(result, state)
|
||||
encoded = self.encode([1, tid, state, result])
|
||||
with client.pipeline() as pipe:
|
||||
pipeline = (
|
||||
pipe.zadd(jkey, {encoded: group_index}).zcount(jkey, "-inf", "+inf")
|
||||
if self._chord_zset
|
||||
else pipe.rpush(jkey, encoded).llen(jkey)
|
||||
).get(tkey).get(skey)
|
||||
if self.expires:
|
||||
pipeline = pipeline \
|
||||
.expire(jkey, self.expires) \
|
||||
.expire(tkey, self.expires) \
|
||||
.expire(skey, self.expires)
|
||||
|
||||
_, readycount, totaldiff, chord_size_bytes = pipeline.execute()[:4]
|
||||
|
||||
totaldiff = int(totaldiff or 0)
|
||||
|
||||
if chord_size_bytes:
|
||||
try:
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
total = int(chord_size_bytes) + totaldiff
|
||||
if readycount == total:
|
||||
header_result = GroupResult.restore(gid)
|
||||
if header_result is not None:
|
||||
# If we manage to restore a `GroupResult`, then it must
|
||||
# have been complex and saved by `apply_chord()` earlier.
|
||||
#
|
||||
# Before we can join the `GroupResult`, it needs to be
|
||||
# manually marked as ready to avoid blocking
|
||||
header_result.on_ready()
|
||||
# We'll `join()` it to get the results and ensure they are
|
||||
# structured as intended rather than the flattened version
|
||||
# we'd construct without any other information.
|
||||
join_func = (
|
||||
header_result.join_native
|
||||
if header_result.supports_native_join
|
||||
else header_result.join
|
||||
)
|
||||
with allow_join_result():
|
||||
resl = join_func(
|
||||
timeout=app.conf.result_chord_join_timeout,
|
||||
propagate=True
|
||||
)
|
||||
else:
|
||||
# Otherwise simply extract and decode the results we
|
||||
# stashed along the way, which should be faster for large
|
||||
# numbers of simple results in the chord header.
|
||||
decode, unpack = self.decode, self._unpack_chord_result
|
||||
with client.pipeline() as pipe:
|
||||
if self._chord_zset:
|
||||
pipeline = pipe.zrange(jkey, 0, -1)
|
||||
else:
|
||||
pipeline = pipe.lrange(jkey, 0, total)
|
||||
resl, = pipeline.execute()
|
||||
resl = [unpack(tup, decode) for tup in resl]
|
||||
try:
|
||||
callback.delay(resl)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception(
|
||||
'Chord callback for %r raised: %r', request.group, exc)
|
||||
return self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Callback error: {exc!r}'),
|
||||
)
|
||||
finally:
|
||||
with client.pipeline() as pipe:
|
||||
pipe \
|
||||
.delete(jkey) \
|
||||
.delete(tkey) \
|
||||
.delete(skey) \
|
||||
.execute()
|
||||
except ChordError as exc:
|
||||
logger.exception('Chord %r raised: %r', request.group, exc)
|
||||
return self.chord_error_from_stack(callback, exc)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception('Chord %r raised: %r', request.group, exc)
|
||||
return self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Join error: {exc!r}'),
|
||||
)
|
||||
|
||||
def _create_client(self, **params):
|
||||
return self._get_client()(
|
||||
connection_pool=self._get_pool(**params),
|
||||
)
|
||||
|
||||
def _get_client(self):
|
||||
return self.redis.StrictRedis
|
||||
|
||||
def _get_pool(self, **params):
|
||||
return self.ConnectionPool(**params)
|
||||
|
||||
@property
|
||||
def ConnectionPool(self):
|
||||
if self._ConnectionPool is None:
|
||||
self._ConnectionPool = self.redis.ConnectionPool
|
||||
return self._ConnectionPool
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
return self._create_client(**self.connparams)
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return super().__reduce__(
|
||||
args, dict(kwargs, expires=self.expires, url=self.url))
|
||||
|
||||
|
||||
if getattr(redis, "sentinel", None):
|
||||
class SentinelManagedSSLConnection(
|
||||
redis.sentinel.SentinelManagedConnection,
|
||||
redis.SSLConnection):
|
||||
"""Connect to a Redis server using Sentinel + TLS.
|
||||
|
||||
Use Sentinel to identify which Redis server is the current master
|
||||
to connect to and when connecting to the Master server, use an
|
||||
SSL Connection.
|
||||
"""
|
||||
|
||||
|
||||
class SentinelBackend(RedisBackend):
|
||||
"""Redis sentinel task result store."""
|
||||
|
||||
# URL looks like `sentinel://0.0.0.0:26347/3;sentinel://0.0.0.0:26348/3`
|
||||
_SERVER_URI_SEPARATOR = ";"
|
||||
|
||||
sentinel = getattr(redis, "sentinel", None)
|
||||
connection_class_ssl = SentinelManagedSSLConnection if sentinel else None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if self.sentinel is None:
|
||||
raise ImproperlyConfigured(E_REDIS_SENTINEL_MISSING.strip())
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def as_uri(self, include_password=False):
|
||||
"""Return the server addresses as URIs, sanitizing the password or not."""
|
||||
# Allow superclass to do work if we don't need to force sanitization
|
||||
if include_password:
|
||||
return super().as_uri(
|
||||
include_password=include_password,
|
||||
)
|
||||
# Otherwise we need to ensure that all components get sanitized rather
|
||||
# by passing them one by one to the `kombu` helper
|
||||
uri_chunks = (
|
||||
maybe_sanitize_url(chunk)
|
||||
for chunk in (self.url or "").split(self._SERVER_URI_SEPARATOR)
|
||||
)
|
||||
# Similar to the superclass, strip the trailing slash from URIs with
|
||||
# all components empty other than the scheme
|
||||
return self._SERVER_URI_SEPARATOR.join(
|
||||
uri[:-1] if uri.endswith(":///") else uri
|
||||
for uri in uri_chunks
|
||||
)
|
||||
|
||||
def _params_from_url(self, url, defaults):
|
||||
chunks = url.split(self._SERVER_URI_SEPARATOR)
|
||||
connparams = dict(defaults, hosts=[])
|
||||
for chunk in chunks:
|
||||
data = super()._params_from_url(
|
||||
url=chunk, defaults=defaults)
|
||||
connparams['hosts'].append(data)
|
||||
for param in ("host", "port", "db", "password"):
|
||||
connparams.pop(param)
|
||||
|
||||
# Adding db/password in connparams to connect to the correct instance
|
||||
for param in ("db", "password"):
|
||||
if connparams['hosts'] and param in connparams['hosts'][0]:
|
||||
connparams[param] = connparams['hosts'][0].get(param)
|
||||
return connparams
|
||||
|
||||
def _get_sentinel_instance(self, **params):
|
||||
connparams = params.copy()
|
||||
|
||||
hosts = connparams.pop("hosts")
|
||||
min_other_sentinels = self._transport_options.get("min_other_sentinels", 0)
|
||||
sentinel_kwargs = self._transport_options.get("sentinel_kwargs", {})
|
||||
|
||||
sentinel_instance = self.sentinel.Sentinel(
|
||||
[(cp['host'], cp['port']) for cp in hosts],
|
||||
min_other_sentinels=min_other_sentinels,
|
||||
sentinel_kwargs=sentinel_kwargs,
|
||||
**connparams)
|
||||
|
||||
return sentinel_instance
|
||||
|
||||
def _get_pool(self, **params):
|
||||
sentinel_instance = self._get_sentinel_instance(**params)
|
||||
|
||||
master_name = self._transport_options.get("master_name", None)
|
||||
|
||||
return sentinel_instance.master_for(
|
||||
service_name=master_name,
|
||||
redis_class=self._get_client(),
|
||||
).connection_pool
|
||||
342
ETB-API/venv/lib/python3.12/site-packages/celery/backends/rpc.py
Normal file
342
ETB-API/venv/lib/python3.12/site-packages/celery/backends/rpc.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""The ``RPC`` result backend for AMQP brokers.
|
||||
|
||||
RPC-style result backend, using reply-to and one queue per client.
|
||||
"""
|
||||
import time
|
||||
|
||||
import kombu
|
||||
from kombu.common import maybe_declare
|
||||
from kombu.utils.compat import register_after_fork
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from celery import states
|
||||
from celery._state import current_task, task_join_will_block
|
||||
|
||||
from . import base
|
||||
from .asynchronous import AsyncBackendMixin, BaseResultConsumer
|
||||
|
||||
__all__ = ('BacklogLimitExceeded', 'RPCBackend')
|
||||
|
||||
E_NO_CHORD_SUPPORT = """
|
||||
The "rpc" result backend does not support chords!
|
||||
|
||||
Note that a group chained with a task is also upgraded to be a chord,
|
||||
as this pattern requires synchronization.
|
||||
|
||||
Result backends that supports chords: Redis, Database, Memcached, and more.
|
||||
"""
|
||||
|
||||
|
||||
class BacklogLimitExceeded(Exception):
|
||||
"""Too much state history to fast-forward."""
|
||||
|
||||
|
||||
def _on_after_fork_cleanup_backend(backend):
|
||||
backend._after_fork()
|
||||
|
||||
|
||||
class ResultConsumer(BaseResultConsumer):
|
||||
Consumer = kombu.Consumer
|
||||
|
||||
_connection = None
|
||||
_consumer = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._create_binding = self.backend._create_binding
|
||||
|
||||
def start(self, initial_task_id, no_ack=True, **kwargs):
|
||||
self._connection = self.app.connection()
|
||||
initial_queue = self._create_binding(initial_task_id)
|
||||
self._consumer = self.Consumer(
|
||||
self._connection.default_channel, [initial_queue],
|
||||
callbacks=[self.on_state_change], no_ack=no_ack,
|
||||
accept=self.accept)
|
||||
self._consumer.consume()
|
||||
|
||||
def drain_events(self, timeout=None):
|
||||
if self._connection:
|
||||
return self._connection.drain_events(timeout=timeout)
|
||||
elif timeout:
|
||||
time.sleep(timeout)
|
||||
|
||||
def stop(self):
|
||||
try:
|
||||
self._consumer.cancel()
|
||||
finally:
|
||||
self._connection.close()
|
||||
|
||||
def on_after_fork(self):
|
||||
self._consumer = None
|
||||
if self._connection is not None:
|
||||
self._connection.collect()
|
||||
self._connection = None
|
||||
|
||||
def consume_from(self, task_id):
|
||||
if self._consumer is None:
|
||||
return self.start(task_id)
|
||||
queue = self._create_binding(task_id)
|
||||
if not self._consumer.consuming_from(queue):
|
||||
self._consumer.add_queue(queue)
|
||||
self._consumer.consume()
|
||||
|
||||
def cancel_for(self, task_id):
|
||||
if self._consumer:
|
||||
self._consumer.cancel_by_queue(self._create_binding(task_id).name)
|
||||
|
||||
|
||||
class RPCBackend(base.Backend, AsyncBackendMixin):
|
||||
"""Base class for the RPC result backend."""
|
||||
|
||||
Exchange = kombu.Exchange
|
||||
Producer = kombu.Producer
|
||||
ResultConsumer = ResultConsumer
|
||||
|
||||
#: Exception raised when there are too many messages for a task id.
|
||||
BacklogLimitExceeded = BacklogLimitExceeded
|
||||
|
||||
persistent = False
|
||||
supports_autoexpire = True
|
||||
supports_native_join = True
|
||||
|
||||
retry_policy = {
|
||||
'max_retries': 20,
|
||||
'interval_start': 0,
|
||||
'interval_step': 1,
|
||||
'interval_max': 1,
|
||||
}
|
||||
|
||||
class Consumer(kombu.Consumer):
|
||||
"""Consumer that requires manual declaration of queues."""
|
||||
|
||||
auto_declare = False
|
||||
|
||||
class Queue(kombu.Queue):
|
||||
"""Queue that never caches declaration."""
|
||||
|
||||
can_cache_declaration = False
|
||||
|
||||
def __init__(self, app, connection=None, exchange=None, exchange_type=None,
|
||||
persistent=None, serializer=None, auto_delete=True, **kwargs):
|
||||
super().__init__(app, **kwargs)
|
||||
conf = self.app.conf
|
||||
self._connection = connection
|
||||
self._out_of_band = {}
|
||||
self.persistent = self.prepare_persistent(persistent)
|
||||
self.delivery_mode = 2 if self.persistent else 1
|
||||
exchange = exchange or conf.result_exchange
|
||||
exchange_type = exchange_type or conf.result_exchange_type
|
||||
self.exchange = self._create_exchange(
|
||||
exchange, exchange_type, self.delivery_mode,
|
||||
)
|
||||
self.serializer = serializer or conf.result_serializer
|
||||
self.auto_delete = auto_delete
|
||||
self.result_consumer = self.ResultConsumer(
|
||||
self, self.app, self.accept,
|
||||
self._pending_results, self._pending_messages,
|
||||
)
|
||||
if register_after_fork is not None:
|
||||
register_after_fork(self, _on_after_fork_cleanup_backend)
|
||||
|
||||
def _after_fork(self):
|
||||
# clear state for child processes.
|
||||
self._pending_results.clear()
|
||||
self.result_consumer._after_fork()
|
||||
|
||||
def _create_exchange(self, name, type='direct', delivery_mode=2):
|
||||
# uses direct to queue routing (anon exchange).
|
||||
return self.Exchange(None)
|
||||
|
||||
def _create_binding(self, task_id):
|
||||
"""Create new binding for task with id."""
|
||||
# RPC backend caches the binding, as one queue is used for all tasks.
|
||||
return self.binding
|
||||
|
||||
def ensure_chords_allowed(self):
|
||||
raise NotImplementedError(E_NO_CHORD_SUPPORT.strip())
|
||||
|
||||
def on_task_call(self, producer, task_id):
|
||||
# Called every time a task is sent when using this backend.
|
||||
# We declare the queue we receive replies on in advance of sending
|
||||
# the message, but we skip this if running in the prefork pool
|
||||
# (task_join_will_block), as we know the queue is already declared.
|
||||
if not task_join_will_block():
|
||||
maybe_declare(self.binding(producer.channel), retry=True)
|
||||
|
||||
def destination_for(self, task_id, request):
|
||||
"""Get the destination for result by task id.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: tuple of ``(reply_to, correlation_id)``.
|
||||
"""
|
||||
# Backends didn't always receive the `request`, so we must still
|
||||
# support old code that relies on current_task.
|
||||
try:
|
||||
request = request or current_task.request
|
||||
except AttributeError:
|
||||
raise RuntimeError(
|
||||
f'RPC backend missing task request for {task_id!r}')
|
||||
return request.reply_to, request.correlation_id or task_id
|
||||
|
||||
def on_reply_declare(self, task_id):
|
||||
# Return value here is used as the `declare=` argument
|
||||
# for Producer.publish.
|
||||
# By default we don't have to declare anything when sending a result.
|
||||
pass
|
||||
|
||||
def on_result_fulfilled(self, result):
|
||||
# This usually cancels the queue after the result is received,
|
||||
# but we don't have to cancel since we have one queue per process.
|
||||
pass
|
||||
|
||||
def as_uri(self, include_password=True):
|
||||
return 'rpc://'
|
||||
|
||||
def store_result(self, task_id, result, state,
|
||||
traceback=None, request=None, **kwargs):
|
||||
"""Send task return value and state."""
|
||||
routing_key, correlation_id = self.destination_for(task_id, request)
|
||||
if not routing_key:
|
||||
return
|
||||
with self.app.amqp.producer_pool.acquire(block=True) as producer:
|
||||
producer.publish(
|
||||
self._to_result(task_id, state, result, traceback, request),
|
||||
exchange=self.exchange,
|
||||
routing_key=routing_key,
|
||||
correlation_id=correlation_id,
|
||||
serializer=self.serializer,
|
||||
retry=True, retry_policy=self.retry_policy,
|
||||
declare=self.on_reply_declare(task_id),
|
||||
delivery_mode=self.delivery_mode,
|
||||
)
|
||||
return result
|
||||
|
||||
def _to_result(self, task_id, state, result, traceback, request):
|
||||
return {
|
||||
'task_id': task_id,
|
||||
'status': state,
|
||||
'result': self.encode_result(result, state),
|
||||
'traceback': traceback,
|
||||
'children': self.current_task_children(request),
|
||||
}
|
||||
|
||||
def on_out_of_band_result(self, task_id, message):
|
||||
# Callback called when a reply for a task is received,
|
||||
# but we have no idea what to do with it.
|
||||
# Since the result is not pending, we put it in a separate
|
||||
# buffer: probably it will become pending later.
|
||||
if self.result_consumer:
|
||||
self.result_consumer.on_out_of_band_result(message)
|
||||
self._out_of_band[task_id] = message
|
||||
|
||||
def get_task_meta(self, task_id, backlog_limit=1000):
|
||||
buffered = self._out_of_band.pop(task_id, None)
|
||||
if buffered:
|
||||
return self._set_cache_by_message(task_id, buffered)
|
||||
|
||||
# Polling and using basic_get
|
||||
latest_by_id = {}
|
||||
prev = None
|
||||
for acc in self._slurp_from_queue(task_id, self.accept, backlog_limit):
|
||||
tid = self._get_message_task_id(acc)
|
||||
prev, latest_by_id[tid] = latest_by_id.get(tid), acc
|
||||
if prev:
|
||||
# backends aren't expected to keep history,
|
||||
# so we delete everything except the most recent state.
|
||||
prev.ack()
|
||||
prev = None
|
||||
|
||||
latest = latest_by_id.pop(task_id, None)
|
||||
for tid, msg in latest_by_id.items():
|
||||
self.on_out_of_band_result(tid, msg)
|
||||
|
||||
if latest:
|
||||
latest.requeue()
|
||||
return self._set_cache_by_message(task_id, latest)
|
||||
else:
|
||||
# no new state, use previous
|
||||
try:
|
||||
return self._cache[task_id]
|
||||
except KeyError:
|
||||
# result probably pending.
|
||||
return {'status': states.PENDING, 'result': None}
|
||||
poll = get_task_meta # XXX compat
|
||||
|
||||
def _set_cache_by_message(self, task_id, message):
|
||||
payload = self._cache[task_id] = self.meta_from_decoded(
|
||||
message.payload)
|
||||
return payload
|
||||
|
||||
def _slurp_from_queue(self, task_id, accept,
|
||||
limit=1000, no_ack=False):
|
||||
with self.app.pool.acquire_channel(block=True) as (_, channel):
|
||||
binding = self._create_binding(task_id)(channel)
|
||||
binding.declare()
|
||||
|
||||
for _ in range(limit):
|
||||
msg = binding.get(accept=accept, no_ack=no_ack)
|
||||
if not msg:
|
||||
break
|
||||
yield msg
|
||||
else:
|
||||
raise self.BacklogLimitExceeded(task_id)
|
||||
|
||||
def _get_message_task_id(self, message):
|
||||
try:
|
||||
# try property first so we don't have to deserialize
|
||||
# the payload.
|
||||
return message.properties['correlation_id']
|
||||
except (AttributeError, KeyError):
|
||||
# message sent by old Celery version, need to deserialize.
|
||||
return message.payload['task_id']
|
||||
|
||||
def revive(self, channel):
|
||||
pass
|
||||
|
||||
def reload_task_result(self, task_id):
|
||||
raise NotImplementedError(
|
||||
'reload_task_result is not supported by this backend.')
|
||||
|
||||
def reload_group_result(self, task_id):
|
||||
"""Reload group result, even if it has been previously fetched."""
|
||||
raise NotImplementedError(
|
||||
'reload_group_result is not supported by this backend.')
|
||||
|
||||
def save_group(self, group_id, result):
|
||||
raise NotImplementedError(
|
||||
'save_group is not supported by this backend.')
|
||||
|
||||
def restore_group(self, group_id, cache=True):
|
||||
raise NotImplementedError(
|
||||
'restore_group is not supported by this backend.')
|
||||
|
||||
def delete_group(self, group_id):
|
||||
raise NotImplementedError(
|
||||
'delete_group is not supported by this backend.')
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return super().__reduce__(args, dict(
|
||||
kwargs,
|
||||
connection=self._connection,
|
||||
exchange=self.exchange.name,
|
||||
exchange_type=self.exchange.type,
|
||||
persistent=self.persistent,
|
||||
serializer=self.serializer,
|
||||
auto_delete=self.auto_delete,
|
||||
expires=self.expires,
|
||||
))
|
||||
|
||||
@property
|
||||
def binding(self):
|
||||
return self.Queue(
|
||||
self.oid, self.exchange, self.oid,
|
||||
durable=False,
|
||||
auto_delete=True,
|
||||
expires=self.expires,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def oid(self):
|
||||
# cached here is the app thread OID: name of queue we receive results on.
|
||||
return self.app.thread_oid
|
||||
@@ -0,0 +1,87 @@
|
||||
"""s3 result store backend."""
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import boto3
|
||||
import botocore
|
||||
except ImportError:
|
||||
boto3 = None
|
||||
botocore = None
|
||||
|
||||
|
||||
__all__ = ('S3Backend',)
|
||||
|
||||
|
||||
class S3Backend(KeyValueStoreBackend):
|
||||
"""An S3 task result store.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.ImproperlyConfigured:
|
||||
if module :pypi:`boto3` is not available,
|
||||
if the :setting:`aws_access_key_id` or
|
||||
setting:`aws_secret_access_key` are not set,
|
||||
or it the :setting:`bucket` is not set.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if not boto3 or not botocore:
|
||||
raise ImproperlyConfigured('You must install boto3'
|
||||
'to use s3 backend')
|
||||
conf = self.app.conf
|
||||
|
||||
self.endpoint_url = conf.get('s3_endpoint_url', None)
|
||||
self.aws_region = conf.get('s3_region', None)
|
||||
|
||||
self.aws_access_key_id = conf.get('s3_access_key_id', None)
|
||||
self.aws_secret_access_key = conf.get('s3_secret_access_key', None)
|
||||
|
||||
self.bucket_name = conf.get('s3_bucket', None)
|
||||
if not self.bucket_name:
|
||||
raise ImproperlyConfigured('Missing bucket name')
|
||||
|
||||
self.base_path = conf.get('s3_base_path', None)
|
||||
|
||||
self._s3_resource = self._connect_to_s3()
|
||||
|
||||
def _get_s3_object(self, key):
|
||||
key_bucket_path = self.base_path + key if self.base_path else key
|
||||
return self._s3_resource.Object(self.bucket_name, key_bucket_path)
|
||||
|
||||
def get(self, key):
|
||||
key = bytes_to_str(key)
|
||||
s3_object = self._get_s3_object(key)
|
||||
try:
|
||||
s3_object.load()
|
||||
data = s3_object.get()['Body'].read()
|
||||
return data if self.content_encoding == 'binary' else data.decode('utf-8')
|
||||
except botocore.exceptions.ClientError as error:
|
||||
if error.response['Error']['Code'] == "404":
|
||||
return None
|
||||
raise error
|
||||
|
||||
def set(self, key, value):
|
||||
key = bytes_to_str(key)
|
||||
s3_object = self._get_s3_object(key)
|
||||
s3_object.put(Body=value)
|
||||
|
||||
def delete(self, key):
|
||||
key = bytes_to_str(key)
|
||||
s3_object = self._get_s3_object(key)
|
||||
s3_object.delete()
|
||||
|
||||
def _connect_to_s3(self):
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
region_name=self.aws_region
|
||||
)
|
||||
if session.get_credentials() is None:
|
||||
raise ImproperlyConfigured('Missing aws s3 creds')
|
||||
return session.resource('s3', endpoint_url=self.endpoint_url)
|
||||
Reference in New Issue
Block a user