Updates
@@ -0,0 +1,4 @@
"""Worker implementation."""
from .worker import WorkController

__all__ = ('WorkController',)
@@ -0,0 +1,154 @@
"""Pool Autoscaling.

This module implements the internal thread responsible
for growing and shrinking the pool according to the
current autoscale settings.

The autoscale thread is only enabled if
the :option:`celery worker --autoscale` option is used.
"""
import os
import threading
from time import monotonic, sleep

from kombu.asynchronous.semaphore import DummyLock

from celery import bootsteps
from celery.utils.log import get_logger
from celery.utils.threads import bgThread

from . import state
from .components import Pool

__all__ = ('Autoscaler', 'WorkerComponent')

logger = get_logger(__name__)
debug, info, error = logger.debug, logger.info, logger.error

AUTOSCALE_KEEPALIVE = float(os.environ.get('AUTOSCALE_KEEPALIVE', 30))


class WorkerComponent(bootsteps.StartStopStep):
    """Bootstep that starts the autoscaler thread/timer in the worker."""

    label = 'Autoscaler'
    conditional = True
    requires = (Pool,)

    def __init__(self, w, **kwargs):
        self.enabled = w.autoscale
        w.autoscaler = None

    def create(self, w):
        scaler = w.autoscaler = self.instantiate(
            w.autoscaler_cls,
            w.pool, w.max_concurrency, w.min_concurrency,
            worker=w, mutex=DummyLock() if w.use_eventloop else None,
        )
        return scaler if not w.use_eventloop else None

    def register_with_event_loop(self, w, hub):
        w.consumer.on_task_message.add(w.autoscaler.maybe_scale)
        hub.call_repeatedly(
            w.autoscaler.keepalive, w.autoscaler.maybe_scale,
        )

    def info(self, w):
        """Return `Autoscaler` info."""
        return {'autoscaler': w.autoscaler.info()}


class Autoscaler(bgThread):
    """Background thread to autoscale pool workers."""

    def __init__(self, pool, max_concurrency,
                 min_concurrency=0, worker=None,
                 keepalive=AUTOSCALE_KEEPALIVE, mutex=None):
        super().__init__()
        self.pool = pool
        self.mutex = mutex or threading.Lock()
        self.max_concurrency = max_concurrency
        self.min_concurrency = min_concurrency
        self.keepalive = keepalive
        self._last_scale_up = None
        self.worker = worker

        assert self.keepalive, 'cannot scale down too fast.'

    def body(self):
        with self.mutex:
            self.maybe_scale()
        sleep(1.0)

    def _maybe_scale(self, req=None):
        procs = self.processes
        cur = min(self.qty, self.max_concurrency)
        if cur > procs:
            self.scale_up(cur - procs)
            return True
        cur = max(self.qty, self.min_concurrency)
        if cur < procs:
            self.scale_down(procs - cur)
            return True

    def maybe_scale(self, req=None):
        if self._maybe_scale(req):
            self.pool.maintain_pool()

    def update(self, max=None, min=None):
        with self.mutex:
            if max is not None:
                if max < self.processes:
                    self._shrink(self.processes - max)
                self._update_consumer_prefetch_count(max)
                self.max_concurrency = max
            if min is not None:
                if min > self.processes:
                    self._grow(min - self.processes)
                self.min_concurrency = min
            return self.max_concurrency, self.min_concurrency

    def scale_up(self, n):
        self._last_scale_up = monotonic()
        return self._grow(n)

    def scale_down(self, n):
        if self._last_scale_up and (
                monotonic() - self._last_scale_up > self.keepalive):
            return self._shrink(n)

    def _grow(self, n):
        info('Scaling up %s processes.', n)
        self.pool.grow(n)

    def _shrink(self, n):
        info('Scaling down %s processes.', n)
        try:
            self.pool.shrink(n)
        except ValueError:
            debug("Autoscaler won't scale down: all processes busy.")
        except Exception as exc:
            error('Autoscaler: scale_down: %r', exc, exc_info=True)

    def _update_consumer_prefetch_count(self, new_max):
        diff = new_max - self.max_concurrency
        if diff:
            self.worker.consumer._update_prefetch_count(
                diff
            )

    def info(self):
        return {
            'max': self.max_concurrency,
            'min': self.min_concurrency,
            'current': self.processes,
            'qty': self.qty,
        }

    @property
    def qty(self):
        return len(state.reserved_requests)

    @property
    def processes(self):
        return self.pool.num_processes
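
The autoscaler keeps the pool size between the configured minimum and maximum, driven by the number of reserved requests. A minimal, self-contained sketch of that decision rule (illustrative only, not part of the commit; `plan_scaling` is a hypothetical helper):

```python
def plan_scaling(reserved, procs, min_c, max_c):
    """Return ('up', n), ('down', n) or None, mirroring Autoscaler._maybe_scale()."""
    target = min(reserved, max_c)      # never exceed the --autoscale max
    if target > procs:
        return 'up', target - procs
    target = max(reserved, min_c)      # never drop below the --autoscale min
    if target < procs:
        return 'down', procs - target
    return None

assert plan_scaling(reserved=8, procs=3, min_c=2, max_c=10) == ('up', 5)
assert plan_scaling(reserved=0, procs=5, min_c=2, max_c=10) == ('down', 3)
```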
@@ -0,0 +1,240 @@
"""Worker-level Bootsteps."""
import atexit
import warnings

from kombu.asynchronous import Hub as _Hub
from kombu.asynchronous import get_event_loop, set_event_loop
from kombu.asynchronous.semaphore import DummyLock, LaxBoundedSemaphore
from kombu.asynchronous.timer import Timer as _Timer

from celery import bootsteps
from celery._state import _set_task_join_will_block
from celery.exceptions import ImproperlyConfigured
from celery.platforms import IS_WINDOWS
from celery.utils.log import worker_logger as logger

__all__ = ('Timer', 'Hub', 'Pool', 'Beat', 'StateDB', 'Consumer')

GREEN_POOLS = {'eventlet', 'gevent'}

ERR_B_GREEN = """\
-B option doesn't work with eventlet/gevent pools: \
use standalone beat instead.\
"""

W_POOL_SETTING = """
The worker_pool setting shouldn't be used to select the eventlet/gevent
pools, instead you *must use the -P* argument so that patches are applied
as early as possible.
"""


class Timer(bootsteps.Step):
    """Timer bootstep."""

    def create(self, w):
        if w.use_eventloop:
            # does not use dedicated timer thread.
            w.timer = _Timer(max_interval=10.0)
        else:
            if not w.timer_cls:
                # Default Timer is set by the pool, as for example, the
                # eventlet pool needs a custom timer implementation.
                w.timer_cls = w.pool_cls.Timer
            w.timer = self.instantiate(w.timer_cls,
                                       max_interval=w.timer_precision,
                                       on_error=self.on_timer_error,
                                       on_tick=self.on_timer_tick)

    def on_timer_error(self, exc):
        logger.error('Timer error: %r', exc, exc_info=True)

    def on_timer_tick(self, delay):
        logger.debug('Timer wake-up! Next ETA %s secs.', delay)


class Hub(bootsteps.StartStopStep):
    """Worker starts the event loop."""

    requires = (Timer,)

    def __init__(self, w, **kwargs):
        w.hub = None
        super().__init__(w, **kwargs)

    def include_if(self, w):
        return w.use_eventloop

    def create(self, w):
        w.hub = get_event_loop()
        if w.hub is None:
            required_hub = getattr(w._conninfo, 'requires_hub', None)
            w.hub = set_event_loop((
                required_hub if required_hub else _Hub)(w.timer))
        self._patch_thread_primitives(w)
        return self

    def start(self, w):
        pass

    def stop(self, w):
        w.hub.close()

    def terminate(self, w):
        w.hub.close()

    def _patch_thread_primitives(self, w):
        # make clock use dummy lock
        w.app.clock.mutex = DummyLock()
        # multiprocessing's ApplyResult uses this lock.
        try:
            from billiard import pool
        except ImportError:
            pass
        else:
            pool.Lock = DummyLock


class Pool(bootsteps.StartStopStep):
    """Bootstep managing the worker pool.

    Describes how to initialize the worker pool, and starts and stops
    the pool during worker start-up/shutdown.

    Adds attributes:

        * autoscale
        * pool
        * max_concurrency
        * min_concurrency
    """

    requires = (Hub,)

    def __init__(self, w, autoscale=None, **kwargs):
        w.pool = None
        w.max_concurrency = None
        w.min_concurrency = w.concurrency
        self.optimization = w.optimization
        if isinstance(autoscale, str):
            max_c, _, min_c = autoscale.partition(',')
            autoscale = [int(max_c), min_c and int(min_c) or 0]
        w.autoscale = autoscale
        if w.autoscale:
            w.max_concurrency, w.min_concurrency = w.autoscale
        super().__init__(w, **kwargs)

    def close(self, w):
        if w.pool:
            w.pool.close()

    def terminate(self, w):
        if w.pool:
            w.pool.terminate()

    def create(self, w):
        semaphore = None
        max_restarts = None
        if w.app.conf.worker_pool in GREEN_POOLS:  # pragma: no cover
            warnings.warn(UserWarning(W_POOL_SETTING))
        threaded = not w.use_eventloop or IS_WINDOWS
        procs = w.min_concurrency
        w.process_task = w._process_task
        if not threaded:
            semaphore = w.semaphore = LaxBoundedSemaphore(procs)
            w._quick_acquire = w.semaphore.acquire
            w._quick_release = w.semaphore.release
            max_restarts = 100
            if w.pool_putlocks and w.pool_cls.uses_semaphore:
                w.process_task = w._process_task_sem
        allow_restart = w.pool_restarts
        pool = w.pool = self.instantiate(
            w.pool_cls, w.min_concurrency,
            initargs=(w.app, w.hostname),
            maxtasksperchild=w.max_tasks_per_child,
            max_memory_per_child=w.max_memory_per_child,
            timeout=w.time_limit,
            soft_timeout=w.soft_time_limit,
            putlocks=w.pool_putlocks and threaded,
            lost_worker_timeout=w.worker_lost_wait,
            threads=threaded,
            max_restarts=max_restarts,
            allow_restart=allow_restart,
            forking_enable=True,
            semaphore=semaphore,
            sched_strategy=self.optimization,
            app=w.app,
        )
        _set_task_join_will_block(pool.task_join_will_block)
        return pool

    def info(self, w):
        return {'pool': w.pool.info if w.pool else 'N/A'}

    def register_with_event_loop(self, w, hub):
        w.pool.register_with_event_loop(hub)


class Beat(bootsteps.StartStopStep):
    """Step used to embed a beat process.

    Enabled when the ``beat`` argument is set.
    """

    label = 'Beat'
    conditional = True

    def __init__(self, w, beat=False, **kwargs):
        self.enabled = w.beat = beat
        w.beat = None
        super().__init__(w, beat=beat, **kwargs)

    def create(self, w):
        from celery.beat import EmbeddedService
        if w.pool_cls.__module__.endswith(('gevent', 'eventlet')):
            raise ImproperlyConfigured(ERR_B_GREEN)
        b = w.beat = EmbeddedService(w.app,
                                     schedule_filename=w.schedule_filename,
                                     scheduler_cls=w.scheduler)
        return b


class StateDB(bootsteps.Step):
    """Bootstep that sets up between-restart state database file."""

    def __init__(self, w, **kwargs):
        self.enabled = w.statedb
        w._persistence = None
        super().__init__(w, **kwargs)

    def create(self, w):
        w._persistence = w.state.Persistent(w.state, w.statedb, w.app.clock)
        atexit.register(w._persistence.save)


class Consumer(bootsteps.StartStopStep):
    """Bootstep starting the Consumer blueprint."""

    last = True

    def create(self, w):
        if w.max_concurrency:
            prefetch_count = max(w.max_concurrency, 1) * w.prefetch_multiplier
        else:
            prefetch_count = w.concurrency * w.prefetch_multiplier
        c = w.consumer = self.instantiate(
            w.consumer_cls, w.process_task,
            hostname=w.hostname,
            task_events=w.task_events,
            init_callback=w.ready_callback,
            initial_prefetch_count=prefetch_count,
            pool=w.pool,
            timer=w.timer,
            app=w.app,
            controller=w,
            hub=w.hub,
            worker_options=w.options,
            disable_rate_limits=w.disable_rate_limits,
            prefetch_multiplier=w.prefetch_multiplier,
        )
        return c
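
Two small derivations above are easy to miss: `Pool.__init__` parses the `--autoscale` value as a `"max,min"` string, and the `Consumer` bootstep derives the initial prefetch count from either the autoscale maximum or the plain concurrency. A short sketch (illustrative only, not part of the commit; helper names are hypothetical):

```python
def parse_autoscale(value):
    max_c, _, min_c = value.partition(',')
    return [int(max_c), int(min_c) if min_c else 0]

def initial_prefetch(concurrency, max_concurrency, prefetch_multiplier):
    if max_concurrency:
        return max(max_concurrency, 1) * prefetch_multiplier
    return concurrency * prefetch_multiplier

assert parse_autoscale('10,3') == [10, 3]
assert parse_autoscale('4') == [4, 0]
assert initial_prefetch(concurrency=8, max_concurrency=None, prefetch_multiplier=4) == 32
assert initial_prefetch(concurrency=3, max_concurrency=10, prefetch_multiplier=4) == 40
```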
@@ -0,0 +1,15 @@
"""Worker consumer."""
from .agent import Agent
from .connection import Connection
from .consumer import Consumer
from .control import Control
from .events import Events
from .gossip import Gossip
from .heart import Heart
from .mingle import Mingle
from .tasks import Tasks

__all__ = (
    'Consumer', 'Agent', 'Connection', 'Control',
    'Events', 'Gossip', 'Heart', 'Mingle', 'Tasks',
)
@@ -0,0 +1,21 @@
"""Celery + :pypi:`cell` integration."""
from celery import bootsteps

from .connection import Connection

__all__ = ('Agent',)


class Agent(bootsteps.StartStopStep):
    """Agent starts :pypi:`cell` actors."""

    conditional = True
    requires = (Connection,)

    def __init__(self, c, **kwargs):
        self.agent_cls = self.enabled = c.app.conf.worker_agent
        super().__init__(c, **kwargs)

    def create(self, c):
        agent = c.agent = self.instantiate(self.agent_cls, c.connection)
        return agent
@@ -0,0 +1,36 @@
"""Consumer Broker Connection Bootstep."""
from kombu.common import ignore_errors

from celery import bootsteps
from celery.utils.log import get_logger

__all__ = ('Connection',)

logger = get_logger(__name__)
info = logger.info


class Connection(bootsteps.StartStopStep):
    """Service managing the consumer broker connection."""

    def __init__(self, c, **kwargs):
        c.connection = None
        super().__init__(c, **kwargs)

    def start(self, c):
        c.connection = c.connect()
        info('Connected to %s', c.connection.as_uri())

    def shutdown(self, c):
        # We must set self.connection to None here, so
        # that the green pidbox thread exits.
        connection, c.connection = c.connection, None
        if connection:
            ignore_errors(connection, connection.close)

    def info(self, c):
        params = 'N/A'
        if c.connection:
            params = c.connection.info()
            params.pop('password', None)  # don't send password.
        return {'broker': params}
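
The `info()` step reports the broker parameters with the password scrubbed out. A tiny sketch (illustrative only, with hypothetical values):

```python
params = {'hostname': 'rabbit.example.com', 'userid': 'guest', 'password': 'secret'}
params.pop('password', None)  # don't send password.
assert 'password' not in params
```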
@@ -0,0 +1,775 @@
"""Worker Consumer Blueprint.

This module contains the components responsible for consuming messages
from the broker, processing the messages and keeping the broker connections
up and running.
"""
import errno
import logging
import os
import warnings
from collections import defaultdict
from time import sleep

from billiard.common import restart_state
from billiard.exceptions import RestartFreqExceeded
from kombu.asynchronous.semaphore import DummyLock
from kombu.exceptions import ContentDisallowed, DecodeError
from kombu.utils.compat import _detect_environment
from kombu.utils.encoding import safe_repr
from kombu.utils.limits import TokenBucket
from vine import ppartial, promise

from celery import bootsteps, signals
from celery.app.trace import build_tracer
from celery.exceptions import (CPendingDeprecationWarning, InvalidTaskError, NotRegistered, WorkerShutdown,
                               WorkerTerminate)
from celery.utils.functional import noop
from celery.utils.log import get_logger
from celery.utils.nodenames import gethostname
from celery.utils.objects import Bunch
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds, rate
from celery.worker import loops
from celery.worker.state import active_requests, maybe_shutdown, requests, reserved_requests, task_reserved

__all__ = ('Consumer', 'Evloop', 'dump_body')

CLOSE = bootsteps.CLOSE
TERMINATE = bootsteps.TERMINATE
STOP_CONDITIONS = {CLOSE, TERMINATE}
logger = get_logger(__name__)
debug, info, warn, error, crit = (logger.debug, logger.info, logger.warning,
                                  logger.error, logger.critical)

CONNECTION_RETRY = """\
consumer: Connection to broker lost. \
Trying to re-establish the connection...\
"""

CONNECTION_RETRY_STEP = """\
Trying again {when}... ({retries}/{max_retries})\
"""

CONNECTION_ERROR = """\
consumer: Cannot connect to %s: %s.
%s
"""

CONNECTION_FAILOVER = """\
Will retry using next failover.\
"""

UNKNOWN_FORMAT = """\
Received and deleted unknown message. Wrong destination?!?

The full contents of the message body was: %s
"""

#: Error message for when an unregistered task is received.
UNKNOWN_TASK_ERROR = """\
Received unregistered task of type %s.
The message has been ignored and discarded.

Did you remember to import the module containing this task?
Or maybe you're using relative imports?

Please see
https://docs.celeryq.dev/en/latest/internals/protocol.html
for more information.

The full contents of the message body was:
%s

The full contents of the message headers:
%s

The delivery info for this task is:
%s
"""

#: Error message for when an invalid task message is received.
INVALID_TASK_ERROR = """\
Received invalid task message: %s
The message has been ignored and discarded.

Please ensure your message conforms to the task
message protocol as described here:
https://docs.celeryq.dev/en/latest/internals/protocol.html

The full contents of the message body was:
%s
"""

MESSAGE_DECODE_ERROR = """\
Can't decode message body: %r [type:%r encoding:%r headers:%s]

body: %s
"""

MESSAGE_REPORT = """\
body: {0}
{{content_type:{1} content_encoding:{2}
delivery_info:{3} headers={4}}}
"""

TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS = """\
Task %s cannot be acknowledged after a connection loss since late acknowledgement is enabled for it.
Terminating it instead.
"""

CANCEL_TASKS_BY_DEFAULT = """
In Celery 5.1 we introduced an optional breaking change which
on connection loss cancels all currently executed tasks with late acknowledgement enabled.
These tasks cannot be acknowledged as the connection is gone, and the tasks are automatically redelivered
back to the queue. You can enable this behavior using the worker_cancel_long_running_tasks_on_connection_loss
setting. In Celery 5.1 it is set to False by default. The setting will be set to True by default in Celery 6.0.
"""
def dump_body(m, body):
    """Format message body for debugging purposes."""
    # v2 protocol does not deserialize body
    body = m.body if body is None else body
    return '{} ({}b)'.format(truncate(safe_repr(body), 1024),
                             len(m.body))


class Consumer:
    """Consumer blueprint."""

    Strategies = dict

    #: Optional callback called the first time the worker
    #: is ready to receive tasks.
    init_callback = None

    #: The current worker pool instance.
    pool = None

    #: A timer used for high-priority internal tasks, such
    #: as sending heartbeats.
    timer = None

    restart_count = -1  # first start is the same as a restart

    #: This flag will be turned off after the first failed
    #: connection attempt.
    first_connection_attempt = True

    class Blueprint(bootsteps.Blueprint):
        """Consumer blueprint."""

        name = 'Consumer'
        default_steps = [
            'celery.worker.consumer.connection:Connection',
            'celery.worker.consumer.mingle:Mingle',
            'celery.worker.consumer.events:Events',
            'celery.worker.consumer.gossip:Gossip',
            'celery.worker.consumer.heart:Heart',
            'celery.worker.consumer.control:Control',
            'celery.worker.consumer.tasks:Tasks',
            'celery.worker.consumer.delayed_delivery:DelayedDelivery',
            'celery.worker.consumer.consumer:Evloop',
            'celery.worker.consumer.agent:Agent',
        ]

        def shutdown(self, parent):
            self.send_all(parent, 'shutdown')

    def __init__(self, on_task_request,
                 init_callback=noop, hostname=None,
                 pool=None, app=None,
                 timer=None, controller=None, hub=None, amqheartbeat=None,
                 worker_options=None, disable_rate_limits=False,
                 initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
        self.app = app
        self.controller = controller
        self.init_callback = init_callback
        self.hostname = hostname or gethostname()
        self.pid = os.getpid()
        self.pool = pool
        self.timer = timer
        self.strategies = self.Strategies()
        self.conninfo = self.app.connection_for_read()
        self.connection_errors = self.conninfo.connection_errors
        self.channel_errors = self.conninfo.channel_errors
        self._restart_state = restart_state(maxR=5, maxT=1)

        self._does_info = logger.isEnabledFor(logging.INFO)
        self._limit_order = 0
        self.on_task_request = on_task_request
        self.on_task_message = set()
        self.amqheartbeat_rate = self.app.conf.broker_heartbeat_checkrate
        self.disable_rate_limits = disable_rate_limits
        self.initial_prefetch_count = initial_prefetch_count
        self.prefetch_multiplier = prefetch_multiplier
        self._maximum_prefetch_restored = True

        # this contains a tokenbucket for each task type by name, used for
        # rate limits, or None if rate limits are disabled for that task.
        self.task_buckets = defaultdict(lambda: None)
        self.reset_rate_limits()

        self.hub = hub
        if self.hub or getattr(self.pool, 'is_green', False):
            self.amqheartbeat = amqheartbeat
            if self.amqheartbeat is None:
                self.amqheartbeat = self.app.conf.broker_heartbeat
        else:
            self.amqheartbeat = 0

        if not hasattr(self, 'loop'):
            self.loop = loops.asynloop if hub else loops.synloop

        if _detect_environment() == 'gevent':
            # there's a gevent bug that causes timeouts to not be reset,
            # so if the connection timeout is exceeded once, it can NEVER
            # connect again.
            self.app.conf.broker_connection_timeout = None

        self._pending_operations = []

        self.steps = []
        self.blueprint = self.Blueprint(
            steps=self.app.steps['consumer'],
            on_close=self.on_close,
        )
        self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))
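
    # Editor's note (illustration, not part of the commit): the heartbeat
    # selection above means AMQP heartbeats are only used when an event loop
    # (hub) or a green pool drives I/O.  For example, with a hub and
    # amqheartbeat=None the value falls back to app.conf.broker_heartbeat;
    # with a prefork pool and no hub it is forced to 0 (disabled).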
    def call_soon(self, p, *args, **kwargs):
        p = ppartial(p, *args, **kwargs)
        if self.hub:
            return self.hub.call_soon(p)
        self._pending_operations.append(p)
        return p

    def perform_pending_operations(self):
        if not self.hub:
            while self._pending_operations:
                try:
                    self._pending_operations.pop()()
                except Exception as exc:  # pylint: disable=broad-except
                    logger.exception('Pending callback raised: %r', exc)

    def bucket_for_task(self, type):
        limit = rate(getattr(type, 'rate_limit', None))
        return TokenBucket(limit, capacity=1) if limit else None

    def reset_rate_limits(self):
        self.task_buckets.update(
            (n, self.bucket_for_task(t)) for n, t in self.app.tasks.items()
        )

    def _update_prefetch_count(self, index=0):
        """Update prefetch count after pool/shrink grow operations.

        Index must be the change in number of processes as a positive
        (increasing) or negative (decreasing) number.

        Note:
            Currently pool grow operations will end up with an offset
            of +1 if the initial size of the pool was 0 (e.g.
            :option:`--autoscale=1,0 <celery worker --autoscale>`).
        """
        num_processes = self.pool.num_processes
        if not self.initial_prefetch_count or not num_processes:
            return  # prefetch disabled
        self.initial_prefetch_count = (
            self.pool.num_processes * self.prefetch_multiplier
        )
        return self._update_qos_eventually(index)

    def _update_qos_eventually(self, index):
        return (self.qos.decrement_eventually if index < 0
                else self.qos.increment_eventually)(
                    abs(index) * self.prefetch_multiplier)

    def _limit_move_to_pool(self, request):
        task_reserved(request)
        self.on_task_request(request)

    def _schedule_bucket_request(self, bucket):
        while True:
            try:
                request, tokens = bucket.pop()
            except IndexError:
                # no request, break
                break

            if bucket.can_consume(tokens):
                self._limit_move_to_pool(request)
                continue
            else:
                # requeue to head, keep the order.
                bucket.contents.appendleft((request, tokens))

                pri = self._limit_order = (self._limit_order + 1) % 10
                hold = bucket.expected_time(tokens)
                self.timer.call_after(
                    hold, self._schedule_bucket_request, (bucket,),
                    priority=pri,
                )
                # no tokens, break
                break

    def _limit_task(self, request, bucket, tokens):
        bucket.add((request, tokens))
        return self._schedule_bucket_request(bucket)

    def _limit_post_eta(self, request, bucket, tokens):
        self.qos.decrement_eventually()
        bucket.add((request, tokens))
        return self._schedule_bucket_request(bucket)

    def start(self):
        blueprint = self.blueprint
        while blueprint.state not in STOP_CONDITIONS:
            maybe_shutdown()
            if self.restart_count:
                try:
                    self._restart_state.step()
                except RestartFreqExceeded as exc:
                    crit('Frequent restarts detected: %r', exc, exc_info=1)
                    sleep(1)
            self.restart_count += 1
            if self.app.conf.broker_channel_error_retry:
                recoverable_errors = (self.connection_errors + self.channel_errors)
            else:
                recoverable_errors = self.connection_errors
            try:
                blueprint.start(self)
            except recoverable_errors as exc:
                # If we're not retrying connections, we need to properly shutdown or terminate
                # the Celery main process instead of abruptly aborting the process without any cleanup.
                is_connection_loss_on_startup = self.first_connection_attempt
                self.first_connection_attempt = False
                connection_retry_type = self._get_connection_retry_type(is_connection_loss_on_startup)
                connection_retry = self.app.conf[connection_retry_type]
                if not connection_retry:
                    crit(
                        f"Retrying to {'establish' if is_connection_loss_on_startup else 're-establish'} "
                        f"a connection to the message broker after a connection loss has "
                        f"been disabled (app.conf.{connection_retry_type}=False). Shutting down..."
                    )
                    raise WorkerShutdown(1) from exc
                if isinstance(exc, OSError) and exc.errno == errno.EMFILE:
                    crit("Too many open files. Aborting...")
                    raise WorkerTerminate(1) from exc
                maybe_shutdown()
                if blueprint.state not in STOP_CONDITIONS:
                    if self.connection:
                        self.on_connection_error_after_connected(exc)
                    else:
                        self.on_connection_error_before_connected(exc)
                    self.on_close()
                blueprint.restart(self)

    def _get_connection_retry_type(self, is_connection_loss_on_startup):
        return ('broker_connection_retry_on_startup'
                if (is_connection_loss_on_startup
                    and self.app.conf.broker_connection_retry_on_startup is not None)
                else 'broker_connection_retry')

    def on_connection_error_before_connected(self, exc):
        error(CONNECTION_ERROR, self.conninfo.as_uri(), exc,
              'Trying to reconnect...')

    def on_connection_error_after_connected(self, exc):
        warn(CONNECTION_RETRY, exc_info=True)
        try:
            self.connection.collect()
        except Exception:  # pylint: disable=broad-except
            pass

        if self.app.conf.worker_cancel_long_running_tasks_on_connection_loss:
            for request in tuple(active_requests):
                if request.task.acks_late and not request.acknowledged:
                    warn(TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS,
                         request)
                    request.cancel(self.pool)
        else:
            warnings.warn(CANCEL_TASKS_BY_DEFAULT, CPendingDeprecationWarning)

        if self.app.conf.worker_enable_prefetch_count_reduction:
            self.initial_prefetch_count = max(
                self.prefetch_multiplier,
                self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier
            )

            self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count
            if not self._maximum_prefetch_restored:
                logger.info(
                    f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid "
                    f"over-fetching since {len(tuple(active_requests))} tasks are currently being processed.\n"
                    f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks "
                    "complete processing."
                )
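
    # Editor's note (worked example, not part of the commit): with 4 pool
    # processes and prefetch_multiplier=4, max_prefetch_count is 16.  If 3
    # tasks are still active when the connection is lost, the prefetch count
    # is temporarily reduced to max(4, 16 - 3 * 4) = 4, and then restored
    # gradually as those tasks complete (see
    # _restore_prefetch_count_after_connection_restart below).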
    def register_with_event_loop(self, hub):
        self.blueprint.send_all(
            self, 'register_with_event_loop', args=(hub,),
            description='Hub.register',
        )

    def shutdown(self):
        self.perform_pending_operations()
        self.blueprint.shutdown(self)

    def stop(self):
        self.blueprint.stop(self)

    def on_ready(self):
        callback, self.init_callback = self.init_callback, None
        if callback:
            callback(self)

    def loop_args(self):
        return (self, self.connection, self.task_consumer,
                self.blueprint, self.hub, self.qos, self.amqheartbeat,
                self.app.clock, self.amqheartbeat_rate)

    def on_decode_error(self, message, exc):
        """Callback called if an error occurs while decoding a message.

        Simply logs the error and acknowledges the message so it
        doesn't enter a loop.

        Arguments:
            message (kombu.Message): The message received.
            exc (Exception): The exception being handled.
        """
        crit(MESSAGE_DECODE_ERROR,
             exc, message.content_type, message.content_encoding,
             safe_repr(message.headers), dump_body(message, message.body),
             exc_info=1)
        message.ack()

    def on_close(self):
        # Clear internal queues to get rid of old messages.
        # They can't be acked anyway, as a delivery tag is specific
        # to the current channel.
        if self.controller and self.controller.semaphore:
            self.controller.semaphore.clear()
        if self.timer:
            self.timer.clear()
        for bucket in self.task_buckets.values():
            if bucket:
                bucket.clear_pending()
        for request_id in reserved_requests:
            if request_id in requests:
                del requests[request_id]
        reserved_requests.clear()
        if self.pool and self.pool.flush:
            self.pool.flush()

    def connect(self):
        """Establish the broker connection used for consuming tasks.

        Retries establishing the connection if the
        :setting:`broker_connection_retry` setting is enabled
        """
        conn = self.connection_for_read(heartbeat=self.amqheartbeat)
        if self.hub:
            conn.transport.register_with_event_loop(conn.connection, self.hub)
        return conn

    def connection_for_read(self, heartbeat=None):
        return self.ensure_connected(
            self.app.connection_for_read(heartbeat=heartbeat))

    def connection_for_write(self, url=None, heartbeat=None):
        return self.ensure_connected(
            self.app.connection_for_write(url=url, heartbeat=heartbeat))

    def ensure_connected(self, conn):
        # Callback called for each retry while the connection
        # can't be established.
        def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
            if getattr(conn, 'alt', None) and interval == 0:
                next_step = CONNECTION_FAILOVER
            next_step = next_step.format(
                when=humanize_seconds(interval, 'in', ' '),
                retries=int(interval / 2),
                max_retries=self.app.conf.broker_connection_max_retries)
            error(CONNECTION_ERROR, conn.as_uri(), exc, next_step)

        # Remember that the connection is lazy, it won't establish
        # until needed.

        # TODO: Rely only on broker_connection_retry_on_startup to determine whether connection retries are disabled.
        #       We will make the switch in Celery 6.0.

        retry_disabled = False

        if self.app.conf.broker_connection_retry_on_startup is None:
            # If broker_connection_retry_on_startup is not set, revert to broker_connection_retry
            # to determine whether connection retries are disabled.
            retry_disabled = not self.app.conf.broker_connection_retry

            if retry_disabled:
                warnings.warn(
                    CPendingDeprecationWarning(
                        "The broker_connection_retry configuration setting will no longer determine\n"
                        "whether broker connection retries are made during startup in Celery 6.0 and above.\n"
                        "If you wish to refrain from retrying connections on startup,\n"
                        "you should set broker_connection_retry_on_startup to False instead.")
                )
        else:
            if self.first_connection_attempt:
                retry_disabled = not self.app.conf.broker_connection_retry_on_startup
            else:
                retry_disabled = not self.app.conf.broker_connection_retry

        if retry_disabled:
            # Retry disabled, just call connect directly.
            conn.connect()
            self.first_connection_attempt = False
            return conn

        conn = conn.ensure_connection(
            _error_handler, self.app.conf.broker_connection_max_retries,
            callback=maybe_shutdown,
        )
        self.first_connection_attempt = False
        return conn
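
    # Editor's note (summary, not part of the commit): ensure_connected()
    # decides whether to retry as follows -- if broker_connection_retry_on_startup
    # is unset, retries follow broker_connection_retry (and a deprecation
    # warning is emitted when that disables them); otherwise the first
    # connection attempt follows broker_connection_retry_on_startup and later
    # reconnects follow broker_connection_retry.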
    def _flush_events(self):
        if self.event_dispatcher:
            self.event_dispatcher.flush()

    def on_send_event_buffered(self):
        if self.hub:
            self.hub._ready.add(self._flush_events)

    def add_task_queue(self, queue, exchange=None, exchange_type=None,
                       routing_key=None, **options):
        cset = self.task_consumer
        queues = self.app.amqp.queues
        # Must use `in` here, as __missing__ will automatically
        # create queues when :setting:`task_create_missing_queues` is enabled.
        # (Issue #1079)
        if queue in queues:
            q = queues[queue]
        else:
            exchange = queue if exchange is None else exchange
            exchange_type = ('direct' if exchange_type is None
                             else exchange_type)
            q = queues.select_add(queue,
                                  exchange=exchange,
                                  exchange_type=exchange_type,
                                  routing_key=routing_key, **options)
        if not cset.consuming_from(queue):
            cset.add_queue(q)
            cset.consume()
            info('Started consuming from %s', queue)

    def cancel_task_queue(self, queue):
        info('Canceling queue %s', queue)
        self.app.amqp.queues.deselect(queue)
        self.task_consumer.cancel_by_queue(queue)

    def apply_eta_task(self, task):
        """Method called by the timer to apply a task with an ETA/countdown."""
        task_reserved(task)
        self.on_task_request(task)
        self.qos.decrement_eventually()

    def _message_report(self, body, message):
        return MESSAGE_REPORT.format(dump_body(message, body),
                                     safe_repr(message.content_type),
                                     safe_repr(message.content_encoding),
                                     safe_repr(message.delivery_info),
                                     safe_repr(message.headers))

    def on_unknown_message(self, body, message):
        warn(UNKNOWN_FORMAT, self._message_report(body, message))
        message.reject_log_error(logger, self.connection_errors)
        signals.task_rejected.send(sender=self, message=message, exc=None)

    def on_unknown_task(self, body, message, exc):
        error(UNKNOWN_TASK_ERROR,
              exc,
              dump_body(message, body),
              message.headers,
              message.delivery_info,
              exc_info=True)
        try:
            id_, name = message.headers['id'], message.headers['task']
            root_id = message.headers.get('root_id')
        except KeyError:  # proto1
            payload = message.payload
            id_, name = payload['id'], payload['task']
            root_id = None
        request = Bunch(
            name=name, chord=None, root_id=root_id,
            correlation_id=message.properties.get('correlation_id'),
            reply_to=message.properties.get('reply_to'),
            errbacks=None,
        )
        message.reject_log_error(logger, self.connection_errors)
        self.app.backend.mark_as_failure(
            id_, NotRegistered(name), request=request,
        )
        if self.event_dispatcher:
            self.event_dispatcher.send(
                'task-failed', uuid=id_,
                exception=f'NotRegistered({name!r})',
            )
        signals.task_unknown.send(
            sender=self, message=message, exc=exc, name=name, id=id_,
        )

    def on_invalid_task(self, body, message, exc):
        error(INVALID_TASK_ERROR, exc, dump_body(message, body),
              exc_info=True)
        message.reject_log_error(logger, self.connection_errors)
        signals.task_rejected.send(sender=self, message=message, exc=exc)

    def update_strategies(self):
        loader = self.app.loader
        for name, task in self.app.tasks.items():
            self.strategies[name] = task.start_strategy(self.app, self)
            task.__trace__ = build_tracer(name, task, loader, self.hostname,
                                          app=self.app)

    def create_task_handler(self, promise=promise):
        strategies = self.strategies
        on_unknown_message = self.on_unknown_message
        on_unknown_task = self.on_unknown_task
        on_invalid_task = self.on_invalid_task
        callbacks = self.on_task_message
        call_soon = self.call_soon

        def on_task_received(message):
            # payload will only be set for v1 protocol, since v2
            # will defer deserializing the message body to the pool.
            payload = None
            try:
                type_ = message.headers['task']  # protocol v2
            except TypeError:
                return on_unknown_message(None, message)
            except KeyError:
                try:
                    payload = message.decode()
                except Exception as exc:  # pylint: disable=broad-except
                    return self.on_decode_error(message, exc)
                try:
                    type_, payload = payload['task'], payload  # protocol v1
                except (TypeError, KeyError):
                    return on_unknown_message(payload, message)
            try:
                strategy = strategies[type_]
            except KeyError as exc:
                return on_unknown_task(None, message, exc)
            else:
                try:
                    ack_log_error_promise = promise(
                        call_soon,
                        (message.ack_log_error,),
                        on_error=self._restore_prefetch_count_after_connection_restart,
                    )
                    reject_log_error_promise = promise(
                        call_soon,
                        (message.reject_log_error,),
                        on_error=self._restore_prefetch_count_after_connection_restart,
                    )

                    if (
                        not self._maximum_prefetch_restored
                        and self.restart_count > 0
                        and self._new_prefetch_count <= self.max_prefetch_count
                    ):
                        ack_log_error_promise.then(self._restore_prefetch_count_after_connection_restart,
                                                   on_error=self._restore_prefetch_count_after_connection_restart)
                        reject_log_error_promise.then(self._restore_prefetch_count_after_connection_restart,
                                                      on_error=self._restore_prefetch_count_after_connection_restart)

                    strategy(
                        message, payload,
                        ack_log_error_promise,
                        reject_log_error_promise,
                        callbacks,
                    )
                except (InvalidTaskError, ContentDisallowed) as exc:
                    return on_invalid_task(payload, message, exc)
                except DecodeError as exc:
                    return self.on_decode_error(message, exc)

        return on_task_received

    def _restore_prefetch_count_after_connection_restart(self, p, *args):
        with self.qos._mutex:
            if any((
                not self.app.conf.worker_enable_prefetch_count_reduction,
                self._maximum_prefetch_restored,
            )):
                return

            new_prefetch_count = min(self.max_prefetch_count, self._new_prefetch_count)
            self.qos.value = self.initial_prefetch_count = new_prefetch_count
            self.qos.set(self.qos.value)

            already_restored = self._maximum_prefetch_restored
            self._maximum_prefetch_restored = new_prefetch_count == self.max_prefetch_count

            if already_restored is False and self._maximum_prefetch_restored is True:
                logger.info(
                    "Resuming normal operations following a restart.\n"
                    f"Prefetch count has been restored to the maximum of {self.max_prefetch_count}"
                )

    @property
    def max_prefetch_count(self):
        return self.pool.num_processes * self.prefetch_multiplier

    @property
    def _new_prefetch_count(self):
        return self.qos.value + self.prefetch_multiplier

    def __repr__(self):
        """``repr(self)``."""
        return '<Consumer: {self.hostname} ({state})>'.format(
            self=self, state=self.blueprint.human_state(),
        )

    def cancel_all_unacked_requests(self):
        """Cancel all active requests that either do not require late acknowledgments or,
        if they do, have not been acknowledged yet.
        """

        def should_cancel(request):
            if not request.task.acks_late:
                # Task does not require late acknowledgment, cancel it.
                return True

            if not request.acknowledged:
                # Task is late acknowledged, but it has not been acknowledged yet, cancel it.
                return True

            # Task is late acknowledged, but it has already been acknowledged.
            return False  # Do not cancel and allow it to gracefully finish as it has already been acknowledged.

        requests_to_cancel = tuple(filter(should_cancel, active_requests))

        if requests_to_cancel:
            for request in requests_to_cancel:
                request.cancel(self.pool)


class Evloop(bootsteps.StartStopStep):
    """Event loop service.

    Note:
        This is always started last.
    """

    label = 'event loop'
    last = True

    def start(self, c):
        self.patch_all(c)
        c.loop(*c.loop_args())

    def patch_all(self, c):
        c.qos._mutex = DummyLock()
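
After a connection loss the consumer restores QoS gradually: each acked or rejected message raises the prefetch target by `prefetch_multiplier` until `max_prefetch_count` is reached again. A minimal sketch of that progression (illustrative only, not part of the commit; `restore_steps` is a hypothetical helper):

```python
def restore_steps(current, multiplier, maximum):
    steps = []
    while current < maximum:
        current = min(maximum, current + multiplier)
        steps.append(current)
    return steps

# Reduced to 4 after a connection loss, with multiplier 4 and maximum 16:
assert restore_steps(4, 4, 16) == [8, 12, 16]
```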
@@ -0,0 +1,33 @@
"""Worker Remote Control Bootstep.

``Control`` -> :mod:`celery.worker.pidbox` -> :mod:`kombu.pidbox`.

The actual commands are implemented in :mod:`celery.worker.control`.
"""
from celery import bootsteps
from celery.utils.log import get_logger
from celery.worker import pidbox

from .tasks import Tasks

__all__ = ('Control',)

logger = get_logger(__name__)


class Control(bootsteps.StartStopStep):
    """Remote control command service."""

    requires = (Tasks,)

    def __init__(self, c, **kwargs):
        self.is_green = c.pool is not None and c.pool.is_green
        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
        self.start = self.box.start
        self.stop = self.box.stop
        self.shutdown = self.box.shutdown
        super().__init__(c, **kwargs)

    def include_if(self, c):
        return (c.app.conf.worker_enable_remote_control and
                c.conninfo.supports_exchange_type('fanout'))
@@ -0,0 +1,247 @@
"""Native delayed delivery functionality for Celery workers.

This module provides the DelayedDelivery bootstep which handles setup and configuration
of native delayed delivery functionality when using quorum queues.
"""
from typing import Iterator, List, Optional, Set, Union, ValuesView

from kombu import Connection, Queue
from kombu.transport.native_delayed_delivery import (bind_queue_to_native_delayed_delivery_exchange,
                                                     declare_native_delayed_delivery_exchanges_and_queues)
from kombu.utils.functional import retry_over_time

from celery import Celery, bootsteps
from celery.utils.log import get_logger
from celery.utils.quorum_queues import detect_quorum_queues
from celery.worker.consumer import Consumer, Tasks

__all__ = ('DelayedDelivery',)

logger = get_logger(__name__)


# Default retry settings
RETRY_INTERVAL = 1.0  # seconds between retries
MAX_RETRIES = 3  # maximum number of retries


# Valid queue types for delayed delivery
VALID_QUEUE_TYPES = {'classic', 'quorum'}


class DelayedDelivery(bootsteps.StartStopStep):
    """Bootstep that sets up native delayed delivery functionality.

    This component handles the setup and configuration of native delayed delivery
    for Celery workers. It is automatically included when quorum queues are
    detected in the application configuration.

    Responsibilities:
        - Declaring native delayed delivery exchanges and queues
        - Binding all application queues to the delayed delivery exchanges
        - Handling connection failures gracefully with retries
        - Validating configuration settings
    """

    requires = (Tasks,)

    def include_if(self, c: Consumer) -> bool:
        """Determine if this bootstep should be included.

        Args:
            c: The Celery consumer instance

        Returns:
            bool: True if quorum queues are detected, False otherwise
        """
        return detect_quorum_queues(c.app, c.app.connection_for_write().transport.driver_type)[0]

    def start(self, c: Consumer) -> None:
        """Initialize delayed delivery for all broker URLs.

        Attempts to set up delayed delivery for each broker URL in the configuration.
        Failures are logged but don't prevent attempting remaining URLs.

        Args:
            c: The Celery consumer instance

        Raises:
            ValueError: If configuration validation fails
        """
        app: Celery = c.app

        try:
            self._validate_configuration(app)
        except ValueError as e:
            logger.critical("Configuration validation failed: %s", str(e))
            raise

        broker_urls = self._validate_broker_urls(app.conf.broker_url)
        setup_errors = []

        for broker_url in broker_urls:
            try:
                retry_over_time(
                    self._setup_delayed_delivery,
                    args=(c, broker_url),
                    catch=(ConnectionRefusedError, OSError),
                    errback=self._on_retry,
                    interval_start=RETRY_INTERVAL,
                    max_retries=MAX_RETRIES,
                )
            except Exception as e:
                logger.warning(
                    "Failed to setup delayed delivery for %r: %s",
                    broker_url, str(e)
                )
                setup_errors.append((broker_url, e))

        if len(setup_errors) == len(broker_urls):
            logger.critical(
                "Failed to setup delayed delivery for all broker URLs. "
                "Native delayed delivery will not be available."
            )
    def _setup_delayed_delivery(self, c: Consumer, broker_url: str) -> None:
        """Set up delayed delivery for a specific broker URL.

        Args:
            c: The Celery consumer instance
            broker_url: The broker URL to configure

        Raises:
            ConnectionRefusedError: If connection to the broker fails
            OSError: If there are network-related issues
            Exception: For other unexpected errors during setup
        """
        connection: Connection = c.app.connection_for_write(url=broker_url)
        queue_type = c.app.conf.broker_native_delayed_delivery_queue_type
        logger.debug(
            "Setting up delayed delivery for broker %r with queue type %r",
            broker_url, queue_type
        )

        try:
            declare_native_delayed_delivery_exchanges_and_queues(
                connection,
                queue_type
            )
        except Exception as e:
            logger.warning(
                "Failed to declare exchanges and queues for %r: %s",
                broker_url, str(e)
            )
            raise

        try:
            self._bind_queues(c.app, connection)
        except Exception as e:
            logger.warning(
                "Failed to bind queues for %r: %s",
                broker_url, str(e)
            )
            raise

    def _bind_queues(self, app: Celery, connection: Connection) -> None:
        """Bind all application queues to delayed delivery exchanges.

        Args:
            app: The Celery application instance
            connection: The broker connection to use

        Raises:
            Exception: If queue binding fails
        """
        queues: ValuesView[Queue] = app.amqp.queues.values()
        if not queues:
            logger.warning("No queues found to bind for delayed delivery")
            return

        for queue in queues:
            try:
                logger.debug("Binding queue %r to delayed delivery exchange", queue.name)
                bind_queue_to_native_delayed_delivery_exchange(connection, queue)
            except Exception as e:
                logger.error(
                    "Failed to bind queue %r: %s",
                    queue.name, str(e)
                )
                raise

    def _on_retry(self, exc: Exception, interval_range: Iterator[float], intervals_count: int) -> None:
        """Callback for retry attempts.

        Args:
            exc: The exception that triggered the retry
            interval_range: An iterator which returns the time in seconds to sleep next
            intervals_count: Number of retry attempts so far
        """
        logger.warning(
            "Retrying delayed delivery setup (attempt %d/%d) after error: %s",
            intervals_count + 1, MAX_RETRIES, str(exc)
        )

    def _validate_configuration(self, app: Celery) -> None:
        """Validate all required configuration settings.

        Args:
            app: The Celery application instance

        Raises:
            ValueError: If any configuration is invalid
        """
        # Validate broker URLs
        self._validate_broker_urls(app.conf.broker_url)

        # Validate queue type
        self._validate_queue_type(app.conf.broker_native_delayed_delivery_queue_type)

    def _validate_broker_urls(self, broker_urls: Union[str, List[str]]) -> Set[str]:
        """Validate and split broker URLs.

        Args:
            broker_urls: Broker URLs, either as a semicolon-separated string
                or as a list of strings

        Returns:
            Set of valid broker URLs

        Raises:
            ValueError: If no valid broker URLs are found or if invalid URLs are provided
        """
        if not broker_urls:
            raise ValueError("broker_url configuration is empty")

        if isinstance(broker_urls, str):
            brokers = broker_urls.split(";")
        elif isinstance(broker_urls, list):
            if not all(isinstance(url, str) for url in broker_urls):
                raise ValueError("All broker URLs must be strings")
            brokers = broker_urls
        else:
            raise ValueError(f"broker_url must be a string or list, got {broker_urls!r}")

        valid_urls = {url for url in brokers}

        if not valid_urls:
            raise ValueError("No valid broker URLs found in configuration")

        return valid_urls

    def _validate_queue_type(self, queue_type: Optional[str]) -> None:
        """Validate the queue type configuration.

        Args:
            queue_type: The configured queue type

        Raises:
            ValueError: If queue type is invalid
        """
        if not queue_type:
            raise ValueError("broker_native_delayed_delivery_queue_type is not configured")

        if queue_type not in VALID_QUEUE_TYPES:
            sorted_types = sorted(VALID_QUEUE_TYPES)
            raise ValueError(
                f"Invalid queue type {queue_type!r}. Must be one of: {', '.join(sorted_types)}"
            )
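
For reference, a short sketch (not part of the commit; the broker URLs are hypothetical) of what the validation helpers above accept:

```python
# Semicolon-separated broker URLs are split into a set of individual URLs.
urls = "amqp://rabbit-1:5672//;amqp://rabbit-2:5672//"
assert set(urls.split(";")) == {"amqp://rabbit-1:5672//", "amqp://rabbit-2:5672//"}

# Only the two supported queue types pass _validate_queue_type.
VALID_QUEUE_TYPES = {'classic', 'quorum'}
assert 'quorum' in VALID_QUEUE_TYPES       # accepted
assert 'lazy' not in VALID_QUEUE_TYPES     # would raise ValueError
```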
@@ -0,0 +1,68 @@
"""Worker Event Dispatcher Bootstep.

``Events`` -> :class:`celery.events.EventDispatcher`.
"""
from kombu.common import ignore_errors

from celery import bootsteps

from .connection import Connection

__all__ = ('Events',)


class Events(bootsteps.StartStopStep):
    """Service used for sending monitoring events."""

    requires = (Connection,)

    def __init__(self, c,
                 task_events=True,
                 without_heartbeat=False,
                 without_gossip=False,
                 **kwargs):
        self.groups = None if task_events else ['worker']
        self.send_events = (
            task_events or
            not without_gossip or
            not without_heartbeat
        )
        self.enabled = self.send_events
        c.event_dispatcher = None
        super().__init__(c, **kwargs)

    def start(self, c):
        # flush events sent while connection was down.
        prev = self._close(c)
        dis = c.event_dispatcher = c.app.events.Dispatcher(
            c.connection_for_write(),
            hostname=c.hostname,
            enabled=self.send_events,
            groups=self.groups,
            # we currently only buffer events when the event loop is enabled
            # XXX This excludes eventlet/gevent, which should actually buffer.
            buffer_group=['task'] if c.hub else None,
            on_send_buffered=c.on_send_event_buffered if c.hub else None,
        )
        if prev:
            dis.extend_buffer(prev)
            dis.flush()

    def stop(self, c):
        pass

    def _close(self, c):
        if c.event_dispatcher:
            dispatcher = c.event_dispatcher
            # remember changes from remote control commands:
            self.groups = dispatcher.groups

            # close custom connection
            if dispatcher.connection:
                ignore_errors(c, dispatcher.connection.close)
            ignore_errors(c, dispatcher.close)
            c.event_dispatcher = None
            return dispatcher

    def shutdown(self, c):
        self._close(c)
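
The Events bootstep derives its dispatcher settings from the worker flags: task events enable all groups, while disabling gossip and heartbeat together is the only way to turn event sending off entirely. A small sketch (illustrative only, not part of the commit; `events_config` is a hypothetical helper):

```python
def events_config(task_events, without_gossip, without_heartbeat):
    groups = None if task_events else ['worker']   # None means "all groups"
    send_events = task_events or not without_gossip or not without_heartbeat
    return groups, send_events

# -E enabled: send everything.
assert events_config(True, False, False) == (None, True)
# No -E, gossip and heartbeat disabled: nothing to send.
assert events_config(False, True, True) == (['worker'], False)
```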
@@ -0,0 +1,206 @@
"""Worker <-> Worker communication Bootstep."""
from collections import defaultdict
from functools import partial
from heapq import heappush
from operator import itemgetter

from kombu import Consumer
from kombu.asynchronous.semaphore import DummyLock
from kombu.exceptions import ContentDisallowed, DecodeError

from celery import bootsteps
from celery.utils.log import get_logger
from celery.utils.objects import Bunch

from .mingle import Mingle

__all__ = ('Gossip',)

logger = get_logger(__name__)
debug, info = logger.debug, logger.info


class Gossip(bootsteps.ConsumerStep):
    """Bootstep consuming events from other workers.

    This keeps the logical clock value up to date.
    """

    label = 'Gossip'
    requires = (Mingle,)
    _cons_stamp_fields = itemgetter(
        'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
    )
    compatible_transports = {'amqp', 'redis'}

    def __init__(self, c, without_gossip=False,
                 interval=5.0, heartbeat_interval=2.0, **kwargs):
        self.enabled = not without_gossip and self.compatible_transport(c.app)
        self.app = c.app
        c.gossip = self
        self.Receiver = c.app.events.Receiver
        self.hostname = c.hostname
        self.full_hostname = '.'.join([self.hostname, str(c.pid)])
        self.on = Bunch(
            node_join=set(),
            node_leave=set(),
            node_lost=set(),
        )

        self.timer = c.timer
        if self.enabled:
            self.state = c.app.events.State(
                on_node_join=self.on_node_join,
                on_node_leave=self.on_node_leave,
                max_tasks_in_memory=1,
            )
            if c.hub:
                c._mutex = DummyLock()
            self.update_state = self.state.event
        self.interval = interval
        self.heartbeat_interval = heartbeat_interval
        self._tref = None
        self.consensus_requests = defaultdict(list)
        self.consensus_replies = {}
        self.event_handlers = {
            'worker.elect': self.on_elect,
            'worker.elect.ack': self.on_elect_ack,
        }
        self.clock = c.app.clock

        self.election_handlers = {
            'task': self.call_task
        }

        super().__init__(c, **kwargs)

    def compatible_transport(self, app):
        with app.connection_for_read() as conn:
            return conn.transport.driver_type in self.compatible_transports

    def election(self, id, topic, action=None):
        self.consensus_replies[id] = []
        self.dispatcher.send(
            'worker-elect',
            id=id, topic=topic, action=action, cver=1,
        )

    def call_task(self, task):
        try:
            self.app.signature(task).apply_async()
        except Exception as exc:  # pylint: disable=broad-except
            logger.exception('Could not call task: %r', exc)

    def on_elect(self, event):
        try:
            (id_, clock, hostname, pid,
             topic, action, _) = self._cons_stamp_fields(event)
        except KeyError as exc:
            return logger.exception('election request missing field %s', exc)
        heappush(
            self.consensus_requests[id_],
            (clock, f'{hostname}.{pid}', topic, action),
        )
        self.dispatcher.send('worker-elect-ack', id=id_)

    def start(self, c):
        super().start(c)
        self.dispatcher = c.event_dispatcher

    def on_elect_ack(self, event):
        id = event['id']
        try:
            replies = self.consensus_replies[id]
        except KeyError:
            return  # not for us
        alive_workers = set(self.state.alive_workers())
        replies.append(event['hostname'])

        if len(replies) >= len(alive_workers):
            _, leader, topic, action = self.clock.sort_heap(
                self.consensus_requests[id],
            )
            if leader == self.full_hostname:
                info('I won the election %r', id)
                try:
                    handler = self.election_handlers[topic]
                except KeyError:
                    logger.exception('Unknown election topic %r', topic)
                else:
                    handler(action)
            else:
                info('node %s elected for %r', leader, id)
            self.consensus_requests.pop(id, None)
            self.consensus_replies.pop(id, None)

    def on_node_join(self, worker):
        debug('%s joined the party', worker.hostname)
        self._call_handlers(self.on.node_join, worker)

    def on_node_leave(self, worker):
        debug('%s left', worker.hostname)
        self._call_handlers(self.on.node_leave, worker)
|
||||
|
||||
def on_node_lost(self, worker):
|
||||
info('missed heartbeat from %s', worker.hostname)
|
||||
self._call_handlers(self.on.node_lost, worker)
|
||||
|
||||
def _call_handlers(self, handlers, *args, **kwargs):
|
||||
for handler in handlers:
|
||||
try:
|
||||
handler(*args, **kwargs)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception(
|
||||
'Ignored error from handler %r: %r', handler, exc)
|
||||
|
||||
def register_timer(self):
|
||||
if self._tref is not None:
|
||||
self._tref.cancel()
|
||||
self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
|
||||
|
||||
def periodic(self):
|
||||
workers = self.state.workers
|
||||
dirty = set()
|
||||
for worker in workers.values():
|
||||
if not worker.alive:
|
||||
dirty.add(worker)
|
||||
self.on_node_lost(worker)
|
||||
for worker in dirty:
|
||||
workers.pop(worker.hostname, None)
|
||||
|
||||
def get_consumers(self, channel):
|
||||
self.register_timer()
|
||||
ev = self.Receiver(channel, routing_key='worker.#',
|
||||
queue_ttl=self.heartbeat_interval)
|
||||
return [Consumer(
|
||||
channel,
|
||||
queues=[ev.queue],
|
||||
on_message=partial(self.on_message, ev.event_from_message),
|
||||
accept=ev.accept,
|
||||
no_ack=True
|
||||
)]
|
||||
|
||||
def on_message(self, prepare, message):
|
||||
_type = message.delivery_info['routing_key']
|
||||
|
||||
# For redis when `fanout_patterns=False` (See Issue #1882)
|
||||
if _type.split('.', 1)[0] == 'task':
|
||||
return
|
||||
try:
|
||||
handler = self.event_handlers[_type]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return handler(message.payload)
|
||||
|
||||
# proto2: hostname in header; proto1: in body
|
||||
hostname = (message.headers.get('hostname') or
|
||||
message.payload['hostname'])
|
||||
if hostname != self.hostname:
|
||||
try:
|
||||
_, event = prepare(message.payload)
|
||||
self.update_state(event)
|
||||
except (DecodeError, ContentDisallowed, TypeError) as exc:
|
||||
logger.error(exc)
|
||||
else:
|
||||
self.clock.forward()
|
||||
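# A hedged sketch (not part of the diffed file): registering a callback with the
# Gossip step's `on.node_join` set from a custom consumer bootstep. The step
# name WatchNeighbors and the app/broker settings are illustrative assumptions.
from celery import Celery, bootsteps
from celery.worker.consumer.gossip import Gossip

app = Celery('proj', broker='amqp://guest@localhost//')

class WatchNeighbors(bootsteps.StartStopStep):
    requires = (Gossip,)

    def start(self, c):
        # c.gossip is set by Gossip.__init__ above.
        c.gossip.on.node_join.add(self.on_node_join)

    def on_node_join(self, worker):
        print(f'worker joined: {worker.hostname}')

app.steps['consumer'].add(WatchNeighbors)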
@@ -0,0 +1,36 @@
|
||||
"""Worker Event Heartbeat Bootstep."""
|
||||
from celery import bootsteps
|
||||
from celery.worker import heartbeat
|
||||
|
||||
from .events import Events
|
||||
|
||||
__all__ = ('Heart',)
|
||||
|
||||
|
||||
class Heart(bootsteps.StartStopStep):
|
||||
"""Bootstep sending event heartbeats.
|
||||
|
||||
This service sends a ``worker-heartbeat`` message every n seconds.
|
||||
|
||||
Note:
|
||||
Not to be confused with AMQP protocol level heartbeats.
|
||||
"""
|
||||
|
||||
requires = (Events,)
|
||||
|
||||
def __init__(self, c,
|
||||
without_heartbeat=False, heartbeat_interval=None, **kwargs):
|
||||
self.enabled = not without_heartbeat
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
c.heart = None
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def start(self, c):
|
||||
c.heart = heartbeat.Heart(
|
||||
c.timer, c.event_dispatcher, self.heartbeat_interval,
|
||||
)
|
||||
c.heart.start()
|
||||
|
||||
def stop(self, c):
|
||||
c.heart = c.heart and c.heart.stop()
|
||||
shutdown = stop
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Worker <-> Worker Sync at startup (Bootstep)."""
|
||||
from celery import bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from .events import Events
|
||||
|
||||
__all__ = ('Mingle',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
debug, info, exception = logger.debug, logger.info, logger.exception
|
||||
|
||||
|
||||
class Mingle(bootsteps.StartStopStep):
|
||||
"""Bootstep syncing state with neighbor workers.
|
||||
|
||||
At startup, or upon consumer restart, this will:
|
||||
|
||||
- Sync logical clocks.
|
||||
- Sync revoked tasks.
|
||||
|
||||
"""
|
||||
|
||||
label = 'Mingle'
|
||||
requires = (Events,)
|
||||
compatible_transports = {'amqp', 'redis', 'gcpubsub'}
|
||||
|
||||
def __init__(self, c, without_mingle=False, **kwargs):
|
||||
self.enabled = not without_mingle and self.compatible_transport(c.app)
|
||||
super().__init__(
|
||||
c, without_mingle=without_mingle, **kwargs)
|
||||
|
||||
def compatible_transport(self, app):
|
||||
with app.connection_for_read() as conn:
|
||||
return conn.transport.driver_type in self.compatible_transports
|
||||
|
||||
def start(self, c):
|
||||
self.sync(c)
|
||||
|
||||
def sync(self, c):
|
||||
info('mingle: searching for neighbors')
|
||||
replies = self.send_hello(c)
|
||||
if replies:
|
||||
info('mingle: sync with %s nodes',
|
||||
len([reply for reply, value in replies.items() if value]))
|
||||
[self.on_node_reply(c, nodename, reply)
|
||||
for nodename, reply in replies.items() if reply]
|
||||
info('mingle: sync complete')
|
||||
else:
|
||||
info('mingle: all alone')
|
||||
|
||||
def send_hello(self, c):
|
||||
inspect = c.app.control.inspect(timeout=1.0, connection=c.connection)
|
||||
our_revoked = c.controller.state.revoked
|
||||
replies = inspect.hello(c.hostname, our_revoked._data) or {}
|
||||
replies.pop(c.hostname, None) # delete my own response
|
||||
return replies
|
||||
|
||||
def on_node_reply(self, c, nodename, reply):
|
||||
debug('mingle: processing reply from %s', nodename)
|
||||
try:
|
||||
self.sync_with_node(c, **reply)
|
||||
except MemoryError:
|
||||
raise
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
exception('mingle: sync with %s failed: %r', nodename, exc)
|
||||
|
||||
def sync_with_node(self, c, clock=None, revoked=None, **kwargs):
|
||||
self.on_clock_event(c, clock)
|
||||
self.on_revoked_received(c, revoked)
|
||||
|
||||
def on_clock_event(self, c, clock):
|
||||
c.app.clock.adjust(clock) if clock else c.app.clock.forward()
|
||||
|
||||
def on_revoked_received(self, c, revoked):
|
||||
if revoked:
|
||||
c.controller.state.revoked.update(revoked)
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Worker Task Consumer Bootstep."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from kombu.common import QoS, ignore_errors
|
||||
|
||||
from celery import bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.quorum_queues import detect_quorum_queues
|
||||
|
||||
from .mingle import Mingle
|
||||
|
||||
__all__ = ('Tasks',)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
debug = logger.debug
|
||||
|
||||
|
||||
class Tasks(bootsteps.StartStopStep):
|
||||
"""Bootstep starting the task message consumer."""
|
||||
|
||||
requires = (Mingle,)
|
||||
|
||||
def __init__(self, c, **kwargs):
|
||||
c.task_consumer = c.qos = None
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def start(self, c):
|
||||
"""Start task consumer."""
|
||||
c.update_strategies()
|
||||
|
||||
qos_global = self.qos_global(c)
|
||||
|
||||
# set initial prefetch count
|
||||
c.connection.default_channel.basic_qos(
|
||||
0, c.initial_prefetch_count, qos_global,
|
||||
)
|
||||
|
||||
c.task_consumer = c.app.amqp.TaskConsumer(
|
||||
c.connection, on_decode_error=c.on_decode_error,
|
||||
)
|
||||
|
||||
def set_prefetch_count(prefetch_count):
|
||||
return c.task_consumer.qos(
|
||||
prefetch_count=prefetch_count,
|
||||
apply_global=qos_global,
|
||||
)
|
||||
c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
|
||||
|
||||
def stop(self, c):
|
||||
"""Stop task consumer."""
|
||||
if c.task_consumer:
|
||||
debug('Canceling task consumer...')
|
||||
ignore_errors(c, c.task_consumer.cancel)
|
||||
|
||||
def shutdown(self, c):
|
||||
"""Shutdown task consumer."""
|
||||
if c.task_consumer:
|
||||
self.stop(c)
|
||||
debug('Closing consumer channel...')
|
||||
ignore_errors(c, c.task_consumer.close)
|
||||
c.task_consumer = None
|
||||
|
||||
def info(self, c):
|
||||
"""Return task consumer info."""
|
||||
return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
|
||||
|
||||
def qos_global(self, c) -> bool:
|
||||
"""Determine if global QoS should be applied.
|
||||
|
||||
Additional information:
|
||||
https://www.rabbitmq.com/docs/consumer-prefetch
|
||||
https://www.rabbitmq.com/docs/quorum-queues#global-qos
|
||||
"""
|
||||
# - RabbitMQ 3.3 completely redefines how basic_qos works...
|
||||
# This will detect if the new qos semantics is in effect,
|
||||
# and if so make sure the 'apply_global' flag is set on qos updates.
|
||||
qos_global = not c.connection.qos_semantics_matches_spec
|
||||
|
||||
if c.app.conf.worker_detect_quorum_queues:
|
||||
using_quorum_queues, qname = detect_quorum_queues(c.app, c.connection.transport.driver_type)
|
||||
|
||||
if using_quorum_queues:
|
||||
qos_global = False
|
||||
logger.info("Global QoS is disabled. Prefetch count in now static.")
|
||||
|
||||
return qos_global
|
||||
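# A hedged configuration sketch (not part of the diffed file) showing the two
# pieces qos_global() looks at: the worker_detect_quorum_queues setting and a
# queue declared with the quorum queue type. All names below are illustrative.
from celery import Celery
from kombu import Queue

app = Celery('proj', broker='amqp://guest@localhost//')
app.conf.worker_detect_quorum_queues = True  # enable detection explicitly
app.conf.task_queues = (
    Queue('orders', queue_arguments={'x-queue-type': 'quorum'}),
)
# With a quorum queue detected, the worker keeps the prefetch count static
# instead of applying a global QoS, as qos_global() above decides.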
@@ -0,0 +1,625 @@
|
||||
"""Worker remote control command implementations."""
|
||||
import io
|
||||
import tempfile
|
||||
from collections import UserDict, defaultdict, namedtuple
|
||||
|
||||
from billiard.common import TERM_SIGNAME
|
||||
from kombu.utils.encoding import safe_repr
|
||||
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.platforms import EX_OK
|
||||
from celery.platforms import signals as _signals
|
||||
from celery.utils.functional import maybe_list
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.serialization import jsonify, strtobool
|
||||
from celery.utils.time import rate
|
||||
|
||||
from . import state as worker_state
|
||||
from .request import Request
|
||||
|
||||
__all__ = ('Panel',)
|
||||
|
||||
DEFAULT_TASK_INFO_ITEMS = ('exchange', 'routing_key', 'rate_limit')
|
||||
logger = get_logger(__name__)
|
||||
|
||||
controller_info_t = namedtuple('controller_info_t', [
|
||||
'alias', 'type', 'visible', 'default_timeout',
|
||||
'help', 'signature', 'args', 'variadic',
|
||||
])
|
||||
|
||||
|
||||
def ok(value):
|
||||
return {'ok': value}
|
||||
|
||||
|
||||
def nok(value):
|
||||
return {'error': value}
|
||||
|
||||
|
||||
class Panel(UserDict):
|
||||
"""Global registry of remote control commands."""
|
||||
|
||||
data = {} # global dict.
|
||||
meta = {} # -"-
|
||||
|
||||
@classmethod
|
||||
def register(cls, *args, **kwargs):
|
||||
if args:
|
||||
return cls._register(**kwargs)(*args)
|
||||
return cls._register(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _register(cls, name=None, alias=None, type='control',
|
||||
visible=True, default_timeout=1.0, help=None,
|
||||
signature=None, args=None, variadic=None):
|
||||
|
||||
def _inner(fun):
|
||||
control_name = name or fun.__name__
|
||||
_help = help or (fun.__doc__ or '').strip().split('\n')[0]
|
||||
cls.data[control_name] = fun
|
||||
cls.meta[control_name] = controller_info_t(
|
||||
alias, type, visible, default_timeout,
|
||||
_help, signature, args, variadic)
|
||||
if alias:
|
||||
cls.data[alias] = fun
|
||||
return fun
|
||||
return _inner
|
||||
|
||||
|
||||
def control_command(**kwargs):
|
||||
return Panel.register(type='control', **kwargs)
|
||||
|
||||
|
||||
def inspect_command(**kwargs):
|
||||
return Panel.register(type='inspect', **kwargs)
|
||||
|
||||
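# A hedged sketch (not part of the diffed file): the decorators above are what
# user code imports to register its own remote control commands. This mirrors
# the documented pattern; the command name and body are illustrative.
from celery.worker.control import control_command

@control_command(
    args=[('n', int)],
    signature='[N=1]',
)
def increase_prefetch_count(state, n=1):
    state.consumer.qos.increment_eventually(n)
    return {'ok': 'prefetch count incremented'}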
# -- App
|
||||
|
||||
|
||||
@inspect_command()
|
||||
def report(state):
|
||||
"""Information about Celery installation for bug reports."""
|
||||
return ok(state.app.bugreport())
|
||||
|
||||
|
||||
@inspect_command(
|
||||
alias='dump_conf', # XXX < backwards compatible
|
||||
signature='[include_defaults=False]',
|
||||
args=[('with_defaults', strtobool)],
|
||||
)
|
||||
def conf(state, with_defaults=False, **kwargs):
|
||||
"""List configuration."""
|
||||
return jsonify(state.app.conf.table(with_defaults=with_defaults),
|
||||
keyfilter=_wanted_config_key,
|
||||
unknown_type_filter=safe_repr)
|
||||
|
||||
|
||||
def _wanted_config_key(key):
|
||||
return isinstance(key, str) and not key.startswith('__')
|
||||
|
||||
|
||||
# -- Task
|
||||
|
||||
@inspect_command(
|
||||
variadic='ids',
|
||||
signature='[id1 [id2 [... [idN]]]]',
|
||||
)
|
||||
def query_task(state, ids, **kwargs):
|
||||
"""Query for task information by id."""
|
||||
return {
|
||||
req.id: (_state_of_task(req), req.info())
|
||||
for req in _find_requests_by_id(maybe_list(ids))
|
||||
}
|
||||
|
||||
|
||||
def _find_requests_by_id(ids,
|
||||
get_request=worker_state.requests.__getitem__):
|
||||
for task_id in ids:
|
||||
try:
|
||||
yield get_request(task_id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def _state_of_task(request,
|
||||
is_active=worker_state.active_requests.__contains__,
|
||||
is_reserved=worker_state.reserved_requests.__contains__):
|
||||
if is_active(request):
|
||||
return 'active'
|
||||
elif is_reserved(request):
|
||||
return 'reserved'
|
||||
return 'ready'
|
||||
|
||||
|
||||
@control_command(
|
||||
variadic='task_id',
|
||||
signature='[id1 [id2 [... [idN]]]]',
|
||||
)
|
||||
def revoke(state, task_id, terminate=False, signal=None, **kwargs):
|
||||
"""Revoke task by task id (or list of ids).
|
||||
|
||||
Keyword Arguments:
|
||||
terminate (bool): Also terminate the process if the task is active.
|
||||
signal (str): Name of signal to use for terminate (e.g., ``KILL``).
|
||||
"""
|
||||
# pylint: disable=redefined-outer-name
|
||||
# XXX Note that this redefines `terminate`:
|
||||
# Outside of this scope that is a function.
|
||||
# supports list argument since 3.1
|
||||
task_ids, task_id = set(maybe_list(task_id) or []), None
|
||||
task_ids = _revoke(state, task_ids, terminate, signal, **kwargs)
|
||||
if isinstance(task_ids, dict) and 'ok' in task_ids:
|
||||
return task_ids
|
||||
return ok(f'tasks {task_ids} flagged as revoked')
|
||||
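# A hedged client-side sketch (not part of the diffed file): this command is
# normally reached through the documented Control API rather than called
# directly. The task id below is a placeholder.
from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')
app.control.revoke('d9078da5-9915-40a0-bfa1-392c7bde42ed')
app.control.revoke('d9078da5-9915-40a0-bfa1-392c7bde42ed',
                   terminate=True, signal='SIGKILL')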
|
||||
|
||||
@control_command(
|
||||
variadic='headers',
|
||||
signature='[key1=value1 [key2=value2 [... [keyN=valueN]]]]',
|
||||
)
|
||||
def revoke_by_stamped_headers(state, headers, terminate=False, signal=None, **kwargs):
|
||||
"""Revoke task by header (or list of headers).
|
||||
|
||||
Keyword Arguments:
|
||||
        headers (dict): Dictionary that contains stamping scheme names as keys and stamps as values.
|
||||
If headers is a list, it will be converted to a dictionary.
|
||||
terminate (bool): Also terminate the process if the task is active.
|
||||
signal (str): Name of signal to use for terminate (e.g., ``KILL``).
|
||||
Sample headers input:
|
||||
{'mtask_id': [id1, id2, id3]}
|
||||
"""
|
||||
# pylint: disable=redefined-outer-name
|
||||
# XXX Note that this redefines `terminate`:
|
||||
# Outside of this scope that is a function.
|
||||
# supports list argument since 3.1
|
||||
signum = _signals.signum(signal or TERM_SIGNAME)
|
||||
|
||||
if isinstance(headers, list):
|
||||
headers = {h.split('=')[0]: h.split('=')[1] for h in headers}
|
||||
|
||||
for header, stamps in headers.items():
|
||||
updated_stamps = maybe_list(worker_state.revoked_stamps.get(header) or []) + list(maybe_list(stamps))
|
||||
worker_state.revoked_stamps[header] = updated_stamps
|
||||
|
||||
if not terminate:
|
||||
return ok(f'headers {headers} flagged as revoked, but not terminated')
|
||||
|
||||
active_requests = list(worker_state.active_requests)
|
||||
|
||||
terminated_scheme_to_stamps_mapping = defaultdict(set)
|
||||
|
||||
    # Terminate all running tasks with matching headers
|
||||
# Go through all active requests, and check if one of the
|
||||
# requests has a stamped header that matches the given headers to revoke
|
||||
|
||||
for req in active_requests:
|
||||
# Check stamps exist
|
||||
if hasattr(req, "stamps") and req.stamps:
|
||||
# if so, check if any stamps match a revoked stamp
|
||||
for expected_header_key, expected_header_value in headers.items():
|
||||
if expected_header_key in req.stamps:
|
||||
expected_header_value = maybe_list(expected_header_value)
|
||||
actual_header = maybe_list(req.stamps[expected_header_key])
|
||||
matching_stamps_for_request = set(actual_header) & set(expected_header_value)
|
||||
                    # Check for any possible match, regardless of whether the stamps are a sequence or not
|
||||
if matching_stamps_for_request:
|
||||
terminated_scheme_to_stamps_mapping[expected_header_key].update(matching_stamps_for_request)
|
||||
req.terminate(state.consumer.pool, signal=signum)
|
||||
|
||||
if not terminated_scheme_to_stamps_mapping:
|
||||
return ok(f'headers {headers} were not terminated')
|
||||
return ok(f'headers {terminated_scheme_to_stamps_mapping} revoked')
|
||||
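# A hedged client-side sketch (not part of the diffed file), assuming the
# Control API exposes revoke_by_stamped_headers (Celery 5.3+). The header
# name and values are placeholders.
from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')
app.control.revoke_by_stamped_headers({'mtask_id': ['id1', 'id2']}, terminate=True)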
|
||||
|
||||
def _revoke(state, task_ids, terminate=False, signal=None, **kwargs):
|
||||
size = len(task_ids)
|
||||
terminated = set()
|
||||
|
||||
worker_state.revoked.update(task_ids)
|
||||
if terminate:
|
||||
signum = _signals.signum(signal or TERM_SIGNAME)
|
||||
for request in _find_requests_by_id(task_ids):
|
||||
if request.id not in terminated:
|
||||
terminated.add(request.id)
|
||||
logger.info('Terminating %s (%s)', request.id, signum)
|
||||
request.terminate(state.consumer.pool, signal=signum)
|
||||
if len(terminated) >= size:
|
||||
break
|
||||
|
||||
if not terminated:
|
||||
return ok('terminate: tasks unknown')
|
||||
return ok('terminate: {}'.format(', '.join(terminated)))
|
||||
|
||||
idstr = ', '.join(task_ids)
|
||||
logger.info('Tasks flagged as revoked: %s', idstr)
|
||||
return task_ids
|
||||
|
||||
|
||||
@control_command(
|
||||
variadic='task_id',
|
||||
args=[('signal', str)],
|
||||
signature='<signal> [id1 [id2 [... [idN]]]]'
|
||||
)
|
||||
def terminate(state, signal, task_id, **kwargs):
|
||||
"""Terminate task by task id (or list of ids)."""
|
||||
return revoke(state, task_id, terminate=True, signal=signal)
|
||||
|
||||
|
||||
@control_command(
|
||||
args=[('task_name', str), ('rate_limit', str)],
|
||||
signature='<task_name> <rate_limit (e.g., 5/s | 5/m | 5/h)>',
|
||||
)
|
||||
def rate_limit(state, task_name, rate_limit, **kwargs):
|
||||
"""Tell worker(s) to modify the rate limit for a task by type.
|
||||
|
||||
See Also:
|
||||
:attr:`celery.app.task.Task.rate_limit`.
|
||||
|
||||
Arguments:
|
||||
task_name (str): Type of task to set rate limit for.
|
||||
rate_limit (int, str): New rate limit.
|
||||
"""
|
||||
# pylint: disable=redefined-outer-name
|
||||
    # XXX Note that this redefines `rate_limit`:
|
||||
# Outside of this scope that is a function.
|
||||
try:
|
||||
rate(rate_limit)
|
||||
except ValueError as exc:
|
||||
return nok(f'Invalid rate limit string: {exc!r}')
|
||||
|
||||
try:
|
||||
state.app.tasks[task_name].rate_limit = rate_limit
|
||||
except KeyError:
|
||||
logger.error('Rate limit attempt for unknown task %s',
|
||||
task_name, exc_info=True)
|
||||
return nok('unknown task')
|
||||
|
||||
state.consumer.reset_rate_limits()
|
||||
|
||||
if not rate_limit:
|
||||
logger.info('Rate limits disabled for tasks of type %s', task_name)
|
||||
return ok('rate limit disabled successfully')
|
||||
|
||||
logger.info('New rate limit for tasks of type %s: %s.',
|
||||
task_name, rate_limit)
|
||||
return ok('new rate limit set successfully')
|
||||
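# A hedged client-side sketch (not part of the diffed file): changing a task's
# rate limit at runtime through the documented Control API. The task name is a
# placeholder.
from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')
app.control.rate_limit('proj.tasks.add', '10/m', reply=True)  # at most 10 per minute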
|
||||
|
||||
@control_command(
|
||||
args=[('task_name', str), ('soft', float), ('hard', float)],
|
||||
signature='<task_name> <soft_secs> [hard_secs]',
|
||||
)
|
||||
def time_limit(state, task_name=None, hard=None, soft=None, **kwargs):
|
||||
"""Tell worker(s) to modify the time limit for task by type.
|
||||
|
||||
Arguments:
|
||||
task_name (str): Name of task to change.
|
||||
hard (float): Hard time limit.
|
||||
soft (float): Soft time limit.
|
||||
"""
|
||||
try:
|
||||
task = state.app.tasks[task_name]
|
||||
except KeyError:
|
||||
logger.error('Change time limit attempt for unknown task %s',
|
||||
task_name, exc_info=True)
|
||||
return nok('unknown task')
|
||||
|
||||
task.soft_time_limit = soft
|
||||
task.time_limit = hard
|
||||
|
||||
logger.info('New time limits for tasks of type %s: soft=%s hard=%s',
|
||||
task_name, soft, hard)
|
||||
return ok('time limits set successfully')
|
||||
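# A hedged client-side sketch (not part of the diffed file): raising the soft
# and hard time limits for one task type on all workers. The task name is a
# placeholder.
from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')
app.control.time_limit('proj.tasks.crawl', soft=60, hard=120, reply=True)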
|
||||
|
||||
# -- Events
|
||||
|
||||
|
||||
@inspect_command()
|
||||
def clock(state, **kwargs):
|
||||
"""Get current logical clock value."""
|
||||
return {'clock': state.app.clock.value}
|
||||
|
||||
|
||||
@control_command()
|
||||
def election(state, id, topic, action=None, **kwargs):
|
||||
"""Hold election.
|
||||
|
||||
Arguments:
|
||||
id (str): Unique election id.
|
||||
topic (str): Election topic.
|
||||
action (str): Action to take for elected actor.
|
||||
"""
|
||||
if state.consumer.gossip:
|
||||
state.consumer.gossip.election(id, topic, action)
|
||||
|
||||
|
||||
@control_command()
|
||||
def enable_events(state):
|
||||
"""Tell worker(s) to send task-related events."""
|
||||
dispatcher = state.consumer.event_dispatcher
|
||||
if dispatcher.groups and 'task' not in dispatcher.groups:
|
||||
dispatcher.groups.add('task')
|
||||
logger.info('Events of group {task} enabled by remote.')
|
||||
return ok('task events enabled')
|
||||
return ok('task events already enabled')
|
||||
|
||||
|
||||
@control_command()
|
||||
def disable_events(state):
|
||||
"""Tell worker(s) to stop sending task-related events."""
|
||||
dispatcher = state.consumer.event_dispatcher
|
||||
if 'task' in dispatcher.groups:
|
||||
dispatcher.groups.discard('task')
|
||||
logger.info('Events of group {task} disabled by remote.')
|
||||
return ok('task events disabled')
|
||||
return ok('task events already disabled')
|
||||
|
||||
|
||||
@control_command()
|
||||
def heartbeat(state):
|
||||
"""Tell worker(s) to send event heartbeat immediately."""
|
||||
logger.debug('Heartbeat requested by remote.')
|
||||
dispatcher = state.consumer.event_dispatcher
|
||||
dispatcher.send('worker-heartbeat', freq=5, **worker_state.SOFTWARE_INFO)
|
||||
|
||||
|
||||
# -- Worker
|
||||
|
||||
@inspect_command(visible=False)
|
||||
def hello(state, from_node, revoked=None, **kwargs):
|
||||
"""Request mingle sync-data."""
|
||||
# pylint: disable=redefined-outer-name
|
||||
# XXX Note that this redefines `revoked`:
|
||||
# Outside of this scope that is a function.
|
||||
if from_node != state.hostname:
|
||||
logger.info('sync with %s', from_node)
|
||||
if revoked:
|
||||
worker_state.revoked.update(revoked)
|
||||
# Do not send expired items to the other worker.
|
||||
worker_state.revoked.purge()
|
||||
return {
|
||||
'revoked': worker_state.revoked._data,
|
||||
'clock': state.app.clock.forward(),
|
||||
}
|
||||
|
||||
|
||||
@inspect_command(default_timeout=0.2)
|
||||
def ping(state, **kwargs):
|
||||
"""Ping worker(s)."""
|
||||
return ok('pong')
|
||||
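# A hedged client-side sketch (not part of the diffed file): the inspect
# commands defined in this module are reached through app.control; ping() and
# stats() are broadcast to all workers and return one reply per node.
from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')
print(app.control.ping(timeout=0.5))          # e.g. [{'worker1@host': {'ok': 'pong'}}]
insp = app.control.inspect(timeout=1.0)
print(insp.stats())                           # per-worker statistics dict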
|
||||
|
||||
@inspect_command()
|
||||
def stats(state, **kwargs):
|
||||
"""Request worker statistics/information."""
|
||||
return state.consumer.controller.stats()
|
||||
|
||||
|
||||
@inspect_command(alias='dump_schedule')
|
||||
def scheduled(state, **kwargs):
|
||||
"""List of currently scheduled ETA/countdown tasks."""
|
||||
return list(_iter_schedule_requests(state.consumer.timer))
|
||||
|
||||
|
||||
def _iter_schedule_requests(timer):
|
||||
for waiting in timer.schedule.queue:
|
||||
try:
|
||||
arg0 = waiting.entry.args[0]
|
||||
except (IndexError, TypeError):
|
||||
continue
|
||||
else:
|
||||
if isinstance(arg0, Request):
|
||||
yield {
|
||||
'eta': arg0.eta.isoformat() if arg0.eta else None,
|
||||
'priority': waiting.priority,
|
||||
'request': arg0.info(),
|
||||
}
|
||||
|
||||
|
||||
@inspect_command(alias='dump_reserved')
|
||||
def reserved(state, **kwargs):
|
||||
"""List of currently reserved tasks, not including scheduled/active."""
|
||||
reserved_tasks = (
|
||||
state.tset(worker_state.reserved_requests) -
|
||||
state.tset(worker_state.active_requests)
|
||||
)
|
||||
if not reserved_tasks:
|
||||
return []
|
||||
return [request.info() for request in reserved_tasks]
|
||||
|
||||
|
||||
@inspect_command(alias='dump_active')
|
||||
def active(state, safe=False, **kwargs):
|
||||
"""List of tasks currently being executed."""
|
||||
return [request.info(safe=safe)
|
||||
for request in state.tset(worker_state.active_requests)]
|
||||
|
||||
|
||||
@inspect_command(alias='dump_revoked')
|
||||
def revoked(state, **kwargs):
|
||||
"""List of revoked task-ids."""
|
||||
return list(worker_state.revoked)
|
||||
|
||||
|
||||
@inspect_command(
|
||||
alias='dump_tasks',
|
||||
variadic='taskinfoitems',
|
||||
signature='[attr1 [attr2 [... [attrN]]]]',
|
||||
)
|
||||
def registered(state, taskinfoitems=None, builtins=False, **kwargs):
|
||||
"""List of registered tasks.
|
||||
|
||||
Arguments:
|
||||
taskinfoitems (Sequence[str]): List of task attributes to include.
|
||||
Defaults to ``exchange,routing_key,rate_limit``.
|
||||
builtins (bool): Also include built-in tasks.
|
||||
"""
|
||||
reg = state.app.tasks
|
||||
taskinfoitems = taskinfoitems or DEFAULT_TASK_INFO_ITEMS
|
||||
|
||||
tasks = reg if builtins else (
|
||||
task for task in reg if not task.startswith('celery.'))
|
||||
|
||||
def _extract_info(task):
|
||||
fields = {
|
||||
field: str(getattr(task, field, None)) for field in taskinfoitems
|
||||
if getattr(task, field, None) is not None
|
||||
}
|
||||
if fields:
|
||||
info = ['='.join(f) for f in fields.items()]
|
||||
return '{} [{}]'.format(task.name, ' '.join(info))
|
||||
return task.name
|
||||
|
||||
return [_extract_info(reg[task]) for task in sorted(tasks)]
|
||||
|
||||
|
||||
# -- Debugging
|
||||
|
||||
@inspect_command(
|
||||
default_timeout=60.0,
|
||||
args=[('type', str), ('num', int), ('max_depth', int)],
|
||||
signature='[object_type=Request] [num=200 [max_depth=10]]',
|
||||
)
|
||||
def objgraph(state, num=200, max_depth=10, type='Request'): # pragma: no cover
|
||||
"""Create graph of uncollected objects (memory-leak debugging).
|
||||
|
||||
Arguments:
|
||||
num (int): Max number of objects to graph.
|
||||
max_depth (int): Traverse at most n levels deep.
|
||||
type (str): Name of object to graph. Default is ``"Request"``.
|
||||
"""
|
||||
try:
|
||||
import objgraph as _objgraph
|
||||
except ImportError:
|
||||
raise ImportError('Requires the objgraph library')
|
||||
logger.info('Dumping graph for type %r', type)
|
||||
with tempfile.NamedTemporaryFile(prefix='cobjg',
|
||||
suffix='.png', delete=False) as fh:
|
||||
objects = _objgraph.by_type(type)[:num]
|
||||
_objgraph.show_backrefs(
|
||||
objects,
|
||||
max_depth=max_depth, highlight=lambda v: v in objects,
|
||||
filename=fh.name,
|
||||
)
|
||||
return {'filename': fh.name}
|
||||
|
||||
|
||||
@inspect_command()
|
||||
def memsample(state, **kwargs):
|
||||
"""Sample current RSS memory usage."""
|
||||
from celery.utils.debug import sample_mem
|
||||
return sample_mem()
|
||||
|
||||
|
||||
@inspect_command(
|
||||
args=[('samples', int)],
|
||||
signature='[n_samples=10]',
|
||||
)
|
||||
def memdump(state, samples=10, **kwargs): # pragma: no cover
|
||||
"""Dump statistics of previous memsample requests."""
|
||||
from celery.utils import debug
|
||||
out = io.StringIO()
|
||||
debug.memdump(file=out)
|
||||
return out.getvalue()
|
||||
|
||||
# -- Pool
|
||||
|
||||
|
||||
@control_command(
|
||||
args=[('n', int)],
|
||||
signature='[N=1]',
|
||||
)
|
||||
def pool_grow(state, n=1, **kwargs):
|
||||
"""Grow pool by n processes/threads."""
|
||||
if state.consumer.controller.autoscaler:
|
||||
return nok("pool_grow is not supported with autoscale. Adjust autoscale range instead.")
|
||||
else:
|
||||
state.consumer.pool.grow(n)
|
||||
state.consumer._update_prefetch_count(n)
|
||||
return ok('pool will grow')
|
||||
|
||||
|
||||
@control_command(
|
||||
args=[('n', int)],
|
||||
signature='[N=1]',
|
||||
)
|
||||
def pool_shrink(state, n=1, **kwargs):
|
||||
"""Shrink pool by n processes/threads."""
|
||||
if state.consumer.controller.autoscaler:
|
||||
return nok("pool_shrink is not supported with autoscale. Adjust autoscale range instead.")
|
||||
else:
|
||||
state.consumer.pool.shrink(n)
|
||||
state.consumer._update_prefetch_count(-n)
|
||||
return ok('pool will shrink')
|
||||
|
||||
|
||||
@control_command()
|
||||
def pool_restart(state, modules=None, reload=False, reloader=None, **kwargs):
|
||||
"""Restart execution pool."""
|
||||
if state.app.conf.worker_pool_restarts:
|
||||
state.consumer.controller.reload(modules, reload, reloader=reloader)
|
||||
return ok('reload started')
|
||||
else:
|
||||
raise ValueError('Pool restarts not enabled')
|
||||
|
||||
|
||||
@control_command(
|
||||
args=[('max', int), ('min', int)],
|
||||
signature='[max [min]]',
|
||||
)
|
||||
def autoscale(state, max=None, min=None):
|
||||
"""Modify autoscale settings."""
|
||||
autoscaler = state.consumer.controller.autoscaler
|
||||
if autoscaler:
|
||||
max_, min_ = autoscaler.update(max, min)
|
||||
return ok(f'autoscale now max={max_} min={min_}')
|
||||
raise ValueError('Autoscale not enabled')
|
||||
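# A hedged client-side sketch (not part of the diffed file): resizing the pool
# remotely. pool_grow/pool_shrink apply only when autoscale is off, matching
# the checks above; autoscale() adjusts the range when it is on.
from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')
app.control.pool_grow(2)        # add two pool processes/threads
app.control.pool_shrink(1)      # remove one
app.control.autoscale(10, 3)    # only if the worker runs with --autoscale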
|
||||
|
||||
@control_command()
|
||||
def shutdown(state, msg='Got shutdown from remote', **kwargs):
|
||||
"""Shutdown worker(s)."""
|
||||
logger.warning(msg)
|
||||
raise WorkerShutdown(EX_OK)
|
||||
|
||||
|
||||
# -- Queues
|
||||
|
||||
@control_command(
|
||||
args=[
|
||||
('queue', str),
|
||||
('exchange', str),
|
||||
('exchange_type', str),
|
||||
('routing_key', str),
|
||||
],
|
||||
signature='<queue> [exchange [type [routing_key]]]',
|
||||
)
|
||||
def add_consumer(state, queue, exchange=None, exchange_type=None,
|
||||
routing_key=None, **options):
|
||||
"""Tell worker(s) to consume from task queue by name."""
|
||||
state.consumer.call_soon(
|
||||
state.consumer.add_task_queue,
|
||||
queue, exchange, exchange_type or 'direct', routing_key, **options)
|
||||
return ok(f'add consumer {queue}')
|
||||
|
||||
|
||||
@control_command(
|
||||
args=[('queue', str)],
|
||||
signature='<queue>',
|
||||
)
|
||||
def cancel_consumer(state, queue, **_):
|
||||
"""Tell worker(s) to stop consuming from task queue by name."""
|
||||
state.consumer.call_soon(
|
||||
state.consumer.cancel_task_queue, queue,
|
||||
)
|
||||
return ok(f'no longer consuming from {queue}')
|
||||
|
||||
|
||||
@inspect_command()
|
||||
def active_queues(state):
|
||||
"""List the task queues a worker is currently consuming from."""
|
||||
if state.consumer.task_consumer:
|
||||
return [dict(queue.as_dict(recurse=True))
|
||||
for queue in state.consumer.task_consumer.queues]
|
||||
return []
|
||||
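# A hedged client-side sketch (not part of the diffed file): managing the
# queues a running worker consumes from, then checking the result with the
# active_queues inspect command defined above. The queue name is a placeholder.
from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')
app.control.add_consumer('priority', reply=True)
print(app.control.inspect().active_queues())
app.control.cancel_consumer('priority', reply=True)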
@@ -0,0 +1,61 @@
|
||||
"""Heartbeat service.
|
||||
|
||||
This is the internal thread responsible for sending heartbeat events
|
||||
at regular intervals (may not be an actual thread).
|
||||
"""
|
||||
from celery.signals import heartbeat_sent
|
||||
from celery.utils.sysinfo import load_average
|
||||
|
||||
from .state import SOFTWARE_INFO, active_requests, all_total_count
|
||||
|
||||
__all__ = ('Heart',)
|
||||
|
||||
|
||||
class Heart:
|
||||
"""Timer sending heartbeats at regular intervals.
|
||||
|
||||
Arguments:
|
||||
timer (kombu.asynchronous.timer.Timer): Timer to use.
|
||||
eventer (celery.events.EventDispatcher): Event dispatcher
|
||||
to use.
|
||||
interval (float): Time in seconds between sending
|
||||
heartbeats. Default is 2 seconds.
|
||||
"""
|
||||
|
||||
def __init__(self, timer, eventer, interval=None):
|
||||
self.timer = timer
|
||||
self.eventer = eventer
|
||||
self.interval = float(interval or 2.0)
|
||||
self.tref = None
|
||||
|
||||
# Make event dispatcher start/stop us when enabled/disabled.
|
||||
self.eventer.on_enabled.add(self.start)
|
||||
self.eventer.on_disabled.add(self.stop)
|
||||
|
||||
# Only send heartbeat_sent signal if it has receivers.
|
||||
self._send_sent_signal = (
|
||||
heartbeat_sent.send if heartbeat_sent.receivers else None)
|
||||
|
||||
def _send(self, event, retry=True):
|
||||
if self._send_sent_signal is not None:
|
||||
self._send_sent_signal(sender=self)
|
||||
return self.eventer.send(event, freq=self.interval,
|
||||
active=len(active_requests),
|
||||
processed=all_total_count[0],
|
||||
loadavg=load_average(),
|
||||
retry=retry,
|
||||
**SOFTWARE_INFO)
|
||||
|
||||
def start(self):
|
||||
if self.eventer.enabled:
|
||||
self._send('worker-online')
|
||||
self.tref = self.timer.call_repeatedly(
|
||||
self.interval, self._send, ('worker-heartbeat',),
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
if self.tref is not None:
|
||||
self.timer.cancel(self.tref)
|
||||
self.tref = None
|
||||
if self.eventer.enabled:
|
||||
self._send('worker-offline', retry=False)
|
||||
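# A hedged sketch (not part of the diffed file): the heartbeat_sent signal used
# above only fires when it has receivers; connecting one looks like this.
from celery.signals import heartbeat_sent

@heartbeat_sent.connect
def log_heartbeat(sender=None, **kwargs):
    # sender is the Heart instance that sent the worker-heartbeat event.
    print('heartbeat sent by', sender)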
ETB-API/venv/lib/python3.12/site-packages/celery/worker/loops.py (new file, 143 lines)
@@ -0,0 +1,143 @@
|
||||
"""The consumers highly-optimized inner loop."""
|
||||
import errno
|
||||
import socket
|
||||
|
||||
from celery import bootsteps
|
||||
from celery.exceptions import WorkerLostError
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from . import state
|
||||
|
||||
__all__ = ('asynloop', 'synloop')
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
# We cache globals and attribute lookups, so disable this warning.
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _quick_drain(connection, timeout=0.1):
|
||||
try:
|
||||
connection.drain_events(timeout=timeout)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
exc_errno = getattr(exc, 'errno', None)
|
||||
if exc_errno is not None and exc_errno != errno.EAGAIN:
|
||||
raise
|
||||
|
||||
|
||||
def _enable_amqheartbeats(timer, connection, rate=2.0):
|
||||
heartbeat_error = [None]
|
||||
|
||||
if not connection:
|
||||
return heartbeat_error
|
||||
|
||||
heartbeat = connection.get_heartbeat_interval() # negotiated
|
||||
if not (heartbeat and connection.supports_heartbeats):
|
||||
return heartbeat_error
|
||||
|
||||
def tick(rate):
|
||||
try:
|
||||
connection.heartbeat_check(rate)
|
||||
except Exception as e:
|
||||
            # heartbeat_error is passed by reference and can be updated in place;
            # do not append here, the list should stay at a fixed size of 1.
|
||||
heartbeat_error[0] = e
|
||||
|
||||
timer.call_repeatedly(heartbeat / rate, tick, (rate,))
|
||||
return heartbeat_error
|
||||
|
||||
|
||||
def asynloop(obj, connection, consumer, blueprint, hub, qos,
|
||||
heartbeat, clock, hbrate=2.0):
|
||||
"""Non-blocking event loop."""
|
||||
RUN = bootsteps.RUN
|
||||
update_qos = qos.update
|
||||
errors = connection.connection_errors
|
||||
|
||||
on_task_received = obj.create_task_handler()
|
||||
|
||||
heartbeat_error = _enable_amqheartbeats(hub.timer, connection, rate=hbrate)
|
||||
|
||||
consumer.on_message = on_task_received
|
||||
obj.controller.register_with_event_loop(hub)
|
||||
obj.register_with_event_loop(hub)
|
||||
consumer.consume()
|
||||
obj.on_ready()
|
||||
|
||||
# did_start_ok will verify that pool processes were able to start,
|
||||
# but this will only work the first time we start, as
|
||||
# maxtasksperchild will mess up metrics.
|
||||
if not obj.restart_count and not obj.pool.did_start_ok():
|
||||
raise WorkerLostError('Could not start worker processes')
|
||||
|
||||
# consumer.consume() may have prefetched up to our
|
||||
# limit - drain an event so we're in a clean state
|
||||
# prior to starting our event loop.
|
||||
if connection.transport.driver_type == 'amqp':
|
||||
hub.call_soon(_quick_drain, connection)
|
||||
|
||||
# FIXME: Use loop.run_forever
|
||||
# Tried and works, but no time to test properly before release.
|
||||
hub.propagate_errors = errors
|
||||
loop = hub.create_loop()
|
||||
|
||||
try:
|
||||
while blueprint.state == RUN and obj.connection:
|
||||
state.maybe_shutdown()
|
||||
if heartbeat_error[0] is not None:
|
||||
raise heartbeat_error[0]
|
||||
|
||||
# We only update QoS when there's no more messages to read.
|
||||
# This groups together qos calls, and makes sure that remote
|
||||
# control commands will be prioritized over task messages.
|
||||
if qos.prev != qos.value:
|
||||
update_qos()
|
||||
|
||||
try:
|
||||
next(loop)
|
||||
except StopIteration:
|
||||
loop = hub.create_loop()
|
||||
finally:
|
||||
try:
|
||||
hub.reset()
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception(
|
||||
'Error cleaning up after event loop: %r', exc)
|
||||
|
||||
|
||||
def synloop(obj, connection, consumer, blueprint, hub, qos,
|
||||
heartbeat, clock, hbrate=2.0, **kwargs):
|
||||
"""Fallback blocking event loop for transports that doesn't support AIO."""
|
||||
RUN = bootsteps.RUN
|
||||
on_task_received = obj.create_task_handler()
|
||||
perform_pending_operations = obj.perform_pending_operations
|
||||
heartbeat_error = [None]
|
||||
if getattr(obj.pool, 'is_green', False):
|
||||
heartbeat_error = _enable_amqheartbeats(obj.timer, connection, rate=hbrate)
|
||||
consumer.on_message = on_task_received
|
||||
consumer.consume()
|
||||
|
||||
obj.on_ready()
|
||||
|
||||
def _loop_cycle():
|
||||
"""
|
||||
Perform one iteration of the blocking event loop.
|
||||
"""
|
||||
if heartbeat_error[0] is not None:
|
||||
raise heartbeat_error[0]
|
||||
if qos.prev != qos.value:
|
||||
qos.update()
|
||||
try:
|
||||
perform_pending_operations()
|
||||
connection.drain_events(timeout=2.0)
|
||||
except socket.timeout:
|
||||
pass
|
||||
except OSError:
|
||||
if blueprint.state == RUN:
|
||||
raise
|
||||
|
||||
while blueprint.state == RUN and obj.connection:
|
||||
try:
|
||||
state.maybe_shutdown()
|
||||
finally:
|
||||
_loop_cycle()
|
||||
@@ -0,0 +1,122 @@
|
||||
"""Worker Pidbox (remote control)."""
|
||||
import socket
|
||||
import threading
|
||||
|
||||
from kombu.common import ignore_errors
|
||||
from kombu.utils.encoding import safe_str
|
||||
|
||||
from celery.utils.collections import AttributeDict
|
||||
from celery.utils.functional import pass1
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from . import control
|
||||
|
||||
__all__ = ('Pidbox', 'gPidbox')
|
||||
|
||||
logger = get_logger(__name__)
|
||||
debug, error, info = logger.debug, logger.error, logger.info
|
||||
|
||||
|
||||
class Pidbox:
|
||||
"""Worker mailbox."""
|
||||
|
||||
consumer = None
|
||||
|
||||
def __init__(self, c):
|
||||
self.c = c
|
||||
self.hostname = c.hostname
|
||||
self.node = c.app.control.mailbox.Node(
|
||||
safe_str(c.hostname),
|
||||
handlers=control.Panel.data,
|
||||
state=AttributeDict(
|
||||
app=c.app,
|
||||
hostname=c.hostname,
|
||||
consumer=c,
|
||||
tset=pass1 if c.controller.use_eventloop else set),
|
||||
)
|
||||
self._forward_clock = self.c.app.clock.forward
|
||||
|
||||
def on_message(self, body, message):
|
||||
# just increase clock as clients usually don't
|
||||
# have a valid clock to adjust with.
|
||||
self._forward_clock()
|
||||
try:
|
||||
self.node.handle_message(body, message)
|
||||
except KeyError as exc:
|
||||
error('No such control command: %s', exc)
|
||||
except Exception as exc:
|
||||
error('Control command error: %r', exc, exc_info=True)
|
||||
self.reset()
|
||||
|
||||
def start(self, c):
|
||||
self.node.channel = c.connection.channel()
|
||||
self.consumer = self.node.listen(callback=self.on_message)
|
||||
self.consumer.on_decode_error = c.on_decode_error
|
||||
|
||||
def on_stop(self):
|
||||
pass
|
||||
|
||||
def stop(self, c):
|
||||
self.on_stop()
|
||||
self.consumer = self._close_channel(c)
|
||||
|
||||
def reset(self):
|
||||
self.stop(self.c)
|
||||
self.start(self.c)
|
||||
|
||||
def _close_channel(self, c):
|
||||
if self.node and self.node.channel:
|
||||
ignore_errors(c, self.node.channel.close)
|
||||
|
||||
def shutdown(self, c):
|
||||
self.on_stop()
|
||||
if self.consumer:
|
||||
debug('Canceling broadcast consumer...')
|
||||
ignore_errors(c, self.consumer.cancel)
|
||||
self.stop(self.c)
|
||||
|
||||
|
||||
class gPidbox(Pidbox):
|
||||
"""Worker pidbox (greenlet)."""
|
||||
|
||||
_node_shutdown = None
|
||||
_node_stopped = None
|
||||
_resets = 0
|
||||
|
||||
def start(self, c):
|
||||
c.pool.spawn_n(self.loop, c)
|
||||
|
||||
def on_stop(self):
|
||||
if self._node_stopped:
|
||||
self._node_shutdown.set()
|
||||
debug('Waiting for broadcast thread to shutdown...')
|
||||
self._node_stopped.wait()
|
||||
self._node_stopped = self._node_shutdown = None
|
||||
|
||||
def reset(self):
|
||||
self._resets += 1
|
||||
|
||||
def _do_reset(self, c, connection):
|
||||
self._close_channel(c)
|
||||
self.node.channel = connection.channel()
|
||||
self.consumer = self.node.listen(callback=self.on_message)
|
||||
self.consumer.consume()
|
||||
|
||||
def loop(self, c):
|
||||
resets = [self._resets]
|
||||
shutdown = self._node_shutdown = threading.Event()
|
||||
stopped = self._node_stopped = threading.Event()
|
||||
try:
|
||||
with c.connection_for_read() as connection:
|
||||
info('pidbox: Connected to %s.', connection.as_uri())
|
||||
self._do_reset(c, connection)
|
||||
while not shutdown.is_set() and c.connection:
|
||||
if resets[0] < self._resets:
|
||||
resets[0] += 1
|
||||
self._do_reset(c, connection)
|
||||
try:
|
||||
connection.drain_events(timeout=1.0)
|
||||
except socket.timeout:
|
||||
pass
|
||||
finally:
|
||||
stopped.set()
|
||||
@@ -0,0 +1,790 @@
|
||||
"""Task request.
|
||||
|
||||
This module defines the :class:`Request` class, that specifies
|
||||
how tasks are executed.
|
||||
"""
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from time import monotonic, time
|
||||
from weakref import ref
|
||||
|
||||
from billiard.common import TERM_SIGNAME
|
||||
from billiard.einfo import ExceptionWithTraceback
|
||||
from kombu.utils.encoding import safe_repr, safe_str
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from celery import current_app, signals
|
||||
from celery.app.task import Context
|
||||
from celery.app.trace import fast_trace_task, trace_task, trace_task_ret
|
||||
from celery.concurrency.base import BasePool
|
||||
from celery.exceptions import (Ignore, InvalidTaskError, Reject, Retry, TaskRevokedError, Terminated,
|
||||
TimeLimitExceeded, WorkerLostError)
|
||||
from celery.platforms import signals as _signals
|
||||
from celery.utils.functional import maybe, maybe_list, noop
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.nodenames import gethostname
|
||||
from celery.utils.serialization import get_pickled_exception
|
||||
from celery.utils.time import maybe_iso8601, maybe_make_aware, timezone
|
||||
|
||||
from . import state
|
||||
|
||||
__all__ = ('Request',)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
# We cache globals and attribute lookups, so disable this warning.
|
||||
|
||||
IS_PYPY = hasattr(sys, 'pypy_version_info')
|
||||
|
||||
logger = get_logger(__name__)
|
||||
debug, info, warn, error = (logger.debug, logger.info,
|
||||
logger.warning, logger.error)
|
||||
_does_info = False
|
||||
_does_debug = False
|
||||
|
||||
|
||||
def __optimize__():
|
||||
# this is also called by celery.app.trace.setup_worker_optimizations
|
||||
global _does_debug
|
||||
global _does_info
|
||||
_does_debug = logger.isEnabledFor(logging.DEBUG)
|
||||
_does_info = logger.isEnabledFor(logging.INFO)
|
||||
|
||||
|
||||
__optimize__()
|
||||
|
||||
# Localize
|
||||
tz_or_local = timezone.tz_or_local
|
||||
send_revoked = signals.task_revoked.send
|
||||
send_retry = signals.task_retry.send
|
||||
|
||||
task_accepted = state.task_accepted
|
||||
task_ready = state.task_ready
|
||||
revoked_tasks = state.revoked
|
||||
revoked_stamps = state.revoked_stamps
|
||||
|
||||
|
||||
class Request:
|
||||
"""A request for task execution."""
|
||||
|
||||
acknowledged = False
|
||||
time_start = None
|
||||
worker_pid = None
|
||||
time_limits = (None, None)
|
||||
_already_revoked = False
|
||||
_already_cancelled = False
|
||||
_terminate_on_ack = None
|
||||
_apply_result = None
|
||||
_tzlocal = None
|
||||
|
||||
if not IS_PYPY: # pragma: no cover
|
||||
__slots__ = (
|
||||
'_app', '_type', 'name', 'id', '_root_id', '_parent_id',
|
||||
'_on_ack', '_body', '_hostname', '_eventer', '_connection_errors',
|
||||
'_task', '_eta', '_expires', '_request_dict', '_on_reject', '_utc',
|
||||
'_content_type', '_content_encoding', '_argsrepr', '_kwargsrepr',
|
||||
'_args', '_kwargs', '_decoded', '__payload',
|
||||
'__weakref__', '__dict__',
|
||||
)
|
||||
|
||||
def __init__(self, message, on_ack=noop,
|
||||
hostname=None, eventer=None, app=None,
|
||||
connection_errors=None, request_dict=None,
|
||||
task=None, on_reject=noop, body=None,
|
||||
headers=None, decoded=False, utc=True,
|
||||
maybe_make_aware=maybe_make_aware,
|
||||
maybe_iso8601=maybe_iso8601, **opts):
|
||||
self._message = message
|
||||
self._request_dict = (message.headers.copy() if headers is None
|
||||
else headers.copy())
|
||||
self._body = message.body if body is None else body
|
||||
self._app = app
|
||||
self._utc = utc
|
||||
self._decoded = decoded
|
||||
if decoded:
|
||||
self._content_type = self._content_encoding = None
|
||||
else:
|
||||
self._content_type, self._content_encoding = (
|
||||
message.content_type, message.content_encoding,
|
||||
)
|
||||
self.__payload = self._body if self._decoded else message.payload
|
||||
self.id = self._request_dict['id']
|
||||
self._type = self.name = self._request_dict['task']
|
||||
if 'shadow' in self._request_dict:
|
||||
self.name = self._request_dict['shadow'] or self.name
|
||||
self._root_id = self._request_dict.get('root_id')
|
||||
self._parent_id = self._request_dict.get('parent_id')
|
||||
timelimit = self._request_dict.get('timelimit', None)
|
||||
if timelimit:
|
||||
self.time_limits = timelimit
|
||||
self._argsrepr = self._request_dict.get('argsrepr', '')
|
||||
self._kwargsrepr = self._request_dict.get('kwargsrepr', '')
|
||||
self._on_ack = on_ack
|
||||
self._on_reject = on_reject
|
||||
self._hostname = hostname or gethostname()
|
||||
self._eventer = eventer
|
||||
self._connection_errors = connection_errors or ()
|
||||
self._task = task or self._app.tasks[self._type]
|
||||
self._ignore_result = self._request_dict.get('ignore_result', False)
|
||||
|
||||
# timezone means the message is timezone-aware, and the only timezone
|
||||
# supported at this point is UTC.
|
||||
eta = self._request_dict.get('eta')
|
||||
if eta is not None:
|
||||
try:
|
||||
eta = maybe_iso8601(eta)
|
||||
except (AttributeError, ValueError, TypeError) as exc:
|
||||
raise InvalidTaskError(
|
||||
f'invalid ETA value {eta!r}: {exc}')
|
||||
self._eta = maybe_make_aware(eta, self.tzlocal)
|
||||
else:
|
||||
self._eta = None
|
||||
|
||||
expires = self._request_dict.get('expires')
|
||||
if expires is not None:
|
||||
try:
|
||||
expires = maybe_iso8601(expires)
|
||||
except (AttributeError, ValueError, TypeError) as exc:
|
||||
raise InvalidTaskError(
|
||||
f'invalid expires value {expires!r}: {exc}')
|
||||
self._expires = maybe_make_aware(expires, self.tzlocal)
|
||||
else:
|
||||
self._expires = None
|
||||
|
||||
delivery_info = message.delivery_info or {}
|
||||
properties = message.properties or {}
|
||||
self._delivery_info = {
|
||||
'exchange': delivery_info.get('exchange'),
|
||||
'routing_key': delivery_info.get('routing_key'),
|
||||
'priority': properties.get('priority'),
|
||||
'redelivered': delivery_info.get('redelivered', False),
|
||||
}
|
||||
self._request_dict.update({
|
||||
'properties': properties,
|
||||
'reply_to': properties.get('reply_to'),
|
||||
'correlation_id': properties.get('correlation_id'),
|
||||
'hostname': self._hostname,
|
||||
'delivery_info': self._delivery_info
|
||||
})
|
||||
        # passed by reference here to avoid a memory usage burst
|
||||
self._request_dict['args'], self._request_dict['kwargs'], _ = self.__payload
|
||||
self._args = self._request_dict['args']
|
||||
self._kwargs = self._request_dict['kwargs']
|
||||
|
||||
@property
|
||||
def delivery_info(self):
|
||||
return self._delivery_info
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self._message
|
||||
|
||||
@property
|
||||
def request_dict(self):
|
||||
return self._request_dict
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self._body
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
return self._app
|
||||
|
||||
@property
|
||||
def utc(self):
|
||||
return self._utc
|
||||
|
||||
@property
|
||||
def content_type(self):
|
||||
return self._content_type
|
||||
|
||||
@property
|
||||
def content_encoding(self):
|
||||
return self._content_encoding
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def root_id(self):
|
||||
return self._root_id
|
||||
|
||||
@property
|
||||
def parent_id(self):
|
||||
return self._parent_id
|
||||
|
||||
@property
|
||||
def argsrepr(self):
|
||||
return self._argsrepr
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return self._args
|
||||
|
||||
@property
|
||||
def kwargs(self):
|
||||
return self._kwargs
|
||||
|
||||
@property
|
||||
def kwargsrepr(self):
|
||||
return self._kwargsrepr
|
||||
|
||||
@property
|
||||
def on_ack(self):
|
||||
return self._on_ack
|
||||
|
||||
@property
|
||||
def on_reject(self):
|
||||
return self._on_reject
|
||||
|
||||
@on_reject.setter
|
||||
def on_reject(self, value):
|
||||
self._on_reject = value
|
||||
|
||||
@property
|
||||
def hostname(self):
|
||||
return self._hostname
|
||||
|
||||
@property
|
||||
def ignore_result(self):
|
||||
return self._ignore_result
|
||||
|
||||
@property
|
||||
def eventer(self):
|
||||
return self._eventer
|
||||
|
||||
@eventer.setter
|
||||
def eventer(self, eventer):
|
||||
self._eventer = eventer
|
||||
|
||||
@property
|
||||
def connection_errors(self):
|
||||
return self._connection_errors
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
return self._task
|
||||
|
||||
@property
|
||||
def eta(self):
|
||||
return self._eta
|
||||
|
||||
@property
|
||||
def expires(self):
|
||||
return self._expires
|
||||
|
||||
@expires.setter
|
||||
def expires(self, value):
|
||||
self._expires = value
|
||||
|
||||
@property
|
||||
def tzlocal(self):
|
||||
if self._tzlocal is None:
|
||||
self._tzlocal = self._app.conf.timezone
|
||||
return self._tzlocal
|
||||
|
||||
@property
|
||||
def store_errors(self):
|
||||
return (not self.task.ignore_result or
|
||||
self.task.store_errors_even_if_ignored)
|
||||
|
||||
@property
|
||||
def task_id(self):
|
||||
# XXX compat
|
||||
return self.id
|
||||
|
||||
@task_id.setter
|
||||
def task_id(self, value):
|
||||
self.id = value
|
||||
|
||||
@property
|
||||
def task_name(self):
|
||||
# XXX compat
|
||||
return self.name
|
||||
|
||||
@task_name.setter
|
||||
def task_name(self, value):
|
||||
self.name = value
|
||||
|
||||
@property
|
||||
def reply_to(self):
|
||||
        # used by the rpc backend when failures are reported by the parent process
|
||||
return self._request_dict['reply_to']
|
||||
|
||||
@property
|
||||
def replaced_task_nesting(self):
|
||||
return self._request_dict.get('replaced_task_nesting', 0)
|
||||
|
||||
@property
|
||||
def groups(self):
|
||||
return self._request_dict.get('groups', [])
|
||||
|
||||
@property
|
||||
def stamped_headers(self) -> list:
|
||||
return self._request_dict.get('stamped_headers') or []
|
||||
|
||||
@property
|
||||
def stamps(self) -> dict:
|
||||
stamps = self._request_dict.get('stamps') or {}
|
||||
return {header: stamps.get(header) for header in self.stamped_headers}
|
||||
|
||||
@property
|
||||
def correlation_id(self):
|
||||
# used similarly to reply_to
|
||||
return self._request_dict['correlation_id']
|
||||
|
||||
def execute_using_pool(self, pool: BasePool, **kwargs):
|
||||
"""Used by the worker to send this task to the pool.
|
||||
|
||||
Arguments:
|
||||
pool (~celery.concurrency.base.TaskPool): The execution pool
|
||||
used to execute this request.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.TaskRevokedError: if the task was revoked.
|
||||
"""
|
||||
task_id = self.id
|
||||
task = self._task
|
||||
if self.revoked():
|
||||
raise TaskRevokedError(task_id)
|
||||
|
||||
time_limit, soft_time_limit = self.time_limits
|
||||
trace = fast_trace_task if self._app.use_fast_trace_task else trace_task_ret
|
||||
result = pool.apply_async(
|
||||
trace,
|
||||
args=(self._type, task_id, self._request_dict, self._body,
|
||||
self._content_type, self._content_encoding),
|
||||
accept_callback=self.on_accepted,
|
||||
timeout_callback=self.on_timeout,
|
||||
callback=self.on_success,
|
||||
error_callback=self.on_failure,
|
||||
soft_timeout=soft_time_limit or task.soft_time_limit,
|
||||
timeout=time_limit or task.time_limit,
|
||||
correlation_id=task_id,
|
||||
)
|
||||
# cannot create weakref to None
|
||||
self._apply_result = maybe(ref, result)
|
||||
return result
|
||||
|
||||
def execute(self, loglevel=None, logfile=None):
|
||||
"""Execute the task in a :func:`~celery.app.trace.trace_task`.
|
||||
|
||||
Arguments:
|
||||
loglevel (int): The loglevel used by the task.
|
||||
logfile (str): The logfile used by the task.
|
||||
"""
|
||||
if self.revoked():
|
||||
return
|
||||
|
||||
# acknowledge task as being processed.
|
||||
if not self.task.acks_late:
|
||||
self.acknowledge()
|
||||
|
||||
_, _, embed = self._payload
|
||||
request = self._request_dict
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
# payload is a property, so pylint doesn't think it's a tuple.
|
||||
request.update({
|
||||
'loglevel': loglevel,
|
||||
'logfile': logfile,
|
||||
'is_eager': False,
|
||||
}, **embed or {})
|
||||
|
||||
retval, I, _, _ = trace_task(self.task, self.id, self._args, self._kwargs, request,
|
||||
hostname=self._hostname, loader=self._app.loader,
|
||||
app=self._app)
|
||||
|
||||
if I:
|
||||
self.reject(requeue=False)
|
||||
else:
|
||||
self.acknowledge()
|
||||
return retval
|
||||
|
||||
def maybe_expire(self):
|
||||
"""If expired, mark the task as revoked."""
|
||||
if self.expires:
|
||||
now = datetime.now(self.expires.tzinfo)
|
||||
if now > self.expires:
|
||||
revoked_tasks.add(self.id)
|
||||
return True
|
||||
|
||||
def terminate(self, pool, signal=None):
|
||||
signal = _signals.signum(signal or TERM_SIGNAME)
|
||||
if self.time_start:
|
||||
pool.terminate_job(self.worker_pid, signal)
|
||||
self._announce_revoked('terminated', True, signal, False)
|
||||
else:
|
||||
self._terminate_on_ack = pool, signal
|
||||
if self._apply_result is not None:
|
||||
obj = self._apply_result() # is a weakref
|
||||
if obj is not None:
|
||||
obj.terminate(signal)
|
||||
|
||||
def cancel(self, pool, signal=None):
|
||||
signal = _signals.signum(signal or TERM_SIGNAME)
|
||||
if self.time_start:
|
||||
pool.terminate_job(self.worker_pid, signal)
|
||||
self._announce_cancelled()
|
||||
|
||||
if self._apply_result is not None:
|
||||
obj = self._apply_result() # is a weakref
|
||||
if obj is not None:
|
||||
obj.terminate(signal)
|
||||
|
||||
def _announce_cancelled(self):
|
||||
task_ready(self)
|
||||
self.send_event('task-cancelled')
|
||||
reason = 'cancelled by Celery'
|
||||
exc = Retry(message=reason)
|
||||
self.task.backend.mark_as_retry(self.id,
|
||||
exc,
|
||||
request=self._context)
|
||||
|
||||
self.task.on_retry(exc, self.id, self.args, self.kwargs, None)
|
||||
self._already_cancelled = True
|
||||
send_retry(self.task, request=self._context, einfo=None)
|
||||
|
||||
def _announce_revoked(self, reason, terminated, signum, expired):
|
||||
task_ready(self)
|
||||
self.send_event('task-revoked',
|
||||
terminated=terminated, signum=signum, expired=expired)
|
||||
self.task.backend.mark_as_revoked(
|
||||
self.id, reason, request=self._context,
|
||||
store_result=self.store_errors,
|
||||
)
|
||||
self.acknowledge()
|
||||
self._already_revoked = True
|
||||
send_revoked(self.task, request=self._context,
|
||||
terminated=terminated, signum=signum, expired=expired)
|
||||
|
||||
def revoked(self):
|
||||
"""If revoked, skip task and mark state."""
|
||||
expired = False
|
||||
if self._already_revoked:
|
||||
return True
|
||||
if self.expires:
|
||||
expired = self.maybe_expire()
|
||||
revoked_by_id = self.id in revoked_tasks
|
||||
revoked_by_header, revoking_header = False, None
|
||||
|
||||
if not revoked_by_id and self.stamped_headers:
|
||||
for stamp in self.stamped_headers:
|
||||
if stamp in revoked_stamps:
|
||||
revoked_header = revoked_stamps[stamp]
|
||||
stamped_header = self._message.headers['stamps'][stamp]
|
||||
|
||||
if isinstance(stamped_header, (list, tuple)):
|
||||
for stamped_value in stamped_header:
|
||||
if stamped_value in maybe_list(revoked_header):
|
||||
revoked_by_header = True
|
||||
revoking_header = {stamp: stamped_value}
|
||||
break
|
||||
else:
|
||||
revoked_by_header = any([
|
||||
stamped_header in maybe_list(revoked_header),
|
||||
stamped_header == revoked_header, # When the header is a single set value
|
||||
])
|
||||
revoking_header = {stamp: stamped_header}
|
||||
break
|
||||
|
||||
if any((expired, revoked_by_id, revoked_by_header)):
|
||||
log_msg = 'Discarding revoked task: %s[%s]'
|
||||
if revoked_by_header:
|
||||
log_msg += ' (revoked by header: %s)' % revoking_header
|
||||
info(log_msg, self.name, self.id)
|
||||
self._announce_revoked(
|
||||
'expired' if expired else 'revoked', False, None, expired,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def send_event(self, type, **fields):
|
||||
if self._eventer and self._eventer.enabled and self.task.send_events:
|
||||
self._eventer.send(type, uuid=self.id, **fields)
|
||||
|
||||
def on_accepted(self, pid, time_accepted):
|
||||
"""Handler called when task is accepted by worker pool."""
|
||||
self.worker_pid = pid
|
||||
# Convert monotonic time_accepted to absolute time
|
||||
self.time_start = time() - (monotonic() - time_accepted)
|
||||
task_accepted(self)
|
||||
if not self.task.acks_late:
|
||||
self.acknowledge()
|
||||
self.send_event('task-started')
|
||||
if _does_debug:
|
||||
debug('Task accepted: %s[%s] pid:%r', self.name, self.id, pid)
|
||||
if self._terminate_on_ack is not None:
|
||||
self.terminate(*self._terminate_on_ack)
|
||||
|
||||
def on_timeout(self, soft, timeout):
|
||||
"""Handler called if the task times out."""
|
||||
if soft:
|
||||
warn('Soft time limit (%ss) exceeded for %s[%s]',
|
||||
timeout, self.name, self.id)
|
||||
else:
|
||||
task_ready(self)
|
||||
error('Hard time limit (%ss) exceeded for %s[%s]',
|
||||
timeout, self.name, self.id)
|
||||
exc = TimeLimitExceeded(timeout)
|
||||
|
||||
self.task.backend.mark_as_failure(
|
||||
self.id, exc, request=self._context,
|
||||
store_result=self.store_errors,
|
||||
)
|
||||
|
||||
if self.task.acks_late and self.task.acks_on_failure_or_timeout:
|
||||
self.acknowledge()
|
||||
|
||||
def on_success(self, failed__retval__runtime, **kwargs):
|
||||
"""Handler called if the task was successfully processed."""
|
||||
failed, retval, runtime = failed__retval__runtime
|
||||
if failed:
|
||||
exc = retval.exception
|
||||
if isinstance(exc, ExceptionWithTraceback):
|
||||
exc = exc.exc
|
||||
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
|
||||
raise exc
|
||||
return self.on_failure(retval, return_ok=True)
|
||||
task_ready(self, successful=True)
|
||||
|
||||
if self.task.acks_late:
|
||||
self.acknowledge()
|
||||
|
||||
self.send_event('task-succeeded', result=retval, runtime=runtime)
|
||||
|
||||
def on_retry(self, exc_info):
|
||||
"""Handler called if the task should be retried."""
|
||||
if self.task.acks_late:
|
||||
self.acknowledge()
|
||||
|
||||
self.send_event('task-retried',
|
||||
exception=safe_repr(exc_info.exception.exc),
|
||||
traceback=safe_str(exc_info.traceback))
|
||||
|
||||
def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
|
||||
"""Handler called if the task raised an exception."""
|
||||
task_ready(self)
|
||||
exc = exc_info.exception
|
||||
|
||||
if isinstance(exc, ExceptionWithTraceback):
|
||||
exc = exc.exc
|
||||
|
||||
is_terminated = isinstance(exc, Terminated)
|
||||
if is_terminated:
|
||||
# If the task was terminated and the task was not cancelled due
|
||||
# to a connection loss, it is revoked.
|
||||
|
||||
# We always cancel the tasks inside the master process.
|
||||
# If the request was cancelled, it was not revoked and there's
|
||||
# nothing to be done.
|
||||
# According to the comment below, we need to check if the task
|
||||
# is already revoked and if it wasn't, we should announce that
|
||||
# it was.
|
||||
if not self._already_cancelled and not self._already_revoked:
|
||||
# This is a special case where the process
|
||||
# would not have had time to write the result.
|
||||
self._announce_revoked(
|
||||
'terminated', True, str(exc), False)
|
||||
return
|
||||
elif isinstance(exc, MemoryError):
|
||||
raise MemoryError(f'Process got: {exc}')
|
||||
elif isinstance(exc, Reject):
|
||||
return self.reject(requeue=exc.requeue)
|
||||
elif isinstance(exc, Ignore):
|
||||
return self.acknowledge()
|
||||
elif isinstance(exc, Retry):
|
||||
return self.on_retry(exc_info)
|
||||
|
||||
# (acks_late) acknowledge after result stored.
|
||||
requeue = False
|
||||
is_worker_lost = isinstance(exc, WorkerLostError)
|
||||
if self.task.acks_late:
|
||||
reject = (
|
||||
(self.task.reject_on_worker_lost and is_worker_lost)
|
||||
or (isinstance(exc, TimeLimitExceeded) and not self.task.acks_on_failure_or_timeout)
|
||||
)
|
||||
ack = self.task.acks_on_failure_or_timeout
|
||||
if reject:
|
||||
requeue = True
|
||||
self.reject(requeue=requeue)
|
||||
send_failed_event = False
|
||||
elif ack:
|
||||
self.acknowledge()
|
||||
else:
|
||||
# supporting the behaviour where a task failed and
|
||||
# need to be removed from prefetched local queue
|
||||
self.reject(requeue=False)
|
||||
|
||||
# This is a special case where the process would not have had time
|
||||
# to write the result.
|
||||
if not requeue and (is_worker_lost or not return_ok):
|
||||
# only mark as failure if task has not been requeued
|
||||
self.task.backend.mark_as_failure(
|
||||
self.id, exc, request=self._context,
|
||||
store_result=self.store_errors,
|
||||
)
|
||||
|
||||
signals.task_failure.send(sender=self.task, task_id=self.id,
|
||||
exception=exc, args=self.args,
|
||||
kwargs=self.kwargs,
|
||||
traceback=exc_info.traceback,
|
||||
einfo=exc_info)
|
||||
|
||||
if send_failed_event:
|
||||
self.send_event(
|
||||
'task-failed',
|
||||
exception=safe_repr(get_pickled_exception(exc_info.exception)),
|
||||
traceback=exc_info.traceback,
|
||||
)
|
||||
|
||||
if not return_ok:
|
||||
error('Task handler raised error: %r', exc,
|
||||
exc_info=exc_info.exc_info)
|
||||
|
||||
def acknowledge(self):
|
||||
"""Acknowledge task."""
|
||||
if not self.acknowledged:
|
||||
self._on_ack(logger, self._connection_errors)
|
||||
self.acknowledged = True
|
||||
|
||||
def reject(self, requeue=False):
|
||||
if not self.acknowledged:
|
||||
self._on_reject(logger, self._connection_errors, requeue)
|
||||
self.acknowledged = True
|
||||
self.send_event('task-rejected', requeue=requeue)
|
||||
|
||||
def info(self, safe=False):
|
||||
return {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'args': self._args if not safe else self._argsrepr,
|
||||
'kwargs': self._kwargs if not safe else self._kwargsrepr,
|
||||
'type': self._type,
|
||||
'hostname': self._hostname,
|
||||
'time_start': self.time_start,
|
||||
'acknowledged': self.acknowledged,
|
||||
'delivery_info': self.delivery_info,
|
||||
'worker_pid': self.worker_pid,
|
||||
}
|
||||
|
||||
def humaninfo(self):
|
||||
return '{0.name}[{0.id}]'.format(self)
|
||||
|
||||
def __str__(self):
|
||||
"""``str(self)``."""
|
||||
return ' '.join([
|
||||
self.humaninfo(),
|
||||
f' ETA:[{self._eta}]' if self._eta else '',
|
||||
f' expires:[{self._expires}]' if self._expires else '',
|
||||
]).strip()
|
||||
|
||||
def __repr__(self):
|
||||
"""``repr(self)``."""
|
||||
return '<{}: {} {} {}>'.format(
|
||||
type(self).__name__, self.humaninfo(),
|
||||
self._argsrepr, self._kwargsrepr,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _payload(self):
|
||||
return self.__payload
|
||||
|
||||
@cached_property
|
||||
def chord(self):
|
||||
# used by backend.mark_as_failure when failure is reported
|
||||
# by parent process
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
# payload is a property, so pylint doesn't think it's a tuple.
|
||||
_, _, embed = self._payload
|
||||
return embed.get('chord')
|
||||
|
||||
@cached_property
|
||||
def errbacks(self):
|
||||
# used by backend.mark_as_failure when failure is reported
|
||||
# by parent process
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
# payload is a property, so pylint doesn't think it's a tuple.
|
||||
_, _, embed = self._payload
|
||||
return embed.get('errbacks')
|
||||
|
||||
@cached_property
|
||||
def group(self):
|
||||
# used by backend.on_chord_part_return when failures reported
|
||||
# by parent process
|
||||
return self._request_dict.get('group')
|
||||
|
||||
@cached_property
|
||||
def _context(self):
|
||||
"""Context (:class:`~celery.app.task.Context`) of this task."""
|
||||
request = self._request_dict
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
# payload is a property, so pylint doesn't think it's a tuple.
|
||||
_, _, embed = self._payload
|
||||
request.update(**embed or {})
|
||||
return Context(request)
|
||||
|
||||
@cached_property
|
||||
def group_index(self):
|
||||
# used by backend.on_chord_part_return to order return values in group
|
||||
return self._request_dict.get('group_index')
|
||||
|
||||
|
||||
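
The cached properties above close out the Request class. Before the create_request_cls() factory below, a short aside: the worker resolves the request class per task via symbol_by_name(task.Request) (see the strategy module later in this diff), so the class can be swapped out from application code. The sketch below shows that hook under stated assumptions; the module path, app, task and class names are made-up examples, not part of this commit.

# Sketch of swapping in a custom request class via the task's `Request` attribute;
# the worker resolves it with symbol_by_name(task.Request).  Paths/names are examples.
from celery import Celery, Task
from celery.worker.request import Request


class LoggingRequest(Request):
    def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
        print(f'task {self.id} failed: {exc_info.exception!r}')
        super().on_failure(exc_info, send_failed_event=send_failed_event,
                           return_ok=return_ok)


app = Celery('proj', broker='amqp://localhost//')


class BaseTask(Task):
    Request = 'proj.tasks:LoggingRequest'  # dotted path, resolved lazily per task


@app.task(base=BaseTask)
def add(x, y):
    return x + y
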

def create_request_cls(base, task, pool, hostname, eventer,
                       ref=ref, revoked_tasks=revoked_tasks,
                       task_ready=task_ready, trace=None, app=current_app):
    default_time_limit = task.time_limit
    default_soft_time_limit = task.soft_time_limit
    apply_async = pool.apply_async
    acks_late = task.acks_late
    events = eventer and eventer.enabled

    if trace is None:
        trace = fast_trace_task if app.use_fast_trace_task else trace_task_ret

    class Request(base):

        def execute_using_pool(self, pool, **kwargs):
            task_id = self.task_id
            if self.revoked():
                raise TaskRevokedError(task_id)

            time_limit, soft_time_limit = self.time_limits
            result = apply_async(
                trace,
                args=(self.type, task_id, self.request_dict, self.body,
                      self.content_type, self.content_encoding),
                accept_callback=self.on_accepted,
                timeout_callback=self.on_timeout,
                callback=self.on_success,
                error_callback=self.on_failure,
                soft_timeout=soft_time_limit or default_soft_time_limit,
                timeout=time_limit or default_time_limit,
                correlation_id=task_id,
            )
            # cannot create weakref to None
            # pylint: disable=attribute-defined-outside-init
            self._apply_result = maybe(ref, result)
            return result

        def on_success(self, failed__retval__runtime, **kwargs):
            failed, retval, runtime = failed__retval__runtime
            if failed:
                exc = retval.exception
                if isinstance(exc, ExceptionWithTraceback):
                    exc = exc.exc
                if isinstance(exc, (SystemExit, KeyboardInterrupt)):
                    raise exc
                return self.on_failure(retval, return_ok=True)
            task_ready(self, successful=True)

            if acks_late:
                self.acknowledge()

            if events:
                self.send_event(
                    'task-succeeded', result=retval, runtime=runtime,
                )

    return Request
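
The stamped-header branch of Request.revoked() above is the least obvious part of the class: a task is revoked when one of its stamped header values matches a value recorded in revoked_stamps. The standalone sketch below restates just that decision for illustration; the function name and the local maybe_list stand-in are hypothetical, not the worker's actual helpers.

# Standalone restatement of the stamped-header check in Request.revoked() above.
# is_revoked_by_stamps() and the local maybe_list() are illustrative stand-ins only.
def is_revoked_by_stamps(stamped_headers, stamps, revoked_stamps):
    def maybe_list(value):  # the real code uses a shared helper with the same meaning
        return value if isinstance(value, (list, tuple)) else [value]

    for stamp in stamped_headers:
        if stamp not in revoked_stamps:
            continue
        revoked_header = revoked_stamps[stamp]
        stamped_header = stamps[stamp]
        if isinstance(stamped_header, (list, tuple)):
            # a task stamped with several values is revoked if any one of them matches
            if any(value in maybe_list(revoked_header) for value in stamped_header):
                return True
        elif (stamped_header in maybe_list(revoked_header)
                or stamped_header == revoked_header):
            return True
    return False


# e.g. revoking every task stamped with visitor_id '42':
assert is_revoked_by_stamps(['visitor_id'], {'visitor_id': '42'}, {'visitor_id': ['42']})
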
ETB-API/venv/lib/python3.12/site-packages/celery/worker/state.py
@@ -0,0 +1,288 @@
"""Internal worker state (global).

This includes the currently active and reserved tasks,
statistics, and revoked tasks.
"""
import os
import platform
import shelve
import sys
import weakref
import zlib
from collections import Counter

from kombu.serialization import pickle, pickle_protocol
from kombu.utils.objects import cached_property

from celery import __version__
from celery.exceptions import WorkerShutdown, WorkerTerminate
from celery.utils.collections import LimitedSet

__all__ = (
    'SOFTWARE_INFO', 'reserved_requests', 'active_requests',
    'total_count', 'revoked', 'task_reserved', 'maybe_shutdown',
    'task_accepted', 'task_ready', 'Persistent',
)

#: Worker software/platform information.
SOFTWARE_INFO = {
    'sw_ident': 'py-celery',
    'sw_ver': __version__,
    'sw_sys': platform.system(),
}

#: maximum number of revokes to keep in memory.
REVOKES_MAX = int(os.environ.get('CELERY_WORKER_REVOKES_MAX', 50000))

#: maximum number of successful tasks to keep in memory.
SUCCESSFUL_MAX = int(os.environ.get('CELERY_WORKER_SUCCESSFUL_MAX', 1000))

#: how many seconds a revoke will be active before
#: being expired when the max limit has been exceeded.
REVOKE_EXPIRES = float(os.environ.get('CELERY_WORKER_REVOKE_EXPIRES', 10800))

#: how many seconds a successful task will be cached in memory
#: before being expired when the max limit has been exceeded.
SUCCESSFUL_EXPIRES = float(os.environ.get('CELERY_WORKER_SUCCESSFUL_EXPIRES', 10800))

#: Mapping of reserved task_id->Request.
requests = {}

#: set of all reserved :class:`~celery.worker.request.Request`'s.
reserved_requests = weakref.WeakSet()

#: set of currently active :class:`~celery.worker.request.Request`'s.
active_requests = weakref.WeakSet()

#: A limited set of successful :class:`~celery.worker.request.Request`'s.
successful_requests = LimitedSet(maxlen=SUCCESSFUL_MAX,
                                 expires=SUCCESSFUL_EXPIRES)

#: count of tasks accepted by the worker, sorted by type.
total_count = Counter()

#: count of all tasks accepted by the worker
all_total_count = [0]

#: the list of currently revoked tasks.  Persistent if ``statedb`` set.
revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)

#: Mapping of stamped headers flagged for revoking.
revoked_stamps = {}

should_stop = None
should_terminate = None


def reset_state():
    requests.clear()
    reserved_requests.clear()
    active_requests.clear()
    successful_requests.clear()
    total_count.clear()
    all_total_count[:] = [0]
    revoked.clear()
    revoked_stamps.clear()


def maybe_shutdown():
    """Shutdown if flags have been set."""
    if should_terminate is not None and should_terminate is not False:
        raise WorkerTerminate(should_terminate)
    elif should_stop is not None and should_stop is not False:
        raise WorkerShutdown(should_stop)


def task_reserved(request,
                  add_request=requests.__setitem__,
                  add_reserved_request=reserved_requests.add):
    """Update global state when a task has been reserved."""
    add_request(request.id, request)
    add_reserved_request(request)


def task_accepted(request,
                  _all_total_count=None,
                  add_request=requests.__setitem__,
                  add_active_request=active_requests.add,
                  add_to_total_count=total_count.update):
    """Update global state when a task has been accepted."""
    if not _all_total_count:
        _all_total_count = all_total_count
    add_request(request.id, request)
    add_active_request(request)
    add_to_total_count({request.name: 1})
    all_total_count[0] += 1


def task_ready(request,
               successful=False,
               remove_request=requests.pop,
               discard_active_request=active_requests.discard,
               discard_reserved_request=reserved_requests.discard):
    """Update global state when a task is ready."""
    if successful:
        successful_requests.add(request.id)

    remove_request(request.id, None)
    discard_active_request(request)
    discard_reserved_request(request)


C_BENCH = os.environ.get('C_BENCH') or os.environ.get('CELERY_BENCH')
C_BENCH_EVERY = int(os.environ.get('C_BENCH_EVERY') or
                    os.environ.get('CELERY_BENCH_EVERY') or 1000)
if C_BENCH:  # pragma: no cover
    import atexit
    from time import monotonic

    from billiard.process import current_process

    from celery.utils.debug import memdump, sample_mem

    all_count = 0
    bench_first = None
    bench_start = None
    bench_last = None
    bench_every = C_BENCH_EVERY
    bench_sample = []
    __reserved = task_reserved
    __ready = task_ready

    if current_process()._name == 'MainProcess':
        @atexit.register
        def on_shutdown():
            if bench_first is not None and bench_last is not None:
                print('- Time spent in benchmark: {!r}'.format(
                    bench_last - bench_first))
                print('- Avg: {}'.format(
                    sum(bench_sample) / len(bench_sample)))
                memdump()

    def task_reserved(request):
        """Called when a task is reserved by the worker."""
        global bench_start
        global bench_first
        now = None
        if bench_start is None:
            bench_start = now = monotonic()
        if bench_first is None:
            bench_first = now

        return __reserved(request)

    def task_ready(request):
        """Called when a task is completed."""
        global all_count
        global bench_start
        global bench_last
        all_count += 1
        if not all_count % bench_every:
            now = monotonic()
            diff = now - bench_start
            print('- Time spent processing {} tasks (since first '
                  'task received): ~{:.4f}s\n'.format(bench_every, diff))
            sys.stdout.flush()
            bench_start = bench_last = now
            bench_sample.append(diff)
            sample_mem()
        return __ready(request)


class Persistent:
    """Stores worker state between restarts.

    This is the persistent data stored by the worker when
    :option:`celery worker --statedb` is enabled.

    Currently only stores revoked task id's.
    """

    storage = shelve
    protocol = pickle_protocol
    compress = zlib.compress
    decompress = zlib.decompress
    _is_open = False

    def __init__(self, state, filename, clock=None):
        self.state = state
        self.filename = filename
        self.clock = clock
        self.merge()

    def open(self):
        return self.storage.open(
            self.filename, protocol=self.protocol, writeback=True,
        )

    def merge(self):
        self._merge_with(self.db)

    def sync(self):
        self._sync_with(self.db)
        self.db.sync()

    def close(self):
        if self._is_open:
            self.db.close()
            self._is_open = False

    def save(self):
        self.sync()
        self.close()

    def _merge_with(self, d):
        self._merge_revoked(d)
        self._merge_clock(d)
        return d

    def _sync_with(self, d):
        self._revoked_tasks.purge()
        d.update({
            '__proto__': 3,
            'zrevoked': self.compress(self._dumps(self._revoked_tasks)),
            'clock': self.clock.forward() if self.clock else 0,
        })
        return d

    def _merge_clock(self, d):
        if self.clock:
            d['clock'] = self.clock.adjust(d.get('clock') or 0)

    def _merge_revoked(self, d):
        try:
            self._merge_revoked_v3(d['zrevoked'])
        except KeyError:
            try:
                self._merge_revoked_v2(d.pop('revoked'))
            except KeyError:
                pass
        # purge expired items at boot
        self._revoked_tasks.purge()

    def _merge_revoked_v3(self, zrevoked):
        if zrevoked:
            self._revoked_tasks.update(pickle.loads(self.decompress(zrevoked)))

    def _merge_revoked_v2(self, saved):
        if not isinstance(saved, LimitedSet):
            # (pre 3.0.18) used to be stored as a dict
            return self._merge_revoked_v1(saved)
        self._revoked_tasks.update(saved)

    def _merge_revoked_v1(self, saved):
        add = self._revoked_tasks.add
        for item in saved:
            add(item)

    def _dumps(self, obj):
        return pickle.dumps(obj, protocol=self.protocol)

    @property
    def _revoked_tasks(self):
        return self.state.revoked

    @cached_property
    def db(self):
        self._is_open = True
        return self.open()
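
The module above is plain global state plus a few helpers that move a request between the reserved, active and successful collections, and two flags that drive shutdown. A minimal sketch of that life cycle, using a hypothetical stand-in request object (the helpers only need id and name attributes):

# Minimal sketch of the request life cycle tracked by celery.worker.state.
# DummyRequest and the ids are illustrative stand-ins, not part of this commit.
from celery.worker import state


class DummyRequest:
    def __init__(self, id, name):
        self.id, self.name = id, name


req = DummyRequest('task-id-1', 'proj.tasks.add')
state.task_reserved(req)                # requests + reserved_requests
state.task_accepted(req)                # active_requests, total_count['proj.tasks.add'] += 1
state.task_ready(req, successful=True)  # successful_requests, removed from the other sets
assert state.total_count['proj.tasks.add'] == 1

state.should_stop = 0                   # anything but None/False flags a warm shutdown
try:
    state.maybe_shutdown()
except state.WorkerShutdown as exc:
    print('worker would shut down with exit code', exc.code)
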
@@ -0,0 +1,208 @@
"""Task execution strategy (optimization)."""
import logging

from kombu.asynchronous.timer import to_timestamp

from celery import signals
from celery.app import trace as _app_trace
from celery.exceptions import InvalidTaskError
from celery.utils.imports import symbol_by_name
from celery.utils.log import get_logger
from celery.utils.saferepr import saferepr
from celery.utils.time import timezone

from .request import create_request_cls
from .state import task_reserved

__all__ = ('default',)

logger = get_logger(__name__)

# pylint: disable=redefined-outer-name
# We cache globals and attribute lookups, so disable this warning.


def hybrid_to_proto2(message, body):
    """Create a fresh protocol 2 message from a hybrid protocol 1/2 message."""
    try:
        args, kwargs = body.get('args', ()), body.get('kwargs', {})
        kwargs.items  # pylint: disable=pointless-statement
    except KeyError:
        raise InvalidTaskError('Message does not have args/kwargs')
    except AttributeError:
        raise InvalidTaskError(
            'Task keyword arguments must be a mapping',
        )

    headers = {
        'lang': body.get('lang'),
        'task': body.get('task'),
        'id': body.get('id'),
        'root_id': body.get('root_id'),
        'parent_id': body.get('parent_id'),
        'group': body.get('group'),
        'meth': body.get('meth'),
        'shadow': body.get('shadow'),
        'eta': body.get('eta'),
        'expires': body.get('expires'),
        'retries': body.get('retries', 0),
        'timelimit': body.get('timelimit', (None, None)),
        'argsrepr': body.get('argsrepr'),
        'kwargsrepr': body.get('kwargsrepr'),
        'origin': body.get('origin'),
    }
    headers.update(message.headers or {})

    embed = {
        'callbacks': body.get('callbacks'),
        'errbacks': body.get('errbacks'),
        'chord': body.get('chord'),
        'chain': None,
    }

    return (args, kwargs, embed), headers, True, body.get('utc', True)


def proto1_to_proto2(message, body):
    """Convert Task message protocol 1 arguments to protocol 2.

    Returns:
        Tuple: of ``(body, headers, already_decoded_status, utc)``
    """
    try:
        args, kwargs = body.get('args', ()), body.get('kwargs', {})
        kwargs.items  # pylint: disable=pointless-statement
    except KeyError:
        raise InvalidTaskError('Message does not have args/kwargs')
    except AttributeError:
        raise InvalidTaskError(
            'Task keyword arguments must be a mapping',
        )
    body.update(
        argsrepr=saferepr(args),
        kwargsrepr=saferepr(kwargs),
        headers=message.headers,
    )
    try:
        body['group'] = body['taskset']
    except KeyError:
        pass
    embed = {
        'callbacks': body.get('callbacks'),
        'errbacks': body.get('errbacks'),
        'chord': body.get('chord'),
        'chain': None,
    }
    return (args, kwargs, embed), body, True, body.get('utc', True)


def default(task, app, consumer,
            info=logger.info, error=logger.error, task_reserved=task_reserved,
            to_system_tz=timezone.to_system, bytes=bytes,
            proto1_to_proto2=proto1_to_proto2):
    """Default task execution strategy.

    Note:
        Strategies are here as an optimization, so sadly
        it's not very easy to override.
    """
    hostname = consumer.hostname
    connection_errors = consumer.connection_errors
    _does_info = logger.isEnabledFor(logging.INFO)

    # task event related
    # (optimized to avoid calling request.send_event)
    eventer = consumer.event_dispatcher
    events = eventer and eventer.enabled
    send_event = eventer and eventer.send
    task_sends_events = events and task.send_events

    call_at = consumer.timer.call_at
    apply_eta_task = consumer.apply_eta_task
    rate_limits_enabled = not consumer.disable_rate_limits
    get_bucket = consumer.task_buckets.__getitem__
    handle = consumer.on_task_request
    limit_task = consumer._limit_task
    limit_post_eta = consumer._limit_post_eta
    Request = symbol_by_name(task.Request)
    Req = create_request_cls(Request, task, consumer.pool, hostname, eventer, app=app)

    revoked_tasks = consumer.controller.state.revoked

    def task_message_handler(message, body, ack, reject, callbacks,
                             to_timestamp=to_timestamp):
        if body is None and 'args' not in message.payload:
            body, headers, decoded, utc = (
                message.body, message.headers, False, app.uses_utc_timezone(),
            )
        else:
            if 'args' in message.payload:
                body, headers, decoded, utc = hybrid_to_proto2(message,
                                                               message.payload)
            else:
                body, headers, decoded, utc = proto1_to_proto2(message, body)

        req = Req(
            message,
            on_ack=ack, on_reject=reject, app=app, hostname=hostname,
            eventer=eventer, task=task, connection_errors=connection_errors,
            body=body, headers=headers, decoded=decoded, utc=utc,
        )
        if _does_info:
            # Similar to `app.trace.info()`, we pass the formatting args as the
            # `extra` kwarg for custom log handlers
            context = {
                'id': req.id,
                'name': req.name,
                'args': req.argsrepr,
                'kwargs': req.kwargsrepr,
                'eta': req.eta,
            }
            info(_app_trace.LOG_RECEIVED, context, extra={'data': context})
        if (req.expires or req.id in revoked_tasks) and req.revoked():
            return

        signals.task_received.send(sender=consumer, request=req)

        if task_sends_events:
            send_event(
                'task-received',
                uuid=req.id, name=req.name,
                args=req.argsrepr, kwargs=req.kwargsrepr,
                root_id=req.root_id, parent_id=req.parent_id,
                retries=req.request_dict.get('retries', 0),
                eta=req.eta and req.eta.isoformat(),
                expires=req.expires and req.expires.isoformat(),
            )

        bucket = None
        eta = None
        if req.eta:
            try:
                if req.utc:
                    eta = to_timestamp(to_system_tz(req.eta))
                else:
                    eta = to_timestamp(req.eta, app.timezone)
            except (OverflowError, ValueError) as exc:
                error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
                      req.eta, exc, req.info(safe=True), exc_info=True)
                req.reject(requeue=False)
        if rate_limits_enabled:
            bucket = get_bucket(task.name)

        if eta and bucket:
            consumer.qos.increment_eventually()
            return call_at(eta, limit_post_eta, (req, bucket, 1),
                           priority=6)
        if eta:
            consumer.qos.increment_eventually()
            call_at(eta, apply_eta_task, (req,), priority=6)
            return task_message_handler
        if bucket:
            return limit_task(req, bucket, 1)

        task_reserved(req)
        if callbacks:
            [callback(req) for callback in callbacks]
        handle(req)
    return task_message_handler
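
proto1_to_proto2() above normalizes old protocol 1 message bodies into the (args, kwargs, embed) shape the request machinery expects. A small sketch of its output, using SimpleNamespace as a stand-in for the kombu message (only its headers attribute is read here); the task name and ids are placeholders:

# What proto1_to_proto2() returns, sketched with a stand-in message object.
from types import SimpleNamespace

from celery.worker.strategy import proto1_to_proto2

message = SimpleNamespace(headers={})
body = {'task': 'proj.tasks.add', 'id': 'task-id-1',
        'args': (2, 2), 'kwargs': {}, 'taskset': 'group-id-1'}

(args, kwargs, embed), new_body, decoded, utc = proto1_to_proto2(message, body)
assert args == (2, 2) and decoded is True and utc is True
assert new_body['group'] == 'group-id-1'   # protocol 1 'taskset' becomes 'group'
assert embed == {'callbacks': None, 'errbacks': None, 'chord': None, 'chain': None}
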
@@ -0,0 +1,435 @@
"""WorkController can be used to instantiate in-process workers.

The command-line interface for the worker is in :mod:`celery.bin.worker`,
while the worker program is in :mod:`celery.apps.worker`.

The worker program is responsible for adding signal handlers,
setting up logging, etc. This is a bare-bones worker without
global side-effects (i.e., except for the global state stored in
:mod:`celery.worker.state`).

The worker consists of several components, all managed by bootsteps
(mod:`celery.bootsteps`).
"""

import os
import sys
from datetime import datetime, timezone
from time import sleep

from billiard import cpu_count
from kombu.utils.compat import detect_environment

from celery import bootsteps
from celery import concurrency as _concurrency
from celery import signals
from celery.bootsteps import RUN, TERMINATE
from celery.exceptions import ImproperlyConfigured, TaskRevokedError, WorkerTerminate
from celery.platforms import EX_FAILURE, create_pidlock
from celery.utils.imports import reload_from_cwd
from celery.utils.log import mlevel
from celery.utils.log import worker_logger as logger
from celery.utils.nodenames import default_nodename, worker_direct
from celery.utils.text import str_to_list
from celery.utils.threads import default_socket_timeout

from . import state

try:
    import resource
except ImportError:
    resource = None


__all__ = ('WorkController',)

#: Default socket timeout at shutdown.
SHUTDOWN_SOCKET_TIMEOUT = 5.0

SELECT_UNKNOWN_QUEUE = """
Trying to select queue subset of {0!r}, but queue {1} isn't
defined in the `task_queues` setting.

If you want to automatically declare unknown queues you can
enable the `task_create_missing_queues` setting.
"""

DESELECT_UNKNOWN_QUEUE = """
Trying to deselect queue subset of {0!r}, but queue {1} isn't
defined in the `task_queues` setting.
"""


class WorkController:
    """Unmanaged worker instance."""

    app = None

    pidlock = None
    blueprint = None
    pool = None
    semaphore = None

    #: contains the exit code if a :exc:`SystemExit` event is handled.
    exitcode = None

    class Blueprint(bootsteps.Blueprint):
        """Worker bootstep blueprint."""

        name = 'Worker'
        default_steps = {
            'celery.worker.components:Hub',
            'celery.worker.components:Pool',
            'celery.worker.components:Beat',
            'celery.worker.components:Timer',
            'celery.worker.components:StateDB',
            'celery.worker.components:Consumer',
            'celery.worker.autoscale:WorkerComponent',
        }

    def __init__(self, app=None, hostname=None, **kwargs):
        self.app = app or self.app
        self.hostname = default_nodename(hostname)
        self.startup_time = datetime.now(timezone.utc)
        self.app.loader.init_worker()
        self.on_before_init(**kwargs)
        self.setup_defaults(**kwargs)
        self.on_after_init(**kwargs)

        self.setup_instance(**self.prepare_args(**kwargs))

    def setup_instance(self, queues=None, ready_callback=None, pidfile=None,
                       include=None, use_eventloop=None, exclude_queues=None,
                       **kwargs):
        self.pidfile = pidfile
        self.setup_queues(queues, exclude_queues)
        self.setup_includes(str_to_list(include))

        # Set default concurrency
        if not self.concurrency:
            try:
                self.concurrency = cpu_count()
            except NotImplementedError:
                self.concurrency = 2

        # Options
        self.loglevel = mlevel(self.loglevel)
        self.ready_callback = ready_callback or self.on_consumer_ready

        # this connection won't establish, only used for params
        self._conninfo = self.app.connection_for_read()
        self.use_eventloop = (
            self.should_use_eventloop() if use_eventloop is None
            else use_eventloop
        )
        self.options = kwargs

        signals.worker_init.send(sender=self)

        # Initialize bootsteps
        self.pool_cls = _concurrency.get_implementation(self.pool_cls)
        self.steps = []
        self.on_init_blueprint()
        self.blueprint = self.Blueprint(
            steps=self.app.steps['worker'],
            on_start=self.on_start,
            on_close=self.on_close,
            on_stopped=self.on_stopped,
        )
        self.blueprint.apply(self, **kwargs)

    def on_init_blueprint(self):
        pass

    def on_before_init(self, **kwargs):
        pass

    def on_after_init(self, **kwargs):
        pass

    def on_start(self):
        if self.pidfile:
            self.pidlock = create_pidlock(self.pidfile)

    def on_consumer_ready(self, consumer):
        pass

    def on_close(self):
        self.app.loader.shutdown_worker()

    def on_stopped(self):
        self.timer.stop()
        self.consumer.shutdown()

        if self.pidlock:
            self.pidlock.release()

    def setup_queues(self, include, exclude=None):
        include = str_to_list(include)
        exclude = str_to_list(exclude)
        try:
            self.app.amqp.queues.select(include)
        except KeyError as exc:
            raise ImproperlyConfigured(
                SELECT_UNKNOWN_QUEUE.strip().format(include, exc))
        try:
            self.app.amqp.queues.deselect(exclude)
        except KeyError as exc:
            raise ImproperlyConfigured(
                DESELECT_UNKNOWN_QUEUE.strip().format(exclude, exc))
        if self.app.conf.worker_direct:
            self.app.amqp.queues.select_add(worker_direct(self.hostname))

    def setup_includes(self, includes):
        # Update celery_include to have all known task modules, so that we
        # ensure all task modules are imported in case an execv happens.
        prev = tuple(self.app.conf.include)
        if includes:
            prev += tuple(includes)
            [self.app.loader.import_task_module(m) for m in includes]
        self.include = includes
        task_modules = {task.__class__.__module__
                        for task in self.app.tasks.values()}
        self.app.conf.include = tuple(set(prev) | task_modules)

    def prepare_args(self, **kwargs):
        return kwargs

    def _send_worker_shutdown(self):
        signals.worker_shutdown.send(sender=self)

    def start(self):
        try:
            self.blueprint.start(self)
        except WorkerTerminate:
            self.terminate()
        except Exception as exc:
            logger.critical('Unrecoverable error: %r', exc, exc_info=True)
            self.stop(exitcode=EX_FAILURE)
        except SystemExit as exc:
            self.stop(exitcode=exc.code)
        except KeyboardInterrupt:
            self.stop(exitcode=EX_FAILURE)

    def register_with_event_loop(self, hub):
        self.blueprint.send_all(
            self, 'register_with_event_loop', args=(hub,),
            description='hub.register',
        )

    def _process_task_sem(self, req):
        return self._quick_acquire(self._process_task, req)

    def _process_task(self, req):
        """Process task by sending it to the pool of workers."""
        try:
            req.execute_using_pool(self.pool)
        except TaskRevokedError:
            try:
                self._quick_release()  # Issue 877
            except AttributeError:
                pass

    def signal_consumer_close(self):
        try:
            self.consumer.close()
        except AttributeError:
            pass

    def should_use_eventloop(self):
        return (detect_environment() == 'default' and
                self._conninfo.transport.implements.asynchronous and
                not self.app.IS_WINDOWS)

    def stop(self, in_sighandler=False, exitcode=None):
        """Graceful shutdown of the worker server (Warm shutdown)."""
        if exitcode is not None:
            self.exitcode = exitcode
        if self.blueprint.state == RUN:
            self.signal_consumer_close()
            if not in_sighandler or self.pool.signal_safe:
                self._shutdown(warm=True)
        self._send_worker_shutdown()

    def terminate(self, in_sighandler=False):
        """Not so graceful shutdown of the worker server (Cold shutdown)."""
        if self.blueprint.state != TERMINATE:
            self.signal_consumer_close()
            if not in_sighandler or self.pool.signal_safe:
                self._shutdown(warm=False)

    def _shutdown(self, warm=True):
        # if blueprint does not exist it means that we had an
        # error before the bootsteps could be initialized.
        if self.blueprint is not None:
            with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT):  # Issue 975
                self.blueprint.stop(self, terminate=not warm)
                self.blueprint.join()

    def reload(self, modules=None, reload=False, reloader=None):
        list(self._reload_modules(
            modules, force_reload=reload, reloader=reloader))

        if self.consumer:
            self.consumer.update_strategies()
            self.consumer.reset_rate_limits()
        try:
            self.pool.restart()
        except NotImplementedError:
            pass

    def _reload_modules(self, modules=None, **kwargs):
        return (
            self._maybe_reload_module(m, **kwargs)
            for m in set(self.app.loader.task_modules
                         if modules is None else (modules or ()))
        )

    def _maybe_reload_module(self, module, force_reload=False, reloader=None):
        if module not in sys.modules:
            logger.debug('importing module %s', module)
            return self.app.loader.import_from_cwd(module)
        elif force_reload:
            logger.debug('reloading module %s', module)
            return reload_from_cwd(sys.modules[module], reloader)

    def info(self):
        uptime = datetime.now(timezone.utc) - self.startup_time
        return {'total': self.state.total_count,
                'pid': os.getpid(),
                'clock': str(self.app.clock),
                'uptime': round(uptime.total_seconds())}

    def rusage(self):
        if resource is None:
            raise NotImplementedError('rusage not supported by this platform')
        s = resource.getrusage(resource.RUSAGE_SELF)
        return {
            'utime': s.ru_utime,
            'stime': s.ru_stime,
            'maxrss': s.ru_maxrss,
            'ixrss': s.ru_ixrss,
            'idrss': s.ru_idrss,
            'isrss': s.ru_isrss,
            'minflt': s.ru_minflt,
            'majflt': s.ru_majflt,
            'nswap': s.ru_nswap,
            'inblock': s.ru_inblock,
            'oublock': s.ru_oublock,
            'msgsnd': s.ru_msgsnd,
            'msgrcv': s.ru_msgrcv,
            'nsignals': s.ru_nsignals,
            'nvcsw': s.ru_nvcsw,
            'nivcsw': s.ru_nivcsw,
        }

    def stats(self):
        info = self.info()
        info.update(self.blueprint.info(self))
        info.update(self.consumer.blueprint.info(self.consumer))
        try:
            info['rusage'] = self.rusage()
        except NotImplementedError:
            info['rusage'] = 'N/A'
        return info

    def __repr__(self):
        """``repr(worker)``."""
        return '<Worker: {self.hostname} ({state})>'.format(
            self=self,
            state=self.blueprint.human_state() if self.blueprint else 'INIT',
        )

    def __str__(self):
        """``str(worker) == worker.hostname``."""
        return self.hostname

    @property
    def state(self):
        return state

    def setup_defaults(self, concurrency=None, loglevel='WARN', logfile=None,
                       task_events=None, pool=None, consumer_cls=None,
                       timer_cls=None, timer_precision=None,
                       autoscaler_cls=None,
                       pool_putlocks=None,
                       pool_restarts=None,
                       optimization=None, O=None,  # O maps to -O=fair
                       statedb=None,
                       time_limit=None,
                       soft_time_limit=None,
                       scheduler=None,
                       pool_cls=None,              # XXX use pool
                       state_db=None,              # XXX use statedb
                       task_time_limit=None,       # XXX use time_limit
                       task_soft_time_limit=None,  # XXX use soft_time_limit
                       scheduler_cls=None,         # XXX use scheduler
                       schedule_filename=None,
                       max_tasks_per_child=None,
                       prefetch_multiplier=None, disable_rate_limits=None,
                       worker_lost_wait=None,
                       max_memory_per_child=None, **_kw):
        either = self.app.either
        self.loglevel = loglevel
        self.logfile = logfile

        self.concurrency = either('worker_concurrency', concurrency)
        self.task_events = either('worker_send_task_events', task_events)
        self.pool_cls = either('worker_pool', pool, pool_cls)
        self.consumer_cls = either('worker_consumer', consumer_cls)
        self.timer_cls = either('worker_timer', timer_cls)
        self.timer_precision = either(
            'worker_timer_precision', timer_precision,
        )
        self.optimization = optimization or O
        self.autoscaler_cls = either('worker_autoscaler', autoscaler_cls)
        self.pool_putlocks = either('worker_pool_putlocks', pool_putlocks)
        self.pool_restarts = either('worker_pool_restarts', pool_restarts)
        self.statedb = either('worker_state_db', statedb, state_db)
        self.schedule_filename = either(
            'beat_schedule_filename', schedule_filename,
        )
        self.scheduler = either('beat_scheduler', scheduler, scheduler_cls)
        self.time_limit = either(
            'task_time_limit', time_limit, task_time_limit)
        self.soft_time_limit = either(
            'task_soft_time_limit', soft_time_limit, task_soft_time_limit,
        )
        self.max_tasks_per_child = either(
            'worker_max_tasks_per_child', max_tasks_per_child,
        )
        self.max_memory_per_child = either(
            'worker_max_memory_per_child', max_memory_per_child,
        )
        self.prefetch_multiplier = int(either(
            'worker_prefetch_multiplier', prefetch_multiplier,
        ))
        self.disable_rate_limits = either(
            'worker_disable_rate_limits', disable_rate_limits,
        )
        self.worker_lost_wait = either('worker_lost_wait', worker_lost_wait)

    def wait_for_soft_shutdown(self):
        """Wait :setting:`worker_soft_shutdown_timeout` if soft shutdown is enabled.

        To enable soft shutdown, set the :setting:`worker_soft_shutdown_timeout` in the
        configuration. Soft shutdown can be used to allow the worker to finish processing
        a few more tasks before initiating a cold shutdown. This mechanism allows the worker
        to finish short tasks that are already in progress and requeue long-running tasks
        to be picked up by another worker.

        .. warning::
            If there are no tasks in the worker, the worker will not wait for the
            soft shutdown timeout even if it is set as it makes no sense to wait for
            the timeout when there are no tasks to process.
        """
        app = self.app
        requests = tuple(state.active_requests)

        if app.conf.worker_enable_soft_shutdown_on_idle:
            requests = True

        if app.conf.worker_soft_shutdown_timeout > 0 and requests:
            log = f"Initiating Soft Shutdown, terminating in {app.conf.worker_soft_shutdown_timeout} seconds"
            logger.warning(log)
            sleep(app.conf.worker_soft_shutdown_timeout)
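
Per the module docstring, WorkController can be instantiated directly for an in-process worker. A minimal sketch under stated assumptions: the broker URL, app and task are placeholders, and start() is left commented out because it blocks the calling thread until shutdown.

# Minimal in-process worker sketch; broker URL, app and task are placeholders.
from celery import Celery
from celery.worker import WorkController

app = Celery('proj', broker='amqp://localhost//')


@app.task
def add(x, y):
    return x + y


worker = WorkController(app=app, hostname='sketch@localhost',
                        concurrency=1, loglevel='INFO', pool='solo')
# worker.start() blocks until shutdown (WorkerShutdown/WorkerTerminate or Ctrl-C),
# so a real embedding would typically run it in a dedicated thread or process.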