Updates
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,165 @@
|
||||
"""Abortable Tasks.
|
||||
|
||||
Abortable tasks overview
|
||||
=========================
|
||||
|
||||
For long-running :class:`Task`'s, it can be desirable to support
|
||||
aborting during execution. Of course, these tasks should be built to
|
||||
support abortion specifically.
|
||||
|
||||
The :class:`AbortableTask` serves as a base class for all :class:`Task`
|
||||
objects that should support abortion by producers.
|
||||
|
||||
* Producers may invoke the :meth:`abort` method on
|
||||
:class:`AbortableAsyncResult` instances, to request abortion.
|
||||
|
||||
* Consumers (workers) should periodically check (and honor!) the
|
||||
:meth:`is_aborted` method at controlled points in their task's
|
||||
:meth:`run` method. The more often, the better.
|
||||
|
||||
The necessary intermediate communication is dealt with by the
|
||||
:class:`AbortableTask` implementation.
|
||||
|
||||
Usage example
|
||||
-------------
|
||||
|
||||
In the consumer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from celery.contrib.abortable import AbortableTask
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from proj.celery import app
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@app.task(bind=True, base=AbortableTask)
|
||||
def long_running_task(self):
|
||||
results = []
|
||||
for i in range(100):
|
||||
# check after every 5 iterations...
|
||||
# (or alternatively, check when some timer is due)
|
||||
if not i % 5:
|
||||
if self.is_aborted():
|
||||
# respect aborted state, and terminate gracefully.
|
||||
logger.warning('Task aborted')
|
||||
return
|
||||
value = do_something_expensive(i)
|
||||
results.append(y)
|
||||
logger.info('Task complete')
|
||||
return results
|
||||
|
||||
In the producer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import time
|
||||
|
||||
from proj.tasks import MyLongRunningTask
|
||||
|
||||
def myview(request):
|
||||
# result is of type AbortableAsyncResult
|
||||
result = long_running_task.delay()
|
||||
|
||||
# abort the task after 10 seconds
|
||||
time.sleep(10)
|
||||
result.abort()
|
||||
|
||||
After the `result.abort()` call, the task execution isn't
|
||||
aborted immediately. In fact, it's not guaranteed to abort at all.
|
||||
Keep checking `result.state` status, or call `result.get(timeout=)` to
|
||||
have it block until the task is finished.
|
||||
|
||||
.. note::
|
||||
|
||||
In order to abort tasks, there needs to be communication between the
|
||||
producer and the consumer. This is currently implemented through the
|
||||
database backend. Therefore, this class will only work with the
|
||||
database backends.
|
||||
"""
|
||||
from celery import Task
|
||||
from celery.result import AsyncResult
|
||||
|
||||
__all__ = ('AbortableAsyncResult', 'AbortableTask')
|
||||
|
||||
|
||||
"""
|
||||
Task States
|
||||
-----------
|
||||
|
||||
.. state:: ABORTED
|
||||
|
||||
ABORTED
|
||||
~~~~~~~
|
||||
|
||||
Task is aborted (typically by the producer) and should be
|
||||
aborted as soon as possible.
|
||||
|
||||
"""
|
||||
ABORTED = 'ABORTED'
|
||||
|
||||
|
||||
class AbortableAsyncResult(AsyncResult):
|
||||
"""Represents an abortable result.
|
||||
|
||||
Specifically, this gives the `AsyncResult` a :meth:`abort()` method,
|
||||
that sets the state of the underlying Task to `'ABORTED'`.
|
||||
"""
|
||||
|
||||
def is_aborted(self):
|
||||
"""Return :const:`True` if the task is (being) aborted."""
|
||||
return self.state == ABORTED
|
||||
|
||||
def abort(self):
|
||||
"""Set the state of the task to :const:`ABORTED`.
|
||||
|
||||
Abortable tasks monitor their state at regular intervals and
|
||||
terminate execution if so.
|
||||
|
||||
Warning:
|
||||
Be aware that invoking this method does not guarantee when the
|
||||
task will be aborted (or even if the task will be aborted at all).
|
||||
"""
|
||||
# TODO: store_result requires all four arguments to be set,
|
||||
# but only state should be updated here
|
||||
return self.backend.store_result(self.id, result=None,
|
||||
state=ABORTED, traceback=None)
|
||||
|
||||
|
||||
class AbortableTask(Task):
|
||||
"""Task that can be aborted.
|
||||
|
||||
This serves as a base class for all :class:`Task`'s
|
||||
that support aborting during execution.
|
||||
|
||||
All subclasses of :class:`AbortableTask` must call the
|
||||
:meth:`is_aborted` method periodically and act accordingly when
|
||||
the call evaluates to :const:`True`.
|
||||
"""
|
||||
|
||||
abstract = True
|
||||
|
||||
def AsyncResult(self, task_id):
|
||||
"""Return the accompanying AbortableAsyncResult instance."""
|
||||
return AbortableAsyncResult(task_id, backend=self.backend)
|
||||
|
||||
def is_aborted(self, **kwargs):
|
||||
"""Return true if task is aborted.
|
||||
|
||||
Checks against the backend whether this
|
||||
:class:`AbortableAsyncResult` is :const:`ABORTED`.
|
||||
|
||||
Always return :const:`False` in case the `task_id` parameter
|
||||
refers to a regular (non-abortable) :class:`Task`.
|
||||
|
||||
Be aware that invoking this method will cause a hit in the
|
||||
backend (for example a database query), so find a good balance
|
||||
between calling it regularly (for responsiveness), but not too
|
||||
often (for performance).
|
||||
"""
|
||||
task_id = kwargs.get('task_id', self.request.id)
|
||||
result = self.AsyncResult(task_id)
|
||||
if not isinstance(result, AbortableAsyncResult):
|
||||
return False
|
||||
return result.is_aborted()
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
import functools
|
||||
|
||||
from django.db import transaction
|
||||
|
||||
from celery.app.task import Task
|
||||
|
||||
|
||||
class DjangoTask(Task):
|
||||
"""
|
||||
Extend the base :class:`~celery.app.task.Task` for Django.
|
||||
|
||||
Provide a nicer API to trigger tasks at the end of the DB transaction.
|
||||
"""
|
||||
|
||||
def delay_on_commit(self, *args, **kwargs) -> None:
|
||||
"""Call :meth:`~celery.app.task.Task.delay` with Django's ``on_commit()``."""
|
||||
transaction.on_commit(functools.partial(self.delay, *args, **kwargs))
|
||||
|
||||
def apply_async_on_commit(self, *args, **kwargs) -> None:
|
||||
"""Call :meth:`~celery.app.task.Task.apply_async` with Django's ``on_commit()``."""
|
||||
transaction.on_commit(functools.partial(self.apply_async, *args, **kwargs))
|
||||
@@ -0,0 +1,416 @@
|
||||
"""Message migration tools (Broker <-> Broker)."""
|
||||
import socket
|
||||
from functools import partial
|
||||
from itertools import cycle, islice
|
||||
|
||||
from kombu import Queue, eventloop
|
||||
from kombu.common import maybe_declare
|
||||
from kombu.utils.encoding import ensure_bytes
|
||||
|
||||
from celery.app import app_or_default
|
||||
from celery.utils.nodenames import worker_direct
|
||||
from celery.utils.text import str_to_list
|
||||
|
||||
__all__ = (
|
||||
'StopFiltering', 'State', 'republish', 'migrate_task',
|
||||
'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
|
||||
'start_filter', 'move_task_by_id', 'move_by_idmap',
|
||||
'move_by_taskmap', 'move_direct', 'move_direct_by_id',
|
||||
)
|
||||
|
||||
MOVING_PROGRESS_FMT = """\
|
||||
Moving task {state.filtered}/{state.strtotal}: \
|
||||
{body[task]}[{body[id]}]\
|
||||
"""
|
||||
|
||||
|
||||
class StopFiltering(Exception):
|
||||
"""Semi-predicate used to signal filter stop."""
|
||||
|
||||
|
||||
class State:
|
||||
"""Migration progress state."""
|
||||
|
||||
count = 0
|
||||
filtered = 0
|
||||
total_apx = 0
|
||||
|
||||
@property
|
||||
def strtotal(self):
|
||||
if not self.total_apx:
|
||||
return '?'
|
||||
return str(self.total_apx)
|
||||
|
||||
def __repr__(self):
|
||||
if self.filtered:
|
||||
return f'^{self.filtered}'
|
||||
return f'{self.count}/{self.strtotal}'
|
||||
|
||||
|
||||
def republish(producer, message, exchange=None, routing_key=None,
|
||||
remove_props=None):
|
||||
"""Republish message."""
|
||||
if not remove_props:
|
||||
remove_props = ['application_headers', 'content_type',
|
||||
'content_encoding', 'headers']
|
||||
body = ensure_bytes(message.body) # use raw message body.
|
||||
info, headers, props = (message.delivery_info,
|
||||
message.headers, message.properties)
|
||||
exchange = info['exchange'] if exchange is None else exchange
|
||||
routing_key = info['routing_key'] if routing_key is None else routing_key
|
||||
ctype, enc = message.content_type, message.content_encoding
|
||||
# remove compression header, as this will be inserted again
|
||||
# when the message is recompressed.
|
||||
compression = headers.pop('compression', None)
|
||||
|
||||
expiration = props.pop('expiration', None)
|
||||
# ensure expiration is a float
|
||||
expiration = float(expiration) if expiration is not None else None
|
||||
|
||||
for key in remove_props:
|
||||
props.pop(key, None)
|
||||
|
||||
producer.publish(ensure_bytes(body), exchange=exchange,
|
||||
routing_key=routing_key, compression=compression,
|
||||
headers=headers, content_type=ctype,
|
||||
content_encoding=enc, expiration=expiration,
|
||||
**props)
|
||||
|
||||
|
||||
def migrate_task(producer, body_, message, queues=None):
|
||||
"""Migrate single task message."""
|
||||
info = message.delivery_info
|
||||
queues = {} if queues is None else queues
|
||||
republish(producer, message,
|
||||
exchange=queues.get(info['exchange']),
|
||||
routing_key=queues.get(info['routing_key']))
|
||||
|
||||
|
||||
def filter_callback(callback, tasks):
|
||||
|
||||
def filtered(body, message):
|
||||
if tasks and body['task'] not in tasks:
|
||||
return
|
||||
|
||||
return callback(body, message)
|
||||
return filtered
|
||||
|
||||
|
||||
def migrate_tasks(source, dest, migrate=migrate_task, app=None,
|
||||
queues=None, **kwargs):
|
||||
"""Migrate tasks from one broker to another."""
|
||||
app = app_or_default(app)
|
||||
queues = prepare_queues(queues)
|
||||
producer = app.amqp.Producer(dest, auto_declare=False)
|
||||
migrate = partial(migrate, producer, queues=queues)
|
||||
|
||||
def on_declare_queue(queue):
|
||||
new_queue = queue(producer.channel)
|
||||
new_queue.name = queues.get(queue.name, queue.name)
|
||||
if new_queue.routing_key == queue.name:
|
||||
new_queue.routing_key = queues.get(queue.name,
|
||||
new_queue.routing_key)
|
||||
if new_queue.exchange.name == queue.name:
|
||||
new_queue.exchange.name = queues.get(queue.name, queue.name)
|
||||
new_queue.declare()
|
||||
|
||||
return start_filter(app, source, migrate, queues=queues,
|
||||
on_declare_queue=on_declare_queue, **kwargs)
|
||||
|
||||
|
||||
def _maybe_queue(app, q):
|
||||
if isinstance(q, str):
|
||||
return app.amqp.queues[q]
|
||||
return q
|
||||
|
||||
|
||||
def move(predicate, connection=None, exchange=None, routing_key=None,
|
||||
source=None, app=None, callback=None, limit=None, transform=None,
|
||||
**kwargs):
|
||||
"""Find tasks by filtering them and move the tasks to a new queue.
|
||||
|
||||
Arguments:
|
||||
predicate (Callable): Filter function used to decide the messages
|
||||
to move. Must accept the standard signature of ``(body, message)``
|
||||
used by Kombu consumer callbacks. If the predicate wants the
|
||||
message to be moved it must return either:
|
||||
|
||||
1) a tuple of ``(exchange, routing_key)``, or
|
||||
|
||||
2) a :class:`~kombu.entity.Queue` instance, or
|
||||
|
||||
3) any other true value means the specified
|
||||
``exchange`` and ``routing_key`` arguments will be used.
|
||||
connection (kombu.Connection): Custom connection to use.
|
||||
source: List[Union[str, kombu.Queue]]: Optional list of source
|
||||
queues to use instead of the default (queues
|
||||
in :setting:`task_queues`). This list can also contain
|
||||
:class:`~kombu.entity.Queue` instances.
|
||||
exchange (str, kombu.Exchange): Default destination exchange.
|
||||
routing_key (str): Default destination routing key.
|
||||
limit (int): Limit number of messages to filter.
|
||||
callback (Callable): Callback called after message moved,
|
||||
with signature ``(state, body, message)``.
|
||||
transform (Callable): Optional function to transform the return
|
||||
value (destination) of the filter function.
|
||||
|
||||
Also supports the same keyword arguments as :func:`start_filter`.
|
||||
|
||||
To demonstrate, the :func:`move_task_by_id` operation can be implemented
|
||||
like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def is_wanted_task(body, message):
|
||||
if body['id'] == wanted_id:
|
||||
return Queue('foo', exchange=Exchange('foo'),
|
||||
routing_key='foo')
|
||||
|
||||
move(is_wanted_task)
|
||||
|
||||
or with a transform:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def transform(value):
|
||||
if isinstance(value, str):
|
||||
return Queue(value, Exchange(value), value)
|
||||
return value
|
||||
|
||||
move(is_wanted_task, transform=transform)
|
||||
|
||||
Note:
|
||||
The predicate may also return a tuple of ``(exchange, routing_key)``
|
||||
to specify the destination to where the task should be moved,
|
||||
or a :class:`~kombu.entity.Queue` instance.
|
||||
Any other true value means that the task will be moved to the
|
||||
default exchange/routing_key.
|
||||
"""
|
||||
app = app_or_default(app)
|
||||
queues = [_maybe_queue(app, queue) for queue in source or []] or None
|
||||
with app.connection_or_acquire(connection, pool=False) as conn:
|
||||
producer = app.amqp.Producer(conn)
|
||||
state = State()
|
||||
|
||||
def on_task(body, message):
|
||||
ret = predicate(body, message)
|
||||
if ret:
|
||||
if transform:
|
||||
ret = transform(ret)
|
||||
if isinstance(ret, Queue):
|
||||
maybe_declare(ret, conn.default_channel)
|
||||
ex, rk = ret.exchange.name, ret.routing_key
|
||||
else:
|
||||
ex, rk = expand_dest(ret, exchange, routing_key)
|
||||
republish(producer, message,
|
||||
exchange=ex, routing_key=rk)
|
||||
message.ack()
|
||||
|
||||
state.filtered += 1
|
||||
if callback:
|
||||
callback(state, body, message)
|
||||
if limit and state.filtered >= limit:
|
||||
raise StopFiltering()
|
||||
|
||||
return start_filter(app, conn, on_task, consume_from=queues, **kwargs)
|
||||
|
||||
|
||||
def expand_dest(ret, exchange, routing_key):
|
||||
try:
|
||||
ex, rk = ret
|
||||
except (TypeError, ValueError):
|
||||
ex, rk = exchange, routing_key
|
||||
return ex, rk
|
||||
|
||||
|
||||
def task_id_eq(task_id, body, message):
|
||||
"""Return true if task id equals task_id'."""
|
||||
return body['id'] == task_id
|
||||
|
||||
|
||||
def task_id_in(ids, body, message):
|
||||
"""Return true if task id is member of set ids'."""
|
||||
return body['id'] in ids
|
||||
|
||||
|
||||
def prepare_queues(queues):
|
||||
if isinstance(queues, str):
|
||||
queues = queues.split(',')
|
||||
if isinstance(queues, list):
|
||||
queues = dict(tuple(islice(cycle(q.split(':')), None, 2))
|
||||
for q in queues)
|
||||
if queues is None:
|
||||
queues = {}
|
||||
return queues
|
||||
|
||||
|
||||
class Filterer:
|
||||
|
||||
def __init__(self, app, conn, filter,
|
||||
limit=None, timeout=1.0,
|
||||
ack_messages=False, tasks=None, queues=None,
|
||||
callback=None, forever=False, on_declare_queue=None,
|
||||
consume_from=None, state=None, accept=None, **kwargs):
|
||||
self.app = app
|
||||
self.conn = conn
|
||||
self.filter = filter
|
||||
self.limit = limit
|
||||
self.timeout = timeout
|
||||
self.ack_messages = ack_messages
|
||||
self.tasks = set(str_to_list(tasks) or [])
|
||||
self.queues = prepare_queues(queues)
|
||||
self.callback = callback
|
||||
self.forever = forever
|
||||
self.on_declare_queue = on_declare_queue
|
||||
self.consume_from = [
|
||||
_maybe_queue(self.app, q)
|
||||
for q in consume_from or list(self.queues)
|
||||
]
|
||||
self.state = state or State()
|
||||
self.accept = accept
|
||||
|
||||
def start(self):
|
||||
# start migrating messages.
|
||||
with self.prepare_consumer(self.create_consumer()):
|
||||
try:
|
||||
for _ in eventloop(self.conn, # pragma: no cover
|
||||
timeout=self.timeout,
|
||||
ignore_timeouts=self.forever):
|
||||
pass
|
||||
except socket.timeout:
|
||||
pass
|
||||
except StopFiltering:
|
||||
pass
|
||||
return self.state
|
||||
|
||||
def update_state(self, body, message):
|
||||
self.state.count += 1
|
||||
if self.limit and self.state.count >= self.limit:
|
||||
raise StopFiltering()
|
||||
|
||||
def ack_message(self, body, message):
|
||||
message.ack()
|
||||
|
||||
def create_consumer(self):
|
||||
return self.app.amqp.TaskConsumer(
|
||||
self.conn,
|
||||
queues=self.consume_from,
|
||||
accept=self.accept,
|
||||
)
|
||||
|
||||
def prepare_consumer(self, consumer):
|
||||
filter = self.filter
|
||||
update_state = self.update_state
|
||||
ack_message = self.ack_message
|
||||
if self.tasks:
|
||||
filter = filter_callback(filter, self.tasks)
|
||||
update_state = filter_callback(update_state, self.tasks)
|
||||
ack_message = filter_callback(ack_message, self.tasks)
|
||||
consumer.register_callback(filter)
|
||||
consumer.register_callback(update_state)
|
||||
if self.ack_messages:
|
||||
consumer.register_callback(self.ack_message)
|
||||
if self.callback is not None:
|
||||
callback = partial(self.callback, self.state)
|
||||
if self.tasks:
|
||||
callback = filter_callback(callback, self.tasks)
|
||||
consumer.register_callback(callback)
|
||||
self.declare_queues(consumer)
|
||||
return consumer
|
||||
|
||||
def declare_queues(self, consumer):
|
||||
# declare all queues on the new broker.
|
||||
for queue in consumer.queues:
|
||||
if self.queues and queue.name not in self.queues:
|
||||
continue
|
||||
if self.on_declare_queue is not None:
|
||||
self.on_declare_queue(queue)
|
||||
try:
|
||||
_, mcount, _ = queue(
|
||||
consumer.channel).queue_declare(passive=True)
|
||||
if mcount:
|
||||
self.state.total_apx += mcount
|
||||
except self.conn.channel_errors:
|
||||
pass
|
||||
|
||||
|
||||
def start_filter(app, conn, filter, limit=None, timeout=1.0,
|
||||
ack_messages=False, tasks=None, queues=None,
|
||||
callback=None, forever=False, on_declare_queue=None,
|
||||
consume_from=None, state=None, accept=None, **kwargs):
|
||||
"""Filter tasks."""
|
||||
return Filterer(
|
||||
app, conn, filter,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
ack_messages=ack_messages,
|
||||
tasks=tasks,
|
||||
queues=queues,
|
||||
callback=callback,
|
||||
forever=forever,
|
||||
on_declare_queue=on_declare_queue,
|
||||
consume_from=consume_from,
|
||||
state=state,
|
||||
accept=accept,
|
||||
**kwargs).start()
|
||||
|
||||
|
||||
def move_task_by_id(task_id, dest, **kwargs):
|
||||
"""Find a task by id and move it to another queue.
|
||||
|
||||
Arguments:
|
||||
task_id (str): Id of task to find and move.
|
||||
dest: (str, kombu.Queue): Destination queue.
|
||||
transform (Callable): Optional function to transform the return
|
||||
value (destination) of the filter function.
|
||||
**kwargs (Any): Also supports the same keyword
|
||||
arguments as :func:`move`.
|
||||
"""
|
||||
return move_by_idmap({task_id: dest}, **kwargs)
|
||||
|
||||
|
||||
def move_by_idmap(map, **kwargs):
|
||||
"""Move tasks by matching from a ``task_id: queue`` mapping.
|
||||
|
||||
Where ``queue`` is a queue to move the task to.
|
||||
|
||||
Example:
|
||||
>>> move_by_idmap({
|
||||
... '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue('name'),
|
||||
... 'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue('name'),
|
||||
... '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue('name')},
|
||||
... queues=['hipri'])
|
||||
"""
|
||||
def task_id_in_map(body, message):
|
||||
return map.get(message.properties['correlation_id'])
|
||||
|
||||
# adding the limit means that we don't have to consume any more
|
||||
# when we've found everything.
|
||||
return move(task_id_in_map, limit=len(map), **kwargs)
|
||||
|
||||
|
||||
def move_by_taskmap(map, **kwargs):
|
||||
"""Move tasks by matching from a ``task_name: queue`` mapping.
|
||||
|
||||
``queue`` is the queue to move the task to.
|
||||
|
||||
Example:
|
||||
>>> move_by_taskmap({
|
||||
... 'tasks.add': Queue('name'),
|
||||
... 'tasks.mul': Queue('name'),
|
||||
... })
|
||||
"""
|
||||
def task_name_in_map(body, message):
|
||||
return map.get(body['task']) # <- name of task
|
||||
|
||||
return move(task_name_in_map, **kwargs)
|
||||
|
||||
|
||||
def filter_status(state, body, message, **kwargs):
|
||||
print(MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs))
|
||||
|
||||
|
||||
move_direct = partial(move, transform=worker_direct)
|
||||
move_direct_by_id = partial(move_task_by_id, transform=worker_direct)
|
||||
move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct)
|
||||
move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct)
|
||||
@@ -0,0 +1,216 @@
|
||||
"""Fixtures and testing utilities for :pypi:`pytest <pytest>`."""
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Sequence, Union # noqa
|
||||
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from celery import Celery
|
||||
|
||||
from ..worker import WorkController
|
||||
else:
|
||||
Celery = WorkController = object
|
||||
|
||||
|
||||
NO_WORKER = os.environ.get('NO_WORKER')
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
# Well, they're called fixtures....
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Register additional pytest configuration."""
|
||||
# add the pytest.mark.celery() marker registration to the pytest.ini [markers] section
|
||||
# this prevents pytest 4.5 and newer from issuing a warning about an unknown marker
|
||||
# and shows helpful marker documentation when running pytest --markers.
|
||||
config.addinivalue_line(
|
||||
"markers", "celery(**overrides): override celery configuration for a test case"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _create_app(enable_logging=False,
|
||||
use_trap=False,
|
||||
parameters=None,
|
||||
**config):
|
||||
# type: (Any, Any, Any, **Any) -> Celery
|
||||
"""Utility context used to setup Celery app for pytest fixtures."""
|
||||
|
||||
from .testing.app import TestApp, setup_default_app
|
||||
|
||||
parameters = {} if not parameters else parameters
|
||||
test_app = TestApp(
|
||||
set_as_current=False,
|
||||
enable_logging=enable_logging,
|
||||
config=config,
|
||||
**parameters
|
||||
)
|
||||
with setup_default_app(test_app, use_trap=use_trap):
|
||||
yield test_app
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def use_celery_app_trap():
|
||||
# type: () -> bool
|
||||
"""You can override this fixture to enable the app trap.
|
||||
|
||||
The app trap raises an exception whenever something attempts
|
||||
to use the current or default apps.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_session_app(request,
|
||||
celery_config,
|
||||
celery_parameters,
|
||||
celery_enable_logging,
|
||||
use_celery_app_trap):
|
||||
# type: (Any, Any, Any, Any, Any) -> Celery
|
||||
"""Session Fixture: Return app for session fixtures."""
|
||||
mark = request.node.get_closest_marker('celery')
|
||||
config = dict(celery_config, **mark.kwargs if mark else {})
|
||||
with _create_app(enable_logging=celery_enable_logging,
|
||||
use_trap=use_celery_app_trap,
|
||||
parameters=celery_parameters,
|
||||
**config) as app:
|
||||
if not use_celery_app_trap:
|
||||
app.set_default()
|
||||
app.set_current()
|
||||
yield app
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_session_worker(
|
||||
request, # type: Any
|
||||
celery_session_app, # type: Celery
|
||||
celery_includes, # type: Sequence[str]
|
||||
celery_class_tasks, # type: str
|
||||
celery_worker_pool, # type: Any
|
||||
celery_worker_parameters, # type: Mapping[str, Any]
|
||||
):
|
||||
# type: (...) -> WorkController
|
||||
"""Session Fixture: Start worker that lives throughout test suite."""
|
||||
from .testing import worker
|
||||
|
||||
if not NO_WORKER:
|
||||
for module in celery_includes:
|
||||
celery_session_app.loader.import_task_module(module)
|
||||
for class_task in celery_class_tasks:
|
||||
celery_session_app.register_task(class_task)
|
||||
with worker.start_worker(celery_session_app,
|
||||
pool=celery_worker_pool,
|
||||
**celery_worker_parameters) as w:
|
||||
yield w
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_enable_logging():
|
||||
# type: () -> bool
|
||||
"""You can override this fixture to enable logging."""
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_includes():
|
||||
# type: () -> Sequence[str]
|
||||
"""You can override this include modules when a worker start.
|
||||
|
||||
You can have this return a list of module names to import,
|
||||
these can be task modules, modules registering signals, and so on.
|
||||
"""
|
||||
return ()
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_worker_pool():
|
||||
# type: () -> Union[str, Any]
|
||||
"""You can override this fixture to set the worker pool.
|
||||
|
||||
The "solo" pool is used by default, but you can set this to
|
||||
return e.g. "prefork".
|
||||
"""
|
||||
return 'solo'
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_config():
|
||||
# type: () -> Mapping[str, Any]
|
||||
"""Redefine this fixture to configure the test Celery app.
|
||||
|
||||
The config returned by your fixture will then be used
|
||||
to configure the :func:`celery_app` fixture.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_parameters():
|
||||
# type: () -> Mapping[str, Any]
|
||||
"""Redefine this fixture to change the init parameters of test Celery app.
|
||||
|
||||
The dict returned by your fixture will then be used
|
||||
as parameters when instantiating :class:`~celery.Celery`.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_worker_parameters():
|
||||
# type: () -> Mapping[str, Any]
|
||||
"""Redefine this fixture to change the init parameters of Celery workers.
|
||||
|
||||
This can be used e. g. to define queues the worker will consume tasks from.
|
||||
|
||||
The dict returned by your fixture will then be used
|
||||
as parameters when instantiating :class:`~celery.worker.WorkController`.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def celery_app(request,
|
||||
celery_config,
|
||||
celery_parameters,
|
||||
celery_enable_logging,
|
||||
use_celery_app_trap):
|
||||
"""Fixture creating a Celery application instance."""
|
||||
mark = request.node.get_closest_marker('celery')
|
||||
config = dict(celery_config, **mark.kwargs if mark else {})
|
||||
with _create_app(enable_logging=celery_enable_logging,
|
||||
use_trap=use_celery_app_trap,
|
||||
parameters=celery_parameters,
|
||||
**config) as app:
|
||||
yield app
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def celery_class_tasks():
|
||||
"""Redefine this fixture to register tasks with the test Celery app."""
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def celery_worker(request,
|
||||
celery_app,
|
||||
celery_includes,
|
||||
celery_worker_pool,
|
||||
celery_worker_parameters):
|
||||
# type: (Any, Celery, Sequence[str], str, Any) -> WorkController
|
||||
"""Fixture: Start worker in a thread, stop it when the test returns."""
|
||||
from .testing import worker
|
||||
|
||||
if not NO_WORKER:
|
||||
for module in celery_includes:
|
||||
celery_app.loader.import_task_module(module)
|
||||
with worker.start_worker(celery_app,
|
||||
pool=celery_worker_pool,
|
||||
**celery_worker_parameters) as w:
|
||||
yield w
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def depends_on_current_app(celery_app):
|
||||
"""Fixture that sets app as current."""
|
||||
celery_app.set_current()
|
||||
187
ETB-API/venv/lib/python3.12/site-packages/celery/contrib/rdb.py
Normal file
187
ETB-API/venv/lib/python3.12/site-packages/celery/contrib/rdb.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Remote Debugger.
|
||||
|
||||
Introduction
|
||||
============
|
||||
|
||||
This is a remote debugger for Celery tasks running in multiprocessing
|
||||
pool workers. Inspired by a lost post on dzone.com.
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from celery.contrib import rdb
|
||||
from celery import task
|
||||
|
||||
@task()
|
||||
def add(x, y):
|
||||
result = x + y
|
||||
rdb.set_trace()
|
||||
return result
|
||||
|
||||
Environment Variables
|
||||
=====================
|
||||
|
||||
.. envvar:: CELERY_RDB_HOST
|
||||
|
||||
``CELERY_RDB_HOST``
|
||||
-------------------
|
||||
|
||||
Hostname to bind to. Default is '127.0.0.1' (only accessible from
|
||||
localhost).
|
||||
|
||||
.. envvar:: CELERY_RDB_PORT
|
||||
|
||||
``CELERY_RDB_PORT``
|
||||
-------------------
|
||||
|
||||
Base port to bind to. Default is 6899.
|
||||
The debugger will try to find an available port starting from the
|
||||
base port. The selected port will be logged by the worker.
|
||||
"""
|
||||
import errno
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from pdb import Pdb
|
||||
|
||||
from billiard.process import current_process
|
||||
|
||||
__all__ = (
|
||||
'CELERY_RDB_HOST', 'CELERY_RDB_PORT', 'DEFAULT_PORT',
|
||||
'Rdb', 'debugger', 'set_trace',
|
||||
)
|
||||
|
||||
DEFAULT_PORT = 6899
|
||||
|
||||
CELERY_RDB_HOST = os.environ.get('CELERY_RDB_HOST') or '127.0.0.1'
|
||||
CELERY_RDB_PORT = int(os.environ.get('CELERY_RDB_PORT') or DEFAULT_PORT)
|
||||
|
||||
#: Holds the currently active debugger.
|
||||
_current = [None]
|
||||
|
||||
_frame = getattr(sys, '_getframe')
|
||||
|
||||
NO_AVAILABLE_PORT = """\
|
||||
{self.ident}: Couldn't find an available port.
|
||||
|
||||
Please specify one using the CELERY_RDB_PORT environment variable.
|
||||
"""
|
||||
|
||||
BANNER = """\
|
||||
{self.ident}: Ready to connect: telnet {self.host} {self.port}
|
||||
|
||||
Type `exit` in session to continue.
|
||||
|
||||
{self.ident}: Waiting for client...
|
||||
"""
|
||||
|
||||
SESSION_STARTED = '{self.ident}: Now in session with {self.remote_addr}.'
|
||||
SESSION_ENDED = '{self.ident}: Session with {self.remote_addr} ended.'
|
||||
|
||||
|
||||
class Rdb(Pdb):
|
||||
"""Remote debugger."""
|
||||
|
||||
me = 'Remote Debugger'
|
||||
_prev_outs = None
|
||||
_sock = None
|
||||
|
||||
def __init__(self, host=CELERY_RDB_HOST, port=CELERY_RDB_PORT,
|
||||
port_search_limit=100, port_skew=+0, out=sys.stdout):
|
||||
self.active = True
|
||||
self.out = out
|
||||
|
||||
self._prev_handles = sys.stdin, sys.stdout
|
||||
|
||||
self._sock, this_port = self.get_avail_port(
|
||||
host, port, port_search_limit, port_skew,
|
||||
)
|
||||
self._sock.setblocking(1)
|
||||
self._sock.listen(1)
|
||||
self.ident = f'{self.me}:{this_port}'
|
||||
self.host = host
|
||||
self.port = this_port
|
||||
self.say(BANNER.format(self=self))
|
||||
|
||||
self._client, address = self._sock.accept()
|
||||
self._client.setblocking(1)
|
||||
self.remote_addr = ':'.join(str(v) for v in address)
|
||||
self.say(SESSION_STARTED.format(self=self))
|
||||
self._handle = sys.stdin = sys.stdout = self._client.makefile('rw')
|
||||
super().__init__(completekey='tab',
|
||||
stdin=self._handle, stdout=self._handle)
|
||||
|
||||
def get_avail_port(self, host, port, search_limit=100, skew=+0):
|
||||
try:
|
||||
_, skew = current_process().name.split('-')
|
||||
skew = int(skew)
|
||||
except ValueError:
|
||||
pass
|
||||
this_port = None
|
||||
for i in range(search_limit):
|
||||
_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
this_port = port + skew + i
|
||||
try:
|
||||
_sock.bind((host, this_port))
|
||||
except OSError as exc:
|
||||
if exc.errno in [errno.EADDRINUSE, errno.EINVAL]:
|
||||
continue
|
||||
raise
|
||||
else:
|
||||
return _sock, this_port
|
||||
raise Exception(NO_AVAILABLE_PORT.format(self=self))
|
||||
|
||||
def say(self, m):
|
||||
print(m, file=self.out)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
self._close_session()
|
||||
|
||||
def _close_session(self):
|
||||
self.stdin, self.stdout = sys.stdin, sys.stdout = self._prev_handles
|
||||
if self.active:
|
||||
if self._handle is not None:
|
||||
self._handle.close()
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
if self._sock is not None:
|
||||
self._sock.close()
|
||||
self.active = False
|
||||
self.say(SESSION_ENDED.format(self=self))
|
||||
|
||||
def do_continue(self, arg):
|
||||
self._close_session()
|
||||
self.set_continue()
|
||||
return 1
|
||||
do_c = do_cont = do_continue
|
||||
|
||||
def do_quit(self, arg):
|
||||
self._close_session()
|
||||
self.set_quit()
|
||||
return 1
|
||||
do_q = do_exit = do_quit
|
||||
|
||||
def set_quit(self):
|
||||
# this raises a BdbQuit exception that we're unable to catch.
|
||||
sys.settrace(None)
|
||||
|
||||
|
||||
def debugger():
|
||||
"""Return the current debugger instance, or create if none."""
|
||||
rdb = _current[0]
|
||||
if rdb is None or not rdb.active:
|
||||
rdb = _current[0] = Rdb()
|
||||
return rdb
|
||||
|
||||
|
||||
def set_trace(frame=None):
|
||||
"""Set break-point at current location, or a specified frame."""
|
||||
if frame is None:
|
||||
frame = _frame().f_back
|
||||
return debugger().set_trace(frame)
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Sphinx documentation plugin used to document tasks.
|
||||
|
||||
Introduction
|
||||
============
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
The Celery extension for Sphinx requires Sphinx 2.0 or later.
|
||||
|
||||
Add the extension to your :file:`docs/conf.py` configuration module:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
extensions = (...,
|
||||
'celery.contrib.sphinx')
|
||||
|
||||
If you'd like to change the prefix for tasks in reference documentation
|
||||
then you can change the ``celery_task_prefix`` configuration value:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
celery_task_prefix = '(task)' # < default
|
||||
|
||||
With the extension installed `autodoc` will automatically find
|
||||
task decorated objects (e.g. when using the automodule directive)
|
||||
and generate the correct (as well as add a ``(task)`` prefix),
|
||||
and you can also refer to the tasks using `:task:proj.tasks.add`
|
||||
syntax.
|
||||
|
||||
Use ``.. autotask::`` to alternatively manually document a task.
|
||||
"""
|
||||
from inspect import signature
|
||||
|
||||
from docutils import nodes
|
||||
from sphinx.domains.python import PyFunction
|
||||
from sphinx.ext.autodoc import FunctionDocumenter
|
||||
|
||||
from celery.app.task import BaseTask
|
||||
|
||||
|
||||
class TaskDocumenter(FunctionDocumenter):
|
||||
"""Document task definitions."""
|
||||
|
||||
objtype = 'task'
|
||||
member_order = 11
|
||||
|
||||
@classmethod
|
||||
def can_document_member(cls, member, membername, isattr, parent):
|
||||
return isinstance(member, BaseTask) and getattr(member, '__wrapped__')
|
||||
|
||||
def format_args(self):
|
||||
wrapped = getattr(self.object, '__wrapped__', None)
|
||||
if wrapped is not None:
|
||||
sig = signature(wrapped)
|
||||
if "self" in sig.parameters or "cls" in sig.parameters:
|
||||
sig = sig.replace(parameters=list(sig.parameters.values())[1:])
|
||||
return str(sig)
|
||||
return ''
|
||||
|
||||
def document_members(self, all_members=False):
|
||||
pass
|
||||
|
||||
def check_module(self):
|
||||
# Normally checks if *self.object* is really defined in the module
|
||||
# given by *self.modname*. But since functions decorated with the @task
|
||||
# decorator are instances living in the celery.local, we have to check
|
||||
# the wrapped function instead.
|
||||
wrapped = getattr(self.object, '__wrapped__', None)
|
||||
if wrapped and getattr(wrapped, '__module__') == self.modname:
|
||||
return True
|
||||
return super().check_module()
|
||||
|
||||
|
||||
class TaskDirective(PyFunction):
|
||||
"""Sphinx task directive."""
|
||||
|
||||
def get_signature_prefix(self, sig):
|
||||
return [nodes.Text(self.env.config.celery_task_prefix)]
|
||||
|
||||
|
||||
def autodoc_skip_member_handler(app, what, name, obj, skip, options):
|
||||
"""Handler for autodoc-skip-member event."""
|
||||
# Celery tasks created with the @task decorator have the property
|
||||
# that *obj.__doc__* and *obj.__class__.__doc__* are equal, which
|
||||
# trips up the logic in sphinx.ext.autodoc that is supposed to
|
||||
# suppress repetition of class documentation in an instance of the
|
||||
# class. This overrides that behavior.
|
||||
if isinstance(obj, BaseTask) and getattr(obj, '__wrapped__'):
|
||||
if skip:
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def setup(app):
|
||||
"""Setup Sphinx extension."""
|
||||
app.setup_extension('sphinx.ext.autodoc')
|
||||
app.add_autodocumenter(TaskDocumenter)
|
||||
app.add_directive_to_domain('py', 'task', TaskDirective)
|
||||
app.add_config_value('celery_task_prefix', '(task)', True)
|
||||
app.connect('autodoc-skip-member', autodoc_skip_member_handler)
|
||||
|
||||
return {
|
||||
'parallel_read_safe': True
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,112 @@
|
||||
"""Create Celery app instances used for testing."""
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
|
||||
from kombu.utils.imports import symbol_by_name
|
||||
|
||||
from celery import Celery, _state
|
||||
|
||||
#: Contains the default configuration values for the test app.
|
||||
DEFAULT_TEST_CONFIG = {
|
||||
'worker_hijack_root_logger': False,
|
||||
'worker_log_color': False,
|
||||
'accept_content': {'json'},
|
||||
'enable_utc': True,
|
||||
'timezone': 'UTC',
|
||||
'broker_url': 'memory://',
|
||||
'result_backend': 'cache+memory://',
|
||||
'broker_heartbeat': 0,
|
||||
}
|
||||
|
||||
|
||||
class Trap:
|
||||
"""Trap that pretends to be an app but raises an exception instead.
|
||||
|
||||
This to protect from code that does not properly pass app instances,
|
||||
then falls back to the current_app.
|
||||
"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Workaround to allow unittest.mock to patch this object
|
||||
# in Python 3.8 and above.
|
||||
if name == '_is_coroutine' or name == '__func__':
|
||||
return None
|
||||
print(name)
|
||||
raise RuntimeError('Test depends on current_app')
|
||||
|
||||
|
||||
class UnitLogging(symbol_by_name(Celery.log_cls)):
|
||||
"""Sets up logging for the test application."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.already_setup = True
|
||||
|
||||
|
||||
def TestApp(name=None, config=None, enable_logging=False, set_as_current=False,
|
||||
log=UnitLogging, backend=None, broker=None, **kwargs):
|
||||
"""App used for testing."""
|
||||
from . import tasks # noqa
|
||||
config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {})
|
||||
if broker is not None:
|
||||
config.pop('broker_url', None)
|
||||
if backend is not None:
|
||||
config.pop('result_backend', None)
|
||||
log = None if enable_logging else log
|
||||
test_app = Celery(
|
||||
name or 'celery.tests',
|
||||
set_as_current=set_as_current,
|
||||
log=log,
|
||||
broker=broker,
|
||||
backend=backend,
|
||||
**kwargs)
|
||||
test_app.add_defaults(config)
|
||||
return test_app
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_trap(app):
|
||||
"""Contextmanager that installs the trap app.
|
||||
|
||||
The trap means that anything trying to use the current or default app
|
||||
will raise an exception.
|
||||
"""
|
||||
trap = Trap()
|
||||
prev_tls = _state._tls
|
||||
_state.set_default_app(trap)
|
||||
|
||||
class NonTLS:
|
||||
current_app = trap
|
||||
_state._tls = NonTLS()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_state._tls = prev_tls
|
||||
|
||||
|
||||
@contextmanager
|
||||
def setup_default_app(app, use_trap=False):
|
||||
"""Setup default app for testing.
|
||||
|
||||
Ensures state is clean after the test returns.
|
||||
"""
|
||||
prev_current_app = _state.get_current_app()
|
||||
prev_default_app = _state.default_app
|
||||
prev_finalizers = set(_state._on_app_finalizers)
|
||||
prev_apps = weakref.WeakSet(_state._apps)
|
||||
|
||||
try:
|
||||
if use_trap:
|
||||
with set_trap(app):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
finally:
|
||||
_state.set_default_app(prev_default_app)
|
||||
_state._tls.current_app = prev_current_app
|
||||
if app is not prev_current_app:
|
||||
app.close()
|
||||
_state._on_app_finalizers = prev_finalizers
|
||||
_state._apps = prev_apps
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Integration testing utilities."""
|
||||
import socket
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from itertools import count
|
||||
from typing import Any, Callable, Dict, Sequence, TextIO, Tuple # noqa
|
||||
|
||||
from kombu.exceptions import ContentDisallowed
|
||||
from kombu.utils.functional import retry_over_time
|
||||
|
||||
from celery import states
|
||||
from celery.exceptions import TimeoutError
|
||||
from celery.result import AsyncResult, ResultSet # noqa
|
||||
from celery.utils.text import truncate
|
||||
from celery.utils.time import humanize_seconds as _humanize_seconds
|
||||
|
||||
E_STILL_WAITING = 'Still waiting for {0}. Trying again {when}: {exc!r}'
|
||||
|
||||
humanize_seconds = partial(_humanize_seconds, microseconds=True)
|
||||
|
||||
|
||||
class Sentinel(Exception):
|
||||
"""Signifies the end of something."""
|
||||
|
||||
|
||||
class ManagerMixin:
|
||||
"""Mixin that adds :class:`Manager` capabilities."""
|
||||
|
||||
def _init_manager(self,
|
||||
block_timeout=30 * 60.0, no_join=False,
|
||||
stdout=None, stderr=None):
|
||||
# type: (float, bool, TextIO, TextIO) -> None
|
||||
self.stdout = sys.stdout if stdout is None else stdout
|
||||
self.stderr = sys.stderr if stderr is None else stderr
|
||||
self.connerrors = self.app.connection().recoverable_connection_errors
|
||||
self.block_timeout = block_timeout
|
||||
self.no_join = no_join
|
||||
|
||||
def remark(self, s, sep='-'):
|
||||
# type: (str, str) -> None
|
||||
print(f'{sep}{s}', file=self.stdout)
|
||||
|
||||
def missing_results(self, r):
|
||||
# type: (Sequence[AsyncResult]) -> Sequence[str]
|
||||
return [res.id for res in r if res.id not in res.backend._cache]
|
||||
|
||||
def wait_for(
|
||||
self,
|
||||
fun, # type: Callable
|
||||
catch, # type: Sequence[Any]
|
||||
desc="thing", # type: str
|
||||
args=(), # type: Tuple
|
||||
kwargs=None, # type: Dict
|
||||
errback=None, # type: Callable
|
||||
max_retries=10, # type: int
|
||||
interval_start=0.1, # type: float
|
||||
interval_step=0.5, # type: float
|
||||
interval_max=5.0, # type: float
|
||||
emit_warning=False, # type: bool
|
||||
**options # type: Any
|
||||
):
|
||||
# type: (...) -> Any
|
||||
"""Wait for event to happen.
|
||||
|
||||
The `catch` argument specifies the exception that means the event
|
||||
has not happened yet.
|
||||
"""
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
|
||||
def on_error(exc, intervals, retries):
|
||||
interval = next(intervals)
|
||||
if emit_warning:
|
||||
self.warn(E_STILL_WAITING.format(
|
||||
desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
|
||||
))
|
||||
if errback:
|
||||
errback(exc, interval, retries)
|
||||
return interval
|
||||
|
||||
return self.retry_over_time(
|
||||
fun, catch,
|
||||
args=args, kwargs=kwargs,
|
||||
errback=on_error, max_retries=max_retries,
|
||||
interval_start=interval_start, interval_step=interval_step,
|
||||
**options
|
||||
)
|
||||
|
||||
def ensure_not_for_a_while(self, fun, catch,
|
||||
desc='thing', max_retries=20,
|
||||
interval_start=0.1, interval_step=0.02,
|
||||
interval_max=1.0, emit_warning=False,
|
||||
**options):
|
||||
"""Make sure something does not happen (at least for a while)."""
|
||||
try:
|
||||
return self.wait_for(
|
||||
fun, catch, desc=desc, max_retries=max_retries,
|
||||
interval_start=interval_start, interval_step=interval_step,
|
||||
interval_max=interval_max, emit_warning=emit_warning,
|
||||
)
|
||||
except catch:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(f'Should not have happened: {desc}')
|
||||
|
||||
def retry_over_time(self, *args, **kwargs):
|
||||
return retry_over_time(*args, **kwargs)
|
||||
|
||||
def join(self, r, propagate=False, max_retries=10, **kwargs):
|
||||
if self.no_join:
|
||||
return
|
||||
if not isinstance(r, ResultSet):
|
||||
r = self.app.ResultSet([r])
|
||||
received = []
|
||||
|
||||
def on_result(task_id, value):
|
||||
received.append(task_id)
|
||||
|
||||
for i in range(max_retries) if max_retries else count(0):
|
||||
received[:] = []
|
||||
try:
|
||||
return r.get(callback=on_result, propagate=propagate, **kwargs)
|
||||
except (socket.timeout, TimeoutError) as exc:
|
||||
waiting_for = self.missing_results(r)
|
||||
self.remark(
|
||||
'Still waiting for {}/{}: [{}]: {!r}'.format(
|
||||
len(r) - len(received), len(r),
|
||||
truncate(', '.join(waiting_for)), exc), '!',
|
||||
)
|
||||
except self.connerrors as exc:
|
||||
self.remark(f'join: connection lost: {exc!r}', '!')
|
||||
raise AssertionError('Test failed: Missing task results')
|
||||
|
||||
def inspect(self, timeout=3.0):
|
||||
return self.app.control.inspect(timeout=timeout)
|
||||
|
||||
def query_tasks(self, ids, timeout=0.5):
|
||||
tasks = self.inspect(timeout).query_task(*ids) or {}
|
||||
yield from tasks.items()
|
||||
|
||||
def query_task_states(self, ids, timeout=0.5):
|
||||
states = defaultdict(set)
|
||||
for hostname, reply in self.query_tasks(ids, timeout=timeout):
|
||||
for task_id, (state, _) in reply.items():
|
||||
states[state].add(task_id)
|
||||
return states
|
||||
|
||||
def assert_accepted(self, ids, interval=0.5,
|
||||
desc='waiting for tasks to be accepted', **policy):
|
||||
return self.assert_task_worker_state(
|
||||
self.is_accepted, ids, interval=interval, desc=desc, **policy
|
||||
)
|
||||
|
||||
def assert_received(self, ids, interval=0.5,
|
||||
desc='waiting for tasks to be received', **policy):
|
||||
return self.assert_task_worker_state(
|
||||
self.is_received, ids, interval=interval, desc=desc, **policy
|
||||
)
|
||||
|
||||
def assert_result_tasks_in_progress_or_completed(
|
||||
self,
|
||||
async_results,
|
||||
interval=0.5,
|
||||
desc='waiting for tasks to be started or completed',
|
||||
**policy
|
||||
):
|
||||
return self.assert_task_state_from_result(
|
||||
self.is_result_task_in_progress,
|
||||
async_results,
|
||||
interval=interval, desc=desc, **policy
|
||||
)
|
||||
|
||||
def assert_task_state_from_result(self, fun, results,
|
||||
interval=0.5, **policy):
|
||||
return self.wait_for(
|
||||
partial(self.true_or_raise, fun, results, timeout=interval),
|
||||
(Sentinel,), **policy
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_result_task_in_progress(results, **kwargs):
|
||||
possible_states = (states.STARTED, states.SUCCESS)
|
||||
return all(result.state in possible_states for result in results)
|
||||
|
||||
def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
|
||||
return self.wait_for(
|
||||
partial(self.true_or_raise, fun, ids, timeout=interval),
|
||||
(Sentinel,), **policy
|
||||
)
|
||||
|
||||
def is_received(self, ids, **kwargs):
|
||||
return self._ids_matches_state(
|
||||
['reserved', 'active', 'ready'], ids, **kwargs)
|
||||
|
||||
def is_accepted(self, ids, **kwargs):
|
||||
return self._ids_matches_state(['active', 'ready'], ids, **kwargs)
|
||||
|
||||
def _ids_matches_state(self, expected_states, ids, timeout=0.5):
|
||||
states = self.query_task_states(ids, timeout=timeout)
|
||||
return all(
|
||||
any(t in s for s in [states[k] for k in expected_states])
|
||||
for t in ids
|
||||
)
|
||||
|
||||
def true_or_raise(self, fun, *args, **kwargs):
|
||||
res = fun(*args, **kwargs)
|
||||
if not res:
|
||||
raise Sentinel()
|
||||
return res
|
||||
|
||||
def wait_until_idle(self):
|
||||
control = self.app.control
|
||||
with self.app.connection() as connection:
|
||||
# Try to purge the queue before we start
|
||||
# to attempt to avoid interference from other tests
|
||||
while True:
|
||||
count = control.purge(connection=connection)
|
||||
if count == 0:
|
||||
break
|
||||
|
||||
# Wait until worker is idle
|
||||
inspect = control.inspect()
|
||||
inspect.connection = connection
|
||||
while True:
|
||||
try:
|
||||
count = sum(len(t) for t in inspect.active().values())
|
||||
except ContentDisallowed:
|
||||
# test_security_task_done may trigger this exception
|
||||
break
|
||||
if count == 0:
|
||||
break
|
||||
|
||||
|
||||
class Manager(ManagerMixin):
|
||||
"""Test helpers for task integration tests."""
|
||||
|
||||
def __init__(self, app, **kwargs):
|
||||
self.app = app
|
||||
self._init_manager(**kwargs)
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Useful mocks for unit testing."""
|
||||
import numbers
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Mapping, Sequence # noqa
|
||||
from unittest.mock import Mock
|
||||
|
||||
from celery import Celery # noqa
|
||||
from celery.canvas import Signature # noqa
|
||||
|
||||
|
||||
def TaskMessage(
|
||||
name, # type: str
|
||||
id=None, # type: str
|
||||
args=(), # type: Sequence
|
||||
kwargs=None, # type: Mapping
|
||||
callbacks=None, # type: Sequence[Signature]
|
||||
errbacks=None, # type: Sequence[Signature]
|
||||
chain=None, # type: Sequence[Signature]
|
||||
shadow=None, # type: str
|
||||
utc=None, # type: bool
|
||||
**options # type: Any
|
||||
):
|
||||
# type: (...) -> Any
|
||||
"""Create task message in protocol 2 format."""
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
from kombu.serialization import dumps
|
||||
|
||||
from celery import uuid
|
||||
id = id or uuid()
|
||||
message = Mock(name=f'TaskMessage-{id}')
|
||||
message.headers = {
|
||||
'id': id,
|
||||
'task': name,
|
||||
'shadow': shadow,
|
||||
}
|
||||
embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
|
||||
message.headers.update(options)
|
||||
message.content_type, message.content_encoding, message.body = dumps(
|
||||
(args, kwargs, embed), serializer='json',
|
||||
)
|
||||
message.payload = (args, kwargs, embed)
|
||||
return message
|
||||
|
||||
|
||||
def TaskMessage1(
|
||||
name, # type: str
|
||||
id=None, # type: str
|
||||
args=(), # type: Sequence
|
||||
kwargs=None, # type: Mapping
|
||||
callbacks=None, # type: Sequence[Signature]
|
||||
errbacks=None, # type: Sequence[Signature]
|
||||
chain=None, # type: Sequence[Signature]
|
||||
**options # type: Any
|
||||
):
|
||||
# type: (...) -> Any
|
||||
"""Create task message in protocol 1 format."""
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
from kombu.serialization import dumps
|
||||
|
||||
from celery import uuid
|
||||
id = id or uuid()
|
||||
message = Mock(name=f'TaskMessage-{id}')
|
||||
message.headers = {}
|
||||
message.payload = {
|
||||
'task': name,
|
||||
'id': id,
|
||||
'args': args,
|
||||
'kwargs': kwargs,
|
||||
'callbacks': callbacks,
|
||||
'errbacks': errbacks,
|
||||
}
|
||||
message.payload.update(options)
|
||||
message.content_type, message.content_encoding, message.body = dumps(
|
||||
message.payload,
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
|
||||
# type: (Celery, Signature, bool, Any) -> Any
|
||||
"""Create task message from :class:`celery.Signature`.
|
||||
|
||||
Example:
|
||||
>>> m = task_message_from_sig(app, add.s(2, 2))
|
||||
>>> amqp_client.basic_publish(m, exchange='ex', routing_key='rkey')
|
||||
"""
|
||||
sig.freeze()
|
||||
callbacks = sig.options.pop('link', None)
|
||||
errbacks = sig.options.pop('link_error', None)
|
||||
countdown = sig.options.pop('countdown', None)
|
||||
if countdown:
|
||||
eta = app.now() + timedelta(seconds=countdown)
|
||||
else:
|
||||
eta = sig.options.pop('eta', None)
|
||||
if eta and isinstance(eta, datetime):
|
||||
eta = eta.isoformat()
|
||||
expires = sig.options.pop('expires', None)
|
||||
if expires and isinstance(expires, numbers.Real):
|
||||
expires = app.now() + timedelta(seconds=expires)
|
||||
if expires and isinstance(expires, datetime):
|
||||
expires = expires.isoformat()
|
||||
return TaskMessage(
|
||||
sig.task, id=sig.id, args=sig.args,
|
||||
kwargs=sig.kwargs,
|
||||
callbacks=[dict(s) for s in callbacks] if callbacks else None,
|
||||
errbacks=[dict(s) for s in errbacks] if errbacks else None,
|
||||
eta=eta,
|
||||
expires=expires,
|
||||
utc=utc,
|
||||
**sig.options
|
||||
)
|
||||
|
||||
|
||||
class _ContextMock(Mock):
|
||||
"""Dummy class implementing __enter__ and __exit__.
|
||||
|
||||
The :keyword:`with` statement requires these to be implemented
|
||||
in the class, not just the instance.
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
pass
|
||||
|
||||
|
||||
def ContextMock(*args, **kwargs):
|
||||
"""Mock that mocks :keyword:`with` statement contexts."""
|
||||
obj = _ContextMock(*args, **kwargs)
|
||||
obj.attach_mock(_ContextMock(), '__enter__')
|
||||
obj.attach_mock(_ContextMock(), '__exit__')
|
||||
obj.__enter__.return_value = obj
|
||||
# if __exit__ return a value the exception is ignored,
|
||||
# so it must return None here.
|
||||
obj.__exit__.return_value = None
|
||||
return obj
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Helper tasks for integration tests."""
|
||||
from celery import shared_task
|
||||
|
||||
|
||||
@shared_task(name='celery.ping')
|
||||
def ping():
|
||||
# type: () -> str
|
||||
"""Simple task that just returns 'pong'."""
|
||||
return 'pong'
|
||||
@@ -0,0 +1,223 @@
|
||||
"""Embedded workers for integration tests."""
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
import celery.worker.consumer # noqa
|
||||
from celery import Celery, worker
|
||||
from celery.result import _set_task_join_will_block, allow_join_result
|
||||
from celery.utils.dispatch import Signal
|
||||
from celery.utils.nodenames import anon_nodename
|
||||
|
||||
WORKER_LOGLEVEL = os.environ.get('WORKER_LOGLEVEL', 'error')
|
||||
|
||||
test_worker_starting = Signal(
|
||||
name='test_worker_starting',
|
||||
providing_args={},
|
||||
)
|
||||
test_worker_started = Signal(
|
||||
name='test_worker_started',
|
||||
providing_args={'worker', 'consumer'},
|
||||
)
|
||||
test_worker_stopped = Signal(
|
||||
name='test_worker_stopped',
|
||||
providing_args={'worker'},
|
||||
)
|
||||
|
||||
|
||||
class TestWorkController(worker.WorkController):
|
||||
"""Worker that can synchronize on being fully started."""
|
||||
|
||||
# When this class is imported in pytest files, prevent pytest from thinking
|
||||
# this is a test class
|
||||
__test__ = False
|
||||
|
||||
logger_queue = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# type: (*Any, **Any) -> None
|
||||
self._on_started = threading.Event()
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.pool_cls.__module__.split('.')[-1] == 'prefork':
|
||||
from billiard import Queue
|
||||
self.logger_queue = Queue()
|
||||
self.pid = os.getpid()
|
||||
|
||||
try:
|
||||
from tblib import pickling_support
|
||||
pickling_support.install()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# collect logs from forked process.
|
||||
# XXX: those logs will appear twice in the live log
|
||||
self.queue_listener = logging.handlers.QueueListener(self.logger_queue, logging.getLogger())
|
||||
self.queue_listener.start()
|
||||
|
||||
class QueueHandler(logging.handlers.QueueHandler):
|
||||
def prepare(self, record):
|
||||
record.from_queue = True
|
||||
# Keep origin record.
|
||||
return record
|
||||
|
||||
def handleError(self, record):
|
||||
if logging.raiseExceptions:
|
||||
raise
|
||||
|
||||
def start(self):
|
||||
if self.logger_queue:
|
||||
handler = self.QueueHandler(self.logger_queue)
|
||||
handler.addFilter(lambda r: r.process != self.pid and not getattr(r, 'from_queue', False))
|
||||
logger = logging.getLogger()
|
||||
logger.addHandler(handler)
|
||||
return super().start()
|
||||
|
||||
def on_consumer_ready(self, consumer):
|
||||
# type: (celery.worker.consumer.Consumer) -> None
|
||||
"""Callback called when the Consumer blueprint is fully started."""
|
||||
self._on_started.set()
|
||||
test_worker_started.send(
|
||||
sender=self.app, worker=self, consumer=consumer)
|
||||
|
||||
def ensure_started(self):
|
||||
# type: () -> None
|
||||
"""Wait for worker to be fully up and running.
|
||||
|
||||
Warning:
|
||||
Worker must be started within a thread for this to work,
|
||||
or it will block forever.
|
||||
"""
|
||||
self._on_started.wait()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def start_worker(
|
||||
app, # type: Celery
|
||||
concurrency=1, # type: int
|
||||
pool='solo', # type: str
|
||||
loglevel=WORKER_LOGLEVEL, # type: Union[str, int]
|
||||
logfile=None, # type: str
|
||||
perform_ping_check=True, # type: bool
|
||||
ping_task_timeout=10.0, # type: float
|
||||
shutdown_timeout=10.0, # type: float
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> Iterable
|
||||
"""Start embedded worker.
|
||||
|
||||
Yields:
|
||||
celery.app.worker.Worker: worker instance.
|
||||
"""
|
||||
test_worker_starting.send(sender=app)
|
||||
|
||||
worker = None
|
||||
try:
|
||||
with _start_worker_thread(app,
|
||||
concurrency=concurrency,
|
||||
pool=pool,
|
||||
loglevel=loglevel,
|
||||
logfile=logfile,
|
||||
perform_ping_check=perform_ping_check,
|
||||
shutdown_timeout=shutdown_timeout,
|
||||
**kwargs) as worker:
|
||||
if perform_ping_check:
|
||||
from .tasks import ping
|
||||
with allow_join_result():
|
||||
assert ping.delay().get(timeout=ping_task_timeout) == 'pong'
|
||||
|
||||
yield worker
|
||||
finally:
|
||||
test_worker_stopped.send(sender=app, worker=worker)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _start_worker_thread(app: Celery,
|
||||
concurrency: int = 1,
|
||||
pool: str = 'solo',
|
||||
loglevel: Union[str, int] = WORKER_LOGLEVEL,
|
||||
logfile: Optional[str] = None,
|
||||
WorkController: Any = TestWorkController,
|
||||
perform_ping_check: bool = True,
|
||||
shutdown_timeout: float = 10.0,
|
||||
**kwargs) -> Iterable[worker.WorkController]:
|
||||
"""Start Celery worker in a thread.
|
||||
|
||||
Yields:
|
||||
celery.worker.Worker: worker instance.
|
||||
"""
|
||||
setup_app_for_worker(app, loglevel, logfile)
|
||||
if perform_ping_check:
|
||||
assert 'celery.ping' in app.tasks
|
||||
# Make sure we can connect to the broker
|
||||
with app.connection(hostname=os.environ.get('TEST_BROKER')) as conn:
|
||||
conn.default_channel.queue_declare
|
||||
|
||||
worker = WorkController(
|
||||
app=app,
|
||||
concurrency=concurrency,
|
||||
hostname=kwargs.pop("hostname", anon_nodename()),
|
||||
pool=pool,
|
||||
loglevel=loglevel,
|
||||
logfile=logfile,
|
||||
# not allowed to override TestWorkController.on_consumer_ready
|
||||
ready_callback=None,
|
||||
without_heartbeat=kwargs.pop("without_heartbeat", True),
|
||||
without_mingle=True,
|
||||
without_gossip=True,
|
||||
**kwargs)
|
||||
|
||||
t = threading.Thread(target=worker.start, daemon=True)
|
||||
t.start()
|
||||
worker.ensure_started()
|
||||
_set_task_join_will_block(False)
|
||||
|
||||
try:
|
||||
yield worker
|
||||
finally:
|
||||
from celery.worker import state
|
||||
state.should_terminate = 0
|
||||
t.join(shutdown_timeout)
|
||||
if t.is_alive():
|
||||
raise RuntimeError(
|
||||
"Worker thread failed to exit within the allocated timeout. "
|
||||
"Consider raising `shutdown_timeout` if your tasks take longer "
|
||||
"to execute."
|
||||
)
|
||||
state.should_terminate = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _start_worker_process(app,
|
||||
concurrency=1,
|
||||
pool='solo',
|
||||
loglevel=WORKER_LOGLEVEL,
|
||||
logfile=None,
|
||||
**kwargs):
|
||||
# type (Celery, int, str, Union[int, str], str, **Any) -> Iterable
|
||||
"""Start worker in separate process.
|
||||
|
||||
Yields:
|
||||
celery.app.worker.Worker: worker instance.
|
||||
"""
|
||||
from celery.apps.multi import Cluster, Node
|
||||
|
||||
app.set_current()
|
||||
cluster = Cluster([Node('testworker1@%h')])
|
||||
cluster.start()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
cluster.stopwait()
|
||||
|
||||
|
||||
def setup_app_for_worker(app: Celery, loglevel: Union[str, int], logfile: str) -> None:
|
||||
"""Setup the app to be used for starting an embedded worker."""
|
||||
app.finalize()
|
||||
app.set_current()
|
||||
app.set_default()
|
||||
type(app.log)._setup = False
|
||||
app.log.setup(loglevel=loglevel, logfile=logfile)
|
||||
Reference in New Issue
Block a user