This commit is contained in:
Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,115 @@
"""Messaging library for Python."""
from __future__ import annotations
import os
import re
import sys
from collections import namedtuple
from typing import Any, cast
__version__ = '5.5.4'
__author__ = 'Ask Solem'
__contact__ = 'auvipy@gmail.com'
__homepage__ = 'https://kombu.readthedocs.io'
__docformat__ = 'restructuredtext en'
# -eof meta-
version_info_t = namedtuple('version_info_t', (
'major', 'minor', 'micro', 'releaselevel', 'serial',
))
# bumpversion can only search for {current_version}
# so we have to parse the version here.
_temp = cast(re.Match, re.match(
r'(\d+)\.(\d+).(\d+)(.+)?', __version__)).groups()
VERSION = version_info = version_info_t(
int(_temp[0]), int(_temp[1]), int(_temp[2]), _temp[3] or '', '')
del _temp
del re
STATICA_HACK = True
globals()['kcah_acitats'[::-1].upper()] = False
if STATICA_HACK: # pragma: no cover
# This is never executed, but tricks static analyzers (PyDev, PyCharm,
# pylint, etc.) into knowing the types of these symbols, and what
# they contain.
from kombu.common import eventloop, uuid # noqa
from kombu.connection import BrokerConnection, Connection # noqa
from kombu.entity import Exchange, Queue, binding # noqa
from kombu.message import Message # noqa
from kombu.messaging import Consumer, Producer # noqa
from kombu.pools import connections, producers # noqa
from kombu.serialization import disable_insecure_serializers # noqa
from kombu.serialization import enable_insecure_serializers # noqa
from kombu.utils.url import parse_url # noqa
# Lazy loading.
# - See werkzeug/__init__.py for the rationale behind this.
from types import ModuleType # noqa
all_by_module = {
'kombu.connection': ['Connection', 'BrokerConnection'],
'kombu.entity': ['Exchange', 'Queue', 'binding'],
'kombu.message': ['Message'],
'kombu.messaging': ['Consumer', 'Producer'],
'kombu.pools': ['connections', 'producers'],
'kombu.utils.url': ['parse_url'],
'kombu.common': ['eventloop', 'uuid'],
'kombu.serialization': [
'enable_insecure_serializers',
'disable_insecure_serializers',
],
}
object_origins = {}
for _module, items in all_by_module.items():
for item in items:
object_origins[item] = _module
class module(ModuleType):
"""Customized Python module."""
def __getattr__(self, name: str) -> Any:
if name in object_origins:
module = __import__(object_origins[name], None, None, [name])
for extra_name in all_by_module[module.__name__]:
setattr(self, extra_name, getattr(module, extra_name))
return getattr(module, name)
return ModuleType.__getattribute__(self, name)
def __dir__(self) -> list[str]:
result = list(new_module.__all__)
result.extend(('__file__', '__path__', '__doc__', '__all__',
'__docformat__', '__name__', '__path__', 'VERSION',
'__package__', '__version__', '__author__',
'__contact__', '__homepage__', '__docformat__'))
return result
# keep a reference to this module so that it's not garbage collected
old_module = sys.modules[__name__]
new_module = sys.modules[__name__] = module(__name__)
new_module.__dict__.update({
'__file__': __file__,
'__path__': __path__,
'__doc__': __doc__,
'__all__': tuple(object_origins),
'__version__': __version__,
'__author__': __author__,
'__contact__': __contact__,
'__homepage__': __homepage__,
'__docformat__': __docformat__,
'__package__': __package__,
'version_info_t': version_info_t,
'version_info': version_info,
'VERSION': VERSION
})
if os.environ.get('KOMBU_LOG_DEBUG'): # pragma: no cover
os.environ.update(KOMBU_LOG_CHANNEL='1', KOMBU_LOG_CONNECTION='1')
from .utils import debug
debug.setup_logging()

View File

@@ -0,0 +1,143 @@
"""Object utilities."""
from __future__ import annotations
from copy import copy
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from .connection import maybe_channel
from .exceptions import NotBoundError
from .utils.functional import ChannelPromise
if TYPE_CHECKING:
from kombu.connection import Connection
from kombu.transport.virtual import Channel
__all__ = ('Object', 'MaybeChannelBound')
_T = TypeVar("_T")
_ObjectType = TypeVar("_ObjectType", bound="Object")
_MaybeChannelBoundType = TypeVar(
"_MaybeChannelBoundType", bound="MaybeChannelBound"
)
def unpickle_dict(
cls: type[_ObjectType], kwargs: dict[str, Any]
) -> _ObjectType:
return cls(**kwargs)
def _any(v: _T) -> _T:
return v
class Object:
"""Common base class.
Supports automatic kwargs->attributes handling, and cloning.
"""
attrs: tuple[tuple[str, Any], ...] = ()
def __init__(self, *args: Any, **kwargs: Any) -> None:
for name, type_ in self.attrs:
value = kwargs.get(name)
if value is not None:
setattr(self, name, (type_ or _any)(value))
else:
try:
getattr(self, name)
except AttributeError:
setattr(self, name, None)
def as_dict(self, recurse: bool = False) -> dict[str, Any]:
def f(obj: Any, type: Callable[[Any], Any] | None = None) -> Any:
if recurse and isinstance(obj, Object):
return obj.as_dict(recurse=True)
return type(obj) if type and obj is not None else obj
return {
attr: f(getattr(self, attr), type) for attr, type in self.attrs
}
def __reduce__(self: _ObjectType) -> tuple[
Callable[[type[_ObjectType], dict[str, Any]], _ObjectType],
tuple[type[_ObjectType], dict[str, Any]]
]:
return unpickle_dict, (self.__class__, self.as_dict())
def __copy__(self: _ObjectType) -> _ObjectType:
return self.__class__(**self.as_dict())
class MaybeChannelBound(Object):
"""Mixin for classes that can be bound to an AMQP channel."""
_channel: Channel | None = None
_is_bound = False
#: Defines whether maybe_declare can skip declaring this entity twice.
can_cache_declaration = False
def __call__(
self: _MaybeChannelBoundType, channel: (Channel | Connection)
) -> _MaybeChannelBoundType:
"""`self(channel) -> self.bind(channel)`."""
return self.bind(channel)
def bind(
self: _MaybeChannelBoundType, channel: (Channel | Connection)
) -> _MaybeChannelBoundType:
"""Create copy of the instance that is bound to a channel."""
return copy(self).maybe_bind(channel)
def maybe_bind(
self: _MaybeChannelBoundType, channel: (Channel | Connection)
) -> _MaybeChannelBoundType:
"""Bind instance to channel if not already bound."""
if not self.is_bound and channel:
self._channel = maybe_channel(channel)
self.when_bound()
self._is_bound = True
return self
def revive(self, channel: Channel) -> None:
"""Revive channel after the connection has been re-established.
Used by :meth:`~kombu.Connection.ensure`.
"""
if self.is_bound:
self._channel = channel
self.when_bound()
def when_bound(self) -> None:
"""Callback called when the class is bound."""
def __repr__(self) -> str:
return self._repr_entity(type(self).__name__)
def _repr_entity(self, item: str = '') -> str:
item = item or type(self).__name__
if self.is_bound:
return '<{} bound to chan:{}>'.format(
item or type(self).__name__, self.channel.channel_id)
return f'<unbound {item}>'
@property
def is_bound(self) -> bool:
"""Flag set if the channel is bound."""
return self._is_bound and self._channel is not None
@property
def channel(self) -> Channel:
"""Current channel if the object is bound."""
channel = self._channel
if channel is None:
raise NotBoundError(
"Can't call method on {} not bound to a channel".format(
type(self).__name__))
if isinstance(channel, ChannelPromise):
channel = self._channel = channel()
return channel

View File

@@ -0,0 +1,9 @@
"""Event loop."""
from __future__ import annotations
from kombu.utils.eventio import ERR, READ, WRITE
from .hub import Hub, get_event_loop, set_event_loop
__all__ = ('READ', 'WRITE', 'ERR', 'Hub', 'get_event_loop', 'set_event_loop')

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from typing import Any
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
def connect_sqs(
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
**kwargs: Any
) -> AsyncSQSConnection:
"""Return async connection to Amazon SQS."""
from .sqs.connection import AsyncSQSConnection
return AsyncSQSConnection(
aws_access_key_id, aws_secret_access_key, **kwargs
)

View File

@@ -0,0 +1,278 @@
"""Amazon AWS Connection."""
from __future__ import annotations
from email import message_from_bytes
from email.mime.message import MIMEMessage
from vine import promise, transform
from kombu.asynchronous.aws.ext import AWSRequest, get_cert_path, get_response
from kombu.asynchronous.http import Headers, Request, get_client
def message_from_headers(hdr):
bs = "\r\n".join("{}: {}".format(*h) for h in hdr)
return message_from_bytes(bs.encode())
__all__ = (
'AsyncHTTPSConnection', 'AsyncConnection',
)
class AsyncHTTPResponse:
"""Async HTTP Response."""
def __init__(self, response):
self.response = response
self._msg = None
self.version = 10
def read(self, *args, **kwargs):
return self.response.body
def getheader(self, name, default=None):
return self.response.headers.get(name, default)
def getheaders(self):
return list(self.response.headers.items())
@property
def msg(self):
if self._msg is None:
self._msg = MIMEMessage(message_from_headers(self.getheaders()))
return self._msg
@property
def status(self):
return self.response.code
@property
def reason(self):
if self.response.error:
return self.response.error.message
return ''
def __repr__(self):
return repr(self.response)
class AsyncHTTPSConnection:
"""Async HTTP Connection."""
Request = Request
Response = AsyncHTTPResponse
method = 'GET'
path = '/'
body = None
default_ports = {'http': 80, 'https': 443}
def __init__(self, strict=None, timeout=20.0, http_client=None):
self.headers = []
self.timeout = timeout
self.strict = strict
self.http_client = http_client or get_client()
def request(self, method, path, body=None, headers=None):
self.path = path
self.method = method
if body is not None:
try:
read = body.read
except AttributeError:
self.body = body
else:
self.body = read()
if headers is not None:
self.headers.extend(list(headers.items()))
def getrequest(self):
headers = Headers(self.headers)
return self.Request(self.path, method=self.method, headers=headers,
body=self.body, connect_timeout=self.timeout,
request_timeout=self.timeout,
validate_cert=True, ca_certs=get_cert_path(True))
def getresponse(self, callback=None):
request = self.getrequest()
request.then(transform(self.Response, callback))
return self.http_client.add_request(request)
def set_debuglevel(self, level):
pass
def connect(self):
pass
def close(self):
pass
def putrequest(self, method, path):
self.method = method
self.path = path
def putheader(self, header, value):
self.headers.append((header, value))
def endheaders(self):
pass
def send(self, data):
if self.body:
self.body += data
else:
self.body = data
def __repr__(self):
return f'<AsyncHTTPConnection: {self.getrequest()!r}>'
class AsyncConnection:
"""Async AWS Connection."""
def __init__(self, sqs_connection, http_client=None, **kwargs):
self.sqs_connection = sqs_connection
self._httpclient = http_client or get_client()
def get_http_connection(self):
return AsyncHTTPSConnection(http_client=self._httpclient)
def _mexe(self, request, sender=None, callback=None):
callback = callback or promise()
conn = self.get_http_connection()
if callable(sender):
sender(conn, request.method, request.path, request.body,
request.headers, callback)
else:
conn.request(request.method, request.url,
request.body, request.headers)
conn.getresponse(callback=callback)
return callback
class AsyncAWSQueryConnection(AsyncConnection):
"""Async AWS Query Connection."""
STATUS_CODE_OK = 200
STATUS_CODE_REQUEST_TIMEOUT = 408
STATUS_CODE_NETWORK_CONNECT_TIMEOUT_ERROR = 599
STATUS_CODE_INTERNAL_ERROR = 500
STATUS_CODE_BAD_GATEWAY = 502
STATUS_CODE_SERVICE_UNAVAILABLE_ERROR = 503
STATUS_CODE_GATEWAY_TIMEOUT = 504
STATUS_CODES_SERVER_ERRORS = (
STATUS_CODE_INTERNAL_ERROR,
STATUS_CODE_BAD_GATEWAY,
STATUS_CODE_SERVICE_UNAVAILABLE_ERROR
)
STATUS_CODES_TIMEOUT = (
STATUS_CODE_REQUEST_TIMEOUT,
STATUS_CODE_NETWORK_CONNECT_TIMEOUT_ERROR,
STATUS_CODE_GATEWAY_TIMEOUT
)
def __init__(self, sqs_connection, http_client=None,
http_client_params=None, **kwargs):
if not http_client_params:
http_client_params = {}
super().__init__(sqs_connection, http_client,
**http_client_params)
def make_request(self, operation, params_, path, verb, callback=None, protocol_params=None):
params = params_.copy()
params.update((protocol_params or {}).get('query', {}))
if operation:
params['Action'] = operation
signer = self.sqs_connection._request_signer
# defaults for non-get
signing_type = 'standard'
param_payload = {'data': params}
if verb.lower() == 'get':
# query-based opts
signing_type = 'presign-url'
param_payload = {'params': params}
request = AWSRequest(method=verb, url=path, **param_payload)
signer.sign(operation, request, signing_type=signing_type)
prepared_request = request.prepare()
return self._mexe(prepared_request, callback=callback)
def get_list(self, operation, params, markers, path='/', parent=None, verb='POST', callback=None,
protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_list_ready, callback, parent or self, markers,
operation
),
protocol_params=protocol_params,
)
def get_object(self, operation, params, path='/', parent=None, verb='GET', callback=None, protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_obj_ready, callback, parent or self, operation
),
protocol_params=protocol_params,
)
def get_status(self, operation, params, path='/', parent=None, verb='GET', callback=None, protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_status_ready, callback, parent or self, operation
),
protocol_params=protocol_params,
)
def _on_list_ready(self, parent, markers, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
_, parsed = get_response(
service_model.operation_model(operation), response.response
)
return parsed
elif (
response.status in self.STATUS_CODES_TIMEOUT or
response.status in self.STATUS_CODES_SERVER_ERRORS
):
# When the server returns a timeout or 50X server error,
# the response is interpreted as an empty list.
# This prevents hanging the Celery worker.
return []
else:
raise self._for_status(response, response.read())
def _on_obj_ready(self, parent, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
_, parsed = get_response(
service_model.operation_model(operation), response.response
)
return parsed
else:
raise self._for_status(response, response.read())
def _on_status_ready(self, parent, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
httpres, _ = get_response(
service_model.operation_model(operation), response.response
)
return httpres.code
else:
raise self._for_status(response, response.read())
def _for_status(self, response, body):
context = 'Empty body' if not body else 'HTTP Error'
return Exception("Request {} HTTP {} {} ({})".format(
context, response.status, response.reason, body
))

View File

@@ -0,0 +1,28 @@
"""Amazon boto3 interface."""
from __future__ import annotations
try:
import boto3
from botocore import exceptions
from botocore.awsrequest import AWSRequest
from botocore.httpsession import get_cert_path
from botocore.response import get_response
except ImportError:
boto3 = None
class _void:
pass
class BotoCoreError(Exception):
pass
exceptions = _void()
exceptions.BotoCoreError = BotoCoreError
AWSRequest = _void()
get_response = _void()
get_cert_path = _void()
__all__ = (
'exceptions', 'AWSRequest', 'get_response', 'get_cert_path',
)

View File

@@ -0,0 +1,320 @@
"""Amazon SQS Connection."""
from __future__ import annotations
import json
from botocore.serialize import Serializer
from vine import transform
from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection
from kombu.asynchronous.aws.ext import AWSRequest
from .ext import boto3
from .message import AsyncMessage
from .queue import AsyncQueue
__all__ = ('AsyncSQSConnection',)
class AsyncSQSConnection(AsyncAWSQueryConnection):
"""Async SQS Connection."""
def __init__(self, sqs_connection, debug=0, region=None, fetch_message_attributes=None, **kwargs):
if boto3 is None:
raise ImportError('boto3 is not installed')
super().__init__(
sqs_connection,
region_name=region, debug=debug,
**kwargs
)
self.fetch_message_attributes = (
fetch_message_attributes if fetch_message_attributes is not None
else ["ApproximateReceiveCount"]
)
def _create_query_request(self, operation, params, queue_url, method):
params = params.copy()
if operation:
params['Action'] = operation
# defaults for non-get
param_payload = {'data': params}
headers = {}
if method.lower() == 'get':
# query-based opts
param_payload = {'params': params}
if method.lower() == 'post':
headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
return AWSRequest(method=method, url=queue_url, headers=headers, **param_payload)
def _create_json_request(self, operation, params, queue_url):
params = params.copy()
params['QueueUrl'] = queue_url
service_model = self.sqs_connection.meta.service_model
operation_model = service_model.operation_model(operation)
url = self.sqs_connection._endpoint.host
headers = {}
# Content-Type
json_version = operation_model.metadata['jsonVersion']
content_type = f'application/x-amz-json-{json_version}'
headers['Content-Type'] = content_type
# X-Amz-Target
target = '{}.{}'.format(
operation_model.metadata['targetPrefix'],
operation_model.name,
)
headers['X-Amz-Target'] = target
param_payload = {
'data': json.dumps(params).encode(),
'headers': headers
}
method = operation_model.http.get('method', Serializer.DEFAULT_METHOD)
return AWSRequest(
method=method,
url=url,
**param_payload
)
def make_request(self, operation_name, params, queue_url, verb, callback=None, protocol_params=None):
"""Override make_request to support different protocols.
botocore has changed the default protocol of communicating
with SQS backend from 'query' to 'json', so we need a special
implementation of make_request for SQS. More information on this can
be found in: https://github.com/celery/kombu/pull/1807.
protocol_params: Optional[dict[str, dict]] of per-protocol additional parameters.
Supported for the SQS query to json protocol transition.
"""
signer = self.sqs_connection._request_signer
service_model = self.sqs_connection.meta.service_model
protocol = service_model.protocol
all_params = {**(params or {}), **protocol_params.get(protocol, {})}
if protocol == 'query':
request = self._create_query_request(
operation_name, all_params, queue_url, verb)
elif protocol == 'json':
request = self._create_json_request(
operation_name, all_params, queue_url)
else:
raise Exception(f'Unsupported protocol: {protocol}.')
signing_type = 'presign-url' if request.method.lower() == 'get' \
else 'standard'
signer.sign(operation_name, request, signing_type=signing_type)
prepared_request = request.prepare()
return self._mexe(prepared_request, callback=callback)
def create_queue(self, queue_name,
visibility_timeout=None, callback=None):
params = {'QueueName': queue_name}
if visibility_timeout:
params['DefaultVisibilityTimeout'] = format(
visibility_timeout, 'd',
)
return self.get_object('CreateQueue', params,
callback=callback)
def delete_queue(self, queue, force_deletion=False, callback=None):
return self.get_status('DeleteQueue', None, queue.id,
callback=callback)
def get_queue_url(self, queue):
res = self.sqs_connection.get_queue_url(QueueName=queue)
return res['QueueUrl']
def get_queue_attributes(self, queue, attribute='All', callback=None):
return self.get_object(
'GetQueueAttributes', {'AttributeName': attribute},
queue.id, callback=callback,
)
def set_queue_attribute(self, queue, attribute, value, callback=None):
return self.get_status(
'SetQueueAttribute',
{},
queue.id, callback=callback,
protocol_params={
'json': {'Attributes': {attribute: value}},
'query': {'Attribute.Name': attribute, 'Attribute.Value': value},
},
)
def receive_message(
self, queue, queue_url, number_messages=1, visibility_timeout=None,
attributes=None, wait_time_seconds=None,
callback=None
):
params = {'MaxNumberOfMessages': number_messages}
proto_params = {'query': {}, 'json': {}}
attrs = attributes if attributes is not None else self.fetch_message_attributes
if visibility_timeout:
params['VisibilityTimeout'] = visibility_timeout
if attrs:
proto_params['json'].update({'AttributeNames': list(attrs)})
proto_params['query'].update(_query_object_encode({'AttributeName': list(attrs)}))
if wait_time_seconds is not None:
params['WaitTimeSeconds'] = wait_time_seconds
return self.get_list(
'ReceiveMessage', params, [('Message', AsyncMessage)],
queue_url, callback=callback, parent=queue,
protocol_params=proto_params,
)
def delete_message(self, queue, receipt_handle, callback=None):
return self.delete_message_from_handle(
queue, receipt_handle, callback,
)
def delete_message_batch(self, queue, messages, callback=None):
p_params = {
'json': {
'Entries': [{'Id': m.id, 'ReceiptHandle': m.receipt_handle} for m in messages],
},
'query': _query_object_encode({
'DeleteMessageBatchRequestEntry': [
{'Id': m.id, 'ReceiptHandle': m.receipt_handle}
for m in messages
],
}),
}
return self.get_object(
'DeleteMessageBatch', {}, queue.id,
verb='POST', callback=callback, protocol_params=p_params,
)
def delete_message_from_handle(self, queue, receipt_handle,
callback=None):
return self.get_status(
'DeleteMessage', {'ReceiptHandle': receipt_handle},
queue, callback=callback,
)
def send_message(self, queue, message_content,
delay_seconds=None, callback=None):
params = {'MessageBody': message_content}
if delay_seconds:
params['DelaySeconds'] = int(delay_seconds)
return self.get_object(
'SendMessage', params, queue.id,
verb='POST', callback=callback,
)
def send_message_batch(self, queue, messages, callback=None):
params = {}
for i, msg in enumerate(messages):
prefix = f'SendMessageBatchRequestEntry.{i + 1}'
params.update({
f'{prefix}.Id': msg[0],
f'{prefix}.MessageBody': msg[1],
f'{prefix}.DelaySeconds': msg[2],
})
return self.get_object(
'SendMessageBatch', params, queue.id,
verb='POST', callback=callback,
)
def change_message_visibility(self, queue, receipt_handle,
visibility_timeout, callback=None):
return self.get_status(
'ChangeMessageVisibility',
{'ReceiptHandle': receipt_handle,
'VisibilityTimeout': visibility_timeout},
queue.id, callback=callback,
)
def change_message_visibility_batch(self, queue, messages, callback=None):
entries = [
{'Id': t[0].id, 'ReceiptHandle': t[0].receipt_handle, 'VisibilityTimeout': t[1]}
for t in messages
]
p_params = {
'json': {'Entries': entries},
'query': _query_object_encode({'ChangeMessageVisibilityBatchRequestEntry': entries}),
}
return self.get_object(
'ChangeMessageVisibilityBatch', {}, queue.id,
verb='POST', callback=callback,
protocol_params=p_params,
)
def get_all_queues(self, prefix='', callback=None):
params = {}
if prefix:
params['QueueNamePrefix'] = prefix
return self.get_list(
'ListQueues', params, [('QueueUrl', AsyncQueue)],
callback=callback,
)
def get_queue(self, queue_name, callback=None):
# TODO Does not support owner_acct_id argument
return self.get_all_queues(
queue_name,
transform(self._on_queue_ready, callback, queue_name),
)
lookup = get_queue
def _on_queue_ready(self, name, queues):
return next(
(q for q in queues if q.url.endswith(name)), None,
)
def get_dead_letter_source_queues(self, queue, callback=None):
return self.get_list(
'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
[('QueueUrl', AsyncQueue)],
callback=callback,
)
def add_permission(self, queue, label, aws_account_id, action_name,
callback=None):
return self.get_status(
'AddPermission',
{'Label': label,
'AWSAccountId': aws_account_id,
'ActionName': action_name},
queue.id, callback=callback,
)
def remove_permission(self, queue, label, callback=None):
return self.get_status(
'RemovePermission', {'Label': label}, queue.id, callback=callback,
)
def _query_object_encode(items):
params = {}
_query_object_encode_part(params, '', items)
return {k: v for k, v in params.items()}
def _query_object_encode_part(params, prefix, part):
dotted = f'{prefix}.' if prefix else prefix
if isinstance(part, (list, tuple)):
for i, item in enumerate(part):
_query_object_encode_part(params, f'{dotted}{i + 1}', item)
elif isinstance(part, dict):
for key, value in part.items():
_query_object_encode_part(params, f'{dotted}{key}', value)
else:
params[prefix] = str(part)

View File

@@ -0,0 +1,9 @@
"""Amazon SQS boto3 interface."""
from __future__ import annotations
try:
import boto3
except ImportError:
boto3 = None

View File

@@ -0,0 +1,35 @@
"""Amazon SQS message implementation."""
from __future__ import annotations
import base64
from kombu.message import Message
from kombu.utils.encoding import str_to_bytes
class BaseAsyncMessage(Message):
"""Base class for messages received on async client."""
class AsyncRawMessage(BaseAsyncMessage):
"""Raw Message."""
class AsyncMessage(BaseAsyncMessage):
"""Serialized message."""
def encode(self, value):
"""Encode/decode the value using Base64 encoding."""
return base64.b64encode(str_to_bytes(value)).decode()
def __getitem__(self, item):
"""Support Boto3-style access on a message."""
if item == 'ReceiptHandle':
return self.receipt_handle
elif item == 'Body':
return self.get_body()
elif item == 'queue':
return self.queue
else:
raise KeyError(item)

View File

@@ -0,0 +1,130 @@
"""Amazon SQS queue implementation."""
from __future__ import annotations
from vine import transform
from .message import AsyncMessage
_all__ = ['AsyncQueue']
def list_first(rs):
"""Get the first item in a list, or None if list empty."""
return rs[0] if len(rs) == 1 else None
class AsyncQueue:
"""Async SQS Queue."""
def __init__(self, connection=None, url=None, message_class=AsyncMessage):
self.connection = connection
self.url = url
self.message_class = message_class
self.visibility_timeout = None
def _NA(self, *args, **kwargs):
raise NotImplementedError()
count_slow = dump = save_to_file = save_to_filename = save = \
save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \
load = clear = _NA
def get_attributes(self, attributes='All', callback=None):
return self.connection.get_queue_attributes(
self, attributes, callback,
)
def set_attribute(self, attribute, value, callback=None):
return self.connection.set_queue_attribute(
self, attribute, value, callback,
)
def get_timeout(self, callback=None, _attr='VisibilityTimeout'):
return self.get_attributes(
_attr, transform(
self._coerce_field_value, callback, _attr, int,
),
)
def _coerce_field_value(self, key, type, response):
return type(response[key])
def set_timeout(self, visibility_timeout, callback=None):
return self.set_attribute(
'VisibilityTimeout', visibility_timeout,
transform(
self._on_timeout_set, callback,
)
)
def _on_timeout_set(self, visibility_timeout):
if visibility_timeout:
self.visibility_timeout = visibility_timeout
return self.visibility_timeout
def add_permission(self, label, aws_account_id, action_name,
callback=None):
return self.connection.add_permission(
self, label, aws_account_id, action_name, callback,
)
def remove_permission(self, label, callback=None):
return self.connection.remove_permission(self, label, callback)
def read(self, visibility_timeout=None, wait_time_seconds=None,
callback=None):
return self.get_messages(
1, visibility_timeout,
wait_time_seconds=wait_time_seconds,
callback=transform(list_first, callback),
)
def write(self, message, delay_seconds=None, callback=None):
return self.connection.send_message(
self, message.get_body_encoded(), delay_seconds,
callback=transform(self._on_message_sent, callback, message),
)
def write_batch(self, messages, callback=None):
return self.connection.send_message_batch(
self, messages, callback=callback,
)
def _on_message_sent(self, orig_message, new_message):
orig_message.id = new_message.id
orig_message.md5 = new_message.md5
return new_message
def get_messages(self, num_messages=1, visibility_timeout=None,
attributes=None, wait_time_seconds=None, callback=None):
return self.connection.receive_message(
self, number_messages=num_messages,
visibility_timeout=visibility_timeout,
attributes=attributes,
wait_time_seconds=wait_time_seconds,
callback=callback,
)
def delete_message(self, message, callback=None):
return self.connection.delete_message(self, message, callback)
def delete_message_batch(self, messages, callback=None):
return self.connection.delete_message_batch(
self, messages, callback=callback,
)
def change_message_visibility_batch(self, messages, callback=None):
return self.connection.change_message_visibility_batch(
self, messages, callback=callback,
)
def delete(self, callback=None):
return self.connection.delete_queue(self, callback=callback)
def count(self, page_size=10, vtimeout=10, callback=None,
_attr='ApproximateNumberOfMessages'):
return self.get_attributes(
_attr, callback=transform(
self._coerce_field_value, callback, _attr, int,
),
)

View File

@@ -0,0 +1,67 @@
"""Event-loop debugging tools."""
from __future__ import annotations
from kombu.utils.eventio import ERR, READ, WRITE
from kombu.utils.functional import reprcall
def repr_flag(flag):
"""Return description of event loop flag."""
return '{}{}{}'.format('R' if flag & READ else '',
'W' if flag & WRITE else '',
'!' if flag & ERR else '')
def _rcb(obj):
if obj is None:
return '<missing>'
if isinstance(obj, str):
return obj
if isinstance(obj, tuple):
cb, args = obj
return reprcall(cb.__name__, args=args)
return obj.__name__
def repr_active(h):
"""Return description of active readers and writers."""
return ', '.join(repr_readers(h) + repr_writers(h))
def repr_events(h, events):
"""Return description of events returned by poll."""
return ', '.join(
'{}({})->{}'.format(
_rcb(callback_for(h, fd, fl, '(GONE)')), fd,
repr_flag(fl),
)
for fd, fl in events
)
def repr_readers(h):
"""Return description of pending readers."""
return [f'({fd}){_rcb(cb)}->{repr_flag(READ | ERR)}'
for fd, cb in h.readers.items()]
def repr_writers(h):
"""Return description of pending writers."""
return [f'({fd}){_rcb(cb)}->{repr_flag(WRITE)}'
for fd, cb in h.writers.items()]
def callback_for(h, fd, flag, *default):
"""Return the callback used for hub+fd+flag."""
try:
if flag & READ:
return h.readers[fd]
if flag & WRITE:
if fd in h.consolidate:
return h.consolidate_callback
return h.writers[fd]
except KeyError:
if default:
return default[0]
raise

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from kombu.asynchronous import get_event_loop
from kombu.asynchronous.http.base import BaseClient, Headers, Request, Response
from kombu.asynchronous.hub import Hub
__all__ = ('Client', 'Headers', 'Response', 'Request', 'get_client')
def Client(hub: Hub | None = None, **kwargs: int) -> BaseClient:
"""Create new HTTP client."""
from .urllib3_client import Urllib3Client
return Urllib3Client(hub, **kwargs)
def get_client(hub: Hub | None = None, **kwargs: int) -> BaseClient:
"""Get or create HTTP client bound to the current event loop."""
hub = hub or get_event_loop()
try:
return hub._current_http_client
except AttributeError:
client = hub._current_http_client = Client(hub, **kwargs)
return client

View File

@@ -0,0 +1,280 @@
"""Base async HTTP client implementation."""
from __future__ import annotations
import sys
from http.client import responses
from typing import TYPE_CHECKING
from vine import Thenable, maybe_promise, promise
from kombu.exceptions import HttpError
from kombu.utils.compat import coro
from kombu.utils.encoding import bytes_to_str
from kombu.utils.functional import maybe_list, memoize
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('Headers', 'Response', 'Request', 'BaseClient')
PYPY = hasattr(sys, 'pypy_version_info')
@memoize(maxsize=1000)
def normalize_header(key):
return '-'.join(p.capitalize() for p in key.split('-'))
class Headers(dict):
"""Represents a mapping of HTTP headers."""
# TODO: This is just a regular dict and will not perform normalization
# when looking up keys etc.
#: Set when all of the headers have been read.
complete = False
#: Internal attribute used to keep track of continuation lines.
_prev_key = None
@Thenable.register
class Request:
"""A HTTP Request.
Arguments:
---------
url (str): The URL to request.
method (str): The HTTP method to use (defaults to ``GET``).
Keyword Arguments:
-----------------
headers (Dict, ~kombu.asynchronous.http.Headers): Optional headers for
this request
body (str): Optional body for this request.
connect_timeout (float): Connection timeout in float seconds
Default is 30.0.
timeout (float): Time in float seconds before the request times out
Default is 30.0.
follow_redirects (bool): Specify if the client should follow redirects
Enabled by default.
max_redirects (int): Maximum number of redirects (default 6).
use_gzip (bool): Allow the server to use gzip compression.
Enabled by default.
validate_cert (bool): Set to true if the server certificate should be
verified when performing ``https://`` requests.
Enabled by default.
auth_username (str): Username for HTTP authentication.
auth_password (str): Password for HTTP authentication.
auth_mode (str): Type of HTTP authentication (``basic`` or ``digest``).
user_agent (str): Custom user agent for this request.
network_interface (str): Network interface to use for this request.
on_ready (Callable): Callback to be called when the response has been
received. Must accept single ``response`` argument.
on_stream (Callable): Optional callback to be called every time body
content has been read from the socket. If specified then the
response body and buffer attributes will not be available.
on_timeout (callable): Optional callback to be called if the request
times out.
on_header (Callable): Optional callback to be called for every header
line received from the server. The signature
is ``(headers, line)`` and note that if you want
``response.headers`` to be populated then your callback needs to
also call ``client.on_header(headers, line)``.
on_prepare (Callable): Optional callback that is implementation
specific (e.g. curl client will pass the ``curl`` instance to
this callback).
proxy_host (str): Optional proxy host. Note that a ``proxy_port`` must
also be provided or a :exc:`ValueError` will be raised.
proxy_username (str): Optional username to use when logging in
to the proxy.
proxy_password (str): Optional password to use when authenticating
with the proxy server.
ca_certs (str): Custom CA certificates file to use.
client_key (str): Optional filename for client SSL key.
client_cert (str): Optional filename for client SSL certificate.
"""
body = user_agent = network_interface = \
auth_username = auth_password = auth_mode = \
proxy_host = proxy_port = proxy_username = proxy_password = \
ca_certs = client_key = client_cert = None
connect_timeout = 30.0
request_timeout = 30.0
follow_redirects = True
max_redirects = 6
use_gzip = True
validate_cert = True
if not PYPY: # pragma: no cover
__slots__ = ('url', 'method', 'on_ready', 'on_timeout', 'on_stream',
'on_prepare', 'on_header', 'headers',
'__weakref__', '__dict__')
def __init__(self, url, method='GET', on_ready=None, on_timeout=None,
on_stream=None, on_prepare=None, on_header=None,
headers=None, **kwargs):
self.url = url
self.method = method or self.method
self.on_ready = maybe_promise(on_ready) or promise()
self.on_timeout = maybe_promise(on_timeout)
self.on_stream = maybe_promise(on_stream)
self.on_prepare = maybe_promise(on_prepare)
self.on_header = maybe_promise(on_header)
if kwargs:
for k, v in kwargs.items():
setattr(self, k, v)
if not isinstance(headers, Headers):
headers = Headers(headers or {})
self.headers = headers
def then(self, callback, errback=None):
self.on_ready.then(callback, errback)
def __repr__(self):
return '<Request: {0.method} {0.url} {0.body}>'.format(self)
class Response:
"""HTTP Response.
Arguments
---------
request (~kombu.asynchronous.http.Request): See :attr:`request`.
code (int): See :attr:`code`.
headers (~kombu.asynchronous.http.Headers): See :attr:`headers`.
buffer (bytes): See :attr:`buffer`
effective_url (str): See :attr:`effective_url`.
status (str): See :attr:`status`.
Attributes
----------
request (~kombu.asynchronous.http.Request): object used to
get this response.
code (int): HTTP response code (e.g. 200, 404, or 500).
headers (~kombu.asynchronous.http.Headers): HTTP headers
for this response.
buffer (bytes): Socket read buffer.
effective_url (str): The destination url for this request after
following redirects.
error (Exception): Error instance if the request resulted in
a HTTP error code.
status (str): Human equivalent of :attr:`code`,
e.g. ``OK``, `Not found`, or 'Internal Server Error'.
"""
if not PYPY: # pragma: no cover
__slots__ = ('request', 'code', 'headers', 'buffer', 'effective_url',
'error', 'status', '_body', '__weakref__')
def __init__(self, request, code, headers=None, buffer=None,
effective_url=None, error=None, status=None):
self.request = request
self.code = code
self.headers = headers if headers is not None else Headers()
self.buffer = buffer
self.effective_url = effective_url or request.url
self._body = None
self.status = status or responses.get(self.code, 'Unknown')
self.error = error
if self.error is None and (self.code < 200 or self.code > 299):
self.error = HttpError(self.code, self.status, self)
def raise_for_error(self):
"""Raise if the request resulted in an HTTP error code.
Raises
------
:class:`~kombu.exceptions.HttpError`
"""
if self.error:
raise self.error
@property
def body(self):
"""The full contents of the response body.
Note:
----
Accessing this property will evaluate the buffer
and subsequent accesses will be cached.
"""
if self._body is None:
if self.buffer is not None:
self._body = self.buffer.getvalue()
return self._body
# these are for compatibility with Requests
@property
def status_code(self):
return self.code
@property
def content(self):
return self.body
@coro
def header_parser(keyt=normalize_header):
while 1:
(line, headers) = yield
if line.startswith('HTTP/'):
continue
elif not line:
headers.complete = True
continue
elif line[0].isspace():
pkey = headers._prev_key
headers[pkey] = ' '.join([headers.get(pkey) or '', line.lstrip()])
else:
key, value = line.split(':', 1)
key = headers._prev_key = keyt(key)
headers[key] = value.strip()
class BaseClient:
"""Base class for HTTP clients.
This class provides the basic structure and functionality for HTTP clients.
Subclasses should implement specific HTTP client behavior.
"""
Headers = Headers
Request = Request
Response = Response
def __init__(self, hub, **kwargs):
self.hub = hub
self._header_parser = header_parser()
def perform(self, request, **kwargs):
for req in maybe_list(request) or []:
if not isinstance(req, self.Request):
req = self.Request(req, **kwargs)
self.add_request(req)
def add_request(self, request):
raise NotImplementedError('must implement add_request')
def close(self):
pass
def on_header(self, headers, line):
try:
self._header_parser.send((bytes_to_str(line), headers))
except StopIteration:
self._header_parser = header_parser()
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.close()

View File

@@ -0,0 +1,219 @@
from __future__ import annotations
from collections import deque
from io import BytesIO
import urllib3
from kombu.asynchronous.hub import Hub, get_event_loop
from kombu.exceptions import HttpError
from .base import BaseClient, Request
__all__ = ('Urllib3Client',)
from ...utils.encoding import bytes_to_str
DEFAULT_USER_AGENT = 'Mozilla/5.0 (compatible; urllib3)'
EXTRA_METHODS = frozenset(['DELETE', 'OPTIONS', 'PATCH'])
def _get_pool_key_parts(request: Request) -> list[str]:
_pool_key_parts = []
if request.network_interface:
_pool_key_parts.append(f"interface={request.network_interface}")
if request.validate_cert:
_pool_key_parts.append("validate_cert=True")
else:
_pool_key_parts.append("validate_cert=False")
if request.ca_certs:
_pool_key_parts.append(f"ca_certs={request.ca_certs}")
if request.client_cert:
_pool_key_parts.append(f"client_cert={request.client_cert}")
if request.client_key:
_pool_key_parts.append(f"client_key={request.client_key}")
return _pool_key_parts
class Urllib3Client(BaseClient):
"""Urllib3 HTTP Client."""
_pools = {}
def __init__(self, hub: Hub | None = None, max_clients: int = 10):
hub = hub or get_event_loop()
super().__init__(hub)
self.max_clients = max_clients
self._pending = deque()
self._timeout_check_tref = self.hub.call_repeatedly(
1.0, self._timeout_check,
)
def pools_close(self):
for pool in self._pools.values():
pool.close()
self._pools.clear()
def close(self):
self._timeout_check_tref.cancel()
self.pools_close()
def add_request(self, request):
self._pending.append(request)
self._process_queue()
return request
def get_pool(self, request: Request):
_pool_key_parts = _get_pool_key_parts(request=request)
_proxy_url = None
proxy_headers = None
if request.proxy_host:
_proxy_url = urllib3.util.Url(
scheme=None,
host=request.proxy_host,
port=request.proxy_port,
)
if request.proxy_username:
proxy_headers = urllib3.make_headers(
proxy_basic_auth=(
f"{request.proxy_username}"
f":{request.proxy_password}"
)
)
else:
proxy_headers = None
_proxy_url = _proxy_url.url
_pool_key_parts.append(f"proxy={_proxy_url}")
if proxy_headers:
_pool_key_parts.append(f"proxy_headers={str(proxy_headers)}")
_pool_key = "|".join(_pool_key_parts)
if _pool_key in self._pools:
return self._pools[_pool_key]
# create new pool
if _proxy_url:
_pool = urllib3.ProxyManager(
proxy_url=_proxy_url,
num_pools=self.max_clients,
proxy_headers=proxy_headers
)
else:
_pool = urllib3.PoolManager(num_pools=self.max_clients)
# Network Interface
if request.network_interface:
_pool.connection_pool_kw['source_address'] = (
request.network_interface,
0
)
# SSL Verification
if request.validate_cert:
_pool.connection_pool_kw['cert_reqs'] = 'CERT_REQUIRED'
else:
_pool.connection_pool_kw['cert_reqs'] = 'CERT_NONE'
# CA Certificates
if request.ca_certs is not None:
_pool.connection_pool_kw['ca_certs'] = request.ca_certs
elif request.validate_cert is True:
try:
from certifi import where
_pool.connection_pool_kw['ca_certs'] = where()
except ImportError:
pass
# Client Certificates
if request.client_cert is not None:
_pool.connection_pool_kw['cert_file'] = request.client_cert
if request.client_key is not None:
_pool.connection_pool_kw['key_file'] = request.client_key
self._pools[_pool_key] = _pool
return _pool
def _timeout_check(self):
self._process_pending_requests()
def _process_pending_requests(self):
while self._pending:
request = self._pending.popleft()
self._process_request(request)
def _process_request(self, request: Request):
# Prepare headers
headers = request.headers
headers.setdefault(
'User-Agent',
bytes_to_str(request.user_agent or DEFAULT_USER_AGENT)
)
headers.setdefault(
'Accept-Encoding',
'gzip,deflate' if request.use_gzip else 'none'
)
# Authentication
if request.auth_username is not None:
headers.update(
urllib3.util.make_headers(
basic_auth=(
f"{request.auth_username}"
f":{request.auth_password or ''}"
)
)
)
# Make the request using urllib3
try:
_pool = self.get_pool(request=request)
response = _pool.request(
request.method,
request.url,
headers=headers,
body=request.body,
preload_content=False,
redirect=request.follow_redirects,
)
buffer = BytesIO(response.data)
response_obj = self.Response(
request=request,
code=response.status,
headers=response.headers,
buffer=buffer,
effective_url=response.geturl(),
error=None
)
except urllib3.exceptions.HTTPError as e:
response_obj = self.Response(
request=request,
code=599,
headers={},
buffer=None,
effective_url=None,
error=HttpError(599, str(e))
)
request.on_ready(response_obj)
def _process_queue(self):
self._process_pending_requests()
def on_readable(self, fd):
pass
def on_writable(self, fd):
pass
def _setup_request(self, curl, request, buffer, headers):
pass

View File

@@ -0,0 +1,399 @@
"""Event loop implementation."""
from __future__ import annotations
import errno
import threading
from contextlib import contextmanager
from copy import copy
from queue import Empty
from time import sleep
from types import GeneratorType as generator
from vine import Thenable, promise
from kombu.log import get_logger
from kombu.utils.compat import fileno
from kombu.utils.eventio import ERR, READ, WRITE, poll
from kombu.utils.objects import cached_property
from .timer import Timer
__all__ = ('Hub', 'get_event_loop', 'set_event_loop')
logger = get_logger(__name__)
_current_loop: Hub | None = None
W_UNKNOWN_EVENT = """\
Received unknown event %r for fd %r, please contact support!\
"""
class Stop(BaseException):
"""Stops the event loop."""
def _raise_stop_error():
raise Stop()
@contextmanager
def _dummy_context(*args, **kwargs):
yield
def get_event_loop() -> Hub | None:
"""Get current event loop object."""
return _current_loop
def set_event_loop(loop: Hub | None) -> Hub | None:
"""Set the current event loop object."""
global _current_loop
_current_loop = loop
return loop
class Hub:
"""Event loop object.
Arguments:
---------
timer (kombu.asynchronous.Timer): Specify custom timer instance.
"""
#: Flag set if reading from an fd will not block.
READ = READ
#: Flag set if writing to an fd will not block.
WRITE = WRITE
#: Flag set on error, and the fd should be read from asap.
ERR = ERR
#: List of callbacks to be called when the loop is exiting,
#: applied with the hub instance as sole argument.
on_close = None
def __init__(self, timer=None):
self.timer = timer if timer is not None else Timer()
self.readers = {}
self.writers = {}
self.on_tick = set()
self.on_close = set()
self._ready = set()
self._ready_lock = threading.Lock()
self._running = False
self._loop = None
# The eventloop (in celery.worker.loops)
# will merge fds in this set and then instead of calling
# the callback for each ready fd it will call the
# :attr:`consolidate_callback` with the list of ready_fds
# as an argument. This API is internal and is only
# used by the multiprocessing pool to find inqueues
# that are ready to write.
self.consolidate = set()
self.consolidate_callback = None
self.propagate_errors = ()
self._create_poller()
@property
def poller(self):
if not self._poller:
self._create_poller()
return self._poller
@poller.setter
def poller(self, value):
self._poller = value
def reset(self):
self.close()
self._create_poller()
def _create_poller(self):
self._poller = poll()
self._register_fd = self._poller.register
self._unregister_fd = self._poller.unregister
def _close_poller(self):
if self._poller is not None:
self._poller.close()
self._poller = None
self._register_fd = None
self._unregister_fd = None
def stop(self):
self.call_soon(_raise_stop_error)
def __repr__(self):
return '<Hub@{:#x}: R:{} W:{}>'.format(
id(self), len(self.readers), len(self.writers),
)
def fire_timers(self, min_delay=1, max_delay=10, max_timers=10,
propagate=()):
timer = self.timer
delay = None
if timer and timer._queue:
for i in range(max_timers):
delay, entry = next(self.scheduler)
if entry is None:
break
try:
entry()
except propagate:
raise
except (MemoryError, AssertionError):
raise
except OSError as exc:
if exc.errno == errno.ENOMEM:
raise
logger.error('Error in timer: %r', exc, exc_info=1)
except Exception as exc:
logger.error('Error in timer: %r', exc, exc_info=1)
return min(delay or min_delay, max_delay)
def _remove_from_loop(self, fd):
try:
self._unregister(fd)
finally:
self._discard(fd)
def add(self, fd, callback, flags, args=(), consolidate=False):
fd = fileno(fd)
try:
self.poller.register(fd, flags)
except ValueError:
self._remove_from_loop(fd)
raise
else:
dest = self.readers if flags & READ else self.writers
if consolidate:
self.consolidate.add(fd)
dest[fd] = None
else:
dest[fd] = callback, args
def remove(self, fd):
fd = fileno(fd)
self._remove_from_loop(fd)
def run_forever(self):
self._running = True
try:
while 1:
try:
self.run_once()
except Stop:
break
finally:
self._running = False
def run_once(self):
try:
next(self.loop)
except StopIteration:
self._loop = None
def call_soon(self, callback, *args):
if not isinstance(callback, Thenable):
callback = promise(callback, args)
with self._ready_lock:
self._ready.add(callback)
return callback
def call_later(self, delay, callback, *args):
return self.timer.call_after(delay, callback, args)
def call_at(self, when, callback, *args):
return self.timer.call_at(when, callback, args)
def call_repeatedly(self, delay, callback, *args):
return self.timer.call_repeatedly(delay, callback, args)
def add_reader(self, fds, callback, *args):
return self.add(fds, callback, READ | ERR, args)
def add_writer(self, fds, callback, *args):
return self.add(fds, callback, WRITE, args)
def remove_reader(self, fd):
writable = fd in self.writers
on_write = self.writers.get(fd)
try:
self._remove_from_loop(fd)
finally:
if writable:
cb, args = on_write
self.add(fd, cb, WRITE, args)
def remove_writer(self, fd):
readable = fd in self.readers
on_read = self.readers.get(fd)
try:
self._remove_from_loop(fd)
finally:
if readable:
cb, args = on_read
self.add(fd, cb, READ | ERR, args)
def _unregister(self, fd):
try:
self.poller.unregister(fd)
except (AttributeError, KeyError, OSError):
pass
def _pop_ready(self):
with self._ready_lock:
ready = self._ready
self._ready = set()
return ready
def close(self, *args):
[self._unregister(fd) for fd in self.readers]
self.readers.clear()
[self._unregister(fd) for fd in self.writers]
self.writers.clear()
self.consolidate.clear()
self._close_poller()
for callback in self.on_close:
callback(self)
# Complete remaining todo before Hub close
# Eg: Acknowledge message
# To avoid infinite loop where one of the callables adds items
# to self._ready (via call_soon or otherwise).
# we create new list with current self._ready
todos = self._pop_ready()
for item in todos:
item()
def _discard(self, fd):
fd = fileno(fd)
self.readers.pop(fd, None)
self.writers.pop(fd, None)
self.consolidate.discard(fd)
def on_callback_error(self, callback, exc):
logger.error(
'Callback %r raised exception: %r', callback, exc, exc_info=1,
)
def create_loop(self,
generator=generator, sleep=sleep, min=min, next=next,
Empty=Empty, StopIteration=StopIteration,
KeyError=KeyError, READ=READ, WRITE=WRITE, ERR=ERR):
readers, writers = self.readers, self.writers
poll = self.poller.poll
fire_timers = self.fire_timers
hub_remove = self.remove
scheduled = self.timer._queue
consolidate = self.consolidate
consolidate_callback = self.consolidate_callback
propagate = self.propagate_errors
while 1:
todo = self._pop_ready()
for item in todo:
if item:
item()
poll_timeout = fire_timers(propagate=propagate) if scheduled else 1
for tick_callback in copy(self.on_tick):
tick_callback()
# print('[[[HUB]]]: %s' % (self.repr_active(),))
if readers or writers:
to_consolidate = []
try:
events = poll(poll_timeout)
# print('[EVENTS]: %s' % (self.repr_events(events),))
except ValueError: # Issue celery/#882
return
for fd, event in events or ():
general_error = False
if fd in consolidate and \
writers.get(fd) is None:
to_consolidate.append(fd)
continue
cb = cbargs = None
if event & READ:
try:
cb, cbargs = readers[fd]
except KeyError:
self.remove_reader(fd)
continue
elif event & WRITE:
try:
cb, cbargs = writers[fd]
except KeyError:
self.remove_writer(fd)
continue
elif event & ERR:
general_error = True
else:
logger.info(W_UNKNOWN_EVENT, event, fd)
general_error = True
if general_error:
try:
cb, cbargs = (readers.get(fd) or
writers.get(fd))
except TypeError:
pass
if cb is None:
self.remove(fd)
continue
if isinstance(cb, generator):
try:
next(cb)
except OSError as exc:
if exc.errno != errno.EBADF:
raise
hub_remove(fd)
except StopIteration:
pass
except Exception:
hub_remove(fd)
raise
else:
try:
cb(*cbargs)
except Empty:
pass
if to_consolidate:
consolidate_callback(to_consolidate)
else:
# no sockets yet, startup is probably not done.
sleep(min(poll_timeout, 0.1))
yield
def repr_active(self):
from .debug import repr_active
return repr_active(self)
def repr_events(self, events):
from .debug import repr_events
return repr_events(self, events or [])
@cached_property
def scheduler(self):
return iter(self.timer)
@property
def loop(self):
if self._loop is None:
self._loop = self.create_loop()
return self._loop

View File

@@ -0,0 +1,127 @@
"""Semaphores and concurrency primitives."""
from __future__ import annotations
import sys
from collections import deque
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from types import TracebackType
from typing import Callable, Deque
if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec
P = ParamSpec("P")
__all__ = ('DummyLock', 'LaxBoundedSemaphore')
class LaxBoundedSemaphore:
"""Asynchronous Bounded Semaphore.
Lax means that the value will stay within the specified
range even if released more times than it was acquired.
Example:
-------
>>> x = LaxBoundedSemaphore(2)
>>> x.acquire(print, 'HELLO 1')
HELLO 1
>>> x.acquire(print, 'HELLO 2')
HELLO 2
>>> x.acquire(print, 'HELLO 3')
>>> x._waiters # private, do not access directly
[print, ('HELLO 3',)]
>>> x.release()
HELLO 3
"""
def __init__(self, value: int) -> None:
self.initial_value = self.value = value
self._waiting: Deque[tuple] = deque()
self._add_waiter = self._waiting.append
self._pop_waiter = self._waiting.popleft
def acquire(
self,
callback: Callable[P, None],
*partial_args: P.args,
**partial_kwargs: P.kwargs
) -> bool:
"""Acquire semaphore.
This will immediately apply ``callback`` if
the resource is available, otherwise the callback is suspended
until the semaphore is released.
Arguments:
---------
callback (Callable): The callback to apply.
*partial_args (Any): partial arguments to callback.
"""
value = self.value
if value <= 0:
self._add_waiter((callback, partial_args, partial_kwargs))
return False
else:
self.value = max(value - 1, 0)
callback(*partial_args, **partial_kwargs)
return True
def release(self) -> None:
"""Release semaphore.
Note:
----
If there are any waiters this will apply the first waiter
that is waiting for the resource (FIFO order).
"""
try:
waiter, args, kwargs = self._pop_waiter()
except IndexError:
self.value = min(self.value + 1, self.initial_value)
else:
waiter(*args, **kwargs)
def grow(self, n: int = 1) -> None:
"""Change the size of the semaphore to accept more users."""
self.initial_value += n
self.value += n
for _ in range(n):
self.release()
def shrink(self, n: int = 1) -> None:
"""Change the size of the semaphore to accept less users."""
self.initial_value = max(self.initial_value - n, 0)
self.value = max(self.value - n, 0)
def clear(self) -> None:
"""Reset the semaphore, which also wipes out any waiting callbacks."""
self._waiting.clear()
self.value = self.initial_value
def __repr__(self) -> str:
return '<{} at {:#x} value:{} waiting:{}>'.format(
self.__class__.__name__, id(self), self.value, len(self._waiting),
)
class DummyLock:
"""Pretending to be a lock."""
def __enter__(self) -> DummyLock:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
pass

View File

@@ -0,0 +1,241 @@
"""Timer scheduling Python callbacks."""
from __future__ import annotations
import heapq
import sys
from collections import namedtuple
from datetime import datetime
from functools import total_ordering
from time import monotonic
from time import time as _time
from typing import TYPE_CHECKING
from weakref import proxy as weakrefproxy
from vine.utils import wraps
from kombu.log import get_logger
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
from backports.zoneinfo import ZoneInfo
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('Entry', 'Timer', 'to_timestamp')
logger = get_logger(__name__)
DEFAULT_MAX_INTERVAL = 2
EPOCH = datetime.fromtimestamp(0, ZoneInfo("UTC"))
IS_PYPY = hasattr(sys, 'pypy_version_info')
scheduled = namedtuple('scheduled', ('eta', 'priority', 'entry'))
def to_timestamp(d, default_timezone=ZoneInfo("UTC"), time=monotonic):
"""Convert datetime to timestamp.
If d' is already a timestamp, then that will be used.
"""
if isinstance(d, datetime):
if d.tzinfo is None:
d = d.replace(tzinfo=default_timezone)
diff = _time() - time()
return max((d - EPOCH).total_seconds() - diff, 0)
return d
@total_ordering
class Entry:
"""Schedule Entry."""
if not IS_PYPY: # pragma: no cover
__slots__ = (
'fun', 'args', 'kwargs', 'tref', 'canceled',
'_last_run', '__weakref__',
)
def __init__(self, fun, args=None, kwargs=None):
self.fun = fun
self.args = args or []
self.kwargs = kwargs or {}
self.tref = weakrefproxy(self)
self._last_run = None
self.canceled = False
def __call__(self):
return self.fun(*self.args, **self.kwargs)
def cancel(self):
try:
self.tref.canceled = True
except ReferenceError: # pragma: no cover
pass
def __repr__(self):
return '<TimerEntry: {}(*{!r}, **{!r})'.format(
self.fun.__name__, self.args, self.kwargs)
# must not use hash() to order entries
def __lt__(self, other):
return id(self) < id(other)
@property
def cancelled(self):
return self.canceled
@cancelled.setter
def cancelled(self, value):
self.canceled = value
class Timer:
"""Async timer implementation."""
Entry = Entry
on_error = None
def __init__(self, max_interval=None, on_error=None, **kwargs):
self.max_interval = float(max_interval or DEFAULT_MAX_INTERVAL)
self.on_error = on_error or self.on_error
self._queue = []
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.stop()
def call_at(self, eta, fun, args=(), kwargs=None, priority=0):
kwargs = {} if not kwargs else kwargs
return self.enter_at(self.Entry(fun, args, kwargs), eta, priority)
def call_after(self, secs, fun, args=(), kwargs=None, priority=0):
kwargs = {} if not kwargs else kwargs
return self.enter_after(secs, self.Entry(fun, args, kwargs), priority)
def call_repeatedly(self, secs, fun, args=(), kwargs=None, priority=0):
kwargs = {} if not kwargs else kwargs
tref = self.Entry(fun, args, kwargs)
@wraps(fun)
def _reschedules(*args, **kwargs):
last, now = tref._last_run, monotonic()
lsince = (now - tref._last_run) if last else secs
try:
if lsince and lsince >= secs:
tref._last_run = now
return fun(*args, **kwargs)
finally:
if not tref.canceled:
last = tref._last_run
next = secs - (now - last) if last else secs
self.enter_after(next, tref, priority)
tref.fun = _reschedules
tref._last_run = None
return self.enter_after(secs, tref, priority)
def enter_at(self, entry, eta=None, priority=0, time=monotonic):
"""Enter function into the scheduler.
Arguments:
---------
entry (~kombu.asynchronous.timer.Entry): Item to enter.
eta (datetime.datetime): Scheduled time.
priority (int): Unused.
"""
if eta is None:
eta = time()
if isinstance(eta, datetime):
try:
eta = to_timestamp(eta)
except Exception as exc:
if not self.handle_error(exc):
raise
return
return self._enter(eta, priority, entry)
def enter_after(self, secs, entry, priority=0, time=monotonic):
return self.enter_at(entry, time() + float(secs), priority)
def _enter(self, eta, priority, entry, push=heapq.heappush):
push(self._queue, scheduled(eta, priority, entry))
return entry
def apply_entry(self, entry):
try:
entry()
except Exception as exc:
if not self.handle_error(exc):
logger.error('Error in timer: %r', exc, exc_info=True)
def handle_error(self, exc_info):
if self.on_error:
self.on_error(exc_info)
return True
def stop(self):
pass
def __iter__(self, min=min, nowfun=monotonic,
pop=heapq.heappop, push=heapq.heappush):
"""Iterate over schedule.
This iterator yields a tuple of ``(wait_seconds, entry)``,
where if entry is :const:`None` the caller should wait
for ``wait_seconds`` until it polls the schedule again.
"""
max_interval = self.max_interval
queue = self._queue
while 1:
if queue:
eventA = queue[0]
now, eta = nowfun(), eventA[0]
if now < eta:
yield min(eta - now, max_interval), None
else:
eventB = pop(queue)
if eventB is eventA:
entry = eventA[2]
if not entry.canceled:
yield None, entry
continue
else:
push(queue, eventB)
else:
yield None, None
def clear(self):
self._queue[:] = [] # atomic, without creating a new list.
def cancel(self, tref):
tref.cancel()
def __len__(self):
return len(self._queue)
def __nonzero__(self):
return True
@property
def queue(self, _pop=heapq.heappop):
"""Snapshot of underlying datastructure."""
events = list(self._queue)
return [_pop(v) for v in [events] * len(events)]
@property
def schedule(self):
return self

View File

@@ -0,0 +1,156 @@
"""Logical Clocks and Synchronization."""
from __future__ import annotations
from itertools import islice
from operator import itemgetter
from threading import Lock
from typing import Any
__all__ = ('LamportClock', 'timetuple')
R_CLOCK = '_lamport(clock={0}, timestamp={1}, id={2} {3!r})'
class timetuple(tuple):
"""Tuple of event clock information.
Can be used as part of a heap to keep events ordered.
Arguments:
---------
clock (Optional[int]): Event clock value.
timestamp (float): Event UNIX timestamp value.
id (str): Event host id (e.g. ``hostname:pid``).
obj (Any): Optional obj to associate with this event.
"""
__slots__ = ()
def __new__(
cls, clock: int | None, timestamp: float, id: str, obj: Any = None
) -> timetuple:
return tuple.__new__(cls, (clock, timestamp, id, obj))
def __repr__(self) -> str:
return R_CLOCK.format(*self)
def __getnewargs__(self) -> tuple:
return tuple(self)
def __lt__(self, other: tuple) -> bool:
# 0: clock 1: timestamp 3: process id
try:
A, B = self[0], other[0]
# uses logical clock value first
if A and B: # use logical clock if available
if A == B: # equal clocks use lower process id
return self[2] < other[2]
return A < B
return self[1] < other[1] # ... or use timestamp
except IndexError:
return NotImplemented
def __gt__(self, other: tuple) -> bool:
return other < self
def __le__(self, other: tuple) -> bool:
return not other < self
def __ge__(self, other: tuple) -> bool:
return not self < other
clock = property(itemgetter(0))
timestamp = property(itemgetter(1))
id = property(itemgetter(2))
obj = property(itemgetter(3))
class LamportClock:
"""Lamport's logical clock.
From Wikipedia:
A Lamport logical clock is a monotonically incrementing software counter
maintained in each process. It follows some simple rules:
* A process increments its counter before each event in that process;
* When a process sends a message, it includes its counter value with
the message;
* On receiving a message, the receiver process sets its counter to be
greater than the maximum of its own value and the received value
before it considers the message received.
Conceptually, this logical clock can be thought of as a clock that only
has meaning in relation to messages moving between processes. When a
process receives a message, it resynchronizes its logical clock with
the sender.
See Also
--------
* `Lamport timestamps`_
* `Lamports distributed mutex`_
.. _`Lamport Timestamps`: https://en.wikipedia.org/wiki/Lamport_timestamps
.. _`Lamports distributed mutex`: https://bit.ly/p99ybE
*Usage*
When sending a message use :meth:`forward` to increment the clock,
when receiving a message use :meth:`adjust` to sync with
the time stamp of the incoming message.
"""
#: The clocks current value.
value = 0
def __init__(
self, initial_value: int = 0, Lock: type[Lock] = Lock
) -> None:
self.value = initial_value
self.mutex = Lock()
def adjust(self, other: int) -> int:
with self.mutex:
value = self.value = max(self.value, other) + 1
return value
def forward(self) -> int:
with self.mutex:
self.value += 1
return self.value
def sort_heap(self, h: list[tuple[int, str]]) -> tuple[int, str]:
"""Sort heap of events.
List of tuples containing at least two elements, representing
an event, where the first element is the event's scalar clock value,
and the second element is the id of the process (usually
``"hostname:pid"``): ``sh([(clock, processid, ...?), (...)])``
The list must already be sorted, which is why we refer to it as a
heap.
The tuple will not be unpacked, so more than two elements can be
present.
Will return the latest event.
"""
if h[0][0] == h[1][0]:
same = []
for PN in zip(h, islice(h, 1, None)):
if PN[0][0] != PN[1][0]:
break # Prev and Next's clocks differ
same.append(PN[0])
# return first item sorted by process id
return sorted(same, key=lambda event: event[1])[0]
# clock values unique, return first item
return h[0]
def __str__(self) -> str:
return str(self.value)
def __repr__(self) -> str:
return f'<LamportClock: {self.value}>'

View File

@@ -0,0 +1,448 @@
"""Common Utilities."""
from __future__ import annotations
import os
import socket
import threading
from collections import deque
from contextlib import contextmanager
from functools import partial
from itertools import count
from uuid import NAMESPACE_OID, uuid3, uuid4, uuid5
from amqp import ChannelError, RecoverableConnectionError
from .entity import Exchange, Queue
from .log import get_logger
from .serialization import registry as serializers
from .utils.uuid import uuid
__all__ = ('Broadcast', 'maybe_declare', 'uuid',
'itermessages', 'send_reply',
'collect_replies', 'insured', 'drain_consumer',
'eventloop')
#: Prefetch count can't exceed short.
PREFETCH_COUNT_MAX = 0xFFFF
logger = get_logger(__name__)
_node_id = None
def get_node_id():
global _node_id
if _node_id is None:
_node_id = uuid4().int
return _node_id
def generate_oid(node_id, process_id, thread_id, instance):
ent = '{:x}-{:x}-{:x}-{:x}'.format(
node_id, process_id, thread_id, id(instance))
try:
ret = str(uuid3(NAMESPACE_OID, ent))
except ValueError:
ret = str(uuid5(NAMESPACE_OID, ent))
return ret
def oid_from(instance, threads=True):
return generate_oid(
get_node_id(),
os.getpid(),
threading.get_ident() if threads else 0,
instance,
)
class Broadcast(Queue):
"""Broadcast queue.
Convenience class used to define broadcast queues.
Every queue instance will have a unique name,
and both the queue and exchange is configured with auto deletion.
Arguments:
---------
name (str): This is used as the name of the exchange.
queue (str): By default a unique id is used for the queue
name for every consumer. You can specify a custom
queue name here.
unique (bool): Always create a unique queue
even if a queue name is supplied.
**kwargs (Any): See :class:`~kombu.Queue` for a list
of additional keyword arguments supported.
"""
attrs = Queue.attrs + (('queue', None),)
def __init__(self,
name=None,
queue=None,
unique=False,
auto_delete=True,
exchange=None,
alias=None,
**kwargs):
if unique:
queue = '{}.{}'.format(queue or 'bcast', uuid())
else:
queue = queue or f'bcast.{uuid()}'
super().__init__(
alias=alias or name,
queue=queue,
name=queue,
auto_delete=auto_delete,
exchange=(exchange if exchange is not None
else Exchange(name, type='fanout')),
**kwargs
)
def declaration_cached(entity, channel):
return entity in channel.connection.client.declared_entities
def maybe_declare(entity, channel=None, retry=False, **retry_policy):
"""Declare entity (cached)."""
if retry:
return _imaybe_declare(entity, channel, **retry_policy)
return _maybe_declare(entity, channel)
def _ensure_channel_is_bound(entity, channel):
"""Make sure the channel is bound to the entity.
:param entity: generic kombu nomenclature, generally an exchange or queue
:param channel: channel to bind to the entity
:return: the updated entity
"""
is_bound = entity.is_bound
if not is_bound:
if not channel:
raise ChannelError(
f"Cannot bind channel {channel} to entity {entity}")
entity = entity.bind(channel)
return entity
def _maybe_declare(entity, channel):
# _maybe_declare sets name on original for autogen queues
orig = entity
_ensure_channel_is_bound(entity, channel)
if channel is None or channel.connection is None:
# If this was called from the `ensure()` method then the channel could have been invalidated
# and the correct channel was re-bound to the entity by calling the `entity.revive()` method.
if not entity.is_bound:
raise ChannelError(
f"channel is None and entity {entity} not bound.")
channel = entity.channel
declared = ident = None
if channel.connection and entity.can_cache_declaration:
declared = channel.connection.client.declared_entities
ident = hash(entity)
if ident in declared:
return False
if not channel.connection:
raise RecoverableConnectionError('channel disconnected')
entity.declare(channel=channel)
if declared is not None and ident:
declared.add(ident)
if orig is not None:
orig.name = entity.name
return True
def _imaybe_declare(entity, channel, **retry_policy):
entity = _ensure_channel_is_bound(entity, channel)
if not entity.channel.connection:
raise RecoverableConnectionError('channel disconnected')
return entity.channel.connection.client.ensure(
entity, _maybe_declare, **retry_policy)(entity, channel)
def drain_consumer(consumer, limit=1, timeout=None, callbacks=None):
"""Drain messages from consumer instance."""
acc = deque()
def on_message(body, message):
acc.append((body, message))
consumer.callbacks = [on_message] + (callbacks or [])
with consumer:
for _ in eventloop(consumer.channel.connection.client,
limit=limit, timeout=timeout, ignore_timeouts=True):
try:
yield acc.popleft()
except IndexError:
pass
def itermessages(conn, channel, queue, limit=1, timeout=None,
callbacks=None, **kwargs):
"""Iterator over messages."""
return drain_consumer(
conn.Consumer(queues=[queue], channel=channel, **kwargs),
limit=limit, timeout=timeout, callbacks=callbacks,
)
def eventloop(conn, limit=None, timeout=None, ignore_timeouts=False):
"""Best practice generator wrapper around ``Connection.drain_events``.
Able to drain events forever, with a limit, and optionally ignoring
timeout errors (a timeout of 1 is often used in environments where
the socket can get "stuck", and is a best practice for Kombu consumers).
``eventloop`` is a generator.
Examples
--------
>>> from kombu.common import eventloop
>>> def run(conn):
... it = eventloop(conn, timeout=1, ignore_timeouts=True)
... next(it) # one event consumed, or timed out.
...
... for _ in eventloop(conn, timeout=1, ignore_timeouts=True):
... pass # loop forever.
It also takes an optional limit parameter, and timeout errors
are propagated by default::
for _ in eventloop(connection, limit=1, timeout=1):
pass
See Also
--------
:func:`itermessages`, which is an event loop bound to one or more
consumers, that yields any messages received.
"""
for i in limit and range(limit) or count():
try:
yield conn.drain_events(timeout=timeout)
except socket.timeout:
if timeout and not ignore_timeouts: # pragma: no cover
raise
def send_reply(exchange, req, msg,
producer=None, retry=False, retry_policy=None, **props):
"""Send reply for request.
Arguments:
---------
exchange (kombu.Exchange, str): Reply exchange
req (~kombu.Message): Original request, a message with
a ``reply_to`` property.
producer (kombu.Producer): Producer instance
retry (bool): If true must retry according to
the ``reply_policy`` argument.
retry_policy (Dict): Retry settings.
**props (Any): Extra properties.
"""
return producer.publish(
msg, exchange=exchange,
retry=retry, retry_policy=retry_policy,
**dict({'routing_key': req.properties['reply_to'],
'correlation_id': req.properties.get('correlation_id'),
'serializer': serializers.type_to_name[req.content_type],
'content_encoding': req.content_encoding}, **props)
)
def collect_replies(conn, channel, queue, *args, **kwargs):
"""Generator collecting replies from ``queue``."""
no_ack = kwargs.setdefault('no_ack', True)
received = False
try:
for body, message in itermessages(conn, channel, queue,
*args, **kwargs):
if not no_ack:
message.ack()
received = True
yield body
finally:
if received:
channel.after_reply_message_received(queue.name)
def _ensure_errback(exc, interval):
logger.error(
'Connection error: %r. Retry in %ss\n', exc, interval,
exc_info=True,
)
@contextmanager
def _ignore_errors(conn):
try:
yield
except conn.connection_errors + conn.channel_errors:
pass
def ignore_errors(conn, fun=None, *args, **kwargs):
"""Ignore connection and channel errors.
The first argument must be a connection object, or any other object
with ``connection_error`` and ``channel_error`` attributes.
Can be used as a function:
.. code-block:: python
def example(connection):
ignore_errors(connection, consumer.channel.close)
or as a context manager:
.. code-block:: python
def example(connection):
with ignore_errors(connection):
consumer.channel.close()
Note:
----
Connection and channel errors should be properly handled,
and not ignored. Using this function is only acceptable in a cleanup
phase, like when a connection is lost or at shutdown.
"""
if fun:
with _ignore_errors(conn):
return fun(*args, **kwargs)
return _ignore_errors(conn)
def revive_connection(connection, channel, on_revive=None):
if on_revive:
on_revive(channel)
def insured(pool, fun, args, kwargs, errback=None, on_revive=None, **opts):
"""Function wrapper to handle connection errors.
Ensures function performing broker commands completes
despite intermittent connection failures.
"""
errback = errback or _ensure_errback
with pool.acquire(block=True) as conn:
conn.ensure_connection(errback=errback)
# we cache the channel for subsequent calls, this has to be
# reset on revival.
channel = conn.default_channel
revive = partial(revive_connection, conn, on_revive=on_revive)
insured = conn.autoretry(fun, channel, errback=errback,
on_revive=revive, **opts)
retval, _ = insured(*args, **dict(kwargs, connection=conn))
return retval
class QoS:
"""Thread safe increment/decrement of a channels prefetch_count.
Arguments:
---------
callback (Callable): Function used to set new prefetch count,
e.g. ``consumer.qos`` or ``channel.basic_qos``. Will be called
with a single ``prefetch_count`` keyword argument.
initial_value (int): Initial prefetch count value..
Example:
-------
>>> from kombu import Consumer, Connection
>>> connection = Connection('amqp://')
>>> consumer = Consumer(connection)
>>> qos = QoS(consumer.qos, initial_prefetch_count=2)
>>> qos.update() # set initial
>>> qos.value
2
>>> def in_some_thread():
... qos.increment_eventually()
>>> def in_some_other_thread():
... qos.decrement_eventually()
>>> while 1:
... if qos.prev != qos.value:
... qos.update() # prefetch changed so update.
It can be used with any function supporting a ``prefetch_count`` keyword
argument::
>>> channel = connection.channel()
>>> QoS(channel.basic_qos, 10)
>>> def set_qos(prefetch_count):
... print('prefetch count now: %r' % (prefetch_count,))
>>> QoS(set_qos, 10)
"""
prev = None
def __init__(self, callback, initial_value):
self.callback = callback
self._mutex = threading.RLock()
self.value = initial_value or 0
def increment_eventually(self, n=1):
"""Increment the value, but do not update the channels QoS.
Note:
----
The MainThread will be responsible for calling :meth:`update`
when necessary.
"""
with self._mutex:
if self.value:
self.value = self.value + max(n, 0)
return self.value
def decrement_eventually(self, n=1):
"""Decrement the value, but do not update the channels QoS.
Note:
----
The MainThread will be responsible for calling :meth:`update`
when necessary.
"""
with self._mutex:
if self.value:
self.value -= n
if self.value < 1:
self.value = 1
return self.value
def set(self, pcount):
"""Set channel prefetch_count setting."""
if pcount != self.prev:
new_value = pcount
if pcount > PREFETCH_COUNT_MAX:
logger.warning('QoS: Disabled: prefetch_count exceeds %r',
PREFETCH_COUNT_MAX)
new_value = 0
logger.debug('basic.qos: prefetch_count->%s', new_value)
self.callback(prefetch_count=new_value)
self.prev = pcount
return pcount
def update(self):
"""Update prefetch count with current value."""
with self._mutex:
return self.set(self.value)

View File

@@ -0,0 +1,227 @@
"""Carrot compatibility interface.
See https://pypi.org/project/carrot/ for documentation.
"""
from __future__ import annotations
from itertools import count
from typing import TYPE_CHECKING
from . import messaging
from .entity import Exchange, Queue
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('Publisher', 'Consumer')
# XXX compat attribute
entry_to_queue = Queue.from_dict
def _iterconsume(connection, consumer, no_ack=False, limit=None):
consumer.consume(no_ack=no_ack)
for iteration in count(0): # for infinity
if limit and iteration >= limit:
break
yield connection.drain_events()
class Publisher(messaging.Producer):
"""Carrot compatible producer."""
exchange = ''
exchange_type = 'direct'
routing_key = ''
durable = True
auto_delete = False
_closed = False
def __init__(self, connection, exchange=None, routing_key=None,
exchange_type=None, durable=None, auto_delete=None,
channel=None, **kwargs):
if channel:
connection = channel
self.exchange = exchange or self.exchange
self.exchange_type = exchange_type or self.exchange_type
self.routing_key = routing_key or self.routing_key
if auto_delete is not None:
self.auto_delete = auto_delete
if durable is not None:
self.durable = durable
if not isinstance(self.exchange, Exchange):
self.exchange = Exchange(name=self.exchange,
type=self.exchange_type,
routing_key=self.routing_key,
auto_delete=self.auto_delete,
durable=self.durable)
super().__init__(connection, self.exchange, **kwargs)
def send(self, *args, **kwargs):
return self.publish(*args, **kwargs)
def close(self):
super().close()
self._closed = True
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.close()
@property
def backend(self):
return self.channel
class Consumer(messaging.Consumer):
"""Carrot compatible consumer."""
queue = ''
exchange = ''
routing_key = ''
exchange_type = 'direct'
durable = True
exclusive = False
auto_delete = False
_closed = False
def __init__(self, connection, queue=None, exchange=None,
routing_key=None, exchange_type=None, durable=None,
exclusive=None, auto_delete=None, **kwargs):
self.backend = connection.channel()
if durable is not None:
self.durable = durable
if exclusive is not None:
self.exclusive = exclusive
if auto_delete is not None:
self.auto_delete = auto_delete
self.queue = queue or self.queue
self.exchange = exchange or self.exchange
self.exchange_type = exchange_type or self.exchange_type
self.routing_key = routing_key or self.routing_key
exchange = Exchange(self.exchange,
type=self.exchange_type,
routing_key=self.routing_key,
auto_delete=self.auto_delete,
durable=self.durable)
queue = Queue(self.queue,
exchange=exchange,
routing_key=self.routing_key,
durable=self.durable,
exclusive=self.exclusive,
auto_delete=self.auto_delete)
super().__init__(self.backend, queue, **kwargs)
def revive(self, channel):
self.backend = channel
super().revive(channel)
def close(self):
self.cancel()
self.backend.close()
self._closed = True
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.close()
def __iter__(self):
return self.iterqueue(infinite=True)
def fetch(self, no_ack=None, enable_callbacks=False):
if no_ack is None:
no_ack = self.no_ack
message = self.queues[0].get(no_ack)
if message:
if enable_callbacks:
self.receive(message.payload, message)
return message
def process_next(self):
raise NotImplementedError('Use fetch(enable_callbacks=True)')
def discard_all(self, filterfunc=None):
if filterfunc is not None:
raise NotImplementedError(
'discard_all does not implement filters')
return self.purge()
def iterconsume(self, limit=None, no_ack=None):
return _iterconsume(self.connection, self, no_ack, limit)
def wait(self, limit=None):
it = self.iterconsume(limit)
return list(it)
def iterqueue(self, limit=None, infinite=False):
for items_since_start in count(): # for infinity
item = self.fetch()
if (not infinite and item is None) or \
(limit and items_since_start >= limit):
break
yield item
class ConsumerSet(messaging.Consumer):
def __init__(self, connection, from_dict=None, consumers=None,
channel=None, **kwargs):
if channel:
self._provided_channel = True
self.backend = channel
else:
self._provided_channel = False
self.backend = connection.channel()
queues = []
if consumers:
for consumer in consumers:
queues.extend(consumer.queues)
if from_dict:
for queue_name, queue_options in from_dict.items():
queues.append(Queue.from_dict(queue_name, **queue_options))
super().__init__(self.backend, queues, **kwargs)
def iterconsume(self, limit=None, no_ack=False):
return _iterconsume(self.connection, self, no_ack, limit)
def discard_all(self):
return self.purge()
def add_consumer_from_dict(self, queue, **options):
return self.add_queue(Queue.from_dict(queue, **options))
def add_consumer(self, consumer):
for queue in consumer.queues:
self.add_queue(queue)
def revive(self, channel):
self.backend = channel
super().revive(channel)
def close(self):
self.cancel()
if not self._provided_channel:
self.channel.close()

View File

@@ -0,0 +1,121 @@
"""Compression utilities."""
from __future__ import annotations
import zlib
from kombu.utils.encoding import ensure_bytes
_aliases = {}
_encoders = {}
_decoders = {}
__all__ = ('register', 'encoders', 'get_encoder',
'get_decoder', 'compress', 'decompress')
def register(encoder, decoder, content_type, aliases=None):
"""Register new compression method.
Arguments:
---------
encoder (Callable): Function used to compress text.
decoder (Callable): Function used to decompress previously
compressed text.
content_type (str): The mime type this compression method
identifies as.
aliases (Sequence[str]): A list of names to associate with
this compression method.
"""
_encoders[content_type] = encoder
_decoders[content_type] = decoder
if aliases:
_aliases.update((alias, content_type) for alias in aliases)
def encoders():
"""Return a list of available compression methods."""
return list(_encoders)
def get_encoder(t):
"""Get encoder by alias name."""
t = _aliases.get(t, t)
return _encoders[t], t
def get_decoder(t):
"""Get decoder by alias name."""
return _decoders[_aliases.get(t, t)]
def compress(body, content_type):
"""Compress text.
Arguments:
---------
body (AnyStr): The text to compress.
content_type (str): mime-type of compression method to use.
"""
encoder, content_type = get_encoder(content_type)
return encoder(ensure_bytes(body)), content_type
def decompress(body, content_type):
"""Decompress compressed text.
Arguments:
---------
body (AnyStr): Previously compressed text to uncompress.
content_type (str): mime-type of compression method used.
"""
return get_decoder(content_type)(body)
register(zlib.compress,
zlib.decompress,
'application/x-gzip', aliases=['gzip', 'zlib'])
try:
import bz2
except ImportError: # pragma: no cover
pass # No bz2 support
else:
register(bz2.compress,
bz2.decompress,
'application/x-bz2', aliases=['bzip2', 'bzip'])
try:
import brotli
except ImportError: # pragma: no cover
pass
else:
register(brotli.compress,
brotli.decompress,
'application/x-brotli', aliases=['brotli'])
try:
import lzma
except ImportError: # pragma: no cover
pass # no lzma support
else:
register(lzma.compress,
lzma.decompress,
'application/x-lzma', aliases=['lzma', 'xz'])
try:
import zstandard as zstd
except ImportError: # pragma: no cover
pass
else:
def zstd_compress(body):
c = zstd.ZstdCompressor()
return c.compress(body)
def zstd_decompress(body):
d = zstd.ZstdDecompressor()
return d.decompress(body)
register(zstd_compress,
zstd_decompress,
'application/zstd', aliases=['zstd', 'zstandard'])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,887 @@
"""Exchange and Queue declarations."""
from __future__ import annotations
import numbers
from .abstract import MaybeChannelBound, Object
from .exceptions import ContentDisallowed
from .serialization import prepare_accept_content
TRANSIENT_DELIVERY_MODE = 1
PERSISTENT_DELIVERY_MODE = 2
DELIVERY_MODES = {'transient': TRANSIENT_DELIVERY_MODE,
'persistent': PERSISTENT_DELIVERY_MODE}
__all__ = ('Exchange', 'Queue', 'binding', 'maybe_delivery_mode')
INTERNAL_EXCHANGE_PREFIX = ('amq.',)
def _reprstr(s):
s = repr(s)
if isinstance(s, str) and s.startswith("u'"):
return s[2:-1]
return s[1:-1]
def pretty_bindings(bindings):
return '[{}]'.format(', '.join(map(str, bindings)))
def maybe_delivery_mode(
v, modes=None, default=PERSISTENT_DELIVERY_MODE):
"""Get delivery mode by name (or none if undefined)."""
modes = DELIVERY_MODES if not modes else modes
if v:
return v if isinstance(v, numbers.Integral) else modes[v]
return default
class Exchange(MaybeChannelBound):
"""An Exchange declaration.
Arguments:
---------
name (str): See :attr:`name`.
type (str): See :attr:`type`.
channel (kombu.Connection, ChannelT): See :attr:`channel`.
durable (bool): See :attr:`durable`.
auto_delete (bool): See :attr:`auto_delete`.
delivery_mode (enum): See :attr:`delivery_mode`.
arguments (Dict): See :attr:`arguments`.
no_declare (bool): See :attr:`no_declare`
Attributes
----------
name (str): Name of the exchange.
Default is no name (the default exchange).
type (str):
*This description of AMQP exchange types was shamelessly stolen
from the blog post `AMQP in 10 minutes: Part 4`_ by
Rajith Attapattu. Reading this article is recommended if you're
new to amqp.*
"AMQP defines four default exchange types (routing algorithms) that
covers most of the common messaging use cases. An AMQP broker can
also define additional exchange types, so see your broker
manual for more information about available exchange types.
* `direct` (*default*)
Direct match between the routing key in the message,
and the routing criteria used when a queue is bound to
this exchange.
* `topic`
Wildcard match between the routing key and the routing
pattern specified in the exchange/queue binding.
The routing key is treated as zero or more words delimited
by `"."` and supports special wildcard characters. `"*"`
matches a single word and `"#"` matches zero or more words.
* `fanout`
Queues are bound to this exchange with no arguments. Hence
any message sent to this exchange will be forwarded to all
queues bound to this exchange.
* `headers`
Queues are bound to this exchange with a table of arguments
containing headers and values (optional). A special
argument named "x-match" determines the matching algorithm,
where `"all"` implies an `AND` (all pairs must match) and
`"any"` implies `OR` (at least one pair must match).
:attr:`arguments` is used to specify the arguments.
.. _`AMQP in 10 minutes: Part 4`:
https://bit.ly/2rcICv5
channel (ChannelT): The channel the exchange is bound to (if bound).
durable (bool): Durable exchanges remain active when a server restarts.
Non-durable exchanges (transient exchanges) are purged when a
server restarts. Default is :const:`True`.
auto_delete (bool): If set, the exchange is deleted when all queues
have finished using it. Default is :const:`False`.
delivery_mode (enum): The default delivery mode used for messages.
The value is an integer, or alias string.
* 1 or `"transient"`
The message is transient. Which means it is stored in
memory only, and is lost if the server dies or restarts.
* 2 or "persistent" (*default*)
The message is persistent. Which means the message is
stored both in-memory, and on disk, and therefore
preserved if the server dies or restarts.
The default value is 2 (persistent).
arguments (Dict): Additional arguments to specify when the exchange
is declared.
no_declare (bool): Never declare this exchange
(:meth:`declare` does nothing).
"""
TRANSIENT_DELIVERY_MODE = TRANSIENT_DELIVERY_MODE
PERSISTENT_DELIVERY_MODE = PERSISTENT_DELIVERY_MODE
name = ''
type = 'direct'
durable = True
auto_delete = False
passive = False
delivery_mode = None
no_declare = False
attrs = (
('name', None),
('type', None),
('arguments', None),
('durable', bool),
('passive', bool),
('auto_delete', bool),
('delivery_mode', lambda m: DELIVERY_MODES.get(m) or m),
('no_declare', bool),
)
def __init__(self, name='', type='', channel=None, **kwargs):
super().__init__(**kwargs)
self.name = name or self.name
self.type = type or self.type
self.maybe_bind(channel)
def __hash__(self):
return hash(f'E|{self.name}')
def _can_declare(self):
return not self.no_declare and (
self.name and not self.name.startswith(
INTERNAL_EXCHANGE_PREFIX))
def declare(self, nowait=False, passive=None, channel=None):
"""Declare the exchange.
Creates the exchange on the broker, unless passive is set
in which case it will only assert that the exchange exists.
Argument:
nowait (bool): If set the server will not respond, and a
response will not be waited for. Default is :const:`False`.
"""
if self._can_declare():
passive = self.passive if passive is None else passive
return (channel or self.channel).exchange_declare(
exchange=self.name, type=self.type, durable=self.durable,
auto_delete=self.auto_delete, arguments=self.arguments,
nowait=nowait, passive=passive,
)
def bind_to(self, exchange='', routing_key='',
arguments=None, nowait=False, channel=None, **kwargs):
"""Bind the exchange to another exchange.
Arguments:
---------
nowait (bool): If set the server will not respond, and the call
will not block waiting for a response.
Default is :const:`False`.
"""
if isinstance(exchange, Exchange):
exchange = exchange.name
return (channel or self.channel).exchange_bind(
destination=self.name,
source=exchange,
routing_key=routing_key,
nowait=nowait,
arguments=arguments,
)
def unbind_from(self, source='', routing_key='',
nowait=False, arguments=None, channel=None):
"""Delete previously created exchange binding from the server."""
if isinstance(source, Exchange):
source = source.name
return (channel or self.channel).exchange_unbind(
destination=self.name,
source=source,
routing_key=routing_key,
nowait=nowait,
arguments=arguments,
)
def Message(self, body, delivery_mode=None, properties=None, **kwargs):
"""Create message instance to be sent with :meth:`publish`.
Arguments:
---------
body (Any): Message body.
delivery_mode (bool): Set custom delivery mode.
Defaults to :attr:`delivery_mode`.
priority (int): Message priority, 0 to broker configured
max priority, where higher is better.
content_type (str): The messages content_type. If content_type
is set, no serialization occurs as it is assumed this is either
a binary object, or you've done your own serialization.
Leave blank if using built-in serialization as our library
properly sets content_type.
content_encoding (str): The character set in which this object
is encoded. Use "binary" if sending in raw binary objects.
Leave blank if using built-in serialization as our library
properly sets content_encoding.
properties (Dict): Message properties.
headers (Dict): Message headers.
"""
properties = {} if properties is None else properties
properties['delivery_mode'] = maybe_delivery_mode(self.delivery_mode)
if (isinstance(body, str) and
properties.get('content_encoding', None)) is None:
kwargs['content_encoding'] = 'utf-8'
return self.channel.prepare_message(
body,
properties=properties,
**kwargs)
def publish(self, message, routing_key=None, mandatory=False,
immediate=False, exchange=None):
"""Publish message.
Arguments:
---------
message (Union[kombu.Message, str, bytes]):
Message to publish.
routing_key (str): Message routing key.
mandatory (bool): Currently not supported.
immediate (bool): Currently not supported.
"""
if isinstance(message, str):
message = self.Message(message)
exchange = exchange or self.name
return self.channel.basic_publish(
message,
exchange=exchange,
routing_key=routing_key,
mandatory=mandatory,
immediate=immediate,
)
def delete(self, if_unused=False, nowait=False):
"""Delete the exchange declaration on server.
Arguments:
---------
if_unused (bool): Delete only if the exchange has no bindings.
Default is :const:`False`.
nowait (bool): If set the server will not respond, and a
response will not be waited for. Default is :const:`False`.
"""
return self.channel.exchange_delete(exchange=self.name,
if_unused=if_unused,
nowait=nowait)
def binding(self, routing_key='', arguments=None, unbind_arguments=None):
return binding(self, routing_key, arguments, unbind_arguments)
def __eq__(self, other):
if isinstance(other, Exchange):
return (self.name == other.name and
self.type == other.type and
self.arguments == other.arguments and
self.durable == other.durable and
self.auto_delete == other.auto_delete and
self.delivery_mode == other.delivery_mode)
return NotImplemented
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return self._repr_entity(self)
def __str__(self):
return 'Exchange {}({})'.format(
_reprstr(self.name) or repr(''), self.type,
)
@property
def can_cache_declaration(self):
return not self.auto_delete
class binding(Object):
"""Represents a queue or exchange binding.
Arguments:
---------
exchange (Exchange): Exchange to bind to.
routing_key (str): Routing key used as binding key.
arguments (Dict): Arguments for bind operation.
unbind_arguments (Dict): Arguments for unbind operation.
"""
attrs = (
('exchange', None),
('routing_key', None),
('arguments', None),
('unbind_arguments', None)
)
def __init__(self, exchange=None, routing_key='',
arguments=None, unbind_arguments=None):
self.exchange = exchange
self.routing_key = routing_key
self.arguments = arguments
self.unbind_arguments = unbind_arguments
def declare(self, channel, nowait=False):
"""Declare destination exchange."""
if self.exchange and self.exchange.name:
self.exchange.declare(channel=channel, nowait=nowait)
def bind(self, entity, nowait=False, channel=None):
"""Bind entity to this binding."""
entity.bind_to(exchange=self.exchange,
routing_key=self.routing_key,
arguments=self.arguments,
nowait=nowait,
channel=channel)
def unbind(self, entity, nowait=False, channel=None):
"""Unbind entity from this binding."""
entity.unbind_from(self.exchange,
routing_key=self.routing_key,
arguments=self.unbind_arguments,
nowait=nowait,
channel=channel)
def __repr__(self):
return f'<binding: {self}>'
def __str__(self):
return '{}->{}'.format(
_reprstr(self.exchange.name), _reprstr(self.routing_key),
)
class Queue(MaybeChannelBound):
"""A Queue declaration.
Arguments:
---------
name (str): See :attr:`name`.
exchange (Exchange, str): See :attr:`exchange`.
routing_key (str): See :attr:`routing_key`.
channel (kombu.Connection, ChannelT): See :attr:`channel`.
durable (bool): See :attr:`durable`.
exclusive (bool): See :attr:`exclusive`.
auto_delete (bool): See :attr:`auto_delete`.
queue_arguments (Dict): See :attr:`queue_arguments`.
binding_arguments (Dict): See :attr:`binding_arguments`.
consumer_arguments (Dict): See :attr:`consumer_arguments`.
no_declare (bool): See :attr:`no_declare`.
on_declared (Callable): See :attr:`on_declared`.
expires (float): See :attr:`expires`.
message_ttl (float): See :attr:`message_ttl`.
max_length (int): See :attr:`max_length`.
max_length_bytes (int): See :attr:`max_length_bytes`.
max_priority (int): See :attr:`max_priority`.
Attributes
----------
name (str): Name of the queue.
Default is no name (default queue destination).
exchange (Exchange): The :class:`Exchange` the queue binds to.
routing_key (str): The routing key (if any), also called *binding key*.
The interpretation of the routing key depends on
the :attr:`Exchange.type`.
* direct exchange
Matches if the routing key property of the message and
the :attr:`routing_key` attribute are identical.
* fanout exchange
Always matches, even if the binding does not have a key.
* topic exchange
Matches the routing key property of the message by a primitive
pattern matching scheme. The message routing key then consists
of words separated by dots (`"."`, like domain names), and
two special characters are available; star (`"*"`) and hash
(`"#"`). The star matches any word, and the hash matches
zero or more words. For example `"*.stock.#"` matches the
routing keys `"usd.stock"` and `"eur.stock.db"` but not
`"stock.nasdaq"`.
channel (ChannelT): The channel the Queue is bound to (if bound).
durable (bool): Durable queues remain active when a server restarts.
Non-durable queues (transient queues) are purged if/when
a server restarts.
Note that durable queues do not necessarily hold persistent
messages, although it does not make sense to send
persistent messages to a transient queue.
Default is :const:`True`.
exclusive (bool): Exclusive queues may only be consumed from by the
current connection. Setting the 'exclusive' flag
always implies 'auto-delete'.
Default is :const:`False`.
auto_delete (bool): If set, the queue is deleted when all consumers
have finished using it. Last consumer can be canceled
either explicitly or because its channel is closed. If
there was no consumer ever on the queue, it won't be
deleted.
expires (float): Set the expiry time (in seconds) for when this
queue should expire.
The expiry time decides how long the queue can stay unused
before it's automatically deleted.
*Unused* means the queue has no consumers, the queue has not been
redeclared, and ``Queue.get`` has not been invoked for a duration
of at least the expiration period.
See https://www.rabbitmq.com/ttl.html#queue-ttl
**RabbitMQ extension**: Only available when using RabbitMQ.
message_ttl (float): Message time to live in seconds.
This setting controls how long messages can stay in the queue
unconsumed. If the expiry time passes before a message consumer
has received the message, the message is deleted and no consumer
will see the message.
See https://www.rabbitmq.com/ttl.html#per-queue-message-ttl
**RabbitMQ extension**: Only available when using RabbitMQ.
max_length (int): Set the maximum number of messages that the
queue can hold.
If the number of messages in the queue size exceeds this limit,
new messages will be dropped (or dead-lettered if a dead letter
exchange is active).
See https://www.rabbitmq.com/maxlength.html
**RabbitMQ extension**: Only available when using RabbitMQ.
max_length_bytes (int): Set the max size (in bytes) for the total
of messages in the queue.
If the total size of all the messages in the queue exceeds this
limit, new messages will be dropped (or dead-lettered if a dead
letter exchange is active).
**RabbitMQ extension**: Only available when using RabbitMQ.
max_priority (int): Set the highest priority number for this queue.
For example if the value is 10, then messages can delivered to
this queue can have a ``priority`` value between 0 and 10,
where 10 is the highest priority.
RabbitMQ queues without a max priority set will ignore
the priority field in the message, so if you want priorities
you need to set the max priority field to declare the queue
as a priority queue.
**RabbitMQ extension**: Only available when using RabbitMQ.
queue_arguments (Dict): Additional arguments used when declaring
the queue. Can be used to to set the arguments value
for RabbitMQ/AMQP's ``queue.declare``.
binding_arguments (Dict): Additional arguments used when binding
the queue. Can be used to to set the arguments value
for RabbitMQ/AMQP's ``queue.declare``.
consumer_arguments (Dict): Additional arguments used when consuming
from this queue. Can be used to to set the arguments value
for RabbitMQ/AMQP's ``basic.consume``.
alias (str): Unused in Kombu, but applications can take advantage
of this, for example to give alternate names to queues with
automatically generated queue names.
on_declared (Callable): Optional callback to be applied when the
queue has been declared (the ``queue_declare`` operation is
complete). This must be a function with a signature that
accepts at least 3 positional arguments:
``(name, messages, consumers)``.
no_declare (bool): Never declare this queue, nor related
entities (:meth:`declare` does nothing).
"""
ContentDisallowed = ContentDisallowed
name = ''
exchange = Exchange('')
routing_key = ''
durable = True
exclusive = False
auto_delete = False
no_ack = False
attrs = (
('name', None),
('exchange', None),
('routing_key', None),
('queue_arguments', None),
('binding_arguments', None),
('consumer_arguments', None),
('durable', bool),
('exclusive', bool),
('auto_delete', bool),
('no_ack', None),
('alias', None),
('bindings', list),
('no_declare', bool),
('expires', float),
('message_ttl', float),
('max_length', int),
('max_length_bytes', int),
('max_priority', int)
)
def __init__(self, name='', exchange=None, routing_key='',
channel=None, bindings=None, on_declared=None,
**kwargs):
super().__init__(**kwargs)
self.name = name or self.name
if isinstance(exchange, str):
self.exchange = Exchange(exchange)
elif isinstance(exchange, Exchange):
self.exchange = exchange
self.routing_key = routing_key or self.routing_key
self.bindings = set(bindings or [])
self.on_declared = on_declared
# allows Queue('name', [binding(...), binding(...), ...])
if isinstance(exchange, (list, tuple, set)):
self.bindings |= set(exchange)
if self.bindings:
self.exchange = None
# exclusive implies auto-delete.
if self.exclusive:
self.auto_delete = True
self.maybe_bind(channel)
def bind(self, channel):
on_declared = self.on_declared
bound = super().bind(channel)
bound.on_declared = on_declared
return bound
def __hash__(self):
return hash(f'Q|{self.name}')
def when_bound(self):
if self.exchange:
self.exchange = self.exchange(self.channel)
def declare(self, nowait=False, channel=None):
"""Declare queue and exchange then binds queue to exchange."""
if not self.no_declare:
# - declare main binding.
self._create_exchange(nowait=nowait, channel=channel)
self._create_queue(nowait=nowait, channel=channel)
self._create_bindings(nowait=nowait, channel=channel)
return self.name
def _create_exchange(self, nowait=False, channel=None):
if self.exchange:
self.exchange.declare(nowait=nowait, channel=channel)
def _create_queue(self, nowait=False, channel=None):
self.queue_declare(nowait=nowait, passive=False, channel=channel)
if self.exchange and self.exchange.name:
self.queue_bind(nowait=nowait, channel=channel)
def _create_bindings(self, nowait=False, channel=None):
for B in self.bindings:
channel = channel or self.channel
B.declare(channel)
B.bind(self, nowait=nowait, channel=channel)
def queue_declare(self, nowait=False, passive=False, channel=None):
"""Declare queue on the server.
Arguments:
---------
nowait (bool): Do not wait for a reply.
passive (bool): If set, the server will not create the queue.
The client can use this to check whether a queue exists
without modifying the server state.
"""
channel = channel or self.channel
queue_arguments = channel.prepare_queue_arguments(
self.queue_arguments or {},
expires=self.expires,
message_ttl=self.message_ttl,
max_length=self.max_length,
max_length_bytes=self.max_length_bytes,
max_priority=self.max_priority,
)
ret = channel.queue_declare(
queue=self.name,
passive=passive,
durable=self.durable,
exclusive=self.exclusive,
auto_delete=self.auto_delete,
arguments=queue_arguments,
nowait=nowait,
)
if not self.name:
self.name = ret[0]
if self.on_declared:
self.on_declared(*ret)
return ret
def queue_bind(self, nowait=False, channel=None):
"""Create the queue binding on the server."""
return self.bind_to(self.exchange, self.routing_key,
self.binding_arguments,
channel=channel, nowait=nowait)
def bind_to(self, exchange='', routing_key='',
arguments=None, nowait=False, channel=None):
if isinstance(exchange, Exchange):
exchange = exchange.name
return (channel or self.channel).queue_bind(
queue=self.name,
exchange=exchange,
routing_key=routing_key,
arguments=arguments,
nowait=nowait,
)
def get(self, no_ack=None, accept=None):
"""Poll the server for a new message.
This method provides direct access to the messages in a
queue using a synchronous dialogue, designed for
specific types of applications where synchronous functionality
is more important than performance.
Returns
-------
~kombu.Message: if a message was available,
or :const:`None` otherwise.
Arguments:
---------
no_ack (bool): If enabled the broker will
automatically ack messages.
accept (Set[str]): Custom list of accepted content types.
"""
no_ack = self.no_ack if no_ack is None else no_ack
message = self.channel.basic_get(queue=self.name, no_ack=no_ack)
if message is not None:
m2p = getattr(self.channel, 'message_to_python', None)
if m2p:
message = m2p(message)
if message.errors:
message._reraise_error()
message.accept = prepare_accept_content(accept)
return message
def purge(self, nowait=False):
"""Remove all ready messages from the queue."""
return self.channel.queue_purge(queue=self.name,
nowait=nowait) or 0
def consume(self, consumer_tag='', callback=None,
no_ack=None, nowait=False, on_cancel=None):
"""Start a queue consumer.
Consumers last as long as the channel they were created on, or
until the client cancels them.
Arguments:
---------
consumer_tag (str): Unique identifier for the consumer.
The consumer tag is local to a connection, so two clients
can use the same consumer tags. If this field is empty
the server will generate a unique tag.
no_ack (bool): If enabled the broker will automatically
ack messages.
nowait (bool): Do not wait for a reply.
callback (Callable): callback called for each delivered message.
on_cancel (Callable): callback called on cancel notify received
from broker.
"""
if no_ack is None:
no_ack = self.no_ack
return self.channel.basic_consume(
queue=self.name,
no_ack=no_ack,
consumer_tag=consumer_tag or '',
callback=callback,
nowait=nowait,
arguments=self.consumer_arguments,
on_cancel=on_cancel,
)
def cancel(self, consumer_tag):
"""Cancel a consumer by consumer tag."""
return self.channel.basic_cancel(consumer_tag)
def delete(self, if_unused=False, if_empty=False, nowait=False):
"""Delete the queue.
Arguments:
---------
if_unused (bool): If set, the server will only delete the queue
if it has no consumers. A channel error will be raised
if the queue has consumers.
if_empty (bool): If set, the server will only delete the queue if
it is empty. If it is not empty a channel error will be raised.
nowait (bool): Do not wait for a reply.
"""
return self.channel.queue_delete(queue=self.name,
if_unused=if_unused,
if_empty=if_empty,
nowait=nowait)
def queue_unbind(self, arguments=None, nowait=False, channel=None):
return self.unbind_from(self.exchange, self.routing_key,
arguments, nowait, channel)
def unbind_from(self, exchange='', routing_key='',
arguments=None, nowait=False, channel=None):
"""Unbind queue by deleting the binding from the server."""
return (channel or self.channel).queue_unbind(
queue=self.name,
exchange=exchange.name,
routing_key=routing_key,
arguments=arguments,
nowait=nowait,
)
def __eq__(self, other):
if isinstance(other, Queue):
return (self.name == other.name and
self.exchange == other.exchange and
self.routing_key == other.routing_key and
self.queue_arguments == other.queue_arguments and
self.binding_arguments == other.binding_arguments and
self.consumer_arguments == other.consumer_arguments and
self.durable == other.durable and
self.exclusive == other.exclusive and
self.auto_delete == other.auto_delete)
return NotImplemented
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
if self.bindings:
return self._repr_entity('Queue {name} -> {bindings}'.format(
name=_reprstr(self.name),
bindings=pretty_bindings(self.bindings),
))
return self._repr_entity(
'Queue {name} -> {0.exchange!r} -> {routing_key}'.format(
self, name=_reprstr(self.name),
routing_key=_reprstr(self.routing_key),
),
)
@property
def can_cache_declaration(self):
if self.queue_arguments:
expiring_queue = "x-expires" in self.queue_arguments
else:
expiring_queue = False
return not expiring_queue and not self.auto_delete
@classmethod
def from_dict(cls, queue, **options):
binding_key = options.get('binding_key') or options.get('routing_key')
e_durable = options.get('exchange_durable')
if e_durable is None:
e_durable = options.get('durable')
e_auto_delete = options.get('exchange_auto_delete')
if e_auto_delete is None:
e_auto_delete = options.get('auto_delete')
q_durable = options.get('queue_durable')
if q_durable is None:
q_durable = options.get('durable')
q_auto_delete = options.get('queue_auto_delete')
if q_auto_delete is None:
q_auto_delete = options.get('auto_delete')
e_arguments = options.get('exchange_arguments')
q_arguments = options.get('queue_arguments')
b_arguments = options.get('binding_arguments')
c_arguments = options.get('consumer_arguments')
bindings = options.get('bindings')
exchange = Exchange(options.get('exchange'),
type=options.get('exchange_type'),
delivery_mode=options.get('delivery_mode'),
routing_key=options.get('routing_key'),
durable=e_durable,
auto_delete=e_auto_delete,
arguments=e_arguments)
return Queue(queue,
exchange=exchange,
routing_key=binding_key,
durable=q_durable,
exclusive=options.get('exclusive'),
auto_delete=q_auto_delete,
no_ack=options.get('no_ack'),
queue_arguments=q_arguments,
binding_arguments=b_arguments,
consumer_arguments=c_arguments,
bindings=bindings)
def as_dict(self, recurse=False):
res = super().as_dict(recurse)
if not recurse:
return res
bindings = res.get('bindings')
if bindings:
res['bindings'] = [b.as_dict(recurse=True) for b in bindings]
return res

View File

@@ -0,0 +1,112 @@
"""Exceptions."""
from __future__ import annotations
from socket import timeout as TimeoutError
from types import TracebackType
from typing import TYPE_CHECKING, TypeVar
from amqp import ChannelError, ConnectionError, ResourceError
if TYPE_CHECKING:
from kombu.asynchronous.http import Response
__all__ = (
'reraise', 'KombuError', 'OperationalError',
'NotBoundError', 'MessageStateError', 'TimeoutError',
'LimitExceeded', 'ConnectionLimitExceeded',
'ChannelLimitExceeded', 'ConnectionError', 'ChannelError',
'VersionMismatch', 'SerializerNotInstalled', 'ResourceError',
'SerializationError', 'EncodeError', 'DecodeError', 'HttpError',
'InconsistencyError',
)
BaseExceptionType = TypeVar('BaseExceptionType', bound=BaseException)
def reraise(
tp: type[BaseExceptionType],
value: BaseExceptionType,
tb: TracebackType | None = None
) -> BaseExceptionType:
"""Reraise exception."""
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
class KombuError(Exception):
"""Common subclass for all Kombu exceptions."""
class OperationalError(KombuError):
"""Recoverable message transport connection error."""
class SerializationError(KombuError):
"""Failed to serialize/deserialize content."""
class EncodeError(SerializationError):
"""Cannot encode object."""
class DecodeError(SerializationError):
"""Cannot decode object."""
class NotBoundError(KombuError):
"""Trying to call channel dependent method on unbound entity."""
class MessageStateError(KombuError):
"""The message has already been acknowledged."""
class LimitExceeded(KombuError):
"""Limit exceeded."""
class ConnectionLimitExceeded(LimitExceeded):
"""Maximum number of simultaneous connections exceeded."""
class ChannelLimitExceeded(LimitExceeded):
"""Maximum number of simultaneous channels exceeded."""
class VersionMismatch(KombuError):
"""Library dependency version mismatch."""
class SerializerNotInstalled(KombuError):
"""Support for the requested serialization type is not installed."""
class ContentDisallowed(SerializerNotInstalled):
"""Consumer does not allow this content-type."""
class InconsistencyError(ConnectionError):
"""Data or environment has been found to be inconsistent.
Depending on the cause it may be possible to retry the operation.
"""
class HttpError(Exception):
"""HTTP Client Error."""
def __init__(
self,
code: int,
message: str | None = None,
response: Response | None = None
) -> None:
self.code = code
self.message = message
self.response = response
super().__init__(code, message, response)
def __str__(self) -> str:
return 'HTTP {0.code}: {0.message}'.format(self)

View File

@@ -0,0 +1,143 @@
"""Logging Utilities."""
from __future__ import annotations
import logging
import numbers
import os
import sys
from logging.handlers import WatchedFileHandler
from typing import TYPE_CHECKING
from .utils.encoding import safe_repr, safe_str
from .utils.functional import maybe_evaluate
from .utils.objects import cached_property
if TYPE_CHECKING:
from logging import Logger
__all__ = ('LogMixin', 'LOG_LEVELS', 'get_loglevel', 'setup_logging')
LOG_LEVELS = dict(logging._nameToLevel)
LOG_LEVELS.update(logging._levelToName)
LOG_LEVELS.setdefault('FATAL', logging.FATAL)
LOG_LEVELS.setdefault(logging.FATAL, 'FATAL')
DISABLE_TRACEBACKS = os.environ.get('DISABLE_TRACEBACKS')
def get_logger(logger: str | Logger):
"""Get logger by name."""
if isinstance(logger, str):
logger = logging.getLogger(logger)
if not logger.handlers:
logger.addHandler(logging.NullHandler())
return logger
def get_loglevel(level):
"""Get loglevel by name."""
if isinstance(level, str):
return LOG_LEVELS[level]
return level
def naive_format_parts(fmt):
parts = fmt.split('%')
for i, e in enumerate(parts[1:]):
yield None if not e or not parts[i - 1] else e[0]
def safeify_format(fmt, args, filters=None):
filters = {'s': safe_str, 'r': safe_repr} if not filters else filters
for index, type in enumerate(naive_format_parts(fmt)):
filt = filters.get(type)
yield filt(args[index]) if filt else args[index]
class LogMixin:
"""Mixin that adds severity methods to any class."""
def debug(self, *args, **kwargs):
return self.log(logging.DEBUG, *args, **kwargs)
def info(self, *args, **kwargs):
return self.log(logging.INFO, *args, **kwargs)
def warn(self, *args, **kwargs):
return self.log(logging.WARN, *args, **kwargs)
def error(self, *args, **kwargs):
kwargs.setdefault('exc_info', True)
return self.log(logging.ERROR, *args, **kwargs)
def critical(self, *args, **kwargs):
kwargs.setdefault('exc_info', True)
return self.log(logging.CRITICAL, *args, **kwargs)
def annotate(self, text):
return f'{self.logger_name} - {text}'
def log(self, severity, *args, **kwargs):
if DISABLE_TRACEBACKS:
kwargs.pop('exc_info', None)
if self.logger.isEnabledFor(severity):
log = self.logger.log
if len(args) > 1 and isinstance(args[0], str):
expand = [maybe_evaluate(arg) for arg in args[1:]]
return log(severity,
self.annotate(args[0].replace('%r', '%s')),
*list(safeify_format(args[0], expand)), **kwargs)
else:
return self.logger.log(
severity, self.annotate(' '.join(map(safe_str, args))),
**kwargs)
def get_logger(self):
return get_logger(self.logger_name)
def is_enabled_for(self, level):
return self.logger.isEnabledFor(self.get_loglevel(level))
def get_loglevel(self, level):
if not isinstance(level, numbers.Integral):
return LOG_LEVELS[level]
return level
@cached_property
def logger(self):
return self.get_logger()
@property
def logger_name(self):
return self.__class__.__name__
class Log(LogMixin):
def __init__(self, name, logger=None):
self._logger_name = name
self._logger = logger
def get_logger(self):
if self._logger:
return self._logger
return super().get_logger()
@property
def logger_name(self):
return self._logger_name
def setup_logging(loglevel=None, logfile=None):
"""Setup logging."""
logger = logging.getLogger()
loglevel = get_loglevel(loglevel or 'ERROR')
logfile = logfile if logfile else sys.__stderr__
if not logger.handlers:
if hasattr(logfile, 'write'):
handler = logging.StreamHandler(logfile)
else:
handler = WatchedFileHandler(logfile)
logger.addHandler(handler)
logger.setLevel(loglevel)
return logger

View File

@@ -0,0 +1,144 @@
"""Pattern matching registry."""
from __future__ import annotations
from fnmatch import fnmatch
from re import match as rematch
from typing import Callable, cast
from .utils.compat import entrypoints
from .utils.encoding import bytes_to_str
MatcherFunction = Callable[[str, str], bool]
class MatcherNotInstalled(Exception):
"""Matcher not installed/found."""
class MatcherRegistry:
"""Pattern matching function registry."""
MatcherNotInstalled = MatcherNotInstalled
matcher_pattern_first = ["pcre", ]
def __init__(self) -> None:
self._matchers: dict[str, MatcherFunction] = {}
self._default_matcher: MatcherFunction | None = None
def register(self, name: str, matcher: MatcherFunction) -> None:
"""Add matcher by name to the registry."""
self._matchers[name] = matcher
def unregister(self, name: str) -> None:
"""Remove matcher by name from the registry."""
try:
self._matchers.pop(name)
except KeyError:
raise self.MatcherNotInstalled(
f'No matcher installed for {name}'
)
def _set_default_matcher(self, name: str) -> None:
"""Set the default matching method.
:param name: The name of the registered matching method.
For example, `glob` (default), `pcre`, or any custom
methods registered using :meth:`register`.
:raises MatcherNotInstalled: If the matching method requested
is not available.
"""
try:
self._default_matcher = self._matchers[name]
except KeyError:
raise self.MatcherNotInstalled(
f'No matcher installed for {name}'
)
def match(
self,
data: bytes,
pattern: bytes,
matcher: str | None = None,
matcher_kwargs: dict[str, str] | None = None
) -> bool:
"""Call the matcher."""
if matcher and not self._matchers.get(matcher):
raise self.MatcherNotInstalled(
f'No matcher installed for {matcher}'
)
match_func = self._matchers[matcher or 'glob']
if matcher in self.matcher_pattern_first:
first_arg = bytes_to_str(pattern)
second_arg = bytes_to_str(data)
else:
first_arg = bytes_to_str(data)
second_arg = bytes_to_str(pattern)
return match_func(first_arg, second_arg, **matcher_kwargs or {})
#: Global registry of matchers.
registry = MatcherRegistry()
"""
.. function:: match(data, pattern, matcher=default_matcher,
matcher_kwargs=None):
Match `data` by `pattern` using `matcher`.
:param data: The data that should be matched. Must be string.
:param pattern: The pattern that should be applied. Must be string.
:keyword matcher: An optional string representing the matching
method (for example, `glob` or `pcre`).
If :const:`None` (default), then `glob` will be used.
:keyword matcher_kwargs: Additional keyword arguments that will be passed
to the specified `matcher`.
:returns: :const:`True` if `data` matches pattern,
:const:`False` otherwise.
:raises MatcherNotInstalled: If the matching method requested is not
available.
"""
match = registry.match
"""
.. function:: register(name, matcher):
Register a new matching method.
:param name: A convenient name for the matching method.
:param matcher: A method that will be passed data and pattern.
"""
register = registry.register
"""
.. function:: unregister(name):
Unregister registered matching method.
:param name: Registered matching method name.
"""
unregister = registry.unregister
def register_glob() -> None:
"""Register glob into default registry."""
registry.register('glob', fnmatch)
def register_pcre() -> None:
"""Register pcre into default registry."""
registry.register('pcre', cast(MatcherFunction, rematch))
# Register the base matching methods.
register_glob()
register_pcre()
# Default matching method is 'glob'
registry._set_default_matcher('glob')
# Load entrypoints from installed extensions
for ep, args in entrypoints('kombu.matchers'):
register(ep.name, *args)

View File

@@ -0,0 +1,234 @@
"""Message class."""
from __future__ import annotations
import sys
from .compression import decompress
from .exceptions import MessageStateError, reraise
from .serialization import loads
from .utils.functional import dictfilter
__all__ = ('Message',)
ACK_STATES = {'ACK', 'REJECTED', 'REQUEUED'}
IS_PYPY = hasattr(sys, 'pypy_version_info')
class Message:
"""Base class for received messages.
Keyword Arguments:
-----------------
channel (ChannelT): If message was received, this should be the
channel that the message was received on.
body (str): Message body.
delivery_mode (bool): Set custom delivery mode.
Defaults to :attr:`delivery_mode`.
priority (int): Message priority, 0 to broker configured
max priority, where higher is better.
content_type (str): The messages content_type. If content_type
is set, no serialization occurs as it is assumed this is either
a binary object, or you've done your own serialization.
Leave blank if using built-in serialization as our library
properly sets content_type.
content_encoding (str): The character set in which this object
is encoded. Use "binary" if sending in raw binary objects.
Leave blank if using built-in serialization as our library
properly sets content_encoding.
properties (Dict): Message properties.
headers (Dict): Message headers.
"""
MessageStateError = MessageStateError
errors = None
if not IS_PYPY: # pragma: no cover
__slots__ = (
'_state', 'channel', 'delivery_tag',
'content_type', 'content_encoding',
'delivery_info', 'headers', 'properties',
'body', '_decoded_cache', 'accept', '__dict__',
)
def __init__(self, body=None, delivery_tag=None,
content_type=None, content_encoding=None, delivery_info=None,
properties=None, headers=None, postencode=None,
accept=None, channel=None, **kwargs):
delivery_info = {} if not delivery_info else delivery_info
self.errors = [] if self.errors is None else self.errors
self.channel = channel
self.delivery_tag = delivery_tag
self.content_type = content_type
self.content_encoding = content_encoding
self.delivery_info = delivery_info
self.headers = headers or {}
self.properties = properties or {}
self._decoded_cache = None
self._state = 'RECEIVED'
self.accept = accept
compression = self.headers.get('compression')
if not self.errors and compression:
try:
body = decompress(body, compression)
except Exception:
self.errors.append(sys.exc_info())
if not self.errors and postencode and isinstance(body, str):
try:
body = body.encode(postencode)
except Exception:
self.errors.append(sys.exc_info())
self.body = body
def _reraise_error(self, callback=None):
try:
reraise(*self.errors[0])
except Exception as exc:
if not callback:
raise
callback(self, exc)
def ack(self, multiple=False):
"""Acknowledge this message as being processed.
This will remove the message from the queue.
Raises
------
MessageStateError: If the message has already been
acknowledged/requeued/rejected.
"""
if self.channel is None:
raise self.MessageStateError(
'This message does not have a receiving channel')
if self.channel.no_ack_consumers is not None:
try:
consumer_tag = self.delivery_info['consumer_tag']
except KeyError:
pass
else:
if consumer_tag in self.channel.no_ack_consumers:
return
if self.acknowledged:
raise self.MessageStateError(
'Message already acknowledged with state: {0._state}'.format(
self))
self.channel.basic_ack(self.delivery_tag, multiple=multiple)
self._state = 'ACK'
def ack_log_error(self, logger, errors, multiple=False):
try:
self.ack(multiple=multiple)
except BrokenPipeError as exc:
logger.critical("Couldn't ack %r, reason:%r",
self.delivery_tag, exc, exc_info=True)
raise
except errors as exc:
logger.critical("Couldn't ack %r, reason:%r",
self.delivery_tag, exc, exc_info=True)
def reject_log_error(self, logger, errors, requeue=False):
try:
self.reject(requeue=requeue)
except errors as exc:
logger.critical("Couldn't reject %r, reason: %r",
self.delivery_tag, exc, exc_info=True)
def reject(self, requeue=False):
"""Reject this message.
The message will be discarded by the server.
Raises
------
MessageStateError: If the message has already been
acknowledged/requeued/rejected.
"""
if self.channel is None:
raise self.MessageStateError(
'This message does not have a receiving channel')
if self.acknowledged:
raise self.MessageStateError(
'Message already acknowledged with state: {0._state}'.format(
self))
self.channel.basic_reject(self.delivery_tag, requeue=requeue)
self._state = 'REJECTED'
def requeue(self):
"""Reject this message and put it back on the queue.
Warning:
-------
You must not use this method as a means of selecting messages
to process.
Raises
------
MessageStateError: If the message has already been
acknowledged/requeued/rejected.
"""
if self.channel is None:
raise self.MessageStateError(
'This message does not have a receiving channel')
if self.acknowledged:
raise self.MessageStateError(
'Message already acknowledged with state: {0._state}'.format(
self))
self.channel.basic_reject(self.delivery_tag, requeue=True)
self._state = 'REQUEUED'
def decode(self):
"""Deserialize the message body.
Returning the original python structure sent by the publisher.
Note:
----
The return value is memoized, use `_decode` to force
re-evaluation.
"""
if not self._decoded_cache:
self._decoded_cache = self._decode()
return self._decoded_cache
def _decode(self):
return loads(self.body, self.content_type,
self.content_encoding, accept=self.accept)
@property
def acknowledged(self):
"""Set to true if the message has been acknowledged."""
return self._state in ACK_STATES
@property
def payload(self):
"""The decoded message body."""
return self._decoded_cache if self._decoded_cache else self.decode()
def __repr__(self):
return '<{} object at {:#x} with details {!r}>'.format(
type(self).__name__, id(self), dictfilter(
state=self._state,
content_type=self.content_type,
delivery_tag=self.delivery_tag,
body_length=len(self.body) if self.body is not None else None,
properties=dictfilter(
correlation_id=self.properties.get('correlation_id'),
type=self.properties.get('type'),
),
delivery_info=dictfilter(
exchange=self.delivery_info.get('exchange'),
routing_key=self.delivery_info.get('routing_key'),
),
),
)

View File

@@ -0,0 +1,678 @@
"""Sending and receiving messages."""
from __future__ import annotations
from itertools import count
from typing import TYPE_CHECKING
from .common import maybe_declare
from .compression import compress
from .connection import PooledConnection, is_connection, maybe_channel
from .entity import Exchange, Queue, maybe_delivery_mode
from .exceptions import ContentDisallowed
from .serialization import dumps, prepare_accept_content
from .utils.functional import ChannelPromise, maybe_list
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('Exchange', 'Queue', 'Producer', 'Consumer')
class Producer:
"""Message Producer.
Arguments:
---------
channel (kombu.Connection, ChannelT): Connection or channel.
exchange (kombu.entity.Exchange, str): Optional default exchange.
routing_key (str): Optional default routing key.
serializer (str): Default serializer. Default is `"json"`.
compression (str): Default compression method.
Default is no compression.
auto_declare (bool): Automatically declare the default exchange
at instantiation. Default is :const:`True`.
on_return (Callable): Callback to call for undeliverable messages,
when the `mandatory` or `immediate` arguments to
:meth:`publish` is used. This callback needs the following
signature: `(exception, exchange, routing_key, message)`.
Note that the producer needs to drain events to use this feature.
"""
#: Default exchange
exchange = None
#: Default routing key.
routing_key = ''
#: Default serializer to use. Default is JSON.
serializer = None
#: Default compression method. Disabled by default.
compression = None
#: By default, if a default exchange is set,
#: that exchange will be declare when publishing a message.
auto_declare = True
#: Basic return callback.
on_return = None
#: Set if channel argument was a Connection instance (using
#: default_channel).
__connection__ = None
def __init__(self, channel, exchange=None, routing_key=None,
serializer=None, auto_declare=None, compression=None,
on_return=None):
self._channel = channel
self.exchange = exchange
self.routing_key = routing_key or self.routing_key
self.serializer = serializer or self.serializer
self.compression = compression or self.compression
self.on_return = on_return or self.on_return
self._channel_promise = None
if self.exchange is None:
self.exchange = Exchange('')
if auto_declare is not None:
self.auto_declare = auto_declare
if self._channel:
self.revive(self._channel)
def __repr__(self):
return f'<Producer: {self._channel}>'
def __reduce__(self):
return self.__class__, self.__reduce_args__()
def __reduce_args__(self):
return (None, self.exchange, self.routing_key, self.serializer,
self.auto_declare, self.compression)
def declare(self):
"""Declare the exchange.
Note:
----
This happens automatically at instantiation when
the :attr:`auto_declare` flag is enabled.
"""
if self.exchange.name:
self.exchange.declare()
def maybe_declare(self, entity, retry=False, **retry_policy):
"""Declare exchange if not already declared during this session."""
if entity:
return maybe_declare(entity, self.channel, retry, **retry_policy)
def _delivery_details(self, exchange, delivery_mode=None,
maybe_delivery_mode=maybe_delivery_mode,
Exchange=Exchange):
if isinstance(exchange, Exchange):
return exchange.name, maybe_delivery_mode(
delivery_mode or exchange.delivery_mode,
)
# exchange is string, so inherit the delivery
# mode of our default exchange.
return exchange, maybe_delivery_mode(
delivery_mode or self.exchange.delivery_mode,
)
def publish(self, body, routing_key=None, delivery_mode=None,
mandatory=False, immediate=False, priority=0,
content_type=None, content_encoding=None, serializer=None,
headers=None, compression=None, exchange=None, retry=False,
retry_policy=None, declare=None, expiration=None, timeout=None,
confirm_timeout=None,
**properties):
"""Publish message to the specified exchange.
Arguments:
---------
body (Any): Message body.
routing_key (str): Message routing key.
delivery_mode (enum): See :attr:`delivery_mode`.
mandatory (bool): Currently not supported.
immediate (bool): Currently not supported.
priority (int): Message priority. A number between 0 and 9.
content_type (str): Content type. Default is auto-detect.
content_encoding (str): Content encoding. Default is auto-detect.
serializer (str): Serializer to use. Default is auto-detect.
compression (str): Compression method to use. Default is none.
headers (Dict): Mapping of arbitrary headers to pass along
with the message body.
exchange (kombu.entity.Exchange, str): Override the exchange.
Note that this exchange must have been declared.
declare (Sequence[EntityT]): Optional list of required entities
that must have been declared before publishing the message.
The entities will be declared using
:func:`~kombu.common.maybe_declare`.
retry (bool): Retry publishing, or declaring entities if the
connection is lost.
retry_policy (Dict): Retry configuration, this is the keywords
supported by :meth:`~kombu.Connection.ensure`.
expiration (float): A TTL in seconds can be specified per message.
Default is no expiration.
timeout (float): Set timeout to wait maximum timeout second
for message to publish.
confirm_timeout (float): Set confirm timeout to wait maximum timeout second
for message to confirm publishing if the channel is set to confirm publish mode.
**properties (Any): Additional message properties, see AMQP spec.
"""
_publish = self._publish
declare = [] if declare is None else declare
headers = {} if headers is None else headers
retry_policy = {} if retry_policy is None else retry_policy
routing_key = self.routing_key if routing_key is None else routing_key
compression = self.compression if compression is None else compression
exchange_name, properties['delivery_mode'] = self._delivery_details(
exchange or self.exchange, delivery_mode,
)
if expiration is not None:
properties['expiration'] = str(int(expiration * 1000))
body, content_type, content_encoding = self._prepare(
body, serializer, content_type, content_encoding,
compression, headers)
if self.auto_declare and self.exchange.name:
if self.exchange not in declare:
# XXX declare should be a Set.
declare.append(self.exchange)
if retry:
self.connection.transport_options.update(retry_policy)
_publish = self.connection.ensure(self, _publish, **retry_policy)
return _publish(
body, priority, content_type, content_encoding,
headers, properties, routing_key, mandatory, immediate,
exchange_name, declare, timeout, confirm_timeout, retry, retry_policy
)
def _publish(self, body, priority, content_type, content_encoding,
headers, properties, routing_key, mandatory,
immediate, exchange, declare, timeout=None, confirm_timeout=None, retry=False, retry_policy=None):
retry_policy = {} if retry_policy is None else retry_policy
channel = self.channel
message = channel.prepare_message(
body, priority, content_type,
content_encoding, headers, properties,
)
if declare:
maybe_declare = self.maybe_declare
for entity in declare:
maybe_declare(entity, retry=retry, **retry_policy)
# handle autogenerated queue names for reply_to
reply_to = properties.get('reply_to')
if isinstance(reply_to, Queue):
properties['reply_to'] = reply_to.name
return channel.basic_publish(
message,
exchange=exchange, routing_key=routing_key,
mandatory=mandatory, immediate=immediate,
timeout=timeout, confirm_timeout=confirm_timeout
)
def _get_channel(self):
channel = self._channel
if isinstance(channel, ChannelPromise):
channel = self._channel = channel()
self.exchange.revive(channel)
if self.on_return:
channel.events['basic_return'].add(self.on_return)
return channel
def _set_channel(self, channel):
self._channel = channel
channel = property(_get_channel, _set_channel)
def revive(self, channel):
"""Revive the producer after connection loss."""
if is_connection(channel):
connection = channel
self.__connection__ = connection
channel = ChannelPromise(lambda: connection.default_channel)
if isinstance(channel, ChannelPromise):
self._channel = channel
self.exchange = self.exchange(channel)
else:
# Channel already concrete
self._channel = channel
if self.on_return:
self._channel.events['basic_return'].add(self.on_return)
self.exchange = self.exchange(channel)
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
# In case the connection is part of a pool it needs to be
# replaced in case of an exception
if self.__connection__ is not None and exc_type is not None:
if isinstance(self.__connection__, PooledConnection):
self.__connection__._pool.replace(self.__connection__)
self.release()
def release(self):
pass
close = release
def _prepare(self, body, serializer=None, content_type=None,
content_encoding=None, compression=None, headers=None):
# No content_type? Then we're serializing the data internally.
if not content_type:
serializer = serializer or self.serializer
(content_type, content_encoding,
body) = dumps(body, serializer=serializer)
else:
# If the programmer doesn't want us to serialize,
# make sure content_encoding is set.
if isinstance(body, str):
if not content_encoding:
content_encoding = 'utf-8'
body = body.encode(content_encoding)
# If they passed in a string, we can't know anything
# about it. So assume it's binary data.
elif not content_encoding:
content_encoding = 'binary'
if compression:
body, headers['compression'] = compress(body, compression)
return body, content_type, content_encoding
@property
def connection(self):
try:
return self.__connection__ or self.channel.connection.client
except AttributeError:
pass
class Consumer:
"""Message consumer.
Arguments:
---------
channel (kombu.Connection, ChannelT): see :attr:`channel`.
queues (Sequence[kombu.Queue]): see :attr:`queues`.
no_ack (bool): see :attr:`no_ack`.
auto_declare (bool): see :attr:`auto_declare`
callbacks (Sequence[Callable]): see :attr:`callbacks`.
on_message (Callable): See :attr:`on_message`
on_decode_error (Callable): see :attr:`on_decode_error`.
prefetch_count (int): see :attr:`prefetch_count`.
"""
ContentDisallowed = ContentDisallowed
#: The connection/channel to use for this consumer.
channel = None
#: A single :class:`~kombu.Queue`, or a list of queues to
#: consume from.
queues = None
#: Flag for automatic message acknowledgment.
#: If enabled the messages are automatically acknowledged by the
#: broker. This can increase performance but means that you
#: have no control of when the message is removed.
#:
#: Disabled by default.
no_ack = None
#: By default all entities will be declared at instantiation, if you
#: want to handle this manually you can set this to :const:`False`.
auto_declare = True
#: List of callbacks called in order when a message is received.
#:
#: The signature of the callbacks must take two arguments:
#: `(body, message)`, which is the decoded message body and
#: the :class:`~kombu.Message` instance.
callbacks = None
#: Optional function called whenever a message is received.
#:
#: When defined this function will be called instead of the
#: :meth:`receive` method, and :attr:`callbacks` will be disabled.
#:
#: So this can be used as an alternative to :attr:`callbacks` when
#: you don't want the body to be automatically decoded.
#: Note that the message will still be decompressed if the message
#: has the ``compression`` header set.
#:
#: The signature of the callback must take a single argument,
#: which is the :class:`~kombu.Message` object.
#:
#: Also note that the ``message.body`` attribute, which is the raw
#: contents of the message body, may in some cases be a read-only
#: :class:`buffer` object.
on_message = None
#: Callback called when a message can't be decoded.
#:
#: The signature of the callback must take two arguments: `(message,
#: exc)`, which is the message that can't be decoded and the exception
#: that occurred while trying to decode it.
on_decode_error = None
#: List of accepted content-types.
#:
#: An exception will be raised if the consumer receives
#: a message with an untrusted content type.
#: By default all content-types are accepted, but not if
#: :func:`kombu.disable_untrusted_serializers` was called,
#: in which case only json is allowed.
accept = None
#: Initial prefetch count
#:
#: If set, the consumer will set the prefetch_count QoS value at startup.
#: Can also be changed using :meth:`qos`.
prefetch_count = None
#: Mapping of queues we consume from.
_queues = None
_tags = count(1) # global
def __init__(self, channel, queues=None, no_ack=None, auto_declare=None,
callbacks=None, on_decode_error=None, on_message=None,
accept=None, prefetch_count=None, tag_prefix=None):
self.channel = channel
self.queues = maybe_list(queues or [])
self.no_ack = self.no_ack if no_ack is None else no_ack
self.callbacks = (self.callbacks or [] if callbacks is None
else callbacks)
self.on_message = on_message
self.tag_prefix = tag_prefix
self._active_tags = {}
if auto_declare is not None:
self.auto_declare = auto_declare
if on_decode_error is not None:
self.on_decode_error = on_decode_error
self.accept = prepare_accept_content(accept)
self.prefetch_count = prefetch_count
if self.channel:
self.revive(self.channel)
@property
def queues(self): # noqa
return list(self._queues.values())
@queues.setter
def queues(self, queues):
self._queues = {q.name: q for q in queues}
def revive(self, channel):
"""Revive consumer after connection loss."""
self._active_tags.clear()
channel = self.channel = maybe_channel(channel)
# modify dict size while iterating over it is not allowed
for qname, queue in list(self._queues.items()):
# name may have changed after declare
self._queues.pop(qname, None)
queue = self._queues[queue.name] = queue(self.channel)
queue.revive(channel)
if self.auto_declare:
self.declare()
if self.prefetch_count is not None:
self.qos(prefetch_count=self.prefetch_count)
def declare(self):
"""Declare queues, exchanges and bindings.
Note:
----
This is done automatically at instantiation
when :attr:`auto_declare` is set.
"""
for queue in self._queues.values():
queue.declare()
def register_callback(self, callback):
"""Register a new callback to be called when a message is received.
Note:
----
The signature of the callback needs to accept two arguments:
`(body, message)`, which is the decoded message body
and the :class:`~kombu.Message` instance.
"""
self.callbacks.append(callback)
def __enter__(self):
self.consume()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
if self.channel and self.channel.connection:
conn_errors = self.channel.connection.client.connection_errors
if not isinstance(exc_val, conn_errors):
try:
self.cancel()
except Exception:
pass
def add_queue(self, queue):
"""Add a queue to the list of queues to consume from.
Note:
----
This will not start consuming from the queue,
for that you will have to call :meth:`consume` after.
"""
queue = queue(self.channel)
if self.auto_declare:
queue.declare()
self._queues[queue.name] = queue
return queue
def consume(self, no_ack=None):
"""Start consuming messages.
Can be called multiple times, but note that while it
will consume from new queues added since the last call,
it will not cancel consuming from removed queues (
use :meth:`cancel_by_queue`).
Arguments:
---------
no_ack (bool): See :attr:`no_ack`.
"""
queues = list(self._queues.values())
if queues:
no_ack = self.no_ack if no_ack is None else no_ack
H, T = queues[:-1], queues[-1]
for queue in H:
self._basic_consume(queue, no_ack=no_ack, nowait=True)
self._basic_consume(T, no_ack=no_ack, nowait=False)
def cancel(self):
"""End all active queue consumers.
Note:
----
This does not affect already delivered messages, but it does
mean the server will not send any more messages for this consumer.
"""
cancel = self.channel.basic_cancel
for tag in self._active_tags.values():
cancel(tag)
self._active_tags.clear()
close = cancel
def cancel_by_queue(self, queue):
"""Cancel consumer by queue name."""
qname = queue.name if isinstance(queue, Queue) else queue
try:
tag = self._active_tags.pop(qname)
except KeyError:
pass
else:
self.channel.basic_cancel(tag)
finally:
self._queues.pop(qname, None)
def consuming_from(self, queue):
"""Return :const:`True` if currently consuming from queue'."""
name = queue
if isinstance(queue, Queue):
name = queue.name
return name in self._active_tags
def purge(self):
"""Purge messages from all queues.
Warning:
-------
This will *delete all ready messages*, there is no undo operation.
"""
return sum(queue.purge() for queue in self._queues.values())
def flow(self, active):
"""Enable/disable flow from peer.
This is a simple flow-control mechanism that a peer can use
to avoid overflowing its queues or otherwise finding itself
receiving more messages than it can process.
The peer that receives a request to stop sending content
will finish sending the current content (if any), and then wait
until flow is reactivated.
"""
self.channel.flow(active)
def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
"""Specify quality of service.
The client can request that messages should be sent in
advance so that when the client finishes processing a message,
the following message is already held locally, rather than needing
to be sent down the channel. Prefetching gives a performance
improvement.
The prefetch window is Ignored if the :attr:`no_ack` option is set.
Arguments:
---------
prefetch_size (int): Specify the prefetch window in octets.
The server will send a message in advance if it is equal to
or smaller in size than the available prefetch size (and
also falls within other prefetch limits). May be set to zero,
meaning "no specific limit", although other prefetch limits
may still apply.
prefetch_count (int): Specify the prefetch window in terms of
whole messages.
apply_global (bool): Apply new settings globally on all channels.
"""
return self.channel.basic_qos(prefetch_size,
prefetch_count,
apply_global)
def recover(self, requeue=False):
"""Redeliver unacknowledged messages.
Asks the broker to redeliver all unacknowledged messages
on the specified channel.
Arguments:
---------
requeue (bool): By default the messages will be redelivered
to the original recipient. With `requeue` set to true, the
server will attempt to requeue the message, potentially then
delivering it to an alternative subscriber.
"""
return self.channel.basic_recover(requeue=requeue)
def receive(self, body, message):
"""Method called when a message is received.
This dispatches to the registered :attr:`callbacks`.
Arguments:
---------
body (Any): The decoded message body.
message (~kombu.Message): The message instance.
Raises
------
NotImplementedError: If no consumer callbacks have been
registered.
"""
callbacks = self.callbacks
if not callbacks:
raise NotImplementedError('Consumer does not have any callbacks')
[callback(body, message) for callback in callbacks]
def _basic_consume(self, queue, consumer_tag=None,
no_ack=no_ack, nowait=True):
tag = self._active_tags.get(queue.name)
if tag is None:
tag = self._add_tag(queue, consumer_tag)
queue.consume(tag, self._receive_callback,
no_ack=no_ack, nowait=nowait)
return tag
def _add_tag(self, queue, consumer_tag=None):
tag = consumer_tag or '{}{}'.format(
self.tag_prefix, next(self._tags))
self._active_tags[queue.name] = tag
return tag
def _receive_callback(self, message):
accept = self.accept
on_m, channel, decoded = self.on_message, self.channel, None
try:
m2p = getattr(channel, 'message_to_python', None)
if m2p:
message = m2p(message)
if accept is not None:
message.accept = accept
if message.errors:
return message._reraise_error(self.on_decode_error)
decoded = None if on_m else message.decode()
except Exception as exc:
if not self.on_decode_error:
raise
self.on_decode_error(message, exc)
else:
return on_m(message) if on_m else self.receive(decoded, message)
def __repr__(self):
return f'<{type(self).__name__}: {self.queues}>'
@property
def connection(self):
try:
return self.channel.connection.client
except AttributeError:
pass

View File

@@ -0,0 +1,303 @@
"""Mixins."""
from __future__ import annotations
import socket
from contextlib import contextmanager
from functools import partial
from itertools import count
from time import sleep
from .common import ignore_errors
from .log import get_logger
from .messaging import Consumer, Producer
from .utils.compat import nested
from .utils.encoding import safe_repr
from .utils.limits import TokenBucket
from .utils.objects import cached_property
__all__ = ('ConsumerMixin', 'ConsumerProducerMixin')
logger = get_logger(__name__)
debug, info, warn, error = (
logger.debug,
logger.info,
logger.warning,
logger.error
)
W_CONN_LOST = """\
Connection to broker lost, trying to re-establish connection...\
"""
W_CONN_ERROR = """\
Broker connection error, trying again in %s seconds: %r.\
"""
class ConsumerMixin:
"""Convenience mixin for implementing consumer programs.
It can be used outside of threads, with threads, or greenthreads
(eventlet/gevent) too.
The basic class would need a :attr:`connection` attribute
which must be a :class:`~kombu.Connection` instance,
and define a :meth:`get_consumers` method that returns a list
of :class:`kombu.Consumer` instances to use.
Supporting multiple consumers is important so that multiple
channels can be used for different QoS requirements.
Example:
-------
.. code-block:: python
class Worker(ConsumerMixin):
task_queue = Queue('tasks', Exchange('tasks'), 'tasks')
def __init__(self, connection):
self.connection = None
def get_consumers(self, Consumer, channel):
return [Consumer(queues=[self.task_queue],
callbacks=[self.on_task])]
def on_task(self, body, message):
print('Got task: {0!r}'.format(body))
message.ack()
Methods
-------
* :meth:`extra_context`
Optional extra context manager that will be entered
after the connection and consumers have been set up.
Takes arguments ``(connection, channel)``.
* :meth:`on_connection_error`
Handler called if the connection is lost/ or
is unavailable.
Takes arguments ``(exc, interval)``, where interval
is the time in seconds when the connection will be retried.
The default handler will log the exception.
* :meth:`on_connection_revived`
Handler called as soon as the connection is re-established
after connection failure.
Takes no arguments.
* :meth:`on_consume_ready`
Handler called when the consumer is ready to accept
messages.
Takes arguments ``(connection, channel, consumers)``.
Also keyword arguments to ``consume`` are forwarded
to this handler.
* :meth:`on_consume_end`
Handler called after the consumers are canceled.
Takes arguments ``(connection, channel)``.
* :meth:`on_iteration`
Handler called for every iteration while draining
events.
Takes no arguments.
* :meth:`on_decode_error`
Handler called if a consumer was unable to decode
the body of a message.
Takes arguments ``(message, exc)`` where message is the
original message object.
The default handler will log the error and
acknowledge the message, so if you override make
sure to call super, or perform these steps yourself.
"""
#: maximum number of retries trying to re-establish the connection,
#: if the connection is lost/unavailable.
connect_max_retries = None
#: When this is set to true the consumer should stop consuming
#: and return, so that it can be joined if it is the implementation
#: of a thread.
should_stop = False
def get_consumers(self, Consumer, channel):
raise NotImplementedError('Subclass responsibility')
def on_connection_revived(self):
pass
def on_consume_ready(self, connection, channel, consumers, **kwargs):
pass
def on_consume_end(self, connection, channel):
pass
def on_iteration(self):
pass
def on_decode_error(self, message, exc):
error("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
exc, message.content_type, message.content_encoding,
safe_repr(message.body))
message.ack()
def on_connection_error(self, exc, interval):
warn(W_CONN_ERROR, interval, exc, exc_info=1)
@contextmanager
def extra_context(self, connection, channel):
yield
def run(self, _tokens=1, **kwargs):
restart_limit = self.restart_limit
errors = (self.connection.connection_errors +
self.connection.channel_errors)
while not self.should_stop:
try:
if restart_limit.can_consume(_tokens): # pragma: no cover
for _ in self.consume(limit=None, **kwargs):
pass
else:
sleep(restart_limit.expected_time(_tokens))
except errors:
warn(W_CONN_LOST, exc_info=1)
@contextmanager
def consumer_context(self, **kwargs):
with self.Consumer() as (connection, channel, consumers):
with self.extra_context(connection, channel):
self.on_consume_ready(connection, channel, consumers, **kwargs)
yield connection, channel, consumers
def consume(self, limit=None, timeout=None, safety_interval=1, **kwargs):
elapsed = 0
with self.consumer_context(**kwargs) as (conn, channel, consumers):
for i in limit and range(limit) or count():
if self.should_stop:
break
self.on_iteration()
try:
conn.drain_events(timeout=safety_interval)
except socket.timeout:
conn.heartbeat_check()
elapsed += safety_interval
if timeout and elapsed >= timeout:
raise
except OSError:
if not self.should_stop:
raise
else:
yield
elapsed = 0
debug('consume exiting')
def maybe_conn_error(self, fun):
"""Use :func:`kombu.common.ignore_errors` instead."""
return ignore_errors(self, fun)
def create_connection(self):
return self.connection.clone()
@contextmanager
def establish_connection(self):
with self.create_connection() as conn:
conn.ensure_connection(self.on_connection_error,
self.connect_max_retries)
yield conn
@contextmanager
def Consumer(self):
with self.establish_connection() as conn:
self.on_connection_revived()
info('Connected to %s', conn.as_uri())
channel = conn.default_channel
cls = partial(Consumer, channel,
on_decode_error=self.on_decode_error)
with self._consume_from(*self.get_consumers(cls, channel)) as c:
yield conn, channel, c
debug('Consumers canceled')
self.on_consume_end(conn, channel)
debug('Connection closed')
def _consume_from(self, *consumers):
return nested(*consumers)
@cached_property
def restart_limit(self):
return TokenBucket(1)
@cached_property
def connection_errors(self):
return self.connection.connection_errors
@cached_property
def channel_errors(self):
return self.connection.channel_errors
class ConsumerProducerMixin(ConsumerMixin):
"""Consumer and Producer mixin.
Version of ConsumerMixin having separate connection for also
publishing messages.
Example:
-------
.. code-block:: python
class Worker(ConsumerProducerMixin):
def __init__(self, connection):
self.connection = connection
def get_consumers(self, Consumer, channel):
return [Consumer(queues=Queue('foo'),
on_message=self.handle_message,
accept='application/json',
prefetch_count=10)]
def handle_message(self, message):
self.producer.publish(
{'message': 'hello to you'},
exchange='',
routing_key=message.properties['reply_to'],
correlation_id=message.properties['correlation_id'],
retry=True,
)
"""
_producer_connection = None
def on_consume_end(self, connection, channel):
if self._producer_connection is not None:
self._producer_connection.close()
self._producer_connection = None
@property
def producer(self):
return Producer(self.producer_connection)
@property
def producer_connection(self):
if self._producer_connection is None:
conn = self.connection.clone()
conn.ensure_connection(self.on_connection_error,
self.connect_max_retries)
self._producer_connection = conn
return self._producer_connection

View File

@@ -0,0 +1,413 @@
"""Generic process mailbox."""
from __future__ import annotations
import socket
import warnings
from collections import defaultdict, deque
from contextlib import contextmanager
from copy import copy
from itertools import count
from time import time
from . import Consumer, Exchange, Producer, Queue
from .clocks import LamportClock
from .common import maybe_declare, oid_from
from .exceptions import InconsistencyError
from .log import get_logger
from .matcher import match
from .utils.functional import maybe_evaluate, reprcall
from .utils.objects import cached_property
from .utils.uuid import uuid
REPLY_QUEUE_EXPIRES = 10
W_PIDBOX_IN_USE = """\
A node named {node.hostname} is already using this process mailbox!
Maybe you forgot to shutdown the other node or did not do so properly?
Or if you meant to start multiple nodes on the same host please make sure
you give each node a unique node name!
"""
__all__ = ('Node', 'Mailbox')
logger = get_logger(__name__)
debug, error = logger.debug, logger.error
class Node:
"""Mailbox node."""
#: hostname of the node.
hostname = None
#: the :class:`Mailbox` this is a node for.
mailbox = None
#: map of method name/handlers.
handlers = None
#: current context (passed on to handlers)
state = None
#: current channel.
channel = None
def __init__(self, hostname, state=None, channel=None,
handlers=None, mailbox=None):
self.channel = channel
self.mailbox = mailbox
self.hostname = hostname
self.state = state
self.adjust_clock = self.mailbox.clock.adjust
if handlers is None:
handlers = {}
self.handlers = handlers
def Consumer(self, channel=None, no_ack=True, accept=None, **options):
queue = self.mailbox.get_queue(self.hostname)
def verify_exclusive(name, messages, consumers):
if consumers:
warnings.warn(W_PIDBOX_IN_USE.format(node=self))
queue.on_declared = verify_exclusive
return Consumer(
channel or self.channel, [queue], no_ack=no_ack,
accept=self.mailbox.accept if accept is None else accept,
**options
)
def handler(self, fun):
self.handlers[fun.__name__] = fun
return fun
def on_decode_error(self, message, exc):
error('Cannot decode message: %r', exc, exc_info=1)
def listen(self, channel=None, callback=None):
consumer = self.Consumer(channel=channel,
callbacks=[callback or self.handle_message],
on_decode_error=self.on_decode_error)
consumer.consume()
return consumer
def dispatch(self, method, arguments=None,
reply_to=None, ticket=None, **kwargs):
arguments = arguments or {}
debug('pidbox received method %s [reply_to:%s ticket:%s]',
reprcall(method, (), kwargs=arguments), reply_to, ticket)
handle = reply_to and self.handle_call or self.handle_cast
try:
reply = handle(method, arguments)
except SystemExit:
raise
except Exception as exc:
error('pidbox command error: %r', exc, exc_info=1)
reply = {'error': repr(exc)}
if reply_to:
self.reply({self.hostname: reply},
exchange=reply_to['exchange'],
routing_key=reply_to['routing_key'],
ticket=ticket)
return reply
def handle(self, method, arguments=None):
arguments = {} if not arguments else arguments
return self.handlers[method](self.state, **arguments)
def handle_call(self, method, arguments):
return self.handle(method, arguments)
def handle_cast(self, method, arguments):
return self.handle(method, arguments)
def handle_message(self, body, message=None):
destination = body.get('destination')
pattern = body.get('pattern')
matcher = body.get('matcher')
if message:
self.adjust_clock(message.headers.get('clock') or 0)
hostname = self.hostname
run_dispatch = False
if destination:
if hostname in destination:
run_dispatch = True
elif pattern and matcher:
if match(hostname, pattern, matcher):
run_dispatch = True
else:
run_dispatch = True
if run_dispatch:
return self.dispatch(**body)
dispatch_from_message = handle_message
def reply(self, data, exchange, routing_key, ticket, **kwargs):
self.mailbox._publish_reply(data, exchange, routing_key, ticket,
channel=self.channel,
serializer=self.mailbox.serializer)
class Mailbox:
"""Process Mailbox."""
node_cls = Node
exchange_fmt = '%s.pidbox'
reply_exchange_fmt = 'reply.%s.pidbox'
#: Name of application.
namespace = None
#: Connection (if bound).
connection = None
#: Exchange type (usually direct, or fanout for broadcast).
type = 'direct'
#: mailbox exchange (init by constructor).
exchange = None
#: exchange to send replies to.
reply_exchange = None
#: Only accepts json messages by default.
accept = ['json']
#: Message serializer
serializer = None
def __init__(self, namespace,
type='direct', connection=None, clock=None,
accept=None, serializer=None, producer_pool=None,
queue_ttl=None, queue_expires=None,
reply_queue_ttl=None, reply_queue_expires=10.0):
self.namespace = namespace
self.connection = connection
self.type = type
self.clock = LamportClock() if clock is None else clock
self.exchange = self._get_exchange(self.namespace, self.type)
self.reply_exchange = self._get_reply_exchange(self.namespace)
self.unclaimed = defaultdict(deque)
self.accept = self.accept if accept is None else accept
self.serializer = self.serializer if serializer is None else serializer
self.queue_ttl = queue_ttl
self.queue_expires = queue_expires
self.reply_queue_ttl = reply_queue_ttl
self.reply_queue_expires = reply_queue_expires
self._producer_pool = producer_pool
def __call__(self, connection):
bound = copy(self)
bound.connection = connection
return bound
def Node(self, hostname=None, state=None, channel=None, handlers=None):
hostname = hostname or socket.gethostname()
return self.node_cls(hostname, state, channel, handlers, mailbox=self)
def call(self, destination, command, kwargs=None,
timeout=None, callback=None, channel=None):
kwargs = {} if not kwargs else kwargs
return self._broadcast(command, kwargs, destination,
reply=True, timeout=timeout,
callback=callback,
channel=channel)
def cast(self, destination, command, kwargs=None):
kwargs = {} if not kwargs else kwargs
return self._broadcast(command, kwargs, destination, reply=False)
def abcast(self, command, kwargs=None):
kwargs = {} if not kwargs else kwargs
return self._broadcast(command, kwargs, reply=False)
def multi_call(self, command, kwargs=None, timeout=1,
limit=None, callback=None, channel=None):
kwargs = {} if not kwargs else kwargs
return self._broadcast(command, kwargs, reply=True,
timeout=timeout, limit=limit,
callback=callback,
channel=channel)
def get_reply_queue(self):
oid = self.oid
return Queue(
f'{oid}.{self.reply_exchange.name}',
exchange=self.reply_exchange,
routing_key=oid,
durable=False,
auto_delete=True,
expires=self.reply_queue_expires,
message_ttl=self.reply_queue_ttl,
)
@cached_property
def reply_queue(self):
return self.get_reply_queue()
def get_queue(self, hostname):
return Queue(
f'{hostname}.{self.namespace}.pidbox',
exchange=self.exchange,
durable=False,
auto_delete=True,
expires=self.queue_expires,
message_ttl=self.queue_ttl,
)
@contextmanager
def producer_or_acquire(self, producer=None, channel=None):
if producer:
yield producer
elif self.producer_pool:
with self.producer_pool.acquire() as producer:
yield producer
else:
yield Producer(channel, auto_declare=False)
def _publish_reply(self, reply, exchange, routing_key, ticket,
channel=None, producer=None, **opts):
chan = channel or self.connection.default_channel
exchange = Exchange(exchange, exchange_type='direct',
delivery_mode='transient',
durable=False)
with self.producer_or_acquire(producer, chan) as producer:
try:
producer.publish(
reply, exchange=exchange, routing_key=routing_key,
declare=[exchange], headers={
'ticket': ticket, 'clock': self.clock.forward(),
}, retry=True,
**opts
)
except InconsistencyError:
# queue probably deleted and no one is expecting a reply.
pass
def _publish(self, type, arguments, destination=None,
reply_ticket=None, channel=None, timeout=None,
serializer=None, producer=None, pattern=None, matcher=None):
message = {'method': type,
'arguments': arguments,
'destination': destination,
'pattern': pattern,
'matcher': matcher}
chan = channel or self.connection.default_channel
exchange = self.exchange
if reply_ticket:
maybe_declare(self.reply_queue(chan))
message.update(ticket=reply_ticket,
reply_to={'exchange': self.reply_exchange.name,
'routing_key': self.oid})
serializer = serializer or self.serializer
with self.producer_or_acquire(producer, chan) as producer:
producer.publish(
message, exchange=exchange.name, declare=[exchange],
headers={'clock': self.clock.forward(),
'expires': time() + timeout if timeout else 0},
serializer=serializer, retry=True,
)
def _broadcast(self, command, arguments=None, destination=None,
reply=False, timeout=1, limit=None,
callback=None, channel=None, serializer=None,
pattern=None, matcher=None):
if destination is not None and \
not isinstance(destination, (list, tuple)):
raise ValueError(
'destination must be a list/tuple not {}'.format(
type(destination)))
if (pattern is not None and not isinstance(pattern, str) and
matcher is not None and not isinstance(matcher, str)):
raise ValueError(
'pattern and matcher must be '
'strings not {}, {}'.format(type(pattern), type(matcher))
)
arguments = arguments or {}
reply_ticket = reply and uuid() or None
chan = channel or self.connection.default_channel
# Set reply limit to number of destinations (if specified)
if limit is None and destination:
limit = destination and len(destination) or None
serializer = serializer or self.serializer
self._publish(command, arguments, destination=destination,
reply_ticket=reply_ticket,
channel=chan,
timeout=timeout,
serializer=serializer,
pattern=pattern,
matcher=matcher)
if reply_ticket:
return self._collect(reply_ticket, limit=limit,
timeout=timeout,
callback=callback,
channel=chan)
def _collect(self, ticket,
limit=None, timeout=1, callback=None,
channel=None, accept=None):
if accept is None:
accept = self.accept
chan = channel or self.connection.default_channel
queue = self.reply_queue
consumer = Consumer(chan, [queue], accept=accept, no_ack=True)
responses = []
unclaimed = self.unclaimed
adjust_clock = self.clock.adjust
try:
return unclaimed.pop(ticket)
except KeyError:
pass
def on_message(body, message):
# ticket header added in kombu 2.5
header = message.headers.get
adjust_clock(header('clock') or 0)
expires = header('expires')
if expires and time() > expires:
return
this_id = header('ticket', ticket)
if this_id == ticket:
if callback:
callback(body)
responses.append(body)
else:
unclaimed[this_id].append(body)
consumer.register_callback(on_message)
try:
with consumer:
for i in limit and range(limit) or count():
try:
self.connection.drain_events(timeout=timeout)
except socket.timeout:
break
return responses
finally:
chan.after_reply_message_received(queue.name)
def _get_exchange(self, namespace, type):
return Exchange(self.exchange_fmt % namespace,
type=type,
durable=False,
delivery_mode='transient')
def _get_reply_exchange(self, namespace):
return Exchange(self.reply_exchange_fmt % namespace,
type='direct',
durable=False,
delivery_mode='transient')
@property
def oid(self):
return oid_from(self)
@cached_property
def producer_pool(self):
return maybe_evaluate(self._producer_pool)

View File

@@ -0,0 +1,152 @@
"""Public resource pools."""
from __future__ import annotations
import os
from itertools import chain
from .connection import Resource
from .messaging import Producer
from .utils.collections import EqualityDict
from .utils.compat import register_after_fork
from .utils.functional import lazy
__all__ = ('ProducerPool', 'PoolGroup', 'register_group',
'connections', 'producers', 'get_limit', 'set_limit', 'reset')
_limit = [10]
_groups = []
use_global_limit = object()
disable_limit_protection = os.environ.get('KOMBU_DISABLE_LIMIT_PROTECTION')
def _after_fork_cleanup_group(group):
group.clear()
class ProducerPool(Resource):
"""Pool of :class:`kombu.Producer` instances."""
Producer = Producer
close_after_fork = True
def __init__(self, connections, *args, **kwargs):
self.connections = connections
self.Producer = kwargs.pop('Producer', None) or self.Producer
super().__init__(*args, **kwargs)
def _acquire_connection(self):
return self.connections.acquire(block=True)
def create_producer(self):
conn = self._acquire_connection()
try:
return self.Producer(conn)
except BaseException:
conn.release()
raise
def new(self):
return lazy(self.create_producer)
def setup(self):
if self.limit:
for _ in range(self.limit):
self._resource.put_nowait(self.new())
def close_resource(self, resource):
pass
def prepare(self, p):
if callable(p):
p = p()
if p._channel is None:
conn = self._acquire_connection()
try:
p.revive(conn)
except BaseException:
conn.release()
raise
return p
def release(self, resource):
if resource.__connection__:
resource.__connection__.release()
resource.channel = None
super().release(resource)
class PoolGroup(EqualityDict):
"""Collection of resource pools."""
def __init__(self, limit=None, close_after_fork=True):
self.limit = limit
self.close_after_fork = close_after_fork
if self.close_after_fork and register_after_fork is not None:
register_after_fork(self, _after_fork_cleanup_group)
def create(self, resource, limit):
raise NotImplementedError('PoolGroups must define ``create``')
def __missing__(self, resource):
limit = self.limit
if limit is use_global_limit:
limit = get_limit()
k = self[resource] = self.create(resource, limit)
return k
def register_group(group):
"""Register group (can be used as decorator)."""
_groups.append(group)
return group
class Connections(PoolGroup):
"""Collection of connection pools."""
def create(self, connection, limit):
return connection.Pool(limit=limit)
connections = register_group(Connections(limit=use_global_limit))
class Producers(PoolGroup):
"""Collection of producer pools."""
def create(self, connection, limit):
return ProducerPool(connections[connection], limit=limit)
producers = register_group(Producers(limit=use_global_limit))
def _all_pools():
return chain(*((g.values() if g else iter([])) for g in _groups))
def get_limit():
"""Get current connection pool limit."""
return _limit[0]
def set_limit(limit, force=False, reset_after=False, ignore_errors=False):
"""Set new connection pool limit."""
limit = limit or 0
glimit = _limit[0] or 0
if limit != glimit:
_limit[0] = limit
for pool in _all_pools():
pool.resize(limit)
return limit
def reset(*args, **kwargs):
"""Reset all pools by closing open resources."""
for pool in _all_pools():
try:
pool.force_close_all()
except Exception:
pass
for group in _groups:
group.clear()

View File

@@ -0,0 +1,258 @@
"""Generic resource pool implementation."""
from __future__ import annotations
import os
from collections import deque
from queue import Empty
from queue import LifoQueue as _LifoQueue
from typing import TYPE_CHECKING
from . import exceptions
from .utils.compat import register_after_fork
from .utils.functional import lazy
if TYPE_CHECKING:
from types import TracebackType
def _after_fork_cleanup_resource(resource):
try:
resource.force_close_all()
except Exception:
pass
class LifoQueue(_LifoQueue):
"""Last in first out version of Queue."""
def _init(self, maxsize):
self.queue = deque()
class Resource:
"""Pool of resources."""
LimitExceeded = exceptions.LimitExceeded
close_after_fork = False
def __init__(self, limit=None, preload=None, close_after_fork=None):
self._limit = limit
self.preload = preload or 0
self._closed = False
self.close_after_fork = (
close_after_fork
if close_after_fork is not None else self.close_after_fork
)
self._resource = LifoQueue()
self._dirty = set()
if self.close_after_fork and register_after_fork is not None:
register_after_fork(self, _after_fork_cleanup_resource)
self.setup()
def setup(self):
raise NotImplementedError('subclass responsibility')
def _add_when_empty(self):
if self.limit and len(self._dirty) >= self.limit:
raise self.LimitExceeded(self.limit)
# All taken, put new on the queue and
# try get again, this way the first in line
# will get the resource.
self._resource.put_nowait(self.new())
def acquire(self, block=False, timeout=None):
"""Acquire resource.
Arguments:
---------
block (bool): If the limit is exceeded,
then block until there is an available item.
timeout (float): Timeout to wait
if ``block`` is true. Default is :const:`None` (forever).
Raises
------
LimitExceeded: if block is false and the limit has been exceeded.
"""
if self._closed:
raise RuntimeError('Acquire on closed pool')
if self.limit:
while 1:
try:
R = self._resource.get(block=block, timeout=timeout)
except Empty:
self._add_when_empty()
else:
try:
R = self.prepare(R)
except BaseException:
if isinstance(R, lazy):
# not evaluated yet, just put it back
self._resource.put_nowait(R)
else:
# evaluated so must try to release/close first.
self.release(R)
raise
self._dirty.add(R)
break
else:
R = self.prepare(self.new())
def release():
"""Release resource so it can be used by another thread.
Warnings:
--------
The caller is responsible for discarding the object,
and to never use the resource again. A new resource must
be acquired if so needed.
"""
self.release(R)
R.release = release
return R
def prepare(self, resource):
return resource
def close_resource(self, resource):
resource.close()
def release_resource(self, resource):
pass
def replace(self, resource):
"""Replace existing resource with a new instance.
This can be used in case of defective resources.
"""
if self.limit:
self._dirty.discard(resource)
self.close_resource(resource)
def release(self, resource):
if self.limit:
self._dirty.discard(resource)
self._resource.put_nowait(resource)
self.release_resource(resource)
else:
self.close_resource(resource)
def collect_resource(self, resource):
pass
def force_close_all(self, close_pool=True):
"""Close and remove all resources in the pool (also those in use).
Used to close resources from parent processes after fork
(e.g. sockets/connections).
Arguments:
---------
close_pool (bool): If True (default) then the pool is marked
as closed. In case of False the pool can be reused.
"""
if self._closed:
return
self._closed = close_pool
dirty = self._dirty
resource = self._resource
while 1: # - acquired
try:
dres = dirty.pop()
except KeyError:
break
try:
self.collect_resource(dres)
except AttributeError: # Issue #78
pass
while 1: # - available
# deque supports '.clear', but lists do not, so for that
# reason we use pop here, so that the underlying object can
# be any object supporting '.pop' and '.append'.
try:
res = resource.queue.pop()
except IndexError:
break
try:
self.collect_resource(res)
except AttributeError:
pass # Issue #78
def resize(self, limit, force=False, ignore_errors=False, reset=False):
prev_limit = self._limit
if (self._dirty and 0 < limit < self._limit) and not ignore_errors:
if not force:
raise RuntimeError(
"Can't shrink pool when in use: was={} now={}".format(
self._limit, limit))
reset = True
self._limit = limit
if reset:
try:
self.force_close_all(close_pool=False)
except Exception:
pass
self.setup()
if limit < prev_limit:
self._shrink_down(collect=limit > 0)
def _shrink_down(self, collect=True):
class Noop:
def __enter__(self):
pass
def __exit__(
self,
exc_type: type,
exc_val: Exception,
exc_tb: TracebackType
) -> None:
pass
resource = self._resource
# Items to the left are last recently used, so we remove those first.
with getattr(resource, 'mutex', Noop()):
# keep in mind the dirty resources are not shrinking
while len(resource.queue) and \
(len(resource.queue) + len(self._dirty)) > self.limit:
R = resource.queue.popleft()
if collect:
self.collect_resource(R)
@property
def limit(self):
return self._limit
@limit.setter
def limit(self, limit):
self.resize(limit)
if os.environ.get('KOMBU_DEBUG_POOL'): # pragma: no cover
_orig_acquire = acquire
_orig_release = release
_next_resource_id = 0
def acquire(self, *args, **kwargs):
import traceback
id = self._next_resource_id = self._next_resource_id + 1
print(f'+{id} ACQUIRE {self.__class__.__name__}')
r = self._orig_acquire(*args, **kwargs)
r._resource_id = id
print(f'-{id} ACQUIRE {self.__class__.__name__}')
if not hasattr(r, 'acquired_by'):
r.acquired_by = []
r.acquired_by.append(traceback.format_stack())
return r
def release(self, resource):
id = resource._resource_id
print(f'+{id} RELEASE {self.__class__.__name__}')
r = self._orig_release(resource)
print(f'-{id} RELEASE {self.__class__.__name__}')
self._next_resource_id -= 1
return r

View File

@@ -0,0 +1,463 @@
"""Serialization utilities."""
from __future__ import annotations
import codecs
import os
import pickle
import sys
from collections import namedtuple
from contextlib import contextmanager
from io import BytesIO
from .exceptions import (ContentDisallowed, DecodeError, EncodeError,
SerializerNotInstalled, reraise)
from .utils.compat import entrypoints
from .utils.encoding import bytes_to_str, str_to_bytes
__all__ = ('pickle', 'loads', 'dumps', 'register', 'unregister')
SKIP_DECODE = frozenset(['binary', 'ascii-8bit'])
TRUSTED_CONTENT = frozenset(['application/data', 'application/text'])
if sys.platform.startswith('java'): # pragma: no cover
def _decode(t, coding):
return codecs.getdecoder(coding)(t)[0]
else:
_decode = codecs.decode
pickle_load = pickle.load
#: We have to use protocol 4 until we drop support for Python 3.6 and 3.7.
pickle_protocol = int(os.environ.get('PICKLE_PROTOCOL', 4))
codec = namedtuple('codec', ('content_type', 'content_encoding', 'encoder'))
@contextmanager
def _reraise_errors(wrapper,
include=(Exception,), exclude=(SerializerNotInstalled,)):
try:
yield
except exclude:
raise
except include as exc:
reraise(wrapper, wrapper(exc), sys.exc_info()[2])
def pickle_loads(s, load=pickle_load):
# used to support buffer objects
return load(BytesIO(s))
def parenthesize_alias(first, second):
return f'{first} ({second})' if first else second
class SerializerRegistry:
"""The registry keeps track of serialization methods."""
def __init__(self):
self._encoders = {}
self._decoders = {}
self._default_encode = None
self._default_content_type = None
self._default_content_encoding = None
self._disabled_content_types = set()
self.type_to_name = {}
self.name_to_type = {}
def register(self, name, encoder, decoder, content_type,
content_encoding='utf-8'):
"""Register a new encoder/decoder.
Arguments:
---------
name (str): A convenience name for the serialization method.
encoder (callable): A method that will be passed a python data
structure and should return a string representing the
serialized data. If :const:`None`, then only a decoder
will be registered. Encoding will not be possible.
decoder (Callable): A method that will be passed a string
representing serialized data and should return a python
data structure. If :const:`None`, then only an encoder
will be registered. Decoding will not be possible.
content_type (str): The mime-type describing the serialized
structure.
content_encoding (str): The content encoding (character set) that
the `decoder` method will be returning. Will usually be
`utf-8`, `us-ascii`, or `binary`.
"""
if encoder:
self._encoders[name] = codec(
content_type, content_encoding, encoder,
)
if decoder:
self._decoders[content_type] = decoder
self.type_to_name[content_type] = name
self.name_to_type[name] = content_type
def enable(self, name):
if '/' not in name:
name = self.name_to_type[name]
self._disabled_content_types.discard(name)
def disable(self, name):
if '/' not in name:
name = self.name_to_type[name]
self._disabled_content_types.add(name)
def unregister(self, name):
"""Unregister registered encoder/decoder.
Arguments:
---------
name (str): Registered serialization method name.
Raises
------
SerializerNotInstalled: If a serializer by that name
cannot be found.
"""
try:
content_type = self.name_to_type[name]
self._decoders.pop(content_type, None)
self._encoders.pop(name, None)
self.type_to_name.pop(content_type, None)
self.name_to_type.pop(name, None)
except KeyError:
raise SerializerNotInstalled(
f'No encoder/decoder installed for {name}')
def _set_default_serializer(self, name):
"""Set the default serialization method used by this library.
Arguments:
---------
name (str): The name of the registered serialization method.
For example, `json` (default), `pickle`, `yaml`, `msgpack`,
or any custom methods registered using :meth:`register`.
Raises
------
SerializerNotInstalled: If the serialization method
requested is not available.
"""
try:
(self._default_content_type, self._default_content_encoding,
self._default_encode) = self._encoders[name]
except KeyError:
raise SerializerNotInstalled(
f'No encoder installed for {name}')
def dumps(self, data, serializer=None):
"""Encode data.
Serialize a data structure into a string suitable for sending
as an AMQP message body.
Arguments:
---------
data (List, Dict, str): The message data to send.
serializer (str): An optional string representing
the serialization method you want the data marshalled
into. (For example, `json`, `raw`, or `pickle`).
If :const:`None` (default), then json will be used, unless
`data` is a :class:`str` or :class:`unicode` object. In this
latter case, no serialization occurs as it would be
unnecessary.
Note that if `serializer` is specified, then that
serialization method will be used even if a :class:`str`
or :class:`unicode` object is passed in.
Returns
-------
Tuple[str, str, str]: A three-item tuple containing the
content type (e.g., `application/json`), content encoding, (e.g.,
`utf-8`) and a string containing the serialized data.
Raises
------
SerializerNotInstalled: If the serialization method
requested is not available.
"""
if serializer == 'raw':
return raw_encode(data)
if serializer and not self._encoders.get(serializer):
raise SerializerNotInstalled(
f'No encoder installed for {serializer}')
# If a raw string was sent, assume binary encoding
# (it's likely either ASCII or a raw binary file, and a character
# set of 'binary' will encompass both, even if not ideal.
if not serializer and isinstance(data, bytes):
# In Python 3+, this would be "bytes"; allow binary data to be
# sent as a message without getting encoder errors
return 'application/data', 'binary', data
# For Unicode objects, force it into a string
if not serializer and isinstance(data, str):
with _reraise_errors(EncodeError, exclude=()):
payload = data.encode('utf-8')
return 'text/plain', 'utf-8', payload
if serializer:
content_type, content_encoding, encoder = \
self._encoders[serializer]
else:
encoder = self._default_encode
content_type = self._default_content_type
content_encoding = self._default_content_encoding
with _reraise_errors(EncodeError):
payload = encoder(data)
return content_type, content_encoding, payload
def loads(self, data, content_type, content_encoding,
accept=None, force=False, _trusted_content=TRUSTED_CONTENT):
"""Decode serialized data.
Deserialize a data stream as serialized using `dumps`
based on `content_type`.
Arguments:
---------
data (bytes, buffer, str): The message data to deserialize.
content_type (str): The content-type of the data.
(e.g., `application/json`).
content_encoding (str): The content-encoding of the data.
(e.g., `utf-8`, `binary`, or `us-ascii`).
accept (Set): List of content-types to accept.
Raises
------
ContentDisallowed: If the content-type is not accepted.
Returns
-------
Any: The unserialized data.
"""
content_type = (bytes_to_str(content_type) if content_type
else 'application/data')
if accept is not None:
if content_type not in _trusted_content \
and content_type not in accept:
raise self._for_untrusted_content(content_type, 'untrusted')
else:
if content_type in self._disabled_content_types and not force:
raise self._for_untrusted_content(content_type, 'disabled')
content_encoding = (content_encoding or 'utf-8').lower()
if data:
decode = self._decoders.get(content_type)
if decode:
with _reraise_errors(DecodeError):
return decode(data)
if content_encoding not in SKIP_DECODE and \
not isinstance(data, str):
with _reraise_errors(DecodeError):
return _decode(data, content_encoding)
return data
def _for_untrusted_content(self, ctype, why):
return ContentDisallowed(
'Refusing to deserialize {} content of type {}'.format(
why,
parenthesize_alias(self.type_to_name.get(ctype, ctype), ctype),
),
)
#: Global registry of serializers/deserializers.
registry = SerializerRegistry()
dumps = registry.dumps
loads = registry.loads
register = registry.register
unregister = registry.unregister
def raw_encode(data):
"""Special case serializer."""
content_type = 'application/data'
payload = data
if isinstance(payload, str):
content_encoding = 'utf-8'
with _reraise_errors(EncodeError, exclude=()):
payload = payload.encode(content_encoding)
else:
content_encoding = 'binary'
return content_type, content_encoding, payload
def register_json():
"""Register a encoder/decoder for JSON serialization."""
from kombu.utils import json as _json
registry.register('json', _json.dumps, _json.loads,
content_type='application/json',
content_encoding='utf-8')
def register_yaml():
"""Register a encoder/decoder for YAML serialization.
It is slower than JSON, but allows for more data types
to be serialized. Useful if you need to send data such as dates
"""
try:
import yaml
registry.register('yaml', yaml.safe_dump, yaml.safe_load,
content_type='application/x-yaml',
content_encoding='utf-8')
except ImportError:
def not_available(*args, **kwargs):
"""Raise SerializerNotInstalled.
Used in case a client receives a yaml message, but yaml
isn't installed.
"""
raise SerializerNotInstalled(
'No decoder installed for YAML. Install the PyYAML library')
registry.register('yaml', None, not_available, 'application/x-yaml')
def unpickle(s):
return pickle_loads(str_to_bytes(s))
def register_pickle():
"""Register pickle serializer.
The fastest serialization method, but restricts
you to python clients.
"""
def pickle_dumps(obj, dumper=pickle.dumps):
return dumper(obj, protocol=pickle_protocol)
registry.register('pickle', pickle_dumps, unpickle,
content_type='application/x-python-serialize',
content_encoding='binary')
def register_msgpack():
"""Register msgpack serializer.
See Also
--------
https://msgpack.org/.
"""
pack = unpack = None
try:
import msgpack
if msgpack.version >= (0, 4):
from msgpack import packb, unpackb
def pack(s): # noqa
return packb(s, use_bin_type=True)
def unpack(s): # noqa
return unpackb(s, raw=False)
else:
def version_mismatch(*args, **kwargs):
raise SerializerNotInstalled(
'msgpack requires msgpack-python >= 0.4.0')
pack = unpack = version_mismatch
except (ImportError, ValueError):
def not_available(*args, **kwargs):
raise SerializerNotInstalled(
'No decoder installed for msgpack. '
'Please install the msgpack-python library')
pack = unpack = not_available
registry.register(
'msgpack', pack, unpack,
content_type='application/x-msgpack',
content_encoding='binary',
)
# Register the base serialization methods.
register_json()
register_pickle()
register_yaml()
register_msgpack()
# Default serializer is 'json'
registry._set_default_serializer('json')
NOTSET = object()
def enable_insecure_serializers(choices=NOTSET):
"""Enable serializers that are considered to be unsafe.
Note:
----
Will enable ``pickle``, ``yaml`` and ``msgpack`` by default, but you
can also specify a list of serializers (by name or content type)
to enable.
"""
choices = ['pickle', 'yaml', 'msgpack'] if choices is NOTSET else choices
if choices is not None:
for choice in choices:
try:
registry.enable(choice)
except KeyError:
pass
def disable_insecure_serializers(allowed=NOTSET):
"""Disable untrusted serializers.
Will disable all serializers except ``json``
or you can specify a list of deserializers to allow.
Note:
----
Producers will still be able to serialize data
in these formats, but consumers will not accept
incoming data using the untrusted content types.
"""
allowed = ['json'] if allowed is NOTSET else allowed
for name in registry._decoders:
registry.disable(name)
if allowed is not None:
for name in allowed:
registry.enable(name)
# Insecure serializers are disabled by default since v3.0
disable_insecure_serializers()
# Load entrypoints from installed extensions
for ep, args in entrypoints('kombu.serializers'): # pragma: no cover
register(ep.name, *args)
def prepare_accept_content(content_types, name_to_type=None):
"""Replace aliases of content_types with full names from registry.
Raises
------
SerializerNotInstalled: If the serialization method
requested is not available.
"""
name_to_type = registry.name_to_type if not name_to_type else name_to_type
if content_types is not None:
try:
return {n if '/' in n else name_to_type[n] for n in content_types}
except KeyError as e:
raise SerializerNotInstalled(
f'No encoder/decoder installed for {e.args[0]}')
return content_types

View File

@@ -0,0 +1,163 @@
"""Simple messaging interface."""
from __future__ import annotations
import socket
from collections import deque
from queue import Empty
from time import monotonic
from typing import TYPE_CHECKING
from . import entity, messaging
from .connection import maybe_channel
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('SimpleQueue', 'SimpleBuffer')
class SimpleBase:
Empty = Empty
_consuming = False
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.close()
def __init__(self, channel, producer, consumer, no_ack=False):
self.channel = maybe_channel(channel)
self.producer = producer
self.consumer = consumer
self.no_ack = no_ack
self.queue = self.consumer.queues[0]
self.buffer = deque()
self.consumer.register_callback(self._receive)
def get(self, block=True, timeout=None):
if not block:
return self.get_nowait()
self._consume()
time_start = monotonic()
remaining = timeout
while True:
if self.buffer:
return self.buffer.popleft()
if remaining is not None and remaining <= 0.0:
raise self.Empty()
try:
# The `drain_events` method will
# block on the socket connection to rabbitmq. if any
# application-level messages are received, it will put them
# into `self.buffer`.
# * The method will block for UP TO `timeout` milliseconds.
# * The method may raise a socket.timeout exception; or...
# * The method may return without having put anything on
# `self.buffer`. This is because internal heartbeat
# messages are sent over the same socket; also POSIX makes
# no guarantees against socket calls returning early.
self.channel.connection.client.drain_events(timeout=remaining)
except socket.timeout:
raise self.Empty()
if remaining is not None:
elapsed = monotonic() - time_start
remaining = timeout - elapsed
def get_nowait(self):
m = self.queue.get(no_ack=self.no_ack, accept=self.consumer.accept)
if not m:
raise self.Empty()
return m
def put(self, message, serializer=None, headers=None, compression=None,
routing_key=None, **kwargs):
self.producer.publish(message,
serializer=serializer,
routing_key=routing_key,
headers=headers,
compression=compression,
**kwargs)
def clear(self):
return self.consumer.purge()
def qsize(self):
_, size, _ = self.queue.queue_declare(passive=True)
return size
def close(self):
self.consumer.cancel()
def _receive(self, message_data, message):
self.buffer.append(message)
def _consume(self):
if not self._consuming:
self.consumer.consume(no_ack=self.no_ack)
self._consuming = True
def __len__(self):
"""`len(self) -> self.qsize()`."""
return self.qsize()
def __bool__(self):
return True
__nonzero__ = __bool__
class SimpleQueue(SimpleBase):
"""Simple API for persistent queues."""
no_ack = False
queue_opts = {}
queue_args = {}
exchange_opts = {'type': 'direct'}
def __init__(self, channel, name, no_ack=None, queue_opts=None,
queue_args=None, exchange_opts=None, serializer=None,
compression=None, accept=None):
queue = name
queue_opts = dict(self.queue_opts, **queue_opts or {})
queue_args = dict(self.queue_args, **queue_args or {})
exchange_opts = dict(self.exchange_opts, **exchange_opts or {})
if no_ack is None:
no_ack = self.no_ack
if not isinstance(queue, entity.Queue):
exchange = entity.Exchange(name, **exchange_opts)
queue = entity.Queue(name, exchange, name,
queue_arguments=queue_args,
**queue_opts)
routing_key = name
else:
exchange = queue.exchange
routing_key = queue.routing_key
consumer = messaging.Consumer(channel, queue, accept=accept)
producer = messaging.Producer(channel, exchange,
serializer=serializer,
routing_key=routing_key,
compression=compression)
super().__init__(channel, producer,
consumer, no_ack)
class SimpleBuffer(SimpleQueue):
"""Simple API for ephemeral queues."""
no_ack = True
queue_opts = {'durable': False,
'auto_delete': True}
exchange_opts = {'durable': False,
'delivery_mode': 'transient',
'auto_delete': True}

View File

@@ -0,0 +1,202 @@
"""SoftLayer Message Queue transport module for kombu.
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: No
* Supports TTL: No
Connection String
=================
*Unreviewed*
Transport Options
=================
*Unreviewed*
"""
from __future__ import annotations
import os
import socket
import string
from queue import Empty
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
try:
from softlayer_messaging import get_client
from softlayer_messaging.errors import ResponseError
except ImportError: # pragma: no cover
get_client = ResponseError = None
# dots are replaced by dash, all other punctuation replaced by underscore.
CHARS_REPLACE_TABLE = {
ord(c): 0x5f for c in string.punctuation if c not in '_'
}
class Channel(virtual.Channel):
"""SLMQ Channel."""
default_visibility_timeout = 1800 # 30 minutes.
domain_format = 'kombu%(vhost)s'
_slmq = None
_queue_cache = {}
_noack_queues = set()
def __init__(self, *args, **kwargs):
if get_client is None:
raise ImportError(
'SLMQ transport requires the softlayer_messaging library',
)
super().__init__(*args, **kwargs)
queues = self.slmq.queues()
for queue in queues:
self._queue_cache[queue] = queue
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
return super().basic_consume(queue, no_ack,
*args, **kwargs)
def basic_cancel(self, consumer_tag):
if consumer_tag in self._consumers:
queue = self._tag_to_queue[consumer_tag]
self._noack_queues.discard(queue)
return super().basic_cancel(consumer_tag)
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
"""Format AMQP queue name into a valid SLQS queue name."""
return str(safe_str(name)).translate(table)
def _new_queue(self, queue, **kwargs):
"""Ensure a queue exists in SLQS."""
queue = self.entity_name(self.queue_name_prefix + queue)
try:
return self._queue_cache[queue]
except KeyError:
try:
self.slmq.create_queue(
queue, visibility_timeout=self.visibility_timeout)
except ResponseError:
pass
q = self._queue_cache[queue] = self.slmq.queue(queue)
return q
def _delete(self, queue, *args, **kwargs):
"""Delete queue by name."""
queue_name = self.entity_name(queue)
self._queue_cache.pop(queue_name, None)
self.slmq.queue(queue_name).delete(force=True)
super()._delete(queue_name)
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q = self._new_queue(queue)
q.push(dumps(message))
def _get(self, queue):
"""Try to retrieve a single message off ``queue``."""
q = self._new_queue(queue)
rs = q.pop(1)
if rs['items']:
m = rs['items'][0]
payload = loads(bytes_to_str(m['body']))
if queue in self._noack_queues:
q.message(m['id']).delete()
else:
payload['properties']['delivery_info'].update({
'slmq_message_id': m['id'], 'slmq_queue_name': q.name})
return payload
raise Empty()
def basic_ack(self, delivery_tag):
delivery_info = self.qos.get(delivery_tag).delivery_info
try:
queue = delivery_info['slmq_queue_name']
except KeyError:
pass
else:
self.delete_message(queue, delivery_info['slmq_message_id'])
super().basic_ack(delivery_tag)
def _size(self, queue):
"""Return the number of messages in a queue."""
return self._new_queue(queue).detail()['message_count']
def _purge(self, queue):
"""Delete all current messages in a queue."""
q = self._new_queue(queue)
n = 0
results = q.pop(10)
while results['items']:
for m in results['items']:
self.delete_message(queue, m['id'])
n += 1
results = q.pop(10)
return n
def delete_message(self, queue, message_id):
q = self.slmq.queue(self.entity_name(queue))
return q.message(message_id).delete()
@property
def slmq(self):
if self._slmq is None:
conninfo = self.conninfo
account = os.environ.get('SLMQ_ACCOUNT', conninfo.virtual_host)
user = os.environ.get('SL_USERNAME', conninfo.userid)
api_key = os.environ.get('SL_API_KEY', conninfo.password)
host = os.environ.get('SLMQ_HOST', conninfo.hostname)
port = os.environ.get('SLMQ_PORT', conninfo.port)
secure = bool(os.environ.get(
'SLMQ_SECURE', self.transport_options.get('secure')) or True,
)
endpoint = '{}://{}{}'.format(
'https' if secure else 'http', host,
f':{port}' if port else '',
)
self._slmq = get_client(account, endpoint=endpoint)
self._slmq.authenticate(user, api_key)
return self._slmq
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def visibility_timeout(self):
return (self.transport_options.get('visibility_timeout') or
self.default_visibility_timeout)
@cached_property
def queue_name_prefix(self):
return self.transport_options.get('queue_name_prefix', '')
class Transport(virtual.Transport):
"""SLMQ Transport."""
Channel = Channel
polling_interval = 1
default_port = None
connection_errors = (
virtual.Transport.connection_errors + (
ResponseError, socket.error
)
)

View File

@@ -0,0 +1,973 @@
"""Amazon SQS transport module for Kombu.
This package implements an AMQP-like interface on top of Amazons SQS service,
with the goal of being optimized for high performance and reliability.
The default settings for this module are focused now on high performance in
task queue situations where tasks are small, idempotent and run very fast.
SQS Features supported by this transport
========================================
Long Polling
------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-long-polling.html
Long polling is enabled by setting the `wait_time_seconds` transport
option to a number > 1. Amazon supports up to 20 seconds. This is
enabled with 10 seconds by default.
Batch API Actions
-----------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-batch-api.html
The default behavior of the SQS Channel.drain_events() method is to
request up to the 'prefetch_count' messages on every request to SQS.
These messages are stored locally in a deque object and passed back
to the Transport until the deque is empty, before triggering a new
API call to Amazon.
This behavior dramatically speeds up the rate that you can pull tasks
from SQS when you have short-running tasks (or a large number of workers).
When a Celery worker has multiple queues to monitor, it will pull down
up to 'prefetch_count' messages from queueA and work on them all before
moving on to queueB. If queueB is empty, it will wait up until
'polling_interval' expires before moving back and checking on queueA.
Message Attributes
-----------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html
SQS supports sending message attributes along with the message body.
To use this feature, you can pass a 'message_attributes' as keyword argument
to `basic_publish` method.
Other Features supported by this transport
==========================================
Predefined Queues
-----------------
The default behavior of this transport is to use a single AWS credential
pair in order to manage all SQS queues (e.g. listing queues, creating
queues, polling queues, deleting messages).
If it is preferable for your environment to use multiple AWS credentials, you
can use the 'predefined_queues' setting inside the 'transport_options' map.
This setting allows you to specify the SQS queue URL and AWS credentials for
each of your queues. For example, if you have two queues which both already
exist in AWS) you can tell this transport about them as follows:
.. code-block:: python
transport_options = {
'predefined_queues': {
'queue-1': {
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/aaa',
'access_key_id': 'a',
'secret_access_key': 'b',
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
'backoff_tasks': ['svc.tasks.tasks.task1'] # optional
},
'queue-2.fifo': {
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb.fifo',
'access_key_id': 'c',
'secret_access_key': 'd',
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
'backoff_tasks': ['svc.tasks.tasks.task2'] # optional
},
}
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
'sts_token_timeout': 900 # optional
}
Note that FIFO and standard queues must be named accordingly (the name of
a FIFO queue must end with the .fifo suffix).
backoff_policy & backoff_tasks are optional arguments. These arguments
automatically change the message visibility timeout, in order to have
different times between specific task retries. This would apply after
task failure.
AWS STS authentication is supported, by using sts_role_arn, and
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
to 900 seconds. After the mentioned period, a new token will be created.
If you authenticate using Okta_ (e.g. calling |gac|_), you can also specify
a 'session_token' to connect to a queue. Note that those tokens have a
limited lifetime and are therefore only suited for short-lived tests.
.. _Okta: https://www.okta.com/
.. _gac: https://github.com/Nike-Inc/gimme-aws-creds#readme
.. |gac| replace:: ``gimme-aws-creds``
Client config
-------------
In some cases you may need to override the botocore config. You can do it
as follows:
.. code-block:: python
transport_option = {
'client-config': {
'connect_timeout': 5,
},
}
For a complete list of settings you can adjust using this option see
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
"""
from __future__ import annotations
import base64
import socket
import string
import uuid
from datetime import datetime
from queue import Empty
from botocore.client import Config
from botocore.exceptions import ClientError
from vine import ensure_promise, promise, transform
from kombu.asynchronous import get_event_loop
from kombu.asynchronous.aws.ext import boto3, exceptions
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
from kombu.asynchronous.aws.sqs.message import AsyncMessage
from kombu.log import get_logger
from kombu.utils import scheduling
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
logger = get_logger(__name__)
# dots are replaced by dash, dash remains dash, all other punctuation
# replaced by underscore.
CHARS_REPLACE_TABLE = {
ord(c): 0x5f for c in string.punctuation if c not in '-_.'
}
CHARS_REPLACE_TABLE[0x2e] = 0x2d # '.' -> '-'
#: SQS bulk get supports a maximum of 10 messages at a time.
SQS_MAX_MESSAGES = 10
def maybe_int(x):
"""Try to convert x' to int, or return x' if that fails."""
try:
return int(x)
except ValueError:
return x
class UndefinedQueueException(Exception):
"""Predefined queues are being used and an undefined queue was used."""
class InvalidQueueException(Exception):
"""Predefined queues are being used and configuration is not valid."""
class AccessDeniedQueueException(Exception):
"""Raised when access to the AWS queue is denied.
This may occur if the permissions are not correctly set or the
credentials are invalid.
"""
class DoesNotExistQueueException(Exception):
"""The specified queue doesn't exist."""
class QoS(virtual.QoS):
"""Quality of Service guarantees implementation for SQS."""
def reject(self, delivery_tag, requeue=False):
super().reject(delivery_tag, requeue=requeue)
routing_key, message, backoff_tasks, backoff_policy = \
self._extract_backoff_policy_configuration_and_message(
delivery_tag)
if routing_key and message and backoff_tasks and backoff_policy:
self.apply_backoff_policy(
routing_key, delivery_tag, backoff_policy, backoff_tasks)
def _extract_backoff_policy_configuration_and_message(self, delivery_tag):
try:
message = self._delivered[delivery_tag]
routing_key = message.delivery_info['routing_key']
except KeyError:
return None, None, None, None
if not routing_key or not message:
return None, None, None, None
queue_config = self.channel.predefined_queues.get(routing_key, {})
backoff_tasks = queue_config.get('backoff_tasks')
backoff_policy = queue_config.get('backoff_policy')
return routing_key, message, backoff_tasks, backoff_policy
def apply_backoff_policy(self, routing_key, delivery_tag,
backoff_policy, backoff_tasks):
queue_url = self.channel._queue_cache[routing_key]
task_name, number_of_retries = \
self.extract_task_name_and_number_of_retries(delivery_tag)
if not task_name or not number_of_retries:
return None
policy_value = backoff_policy.get(number_of_retries)
if task_name in backoff_tasks and policy_value is not None:
c = self.channel.sqs(routing_key)
c.change_message_visibility(
QueueUrl=queue_url,
ReceiptHandle=delivery_tag,
VisibilityTimeout=policy_value
)
def extract_task_name_and_number_of_retries(self, delivery_tag):
message = self._delivered[delivery_tag]
message_headers = message.headers
task_name = message_headers['task']
number_of_retries = int(
message.properties['delivery_info']['sqs_message']
['Attributes']['ApproximateReceiveCount'])
return task_name, number_of_retries
class Channel(virtual.Channel):
"""SQS Channel."""
default_region = 'us-east-1'
default_visibility_timeout = 1800 # 30 minutes.
default_wait_time_seconds = 10 # up to 20 seconds max
domain_format = 'kombu%(vhost)s'
_asynsqs = None
_predefined_queue_async_clients = {} # A client for each predefined queue
_sqs = None
_predefined_queue_clients = {} # A client for each predefined queue
_queue_cache = {} # SQS queue name => SQS queue URL
_noack_queues = set()
QoS = QoS
def __init__(self, *args, **kwargs):
if boto3 is None:
raise ImportError('boto3 is not installed')
super().__init__(*args, **kwargs)
self._validate_predifined_queues()
# SQS blows up if you try to create a new queue when one already
# exists but with a different visibility_timeout. This prepopulates
# the queue_cache to protect us from recreating
# queues that are known to already exist.
self._update_queue_cache(self.queue_name_prefix)
self.hub = kwargs.get('hub') or get_event_loop()
def _validate_predifined_queues(self):
"""Check that standard and FIFO queues are named properly.
AWS requires FIFO queues to have a name
that ends with the .fifo suffix.
"""
for queue_name, q in self.predefined_queues.items():
fifo_url = q['url'].endswith('.fifo')
fifo_name = queue_name.endswith('.fifo')
if fifo_url and not fifo_name:
raise InvalidQueueException(
"Queue with url '{}' must have a name "
"ending with .fifo".format(q['url'])
)
elif not fifo_url and fifo_name:
raise InvalidQueueException(
"Queue with name '{}' is not a FIFO queue: "
"'{}'".format(queue_name, q['url'])
)
def _update_queue_cache(self, queue_name_prefix):
if self.predefined_queues:
for queue_name, q in self.predefined_queues.items():
self._queue_cache[queue_name] = q['url']
return
resp = self.sqs().list_queues(QueueNamePrefix=queue_name_prefix)
for url in resp.get('QueueUrls', []):
queue_name = url.split('/')[-1]
self._queue_cache[queue_name] = url
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
if self.hub:
self._loop1(queue)
return super().basic_consume(
queue, no_ack, *args, **kwargs
)
def basic_cancel(self, consumer_tag):
if consumer_tag in self._consumers:
queue = self._tag_to_queue[consumer_tag]
self._noack_queues.discard(queue)
return super().basic_cancel(consumer_tag)
def drain_events(self, timeout=None, callback=None, **kwargs):
"""Return a single payload message from one of our queues.
Raises
------
Queue.Empty: if no messages available.
"""
# If we're not allowed to consume or have no consumers, raise Empty
if not self._consumers or not self.qos.can_consume():
raise Empty()
# At this point, go and get more messages from SQS
self._poll(self.cycle, callback, timeout=timeout)
def _reset_cycle(self):
"""Reset the consume cycle.
Returns
-------
FairCycle: object that points to our _get_bulk() method
rather than the standard _get() method. This allows for
multiple messages to be returned at once from SQS (
based on the prefetch limit).
"""
self._cycle = scheduling.FairCycle(
self._get_bulk, self._active_queues, Empty,
)
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
"""Format AMQP queue name into a legal SQS queue name."""
if name.endswith('.fifo'):
partial = name[:-len('.fifo')]
partial = str(safe_str(partial)).translate(table)
return partial + '.fifo'
else:
return str(safe_str(name)).translate(table)
def canonical_queue_name(self, queue_name):
return self.entity_name(self.queue_name_prefix + queue_name)
def _resolve_queue_url(self, queue):
"""Try to retrieve the SQS queue URL for a given queue name."""
# Translate to SQS name for consistency with initial
# _queue_cache population.
sqs_qname = self.canonical_queue_name(queue)
# The SQS ListQueues method only returns 1000 queues. When you have
# so many queues, it's possible that the queue you are looking for is
# not cached. In this case, we could update the cache with the exact
# queue name first.
if sqs_qname not in self._queue_cache:
self._update_queue_cache(sqs_qname)
try:
return self._queue_cache[sqs_qname]
except KeyError:
if self.predefined_queues:
raise UndefinedQueueException((
"Queue with name '{}' must be "
"defined in 'predefined_queues'."
).format(sqs_qname))
raise DoesNotExistQueueException(
f"Queue with name '{sqs_qname}' doesn't exist in SQS"
)
def _new_queue(self, queue, **kwargs):
"""Ensure a queue with given name exists in SQS.
Arguments:
---------
queue (str): the AMQP queue name
Returns
str: the SQS queue URL
"""
try:
return self._resolve_queue_url(queue)
except DoesNotExistQueueException:
sqs_qname = self.canonical_queue_name(queue)
attributes = {'VisibilityTimeout': str(self.visibility_timeout)}
if sqs_qname.endswith('.fifo'):
attributes['FifoQueue'] = 'true'
resp = self._create_queue(sqs_qname, attributes)
self._queue_cache[sqs_qname] = resp['QueueUrl']
return resp['QueueUrl']
def _create_queue(self, queue_name, attributes):
"""Create an SQS queue with a given name and nominal attributes."""
# Allow specifying additional boto create_queue Attributes
# via transport options
if self.predefined_queues:
return None
attributes.update(
self.transport_options.get('sqs-creation-attributes') or {},
)
return self.sqs(queue=queue_name).create_queue(
QueueName=queue_name,
Attributes=attributes,
)
def _delete(self, queue, *args, **kwargs):
"""Delete queue by name."""
if self.predefined_queues:
return
q_url = self._resolve_queue_url(queue)
self.sqs().delete_queue(
QueueUrl=q_url,
)
self._queue_cache.pop(queue, None)
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q_url = self._new_queue(queue)
kwargs = {'QueueUrl': q_url}
if 'properties' in message:
if 'message_attributes' in message['properties']:
# we don't want to want to have the attribute in the body
kwargs['MessageAttributes'] = \
message['properties'].pop('message_attributes')
if queue.endswith('.fifo'):
if 'MessageGroupId' in message['properties']:
kwargs['MessageGroupId'] = \
message['properties']['MessageGroupId']
else:
kwargs['MessageGroupId'] = 'default'
if 'MessageDeduplicationId' in message['properties']:
kwargs['MessageDeduplicationId'] = \
message['properties']['MessageDeduplicationId']
else:
kwargs['MessageDeduplicationId'] = str(uuid.uuid4())
else:
if "DelaySeconds" in message['properties']:
kwargs['DelaySeconds'] = \
message['properties']['DelaySeconds']
if self.sqs_base64_encoding:
body = AsyncMessage().encode(dumps(message))
else:
body = dumps(message)
kwargs['MessageBody'] = body
c = self.sqs(queue=self.canonical_queue_name(queue))
if message.get('redelivered'):
c.change_message_visibility(
QueueUrl=q_url,
ReceiptHandle=message['properties']['delivery_tag'],
VisibilityTimeout=0
)
else:
c.send_message(**kwargs)
@staticmethod
def _optional_b64_decode(byte_string):
try:
data = base64.b64decode(byte_string)
if base64.b64encode(data) == byte_string:
return data
# else the base64 module found some embedded base64 content
# that should be ignored.
except Exception: # pylint: disable=broad-except
pass
return byte_string
def _message_to_python(self, message, queue_name, q_url):
body = self._optional_b64_decode(message['Body'].encode())
payload = loads(bytes_to_str(body))
if queue_name in self._noack_queues:
q_url = self._new_queue(queue_name)
self.asynsqs(queue=queue_name).delete_message(
q_url,
message['ReceiptHandle'],
)
else:
try:
properties = payload['properties']
delivery_info = payload['properties']['delivery_info']
except KeyError:
# json message not sent by kombu?
delivery_info = {}
properties = {'delivery_info': delivery_info}
payload.update({
'body': bytes_to_str(body),
'properties': properties,
})
# set delivery tag to SQS receipt handle
delivery_info.update({
'sqs_message': message, 'sqs_queue': q_url,
})
properties['delivery_tag'] = message['ReceiptHandle']
return payload
def _messages_to_python(self, messages, queue):
"""Convert a list of SQS Message objects into Payloads.
This method handles converting SQS Message objects into
Payloads, and appropriately updating the queue depending on
the 'ack' settings for that queue.
Arguments:
---------
messages (SQSMessage): A list of SQS Message objects.
queue (str): Name representing the queue they came from.
Returns
-------
List: A list of Payload objects
"""
q_url = self._new_queue(queue)
return [self._message_to_python(m, queue, q_url) for m in messages]
def _get_bulk(self, queue,
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
"""Try to retrieve multiple messages off ``queue``.
Where :meth:`_get` returns a single Payload object, this method
returns a list of Payload objects. The number of objects returned
is determined by the total number of messages available in the queue
and the number of messages the QoS object allows (based on the
prefetch_count).
Note:
----
Ignores QoS limits so caller is responsible for checking
that we are allowed to consume at least one message from the
queue. get_bulk will then ask QoS for an estimate of
the number of extra messages that we can consume.
Arguments:
---------
queue (str): The queue name to pull from.
Returns
-------
List[Message]
"""
# drain_events calls `can_consume` first, consuming
# a token, so we know that we are allowed to consume at least
# one message.
# Note: ignoring max_messages for SQS with boto3
max_count = self._get_message_estimate()
if max_count:
q_url = self._new_queue(queue)
resp = self.sqs(queue=queue).receive_message(
QueueUrl=q_url, MaxNumberOfMessages=max_count,
WaitTimeSeconds=self.wait_time_seconds)
if resp.get('Messages'):
for m in resp['Messages']:
m['Body'] = AsyncMessage(body=m['Body']).decode()
for msg in self._messages_to_python(resp['Messages'], queue):
self.connection._deliver(msg, queue)
return
raise Empty()
def _get(self, queue):
"""Try to retrieve a single message off ``queue``."""
q_url = self._new_queue(queue)
resp = self.sqs(queue=queue).receive_message(
QueueUrl=q_url, MaxNumberOfMessages=1,
WaitTimeSeconds=self.wait_time_seconds)
if resp.get('Messages'):
body = AsyncMessage(body=resp['Messages'][0]['Body']).decode()
resp['Messages'][0]['Body'] = body
return self._messages_to_python(resp['Messages'], queue)[0]
raise Empty()
def _loop1(self, queue, _=None):
self.hub.call_soon(self._schedule_queue, queue)
def _schedule_queue(self, queue):
if queue in self._active_queues:
if self.qos.can_consume():
self._get_bulk_async(
queue, callback=promise(self._loop1, (queue,)),
)
else:
self._loop1(queue)
def _get_message_estimate(self, max_if_unlimited=SQS_MAX_MESSAGES):
maxcount = self.qos.can_consume_max_estimate()
return min(
max_if_unlimited if maxcount is None else max(maxcount, 1),
max_if_unlimited,
)
def _get_bulk_async(self, queue,
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
maxcount = self._get_message_estimate()
if maxcount:
return self._get_async(queue, maxcount, callback=callback)
# Not allowed to consume, make sure to notify callback..
callback = ensure_promise(callback)
callback([])
return callback
def _get_async(self, queue, count=1, callback=None):
q_url = self._new_queue(queue)
qname = self.canonical_queue_name(queue)
return self._get_from_sqs(
queue_name=qname, queue_url=q_url, count=count,
connection=self.asynsqs(queue=qname),
callback=transform(
self._on_messages_ready, callback, q_url, queue
),
)
def _on_messages_ready(self, queue, qname, messages):
if 'Messages' in messages and messages['Messages']:
callbacks = self.connection._callbacks
for msg in messages['Messages']:
msg_parsed = self._message_to_python(msg, qname, queue)
callbacks[qname](msg_parsed)
def _get_from_sqs(self, queue_name, queue_url,
connection, count=1, callback=None):
"""Retrieve and handle messages from SQS.
Uses long polling and returns :class:`~vine.promises.promise`.
"""
return connection.receive_message(
queue_name, queue_url, number_messages=count,
wait_time_seconds=self.wait_time_seconds,
callback=callback,
)
def _restore(self, message,
unwanted_delivery_info=('sqs_message', 'sqs_queue')):
for unwanted_key in unwanted_delivery_info:
# Remove objects that aren't JSON serializable (Issue #1108).
message.delivery_info.pop(unwanted_key, None)
return super()._restore(message)
def basic_ack(self, delivery_tag, multiple=False):
try:
message = self.qos.get(delivery_tag).delivery_info
sqs_message = message['sqs_message']
except KeyError:
super().basic_ack(delivery_tag)
else:
queue = None
if 'routing_key' in message:
queue = self.canonical_queue_name(message['routing_key'])
try:
self.sqs(queue=queue).delete_message(
QueueUrl=message['sqs_queue'],
ReceiptHandle=sqs_message['ReceiptHandle']
)
except ClientError as exception:
if exception.response['Error']['Code'] == 'AccessDenied':
raise AccessDeniedQueueException(
exception.response["Error"]["Message"]
)
super().basic_reject(delivery_tag)
else:
super().basic_ack(delivery_tag)
def _size(self, queue):
"""Return the number of messages in a queue."""
q_url = self._new_queue(queue)
c = self.sqs(queue=self.canonical_queue_name(queue))
resp = c.get_queue_attributes(
QueueUrl=q_url,
AttributeNames=['ApproximateNumberOfMessages'])
return int(resp['Attributes']['ApproximateNumberOfMessages'])
def _purge(self, queue):
"""Delete all current messages in a queue."""
q_url = self._new_queue(queue)
# SQS is slow at registering messages, so run for a few
# iterations to ensure messages are detected and deleted.
size = 0
for i in range(10):
size += int(self._size(queue))
if not size:
break
self.sqs(queue=queue).purge_queue(QueueUrl=q_url)
return size
def close(self):
super().close()
# if self._asynsqs:
# try:
# self.asynsqs().close()
# except AttributeError as exc: # FIXME ???
# if "can't set attribute" not in str(exc):
# raise
def new_sqs_client(self, region, access_key_id,
secret_access_key, session_token=None):
session = boto3.session.Session(
region_name=region,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
)
is_secure = self.is_secure if self.is_secure is not None else True
client_kwargs = {
'use_ssl': is_secure
}
if self.endpoint_url is not None:
client_kwargs['endpoint_url'] = self.endpoint_url
client_config = self.transport_options.get('client-config') or {}
config = Config(**client_config)
return session.client('sqs', config=config, **client_kwargs)
def sqs(self, queue=None):
if queue is not None and self.predefined_queues:
if queue not in self.predefined_queues:
raise UndefinedQueueException(
f"Queue with name '{queue}' must be defined"
" in 'predefined_queues'.")
q = self.predefined_queues[queue]
if self.transport_options.get('sts_role_arn'):
return self._handle_sts_session(queue, q)
if not self.transport_options.get('sts_role_arn'):
if queue in self._predefined_queue_clients:
return self._predefined_queue_clients[queue]
else:
c = self._predefined_queue_clients[queue] = \
self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=q.get(
'access_key_id', self.conninfo.userid),
secret_access_key=q.get(
'secret_access_key', self.conninfo.password)
)
return c
if self._sqs is not None:
return self._sqs
c = self._sqs = self.new_sqs_client(
region=self.region,
access_key_id=self.conninfo.userid,
secret_access_key=self.conninfo.password,
)
return c
def _handle_sts_session(self, queue, q):
region = q.get('region', self.region)
if not hasattr(self, 'sts_expiration'): # STS token - token init
return self._new_predefined_queue_client_with_sts_session(queue, region)
# STS token - refresh if expired
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
return self._new_predefined_queue_client_with_sts_session(queue, region)
else: # STS token - ruse existing
if queue not in self._predefined_queue_clients:
return self._new_predefined_queue_client_with_sts_session(queue, region)
return self._predefined_queue_clients[queue]
def _new_predefined_queue_client_with_sts_session(self, queue, region):
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=region,
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
def generate_sts_session_token(self, role_arn, token_expiry_seconds):
sts_client = boto3.client('sts')
sts_policy = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName='Celery',
DurationSeconds=token_expiry_seconds
)
return sts_policy['Credentials']
def asynsqs(self, queue=None):
if queue is not None and self.predefined_queues:
if queue in self._predefined_queue_async_clients and \
not hasattr(self, 'sts_expiration'):
return self._predefined_queue_async_clients[queue]
if queue not in self.predefined_queues:
raise UndefinedQueueException((
"Queue with name '{}' must be defined in "
"'predefined_queues'."
).format(queue))
q = self.predefined_queues[queue]
c = self._predefined_queue_async_clients[queue] = \
AsyncSQSConnection(
sqs_connection=self.sqs(queue=queue),
region=q.get('region', self.region),
fetch_message_attributes=self.fetch_message_attributes,
)
return c
if self._asynsqs is not None:
return self._asynsqs
c = self._asynsqs = AsyncSQSConnection(
sqs_connection=self.sqs(queue=queue),
region=self.region,
fetch_message_attributes=self.fetch_message_attributes,
)
return c
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def visibility_timeout(self):
return (self.transport_options.get('visibility_timeout') or
self.default_visibility_timeout)
@cached_property
def predefined_queues(self):
"""Map of queue_name to predefined queue settings."""
return self.transport_options.get('predefined_queues', {})
@cached_property
def queue_name_prefix(self):
return self.transport_options.get('queue_name_prefix', '')
@cached_property
def supports_fanout(self):
return False
@cached_property
def region(self):
return (self.transport_options.get('region') or
boto3.Session().region_name or
self.default_region)
@cached_property
def regioninfo(self):
return self.transport_options.get('regioninfo')
@cached_property
def is_secure(self):
return self.transport_options.get('is_secure')
@cached_property
def port(self):
return self.transport_options.get('port')
@cached_property
def endpoint_url(self):
if self.conninfo.hostname is not None:
scheme = 'https' if self.is_secure else 'http'
if self.conninfo.port is not None:
port = f':{self.conninfo.port}'
else:
port = ''
return '{}://{}{}'.format(
scheme,
self.conninfo.hostname,
port
)
@cached_property
def wait_time_seconds(self):
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
@cached_property
def sqs_base64_encoding(self):
return self.transport_options.get('sqs_base64_encoding', True)
@cached_property
def fetch_message_attributes(self):
return self.transport_options.get('fetch_message_attributes')
class Transport(virtual.Transport):
"""SQS Transport.
Additional queue attributes can be supplied to SQS during queue
creation by passing an ``sqs-creation-attributes`` key in
transport_options. ``sqs-creation-attributes`` must be a dict whose
key-value pairs correspond with Attributes in the
`CreateQueue SQS API`_.
For example, to have SQS queues created with server-side encryption
enabled using the default Amazon Managed Customer Master Key, you
can set ``KmsMasterKeyId`` Attribute. When the queue is initially
created by Kombu, encryption will be enabled.
.. code-block:: python
from kombu.transport.SQS import Transport
transport = Transport(
...,
transport_options={
'sqs-creation-attributes': {
'KmsMasterKeyId': 'alias/aws/sqs',
},
}
)
.. _CreateQueue SQS API: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_CreateQueue.html#API_CreateQueue_RequestParameters
The ``ApproximateReceiveCount`` message attribute is fetched by this
transport by default. Requested message attributes can be changed by
setting ``fetch_message_attributes`` in the transport options.
.. code-block:: python
from kombu.transport.SQS import Transport
transport = Transport(
...,
transport_options={
'fetch_message_attributes': ["All"],
}
)
.. _Message Attributes: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#SQS-ReceiveMessage-request-AttributeNames
""" # noqa: E501
Channel = Channel
polling_interval = 1
wait_time_seconds = 0
default_port = None
connection_errors = (
virtual.Transport.connection_errors +
(exceptions.BotoCoreError, socket.error)
)
channel_errors = (
virtual.Transport.channel_errors + (exceptions.BotoCoreError,)
)
driver_type = 'sqs'
driver_name = 'sqs'
implements = virtual.Transport.implements.extend(
asynchronous=True,
exchange_type=frozenset(['direct']),
)
@property
def default_connection_params(self):
return {'port': self.default_port}

View File

@@ -0,0 +1,93 @@
"""Built-in transports."""
from __future__ import annotations
from kombu.utils.compat import _detect_environment
from kombu.utils.imports import symbol_by_name
def supports_librabbitmq() -> bool | None:
"""Return true if :pypi:`librabbitmq` can be used."""
if _detect_environment() == 'default':
try:
import librabbitmq # noqa
except ImportError: # pragma: no cover
pass
else: # pragma: no cover
return True
return None
TRANSPORT_ALIASES = {
'amqp': 'kombu.transport.pyamqp:Transport',
'amqps': 'kombu.transport.pyamqp:SSLTransport',
'pyamqp': 'kombu.transport.pyamqp:Transport',
'librabbitmq': 'kombu.transport.librabbitmq:Transport',
'confluentkafka': 'kombu.transport.confluentkafka:Transport',
'kafka': 'kombu.transport.confluentkafka:Transport',
'memory': 'kombu.transport.memory:Transport',
'redis': 'kombu.transport.redis:Transport',
'rediss': 'kombu.transport.redis:Transport',
'SQS': 'kombu.transport.SQS:Transport',
'sqs': 'kombu.transport.SQS:Transport',
'mongodb': 'kombu.transport.mongodb:Transport',
'zookeeper': 'kombu.transport.zookeeper:Transport',
'sqlalchemy': 'kombu.transport.sqlalchemy:Transport',
'sqla': 'kombu.transport.sqlalchemy:Transport',
'SLMQ': 'kombu.transport.SLMQ.Transport',
'slmq': 'kombu.transport.SLMQ.Transport',
'filesystem': 'kombu.transport.filesystem:Transport',
'qpid': 'kombu.transport.qpid:Transport',
'sentinel': 'kombu.transport.redis:SentinelTransport',
'consul': 'kombu.transport.consul:Transport',
'etcd': 'kombu.transport.etcd:Transport',
'azurestoragequeues': 'kombu.transport.azurestoragequeues:Transport',
'azureservicebus': 'kombu.transport.azureservicebus:Transport',
'pyro': 'kombu.transport.pyro:Transport',
'gcpubsub': 'kombu.transport.gcpubsub:Transport',
}
_transport_cache = {}
def resolve_transport(transport: str | None = None) -> str | None:
"""Get transport by name.
Arguments:
---------
transport (Union[str, type]): This can be either
an actual transport class, or the fully qualified
path to a transport class, or the alias of a transport.
"""
if isinstance(transport, str):
try:
transport = TRANSPORT_ALIASES[transport]
except KeyError:
if '.' not in transport and ':' not in transport:
from kombu.utils.text import fmatch_best
alt = fmatch_best(transport, TRANSPORT_ALIASES)
if alt:
raise KeyError(
'No such transport: {}. Did you mean {}?'.format(
transport, alt))
raise KeyError(f'No such transport: {transport}')
else:
if callable(transport):
transport = transport()
return symbol_by_name(transport)
return transport
def get_transport_cls(transport: str | None = None) -> str | None:
"""Get transport class by name.
The transport string is the full path to a transport class, e.g.::
"kombu.transport.pyamqp:Transport"
If the name does not include `"."` (is not fully qualified),
the alias table will be consulted.
"""
if transport not in _transport_cache:
_transport_cache[transport] = resolve_transport(transport)
return _transport_cache[transport]

View File

@@ -0,0 +1,498 @@
"""Azure Service Bus Message Queue transport module for kombu.
Note that the Shared Access Policy used to connect to Azure Service Bus
requires Manage, Send and Listen claims since the broker will create new
queues and delete old queues as required.
Notes when using with Celery if you are experiencing issues with programs not
terminating properly. The Azure Service Bus SDK uses the Azure uAMQP library
which in turn creates some threads. If the AzureServiceBus Channel is closed,
said threads will be closed properly, but it seems there are times when Celery
does not do this so these threads will be left running. As the uAMQP threads
are not marked as Daemon threads, they will not be killed when the main thread
exits. Setting the ``uamqp_keep_alive_interval`` transport option to 0 will
prevent the keep_alive thread from starting
More information about Azure Service Bus:
https://azure.microsoft.com/en-us/services/service-bus/
Features
========
* Type: Virtual
* Supports Direct: *Unreviewed*
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*
Connection String
=================
Connection string has the following formats:
.. code-block::
azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE
azureservicebus://DefaultAzureCredential@SERVICE_BUSNAMESPACE
azureservicebus://ManagedIdentityCredential@SERVICE_BUSNAMESPACE
Transport Options
=================
* ``queue_name_prefix`` - String prefix to prepend to queue names in a
service bus namespace.
* ``wait_time_seconds`` - Number of seconds to wait to receive messages.
Default ``5``
* ``peek_lock_seconds`` - Number of seconds the message is visible for before
it is requeued and sent to another consumer. Default ``60``
* ``uamqp_keep_alive_interval`` - Interval in seconds the Azure uAMQP library
should send keepalive messages. Default ``30``
* ``retry_total`` - Azure SDK retry total. Default ``3``
* ``retry_backoff_factor`` - Azure SDK exponential backoff factor.
Default ``0.8``
* ``retry_backoff_max`` - Azure SDK retry total time. Default ``120``
"""
from __future__ import annotations
import string
from queue import Empty
from typing import Any
import azure.core.exceptions
import azure.servicebus.exceptions
import isodate
from azure.servicebus import (ServiceBusClient, ServiceBusMessage,
ServiceBusReceiveMode, ServiceBusReceiver,
ServiceBusSender)
from azure.servicebus.management import ServiceBusAdministrationClient
try:
from azure.identity import (DefaultAzureCredential,
ManagedIdentityCredential)
except ImportError:
DefaultAzureCredential = None
ManagedIdentityCredential = None
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
# dots are replaced by dash, all other punctuation replaced by underscore.
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
CHARS_REPLACE_TABLE = {
ord('.'): ord('-'),
**{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE}
}
class SendReceive:
"""Container for Sender and Receiver."""
def __init__(self,
receiver: ServiceBusReceiver | None = None,
sender: ServiceBusSender | None = None):
self.receiver: ServiceBusReceiver = receiver
self.sender: ServiceBusSender = sender
def close(self) -> None:
if self.receiver:
self.receiver.close()
self.receiver = None
if self.sender:
self.sender.close()
self.sender = None
class Channel(virtual.Channel):
"""Azure Service Bus channel."""
default_wait_time_seconds: int = 5 # in seconds
default_peek_lock_seconds: int = 60 # in seconds (default 60, max 300)
# in seconds (is the default from service bus repo)
default_uamqp_keep_alive_interval: int = 30
# number of retries (is the default from service bus repo)
default_retry_total: int = 3
# exponential backoff factor (is the default from service bus repo)
default_retry_backoff_factor: float = 0.8
# Max time to backoff (is the default from service bus repo)
default_retry_backoff_max: int = 120
domain_format: str = 'kombu%(vhost)s'
_queue_cache: dict[str, SendReceive] = {}
_noack_queues: set[str] = set()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._namespace = None
self._policy = None
self._sas_key = None
self._connection_string = None
self._try_parse_connection_string()
self.qos.restore_at_shutdown = False
def _try_parse_connection_string(self) -> None:
self._namespace, self._credential = Transport.parse_uri(
self.conninfo.hostname)
if (
DefaultAzureCredential is not None
and isinstance(self._credential, DefaultAzureCredential)
) or (
ManagedIdentityCredential is not None
and isinstance(self._credential, ManagedIdentityCredential)
):
return None
if ":" in self._credential:
self._policy, self._sas_key = self._credential.split(':', 1)
conn_dict = {
'Endpoint': 'sb://' + self._namespace,
'SharedAccessKeyName': self._policy,
'SharedAccessKey': self._sas_key,
}
self._connection_string = ';'.join(
[key + '=' + value for key, value in conn_dict.items()])
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
return super().basic_consume(
queue, no_ack, *args, **kwargs
)
def basic_cancel(self, consumer_tag):
if consumer_tag in self._consumers:
queue = self._tag_to_queue[consumer_tag]
self._noack_queues.discard(queue)
return super().basic_cancel(consumer_tag)
def _add_queue_to_cache(
self, name: str,
receiver: ServiceBusReceiver | None = None,
sender: ServiceBusSender | None = None
) -> SendReceive:
if name in self._queue_cache:
obj = self._queue_cache[name]
obj.sender = obj.sender or sender
obj.receiver = obj.receiver or receiver
else:
obj = SendReceive(receiver, sender)
self._queue_cache[name] = obj
return obj
def _get_asb_sender(self, queue: str) -> SendReceive:
queue_obj = self._queue_cache.get(queue, None)
if queue_obj is None or queue_obj.sender is None:
sender = self.queue_service.get_queue_sender(
queue, keep_alive=self.uamqp_keep_alive_interval)
queue_obj = self._add_queue_to_cache(queue, sender=sender)
return queue_obj
def _get_asb_receiver(
self, queue: str,
recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK,
queue_cache_key: str | None = None) -> SendReceive:
cache_key = queue_cache_key or queue
queue_obj = self._queue_cache.get(cache_key, None)
if queue_obj is None or queue_obj.receiver is None:
receiver = self.queue_service.get_queue_receiver(
queue_name=queue, receive_mode=recv_mode,
keep_alive=self.uamqp_keep_alive_interval)
queue_obj = self._add_queue_to_cache(cache_key, receiver=receiver)
return queue_obj
def entity_name(
self, name: str, table: dict[int, int] | None = None) -> str:
"""Format AMQP queue name into a valid ServiceBus queue name."""
return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE)
def _restore(self, message: virtual.base.Message) -> None:
# Not be needed as ASB handles unacked messages
# Remove 'azure_message' as its not JSON serializable
# message.delivery_info.pop('azure_message', None)
# super()._restore(message)
pass
def _new_queue(self, queue: str, **kwargs) -> SendReceive:
"""Ensure a queue exists in ServiceBus."""
queue = self.entity_name(self.queue_name_prefix + queue)
try:
return self._queue_cache[queue]
except KeyError:
# Converts seconds into ISO8601 duration format
# ie 66seconds = P1M6S
lock_duration = isodate.duration_isoformat(
isodate.Duration(seconds=self.peek_lock_seconds))
try:
self.queue_mgmt_service.create_queue(
queue_name=queue, lock_duration=lock_duration)
except azure.core.exceptions.ResourceExistsError:
pass
return self._add_queue_to_cache(queue)
def _delete(self, queue: str, *args, **kwargs) -> None:
"""Delete queue by name."""
queue = self.entity_name(self.queue_name_prefix + queue)
self.queue_mgmt_service.delete_queue(queue)
send_receive_obj = self._queue_cache.pop(queue, None)
if send_receive_obj:
send_receive_obj.close()
def _put(self, queue: str, message, **kwargs) -> None:
"""Put message onto queue."""
queue = self.entity_name(self.queue_name_prefix + queue)
msg = ServiceBusMessage(dumps(message))
queue_obj = self._get_asb_sender(queue)
queue_obj.sender.send_messages(msg)
def _get(
self, queue: str,
timeout: float | int | None = None
) -> dict[str, Any]:
"""Try to retrieve a single message off ``queue``."""
# If we're not ack'ing for this queue, just change receive_mode
recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE \
if queue in self._noack_queues else ServiceBusReceiveMode.PEEK_LOCK
queue = self.entity_name(self.queue_name_prefix + queue)
queue_obj = self._get_asb_receiver(queue, recv_mode)
messages = queue_obj.receiver.receive_messages(
max_message_count=1,
max_wait_time=timeout or self.wait_time_seconds)
if not messages:
raise Empty()
# message.body is either byte or generator[bytes]
message = messages[0]
if not isinstance(message.body, bytes):
body = b''.join(message.body)
else:
body = message.body
msg = loads(bytes_to_str(body))
msg['properties']['delivery_info']['azure_message'] = message
msg['properties']['delivery_info']['azure_queue_name'] = queue
return msg
def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None:
try:
delivery_info = self.qos.get(delivery_tag).delivery_info
except KeyError:
super().basic_ack(delivery_tag)
else:
queue = delivery_info['azure_queue_name']
# recv_mode is PEEK_LOCK when ack'ing messages
queue_obj = self._get_asb_receiver(queue)
try:
queue_obj.receiver.complete_message(
delivery_info['azure_message'])
except azure.servicebus.exceptions.MessageAlreadySettled:
super().basic_ack(delivery_tag)
except Exception:
super().basic_reject(delivery_tag)
else:
super().basic_ack(delivery_tag)
def _size(self, queue: str) -> int:
"""Return the number of messages in a queue."""
queue = self.entity_name(self.queue_name_prefix + queue)
props = self.queue_mgmt_service.get_queue_runtime_properties(queue)
return props.total_message_count
def _purge(self, queue) -> int:
"""Delete all current messages in a queue."""
# Azure doesn't provide a purge api yet
n = 0
max_purge_count = 10
queue = self.entity_name(self.queue_name_prefix + queue)
# By default all the receivers will be in PEEK_LOCK receive mode
queue_obj = self._queue_cache.get(queue, None)
if queue not in self._noack_queues or \
queue_obj is None or queue_obj.receiver is None:
queue_obj = self._get_asb_receiver(
queue,
ServiceBusReceiveMode.RECEIVE_AND_DELETE, 'purge_' + queue
)
while True:
messages = queue_obj.receiver.receive_messages(
max_message_count=max_purge_count,
max_wait_time=0.2
)
n += len(messages)
if len(messages) < max_purge_count:
break
return n
def close(self) -> None:
# receivers and senders spawn threads so clean them up
if not self.closed:
self.closed = True
for queue_obj in self._queue_cache.values():
queue_obj.close()
self._queue_cache.clear()
if self.connection is not None:
self.connection.close_channel(self)
@cached_property
def queue_service(self) -> ServiceBusClient:
if self._connection_string:
return ServiceBusClient.from_connection_string(
self._connection_string,
retry_total=self.retry_total,
retry_backoff_factor=self.retry_backoff_factor,
retry_backoff_max=self.retry_backoff_max
)
return ServiceBusClient(
self._namespace,
self._credential,
retry_total=self.retry_total,
retry_backoff_factor=self.retry_backoff_factor,
retry_backoff_max=self.retry_backoff_max
)
@cached_property
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
if self._connection_string:
return ServiceBusAdministrationClient.from_connection_string(
self._connection_string
)
return ServiceBusAdministrationClient(
self._namespace, self._credential
)
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def queue_name_prefix(self) -> str:
return self.transport_options.get('queue_name_prefix', '')
@cached_property
def wait_time_seconds(self) -> int:
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
@cached_property
def peek_lock_seconds(self) -> int:
return min(self.transport_options.get('peek_lock_seconds',
self.default_peek_lock_seconds),
300) # Limit upper bounds to 300
@cached_property
def uamqp_keep_alive_interval(self) -> int:
return self.transport_options.get(
'uamqp_keep_alive_interval',
self.default_uamqp_keep_alive_interval
)
@cached_property
def retry_total(self) -> int:
return self.transport_options.get(
'retry_total', self.default_retry_total)
@cached_property
def retry_backoff_factor(self) -> float:
return self.transport_options.get(
'retry_backoff_factor', self.default_retry_backoff_factor)
@cached_property
def retry_backoff_max(self) -> int:
return self.transport_options.get(
'retry_backoff_max', self.default_retry_backoff_max)
class Transport(virtual.Transport):
"""Azure Service Bus transport."""
Channel = Channel
polling_interval = 1
default_port = None
can_parse_url = True
@staticmethod
def parse_uri(uri: str) -> tuple[str, str | DefaultAzureCredential |
ManagedIdentityCredential]:
# URL like:
# azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
# urllib parse does not work as the sas key could contain a slash
# e.g.: azureservicebus://rootpolicy:some/key@somenamespace
# > 'rootpolicy:some/key@somenamespace'
uri = uri.replace('azureservicebus://', '')
# > 'rootpolicy:some/key', 'somenamespace'
credential, namespace = uri.rsplit('@', 1)
if not namespace.endswith('.net'):
namespace += '.servicebus.windows.net'
if "DefaultAzureCredential".lower() == credential.lower():
if DefaultAzureCredential is None:
raise ImportError('Azure Service Bus transport with a '
'DefaultAzureCredential requires the '
'azure-identity library')
credential = DefaultAzureCredential()
elif "ManagedIdentityCredential".lower() == credential.lower():
if ManagedIdentityCredential is None:
raise ImportError('Azure Service Bus transport with a '
'ManagedIdentityCredential requires the '
'azure-identity library')
credential = ManagedIdentityCredential()
else:
# > 'rootpolicy', 'some/key'
policy, sas_key = credential.split(':', 1)
credential = f"{policy}:{sas_key}"
# Validate ASB connection string
if not all([namespace, credential]):
raise ValueError(
'Need a URI like '
'azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} ' # noqa
'or the azure Endpoint connection string'
)
return namespace, credential
@classmethod
def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
namespace, credential = cls.parse_uri(uri)
if isinstance(credential, str) and ":" in credential:
policy, sas_key = credential.split(':', 1)
return 'azureservicebus://{}:{}@{}'.format(
policy,
sas_key if include_password else mask,
namespace
)
return 'azureservicebus://{}@{}'.format(
credential.__class__.__name__,
namespace
)

View File

@@ -0,0 +1,263 @@
"""Azure Storage Queues transport module for kombu.
More information about Azure Storage Queues:
https://azure.microsoft.com/en-us/services/storage/queues/
Features
========
* Type: Virtual
* Supports Direct: *Unreviewed*
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*
Connection String
=================
Connection string has the following formats:
.. code-block::
azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
Note that if the access key for the storage account contains a forward slash
(``/``), it will have to be regenerated before it can be used in the connection
URL.
.. code-block::
azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
If you wish to use an `Azure Managed Identity` you may use the
``DefaultAzureCredential`` format of the connection string which will use
``DefaultAzureCredential`` class in the azure-identity package. You may want to
read the `azure-identity documentation` for more information on how the
``DefaultAzureCredential`` works.
.. _azure-identity documentation:
https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
.. _Azure Managed Identity:
https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
Transport Options
=================
* ``queue_name_prefix``
"""
from __future__ import annotations
import string
from queue import Empty
from typing import Any
from azure.core.exceptions import ResourceExistsError
from kombu.utils.encoding import safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
try:
from azure.storage.queue import QueueServiceClient
except ImportError: # pragma: no cover
QueueServiceClient = None
try:
from azure.identity import (DefaultAzureCredential,
ManagedIdentityCredential)
except ImportError:
DefaultAzureCredential = None
ManagedIdentityCredential = None
# Azure storage queues allow only alphanumeric and dashes
# so, replace everything with a dash
CHARS_REPLACE_TABLE = {
ord(c): 0x2d for c in string.punctuation
}
class Channel(virtual.Channel):
"""Azure Storage Queues channel."""
domain_format: str = 'kombu%(vhost)s'
_queue_service: QueueServiceClient | None = None
_queue_name_cache: dict[Any, Any] = {}
no_ack: bool = True
_noack_queues: set[Any] = set()
def __init__(self, *args, **kwargs):
if QueueServiceClient is None:
raise ImportError('Azure Storage Queues transport requires the '
'azure-storage-queue library')
super().__init__(*args, **kwargs)
self._credential, self._url = Transport.parse_uri(
self.conninfo.hostname
)
for queue in self.queue_service.list_queues():
self._queue_name_cache[queue['name']] = queue
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
return super().basic_consume(queue, no_ack,
*args, **kwargs)
def entity_name(self, name, table=CHARS_REPLACE_TABLE) -> str:
"""Format AMQP queue name into a valid Azure Storage Queue name."""
return str(safe_str(name)).translate(table)
def _ensure_queue(self, queue):
"""Ensure a queue exists."""
queue = self.entity_name(self.queue_name_prefix + queue)
try:
q = self._queue_service.get_queue_client(
queue=self._queue_name_cache[queue]
)
except KeyError:
try:
q = self.queue_service.create_queue(queue)
except ResourceExistsError:
q = self._queue_service.get_queue_client(queue=queue)
self._queue_name_cache[queue] = q.get_queue_properties()
return q
def _delete(self, queue, *args, **kwargs):
"""Delete queue by name."""
queue_name = self.entity_name(queue)
self._queue_name_cache.pop(queue_name, None)
self.queue_service.delete_queue(queue_name)
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q = self._ensure_queue(queue)
encoded_message = dumps(message)
q.send_message(encoded_message)
def _get(self, queue, timeout=None):
"""Try to retrieve a single message off ``queue``."""
q = self._ensure_queue(queue)
messages = q.receive_messages(messages_per_page=1, timeout=timeout)
try:
message = next(messages)
except StopIteration:
raise Empty()
content = loads(message.content)
q.delete_message(message=message)
return content
def _size(self, queue):
"""Return the number of messages in a queue."""
q = self._ensure_queue(queue)
return q.get_queue_properties().approximate_message_count
def _purge(self, queue):
"""Delete all current messages in a queue."""
q = self._ensure_queue(queue)
n = self._size(q.queue_name)
q.clear_messages()
return n
@property
def queue_service(self) -> QueueServiceClient:
if self._queue_service is None:
self._queue_service = QueueServiceClient(
account_url=self._url, credential=self._credential
)
return self._queue_service
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def queue_name_prefix(self) -> str:
return self.transport_options.get('queue_name_prefix', '')
class Transport(virtual.Transport):
"""Azure Storage Queues transport."""
Channel = Channel
polling_interval: int = 1
default_port: int | None = None
can_parse_url: bool = True
@staticmethod
def parse_uri(uri: str) -> tuple[str | dict, str]:
# URL like:
# azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
# azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
# azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
# azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
# urllib parse does not work as the sas key could contain a slash
# e.g.: azurestoragequeues://some/key@someurl
try:
# > 'some/key@url'
uri = uri.replace('azurestoragequeues://', '')
# > 'some/key', 'url'
credential, url = uri.rsplit('@', 1)
if "DefaultAzureCredential".lower() == credential.lower():
if DefaultAzureCredential is None:
raise ImportError('Azure Storage Queues transport with a '
'DefaultAzureCredential requires the '
'azure-identity library')
credential = DefaultAzureCredential()
elif "ManagedIdentityCredential".lower() == credential.lower():
if ManagedIdentityCredential is None:
raise ImportError('Azure Storage Queues transport with a '
'ManagedIdentityCredential requires the '
'azure-identity library')
credential = ManagedIdentityCredential()
elif "devstoreaccount1" in url and ".core.windows.net" not in url:
# parse credential as a dict if Azurite is being used
credential = {
"account_name": "devstoreaccount1",
"account_key": credential,
}
# Validate parameters
assert all([credential, url])
except Exception:
raise ValueError(
'Need a URI like '
'azurestoragequeues://{SAS or access key}@{URL}, '
'azurestoragequeues://DefaultAzureCredential@{URL}, '
', or '
'azurestoragequeues://ManagedIdentityCredential@{URL}'
)
return credential, url
@classmethod
def as_uri(
cls, uri: str, include_password: bool = False, mask: str = "**"
) -> str:
credential, url = cls.parse_uri(uri)
return "azurestoragequeues://{}@{}".format(
credential if include_password else mask, url
)

View File

@@ -0,0 +1,271 @@
"""Base transport interface."""
# flake8: noqa
from __future__ import annotations
import errno
import socket
from typing import TYPE_CHECKING
from amqp.exceptions import RecoverableConnectionError
from kombu.exceptions import ChannelError, ConnectionError
from kombu.message import Message
from kombu.utils.functional import dictfilter
from kombu.utils.objects import cached_property
from kombu.utils.time import maybe_s_to_ms
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('Message', 'StdChannel', 'Management', 'Transport')
RABBITMQ_QUEUE_ARGUMENTS = {
'expires': ('x-expires', maybe_s_to_ms),
'message_ttl': ('x-message-ttl', maybe_s_to_ms),
'max_length': ('x-max-length', int),
'max_length_bytes': ('x-max-length-bytes', int),
'max_priority': ('x-max-priority', int),
} # type: Mapping[str, Tuple[str, Callable]]
def to_rabbitmq_queue_arguments(arguments, **options):
# type: (Mapping, **Any) -> Dict
"""Convert queue arguments to RabbitMQ queue arguments.
This is the implementation for Channel.prepare_queue_arguments
for AMQP-based transports. It's used by both the pyamqp and librabbitmq
transports.
Arguments:
arguments (Mapping):
User-supplied arguments (``Queue.queue_arguments``).
Keyword Arguments:
expires (float): Queue expiry time in seconds.
This will be converted to ``x-expires`` in int milliseconds.
message_ttl (float): Message TTL in seconds.
This will be converted to ``x-message-ttl`` in int milliseconds.
max_length (int): Max queue length (in number of messages).
This will be converted to ``x-max-length`` int.
max_length_bytes (int): Max queue size in bytes.
This will be converted to ``x-max-length-bytes`` int.
max_priority (int): Max priority steps for queue.
This will be converted to ``x-max-priority`` int.
Returns
-------
Dict: RabbitMQ compatible queue arguments.
"""
prepared = dictfilter(dict(
_to_rabbitmq_queue_argument(key, value)
for key, value in options.items()
))
return dict(arguments, **prepared) if prepared else arguments
def _to_rabbitmq_queue_argument(key, value):
# type: (str, Any) -> Tuple[str, Any]
opt, typ = RABBITMQ_QUEUE_ARGUMENTS[key]
return opt, typ(value) if value is not None else value
def _LeftBlank(obj, method):
return NotImplementedError(
'Transport {0.__module__}.{0.__name__} does not implement {1}'.format(
obj.__class__, method))
class StdChannel:
"""Standard channel base class."""
no_ack_consumers = None
def Consumer(self, *args, **kwargs):
from kombu.messaging import Consumer
return Consumer(self, *args, **kwargs)
def Producer(self, *args, **kwargs):
from kombu.messaging import Producer
return Producer(self, *args, **kwargs)
def get_bindings(self):
raise _LeftBlank(self, 'get_bindings')
def after_reply_message_received(self, queue):
"""Callback called after RPC reply received.
Notes
-----
Reply queue semantics: can be used to delete the queue
after transient reply message received.
"""
def prepare_queue_arguments(self, arguments, **kwargs):
return arguments
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.close()
class Management:
"""AMQP Management API (incomplete)."""
def __init__(self, transport):
self.transport = transport
def get_bindings(self):
raise _LeftBlank(self, 'get_bindings')
class Implements(dict):
"""Helper class used to define transport features."""
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(key)
def __setattr__(self, key, value):
self[key] = value
def extend(self, **kwargs):
return self.__class__(self, **kwargs)
default_transport_capabilities = Implements(
asynchronous=False,
exchange_type=frozenset(['direct', 'topic', 'fanout', 'headers']),
heartbeats=False,
)
class Transport:
"""Base class for transports."""
Management = Management
#: The :class:`~kombu.Connection` owning this instance.
client = None
#: Set to True if :class:`~kombu.Connection` should pass the URL
#: unmodified.
can_parse_url = False
#: Default port used when no port has been specified.
default_port = None
#: Tuple of errors that can happen due to connection failure.
connection_errors = (ConnectionError,)
#: Tuple of errors that can happen due to channel/method failure.
channel_errors = (ChannelError,)
#: Type of driver, can be used to separate transports
#: using the AMQP protocol (driver_type: 'amqp'),
#: Redis (driver_type: 'redis'), etc...
driver_type = 'N/A'
#: Name of driver library (e.g. 'py-amqp', 'redis').
driver_name = 'N/A'
__reader = None
implements = default_transport_capabilities.extend()
def __init__(self, client, **kwargs):
self.client = client
def establish_connection(self):
raise _LeftBlank(self, 'establish_connection')
def close_connection(self, connection):
raise _LeftBlank(self, 'close_connection')
def create_channel(self, connection):
raise _LeftBlank(self, 'create_channel')
def close_channel(self, connection):
raise _LeftBlank(self, 'close_channel')
def drain_events(self, connection, **kwargs):
raise _LeftBlank(self, 'drain_events')
def heartbeat_check(self, connection, rate=2):
pass
def driver_version(self):
return 'N/A'
def get_heartbeat_interval(self, connection):
return 0
def register_with_event_loop(self, connection, loop):
pass
def unregister_from_event_loop(self, connection, loop):
pass
def verify_connection(self, connection):
return True
def _make_reader(self, connection, timeout=socket.timeout,
error=socket.error, _unavail=(errno.EAGAIN, errno.EINTR)):
drain_events = connection.drain_events
def _read(loop):
if not connection.connected:
raise RecoverableConnectionError('Socket was disconnected')
try:
drain_events(timeout=0)
except timeout:
return
except error as exc:
if exc.errno in _unavail:
return
raise
loop.call_soon(_read, loop)
return _read
def qos_semantics_matches_spec(self, connection):
return True
def on_readable(self, connection, loop):
reader = self.__reader
if reader is None:
reader = self.__reader = self._make_reader(connection)
reader(loop)
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
"""Customise the display format of the URI."""
raise NotImplementedError()
@property
def default_connection_params(self):
return {}
def get_manager(self, *args, **kwargs):
return self.Management(self)
@cached_property
def manager(self):
return self.get_manager()
@property
def supports_heartbeats(self):
return self.implements.heartbeats
@property
def supports_ev(self):
return self.implements.asynchronous

View File

@@ -0,0 +1,380 @@
"""confluent-kafka transport module for Kombu.
Kafka transport using confluent-kafka library.
**References**
- http://docs.confluent.io/current/clients/confluent-kafka-python
**Limitations**
The confluent-kafka transport does not support PyPy environment.
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: No
* Supports TTL: No
Connection String
=================
Connection string has the following format:
.. code-block::
confluentkafka://[USER:PASSWORD@]KAFKA_ADDRESS[:PORT]
Transport Options
=================
* ``connection_wait_time_seconds`` - Time in seconds to wait for connection
to succeed. Default ``5``
* ``wait_time_seconds`` - Time in seconds to wait to receive messages.
Default ``5``
* ``security_protocol`` - Protocol used to communicate with broker.
Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
an explanation of valid values. Default ``plaintext``
* ``sasl_mechanism`` - SASL mechanism to use for authentication.
Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
an explanation of valid values.
* ``num_partitions`` - Number of partitions to create. Default ``1``
* ``replication_factor`` - Replication factor of partitions. Default ``1``
* ``topic_config`` - Topic configuration. Must be a dict whose key-value pairs
correspond with attributes in the
http://kafka.apache.org/documentation.html#topicconfigs.
* ``kafka_common_config`` - Configuration applied to producer, consumer and
admin client. Must be a dict whose key-value pairs correspond with attributes
in the https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
* ``kafka_producer_config`` - Producer configuration. Must be a dict whose
key-value pairs correspond with attributes in the
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
* ``kafka_consumer_config`` - Consumer configuration. Must be a dict whose
key-value pairs correspond with attributes in the
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
* ``kafka_admin_config`` - Admin client configuration. Must be a dict whose
key-value pairs correspond with attributes in the
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
"""
from __future__ import annotations
from queue import Empty
from kombu.transport import virtual
from kombu.utils import cached_property
from kombu.utils.encoding import str_to_bytes
from kombu.utils.json import dumps, loads
try:
import confluent_kafka
from confluent_kafka import (Consumer, KafkaException, Producer,
TopicPartition)
from confluent_kafka.admin import AdminClient, NewTopic
KAFKA_CONNECTION_ERRORS = ()
KAFKA_CHANNEL_ERRORS = ()
except ImportError:
confluent_kafka = None
KAFKA_CONNECTION_ERRORS = KAFKA_CHANNEL_ERRORS = ()
from kombu.log import get_logger
logger = get_logger(__name__)
DEFAULT_PORT = 9092
class NoBrokersAvailable(KafkaException):
"""Kafka broker is not available exception."""
retriable = True
class Message(virtual.Message):
"""Message object."""
def __init__(self, payload, channel=None, **kwargs):
self.topic = payload.get('topic')
super().__init__(payload, channel=channel, **kwargs)
class QoS(virtual.QoS):
"""Quality of Service guarantees."""
_not_yet_acked = {}
def can_consume(self):
"""Return true if the channel can be consumed from.
:returns: True, if this QoS object can accept a message.
:rtype: bool
"""
return not self.prefetch_count or len(self._not_yet_acked) < self \
.prefetch_count
def can_consume_max_estimate(self):
if self.prefetch_count:
return self.prefetch_count - len(self._not_yet_acked)
else:
return 1
def append(self, message, delivery_tag):
self._not_yet_acked[delivery_tag] = message
def get(self, delivery_tag):
return self._not_yet_acked[delivery_tag]
def ack(self, delivery_tag):
if delivery_tag not in self._not_yet_acked:
return
message = self._not_yet_acked.pop(delivery_tag)
consumer = self.channel._get_consumer(message.topic)
consumer.commit()
def reject(self, delivery_tag, requeue=False):
"""Reject a message by delivery tag.
If requeue is True, then the last consumed message is reverted so
it'll be refetched on the next attempt.
If False, that message is consumed and ignored.
"""
if requeue:
message = self._not_yet_acked.pop(delivery_tag)
consumer = self.channel._get_consumer(message.topic)
for assignment in consumer.assignment():
topic_partition = TopicPartition(message.topic,
assignment.partition)
[committed_offset] = consumer.committed([topic_partition])
consumer.seek(committed_offset)
else:
self.ack(delivery_tag)
def restore_unacked_once(self, stderr=None):
pass
class Channel(virtual.Channel):
"""Kafka Channel."""
QoS = QoS
Message = Message
default_wait_time_seconds = 5
default_connection_wait_time_seconds = 5
_client = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._kafka_consumers = {}
self._kafka_producers = {}
self._client = self._open()
def sanitize_queue_name(self, queue):
"""Need to sanitize the name, celery sometimes pushes in @ signs."""
return str(queue).replace('@', '')
def _get_producer(self, queue):
"""Create/get a producer instance for the given topic/queue."""
queue = self.sanitize_queue_name(queue)
producer = self._kafka_producers.get(queue, None)
if producer is None:
producer = Producer({
**self.common_config,
**(self.options.get('kafka_producer_config') or {}),
})
self._kafka_producers[queue] = producer
return producer
def _get_consumer(self, queue):
"""Create/get a consumer instance for the given topic/queue."""
queue = self.sanitize_queue_name(queue)
consumer = self._kafka_consumers.get(queue, None)
if consumer is None:
consumer = Consumer({
'group.id': f'{queue}-consumer-group',
'auto.offset.reset': 'earliest',
'enable.auto.commit': False,
**self.common_config,
**(self.options.get('kafka_consumer_config') or {}),
})
consumer.subscribe([queue])
self._kafka_consumers[queue] = consumer
return consumer
def _put(self, queue, message, **kwargs):
"""Put a message on the topic/queue."""
queue = self.sanitize_queue_name(queue)
producer = self._get_producer(queue)
producer.produce(queue, str_to_bytes(dumps(message)))
producer.flush()
def _get(self, queue, **kwargs):
"""Get a message from the topic/queue."""
queue = self.sanitize_queue_name(queue)
consumer = self._get_consumer(queue)
message = None
try:
message = consumer.poll(self.wait_time_seconds)
except StopIteration:
pass
if not message:
raise Empty()
error = message.error()
if error:
logger.error(error)
raise Empty()
return {**loads(message.value()), 'topic': message.topic()}
def _delete(self, queue, *args, **kwargs):
"""Delete a queue/topic."""
queue = self.sanitize_queue_name(queue)
self._kafka_consumers[queue].close()
self._kafka_consumers.pop(queue)
self.client.delete_topics([queue])
def _size(self, queue):
"""Get the number of pending messages in the topic/queue."""
queue = self.sanitize_queue_name(queue)
consumer = self._kafka_consumers.get(queue, None)
if consumer is None:
return 0
size = 0
for assignment in consumer.assignment():
topic_partition = TopicPartition(queue, assignment.partition)
(_, end_offset) = consumer.get_watermark_offsets(topic_partition)
[committed_offset] = consumer.committed([topic_partition])
size += end_offset - committed_offset.offset
return size
def _new_queue(self, queue, **kwargs):
"""Create a new topic if it does not exist."""
queue = self.sanitize_queue_name(queue)
if queue in self.client.list_topics().topics:
return
topic = NewTopic(
queue,
num_partitions=self.options.get('num_partitions', 1),
replication_factor=self.options.get('replication_factor', 1),
config=self.options.get('topic_config', {})
)
self.client.create_topics(new_topics=[topic])
def _has_queue(self, queue, **kwargs):
"""Check if a topic already exists."""
queue = self.sanitize_queue_name(queue)
return queue in self.client.list_topics().topics
def _open(self):
client = AdminClient({
**self.common_config,
**(self.options.get('kafka_admin_config') or {}),
})
try:
# seems to be the only way to check connection
client.list_topics(timeout=self.wait_time_seconds)
except confluent_kafka.KafkaException as e:
raise NoBrokersAvailable(e)
return client
@property
def client(self):
if self._client is None:
self._client = self._open()
return self._client
@property
def options(self):
return self.connection.client.transport_options
@property
def conninfo(self):
return self.connection.client
@cached_property
def wait_time_seconds(self):
return self.options.get(
'wait_time_seconds', self.default_wait_time_seconds
)
@cached_property
def connection_wait_time_seconds(self):
return self.options.get(
'connection_wait_time_seconds',
self.default_connection_wait_time_seconds,
)
@cached_property
def common_config(self):
conninfo = self.connection.client
config = {
'bootstrap.servers':
f'{conninfo.hostname}:{int(conninfo.port) or DEFAULT_PORT}',
}
security_protocol = self.options.get('security_protocol', 'plaintext')
if security_protocol.lower() != 'plaintext':
config.update({
'security.protocol': security_protocol,
'sasl.username': conninfo.userid,
'sasl.password': conninfo.password,
'sasl.mechanism': self.options.get('sasl_mechanism'),
})
config.update(self.options.get('kafka_common_config') or {})
return config
def close(self):
super().close()
self._kafka_producers = {}
for consumer in self._kafka_consumers.values():
consumer.close()
self._kafka_consumers = {}
class Transport(virtual.Transport):
"""Kafka Transport."""
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
pass
Channel = Channel
default_port = DEFAULT_PORT
driver_type = 'kafka'
driver_name = 'confluentkafka'
recoverable_connection_errors = (
NoBrokersAvailable,
)
def __init__(self, client, **kwargs):
if confluent_kafka is None:
raise ImportError('The confluent-kafka library is not installed')
super().__init__(client, **kwargs)
def driver_version(self):
return confluent_kafka.__version__
def establish_connection(self):
return super().establish_connection()
def close_connection(self, connection):
return super().close_connection(connection)

View File

@@ -0,0 +1,323 @@
"""Consul Transport module for Kombu.
Features
========
It uses Consul.io's Key/Value store to transport messages in Queues
It uses python-consul for talking to Consul's HTTP API
Features
========
* Type: Native
* Supports Direct: Yes
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*
Connection String
=================
Connection string has the following format:
.. code-block::
consul://CONSUL_ADDRESS[:PORT]
"""
from __future__ import annotations
import socket
import uuid
from collections import defaultdict
from contextlib import contextmanager
from queue import Empty
from time import monotonic
from kombu.exceptions import ChannelError
from kombu.log import get_logger
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
try:
import consul
except ImportError:
consul = None
logger = get_logger('kombu.transport.consul')
DEFAULT_PORT = 8500
DEFAULT_HOST = 'localhost'
class LockError(Exception):
"""An error occurred while trying to acquire the lock."""
class Channel(virtual.Channel):
"""Consul Channel class which talks to the Consul Key/Value store."""
prefix = 'kombu'
index = None
timeout = '10s'
session_ttl = 30
def __init__(self, *args, **kwargs):
if consul is None:
raise ImportError('Missing python-consul library')
super().__init__(*args, **kwargs)
port = self.connection.client.port or self.connection.default_port
host = self.connection.client.hostname or DEFAULT_HOST
logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout)
self.queues = defaultdict(dict)
self.client = consul.Consul(host=host, port=int(port))
def _lock_key(self, queue):
return f'{self.prefix}/{queue}.lock'
def _key_prefix(self, queue):
return f'{self.prefix}/{queue}'
def _get_or_create_session(self, queue):
"""Get or create consul session.
Try to renew the session if it exists, otherwise create a new
session in Consul.
This session is used to acquire a lock inside Consul so that we achieve
read-consistency between the nodes.
Arguments:
---------
queue (str): The name of the Queue.
Returns
-------
str: The ID of the session.
"""
try:
session_id = self.queues[queue]['session_id']
except KeyError:
session_id = None
return (self._renew_existing_session(session_id)
if session_id is not None else self._create_new_session())
def _renew_existing_session(self, session_id):
logger.debug('Trying to renew existing session %s', session_id)
session = self.client.session.renew(session_id=session_id)
return session.get('ID')
def _create_new_session(self):
logger.debug('Creating session %s with TTL %s',
self.lock_name, self.session_ttl)
session_id = self.client.session.create(
name=self.lock_name, ttl=self.session_ttl)
logger.debug('Created session %s with id %s',
self.lock_name, session_id)
return session_id
@contextmanager
def _queue_lock(self, queue, raising=LockError):
"""Try to acquire a lock on the Queue.
It does so by creating a object called 'lock' which is locked by the
current session..
This way other nodes are not able to write to the lock object which
means that they have to wait before the lock is released.
Arguments:
---------
queue (str): The name of the Queue.
raising (Exception): Set custom lock error class.
Raises
------
LockError: if the lock cannot be acquired.
Returns
-------
bool: success?
"""
self._acquire_lock(queue, raising=raising)
try:
yield
finally:
self._release_lock(queue)
def _acquire_lock(self, queue, raising=LockError):
session_id = self._get_or_create_session(queue)
lock_key = self._lock_key(queue)
logger.debug('Trying to create lock object %s with session %s',
lock_key, session_id)
if self.client.kv.put(key=lock_key,
acquire=session_id,
value=self.lock_name):
self.queues[queue]['session_id'] = session_id
return
logger.info('Could not acquire lock on key %s', lock_key)
raise raising()
def _release_lock(self, queue):
"""Try to release a lock.
It does so by simply removing the lock key in Consul.
Arguments:
---------
queue (str): The name of the queue we want to release
the lock from.
"""
logger.debug('Removing lock key %s', self._lock_key(queue))
self.client.kv.delete(key=self._lock_key(queue))
def _destroy_session(self, queue):
"""Destroy a previously created Consul session.
Will release all locks it still might hold.
Arguments:
---------
queue (str): The name of the Queue.
"""
logger.debug('Destroying session %s', self.queues[queue]['session_id'])
self.client.session.destroy(self.queues[queue]['session_id'])
def _new_queue(self, queue, **_):
self.queues[queue] = {'session_id': None}
return self.client.kv.put(key=self._key_prefix(queue), value=None)
def _delete(self, queue, *args, **_):
self._destroy_session(queue)
self.queues.pop(queue, None)
self._purge(queue)
def _put(self, queue, payload, **_):
"""Put `message` onto `queue`.
This simply writes a key to the K/V store of Consul
"""
key = '{}/msg/{}_{}'.format(
self._key_prefix(queue),
int(round(monotonic() * 1000)),
uuid.uuid4(),
)
if not self.client.kv.put(key=key, value=dumps(payload), cas=0):
raise ChannelError(f'Cannot add key {key!r} to consul')
def _get(self, queue, timeout=None):
"""Get the first available message from the queue.
Before it does so it acquires a lock on the Key/Value store so
only one node reads at the same time. This is for read consistency
"""
with self._queue_lock(queue, raising=Empty):
key = f'{self._key_prefix(queue)}/msg/'
logger.debug('Fetching key %s with index %s', key, self.index)
self.index, data = self.client.kv.get(
key=key, recurse=True,
index=self.index, wait=self.timeout,
)
try:
if data is None:
raise Empty()
logger.debug('Removing key %s with modifyindex %s',
data[0]['Key'], data[0]['ModifyIndex'])
self.client.kv.delete(key=data[0]['Key'],
cas=data[0]['ModifyIndex'])
return loads(data[0]['Value'])
except TypeError:
pass
raise Empty()
def _purge(self, queue):
self._destroy_session(queue)
return self.client.kv.delete(
key=f'{self._key_prefix(queue)}/msg/',
recurse=True,
)
def _size(self, queue):
size = 0
try:
key = f'{self._key_prefix(queue)}/msg/'
logger.debug('Fetching key recursively %s with index %s',
key, self.index)
self.index, data = self.client.kv.get(
key=key, recurse=True,
index=self.index, wait=self.timeout,
)
size = len(data)
except TypeError:
pass
logger.debug('Found %s keys under %s with index %s',
size, key, self.index)
return size
@cached_property
def lock_name(self):
return f'{socket.gethostname()}'
class Transport(virtual.Transport):
"""Consul K/V storage Transport for Kombu."""
Channel = Channel
default_port = DEFAULT_PORT
driver_type = 'consul'
driver_name = 'consul'
if consul:
connection_errors = (
virtual.Transport.connection_errors + (
consul.ConsulException, consul.base.ConsulException
)
)
channel_errors = (
virtual.Transport.channel_errors + (
consul.ConsulException, consul.base.ConsulException
)
)
def __init__(self, *args, **kwargs):
if consul is None:
raise ImportError('Missing python-consul library')
super().__init__(*args, **kwargs)
def verify_connection(self, connection):
port = connection.client.port or self.default_port
host = connection.client.hostname or DEFAULT_HOST
logger.debug('Verify Consul connection to %s:%s', host, port)
try:
client = consul.Consul(host=host, port=int(port))
client.agent.self()
return True
except ValueError:
pass
return False
def driver_version(self):
return consul.__version__

View File

@@ -0,0 +1,300 @@
"""Etcd Transport module for Kombu.
It uses Etcd as a store to transport messages in Queues
It uses python-etcd for talking to Etcd's HTTP API
Features
========
* Type: Virtual
* Supports Direct: *Unreviewed*
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*
Connection String
=================
Connection string has the following format:
.. code-block::
'etcd'://SERVER:PORT
"""
from __future__ import annotations
import os
import socket
from collections import defaultdict
from contextlib import contextmanager
from queue import Empty
from kombu.exceptions import ChannelError
from kombu.log import get_logger
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
try:
import etcd
except ImportError:
etcd = None
logger = get_logger('kombu.transport.etcd')
DEFAULT_PORT = 2379
DEFAULT_HOST = 'localhost'
class Channel(virtual.Channel):
"""Etcd Channel class which talks to the Etcd."""
prefix = 'kombu'
index = None
timeout = 10
session_ttl = 30
lock_ttl = 10
def __init__(self, *args, **kwargs):
if etcd is None:
raise ImportError('Missing python-etcd library')
super().__init__(*args, **kwargs)
port = self.connection.client.port or self.connection.default_port
host = self.connection.client.hostname or DEFAULT_HOST
logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout)
self.queues = defaultdict(dict)
self.client = etcd.Client(host=host, port=int(port))
def _key_prefix(self, queue):
"""Create and return the `queue` with the proper prefix.
Arguments:
---------
queue (str): The name of the queue.
"""
return f'{self.prefix}/{queue}'
@contextmanager
def _queue_lock(self, queue):
"""Try to acquire a lock on the Queue.
It does so by creating a object called 'lock' which is locked by the
current session..
This way other nodes are not able to write to the lock object which
means that they have to wait before the lock is released.
Arguments:
---------
queue (str): The name of the queue.
"""
lock = etcd.Lock(self.client, queue)
lock._uuid = self.lock_value
logger.debug(f'Acquiring lock {lock.name}')
lock.acquire(blocking=True, lock_ttl=self.lock_ttl)
try:
yield
finally:
logger.debug(f'Releasing lock {lock.name}')
lock.release()
def _new_queue(self, queue, **_):
"""Create a new `queue` if the `queue` doesn't already exist.
Arguments:
---------
queue (str): The name of the queue.
"""
self.queues[queue] = queue
with self._queue_lock(queue):
try:
return self.client.write(
key=self._key_prefix(queue), dir=True, value=None)
except etcd.EtcdNotFile:
logger.debug(f'Queue "{queue}" already exists')
return self.client.read(key=self._key_prefix(queue))
def _has_queue(self, queue, **kwargs):
"""Verify that queue exists.
Returns
-------
bool: Should return :const:`True` if the queue exists
or :const:`False` otherwise.
"""
try:
self.client.read(self._key_prefix(queue))
return True
except etcd.EtcdKeyNotFound:
return False
def _delete(self, queue, *args, **_):
"""Delete a `queue`.
Arguments:
---------
queue (str): The name of the queue.
"""
self.queues.pop(queue, None)
self._purge(queue)
def _put(self, queue, payload, **_):
"""Put `message` onto `queue`.
This simply writes a key to the Etcd store
Arguments:
---------
queue (str): The name of the queue.
payload (dict): Message data which will be dumped to etcd.
"""
with self._queue_lock(queue):
key = self._key_prefix(queue)
if not self.client.write(
key=key,
value=dumps(payload),
append=True):
raise ChannelError(f'Cannot add key {key!r} to etcd')
def _get(self, queue, timeout=None):
"""Get the first available message from the queue.
Before it does so it acquires a lock on the store so
only one node reads at the same time. This is for read consistency
Arguments:
---------
queue (str): The name of the queue.
timeout (int): Optional seconds to wait for a response.
"""
with self._queue_lock(queue):
key = self._key_prefix(queue)
logger.debug('Fetching key %s with index %s', key, self.index)
try:
result = self.client.read(
key=key, recursive=True,
index=self.index, timeout=self.timeout)
if result is None:
raise Empty()
item = result._children[-1]
logger.debug('Removing key {}'.format(item['key']))
msg_content = loads(item['value'])
self.client.delete(key=item['key'])
return msg_content
except (TypeError, IndexError, etcd.EtcdException) as error:
logger.debug(f'_get failed: {type(error)}:{error}')
raise Empty()
def _purge(self, queue):
"""Remove all `message`s from a `queue`.
Arguments:
---------
queue (str): The name of the queue.
"""
with self._queue_lock(queue):
key = self._key_prefix(queue)
logger.debug(f'Purging queue at key {key}')
return self.client.delete(key=key, recursive=True)
def _size(self, queue):
"""Return the size of the `queue`.
Arguments:
---------
queue (str): The name of the queue.
"""
with self._queue_lock(queue):
size = 0
try:
key = self._key_prefix(queue)
logger.debug('Fetching key recursively %s with index %s',
key, self.index)
result = self.client.read(
key=key, recursive=True,
index=self.index)
size = len(result._children)
except TypeError:
pass
logger.debug('Found %s keys under %s with index %s',
size, key, self.index)
return size
@cached_property
def lock_value(self):
return f'{socket.gethostname()}.{os.getpid()}'
class Transport(virtual.Transport):
"""Etcd storage Transport for Kombu."""
Channel = Channel
default_port = DEFAULT_PORT
driver_type = 'etcd'
driver_name = 'python-etcd'
polling_interval = 3
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct']))
if etcd:
connection_errors = (
virtual.Transport.connection_errors + (etcd.EtcdException, )
)
channel_errors = (
virtual.Transport.channel_errors + (etcd.EtcdException, )
)
def __init__(self, *args, **kwargs):
"""Create a new instance of etcd.Transport."""
if etcd is None:
raise ImportError('Missing python-etcd library')
super().__init__(*args, **kwargs)
def verify_connection(self, connection):
"""Verify the connection works."""
port = connection.client.port or self.default_port
host = connection.client.hostname or DEFAULT_HOST
logger.debug('Verify Etcd connection to %s:%s', host, port)
try:
etcd.Client(host=host, port=int(port))
return True
except ValueError:
pass
return False
def driver_version(self):
"""Return the version of the etcd library.
.. note::
python-etcd has no __version__. This is a workaround.
"""
try:
import pip.commands.freeze
for x in pip.commands.freeze.freeze():
if x.startswith('python-etcd'):
return x.split('==')[1]
except (ImportError, IndexError):
logger.warning('Unable to find the python-etcd version.')
return 'Unknown'

View File

@@ -0,0 +1,352 @@
"""File-system Transport module for kombu.
Transport using the file-system as the message store. Messages written to the
queue are stored in `data_folder_in` directory and
messages read from the queue are read from `data_folder_out` directory. Both
directories must be created manually. Simple example:
* Producer:
.. code-block:: python
import kombu
conn = kombu.Connection(
'filesystem://', transport_options={
'data_folder_in': 'data_in', 'data_folder_out': 'data_out'
}
)
conn.connect()
test_queue = kombu.Queue('test', routing_key='test')
with conn as conn:
with conn.default_channel as channel:
producer = kombu.Producer(channel)
producer.publish(
{'hello': 'world'},
retry=True,
exchange=test_queue.exchange,
routing_key=test_queue.routing_key,
declare=[test_queue],
serializer='pickle'
)
* Consumer:
.. code-block:: python
import kombu
conn = kombu.Connection(
'filesystem://', transport_options={
'data_folder_in': 'data_out', 'data_folder_out': 'data_in'
}
)
conn.connect()
def callback(body, message):
print(body, message)
message.ack()
test_queue = kombu.Queue('test', routing_key='test')
with conn as conn:
with conn.default_channel as channel:
consumer = kombu.Consumer(
conn, [test_queue], accept=['pickle']
)
consumer.register_callback(callback)
with consumer:
conn.drain_events(timeout=1)
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
Connection String
=================
Connection string is in the following format:
.. code-block::
filesystem://
Transport Options
=================
* ``data_folder_in`` - directory where are messages stored when written
to queue.
* ``data_folder_out`` - directory from which are messages read when read from
queue.
* ``store_processed`` - if set to True, all processed messages are backed up to
``processed_folder``.
* ``processed_folder`` - directory where are backed up processed files.
* ``control_folder`` - directory where are exchange-queue table stored.
"""
from __future__ import annotations
import os
import shutil
import tempfile
import uuid
from collections import namedtuple
from pathlib import Path
from queue import Empty
from time import monotonic
from kombu.exceptions import ChannelError
from kombu.transport import virtual
from kombu.utils.encoding import bytes_to_str, str_to_bytes
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
VERSION = (1, 0, 0)
__version__ = '.'.join(map(str, VERSION))
# needs win32all to work on Windows
if os.name == 'nt':
import pywintypes
import win32con
import win32file
LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK
# 0 is the default
LOCK_SH = 0
LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY
__overlapped = pywintypes.OVERLAPPED()
def lock(file, flags):
"""Create file lock."""
hfile = win32file._get_osfhandle(file.fileno())
win32file.LockFileEx(hfile, flags, 0, 0xffff0000, __overlapped)
def unlock(file):
"""Remove file lock."""
hfile = win32file._get_osfhandle(file.fileno())
win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped)
elif os.name == 'posix':
import fcntl
from fcntl import LOCK_EX, LOCK_SH
def lock(file, flags):
"""Create file lock."""
fcntl.flock(file.fileno(), flags)
def unlock(file):
"""Remove file lock."""
fcntl.flock(file.fileno(), fcntl.LOCK_UN)
else:
raise RuntimeError(
'Filesystem plugin only defined for NT and POSIX platforms')
exchange_queue_t = namedtuple("exchange_queue_t",
["routing_key", "pattern", "queue"])
class Channel(virtual.Channel):
"""Filesystem Channel."""
supports_fanout = True
def get_table(self, exchange):
file = self.control_folder / f"{exchange}.exchange"
try:
f_obj = file.open("r")
try:
lock(f_obj, LOCK_SH)
exchange_table = loads(bytes_to_str(f_obj.read()))
return [exchange_queue_t(*q) for q in exchange_table]
finally:
unlock(f_obj)
f_obj.close()
except FileNotFoundError:
return []
except OSError:
raise ChannelError(f"Cannot open {file}")
def _queue_bind(self, exchange, routing_key, pattern, queue):
file = self.control_folder / f"{exchange}.exchange"
self.control_folder.mkdir(exist_ok=True)
queue_val = exchange_queue_t(routing_key or "", pattern or "",
queue or "")
try:
if file.exists():
f_obj = file.open("rb+", buffering=0)
lock(f_obj, LOCK_EX)
exchange_table = loads(bytes_to_str(f_obj.read()))
queues = [exchange_queue_t(*q) for q in exchange_table]
if queue_val not in queues:
queues.insert(0, queue_val)
f_obj.seek(0)
f_obj.write(str_to_bytes(dumps(queues)))
else:
f_obj = file.open("wb", buffering=0)
lock(f_obj, LOCK_EX)
queues = [queue_val]
f_obj.write(str_to_bytes(dumps(queues)))
finally:
unlock(f_obj)
f_obj.close()
def _put_fanout(self, exchange, payload, routing_key, **kwargs):
for q in self.get_table(exchange):
self._put(q.queue, payload, **kwargs)
def _put(self, queue, payload, **kwargs):
"""Put `message` onto `queue`."""
filename = '{}_{}.{}.msg'.format(int(round(monotonic() * 1000)),
uuid.uuid4(), queue)
filename = os.path.join(self.data_folder_out, filename)
try:
f = open(filename, 'wb', buffering=0)
lock(f, LOCK_EX)
f.write(str_to_bytes(dumps(payload)))
except OSError:
raise ChannelError(
f'Cannot add file {filename!r} to directory')
finally:
unlock(f)
f.close()
def _get(self, queue):
"""Get next message from `queue`."""
queue_find = '.' + queue + '.msg'
folder = os.listdir(self.data_folder_in)
folder = sorted(folder)
while len(folder) > 0:
filename = folder.pop(0)
# only handle message for the requested queue
if filename.find(queue_find) < 0:
continue
if self.store_processed:
processed_folder = self.processed_folder
else:
processed_folder = tempfile.gettempdir()
try:
# move the file to the tmp/processed folder
shutil.move(os.path.join(self.data_folder_in, filename),
processed_folder)
except OSError:
# file could be locked, or removed in meantime so ignore
continue
filename = os.path.join(processed_folder, filename)
try:
f = open(filename, 'rb')
payload = f.read()
f.close()
if not self.store_processed:
os.remove(filename)
except OSError:
raise ChannelError(
f'Cannot read file {filename!r} from queue.')
return loads(bytes_to_str(payload))
raise Empty()
def _purge(self, queue):
"""Remove all messages from `queue`."""
count = 0
queue_find = '.' + queue + '.msg'
folder = os.listdir(self.data_folder_in)
while len(folder) > 0:
filename = folder.pop()
try:
# only purge messages for the requested queue
if filename.find(queue_find) < 0:
continue
filename = os.path.join(self.data_folder_in, filename)
os.remove(filename)
count += 1
except OSError:
# we simply ignore its existence, as it was probably
# processed by another worker
pass
return count
def _size(self, queue):
"""Return the number of messages in `queue` as an :class:`int`."""
count = 0
queue_find = f'.{queue}.msg'
folder = os.listdir(self.data_folder_in)
while len(folder) > 0:
filename = folder.pop()
# only handle message for the requested queue
if filename.find(queue_find) < 0:
continue
count += 1
return count
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def data_folder_in(self):
return self.transport_options.get('data_folder_in', 'data_in')
@cached_property
def data_folder_out(self):
return self.transport_options.get('data_folder_out', 'data_out')
@cached_property
def store_processed(self):
return self.transport_options.get('store_processed', False)
@cached_property
def processed_folder(self):
return self.transport_options.get('processed_folder', 'processed')
@property
def control_folder(self):
return Path(self.transport_options.get('control_folder', 'control'))
class Transport(virtual.Transport):
"""Filesystem Transport."""
implements = virtual.Transport.implements.extend(
asynchronous=False,
exchange_type=frozenset(['direct', 'topic', 'fanout'])
)
Channel = Channel
# filesystem backend state is global.
global_state = virtual.BrokerState()
default_port = 0
driver_type = 'filesystem'
driver_name = 'filesystem'
def __init__(self, client, **kwargs):
super().__init__(client, **kwargs)
self.state = self.global_state
def driver_version(self):
return 'N/A'

Some files were not shown because too many files have changed in this diff Show More