This commit is contained in:
Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,202 @@
"""SoftLayer Message Queue transport module for kombu.
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: No
* Supports TTL: No
Connection String
=================
*Unreviewed*
Transport Options
=================
*Unreviewed*
"""
from __future__ import annotations
import os
import socket
import string
from queue import Empty
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
try:
from softlayer_messaging import get_client
from softlayer_messaging.errors import ResponseError
except ImportError: # pragma: no cover
get_client = ResponseError = None
# dots are replaced by dash, all other punctuation replaced by underscore.
CHARS_REPLACE_TABLE = {
ord(c): 0x5f for c in string.punctuation if c not in '_'
}
class Channel(virtual.Channel):
"""SLMQ Channel."""
default_visibility_timeout = 1800 # 30 minutes.
domain_format = 'kombu%(vhost)s'
_slmq = None
_queue_cache = {}
_noack_queues = set()
def __init__(self, *args, **kwargs):
if get_client is None:
raise ImportError(
'SLMQ transport requires the softlayer_messaging library',
)
super().__init__(*args, **kwargs)
queues = self.slmq.queues()
for queue in queues:
self._queue_cache[queue] = queue
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
return super().basic_consume(queue, no_ack,
*args, **kwargs)
def basic_cancel(self, consumer_tag):
if consumer_tag in self._consumers:
queue = self._tag_to_queue[consumer_tag]
self._noack_queues.discard(queue)
return super().basic_cancel(consumer_tag)
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
"""Format AMQP queue name into a valid SLQS queue name."""
return str(safe_str(name)).translate(table)
def _new_queue(self, queue, **kwargs):
"""Ensure a queue exists in SLQS."""
queue = self.entity_name(self.queue_name_prefix + queue)
try:
return self._queue_cache[queue]
except KeyError:
try:
self.slmq.create_queue(
queue, visibility_timeout=self.visibility_timeout)
except ResponseError:
pass
q = self._queue_cache[queue] = self.slmq.queue(queue)
return q
def _delete(self, queue, *args, **kwargs):
"""Delete queue by name."""
queue_name = self.entity_name(queue)
self._queue_cache.pop(queue_name, None)
self.slmq.queue(queue_name).delete(force=True)
super()._delete(queue_name)
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q = self._new_queue(queue)
q.push(dumps(message))
def _get(self, queue):
"""Try to retrieve a single message off ``queue``."""
q = self._new_queue(queue)
rs = q.pop(1)
if rs['items']:
m = rs['items'][0]
payload = loads(bytes_to_str(m['body']))
if queue in self._noack_queues:
q.message(m['id']).delete()
else:
payload['properties']['delivery_info'].update({
'slmq_message_id': m['id'], 'slmq_queue_name': q.name})
return payload
raise Empty()
def basic_ack(self, delivery_tag):
delivery_info = self.qos.get(delivery_tag).delivery_info
try:
queue = delivery_info['slmq_queue_name']
except KeyError:
pass
else:
self.delete_message(queue, delivery_info['slmq_message_id'])
super().basic_ack(delivery_tag)
def _size(self, queue):
"""Return the number of messages in a queue."""
return self._new_queue(queue).detail()['message_count']
def _purge(self, queue):
"""Delete all current messages in a queue."""
q = self._new_queue(queue)
n = 0
results = q.pop(10)
while results['items']:
for m in results['items']:
self.delete_message(queue, m['id'])
n += 1
results = q.pop(10)
return n
def delete_message(self, queue, message_id):
q = self.slmq.queue(self.entity_name(queue))
return q.message(message_id).delete()
@property
def slmq(self):
if self._slmq is None:
conninfo = self.conninfo
account = os.environ.get('SLMQ_ACCOUNT', conninfo.virtual_host)
user = os.environ.get('SL_USERNAME', conninfo.userid)
api_key = os.environ.get('SL_API_KEY', conninfo.password)
host = os.environ.get('SLMQ_HOST', conninfo.hostname)
port = os.environ.get('SLMQ_PORT', conninfo.port)
secure = bool(os.environ.get(
'SLMQ_SECURE', self.transport_options.get('secure')) or True,
)
endpoint = '{}://{}{}'.format(
'https' if secure else 'http', host,
f':{port}' if port else '',
)
self._slmq = get_client(account, endpoint=endpoint)
self._slmq.authenticate(user, api_key)
return self._slmq
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def visibility_timeout(self):
return (self.transport_options.get('visibility_timeout') or
self.default_visibility_timeout)
@cached_property
def queue_name_prefix(self):
return self.transport_options.get('queue_name_prefix', '')
class Transport(virtual.Transport):
"""SLMQ Transport."""
Channel = Channel
polling_interval = 1
default_port = None
connection_errors = (
virtual.Transport.connection_errors + (
ResponseError, socket.error
)
)

View File

@@ -0,0 +1,973 @@
"""Amazon SQS transport module for Kombu.
This package implements an AMQP-like interface on top of Amazons SQS service,
with the goal of being optimized for high performance and reliability.
The default settings for this module are focused now on high performance in
task queue situations where tasks are small, idempotent and run very fast.
SQS Features supported by this transport
========================================
Long Polling
------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-long-polling.html
Long polling is enabled by setting the `wait_time_seconds` transport
option to a number > 1. Amazon supports up to 20 seconds. This is
enabled with 10 seconds by default.
Batch API Actions
-----------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-batch-api.html
The default behavior of the SQS Channel.drain_events() method is to
request up to the 'prefetch_count' messages on every request to SQS.
These messages are stored locally in a deque object and passed back
to the Transport until the deque is empty, before triggering a new
API call to Amazon.
This behavior dramatically speeds up the rate that you can pull tasks
from SQS when you have short-running tasks (or a large number of workers).
When a Celery worker has multiple queues to monitor, it will pull down
up to 'prefetch_count' messages from queueA and work on them all before
moving on to queueB. If queueB is empty, it will wait up until
'polling_interval' expires before moving back and checking on queueA.
Message Attributes
-----------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html
SQS supports sending message attributes along with the message body.
To use this feature, you can pass a 'message_attributes' as keyword argument
to `basic_publish` method.
Other Features supported by this transport
==========================================
Predefined Queues
-----------------
The default behavior of this transport is to use a single AWS credential
pair in order to manage all SQS queues (e.g. listing queues, creating
queues, polling queues, deleting messages).
If it is preferable for your environment to use multiple AWS credentials, you
can use the 'predefined_queues' setting inside the 'transport_options' map.
This setting allows you to specify the SQS queue URL and AWS credentials for
each of your queues. For example, if you have two queues which both already
exist in AWS) you can tell this transport about them as follows:
.. code-block:: python
transport_options = {
'predefined_queues': {
'queue-1': {
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/aaa',
'access_key_id': 'a',
'secret_access_key': 'b',
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
'backoff_tasks': ['svc.tasks.tasks.task1'] # optional
},
'queue-2.fifo': {
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb.fifo',
'access_key_id': 'c',
'secret_access_key': 'd',
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
'backoff_tasks': ['svc.tasks.tasks.task2'] # optional
},
}
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
'sts_token_timeout': 900 # optional
}
Note that FIFO and standard queues must be named accordingly (the name of
a FIFO queue must end with the .fifo suffix).
backoff_policy & backoff_tasks are optional arguments. These arguments
automatically change the message visibility timeout, in order to have
different times between specific task retries. This would apply after
task failure.
AWS STS authentication is supported, by using sts_role_arn, and
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
to 900 seconds. After the mentioned period, a new token will be created.
If you authenticate using Okta_ (e.g. calling |gac|_), you can also specify
a 'session_token' to connect to a queue. Note that those tokens have a
limited lifetime and are therefore only suited for short-lived tests.
.. _Okta: https://www.okta.com/
.. _gac: https://github.com/Nike-Inc/gimme-aws-creds#readme
.. |gac| replace:: ``gimme-aws-creds``
Client config
-------------
In some cases you may need to override the botocore config. You can do it
as follows:
.. code-block:: python
transport_option = {
'client-config': {
'connect_timeout': 5,
},
}
For a complete list of settings you can adjust using this option see
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
"""
from __future__ import annotations
import base64
import socket
import string
import uuid
from datetime import datetime
from queue import Empty
from botocore.client import Config
from botocore.exceptions import ClientError
from vine import ensure_promise, promise, transform
from kombu.asynchronous import get_event_loop
from kombu.asynchronous.aws.ext import boto3, exceptions
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
from kombu.asynchronous.aws.sqs.message import AsyncMessage
from kombu.log import get_logger
from kombu.utils import scheduling
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
logger = get_logger(__name__)
# dots are replaced by dash, dash remains dash, all other punctuation
# replaced by underscore.
CHARS_REPLACE_TABLE = {
ord(c): 0x5f for c in string.punctuation if c not in '-_.'
}
CHARS_REPLACE_TABLE[0x2e] = 0x2d # '.' -> '-'
#: SQS bulk get supports a maximum of 10 messages at a time.
SQS_MAX_MESSAGES = 10
def maybe_int(x):
"""Try to convert x' to int, or return x' if that fails."""
try:
return int(x)
except ValueError:
return x
class UndefinedQueueException(Exception):
"""Predefined queues are being used and an undefined queue was used."""
class InvalidQueueException(Exception):
"""Predefined queues are being used and configuration is not valid."""
class AccessDeniedQueueException(Exception):
"""Raised when access to the AWS queue is denied.
This may occur if the permissions are not correctly set or the
credentials are invalid.
"""
class DoesNotExistQueueException(Exception):
"""The specified queue doesn't exist."""
class QoS(virtual.QoS):
"""Quality of Service guarantees implementation for SQS."""
def reject(self, delivery_tag, requeue=False):
super().reject(delivery_tag, requeue=requeue)
routing_key, message, backoff_tasks, backoff_policy = \
self._extract_backoff_policy_configuration_and_message(
delivery_tag)
if routing_key and message and backoff_tasks and backoff_policy:
self.apply_backoff_policy(
routing_key, delivery_tag, backoff_policy, backoff_tasks)
def _extract_backoff_policy_configuration_and_message(self, delivery_tag):
try:
message = self._delivered[delivery_tag]
routing_key = message.delivery_info['routing_key']
except KeyError:
return None, None, None, None
if not routing_key or not message:
return None, None, None, None
queue_config = self.channel.predefined_queues.get(routing_key, {})
backoff_tasks = queue_config.get('backoff_tasks')
backoff_policy = queue_config.get('backoff_policy')
return routing_key, message, backoff_tasks, backoff_policy
def apply_backoff_policy(self, routing_key, delivery_tag,
backoff_policy, backoff_tasks):
queue_url = self.channel._queue_cache[routing_key]
task_name, number_of_retries = \
self.extract_task_name_and_number_of_retries(delivery_tag)
if not task_name or not number_of_retries:
return None
policy_value = backoff_policy.get(number_of_retries)
if task_name in backoff_tasks and policy_value is not None:
c = self.channel.sqs(routing_key)
c.change_message_visibility(
QueueUrl=queue_url,
ReceiptHandle=delivery_tag,
VisibilityTimeout=policy_value
)
def extract_task_name_and_number_of_retries(self, delivery_tag):
message = self._delivered[delivery_tag]
message_headers = message.headers
task_name = message_headers['task']
number_of_retries = int(
message.properties['delivery_info']['sqs_message']
['Attributes']['ApproximateReceiveCount'])
return task_name, number_of_retries
class Channel(virtual.Channel):
"""SQS Channel."""
default_region = 'us-east-1'
default_visibility_timeout = 1800 # 30 minutes.
default_wait_time_seconds = 10 # up to 20 seconds max
domain_format = 'kombu%(vhost)s'
_asynsqs = None
_predefined_queue_async_clients = {} # A client for each predefined queue
_sqs = None
_predefined_queue_clients = {} # A client for each predefined queue
_queue_cache = {} # SQS queue name => SQS queue URL
_noack_queues = set()
QoS = QoS
def __init__(self, *args, **kwargs):
if boto3 is None:
raise ImportError('boto3 is not installed')
super().__init__(*args, **kwargs)
self._validate_predifined_queues()
# SQS blows up if you try to create a new queue when one already
# exists but with a different visibility_timeout. This prepopulates
# the queue_cache to protect us from recreating
# queues that are known to already exist.
self._update_queue_cache(self.queue_name_prefix)
self.hub = kwargs.get('hub') or get_event_loop()
def _validate_predifined_queues(self):
"""Check that standard and FIFO queues are named properly.
AWS requires FIFO queues to have a name
that ends with the .fifo suffix.
"""
for queue_name, q in self.predefined_queues.items():
fifo_url = q['url'].endswith('.fifo')
fifo_name = queue_name.endswith('.fifo')
if fifo_url and not fifo_name:
raise InvalidQueueException(
"Queue with url '{}' must have a name "
"ending with .fifo".format(q['url'])
)
elif not fifo_url and fifo_name:
raise InvalidQueueException(
"Queue with name '{}' is not a FIFO queue: "
"'{}'".format(queue_name, q['url'])
)
def _update_queue_cache(self, queue_name_prefix):
if self.predefined_queues:
for queue_name, q in self.predefined_queues.items():
self._queue_cache[queue_name] = q['url']
return
resp = self.sqs().list_queues(QueueNamePrefix=queue_name_prefix)
for url in resp.get('QueueUrls', []):
queue_name = url.split('/')[-1]
self._queue_cache[queue_name] = url
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
if self.hub:
self._loop1(queue)
return super().basic_consume(
queue, no_ack, *args, **kwargs
)
def basic_cancel(self, consumer_tag):
if consumer_tag in self._consumers:
queue = self._tag_to_queue[consumer_tag]
self._noack_queues.discard(queue)
return super().basic_cancel(consumer_tag)
def drain_events(self, timeout=None, callback=None, **kwargs):
"""Return a single payload message from one of our queues.
Raises
------
Queue.Empty: if no messages available.
"""
# If we're not allowed to consume or have no consumers, raise Empty
if not self._consumers or not self.qos.can_consume():
raise Empty()
# At this point, go and get more messages from SQS
self._poll(self.cycle, callback, timeout=timeout)
def _reset_cycle(self):
"""Reset the consume cycle.
Returns
-------
FairCycle: object that points to our _get_bulk() method
rather than the standard _get() method. This allows for
multiple messages to be returned at once from SQS (
based on the prefetch limit).
"""
self._cycle = scheduling.FairCycle(
self._get_bulk, self._active_queues, Empty,
)
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
"""Format AMQP queue name into a legal SQS queue name."""
if name.endswith('.fifo'):
partial = name[:-len('.fifo')]
partial = str(safe_str(partial)).translate(table)
return partial + '.fifo'
else:
return str(safe_str(name)).translate(table)
def canonical_queue_name(self, queue_name):
return self.entity_name(self.queue_name_prefix + queue_name)
def _resolve_queue_url(self, queue):
"""Try to retrieve the SQS queue URL for a given queue name."""
# Translate to SQS name for consistency with initial
# _queue_cache population.
sqs_qname = self.canonical_queue_name(queue)
# The SQS ListQueues method only returns 1000 queues. When you have
# so many queues, it's possible that the queue you are looking for is
# not cached. In this case, we could update the cache with the exact
# queue name first.
if sqs_qname not in self._queue_cache:
self._update_queue_cache(sqs_qname)
try:
return self._queue_cache[sqs_qname]
except KeyError:
if self.predefined_queues:
raise UndefinedQueueException((
"Queue with name '{}' must be "
"defined in 'predefined_queues'."
).format(sqs_qname))
raise DoesNotExistQueueException(
f"Queue with name '{sqs_qname}' doesn't exist in SQS"
)
def _new_queue(self, queue, **kwargs):
"""Ensure a queue with given name exists in SQS.
Arguments:
---------
queue (str): the AMQP queue name
Returns
str: the SQS queue URL
"""
try:
return self._resolve_queue_url(queue)
except DoesNotExistQueueException:
sqs_qname = self.canonical_queue_name(queue)
attributes = {'VisibilityTimeout': str(self.visibility_timeout)}
if sqs_qname.endswith('.fifo'):
attributes['FifoQueue'] = 'true'
resp = self._create_queue(sqs_qname, attributes)
self._queue_cache[sqs_qname] = resp['QueueUrl']
return resp['QueueUrl']
def _create_queue(self, queue_name, attributes):
"""Create an SQS queue with a given name and nominal attributes."""
# Allow specifying additional boto create_queue Attributes
# via transport options
if self.predefined_queues:
return None
attributes.update(
self.transport_options.get('sqs-creation-attributes') or {},
)
return self.sqs(queue=queue_name).create_queue(
QueueName=queue_name,
Attributes=attributes,
)
def _delete(self, queue, *args, **kwargs):
"""Delete queue by name."""
if self.predefined_queues:
return
q_url = self._resolve_queue_url(queue)
self.sqs().delete_queue(
QueueUrl=q_url,
)
self._queue_cache.pop(queue, None)
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q_url = self._new_queue(queue)
kwargs = {'QueueUrl': q_url}
if 'properties' in message:
if 'message_attributes' in message['properties']:
# we don't want to want to have the attribute in the body
kwargs['MessageAttributes'] = \
message['properties'].pop('message_attributes')
if queue.endswith('.fifo'):
if 'MessageGroupId' in message['properties']:
kwargs['MessageGroupId'] = \
message['properties']['MessageGroupId']
else:
kwargs['MessageGroupId'] = 'default'
if 'MessageDeduplicationId' in message['properties']:
kwargs['MessageDeduplicationId'] = \
message['properties']['MessageDeduplicationId']
else:
kwargs['MessageDeduplicationId'] = str(uuid.uuid4())
else:
if "DelaySeconds" in message['properties']:
kwargs['DelaySeconds'] = \
message['properties']['DelaySeconds']
if self.sqs_base64_encoding:
body = AsyncMessage().encode(dumps(message))
else:
body = dumps(message)
kwargs['MessageBody'] = body
c = self.sqs(queue=self.canonical_queue_name(queue))
if message.get('redelivered'):
c.change_message_visibility(
QueueUrl=q_url,
ReceiptHandle=message['properties']['delivery_tag'],
VisibilityTimeout=0
)
else:
c.send_message(**kwargs)
@staticmethod
def _optional_b64_decode(byte_string):
try:
data = base64.b64decode(byte_string)
if base64.b64encode(data) == byte_string:
return data
# else the base64 module found some embedded base64 content
# that should be ignored.
except Exception: # pylint: disable=broad-except
pass
return byte_string
def _message_to_python(self, message, queue_name, q_url):
body = self._optional_b64_decode(message['Body'].encode())
payload = loads(bytes_to_str(body))
if queue_name in self._noack_queues:
q_url = self._new_queue(queue_name)
self.asynsqs(queue=queue_name).delete_message(
q_url,
message['ReceiptHandle'],
)
else:
try:
properties = payload['properties']
delivery_info = payload['properties']['delivery_info']
except KeyError:
# json message not sent by kombu?
delivery_info = {}
properties = {'delivery_info': delivery_info}
payload.update({
'body': bytes_to_str(body),
'properties': properties,
})
# set delivery tag to SQS receipt handle
delivery_info.update({
'sqs_message': message, 'sqs_queue': q_url,
})
properties['delivery_tag'] = message['ReceiptHandle']
return payload
def _messages_to_python(self, messages, queue):
"""Convert a list of SQS Message objects into Payloads.
This method handles converting SQS Message objects into
Payloads, and appropriately updating the queue depending on
the 'ack' settings for that queue.
Arguments:
---------
messages (SQSMessage): A list of SQS Message objects.
queue (str): Name representing the queue they came from.
Returns
-------
List: A list of Payload objects
"""
q_url = self._new_queue(queue)
return [self._message_to_python(m, queue, q_url) for m in messages]
def _get_bulk(self, queue,
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
"""Try to retrieve multiple messages off ``queue``.
Where :meth:`_get` returns a single Payload object, this method
returns a list of Payload objects. The number of objects returned
is determined by the total number of messages available in the queue
and the number of messages the QoS object allows (based on the
prefetch_count).
Note:
----
Ignores QoS limits so caller is responsible for checking
that we are allowed to consume at least one message from the
queue. get_bulk will then ask QoS for an estimate of
the number of extra messages that we can consume.
Arguments:
---------
queue (str): The queue name to pull from.
Returns
-------
List[Message]
"""
# drain_events calls `can_consume` first, consuming
# a token, so we know that we are allowed to consume at least
# one message.
# Note: ignoring max_messages for SQS with boto3
max_count = self._get_message_estimate()
if max_count:
q_url = self._new_queue(queue)
resp = self.sqs(queue=queue).receive_message(
QueueUrl=q_url, MaxNumberOfMessages=max_count,
WaitTimeSeconds=self.wait_time_seconds)
if resp.get('Messages'):
for m in resp['Messages']:
m['Body'] = AsyncMessage(body=m['Body']).decode()
for msg in self._messages_to_python(resp['Messages'], queue):
self.connection._deliver(msg, queue)
return
raise Empty()
def _get(self, queue):
"""Try to retrieve a single message off ``queue``."""
q_url = self._new_queue(queue)
resp = self.sqs(queue=queue).receive_message(
QueueUrl=q_url, MaxNumberOfMessages=1,
WaitTimeSeconds=self.wait_time_seconds)
if resp.get('Messages'):
body = AsyncMessage(body=resp['Messages'][0]['Body']).decode()
resp['Messages'][0]['Body'] = body
return self._messages_to_python(resp['Messages'], queue)[0]
raise Empty()
def _loop1(self, queue, _=None):
self.hub.call_soon(self._schedule_queue, queue)
def _schedule_queue(self, queue):
if queue in self._active_queues:
if self.qos.can_consume():
self._get_bulk_async(
queue, callback=promise(self._loop1, (queue,)),
)
else:
self._loop1(queue)
def _get_message_estimate(self, max_if_unlimited=SQS_MAX_MESSAGES):
maxcount = self.qos.can_consume_max_estimate()
return min(
max_if_unlimited if maxcount is None else max(maxcount, 1),
max_if_unlimited,
)
def _get_bulk_async(self, queue,
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
maxcount = self._get_message_estimate()
if maxcount:
return self._get_async(queue, maxcount, callback=callback)
# Not allowed to consume, make sure to notify callback..
callback = ensure_promise(callback)
callback([])
return callback
def _get_async(self, queue, count=1, callback=None):
q_url = self._new_queue(queue)
qname = self.canonical_queue_name(queue)
return self._get_from_sqs(
queue_name=qname, queue_url=q_url, count=count,
connection=self.asynsqs(queue=qname),
callback=transform(
self._on_messages_ready, callback, q_url, queue
),
)
def _on_messages_ready(self, queue, qname, messages):
if 'Messages' in messages and messages['Messages']:
callbacks = self.connection._callbacks
for msg in messages['Messages']:
msg_parsed = self._message_to_python(msg, qname, queue)
callbacks[qname](msg_parsed)
def _get_from_sqs(self, queue_name, queue_url,
connection, count=1, callback=None):
"""Retrieve and handle messages from SQS.
Uses long polling and returns :class:`~vine.promises.promise`.
"""
return connection.receive_message(
queue_name, queue_url, number_messages=count,
wait_time_seconds=self.wait_time_seconds,
callback=callback,
)
def _restore(self, message,
unwanted_delivery_info=('sqs_message', 'sqs_queue')):
for unwanted_key in unwanted_delivery_info:
# Remove objects that aren't JSON serializable (Issue #1108).
message.delivery_info.pop(unwanted_key, None)
return super()._restore(message)
def basic_ack(self, delivery_tag, multiple=False):
try:
message = self.qos.get(delivery_tag).delivery_info
sqs_message = message['sqs_message']
except KeyError:
super().basic_ack(delivery_tag)
else:
queue = None
if 'routing_key' in message:
queue = self.canonical_queue_name(message['routing_key'])
try:
self.sqs(queue=queue).delete_message(
QueueUrl=message['sqs_queue'],
ReceiptHandle=sqs_message['ReceiptHandle']
)
except ClientError as exception:
if exception.response['Error']['Code'] == 'AccessDenied':
raise AccessDeniedQueueException(
exception.response["Error"]["Message"]
)
super().basic_reject(delivery_tag)
else:
super().basic_ack(delivery_tag)
def _size(self, queue):
"""Return the number of messages in a queue."""
q_url = self._new_queue(queue)
c = self.sqs(queue=self.canonical_queue_name(queue))
resp = c.get_queue_attributes(
QueueUrl=q_url,
AttributeNames=['ApproximateNumberOfMessages'])
return int(resp['Attributes']['ApproximateNumberOfMessages'])
def _purge(self, queue):
"""Delete all current messages in a queue."""
q_url = self._new_queue(queue)
# SQS is slow at registering messages, so run for a few
# iterations to ensure messages are detected and deleted.
size = 0
for i in range(10):
size += int(self._size(queue))
if not size:
break
self.sqs(queue=queue).purge_queue(QueueUrl=q_url)
return size
def close(self):
super().close()
# if self._asynsqs:
# try:
# self.asynsqs().close()
# except AttributeError as exc: # FIXME ???
# if "can't set attribute" not in str(exc):
# raise
def new_sqs_client(self, region, access_key_id,
secret_access_key, session_token=None):
session = boto3.session.Session(
region_name=region,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
)
is_secure = self.is_secure if self.is_secure is not None else True
client_kwargs = {
'use_ssl': is_secure
}
if self.endpoint_url is not None:
client_kwargs['endpoint_url'] = self.endpoint_url
client_config = self.transport_options.get('client-config') or {}
config = Config(**client_config)
return session.client('sqs', config=config, **client_kwargs)
def sqs(self, queue=None):
if queue is not None and self.predefined_queues:
if queue not in self.predefined_queues:
raise UndefinedQueueException(
f"Queue with name '{queue}' must be defined"
" in 'predefined_queues'.")
q = self.predefined_queues[queue]
if self.transport_options.get('sts_role_arn'):
return self._handle_sts_session(queue, q)
if not self.transport_options.get('sts_role_arn'):
if queue in self._predefined_queue_clients:
return self._predefined_queue_clients[queue]
else:
c = self._predefined_queue_clients[queue] = \
self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=q.get(
'access_key_id', self.conninfo.userid),
secret_access_key=q.get(
'secret_access_key', self.conninfo.password)
)
return c
if self._sqs is not None:
return self._sqs
c = self._sqs = self.new_sqs_client(
region=self.region,
access_key_id=self.conninfo.userid,
secret_access_key=self.conninfo.password,
)
return c
def _handle_sts_session(self, queue, q):
region = q.get('region', self.region)
if not hasattr(self, 'sts_expiration'): # STS token - token init
return self._new_predefined_queue_client_with_sts_session(queue, region)
# STS token - refresh if expired
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
return self._new_predefined_queue_client_with_sts_session(queue, region)
else: # STS token - ruse existing
if queue not in self._predefined_queue_clients:
return self._new_predefined_queue_client_with_sts_session(queue, region)
return self._predefined_queue_clients[queue]
def _new_predefined_queue_client_with_sts_session(self, queue, region):
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=region,
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
def generate_sts_session_token(self, role_arn, token_expiry_seconds):
sts_client = boto3.client('sts')
sts_policy = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName='Celery',
DurationSeconds=token_expiry_seconds
)
return sts_policy['Credentials']
def asynsqs(self, queue=None):
if queue is not None and self.predefined_queues:
if queue in self._predefined_queue_async_clients and \
not hasattr(self, 'sts_expiration'):
return self._predefined_queue_async_clients[queue]
if queue not in self.predefined_queues:
raise UndefinedQueueException((
"Queue with name '{}' must be defined in "
"'predefined_queues'."
).format(queue))
q = self.predefined_queues[queue]
c = self._predefined_queue_async_clients[queue] = \
AsyncSQSConnection(
sqs_connection=self.sqs(queue=queue),
region=q.get('region', self.region),
fetch_message_attributes=self.fetch_message_attributes,
)
return c
if self._asynsqs is not None:
return self._asynsqs
c = self._asynsqs = AsyncSQSConnection(
sqs_connection=self.sqs(queue=queue),
region=self.region,
fetch_message_attributes=self.fetch_message_attributes,
)
return c
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def visibility_timeout(self):
return (self.transport_options.get('visibility_timeout') or
self.default_visibility_timeout)
@cached_property
def predefined_queues(self):
"""Map of queue_name to predefined queue settings."""
return self.transport_options.get('predefined_queues', {})
@cached_property
def queue_name_prefix(self):
return self.transport_options.get('queue_name_prefix', '')
@cached_property
def supports_fanout(self):
return False
@cached_property
def region(self):
return (self.transport_options.get('region') or
boto3.Session().region_name or
self.default_region)
@cached_property
def regioninfo(self):
return self.transport_options.get('regioninfo')
@cached_property
def is_secure(self):
return self.transport_options.get('is_secure')
@cached_property
def port(self):
return self.transport_options.get('port')
@cached_property
def endpoint_url(self):
if self.conninfo.hostname is not None:
scheme = 'https' if self.is_secure else 'http'
if self.conninfo.port is not None:
port = f':{self.conninfo.port}'
else:
port = ''
return '{}://{}{}'.format(
scheme,
self.conninfo.hostname,
port
)
@cached_property
def wait_time_seconds(self):
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
@cached_property
def sqs_base64_encoding(self):
return self.transport_options.get('sqs_base64_encoding', True)
@cached_property
def fetch_message_attributes(self):
return self.transport_options.get('fetch_message_attributes')
class Transport(virtual.Transport):
"""SQS Transport.
Additional queue attributes can be supplied to SQS during queue
creation by passing an ``sqs-creation-attributes`` key in
transport_options. ``sqs-creation-attributes`` must be a dict whose
key-value pairs correspond with Attributes in the
`CreateQueue SQS API`_.
For example, to have SQS queues created with server-side encryption
enabled using the default Amazon Managed Customer Master Key, you
can set ``KmsMasterKeyId`` Attribute. When the queue is initially
created by Kombu, encryption will be enabled.
.. code-block:: python
from kombu.transport.SQS import Transport
transport = Transport(
...,
transport_options={
'sqs-creation-attributes': {
'KmsMasterKeyId': 'alias/aws/sqs',
},
}
)
.. _CreateQueue SQS API: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_CreateQueue.html#API_CreateQueue_RequestParameters
The ``ApproximateReceiveCount`` message attribute is fetched by this
transport by default. Requested message attributes can be changed by
setting ``fetch_message_attributes`` in the transport options.
.. code-block:: python
from kombu.transport.SQS import Transport
transport = Transport(
...,
transport_options={
'fetch_message_attributes': ["All"],
}
)
.. _Message Attributes: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#SQS-ReceiveMessage-request-AttributeNames
""" # noqa: E501
Channel = Channel
polling_interval = 1
wait_time_seconds = 0
default_port = None
connection_errors = (
virtual.Transport.connection_errors +
(exceptions.BotoCoreError, socket.error)
)
channel_errors = (
virtual.Transport.channel_errors + (exceptions.BotoCoreError,)
)
driver_type = 'sqs'
driver_name = 'sqs'
implements = virtual.Transport.implements.extend(
asynchronous=True,
exchange_type=frozenset(['direct']),
)
@property
def default_connection_params(self):
return {'port': self.default_port}

View File

@@ -0,0 +1,93 @@
"""Built-in transports."""
from __future__ import annotations
from kombu.utils.compat import _detect_environment
from kombu.utils.imports import symbol_by_name
def supports_librabbitmq() -> bool | None:
"""Return true if :pypi:`librabbitmq` can be used."""
if _detect_environment() == 'default':
try:
import librabbitmq # noqa
except ImportError: # pragma: no cover
pass
else: # pragma: no cover
return True
return None
TRANSPORT_ALIASES = {
'amqp': 'kombu.transport.pyamqp:Transport',
'amqps': 'kombu.transport.pyamqp:SSLTransport',
'pyamqp': 'kombu.transport.pyamqp:Transport',
'librabbitmq': 'kombu.transport.librabbitmq:Transport',
'confluentkafka': 'kombu.transport.confluentkafka:Transport',
'kafka': 'kombu.transport.confluentkafka:Transport',
'memory': 'kombu.transport.memory:Transport',
'redis': 'kombu.transport.redis:Transport',
'rediss': 'kombu.transport.redis:Transport',
'SQS': 'kombu.transport.SQS:Transport',
'sqs': 'kombu.transport.SQS:Transport',
'mongodb': 'kombu.transport.mongodb:Transport',
'zookeeper': 'kombu.transport.zookeeper:Transport',
'sqlalchemy': 'kombu.transport.sqlalchemy:Transport',
'sqla': 'kombu.transport.sqlalchemy:Transport',
'SLMQ': 'kombu.transport.SLMQ.Transport',
'slmq': 'kombu.transport.SLMQ.Transport',
'filesystem': 'kombu.transport.filesystem:Transport',
'qpid': 'kombu.transport.qpid:Transport',
'sentinel': 'kombu.transport.redis:SentinelTransport',
'consul': 'kombu.transport.consul:Transport',
'etcd': 'kombu.transport.etcd:Transport',
'azurestoragequeues': 'kombu.transport.azurestoragequeues:Transport',
'azureservicebus': 'kombu.transport.azureservicebus:Transport',
'pyro': 'kombu.transport.pyro:Transport',
'gcpubsub': 'kombu.transport.gcpubsub:Transport',
}
_transport_cache = {}
def resolve_transport(transport: str | None = None) -> str | None:
"""Get transport by name.
Arguments:
---------
transport (Union[str, type]): This can be either
an actual transport class, or the fully qualified
path to a transport class, or the alias of a transport.
"""
if isinstance(transport, str):
try:
transport = TRANSPORT_ALIASES[transport]
except KeyError:
if '.' not in transport and ':' not in transport:
from kombu.utils.text import fmatch_best
alt = fmatch_best(transport, TRANSPORT_ALIASES)
if alt:
raise KeyError(
'No such transport: {}. Did you mean {}?'.format(
transport, alt))
raise KeyError(f'No such transport: {transport}')
else:
if callable(transport):
transport = transport()
return symbol_by_name(transport)
return transport
def get_transport_cls(transport: str | None = None) -> str | None:
"""Get transport class by name.
The transport string is the full path to a transport class, e.g.::
"kombu.transport.pyamqp:Transport"
If the name does not include `"."` (is not fully qualified),
the alias table will be consulted.
"""
if transport not in _transport_cache:
_transport_cache[transport] = resolve_transport(transport)
return _transport_cache[transport]

View File

@@ -0,0 +1,498 @@
"""Azure Service Bus Message Queue transport module for kombu.
Note that the Shared Access Policy used to connect to Azure Service Bus
requires Manage, Send and Listen claims since the broker will create new
queues and delete old queues as required.
Notes when using with Celery if you are experiencing issues with programs not
terminating properly. The Azure Service Bus SDK uses the Azure uAMQP library
which in turn creates some threads. If the AzureServiceBus Channel is closed,
said threads will be closed properly, but it seems there are times when Celery
does not do this so these threads will be left running. As the uAMQP threads
are not marked as Daemon threads, they will not be killed when the main thread
exits. Setting the ``uamqp_keep_alive_interval`` transport option to 0 will
prevent the keep_alive thread from starting
More information about Azure Service Bus:
https://azure.microsoft.com/en-us/services/service-bus/
Features
========
* Type: Virtual
* Supports Direct: *Unreviewed*
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*
Connection String
=================
Connection string has the following formats:
.. code-block::
azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE
azureservicebus://DefaultAzureCredential@SERVICE_BUSNAMESPACE
azureservicebus://ManagedIdentityCredential@SERVICE_BUSNAMESPACE
Transport Options
=================
* ``queue_name_prefix`` - String prefix to prepend to queue names in a
service bus namespace.
* ``wait_time_seconds`` - Number of seconds to wait to receive messages.
Default ``5``
* ``peek_lock_seconds`` - Number of seconds the message is visible for before
it is requeued and sent to another consumer. Default ``60``
* ``uamqp_keep_alive_interval`` - Interval in seconds the Azure uAMQP library
should send keepalive messages. Default ``30``
* ``retry_total`` - Azure SDK retry total. Default ``3``
* ``retry_backoff_factor`` - Azure SDK exponential backoff factor.
Default ``0.8``
* ``retry_backoff_max`` - Azure SDK retry total time. Default ``120``
"""
from __future__ import annotations
import string
from queue import Empty
from typing import Any
import azure.core.exceptions
import azure.servicebus.exceptions
import isodate
from azure.servicebus import (ServiceBusClient, ServiceBusMessage,
ServiceBusReceiveMode, ServiceBusReceiver,
ServiceBusSender)
from azure.servicebus.management import ServiceBusAdministrationClient
try:
from azure.identity import (DefaultAzureCredential,
ManagedIdentityCredential)
except ImportError:
DefaultAzureCredential = None
ManagedIdentityCredential = None
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
# dots are replaced by dash, all other punctuation replaced by underscore.
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
CHARS_REPLACE_TABLE = {
ord('.'): ord('-'),
**{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE}
}
class SendReceive:
"""Container for Sender and Receiver."""
def __init__(self,
receiver: ServiceBusReceiver | None = None,
sender: ServiceBusSender | None = None):
self.receiver: ServiceBusReceiver = receiver
self.sender: ServiceBusSender = sender
def close(self) -> None:
if self.receiver:
self.receiver.close()
self.receiver = None
if self.sender:
self.sender.close()
self.sender = None
class Channel(virtual.Channel):
"""Azure Service Bus channel."""
default_wait_time_seconds: int = 5 # in seconds
default_peek_lock_seconds: int = 60 # in seconds (default 60, max 300)
# in seconds (is the default from service bus repo)
default_uamqp_keep_alive_interval: int = 30
# number of retries (is the default from service bus repo)
default_retry_total: int = 3
# exponential backoff factor (is the default from service bus repo)
default_retry_backoff_factor: float = 0.8
# Max time to backoff (is the default from service bus repo)
default_retry_backoff_max: int = 120
domain_format: str = 'kombu%(vhost)s'
_queue_cache: dict[str, SendReceive] = {}
_noack_queues: set[str] = set()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._namespace = None
self._policy = None
self._sas_key = None
self._connection_string = None
self._try_parse_connection_string()
self.qos.restore_at_shutdown = False
def _try_parse_connection_string(self) -> None:
self._namespace, self._credential = Transport.parse_uri(
self.conninfo.hostname)
if (
DefaultAzureCredential is not None
and isinstance(self._credential, DefaultAzureCredential)
) or (
ManagedIdentityCredential is not None
and isinstance(self._credential, ManagedIdentityCredential)
):
return None
if ":" in self._credential:
self._policy, self._sas_key = self._credential.split(':', 1)
conn_dict = {
'Endpoint': 'sb://' + self._namespace,
'SharedAccessKeyName': self._policy,
'SharedAccessKey': self._sas_key,
}
self._connection_string = ';'.join(
[key + '=' + value for key, value in conn_dict.items()])
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
return super().basic_consume(
queue, no_ack, *args, **kwargs
)
def basic_cancel(self, consumer_tag):
if consumer_tag in self._consumers:
queue = self._tag_to_queue[consumer_tag]
self._noack_queues.discard(queue)
return super().basic_cancel(consumer_tag)
def _add_queue_to_cache(
self, name: str,
receiver: ServiceBusReceiver | None = None,
sender: ServiceBusSender | None = None
) -> SendReceive:
if name in self._queue_cache:
obj = self._queue_cache[name]
obj.sender = obj.sender or sender
obj.receiver = obj.receiver or receiver
else:
obj = SendReceive(receiver, sender)
self._queue_cache[name] = obj
return obj
def _get_asb_sender(self, queue: str) -> SendReceive:
queue_obj = self._queue_cache.get(queue, None)
if queue_obj is None or queue_obj.sender is None:
sender = self.queue_service.get_queue_sender(
queue, keep_alive=self.uamqp_keep_alive_interval)
queue_obj = self._add_queue_to_cache(queue, sender=sender)
return queue_obj
def _get_asb_receiver(
self, queue: str,
recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK,
queue_cache_key: str | None = None) -> SendReceive:
cache_key = queue_cache_key or queue
queue_obj = self._queue_cache.get(cache_key, None)
if queue_obj is None or queue_obj.receiver is None:
receiver = self.queue_service.get_queue_receiver(
queue_name=queue, receive_mode=recv_mode,
keep_alive=self.uamqp_keep_alive_interval)
queue_obj = self._add_queue_to_cache(cache_key, receiver=receiver)
return queue_obj
def entity_name(
self, name: str, table: dict[int, int] | None = None) -> str:
"""Format AMQP queue name into a valid ServiceBus queue name."""
return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE)
def _restore(self, message: virtual.base.Message) -> None:
# Not be needed as ASB handles unacked messages
# Remove 'azure_message' as its not JSON serializable
# message.delivery_info.pop('azure_message', None)
# super()._restore(message)
pass
def _new_queue(self, queue: str, **kwargs) -> SendReceive:
"""Ensure a queue exists in ServiceBus."""
queue = self.entity_name(self.queue_name_prefix + queue)
try:
return self._queue_cache[queue]
except KeyError:
# Converts seconds into ISO8601 duration format
# ie 66seconds = P1M6S
lock_duration = isodate.duration_isoformat(
isodate.Duration(seconds=self.peek_lock_seconds))
try:
self.queue_mgmt_service.create_queue(
queue_name=queue, lock_duration=lock_duration)
except azure.core.exceptions.ResourceExistsError:
pass
return self._add_queue_to_cache(queue)
def _delete(self, queue: str, *args, **kwargs) -> None:
"""Delete queue by name."""
queue = self.entity_name(self.queue_name_prefix + queue)
self.queue_mgmt_service.delete_queue(queue)
send_receive_obj = self._queue_cache.pop(queue, None)
if send_receive_obj:
send_receive_obj.close()
def _put(self, queue: str, message, **kwargs) -> None:
"""Put message onto queue."""
queue = self.entity_name(self.queue_name_prefix + queue)
msg = ServiceBusMessage(dumps(message))
queue_obj = self._get_asb_sender(queue)
queue_obj.sender.send_messages(msg)
def _get(
self, queue: str,
timeout: float | int | None = None
) -> dict[str, Any]:
"""Try to retrieve a single message off ``queue``."""
# If we're not ack'ing for this queue, just change receive_mode
recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE \
if queue in self._noack_queues else ServiceBusReceiveMode.PEEK_LOCK
queue = self.entity_name(self.queue_name_prefix + queue)
queue_obj = self._get_asb_receiver(queue, recv_mode)
messages = queue_obj.receiver.receive_messages(
max_message_count=1,
max_wait_time=timeout or self.wait_time_seconds)
if not messages:
raise Empty()
# message.body is either byte or generator[bytes]
message = messages[0]
if not isinstance(message.body, bytes):
body = b''.join(message.body)
else:
body = message.body
msg = loads(bytes_to_str(body))
msg['properties']['delivery_info']['azure_message'] = message
msg['properties']['delivery_info']['azure_queue_name'] = queue
return msg
def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None:
try:
delivery_info = self.qos.get(delivery_tag).delivery_info
except KeyError:
super().basic_ack(delivery_tag)
else:
queue = delivery_info['azure_queue_name']
# recv_mode is PEEK_LOCK when ack'ing messages
queue_obj = self._get_asb_receiver(queue)
try:
queue_obj.receiver.complete_message(
delivery_info['azure_message'])
except azure.servicebus.exceptions.MessageAlreadySettled:
super().basic_ack(delivery_tag)
except Exception:
super().basic_reject(delivery_tag)
else:
super().basic_ack(delivery_tag)
def _size(self, queue: str) -> int:
"""Return the number of messages in a queue."""
queue = self.entity_name(self.queue_name_prefix + queue)
props = self.queue_mgmt_service.get_queue_runtime_properties(queue)
return props.total_message_count
def _purge(self, queue) -> int:
"""Delete all current messages in a queue."""
# Azure doesn't provide a purge api yet
n = 0
max_purge_count = 10
queue = self.entity_name(self.queue_name_prefix + queue)
# By default all the receivers will be in PEEK_LOCK receive mode
queue_obj = self._queue_cache.get(queue, None)
if queue not in self._noack_queues or \
queue_obj is None or queue_obj.receiver is None:
queue_obj = self._get_asb_receiver(
queue,
ServiceBusReceiveMode.RECEIVE_AND_DELETE, 'purge_' + queue
)
while True:
messages = queue_obj.receiver.receive_messages(
max_message_count=max_purge_count,
max_wait_time=0.2
)
n += len(messages)
if len(messages) < max_purge_count:
break
return n
def close(self) -> None:
# receivers and senders spawn threads so clean them up
if not self.closed:
self.closed = True
for queue_obj in self._queue_cache.values():
queue_obj.close()
self._queue_cache.clear()
if self.connection is not None:
self.connection.close_channel(self)
@cached_property
def queue_service(self) -> ServiceBusClient:
if self._connection_string:
return ServiceBusClient.from_connection_string(
self._connection_string,
retry_total=self.retry_total,
retry_backoff_factor=self.retry_backoff_factor,
retry_backoff_max=self.retry_backoff_max
)
return ServiceBusClient(
self._namespace,
self._credential,
retry_total=self.retry_total,
retry_backoff_factor=self.retry_backoff_factor,
retry_backoff_max=self.retry_backoff_max
)
@cached_property
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
if self._connection_string:
return ServiceBusAdministrationClient.from_connection_string(
self._connection_string
)
return ServiceBusAdministrationClient(
self._namespace, self._credential
)
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def queue_name_prefix(self) -> str:
return self.transport_options.get('queue_name_prefix', '')
@cached_property
def wait_time_seconds(self) -> int:
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
@cached_property
def peek_lock_seconds(self) -> int:
return min(self.transport_options.get('peek_lock_seconds',
self.default_peek_lock_seconds),
300) # Limit upper bounds to 300
@cached_property
def uamqp_keep_alive_interval(self) -> int:
return self.transport_options.get(
'uamqp_keep_alive_interval',
self.default_uamqp_keep_alive_interval
)
@cached_property
def retry_total(self) -> int:
return self.transport_options.get(
'retry_total', self.default_retry_total)
@cached_property
def retry_backoff_factor(self) -> float:
return self.transport_options.get(
'retry_backoff_factor', self.default_retry_backoff_factor)
@cached_property
def retry_backoff_max(self) -> int:
return self.transport_options.get(
'retry_backoff_max', self.default_retry_backoff_max)
class Transport(virtual.Transport):
"""Azure Service Bus transport."""
Channel = Channel
polling_interval = 1
default_port = None
can_parse_url = True
@staticmethod
def parse_uri(uri: str) -> tuple[str, str | DefaultAzureCredential |
ManagedIdentityCredential]:
# URL like:
# azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
# urllib parse does not work as the sas key could contain a slash
# e.g.: azureservicebus://rootpolicy:some/key@somenamespace
# > 'rootpolicy:some/key@somenamespace'
uri = uri.replace('azureservicebus://', '')
# > 'rootpolicy:some/key', 'somenamespace'
credential, namespace = uri.rsplit('@', 1)
if not namespace.endswith('.net'):
namespace += '.servicebus.windows.net'
if "DefaultAzureCredential".lower() == credential.lower():
if DefaultAzureCredential is None:
raise ImportError('Azure Service Bus transport with a '
'DefaultAzureCredential requires the '
'azure-identity library')
credential = DefaultAzureCredential()
elif "ManagedIdentityCredential".lower() == credential.lower():
if ManagedIdentityCredential is None:
raise ImportError('Azure Service Bus transport with a '
'ManagedIdentityCredential requires the '
'azure-identity library')
credential = ManagedIdentityCredential()
else:
# > 'rootpolicy', 'some/key'
policy, sas_key = credential.split(':', 1)
credential = f"{policy}:{sas_key}"
# Validate ASB connection string
if not all([namespace, credential]):
raise ValueError(
'Need a URI like '
'azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} ' # noqa
'or the azure Endpoint connection string'
)
return namespace, credential
@classmethod
def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
namespace, credential = cls.parse_uri(uri)
if isinstance(credential, str) and ":" in credential:
policy, sas_key = credential.split(':', 1)
return 'azureservicebus://{}:{}@{}'.format(
policy,
sas_key if include_password else mask,
namespace
)
return 'azureservicebus://{}@{}'.format(
credential.__class__.__name__,
namespace
)

View File

@@ -0,0 +1,263 @@
"""Azure Storage Queues transport module for kombu.
More information about Azure Storage Queues:
https://azure.microsoft.com/en-us/services/storage/queues/
Features
========
* Type: Virtual
* Supports Direct: *Unreviewed*
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*
Connection String
=================
Connection string has the following formats:
.. code-block::
azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
Note that if the access key for the storage account contains a forward slash
(``/``), it will have to be regenerated before it can be used in the connection
URL.
.. code-block::
azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
If you wish to use an `Azure Managed Identity` you may use the
``DefaultAzureCredential`` format of the connection string which will use
``DefaultAzureCredential`` class in the azure-identity package. You may want to
read the `azure-identity documentation` for more information on how the
``DefaultAzureCredential`` works.
.. _azure-identity documentation:
https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
.. _Azure Managed Identity:
https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
Transport Options
=================
* ``queue_name_prefix``
"""
from __future__ import annotations
import string
from queue import Empty
from typing import Any
from azure.core.exceptions import ResourceExistsError
from kombu.utils.encoding import safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
try:
from azure.storage.queue import QueueServiceClient
except ImportError: # pragma: no cover
QueueServiceClient = None
try:
from azure.identity import (DefaultAzureCredential,
ManagedIdentityCredential)
except ImportError:
DefaultAzureCredential = None
ManagedIdentityCredential = None
# Azure storage queues allow only alphanumeric and dashes
# so, replace everything with a dash
CHARS_REPLACE_TABLE = {
ord(c): 0x2d for c in string.punctuation
}
class Channel(virtual.Channel):
"""Azure Storage Queues channel."""
domain_format: str = 'kombu%(vhost)s'
_queue_service: QueueServiceClient | None = None
_queue_name_cache: dict[Any, Any] = {}
no_ack: bool = True
_noack_queues: set[Any] = set()
def __init__(self, *args, **kwargs):
if QueueServiceClient is None:
raise ImportError('Azure Storage Queues transport requires the '
'azure-storage-queue library')
super().__init__(*args, **kwargs)
self._credential, self._url = Transport.parse_uri(
self.conninfo.hostname
)
for queue in self.queue_service.list_queues():
self._queue_name_cache[queue['name']] = queue
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
return super().basic_consume(queue, no_ack,
*args, **kwargs)
def entity_name(self, name, table=CHARS_REPLACE_TABLE) -> str:
"""Format AMQP queue name into a valid Azure Storage Queue name."""
return str(safe_str(name)).translate(table)
def _ensure_queue(self, queue):
"""Ensure a queue exists."""
queue = self.entity_name(self.queue_name_prefix + queue)
try:
q = self._queue_service.get_queue_client(
queue=self._queue_name_cache[queue]
)
except KeyError:
try:
q = self.queue_service.create_queue(queue)
except ResourceExistsError:
q = self._queue_service.get_queue_client(queue=queue)
self._queue_name_cache[queue] = q.get_queue_properties()
return q
def _delete(self, queue, *args, **kwargs):
"""Delete queue by name."""
queue_name = self.entity_name(queue)
self._queue_name_cache.pop(queue_name, None)
self.queue_service.delete_queue(queue_name)
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q = self._ensure_queue(queue)
encoded_message = dumps(message)
q.send_message(encoded_message)
def _get(self, queue, timeout=None):
"""Try to retrieve a single message off ``queue``."""
q = self._ensure_queue(queue)
messages = q.receive_messages(messages_per_page=1, timeout=timeout)
try:
message = next(messages)
except StopIteration:
raise Empty()
content = loads(message.content)
q.delete_message(message=message)
return content
def _size(self, queue):
"""Return the number of messages in a queue."""
q = self._ensure_queue(queue)
return q.get_queue_properties().approximate_message_count
def _purge(self, queue):
"""Delete all current messages in a queue."""
q = self._ensure_queue(queue)
n = self._size(q.queue_name)
q.clear_messages()
return n
@property
def queue_service(self) -> QueueServiceClient:
if self._queue_service is None:
self._queue_service = QueueServiceClient(
account_url=self._url, credential=self._credential
)
return self._queue_service
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def queue_name_prefix(self) -> str:
return self.transport_options.get('queue_name_prefix', '')
class Transport(virtual.Transport):
"""Azure Storage Queues transport."""
Channel = Channel
polling_interval: int = 1
default_port: int | None = None
can_parse_url: bool = True
@staticmethod
def parse_uri(uri: str) -> tuple[str | dict, str]:
# URL like:
# azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
# azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
# azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
# azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
# urllib parse does not work as the sas key could contain a slash
# e.g.: azurestoragequeues://some/key@someurl
try:
# > 'some/key@url'
uri = uri.replace('azurestoragequeues://', '')
# > 'some/key', 'url'
credential, url = uri.rsplit('@', 1)
if "DefaultAzureCredential".lower() == credential.lower():
if DefaultAzureCredential is None:
raise ImportError('Azure Storage Queues transport with a '
'DefaultAzureCredential requires the '
'azure-identity library')
credential = DefaultAzureCredential()
elif "ManagedIdentityCredential".lower() == credential.lower():
if ManagedIdentityCredential is None:
raise ImportError('Azure Storage Queues transport with a '
'ManagedIdentityCredential requires the '
'azure-identity library')
credential = ManagedIdentityCredential()
elif "devstoreaccount1" in url and ".core.windows.net" not in url:
# parse credential as a dict if Azurite is being used
credential = {
"account_name": "devstoreaccount1",
"account_key": credential,
}
# Validate parameters
assert all([credential, url])
except Exception:
raise ValueError(
'Need a URI like '
'azurestoragequeues://{SAS or access key}@{URL}, '
'azurestoragequeues://DefaultAzureCredential@{URL}, '
', or '
'azurestoragequeues://ManagedIdentityCredential@{URL}'
)
return credential, url
@classmethod
def as_uri(
cls, uri: str, include_password: bool = False, mask: str = "**"
) -> str:
credential, url = cls.parse_uri(uri)
return "azurestoragequeues://{}@{}".format(
credential if include_password else mask, url
)

View File

@@ -0,0 +1,271 @@
"""Base transport interface."""
# flake8: noqa
from __future__ import annotations
import errno
import socket
from typing import TYPE_CHECKING
from amqp.exceptions import RecoverableConnectionError
from kombu.exceptions import ChannelError, ConnectionError
from kombu.message import Message
from kombu.utils.functional import dictfilter
from kombu.utils.objects import cached_property
from kombu.utils.time import maybe_s_to_ms
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('Message', 'StdChannel', 'Management', 'Transport')
RABBITMQ_QUEUE_ARGUMENTS = {
'expires': ('x-expires', maybe_s_to_ms),
'message_ttl': ('x-message-ttl', maybe_s_to_ms),
'max_length': ('x-max-length', int),
'max_length_bytes': ('x-max-length-bytes', int),
'max_priority': ('x-max-priority', int),
} # type: Mapping[str, Tuple[str, Callable]]
def to_rabbitmq_queue_arguments(arguments, **options):
# type: (Mapping, **Any) -> Dict
"""Convert queue arguments to RabbitMQ queue arguments.
This is the implementation for Channel.prepare_queue_arguments
for AMQP-based transports. It's used by both the pyamqp and librabbitmq
transports.
Arguments:
arguments (Mapping):
User-supplied arguments (``Queue.queue_arguments``).
Keyword Arguments:
expires (float): Queue expiry time in seconds.
This will be converted to ``x-expires`` in int milliseconds.
message_ttl (float): Message TTL in seconds.
This will be converted to ``x-message-ttl`` in int milliseconds.
max_length (int): Max queue length (in number of messages).
This will be converted to ``x-max-length`` int.
max_length_bytes (int): Max queue size in bytes.
This will be converted to ``x-max-length-bytes`` int.
max_priority (int): Max priority steps for queue.
This will be converted to ``x-max-priority`` int.
Returns
-------
Dict: RabbitMQ compatible queue arguments.
"""
prepared = dictfilter(dict(
_to_rabbitmq_queue_argument(key, value)
for key, value in options.items()
))
return dict(arguments, **prepared) if prepared else arguments
def _to_rabbitmq_queue_argument(key, value):
# type: (str, Any) -> Tuple[str, Any]
opt, typ = RABBITMQ_QUEUE_ARGUMENTS[key]
return opt, typ(value) if value is not None else value
def _LeftBlank(obj, method):
return NotImplementedError(
'Transport {0.__module__}.{0.__name__} does not implement {1}'.format(
obj.__class__, method))
class StdChannel:
"""Standard channel base class."""
no_ack_consumers = None
def Consumer(self, *args, **kwargs):
from kombu.messaging import Consumer
return Consumer(self, *args, **kwargs)
def Producer(self, *args, **kwargs):
from kombu.messaging import Producer
return Producer(self, *args, **kwargs)
def get_bindings(self):
raise _LeftBlank(self, 'get_bindings')
def after_reply_message_received(self, queue):
"""Callback called after RPC reply received.
Notes
-----
Reply queue semantics: can be used to delete the queue
after transient reply message received.
"""
def prepare_queue_arguments(self, arguments, **kwargs):
return arguments
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.close()
class Management:
"""AMQP Management API (incomplete)."""
def __init__(self, transport):
self.transport = transport
def get_bindings(self):
raise _LeftBlank(self, 'get_bindings')
class Implements(dict):
"""Helper class used to define transport features."""
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(key)
def __setattr__(self, key, value):
self[key] = value
def extend(self, **kwargs):
return self.__class__(self, **kwargs)
default_transport_capabilities = Implements(
asynchronous=False,
exchange_type=frozenset(['direct', 'topic', 'fanout', 'headers']),
heartbeats=False,
)
class Transport:
"""Base class for transports."""
Management = Management
#: The :class:`~kombu.Connection` owning this instance.
client = None
#: Set to True if :class:`~kombu.Connection` should pass the URL
#: unmodified.
can_parse_url = False
#: Default port used when no port has been specified.
default_port = None
#: Tuple of errors that can happen due to connection failure.
connection_errors = (ConnectionError,)
#: Tuple of errors that can happen due to channel/method failure.
channel_errors = (ChannelError,)
#: Type of driver, can be used to separate transports
#: using the AMQP protocol (driver_type: 'amqp'),
#: Redis (driver_type: 'redis'), etc...
driver_type = 'N/A'
#: Name of driver library (e.g. 'py-amqp', 'redis').
driver_name = 'N/A'
__reader = None
implements = default_transport_capabilities.extend()
def __init__(self, client, **kwargs):
self.client = client
def establish_connection(self):
raise _LeftBlank(self, 'establish_connection')
def close_connection(self, connection):
raise _LeftBlank(self, 'close_connection')
def create_channel(self, connection):
raise _LeftBlank(self, 'create_channel')
def close_channel(self, connection):
raise _LeftBlank(self, 'close_channel')
def drain_events(self, connection, **kwargs):
raise _LeftBlank(self, 'drain_events')
def heartbeat_check(self, connection, rate=2):
pass
def driver_version(self):
return 'N/A'
def get_heartbeat_interval(self, connection):
return 0
def register_with_event_loop(self, connection, loop):
pass
def unregister_from_event_loop(self, connection, loop):
pass
def verify_connection(self, connection):
return True
def _make_reader(self, connection, timeout=socket.timeout,
error=socket.error, _unavail=(errno.EAGAIN, errno.EINTR)):
drain_events = connection.drain_events
def _read(loop):
if not connection.connected:
raise RecoverableConnectionError('Socket was disconnected')
try:
drain_events(timeout=0)
except timeout:
return
except error as exc:
if exc.errno in _unavail:
return
raise
loop.call_soon(_read, loop)
return _read
def qos_semantics_matches_spec(self, connection):
return True
def on_readable(self, connection, loop):
reader = self.__reader
if reader is None:
reader = self.__reader = self._make_reader(connection)
reader(loop)
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
"""Customise the display format of the URI."""
raise NotImplementedError()
@property
def default_connection_params(self):
return {}
def get_manager(self, *args, **kwargs):
return self.Management(self)
@cached_property
def manager(self):
return self.get_manager()
@property
def supports_heartbeats(self):
return self.implements.heartbeats
@property
def supports_ev(self):
return self.implements.asynchronous

View File

@@ -0,0 +1,380 @@
"""confluent-kafka transport module for Kombu.
Kafka transport using confluent-kafka library.
**References**
- http://docs.confluent.io/current/clients/confluent-kafka-python
**Limitations**
The confluent-kafka transport does not support PyPy environment.
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: No
* Supports TTL: No
Connection String
=================
Connection string has the following format:
.. code-block::
confluentkafka://[USER:PASSWORD@]KAFKA_ADDRESS[:PORT]
Transport Options
=================
* ``connection_wait_time_seconds`` - Time in seconds to wait for connection
to succeed. Default ``5``
* ``wait_time_seconds`` - Time in seconds to wait to receive messages.
Default ``5``
* ``security_protocol`` - Protocol used to communicate with broker.
Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
an explanation of valid values. Default ``plaintext``
* ``sasl_mechanism`` - SASL mechanism to use for authentication.
Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
an explanation of valid values.
* ``num_partitions`` - Number of partitions to create. Default ``1``
* ``replication_factor`` - Replication factor of partitions. Default ``1``
* ``topic_config`` - Topic configuration. Must be a dict whose key-value pairs
correspond with attributes in the
http://kafka.apache.org/documentation.html#topicconfigs.
* ``kafka_common_config`` - Configuration applied to producer, consumer and
admin client. Must be a dict whose key-value pairs correspond with attributes
in the https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
* ``kafka_producer_config`` - Producer configuration. Must be a dict whose
key-value pairs correspond with attributes in the
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
* ``kafka_consumer_config`` - Consumer configuration. Must be a dict whose
key-value pairs correspond with attributes in the
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
* ``kafka_admin_config`` - Admin client configuration. Must be a dict whose
key-value pairs correspond with attributes in the
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
"""
from __future__ import annotations
from queue import Empty
from kombu.transport import virtual
from kombu.utils import cached_property
from kombu.utils.encoding import str_to_bytes
from kombu.utils.json import dumps, loads
try:
import confluent_kafka
from confluent_kafka import (Consumer, KafkaException, Producer,
TopicPartition)
from confluent_kafka.admin import AdminClient, NewTopic
KAFKA_CONNECTION_ERRORS = ()
KAFKA_CHANNEL_ERRORS = ()
except ImportError:
confluent_kafka = None
KAFKA_CONNECTION_ERRORS = KAFKA_CHANNEL_ERRORS = ()
from kombu.log import get_logger
logger = get_logger(__name__)
DEFAULT_PORT = 9092
class NoBrokersAvailable(KafkaException):
"""Kafka broker is not available exception."""
retriable = True
class Message(virtual.Message):
"""Message object."""
def __init__(self, payload, channel=None, **kwargs):
self.topic = payload.get('topic')
super().__init__(payload, channel=channel, **kwargs)
class QoS(virtual.QoS):
"""Quality of Service guarantees."""
_not_yet_acked = {}
def can_consume(self):
"""Return true if the channel can be consumed from.
:returns: True, if this QoS object can accept a message.
:rtype: bool
"""
return not self.prefetch_count or len(self._not_yet_acked) < self \
.prefetch_count
def can_consume_max_estimate(self):
if self.prefetch_count:
return self.prefetch_count - len(self._not_yet_acked)
else:
return 1
def append(self, message, delivery_tag):
self._not_yet_acked[delivery_tag] = message
def get(self, delivery_tag):
return self._not_yet_acked[delivery_tag]
def ack(self, delivery_tag):
if delivery_tag not in self._not_yet_acked:
return
message = self._not_yet_acked.pop(delivery_tag)
consumer = self.channel._get_consumer(message.topic)
consumer.commit()
def reject(self, delivery_tag, requeue=False):
"""Reject a message by delivery tag.
If requeue is True, then the last consumed message is reverted so
it'll be refetched on the next attempt.
If False, that message is consumed and ignored.
"""
if requeue:
message = self._not_yet_acked.pop(delivery_tag)
consumer = self.channel._get_consumer(message.topic)
for assignment in consumer.assignment():
topic_partition = TopicPartition(message.topic,
assignment.partition)
[committed_offset] = consumer.committed([topic_partition])
consumer.seek(committed_offset)
else:
self.ack(delivery_tag)
def restore_unacked_once(self, stderr=None):
pass
class Channel(virtual.Channel):
"""Kafka Channel."""
QoS = QoS
Message = Message
default_wait_time_seconds = 5
default_connection_wait_time_seconds = 5
_client = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._kafka_consumers = {}
self._kafka_producers = {}
self._client = self._open()
def sanitize_queue_name(self, queue):
"""Need to sanitize the name, celery sometimes pushes in @ signs."""
return str(queue).replace('@', '')
def _get_producer(self, queue):
"""Create/get a producer instance for the given topic/queue."""
queue = self.sanitize_queue_name(queue)
producer = self._kafka_producers.get(queue, None)
if producer is None:
producer = Producer({
**self.common_config,
**(self.options.get('kafka_producer_config') or {}),
})
self._kafka_producers[queue] = producer
return producer
def _get_consumer(self, queue):
"""Create/get a consumer instance for the given topic/queue."""
queue = self.sanitize_queue_name(queue)
consumer = self._kafka_consumers.get(queue, None)
if consumer is None:
consumer = Consumer({
'group.id': f'{queue}-consumer-group',
'auto.offset.reset': 'earliest',
'enable.auto.commit': False,
**self.common_config,
**(self.options.get('kafka_consumer_config') or {}),
})
consumer.subscribe([queue])
self._kafka_consumers[queue] = consumer
return consumer
def _put(self, queue, message, **kwargs):
"""Put a message on the topic/queue."""
queue = self.sanitize_queue_name(queue)
producer = self._get_producer(queue)
producer.produce(queue, str_to_bytes(dumps(message)))
producer.flush()
def _get(self, queue, **kwargs):
"""Get a message from the topic/queue."""
queue = self.sanitize_queue_name(queue)
consumer = self._get_consumer(queue)
message = None
try:
message = consumer.poll(self.wait_time_seconds)
except StopIteration:
pass
if not message:
raise Empty()
error = message.error()
if error:
logger.error(error)
raise Empty()
return {**loads(message.value()), 'topic': message.topic()}
def _delete(self, queue, *args, **kwargs):
"""Delete a queue/topic."""
queue = self.sanitize_queue_name(queue)
self._kafka_consumers[queue].close()
self._kafka_consumers.pop(queue)
self.client.delete_topics([queue])
def _size(self, queue):
"""Get the number of pending messages in the topic/queue."""
queue = self.sanitize_queue_name(queue)
consumer = self._kafka_consumers.get(queue, None)
if consumer is None:
return 0
size = 0
for assignment in consumer.assignment():
topic_partition = TopicPartition(queue, assignment.partition)
(_, end_offset) = consumer.get_watermark_offsets(topic_partition)
[committed_offset] = consumer.committed([topic_partition])
size += end_offset - committed_offset.offset
return size
def _new_queue(self, queue, **kwargs):
"""Create a new topic if it does not exist."""
queue = self.sanitize_queue_name(queue)
if queue in self.client.list_topics().topics:
return
topic = NewTopic(
queue,
num_partitions=self.options.get('num_partitions', 1),
replication_factor=self.options.get('replication_factor', 1),
config=self.options.get('topic_config', {})
)
self.client.create_topics(new_topics=[topic])
def _has_queue(self, queue, **kwargs):
"""Check if a topic already exists."""
queue = self.sanitize_queue_name(queue)
return queue in self.client.list_topics().topics
def _open(self):
client = AdminClient({
**self.common_config,
**(self.options.get('kafka_admin_config') or {}),
})
try:
# seems to be the only way to check connection
client.list_topics(timeout=self.wait_time_seconds)
except confluent_kafka.KafkaException as e:
raise NoBrokersAvailable(e)
return client
@property
def client(self):
if self._client is None:
self._client = self._open()
return self._client
@property
def options(self):
return self.connection.client.transport_options
@property
def conninfo(self):
return self.connection.client
@cached_property
def wait_time_seconds(self):
return self.options.get(
'wait_time_seconds', self.default_wait_time_seconds
)
@cached_property
def connection_wait_time_seconds(self):
return self.options.get(
'connection_wait_time_seconds',
self.default_connection_wait_time_seconds,
)
@cached_property
def common_config(self):
conninfo = self.connection.client
config = {
'bootstrap.servers':
f'{conninfo.hostname}:{int(conninfo.port) or DEFAULT_PORT}',
}
security_protocol = self.options.get('security_protocol', 'plaintext')
if security_protocol.lower() != 'plaintext':
config.update({
'security.protocol': security_protocol,
'sasl.username': conninfo.userid,
'sasl.password': conninfo.password,
'sasl.mechanism': self.options.get('sasl_mechanism'),
})
config.update(self.options.get('kafka_common_config') or {})
return config
def close(self):
super().close()
self._kafka_producers = {}
for consumer in self._kafka_consumers.values():
consumer.close()
self._kafka_consumers = {}
class Transport(virtual.Transport):
"""Kafka Transport."""
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
pass
Channel = Channel
default_port = DEFAULT_PORT
driver_type = 'kafka'
driver_name = 'confluentkafka'
recoverable_connection_errors = (
NoBrokersAvailable,
)
def __init__(self, client, **kwargs):
if confluent_kafka is None:
raise ImportError('The confluent-kafka library is not installed')
super().__init__(client, **kwargs)
def driver_version(self):
return confluent_kafka.__version__
def establish_connection(self):
return super().establish_connection()
def close_connection(self, connection):
return super().close_connection(connection)

View File

@@ -0,0 +1,323 @@
"""Consul Transport module for Kombu.
Features
========
It uses Consul.io's Key/Value store to transport messages in Queues
It uses python-consul for talking to Consul's HTTP API
Features
========
* Type: Native
* Supports Direct: Yes
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*
Connection String
=================
Connection string has the following format:
.. code-block::
consul://CONSUL_ADDRESS[:PORT]
"""
from __future__ import annotations
import socket
import uuid
from collections import defaultdict
from contextlib import contextmanager
from queue import Empty
from time import monotonic
from kombu.exceptions import ChannelError
from kombu.log import get_logger
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
try:
import consul
except ImportError:
consul = None
logger = get_logger('kombu.transport.consul')
DEFAULT_PORT = 8500
DEFAULT_HOST = 'localhost'
class LockError(Exception):
"""An error occurred while trying to acquire the lock."""
class Channel(virtual.Channel):
"""Consul Channel class which talks to the Consul Key/Value store."""
prefix = 'kombu'
index = None
timeout = '10s'
session_ttl = 30
def __init__(self, *args, **kwargs):
if consul is None:
raise ImportError('Missing python-consul library')
super().__init__(*args, **kwargs)
port = self.connection.client.port or self.connection.default_port
host = self.connection.client.hostname or DEFAULT_HOST
logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout)
self.queues = defaultdict(dict)
self.client = consul.Consul(host=host, port=int(port))
def _lock_key(self, queue):
return f'{self.prefix}/{queue}.lock'
def _key_prefix(self, queue):
return f'{self.prefix}/{queue}'
def _get_or_create_session(self, queue):
"""Get or create consul session.
Try to renew the session if it exists, otherwise create a new
session in Consul.
This session is used to acquire a lock inside Consul so that we achieve
read-consistency between the nodes.
Arguments:
---------
queue (str): The name of the Queue.
Returns
-------
str: The ID of the session.
"""
try:
session_id = self.queues[queue]['session_id']
except KeyError:
session_id = None
return (self._renew_existing_session(session_id)
if session_id is not None else self._create_new_session())
def _renew_existing_session(self, session_id):
logger.debug('Trying to renew existing session %s', session_id)
session = self.client.session.renew(session_id=session_id)
return session.get('ID')
def _create_new_session(self):
logger.debug('Creating session %s with TTL %s',
self.lock_name, self.session_ttl)
session_id = self.client.session.create(
name=self.lock_name, ttl=self.session_ttl)
logger.debug('Created session %s with id %s',
self.lock_name, session_id)
return session_id
@contextmanager
def _queue_lock(self, queue, raising=LockError):
"""Try to acquire a lock on the Queue.
It does so by creating a object called 'lock' which is locked by the
current session..
This way other nodes are not able to write to the lock object which
means that they have to wait before the lock is released.
Arguments:
---------
queue (str): The name of the Queue.
raising (Exception): Set custom lock error class.
Raises
------
LockError: if the lock cannot be acquired.
Returns
-------
bool: success?
"""
self._acquire_lock(queue, raising=raising)
try:
yield
finally:
self._release_lock(queue)
def _acquire_lock(self, queue, raising=LockError):
session_id = self._get_or_create_session(queue)
lock_key = self._lock_key(queue)
logger.debug('Trying to create lock object %s with session %s',
lock_key, session_id)
if self.client.kv.put(key=lock_key,
acquire=session_id,
value=self.lock_name):
self.queues[queue]['session_id'] = session_id
return
logger.info('Could not acquire lock on key %s', lock_key)
raise raising()
def _release_lock(self, queue):
"""Try to release a lock.
It does so by simply removing the lock key in Consul.
Arguments:
---------
queue (str): The name of the queue we want to release
the lock from.
"""
logger.debug('Removing lock key %s', self._lock_key(queue))
self.client.kv.delete(key=self._lock_key(queue))
def _destroy_session(self, queue):
"""Destroy a previously created Consul session.
Will release all locks it still might hold.
Arguments:
---------
queue (str): The name of the Queue.
"""
logger.debug('Destroying session %s', self.queues[queue]['session_id'])
self.client.session.destroy(self.queues[queue]['session_id'])
def _new_queue(self, queue, **_):
self.queues[queue] = {'session_id': None}
return self.client.kv.put(key=self._key_prefix(queue), value=None)
def _delete(self, queue, *args, **_):
self._destroy_session(queue)
self.queues.pop(queue, None)
self._purge(queue)
def _put(self, queue, payload, **_):
"""Put `message` onto `queue`.
This simply writes a key to the K/V store of Consul
"""
key = '{}/msg/{}_{}'.format(
self._key_prefix(queue),
int(round(monotonic() * 1000)),
uuid.uuid4(),
)
if not self.client.kv.put(key=key, value=dumps(payload), cas=0):
raise ChannelError(f'Cannot add key {key!r} to consul')
def _get(self, queue, timeout=None):
"""Get the first available message from the queue.
Before it does so it acquires a lock on the Key/Value store so
only one node reads at the same time. This is for read consistency
"""
with self._queue_lock(queue, raising=Empty):
key = f'{self._key_prefix(queue)}/msg/'
logger.debug('Fetching key %s with index %s', key, self.index)
self.index, data = self.client.kv.get(
key=key, recurse=True,
index=self.index, wait=self.timeout,
)
try:
if data is None:
raise Empty()
logger.debug('Removing key %s with modifyindex %s',
data[0]['Key'], data[0]['ModifyIndex'])
self.client.kv.delete(key=data[0]['Key'],
cas=data[0]['ModifyIndex'])
return loads(data[0]['Value'])
except TypeError:
pass
raise Empty()
def _purge(self, queue):
self._destroy_session(queue)
return self.client.kv.delete(
key=f'{self._key_prefix(queue)}/msg/',
recurse=True,
)
def _size(self, queue):
size = 0
try:
key = f'{self._key_prefix(queue)}/msg/'
logger.debug('Fetching key recursively %s with index %s',
key, self.index)
self.index, data = self.client.kv.get(
key=key, recurse=True,
index=self.index, wait=self.timeout,
)
size = len(data)
except TypeError:
pass
logger.debug('Found %s keys under %s with index %s',
size, key, self.index)
return size
@cached_property
def lock_name(self):
return f'{socket.gethostname()}'
class Transport(virtual.Transport):
"""Consul K/V storage Transport for Kombu."""
Channel = Channel
default_port = DEFAULT_PORT
driver_type = 'consul'
driver_name = 'consul'
if consul:
connection_errors = (
virtual.Transport.connection_errors + (
consul.ConsulException, consul.base.ConsulException
)
)
channel_errors = (
virtual.Transport.channel_errors + (
consul.ConsulException, consul.base.ConsulException
)
)
def __init__(self, *args, **kwargs):
if consul is None:
raise ImportError('Missing python-consul library')
super().__init__(*args, **kwargs)
def verify_connection(self, connection):
port = connection.client.port or self.default_port
host = connection.client.hostname or DEFAULT_HOST
logger.debug('Verify Consul connection to %s:%s', host, port)
try:
client = consul.Consul(host=host, port=int(port))
client.agent.self()
return True
except ValueError:
pass
return False
def driver_version(self):
return consul.__version__

View File

@@ -0,0 +1,300 @@
"""Etcd Transport module for Kombu.
It uses Etcd as a store to transport messages in Queues
It uses python-etcd for talking to Etcd's HTTP API
Features
========
* Type: Virtual
* Supports Direct: *Unreviewed*
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*
Connection String
=================
Connection string has the following format:
.. code-block::
'etcd'://SERVER:PORT
"""
from __future__ import annotations
import os
import socket
from collections import defaultdict
from contextlib import contextmanager
from queue import Empty
from kombu.exceptions import ChannelError
from kombu.log import get_logger
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
try:
import etcd
except ImportError:
etcd = None
logger = get_logger('kombu.transport.etcd')
DEFAULT_PORT = 2379
DEFAULT_HOST = 'localhost'
class Channel(virtual.Channel):
"""Etcd Channel class which talks to the Etcd."""
prefix = 'kombu'
index = None
timeout = 10
session_ttl = 30
lock_ttl = 10
def __init__(self, *args, **kwargs):
if etcd is None:
raise ImportError('Missing python-etcd library')
super().__init__(*args, **kwargs)
port = self.connection.client.port or self.connection.default_port
host = self.connection.client.hostname or DEFAULT_HOST
logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout)
self.queues = defaultdict(dict)
self.client = etcd.Client(host=host, port=int(port))
def _key_prefix(self, queue):
"""Create and return the `queue` with the proper prefix.
Arguments:
---------
queue (str): The name of the queue.
"""
return f'{self.prefix}/{queue}'
@contextmanager
def _queue_lock(self, queue):
"""Try to acquire a lock on the Queue.
It does so by creating a object called 'lock' which is locked by the
current session..
This way other nodes are not able to write to the lock object which
means that they have to wait before the lock is released.
Arguments:
---------
queue (str): The name of the queue.
"""
lock = etcd.Lock(self.client, queue)
lock._uuid = self.lock_value
logger.debug(f'Acquiring lock {lock.name}')
lock.acquire(blocking=True, lock_ttl=self.lock_ttl)
try:
yield
finally:
logger.debug(f'Releasing lock {lock.name}')
lock.release()
def _new_queue(self, queue, **_):
"""Create a new `queue` if the `queue` doesn't already exist.
Arguments:
---------
queue (str): The name of the queue.
"""
self.queues[queue] = queue
with self._queue_lock(queue):
try:
return self.client.write(
key=self._key_prefix(queue), dir=True, value=None)
except etcd.EtcdNotFile:
logger.debug(f'Queue "{queue}" already exists')
return self.client.read(key=self._key_prefix(queue))
def _has_queue(self, queue, **kwargs):
"""Verify that queue exists.
Returns
-------
bool: Should return :const:`True` if the queue exists
or :const:`False` otherwise.
"""
try:
self.client.read(self._key_prefix(queue))
return True
except etcd.EtcdKeyNotFound:
return False
def _delete(self, queue, *args, **_):
"""Delete a `queue`.
Arguments:
---------
queue (str): The name of the queue.
"""
self.queues.pop(queue, None)
self._purge(queue)
def _put(self, queue, payload, **_):
"""Put `message` onto `queue`.
This simply writes a key to the Etcd store
Arguments:
---------
queue (str): The name of the queue.
payload (dict): Message data which will be dumped to etcd.
"""
with self._queue_lock(queue):
key = self._key_prefix(queue)
if not self.client.write(
key=key,
value=dumps(payload),
append=True):
raise ChannelError(f'Cannot add key {key!r} to etcd')
def _get(self, queue, timeout=None):
"""Get the first available message from the queue.
Before it does so it acquires a lock on the store so
only one node reads at the same time. This is for read consistency
Arguments:
---------
queue (str): The name of the queue.
timeout (int): Optional seconds to wait for a response.
"""
with self._queue_lock(queue):
key = self._key_prefix(queue)
logger.debug('Fetching key %s with index %s', key, self.index)
try:
result = self.client.read(
key=key, recursive=True,
index=self.index, timeout=self.timeout)
if result is None:
raise Empty()
item = result._children[-1]
logger.debug('Removing key {}'.format(item['key']))
msg_content = loads(item['value'])
self.client.delete(key=item['key'])
return msg_content
except (TypeError, IndexError, etcd.EtcdException) as error:
logger.debug(f'_get failed: {type(error)}:{error}')
raise Empty()
def _purge(self, queue):
"""Remove all `message`s from a `queue`.
Arguments:
---------
queue (str): The name of the queue.
"""
with self._queue_lock(queue):
key = self._key_prefix(queue)
logger.debug(f'Purging queue at key {key}')
return self.client.delete(key=key, recursive=True)
def _size(self, queue):
"""Return the size of the `queue`.
Arguments:
---------
queue (str): The name of the queue.
"""
with self._queue_lock(queue):
size = 0
try:
key = self._key_prefix(queue)
logger.debug('Fetching key recursively %s with index %s',
key, self.index)
result = self.client.read(
key=key, recursive=True,
index=self.index)
size = len(result._children)
except TypeError:
pass
logger.debug('Found %s keys under %s with index %s',
size, key, self.index)
return size
@cached_property
def lock_value(self):
return f'{socket.gethostname()}.{os.getpid()}'
class Transport(virtual.Transport):
"""Etcd storage Transport for Kombu."""
Channel = Channel
default_port = DEFAULT_PORT
driver_type = 'etcd'
driver_name = 'python-etcd'
polling_interval = 3
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct']))
if etcd:
connection_errors = (
virtual.Transport.connection_errors + (etcd.EtcdException, )
)
channel_errors = (
virtual.Transport.channel_errors + (etcd.EtcdException, )
)
def __init__(self, *args, **kwargs):
"""Create a new instance of etcd.Transport."""
if etcd is None:
raise ImportError('Missing python-etcd library')
super().__init__(*args, **kwargs)
def verify_connection(self, connection):
"""Verify the connection works."""
port = connection.client.port or self.default_port
host = connection.client.hostname or DEFAULT_HOST
logger.debug('Verify Etcd connection to %s:%s', host, port)
try:
etcd.Client(host=host, port=int(port))
return True
except ValueError:
pass
return False
def driver_version(self):
"""Return the version of the etcd library.
.. note::
python-etcd has no __version__. This is a workaround.
"""
try:
import pip.commands.freeze
for x in pip.commands.freeze.freeze():
if x.startswith('python-etcd'):
return x.split('==')[1]
except (ImportError, IndexError):
logger.warning('Unable to find the python-etcd version.')
return 'Unknown'

View File

@@ -0,0 +1,352 @@
"""File-system Transport module for kombu.
Transport using the file-system as the message store. Messages written to the
queue are stored in `data_folder_in` directory and
messages read from the queue are read from `data_folder_out` directory. Both
directories must be created manually. Simple example:
* Producer:
.. code-block:: python
import kombu
conn = kombu.Connection(
'filesystem://', transport_options={
'data_folder_in': 'data_in', 'data_folder_out': 'data_out'
}
)
conn.connect()
test_queue = kombu.Queue('test', routing_key='test')
with conn as conn:
with conn.default_channel as channel:
producer = kombu.Producer(channel)
producer.publish(
{'hello': 'world'},
retry=True,
exchange=test_queue.exchange,
routing_key=test_queue.routing_key,
declare=[test_queue],
serializer='pickle'
)
* Consumer:
.. code-block:: python
import kombu
conn = kombu.Connection(
'filesystem://', transport_options={
'data_folder_in': 'data_out', 'data_folder_out': 'data_in'
}
)
conn.connect()
def callback(body, message):
print(body, message)
message.ack()
test_queue = kombu.Queue('test', routing_key='test')
with conn as conn:
with conn.default_channel as channel:
consumer = kombu.Consumer(
conn, [test_queue], accept=['pickle']
)
consumer.register_callback(callback)
with consumer:
conn.drain_events(timeout=1)
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
Connection String
=================
Connection string is in the following format:
.. code-block::
filesystem://
Transport Options
=================
* ``data_folder_in`` - directory where are messages stored when written
to queue.
* ``data_folder_out`` - directory from which are messages read when read from
queue.
* ``store_processed`` - if set to True, all processed messages are backed up to
``processed_folder``.
* ``processed_folder`` - directory where are backed up processed files.
* ``control_folder`` - directory where are exchange-queue table stored.
"""
from __future__ import annotations
import os
import shutil
import tempfile
import uuid
from collections import namedtuple
from pathlib import Path
from queue import Empty
from time import monotonic
from kombu.exceptions import ChannelError
from kombu.transport import virtual
from kombu.utils.encoding import bytes_to_str, str_to_bytes
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
VERSION = (1, 0, 0)
__version__ = '.'.join(map(str, VERSION))
# needs win32all to work on Windows
if os.name == 'nt':
import pywintypes
import win32con
import win32file
LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK
# 0 is the default
LOCK_SH = 0
LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY
__overlapped = pywintypes.OVERLAPPED()
def lock(file, flags):
"""Create file lock."""
hfile = win32file._get_osfhandle(file.fileno())
win32file.LockFileEx(hfile, flags, 0, 0xffff0000, __overlapped)
def unlock(file):
"""Remove file lock."""
hfile = win32file._get_osfhandle(file.fileno())
win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped)
elif os.name == 'posix':
import fcntl
from fcntl import LOCK_EX, LOCK_SH
def lock(file, flags):
"""Create file lock."""
fcntl.flock(file.fileno(), flags)
def unlock(file):
"""Remove file lock."""
fcntl.flock(file.fileno(), fcntl.LOCK_UN)
else:
raise RuntimeError(
'Filesystem plugin only defined for NT and POSIX platforms')
exchange_queue_t = namedtuple("exchange_queue_t",
["routing_key", "pattern", "queue"])
class Channel(virtual.Channel):
"""Filesystem Channel."""
supports_fanout = True
def get_table(self, exchange):
file = self.control_folder / f"{exchange}.exchange"
try:
f_obj = file.open("r")
try:
lock(f_obj, LOCK_SH)
exchange_table = loads(bytes_to_str(f_obj.read()))
return [exchange_queue_t(*q) for q in exchange_table]
finally:
unlock(f_obj)
f_obj.close()
except FileNotFoundError:
return []
except OSError:
raise ChannelError(f"Cannot open {file}")
def _queue_bind(self, exchange, routing_key, pattern, queue):
file = self.control_folder / f"{exchange}.exchange"
self.control_folder.mkdir(exist_ok=True)
queue_val = exchange_queue_t(routing_key or "", pattern or "",
queue or "")
try:
if file.exists():
f_obj = file.open("rb+", buffering=0)
lock(f_obj, LOCK_EX)
exchange_table = loads(bytes_to_str(f_obj.read()))
queues = [exchange_queue_t(*q) for q in exchange_table]
if queue_val not in queues:
queues.insert(0, queue_val)
f_obj.seek(0)
f_obj.write(str_to_bytes(dumps(queues)))
else:
f_obj = file.open("wb", buffering=0)
lock(f_obj, LOCK_EX)
queues = [queue_val]
f_obj.write(str_to_bytes(dumps(queues)))
finally:
unlock(f_obj)
f_obj.close()
def _put_fanout(self, exchange, payload, routing_key, **kwargs):
for q in self.get_table(exchange):
self._put(q.queue, payload, **kwargs)
def _put(self, queue, payload, **kwargs):
"""Put `message` onto `queue`."""
filename = '{}_{}.{}.msg'.format(int(round(monotonic() * 1000)),
uuid.uuid4(), queue)
filename = os.path.join(self.data_folder_out, filename)
try:
f = open(filename, 'wb', buffering=0)
lock(f, LOCK_EX)
f.write(str_to_bytes(dumps(payload)))
except OSError:
raise ChannelError(
f'Cannot add file {filename!r} to directory')
finally:
unlock(f)
f.close()
def _get(self, queue):
"""Get next message from `queue`."""
queue_find = '.' + queue + '.msg'
folder = os.listdir(self.data_folder_in)
folder = sorted(folder)
while len(folder) > 0:
filename = folder.pop(0)
# only handle message for the requested queue
if filename.find(queue_find) < 0:
continue
if self.store_processed:
processed_folder = self.processed_folder
else:
processed_folder = tempfile.gettempdir()
try:
# move the file to the tmp/processed folder
shutil.move(os.path.join(self.data_folder_in, filename),
processed_folder)
except OSError:
# file could be locked, or removed in meantime so ignore
continue
filename = os.path.join(processed_folder, filename)
try:
f = open(filename, 'rb')
payload = f.read()
f.close()
if not self.store_processed:
os.remove(filename)
except OSError:
raise ChannelError(
f'Cannot read file {filename!r} from queue.')
return loads(bytes_to_str(payload))
raise Empty()
def _purge(self, queue):
"""Remove all messages from `queue`."""
count = 0
queue_find = '.' + queue + '.msg'
folder = os.listdir(self.data_folder_in)
while len(folder) > 0:
filename = folder.pop()
try:
# only purge messages for the requested queue
if filename.find(queue_find) < 0:
continue
filename = os.path.join(self.data_folder_in, filename)
os.remove(filename)
count += 1
except OSError:
# we simply ignore its existence, as it was probably
# processed by another worker
pass
return count
def _size(self, queue):
"""Return the number of messages in `queue` as an :class:`int`."""
count = 0
queue_find = f'.{queue}.msg'
folder = os.listdir(self.data_folder_in)
while len(folder) > 0:
filename = folder.pop()
# only handle message for the requested queue
if filename.find(queue_find) < 0:
continue
count += 1
return count
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def data_folder_in(self):
return self.transport_options.get('data_folder_in', 'data_in')
@cached_property
def data_folder_out(self):
return self.transport_options.get('data_folder_out', 'data_out')
@cached_property
def store_processed(self):
return self.transport_options.get('store_processed', False)
@cached_property
def processed_folder(self):
return self.transport_options.get('processed_folder', 'processed')
@property
def control_folder(self):
return Path(self.transport_options.get('control_folder', 'control'))
class Transport(virtual.Transport):
"""Filesystem Transport."""
implements = virtual.Transport.implements.extend(
asynchronous=False,
exchange_type=frozenset(['direct', 'topic', 'fanout'])
)
Channel = Channel
# filesystem backend state is global.
global_state = virtual.BrokerState()
default_port = 0
driver_type = 'filesystem'
driver_name = 'filesystem'
def __init__(self, client, **kwargs):
super().__init__(client, **kwargs)
self.state = self.global_state
def driver_version(self):
return 'N/A'

View File

@@ -0,0 +1,810 @@
"""GCP Pub/Sub transport module for kombu.
More information about GCP Pub/Sub:
https://cloud.google.com/pubsub
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: No
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
Connection String
=================
Connection string has the following formats:
.. code-block::
gcpubsub://projects/project-name
Transport Options
=================
* ``queue_name_prefix``: (str) Prefix for queue names.
* ``ack_deadline_seconds``: (int) The maximum time after receiving a message
and acknowledging it before pub/sub redelivers the message.
* ``expiration_seconds``: (int) Subscriptions without any subscriber
activity or changes made to their properties are removed after this period.
Examples of subscriber activities include open connections,
active pulls, or successful pushes.
* ``wait_time_seconds``: (int) The maximum time to wait for new messages.
Defaults to 10.
* ``retry_timeout_seconds``: (int) The maximum time to wait before retrying.
* ``bulk_max_messages``: (int) The maximum number of messages to pull in bulk.
Defaults to 32.
"""
from __future__ import annotations
import dataclasses
import datetime
import string
import threading
from concurrent.futures import (FIRST_COMPLETED, Future, ThreadPoolExecutor,
wait)
from contextlib import suppress
from os import getpid
from queue import Empty
from threading import Lock
from time import monotonic, sleep
from uuid import NAMESPACE_OID, uuid3
from _socket import gethostname
from _socket import timeout as socket_timeout
from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded,
PermissionDenied)
from google.api_core.retry import Retry
from google.cloud import monitoring_v3
from google.cloud.monitoring_v3 import query
from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
from google.cloud.pubsub_v1 import exceptions as pubsub_exceptions
from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions
from google.cloud.pubsub_v1.subscriber import \
exceptions as subscriber_exceptions
from google.pubsub_v1 import gapic_version as package_version
from kombu.entity import TRANSIENT_DELIVERY_MODE
from kombu.log import get_logger
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
logger = get_logger('kombu.transport.gcpubsub')
# dots are replaced by dash, all other punctuation replaced by underscore.
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
CHARS_REPLACE_TABLE = {
ord('.'): ord('-'),
**{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE},
}
class UnackedIds:
"""Threadsafe list of ack_ids."""
def __init__(self):
self._list = []
self._lock = Lock()
def append(self, val):
# append is atomic
self._list.append(val)
def extend(self, vals: list):
# extend is atomic
self._list.extend(vals)
def pop(self, index=-1):
with self._lock:
return self._list.pop(index)
def remove(self, val):
with self._lock, suppress(ValueError):
self._list.remove(val)
def __len__(self):
with self._lock:
return len(self._list)
def __getitem__(self, item):
# getitem is atomic
return self._list[item]
class AtomicCounter:
"""Threadsafe counter.
Returns the value after inc/dec operations.
"""
def __init__(self, initial=0):
self._value = initial
self._lock = Lock()
def inc(self, n=1):
with self._lock:
self._value += n
return self._value
def dec(self, n=1):
with self._lock:
self._value -= n
return self._value
def get(self):
with self._lock:
return self._value
@dataclasses.dataclass
class QueueDescriptor:
"""Pub/Sub queue descriptor."""
name: str
topic_path: str # projects/{project_id}/topics/{topic_id}
subscription_id: str
subscription_path: str # projects/{project_id}/subscriptions/{subscription_id}
unacked_ids: UnackedIds = dataclasses.field(default_factory=UnackedIds)
class Channel(virtual.Channel):
"""GCP Pub/Sub channel."""
supports_fanout = True
do_restore = False # pub/sub does that for us
default_wait_time_seconds = 10
default_ack_deadline_seconds = 240
default_expiration_seconds = 86400
default_retry_timeout_seconds = 300
default_bulk_max_messages = 32
_min_ack_deadline = 10
_fanout_exchanges = set()
_unacked_extender: threading.Thread = None
_stop_extender = threading.Event()
_n_channels = AtomicCounter()
_queue_cache: dict[str, QueueDescriptor] = {}
_tmp_subscriptions: set[str] = set()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pool = ThreadPoolExecutor()
logger.info('new GCP pub/sub channel: %s', self.conninfo.hostname)
self.project_id = Transport.parse_uri(self.conninfo.hostname)
if self._n_channels.inc() == 1:
Channel._unacked_extender = threading.Thread(
target=self._extend_unacked_deadline,
daemon=True,
)
self._stop_extender.clear()
Channel._unacked_extender.start()
def entity_name(self, name: str, table=CHARS_REPLACE_TABLE) -> str:
"""Format AMQP queue name into a valid Pub/Sub queue name."""
if not name.startswith(self.queue_name_prefix):
name = self.queue_name_prefix + name
return str(safe_str(name)).translate(table)
def _queue_bind(self, exchange, routing_key, pattern, queue):
exchange_type = self.typeof(exchange).type
queue = self.entity_name(queue)
logger.debug(
'binding queue: %s to %s exchange: %s with routing_key: %s',
queue,
exchange_type,
exchange,
routing_key,
)
filter_args = {}
if exchange_type == 'direct':
# Direct exchange is implemented as a single subscription
# E.g. for exchange 'test_direct':
# -topic:'test_direct'
# -bound queue:'direct1':
# -subscription: direct1' on topic 'test_direct'
# -filter:routing_key'
filter_args = {
'filter': f'attributes.routing_key="{routing_key}"'
}
subscription_path = self.subscriber.subscription_path(
self.project_id, queue
)
message_retention_duration = self.expiration_seconds
elif exchange_type == 'fanout':
# Fanout exchange is implemented as a separate subscription.
# E.g. for exchange 'test_fanout':
# -topic:'test_fanout'
# -bound queue 'fanout1':
# -subscription:'fanout1-uuid' on topic 'test_fanout'
# -bound queue 'fanout2':
# -subscription:'fanout2-uuid' on topic 'test_fanout'
uid = f'{uuid3(NAMESPACE_OID, f"{gethostname()}.{getpid()}")}'
uniq_sub_name = f'{queue}-{uid}'
subscription_path = self.subscriber.subscription_path(
self.project_id, uniq_sub_name
)
self._tmp_subscriptions.add(subscription_path)
self._fanout_exchanges.add(exchange)
message_retention_duration = 600
else:
raise NotImplementedError(
f'exchange type {exchange_type} not implemented'
)
exchange_topic = self._create_topic(
self.project_id, exchange, message_retention_duration
)
self._create_subscription(
topic_path=exchange_topic,
subscription_path=subscription_path,
filter_args=filter_args,
msg_retention=message_retention_duration,
)
qdesc = QueueDescriptor(
name=queue,
topic_path=exchange_topic,
subscription_id=queue,
subscription_path=subscription_path,
)
self._queue_cache[queue] = qdesc
def _create_topic(
self,
project_id: str,
topic_id: str,
message_retention_duration: int = None,
) -> str:
topic_path = self.publisher.topic_path(project_id, topic_id)
if self._is_topic_exists(topic_path):
# topic creation takes a while, so skip if possible
logger.debug('topic: %s exists', topic_path)
return topic_path
try:
logger.debug('creating topic: %s', topic_path)
request = {'name': topic_path}
if message_retention_duration:
request[
'message_retention_duration'
] = f'{message_retention_duration}s'
self.publisher.create_topic(request=request)
except AlreadyExists:
pass
return topic_path
def _is_topic_exists(self, topic_path: str) -> bool:
topics = self.publisher.list_topics(
request={"project": f'projects/{self.project_id}'}
)
for t in topics:
if t.name == topic_path:
return True
return False
def _create_subscription(
self,
project_id: str = None,
topic_id: str = None,
topic_path: str = None,
subscription_path: str = None,
filter_args=None,
msg_retention: int = None,
) -> str:
subscription_path = (
subscription_path
or self.subscriber.subscription_path(self.project_id, topic_id)
)
topic_path = topic_path or self.publisher.topic_path(
project_id, topic_id
)
try:
logger.debug(
'creating subscription: %s, topic: %s, filter: %s',
subscription_path,
topic_path,
filter_args,
)
msg_retention = msg_retention or self.expiration_seconds
self.subscriber.create_subscription(
request={
"name": subscription_path,
"topic": topic_path,
'ack_deadline_seconds': self.ack_deadline_seconds,
'expiration_policy': {
'ttl': f'{self.expiration_seconds}s'
},
'message_retention_duration': f'{msg_retention}s',
**(filter_args or {}),
}
)
except AlreadyExists:
pass
return subscription_path
def _delete(self, queue, *args, **kwargs):
"""Delete a queue by name."""
queue = self.entity_name(queue)
logger.info('deleting queue: %s', queue)
qdesc = self._queue_cache.get(queue)
if not qdesc:
return
self.subscriber.delete_subscription(
request={"subscription": qdesc.subscription_path}
)
self._queue_cache.pop(queue, None)
def _put(self, queue, message, **kwargs):
"""Put a message onto the queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache[queue]
routing_key = self._get_routing_key(message)
logger.debug(
'putting message to queue: %s, topic: %s, routing_key: %s',
queue,
qdesc.topic_path,
routing_key,
)
encoded_message = dumps(message)
self.publisher.publish(
qdesc.topic_path,
encoded_message.encode("utf-8"),
routing_key=routing_key,
)
def _put_fanout(self, exchange, message, routing_key, **kwargs):
"""Put a message onto fanout exchange."""
self._lookup(exchange, routing_key)
topic_path = self.publisher.topic_path(self.project_id, exchange)
logger.debug(
'putting msg to fanout exchange: %s, topic: %s',
exchange,
topic_path,
)
encoded_message = dumps(message)
self.publisher.publish(
topic_path,
encoded_message.encode("utf-8"),
retry=Retry(deadline=self.retry_timeout_seconds),
)
def _get(self, queue: str, timeout: float = None):
"""Retrieves a single message from a queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache[queue]
try:
response = self.subscriber.pull(
request={
'subscription': qdesc.subscription_path,
'max_messages': 1,
},
retry=Retry(deadline=self.retry_timeout_seconds),
timeout=timeout or self.wait_time_seconds,
)
except DeadlineExceeded:
raise Empty()
if len(response.received_messages) == 0:
raise Empty()
message = response.received_messages[0]
ack_id = message.ack_id
payload = loads(message.message.data)
delivery_info = payload['properties']['delivery_info']
logger.debug(
'queue:%s got message, ack_id: %s, payload: %s',
queue,
ack_id,
payload['properties'],
)
if self._is_auto_ack(payload['properties']):
logger.debug('auto acking message ack_id: %s', ack_id)
self._do_ack([ack_id], qdesc.subscription_path)
else:
delivery_info['gcpubsub_message'] = {
'queue': queue,
'ack_id': ack_id,
'message_id': message.message.message_id,
'subscription_path': qdesc.subscription_path,
}
qdesc.unacked_ids.append(ack_id)
return payload
def _is_auto_ack(self, payload_properties: dict):
exchange = payload_properties['delivery_info']['exchange']
delivery_mode = payload_properties['delivery_mode']
return (
delivery_mode == TRANSIENT_DELIVERY_MODE
or exchange in self._fanout_exchanges
)
def _get_bulk(self, queue: str, timeout: float):
"""Retrieves bulk of messages from a queue."""
prefixed_queue = self.entity_name(queue)
qdesc = self._queue_cache[prefixed_queue]
max_messages = self._get_max_messages_estimate()
if not max_messages:
raise Empty()
try:
response = self.subscriber.pull(
request={
'subscription': qdesc.subscription_path,
'max_messages': max_messages,
},
retry=Retry(deadline=self.retry_timeout_seconds),
timeout=timeout or self.wait_time_seconds,
)
except DeadlineExceeded:
raise Empty()
received_messages = response.received_messages
if len(received_messages) == 0:
raise Empty()
auto_ack_ids = []
ret_payloads = []
logger.debug(
'batching %d messages from queue: %s',
len(received_messages),
prefixed_queue,
)
for message in received_messages:
ack_id = message.ack_id
payload = loads(bytes_to_str(message.message.data))
delivery_info = payload['properties']['delivery_info']
delivery_info['gcpubsub_message'] = {
'queue': prefixed_queue,
'ack_id': ack_id,
'message_id': message.message.message_id,
'subscription_path': qdesc.subscription_path,
}
if self._is_auto_ack(payload['properties']):
auto_ack_ids.append(ack_id)
else:
qdesc.unacked_ids.append(ack_id)
ret_payloads.append(payload)
if auto_ack_ids:
logger.debug('auto acking ack_ids: %s', auto_ack_ids)
self._do_ack(auto_ack_ids, qdesc.subscription_path)
return queue, ret_payloads
def _get_max_messages_estimate(self) -> int:
max_allowed = self.qos.can_consume_max_estimate()
max_if_unlimited = self.bulk_max_messages
return max_if_unlimited if max_allowed is None else max_allowed
def _lookup(self, exchange, routing_key, default=None):
exchange_info = self.state.exchanges.get(exchange, {})
if not exchange_info:
return super()._lookup(exchange, routing_key, default)
ret = self.typeof(exchange).lookup(
self.get_table(exchange),
exchange,
routing_key,
default,
)
if ret:
return ret
logger.debug(
'no queues bound to exchange: %s, binding on the fly',
exchange,
)
self.queue_bind(exchange, exchange, routing_key)
return [exchange]
def _size(self, queue: str) -> int:
"""Return the number of messages in a queue.
This is a *rough* estimation, as Pub/Sub doesn't provide
an exact API.
"""
queue = self.entity_name(queue)
if queue not in self._queue_cache:
return 0
qdesc = self._queue_cache[queue]
result = query.Query(
self.monitor,
self.project_id,
'pubsub.googleapis.com/subscription/num_undelivered_messages',
end_time=datetime.datetime.now(),
minutes=1,
).select_resources(subscription_id=qdesc.subscription_id)
# monitoring API requires the caller to have the monitoring.viewer
# role. Since we can live without the exact number of messages
# in the queue, we can ignore the exception and allow users to
# use the transport without this role.
with suppress(PermissionDenied):
return sum(
content.points[0].value.int64_value for content in result
)
return -1
def basic_ack(self, delivery_tag, multiple=False):
"""Acknowledge one message."""
if multiple:
raise NotImplementedError('multiple acks not implemented')
delivery_info = self.qos.get(delivery_tag).delivery_info
pubsub_message = delivery_info['gcpubsub_message']
ack_id = pubsub_message['ack_id']
queue = pubsub_message['queue']
logger.debug('ack message. queue: %s ack_id: %s', queue, ack_id)
subscription_path = pubsub_message['subscription_path']
self._do_ack([ack_id], subscription_path)
qdesc = self._queue_cache[queue]
qdesc.unacked_ids.remove(ack_id)
super().basic_ack(delivery_tag)
def _do_ack(self, ack_ids: list[str], subscription_path: str):
self.subscriber.acknowledge(
request={"subscription": subscription_path, "ack_ids": ack_ids},
retry=Retry(deadline=self.retry_timeout_seconds),
)
def _purge(self, queue: str):
"""Delete all current messages in a queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache.get(queue)
if not qdesc:
return
n = self._size(queue)
self.subscriber.seek(
request={
"subscription": qdesc.subscription_path,
"time": datetime.datetime.now(),
}
)
return n
def _extend_unacked_deadline(self):
thread_id = threading.get_native_id()
logger.info(
'unacked deadline extension thread: [%s] started',
thread_id,
)
min_deadline_sleep = self._min_ack_deadline / 2
sleep_time = max(min_deadline_sleep, self.ack_deadline_seconds / 4)
while not self._stop_extender.wait(sleep_time):
for qdesc in self._queue_cache.values():
if len(qdesc.unacked_ids) == 0:
logger.debug(
'thread [%s]: no unacked messages for %s',
thread_id,
qdesc.subscription_path,
)
continue
logger.debug(
'thread [%s]: extend ack deadline for %s: %d msgs [%s]',
thread_id,
qdesc.subscription_path,
len(qdesc.unacked_ids),
list(qdesc.unacked_ids),
)
self.subscriber.modify_ack_deadline(
request={
"subscription": qdesc.subscription_path,
"ack_ids": list(qdesc.unacked_ids),
"ack_deadline_seconds": self.ack_deadline_seconds,
}
)
logger.info(
'unacked deadline extension thread [%s] stopped', thread_id
)
def after_reply_message_received(self, queue: str):
queue = self.entity_name(queue)
sub = self.subscriber.subscription_path(self.project_id, queue)
logger.debug(
'after_reply_message_received: queue: %s, sub: %s', queue, sub
)
self._tmp_subscriptions.add(sub)
@cached_property
def subscriber(self):
return SubscriberClient()
@cached_property
def publisher(self):
return PublisherClient()
@cached_property
def monitor(self):
return monitoring_v3.MetricServiceClient()
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def wait_time_seconds(self):
return self.transport_options.get(
'wait_time_seconds', self.default_wait_time_seconds
)
@cached_property
def retry_timeout_seconds(self):
return self.transport_options.get(
'retry_timeout_seconds', self.default_retry_timeout_seconds
)
@cached_property
def ack_deadline_seconds(self):
return self.transport_options.get(
'ack_deadline_seconds', self.default_ack_deadline_seconds
)
@cached_property
def queue_name_prefix(self):
return self.transport_options.get('queue_name_prefix', 'kombu-')
@cached_property
def expiration_seconds(self):
return self.transport_options.get(
'expiration_seconds', self.default_expiration_seconds
)
@cached_property
def bulk_max_messages(self):
return self.transport_options.get(
'bulk_max_messages', self.default_bulk_max_messages
)
def close(self):
"""Close the channel."""
logger.debug('closing channel')
while self._tmp_subscriptions:
sub = self._tmp_subscriptions.pop()
with suppress(Exception):
logger.debug('deleting subscription: %s', sub)
self.subscriber.delete_subscription(
request={"subscription": sub}
)
if not self._n_channels.dec():
self._stop_extender.set()
Channel._unacked_extender.join()
super().close()
@staticmethod
def _get_routing_key(message):
routing_key = (
message['properties']
.get('delivery_info', {})
.get('routing_key', '')
)
return routing_key
class Transport(virtual.Transport):
"""GCP Pub/Sub transport."""
Channel = Channel
can_parse_url = True
polling_interval = 0.1
connection_errors = virtual.Transport.connection_errors + (
pubsub_exceptions.TimeoutError,
)
channel_errors = (
virtual.Transport.channel_errors
+ (
publisher_exceptions.FlowControlLimitError,
publisher_exceptions.MessageTooLargeError,
publisher_exceptions.PublishError,
publisher_exceptions.TimeoutError,
publisher_exceptions.PublishToPausedOrderingKeyException,
)
+ (subscriber_exceptions.AcknowledgeError,)
)
driver_type = 'gcpubsub'
driver_name = 'pubsub_v1'
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct', 'fanout']),
)
def __init__(self, client, **kwargs):
super().__init__(client, **kwargs)
self._pool = ThreadPoolExecutor()
self._get_bulk_future_to_queue: dict[Future, str] = dict()
def driver_version(self):
return package_version.__version__
@staticmethod
def parse_uri(uri: str) -> str:
# URL like:
# gcpubsub://projects/project-name
project = uri.split('gcpubsub://projects/')[1]
return project.strip('/')
@classmethod
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
return uri or 'gcpubsub://'
def drain_events(self, connection, timeout=None):
time_start = monotonic()
polling_interval = self.polling_interval
if timeout and polling_interval and polling_interval > timeout:
polling_interval = timeout
while 1:
try:
self._drain_from_active_queues(timeout=timeout)
except Empty:
if timeout and monotonic() - time_start >= timeout:
raise socket_timeout()
if polling_interval:
sleep(polling_interval)
else:
break
def _drain_from_active_queues(self, timeout):
# cleanup empty requests from prev run
self._rm_empty_bulk_requests()
# submit new requests for all active queues
# longer timeout means less frequent polling
# and more messages in a single bulk
self._submit_get_bulk_requests(timeout=10)
done, _ = wait(
self._get_bulk_future_to_queue,
timeout=timeout,
return_when=FIRST_COMPLETED,
)
empty = {f for f in done if f.exception()}
done -= empty
for f in empty:
self._get_bulk_future_to_queue.pop(f, None)
if not done:
raise Empty()
logger.debug('got %d done get_bulk tasks', len(done))
for f in done:
queue, payloads = f.result()
for payload in payloads:
logger.debug('consuming message from queue: %s', queue)
if queue not in self._callbacks:
logger.warning(
'Message for queue %s without consumers', queue
)
continue
self._deliver(payload, queue)
self._get_bulk_future_to_queue.pop(f, None)
def _rm_empty_bulk_requests(self):
empty = {
f
for f in self._get_bulk_future_to_queue
if f.done() and f.exception()
}
for f in empty:
self._get_bulk_future_to_queue.pop(f, None)
def _submit_get_bulk_requests(self, timeout):
queues_with_submitted_get_bulk = set(
self._get_bulk_future_to_queue.values()
)
for channel in self.channels:
for queue in channel._active_queues:
if queue in queues_with_submitted_get_bulk:
continue
future = self._pool.submit(channel._get_bulk, queue, timeout)
self._get_bulk_future_to_queue[future] = queue

View File

@@ -0,0 +1,190 @@
"""`librabbitmq`_ transport.
.. _`librabbitmq`: https://pypi.org/project/librabbitmq/
"""
from __future__ import annotations
import os
import socket
import warnings
import librabbitmq as amqp
from librabbitmq import ChannelError, ConnectionError
from kombu.utils.amq_manager import get_manager
from kombu.utils.text import version_string_as_tuple
from . import base
from .base import to_rabbitmq_queue_arguments
W_VERSION = """
librabbitmq version too old to detect RabbitMQ version information
so make sure you are using librabbitmq 1.5 when using rabbitmq > 3.3
"""
DEFAULT_PORT = 5672
DEFAULT_SSL_PORT = 5671
NO_SSL_ERROR = """\
ssl not supported by librabbitmq, please use pyamqp:// or stunnel\
"""
class Message(base.Message):
"""AMQP Message (librabbitmq)."""
def __init__(self, channel, props, info, body):
super().__init__(
channel=channel,
body=body,
delivery_info=info,
properties=props,
delivery_tag=info.get('delivery_tag'),
content_type=props.get('content_type'),
content_encoding=props.get('content_encoding'),
headers=props.get('headers'))
class Channel(amqp.Channel, base.StdChannel):
"""AMQP Channel (librabbitmq)."""
Message = Message
def prepare_message(self, body, priority=None,
content_type=None, content_encoding=None,
headers=None, properties=None):
"""Encapsulate data into a AMQP message."""
properties = properties if properties is not None else {}
properties.update({'content_type': content_type,
'content_encoding': content_encoding,
'headers': headers})
# Don't include priority if it's not an integer.
# If that's the case librabbitmq will fail
# and raise an exception.
if priority is not None:
properties['priority'] = priority
return body, properties
def prepare_queue_arguments(self, arguments, **kwargs):
arguments = to_rabbitmq_queue_arguments(arguments, **kwargs)
return {k.encode('utf8'): v for k, v in arguments.items()}
class Connection(amqp.Connection):
"""AMQP Connection (librabbitmq)."""
Channel = Channel
Message = Message
class Transport(base.Transport):
"""AMQP Transport (librabbitmq)."""
Connection = Connection
default_port = DEFAULT_PORT
default_ssl_port = DEFAULT_SSL_PORT
connection_errors = (
base.Transport.connection_errors + (
ConnectionError, socket.error, IOError, OSError)
)
channel_errors = (
base.Transport.channel_errors + (ChannelError,)
)
driver_type = 'amqp'
driver_name = 'librabbitmq'
implements = base.Transport.implements.extend(
asynchronous=True,
heartbeats=False,
)
def __init__(self, client, **kwargs):
self.client = client
self.default_port = kwargs.get('default_port') or self.default_port
self.default_ssl_port = (kwargs.get('default_ssl_port') or
self.default_ssl_port)
self.__reader = None
def driver_version(self):
return amqp.__version__
def create_channel(self, connection):
return connection.channel()
def drain_events(self, connection, **kwargs):
return connection.drain_events(**kwargs)
def establish_connection(self):
"""Establish connection to the AMQP broker."""
conninfo = self.client
for name, default_value in self.default_connection_params.items():
if not getattr(conninfo, name, None):
setattr(conninfo, name, default_value)
if conninfo.ssl:
raise NotImplementedError(NO_SSL_ERROR)
opts = dict({
'host': conninfo.host,
'userid': conninfo.userid,
'password': conninfo.password,
'virtual_host': conninfo.virtual_host,
'login_method': conninfo.login_method,
'insist': conninfo.insist,
'ssl': conninfo.ssl,
'connect_timeout': conninfo.connect_timeout,
}, **conninfo.transport_options or {})
conn = self.Connection(**opts)
conn.client = self.client
self.client.drain_events = conn.drain_events
return conn
def close_connection(self, connection):
"""Close the AMQP broker connection."""
self.client.drain_events = None
connection.close()
def _collect(self, connection):
if connection is not None:
for channel in connection.channels.values():
channel.connection = None
try:
os.close(connection.fileno())
except (OSError, ValueError):
pass
connection.channels.clear()
connection.callbacks.clear()
self.client.drain_events = None
self.client = None
def verify_connection(self, connection):
return connection.connected
def register_with_event_loop(self, connection, loop):
loop.add_reader(
connection.fileno(), self.on_readable, connection, loop,
)
def get_manager(self, *args, **kwargs):
return get_manager(self.client, *args, **kwargs)
def qos_semantics_matches_spec(self, connection):
try:
props = connection.server_properties
except AttributeError:
warnings.warn(UserWarning(W_VERSION))
else:
if props.get('product') == 'RabbitMQ':
return version_string_as_tuple(props['version']) < (3, 3)
return True
@property
def default_connection_params(self):
return {
'userid': 'guest',
'password': 'guest',
'port': (self.default_ssl_port if self.client.ssl
else self.default_port),
'hostname': 'localhost',
'login_method': 'PLAIN',
}

View File

@@ -0,0 +1,106 @@
"""In-memory transport module for Kombu.
Simple transport using memory for storing messages.
Messages can be passed only between threads.
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: No
* Supports TTL: Yes
Connection String
=================
Connection string is in the following format:
.. code-block::
memory://
"""
from __future__ import annotations
from collections import defaultdict
from queue import Queue
from . import base, virtual
class Channel(virtual.Channel):
"""In-memory Channel."""
events = defaultdict(set)
queues = {}
do_restore = False
supports_fanout = True
def _has_queue(self, queue, **kwargs):
return queue in self.queues
def _new_queue(self, queue, **kwargs):
if queue not in self.queues:
self.queues[queue] = Queue()
def _get(self, queue, timeout=None):
return self._queue_for(queue).get(block=False)
def _queue_for(self, queue):
if queue not in self.queues:
self.queues[queue] = Queue()
return self.queues[queue]
def _queue_bind(self, *args):
pass
def _put_fanout(self, exchange, message, routing_key=None, **kwargs):
for queue in self._lookup(exchange, routing_key):
self._queue_for(queue).put(message)
def _put(self, queue, message, **kwargs):
self._queue_for(queue).put(message)
def _size(self, queue):
return self._queue_for(queue).qsize()
def _delete(self, queue, *args, **kwargs):
self.queues.pop(queue, None)
def _purge(self, queue):
q = self._queue_for(queue)
size = q.qsize()
q.queue.clear()
return size
def close(self):
super().close()
for queue in self.queues.values():
queue.empty()
self.queues = {}
def after_reply_message_received(self, queue):
pass
class Transport(virtual.Transport):
"""In-memory Transport."""
Channel = Channel
#: memory backend state is global.
global_state = virtual.BrokerState()
implements = base.Transport.implements
driver_type = 'memory'
driver_name = 'memory'
def __init__(self, client, **kwargs):
super().__init__(client, **kwargs)
self.state = self.global_state
def driver_version(self):
return 'N/A'

View File

@@ -0,0 +1,534 @@
# copyright: (c) 2010 - 2013 by Flavio Percoco Premoli.
# license: BSD, see LICENSE for more details.
"""MongoDB transport module for kombu.
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: Yes
* Supports TTL: Yes
Connection String
=================
*Unreviewed*
Transport Options
=================
* ``connect_timeout``,
* ``ssl``,
* ``ttl``,
* ``capped_queue_size``,
* ``default_hostname``,
* ``default_port``,
* ``default_database``,
* ``messages_collection``,
* ``routing_collection``,
* ``broadcast_collection``,
* ``queues_collection``,
* ``calc_queue_size``,
"""
from __future__ import annotations
import datetime
from queue import Empty
import pymongo
from pymongo import MongoClient, errors, uri_parser
from pymongo.cursor import CursorType
from kombu.exceptions import VersionMismatch
from kombu.utils.compat import _detect_environment
from kombu.utils.encoding import bytes_to_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from kombu.utils.url import maybe_sanitize_url
from . import virtual
from .base import to_rabbitmq_queue_arguments
E_SERVER_VERSION = """\
Kombu requires MongoDB version 1.3+ (server is {0})\
"""
E_NO_TTL_INDEXES = """\
Kombu requires MongoDB version 2.2+ (server is {0}) for TTL indexes support\
"""
class BroadcastCursor:
"""Cursor for broadcast queues."""
def __init__(self, cursor):
self._cursor = cursor
self._offset = 0
self.purge(rewind=False)
def get_size(self):
return self._cursor.collection.count_documents({}) - self._offset
def close(self):
self._cursor.close()
def purge(self, rewind=True):
if rewind:
self._cursor.rewind()
# Fast-forward the cursor past old events
self._offset = self._cursor.collection.count_documents({})
self._cursor = self._cursor.skip(self._offset)
def __iter__(self):
return self
def __next__(self):
while True:
try:
msg = next(self._cursor)
except pymongo.errors.OperationFailure as exc:
# In some cases tailed cursor can become invalid
# and have to be reinitalized
if 'not valid at server' in str(exc):
self.purge()
continue
raise
else:
break
self._offset += 1
return msg
next = __next__
class Channel(virtual.Channel):
"""MongoDB Channel."""
supports_fanout = True
# Mutable container. Shared by all class instances
_fanout_queues = {}
# Options
ssl = False
ttl = False
connect_timeout = None
capped_queue_size = 100000
calc_queue_size = True
default_hostname = '127.0.0.1'
default_port = 27017
default_database = 'kombu_default'
messages_collection = 'messages'
routing_collection = 'messages.routing'
broadcast_collection = 'messages.broadcast'
queues_collection = 'messages.queues'
from_transport_options = (virtual.Channel.from_transport_options + (
'connect_timeout', 'ssl', 'ttl', 'capped_queue_size',
'default_hostname', 'default_port', 'default_database',
'messages_collection', 'routing_collection',
'broadcast_collection', 'queues_collection',
'calc_queue_size',
))
def __init__(self, *vargs, **kwargs):
super().__init__(*vargs, **kwargs)
self._broadcast_cursors = {}
# Evaluate connection
self.client
# AbstractChannel/Channel interface implementation
def _new_queue(self, queue, **kwargs):
if self.ttl:
self.queues.update_one(
{'_id': queue},
{
'$set': {
'_id': queue,
'options': kwargs,
'expire_at': self._get_queue_expire(
kwargs, 'x-expires'
),
},
},
upsert=True)
def _get(self, queue):
if queue in self._fanout_queues:
try:
msg = next(self._get_broadcast_cursor(queue))
except StopIteration:
msg = None
else:
msg = self.messages.find_one_and_delete(
{'queue': queue},
sort=[('priority', pymongo.ASCENDING)],
)
if self.ttl:
self._update_queues_expire(queue)
if msg is None:
raise Empty()
return loads(bytes_to_str(msg['payload']))
def _size(self, queue):
# Do not calculate actual queue size if requested
# for performance considerations
if not self.calc_queue_size:
return super()._size(queue)
if queue in self._fanout_queues:
return self._get_broadcast_cursor(queue).get_size()
return self.messages.count_documents({'queue': queue})
def _put(self, queue, message, **kwargs):
data = {
'payload': dumps(message),
'queue': queue,
'priority': self._get_message_priority(message, reverse=True)
}
if self.ttl:
data['expire_at'] = self._get_queue_expire(queue, 'x-message-ttl')
msg_expire = self._get_message_expire(message)
if msg_expire is not None and (
data['expire_at'] is None or msg_expire < data['expire_at']
):
data['expire_at'] = msg_expire
self.messages.insert_one(data)
def _put_fanout(self, exchange, message, routing_key, **kwargs):
self.broadcast.insert_one({'payload': dumps(message),
'queue': exchange})
def _purge(self, queue):
size = self._size(queue)
if queue in self._fanout_queues:
self._get_broadcast_cursor(queue).purge()
else:
self.messages.delete_many({'queue': queue})
return size
def get_table(self, exchange):
localRoutes = frozenset(self.state.exchanges[exchange]['table'])
brokerRoutes = self.routing.find(
{'exchange': exchange}
)
return localRoutes | frozenset(
(r['routing_key'], r['pattern'], r['queue'])
for r in brokerRoutes
)
def _queue_bind(self, exchange, routing_key, pattern, queue):
if self.typeof(exchange).type == 'fanout':
self._create_broadcast_cursor(
exchange, routing_key, pattern, queue)
self._fanout_queues[queue] = exchange
lookup = {
'exchange': exchange,
'queue': queue,
'routing_key': routing_key,
'pattern': pattern,
}
data = lookup.copy()
if self.ttl:
data['expire_at'] = self._get_queue_expire(queue, 'x-expires')
self.routing.update_one(lookup, {'$set': data}, upsert=True)
def queue_delete(self, queue, **kwargs):
self.routing.delete_many({'queue': queue})
if self.ttl:
self.queues.delete_one({'_id': queue})
super().queue_delete(queue, **kwargs)
if queue in self._fanout_queues:
try:
cursor = self._broadcast_cursors.pop(queue)
except KeyError:
pass
else:
cursor.close()
self._fanout_queues.pop(queue)
# Implementation details
def _parse_uri(self, scheme='mongodb://'):
# See mongodb uri documentation:
# https://docs.mongodb.org/manual/reference/connection-string/
client = self.connection.client
hostname = client.hostname
if hostname.startswith('srv://'):
scheme = 'mongodb+srv://'
hostname = 'mongodb+' + hostname
if not hostname.startswith(scheme):
hostname = scheme + hostname
if not hostname[len(scheme):]:
hostname += self.default_hostname
if client.userid and '@' not in hostname:
head, tail = hostname.split('://')
credentials = client.userid
if client.password:
credentials += ':' + client.password
hostname = head + '://' + credentials + '@' + tail
port = client.port if client.port else self.default_port
# We disable validating and normalization parameters here,
# because pymongo will validate and normalize parameters later in __init__ of MongoClient
parsed = uri_parser.parse_uri(hostname, port, validate=False)
dbname = parsed['database'] or client.virtual_host
if dbname in ('/', None):
dbname = self.default_database
options = {
'auto_start_request': True,
'ssl': self.ssl,
'connectTimeoutMS': (int(self.connect_timeout * 1000)
if self.connect_timeout else None),
}
options.update(parsed['options'])
options = self._prepare_client_options(options)
if 'tls' in options:
options.pop('ssl')
return hostname, dbname, options
def _prepare_client_options(self, options):
if pymongo.version_tuple >= (3,):
options.pop('auto_start_request', None)
if isinstance(options.get('readpreference'), int):
modes = pymongo.read_preferences._MONGOS_MODES
options['readpreference'] = modes[options['readpreference']]
return options
def prepare_queue_arguments(self, arguments, **kwargs):
return to_rabbitmq_queue_arguments(arguments, **kwargs)
def _open(self, scheme='mongodb://'):
hostname, dbname, conf = self._parse_uri(scheme=scheme)
conf['host'] = hostname
env = _detect_environment()
if env == 'gevent':
from gevent import monkey
monkey.patch_all()
elif env == 'eventlet':
from eventlet import monkey_patch
monkey_patch()
mongoconn = MongoClient(**conf)
database = mongoconn[dbname]
version_str = mongoconn.server_info()['version']
version_str = version_str.split('-')[0]
version = tuple(map(int, version_str.split('.')))
if version < (1, 3):
raise VersionMismatch(E_SERVER_VERSION.format(version_str))
elif self.ttl and version < (2, 2):
raise VersionMismatch(E_NO_TTL_INDEXES.format(version_str))
return database
def _create_broadcast(self, database):
"""Create capped collection for broadcast messages."""
if self.broadcast_collection in database.list_collection_names():
return
database.create_collection(self.broadcast_collection,
size=self.capped_queue_size,
capped=True)
def _ensure_indexes(self, database):
"""Ensure indexes on collections."""
messages = database[self.messages_collection]
messages.create_index(
[('queue', 1), ('priority', 1), ('_id', 1)], background=True,
)
database[self.broadcast_collection].create_index([('queue', 1)])
routing = database[self.routing_collection]
routing.create_index([('queue', 1), ('exchange', 1)])
if self.ttl:
messages.create_index([('expire_at', 1)], expireAfterSeconds=0)
routing.create_index([('expire_at', 1)], expireAfterSeconds=0)
database[self.queues_collection].create_index(
[('expire_at', 1)], expireAfterSeconds=0)
def _create_client(self):
"""Actually creates connection."""
database = self._open()
self._create_broadcast(database)
self._ensure_indexes(database)
return database
@cached_property
def client(self):
return self._create_client()
@cached_property
def messages(self):
return self.client[self.messages_collection]
@cached_property
def routing(self):
return self.client[self.routing_collection]
@cached_property
def broadcast(self):
return self.client[self.broadcast_collection]
@cached_property
def queues(self):
return self.client[self.queues_collection]
def _get_broadcast_cursor(self, queue):
try:
return self._broadcast_cursors[queue]
except KeyError:
# Cursor may be absent when Channel created more than once.
# _fanout_queues is a class-level mutable attribute so it's
# shared over all Channel instances.
return self._create_broadcast_cursor(
self._fanout_queues[queue], None, None, queue,
)
def _create_broadcast_cursor(self, exchange, routing_key, pattern, queue):
if pymongo.version_tuple >= (3, ):
query = {
'filter': {'queue': exchange},
'cursor_type': CursorType.TAILABLE,
}
else:
query = {
'query': {'queue': exchange},
'tailable': True,
}
cursor = self.broadcast.find(**query)
ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor)
return ret
def _get_message_expire(self, message):
value = message.get('properties', {}).get('expiration')
if value is not None:
return self.get_now() + datetime.timedelta(milliseconds=int(value))
def _get_queue_expire(self, queue, argument):
"""Get expiration header named `argument` of queue definition.
Note:
----
`queue` must be either queue name or options itself.
"""
if isinstance(queue, str):
doc = self.queues.find_one({'_id': queue})
if not doc:
return
data = doc['options']
else:
data = queue
try:
value = data['arguments'][argument]
except (KeyError, TypeError):
return
return self.get_now() + datetime.timedelta(milliseconds=value)
def _update_queues_expire(self, queue):
"""Update expiration field on queues documents."""
expire_at = self._get_queue_expire(queue, 'x-expires')
if not expire_at:
return
self.routing.update_many(
{'queue': queue}, {'$set': {'expire_at': expire_at}})
self.queues.update_many(
{'_id': queue}, {'$set': {'expire_at': expire_at}})
def get_now(self):
"""Return current time in UTC."""
return datetime.datetime.utcnow()
class Transport(virtual.Transport):
"""MongoDB Transport."""
Channel = Channel
can_parse_url = True
polling_interval = 1
default_port = Channel.default_port
connection_errors = (
virtual.Transport.connection_errors + (errors.ConnectionFailure,)
)
channel_errors = (
virtual.Transport.channel_errors + (
errors.ConnectionFailure,
errors.OperationFailure)
)
driver_type = 'mongodb'
driver_name = 'pymongo'
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct', 'topic', 'fanout']),
)
def driver_version(self):
return pymongo.version
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
if not uri:
return 'mongodb://'
if include_password:
return uri
if ',' not in uri:
return maybe_sanitize_url(uri)
uri1, remainder = uri.split(',', 1)
return ','.join([maybe_sanitize_url(uri1), remainder])

View File

@@ -0,0 +1,134 @@
"""Native Delayed Delivery API.
Only relevant for RabbitMQ.
"""
from __future__ import annotations
from kombu import Connection, Exchange, Queue, binding
from kombu.log import get_logger
logger = get_logger(__name__)
MAX_NUMBER_OF_BITS_TO_USE = 28
MAX_LEVEL = MAX_NUMBER_OF_BITS_TO_USE - 1
CELERY_DELAYED_DELIVERY_EXCHANGE = "celery_delayed_delivery"
def level_name(level: int) -> str:
"""Generates the delayed queue/exchange name based on the level."""
if level < 0:
raise ValueError("level must be a non-negative number")
return f"celery_delayed_{level}"
def declare_native_delayed_delivery_exchanges_and_queues(connection: Connection, queue_type: str) -> None:
"""Declares all native delayed delivery exchanges and queues."""
if queue_type != "classic" and queue_type != "quorum":
raise ValueError("queue_type must be either classic or quorum")
channel = connection.channel()
routing_key: str = "1.#"
for level in range(27, -1, - 1):
current_level = level_name(level)
next_level = level_name(level - 1) if level > 0 else None
delayed_exchange: Exchange = Exchange(
current_level, type="topic").bind(channel)
delayed_exchange.declare()
queue_arguments = {
"x-queue-type": queue_type,
"x-overflow": "reject-publish",
"x-message-ttl": pow(2, level) * 1000,
"x-dead-letter-exchange": next_level if level > 0 else CELERY_DELAYED_DELIVERY_EXCHANGE,
}
if queue_type == 'quorum':
queue_arguments["x-dead-letter-strategy"] = "at-least-once"
delayed_queue: Queue = Queue(
current_level,
queue_arguments=queue_arguments
).bind(channel)
delayed_queue.declare()
delayed_queue.bind_to(current_level, routing_key)
routing_key = "*." + routing_key
routing_key = "0.#"
for level in range(27, 0, - 1):
current_level = level_name(level)
next_level = level_name(level - 1) if level > 0 else None
next_level_exchange: Exchange = Exchange(
next_level, type="topic").bind(channel)
next_level_exchange.bind_to(current_level, routing_key)
routing_key = "*." + routing_key
delivery_exchange: Exchange = Exchange(
CELERY_DELAYED_DELIVERY_EXCHANGE, type="topic").bind(channel)
delivery_exchange.declare()
delivery_exchange.bind_to(level_name(0), routing_key)
def bind_queue_to_native_delayed_delivery_exchange(connection: Connection, queue: Queue) -> None:
"""Bind a queue to the native delayed delivery exchange.
When a message arrives at the delivery exchange, it must be forwarded to
the original exchange and queue. To accomplish this, the function retrieves
the exchange or binding objects associated with the queue and binds them to
the delivery exchange.
:param connection: The connection object used to create and manage the channel.
:type connection: Connection
:param queue: The queue to be bound to the native delayed delivery exchange.
:type queue: Queue
Warning:
-------
If a direct exchange is detected, a warning will be logged because
native delayed delivery does not support direct exchanges.
"""
channel = connection.channel()
queue = queue.bind(channel)
bindings: set[binding] = set()
if queue.exchange:
bindings.add(binding(
queue.exchange,
routing_key=queue.routing_key,
arguments=queue.binding_arguments
))
elif queue.bindings:
bindings = queue.bindings
for binding_entry in bindings:
exchange: Exchange = binding_entry.exchange.bind(channel)
if exchange.type == 'direct':
logger.warning(f"Exchange {exchange.name} is a direct exchange "
f"and native delayed delivery do not support direct exchanges.\n"
f"ETA tasks published to this exchange will block the worker until the ETA arrives.")
continue
routing_key = binding_entry.routing_key if binding_entry.routing_key.startswith(
'#') else f"#.{binding_entry.routing_key}"
exchange.bind_to(CELERY_DELAYED_DELIVERY_EXCHANGE, routing_key=routing_key)
queue.bind_to(exchange.name, routing_key=routing_key)
def calculate_routing_key(countdown: int, routing_key: str) -> str:
"""Calculate the routing key for publishing a delayed message based on the countdown."""
if countdown < 1:
raise ValueError("countdown must be a positive number")
if not routing_key:
raise ValueError("routing_key must be non-empty")
return '.'.join(list(f'{countdown:028b}')) + f'.{routing_key}'

View File

@@ -0,0 +1,253 @@
"""pyamqp transport module for Kombu.
Pure-Python amqp transport using py-amqp library.
Features
========
* Type: Native
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: Yes
* Supports TTL: Yes
Connection String
=================
Connection string can have the following formats:
.. code-block::
amqp://[USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]
[USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]
amqp://
For TLS encryption use:
.. code-block::
amqps://[USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]
Transport Options
=================
Transport Options are passed to constructor of underlying py-amqp
:class:`~kombu.connection.Connection` class.
Using TLS
=========
Transport over TLS can be enabled by ``ssl`` parameter of
:class:`~kombu.Connection` class. By setting ``ssl=True``, TLS transport is
used::
conn = Connect('amqp://', ssl=True)
This is equivalent to ``amqps://`` transport URI::
conn = Connect('amqps://')
For adding additional parameters to underlying TLS, ``ssl`` parameter should
be set with dict instead of True::
conn = Connect('amqp://broker.example.com', ssl={
'keyfile': '/path/to/keyfile'
'certfile': '/path/to/certfile',
'ca_certs': '/path/to/ca_certfile'
}
)
All parameters are passed to ``ssl`` parameter of
:class:`amqp.connection.Connection` class.
SSL option ``server_hostname`` can be set to ``None`` which is causing using
hostname from broker URL. This is useful when failover is used to fill
``server_hostname`` with currently used broker::
conn = Connect('amqp://broker1.example.com;broker2.example.com', ssl={
'server_hostname': None
}
)
"""
from __future__ import annotations
import amqp
from kombu.utils.amq_manager import get_manager
from kombu.utils.text import version_string_as_tuple
from . import base
from .base import to_rabbitmq_queue_arguments
DEFAULT_PORT = 5672
DEFAULT_SSL_PORT = 5671
class Message(base.Message):
"""AMQP Message."""
def __init__(self, msg, channel=None, **kwargs):
props = msg.properties
super().__init__(
body=msg.body,
channel=channel,
delivery_tag=msg.delivery_tag,
content_type=props.get('content_type'),
content_encoding=props.get('content_encoding'),
delivery_info=msg.delivery_info,
properties=msg.properties,
headers=props.get('application_headers') or {},
**kwargs)
class Channel(amqp.Channel, base.StdChannel):
"""AMQP Channel."""
Message = Message
def prepare_message(self, body, priority=None,
content_type=None, content_encoding=None,
headers=None, properties=None, _Message=amqp.Message):
"""Prepare message so that it can be sent using this transport."""
return _Message(
body,
priority=priority,
content_type=content_type,
content_encoding=content_encoding,
application_headers=headers,
**properties or {}
)
def prepare_queue_arguments(self, arguments, **kwargs):
return to_rabbitmq_queue_arguments(arguments, **kwargs)
def message_to_python(self, raw_message):
"""Convert encoded message body back to a Python value."""
return self.Message(raw_message, channel=self)
class Connection(amqp.Connection):
"""AMQP Connection."""
Channel = Channel
class Transport(base.Transport):
"""AMQP Transport."""
Connection = Connection
default_port = DEFAULT_PORT
default_ssl_port = DEFAULT_SSL_PORT
# it's very annoying that pyamqp sometimes raises AttributeError
# if the connection is lost, but nothing we can do about that here.
connection_errors = amqp.Connection.connection_errors
channel_errors = amqp.Connection.channel_errors
recoverable_connection_errors = \
amqp.Connection.recoverable_connection_errors
recoverable_channel_errors = amqp.Connection.recoverable_channel_errors
driver_name = 'py-amqp'
driver_type = 'amqp'
implements = base.Transport.implements.extend(
asynchronous=True,
heartbeats=True,
)
def __init__(self, client,
default_port=None, default_ssl_port=None, **kwargs):
self.client = client
self.default_port = default_port or self.default_port
self.default_ssl_port = default_ssl_port or self.default_ssl_port
def driver_version(self):
return amqp.__version__
def create_channel(self, connection):
return connection.channel()
def drain_events(self, connection, **kwargs):
return connection.drain_events(**kwargs)
def _collect(self, connection):
if connection is not None:
connection.collect()
def establish_connection(self):
"""Establish connection to the AMQP broker."""
conninfo = self.client
for name, default_value in self.default_connection_params.items():
if not getattr(conninfo, name, None):
setattr(conninfo, name, default_value)
if conninfo.hostname == 'localhost':
conninfo.hostname = '127.0.0.1'
# when server_hostname is None, use hostname from URI.
if isinstance(conninfo.ssl, dict) and \
'server_hostname' in conninfo.ssl and \
conninfo.ssl['server_hostname'] is None:
conninfo.ssl['server_hostname'] = conninfo.hostname
opts = dict({
'host': conninfo.host,
'userid': conninfo.userid,
'password': conninfo.password,
'login_method': conninfo.login_method,
'virtual_host': conninfo.virtual_host,
'insist': conninfo.insist,
'ssl': conninfo.ssl,
'connect_timeout': conninfo.connect_timeout,
'heartbeat': conninfo.heartbeat,
}, **conninfo.transport_options or {})
conn = self.Connection(**opts)
conn.client = self.client
conn.connect()
return conn
def verify_connection(self, connection):
return connection.connected
def close_connection(self, connection):
"""Close the AMQP broker connection."""
connection.client = None
connection.close()
def get_heartbeat_interval(self, connection):
return connection.heartbeat
def register_with_event_loop(self, connection, loop):
connection.transport.raise_on_initial_eintr = True
loop.add_reader(connection.sock, self.on_readable, connection, loop)
def heartbeat_check(self, connection, rate=2):
return connection.heartbeat_tick(rate=rate)
def qos_semantics_matches_spec(self, connection):
props = connection.server_properties
if props.get('product') == 'RabbitMQ':
return version_string_as_tuple(props['version']) < (3, 3)
return True
@property
def default_connection_params(self):
return {
'userid': 'guest',
'password': 'guest',
'port': (self.default_ssl_port if self.client.ssl
else self.default_port),
'hostname': 'localhost',
'login_method': 'PLAIN',
}
def get_manager(self, *args, **kwargs):
return get_manager(self.client, *args, **kwargs)
class SSLTransport(Transport):
"""AMQP SSL Transport."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# ugh, not exactly pure, but hey, it's python.
if not self.client.ssl: # not dict or False
self.client.ssl = True

View File

@@ -0,0 +1,212 @@
"""Pyro transport module for kombu.
Pyro transport, and Kombu Broker daemon.
Requires the :mod:`Pyro4` library to be installed.
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: No
* Supports TTL: No
Connection String
=================
To use the Pyro transport with Kombu, use an url of the form:
.. code-block::
pyro://localhost/kombu.broker
The hostname is where the transport will be looking for a Pyro name server,
which is used in turn to locate the kombu.broker Pyro service.
This broker can be launched by simply executing this transport module directly,
with the command: ``python -m kombu.transport.pyro``
Transport Options
=================
"""
from __future__ import annotations
import sys
from queue import Empty, Queue
from kombu.exceptions import reraise
from kombu.log import get_logger
from kombu.utils.objects import cached_property
from . import virtual
try:
import Pyro4 as pyro
from Pyro4.errors import NamingError
from Pyro4.util import SerializerBase
except ImportError: # pragma: no cover
pyro = NamingError = SerializerBase = None
DEFAULT_PORT = 9090
E_NAMESERVER = """\
Unable to locate pyro nameserver on host {0.hostname}\
"""
E_LOOKUP = """\
Unable to lookup '{0.virtual_host}' in pyro nameserver on host {0.hostname}\
"""
logger = get_logger(__name__)
class Channel(virtual.Channel):
"""Pyro Channel."""
def close(self):
super().close()
if self.shared_queues:
self.shared_queues._pyroRelease()
def queues(self):
return self.shared_queues.get_queue_names()
def _new_queue(self, queue, **kwargs):
if queue not in self.queues():
self.shared_queues.new_queue(queue)
def _has_queue(self, queue, **kwargs):
return self.shared_queues.has_queue(queue)
def _get(self, queue, timeout=None):
queue = self._queue_for(queue)
return self.shared_queues.get(queue)
def _queue_for(self, queue):
if queue not in self.queues():
self.shared_queues.new_queue(queue)
return queue
def _put(self, queue, message, **kwargs):
queue = self._queue_for(queue)
self.shared_queues.put(queue, message)
def _size(self, queue):
return self.shared_queues.size(queue)
def _delete(self, queue, *args, **kwargs):
self.shared_queues.delete(queue)
def _purge(self, queue):
return self.shared_queues.purge(queue)
def after_reply_message_received(self, queue):
pass
@cached_property
def shared_queues(self):
return self.connection.shared_queues
class Transport(virtual.Transport):
"""Pyro Transport."""
Channel = Channel
#: memory backend state is global.
# TODO: To be checked whether state can be per-Transport
global_state = virtual.BrokerState()
default_port = DEFAULT_PORT
driver_type = driver_name = 'pyro'
def __init__(self, client, **kwargs):
super().__init__(client, **kwargs)
self.state = self.global_state
def _open(self):
logger.debug("trying Pyro nameserver to find the broker daemon")
conninfo = self.client
try:
nameserver = pyro.locateNS(host=conninfo.hostname,
port=self.default_port)
except NamingError:
reraise(NamingError, NamingError(E_NAMESERVER.format(conninfo)),
sys.exc_info()[2])
try:
# name of registered pyro object
uri = nameserver.lookup(conninfo.virtual_host)
return pyro.Proxy(uri)
except NamingError:
reraise(NamingError, NamingError(E_LOOKUP.format(conninfo)),
sys.exc_info()[2])
def driver_version(self):
return pyro.__version__
@cached_property
def shared_queues(self):
return self._open()
if pyro is not None:
SerializerBase.register_dict_to_class("queue.Empty",
lambda cls, data: Empty())
@pyro.expose
@pyro.behavior(instance_mode="single")
class KombuBroker:
"""Kombu Broker used by the Pyro transport.
You have to run this as a separate (Pyro) service.
"""
def __init__(self):
self.queues = {}
def get_queue_names(self):
return list(self.queues)
def new_queue(self, queue):
if queue in self.queues:
return # silently ignore the fact that queue already exists
self.queues[queue] = Queue()
def has_queue(self, queue):
return queue in self.queues
def get(self, queue):
return self.queues[queue].get(block=False)
def put(self, queue, message):
self.queues[queue].put(message)
def size(self, queue):
return self.queues[queue].qsize()
def delete(self, queue):
del self.queues[queue]
def purge(self, queue):
while True:
try:
self.queues[queue].get(blocking=False)
except Empty:
break
# launch a Kombu Broker daemon with the command:
# ``python -m kombu.transport.pyro``
if __name__ == "__main__":
print("Launching Broker for Kombu's Pyro transport.")
with pyro.Daemon() as daemon:
print("(Expecting a Pyro name server at {}:{})"
.format(pyro.config.NS_HOST, pyro.config.NS_PORT))
with pyro.locateNS() as ns:
print("You can connect with Kombu using the url "
"'pyro://{}/kombu.broker'".format(pyro.config.NS_HOST))
uri = daemon.register(KombuBroker)
ns.register("kombu.broker", uri)
daemon.requestLoop()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,256 @@
"""SQLAlchemy Transport module for kombu.
Kombu transport using SQL Database as the message store.
Features
========
* Type: Virtual
* Supports Direct: yes
* Supports Topic: yes
* Supports Fanout: no
* Supports Priority: no
* Supports TTL: no
Connection String
=================
.. code-block::
sqla+SQL_ALCHEMY_CONNECTION_STRING
sqlalchemy+SQL_ALCHEMY_CONNECTION_STRING
For details about ``SQL_ALCHEMY_CONNECTION_STRING`` see SQLAlchemy Engine Configuration documentation.
Examples
--------
.. code-block::
# PostgreSQL with default driver
sqla+postgresql://scott:tiger@localhost/mydatabase
# PostgreSQL with psycopg2 driver
sqla+postgresql+psycopg2://scott:tiger@localhost/mydatabase
# PostgreSQL with pg8000 driver
sqla+postgresql+pg8000://scott:tiger@localhost/mydatabase
# MySQL with default driver
sqla+mysql://scott:tiger@localhost/foo
# MySQL with mysqlclient driver (a maintained fork of MySQL-Python)
sqla+mysql+mysqldb://scott:tiger@localhost/foo
# MySQL with PyMySQL driver
sqla+mysql+pymysql://scott:tiger@localhost/foo
Transport Options
=================
* ``queue_tablename``: Name of table storing queues.
* ``message_tablename``: Name of table storing messages.
Moreover parameters of :func:`sqlalchemy.create_engine()` function can be passed as transport options.
"""
from __future__ import annotations
import threading
from json import dumps, loads
from queue import Empty
from sqlalchemy import create_engine, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from kombu.transport import virtual
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str
from .models import Message as MessageBase
from .models import ModelBase
from .models import Queue as QueueBase
from .models import class_registry, metadata
# SQLAlchemy overrides != False to have special meaning and pep8 complains
# flake8: noqa
VERSION = (1, 4, 1)
__version__ = '.'.join(map(str, VERSION))
_MUTEX = threading.RLock()
class Channel(virtual.Channel):
"""The channel class."""
_session = None
_engines = {} # engine cache
def __init__(self, connection, **kwargs):
self._configure_entity_tablenames(connection.client.transport_options)
super().__init__(connection, **kwargs)
def _configure_entity_tablenames(self, opts):
self.queue_tablename = opts.get('queue_tablename', 'kombu_queue')
self.message_tablename = opts.get('message_tablename', 'kombu_message')
#
# Define the model definitions. This registers the declarative
# classes with the active SQLAlchemy metadata object. This *must* be
# done prior to the ``create_engine`` call.
#
self.queue_cls and self.message_cls
def _engine_from_config(self):
conninfo = self.connection.client
transport_options = conninfo.transport_options.copy()
transport_options.pop('queue_tablename', None)
transport_options.pop('message_tablename', None)
transport_options.pop('callback', None)
transport_options.pop('errback', None)
transport_options.pop('max_retries', None)
transport_options.pop('interval_start', None)
transport_options.pop('interval_step', None)
transport_options.pop('interval_max', None)
transport_options.pop('retry_errors', None)
return create_engine(conninfo.hostname, **transport_options)
def _open(self):
conninfo = self.connection.client
if conninfo.hostname not in self._engines:
with _MUTEX:
if conninfo.hostname in self._engines:
# Engine was created while we were waiting to
# acquire the lock.
return self._engines[conninfo.hostname]
engine = self._engine_from_config()
Session = sessionmaker(bind=engine)
metadata.create_all(engine)
self._engines[conninfo.hostname] = engine, Session
return self._engines[conninfo.hostname]
@property
def session(self):
if self._session is None:
_, Session = self._open()
self._session = Session()
return self._session
def _get_or_create(self, queue):
obj = self.session.query(self.queue_cls) \
.filter(self.queue_cls.name == queue).first()
if not obj:
with _MUTEX:
obj = self.session.query(self.queue_cls) \
.filter(self.queue_cls.name == queue).first()
if obj:
# Queue was created while we were waiting to
# acquire the lock.
return obj
obj = self.queue_cls(queue)
self.session.add(obj)
try:
self.session.commit()
except OperationalError:
self.session.rollback()
return obj
def _new_queue(self, queue, **kwargs):
self._get_or_create(queue)
def _put(self, queue, payload, **kwargs):
obj = self._get_or_create(queue)
message = self.message_cls(dumps(payload), obj)
self.session.add(message)
try:
self.session.commit()
except OperationalError:
self.session.rollback()
def _get(self, queue):
obj = self._get_or_create(queue)
if self.session.bind.name == 'sqlite':
self.session.execute(text('BEGIN IMMEDIATE TRANSACTION'))
try:
msg = self.session.query(self.message_cls) \
.with_for_update() \
.filter(self.message_cls.queue_id == obj.id) \
.filter(self.message_cls.visible != False) \
.order_by(self.message_cls.sent_at) \
.order_by(self.message_cls.id) \
.limit(1) \
.first()
if msg:
msg.visible = False
return loads(bytes_to_str(msg.payload))
raise Empty()
finally:
self.session.commit()
def _query_all(self, queue):
obj = self._get_or_create(queue)
return self.session.query(self.message_cls) \
.filter(self.message_cls.queue_id == obj.id)
def _purge(self, queue):
count = self._query_all(queue).delete(synchronize_session=False)
try:
self.session.commit()
except OperationalError:
self.session.rollback()
return count
def _size(self, queue):
return self._query_all(queue).count()
def _declarative_cls(self, name, base, ns):
if name not in class_registry:
with _MUTEX:
if name in class_registry:
# Class was registered while we were waiting to
# acquire the lock.
return class_registry[name]
return type(str(name), (base, ModelBase), ns)
return class_registry[name]
@cached_property
def queue_cls(self):
return self._declarative_cls(
'Queue',
QueueBase,
{'__tablename__': self.queue_tablename}
)
@cached_property
def message_cls(self):
return self._declarative_cls(
'Message',
MessageBase,
{'__tablename__': self.message_tablename}
)
class Transport(virtual.Transport):
"""The transport class."""
Channel = Channel
can_parse_url = True
default_port = 0
driver_type = 'sql'
driver_name = 'sqlalchemy'
connection_errors = (OperationalError, )
def driver_version(self):
import sqlalchemy
return sqlalchemy.__version__

View File

@@ -0,0 +1,76 @@
"""Kombu transport using SQLAlchemy as the message store."""
from __future__ import annotations
import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer,
Sequence, SmallInteger, String, Text)
from sqlalchemy.orm import relationship
from sqlalchemy.schema import MetaData
try:
from sqlalchemy.orm import declarative_base, declared_attr
except ImportError:
# TODO: Remove this once we drop support for SQLAlchemy < 1.4.
from sqlalchemy.ext.declarative import declarative_base, declared_attr
class_registry = {}
metadata = MetaData()
ModelBase = declarative_base(metadata=metadata, class_registry=class_registry)
class Queue:
"""The queue class."""
__table_args__ = {'sqlite_autoincrement': True, 'mysql_engine': 'InnoDB'}
id = Column(Integer, Sequence('queue_id_sequence'), primary_key=True,
autoincrement=True)
name = Column(String(200), unique=True)
def __init__(self, name):
self.name = name
def __str__(self):
return f'<Queue({self.name})>'
@declared_attr
def messages(cls):
return relationship('Message', backref='queue', lazy='noload')
class Message:
"""The message class."""
__table_args__ = (
Index('ix_kombu_message_timestamp_id', 'timestamp', 'id'),
{'sqlite_autoincrement': True, 'mysql_engine': 'InnoDB'}
)
id = Column(Integer, Sequence('message_id_sequence'),
primary_key=True, autoincrement=True)
visible = Column(Boolean, default=True, index=True)
sent_at = Column('timestamp', DateTime, nullable=True, index=True,
onupdate=datetime.datetime.now)
payload = Column(Text, nullable=False)
version = Column(SmallInteger, nullable=False, default=1)
__mapper_args__ = {'version_id_col': version}
def __init__(self, payload, queue):
self.payload = payload
self.queue = queue
def __str__(self):
return '<Message: {0.sent_at} {0.payload} {0.queue_id}>'.format(self)
@declared_attr
def queue_id(self):
return Column(
Integer,
ForeignKey(
'%s.id' % class_registry['Queue'].__tablename__,
name='FK_kombu_message_queue'
)
)

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from .base import (AbstractChannel, Base64, BrokerState, Channel, Empty,
Management, Message, NotEquivalentError, QoS, Transport,
UndeliverableWarning, binding_key_t, queue_binding_t)
__all__ = (
'Base64', 'NotEquivalentError', 'UndeliverableWarning', 'BrokerState',
'QoS', 'Message', 'AbstractChannel', 'Channel', 'Management', 'Transport',
'Empty', 'binding_key_t', 'queue_binding_t',
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,164 @@
"""Virtual AMQ Exchange.
Implementations of the standard exchanges defined
by the AMQ protocol (excluding the `headers` exchange).
"""
from __future__ import annotations
import re
from kombu.utils.text import escape_regex
class ExchangeType:
"""Base class for exchanges.
Implements the specifics for an exchange type.
Arguments:
---------
channel (ChannelT): AMQ Channel.
"""
type = None
def __init__(self, channel):
self.channel = channel
def lookup(self, table, exchange, routing_key, default):
"""Lookup all queues matching `routing_key` in `exchange`.
Returns
-------
str: queue name, or 'default' if no queues matched.
"""
raise NotImplementedError('subclass responsibility')
def prepare_bind(self, queue, exchange, routing_key, arguments):
"""Prepare queue-binding.
Returns
-------
Tuple[str, Pattern, str]: of `(routing_key, regex, queue)`
to be stored for bindings to this exchange.
"""
return routing_key, None, queue
def equivalent(self, prev, exchange, type,
durable, auto_delete, arguments):
"""Return true if `prev` and `exchange` is equivalent."""
return (type == prev['type'] and
durable == prev['durable'] and
auto_delete == prev['auto_delete'] and
(arguments or {}) == (prev['arguments'] or {}))
class DirectExchange(ExchangeType):
"""Direct exchange.
The `direct` exchange routes based on exact routing keys.
"""
type = 'direct'
def lookup(self, table, exchange, routing_key, default):
return {
queue for rkey, _, queue in table
if rkey == routing_key
}
def deliver(self, message, exchange, routing_key, **kwargs):
_lookup = self.channel._lookup
_put = self.channel._put
for queue in _lookup(exchange, routing_key):
_put(queue, message, **kwargs)
class TopicExchange(ExchangeType):
"""Topic exchange.
The `topic` exchange routes messages based on words separated by
dots, using wildcard characters ``*`` (any single word), and ``#``
(one or more words).
"""
type = 'topic'
#: map of wildcard to regex conversions
wildcards = {'*': r'.*?[^\.]',
'#': r'.*?'}
#: compiled regex cache
_compiled = {}
def lookup(self, table, exchange, routing_key, default):
return {
queue for rkey, pattern, queue in table
if self._match(pattern, routing_key)
}
def deliver(self, message, exchange, routing_key, **kwargs):
_lookup = self.channel._lookup
_put = self.channel._put
deadletter = self.channel.deadletter_queue
for queue in [q for q in _lookup(exchange, routing_key)
if q and q != deadletter]:
_put(queue, message, **kwargs)
def prepare_bind(self, queue, exchange, routing_key, arguments):
return routing_key, self.key_to_pattern(routing_key), queue
def key_to_pattern(self, rkey):
"""Get the corresponding regex for any routing key."""
return '^%s$' % (r'\.'.join(
self.wildcards.get(word, word)
for word in escape_regex(rkey, '.#*').split('.')
))
def _match(self, pattern, string):
"""Match regular expression (cached).
Same as :func:`re.match`, except the regex is compiled and cached,
then reused on subsequent matches with the same pattern.
"""
try:
compiled = self._compiled[pattern]
except KeyError:
compiled = self._compiled[pattern] = re.compile(pattern, re.U)
return compiled.match(string)
class FanoutExchange(ExchangeType):
"""Fanout exchange.
The `fanout` exchange implements broadcast messaging by delivering
copies of all messages to all queues bound to the exchange.
To support fanout the virtual channel needs to store the table
as shared state. This requires that the `Channel.supports_fanout`
attribute is set to true, and the `Channel._queue_bind` and
`Channel.get_table` methods are implemented.
See Also
--------
the redis backend for an example implementation of these methods.
"""
type = 'fanout'
def lookup(self, table, exchange, routing_key, default):
return {queue for _, _, queue in table}
def deliver(self, message, exchange, routing_key, **kwargs):
if self.channel.supports_fanout:
self.channel._put_fanout(
exchange, message, routing_key, **kwargs)
#: Map of standard exchange types and corresponding classes.
STANDARD_EXCHANGE_TYPES = {
'direct': DirectExchange,
'topic': TopicExchange,
'fanout': FanoutExchange,
}

View File

@@ -0,0 +1,223 @@
# copyright: (c) 2010 - 2013 by Mahendra M.
# license: BSD, see LICENSE for more details.
"""Zookeeper transport module for kombu.
Zookeeper based transport. This transport uses the built-in kazoo Zookeeper
based queue implementation.
**References**
- https://zookeeper.apache.org/doc/current/recipes.html#sc_recipes_Queues
- https://kazoo.readthedocs.io/en/latest/api/recipe/queue.html
**Limitations**
This queue does not offer reliable consumption. An entry is removed from
the queue prior to being processed. So if an error occurs, the consumer
has to re-queue the item or it will be lost.
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: Yes
* Supports TTL: No
Connection String
=================
Connects to a zookeeper node as:
.. code-block::
zookeeper://SERVER:PORT/VHOST
The <vhost> becomes the base for all the other znodes. So we can use
it like a vhost.
Transport Options
=================
"""
from __future__ import annotations
import os
import socket
from queue import Empty
from kombu.utils.encoding import bytes_to_str, ensure_bytes
from kombu.utils.json import dumps, loads
from . import virtual
try:
import kazoo
from kazoo.client import KazooClient
from kazoo.recipe.queue import Queue
KZ_CONNECTION_ERRORS = (
kazoo.exceptions.SystemErrorException,
kazoo.exceptions.ConnectionLossException,
kazoo.exceptions.MarshallingErrorException,
kazoo.exceptions.UnimplementedException,
kazoo.exceptions.OperationTimeoutException,
kazoo.exceptions.NoAuthException,
kazoo.exceptions.InvalidACLException,
kazoo.exceptions.AuthFailedException,
kazoo.exceptions.SessionExpiredException,
)
KZ_CHANNEL_ERRORS = (
kazoo.exceptions.RuntimeInconsistencyException,
kazoo.exceptions.DataInconsistencyException,
kazoo.exceptions.BadArgumentsException,
kazoo.exceptions.MarshallingErrorException,
kazoo.exceptions.UnimplementedException,
kazoo.exceptions.OperationTimeoutException,
kazoo.exceptions.ApiErrorException,
kazoo.exceptions.NoNodeException,
kazoo.exceptions.NoAuthException,
kazoo.exceptions.NodeExistsException,
kazoo.exceptions.NoChildrenForEphemeralsException,
kazoo.exceptions.NotEmptyException,
kazoo.exceptions.SessionExpiredException,
kazoo.exceptions.InvalidCallbackException,
socket.error,
)
except ImportError:
kazoo = None
KZ_CONNECTION_ERRORS = KZ_CHANNEL_ERRORS = ()
DEFAULT_PORT = 2181
__author__ = 'Mahendra M <mahendra.m@gmail.com>'
class Channel(virtual.Channel):
"""Zookeeper Channel."""
_client = None
_queues = {}
def __init__(self, connection, **kwargs):
super().__init__(connection, **kwargs)
vhost = self.connection.client.virtual_host
self._vhost = '/{}'.format(vhost.strip('/'))
def _get_path(self, queue_name):
return os.path.join(self._vhost, queue_name)
def _get_queue(self, queue_name):
queue = self._queues.get(queue_name, None)
if queue is None:
queue = Queue(self.client, self._get_path(queue_name))
self._queues[queue_name] = queue
# Ensure that the queue is created
len(queue)
return queue
def _put(self, queue, message, **kwargs):
return self._get_queue(queue).put(
ensure_bytes(dumps(message)),
priority=self._get_message_priority(message, reverse=True),
)
def _get(self, queue):
queue = self._get_queue(queue)
msg = queue.get()
if msg is None:
raise Empty()
return loads(bytes_to_str(msg))
def _purge(self, queue):
count = 0
queue = self._get_queue(queue)
while True:
msg = queue.get()
if msg is None:
break
count += 1
return count
def _delete(self, queue, *args, **kwargs):
if self._has_queue(queue):
self._purge(queue)
self.client.delete(self._get_path(queue))
def _size(self, queue):
queue = self._get_queue(queue)
return len(queue)
def _new_queue(self, queue, **kwargs):
if not self._has_queue(queue):
queue = self._get_queue(queue)
def _has_queue(self, queue):
return self.client.exists(self._get_path(queue)) is not None
def _open(self):
conninfo = self.connection.client
hosts = []
if conninfo.alt:
for host_port in conninfo.alt:
if host_port.startswith('zookeeper://'):
host_port = host_port[len('zookeeper://'):]
if not host_port:
continue
try:
host, port = host_port.split(':', 1)
host_port = (host, int(port))
except ValueError:
if host_port == conninfo.hostname:
host_port = (host_port, conninfo.port or DEFAULT_PORT)
else:
host_port = (host_port, DEFAULT_PORT)
hosts.append(host_port)
host_port = (conninfo.hostname, conninfo.port or DEFAULT_PORT)
if host_port not in hosts:
hosts.insert(0, host_port)
conn_str = ','.join([f'{h}:{p}' for h, p in hosts])
conn = KazooClient(conn_str)
conn.start()
return conn
@property
def client(self):
if self._client is None:
self._client = self._open()
return self._client
class Transport(virtual.Transport):
"""Zookeeper Transport."""
Channel = Channel
polling_interval = 1
default_port = DEFAULT_PORT
connection_errors = (
virtual.Transport.connection_errors + KZ_CONNECTION_ERRORS
)
channel_errors = (
virtual.Transport.channel_errors + KZ_CHANNEL_ERRORS
)
driver_type = 'zookeeper'
driver_name = 'kazoo'
def __init__(self, *args, **kwargs):
if kazoo is None:
raise ImportError('The kazoo library is not installed')
super().__init__(*args, **kwargs)
def driver_version(self):
return kazoo.__version__