Updates
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
"""Worker consumer."""
|
||||
from .agent import Agent
|
||||
from .connection import Connection
|
||||
from .consumer import Consumer
|
||||
from .control import Control
|
||||
from .events import Events
|
||||
from .gossip import Gossip
|
||||
from .heart import Heart
|
||||
from .mingle import Mingle
|
||||
from .tasks import Tasks
|
||||
|
||||
__all__ = (
|
||||
'Consumer', 'Agent', 'Connection', 'Control',
|
||||
'Events', 'Gossip', 'Heart', 'Mingle', 'Tasks',
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
"""Celery + :pypi:`cell` integration."""
|
||||
from celery import bootsteps
|
||||
|
||||
from .connection import Connection
|
||||
|
||||
__all__ = ('Agent',)
|
||||
|
||||
|
||||
class Agent(bootsteps.StartStopStep):
|
||||
"""Agent starts :pypi:`cell` actors."""
|
||||
|
||||
conditional = True
|
||||
requires = (Connection,)
|
||||
|
||||
def __init__(self, c, **kwargs):
|
||||
self.agent_cls = self.enabled = c.app.conf.worker_agent
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def create(self, c):
|
||||
agent = c.agent = self.instantiate(self.agent_cls, c.connection)
|
||||
return agent
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Consumer Broker Connection Bootstep."""
|
||||
from kombu.common import ignore_errors
|
||||
|
||||
from celery import bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
__all__ = ('Connection',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
info = logger.info
|
||||
|
||||
|
||||
class Connection(bootsteps.StartStopStep):
|
||||
"""Service managing the consumer broker connection."""
|
||||
|
||||
def __init__(self, c, **kwargs):
|
||||
c.connection = None
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def start(self, c):
|
||||
c.connection = c.connect()
|
||||
info('Connected to %s', c.connection.as_uri())
|
||||
|
||||
def shutdown(self, c):
|
||||
# We must set self.connection to None here, so
|
||||
# that the green pidbox thread exits.
|
||||
connection, c.connection = c.connection, None
|
||||
if connection:
|
||||
ignore_errors(connection, connection.close)
|
||||
|
||||
def info(self, c):
|
||||
params = 'N/A'
|
||||
if c.connection:
|
||||
params = c.connection.info()
|
||||
params.pop('password', None) # don't send password.
|
||||
return {'broker': params}
|
||||
@@ -0,0 +1,775 @@
|
||||
"""Worker Consumer Blueprint.
|
||||
|
||||
This module contains the components responsible for consuming messages
|
||||
from the broker, processing the messages and keeping the broker connections
|
||||
up and running.
|
||||
"""
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from time import sleep
|
||||
|
||||
from billiard.common import restart_state
|
||||
from billiard.exceptions import RestartFreqExceeded
|
||||
from kombu.asynchronous.semaphore import DummyLock
|
||||
from kombu.exceptions import ContentDisallowed, DecodeError
|
||||
from kombu.utils.compat import _detect_environment
|
||||
from kombu.utils.encoding import safe_repr
|
||||
from kombu.utils.limits import TokenBucket
|
||||
from vine import ppartial, promise
|
||||
|
||||
from celery import bootsteps, signals
|
||||
from celery.app.trace import build_tracer
|
||||
from celery.exceptions import (CPendingDeprecationWarning, InvalidTaskError, NotRegistered, WorkerShutdown,
|
||||
WorkerTerminate)
|
||||
from celery.utils.functional import noop
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.nodenames import gethostname
|
||||
from celery.utils.objects import Bunch
|
||||
from celery.utils.text import truncate
|
||||
from celery.utils.time import humanize_seconds, rate
|
||||
from celery.worker import loops
|
||||
from celery.worker.state import active_requests, maybe_shutdown, requests, reserved_requests, task_reserved
|
||||
|
||||
__all__ = ('Consumer', 'Evloop', 'dump_body')
|
||||
|
||||
CLOSE = bootsteps.CLOSE
|
||||
TERMINATE = bootsteps.TERMINATE
|
||||
STOP_CONDITIONS = {CLOSE, TERMINATE}
|
||||
logger = get_logger(__name__)
|
||||
debug, info, warn, error, crit = (logger.debug, logger.info, logger.warning,
|
||||
logger.error, logger.critical)
|
||||
|
||||
CONNECTION_RETRY = """\
|
||||
consumer: Connection to broker lost. \
|
||||
Trying to re-establish the connection...\
|
||||
"""
|
||||
|
||||
CONNECTION_RETRY_STEP = """\
|
||||
Trying again {when}... ({retries}/{max_retries})\
|
||||
"""
|
||||
|
||||
CONNECTION_ERROR = """\
|
||||
consumer: Cannot connect to %s: %s.
|
||||
%s
|
||||
"""
|
||||
|
||||
CONNECTION_FAILOVER = """\
|
||||
Will retry using next failover.\
|
||||
"""
|
||||
|
||||
UNKNOWN_FORMAT = """\
|
||||
Received and deleted unknown message. Wrong destination?!?
|
||||
|
||||
The full contents of the message body was: %s
|
||||
"""
|
||||
|
||||
#: Error message for when an unregistered task is received.
|
||||
UNKNOWN_TASK_ERROR = """\
|
||||
Received unregistered task of type %s.
|
||||
The message has been ignored and discarded.
|
||||
|
||||
Did you remember to import the module containing this task?
|
||||
Or maybe you're using relative imports?
|
||||
|
||||
Please see
|
||||
https://docs.celeryq.dev/en/latest/internals/protocol.html
|
||||
for more information.
|
||||
|
||||
The full contents of the message body was:
|
||||
%s
|
||||
|
||||
The full contents of the message headers:
|
||||
%s
|
||||
|
||||
The delivery info for this task is:
|
||||
%s
|
||||
"""
|
||||
|
||||
#: Error message for when an invalid task message is received.
|
||||
INVALID_TASK_ERROR = """\
|
||||
Received invalid task message: %s
|
||||
The message has been ignored and discarded.
|
||||
|
||||
Please ensure your message conforms to the task
|
||||
message protocol as described here:
|
||||
https://docs.celeryq.dev/en/latest/internals/protocol.html
|
||||
|
||||
The full contents of the message body was:
|
||||
%s
|
||||
"""
|
||||
|
||||
MESSAGE_DECODE_ERROR = """\
|
||||
Can't decode message body: %r [type:%r encoding:%r headers:%s]
|
||||
|
||||
body: %s
|
||||
"""
|
||||
|
||||
MESSAGE_REPORT = """\
|
||||
body: {0}
|
||||
{{content_type:{1} content_encoding:{2}
|
||||
delivery_info:{3} headers={4}}}
|
||||
"""
|
||||
|
||||
TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS = """\
|
||||
Task %s cannot be acknowledged after a connection loss since late acknowledgement is enabled for it.
|
||||
Terminating it instead.
|
||||
"""
|
||||
|
||||
CANCEL_TASKS_BY_DEFAULT = """
|
||||
In Celery 5.1 we introduced an optional breaking change which
|
||||
on connection loss cancels all currently executed tasks with late acknowledgement enabled.
|
||||
These tasks cannot be acknowledged as the connection is gone, and the tasks are automatically redelivered
|
||||
back to the queue. You can enable this behavior using the worker_cancel_long_running_tasks_on_connection_loss
|
||||
setting. In Celery 5.1 it is set to False by default. The setting will be set to True by default in Celery 6.0.
|
||||
"""
|
||||
|
||||
|
||||
def dump_body(m, body):
|
||||
"""Format message body for debugging purposes."""
|
||||
# v2 protocol does not deserialize body
|
||||
body = m.body if body is None else body
|
||||
return '{} ({}b)'.format(truncate(safe_repr(body), 1024),
|
||||
len(m.body))
|
||||
|
||||
|
||||
class Consumer:
|
||||
"""Consumer blueprint."""
|
||||
|
||||
Strategies = dict
|
||||
|
||||
#: Optional callback called the first time the worker
|
||||
#: is ready to receive tasks.
|
||||
init_callback = None
|
||||
|
||||
#: The current worker pool instance.
|
||||
pool = None
|
||||
|
||||
#: A timer used for high-priority internal tasks, such
|
||||
#: as sending heartbeats.
|
||||
timer = None
|
||||
|
||||
restart_count = -1 # first start is the same as a restart
|
||||
|
||||
#: This flag will be turned off after the first failed
|
||||
#: connection attempt.
|
||||
first_connection_attempt = True
|
||||
|
||||
class Blueprint(bootsteps.Blueprint):
|
||||
"""Consumer blueprint."""
|
||||
|
||||
name = 'Consumer'
|
||||
default_steps = [
|
||||
'celery.worker.consumer.connection:Connection',
|
||||
'celery.worker.consumer.mingle:Mingle',
|
||||
'celery.worker.consumer.events:Events',
|
||||
'celery.worker.consumer.gossip:Gossip',
|
||||
'celery.worker.consumer.heart:Heart',
|
||||
'celery.worker.consumer.control:Control',
|
||||
'celery.worker.consumer.tasks:Tasks',
|
||||
'celery.worker.consumer.delayed_delivery:DelayedDelivery',
|
||||
'celery.worker.consumer.consumer:Evloop',
|
||||
'celery.worker.consumer.agent:Agent',
|
||||
]
|
||||
|
||||
def shutdown(self, parent):
|
||||
self.send_all(parent, 'shutdown')
|
||||
|
||||
def __init__(self, on_task_request,
|
||||
init_callback=noop, hostname=None,
|
||||
pool=None, app=None,
|
||||
timer=None, controller=None, hub=None, amqheartbeat=None,
|
||||
worker_options=None, disable_rate_limits=False,
|
||||
initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
|
||||
self.app = app
|
||||
self.controller = controller
|
||||
self.init_callback = init_callback
|
||||
self.hostname = hostname or gethostname()
|
||||
self.pid = os.getpid()
|
||||
self.pool = pool
|
||||
self.timer = timer
|
||||
self.strategies = self.Strategies()
|
||||
self.conninfo = self.app.connection_for_read()
|
||||
self.connection_errors = self.conninfo.connection_errors
|
||||
self.channel_errors = self.conninfo.channel_errors
|
||||
self._restart_state = restart_state(maxR=5, maxT=1)
|
||||
|
||||
self._does_info = logger.isEnabledFor(logging.INFO)
|
||||
self._limit_order = 0
|
||||
self.on_task_request = on_task_request
|
||||
self.on_task_message = set()
|
||||
self.amqheartbeat_rate = self.app.conf.broker_heartbeat_checkrate
|
||||
self.disable_rate_limits = disable_rate_limits
|
||||
self.initial_prefetch_count = initial_prefetch_count
|
||||
self.prefetch_multiplier = prefetch_multiplier
|
||||
self._maximum_prefetch_restored = True
|
||||
|
||||
# this contains a tokenbucket for each task type by name, used for
|
||||
# rate limits, or None if rate limits are disabled for that task.
|
||||
self.task_buckets = defaultdict(lambda: None)
|
||||
self.reset_rate_limits()
|
||||
|
||||
self.hub = hub
|
||||
if self.hub or getattr(self.pool, 'is_green', False):
|
||||
self.amqheartbeat = amqheartbeat
|
||||
if self.amqheartbeat is None:
|
||||
self.amqheartbeat = self.app.conf.broker_heartbeat
|
||||
else:
|
||||
self.amqheartbeat = 0
|
||||
|
||||
if not hasattr(self, 'loop'):
|
||||
self.loop = loops.asynloop if hub else loops.synloop
|
||||
|
||||
if _detect_environment() == 'gevent':
|
||||
# there's a gevent bug that causes timeouts to not be reset,
|
||||
# so if the connection timeout is exceeded once, it can NEVER
|
||||
# connect again.
|
||||
self.app.conf.broker_connection_timeout = None
|
||||
|
||||
self._pending_operations = []
|
||||
|
||||
self.steps = []
|
||||
self.blueprint = self.Blueprint(
|
||||
steps=self.app.steps['consumer'],
|
||||
on_close=self.on_close,
|
||||
)
|
||||
self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))
|
||||
|
||||
def call_soon(self, p, *args, **kwargs):
|
||||
p = ppartial(p, *args, **kwargs)
|
||||
if self.hub:
|
||||
return self.hub.call_soon(p)
|
||||
self._pending_operations.append(p)
|
||||
return p
|
||||
|
||||
def perform_pending_operations(self):
|
||||
if not self.hub:
|
||||
while self._pending_operations:
|
||||
try:
|
||||
self._pending_operations.pop()()
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception('Pending callback raised: %r', exc)
|
||||
|
||||
def bucket_for_task(self, type):
|
||||
limit = rate(getattr(type, 'rate_limit', None))
|
||||
return TokenBucket(limit, capacity=1) if limit else None
|
||||
|
||||
def reset_rate_limits(self):
|
||||
self.task_buckets.update(
|
||||
(n, self.bucket_for_task(t)) for n, t in self.app.tasks.items()
|
||||
)
|
||||
|
||||
def _update_prefetch_count(self, index=0):
|
||||
"""Update prefetch count after pool/shrink grow operations.
|
||||
|
||||
Index must be the change in number of processes as a positive
|
||||
(increasing) or negative (decreasing) number.
|
||||
|
||||
Note:
|
||||
Currently pool grow operations will end up with an offset
|
||||
of +1 if the initial size of the pool was 0 (e.g.
|
||||
:option:`--autoscale=1,0 <celery worker --autoscale>`).
|
||||
"""
|
||||
num_processes = self.pool.num_processes
|
||||
if not self.initial_prefetch_count or not num_processes:
|
||||
return # prefetch disabled
|
||||
self.initial_prefetch_count = (
|
||||
self.pool.num_processes * self.prefetch_multiplier
|
||||
)
|
||||
return self._update_qos_eventually(index)
|
||||
|
||||
def _update_qos_eventually(self, index):
|
||||
return (self.qos.decrement_eventually if index < 0
|
||||
else self.qos.increment_eventually)(
|
||||
abs(index) * self.prefetch_multiplier)
|
||||
|
||||
def _limit_move_to_pool(self, request):
|
||||
task_reserved(request)
|
||||
self.on_task_request(request)
|
||||
|
||||
def _schedule_bucket_request(self, bucket):
|
||||
while True:
|
||||
try:
|
||||
request, tokens = bucket.pop()
|
||||
except IndexError:
|
||||
# no request, break
|
||||
break
|
||||
|
||||
if bucket.can_consume(tokens):
|
||||
self._limit_move_to_pool(request)
|
||||
continue
|
||||
else:
|
||||
# requeue to head, keep the order.
|
||||
bucket.contents.appendleft((request, tokens))
|
||||
|
||||
pri = self._limit_order = (self._limit_order + 1) % 10
|
||||
hold = bucket.expected_time(tokens)
|
||||
self.timer.call_after(
|
||||
hold, self._schedule_bucket_request, (bucket,),
|
||||
priority=pri,
|
||||
)
|
||||
# no tokens, break
|
||||
break
|
||||
|
||||
def _limit_task(self, request, bucket, tokens):
|
||||
bucket.add((request, tokens))
|
||||
return self._schedule_bucket_request(bucket)
|
||||
|
||||
def _limit_post_eta(self, request, bucket, tokens):
|
||||
self.qos.decrement_eventually()
|
||||
bucket.add((request, tokens))
|
||||
return self._schedule_bucket_request(bucket)
|
||||
|
||||
def start(self):
|
||||
blueprint = self.blueprint
|
||||
while blueprint.state not in STOP_CONDITIONS:
|
||||
maybe_shutdown()
|
||||
if self.restart_count:
|
||||
try:
|
||||
self._restart_state.step()
|
||||
except RestartFreqExceeded as exc:
|
||||
crit('Frequent restarts detected: %r', exc, exc_info=1)
|
||||
sleep(1)
|
||||
self.restart_count += 1
|
||||
if self.app.conf.broker_channel_error_retry:
|
||||
recoverable_errors = (self.connection_errors + self.channel_errors)
|
||||
else:
|
||||
recoverable_errors = self.connection_errors
|
||||
try:
|
||||
blueprint.start(self)
|
||||
except recoverable_errors as exc:
|
||||
# If we're not retrying connections, we need to properly shutdown or terminate
|
||||
# the Celery main process instead of abruptly aborting the process without any cleanup.
|
||||
is_connection_loss_on_startup = self.first_connection_attempt
|
||||
self.first_connection_attempt = False
|
||||
connection_retry_type = self._get_connection_retry_type(is_connection_loss_on_startup)
|
||||
connection_retry = self.app.conf[connection_retry_type]
|
||||
if not connection_retry:
|
||||
crit(
|
||||
f"Retrying to {'establish' if is_connection_loss_on_startup else 're-establish'} "
|
||||
f"a connection to the message broker after a connection loss has "
|
||||
f"been disabled (app.conf.{connection_retry_type}=False). Shutting down..."
|
||||
)
|
||||
raise WorkerShutdown(1) from exc
|
||||
if isinstance(exc, OSError) and exc.errno == errno.EMFILE:
|
||||
crit("Too many open files. Aborting...")
|
||||
raise WorkerTerminate(1) from exc
|
||||
maybe_shutdown()
|
||||
if blueprint.state not in STOP_CONDITIONS:
|
||||
if self.connection:
|
||||
self.on_connection_error_after_connected(exc)
|
||||
else:
|
||||
self.on_connection_error_before_connected(exc)
|
||||
self.on_close()
|
||||
blueprint.restart(self)
|
||||
|
||||
def _get_connection_retry_type(self, is_connection_loss_on_startup):
|
||||
return ('broker_connection_retry_on_startup'
|
||||
if (is_connection_loss_on_startup
|
||||
and self.app.conf.broker_connection_retry_on_startup is not None)
|
||||
else 'broker_connection_retry')
|
||||
|
||||
def on_connection_error_before_connected(self, exc):
|
||||
error(CONNECTION_ERROR, self.conninfo.as_uri(), exc,
|
||||
'Trying to reconnect...')
|
||||
|
||||
def on_connection_error_after_connected(self, exc):
|
||||
warn(CONNECTION_RETRY, exc_info=True)
|
||||
try:
|
||||
self.connection.collect()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
if self.app.conf.worker_cancel_long_running_tasks_on_connection_loss:
|
||||
for request in tuple(active_requests):
|
||||
if request.task.acks_late and not request.acknowledged:
|
||||
warn(TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS,
|
||||
request)
|
||||
request.cancel(self.pool)
|
||||
else:
|
||||
warnings.warn(CANCEL_TASKS_BY_DEFAULT, CPendingDeprecationWarning)
|
||||
|
||||
if self.app.conf.worker_enable_prefetch_count_reduction:
|
||||
self.initial_prefetch_count = max(
|
||||
self.prefetch_multiplier,
|
||||
self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier
|
||||
)
|
||||
|
||||
self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count
|
||||
if not self._maximum_prefetch_restored:
|
||||
logger.info(
|
||||
f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid "
|
||||
f"over-fetching since {len(tuple(active_requests))} tasks are currently being processed.\n"
|
||||
f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks "
|
||||
"complete processing."
|
||||
)
|
||||
|
||||
def register_with_event_loop(self, hub):
|
||||
self.blueprint.send_all(
|
||||
self, 'register_with_event_loop', args=(hub,),
|
||||
description='Hub.register',
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
self.perform_pending_operations()
|
||||
self.blueprint.shutdown(self)
|
||||
|
||||
def stop(self):
|
||||
self.blueprint.stop(self)
|
||||
|
||||
def on_ready(self):
|
||||
callback, self.init_callback = self.init_callback, None
|
||||
if callback:
|
||||
callback(self)
|
||||
|
||||
def loop_args(self):
|
||||
return (self, self.connection, self.task_consumer,
|
||||
self.blueprint, self.hub, self.qos, self.amqheartbeat,
|
||||
self.app.clock, self.amqheartbeat_rate)
|
||||
|
||||
def on_decode_error(self, message, exc):
|
||||
"""Callback called if an error occurs while decoding a message.
|
||||
|
||||
Simply logs the error and acknowledges the message so it
|
||||
doesn't enter a loop.
|
||||
|
||||
Arguments:
|
||||
message (kombu.Message): The message received.
|
||||
exc (Exception): The exception being handled.
|
||||
"""
|
||||
crit(MESSAGE_DECODE_ERROR,
|
||||
exc, message.content_type, message.content_encoding,
|
||||
safe_repr(message.headers), dump_body(message, message.body),
|
||||
exc_info=1)
|
||||
message.ack()
|
||||
|
||||
def on_close(self):
|
||||
# Clear internal queues to get rid of old messages.
|
||||
# They can't be acked anyway, as a delivery tag is specific
|
||||
# to the current channel.
|
||||
if self.controller and self.controller.semaphore:
|
||||
self.controller.semaphore.clear()
|
||||
if self.timer:
|
||||
self.timer.clear()
|
||||
for bucket in self.task_buckets.values():
|
||||
if bucket:
|
||||
bucket.clear_pending()
|
||||
for request_id in reserved_requests:
|
||||
if request_id in requests:
|
||||
del requests[request_id]
|
||||
reserved_requests.clear()
|
||||
if self.pool and self.pool.flush:
|
||||
self.pool.flush()
|
||||
|
||||
def connect(self):
|
||||
"""Establish the broker connection used for consuming tasks.
|
||||
|
||||
Retries establishing the connection if the
|
||||
:setting:`broker_connection_retry` setting is enabled
|
||||
"""
|
||||
conn = self.connection_for_read(heartbeat=self.amqheartbeat)
|
||||
if self.hub:
|
||||
conn.transport.register_with_event_loop(conn.connection, self.hub)
|
||||
return conn
|
||||
|
||||
def connection_for_read(self, heartbeat=None):
|
||||
return self.ensure_connected(
|
||||
self.app.connection_for_read(heartbeat=heartbeat))
|
||||
|
||||
def connection_for_write(self, url=None, heartbeat=None):
|
||||
return self.ensure_connected(
|
||||
self.app.connection_for_write(url=url, heartbeat=heartbeat))
|
||||
|
||||
def ensure_connected(self, conn):
|
||||
# Callback called for each retry while the connection
|
||||
# can't be established.
|
||||
def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
|
||||
if getattr(conn, 'alt', None) and interval == 0:
|
||||
next_step = CONNECTION_FAILOVER
|
||||
next_step = next_step.format(
|
||||
when=humanize_seconds(interval, 'in', ' '),
|
||||
retries=int(interval / 2),
|
||||
max_retries=self.app.conf.broker_connection_max_retries)
|
||||
error(CONNECTION_ERROR, conn.as_uri(), exc, next_step)
|
||||
|
||||
# Remember that the connection is lazy, it won't establish
|
||||
# until needed.
|
||||
|
||||
# TODO: Rely only on broker_connection_retry_on_startup to determine whether connection retries are disabled.
|
||||
# We will make the switch in Celery 6.0.
|
||||
|
||||
retry_disabled = False
|
||||
|
||||
if self.app.conf.broker_connection_retry_on_startup is None:
|
||||
# If broker_connection_retry_on_startup is not set, revert to broker_connection_retry
|
||||
# to determine whether connection retries are disabled.
|
||||
retry_disabled = not self.app.conf.broker_connection_retry
|
||||
|
||||
if retry_disabled:
|
||||
warnings.warn(
|
||||
CPendingDeprecationWarning(
|
||||
"The broker_connection_retry configuration setting will no longer determine\n"
|
||||
"whether broker connection retries are made during startup in Celery 6.0 and above.\n"
|
||||
"If you wish to refrain from retrying connections on startup,\n"
|
||||
"you should set broker_connection_retry_on_startup to False instead.")
|
||||
)
|
||||
else:
|
||||
if self.first_connection_attempt:
|
||||
retry_disabled = not self.app.conf.broker_connection_retry_on_startup
|
||||
else:
|
||||
retry_disabled = not self.app.conf.broker_connection_retry
|
||||
|
||||
if retry_disabled:
|
||||
# Retry disabled, just call connect directly.
|
||||
conn.connect()
|
||||
self.first_connection_attempt = False
|
||||
return conn
|
||||
|
||||
conn = conn.ensure_connection(
|
||||
_error_handler, self.app.conf.broker_connection_max_retries,
|
||||
callback=maybe_shutdown,
|
||||
)
|
||||
self.first_connection_attempt = False
|
||||
return conn
|
||||
|
||||
def _flush_events(self):
|
||||
if self.event_dispatcher:
|
||||
self.event_dispatcher.flush()
|
||||
|
||||
def on_send_event_buffered(self):
|
||||
if self.hub:
|
||||
self.hub._ready.add(self._flush_events)
|
||||
|
||||
def add_task_queue(self, queue, exchange=None, exchange_type=None,
|
||||
routing_key=None, **options):
|
||||
cset = self.task_consumer
|
||||
queues = self.app.amqp.queues
|
||||
# Must use in' here, as __missing__ will automatically
|
||||
# create queues when :setting:`task_create_missing_queues` is enabled.
|
||||
# (Issue #1079)
|
||||
if queue in queues:
|
||||
q = queues[queue]
|
||||
else:
|
||||
exchange = queue if exchange is None else exchange
|
||||
exchange_type = ('direct' if exchange_type is None
|
||||
else exchange_type)
|
||||
q = queues.select_add(queue,
|
||||
exchange=exchange,
|
||||
exchange_type=exchange_type,
|
||||
routing_key=routing_key, **options)
|
||||
if not cset.consuming_from(queue):
|
||||
cset.add_queue(q)
|
||||
cset.consume()
|
||||
info('Started consuming from %s', queue)
|
||||
|
||||
def cancel_task_queue(self, queue):
|
||||
info('Canceling queue %s', queue)
|
||||
self.app.amqp.queues.deselect(queue)
|
||||
self.task_consumer.cancel_by_queue(queue)
|
||||
|
||||
def apply_eta_task(self, task):
|
||||
"""Method called by the timer to apply a task with an ETA/countdown."""
|
||||
task_reserved(task)
|
||||
self.on_task_request(task)
|
||||
self.qos.decrement_eventually()
|
||||
|
||||
def _message_report(self, body, message):
|
||||
return MESSAGE_REPORT.format(dump_body(message, body),
|
||||
safe_repr(message.content_type),
|
||||
safe_repr(message.content_encoding),
|
||||
safe_repr(message.delivery_info),
|
||||
safe_repr(message.headers))
|
||||
|
||||
def on_unknown_message(self, body, message):
|
||||
warn(UNKNOWN_FORMAT, self._message_report(body, message))
|
||||
message.reject_log_error(logger, self.connection_errors)
|
||||
signals.task_rejected.send(sender=self, message=message, exc=None)
|
||||
|
||||
def on_unknown_task(self, body, message, exc):
|
||||
error(UNKNOWN_TASK_ERROR,
|
||||
exc,
|
||||
dump_body(message, body),
|
||||
message.headers,
|
||||
message.delivery_info,
|
||||
exc_info=True)
|
||||
try:
|
||||
id_, name = message.headers['id'], message.headers['task']
|
||||
root_id = message.headers.get('root_id')
|
||||
except KeyError: # proto1
|
||||
payload = message.payload
|
||||
id_, name = payload['id'], payload['task']
|
||||
root_id = None
|
||||
request = Bunch(
|
||||
name=name, chord=None, root_id=root_id,
|
||||
correlation_id=message.properties.get('correlation_id'),
|
||||
reply_to=message.properties.get('reply_to'),
|
||||
errbacks=None,
|
||||
)
|
||||
message.reject_log_error(logger, self.connection_errors)
|
||||
self.app.backend.mark_as_failure(
|
||||
id_, NotRegistered(name), request=request,
|
||||
)
|
||||
if self.event_dispatcher:
|
||||
self.event_dispatcher.send(
|
||||
'task-failed', uuid=id_,
|
||||
exception=f'NotRegistered({name!r})',
|
||||
)
|
||||
signals.task_unknown.send(
|
||||
sender=self, message=message, exc=exc, name=name, id=id_,
|
||||
)
|
||||
|
||||
def on_invalid_task(self, body, message, exc):
|
||||
error(INVALID_TASK_ERROR, exc, dump_body(message, body),
|
||||
exc_info=True)
|
||||
message.reject_log_error(logger, self.connection_errors)
|
||||
signals.task_rejected.send(sender=self, message=message, exc=exc)
|
||||
|
||||
def update_strategies(self):
|
||||
loader = self.app.loader
|
||||
for name, task in self.app.tasks.items():
|
||||
self.strategies[name] = task.start_strategy(self.app, self)
|
||||
task.__trace__ = build_tracer(name, task, loader, self.hostname,
|
||||
app=self.app)
|
||||
|
||||
def create_task_handler(self, promise=promise):
|
||||
strategies = self.strategies
|
||||
on_unknown_message = self.on_unknown_message
|
||||
on_unknown_task = self.on_unknown_task
|
||||
on_invalid_task = self.on_invalid_task
|
||||
callbacks = self.on_task_message
|
||||
call_soon = self.call_soon
|
||||
|
||||
def on_task_received(message):
|
||||
# payload will only be set for v1 protocol, since v2
|
||||
# will defer deserializing the message body to the pool.
|
||||
payload = None
|
||||
try:
|
||||
type_ = message.headers['task'] # protocol v2
|
||||
except TypeError:
|
||||
return on_unknown_message(None, message)
|
||||
except KeyError:
|
||||
try:
|
||||
payload = message.decode()
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
return self.on_decode_error(message, exc)
|
||||
try:
|
||||
type_, payload = payload['task'], payload # protocol v1
|
||||
except (TypeError, KeyError):
|
||||
return on_unknown_message(payload, message)
|
||||
try:
|
||||
strategy = strategies[type_]
|
||||
except KeyError as exc:
|
||||
return on_unknown_task(None, message, exc)
|
||||
else:
|
||||
try:
|
||||
ack_log_error_promise = promise(
|
||||
call_soon,
|
||||
(message.ack_log_error,),
|
||||
on_error=self._restore_prefetch_count_after_connection_restart,
|
||||
)
|
||||
reject_log_error_promise = promise(
|
||||
call_soon,
|
||||
(message.reject_log_error,),
|
||||
on_error=self._restore_prefetch_count_after_connection_restart,
|
||||
)
|
||||
|
||||
if (
|
||||
not self._maximum_prefetch_restored
|
||||
and self.restart_count > 0
|
||||
and self._new_prefetch_count <= self.max_prefetch_count
|
||||
):
|
||||
ack_log_error_promise.then(self._restore_prefetch_count_after_connection_restart,
|
||||
on_error=self._restore_prefetch_count_after_connection_restart)
|
||||
reject_log_error_promise.then(self._restore_prefetch_count_after_connection_restart,
|
||||
on_error=self._restore_prefetch_count_after_connection_restart)
|
||||
|
||||
strategy(
|
||||
message, payload,
|
||||
ack_log_error_promise,
|
||||
reject_log_error_promise,
|
||||
callbacks,
|
||||
)
|
||||
except (InvalidTaskError, ContentDisallowed) as exc:
|
||||
return on_invalid_task(payload, message, exc)
|
||||
except DecodeError as exc:
|
||||
return self.on_decode_error(message, exc)
|
||||
|
||||
return on_task_received
|
||||
|
||||
def _restore_prefetch_count_after_connection_restart(self, p, *args):
|
||||
with self.qos._mutex:
|
||||
if any((
|
||||
not self.app.conf.worker_enable_prefetch_count_reduction,
|
||||
self._maximum_prefetch_restored,
|
||||
)):
|
||||
return
|
||||
|
||||
new_prefetch_count = min(self.max_prefetch_count, self._new_prefetch_count)
|
||||
self.qos.value = self.initial_prefetch_count = new_prefetch_count
|
||||
self.qos.set(self.qos.value)
|
||||
|
||||
already_restored = self._maximum_prefetch_restored
|
||||
self._maximum_prefetch_restored = new_prefetch_count == self.max_prefetch_count
|
||||
|
||||
if already_restored is False and self._maximum_prefetch_restored is True:
|
||||
logger.info(
|
||||
"Resuming normal operations following a restart.\n"
|
||||
f"Prefetch count has been restored to the maximum of {self.max_prefetch_count}"
|
||||
)
|
||||
|
||||
@property
|
||||
def max_prefetch_count(self):
|
||||
return self.pool.num_processes * self.prefetch_multiplier
|
||||
|
||||
@property
|
||||
def _new_prefetch_count(self):
|
||||
return self.qos.value + self.prefetch_multiplier
|
||||
|
||||
def __repr__(self):
|
||||
"""``repr(self)``."""
|
||||
return '<Consumer: {self.hostname} ({state})>'.format(
|
||||
self=self, state=self.blueprint.human_state(),
|
||||
)
|
||||
|
||||
def cancel_all_unacked_requests(self):
|
||||
"""Cancel all active requests that either do not require late acknowledgments or,
|
||||
if they do, have not been acknowledged yet.
|
||||
"""
|
||||
|
||||
def should_cancel(request):
|
||||
if not request.task.acks_late:
|
||||
# Task does not require late acknowledgment, cancel it.
|
||||
return True
|
||||
|
||||
if not request.acknowledged:
|
||||
# Task is late acknowledged, but it has not been acknowledged yet, cancel it.
|
||||
return True
|
||||
|
||||
# Task is late acknowledged, but it has already been acknowledged.
|
||||
return False # Do not cancel and allow it to gracefully finish as it has already been acknowledged.
|
||||
|
||||
requests_to_cancel = tuple(filter(should_cancel, active_requests))
|
||||
|
||||
if requests_to_cancel:
|
||||
for request in requests_to_cancel:
|
||||
request.cancel(self.pool)
|
||||
|
||||
|
||||
class Evloop(bootsteps.StartStopStep):
|
||||
"""Event loop service.
|
||||
|
||||
Note:
|
||||
This is always started last.
|
||||
"""
|
||||
|
||||
label = 'event loop'
|
||||
last = True
|
||||
|
||||
def start(self, c):
|
||||
self.patch_all(c)
|
||||
c.loop(*c.loop_args())
|
||||
|
||||
def patch_all(self, c):
|
||||
c.qos._mutex = DummyLock()
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Worker Remote Control Bootstep.
|
||||
|
||||
``Control`` -> :mod:`celery.worker.pidbox` -> :mod:`kombu.pidbox`.
|
||||
|
||||
The actual commands are implemented in :mod:`celery.worker.control`.
|
||||
"""
|
||||
from celery import bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
from celery.worker import pidbox
|
||||
|
||||
from .tasks import Tasks
|
||||
|
||||
__all__ = ('Control',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Control(bootsteps.StartStopStep):
|
||||
"""Remote control command service."""
|
||||
|
||||
requires = (Tasks,)
|
||||
|
||||
def __init__(self, c, **kwargs):
|
||||
self.is_green = c.pool is not None and c.pool.is_green
|
||||
self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
|
||||
self.start = self.box.start
|
||||
self.stop = self.box.stop
|
||||
self.shutdown = self.box.shutdown
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def include_if(self, c):
|
||||
return (c.app.conf.worker_enable_remote_control and
|
||||
c.conninfo.supports_exchange_type('fanout'))
|
||||
@@ -0,0 +1,247 @@
|
||||
"""Native delayed delivery functionality for Celery workers.
|
||||
|
||||
This module provides the DelayedDelivery bootstep which handles setup and configuration
|
||||
of native delayed delivery functionality when using quorum queues.
|
||||
"""
|
||||
from typing import Iterator, List, Optional, Set, Union, ValuesView
|
||||
|
||||
from kombu import Connection, Queue
|
||||
from kombu.transport.native_delayed_delivery import (bind_queue_to_native_delayed_delivery_exchange,
|
||||
declare_native_delayed_delivery_exchanges_and_queues)
|
||||
from kombu.utils.functional import retry_over_time
|
||||
|
||||
from celery import Celery, bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.quorum_queues import detect_quorum_queues
|
||||
from celery.worker.consumer import Consumer, Tasks
|
||||
|
||||
__all__ = ('DelayedDelivery',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Default retry settings
|
||||
RETRY_INTERVAL = 1.0 # seconds between retries
|
||||
MAX_RETRIES = 3 # maximum number of retries
|
||||
|
||||
|
||||
# Valid queue types for delayed delivery
|
||||
VALID_QUEUE_TYPES = {'classic', 'quorum'}
|
||||
|
||||
|
||||
class DelayedDelivery(bootsteps.StartStopStep):
|
||||
"""Bootstep that sets up native delayed delivery functionality.
|
||||
|
||||
This component handles the setup and configuration of native delayed delivery
|
||||
for Celery workers. It is automatically included when quorum queues are
|
||||
detected in the application configuration.
|
||||
|
||||
Responsibilities:
|
||||
- Declaring native delayed delivery exchanges and queues
|
||||
- Binding all application queues to the delayed delivery exchanges
|
||||
- Handling connection failures gracefully with retries
|
||||
- Validating configuration settings
|
||||
"""
|
||||
|
||||
requires = (Tasks,)
|
||||
|
||||
def include_if(self, c: Consumer) -> bool:
|
||||
"""Determine if this bootstep should be included.
|
||||
|
||||
Args:
|
||||
c: The Celery consumer instance
|
||||
|
||||
Returns:
|
||||
bool: True if quorum queues are detected, False otherwise
|
||||
"""
|
||||
return detect_quorum_queues(c.app, c.app.connection_for_write().transport.driver_type)[0]
|
||||
|
||||
def start(self, c: Consumer) -> None:
|
||||
"""Initialize delayed delivery for all broker URLs.
|
||||
|
||||
Attempts to set up delayed delivery for each broker URL in the configuration.
|
||||
Failures are logged but don't prevent attempting remaining URLs.
|
||||
|
||||
Args:
|
||||
c: The Celery consumer instance
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration validation fails
|
||||
"""
|
||||
app: Celery = c.app
|
||||
|
||||
try:
|
||||
self._validate_configuration(app)
|
||||
except ValueError as e:
|
||||
logger.critical("Configuration validation failed: %s", str(e))
|
||||
raise
|
||||
|
||||
broker_urls = self._validate_broker_urls(app.conf.broker_url)
|
||||
setup_errors = []
|
||||
|
||||
for broker_url in broker_urls:
|
||||
try:
|
||||
retry_over_time(
|
||||
self._setup_delayed_delivery,
|
||||
args=(c, broker_url),
|
||||
catch=(ConnectionRefusedError, OSError),
|
||||
errback=self._on_retry,
|
||||
interval_start=RETRY_INTERVAL,
|
||||
max_retries=MAX_RETRIES,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to setup delayed delivery for %r: %s",
|
||||
broker_url, str(e)
|
||||
)
|
||||
setup_errors.append((broker_url, e))
|
||||
|
||||
if len(setup_errors) == len(broker_urls):
|
||||
logger.critical(
|
||||
"Failed to setup delayed delivery for all broker URLs. "
|
||||
"Native delayed delivery will not be available."
|
||||
)
|
||||
|
||||
def _setup_delayed_delivery(self, c: Consumer, broker_url: str) -> None:
|
||||
"""Set up delayed delivery for a specific broker URL.
|
||||
|
||||
Args:
|
||||
c: The Celery consumer instance
|
||||
broker_url: The broker URL to configure
|
||||
|
||||
Raises:
|
||||
ConnectionRefusedError: If connection to the broker fails
|
||||
OSError: If there are network-related issues
|
||||
Exception: For other unexpected errors during setup
|
||||
"""
|
||||
connection: Connection = c.app.connection_for_write(url=broker_url)
|
||||
queue_type = c.app.conf.broker_native_delayed_delivery_queue_type
|
||||
logger.debug(
|
||||
"Setting up delayed delivery for broker %r with queue type %r",
|
||||
broker_url, queue_type
|
||||
)
|
||||
|
||||
try:
|
||||
declare_native_delayed_delivery_exchanges_and_queues(
|
||||
connection,
|
||||
queue_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to declare exchanges and queues for %r: %s",
|
||||
broker_url, str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
self._bind_queues(c.app, connection)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to bind queues for %r: %s",
|
||||
broker_url, str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
def _bind_queues(self, app: Celery, connection: Connection) -> None:
|
||||
"""Bind all application queues to delayed delivery exchanges.
|
||||
|
||||
Args:
|
||||
app: The Celery application instance
|
||||
connection: The broker connection to use
|
||||
|
||||
Raises:
|
||||
Exception: If queue binding fails
|
||||
"""
|
||||
queues: ValuesView[Queue] = app.amqp.queues.values()
|
||||
if not queues:
|
||||
logger.warning("No queues found to bind for delayed delivery")
|
||||
return
|
||||
|
||||
for queue in queues:
|
||||
try:
|
||||
logger.debug("Binding queue %r to delayed delivery exchange", queue.name)
|
||||
bind_queue_to_native_delayed_delivery_exchange(connection, queue)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to bind queue %r: %s",
|
||||
queue.name, str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
def _on_retry(self, exc: Exception, interval_range: Iterator[float], intervals_count: int) -> None:
|
||||
"""Callback for retry attempts.
|
||||
|
||||
Args:
|
||||
exc: The exception that triggered the retry
|
||||
interval_range: An iterator which returns the time in seconds to sleep next
|
||||
intervals_count: Number of retry attempts so far
|
||||
"""
|
||||
logger.warning(
|
||||
"Retrying delayed delivery setup (attempt %d/%d) after error: %s",
|
||||
intervals_count + 1, MAX_RETRIES, str(exc)
|
||||
)
|
||||
|
||||
def _validate_configuration(self, app: Celery) -> None:
|
||||
"""Validate all required configuration settings.
|
||||
|
||||
Args:
|
||||
app: The Celery application instance
|
||||
|
||||
Raises:
|
||||
ValueError: If any configuration is invalid
|
||||
"""
|
||||
# Validate broker URLs
|
||||
self._validate_broker_urls(app.conf.broker_url)
|
||||
|
||||
# Validate queue type
|
||||
self._validate_queue_type(app.conf.broker_native_delayed_delivery_queue_type)
|
||||
|
||||
def _validate_broker_urls(self, broker_urls: Union[str, List[str]]) -> Set[str]:
|
||||
"""Validate and split broker URLs.
|
||||
|
||||
Args:
|
||||
broker_urls: Broker URLs, either as a semicolon-separated string
|
||||
or as a list of strings
|
||||
|
||||
Returns:
|
||||
Set of valid broker URLs
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid broker URLs are found or if invalid URLs are provided
|
||||
"""
|
||||
if not broker_urls:
|
||||
raise ValueError("broker_url configuration is empty")
|
||||
|
||||
if isinstance(broker_urls, str):
|
||||
brokers = broker_urls.split(";")
|
||||
elif isinstance(broker_urls, list):
|
||||
if not all(isinstance(url, str) for url in broker_urls):
|
||||
raise ValueError("All broker URLs must be strings")
|
||||
brokers = broker_urls
|
||||
else:
|
||||
raise ValueError(f"broker_url must be a string or list, got {broker_urls!r}")
|
||||
|
||||
valid_urls = {url for url in brokers}
|
||||
|
||||
if not valid_urls:
|
||||
raise ValueError("No valid broker URLs found in configuration")
|
||||
|
||||
return valid_urls
|
||||
|
||||
def _validate_queue_type(self, queue_type: Optional[str]) -> None:
|
||||
"""Validate the queue type configuration.
|
||||
|
||||
Args:
|
||||
queue_type: The configured queue type
|
||||
|
||||
Raises:
|
||||
ValueError: If queue type is invalid
|
||||
"""
|
||||
if not queue_type:
|
||||
raise ValueError("broker_native_delayed_delivery_queue_type is not configured")
|
||||
|
||||
if queue_type not in VALID_QUEUE_TYPES:
|
||||
sorted_types = sorted(VALID_QUEUE_TYPES)
|
||||
raise ValueError(
|
||||
f"Invalid queue type {queue_type!r}. Must be one of: {', '.join(sorted_types)}"
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Worker Event Dispatcher Bootstep.
|
||||
|
||||
``Events`` -> :class:`celery.events.EventDispatcher`.
|
||||
"""
|
||||
from kombu.common import ignore_errors
|
||||
|
||||
from celery import bootsteps
|
||||
|
||||
from .connection import Connection
|
||||
|
||||
__all__ = ('Events',)
|
||||
|
||||
|
||||
class Events(bootsteps.StartStopStep):
|
||||
"""Service used for sending monitoring events."""
|
||||
|
||||
requires = (Connection,)
|
||||
|
||||
def __init__(self, c,
|
||||
task_events=True,
|
||||
without_heartbeat=False,
|
||||
without_gossip=False,
|
||||
**kwargs):
|
||||
self.groups = None if task_events else ['worker']
|
||||
self.send_events = (
|
||||
task_events or
|
||||
not without_gossip or
|
||||
not without_heartbeat
|
||||
)
|
||||
self.enabled = self.send_events
|
||||
c.event_dispatcher = None
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def start(self, c):
|
||||
# flush events sent while connection was down.
|
||||
prev = self._close(c)
|
||||
dis = c.event_dispatcher = c.app.events.Dispatcher(
|
||||
c.connection_for_write(),
|
||||
hostname=c.hostname,
|
||||
enabled=self.send_events,
|
||||
groups=self.groups,
|
||||
# we currently only buffer events when the event loop is enabled
|
||||
# XXX This excludes eventlet/gevent, which should actually buffer.
|
||||
buffer_group=['task'] if c.hub else None,
|
||||
on_send_buffered=c.on_send_event_buffered if c.hub else None,
|
||||
)
|
||||
if prev:
|
||||
dis.extend_buffer(prev)
|
||||
dis.flush()
|
||||
|
||||
def stop(self, c):
|
||||
pass
|
||||
|
||||
def _close(self, c):
|
||||
if c.event_dispatcher:
|
||||
dispatcher = c.event_dispatcher
|
||||
# remember changes from remote control commands:
|
||||
self.groups = dispatcher.groups
|
||||
|
||||
# close custom connection
|
||||
if dispatcher.connection:
|
||||
ignore_errors(c, dispatcher.connection.close)
|
||||
ignore_errors(c, dispatcher.close)
|
||||
c.event_dispatcher = None
|
||||
return dispatcher
|
||||
|
||||
def shutdown(self, c):
|
||||
self._close(c)
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Worker <-> Worker communication Bootstep."""
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from heapq import heappush
|
||||
from operator import itemgetter
|
||||
|
||||
from kombu import Consumer
|
||||
from kombu.asynchronous.semaphore import DummyLock
|
||||
from kombu.exceptions import ContentDisallowed, DecodeError
|
||||
|
||||
from celery import bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.objects import Bunch
|
||||
|
||||
from .mingle import Mingle
|
||||
|
||||
__all__ = ('Gossip',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
debug, info = logger.debug, logger.info
|
||||
|
||||
|
||||
class Gossip(bootsteps.ConsumerStep):
|
||||
"""Bootstep consuming events from other workers.
|
||||
|
||||
This keeps the logical clock value up to date.
|
||||
"""
|
||||
|
||||
label = 'Gossip'
|
||||
requires = (Mingle,)
|
||||
_cons_stamp_fields = itemgetter(
|
||||
'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
|
||||
)
|
||||
compatible_transports = {'amqp', 'redis'}
|
||||
|
||||
def __init__(self, c, without_gossip=False,
|
||||
interval=5.0, heartbeat_interval=2.0, **kwargs):
|
||||
self.enabled = not without_gossip and self.compatible_transport(c.app)
|
||||
self.app = c.app
|
||||
c.gossip = self
|
||||
self.Receiver = c.app.events.Receiver
|
||||
self.hostname = c.hostname
|
||||
self.full_hostname = '.'.join([self.hostname, str(c.pid)])
|
||||
self.on = Bunch(
|
||||
node_join=set(),
|
||||
node_leave=set(),
|
||||
node_lost=set(),
|
||||
)
|
||||
|
||||
self.timer = c.timer
|
||||
if self.enabled:
|
||||
self.state = c.app.events.State(
|
||||
on_node_join=self.on_node_join,
|
||||
on_node_leave=self.on_node_leave,
|
||||
max_tasks_in_memory=1,
|
||||
)
|
||||
if c.hub:
|
||||
c._mutex = DummyLock()
|
||||
self.update_state = self.state.event
|
||||
self.interval = interval
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
self._tref = None
|
||||
self.consensus_requests = defaultdict(list)
|
||||
self.consensus_replies = {}
|
||||
self.event_handlers = {
|
||||
'worker.elect': self.on_elect,
|
||||
'worker.elect.ack': self.on_elect_ack,
|
||||
}
|
||||
self.clock = c.app.clock
|
||||
|
||||
self.election_handlers = {
|
||||
'task': self.call_task
|
||||
}
|
||||
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def compatible_transport(self, app):
|
||||
with app.connection_for_read() as conn:
|
||||
return conn.transport.driver_type in self.compatible_transports
|
||||
|
||||
def election(self, id, topic, action=None):
|
||||
self.consensus_replies[id] = []
|
||||
self.dispatcher.send(
|
||||
'worker-elect',
|
||||
id=id, topic=topic, action=action, cver=1,
|
||||
)
|
||||
|
||||
def call_task(self, task):
|
||||
try:
|
||||
self.app.signature(task).apply_async()
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception('Could not call task: %r', exc)
|
||||
|
||||
def on_elect(self, event):
|
||||
try:
|
||||
(id_, clock, hostname, pid,
|
||||
topic, action, _) = self._cons_stamp_fields(event)
|
||||
except KeyError as exc:
|
||||
return logger.exception('election request missing field %s', exc)
|
||||
heappush(
|
||||
self.consensus_requests[id_],
|
||||
(clock, f'{hostname}.{pid}', topic, action),
|
||||
)
|
||||
self.dispatcher.send('worker-elect-ack', id=id_)
|
||||
|
||||
def start(self, c):
|
||||
super().start(c)
|
||||
self.dispatcher = c.event_dispatcher
|
||||
|
||||
def on_elect_ack(self, event):
|
||||
id = event['id']
|
||||
try:
|
||||
replies = self.consensus_replies[id]
|
||||
except KeyError:
|
||||
return # not for us
|
||||
alive_workers = set(self.state.alive_workers())
|
||||
replies.append(event['hostname'])
|
||||
|
||||
if len(replies) >= len(alive_workers):
|
||||
_, leader, topic, action = self.clock.sort_heap(
|
||||
self.consensus_requests[id],
|
||||
)
|
||||
if leader == self.full_hostname:
|
||||
info('I won the election %r', id)
|
||||
try:
|
||||
handler = self.election_handlers[topic]
|
||||
except KeyError:
|
||||
logger.exception('Unknown election topic %r', topic)
|
||||
else:
|
||||
handler(action)
|
||||
else:
|
||||
info('node %s elected for %r', leader, id)
|
||||
self.consensus_requests.pop(id, None)
|
||||
self.consensus_replies.pop(id, None)
|
||||
|
||||
def on_node_join(self, worker):
|
||||
debug('%s joined the party', worker.hostname)
|
||||
self._call_handlers(self.on.node_join, worker)
|
||||
|
||||
def on_node_leave(self, worker):
|
||||
debug('%s left', worker.hostname)
|
||||
self._call_handlers(self.on.node_leave, worker)
|
||||
|
||||
def on_node_lost(self, worker):
|
||||
info('missed heartbeat from %s', worker.hostname)
|
||||
self._call_handlers(self.on.node_lost, worker)
|
||||
|
||||
def _call_handlers(self, handlers, *args, **kwargs):
|
||||
for handler in handlers:
|
||||
try:
|
||||
handler(*args, **kwargs)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception(
|
||||
'Ignored error from handler %r: %r', handler, exc)
|
||||
|
||||
def register_timer(self):
|
||||
if self._tref is not None:
|
||||
self._tref.cancel()
|
||||
self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
|
||||
|
||||
def periodic(self):
|
||||
workers = self.state.workers
|
||||
dirty = set()
|
||||
for worker in workers.values():
|
||||
if not worker.alive:
|
||||
dirty.add(worker)
|
||||
self.on_node_lost(worker)
|
||||
for worker in dirty:
|
||||
workers.pop(worker.hostname, None)
|
||||
|
||||
def get_consumers(self, channel):
|
||||
self.register_timer()
|
||||
ev = self.Receiver(channel, routing_key='worker.#',
|
||||
queue_ttl=self.heartbeat_interval)
|
||||
return [Consumer(
|
||||
channel,
|
||||
queues=[ev.queue],
|
||||
on_message=partial(self.on_message, ev.event_from_message),
|
||||
accept=ev.accept,
|
||||
no_ack=True
|
||||
)]
|
||||
|
||||
def on_message(self, prepare, message):
|
||||
_type = message.delivery_info['routing_key']
|
||||
|
||||
# For redis when `fanout_patterns=False` (See Issue #1882)
|
||||
if _type.split('.', 1)[0] == 'task':
|
||||
return
|
||||
try:
|
||||
handler = self.event_handlers[_type]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return handler(message.payload)
|
||||
|
||||
# proto2: hostname in header; proto1: in body
|
||||
hostname = (message.headers.get('hostname') or
|
||||
message.payload['hostname'])
|
||||
if hostname != self.hostname:
|
||||
try:
|
||||
_, event = prepare(message.payload)
|
||||
self.update_state(event)
|
||||
except (DecodeError, ContentDisallowed, TypeError) as exc:
|
||||
logger.error(exc)
|
||||
else:
|
||||
self.clock.forward()
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Worker Event Heartbeat Bootstep."""
|
||||
from celery import bootsteps
|
||||
from celery.worker import heartbeat
|
||||
|
||||
from .events import Events
|
||||
|
||||
__all__ = ('Heart',)
|
||||
|
||||
|
||||
class Heart(bootsteps.StartStopStep):
|
||||
"""Bootstep sending event heartbeats.
|
||||
|
||||
This service sends a ``worker-heartbeat`` message every n seconds.
|
||||
|
||||
Note:
|
||||
Not to be confused with AMQP protocol level heartbeats.
|
||||
"""
|
||||
|
||||
requires = (Events,)
|
||||
|
||||
def __init__(self, c,
|
||||
without_heartbeat=False, heartbeat_interval=None, **kwargs):
|
||||
self.enabled = not without_heartbeat
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
c.heart = None
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def start(self, c):
|
||||
c.heart = heartbeat.Heart(
|
||||
c.timer, c.event_dispatcher, self.heartbeat_interval,
|
||||
)
|
||||
c.heart.start()
|
||||
|
||||
def stop(self, c):
|
||||
c.heart = c.heart and c.heart.stop()
|
||||
shutdown = stop
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Worker <-> Worker Sync at startup (Bootstep)."""
|
||||
from celery import bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from .events import Events
|
||||
|
||||
__all__ = ('Mingle',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
debug, info, exception = logger.debug, logger.info, logger.exception
|
||||
|
||||
|
||||
class Mingle(bootsteps.StartStopStep):
|
||||
"""Bootstep syncing state with neighbor workers.
|
||||
|
||||
At startup, or upon consumer restart, this will:
|
||||
|
||||
- Sync logical clocks.
|
||||
- Sync revoked tasks.
|
||||
|
||||
"""
|
||||
|
||||
label = 'Mingle'
|
||||
requires = (Events,)
|
||||
compatible_transports = {'amqp', 'redis', 'gcpubsub'}
|
||||
|
||||
def __init__(self, c, without_mingle=False, **kwargs):
|
||||
self.enabled = not without_mingle and self.compatible_transport(c.app)
|
||||
super().__init__(
|
||||
c, without_mingle=without_mingle, **kwargs)
|
||||
|
||||
def compatible_transport(self, app):
|
||||
with app.connection_for_read() as conn:
|
||||
return conn.transport.driver_type in self.compatible_transports
|
||||
|
||||
def start(self, c):
|
||||
self.sync(c)
|
||||
|
||||
def sync(self, c):
|
||||
info('mingle: searching for neighbors')
|
||||
replies = self.send_hello(c)
|
||||
if replies:
|
||||
info('mingle: sync with %s nodes',
|
||||
len([reply for reply, value in replies.items() if value]))
|
||||
[self.on_node_reply(c, nodename, reply)
|
||||
for nodename, reply in replies.items() if reply]
|
||||
info('mingle: sync complete')
|
||||
else:
|
||||
info('mingle: all alone')
|
||||
|
||||
def send_hello(self, c):
|
||||
inspect = c.app.control.inspect(timeout=1.0, connection=c.connection)
|
||||
our_revoked = c.controller.state.revoked
|
||||
replies = inspect.hello(c.hostname, our_revoked._data) or {}
|
||||
replies.pop(c.hostname, None) # delete my own response
|
||||
return replies
|
||||
|
||||
def on_node_reply(self, c, nodename, reply):
|
||||
debug('mingle: processing reply from %s', nodename)
|
||||
try:
|
||||
self.sync_with_node(c, **reply)
|
||||
except MemoryError:
|
||||
raise
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
exception('mingle: sync with %s failed: %r', nodename, exc)
|
||||
|
||||
def sync_with_node(self, c, clock=None, revoked=None, **kwargs):
|
||||
self.on_clock_event(c, clock)
|
||||
self.on_revoked_received(c, revoked)
|
||||
|
||||
def on_clock_event(self, c, clock):
|
||||
c.app.clock.adjust(clock) if clock else c.app.clock.forward()
|
||||
|
||||
def on_revoked_received(self, c, revoked):
|
||||
if revoked:
|
||||
c.controller.state.revoked.update(revoked)
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Worker Task Consumer Bootstep."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from kombu.common import QoS, ignore_errors
|
||||
|
||||
from celery import bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.quorum_queues import detect_quorum_queues
|
||||
|
||||
from .mingle import Mingle
|
||||
|
||||
__all__ = ('Tasks',)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
debug = logger.debug
|
||||
|
||||
|
||||
class Tasks(bootsteps.StartStopStep):
|
||||
"""Bootstep starting the task message consumer."""
|
||||
|
||||
requires = (Mingle,)
|
||||
|
||||
def __init__(self, c, **kwargs):
|
||||
c.task_consumer = c.qos = None
|
||||
super().__init__(c, **kwargs)
|
||||
|
||||
def start(self, c):
|
||||
"""Start task consumer."""
|
||||
c.update_strategies()
|
||||
|
||||
qos_global = self.qos_global(c)
|
||||
|
||||
# set initial prefetch count
|
||||
c.connection.default_channel.basic_qos(
|
||||
0, c.initial_prefetch_count, qos_global,
|
||||
)
|
||||
|
||||
c.task_consumer = c.app.amqp.TaskConsumer(
|
||||
c.connection, on_decode_error=c.on_decode_error,
|
||||
)
|
||||
|
||||
def set_prefetch_count(prefetch_count):
|
||||
return c.task_consumer.qos(
|
||||
prefetch_count=prefetch_count,
|
||||
apply_global=qos_global,
|
||||
)
|
||||
c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
|
||||
|
||||
def stop(self, c):
|
||||
"""Stop task consumer."""
|
||||
if c.task_consumer:
|
||||
debug('Canceling task consumer...')
|
||||
ignore_errors(c, c.task_consumer.cancel)
|
||||
|
||||
def shutdown(self, c):
|
||||
"""Shutdown task consumer."""
|
||||
if c.task_consumer:
|
||||
self.stop(c)
|
||||
debug('Closing consumer channel...')
|
||||
ignore_errors(c, c.task_consumer.close)
|
||||
c.task_consumer = None
|
||||
|
||||
def info(self, c):
|
||||
"""Return task consumer info."""
|
||||
return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
|
||||
|
||||
def qos_global(self, c) -> bool:
|
||||
"""Determine if global QoS should be applied.
|
||||
|
||||
Additional information:
|
||||
https://www.rabbitmq.com/docs/consumer-prefetch
|
||||
https://www.rabbitmq.com/docs/quorum-queues#global-qos
|
||||
"""
|
||||
# - RabbitMQ 3.3 completely redefines how basic_qos works...
|
||||
# This will detect if the new qos semantics is in effect,
|
||||
# and if so make sure the 'apply_global' flag is set on qos updates.
|
||||
qos_global = not c.connection.qos_semantics_matches_spec
|
||||
|
||||
if c.app.conf.worker_detect_quorum_queues:
|
||||
using_quorum_queues, qname = detect_quorum_queues(c.app, c.connection.transport.driver_type)
|
||||
|
||||
if using_quorum_queues:
|
||||
qos_global = False
|
||||
logger.info("Global QoS is disabled. Prefetch count in now static.")
|
||||
|
||||
return qos_global
|
||||
Reference in New Issue
Block a user