Updates
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,112 @@
|
||||
"""Create Celery app instances used for testing."""
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
|
||||
from kombu.utils.imports import symbol_by_name
|
||||
|
||||
from celery import Celery, _state
|
||||
|
||||
#: Contains the default configuration values for the test app.
|
||||
DEFAULT_TEST_CONFIG = {
|
||||
'worker_hijack_root_logger': False,
|
||||
'worker_log_color': False,
|
||||
'accept_content': {'json'},
|
||||
'enable_utc': True,
|
||||
'timezone': 'UTC',
|
||||
'broker_url': 'memory://',
|
||||
'result_backend': 'cache+memory://',
|
||||
'broker_heartbeat': 0,
|
||||
}
|
||||
|
||||
|
||||
class Trap:
|
||||
"""Trap that pretends to be an app but raises an exception instead.
|
||||
|
||||
This to protect from code that does not properly pass app instances,
|
||||
then falls back to the current_app.
|
||||
"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Workaround to allow unittest.mock to patch this object
|
||||
# in Python 3.8 and above.
|
||||
if name == '_is_coroutine' or name == '__func__':
|
||||
return None
|
||||
print(name)
|
||||
raise RuntimeError('Test depends on current_app')
|
||||
|
||||
|
||||
class UnitLogging(symbol_by_name(Celery.log_cls)):
|
||||
"""Sets up logging for the test application."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.already_setup = True
|
||||
|
||||
|
||||
def TestApp(name=None, config=None, enable_logging=False, set_as_current=False,
|
||||
log=UnitLogging, backend=None, broker=None, **kwargs):
|
||||
"""App used for testing."""
|
||||
from . import tasks # noqa
|
||||
config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {})
|
||||
if broker is not None:
|
||||
config.pop('broker_url', None)
|
||||
if backend is not None:
|
||||
config.pop('result_backend', None)
|
||||
log = None if enable_logging else log
|
||||
test_app = Celery(
|
||||
name or 'celery.tests',
|
||||
set_as_current=set_as_current,
|
||||
log=log,
|
||||
broker=broker,
|
||||
backend=backend,
|
||||
**kwargs)
|
||||
test_app.add_defaults(config)
|
||||
return test_app
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_trap(app):
|
||||
"""Contextmanager that installs the trap app.
|
||||
|
||||
The trap means that anything trying to use the current or default app
|
||||
will raise an exception.
|
||||
"""
|
||||
trap = Trap()
|
||||
prev_tls = _state._tls
|
||||
_state.set_default_app(trap)
|
||||
|
||||
class NonTLS:
|
||||
current_app = trap
|
||||
_state._tls = NonTLS()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_state._tls = prev_tls
|
||||
|
||||
|
||||
@contextmanager
|
||||
def setup_default_app(app, use_trap=False):
|
||||
"""Setup default app for testing.
|
||||
|
||||
Ensures state is clean after the test returns.
|
||||
"""
|
||||
prev_current_app = _state.get_current_app()
|
||||
prev_default_app = _state.default_app
|
||||
prev_finalizers = set(_state._on_app_finalizers)
|
||||
prev_apps = weakref.WeakSet(_state._apps)
|
||||
|
||||
try:
|
||||
if use_trap:
|
||||
with set_trap(app):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
finally:
|
||||
_state.set_default_app(prev_default_app)
|
||||
_state._tls.current_app = prev_current_app
|
||||
if app is not prev_current_app:
|
||||
app.close()
|
||||
_state._on_app_finalizers = prev_finalizers
|
||||
_state._apps = prev_apps
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Integration testing utilities."""
|
||||
import socket
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from itertools import count
|
||||
from typing import Any, Callable, Dict, Sequence, TextIO, Tuple # noqa
|
||||
|
||||
from kombu.exceptions import ContentDisallowed
|
||||
from kombu.utils.functional import retry_over_time
|
||||
|
||||
from celery import states
|
||||
from celery.exceptions import TimeoutError
|
||||
from celery.result import AsyncResult, ResultSet # noqa
|
||||
from celery.utils.text import truncate
|
||||
from celery.utils.time import humanize_seconds as _humanize_seconds
|
||||
|
||||
E_STILL_WAITING = 'Still waiting for {0}. Trying again {when}: {exc!r}'
|
||||
|
||||
humanize_seconds = partial(_humanize_seconds, microseconds=True)
|
||||
|
||||
|
||||
class Sentinel(Exception):
|
||||
"""Signifies the end of something."""
|
||||
|
||||
|
||||
class ManagerMixin:
|
||||
"""Mixin that adds :class:`Manager` capabilities."""
|
||||
|
||||
def _init_manager(self,
|
||||
block_timeout=30 * 60.0, no_join=False,
|
||||
stdout=None, stderr=None):
|
||||
# type: (float, bool, TextIO, TextIO) -> None
|
||||
self.stdout = sys.stdout if stdout is None else stdout
|
||||
self.stderr = sys.stderr if stderr is None else stderr
|
||||
self.connerrors = self.app.connection().recoverable_connection_errors
|
||||
self.block_timeout = block_timeout
|
||||
self.no_join = no_join
|
||||
|
||||
def remark(self, s, sep='-'):
|
||||
# type: (str, str) -> None
|
||||
print(f'{sep}{s}', file=self.stdout)
|
||||
|
||||
def missing_results(self, r):
|
||||
# type: (Sequence[AsyncResult]) -> Sequence[str]
|
||||
return [res.id for res in r if res.id not in res.backend._cache]
|
||||
|
||||
def wait_for(
|
||||
self,
|
||||
fun, # type: Callable
|
||||
catch, # type: Sequence[Any]
|
||||
desc="thing", # type: str
|
||||
args=(), # type: Tuple
|
||||
kwargs=None, # type: Dict
|
||||
errback=None, # type: Callable
|
||||
max_retries=10, # type: int
|
||||
interval_start=0.1, # type: float
|
||||
interval_step=0.5, # type: float
|
||||
interval_max=5.0, # type: float
|
||||
emit_warning=False, # type: bool
|
||||
**options # type: Any
|
||||
):
|
||||
# type: (...) -> Any
|
||||
"""Wait for event to happen.
|
||||
|
||||
The `catch` argument specifies the exception that means the event
|
||||
has not happened yet.
|
||||
"""
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
|
||||
def on_error(exc, intervals, retries):
|
||||
interval = next(intervals)
|
||||
if emit_warning:
|
||||
self.warn(E_STILL_WAITING.format(
|
||||
desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
|
||||
))
|
||||
if errback:
|
||||
errback(exc, interval, retries)
|
||||
return interval
|
||||
|
||||
return self.retry_over_time(
|
||||
fun, catch,
|
||||
args=args, kwargs=kwargs,
|
||||
errback=on_error, max_retries=max_retries,
|
||||
interval_start=interval_start, interval_step=interval_step,
|
||||
**options
|
||||
)
|
||||
|
||||
def ensure_not_for_a_while(self, fun, catch,
|
||||
desc='thing', max_retries=20,
|
||||
interval_start=0.1, interval_step=0.02,
|
||||
interval_max=1.0, emit_warning=False,
|
||||
**options):
|
||||
"""Make sure something does not happen (at least for a while)."""
|
||||
try:
|
||||
return self.wait_for(
|
||||
fun, catch, desc=desc, max_retries=max_retries,
|
||||
interval_start=interval_start, interval_step=interval_step,
|
||||
interval_max=interval_max, emit_warning=emit_warning,
|
||||
)
|
||||
except catch:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(f'Should not have happened: {desc}')
|
||||
|
||||
def retry_over_time(self, *args, **kwargs):
|
||||
return retry_over_time(*args, **kwargs)
|
||||
|
||||
def join(self, r, propagate=False, max_retries=10, **kwargs):
|
||||
if self.no_join:
|
||||
return
|
||||
if not isinstance(r, ResultSet):
|
||||
r = self.app.ResultSet([r])
|
||||
received = []
|
||||
|
||||
def on_result(task_id, value):
|
||||
received.append(task_id)
|
||||
|
||||
for i in range(max_retries) if max_retries else count(0):
|
||||
received[:] = []
|
||||
try:
|
||||
return r.get(callback=on_result, propagate=propagate, **kwargs)
|
||||
except (socket.timeout, TimeoutError) as exc:
|
||||
waiting_for = self.missing_results(r)
|
||||
self.remark(
|
||||
'Still waiting for {}/{}: [{}]: {!r}'.format(
|
||||
len(r) - len(received), len(r),
|
||||
truncate(', '.join(waiting_for)), exc), '!',
|
||||
)
|
||||
except self.connerrors as exc:
|
||||
self.remark(f'join: connection lost: {exc!r}', '!')
|
||||
raise AssertionError('Test failed: Missing task results')
|
||||
|
||||
def inspect(self, timeout=3.0):
|
||||
return self.app.control.inspect(timeout=timeout)
|
||||
|
||||
def query_tasks(self, ids, timeout=0.5):
|
||||
tasks = self.inspect(timeout).query_task(*ids) or {}
|
||||
yield from tasks.items()
|
||||
|
||||
def query_task_states(self, ids, timeout=0.5):
|
||||
states = defaultdict(set)
|
||||
for hostname, reply in self.query_tasks(ids, timeout=timeout):
|
||||
for task_id, (state, _) in reply.items():
|
||||
states[state].add(task_id)
|
||||
return states
|
||||
|
||||
def assert_accepted(self, ids, interval=0.5,
|
||||
desc='waiting for tasks to be accepted', **policy):
|
||||
return self.assert_task_worker_state(
|
||||
self.is_accepted, ids, interval=interval, desc=desc, **policy
|
||||
)
|
||||
|
||||
def assert_received(self, ids, interval=0.5,
|
||||
desc='waiting for tasks to be received', **policy):
|
||||
return self.assert_task_worker_state(
|
||||
self.is_received, ids, interval=interval, desc=desc, **policy
|
||||
)
|
||||
|
||||
def assert_result_tasks_in_progress_or_completed(
|
||||
self,
|
||||
async_results,
|
||||
interval=0.5,
|
||||
desc='waiting for tasks to be started or completed',
|
||||
**policy
|
||||
):
|
||||
return self.assert_task_state_from_result(
|
||||
self.is_result_task_in_progress,
|
||||
async_results,
|
||||
interval=interval, desc=desc, **policy
|
||||
)
|
||||
|
||||
def assert_task_state_from_result(self, fun, results,
|
||||
interval=0.5, **policy):
|
||||
return self.wait_for(
|
||||
partial(self.true_or_raise, fun, results, timeout=interval),
|
||||
(Sentinel,), **policy
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_result_task_in_progress(results, **kwargs):
|
||||
possible_states = (states.STARTED, states.SUCCESS)
|
||||
return all(result.state in possible_states for result in results)
|
||||
|
||||
def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
|
||||
return self.wait_for(
|
||||
partial(self.true_or_raise, fun, ids, timeout=interval),
|
||||
(Sentinel,), **policy
|
||||
)
|
||||
|
||||
def is_received(self, ids, **kwargs):
|
||||
return self._ids_matches_state(
|
||||
['reserved', 'active', 'ready'], ids, **kwargs)
|
||||
|
||||
def is_accepted(self, ids, **kwargs):
|
||||
return self._ids_matches_state(['active', 'ready'], ids, **kwargs)
|
||||
|
||||
def _ids_matches_state(self, expected_states, ids, timeout=0.5):
|
||||
states = self.query_task_states(ids, timeout=timeout)
|
||||
return all(
|
||||
any(t in s for s in [states[k] for k in expected_states])
|
||||
for t in ids
|
||||
)
|
||||
|
||||
def true_or_raise(self, fun, *args, **kwargs):
|
||||
res = fun(*args, **kwargs)
|
||||
if not res:
|
||||
raise Sentinel()
|
||||
return res
|
||||
|
||||
def wait_until_idle(self):
|
||||
control = self.app.control
|
||||
with self.app.connection() as connection:
|
||||
# Try to purge the queue before we start
|
||||
# to attempt to avoid interference from other tests
|
||||
while True:
|
||||
count = control.purge(connection=connection)
|
||||
if count == 0:
|
||||
break
|
||||
|
||||
# Wait until worker is idle
|
||||
inspect = control.inspect()
|
||||
inspect.connection = connection
|
||||
while True:
|
||||
try:
|
||||
count = sum(len(t) for t in inspect.active().values())
|
||||
except ContentDisallowed:
|
||||
# test_security_task_done may trigger this exception
|
||||
break
|
||||
if count == 0:
|
||||
break
|
||||
|
||||
|
||||
class Manager(ManagerMixin):
|
||||
"""Test helpers for task integration tests."""
|
||||
|
||||
def __init__(self, app, **kwargs):
|
||||
self.app = app
|
||||
self._init_manager(**kwargs)
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Useful mocks for unit testing."""
|
||||
import numbers
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Mapping, Sequence # noqa
|
||||
from unittest.mock import Mock
|
||||
|
||||
from celery import Celery # noqa
|
||||
from celery.canvas import Signature # noqa
|
||||
|
||||
|
||||
def TaskMessage(
|
||||
name, # type: str
|
||||
id=None, # type: str
|
||||
args=(), # type: Sequence
|
||||
kwargs=None, # type: Mapping
|
||||
callbacks=None, # type: Sequence[Signature]
|
||||
errbacks=None, # type: Sequence[Signature]
|
||||
chain=None, # type: Sequence[Signature]
|
||||
shadow=None, # type: str
|
||||
utc=None, # type: bool
|
||||
**options # type: Any
|
||||
):
|
||||
# type: (...) -> Any
|
||||
"""Create task message in protocol 2 format."""
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
from kombu.serialization import dumps
|
||||
|
||||
from celery import uuid
|
||||
id = id or uuid()
|
||||
message = Mock(name=f'TaskMessage-{id}')
|
||||
message.headers = {
|
||||
'id': id,
|
||||
'task': name,
|
||||
'shadow': shadow,
|
||||
}
|
||||
embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
|
||||
message.headers.update(options)
|
||||
message.content_type, message.content_encoding, message.body = dumps(
|
||||
(args, kwargs, embed), serializer='json',
|
||||
)
|
||||
message.payload = (args, kwargs, embed)
|
||||
return message
|
||||
|
||||
|
||||
def TaskMessage1(
|
||||
name, # type: str
|
||||
id=None, # type: str
|
||||
args=(), # type: Sequence
|
||||
kwargs=None, # type: Mapping
|
||||
callbacks=None, # type: Sequence[Signature]
|
||||
errbacks=None, # type: Sequence[Signature]
|
||||
chain=None, # type: Sequence[Signature]
|
||||
**options # type: Any
|
||||
):
|
||||
# type: (...) -> Any
|
||||
"""Create task message in protocol 1 format."""
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
from kombu.serialization import dumps
|
||||
|
||||
from celery import uuid
|
||||
id = id or uuid()
|
||||
message = Mock(name=f'TaskMessage-{id}')
|
||||
message.headers = {}
|
||||
message.payload = {
|
||||
'task': name,
|
||||
'id': id,
|
||||
'args': args,
|
||||
'kwargs': kwargs,
|
||||
'callbacks': callbacks,
|
||||
'errbacks': errbacks,
|
||||
}
|
||||
message.payload.update(options)
|
||||
message.content_type, message.content_encoding, message.body = dumps(
|
||||
message.payload,
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
|
||||
# type: (Celery, Signature, bool, Any) -> Any
|
||||
"""Create task message from :class:`celery.Signature`.
|
||||
|
||||
Example:
|
||||
>>> m = task_message_from_sig(app, add.s(2, 2))
|
||||
>>> amqp_client.basic_publish(m, exchange='ex', routing_key='rkey')
|
||||
"""
|
||||
sig.freeze()
|
||||
callbacks = sig.options.pop('link', None)
|
||||
errbacks = sig.options.pop('link_error', None)
|
||||
countdown = sig.options.pop('countdown', None)
|
||||
if countdown:
|
||||
eta = app.now() + timedelta(seconds=countdown)
|
||||
else:
|
||||
eta = sig.options.pop('eta', None)
|
||||
if eta and isinstance(eta, datetime):
|
||||
eta = eta.isoformat()
|
||||
expires = sig.options.pop('expires', None)
|
||||
if expires and isinstance(expires, numbers.Real):
|
||||
expires = app.now() + timedelta(seconds=expires)
|
||||
if expires and isinstance(expires, datetime):
|
||||
expires = expires.isoformat()
|
||||
return TaskMessage(
|
||||
sig.task, id=sig.id, args=sig.args,
|
||||
kwargs=sig.kwargs,
|
||||
callbacks=[dict(s) for s in callbacks] if callbacks else None,
|
||||
errbacks=[dict(s) for s in errbacks] if errbacks else None,
|
||||
eta=eta,
|
||||
expires=expires,
|
||||
utc=utc,
|
||||
**sig.options
|
||||
)
|
||||
|
||||
|
||||
class _ContextMock(Mock):
|
||||
"""Dummy class implementing __enter__ and __exit__.
|
||||
|
||||
The :keyword:`with` statement requires these to be implemented
|
||||
in the class, not just the instance.
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
pass
|
||||
|
||||
|
||||
def ContextMock(*args, **kwargs):
|
||||
"""Mock that mocks :keyword:`with` statement contexts."""
|
||||
obj = _ContextMock(*args, **kwargs)
|
||||
obj.attach_mock(_ContextMock(), '__enter__')
|
||||
obj.attach_mock(_ContextMock(), '__exit__')
|
||||
obj.__enter__.return_value = obj
|
||||
# if __exit__ return a value the exception is ignored,
|
||||
# so it must return None here.
|
||||
obj.__exit__.return_value = None
|
||||
return obj
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Helper tasks for integration tests."""
|
||||
from celery import shared_task
|
||||
|
||||
|
||||
@shared_task(name='celery.ping')
|
||||
def ping():
|
||||
# type: () -> str
|
||||
"""Simple task that just returns 'pong'."""
|
||||
return 'pong'
|
||||
@@ -0,0 +1,223 @@
|
||||
"""Embedded workers for integration tests."""
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
import celery.worker.consumer # noqa
|
||||
from celery import Celery, worker
|
||||
from celery.result import _set_task_join_will_block, allow_join_result
|
||||
from celery.utils.dispatch import Signal
|
||||
from celery.utils.nodenames import anon_nodename
|
||||
|
||||
WORKER_LOGLEVEL = os.environ.get('WORKER_LOGLEVEL', 'error')
|
||||
|
||||
test_worker_starting = Signal(
|
||||
name='test_worker_starting',
|
||||
providing_args={},
|
||||
)
|
||||
test_worker_started = Signal(
|
||||
name='test_worker_started',
|
||||
providing_args={'worker', 'consumer'},
|
||||
)
|
||||
test_worker_stopped = Signal(
|
||||
name='test_worker_stopped',
|
||||
providing_args={'worker'},
|
||||
)
|
||||
|
||||
|
||||
class TestWorkController(worker.WorkController):
|
||||
"""Worker that can synchronize on being fully started."""
|
||||
|
||||
# When this class is imported in pytest files, prevent pytest from thinking
|
||||
# this is a test class
|
||||
__test__ = False
|
||||
|
||||
logger_queue = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# type: (*Any, **Any) -> None
|
||||
self._on_started = threading.Event()
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.pool_cls.__module__.split('.')[-1] == 'prefork':
|
||||
from billiard import Queue
|
||||
self.logger_queue = Queue()
|
||||
self.pid = os.getpid()
|
||||
|
||||
try:
|
||||
from tblib import pickling_support
|
||||
pickling_support.install()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# collect logs from forked process.
|
||||
# XXX: those logs will appear twice in the live log
|
||||
self.queue_listener = logging.handlers.QueueListener(self.logger_queue, logging.getLogger())
|
||||
self.queue_listener.start()
|
||||
|
||||
class QueueHandler(logging.handlers.QueueHandler):
|
||||
def prepare(self, record):
|
||||
record.from_queue = True
|
||||
# Keep origin record.
|
||||
return record
|
||||
|
||||
def handleError(self, record):
|
||||
if logging.raiseExceptions:
|
||||
raise
|
||||
|
||||
def start(self):
|
||||
if self.logger_queue:
|
||||
handler = self.QueueHandler(self.logger_queue)
|
||||
handler.addFilter(lambda r: r.process != self.pid and not getattr(r, 'from_queue', False))
|
||||
logger = logging.getLogger()
|
||||
logger.addHandler(handler)
|
||||
return super().start()
|
||||
|
||||
def on_consumer_ready(self, consumer):
|
||||
# type: (celery.worker.consumer.Consumer) -> None
|
||||
"""Callback called when the Consumer blueprint is fully started."""
|
||||
self._on_started.set()
|
||||
test_worker_started.send(
|
||||
sender=self.app, worker=self, consumer=consumer)
|
||||
|
||||
def ensure_started(self):
|
||||
# type: () -> None
|
||||
"""Wait for worker to be fully up and running.
|
||||
|
||||
Warning:
|
||||
Worker must be started within a thread for this to work,
|
||||
or it will block forever.
|
||||
"""
|
||||
self._on_started.wait()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def start_worker(
|
||||
app, # type: Celery
|
||||
concurrency=1, # type: int
|
||||
pool='solo', # type: str
|
||||
loglevel=WORKER_LOGLEVEL, # type: Union[str, int]
|
||||
logfile=None, # type: str
|
||||
perform_ping_check=True, # type: bool
|
||||
ping_task_timeout=10.0, # type: float
|
||||
shutdown_timeout=10.0, # type: float
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> Iterable
|
||||
"""Start embedded worker.
|
||||
|
||||
Yields:
|
||||
celery.app.worker.Worker: worker instance.
|
||||
"""
|
||||
test_worker_starting.send(sender=app)
|
||||
|
||||
worker = None
|
||||
try:
|
||||
with _start_worker_thread(app,
|
||||
concurrency=concurrency,
|
||||
pool=pool,
|
||||
loglevel=loglevel,
|
||||
logfile=logfile,
|
||||
perform_ping_check=perform_ping_check,
|
||||
shutdown_timeout=shutdown_timeout,
|
||||
**kwargs) as worker:
|
||||
if perform_ping_check:
|
||||
from .tasks import ping
|
||||
with allow_join_result():
|
||||
assert ping.delay().get(timeout=ping_task_timeout) == 'pong'
|
||||
|
||||
yield worker
|
||||
finally:
|
||||
test_worker_stopped.send(sender=app, worker=worker)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _start_worker_thread(app: Celery,
|
||||
concurrency: int = 1,
|
||||
pool: str = 'solo',
|
||||
loglevel: Union[str, int] = WORKER_LOGLEVEL,
|
||||
logfile: Optional[str] = None,
|
||||
WorkController: Any = TestWorkController,
|
||||
perform_ping_check: bool = True,
|
||||
shutdown_timeout: float = 10.0,
|
||||
**kwargs) -> Iterable[worker.WorkController]:
|
||||
"""Start Celery worker in a thread.
|
||||
|
||||
Yields:
|
||||
celery.worker.Worker: worker instance.
|
||||
"""
|
||||
setup_app_for_worker(app, loglevel, logfile)
|
||||
if perform_ping_check:
|
||||
assert 'celery.ping' in app.tasks
|
||||
# Make sure we can connect to the broker
|
||||
with app.connection(hostname=os.environ.get('TEST_BROKER')) as conn:
|
||||
conn.default_channel.queue_declare
|
||||
|
||||
worker = WorkController(
|
||||
app=app,
|
||||
concurrency=concurrency,
|
||||
hostname=kwargs.pop("hostname", anon_nodename()),
|
||||
pool=pool,
|
||||
loglevel=loglevel,
|
||||
logfile=logfile,
|
||||
# not allowed to override TestWorkController.on_consumer_ready
|
||||
ready_callback=None,
|
||||
without_heartbeat=kwargs.pop("without_heartbeat", True),
|
||||
without_mingle=True,
|
||||
without_gossip=True,
|
||||
**kwargs)
|
||||
|
||||
t = threading.Thread(target=worker.start, daemon=True)
|
||||
t.start()
|
||||
worker.ensure_started()
|
||||
_set_task_join_will_block(False)
|
||||
|
||||
try:
|
||||
yield worker
|
||||
finally:
|
||||
from celery.worker import state
|
||||
state.should_terminate = 0
|
||||
t.join(shutdown_timeout)
|
||||
if t.is_alive():
|
||||
raise RuntimeError(
|
||||
"Worker thread failed to exit within the allocated timeout. "
|
||||
"Consider raising `shutdown_timeout` if your tasks take longer "
|
||||
"to execute."
|
||||
)
|
||||
state.should_terminate = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _start_worker_process(app,
|
||||
concurrency=1,
|
||||
pool='solo',
|
||||
loglevel=WORKER_LOGLEVEL,
|
||||
logfile=None,
|
||||
**kwargs):
|
||||
# type (Celery, int, str, Union[int, str], str, **Any) -> Iterable
|
||||
"""Start worker in separate process.
|
||||
|
||||
Yields:
|
||||
celery.app.worker.Worker: worker instance.
|
||||
"""
|
||||
from celery.apps.multi import Cluster, Node
|
||||
|
||||
app.set_current()
|
||||
cluster = Cluster([Node('testworker1@%h')])
|
||||
cluster.start()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
cluster.stopwait()
|
||||
|
||||
|
||||
def setup_app_for_worker(app: Celery, loglevel: Union[str, int], logfile: str) -> None:
|
||||
"""Setup the app to be used for starting an embedded worker."""
|
||||
app.finalize()
|
||||
app.set_current()
|
||||
app.set_default()
|
||||
type(app.log)._setup = False
|
||||
app.log.setup(loglevel=loglevel, logfile=logfile)
|
||||
Reference in New Issue
Block a user