Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
"""Worker implementation."""
from .worker import WorkController
__all__ = ('WorkController',)

View File

@@ -0,0 +1,154 @@
"""Pool Autoscaling.
This module implements the internal thread responsible
for growing and shrinking the pool according to the
current autoscale settings.
The autoscale thread is only enabled if
the :option:`celery worker --autoscale` option is used.
"""
import os
import threading
from time import monotonic, sleep
from kombu.asynchronous.semaphore import DummyLock
from celery import bootsteps
from celery.utils.log import get_logger
from celery.utils.threads import bgThread
from . import state
from .components import Pool
__all__ = ('Autoscaler', 'WorkerComponent')
logger = get_logger(__name__)
debug, info, error = logger.debug, logger.info, logger.error
AUTOSCALE_KEEPALIVE = float(os.environ.get('AUTOSCALE_KEEPALIVE', 30))
class WorkerComponent(bootsteps.StartStopStep):
"""Bootstep that starts the autoscaler thread/timer in the worker."""
label = 'Autoscaler'
conditional = True
requires = (Pool,)
def __init__(self, w, **kwargs):
self.enabled = w.autoscale
w.autoscaler = None
def create(self, w):
scaler = w.autoscaler = self.instantiate(
w.autoscaler_cls,
w.pool, w.max_concurrency, w.min_concurrency,
worker=w, mutex=DummyLock() if w.use_eventloop else None,
)
return scaler if not w.use_eventloop else None
def register_with_event_loop(self, w, hub):
w.consumer.on_task_message.add(w.autoscaler.maybe_scale)
hub.call_repeatedly(
w.autoscaler.keepalive, w.autoscaler.maybe_scale,
)
def info(self, w):
"""Return `Autoscaler` info."""
return {'autoscaler': w.autoscaler.info()}
class Autoscaler(bgThread):
"""Background thread to autoscale pool workers."""
def __init__(self, pool, max_concurrency,
min_concurrency=0, worker=None,
keepalive=AUTOSCALE_KEEPALIVE, mutex=None):
super().__init__()
self.pool = pool
self.mutex = mutex or threading.Lock()
self.max_concurrency = max_concurrency
self.min_concurrency = min_concurrency
self.keepalive = keepalive
self._last_scale_up = None
self.worker = worker
assert self.keepalive, 'cannot scale down too fast.'
def body(self):
with self.mutex:
self.maybe_scale()
sleep(1.0)
def _maybe_scale(self, req=None):
procs = self.processes
cur = min(self.qty, self.max_concurrency)
if cur > procs:
self.scale_up(cur - procs)
return True
cur = max(self.qty, self.min_concurrency)
if cur < procs:
self.scale_down(procs - cur)
return True
def maybe_scale(self, req=None):
if self._maybe_scale(req):
self.pool.maintain_pool()
def update(self, max=None, min=None):
with self.mutex:
if max is not None:
if max < self.processes:
self._shrink(self.processes - max)
self._update_consumer_prefetch_count(max)
self.max_concurrency = max
if min is not None:
if min > self.processes:
self._grow(min - self.processes)
self.min_concurrency = min
return self.max_concurrency, self.min_concurrency
def scale_up(self, n):
self._last_scale_up = monotonic()
return self._grow(n)
def scale_down(self, n):
if self._last_scale_up and (
monotonic() - self._last_scale_up > self.keepalive):
return self._shrink(n)
def _grow(self, n):
info('Scaling up %s processes.', n)
self.pool.grow(n)
def _shrink(self, n):
info('Scaling down %s processes.', n)
try:
self.pool.shrink(n)
except ValueError:
debug("Autoscaler won't scale down: all processes busy.")
except Exception as exc:
error('Autoscaler: scale_down: %r', exc, exc_info=True)
def _update_consumer_prefetch_count(self, new_max):
diff = new_max - self.max_concurrency
if diff:
self.worker.consumer._update_prefetch_count(
diff
)
def info(self):
return {
'max': self.max_concurrency,
'min': self.min_concurrency,
'current': self.processes,
'qty': self.qty,
}
@property
def qty(self):
return len(state.reserved_requests)
@property
def processes(self):
return self.pool.num_processes
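For orientation, the decision rule in ``_maybe_scale`` above grows the pool toward ``min(qty, max_concurrency)`` and shrinks it toward ``max(qty, min_concurrency)``, where ``qty`` is the number of reserved requests. A standalone sketch of that arithmetic, not Celery API (the ``--autoscale=10,3`` figures are illustrative):

def target_delta(qty, procs, max_concurrency, min_concurrency):
    """Signed change in pool size that the Autoscaler above would request."""
    cur = min(qty, max_concurrency)
    if cur > procs:                  # more reserved tasks than processes: grow
        return cur - procs
    cur = max(qty, min_concurrency)
    if cur < procs:                  # idle capacity above the floor: shrink
        return cur - procs
    return 0

# Started with --autoscale=10,3: 7 reserved tasks but only 3 processes -> grow by 4.
assert target_delta(qty=7, procs=3, max_concurrency=10, min_concurrency=3) == 4
# 1 reserved task and 8 processes -> shrink by 5, back down to the minimum of 3.
assert target_delta(qty=1, procs=8, max_concurrency=10, min_concurrency=3) == -5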

View File

@@ -0,0 +1,240 @@
"""Worker-level Bootsteps."""
import atexit
import warnings
from kombu.asynchronous import Hub as _Hub
from kombu.asynchronous import get_event_loop, set_event_loop
from kombu.asynchronous.semaphore import DummyLock, LaxBoundedSemaphore
from kombu.asynchronous.timer import Timer as _Timer
from celery import bootsteps
from celery._state import _set_task_join_will_block
from celery.exceptions import ImproperlyConfigured
from celery.platforms import IS_WINDOWS
from celery.utils.log import worker_logger as logger
__all__ = ('Timer', 'Hub', 'Pool', 'Beat', 'StateDB', 'Consumer')
GREEN_POOLS = {'eventlet', 'gevent'}
ERR_B_GREEN = """\
-B option doesn't work with eventlet/gevent pools: \
use standalone beat instead.\
"""
W_POOL_SETTING = """
The worker_pool setting shouldn't be used to select the eventlet/gevent
pools, instead you *must use the -P* argument so that patches are applied
as early as possible.
"""
class Timer(bootsteps.Step):
"""Timer bootstep."""
def create(self, w):
if w.use_eventloop:
# does not use dedicated timer thread.
w.timer = _Timer(max_interval=10.0)
else:
if not w.timer_cls:
# Default Timer is set by the pool, as for example, the
# eventlet pool needs a custom timer implementation.
w.timer_cls = w.pool_cls.Timer
w.timer = self.instantiate(w.timer_cls,
max_interval=w.timer_precision,
on_error=self.on_timer_error,
on_tick=self.on_timer_tick)
def on_timer_error(self, exc):
logger.error('Timer error: %r', exc, exc_info=True)
def on_timer_tick(self, delay):
logger.debug('Timer wake-up! Next ETA %s secs.', delay)
class Hub(bootsteps.StartStopStep):
"""Worker starts the event loop."""
requires = (Timer,)
def __init__(self, w, **kwargs):
w.hub = None
super().__init__(w, **kwargs)
def include_if(self, w):
return w.use_eventloop
def create(self, w):
w.hub = get_event_loop()
if w.hub is None:
required_hub = getattr(w._conninfo, 'requires_hub', None)
w.hub = set_event_loop((
required_hub if required_hub else _Hub)(w.timer))
self._patch_thread_primitives(w)
return self
def start(self, w):
pass
def stop(self, w):
w.hub.close()
def terminate(self, w):
w.hub.close()
def _patch_thread_primitives(self, w):
# make clock use dummy lock
w.app.clock.mutex = DummyLock()
# multiprocessing's ApplyResult uses this lock.
try:
from billiard import pool
except ImportError:
pass
else:
pool.Lock = DummyLock
class Pool(bootsteps.StartStopStep):
"""Bootstep managing the worker pool.
Describes how to initialize the worker pool, and starts and stops
the pool during worker start-up/shutdown.
Adds attributes:
* autoscale
* pool
* max_concurrency
* min_concurrency
"""
requires = (Hub,)
def __init__(self, w, autoscale=None, **kwargs):
w.pool = None
w.max_concurrency = None
w.min_concurrency = w.concurrency
self.optimization = w.optimization
if isinstance(autoscale, str):
max_c, _, min_c = autoscale.partition(',')
autoscale = [int(max_c), min_c and int(min_c) or 0]
w.autoscale = autoscale
if w.autoscale:
w.max_concurrency, w.min_concurrency = w.autoscale
super().__init__(w, **kwargs)
def close(self, w):
if w.pool:
w.pool.close()
def terminate(self, w):
if w.pool:
w.pool.terminate()
def create(self, w):
semaphore = None
max_restarts = None
if w.app.conf.worker_pool in GREEN_POOLS: # pragma: no cover
warnings.warn(UserWarning(W_POOL_SETTING))
threaded = not w.use_eventloop or IS_WINDOWS
procs = w.min_concurrency
w.process_task = w._process_task
if not threaded:
semaphore = w.semaphore = LaxBoundedSemaphore(procs)
w._quick_acquire = w.semaphore.acquire
w._quick_release = w.semaphore.release
max_restarts = 100
if w.pool_putlocks and w.pool_cls.uses_semaphore:
w.process_task = w._process_task_sem
allow_restart = w.pool_restarts
pool = w.pool = self.instantiate(
w.pool_cls, w.min_concurrency,
initargs=(w.app, w.hostname),
maxtasksperchild=w.max_tasks_per_child,
max_memory_per_child=w.max_memory_per_child,
timeout=w.time_limit,
soft_timeout=w.soft_time_limit,
putlocks=w.pool_putlocks and threaded,
lost_worker_timeout=w.worker_lost_wait,
threads=threaded,
max_restarts=max_restarts,
allow_restart=allow_restart,
forking_enable=True,
semaphore=semaphore,
sched_strategy=self.optimization,
app=w.app,
)
_set_task_join_will_block(pool.task_join_will_block)
return pool
def info(self, w):
return {'pool': w.pool.info if w.pool else 'N/A'}
def register_with_event_loop(self, w, hub):
w.pool.register_with_event_loop(hub)
class Beat(bootsteps.StartStopStep):
"""Step used to embed a beat process.
Enabled when the ``beat`` argument is set.
"""
label = 'Beat'
conditional = True
def __init__(self, w, beat=False, **kwargs):
self.enabled = w.beat = beat
w.beat = None
super().__init__(w, beat=beat, **kwargs)
def create(self, w):
from celery.beat import EmbeddedService
if w.pool_cls.__module__.endswith(('gevent', 'eventlet')):
raise ImproperlyConfigured(ERR_B_GREEN)
b = w.beat = EmbeddedService(w.app,
schedule_filename=w.schedule_filename,
scheduler_cls=w.scheduler)
return b
class StateDB(bootsteps.Step):
"""Bootstep that sets up between-restart state database file."""
def __init__(self, w, **kwargs):
self.enabled = w.statedb
w._persistence = None
super().__init__(w, **kwargs)
def create(self, w):
w._persistence = w.state.Persistent(w.state, w.statedb, w.app.clock)
atexit.register(w._persistence.save)
class Consumer(bootsteps.StartStopStep):
"""Bootstep starting the Consumer blueprint."""
last = True
def create(self, w):
if w.max_concurrency:
prefetch_count = max(w.max_concurrency, 1) * w.prefetch_multiplier
else:
prefetch_count = w.concurrency * w.prefetch_multiplier
c = w.consumer = self.instantiate(
w.consumer_cls, w.process_task,
hostname=w.hostname,
task_events=w.task_events,
init_callback=w.ready_callback,
initial_prefetch_count=prefetch_count,
pool=w.pool,
timer=w.timer,
app=w.app,
controller=w,
hub=w.hub,
worker_options=w.options,
disable_rate_limits=w.disable_rate_limits,
prefetch_multiplier=w.prefetch_multiplier,
)
return c
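The steps above make up the worker blueprint; user code can hook into the same graph via ``app.steps['worker']``, the documented extension point. A minimal sketch of a custom step depending on the ``Pool`` step above (the step name, app name, broker URL and printed output are illustrative, not part of this commit):

from celery import Celery, bootsteps

app = Celery('proj', broker='amqp://')  # hypothetical app name and broker URL

class DumpPoolInfo(bootsteps.StartStopStep):
    """Illustrative step that runs once the Pool step has been started."""
    requires = {'celery.worker.components:Pool'}

    def start(self, worker):
        print('pool started with %d processes' % worker.pool.num_processes)

    def stop(self, worker):
        print('pool is stopping')

app.steps['worker'].add(DumpPoolInfo)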

View File

@@ -0,0 +1,15 @@
"""Worker consumer."""
from .agent import Agent
from .connection import Connection
from .consumer import Consumer
from .control import Control
from .events import Events
from .gossip import Gossip
from .heart import Heart
from .mingle import Mingle
from .tasks import Tasks
__all__ = (
'Consumer', 'Agent', 'Connection', 'Control',
'Events', 'Gossip', 'Heart', 'Mingle', 'Tasks',
)

View File

@@ -0,0 +1,21 @@
"""Celery + :pypi:`cell` integration."""
from celery import bootsteps
from .connection import Connection
__all__ = ('Agent',)
class Agent(bootsteps.StartStopStep):
"""Agent starts :pypi:`cell` actors."""
conditional = True
requires = (Connection,)
def __init__(self, c, **kwargs):
self.agent_cls = self.enabled = c.app.conf.worker_agent
super().__init__(c, **kwargs)
def create(self, c):
agent = c.agent = self.instantiate(self.agent_cls, c.connection)
return agent

View File

@@ -0,0 +1,36 @@
"""Consumer Broker Connection Bootstep."""
from kombu.common import ignore_errors
from celery import bootsteps
from celery.utils.log import get_logger
__all__ = ('Connection',)
logger = get_logger(__name__)
info = logger.info
class Connection(bootsteps.StartStopStep):
"""Service managing the consumer broker connection."""
def __init__(self, c, **kwargs):
c.connection = None
super().__init__(c, **kwargs)
def start(self, c):
c.connection = c.connect()
info('Connected to %s', c.connection.as_uri())
def shutdown(self, c):
# We must set self.connection to None here, so
# that the green pidbox thread exits.
connection, c.connection = c.connection, None
if connection:
ignore_errors(connection, connection.close)
def info(self, c):
params = 'N/A'
if c.connection:
params = c.connection.info()
params.pop('password', None) # don't send password.
return {'broker': params}

View File

@@ -0,0 +1,745 @@
"""Worker Consumer Blueprint.
This module contains the components responsible for consuming messages
from the broker, processing the messages and keeping the broker connections
up and running.
"""
import errno
import logging
import os
import warnings
from collections import defaultdict
from time import sleep
from billiard.common import restart_state
from billiard.exceptions import RestartFreqExceeded
from kombu.asynchronous.semaphore import DummyLock
from kombu.exceptions import ContentDisallowed, DecodeError
from kombu.utils.compat import _detect_environment
from kombu.utils.encoding import safe_repr
from kombu.utils.limits import TokenBucket
from vine import ppartial, promise
from celery import bootsteps, signals
from celery.app.trace import build_tracer
from celery.exceptions import (CPendingDeprecationWarning, InvalidTaskError, NotRegistered, WorkerShutdown,
WorkerTerminate)
from celery.utils.functional import noop
from celery.utils.log import get_logger
from celery.utils.nodenames import gethostname
from celery.utils.objects import Bunch
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds, rate
from celery.worker import loops
from celery.worker.state import active_requests, maybe_shutdown, requests, reserved_requests, task_reserved
__all__ = ('Consumer', 'Evloop', 'dump_body')
CLOSE = bootsteps.CLOSE
TERMINATE = bootsteps.TERMINATE
STOP_CONDITIONS = {CLOSE, TERMINATE}
logger = get_logger(__name__)
debug, info, warn, error, crit = (logger.debug, logger.info, logger.warning,
logger.error, logger.critical)
CONNECTION_RETRY = """\
consumer: Connection to broker lost. \
Trying to re-establish the connection...\
"""
CONNECTION_RETRY_STEP = """\
Trying again {when}... ({retries}/{max_retries})\
"""
CONNECTION_ERROR = """\
consumer: Cannot connect to %s: %s.
%s
"""
CONNECTION_FAILOVER = """\
Will retry using next failover.\
"""
UNKNOWN_FORMAT = """\
Received and deleted unknown message. Wrong destination?!?
The full contents of the message body were: %s
"""
#: Error message for when an unregistered task is received.
UNKNOWN_TASK_ERROR = """\
Received unregistered task of type %s.
The message has been ignored and discarded.
Did you remember to import the module containing this task?
Or maybe you're using relative imports?
Please see
https://docs.celeryq.dev/en/latest/internals/protocol.html
for more information.
The full contents of the message body were:
%s
The full contents of the message headers:
%s
The delivery info for this task is:
%s
"""
#: Error message for when an invalid task message is received.
INVALID_TASK_ERROR = """\
Received invalid task message: %s
The message has been ignored and discarded.
Please ensure your message conforms to the task
message protocol as described here:
https://docs.celeryq.dev/en/latest/internals/protocol.html
The full contents of the message body were:
%s
"""
MESSAGE_DECODE_ERROR = """\
Can't decode message body: %r [type:%r encoding:%r headers:%s]
body: %s
"""
MESSAGE_REPORT = """\
body: {0}
{{content_type:{1} content_encoding:{2}
delivery_info:{3} headers={4}}}
"""
TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS = """\
Task %s cannot be acknowledged after a connection loss since late acknowledgement is enabled for it.
Terminating it instead.
"""
CANCEL_TASKS_BY_DEFAULT = """
In Celery 5.1 we introduced an optional breaking change which
on connection loss cancels all currently executed tasks with late acknowledgement enabled.
These tasks cannot be acknowledged as the connection is gone, and the tasks are automatically redelivered
back to the queue. You can enable this behavior using the worker_cancel_long_running_tasks_on_connection_loss
setting. In Celery 5.1 it is set to False by default. The setting will be set to True by default in Celery 6.0.
"""
def dump_body(m, body):
"""Format message body for debugging purposes."""
# v2 protocol does not deserialize body
body = m.body if body is None else body
return '{} ({}b)'.format(truncate(safe_repr(body), 1024),
len(m.body))
class Consumer:
"""Consumer blueprint."""
Strategies = dict
#: Optional callback called the first time the worker
#: is ready to receive tasks.
init_callback = None
#: The current worker pool instance.
pool = None
#: A timer used for high-priority internal tasks, such
#: as sending heartbeats.
timer = None
restart_count = -1 # first start is the same as a restart
#: This flag will be turned off after the first failed
#: connection attempt.
first_connection_attempt = True
class Blueprint(bootsteps.Blueprint):
"""Consumer blueprint."""
name = 'Consumer'
default_steps = [
'celery.worker.consumer.connection:Connection',
'celery.worker.consumer.mingle:Mingle',
'celery.worker.consumer.events:Events',
'celery.worker.consumer.gossip:Gossip',
'celery.worker.consumer.heart:Heart',
'celery.worker.consumer.control:Control',
'celery.worker.consumer.tasks:Tasks',
'celery.worker.consumer.consumer:Evloop',
'celery.worker.consumer.agent:Agent',
]
def shutdown(self, parent):
self.send_all(parent, 'shutdown')
def __init__(self, on_task_request,
init_callback=noop, hostname=None,
pool=None, app=None,
timer=None, controller=None, hub=None, amqheartbeat=None,
worker_options=None, disable_rate_limits=False,
initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
self.app = app
self.controller = controller
self.init_callback = init_callback
self.hostname = hostname or gethostname()
self.pid = os.getpid()
self.pool = pool
self.timer = timer
self.strategies = self.Strategies()
self.conninfo = self.app.connection_for_read()
self.connection_errors = self.conninfo.connection_errors
self.channel_errors = self.conninfo.channel_errors
self._restart_state = restart_state(maxR=5, maxT=1)
self._does_info = logger.isEnabledFor(logging.INFO)
self._limit_order = 0
self.on_task_request = on_task_request
self.on_task_message = set()
self.amqheartbeat_rate = self.app.conf.broker_heartbeat_checkrate
self.disable_rate_limits = disable_rate_limits
self.initial_prefetch_count = initial_prefetch_count
self.prefetch_multiplier = prefetch_multiplier
self._maximum_prefetch_restored = True
# this contains a tokenbucket for each task type by name, used for
# rate limits, or None if rate limits are disabled for that task.
self.task_buckets = defaultdict(lambda: None)
self.reset_rate_limits()
self.hub = hub
if self.hub or getattr(self.pool, 'is_green', False):
self.amqheartbeat = amqheartbeat
if self.amqheartbeat is None:
self.amqheartbeat = self.app.conf.broker_heartbeat
else:
self.amqheartbeat = 0
if not hasattr(self, 'loop'):
self.loop = loops.asynloop if hub else loops.synloop
if _detect_environment() == 'gevent':
# there's a gevent bug that causes timeouts to not be reset,
# so if the connection timeout is exceeded once, it can NEVER
# connect again.
self.app.conf.broker_connection_timeout = None
self._pending_operations = []
self.steps = []
self.blueprint = self.Blueprint(
steps=self.app.steps['consumer'],
on_close=self.on_close,
)
self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))
def call_soon(self, p, *args, **kwargs):
p = ppartial(p, *args, **kwargs)
if self.hub:
return self.hub.call_soon(p)
self._pending_operations.append(p)
return p
def perform_pending_operations(self):
if not self.hub:
while self._pending_operations:
try:
self._pending_operations.pop()()
except Exception as exc: # pylint: disable=broad-except
logger.exception('Pending callback raised: %r', exc)
def bucket_for_task(self, type):
limit = rate(getattr(type, 'rate_limit', None))
return TokenBucket(limit, capacity=1) if limit else None
def reset_rate_limits(self):
self.task_buckets.update(
(n, self.bucket_for_task(t)) for n, t in self.app.tasks.items()
)
def _update_prefetch_count(self, index=0):
"""Update prefetch count after pool/shrink grow operations.
Index must be the change in number of processes as a positive
(increasing) or negative (decreasing) number.
Note:
Currently pool grow operations will end up with an offset
of +1 if the initial size of the pool was 0 (e.g.
:option:`--autoscale=1,0 <celery worker --autoscale>`).
"""
num_processes = self.pool.num_processes
if not self.initial_prefetch_count or not num_processes:
return # prefetch disabled
self.initial_prefetch_count = (
self.pool.num_processes * self.prefetch_multiplier
)
return self._update_qos_eventually(index)
def _update_qos_eventually(self, index):
return (self.qos.decrement_eventually if index < 0
else self.qos.increment_eventually)(
abs(index) * self.prefetch_multiplier)
def _limit_move_to_pool(self, request):
task_reserved(request)
self.on_task_request(request)
def _schedule_bucket_request(self, bucket):
while True:
try:
request, tokens = bucket.pop()
except IndexError:
# no request, break
break
if bucket.can_consume(tokens):
self._limit_move_to_pool(request)
continue
else:
# requeue to head, keep the order.
bucket.contents.appendleft((request, tokens))
pri = self._limit_order = (self._limit_order + 1) % 10
hold = bucket.expected_time(tokens)
self.timer.call_after(
hold, self._schedule_bucket_request, (bucket,),
priority=pri,
)
# no tokens, break
break
def _limit_task(self, request, bucket, tokens):
bucket.add((request, tokens))
return self._schedule_bucket_request(bucket)
def _limit_post_eta(self, request, bucket, tokens):
self.qos.decrement_eventually()
bucket.add((request, tokens))
return self._schedule_bucket_request(bucket)
def start(self):
blueprint = self.blueprint
while blueprint.state not in STOP_CONDITIONS:
maybe_shutdown()
if self.restart_count:
try:
self._restart_state.step()
except RestartFreqExceeded as exc:
crit('Frequent restarts detected: %r', exc, exc_info=1)
sleep(1)
self.restart_count += 1
if self.app.conf.broker_channel_error_retry:
recoverable_errors = (self.connection_errors + self.channel_errors)
else:
recoverable_errors = self.connection_errors
try:
blueprint.start(self)
except recoverable_errors as exc:
# If we're not retrying connections, we need to properly shutdown or terminate
# the Celery main process instead of abruptly aborting the process without any cleanup.
is_connection_loss_on_startup = self.first_connection_attempt
self.first_connection_attempt = False
connection_retry_type = self._get_connection_retry_type(is_connection_loss_on_startup)
connection_retry = self.app.conf[connection_retry_type]
if not connection_retry:
crit(
f"Retrying to {'establish' if is_connection_loss_on_startup else 're-establish'} "
f"a connection to the message broker after a connection loss has "
f"been disabled (app.conf.{connection_retry_type}=False). Shutting down..."
)
raise WorkerShutdown(1) from exc
if isinstance(exc, OSError) and exc.errno == errno.EMFILE:
crit("Too many open files. Aborting...")
raise WorkerTerminate(1) from exc
maybe_shutdown()
if blueprint.state not in STOP_CONDITIONS:
if self.connection:
self.on_connection_error_after_connected(exc)
else:
self.on_connection_error_before_connected(exc)
self.on_close()
blueprint.restart(self)
def _get_connection_retry_type(self, is_connection_loss_on_startup):
return ('broker_connection_retry_on_startup'
if (is_connection_loss_on_startup
and self.app.conf.broker_connection_retry_on_startup is not None)
else 'broker_connection_retry')
def on_connection_error_before_connected(self, exc):
error(CONNECTION_ERROR, self.conninfo.as_uri(), exc,
'Trying to reconnect...')
def on_connection_error_after_connected(self, exc):
warn(CONNECTION_RETRY, exc_info=True)
try:
self.connection.collect()
except Exception: # pylint: disable=broad-except
pass
if self.app.conf.worker_cancel_long_running_tasks_on_connection_loss:
for request in tuple(active_requests):
if request.task.acks_late and not request.acknowledged:
warn(TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS,
request)
request.cancel(self.pool)
else:
warnings.warn(CANCEL_TASKS_BY_DEFAULT, CPendingDeprecationWarning)
self.initial_prefetch_count = max(
self.prefetch_multiplier,
self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier
)
self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count
if not self._maximum_prefetch_restored:
logger.info(
f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid over-fetching "
f"since {len(tuple(active_requests))} tasks are currently being processed.\n"
f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks "
"complete processing."
)
def register_with_event_loop(self, hub):
self.blueprint.send_all(
self, 'register_with_event_loop', args=(hub,),
description='Hub.register',
)
def shutdown(self):
self.blueprint.shutdown(self)
def stop(self):
self.blueprint.stop(self)
def on_ready(self):
callback, self.init_callback = self.init_callback, None
if callback:
callback(self)
def loop_args(self):
return (self, self.connection, self.task_consumer,
self.blueprint, self.hub, self.qos, self.amqheartbeat,
self.app.clock, self.amqheartbeat_rate)
def on_decode_error(self, message, exc):
"""Callback called if an error occurs while decoding a message.
Simply logs the error and acknowledges the message so it
doesn't enter a loop.
Arguments:
message (kombu.Message): The message received.
exc (Exception): The exception being handled.
"""
crit(MESSAGE_DECODE_ERROR,
exc, message.content_type, message.content_encoding,
safe_repr(message.headers), dump_body(message, message.body),
exc_info=1)
message.ack()
def on_close(self):
# Clear internal queues to get rid of old messages.
# They can't be acked anyway, as a delivery tag is specific
# to the current channel.
if self.controller and self.controller.semaphore:
self.controller.semaphore.clear()
if self.timer:
self.timer.clear()
for bucket in self.task_buckets.values():
if bucket:
bucket.clear_pending()
for request_id in reserved_requests:
if request_id in requests:
del requests[request_id]
reserved_requests.clear()
if self.pool and self.pool.flush:
self.pool.flush()
def connect(self):
"""Establish the broker connection used for consuming tasks.
Retries establishing the connection if the
:setting:`broker_connection_retry` setting is enabled
"""
conn = self.connection_for_read(heartbeat=self.amqheartbeat)
if self.hub:
conn.transport.register_with_event_loop(conn.connection, self.hub)
return conn
def connection_for_read(self, heartbeat=None):
return self.ensure_connected(
self.app.connection_for_read(heartbeat=heartbeat))
def connection_for_write(self, heartbeat=None):
return self.ensure_connected(
self.app.connection_for_write(heartbeat=heartbeat))
def ensure_connected(self, conn):
# Callback called for each retry while the connection
# can't be established.
def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
if getattr(conn, 'alt', None) and interval == 0:
next_step = CONNECTION_FAILOVER
next_step = next_step.format(
when=humanize_seconds(interval, 'in', ' '),
retries=int(interval / 2),
max_retries=self.app.conf.broker_connection_max_retries)
error(CONNECTION_ERROR, conn.as_uri(), exc, next_step)
# Remember that the connection is lazy, it won't establish
# until needed.
# TODO: Rely only on broker_connection_retry_on_startup to determine whether connection retries are disabled.
# We will make the switch in Celery 6.0.
retry_disabled = False
if self.app.conf.broker_connection_retry_on_startup is None:
# If broker_connection_retry_on_startup is not set, revert to broker_connection_retry
# to determine whether connection retries are disabled.
retry_disabled = not self.app.conf.broker_connection_retry
warnings.warn(
CPendingDeprecationWarning(
f"The broker_connection_retry configuration setting will no longer determine\n"
f"whether broker connection retries are made during startup in Celery 6.0 and above.\n"
f"If you wish to retain the existing behavior for retrying connections on startup,\n"
f"you should set broker_connection_retry_on_startup to {self.app.conf.broker_connection_retry}.")
)
else:
if self.first_connection_attempt:
retry_disabled = not self.app.conf.broker_connection_retry_on_startup
else:
retry_disabled = not self.app.conf.broker_connection_retry
if retry_disabled:
# Retry disabled, just call connect directly.
conn.connect()
self.first_connection_attempt = False
return conn
conn = conn.ensure_connection(
_error_handler, self.app.conf.broker_connection_max_retries,
callback=maybe_shutdown,
)
self.first_connection_attempt = False
return conn
def _flush_events(self):
if self.event_dispatcher:
self.event_dispatcher.flush()
def on_send_event_buffered(self):
if self.hub:
self.hub._ready.add(self._flush_events)
def add_task_queue(self, queue, exchange=None, exchange_type=None,
routing_key=None, **options):
cset = self.task_consumer
queues = self.app.amqp.queues
# Must use `in` here, as __missing__ will automatically
# create queues when :setting:`task_create_missing_queues` is enabled.
# (Issue #1079)
if queue in queues:
q = queues[queue]
else:
exchange = queue if exchange is None else exchange
exchange_type = ('direct' if exchange_type is None
else exchange_type)
q = queues.select_add(queue,
exchange=exchange,
exchange_type=exchange_type,
routing_key=routing_key, **options)
if not cset.consuming_from(queue):
cset.add_queue(q)
cset.consume()
info('Started consuming from %s', queue)
def cancel_task_queue(self, queue):
info('Canceling queue %s', queue)
self.app.amqp.queues.deselect(queue)
self.task_consumer.cancel_by_queue(queue)
def apply_eta_task(self, task):
"""Method called by the timer to apply a task with an ETA/countdown."""
task_reserved(task)
self.on_task_request(task)
self.qos.decrement_eventually()
def _message_report(self, body, message):
return MESSAGE_REPORT.format(dump_body(message, body),
safe_repr(message.content_type),
safe_repr(message.content_encoding),
safe_repr(message.delivery_info),
safe_repr(message.headers))
def on_unknown_message(self, body, message):
warn(UNKNOWN_FORMAT, self._message_report(body, message))
message.reject_log_error(logger, self.connection_errors)
signals.task_rejected.send(sender=self, message=message, exc=None)
def on_unknown_task(self, body, message, exc):
error(UNKNOWN_TASK_ERROR,
exc,
dump_body(message, body),
message.headers,
message.delivery_info,
exc_info=True)
try:
id_, name = message.headers['id'], message.headers['task']
root_id = message.headers.get('root_id')
except KeyError: # proto1
payload = message.payload
id_, name = payload['id'], payload['task']
root_id = None
request = Bunch(
name=name, chord=None, root_id=root_id,
correlation_id=message.properties.get('correlation_id'),
reply_to=message.properties.get('reply_to'),
errbacks=None,
)
message.reject_log_error(logger, self.connection_errors)
self.app.backend.mark_as_failure(
id_, NotRegistered(name), request=request,
)
if self.event_dispatcher:
self.event_dispatcher.send(
'task-failed', uuid=id_,
exception=f'NotRegistered({name!r})',
)
signals.task_unknown.send(
sender=self, message=message, exc=exc, name=name, id=id_,
)
def on_invalid_task(self, body, message, exc):
error(INVALID_TASK_ERROR, exc, dump_body(message, body),
exc_info=True)
message.reject_log_error(logger, self.connection_errors)
signals.task_rejected.send(sender=self, message=message, exc=exc)
def update_strategies(self):
loader = self.app.loader
for name, task in self.app.tasks.items():
self.strategies[name] = task.start_strategy(self.app, self)
task.__trace__ = build_tracer(name, task, loader, self.hostname,
app=self.app)
def create_task_handler(self, promise=promise):
strategies = self.strategies
on_unknown_message = self.on_unknown_message
on_unknown_task = self.on_unknown_task
on_invalid_task = self.on_invalid_task
callbacks = self.on_task_message
call_soon = self.call_soon
def on_task_received(message):
# payload will only be set for v1 protocol, since v2
# will defer deserializing the message body to the pool.
payload = None
try:
type_ = message.headers['task'] # protocol v2
except TypeError:
return on_unknown_message(None, message)
except KeyError:
try:
payload = message.decode()
except Exception as exc: # pylint: disable=broad-except
return self.on_decode_error(message, exc)
try:
type_, payload = payload['task'], payload # protocol v1
except (TypeError, KeyError):
return on_unknown_message(payload, message)
try:
strategy = strategies[type_]
except KeyError as exc:
return on_unknown_task(None, message, exc)
else:
try:
ack_log_error_promise = promise(
call_soon,
(message.ack_log_error,),
on_error=self._restore_prefetch_count_after_connection_restart,
)
reject_log_error_promise = promise(
call_soon,
(message.reject_log_error,),
on_error=self._restore_prefetch_count_after_connection_restart,
)
if (
not self._maximum_prefetch_restored
and self.restart_count > 0
and self._new_prefetch_count <= self.max_prefetch_count
):
ack_log_error_promise.then(self._restore_prefetch_count_after_connection_restart,
on_error=self._restore_prefetch_count_after_connection_restart)
reject_log_error_promise.then(self._restore_prefetch_count_after_connection_restart,
on_error=self._restore_prefetch_count_after_connection_restart)
strategy(
message, payload,
ack_log_error_promise,
reject_log_error_promise,
callbacks,
)
except (InvalidTaskError, ContentDisallowed) as exc:
return on_invalid_task(payload, message, exc)
except DecodeError as exc:
return self.on_decode_error(message, exc)
return on_task_received
def _restore_prefetch_count_after_connection_restart(self, p, *args):
with self.qos._mutex:
if self._maximum_prefetch_restored:
return
new_prefetch_count = min(self.max_prefetch_count, self._new_prefetch_count)
self.qos.value = self.initial_prefetch_count = new_prefetch_count
self.qos.set(self.qos.value)
already_restored = self._maximum_prefetch_restored
self._maximum_prefetch_restored = new_prefetch_count == self.max_prefetch_count
if already_restored is False and self._maximum_prefetch_restored is True:
logger.info(
"Resuming normal operations following a restart.\n"
f"Prefetch count has been restored to the maximum of {self.max_prefetch_count}"
)
@property
def max_prefetch_count(self):
return self.pool.num_processes * self.prefetch_multiplier
@property
def _new_prefetch_count(self):
return self.qos.value + self.prefetch_multiplier
def __repr__(self):
"""``repr(self)``."""
return '<Consumer: {self.hostname} ({state})>'.format(
self=self, state=self.blueprint.human_state(),
)
class Evloop(bootsteps.StartStopStep):
"""Event loop service.
Note:
This is always started last.
"""
label = 'event loop'
last = True
def start(self, c):
self.patch_all(c)
c.loop(*c.loop_args())
def patch_all(self, c):
c.qos._mutex = DummyLock()
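To make the connection-loss prefetch handling above concrete, here is a standalone sketch of the arithmetic used by ``on_connection_error_after_connected`` and the ``max_prefetch_count`` property (the figures are illustrative):

def reduced_prefetch(num_processes, multiplier, active):
    """Mirror the temporary prefetch reduction applied after a connection loss."""
    max_prefetch = num_processes * multiplier          # max_prefetch_count
    return max(multiplier, max_prefetch - active * multiplier)

# 8 processes with a multiplier of 4 give a max_prefetch_count of 32; with 5 tasks
# still running the count is dropped to 12 and then restored one multiplier at a
# time by _restore_prefetch_count_after_connection_restart as those tasks are acked.
assert reduced_prefetch(num_processes=8, multiplier=4, active=5) == 12
# It never drops below a single multiplier:
assert reduced_prefetch(num_processes=8, multiplier=4, active=20) == 4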

View File

@@ -0,0 +1,33 @@
"""Worker Remote Control Bootstep.
``Control`` -> :mod:`celery.worker.pidbox` -> :mod:`kombu.pidbox`.
The actual commands are implemented in :mod:`celery.worker.control`.
"""
from celery import bootsteps
from celery.utils.log import get_logger
from celery.worker import pidbox
from .tasks import Tasks
__all__ = ('Control',)
logger = get_logger(__name__)
class Control(bootsteps.StartStopStep):
"""Remote control command service."""
requires = (Tasks,)
def __init__(self, c, **kwargs):
self.is_green = c.pool is not None and c.pool.is_green
self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
self.start = self.box.start
self.stop = self.box.stop
self.shutdown = self.box.shutdown
super().__init__(c, **kwargs)
def include_if(self, c):
return (c.app.conf.worker_enable_remote_control and
c.conninfo.supports_exchange_type('fanout'))

View File

@@ -0,0 +1,68 @@
"""Worker Event Dispatcher Bootstep.
``Events`` -> :class:`celery.events.EventDispatcher`.
"""
from kombu.common import ignore_errors
from celery import bootsteps
from .connection import Connection
__all__ = ('Events',)
class Events(bootsteps.StartStopStep):
"""Service used for sending monitoring events."""
requires = (Connection,)
def __init__(self, c,
task_events=True,
without_heartbeat=False,
without_gossip=False,
**kwargs):
self.groups = None if task_events else ['worker']
self.send_events = (
task_events or
not without_gossip or
not without_heartbeat
)
self.enabled = self.send_events
c.event_dispatcher = None
super().__init__(c, **kwargs)
def start(self, c):
# flush events sent while connection was down.
prev = self._close(c)
dis = c.event_dispatcher = c.app.events.Dispatcher(
c.connection_for_write(),
hostname=c.hostname,
enabled=self.send_events,
groups=self.groups,
# we currently only buffer events when the event loop is enabled
# XXX This excludes eventlet/gevent, which should actually buffer.
buffer_group=['task'] if c.hub else None,
on_send_buffered=c.on_send_event_buffered if c.hub else None,
)
if prev:
dis.extend_buffer(prev)
dis.flush()
def stop(self, c):
pass
def _close(self, c):
if c.event_dispatcher:
dispatcher = c.event_dispatcher
# remember changes from remote control commands:
self.groups = dispatcher.groups
# close custom connection
if dispatcher.connection:
ignore_errors(c, dispatcher.connection.close)
ignore_errors(c, dispatcher.close)
c.event_dispatcher = None
return dispatcher
def shutdown(self, c):
self._close(c)
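The dispatcher configured here is what the remote ``enable_events``/``disable_events`` commands (defined in the control command module later in this commit) toggle at runtime. A hedged client-side sketch (app name and broker URL are hypothetical):

from celery import Celery

app = Celery('proj', broker='amqp://')  # hypothetical app name and broker URL

# Broadcast to all workers: start emitting task events, then stop again.
app.control.enable_events()
app.control.disable_events()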

View File

@@ -0,0 +1,205 @@
"""Worker <-> Worker communication Bootstep."""
from collections import defaultdict
from functools import partial
from heapq import heappush
from operator import itemgetter
from kombu import Consumer
from kombu.asynchronous.semaphore import DummyLock
from kombu.exceptions import ContentDisallowed, DecodeError
from celery import bootsteps
from celery.utils.log import get_logger
from celery.utils.objects import Bunch
from .mingle import Mingle
__all__ = ('Gossip',)
logger = get_logger(__name__)
debug, info = logger.debug, logger.info
class Gossip(bootsteps.ConsumerStep):
"""Bootstep consuming events from other workers.
This keeps the logical clock value up to date.
"""
label = 'Gossip'
requires = (Mingle,)
_cons_stamp_fields = itemgetter(
'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
)
compatible_transports = {'amqp', 'redis'}
def __init__(self, c, without_gossip=False,
interval=5.0, heartbeat_interval=2.0, **kwargs):
self.enabled = not without_gossip and self.compatible_transport(c.app)
self.app = c.app
c.gossip = self
self.Receiver = c.app.events.Receiver
self.hostname = c.hostname
self.full_hostname = '.'.join([self.hostname, str(c.pid)])
self.on = Bunch(
node_join=set(),
node_leave=set(),
node_lost=set(),
)
self.timer = c.timer
if self.enabled:
self.state = c.app.events.State(
on_node_join=self.on_node_join,
on_node_leave=self.on_node_leave,
max_tasks_in_memory=1,
)
if c.hub:
c._mutex = DummyLock()
self.update_state = self.state.event
self.interval = interval
self.heartbeat_interval = heartbeat_interval
self._tref = None
self.consensus_requests = defaultdict(list)
self.consensus_replies = {}
self.event_handlers = {
'worker.elect': self.on_elect,
'worker.elect.ack': self.on_elect_ack,
}
self.clock = c.app.clock
self.election_handlers = {
'task': self.call_task
}
super().__init__(c, **kwargs)
def compatible_transport(self, app):
with app.connection_for_read() as conn:
return conn.transport.driver_type in self.compatible_transports
def election(self, id, topic, action=None):
self.consensus_replies[id] = []
self.dispatcher.send(
'worker-elect',
id=id, topic=topic, action=action, cver=1,
)
def call_task(self, task):
try:
self.app.signature(task).apply_async()
except Exception as exc: # pylint: disable=broad-except
logger.exception('Could not call task: %r', exc)
def on_elect(self, event):
try:
(id_, clock, hostname, pid,
topic, action, _) = self._cons_stamp_fields(event)
except KeyError as exc:
return logger.exception('election request missing field %s', exc)
heappush(
self.consensus_requests[id_],
(clock, f'{hostname}.{pid}', topic, action),
)
self.dispatcher.send('worker-elect-ack', id=id_)
def start(self, c):
super().start(c)
self.dispatcher = c.event_dispatcher
def on_elect_ack(self, event):
id = event['id']
try:
replies = self.consensus_replies[id]
except KeyError:
return # not for us
alive_workers = set(self.state.alive_workers())
replies.append(event['hostname'])
if len(replies) >= len(alive_workers):
_, leader, topic, action = self.clock.sort_heap(
self.consensus_requests[id],
)
if leader == self.full_hostname:
info('I won the election %r', id)
try:
handler = self.election_handlers[topic]
except KeyError:
logger.exception('Unknown election topic %r', topic)
else:
handler(action)
else:
info('node %s elected for %r', leader, id)
self.consensus_requests.pop(id, None)
self.consensus_replies.pop(id, None)
def on_node_join(self, worker):
debug('%s joined the party', worker.hostname)
self._call_handlers(self.on.node_join, worker)
def on_node_leave(self, worker):
debug('%s left', worker.hostname)
self._call_handlers(self.on.node_leave, worker)
def on_node_lost(self, worker):
info('missed heartbeat from %s', worker.hostname)
self._call_handlers(self.on.node_lost, worker)
def _call_handlers(self, handlers, *args, **kwargs):
for handler in handlers:
try:
handler(*args, **kwargs)
except Exception as exc: # pylint: disable=broad-except
logger.exception(
'Ignored error from handler %r: %r', handler, exc)
def register_timer(self):
if self._tref is not None:
self._tref.cancel()
self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
def periodic(self):
workers = self.state.workers
dirty = set()
for worker in workers.values():
if not worker.alive:
dirty.add(worker)
self.on_node_lost(worker)
for worker in dirty:
workers.pop(worker.hostname, None)
def get_consumers(self, channel):
self.register_timer()
ev = self.Receiver(channel, routing_key='worker.#',
queue_ttl=self.heartbeat_interval)
return [Consumer(
channel,
queues=[ev.queue],
on_message=partial(self.on_message, ev.event_from_message),
no_ack=True
)]
def on_message(self, prepare, message):
_type = message.delivery_info['routing_key']
# For redis when `fanout_patterns=False` (See Issue #1882)
if _type.split('.', 1)[0] == 'task':
return
try:
handler = self.event_handlers[_type]
except KeyError:
pass
else:
return handler(message.payload)
# proto2: hostname in header; proto1: in body
hostname = (message.headers.get('hostname') or
message.payload['hostname'])
if hostname != self.hostname:
try:
_, event = prepare(message.payload)
self.update_state(event)
except (DecodeError, ContentDisallowed, TypeError) as exc:
logger.error(exc)
else:
self.clock.forward()
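The ``on`` Bunch above (``node_join``/``node_leave``/``node_lost``) holds callback sets that other consumer bootsteps may register with. A hedged sketch of such a step (the step name, app name, broker URL and printed output are illustrative, not part of this commit):

from celery import Celery, bootsteps

app = Celery('proj', broker='amqp://')  # hypothetical app name and broker URL

class LogNodeJoins(bootsteps.StartStopStep):
    """Illustrative step subscribing to the Gossip hooks defined above."""
    requires = {'celery.worker.consumer.gossip:Gossip'}

    def start(self, c):
        # c.gossip.on is the Bunch(node_join=..., node_leave=..., node_lost=...).
        c.gossip.on.node_join.add(self.announce)

    def announce(self, worker):
        print('gossip: %s joined the cluster' % worker.hostname)

app.steps['consumer'].add(LogNodeJoins)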

View File

@@ -0,0 +1,36 @@
"""Worker Event Heartbeat Bootstep."""
from celery import bootsteps
from celery.worker import heartbeat
from .events import Events
__all__ = ('Heart',)
class Heart(bootsteps.StartStopStep):
"""Bootstep sending event heartbeats.
This service sends a ``worker-heartbeat`` message every n seconds.
Note:
Not to be confused with AMQP protocol level heartbeats.
"""
requires = (Events,)
def __init__(self, c,
without_heartbeat=False, heartbeat_interval=None, **kwargs):
self.enabled = not without_heartbeat
self.heartbeat_interval = heartbeat_interval
c.heart = None
super().__init__(c, **kwargs)
def start(self, c):
c.heart = heartbeat.Heart(
c.timer, c.event_dispatcher, self.heartbeat_interval,
)
c.heart.start()
def stop(self, c):
c.heart = c.heart and c.heart.stop()
shutdown = stop

View File

@@ -0,0 +1,76 @@
"""Worker <-> Worker Sync at startup (Bootstep)."""
from celery import bootsteps
from celery.utils.log import get_logger
from .events import Events
__all__ = ('Mingle',)
logger = get_logger(__name__)
debug, info, exception = logger.debug, logger.info, logger.exception
class Mingle(bootsteps.StartStopStep):
"""Bootstep syncing state with neighbor workers.
At startup, or upon consumer restart, this will:
- Sync logical clocks.
- Sync revoked tasks.
"""
label = 'Mingle'
requires = (Events,)
compatible_transports = {'amqp', 'redis'}
def __init__(self, c, without_mingle=False, **kwargs):
self.enabled = not without_mingle and self.compatible_transport(c.app)
super().__init__(
c, without_mingle=without_mingle, **kwargs)
def compatible_transport(self, app):
with app.connection_for_read() as conn:
return conn.transport.driver_type in self.compatible_transports
def start(self, c):
self.sync(c)
def sync(self, c):
info('mingle: searching for neighbors')
replies = self.send_hello(c)
if replies:
info('mingle: sync with %s nodes',
len([reply for reply, value in replies.items() if value]))
[self.on_node_reply(c, nodename, reply)
for nodename, reply in replies.items() if reply]
info('mingle: sync complete')
else:
info('mingle: all alone')
def send_hello(self, c):
inspect = c.app.control.inspect(timeout=1.0, connection=c.connection)
our_revoked = c.controller.state.revoked
replies = inspect.hello(c.hostname, our_revoked._data) or {}
replies.pop(c.hostname, None) # delete my own response
return replies
def on_node_reply(self, c, nodename, reply):
debug('mingle: processing reply from %s', nodename)
try:
self.sync_with_node(c, **reply)
except MemoryError:
raise
except Exception as exc: # pylint: disable=broad-except
exception('mingle: sync with %s failed: %r', nodename, exc)
def sync_with_node(self, c, clock=None, revoked=None, **kwargs):
self.on_clock_event(c, clock)
self.on_revoked_received(c, revoked)
def on_clock_event(self, c, clock):
c.app.clock.adjust(clock) if clock else c.app.clock.forward()
def on_revoked_received(self, c, revoked):
if revoked:
c.controller.state.revoked.update(revoked)

View File

@@ -0,0 +1,65 @@
"""Worker Task Consumer Bootstep."""
from kombu.common import QoS, ignore_errors
from celery import bootsteps
from celery.utils.log import get_logger
from .mingle import Mingle
__all__ = ('Tasks',)
logger = get_logger(__name__)
debug = logger.debug
class Tasks(bootsteps.StartStopStep):
"""Bootstep starting the task message consumer."""
requires = (Mingle,)
def __init__(self, c, **kwargs):
c.task_consumer = c.qos = None
super().__init__(c, **kwargs)
def start(self, c):
"""Start task consumer."""
c.update_strategies()
# - RabbitMQ 3.3 completely redefines how basic_qos works...
# This will detect if the new qos semantics is in effect,
# and if so make sure the 'apply_global' flag is set on qos updates.
qos_global = not c.connection.qos_semantics_matches_spec
# set initial prefetch count
c.connection.default_channel.basic_qos(
0, c.initial_prefetch_count, qos_global,
)
c.task_consumer = c.app.amqp.TaskConsumer(
c.connection, on_decode_error=c.on_decode_error,
)
def set_prefetch_count(prefetch_count):
return c.task_consumer.qos(
prefetch_count=prefetch_count,
apply_global=qos_global,
)
c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
def stop(self, c):
"""Stop task consumer."""
if c.task_consumer:
debug('Canceling task consumer...')
ignore_errors(c, c.task_consumer.cancel)
def shutdown(self, c):
"""Shutdown task consumer."""
if c.task_consumer:
self.stop(c)
debug('Closing consumer channel...')
ignore_errors(c, c.task_consumer.close)
c.task_consumer = None
def info(self, c):
"""Return task consumer info."""
return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
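The initial ``basic_qos`` value set here is ``concurrency * worker_prefetch_multiplier`` (computed by the ``Consumer`` step in the components module earlier in this commit). A hedged configuration sketch (app name and broker URL are hypothetical):

from celery import Celery

app = Celery('proj', broker='amqp://')   # hypothetical app name and broker URL
app.conf.worker_prefetch_multiplier = 1  # prefetch exactly one message per pool process
app.conf.task_acks_late = True           # common companion setting for long-running tasks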

View File

@@ -0,0 +1,624 @@
"""Worker remote control command implementations."""
import io
import tempfile
from collections import UserDict, defaultdict, namedtuple
from billiard.common import TERM_SIGNAME
from kombu.utils.encoding import safe_repr
from celery.exceptions import WorkerShutdown
from celery.platforms import signals as _signals
from celery.utils.functional import maybe_list
from celery.utils.log import get_logger
from celery.utils.serialization import jsonify, strtobool
from celery.utils.time import rate
from . import state as worker_state
from .request import Request
__all__ = ('Panel',)
DEFAULT_TASK_INFO_ITEMS = ('exchange', 'routing_key', 'rate_limit')
logger = get_logger(__name__)
controller_info_t = namedtuple('controller_info_t', [
'alias', 'type', 'visible', 'default_timeout',
'help', 'signature', 'args', 'variadic',
])
def ok(value):
return {'ok': value}
def nok(value):
return {'error': value}
class Panel(UserDict):
"""Global registry of remote control commands."""
data = {} # global dict.
meta = {} # -"-
@classmethod
def register(cls, *args, **kwargs):
if args:
return cls._register(**kwargs)(*args)
return cls._register(**kwargs)
@classmethod
def _register(cls, name=None, alias=None, type='control',
visible=True, default_timeout=1.0, help=None,
signature=None, args=None, variadic=None):
def _inner(fun):
control_name = name or fun.__name__
_help = help or (fun.__doc__ or '').strip().split('\n')[0]
cls.data[control_name] = fun
cls.meta[control_name] = controller_info_t(
alias, type, visible, default_timeout,
_help, signature, args, variadic)
if alias:
cls.data[alias] = fun
return fun
return _inner
def control_command(**kwargs):
return Panel.register(type='control', **kwargs)
def inspect_command(**kwargs):
return Panel.register(type='inspect', **kwargs)
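# A hedged example, not part of this commit: the decorators above are also the
# documented way to register custom remote control commands from user code; the
# command name and reply text below are illustrative.
#
#     from celery.worker.control import control_command
#
#     @control_command(
#         args=[('n', int)],
#         signature='[N=1]',
#     )
#     def increase_prefetch_count(state, n=1):
#         """Increase the prefetch count by n."""
#         state.consumer.qos.increment_eventually(n)
#         return {'ok': 'prefetch count incremented'}
#
# Once registered, the command can be invoked from a client with
# app.control.broadcast('increase_prefetch_count', arguments={'n': 1}, reply=True).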
# -- App
@inspect_command()
def report(state):
"""Information about Celery installation for bug reports."""
return ok(state.app.bugreport())
@inspect_command(
alias='dump_conf', # XXX < backwards compatible
signature='[include_defaults=False]',
args=[('with_defaults', strtobool)],
)
def conf(state, with_defaults=False, **kwargs):
"""List configuration."""
return jsonify(state.app.conf.table(with_defaults=with_defaults),
keyfilter=_wanted_config_key,
unknown_type_filter=safe_repr)
def _wanted_config_key(key):
return isinstance(key, str) and not key.startswith('__')
# -- Task
@inspect_command(
variadic='ids',
signature='[id1 [id2 [... [idN]]]]',
)
def query_task(state, ids, **kwargs):
"""Query for task information by id."""
return {
req.id: (_state_of_task(req), req.info())
for req in _find_requests_by_id(maybe_list(ids))
}
def _find_requests_by_id(ids,
get_request=worker_state.requests.__getitem__):
for task_id in ids:
try:
yield get_request(task_id)
except KeyError:
pass
def _state_of_task(request,
is_active=worker_state.active_requests.__contains__,
is_reserved=worker_state.reserved_requests.__contains__):
if is_active(request):
return 'active'
elif is_reserved(request):
return 'reserved'
return 'ready'
@control_command(
variadic='task_id',
signature='[id1 [id2 [... [idN]]]]',
)
def revoke(state, task_id, terminate=False, signal=None, **kwargs):
"""Revoke task by task id (or list of ids).
Keyword Arguments:
terminate (bool): Also terminate the process if the task is active.
signal (str): Name of signal to use for terminate (e.g., ``KILL``).
"""
# pylint: disable=redefined-outer-name
# XXX Note that this redefines `terminate`:
# Outside of this scope that is a function.
# supports list argument since 3.1
task_ids, task_id = set(maybe_list(task_id) or []), None
task_ids = _revoke(state, task_ids, terminate, signal, **kwargs)
if isinstance(task_ids, dict) and 'ok' in task_ids:
return task_ids
return ok(f'tasks {task_ids} flagged as revoked')
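# From the client side this command is normally reached through the documented
# app.control.revoke() helper; a hedged usage sketch ('<task-id>' is a placeholder):
#
#     app.control.revoke('<task-id>', terminate=True, signal='SIGKILL')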
@control_command(
variadic='headers',
signature='[key1=value1 [key2=value2 [... [keyN=valueN]]]]',
)
def revoke_by_stamped_headers(state, headers, terminate=False, signal=None, **kwargs):
"""Revoke task by header (or list of headers).
Keyword Arguments:
headers (dict): Dictionary that contains stamping scheme names as keys and stamps as values.
If headers is a list, it will be converted to a dictionary.
terminate (bool): Also terminate the process if the task is active.
signal (str): Name of signal to use for terminate (e.g., ``KILL``).
Sample headers input:
{'mtask_id': [id1, id2, id3]}
"""
# pylint: disable=redefined-outer-name
# XXX Note that this redefines `terminate`:
# Outside of this scope that is a function.
# supports list argument since 3.1
signum = _signals.signum(signal or TERM_SIGNAME)
if isinstance(headers, list):
headers = {h.split('=')[0]: h.split('=')[1] for h in headers}
for header, stamps in headers.items():
updated_stamps = maybe_list(worker_state.revoked_stamps.get(header) or []) + list(maybe_list(stamps))
worker_state.revoked_stamps[header] = updated_stamps
if not terminate:
return ok(f'headers {headers} flagged as revoked, but not terminated')
active_requests = list(worker_state.active_requests)
terminated_scheme_to_stamps_mapping = defaultdict(set)
# Terminate all running tasks of matching headers
# Go through all active requests, and check if one of the
# requests has a stamped header that matches the given headers to revoke
for req in active_requests:
# Check stamps exist
if hasattr(req, "stamps") and req.stamps:
# if so, check if any stamps match a revoked stamp
for expected_header_key, expected_header_value in headers.items():
if expected_header_key in req.stamps:
expected_header_value = maybe_list(expected_header_value)
actual_header = maybe_list(req.stamps[expected_header_key])
matching_stamps_for_request = set(actual_header) & set(expected_header_value)
# Check any possible match regardless if the stamps are a sequence or not
if matching_stamps_for_request:
terminated_scheme_to_stamps_mapping[expected_header_key].update(matching_stamps_for_request)
req.terminate(state.consumer.pool, signal=signum)
if not terminated_scheme_to_stamps_mapping:
return ok(f'headers {headers} were not terminated')
return ok(f'headers {terminated_scheme_to_stamps_mapping} revoked')
def _revoke(state, task_ids, terminate=False, signal=None, **kwargs):
size = len(task_ids)
terminated = set()
worker_state.revoked.update(task_ids)
if terminate:
signum = _signals.signum(signal or TERM_SIGNAME)
for request in _find_requests_by_id(task_ids):
if request.id not in terminated:
terminated.add(request.id)
logger.info('Terminating %s (%s)', request.id, signum)
request.terminate(state.consumer.pool, signal=signum)
if len(terminated) >= size:
break
if not terminated:
return ok('terminate: tasks unknown')
return ok('terminate: {}'.format(', '.join(terminated)))
idstr = ', '.join(task_ids)
logger.info('Tasks flagged as revoked: %s', idstr)
return task_ids
@control_command(
variadic='task_id',
args=[('signal', str)],
signature='<signal> [id1 [id2 [... [idN]]]]'
)
def terminate(state, signal, task_id, **kwargs):
"""Terminate task by task id (or list of ids)."""
return revoke(state, task_id, terminate=True, signal=signal)
@control_command(
args=[('task_name', str), ('rate_limit', str)],
signature='<task_name> <rate_limit (e.g., 5/s | 5/m | 5/h)>',
)
def rate_limit(state, task_name, rate_limit, **kwargs):
"""Tell worker(s) to modify the rate limit for a task by type.
See Also:
:attr:`celery.app.task.Task.rate_limit`.
Arguments:
task_name (str): Type of task to set rate limit for.
rate_limit (int, str): New rate limit.
"""
# pylint: disable=redefined-outer-name
# XXX Note that this redefines `rate_limit`:
# Outside of this scope that is a function.
try:
rate(rate_limit)
except ValueError as exc:
return nok(f'Invalid rate limit string: {exc!r}')
try:
state.app.tasks[task_name].rate_limit = rate_limit
except KeyError:
logger.error('Rate limit attempt for unknown task %s',
task_name, exc_info=True)
return nok('unknown task')
state.consumer.reset_rate_limits()
if not rate_limit:
logger.info('Rate limits disabled for tasks of type %s', task_name)
return ok('rate limit disabled successfully')
logger.info('New rate limit for tasks of type %s: %s.',
task_name, rate_limit)
return ok('new rate limit set successfully')
@control_command(
args=[('task_name', str), ('soft', float), ('hard', float)],
signature='<task_name> <soft_secs> [hard_secs]',
)
def time_limit(state, task_name=None, hard=None, soft=None, **kwargs):
"""Tell worker(s) to modify the time limit for task by type.
Arguments:
task_name (str): Name of task to change.
hard (float): Hard time limit.
soft (float): Soft time limit.
"""
try:
task = state.app.tasks[task_name]
except KeyError:
logger.error('Change time limit attempt for unknown task %s',
task_name, exc_info=True)
return nok('unknown task')
task.soft_time_limit = soft
task.time_limit = hard
logger.info('New time limits for tasks of type %s: soft=%s hard=%s',
task_name, soft, hard)
return ok('time limits set successfully')
# -- Events
@inspect_command()
def clock(state, **kwargs):
"""Get current logical clock value."""
return {'clock': state.app.clock.value}
@control_command()
def election(state, id, topic, action=None, **kwargs):
"""Hold election.
Arguments:
id (str): Unique election id.
topic (str): Election topic.
action (str): Action to take for elected actor.
"""
if state.consumer.gossip:
state.consumer.gossip.election(id, topic, action)
@control_command()
def enable_events(state):
"""Tell worker(s) to send task-related events."""
dispatcher = state.consumer.event_dispatcher
if dispatcher.groups and 'task' not in dispatcher.groups:
dispatcher.groups.add('task')
logger.info('Events of group {task} enabled by remote.')
return ok('task events enabled')
return ok('task events already enabled')
@control_command()
def disable_events(state):
"""Tell worker(s) to stop sending task-related events."""
dispatcher = state.consumer.event_dispatcher
if 'task' in dispatcher.groups:
dispatcher.groups.discard('task')
logger.info('Events of group {task} disabled by remote.')
return ok('task events disabled')
return ok('task events already disabled')
@control_command()
def heartbeat(state):
"""Tell worker(s) to send event heartbeat immediately."""
logger.debug('Heartbeat requested by remote.')
dispatcher = state.consumer.event_dispatcher
dispatcher.send('worker-heartbeat', freq=5, **worker_state.SOFTWARE_INFO)
# -- Worker
@inspect_command(visible=False)
def hello(state, from_node, revoked=None, **kwargs):
"""Request mingle sync-data."""
# pylint: disable=redefined-outer-name
# XXX Note that this redefines `revoked`:
# Outside of this scope that is a function.
if from_node != state.hostname:
logger.info('sync with %s', from_node)
if revoked:
worker_state.revoked.update(revoked)
# Do not send expired items to the other worker.
worker_state.revoked.purge()
return {
'revoked': worker_state.revoked._data,
'clock': state.app.clock.forward(),
}
@inspect_command(default_timeout=0.2)
def ping(state, **kwargs):
"""Ping worker(s)."""
return ok('pong')
@inspect_command()
def stats(state, **kwargs):
"""Request worker statistics/information."""
return state.consumer.controller.stats()
@inspect_command(alias='dump_schedule')
def scheduled(state, **kwargs):
"""List of currently scheduled ETA/countdown tasks."""
return list(_iter_schedule_requests(state.consumer.timer))
def _iter_schedule_requests(timer):
for waiting in timer.schedule.queue:
try:
arg0 = waiting.entry.args[0]
except (IndexError, TypeError):
continue
else:
if isinstance(arg0, Request):
yield {
'eta': arg0.eta.isoformat() if arg0.eta else None,
'priority': waiting.priority,
'request': arg0.info(),
}
@inspect_command(alias='dump_reserved')
def reserved(state, **kwargs):
"""List of currently reserved tasks, not including scheduled/active."""
reserved_tasks = (
state.tset(worker_state.reserved_requests) -
state.tset(worker_state.active_requests)
)
if not reserved_tasks:
return []
return [request.info() for request in reserved_tasks]
@inspect_command(alias='dump_active')
def active(state, safe=False, **kwargs):
"""List of tasks currently being executed."""
return [request.info(safe=safe)
for request in state.tset(worker_state.active_requests)]
@inspect_command(alias='dump_revoked')
def revoked(state, **kwargs):
"""List of revoked task-ids."""
return list(worker_state.revoked)
@inspect_command(
alias='dump_tasks',
variadic='taskinfoitems',
signature='[attr1 [attr2 [... [attrN]]]]',
)
def registered(state, taskinfoitems=None, builtins=False, **kwargs):
"""List of registered tasks.
Arguments:
taskinfoitems (Sequence[str]): List of task attributes to include.
Defaults to ``exchange,routing_key,rate_limit``.
builtins (bool): Also include built-in tasks.
"""
reg = state.app.tasks
taskinfoitems = taskinfoitems or DEFAULT_TASK_INFO_ITEMS
tasks = reg if builtins else (
task for task in reg if not task.startswith('celery.'))
def _extract_info(task):
fields = {
field: str(getattr(task, field, None)) for field in taskinfoitems
if getattr(task, field, None) is not None
}
if fields:
info = ['='.join(f) for f in fields.items()]
return '{} [{}]'.format(task.name, ' '.join(info))
return task.name
return [_extract_info(reg[task]) for task in sorted(tasks)]
# -- Debugging
@inspect_command(
default_timeout=60.0,
args=[('type', str), ('num', int), ('max_depth', int)],
signature='[object_type=Request] [num=200 [max_depth=10]]',
)
def objgraph(state, num=200, max_depth=10, type='Request'): # pragma: no cover
"""Create graph of uncollected objects (memory-leak debugging).
Arguments:
num (int): Max number of objects to graph.
max_depth (int): Traverse at most n levels deep.
type (str): Name of object to graph. Default is ``"Request"``.
"""
try:
import objgraph as _objgraph
except ImportError:
raise ImportError('Requires the objgraph library')
logger.info('Dumping graph for type %r', type)
with tempfile.NamedTemporaryFile(prefix='cobjg',
suffix='.png', delete=False) as fh:
objects = _objgraph.by_type(type)[:num]
_objgraph.show_backrefs(
objects,
max_depth=max_depth, highlight=lambda v: v in objects,
filename=fh.name,
)
return {'filename': fh.name}
@inspect_command()
def memsample(state, **kwargs):
"""Sample current RSS memory usage."""
from celery.utils.debug import sample_mem
return sample_mem()
@inspect_command(
args=[('samples', int)],
signature='[n_samples=10]',
)
def memdump(state, samples=10, **kwargs): # pragma: no cover
"""Dump statistics of previous memsample requests."""
from celery.utils import debug
out = io.StringIO()
debug.memdump(file=out)
return out.getvalue()
# -- Pool
@control_command(
args=[('n', int)],
signature='[N=1]',
)
def pool_grow(state, n=1, **kwargs):
"""Grow pool by n processes/threads."""
if state.consumer.controller.autoscaler:
return nok("pool_grow is not supported with autoscale. Adjust autoscale range instead.")
else:
state.consumer.pool.grow(n)
state.consumer._update_prefetch_count(n)
return ok('pool will grow')
@control_command(
args=[('n', int)],
signature='[N=1]',
)
def pool_shrink(state, n=1, **kwargs):
"""Shrink pool by n processes/threads."""
if state.consumer.controller.autoscaler:
return nok("pool_shrink is not supported with autoscale. Adjust autoscale range instead.")
else:
state.consumer.pool.shrink(n)
state.consumer._update_prefetch_count(-n)
return ok('pool will shrink')
@control_command()
def pool_restart(state, modules=None, reload=False, reloader=None, **kwargs):
"""Restart execution pool."""
if state.app.conf.worker_pool_restarts:
state.consumer.controller.reload(modules, reload, reloader=reloader)
return ok('reload started')
else:
raise ValueError('Pool restarts not enabled')
@control_command(
args=[('max', int), ('min', int)],
signature='[max [min]]',
)
def autoscale(state, max=None, min=None):
"""Modify autoscale settings."""
autoscaler = state.consumer.controller.autoscaler
if autoscaler:
max_, min_ = autoscaler.update(max, min)
return ok(f'autoscale now max={max_} min={min_}')
raise ValueError('Autoscale not enabled')
@control_command()
def shutdown(state, msg='Got shutdown from remote', **kwargs):
"""Shutdown worker(s)."""
logger.warning(msg)
raise WorkerShutdown(msg)
# -- Queues
@control_command(
args=[
('queue', str),
('exchange', str),
('exchange_type', str),
('routing_key', str),
],
signature='<queue> [exchange [type [routing_key]]]',
)
def add_consumer(state, queue, exchange=None, exchange_type=None,
routing_key=None, **options):
"""Tell worker(s) to consume from task queue by name."""
state.consumer.call_soon(
state.consumer.add_task_queue,
queue, exchange, exchange_type or 'direct', routing_key, **options)
return ok(f'add consumer {queue}')
@control_command(
args=[('queue', str)],
signature='<queue>',
)
def cancel_consumer(state, queue, **_):
"""Tell worker(s) to stop consuming from task queue by name."""
state.consumer.call_soon(
state.consumer.cancel_task_queue, queue,
)
return ok(f'no longer consuming from {queue}')
@inspect_command()
def active_queues(state):
"""List the task queues a worker is currently consuming from."""
if state.consumer.task_consumer:
return [dict(queue.as_dict(recurse=True))
for queue in state.consumer.task_consumer.queues]
return []
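# --- Illustrative client-side usage (a sketch, not part of this module) ---
# The handlers above are normally reached from a client process through
# ``app.control``.  The app name, broker URL and task name below are
# assumptions for the example; a running broker and worker are assumed.
if __name__ == '__main__':  # pragma: no cover
    from celery import Celery

    app = Celery('example', broker='amqp://guest@localhost//')

    # Broadcast commands map onto the @control_command handlers above.
    app.control.rate_limit('proj.tasks.add', '10/m')
    app.control.time_limit('proj.tasks.add', soft=60, hard=120)
    app.control.enable_events()

    # Inspect commands map onto the @inspect_command handlers above.
    insp = app.control.inspect(timeout=1.0)
    print(insp.ping())
    print(insp.active())
    print(insp.registered('rate_limit', 'time_limit'))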

View File

@@ -0,0 +1,61 @@
"""Heartbeat service.
This is the internal thread responsible for sending heartbeat events
at regular intervals (may not be an actual thread).
"""
from celery.signals import heartbeat_sent
from celery.utils.sysinfo import load_average
from .state import SOFTWARE_INFO, active_requests, all_total_count
__all__ = ('Heart',)
class Heart:
"""Timer sending heartbeats at regular intervals.
Arguments:
timer (kombu.asynchronous.timer.Timer): Timer to use.
eventer (celery.events.EventDispatcher): Event dispatcher
to use.
interval (float): Time in seconds between sending
heartbeats. Default is 2 seconds.
"""
def __init__(self, timer, eventer, interval=None):
self.timer = timer
self.eventer = eventer
self.interval = float(interval or 2.0)
self.tref = None
# Make event dispatcher start/stop us when enabled/disabled.
self.eventer.on_enabled.add(self.start)
self.eventer.on_disabled.add(self.stop)
# Only send heartbeat_sent signal if it has receivers.
self._send_sent_signal = (
heartbeat_sent.send if heartbeat_sent.receivers else None)
def _send(self, event, retry=True):
if self._send_sent_signal is not None:
self._send_sent_signal(sender=self)
return self.eventer.send(event, freq=self.interval,
active=len(active_requests),
processed=all_total_count[0],
loadavg=load_average(),
retry=retry,
**SOFTWARE_INFO)
def start(self):
if self.eventer.enabled:
self._send('worker-online')
self.tref = self.timer.call_repeatedly(
self.interval, self._send, ('worker-heartbeat',),
)
def stop(self):
if self.tref is not None:
self.timer.cancel(self.tref)
self.tref = None
if self.eventer.enabled:
self._send('worker-offline', retry=False)
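# --- Illustrative wiring (a sketch, not part of this module) ---
# In a running worker, Heart is created by the consumer's Heart bootstep with
# the hub timer and the real event dispatcher.  The stub dispatcher below is
# an assumption for the example; it only prints the events it is asked to send.
if __name__ == '__main__':  # pragma: no cover
    from kombu.asynchronous.timer import Timer

    class _StubDispatcher:
        enabled = True

        def __init__(self):
            self.on_enabled = set()
            self.on_disabled = set()

        def send(self, type, **fields):
            print('event:', type)

    heart = Heart(Timer(), _StubDispatcher(), interval=2.0)
    heart.start()   # sends 'worker-online', schedules repeating 'worker-heartbeat'
    heart.stop()    # cancels the schedule and sends 'worker-offline'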

View File

@@ -0,0 +1,135 @@
"""The consumers highly-optimized inner loop."""
import errno
import socket
from celery import bootsteps
from celery.exceptions import WorkerLostError
from celery.utils.log import get_logger
from . import state
__all__ = ('asynloop', 'synloop')
# pylint: disable=redefined-outer-name
# We cache globals and attribute lookups, so disable this warning.
logger = get_logger(__name__)
def _quick_drain(connection, timeout=0.1):
try:
connection.drain_events(timeout=timeout)
except Exception as exc: # pylint: disable=broad-except
exc_errno = getattr(exc, 'errno', None)
if exc_errno is not None and exc_errno != errno.EAGAIN:
raise
def _enable_amqheartbeats(timer, connection, rate=2.0):
heartbeat_error = [None]
if not connection:
return heartbeat_error
heartbeat = connection.get_heartbeat_interval() # negotiated
if not (heartbeat and connection.supports_heartbeats):
return heartbeat_error
def tick(rate):
try:
connection.heartbeat_check(rate)
except Exception as e:
            # heartbeat_error is passed by reference and can be updated here
            # (no append; the list is meant to stay at a fixed size of 1).
heartbeat_error[0] = e
timer.call_repeatedly(heartbeat / rate, tick, (rate,))
return heartbeat_error
def asynloop(obj, connection, consumer, blueprint, hub, qos,
heartbeat, clock, hbrate=2.0):
"""Non-blocking event loop."""
RUN = bootsteps.RUN
update_qos = qos.update
errors = connection.connection_errors
on_task_received = obj.create_task_handler()
heartbeat_error = _enable_amqheartbeats(hub.timer, connection, rate=hbrate)
consumer.on_message = on_task_received
obj.controller.register_with_event_loop(hub)
obj.register_with_event_loop(hub)
consumer.consume()
obj.on_ready()
# did_start_ok will verify that pool processes were able to start,
# but this will only work the first time we start, as
# maxtasksperchild will mess up metrics.
if not obj.restart_count and not obj.pool.did_start_ok():
raise WorkerLostError('Could not start worker processes')
# consumer.consume() may have prefetched up to our
# limit - drain an event so we're in a clean state
# prior to starting our event loop.
if connection.transport.driver_type == 'amqp':
hub.call_soon(_quick_drain, connection)
# FIXME: Use loop.run_forever
# Tried and works, but no time to test properly before release.
hub.propagate_errors = errors
loop = hub.create_loop()
try:
while blueprint.state == RUN and obj.connection:
state.maybe_shutdown()
if heartbeat_error[0] is not None:
raise heartbeat_error[0]
# We only update QoS when there's no more messages to read.
# This groups together qos calls, and makes sure that remote
# control commands will be prioritized over task messages.
if qos.prev != qos.value:
update_qos()
try:
next(loop)
except StopIteration:
loop = hub.create_loop()
finally:
try:
hub.reset()
except Exception as exc: # pylint: disable=broad-except
logger.exception(
'Error cleaning up after event loop: %r', exc)
def synloop(obj, connection, consumer, blueprint, hub, qos,
heartbeat, clock, hbrate=2.0, **kwargs):
"""Fallback blocking event loop for transports that doesn't support AIO."""
RUN = bootsteps.RUN
on_task_received = obj.create_task_handler()
perform_pending_operations = obj.perform_pending_operations
heartbeat_error = [None]
if getattr(obj.pool, 'is_green', False):
heartbeat_error = _enable_amqheartbeats(obj.timer, connection, rate=hbrate)
consumer.on_message = on_task_received
consumer.consume()
obj.on_ready()
while blueprint.state == RUN and obj.connection:
state.maybe_shutdown()
if heartbeat_error[0] is not None:
raise heartbeat_error[0]
if qos.prev != qos.value:
qos.update()
try:
perform_pending_operations()
connection.drain_events(timeout=2.0)
except socket.timeout:
pass
except OSError:
if blueprint.state == RUN:
raise
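# --- Illustrative sketch of the error channel used above (not part of this module) ---
# ``tick`` runs from the timer and cannot raise into the event loop directly,
# so _enable_amqheartbeats parks any failure in a single-slot list that the
# loop checks and re-raises on every iteration.  The exception is simulated.
if __name__ == '__main__':  # pragma: no cover
    heartbeat_error = [None]

    def tick():
        try:
            raise ConnectionError('simulated: broker heartbeat missed')
        except Exception as exc:
            heartbeat_error[0] = exc

    tick()
    if heartbeat_error[0] is not None:
        print('the loop would re-raise:', heartbeat_error[0])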

View File

@@ -0,0 +1,122 @@
"""Worker Pidbox (remote control)."""
import socket
import threading
from kombu.common import ignore_errors
from kombu.utils.encoding import safe_str
from celery.utils.collections import AttributeDict
from celery.utils.functional import pass1
from celery.utils.log import get_logger
from . import control
__all__ = ('Pidbox', 'gPidbox')
logger = get_logger(__name__)
debug, error, info = logger.debug, logger.error, logger.info
class Pidbox:
"""Worker mailbox."""
consumer = None
def __init__(self, c):
self.c = c
self.hostname = c.hostname
self.node = c.app.control.mailbox.Node(
safe_str(c.hostname),
handlers=control.Panel.data,
state=AttributeDict(
app=c.app,
hostname=c.hostname,
consumer=c,
tset=pass1 if c.controller.use_eventloop else set),
)
self._forward_clock = self.c.app.clock.forward
def on_message(self, body, message):
# just increase clock as clients usually don't
# have a valid clock to adjust with.
self._forward_clock()
try:
self.node.handle_message(body, message)
except KeyError as exc:
error('No such control command: %s', exc)
except Exception as exc:
error('Control command error: %r', exc, exc_info=True)
self.reset()
def start(self, c):
self.node.channel = c.connection.channel()
self.consumer = self.node.listen(callback=self.on_message)
self.consumer.on_decode_error = c.on_decode_error
def on_stop(self):
pass
def stop(self, c):
self.on_stop()
self.consumer = self._close_channel(c)
def reset(self):
self.stop(self.c)
self.start(self.c)
def _close_channel(self, c):
if self.node and self.node.channel:
ignore_errors(c, self.node.channel.close)
def shutdown(self, c):
self.on_stop()
if self.consumer:
debug('Canceling broadcast consumer...')
ignore_errors(c, self.consumer.cancel)
self.stop(self.c)
class gPidbox(Pidbox):
"""Worker pidbox (greenlet)."""
_node_shutdown = None
_node_stopped = None
_resets = 0
def start(self, c):
c.pool.spawn_n(self.loop, c)
def on_stop(self):
if self._node_stopped:
self._node_shutdown.set()
debug('Waiting for broadcast thread to shutdown...')
self._node_stopped.wait()
self._node_stopped = self._node_shutdown = None
def reset(self):
self._resets += 1
def _do_reset(self, c, connection):
self._close_channel(c)
self.node.channel = connection.channel()
self.consumer = self.node.listen(callback=self.on_message)
self.consumer.consume()
def loop(self, c):
resets = [self._resets]
shutdown = self._node_shutdown = threading.Event()
stopped = self._node_stopped = threading.Event()
try:
with c.connection_for_read() as connection:
info('pidbox: Connected to %s.', connection.as_uri())
self._do_reset(c, connection)
while not shutdown.is_set() and c.connection:
if resets[0] < self._resets:
resets[0] += 1
self._do_reset(c, connection)
try:
connection.drain_events(timeout=1.0)
except socket.timeout:
pass
finally:
stopped.set()
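# --- Illustrative sketch (not part of this module) ---
# The mailbox node dispatches incoming pidbox messages to the handlers that
# celery.worker.control registers in ``control.Panel.data``, keyed by command
# name; this simply lists a few of those registered names.
if __name__ == '__main__':  # pragma: no cover
    print(len(control.Panel.data), 'registered commands, e.g.:',
          sorted(control.Panel.data)[:5])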

View File

@@ -0,0 +1,790 @@
"""Task request.
This module defines the :class:`Request` class, that specifies
how tasks are executed.
"""
import logging
import sys
from datetime import datetime
from time import monotonic, time
from weakref import ref
from billiard.common import TERM_SIGNAME
from billiard.einfo import ExceptionWithTraceback
from kombu.utils.encoding import safe_repr, safe_str
from kombu.utils.objects import cached_property
from celery import current_app, signals
from celery.app.task import Context
from celery.app.trace import fast_trace_task, trace_task, trace_task_ret
from celery.concurrency.base import BasePool
from celery.exceptions import (Ignore, InvalidTaskError, Reject, Retry, TaskRevokedError, Terminated,
TimeLimitExceeded, WorkerLostError)
from celery.platforms import signals as _signals
from celery.utils.functional import maybe, maybe_list, noop
from celery.utils.log import get_logger
from celery.utils.nodenames import gethostname
from celery.utils.serialization import get_pickled_exception
from celery.utils.time import maybe_iso8601, maybe_make_aware, timezone
from . import state
__all__ = ('Request',)
# pylint: disable=redefined-outer-name
# We cache globals and attribute lookups, so disable this warning.
IS_PYPY = hasattr(sys, 'pypy_version_info')
logger = get_logger(__name__)
debug, info, warn, error = (logger.debug, logger.info,
logger.warning, logger.error)
_does_info = False
_does_debug = False
def __optimize__():
# this is also called by celery.app.trace.setup_worker_optimizations
global _does_debug
global _does_info
_does_debug = logger.isEnabledFor(logging.DEBUG)
_does_info = logger.isEnabledFor(logging.INFO)
__optimize__()
# Localize
tz_or_local = timezone.tz_or_local
send_revoked = signals.task_revoked.send
send_retry = signals.task_retry.send
task_accepted = state.task_accepted
task_ready = state.task_ready
revoked_tasks = state.revoked
revoked_stamps = state.revoked_stamps
class Request:
"""A request for task execution."""
acknowledged = False
time_start = None
worker_pid = None
time_limits = (None, None)
_already_revoked = False
_already_cancelled = False
_terminate_on_ack = None
_apply_result = None
_tzlocal = None
if not IS_PYPY: # pragma: no cover
__slots__ = (
'_app', '_type', 'name', 'id', '_root_id', '_parent_id',
'_on_ack', '_body', '_hostname', '_eventer', '_connection_errors',
'_task', '_eta', '_expires', '_request_dict', '_on_reject', '_utc',
'_content_type', '_content_encoding', '_argsrepr', '_kwargsrepr',
'_args', '_kwargs', '_decoded', '__payload',
'__weakref__', '__dict__',
)
def __init__(self, message, on_ack=noop,
hostname=None, eventer=None, app=None,
connection_errors=None, request_dict=None,
task=None, on_reject=noop, body=None,
headers=None, decoded=False, utc=True,
maybe_make_aware=maybe_make_aware,
maybe_iso8601=maybe_iso8601, **opts):
self._message = message
self._request_dict = (message.headers.copy() if headers is None
else headers.copy())
self._body = message.body if body is None else body
self._app = app
self._utc = utc
self._decoded = decoded
if decoded:
self._content_type = self._content_encoding = None
else:
self._content_type, self._content_encoding = (
message.content_type, message.content_encoding,
)
self.__payload = self._body if self._decoded else message.payload
self.id = self._request_dict['id']
self._type = self.name = self._request_dict['task']
if 'shadow' in self._request_dict:
self.name = self._request_dict['shadow'] or self.name
self._root_id = self._request_dict.get('root_id')
self._parent_id = self._request_dict.get('parent_id')
timelimit = self._request_dict.get('timelimit', None)
if timelimit:
self.time_limits = timelimit
self._argsrepr = self._request_dict.get('argsrepr', '')
self._kwargsrepr = self._request_dict.get('kwargsrepr', '')
self._on_ack = on_ack
self._on_reject = on_reject
self._hostname = hostname or gethostname()
self._eventer = eventer
self._connection_errors = connection_errors or ()
self._task = task or self._app.tasks[self._type]
self._ignore_result = self._request_dict.get('ignore_result', False)
# timezone means the message is timezone-aware, and the only timezone
# supported at this point is UTC.
eta = self._request_dict.get('eta')
if eta is not None:
try:
eta = maybe_iso8601(eta)
except (AttributeError, ValueError, TypeError) as exc:
raise InvalidTaskError(
f'invalid ETA value {eta!r}: {exc}')
self._eta = maybe_make_aware(eta, self.tzlocal)
else:
self._eta = None
expires = self._request_dict.get('expires')
if expires is not None:
try:
expires = maybe_iso8601(expires)
except (AttributeError, ValueError, TypeError) as exc:
raise InvalidTaskError(
f'invalid expires value {expires!r}: {exc}')
self._expires = maybe_make_aware(expires, self.tzlocal)
else:
self._expires = None
delivery_info = message.delivery_info or {}
properties = message.properties or {}
self._delivery_info = {
'exchange': delivery_info.get('exchange'),
'routing_key': delivery_info.get('routing_key'),
'priority': properties.get('priority'),
'redelivered': delivery_info.get('redelivered', False),
}
self._request_dict.update({
'properties': properties,
'reply_to': properties.get('reply_to'),
'correlation_id': properties.get('correlation_id'),
'hostname': self._hostname,
'delivery_info': self._delivery_info
})
        # passed by reference to avoid a burst in memory usage
self._request_dict['args'], self._request_dict['kwargs'], _ = self.__payload
self._args = self._request_dict['args']
self._kwargs = self._request_dict['kwargs']
@property
def delivery_info(self):
return self._delivery_info
@property
def message(self):
return self._message
@property
def request_dict(self):
return self._request_dict
@property
def body(self):
return self._body
@property
def app(self):
return self._app
@property
def utc(self):
return self._utc
@property
def content_type(self):
return self._content_type
@property
def content_encoding(self):
return self._content_encoding
@property
def type(self):
return self._type
@property
def root_id(self):
return self._root_id
@property
def parent_id(self):
return self._parent_id
@property
def argsrepr(self):
return self._argsrepr
@property
def args(self):
return self._args
@property
def kwargs(self):
return self._kwargs
@property
def kwargsrepr(self):
return self._kwargsrepr
@property
def on_ack(self):
return self._on_ack
@property
def on_reject(self):
return self._on_reject
@on_reject.setter
def on_reject(self, value):
self._on_reject = value
@property
def hostname(self):
return self._hostname
@property
def ignore_result(self):
return self._ignore_result
@property
def eventer(self):
return self._eventer
@eventer.setter
def eventer(self, eventer):
self._eventer = eventer
@property
def connection_errors(self):
return self._connection_errors
@property
def task(self):
return self._task
@property
def eta(self):
return self._eta
@property
def expires(self):
return self._expires
@expires.setter
def expires(self, value):
self._expires = value
@property
def tzlocal(self):
if self._tzlocal is None:
self._tzlocal = self._app.conf.timezone
return self._tzlocal
@property
def store_errors(self):
return (not self.task.ignore_result or
self.task.store_errors_even_if_ignored)
@property
def task_id(self):
# XXX compat
return self.id
@task_id.setter
def task_id(self, value):
self.id = value
@property
def task_name(self):
# XXX compat
return self.name
@task_name.setter
def task_name(self, value):
self.name = value
@property
def reply_to(self):
# used by rpc backend when failures reported by parent process
return self._request_dict['reply_to']
@property
def replaced_task_nesting(self):
return self._request_dict.get('replaced_task_nesting', 0)
@property
def groups(self):
return self._request_dict.get('groups', [])
@property
def stamped_headers(self) -> list:
return self._request_dict.get('stamped_headers') or []
@property
def stamps(self) -> dict:
stamps = self._request_dict.get('stamps') or {}
return {header: stamps.get(header) for header in self.stamped_headers}
@property
def correlation_id(self):
# used similarly to reply_to
return self._request_dict['correlation_id']
def execute_using_pool(self, pool: BasePool, **kwargs):
"""Used by the worker to send this task to the pool.
Arguments:
pool (~celery.concurrency.base.TaskPool): The execution pool
used to execute this request.
Raises:
celery.exceptions.TaskRevokedError: if the task was revoked.
"""
task_id = self.id
task = self._task
if self.revoked():
raise TaskRevokedError(task_id)
time_limit, soft_time_limit = self.time_limits
trace = fast_trace_task if self._app.use_fast_trace_task else trace_task_ret
result = pool.apply_async(
trace,
args=(self._type, task_id, self._request_dict, self._body,
self._content_type, self._content_encoding),
accept_callback=self.on_accepted,
timeout_callback=self.on_timeout,
callback=self.on_success,
error_callback=self.on_failure,
soft_timeout=soft_time_limit or task.soft_time_limit,
timeout=time_limit or task.time_limit,
correlation_id=task_id,
)
# cannot create weakref to None
self._apply_result = maybe(ref, result)
return result
def execute(self, loglevel=None, logfile=None):
"""Execute the task in a :func:`~celery.app.trace.trace_task`.
Arguments:
loglevel (int): The loglevel used by the task.
logfile (str): The logfile used by the task.
"""
if self.revoked():
return
# acknowledge task as being processed.
if not self.task.acks_late:
self.acknowledge()
_, _, embed = self._payload
request = self._request_dict
# pylint: disable=unpacking-non-sequence
# payload is a property, so pylint doesn't think it's a tuple.
request.update({
'loglevel': loglevel,
'logfile': logfile,
'is_eager': False,
}, **embed or {})
retval, I, _, _ = trace_task(self.task, self.id, self._args, self._kwargs, request,
hostname=self._hostname, loader=self._app.loader,
app=self._app)
if I:
self.reject(requeue=False)
else:
self.acknowledge()
return retval
def maybe_expire(self):
"""If expired, mark the task as revoked."""
if self.expires:
now = datetime.now(self.expires.tzinfo)
if now > self.expires:
revoked_tasks.add(self.id)
return True
def terminate(self, pool, signal=None):
signal = _signals.signum(signal or TERM_SIGNAME)
if self.time_start:
pool.terminate_job(self.worker_pid, signal)
self._announce_revoked('terminated', True, signal, False)
else:
self._terminate_on_ack = pool, signal
if self._apply_result is not None:
obj = self._apply_result() # is a weakref
if obj is not None:
obj.terminate(signal)
def cancel(self, pool, signal=None):
signal = _signals.signum(signal or TERM_SIGNAME)
if self.time_start:
pool.terminate_job(self.worker_pid, signal)
self._announce_cancelled()
if self._apply_result is not None:
obj = self._apply_result() # is a weakref
if obj is not None:
obj.terminate(signal)
def _announce_cancelled(self):
task_ready(self)
self.send_event('task-cancelled')
reason = 'cancelled by Celery'
exc = Retry(message=reason)
self.task.backend.mark_as_retry(self.id,
exc,
request=self._context)
self.task.on_retry(exc, self.id, self.args, self.kwargs, None)
self._already_cancelled = True
send_retry(self.task, request=self._context, einfo=None)
def _announce_revoked(self, reason, terminated, signum, expired):
task_ready(self)
self.send_event('task-revoked',
terminated=terminated, signum=signum, expired=expired)
self.task.backend.mark_as_revoked(
self.id, reason, request=self._context,
store_result=self.store_errors,
)
self.acknowledge()
self._already_revoked = True
send_revoked(self.task, request=self._context,
terminated=terminated, signum=signum, expired=expired)
def revoked(self):
"""If revoked, skip task and mark state."""
expired = False
if self._already_revoked:
return True
if self.expires:
expired = self.maybe_expire()
revoked_by_id = self.id in revoked_tasks
revoked_by_header, revoking_header = False, None
if not revoked_by_id and self.stamped_headers:
for stamp in self.stamped_headers:
if stamp in revoked_stamps:
revoked_header = revoked_stamps[stamp]
stamped_header = self._message.headers['stamps'][stamp]
if isinstance(stamped_header, (list, tuple)):
for stamped_value in stamped_header:
if stamped_value in maybe_list(revoked_header):
revoked_by_header = True
revoking_header = {stamp: stamped_value}
break
else:
revoked_by_header = any([
stamped_header in maybe_list(revoked_header),
stamped_header == revoked_header, # When the header is a single set value
])
revoking_header = {stamp: stamped_header}
break
if any((expired, revoked_by_id, revoked_by_header)):
log_msg = 'Discarding revoked task: %s[%s]'
if revoked_by_header:
log_msg += ' (revoked by header: %s)' % revoking_header
info(log_msg, self.name, self.id)
self._announce_revoked(
'expired' if expired else 'revoked', False, None, expired,
)
return True
return False
def send_event(self, type, **fields):
if self._eventer and self._eventer.enabled and self.task.send_events:
self._eventer.send(type, uuid=self.id, **fields)
def on_accepted(self, pid, time_accepted):
"""Handler called when task is accepted by worker pool."""
self.worker_pid = pid
# Convert monotonic time_accepted to absolute time
self.time_start = time() - (monotonic() - time_accepted)
task_accepted(self)
if not self.task.acks_late:
self.acknowledge()
self.send_event('task-started')
if _does_debug:
debug('Task accepted: %s[%s] pid:%r', self.name, self.id, pid)
if self._terminate_on_ack is not None:
self.terminate(*self._terminate_on_ack)
def on_timeout(self, soft, timeout):
"""Handler called if the task times out."""
if soft:
warn('Soft time limit (%ss) exceeded for %s[%s]',
timeout, self.name, self.id)
else:
task_ready(self)
error('Hard time limit (%ss) exceeded for %s[%s]',
timeout, self.name, self.id)
exc = TimeLimitExceeded(timeout)
self.task.backend.mark_as_failure(
self.id, exc, request=self._context,
store_result=self.store_errors,
)
if self.task.acks_late and self.task.acks_on_failure_or_timeout:
self.acknowledge()
def on_success(self, failed__retval__runtime, **kwargs):
"""Handler called if the task was successfully processed."""
failed, retval, runtime = failed__retval__runtime
if failed:
exc = retval.exception
if isinstance(exc, ExceptionWithTraceback):
exc = exc.exc
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
raise exc
return self.on_failure(retval, return_ok=True)
task_ready(self, successful=True)
if self.task.acks_late:
self.acknowledge()
self.send_event('task-succeeded', result=retval, runtime=runtime)
def on_retry(self, exc_info):
"""Handler called if the task should be retried."""
if self.task.acks_late:
self.acknowledge()
self.send_event('task-retried',
exception=safe_repr(exc_info.exception.exc),
traceback=safe_str(exc_info.traceback))
def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
"""Handler called if the task raised an exception."""
task_ready(self)
exc = exc_info.exception
if isinstance(exc, ExceptionWithTraceback):
exc = exc.exc
is_terminated = isinstance(exc, Terminated)
if is_terminated:
            # A terminated task is considered revoked unless it was cancelled
            # (e.g. because of a connection loss).  Tasks are always cancelled
            # from inside the master process, and a cancelled request was
            # never revoked, so in that case there is nothing left to do.
            # Otherwise, announce the revocation here if it hasn't been
            # announced already.
if not self._already_cancelled and not self._already_revoked:
# This is a special case where the process
# would not have had time to write the result.
self._announce_revoked(
'terminated', True, str(exc), False)
return
elif isinstance(exc, MemoryError):
raise MemoryError(f'Process got: {exc}')
elif isinstance(exc, Reject):
return self.reject(requeue=exc.requeue)
elif isinstance(exc, Ignore):
return self.acknowledge()
elif isinstance(exc, Retry):
return self.on_retry(exc_info)
# (acks_late) acknowledge after result stored.
requeue = False
is_worker_lost = isinstance(exc, WorkerLostError)
if self.task.acks_late:
reject = (
self.task.reject_on_worker_lost and
is_worker_lost
)
ack = self.task.acks_on_failure_or_timeout
if reject:
requeue = True
self.reject(requeue=requeue)
send_failed_event = False
elif ack:
self.acknowledge()
else:
                # supports the behaviour where a failed task needs to be
                # removed from the prefetched local queue
self.reject(requeue=False)
# This is a special case where the process would not have had time
# to write the result.
if not requeue and (is_worker_lost or not return_ok):
# only mark as failure if task has not been requeued
self.task.backend.mark_as_failure(
self.id, exc, request=self._context,
store_result=self.store_errors,
)
signals.task_failure.send(sender=self.task, task_id=self.id,
exception=exc, args=self.args,
kwargs=self.kwargs,
traceback=exc_info.traceback,
einfo=exc_info)
if send_failed_event:
self.send_event(
'task-failed',
exception=safe_repr(get_pickled_exception(exc_info.exception)),
traceback=exc_info.traceback,
)
if not return_ok:
error('Task handler raised error: %r', exc,
exc_info=exc_info.exc_info)
def acknowledge(self):
"""Acknowledge task."""
if not self.acknowledged:
self._on_ack(logger, self._connection_errors)
self.acknowledged = True
def reject(self, requeue=False):
if not self.acknowledged:
self._on_reject(logger, self._connection_errors, requeue)
self.acknowledged = True
self.send_event('task-rejected', requeue=requeue)
def info(self, safe=False):
return {
'id': self.id,
'name': self.name,
'args': self._args if not safe else self._argsrepr,
'kwargs': self._kwargs if not safe else self._kwargsrepr,
'type': self._type,
'hostname': self._hostname,
'time_start': self.time_start,
'acknowledged': self.acknowledged,
'delivery_info': self.delivery_info,
'worker_pid': self.worker_pid,
}
def humaninfo(self):
return '{0.name}[{0.id}]'.format(self)
def __str__(self):
"""``str(self)``."""
return ' '.join([
self.humaninfo(),
f' ETA:[{self._eta}]' if self._eta else '',
f' expires:[{self._expires}]' if self._expires else '',
]).strip()
def __repr__(self):
"""``repr(self)``."""
return '<{}: {} {} {}>'.format(
type(self).__name__, self.humaninfo(),
self._argsrepr, self._kwargsrepr,
)
@cached_property
def _payload(self):
return self.__payload
@cached_property
def chord(self):
# used by backend.mark_as_failure when failure is reported
# by parent process
# pylint: disable=unpacking-non-sequence
# payload is a property, so pylint doesn't think it's a tuple.
_, _, embed = self._payload
return embed.get('chord')
@cached_property
def errbacks(self):
# used by backend.mark_as_failure when failure is reported
# by parent process
# pylint: disable=unpacking-non-sequence
# payload is a property, so pylint doesn't think it's a tuple.
_, _, embed = self._payload
return embed.get('errbacks')
@cached_property
def group(self):
# used by backend.on_chord_part_return when failures reported
# by parent process
return self._request_dict.get('group')
@cached_property
def _context(self):
"""Context (:class:`~celery.app.task.Context`) of this task."""
request = self._request_dict
# pylint: disable=unpacking-non-sequence
# payload is a property, so pylint doesn't think it's a tuple.
_, _, embed = self._payload
request.update(**embed or {})
return Context(request)
@cached_property
def group_index(self):
# used by backend.on_chord_part_return to order return values in group
return self._request_dict.get('group_index')
def create_request_cls(base, task, pool, hostname, eventer,
ref=ref, revoked_tasks=revoked_tasks,
task_ready=task_ready, trace=None, app=current_app):
default_time_limit = task.time_limit
default_soft_time_limit = task.soft_time_limit
apply_async = pool.apply_async
acks_late = task.acks_late
events = eventer and eventer.enabled
if trace is None:
trace = fast_trace_task if app.use_fast_trace_task else trace_task_ret
class Request(base):
def execute_using_pool(self, pool, **kwargs):
task_id = self.task_id
if self.revoked():
raise TaskRevokedError(task_id)
time_limit, soft_time_limit = self.time_limits
result = apply_async(
trace,
args=(self.type, task_id, self.request_dict, self.body,
self.content_type, self.content_encoding),
accept_callback=self.on_accepted,
timeout_callback=self.on_timeout,
callback=self.on_success,
error_callback=self.on_failure,
soft_timeout=soft_time_limit or default_soft_time_limit,
timeout=time_limit or default_time_limit,
correlation_id=task_id,
)
# cannot create weakref to None
# pylint: disable=attribute-defined-outside-init
self._apply_result = maybe(ref, result)
return result
def on_success(self, failed__retval__runtime, **kwargs):
failed, retval, runtime = failed__retval__runtime
if failed:
exc = retval.exception
if isinstance(exc, ExceptionWithTraceback):
exc = exc.exc
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
raise exc
return self.on_failure(retval, return_ok=True)
task_ready(self)
if acks_late:
self.acknowledge()
if events:
self.send_event(
'task-succeeded', result=retval, runtime=runtime,
)
return Request
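# --- Illustrative sketch (not part of this module) ---
# The "cannot create weakref to None" pattern from execute_using_pool():
# ``maybe(ref, result)`` only wraps the pool's async result in a weak
# reference when a result object actually exists.  _FakeAsyncResult is a
# stand-in assumed for the example.
if __name__ == '__main__':  # pragma: no cover
    class _FakeAsyncResult:
        pass

    result = _FakeAsyncResult()
    print(maybe(ref, None))      # -> None; no weakref is created
    print(maybe(ref, result))    # -> <weakref ...> to the async result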

View File

@@ -0,0 +1,288 @@
"""Internal worker state (global).
This includes the currently active and reserved tasks,
statistics, and revoked tasks.
"""
import os
import platform
import shelve
import sys
import weakref
import zlib
from collections import Counter
from kombu.serialization import pickle, pickle_protocol
from kombu.utils.objects import cached_property
from celery import __version__
from celery.exceptions import WorkerShutdown, WorkerTerminate
from celery.utils.collections import LimitedSet
__all__ = (
'SOFTWARE_INFO', 'reserved_requests', 'active_requests',
'total_count', 'revoked', 'task_reserved', 'maybe_shutdown',
'task_accepted', 'task_ready', 'Persistent',
)
#: Worker software/platform information.
SOFTWARE_INFO = {
'sw_ident': 'py-celery',
'sw_ver': __version__,
'sw_sys': platform.system(),
}
#: maximum number of revokes to keep in memory.
REVOKES_MAX = int(os.environ.get('CELERY_WORKER_REVOKES_MAX', 50000))
#: maximum number of successful tasks to keep in memory.
SUCCESSFUL_MAX = int(os.environ.get('CELERY_WORKER_SUCCESSFUL_MAX', 1000))
#: how many seconds a revoke will be active before
#: being expired when the max limit has been exceeded.
REVOKE_EXPIRES = float(os.environ.get('CELERY_WORKER_REVOKE_EXPIRES', 10800))
#: how many seconds a successful task will be cached in memory
#: before being expired when the max limit has been exceeded.
SUCCESSFUL_EXPIRES = float(os.environ.get('CELERY_WORKER_SUCCESSFUL_EXPIRES', 10800))
#: Mapping of reserved task_id->Request.
requests = {}
#: set of all reserved :class:`~celery.worker.request.Request`'s.
reserved_requests = weakref.WeakSet()
#: set of currently active :class:`~celery.worker.request.Request`'s.
active_requests = weakref.WeakSet()
#: A limited set of successful :class:`~celery.worker.request.Request`'s.
successful_requests = LimitedSet(maxlen=SUCCESSFUL_MAX,
expires=SUCCESSFUL_EXPIRES)
#: count of tasks accepted by the worker, sorted by type.
total_count = Counter()
#: count of all tasks accepted by the worker
all_total_count = [0]
#: the list of currently revoked tasks. Persistent if ``statedb`` set.
revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
#: Mapping of stamped headers flagged for revoking.
revoked_stamps = {}
should_stop = None
should_terminate = None
def reset_state():
requests.clear()
reserved_requests.clear()
active_requests.clear()
successful_requests.clear()
total_count.clear()
all_total_count[:] = [0]
revoked.clear()
revoked_stamps.clear()
def maybe_shutdown():
"""Shutdown if flags have been set."""
if should_terminate is not None and should_terminate is not False:
raise WorkerTerminate(should_terminate)
elif should_stop is not None and should_stop is not False:
raise WorkerShutdown(should_stop)
def task_reserved(request,
add_request=requests.__setitem__,
add_reserved_request=reserved_requests.add):
"""Update global state when a task has been reserved."""
add_request(request.id, request)
add_reserved_request(request)
def task_accepted(request,
_all_total_count=None,
add_request=requests.__setitem__,
add_active_request=active_requests.add,
add_to_total_count=total_count.update):
"""Update global state when a task has been accepted."""
if not _all_total_count:
_all_total_count = all_total_count
add_request(request.id, request)
add_active_request(request)
add_to_total_count({request.name: 1})
all_total_count[0] += 1
def task_ready(request,
successful=False,
remove_request=requests.pop,
discard_active_request=active_requests.discard,
discard_reserved_request=reserved_requests.discard):
"""Update global state when a task is ready."""
if successful:
successful_requests.add(request.id)
remove_request(request.id, None)
discard_active_request(request)
discard_reserved_request(request)
C_BENCH = os.environ.get('C_BENCH') or os.environ.get('CELERY_BENCH')
C_BENCH_EVERY = int(os.environ.get('C_BENCH_EVERY') or
os.environ.get('CELERY_BENCH_EVERY') or 1000)
if C_BENCH: # pragma: no cover
import atexit
from time import monotonic
from billiard.process import current_process
from celery.utils.debug import memdump, sample_mem
all_count = 0
bench_first = None
bench_start = None
bench_last = None
bench_every = C_BENCH_EVERY
bench_sample = []
__reserved = task_reserved
__ready = task_ready
if current_process()._name == 'MainProcess':
@atexit.register
def on_shutdown():
if bench_first is not None and bench_last is not None:
print('- Time spent in benchmark: {!r}'.format(
bench_last - bench_first))
print('- Avg: {}'.format(
sum(bench_sample) / len(bench_sample)))
memdump()
def task_reserved(request):
"""Called when a task is reserved by the worker."""
global bench_start
global bench_first
now = None
if bench_start is None:
bench_start = now = monotonic()
if bench_first is None:
bench_first = now
return __reserved(request)
def task_ready(request):
"""Called when a task is completed."""
global all_count
global bench_start
global bench_last
all_count += 1
if not all_count % bench_every:
now = monotonic()
diff = now - bench_start
print('- Time spent processing {} tasks (since first '
'task received): ~{:.4f}s\n'.format(bench_every, diff))
sys.stdout.flush()
bench_start = bench_last = now
bench_sample.append(diff)
sample_mem()
return __ready(request)
class Persistent:
"""Stores worker state between restarts.
This is the persistent data stored by the worker when
:option:`celery worker --statedb` is enabled.
Currently only stores revoked task id's.
"""
storage = shelve
protocol = pickle_protocol
compress = zlib.compress
decompress = zlib.decompress
_is_open = False
def __init__(self, state, filename, clock=None):
self.state = state
self.filename = filename
self.clock = clock
self.merge()
def open(self):
return self.storage.open(
self.filename, protocol=self.protocol, writeback=True,
)
def merge(self):
self._merge_with(self.db)
def sync(self):
self._sync_with(self.db)
self.db.sync()
def close(self):
if self._is_open:
self.db.close()
self._is_open = False
def save(self):
self.sync()
self.close()
def _merge_with(self, d):
self._merge_revoked(d)
self._merge_clock(d)
return d
def _sync_with(self, d):
self._revoked_tasks.purge()
d.update({
'__proto__': 3,
'zrevoked': self.compress(self._dumps(self._revoked_tasks)),
'clock': self.clock.forward() if self.clock else 0,
})
return d
def _merge_clock(self, d):
if self.clock:
d['clock'] = self.clock.adjust(d.get('clock') or 0)
def _merge_revoked(self, d):
try:
self._merge_revoked_v3(d['zrevoked'])
except KeyError:
try:
self._merge_revoked_v2(d.pop('revoked'))
except KeyError:
pass
# purge expired items at boot
self._revoked_tasks.purge()
def _merge_revoked_v3(self, zrevoked):
if zrevoked:
self._revoked_tasks.update(pickle.loads(self.decompress(zrevoked)))
def _merge_revoked_v2(self, saved):
if not isinstance(saved, LimitedSet):
# (pre 3.0.18) used to be stored as a dict
return self._merge_revoked_v1(saved)
self._revoked_tasks.update(saved)
def _merge_revoked_v1(self, saved):
add = self._revoked_tasks.add
for item in saved:
add(item)
def _dumps(self, obj):
return pickle.dumps(obj, protocol=self.protocol)
@property
def _revoked_tasks(self):
return self.state.revoked
@cached_property
def db(self):
self._is_open = True
return self.open()
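# --- Illustrative sketch (not part of this module) ---
# Round-trip of the Persistent store: revoked ids are compressed and written
# to a shelve db so they survive restarts.  The task id and the db filename
# are assumptions for the example; any object exposing ``.revoked`` works as
# the state argument.
if __name__ == '__main__':  # pragma: no cover
    class _StateNS:
        pass

    _StateNS.revoked = revoked
    revoked.add('hypothetical-task-id')
    persistent = Persistent(_StateNS, 'worker.state.db')
    persistent.save()   # sync() writes 'zrevoked' and 'clock', then close()
    print('persisted', len(revoked), 'revoked id(s)')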

View File

@@ -0,0 +1,208 @@
"""Task execution strategy (optimization)."""
import logging
from kombu.asynchronous.timer import to_timestamp
from celery import signals
from celery.app import trace as _app_trace
from celery.exceptions import InvalidTaskError
from celery.utils.imports import symbol_by_name
from celery.utils.log import get_logger
from celery.utils.saferepr import saferepr
from celery.utils.time import timezone
from .request import create_request_cls
from .state import task_reserved
__all__ = ('default',)
logger = get_logger(__name__)
# pylint: disable=redefined-outer-name
# We cache globals and attribute lookups, so disable this warning.
def hybrid_to_proto2(message, body):
"""Create a fresh protocol 2 message from a hybrid protocol 1/2 message."""
try:
args, kwargs = body.get('args', ()), body.get('kwargs', {})
kwargs.items # pylint: disable=pointless-statement
except KeyError:
raise InvalidTaskError('Message does not have args/kwargs')
except AttributeError:
raise InvalidTaskError(
'Task keyword arguments must be a mapping',
)
headers = {
'lang': body.get('lang'),
'task': body.get('task'),
'id': body.get('id'),
'root_id': body.get('root_id'),
'parent_id': body.get('parent_id'),
'group': body.get('group'),
'meth': body.get('meth'),
'shadow': body.get('shadow'),
'eta': body.get('eta'),
'expires': body.get('expires'),
'retries': body.get('retries', 0),
'timelimit': body.get('timelimit', (None, None)),
'argsrepr': body.get('argsrepr'),
'kwargsrepr': body.get('kwargsrepr'),
'origin': body.get('origin'),
}
headers.update(message.headers or {})
embed = {
'callbacks': body.get('callbacks'),
'errbacks': body.get('errbacks'),
'chord': body.get('chord'),
'chain': None,
}
return (args, kwargs, embed), headers, True, body.get('utc', True)
def proto1_to_proto2(message, body):
"""Convert Task message protocol 1 arguments to protocol 2.
Returns:
Tuple: of ``(body, headers, already_decoded_status, utc)``
"""
try:
args, kwargs = body.get('args', ()), body.get('kwargs', {})
kwargs.items # pylint: disable=pointless-statement
except KeyError:
raise InvalidTaskError('Message does not have args/kwargs')
except AttributeError:
raise InvalidTaskError(
'Task keyword arguments must be a mapping',
)
body.update(
argsrepr=saferepr(args),
kwargsrepr=saferepr(kwargs),
headers=message.headers,
)
try:
body['group'] = body['taskset']
except KeyError:
pass
embed = {
'callbacks': body.get('callbacks'),
'errbacks': body.get('errbacks'),
'chord': body.get('chord'),
'chain': None,
}
return (args, kwargs, embed), body, True, body.get('utc', True)
def default(task, app, consumer,
info=logger.info, error=logger.error, task_reserved=task_reserved,
to_system_tz=timezone.to_system, bytes=bytes,
proto1_to_proto2=proto1_to_proto2):
"""Default task execution strategy.
Note:
Strategies are here as an optimization, so sadly
it's not very easy to override.
"""
hostname = consumer.hostname
connection_errors = consumer.connection_errors
_does_info = logger.isEnabledFor(logging.INFO)
# task event related
# (optimized to avoid calling request.send_event)
eventer = consumer.event_dispatcher
events = eventer and eventer.enabled
send_event = eventer and eventer.send
task_sends_events = events and task.send_events
call_at = consumer.timer.call_at
apply_eta_task = consumer.apply_eta_task
rate_limits_enabled = not consumer.disable_rate_limits
get_bucket = consumer.task_buckets.__getitem__
handle = consumer.on_task_request
limit_task = consumer._limit_task
limit_post_eta = consumer._limit_post_eta
Request = symbol_by_name(task.Request)
Req = create_request_cls(Request, task, consumer.pool, hostname, eventer, app=app)
revoked_tasks = consumer.controller.state.revoked
def task_message_handler(message, body, ack, reject, callbacks,
to_timestamp=to_timestamp):
if body is None and 'args' not in message.payload:
body, headers, decoded, utc = (
message.body, message.headers, False, app.uses_utc_timezone(),
)
else:
if 'args' in message.payload:
body, headers, decoded, utc = hybrid_to_proto2(message,
message.payload)
else:
body, headers, decoded, utc = proto1_to_proto2(message, body)
req = Req(
message,
on_ack=ack, on_reject=reject, app=app, hostname=hostname,
eventer=eventer, task=task, connection_errors=connection_errors,
body=body, headers=headers, decoded=decoded, utc=utc,
)
if _does_info:
# Similar to `app.trace.info()`, we pass the formatting args as the
# `extra` kwarg for custom log handlers
context = {
'id': req.id,
'name': req.name,
'args': req.argsrepr,
'kwargs': req.kwargsrepr,
'eta': req.eta,
}
info(_app_trace.LOG_RECEIVED, context, extra={'data': context})
if (req.expires or req.id in revoked_tasks) and req.revoked():
return
signals.task_received.send(sender=consumer, request=req)
if task_sends_events:
send_event(
'task-received',
uuid=req.id, name=req.name,
args=req.argsrepr, kwargs=req.kwargsrepr,
root_id=req.root_id, parent_id=req.parent_id,
retries=req.request_dict.get('retries', 0),
eta=req.eta and req.eta.isoformat(),
expires=req.expires and req.expires.isoformat(),
)
bucket = None
eta = None
if req.eta:
try:
if req.utc:
eta = to_timestamp(to_system_tz(req.eta))
else:
eta = to_timestamp(req.eta, app.timezone)
except (OverflowError, ValueError) as exc:
error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
req.eta, exc, req.info(safe=True), exc_info=True)
req.reject(requeue=False)
if rate_limits_enabled:
bucket = get_bucket(task.name)
if eta and bucket:
consumer.qos.increment_eventually()
return call_at(eta, limit_post_eta, (req, bucket, 1),
priority=6)
if eta:
consumer.qos.increment_eventually()
call_at(eta, apply_eta_task, (req,), priority=6)
return task_message_handler
if bucket:
return limit_task(req, bucket, 1)
task_reserved(req)
if callbacks:
[callback(req) for callback in callbacks]
handle(req)
return task_message_handler
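# --- Illustrative sketch (not part of this module) ---
# The on-the-fly protocol upgrade used by the handler above: a protocol 1
# body is rewritten into the protocol 2 shape of (args, kwargs, embed) plus
# headers.  The stub message and the body values are assumptions for the
# example.
if __name__ == '__main__':  # pragma: no cover
    class _StubMessage:
        headers = {}

    body = {'task': 'proj.tasks.add', 'id': 'hypothetical-id',
            'args': (2, 2), 'kwargs': {}, 'taskset': 'hypothetical-group-id'}
    (args, kwargs, embed), new_body, decoded, utc = proto1_to_proto2(
        _StubMessage(), body)
    print(args, kwargs, new_body['group'], decoded, utc)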

View File

@@ -0,0 +1,409 @@
"""WorkController can be used to instantiate in-process workers.
The command-line interface for the worker is in :mod:`celery.bin.worker`,
while the worker program is in :mod:`celery.apps.worker`.
The worker program is responsible for adding signal handlers,
setting up logging, etc. This is a bare-bones worker without
global side-effects (i.e., except for the global state stored in
:mod:`celery.worker.state`).
The worker consists of several components, all managed by bootsteps
(:mod:`celery.bootsteps`).
"""
import os
import sys
from datetime import datetime
from billiard import cpu_count
from kombu.utils.compat import detect_environment
from celery import bootsteps
from celery import concurrency as _concurrency
from celery import signals
from celery.bootsteps import RUN, TERMINATE
from celery.exceptions import ImproperlyConfigured, TaskRevokedError, WorkerTerminate
from celery.platforms import EX_FAILURE, create_pidlock
from celery.utils.imports import reload_from_cwd
from celery.utils.log import mlevel
from celery.utils.log import worker_logger as logger
from celery.utils.nodenames import default_nodename, worker_direct
from celery.utils.text import str_to_list
from celery.utils.threads import default_socket_timeout
from . import state
try:
import resource
except ImportError:
resource = None
__all__ = ('WorkController',)
#: Default socket timeout at shutdown.
SHUTDOWN_SOCKET_TIMEOUT = 5.0
SELECT_UNKNOWN_QUEUE = """
Trying to select queue subset of {0!r}, but queue {1} isn't
defined in the `task_queues` setting.
If you want to automatically declare unknown queues you can
enable the `task_create_missing_queues` setting.
"""
DESELECT_UNKNOWN_QUEUE = """
Trying to deselect queue subset of {0!r}, but queue {1} isn't
defined in the `task_queues` setting.
"""
class WorkController:
"""Unmanaged worker instance."""
app = None
pidlock = None
blueprint = None
pool = None
semaphore = None
#: contains the exit code if a :exc:`SystemExit` event is handled.
exitcode = None
class Blueprint(bootsteps.Blueprint):
"""Worker bootstep blueprint."""
name = 'Worker'
default_steps = {
'celery.worker.components:Hub',
'celery.worker.components:Pool',
'celery.worker.components:Beat',
'celery.worker.components:Timer',
'celery.worker.components:StateDB',
'celery.worker.components:Consumer',
'celery.worker.autoscale:WorkerComponent',
}
def __init__(self, app=None, hostname=None, **kwargs):
self.app = app or self.app
self.hostname = default_nodename(hostname)
self.startup_time = datetime.utcnow()
self.app.loader.init_worker()
self.on_before_init(**kwargs)
self.setup_defaults(**kwargs)
self.on_after_init(**kwargs)
self.setup_instance(**self.prepare_args(**kwargs))
def setup_instance(self, queues=None, ready_callback=None, pidfile=None,
include=None, use_eventloop=None, exclude_queues=None,
**kwargs):
self.pidfile = pidfile
self.setup_queues(queues, exclude_queues)
self.setup_includes(str_to_list(include))
# Set default concurrency
if not self.concurrency:
try:
self.concurrency = cpu_count()
except NotImplementedError:
self.concurrency = 2
# Options
self.loglevel = mlevel(self.loglevel)
self.ready_callback = ready_callback or self.on_consumer_ready
        # this connection is never actually established; it's only
        # used to inspect transport/connection parameters
self._conninfo = self.app.connection_for_read()
self.use_eventloop = (
self.should_use_eventloop() if use_eventloop is None
else use_eventloop
)
self.options = kwargs
signals.worker_init.send(sender=self)
# Initialize bootsteps
self.pool_cls = _concurrency.get_implementation(self.pool_cls)
self.steps = []
self.on_init_blueprint()
self.blueprint = self.Blueprint(
steps=self.app.steps['worker'],
on_start=self.on_start,
on_close=self.on_close,
on_stopped=self.on_stopped,
)
self.blueprint.apply(self, **kwargs)
def on_init_blueprint(self):
pass
def on_before_init(self, **kwargs):
pass
def on_after_init(self, **kwargs):
pass
def on_start(self):
if self.pidfile:
self.pidlock = create_pidlock(self.pidfile)
def on_consumer_ready(self, consumer):
pass
def on_close(self):
self.app.loader.shutdown_worker()
def on_stopped(self):
self.timer.stop()
self.consumer.shutdown()
if self.pidlock:
self.pidlock.release()
def setup_queues(self, include, exclude=None):
include = str_to_list(include)
exclude = str_to_list(exclude)
try:
self.app.amqp.queues.select(include)
except KeyError as exc:
raise ImproperlyConfigured(
SELECT_UNKNOWN_QUEUE.strip().format(include, exc))
try:
self.app.amqp.queues.deselect(exclude)
except KeyError as exc:
raise ImproperlyConfigured(
DESELECT_UNKNOWN_QUEUE.strip().format(exclude, exc))
if self.app.conf.worker_direct:
self.app.amqp.queues.select_add(worker_direct(self.hostname))
def setup_includes(self, includes):
# Update celery_include to have all known task modules, so that we
# ensure all task modules are imported in case an execv happens.
prev = tuple(self.app.conf.include)
if includes:
prev += tuple(includes)
[self.app.loader.import_task_module(m) for m in includes]
self.include = includes
task_modules = {task.__class__.__module__
for task in self.app.tasks.values()}
self.app.conf.include = tuple(set(prev) | task_modules)
def prepare_args(self, **kwargs):
return kwargs
def _send_worker_shutdown(self):
signals.worker_shutdown.send(sender=self)
def start(self):
try:
self.blueprint.start(self)
except WorkerTerminate:
self.terminate()
except Exception as exc:
logger.critical('Unrecoverable error: %r', exc, exc_info=True)
self.stop(exitcode=EX_FAILURE)
except SystemExit as exc:
self.stop(exitcode=exc.code)
except KeyboardInterrupt:
self.stop(exitcode=EX_FAILURE)
def register_with_event_loop(self, hub):
self.blueprint.send_all(
self, 'register_with_event_loop', args=(hub,),
description='hub.register',
)
def _process_task_sem(self, req):
return self._quick_acquire(self._process_task, req)
def _process_task(self, req):
"""Process task by sending it to the pool of workers."""
try:
req.execute_using_pool(self.pool)
except TaskRevokedError:
try:
self._quick_release() # Issue 877
except AttributeError:
pass
def signal_consumer_close(self):
try:
self.consumer.close()
except AttributeError:
pass
def should_use_eventloop(self):
return (detect_environment() == 'default' and
self._conninfo.transport.implements.asynchronous and
not self.app.IS_WINDOWS)
def stop(self, in_sighandler=False, exitcode=None):
"""Graceful shutdown of the worker server."""
if exitcode is not None:
self.exitcode = exitcode
if self.blueprint.state == RUN:
self.signal_consumer_close()
if not in_sighandler or self.pool.signal_safe:
self._shutdown(warm=True)
self._send_worker_shutdown()
def terminate(self, in_sighandler=False):
"""Not so graceful shutdown of the worker server."""
if self.blueprint.state != TERMINATE:
self.signal_consumer_close()
if not in_sighandler or self.pool.signal_safe:
self._shutdown(warm=False)
def _shutdown(self, warm=True):
# if blueprint does not exist it means that we had an
# error before the bootsteps could be initialized.
if self.blueprint is not None:
with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT): # Issue 975
self.blueprint.stop(self, terminate=not warm)
self.blueprint.join()
def reload(self, modules=None, reload=False, reloader=None):
list(self._reload_modules(
modules, force_reload=reload, reloader=reloader))
if self.consumer:
self.consumer.update_strategies()
self.consumer.reset_rate_limits()
try:
self.pool.restart()
except NotImplementedError:
pass
def _reload_modules(self, modules=None, **kwargs):
return (
self._maybe_reload_module(m, **kwargs)
for m in set(self.app.loader.task_modules
if modules is None else (modules or ()))
)
def _maybe_reload_module(self, module, force_reload=False, reloader=None):
if module not in sys.modules:
logger.debug('importing module %s', module)
return self.app.loader.import_from_cwd(module)
elif force_reload:
logger.debug('reloading module %s', module)
return reload_from_cwd(sys.modules[module], reloader)
def info(self):
uptime = datetime.utcnow() - self.startup_time
return {'total': self.state.total_count,
'pid': os.getpid(),
'clock': str(self.app.clock),
'uptime': round(uptime.total_seconds())}
def rusage(self):
if resource is None:
raise NotImplementedError('rusage not supported by this platform')
s = resource.getrusage(resource.RUSAGE_SELF)
return {
'utime': s.ru_utime,
'stime': s.ru_stime,
'maxrss': s.ru_maxrss,
'ixrss': s.ru_ixrss,
'idrss': s.ru_idrss,
'isrss': s.ru_isrss,
'minflt': s.ru_minflt,
'majflt': s.ru_majflt,
'nswap': s.ru_nswap,
'inblock': s.ru_inblock,
'oublock': s.ru_oublock,
'msgsnd': s.ru_msgsnd,
'msgrcv': s.ru_msgrcv,
'nsignals': s.ru_nsignals,
'nvcsw': s.ru_nvcsw,
'nivcsw': s.ru_nivcsw,
}
def stats(self):
info = self.info()
info.update(self.blueprint.info(self))
info.update(self.consumer.blueprint.info(self.consumer))
try:
info['rusage'] = self.rusage()
except NotImplementedError:
info['rusage'] = 'N/A'
return info
def __repr__(self):
"""``repr(worker)``."""
return '<Worker: {self.hostname} ({state})>'.format(
self=self,
state=self.blueprint.human_state() if self.blueprint else 'INIT',
)
def __str__(self):
"""``str(worker) == worker.hostname``."""
return self.hostname
@property
def state(self):
return state
def setup_defaults(self, concurrency=None, loglevel='WARN', logfile=None,
task_events=None, pool=None, consumer_cls=None,
timer_cls=None, timer_precision=None,
autoscaler_cls=None,
pool_putlocks=None,
pool_restarts=None,
optimization=None, O=None, # O maps to -O=fair
statedb=None,
time_limit=None,
soft_time_limit=None,
scheduler=None,
pool_cls=None, # XXX use pool
state_db=None, # XXX use statedb
task_time_limit=None, # XXX use time_limit
task_soft_time_limit=None, # XXX use soft_time_limit
scheduler_cls=None, # XXX use scheduler
schedule_filename=None,
max_tasks_per_child=None,
prefetch_multiplier=None, disable_rate_limits=None,
worker_lost_wait=None,
max_memory_per_child=None, **_kw):
either = self.app.either
self.loglevel = loglevel
self.logfile = logfile
self.concurrency = either('worker_concurrency', concurrency)
self.task_events = either('worker_send_task_events', task_events)
self.pool_cls = either('worker_pool', pool, pool_cls)
self.consumer_cls = either('worker_consumer', consumer_cls)
self.timer_cls = either('worker_timer', timer_cls)
self.timer_precision = either(
'worker_timer_precision', timer_precision,
)
self.optimization = optimization or O
self.autoscaler_cls = either('worker_autoscaler', autoscaler_cls)
self.pool_putlocks = either('worker_pool_putlocks', pool_putlocks)
self.pool_restarts = either('worker_pool_restarts', pool_restarts)
self.statedb = either('worker_state_db', statedb, state_db)
self.schedule_filename = either(
'beat_schedule_filename', schedule_filename,
)
self.scheduler = either('beat_scheduler', scheduler, scheduler_cls)
self.time_limit = either(
'task_time_limit', time_limit, task_time_limit)
self.soft_time_limit = either(
'task_soft_time_limit', soft_time_limit, task_soft_time_limit,
)
self.max_tasks_per_child = either(
'worker_max_tasks_per_child', max_tasks_per_child,
)
self.max_memory_per_child = either(
'worker_max_memory_per_child', max_memory_per_child,
)
self.prefetch_multiplier = int(either(
'worker_prefetch_multiplier', prefetch_multiplier,
))
self.disable_rate_limits = either(
'worker_disable_rate_limits', disable_rate_limits,
)
self.worker_lost_wait = either('worker_lost_wait', worker_lost_wait)