commit c67067a2a4
Author: Iliyan Angelov
Date: 2025-09-14 23:24:25 +03:00
71311 changed files with 6800714 additions and 0 deletions


@@ -0,0 +1,112 @@
"""Create Celery app instances used for testing."""
import weakref
from contextlib import contextmanager
from copy import deepcopy
from kombu.utils.imports import symbol_by_name
from celery import Celery, _state
#: Contains the default configuration values for the test app.
DEFAULT_TEST_CONFIG = {
'worker_hijack_root_logger': False,
'worker_log_color': False,
'accept_content': {'json'},
'enable_utc': True,
'timezone': 'UTC',
'broker_url': 'memory://',
'result_backend': 'cache+memory://',
'broker_heartbeat': 0,
}
class Trap:
"""Trap that pretends to be an app but raises an exception instead.
    This is to protect against code that does not properly pass app
    instances and instead falls back to the current_app.
"""
def __getattr__(self, name):
# Workaround to allow unittest.mock to patch this object
# in Python 3.8 and above.
if name == '_is_coroutine' or name == '__func__':
return None
        print(name)  # surface the offending attribute name in test output
raise RuntimeError('Test depends on current_app')
class UnitLogging(symbol_by_name(Celery.log_cls)):
"""Sets up logging for the test application."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.already_setup = True
def TestApp(name=None, config=None, enable_logging=False, set_as_current=False,
log=UnitLogging, backend=None, broker=None, **kwargs):
"""App used for testing."""
from . import tasks # noqa
config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {})
if broker is not None:
config.pop('broker_url', None)
if backend is not None:
config.pop('result_backend', None)
log = None if enable_logging else log
test_app = Celery(
name or 'celery.tests',
set_as_current=set_as_current,
log=log,
broker=broker,
backend=backend,
**kwargs)
test_app.add_defaults(config)
return test_app
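
As a usage sketch (the app name and config key here are illustrative): TestApp merges any passed config over DEFAULT_TEST_CONFIG, so explicit overrides and the defaults coexist:

app = TestApp('my.tests', config={'task_always_eager': True})
assert app.conf.broker_url == 'memory://'   # default kept
assert app.conf.task_always_eager           # override applied
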
@contextmanager
def set_trap(app):
"""Contextmanager that installs the trap app.
The trap means that anything trying to use the current or default app
will raise an exception.
"""
trap = Trap()
prev_tls = _state._tls
_state.set_default_app(trap)
class NonTLS:
current_app = trap
_state._tls = NonTLS()
try:
yield
finally:
_state._tls = prev_tls
@contextmanager
def setup_default_app(app, use_trap=False):
"""Setup default app for testing.
Ensures state is clean after the test returns.
"""
prev_current_app = _state.get_current_app()
prev_default_app = _state.default_app
prev_finalizers = set(_state._on_app_finalizers)
prev_apps = weakref.WeakSet(_state._apps)
try:
if use_trap:
with set_trap(app):
yield
else:
yield
finally:
_state.set_default_app(prev_default_app)
_state._tls.current_app = prev_current_app
if app is not prev_current_app:
app.close()
_state._on_app_finalizers = prev_finalizers
_state._apps = prev_apps
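
A minimal sketch of how these context managers compose in a test (the test function itself is hypothetical):

def test_code_under_test_gets_app_explicitly():
    app = TestApp(set_as_current=False)
    with setup_default_app(app, use_trap=True):
        # Anything that falls back to current_app now hits Trap and
        # raises RuntimeError; code must receive `app` explicitly.
        assert app.conf.timezone == 'UTC'
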


@@ -0,0 +1,239 @@
"""Integration testing utilities."""
import socket
import sys
from collections import defaultdict
from functools import partial
from itertools import count
from typing import Any, Callable, Dict, Sequence, TextIO, Tuple # noqa
from kombu.exceptions import ContentDisallowed
from kombu.utils.functional import retry_over_time
from celery import states
from celery.exceptions import TimeoutError
from celery.result import AsyncResult, ResultSet # noqa
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds as _humanize_seconds
E_STILL_WAITING = 'Still waiting for {0}. Trying again {when}: {exc!r}'
humanize_seconds = partial(_humanize_seconds, microseconds=True)
class Sentinel(Exception):
"""Signifies the end of something."""
class ManagerMixin:
"""Mixin that adds :class:`Manager` capabilities."""
def _init_manager(self,
block_timeout=30 * 60.0, no_join=False,
stdout=None, stderr=None):
# type: (float, bool, TextIO, TextIO) -> None
self.stdout = sys.stdout if stdout is None else stdout
self.stderr = sys.stderr if stderr is None else stderr
self.connerrors = self.app.connection().recoverable_connection_errors
self.block_timeout = block_timeout
self.no_join = no_join
def remark(self, s, sep='-'):
# type: (str, str) -> None
print(f'{sep}{s}', file=self.stdout)
def missing_results(self, r):
# type: (Sequence[AsyncResult]) -> Sequence[str]
return [res.id for res in r if res.id not in res.backend._cache]
def wait_for(
self,
fun, # type: Callable
catch, # type: Sequence[Any]
desc="thing", # type: str
args=(), # type: Tuple
kwargs=None, # type: Dict
errback=None, # type: Callable
max_retries=10, # type: int
interval_start=0.1, # type: float
interval_step=0.5, # type: float
interval_max=5.0, # type: float
emit_warning=False, # type: bool
**options # type: Any
):
# type: (...) -> Any
"""Wait for event to happen.
The `catch` argument specifies the exception that means the event
has not happened yet.
"""
kwargs = {} if not kwargs else kwargs
def on_error(exc, intervals, retries):
interval = next(intervals)
if emit_warning:
self.warn(E_STILL_WAITING.format(
desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
))
if errback:
errback(exc, interval, retries)
return interval
        return self.retry_over_time(
            fun, catch,
            args=args, kwargs=kwargs,
            errback=on_error, max_retries=max_retries,
            interval_start=interval_start, interval_step=interval_step,
            interval_max=interval_max,
            **options
        )
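
For example (a hedged sketch assuming a `manager` instance and a pending AsyncResult named `result`), wait_for keeps calling the function until it stops raising the exceptions listed in `catch`:

value = manager.wait_for(
    lambda: result.get(timeout=1),   # hypothetical pending result
    catch=(TimeoutError,),
    desc='task result',
    max_retries=5,
)
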
def ensure_not_for_a_while(self, fun, catch,
desc='thing', max_retries=20,
interval_start=0.1, interval_step=0.02,
interval_max=1.0, emit_warning=False,
**options):
"""Make sure something does not happen (at least for a while)."""
try:
return self.wait_for(
fun, catch, desc=desc, max_retries=max_retries,
interval_start=interval_start, interval_step=interval_step,
interval_max=interval_max, emit_warning=emit_warning,
)
except catch:
pass
else:
raise AssertionError(f'Should not have happened: {desc}')
def retry_over_time(self, *args, **kwargs):
return retry_over_time(*args, **kwargs)
def join(self, r, propagate=False, max_retries=10, **kwargs):
if self.no_join:
return
if not isinstance(r, ResultSet):
r = self.app.ResultSet([r])
received = []
def on_result(task_id, value):
received.append(task_id)
for i in range(max_retries) if max_retries else count(0):
received[:] = []
try:
return r.get(callback=on_result, propagate=propagate, **kwargs)
except (socket.timeout, TimeoutError) as exc:
waiting_for = self.missing_results(r)
self.remark(
'Still waiting for {}/{}: [{}]: {!r}'.format(
len(r) - len(received), len(r),
truncate(', '.join(waiting_for)), exc), '!',
)
except self.connerrors as exc:
self.remark(f'join: connection lost: {exc!r}', '!')
raise AssertionError('Test failed: Missing task results')
def inspect(self, timeout=3.0):
return self.app.control.inspect(timeout=timeout)
def query_tasks(self, ids, timeout=0.5):
tasks = self.inspect(timeout).query_task(*ids) or {}
yield from tasks.items()
def query_task_states(self, ids, timeout=0.5):
states = defaultdict(set)
for hostname, reply in self.query_tasks(ids, timeout=timeout):
for task_id, (state, _) in reply.items():
states[state].add(task_id)
return states
def assert_accepted(self, ids, interval=0.5,
desc='waiting for tasks to be accepted', **policy):
return self.assert_task_worker_state(
self.is_accepted, ids, interval=interval, desc=desc, **policy
)
def assert_received(self, ids, interval=0.5,
desc='waiting for tasks to be received', **policy):
return self.assert_task_worker_state(
self.is_received, ids, interval=interval, desc=desc, **policy
)
def assert_result_tasks_in_progress_or_completed(
self,
async_results,
interval=0.5,
desc='waiting for tasks to be started or completed',
**policy
):
return self.assert_task_state_from_result(
self.is_result_task_in_progress,
async_results,
interval=interval, desc=desc, **policy
)
def assert_task_state_from_result(self, fun, results,
interval=0.5, **policy):
return self.wait_for(
partial(self.true_or_raise, fun, results, timeout=interval),
(Sentinel,), **policy
)
@staticmethod
def is_result_task_in_progress(results, **kwargs):
possible_states = (states.STARTED, states.SUCCESS)
return all(result.state in possible_states for result in results)
def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
return self.wait_for(
partial(self.true_or_raise, fun, ids, timeout=interval),
(Sentinel,), **policy
)
def is_received(self, ids, **kwargs):
return self._ids_matches_state(
['reserved', 'active', 'ready'], ids, **kwargs)
def is_accepted(self, ids, **kwargs):
return self._ids_matches_state(['active', 'ready'], ids, **kwargs)
def _ids_matches_state(self, expected_states, ids, timeout=0.5):
states = self.query_task_states(ids, timeout=timeout)
return all(
any(t in s for s in [states[k] for k in expected_states])
for t in ids
)
def true_or_raise(self, fun, *args, **kwargs):
res = fun(*args, **kwargs)
if not res:
raise Sentinel()
return res
def wait_until_idle(self):
control = self.app.control
with self.app.connection() as connection:
            # Try to purge the queue before we start,
            # to avoid interference from other tests.
while True:
count = control.purge(connection=connection)
if count == 0:
break
# Wait until worker is idle
inspect = control.inspect()
inspect.connection = connection
while True:
try:
                    # active() may return None if no worker replied in time.
                    count = sum(len(t) for t in (inspect.active() or {}).values())
except ContentDisallowed:
# test_security_task_done may trigger this exception
break
if count == 0:
break
class Manager(ManagerMixin):
"""Test helpers for task integration tests."""
def __init__(self, app, **kwargs):
self.app = app
self._init_manager(**kwargs)
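
Wiring this into an integration test is typically a one-liner; a sketch assuming an existing Celery `app` and a hypothetical task id:

manager = Manager(app, block_timeout=60.0)
manager.assert_accepted(['4f8c1e9a-task-id'])   # hypothetical id
manager.wait_until_idle()
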


@@ -0,0 +1,137 @@
"""Useful mocks for unit testing."""
import numbers
from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence # noqa
from unittest.mock import Mock
from celery import Celery # noqa
from celery.canvas import Signature # noqa
def TaskMessage(
name, # type: str
id=None, # type: str
args=(), # type: Sequence
kwargs=None, # type: Mapping
callbacks=None, # type: Sequence[Signature]
errbacks=None, # type: Sequence[Signature]
chain=None, # type: Sequence[Signature]
shadow=None, # type: str
utc=None, # type: bool
**options # type: Any
):
# type: (...) -> Any
"""Create task message in protocol 2 format."""
kwargs = {} if not kwargs else kwargs
from kombu.serialization import dumps
from celery import uuid
id = id or uuid()
message = Mock(name=f'TaskMessage-{id}')
message.headers = {
'id': id,
'task': name,
'shadow': shadow,
}
embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
message.headers.update(options)
message.content_type, message.content_encoding, message.body = dumps(
(args, kwargs, embed), serializer='json',
)
message.payload = (args, kwargs, embed)
return message
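
For instance, a protocol 2 message for a hypothetical tasks.add task; the body is the JSON-serialized (args, kwargs, embed) triple:

msg = TaskMessage('tasks.add', args=(2, 2))
assert msg.headers['task'] == 'tasks.add'
assert msg.payload[0] == (2, 2)   # args round-trip through the payload
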
def TaskMessage1(
name, # type: str
id=None, # type: str
args=(), # type: Sequence
kwargs=None, # type: Mapping
callbacks=None, # type: Sequence[Signature]
errbacks=None, # type: Sequence[Signature]
chain=None, # type: Sequence[Signature]
**options # type: Any
):
# type: (...) -> Any
"""Create task message in protocol 1 format."""
kwargs = {} if not kwargs else kwargs
from kombu.serialization import dumps
from celery import uuid
id = id or uuid()
message = Mock(name=f'TaskMessage-{id}')
message.headers = {}
message.payload = {
'task': name,
'id': id,
'args': args,
'kwargs': kwargs,
'callbacks': callbacks,
'errbacks': errbacks,
}
message.payload.update(options)
message.content_type, message.content_encoding, message.body = dumps(
message.payload,
)
return message
def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
# type: (Celery, Signature, bool, Any) -> Any
"""Create task message from :class:`celery.Signature`.
Example:
>>> m = task_message_from_sig(app, add.s(2, 2))
>>> amqp_client.basic_publish(m, exchange='ex', routing_key='rkey')
"""
sig.freeze()
callbacks = sig.options.pop('link', None)
errbacks = sig.options.pop('link_error', None)
countdown = sig.options.pop('countdown', None)
if countdown:
eta = app.now() + timedelta(seconds=countdown)
else:
eta = sig.options.pop('eta', None)
if eta and isinstance(eta, datetime):
eta = eta.isoformat()
expires = sig.options.pop('expires', None)
if expires and isinstance(expires, numbers.Real):
expires = app.now() + timedelta(seconds=expires)
if expires and isinstance(expires, datetime):
expires = expires.isoformat()
return TaskMessage(
sig.task, id=sig.id, args=sig.args,
kwargs=sig.kwargs,
callbacks=[dict(s) for s in callbacks] if callbacks else None,
errbacks=[dict(s) for s in errbacks] if errbacks else None,
eta=eta,
expires=expires,
utc=utc,
**sig.options
)
class _ContextMock(Mock):
"""Dummy class implementing __enter__ and __exit__.
The :keyword:`with` statement requires these to be implemented
in the class, not just the instance.
"""
def __enter__(self):
return self
def __exit__(self, *exc_info):
pass
def ContextMock(*args, **kwargs):
"""Mock that mocks :keyword:`with` statement contexts."""
obj = _ContextMock(*args, **kwargs)
obj.attach_mock(_ContextMock(), '__enter__')
obj.attach_mock(_ContextMock(), '__exit__')
obj.__enter__.return_value = obj
    # If __exit__ returns a true value the exception is suppressed,
    # so it must return None here.
obj.__exit__.return_value = None
return obj
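
A short sketch of the resulting mock in use (the name is illustrative); the class-level __enter__ returns the mock itself:

conn = ContextMock(name='connection')
with conn as c:
    assert c is conn   # __enter__ hands back the same mock
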


@@ -0,0 +1,9 @@
"""Helper tasks for integration tests."""
from celery import shared_task
@shared_task(name='celery.ping')
def ping():
# type: () -> str
"""Simple task that just returns 'pong'."""
return 'pong'
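
In an integration test this task doubles as a liveness probe; a hedged sketch, assuming a worker is running against the test app:

result = ping.delay()
assert result.get(timeout=10) == 'pong'
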


@@ -0,0 +1,221 @@
"""Embedded workers for integration tests."""
import logging
import logging.handlers
import os
import threading
from contextlib import contextmanager
from typing import Any, Iterable, Union # noqa
import celery.worker.consumer # noqa
from celery import Celery, worker # noqa
from celery.result import _set_task_join_will_block, allow_join_result
from celery.utils.dispatch import Signal
from celery.utils.nodenames import anon_nodename
WORKER_LOGLEVEL = os.environ.get('WORKER_LOGLEVEL', 'error')
test_worker_starting = Signal(
name='test_worker_starting',
providing_args={},
)
test_worker_started = Signal(
name='test_worker_started',
providing_args={'worker', 'consumer'},
)
test_worker_stopped = Signal(
name='test_worker_stopped',
providing_args={'worker'},
)
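
These signals can be hooked from fixtures or conftest modules; a minimal sketch of a receiver (the handler name and body are illustrative):

def on_worker_started(sender=None, worker=None, consumer=None, **kwargs):
    # Illustrative side effect: record that the embedded worker is up.
    print(f'worker {worker.hostname} is ready')

test_worker_started.connect(on_worker_started)
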
class TestWorkController(worker.WorkController):
"""Worker that can synchronize on being fully started."""
logger_queue = None
def __init__(self, *args, **kwargs):
# type: (*Any, **Any) -> None
self._on_started = threading.Event()
super().__init__(*args, **kwargs)
if self.pool_cls.__module__.split('.')[-1] == 'prefork':
from billiard import Queue
self.logger_queue = Queue()
self.pid = os.getpid()
try:
from tblib import pickling_support
pickling_support.install()
except ImportError:
pass
            # Collect logs from the forked processes.
            # XXX: those logs will appear twice in the live log.
self.queue_listener = logging.handlers.QueueListener(self.logger_queue, logging.getLogger())
self.queue_listener.start()
class QueueHandler(logging.handlers.QueueHandler):
def prepare(self, record):
record.from_queue = True
            # Keep the original record.
return record
def handleError(self, record):
if logging.raiseExceptions:
raise
def start(self):
if self.logger_queue:
handler = self.QueueHandler(self.logger_queue)
handler.addFilter(lambda r: r.process != self.pid and not getattr(r, 'from_queue', False))
logger = logging.getLogger()
logger.addHandler(handler)
return super().start()
def on_consumer_ready(self, consumer):
# type: (celery.worker.consumer.Consumer) -> None
"""Callback called when the Consumer blueprint is fully started."""
self._on_started.set()
test_worker_started.send(
sender=self.app, worker=self, consumer=consumer)
def ensure_started(self):
# type: () -> None
"""Wait for worker to be fully up and running.
Warning:
Worker must be started within a thread for this to work,
or it will block forever.
"""
self._on_started.wait()
@contextmanager
def start_worker(
app, # type: Celery
concurrency=1, # type: int
pool='solo', # type: str
loglevel=WORKER_LOGLEVEL, # type: Union[str, int]
logfile=None, # type: str
perform_ping_check=True, # type: bool
ping_task_timeout=10.0, # type: float
shutdown_timeout=10.0, # type: float
**kwargs # type: Any
):
# type: (...) -> Iterable
"""Start embedded worker.
Yields:
        celery.worker.WorkController: worker instance.
"""
test_worker_starting.send(sender=app)
worker = None
try:
with _start_worker_thread(app,
concurrency=concurrency,
pool=pool,
loglevel=loglevel,
logfile=logfile,
perform_ping_check=perform_ping_check,
shutdown_timeout=shutdown_timeout,
**kwargs) as worker:
if perform_ping_check:
from .tasks import ping
with allow_join_result():
assert ping.delay().get(timeout=ping_task_timeout) == 'pong'
yield worker
finally:
test_worker_stopped.send(sender=app, worker=worker)
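
A common pattern is to wrap start_worker in a pytest fixture; a sketch assuming pytest and a configured `celery_app` fixture (both are assumptions, not part of this file):

import pytest

@pytest.fixture
def embedded_worker(celery_app):
    # perform_ping_check=False avoids requiring the celery.ping task.
    with start_worker(celery_app, perform_ping_check=False) as w:
        yield w
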
@contextmanager
def _start_worker_thread(app,
concurrency=1,
pool='solo',
loglevel=WORKER_LOGLEVEL,
logfile=None,
WorkController=TestWorkController,
perform_ping_check=True,
shutdown_timeout=10.0,
**kwargs):
    # type: (Celery, int, str, Union[str, int], str, Any, bool, float, **Any) -> Iterable
"""Start Celery worker in a thread.
Yields:
        celery.worker.WorkController: worker instance.
"""
setup_app_for_worker(app, loglevel, logfile)
if perform_ping_check:
assert 'celery.ping' in app.tasks
    # Make sure we can connect to the broker.
    with app.connection(hostname=os.environ.get('TEST_BROKER')) as conn:
        # Deliberate attribute access (no call): touching default_channel
        # forces the connection and channel to be established.
        conn.default_channel.queue_declare
worker = WorkController(
app=app,
concurrency=concurrency,
hostname=anon_nodename(),
pool=pool,
loglevel=loglevel,
logfile=logfile,
# not allowed to override TestWorkController.on_consumer_ready
ready_callback=None,
without_heartbeat=kwargs.pop("without_heartbeat", True),
without_mingle=True,
without_gossip=True,
**kwargs)
t = threading.Thread(target=worker.start, daemon=True)
t.start()
worker.ensure_started()
_set_task_join_will_block(False)
try:
yield worker
finally:
from celery.worker import state
state.should_terminate = 0
t.join(shutdown_timeout)
if t.is_alive():
raise RuntimeError(
"Worker thread failed to exit within the allocated timeout. "
"Consider raising `shutdown_timeout` if your tasks take longer "
"to execute."
)
state.should_terminate = None
@contextmanager
def _start_worker_process(app,
concurrency=1,
pool='solo',
loglevel=WORKER_LOGLEVEL,
logfile=None,
**kwargs):
    # type: (Celery, int, str, Union[int, str], str, **Any) -> Iterable
    """Start worker in a separate process.
    Yields:
        None: the worker runs in a separate process managed by a Cluster.
"""
from celery.apps.multi import Cluster, Node
app.set_current()
cluster = Cluster([Node('testworker1@%h')])
cluster.start()
try:
yield
finally:
cluster.stopwait()
def setup_app_for_worker(app, loglevel, logfile) -> None:
# type: (Celery, Union[str, int], str) -> None
"""Setup the app to be used for starting an embedded worker."""
app.finalize()
app.set_current()
app.set_default()
type(app.log)._setup = False
app.log.setup(loglevel=loglevel, logfile=logfile)