This commit is contained in:
Iliyan Angelov
2025-09-19 11:58:53 +03:00
parent 306b20e24a
commit 6b247e5b9f
11423 changed files with 1500615 additions and 778 deletions

View File

@@ -0,0 +1,165 @@
"""Abortable Tasks.
Abortable tasks overview
=========================
For long-running :class:`Task`'s, it can be desirable to support
aborting during execution. Of course, these tasks should be built to
support abortion specifically.
The :class:`AbortableTask` serves as a base class for all :class:`Task`
objects that should support abortion by producers.
* Producers may invoke the :meth:`abort` method on
:class:`AbortableAsyncResult` instances, to request abortion.
* Consumers (workers) should periodically check (and honor!) the
:meth:`is_aborted` method at controlled points in their task's
:meth:`run` method. The more often, the better.
The necessary intermediate communication is dealt with by the
:class:`AbortableTask` implementation.
Usage example
-------------
In the consumer:
.. code-block:: python
from celery.contrib.abortable import AbortableTask
from celery.utils.log import get_task_logger
from proj.celery import app
logger = get_task_logger(__name__)
@app.task(bind=True, base=AbortableTask)
def long_running_task(self):
results = []
for i in range(100):
# check after every 5 iterations...
# (or alternatively, check when some timer is due)
if not i % 5:
if self.is_aborted():
# respect aborted state, and terminate gracefully.
logger.warning('Task aborted')
return
value = do_something_expensive(i)
results.append(value)
logger.info('Task complete')
return results
In the producer:
.. code-block:: python
import time
from proj.tasks import long_running_task
def myview(request):
# result is of type AbortableAsyncResult
result = long_running_task.delay()
# abort the task after 10 seconds
time.sleep(10)
result.abort()
After the `result.abort()` call, the task execution isn't
aborted immediately. In fact, it's not guaranteed to abort at all.
Keep checking `result.state` status, or call `result.get(timeout=)` to
have it block until the task is finished.
.. note::
In order to abort tasks, there needs to be communication between the
producer and the consumer. This is currently implemented through the
database backend. Therefore, this class will only work with the
database backends.
"""
from celery import Task
from celery.result import AsyncResult
# Names exported by ``from <module> import *``.
__all__ = ('AbortableAsyncResult', 'AbortableTask')
"""
Task States
-----------
.. state:: ABORTED
ABORTED
~~~~~~~
Task is aborted (typically by the producer) and should be
aborted as soon as possible.
"""
# State name stored by AbortableAsyncResult.abort() and compared against
# in AbortableAsyncResult.is_aborted().
ABORTED = 'ABORTED'
class AbortableAsyncResult(AsyncResult):
    """Result handle for a task that can be asked to abort.

    Adds an :meth:`abort` method to `AsyncResult` which stores the
    `'ABORTED'` state for the task in the result backend, and an
    :meth:`is_aborted` method that checks for that state.
    """

    def is_aborted(self):
        """Return :const:`True` if the task state is :const:`ABORTED`."""
        current_state = self.state
        return current_state == ABORTED

    def abort(self):
        """Store :const:`ABORTED` as the state of the task.

        Abortable tasks are expected to poll their state and stop
        running once they see it has been set to aborted.

        Warning:
            Calling this only records the request: there is no guarantee
            about when the task stops, or that it stops at all.
        """
        # TODO: store_result requires all four arguments to be set,
        # but only state should be updated here
        backend = self.backend
        return backend.store_result(
            self.id, result=None, state=ABORTED, traceback=None,
        )
class AbortableTask(Task):
    """Base class for tasks that support being aborted while running.

    Subclasses are responsible for calling :meth:`is_aborted`
    periodically from their body and for stopping when it returns
    :const:`True`.
    """

    abstract = True

    def AsyncResult(self, task_id):
        """Return an :class:`AbortableAsyncResult` for ``task_id``."""
        return AbortableAsyncResult(task_id, backend=self.backend)

    def is_aborted(self, **kwargs):
        """Return true if the task has been marked as aborted.

        Looks up the result for the task id (the ``task_id`` keyword
        argument if given, otherwise the id of the current request) and
        checks whether its state is :const:`ABORTED`.

        Returns :const:`False` if the result object is not an
        :class:`AbortableAsyncResult`.

        Note that every call reads the task state from the result
        backend (for example a database query), so balance how often
        it is called: often enough to be responsive, but not so often
        that it hurts performance.
        """
        default_id = self.request.id
        task_id = kwargs.get('task_id', default_id)
        result = self.AsyncResult(task_id)
        if isinstance(result, AbortableAsyncResult):
            return result.is_aborted()
        return False

View File

@@ -0,0 +1,21 @@
import functools
from django.db import transaction
from celery.app.task import Task
class DjangoTask(Task):
    """
    Django flavour of the base :class:`~celery.app.task.Task`.

    Adds helpers that defer sending the task until the current database
    transaction has been committed.
    """

    def delay_on_commit(self, *args, **kwargs) -> None:
        """Register :meth:`~celery.app.task.Task.delay` with Django's ``on_commit()``."""
        send = functools.partial(self.delay, *args, **kwargs)
        transaction.on_commit(send)

    def apply_async_on_commit(self, *args, **kwargs) -> None:
        """Register :meth:`~celery.app.task.Task.apply_async` with Django's ``on_commit()``."""
        send = functools.partial(self.apply_async, *args, **kwargs)
        transaction.on_commit(send)

View File

@@ -0,0 +1,416 @@
"""Message migration tools (Broker <-> Broker)."""
import socket
from functools import partial
from itertools import cycle, islice
from kombu import Queue, eventloop
from kombu.common import maybe_declare
from kombu.utils.encoding import ensure_bytes
from celery.app import app_or_default
from celery.utils.nodenames import worker_direct
from celery.utils.text import str_to_list
# Names exported by ``from <module> import *``.
__all__ = (
    'StopFiltering', 'State', 'republish', 'migrate_task',
    'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
    'start_filter', 'move_task_by_id', 'move_by_idmap',
    'move_by_taskmap', 'move_direct', 'move_direct_by_id',
)
# Progress line template rendered by filter_status(); formatted with
# ``state`` (reads .filtered and .strtotal) and ``body`` (reads the
# 'task' and 'id' keys).
MOVING_PROGRESS_FMT = """\
Moving task {state.filtered}/{state.strtotal}: \
{body[task]}[{body[id]}]\
"""
class StopFiltering(Exception):
    """Raised from a consumer callback to end the filtering loop."""
class State:
    """Counters describing how far a migration/filter run has got."""

    # Number of messages counted so far.
    count = 0
    # Number of messages that matched the filter and were moved.
    filtered = 0
    # Approximate total number of messages; zero means unknown.
    total_apx = 0

    @property
    def strtotal(self):
        """Approximate total as a string, or ``'?'`` when it is unknown."""
        return str(self.total_apx) if self.total_apx else '?'

    def __repr__(self):
        if not self.filtered:
            return f'{self.count}/{self.strtotal}'
        return f'^{self.filtered}'
def republish(producer, message, exchange=None, routing_key=None,
              remove_props=None):
    """Publish an already received message again using ``producer``.

    Arguments:
        producer: Object whose ``publish()`` method sends the message.
        message: Received message; its raw body, headers, properties,
            content type and content encoding are reused.  Note that its
            ``headers`` and ``properties`` mappings are modified in place.
        exchange: Destination exchange; defaults to the exchange found in
            the message's delivery info.
        routing_key: Destination routing key; defaults to the routing key
            found in the message's delivery info.
        remove_props: Property names dropped before publishing; a default
            list is used when this is empty or not given.
    """
    if not remove_props:
        remove_props = ['application_headers', 'content_type',
                        'content_encoding', 'headers']
    payload = ensure_bytes(message.body)  # use raw message body.
    info = message.delivery_info
    headers = message.headers
    props = message.properties
    if exchange is None:
        exchange = info['exchange']
    if routing_key is None:
        routing_key = info['routing_key']
    ctype = message.content_type
    enc = message.content_encoding
    # remove compression header, as this will be inserted again
    # when the message is recompressed.
    compression = headers.pop('compression', None)
    expiration = props.pop('expiration', None)
    # ensure expiration is a float
    if expiration is not None:
        expiration = float(expiration)
    for key in remove_props:
        props.pop(key, None)
    producer.publish(
        ensure_bytes(payload),
        exchange=exchange,
        routing_key=routing_key,
        compression=compression,
        headers=headers,
        content_type=ctype,
        content_encoding=enc,
        expiration=expiration,
        **props
    )
def migrate_task(producer, body_, message, queues=None):
    """Republish one task message, remapping its destination.

    The exchange and routing key from the message's delivery info are
    looked up in ``queues``; names missing from the mapping resolve to
    :const:`None`, which makes :func:`republish` keep the original value.
    """
    info = message.delivery_info
    if queues is None:
        queues = {}
    target_exchange = queues.get(info['exchange'])
    target_key = queues.get(info['routing_key'])
    republish(producer, message,
              exchange=target_exchange, routing_key=target_key)
def filter_callback(callback, tasks):
    """Wrap ``callback`` so it only runs for task names in ``tasks``.

    When ``tasks`` is empty every message is passed through.  Messages
    for other task names are ignored and the wrapper returns
    :const:`None`.
    """

    def filtered(body, message):
        if not tasks or body['task'] in tasks:
            return callback(body, message)
        return None

    return filtered
def migrate_tasks(source, dest, migrate=migrate_task, app=None,
                  queues=None, **kwargs):
    """Migrate tasks from the ``source`` broker to the ``dest`` broker.

    Arguments:
        source: Connection to consume messages from.
        dest: Connection used by the producer that republishes messages.
        migrate (Callable): Function called for every message with
            ``(producer, body, message, queues=queues)``.
        app: Celery app; resolved with ``app_or_default`` when not given.
        queues: Queue name mapping in any form accepted by
            :func:`prepare_queues`.
        **kwargs: Passed on to :func:`start_filter`.

    Returns:
        Whatever :func:`start_filter` returns.
    """
    app = app_or_default(app)
    queues = prepare_queues(queues)
    producer = app.amqp.Producer(dest, auto_declare=False)
    migrate = partial(migrate, producer, queues=queues)

    def on_declare_queue(queue):
        # Declare the queue on the destination channel, renaming the
        # queue, and a routing key/exchange that matches the queue name,
        # according to the ``queues`` mapping.
        target = queue(producer.channel)
        target.name = queues.get(queue.name, queue.name)
        if target.routing_key == queue.name:
            mapped_key = queues.get(queue.name, target.routing_key)
            target.routing_key = mapped_key
        if target.exchange.name == queue.name:
            mapped_exchange = queues.get(queue.name, queue.name)
            target.exchange.name = mapped_exchange
        target.declare()

    return start_filter(
        app, source, migrate,
        queues=queues,
        on_declare_queue=on_declare_queue,
        **kwargs
    )
def _maybe_queue(app, q):
    """Resolve ``q`` via ``app.amqp.queues`` if it is a name, else return it."""
    return app.amqp.queues[q] if isinstance(q, str) else q
def move(predicate, connection=None, exchange=None, routing_key=None,
         source=None, app=None, callback=None, limit=None, transform=None,
         **kwargs):
    """Find tasks by filtering them and move the tasks to a new queue.

    Arguments:
        predicate (Callable): Filter function deciding which messages
            to move.  Called with the Kombu consumer callback signature
            ``(body, message)``.  To have the message moved it must
            return one of:

            1) a tuple of ``(exchange, routing_key)``, or

            2) a :class:`~kombu.entity.Queue` instance, or

            3) any other true value, in which case the ``exchange``
                and ``routing_key`` arguments are used.

        connection (kombu.Connection): Custom connection to use.
        source (List[Union[str, kombu.Queue]]): Optional list of source
            queues to use instead of the default (queues
            in :setting:`task_queues`).  Entries may be queue names or
            :class:`~kombu.entity.Queue` instances.
        exchange (str, kombu.Exchange): Default destination exchange.
        routing_key (str): Default destination routing key.
        limit (int): Stop after this many messages have been moved.
        callback (Callable): Called after each message has been moved,
            with signature ``(state, body, message)``.
        transform (Callable): Optional function applied to the (true)
            value returned by the predicate before it is interpreted
            as a destination.

    Also supports the same keyword arguments as :func:`start_filter`.

    For example, moving a single task by id can be written as:

    .. code-block:: python

        def is_wanted_task(body, message):
            if body['id'] == wanted_id:
                return Queue('foo', exchange=Exchange('foo'),
                             routing_key='foo')

        move(is_wanted_task)

    or with a transform:

    .. code-block:: python

        def transform(value):
            if isinstance(value, str):
                return Queue(value, Exchange(value), value)
            return value

        move(is_wanted_task, transform=transform)
    """
    app = app_or_default(app)
    queues = [_maybe_queue(app, queue) for queue in source or []] or None
    with app.connection_or_acquire(connection, pool=False) as conn:
        producer = app.amqp.Producer(conn)
        state = State()

        def on_task(body, message):
            dest = predicate(body, message)
            if not dest:
                # Predicate rejected the message: leave it alone.
                return
            if transform:
                dest = transform(dest)
            if isinstance(dest, Queue):
                maybe_declare(dest, conn.default_channel)
                ex, rk = dest.exchange.name, dest.routing_key
            else:
                ex, rk = expand_dest(dest, exchange, routing_key)
            republish(producer, message, exchange=ex, routing_key=rk)
            message.ack()

            state.filtered += 1
            if callback:
                callback(state, body, message)
            if limit and state.filtered >= limit:
                raise StopFiltering()

        return start_filter(app, conn, on_task, consume_from=queues, **kwargs)
def expand_dest(ret, exchange, routing_key):
    """Return ``ret`` unpacked as ``(exchange, routing_key)``.

    Falls back to the given ``exchange`` and ``routing_key`` when ``ret``
    cannot be unpacked into exactly two values.
    """
    try:
        dest_exchange, dest_key = ret
    except (TypeError, ValueError):
        return exchange, routing_key
    return dest_exchange, dest_key
def task_id_eq(task_id, body, message):
    """Return true if the id in ``body`` equals ``task_id``."""
    body_id = body['id']
    return body_id == task_id
def task_id_in(ids, body, message):
    """Return true if the id in ``body`` is a member of ``ids``."""
    body_id = body['id']
    return body_id in ids
def prepare_queues(queues):
    """Normalize a queue mapping specification.

    A string is split on commas into a list, and a list of
    ``'source:dest'`` items becomes a ``{source: dest}`` dict; an item
    without a colon maps the name to itself.  :const:`None` becomes an
    empty dict and any other value is returned unchanged.
    """
    if queues is None:
        return {}
    if isinstance(queues, str):
        queues = queues.split(',')
    if isinstance(queues, list):
        mapping = {}
        for spec in queues:
            # cycle() repeats a lone name so 'a' yields ('a', 'a'), while
            # islice() caps the result to the first two parts.
            src, dst = islice(cycle(spec.split(':')), None, 2)
            mapping[src] = dst
        return mapping
    return queues
class Filterer:
    """Consume task messages from a connection and feed them to a filter.

    Arguments:
        app: Celery app used to resolve queue names and create the
            task consumer.
        conn: Broker connection to consume from.
        filter (Callable): Consumer callback, called as
            ``(body, message)`` for every message.
        limit (int): Stop once this many messages have been counted.
        timeout (float): Timeout handed to the kombu event loop.
        ack_messages (bool): If true, ack the messages that are processed.
        tasks: Task names to restrict processing to (passed through
            ``str_to_list``); when empty, every message is processed.
        queues: Queue mapping in any form accepted by
            :func:`prepare_queues`.
        callback (Callable): Extra callback, called as
            ``(state, body, message)``.
        forever (bool): Passed to the event loop as ``ignore_timeouts``.
        on_declare_queue (Callable): Called with each consumer queue
            before its message count is looked up.
        consume_from: Queues (names or queue objects) to consume from;
            defaults to the keys of ``queues``.
        state (State): Existing state object to update; a new
            :class:`State` is created when not given.
        accept: Passed to the task consumer as ``accept``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, app, conn, filter,
                 limit=None, timeout=1.0,
                 ack_messages=False, tasks=None, queues=None,
                 callback=None, forever=False, on_declare_queue=None,
                 consume_from=None, state=None, accept=None, **kwargs):
        self.app = app
        self.conn = conn
        self.filter = filter
        self.limit = limit
        self.timeout = timeout
        self.ack_messages = ack_messages
        self.tasks = set(str_to_list(tasks) or [])
        self.queues = prepare_queues(queues)
        self.callback = callback
        self.forever = forever
        self.on_declare_queue = on_declare_queue
        self.consume_from = [
            _maybe_queue(self.app, q)
            for q in consume_from or list(self.queues)
        ]
        self.state = state or State()
        self.accept = accept

    def start(self):
        """Run the consumer until timeout or :exc:`StopFiltering`.

        Returns:
            State: The state object updated during the run.
        """
        # start migrating messages.
        with self.prepare_consumer(self.create_consumer()):
            try:
                for _ in eventloop(self.conn,  # pragma: no cover
                                   timeout=self.timeout,
                                   ignore_timeouts=self.forever):
                    pass
            except socket.timeout:
                # No message arrived within the timeout: we are done.
                pass
            except StopFiltering:
                # A callback signalled that enough messages were seen.
                pass
        return self.state

    def update_state(self, body, message):
        """Count the message; raise :exc:`StopFiltering` at the limit."""
        self.state.count += 1
        if self.limit and self.state.count >= self.limit:
            raise StopFiltering()

    def ack_message(self, body, message):
        """Acknowledge ``message``."""
        message.ack()

    def create_consumer(self):
        """Return a task consumer for the queues in ``consume_from``."""
        return self.app.amqp.TaskConsumer(
            self.conn,
            queues=self.consume_from,
            accept=self.accept,
        )

    def prepare_consumer(self, consumer):
        """Register all callbacks on ``consumer`` and declare its queues.

        When ``tasks`` is set, every callback (including the ack
        callback) is wrapped so it only runs for those task names.

        Returns:
            The same ``consumer`` that was passed in.
        """
        filter = self.filter
        update_state = self.update_state
        ack_message = self.ack_message
        if self.tasks:
            filter = filter_callback(filter, self.tasks)
            update_state = filter_callback(update_state, self.tasks)
            ack_message = filter_callback(ack_message, self.tasks)
        consumer.register_callback(filter)
        consumer.register_callback(update_state)
        if self.ack_messages:
            # Register the task-filtered wrapper, not the bound method:
            # otherwise messages for tasks outside ``self.tasks`` would be
            # acked (and thereby removed) without ever being processed.
            consumer.register_callback(ack_message)
        if self.callback is not None:
            callback = partial(self.callback, self.state)
            if self.tasks:
                callback = filter_callback(callback, self.tasks)
            consumer.register_callback(callback)
        self.declare_queues(consumer)
        return consumer

    def declare_queues(self, consumer):
        """Declare the consumer's queues and add up their message counts.

        Queues not present in a non-empty ``queues`` mapping are skipped.
        The message count reported by a passive declare is added to
        ``state.total_apx``; channel errors from that declare are ignored.
        """
        # declare all queues on the new broker.
        for queue in consumer.queues:
            if self.queues and queue.name not in self.queues:
                continue
            if self.on_declare_queue is not None:
                self.on_declare_queue(queue)
            try:
                _, mcount, _ = queue(
                    consumer.channel).queue_declare(passive=True)
                if mcount:
                    self.state.total_apx += mcount
            except self.conn.channel_errors:
                # Best effort: the count is only used for progress output.
                pass
def start_filter(app, conn, filter, limit=None, timeout=1.0,
                 ack_messages=False, tasks=None, queues=None,
                 callback=None, forever=False, on_declare_queue=None,
                 consume_from=None, state=None, accept=None, **kwargs):
    """Create a :class:`Filterer` from the arguments and run it.

    Returns:
        The value returned by :meth:`Filterer.start`.
    """
    options = {
        'limit': limit,
        'timeout': timeout,
        'ack_messages': ack_messages,
        'tasks': tasks,
        'queues': queues,
        'callback': callback,
        'forever': forever,
        'on_declare_queue': on_declare_queue,
        'consume_from': consume_from,
        'state': state,
        'accept': accept,
    }
    filterer = Filterer(app, conn, filter, **options, **kwargs)
    return filterer.start()
def move_task_by_id(task_id, dest, **kwargs):
    """Find a task by id and move it to another queue.

    Implemented as :func:`move_by_idmap` with a single-entry mapping.

    Arguments:
        task_id (str): Id of task to find and move.
        dest: (str, kombu.Queue): Destination queue.
        transform (Callable): Optional function to transform the return
            value (destination) of the filter function.
        **kwargs (Any): Also supports the same keyword
            arguments as :func:`move`.
    """
    idmap = {task_id: dest}
    return move_by_idmap(idmap, **kwargs)
def move_by_idmap(map, **kwargs):
    """Move tasks by matching from a ``task_id: queue`` mapping.

    ``queue`` is the queue the task with that id is moved to.  The id is
    read from the ``correlation_id`` message property.

    Example:
        >>> move_by_idmap({
        ...     '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue('name'),
        ...     'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue('name'),
        ...     '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue('name')},
        ...   queues=['hipri'])
    """

    def task_id_in_map(body, message):
        correlation_id = message.properties['correlation_id']
        return map.get(correlation_id)

    # adding the limit means that we don't have to consume any more
    # when we've found everything.
    wanted = len(map)
    return move(task_id_in_map, limit=wanted, **kwargs)
def move_by_taskmap(map, **kwargs):
    """Move tasks by matching from a ``task_name: queue`` mapping.

    ``queue`` is the queue that tasks with that name are moved to.

    Example:
        >>> move_by_taskmap({
        ...     'tasks.add': Queue('name'),
        ...     'tasks.mul': Queue('name'),
        ... })
    """

    def task_name_in_map(body, message):
        task_name = body['task']  # <- name of task
        return map.get(task_name)

    return move(task_name_in_map, **kwargs)
def filter_status(state, body, message, **kwargs):
    """Print a progress line built from ``MOVING_PROGRESS_FMT``."""
    line = MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs)
    print(line)
# Variants of the move functions with ``transform=worker_direct``
# pre-applied, so the predicate's return value is passed through
# ``worker_direct`` before being used as the destination.
move_direct = partial(move, transform=worker_direct)
move_direct_by_id = partial(move_task_by_id, transform=worker_direct)
move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct)
move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct)

View File

@@ -0,0 +1,216 @@
"""Fixtures and testing utilities for :pypi:`pytest <pytest>`."""
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Mapping, Sequence, Union # noqa
import pytest
if TYPE_CHECKING:
from celery import Celery
from ..worker import WorkController
else:
Celery = WorkController = object
# Value of the NO_WORKER environment variable (None when unset); when
# truthy the worker fixtures in this module do not start a worker.
NO_WORKER = os.environ.get('NO_WORKER')
# pylint: disable=redefined-outer-name
# Well, they're called fixtures....
def pytest_configure(config):
    """Register the ``celery`` marker with pytest."""
    # Adding the pytest.mark.celery() marker to the [markers] ini section
    # stops pytest 4.5 and newer from warning about an unknown marker,
    # and makes it show up with documentation in ``pytest --markers``.
    marker_doc = "celery(**overrides): override celery configuration for a test case"
    config.addinivalue_line("markers", marker_doc)
@contextmanager
def _create_app(enable_logging=False,
                use_trap=False,
                parameters=None,
                **config):
    # type: (Any, Any, Any, **Any) -> Celery
    """Context manager creating the test Celery app used by the fixtures.

    Builds a ``TestApp`` from ``config`` and ``parameters`` and yields it
    from inside ``setup_default_app``.
    """
    from .testing.app import TestApp, setup_default_app

    init_kwargs = parameters if parameters else {}
    test_app = TestApp(
        set_as_current=False,
        enable_logging=enable_logging,
        config=config,
        **init_kwargs
    )
    with setup_default_app(test_app, use_trap=use_trap):
        yield test_app
@pytest.fixture(scope='session')
def use_celery_app_trap():
    # type: () -> bool
    """Override this fixture to return true to enable the app trap.

    With the app trap enabled, an exception is raised whenever
    something attempts to use the current or default apps.
    """
    return False
@pytest.fixture(scope='session')
def celery_session_app(request,
                       celery_config,
                       celery_parameters,
                       celery_enable_logging,
                       use_celery_app_trap):
    # type: (Any, Any, Any, Any, Any) -> Celery
    """Session Fixture: Return app for session fixtures.

    Keyword arguments of the closest ``celery`` marker override the
    values from ``celery_config``.  Unless the app trap is enabled the
    app is also set as the default and current app.
    """
    mark = request.node.get_closest_marker('celery')
    overrides = mark.kwargs if mark else {}
    config = dict(celery_config, **overrides)
    app_context = _create_app(
        enable_logging=celery_enable_logging,
        use_trap=use_celery_app_trap,
        parameters=celery_parameters,
        **config
    )
    with app_context as app:
        if not use_celery_app_trap:
            app.set_default()
            app.set_current()
        yield app
@pytest.fixture(scope='session')
def celery_session_worker(request,
                          celery_session_app,
                          celery_includes,
                          celery_class_tasks,
                          celery_worker_pool,
                          celery_worker_parameters):
    # type: (Any, Celery, Sequence[str], Sequence[Any], Any, Mapping[str, Any]) -> WorkController
    """Session Fixture: Start worker that lives throughout test suite.

    Imports the modules in ``celery_includes`` and registers the tasks
    in ``celery_class_tasks`` before starting the worker.  Nothing is
    started (and nothing is yielded) when ``NO_WORKER`` is set.
    """
    from .testing import worker

    if NO_WORKER:
        return
    for module in celery_includes:
        celery_session_app.loader.import_task_module(module)
    for class_task in celery_class_tasks:
        celery_session_app.register_task(class_task)
    running = worker.start_worker(
        celery_session_app,
        pool=celery_worker_pool,
        **celery_worker_parameters
    )
    with running as w:
        yield w
@pytest.fixture(scope='session')
def celery_enable_logging():
    # type: () -> bool
    """Override this fixture to return true to enable logging."""
    return False
@pytest.fixture(scope='session')
def celery_includes():
    # type: () -> Sequence[str]
    """Override this fixture to import modules when a worker starts.

    Return a list of module names to import; these can be task
    modules, modules registering signals, and so on.
    """
    return ()
@pytest.fixture(scope='session')
def celery_worker_pool():
    # type: () -> Union[str, Any]
    """Override this fixture to choose the worker pool.

    Defaults to the "solo" pool, but it can be changed to return
    e.g. "prefork".
    """
    return 'solo'
@pytest.fixture(scope='session')
def celery_config():
    # type: () -> Mapping[str, Any]
    """Override this fixture to configure the test Celery app.

    The mapping returned here is used as the configuration of the app
    created by the :func:`celery_app` fixture.
    """
    return {}
@pytest.fixture(scope='session')
def celery_parameters():
    # type: () -> Mapping[str, Any]
    """Override this fixture to change the init parameters of test Celery app.

    The returned dict is passed as keyword arguments when
    instantiating :class:`~celery.Celery`.
    """
    return {}
@pytest.fixture(scope='session')
def celery_worker_parameters():
    # type: () -> Mapping[str, Any]
    """Override this fixture to change the init parameters of Celery workers.

    For example, this can define the queues the worker will consume
    tasks from.

    The returned dict is passed as keyword arguments when
    instantiating :class:`~celery.worker.WorkController`.
    """
    return {}
@pytest.fixture()
def celery_app(request,
               celery_config,
               celery_parameters,
               celery_enable_logging,
               use_celery_app_trap):
    """Fixture creating a Celery application instance.

    Keyword arguments of the closest ``celery`` marker override the
    values from ``celery_config``.
    """
    mark = request.node.get_closest_marker('celery')
    overrides = mark.kwargs if mark else {}
    config = dict(celery_config, **overrides)
    app_context = _create_app(
        enable_logging=celery_enable_logging,
        use_trap=use_celery_app_trap,
        parameters=celery_parameters,
        **config
    )
    with app_context as app:
        yield app
@pytest.fixture(scope='session')
def celery_class_tasks():
    """Override this fixture to return tasks to register with the test Celery app."""
    return []
@pytest.fixture()
def celery_worker(request,
                  celery_app,
                  celery_includes,
                  celery_worker_pool,
                  celery_worker_parameters):
    # type: (Any, Celery, Sequence[str], str, Any) -> WorkController
    """Fixture: Start worker in a thread, stop it when the test returns.

    Imports the modules in ``celery_includes`` before starting the
    worker.  Nothing is started (and nothing is yielded) when
    ``NO_WORKER`` is set.
    """
    from .testing import worker

    if NO_WORKER:
        return
    for module in celery_includes:
        celery_app.loader.import_task_module(module)
    running = worker.start_worker(
        celery_app,
        pool=celery_worker_pool,
        **celery_worker_parameters
    )
    with running as w:
        yield w
@pytest.fixture()
def depends_on_current_app(celery_app):
    """Fixture that makes the ``celery_app`` fixture's app the current app."""
    celery_app.set_current()

View File

@@ -0,0 +1,187 @@
"""Remote Debugger.
Introduction
============
This is a remote debugger for Celery tasks running in multiprocessing
pool workers. Inspired by a lost post on dzone.com.
Usage
-----
.. code-block:: python
from celery.contrib import rdb
from celery import task
@task()
def add(x, y):
result = x + y
rdb.set_trace()
return result
Environment Variables
=====================
.. envvar:: CELERY_RDB_HOST
``CELERY_RDB_HOST``
-------------------
Hostname to bind to. Default is '127.0.0.1' (only accessible from
localhost).
.. envvar:: CELERY_RDB_PORT
``CELERY_RDB_PORT``
-------------------
Base port to bind to. Default is 6899.
The debugger will try to find an available port starting from the
base port. The selected port will be logged by the worker.
"""
import errno
import os
import socket
import sys
from pdb import Pdb
from billiard.process import current_process
# Names exported by ``from <module> import *``.
__all__ = (
    'CELERY_RDB_HOST', 'CELERY_RDB_PORT', 'DEFAULT_PORT',
    'Rdb', 'debugger', 'set_trace',
)
# Base port used when CELERY_RDB_PORT is unset or empty.
DEFAULT_PORT = 6899
# Host and base port, read from the environment once at import time.
CELERY_RDB_HOST = os.environ.get('CELERY_RDB_HOST') or '127.0.0.1'
CELERY_RDB_PORT = int(os.environ.get('CELERY_RDB_PORT') or DEFAULT_PORT)
#: Holds the currently active debugger.
_current = [None]
# Alias for sys._getframe, used by set_trace() to find the caller's frame.
_frame = getattr(sys, '_getframe')
# Message templates; each is formatted with ``self`` set to the Rdb instance.
NO_AVAILABLE_PORT = """\
{self.ident}: Couldn't find an available port.
Please specify one using the CELERY_RDB_PORT environment variable.
"""
BANNER = """\
{self.ident}: Ready to connect: telnet {self.host} {self.port}
Type `exit` in session to continue.
{self.ident}: Waiting for client...
"""
SESSION_STARTED = '{self.ident}: Now in session with {self.remote_addr}.'
SESSION_ENDED = '{self.ident}: Session with {self.remote_addr} ended.'
class Rdb(Pdb):
    """Remote debugger: a :class:`pdb.Pdb` session served over TCP.

    Creating an instance binds a listening socket, prints connection
    instructions to ``out`` and then blocks until a client connects.
    While the session is open ``sys.stdin`` and ``sys.stdout`` are
    replaced by the client connection; they are restored when the
    session is closed.

    Arguments:
        host (str): Address to bind to.
        port (int): Base port; the first free port at or above
            ``port + skew`` is used.
        port_search_limit (int): Number of consecutive ports to try.
        port_skew (int): Offset added to the base port.  Replaced by the
            numeric suffix of the current process name when that name
            has the form ``<name>-<number>``.
        out: File-like object that status messages are printed to.
    """

    me = 'Remote Debugger'
    # Fallback identity used in messages formatted before ``__init__``
    # has set the port-specific ``ident`` (e.g. the "no available port"
    # error raised by get_avail_port).
    ident = me
    _prev_outs = None
    _sock = None

    def __init__(self, host=CELERY_RDB_HOST, port=CELERY_RDB_PORT,
                 port_search_limit=100, port_skew=+0, out=sys.stdout):
        self.active = True
        self.out = out

        # Remember the real streams so _close_session can restore them.
        self._prev_handles = sys.stdin, sys.stdout

        self._sock, this_port = self.get_avail_port(
            host, port, port_search_limit, port_skew,
        )
        self._sock.setblocking(1)
        self._sock.listen(1)
        self.ident = f'{self.me}:{this_port}'
        self.host = host
        self.port = this_port
        self.say(BANNER.format(self=self))

        # Blocks until a client connects.
        self._client, address = self._sock.accept()
        self._client.setblocking(1)
        self.remote_addr = ':'.join(str(v) for v in address)
        self.say(SESSION_STARTED.format(self=self))
        self._handle = sys.stdin = sys.stdout = self._client.makefile('rw')
        super().__init__(completekey='tab',
                         stdin=self._handle, stdout=self._handle)

    def get_avail_port(self, host, port, search_limit=100, skew=+0):
        """Bind a socket to the first available port.

        Tries ``search_limit`` consecutive ports starting at
        ``port + skew``.

        Returns:
            Tuple: ``(socket, port)`` for the bound socket.

        Raises:
            OSError: If binding fails for a reason other than the address
                being in use or invalid.
            Exception: If none of the ports could be bound.
        """
        try:
            _, name_suffix = current_process().name.split('-')
            # Only replace ``skew`` once the suffix is known to be an
            # integer, so a name such as 'Foo-bar' keeps the given skew
            # instead of leaving a string behind.
            skew = int(name_suffix)
        except ValueError:
            pass
        this_port = None
        for i in range(search_limit):
            _sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            this_port = port + skew + i
            try:
                _sock.bind((host, this_port))
            except OSError as exc:
                # This socket will never be used: close it instead of
                # leaking one descriptor per failed attempt.
                _sock.close()
                if exc.errno in [errno.EADDRINUSE, errno.EINVAL]:
                    continue
                raise
            else:
                return _sock, this_port
        raise Exception(NO_AVAILABLE_PORT.format(self=self))

    def say(self, m):
        """Print message ``m`` to the ``out`` stream."""
        print(m, file=self.out)

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self._close_session()

    def _close_session(self):
        """Restore the original streams and close the connection.

        The sockets are only closed, and the "session ended" message only
        printed, the first time this is called.
        """
        self.stdin, self.stdout = sys.stdin, sys.stdout = self._prev_handles
        if self.active:
            if self._handle is not None:
                self._handle.close()
            if self._client is not None:
                self._client.close()
            if self._sock is not None:
                self._sock.close()
            self.active = False
            self.say(SESSION_ENDED.format(self=self))

    def do_continue(self, arg):
        """Close the remote session and continue execution."""
        self._close_session()
        self.set_continue()
        return 1
    do_c = do_cont = do_continue

    def do_quit(self, arg):
        """Close the remote session and stop debugging."""
        self._close_session()
        self.set_quit()
        return 1
    do_q = do_exit = do_quit

    def set_quit(self):
        """Stop tracing without going through the base implementation."""
        # this raises a BdbQuit exception that we're unable to catch.
        sys.settrace(None)
def debugger():
    """Return the active debugger instance, creating one if needed."""
    existing = _current[0]
    if existing is not None and existing.active:
        return existing
    _current[0] = fresh = Rdb()
    return fresh
def set_trace(frame=None):
    """Set break-point at the caller's location, or at a specified frame."""
    target = _frame().f_back if frame is None else frame
    return debugger().set_trace(target)

View File

@@ -0,0 +1,105 @@
"""Sphinx documentation plugin used to document tasks.
Introduction
============
Usage
-----
The Celery extension for Sphinx requires Sphinx 2.0 or later.
Add the extension to your :file:`docs/conf.py` configuration module:
.. code-block:: python
extensions = (...,
'celery.contrib.sphinx')
If you'd like to change the prefix for tasks in reference documentation
then you can change the ``celery_task_prefix`` configuration value:
.. code-block:: python
celery_task_prefix = '(task)' # < default
With the extension installed `autodoc` will automatically find
task decorated objects (e.g. when using the automodule directive)
and generate the correct documentation (as well as add a ``(task)`` prefix),
and you can also refer to the tasks using `:task:proj.tasks.add`
syntax.
Use ``.. autotask::`` to alternatively manually document a task.
"""
from inspect import signature
from docutils import nodes
from sphinx.domains.python import PyFunction
from sphinx.ext.autodoc import FunctionDocumenter
from celery.app.task import BaseTask
class TaskDocumenter(FunctionDocumenter):
    """Document task definitions."""

    objtype = 'task'
    member_order = 11

    @classmethod
    def can_document_member(cls, member, membername, isattr, parent):
        """Return true for task instances that wrap a function.

        Tasks without a ``__wrapped__`` attribute are rejected instead
        of raising :exc:`AttributeError`.
        """
        return (isinstance(member, BaseTask)
                and bool(getattr(member, '__wrapped__', None)))

    def format_args(self):
        """Return the signature of the wrapped function as a string.

        The first parameter is dropped when the signature has a
        parameter named ``self`` or ``cls``.  Returns an empty string
        when the object has no wrapped function.
        """
        wrapped = getattr(self.object, '__wrapped__', None)
        if wrapped is not None:
            sig = signature(wrapped)
            if "self" in sig.parameters or "cls" in sig.parameters:
                sig = sig.replace(parameters=list(sig.parameters.values())[1:])
            return str(sig)
        return ''

    def document_members(self, all_members=False):
        """Do nothing: members of a task are not documented."""
        pass

    def check_module(self):
        """Return true if the task belongs to the documented module."""
        # Normally checks if *self.object* is really defined in the module
        # given by *self.modname*. But since functions decorated with the @task
        # decorator are instances living in the celery.local, we have to check
        # the wrapped function instead.
        wrapped = getattr(self.object, '__wrapped__', None)
        if wrapped and getattr(wrapped, '__module__', None) == self.modname:
            return True
        return super().check_module()
class TaskDirective(PyFunction):
    """Sphinx directive used for tasks."""

    def get_signature_prefix(self, sig):
        """Return the configured ``celery_task_prefix`` as a text node list."""
        prefix = self.env.config.celery_task_prefix
        return [nodes.Text(prefix)]
def autodoc_skip_member_handler(app, what, name, obj, skip, options):
    """Handler for autodoc-skip-member event.

    Returns:
        :const:`False` (do not skip) for a function-based task that
        autodoc decided to skip, and :const:`None` (keep autodoc's
        decision) in every other case.
    """
    # Celery tasks created with the @task decorator have the property
    # that *obj.__doc__* and *obj.__class__.__doc__* are equal, which
    # trips up the logic in sphinx.ext.autodoc that is supposed to
    # suppress repetition of class documentation in an instance of the
    # class. This overrides that behavior.
    #
    # The ``None`` default matters: a task object without ``__wrapped__``
    # must fall through to the default decision rather than raise
    # AttributeError inside the event handler.
    if isinstance(obj, BaseTask) and getattr(obj, '__wrapped__', None):
        if skip:
            return False
    return None
def setup(app):
    """Setup Sphinx extension.

    Registers the task documenter, the ``py:task`` directive, the
    ``celery_task_prefix`` config value and the autodoc-skip-member
    handler, and returns the extension metadata.
    """
    app.setup_extension('sphinx.ext.autodoc')
    app.add_autodocumenter(TaskDocumenter)
    app.add_directive_to_domain('py', 'task', TaskDirective)
    app.add_config_value('celery_task_prefix', '(task)', True)
    app.connect('autodoc-skip-member', autodoc_skip_member_handler)

    metadata = {'parallel_read_safe': True}
    return metadata

View File

@@ -0,0 +1,112 @@
"""Create Celery app instances used for testing."""
import weakref
from contextlib import contextmanager
from copy import deepcopy
from kombu.utils.imports import symbol_by_name
from celery import Celery, _state
#: Contains the default configuration values for the test app.
#: TestApp() deep-copies this mapping and applies the caller's ``config``
#: on top of it.
DEFAULT_TEST_CONFIG = {
    'worker_hijack_root_logger': False,
    'worker_log_color': False,
    'accept_content': {'json'},
    'enable_utc': True,
    'timezone': 'UTC',
    'broker_url': 'memory://',
    'result_backend': 'cache+memory://',
    'broker_heartbeat': 0,
}
class Trap:
    """Stand-in for an app that raises on any attribute access.

    Used to catch code that does not pass app instances around
    properly and instead falls back to the current_app.
    """

    def __getattr__(self, name):
        # Workaround to allow unittest.mock to patch this object
        # in Python 3.8 and above.
        if name in ('_is_coroutine', '__func__'):
            return None
        print(name)
        raise RuntimeError('Test depends on current_app')
class UnitLogging(symbol_by_name(Celery.log_cls)):
    """Logging class for the test application.

    Marks itself as already set up right after initialization.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.already_setup = True
def TestApp(name=None, config=None, enable_logging=False, set_as_current=False,
            log=UnitLogging, backend=None, broker=None, **kwargs):
    """Create a Celery app for use in tests.

    The app's defaults are a copy of ``DEFAULT_TEST_CONFIG`` updated
    with ``config``.  An explicit ``broker`` or ``backend`` removes the
    corresponding default URL from those defaults.  When
    ``enable_logging`` is true no custom ``log`` class is passed to the
    app.
    """
    from . import tasks  # noqa

    overrides = config or {}
    config = dict(deepcopy(DEFAULT_TEST_CONFIG), **overrides)
    if broker is not None:
        config.pop('broker_url', None)
    if backend is not None:
        config.pop('result_backend', None)
    if enable_logging:
        log = None
    test_app = Celery(
        name or 'celery.tests',
        set_as_current=set_as_current,
        log=log,
        broker=broker,
        backend=backend,
        **kwargs
    )
    test_app.add_defaults(config)
    return test_app
@contextmanager
def set_trap(app):
    """Contextmanager that installs the trap app.

    While active, a :class:`Trap` is installed as the default app and
    as the current app, so anything trying to use either of them raises
    an exception.  The previous ``_state._tls`` object is put back on
    exit.
    """
    trap = Trap()
    saved_tls = _state._tls
    _state.set_default_app(trap)

    class NonTLS:
        current_app = trap

    _state._tls = NonTLS()
    try:
        yield
    finally:
        _state._tls = saved_tls
@contextmanager
def setup_default_app(app, use_trap=False):
    """Setup default app for testing.

    Snapshots the current app, default app, finalizers and registered
    apps before yielding, and restores them afterwards so state is
    clean once the test returns.  ``app`` is closed on exit unless it
    was already the current app.
    """
    saved_current = _state.get_current_app()
    saved_default = _state.default_app
    saved_finalizers = set(_state._on_app_finalizers)
    saved_apps = weakref.WeakSet(_state._apps)

    try:
        if not use_trap:
            yield
        else:
            with set_trap(app):
                yield
    finally:
        _state.set_default_app(saved_default)
        _state._tls.current_app = saved_current
        if app is not saved_current:
            app.close()
        _state._on_app_finalizers = saved_finalizers
        _state._apps = saved_apps

View File

@@ -0,0 +1,239 @@
"""Integration testing utilities."""
import socket
import sys
from collections import defaultdict
from functools import partial
from itertools import count
from typing import Any, Callable, Dict, Sequence, TextIO, Tuple # noqa
from kombu.exceptions import ContentDisallowed
from kombu.utils.functional import retry_over_time
from celery import states
from celery.exceptions import TimeoutError
from celery.result import AsyncResult, ResultSet # noqa
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds as _humanize_seconds
#: Warning template used by :meth:`ManagerMixin.wait_for`: ``{0}`` is the
#: description of what is being awaited, ``when`` the humanized delay
#: before the next attempt and ``exc`` the exception that caused the retry.
E_STILL_WAITING = 'Still waiting for {0}. Trying again {when}: {exc!r}'
#: ``celery.utils.time.humanize_seconds`` pre-bound with
#: ``microseconds=True``.
humanize_seconds = partial(_humanize_seconds, microseconds=True)
class Sentinel(Exception):
    """Signifies the end of something.

    Raised by :meth:`ManagerMixin.true_or_raise` when the checked function
    returns a falsy value; the ``assert_*`` helpers pass it as the ``catch``
    exception to :meth:`ManagerMixin.wait_for` so the check is retried.
    """
class ManagerMixin:
    """Mixin that adds :class:`Manager` capabilities.

    The host class must set ``self.app`` and call :meth:`_init_manager`
    before any of the other helpers are used.
    """

    def _init_manager(self,
                      block_timeout=30 * 60.0, no_join=False,
                      stdout=None, stderr=None):
        # type: (float, bool, TextIO, TextIO) -> None
        """Store the manager settings on the instance.

        Arguments:
            block_timeout: Stored as ``self.block_timeout``.
            no_join: When true, :meth:`join` returns immediately.
            stdout: Stream used by :meth:`remark`; defaults to
                ``sys.stdout``.
            stderr: Stored as ``self.stderr``; defaults to ``sys.stderr``.
        """
        self.stdout = sys.stdout if stdout is None else stdout
        self.stderr = sys.stderr if stderr is None else stderr
        # Exceptions treated as "connection lost, try again" in join().
        self.connerrors = self.app.connection().recoverable_connection_errors
        self.block_timeout = block_timeout
        self.no_join = no_join

    def remark(self, s, sep='-'):
        # type: (str, str) -> None
        """Print ``s`` prefixed with ``sep`` to ``self.stdout``."""
        print(f'{sep}{s}', file=self.stdout)

    def missing_results(self, r):
        # type: (Sequence[AsyncResult]) -> Sequence[str]
        """Return ids of results not present in their backend's cache."""
        return [res.id for res in r if res.id not in res.backend._cache]

    def wait_for(
        self,
        fun,  # type: Callable
        catch,  # type: Sequence[Any]
        desc="thing",  # type: str
        args=(),  # type: Tuple
        kwargs=None,  # type: Dict
        errback=None,  # type: Callable
        max_retries=10,  # type: int
        interval_start=0.1,  # type: float
        interval_step=0.5,  # type: float
        interval_max=5.0,  # type: float
        emit_warning=False,  # type: bool
        **options  # type: Any
    ):
        # type: (...) -> Any
        """Wait for event to happen.

        The `catch` argument specifies the exception that means the event
        has not happened yet.

        Arguments:
            fun: Callable to retry.
            catch: Exception(s) meaning "not yet".
            desc: Description used in the warning message.
            args: Positional arguments for ``fun``.
            kwargs: Keyword arguments for ``fun``.
            errback: Optional callback called as
                ``errback(exc, interval, retries)`` after each failure.
            max_retries: Forwarded to :meth:`retry_over_time`.
            interval_start: Forwarded to :meth:`retry_over_time`.
            interval_step: Forwarded to :meth:`retry_over_time`.
            interval_max: Forwarded to :meth:`retry_over_time`.
            emit_warning: When true, a "still waiting" message is emitted
                on every failed attempt.
            **options: Extra keyword arguments for :meth:`retry_over_time`.

        Returns:
            Whatever :meth:`retry_over_time` returns.
        """
        kwargs = {} if not kwargs else kwargs

        def on_error(exc, intervals, retries):
            interval = next(intervals)
            if emit_warning:
                message = E_STILL_WAITING.format(
                    desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
                )
                # This mixin does not define ``warn`` itself; use the host
                # class' implementation when there is one and fall back to
                # ``remark`` otherwise instead of raising AttributeError.
                warn = getattr(self, 'warn', None)
                if warn is not None:
                    warn(message)
                else:
                    self.remark(message, '!')
            if errback:
                errback(exc, interval, retries)
            return interval

        return self.retry_over_time(
            fun, catch,
            args=args, kwargs=kwargs,
            errback=on_error, max_retries=max_retries,
            interval_start=interval_start, interval_step=interval_step,
            # Previously accepted but never forwarded.
            interval_max=interval_max,
            **options
        )

    def ensure_not_for_a_while(self, fun, catch,
                               desc='thing', max_retries=20,
                               interval_start=0.1, interval_step=0.02,
                               interval_max=1.0, emit_warning=False,
                               **options):
        """Make sure something does not happen (at least for a while).

        Arguments are the same as for :meth:`wait_for`.

        Returns:
            None: when :meth:`wait_for` gave up with one of the ``catch``
            exceptions, i.e. the event did not happen.

        Raises:
            AssertionError: when :meth:`wait_for` returned, i.e. the event
                did happen.
        """
        try:
            # No ``return`` here: returning from inside the ``try`` would
            # skip the failure path below entirely.
            self.wait_for(
                fun, catch, desc=desc, max_retries=max_retries,
                interval_start=interval_start, interval_step=interval_step,
                interval_max=interval_max, emit_warning=emit_warning,
                **options
            )
        except catch:
            return None
        raise AssertionError(f'Should not have happened: {desc}')

    def retry_over_time(self, *args, **kwargs):
        """Call the module-level ``retry_over_time`` with the same arguments.

        Exists as a method so subclasses can override the retry strategy.
        """
        return retry_over_time(*args, **kwargs)

    def join(self, r, propagate=False, max_retries=10, **kwargs):
        """Wait for result ``r`` and return what ``r.get()`` returns.

        A non-``ResultSet`` argument is wrapped in ``self.app.ResultSet``.
        Timeouts and recoverable connection errors are reported through
        :meth:`remark` and the ``get`` is retried, at most ``max_retries``
        times (without limit when ``max_retries`` is falsy).

        Returns:
            None when ``self.no_join`` is set, else the ``get()`` result.

        Raises:
            AssertionError: when all attempts were used up.
        """
        if self.no_join:
            return
        if not isinstance(r, ResultSet):
            r = self.app.ResultSet([r])
        received = []

        def on_result(task_id, value):
            received.append(task_id)

        for _ in range(max_retries) if max_retries else count(0):
            # Start every attempt with an empty "received" list.
            received[:] = []
            try:
                return r.get(callback=on_result, propagate=propagate, **kwargs)
            except (socket.timeout, TimeoutError) as exc:
                waiting_for = self.missing_results(r)
                self.remark(
                    'Still waiting for {}/{}: [{}]: {!r}'.format(
                        len(r) - len(received), len(r),
                        truncate(', '.join(waiting_for)), exc), '!',
                )
            except self.connerrors as exc:
                self.remark(f'join: connection lost: {exc!r}', '!')
        raise AssertionError('Test failed: Missing task results')

    def inspect(self, timeout=3.0):
        """Return ``self.app.control.inspect(timeout=timeout)``."""
        return self.app.control.inspect(timeout=timeout)

    def query_tasks(self, ids, timeout=0.5):
        """Yield the items of the ``query_task(*ids)`` inspect reply.

        Nothing is yielded when the reply is falsy.
        """
        tasks = self.inspect(timeout).query_task(*ids) or {}
        yield from tasks.items()

    def query_task_states(self, ids, timeout=0.5):
        """Return a mapping of state to the set of task ids in that state."""
        by_state = defaultdict(set)
        for _hostname, reply in self.query_tasks(ids, timeout=timeout):
            for task_id, (state, _) in reply.items():
                by_state[state].add(task_id)
        return by_state

    def assert_accepted(self, ids, interval=0.5,
                        desc='waiting for tasks to be accepted', **policy):
        """Wait until :meth:`is_accepted` holds for ``ids``."""
        return self.assert_task_worker_state(
            self.is_accepted, ids, interval=interval, desc=desc, **policy
        )

    def assert_received(self, ids, interval=0.5,
                        desc='waiting for tasks to be received', **policy):
        """Wait until :meth:`is_received` holds for ``ids``."""
        return self.assert_task_worker_state(
            self.is_received, ids, interval=interval, desc=desc, **policy
        )

    def assert_result_tasks_in_progress_or_completed(
        self,
        async_results,
        interval=0.5,
        desc='waiting for tasks to be started or completed',
        **policy
    ):
        """Wait until :meth:`is_result_task_in_progress` holds."""
        return self.assert_task_state_from_result(
            self.is_result_task_in_progress,
            async_results,
            interval=interval, desc=desc, **policy
        )

    def assert_task_state_from_result(self, fun, results,
                                      interval=0.5, **policy):
        """Retry ``fun(results, timeout=interval)`` until it is truthy.

        ``policy`` is forwarded to :meth:`wait_for`.
        """
        return self.wait_for(
            partial(self.true_or_raise, fun, results, timeout=interval),
            (Sentinel,), **policy
        )

    @staticmethod
    def is_result_task_in_progress(results, **kwargs):
        """Return True if every result is in state STARTED or SUCCESS."""
        possible_states = (states.STARTED, states.SUCCESS)
        return all(result.state in possible_states for result in results)

    def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
        """Retry ``fun(ids, timeout=interval)`` until it is truthy.

        ``policy`` is forwarded to :meth:`wait_for`.
        """
        return self.wait_for(
            partial(self.true_or_raise, fun, ids, timeout=interval),
            (Sentinel,), **policy
        )

    def is_received(self, ids, **kwargs):
        """Return True if all ids are reported reserved, active or ready."""
        return self._ids_matches_state(
            ['reserved', 'active', 'ready'], ids, **kwargs)

    def is_accepted(self, ids, **kwargs):
        """Return True if all ids are reported active or ready."""
        return self._ids_matches_state(['active', 'ready'], ids, **kwargs)

    def _ids_matches_state(self, expected_states, ids, timeout=0.5):
        """Return True if every id is in at least one expected state."""
        by_state = self.query_task_states(ids, timeout=timeout)
        return all(
            any(task_id in by_state[name] for name in expected_states)
            for task_id in ids
        )

    def true_or_raise(self, fun, *args, **kwargs):
        """Return ``fun(*args, **kwargs)``; raise ``Sentinel`` if falsy."""
        res = fun(*args, **kwargs)
        if not res:
            raise Sentinel()
        return res

    def wait_until_idle(self):
        """Purge the queues, then block until no active tasks are reported."""
        control = self.app.control
        with self.app.connection() as connection:
            # Try to purge the queue before we start
            # to attempt to avoid interference from other tests
            while True:
                purged = control.purge(connection=connection)
                if purged == 0:
                    break

            # Wait until worker is idle
            inspect = control.inspect()
            inspect.connection = connection
            while True:
                try:
                    # NOTE(review): the reply is used without a None check;
                    # confirm ``active()`` cannot return None here.
                    active = sum(len(t) for t in inspect.active().values())
                except ContentDisallowed:
                    # test_security_task_done may trigger this exception
                    break
                if active == 0:
                    break
class Manager(ManagerMixin):
    """Test helpers for task integration tests.

    Concrete :class:`ManagerMixin` bound to a single app.
    """

    def __init__(self, app, **kwargs):
        # ``app`` must be assigned first: ``_init_manager`` reads
        # ``self.app`` to look up the recoverable connection errors.
        self.app = app
        # Remaining keyword arguments are the ``_init_manager`` settings.
        self._init_manager(**kwargs)

View File

@@ -0,0 +1,137 @@
"""Useful mocks for unit testing."""
import numbers
from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence # noqa
from unittest.mock import Mock
from celery import Celery # noqa
from celery.canvas import Signature # noqa
def TaskMessage(
    name,  # type: str
    id=None,  # type: str
    args=(),  # type: Sequence
    kwargs=None,  # type: Mapping
    callbacks=None,  # type: Sequence[Signature]
    errbacks=None,  # type: Sequence[Signature]
    chain=None,  # type: Sequence[Signature]
    shadow=None,  # type: str
    utc=None,  # type: bool
    **options  # type: Any
):
    # type: (...) -> Any
    """Build a mock task message laid out in protocol 2 format.

    The headers carry ``id``, ``task`` and ``shadow`` plus every extra
    keyword in ``options``; the body is the JSON-serialized
    ``(args, kwargs, embed)`` tuple, where ``embed`` holds the callbacks,
    errbacks and chain.  A task id is generated when ``id`` is falsy.
    ``utc`` is accepted but not used by this function.

    Returns:
        A ``Mock`` with ``headers``, ``payload``, ``content_type``,
        ``content_encoding`` and ``body`` set.
    """
    from kombu.serialization import dumps

    from celery import uuid

    task_kwargs = kwargs if kwargs else {}
    task_id = id or uuid()

    headers = {
        'id': task_id,
        'task': name,
        'shadow': shadow,
    }
    # Extra options are merged last, so they may override the keys above.
    headers.update(options)

    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
    payload = (args, task_kwargs, embed)
    content_type, content_encoding, body = dumps(payload, serializer='json')

    message = Mock(name=f'TaskMessage-{task_id}')
    message.headers = headers
    message.content_type = content_type
    message.content_encoding = content_encoding
    message.body = body
    message.payload = payload
    return message
def TaskMessage1(
    name,  # type: str
    id=None,  # type: str
    args=(),  # type: Sequence
    kwargs=None,  # type: Mapping
    callbacks=None,  # type: Sequence[Signature]
    errbacks=None,  # type: Sequence[Signature]
    chain=None,  # type: Sequence[Signature]
    **options  # type: Any
):
    # type: (...) -> Any
    """Build a mock task message laid out in protocol 1 format.

    Everything travels in the payload dict (``task``, ``id``, ``args``,
    ``kwargs``, ``callbacks``, ``errbacks`` plus any extra ``options``);
    the headers are left empty.  A task id is generated when ``id`` is
    falsy.  ``chain`` is accepted but not written to the payload.

    Returns:
        A ``Mock`` with ``headers``, ``payload``, ``content_type``,
        ``content_encoding`` and ``body`` set.
    """
    from kombu.serialization import dumps

    from celery import uuid

    task_id = id or uuid()
    payload = {
        'task': name,
        'id': task_id,
        'args': args,
        'kwargs': kwargs or {},
        'callbacks': callbacks,
        'errbacks': errbacks,
    }
    # Extra options are merged last, so they may override the keys above.
    payload.update(options)

    message = Mock(name=f'TaskMessage-{task_id}')
    message.headers = {}
    message.payload = payload
    serialized = dumps(payload)
    message.content_type, message.content_encoding, message.body = serialized
    return message
def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
    # type: (Celery, Signature, bool, Any) -> Any
    """Create task message from :class:`celery.Signature`.

    The signature is frozen, and the ``link``, ``link_error``,
    ``countdown``, ``eta`` and ``expires`` entries are popped from
    ``sig.options`` (the signature is modified in place); whatever remains
    in the options is forwarded to ``TaskMessage`` as keyword arguments.

    Arguments:
        app: App whose ``now()`` is the base time for a ``countdown`` or a
            numeric ``expires``.
        sig: Signature to convert.
        utc: Forwarded to ``TaskMessage``.
        TaskMessage: Factory used to build the message.

    Returns:
        The object returned by ``TaskMessage``.

    Example:
        >>> m = task_message_from_sig(app, add.s(2, 2))
        >>> amqp_client.basic_publish(m, exchange='ex', routing_key='rkey')
    """
    sig.freeze()
    options = sig.options
    callbacks = options.pop('link', None)
    errbacks = options.pop('link_error', None)
    countdown = options.pop('countdown', None)
    # Always take ``eta`` out of the options: leaving it behind when a
    # countdown is also given would pass ``eta`` twice to ``TaskMessage``
    # (once explicitly, once through ``**options``) and raise TypeError.
    eta = options.pop('eta', None)
    if countdown:
        # A countdown takes precedence over an explicit eta.
        eta = app.now() + timedelta(seconds=countdown)
    if eta and isinstance(eta, datetime):
        eta = eta.isoformat()
    expires = options.pop('expires', None)
    if expires and isinstance(expires, numbers.Real):
        # Numeric expiry is a number of seconds from now.
        expires = app.now() + timedelta(seconds=expires)
    if expires and isinstance(expires, datetime):
        expires = expires.isoformat()
    return TaskMessage(
        sig.task, id=sig.id, args=sig.args,
        kwargs=sig.kwargs,
        callbacks=[dict(s) for s in callbacks] if callbacks else None,
        errbacks=[dict(s) for s in errbacks] if errbacks else None,
        eta=eta,
        expires=expires,
        utc=utc,
        **options
    )
class _ContextMock(Mock):
    """Mock subclass that defines the context manager protocol.

    The :keyword:`with` statement looks ``__enter__`` and ``__exit__`` up
    on the class rather than on the instance, so they are defined here.
    """

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        return None


def ContextMock(*args, **kwargs):
    """Return a mock that can be used in a :keyword:`with` statement.

    ``__enter__`` and ``__exit__`` are themselves mocks attached to the
    returned object; entering the context yields the object itself.
    """
    ctx = _ContextMock(*args, **kwargs)
    for dunder in ('__enter__', '__exit__'):
        ctx.attach_mock(_ContextMock(), dunder)
    ctx.__enter__.return_value = ctx
    # A truthy ``__exit__`` result would suppress the exception raised in
    # the ``with`` body, so it has to be None.
    ctx.__exit__.return_value = None
    return ctx

View File

@@ -0,0 +1,9 @@
"""Helper tasks for integration tests."""
from celery import shared_task
@shared_task(name='celery.ping')
def ping():
    # type: () -> str
    """Simple task that just returns 'pong'.

    Registered under the name ``celery.ping``; the embedded worker helpers
    send it after startup to check that the worker answers.
    """
    return 'pong'

View File

@@ -0,0 +1,223 @@
"""Embedded workers for integration tests."""
import logging
import logging.handlers
import os
import threading
from contextlib import contextmanager
from typing import Any, Iterable, Optional, Union

import celery.worker.consumer  # noqa
from celery import Celery, worker
from celery.result import _set_task_join_will_block, allow_join_result
from celery.utils.dispatch import Signal
from celery.utils.nodenames import anon_nodename
#: Default log level for embedded workers, read from the
#: ``WORKER_LOGLEVEL`` environment variable (``'error'`` when unset).
WORKER_LOGLEVEL = os.environ.get('WORKER_LOGLEVEL', 'error')

#: Sent by :func:`start_worker` before the worker is started.
#: NOTE(review): ``providing_args`` is an empty dict literal here while the
#: other signals pass sets -- confirm ``Signal`` treats both alike.
test_worker_starting = Signal(
    name='test_worker_starting',
    providing_args={},
)

#: Sent by :meth:`TestWorkController.on_consumer_ready`.
test_worker_started = Signal(
    name='test_worker_started',
    providing_args={'worker', 'consumer'},
)

#: Sent by :func:`start_worker` when its context exits; ``worker`` is None
#: if the worker never came up.
test_worker_stopped = Signal(
    name='test_worker_stopped',
    providing_args={'worker'},
)
class TestWorkController(worker.WorkController):
    """Worker that can synchronize on being fully started.

    For the prefork pool it additionally relays log records emitted in
    other processes back to the root logger of the current process, using
    a queue handler/listener pair from :mod:`logging.handlers`.
    """

    # When this class is imported in pytest files, prevent pytest from thinking
    # this is a test class
    __test__ = False

    # Queue carrying log records from other processes; only created for
    # the prefork pool.
    logger_queue = None

    def __init__(self, *args, **kwargs):
        # type: (*Any, **Any) -> None
        """Initialize the worker and, for prefork, the log relay."""
        # Created before the base initializer runs so the event exists as
        # soon as any startup callback could fire.
        self._on_started = threading.Event()

        super().__init__(*args, **kwargs)

        if self.pool_cls.__module__.split('.')[-1] == 'prefork':
            from billiard import Queue
            self.logger_queue = Queue()
            self.pid = os.getpid()

            # Optional dependency: installed when available, skipped
            # otherwise.
            try:
                from tblib import pickling_support
                pickling_support.install()
            except ImportError:
                pass

            # collect logs from forked process.
            # XXX: those logs will appear twice in the live log
            self.queue_listener = logging.handlers.QueueListener(
                self.logger_queue, logging.getLogger())
            self.queue_listener.start()

    class QueueHandler(logging.handlers.QueueHandler):
        """Queue handler that tags records and does not swallow errors."""

        def prepare(self, record):
            """Mark the record as queued and return it unchanged otherwise."""
            record.from_queue = True
            # Keep origin record.
            return record

        def handleError(self, record):
            """Re-raise the active exception if ``logging.raiseExceptions``."""
            if logging.raiseExceptions:
                raise

    def start(self):
        """Install the queue handler (if any) and start the worker."""
        if self.logger_queue is not None:
            handler = self.QueueHandler(self.logger_queue)
            # Only forward records that come from another process and have
            # not already been through the queue.
            handler.addFilter(
                lambda r: r.process != self.pid
                and not getattr(r, 'from_queue', False))
            logger = logging.getLogger()
            logger.addHandler(handler)
        return super().start()

    def on_consumer_ready(self, consumer):
        # type: (celery.worker.consumer.Consumer) -> None
        """Callback called when the Consumer blueprint is fully started."""
        self._on_started.set()
        test_worker_started.send(
            sender=self.app, worker=self, consumer=consumer)

    def ensure_started(self):
        # type: () -> None
        """Wait for worker to be fully up and running.

        Warning:
            Worker must be started within a thread for this to work,
            or it will block forever.
        """
        self._on_started.wait()
@contextmanager
def start_worker(
    app,  # type: Celery
    concurrency=1,  # type: int
    pool='solo',  # type: str
    loglevel=WORKER_LOGLEVEL,  # type: Union[str, int]
    logfile=None,  # type: str
    perform_ping_check=True,  # type: bool
    ping_task_timeout=10.0,  # type: float
    shutdown_timeout=10.0,  # type: float
    **kwargs  # type: Any
):
    # type: (...) -> Iterable
    """Run an embedded worker for the duration of the ``with`` block.

    ``test_worker_starting`` is sent first and ``test_worker_stopped`` is
    always sent on exit.  When ``perform_ping_check`` is set, the
    ``celery.ping`` task is sent and must answer ``'pong'`` within
    ``ping_task_timeout`` seconds before the worker is handed out.

    Yields:
        The worker instance produced by ``_start_worker_thread``.
    """
    test_worker_starting.send(sender=app)

    running = None
    try:
        worker_context = _start_worker_thread(
            app,
            concurrency=concurrency,
            pool=pool,
            loglevel=loglevel,
            logfile=logfile,
            perform_ping_check=perform_ping_check,
            shutdown_timeout=shutdown_timeout,
            **kwargs)
        with worker_context as running:
            if perform_ping_check:
                from .tasks import ping
                with allow_join_result():
                    reply = ping.delay().get(timeout=ping_task_timeout)
                    assert reply == 'pong'

            yield running
    finally:
        # ``running`` is still None here if the worker never started.
        test_worker_stopped.send(sender=app, worker=running)
@contextmanager
def _start_worker_thread(app: Celery,
                         concurrency: int = 1,
                         pool: str = 'solo',
                         loglevel: Union[str, int] = WORKER_LOGLEVEL,
                         logfile: Optional[str] = None,
                         WorkController: Any = TestWorkController,
                         perform_ping_check: bool = True,
                         shutdown_timeout: float = 10.0,
                         **kwargs) -> Iterable[worker.WorkController]:
    """Run a worker in a daemon thread for the duration of the context.

    The app is prepared with ``setup_app_for_worker``, a broker connection
    is opened once, and the worker's ``start`` runs in a new thread.  The
    context is entered only after ``ensure_started`` has returned.  On
    exit ``state.should_terminate`` is set to ``0`` and the thread is
    joined for at most ``shutdown_timeout`` seconds.

    Yields:
        The ``WorkController`` instance that was started.

    Raises:
        RuntimeError: if the thread is still alive after the join timeout.
    """
    setup_app_for_worker(app, loglevel, logfile)
    if perform_ping_check:
        assert 'celery.ping' in app.tasks
    # Make sure we can connect to the broker
    with app.connection(hostname=os.environ.get('TEST_BROKER')) as conn:
        # NOTE(review): the attribute is only looked up, not called --
        # presumably touching ``default_channel`` is what opens the
        # connection; confirm.
        conn.default_channel.queue_declare

    hostname = kwargs.pop("hostname", anon_nodename())
    without_heartbeat = kwargs.pop("without_heartbeat", True)
    controller = WorkController(
        app=app,
        concurrency=concurrency,
        hostname=hostname,
        pool=pool,
        loglevel=loglevel,
        logfile=logfile,
        # not allowed to override TestWorkController.on_consumer_ready
        ready_callback=None,
        without_heartbeat=without_heartbeat,
        without_mingle=True,
        without_gossip=True,
        **kwargs)

    thread = threading.Thread(target=controller.start, daemon=True)
    thread.start()
    controller.ensure_started()
    _set_task_join_will_block(False)

    try:
        yield controller
    finally:
        from celery.worker import state
        state.should_terminate = 0
        thread.join(shutdown_timeout)
        if thread.is_alive():
            raise RuntimeError(
                "Worker thread failed to exit within the allocated timeout. "
                "Consider raising `shutdown_timeout` if your tasks take longer "
                "to execute."
            )
        state.should_terminate = None
@contextmanager
def _start_worker_process(app,
                          concurrency=1,
                          pool='solo',
                          loglevel=WORKER_LOGLEVEL,
                          logfile=None,
                          **kwargs):
    # type: (Celery, int, str, Union[int, str], str, **Any) -> Iterable
    """Run a one-node worker cluster for the duration of the context.

    ``app`` is made the current app, a cluster with the single node
    ``testworker1@%h`` is started, and ``stopwait`` is called on it when
    the context exits.  The ``concurrency``, ``pool``, ``loglevel``,
    ``logfile`` and extra keyword arguments are accepted but not used here.

    Yields:
        None.
    """
    from celery.apps.multi import Cluster, Node

    app.set_current()
    node = Node('testworker1@%h')
    cluster = Cluster([node])
    cluster.start()
    try:
        yield
    finally:
        cluster.stopwait()
def setup_app_for_worker(app: Celery, loglevel: Union[str, int], logfile: str) -> None:
    """Prepare ``app`` so an embedded worker can be started on it.

    The app is finalized and made both the current and the default app;
    then logging is set up again with the given level and file.
    """
    app.finalize()
    app.set_current()
    app.set_default()
    # Clear the class-level ``_setup`` flag before calling ``setup``.
    log_cls = type(app.log)
    log_cls._setup = False
    app.log.setup(loglevel=loglevel, logfile=logfile)