Updates
This commit is contained in:
@@ -0,0 +1,202 @@
|
||||
"""SoftLayer Message Queue transport module for kombu.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: No
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
*Unreviewed*
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
*Unreviewed*
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import socket
|
||||
import string
|
||||
from queue import Empty
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str, safe_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
from softlayer_messaging import get_client
|
||||
from softlayer_messaging.errors import ResponseError
|
||||
except ImportError: # pragma: no cover
|
||||
get_client = ResponseError = None
|
||||
|
||||
# dots are replaced by dash, all other punctuation replaced by underscore.
|
||||
CHARS_REPLACE_TABLE = {
|
||||
ord(c): 0x5f for c in string.punctuation if c not in '_'
|
||||
}
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""SLMQ Channel."""
|
||||
|
||||
default_visibility_timeout = 1800 # 30 minutes.
|
||||
domain_format = 'kombu%(vhost)s'
|
||||
_slmq = None
|
||||
_queue_cache = {}
|
||||
_noack_queues = set()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if get_client is None:
|
||||
raise ImportError(
|
||||
'SLMQ transport requires the softlayer_messaging library',
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
queues = self.slmq.queues()
|
||||
for queue in queues:
|
||||
self._queue_cache[queue] = queue
|
||||
|
||||
def basic_consume(self, queue, no_ack, *args, **kwargs):
|
||||
if no_ack:
|
||||
self._noack_queues.add(queue)
|
||||
return super().basic_consume(queue, no_ack,
|
||||
*args, **kwargs)
|
||||
|
||||
def basic_cancel(self, consumer_tag):
|
||||
if consumer_tag in self._consumers:
|
||||
queue = self._tag_to_queue[consumer_tag]
|
||||
self._noack_queues.discard(queue)
|
||||
return super().basic_cancel(consumer_tag)
|
||||
|
||||
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
|
||||
"""Format AMQP queue name into a valid SLQS queue name."""
|
||||
return str(safe_str(name)).translate(table)
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
"""Ensure a queue exists in SLQS."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
try:
|
||||
return self._queue_cache[queue]
|
||||
except KeyError:
|
||||
try:
|
||||
self.slmq.create_queue(
|
||||
queue, visibility_timeout=self.visibility_timeout)
|
||||
except ResponseError:
|
||||
pass
|
||||
q = self._queue_cache[queue] = self.slmq.queue(queue)
|
||||
return q
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
"""Delete queue by name."""
|
||||
queue_name = self.entity_name(queue)
|
||||
self._queue_cache.pop(queue_name, None)
|
||||
self.slmq.queue(queue_name).delete(force=True)
|
||||
super()._delete(queue_name)
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
"""Put message onto queue."""
|
||||
q = self._new_queue(queue)
|
||||
q.push(dumps(message))
|
||||
|
||||
def _get(self, queue):
|
||||
"""Try to retrieve a single message off ``queue``."""
|
||||
q = self._new_queue(queue)
|
||||
rs = q.pop(1)
|
||||
if rs['items']:
|
||||
m = rs['items'][0]
|
||||
payload = loads(bytes_to_str(m['body']))
|
||||
if queue in self._noack_queues:
|
||||
q.message(m['id']).delete()
|
||||
else:
|
||||
payload['properties']['delivery_info'].update({
|
||||
'slmq_message_id': m['id'], 'slmq_queue_name': q.name})
|
||||
return payload
|
||||
raise Empty()
|
||||
|
||||
def basic_ack(self, delivery_tag):
|
||||
delivery_info = self.qos.get(delivery_tag).delivery_info
|
||||
try:
|
||||
queue = delivery_info['slmq_queue_name']
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self.delete_message(queue, delivery_info['slmq_message_id'])
|
||||
super().basic_ack(delivery_tag)
|
||||
|
||||
def _size(self, queue):
|
||||
"""Return the number of messages in a queue."""
|
||||
return self._new_queue(queue).detail()['message_count']
|
||||
|
||||
def _purge(self, queue):
|
||||
"""Delete all current messages in a queue."""
|
||||
q = self._new_queue(queue)
|
||||
n = 0
|
||||
results = q.pop(10)
|
||||
while results['items']:
|
||||
for m in results['items']:
|
||||
self.delete_message(queue, m['id'])
|
||||
n += 1
|
||||
results = q.pop(10)
|
||||
return n
|
||||
|
||||
def delete_message(self, queue, message_id):
|
||||
q = self.slmq.queue(self.entity_name(queue))
|
||||
return q.message(message_id).delete()
|
||||
|
||||
@property
|
||||
def slmq(self):
|
||||
if self._slmq is None:
|
||||
conninfo = self.conninfo
|
||||
account = os.environ.get('SLMQ_ACCOUNT', conninfo.virtual_host)
|
||||
user = os.environ.get('SL_USERNAME', conninfo.userid)
|
||||
api_key = os.environ.get('SL_API_KEY', conninfo.password)
|
||||
host = os.environ.get('SLMQ_HOST', conninfo.hostname)
|
||||
port = os.environ.get('SLMQ_PORT', conninfo.port)
|
||||
secure = bool(os.environ.get(
|
||||
'SLMQ_SECURE', self.transport_options.get('secure')) or True,
|
||||
)
|
||||
endpoint = '{}://{}{}'.format(
|
||||
'https' if secure else 'http', host,
|
||||
f':{port}' if port else '',
|
||||
)
|
||||
|
||||
self._slmq = get_client(account, endpoint=endpoint)
|
||||
self._slmq.authenticate(user, api_key)
|
||||
return self._slmq
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
||||
@property
|
||||
def transport_options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def visibility_timeout(self):
|
||||
return (self.transport_options.get('visibility_timeout') or
|
||||
self.default_visibility_timeout)
|
||||
|
||||
@cached_property
|
||||
def queue_name_prefix(self):
|
||||
return self.transport_options.get('queue_name_prefix', '')
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""SLMQ Transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
polling_interval = 1
|
||||
default_port = None
|
||||
connection_errors = (
|
||||
virtual.Transport.connection_errors + (
|
||||
ResponseError, socket.error
|
||||
)
|
||||
)
|
||||
973
ETB-API/venv/lib/python3.12/site-packages/kombu/transport/SQS.py
Normal file
973
ETB-API/venv/lib/python3.12/site-packages/kombu/transport/SQS.py
Normal file
@@ -0,0 +1,973 @@
|
||||
"""Amazon SQS transport module for Kombu.
|
||||
|
||||
This package implements an AMQP-like interface on top of Amazons SQS service,
|
||||
with the goal of being optimized for high performance and reliability.
|
||||
|
||||
The default settings for this module are focused now on high performance in
|
||||
task queue situations where tasks are small, idempotent and run very fast.
|
||||
|
||||
SQS Features supported by this transport
|
||||
========================================
|
||||
Long Polling
|
||||
------------
|
||||
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-long-polling.html
|
||||
|
||||
Long polling is enabled by setting the `wait_time_seconds` transport
|
||||
option to a number > 1. Amazon supports up to 20 seconds. This is
|
||||
enabled with 10 seconds by default.
|
||||
|
||||
Batch API Actions
|
||||
-----------------
|
||||
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-batch-api.html
|
||||
|
||||
The default behavior of the SQS Channel.drain_events() method is to
|
||||
request up to the 'prefetch_count' messages on every request to SQS.
|
||||
These messages are stored locally in a deque object and passed back
|
||||
to the Transport until the deque is empty, before triggering a new
|
||||
API call to Amazon.
|
||||
|
||||
This behavior dramatically speeds up the rate that you can pull tasks
|
||||
from SQS when you have short-running tasks (or a large number of workers).
|
||||
|
||||
When a Celery worker has multiple queues to monitor, it will pull down
|
||||
up to 'prefetch_count' messages from queueA and work on them all before
|
||||
moving on to queueB. If queueB is empty, it will wait up until
|
||||
'polling_interval' expires before moving back and checking on queueA.
|
||||
|
||||
Message Attributes
|
||||
-----------------
|
||||
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html
|
||||
|
||||
SQS supports sending message attributes along with the message body.
|
||||
To use this feature, you can pass a 'message_attributes' as keyword argument
|
||||
to `basic_publish` method.
|
||||
|
||||
Other Features supported by this transport
|
||||
==========================================
|
||||
Predefined Queues
|
||||
-----------------
|
||||
The default behavior of this transport is to use a single AWS credential
|
||||
pair in order to manage all SQS queues (e.g. listing queues, creating
|
||||
queues, polling queues, deleting messages).
|
||||
|
||||
If it is preferable for your environment to use multiple AWS credentials, you
|
||||
can use the 'predefined_queues' setting inside the 'transport_options' map.
|
||||
This setting allows you to specify the SQS queue URL and AWS credentials for
|
||||
each of your queues. For example, if you have two queues which both already
|
||||
exist in AWS) you can tell this transport about them as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
transport_options = {
|
||||
'predefined_queues': {
|
||||
'queue-1': {
|
||||
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/aaa',
|
||||
'access_key_id': 'a',
|
||||
'secret_access_key': 'b',
|
||||
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
|
||||
'backoff_tasks': ['svc.tasks.tasks.task1'] # optional
|
||||
},
|
||||
'queue-2.fifo': {
|
||||
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb.fifo',
|
||||
'access_key_id': 'c',
|
||||
'secret_access_key': 'd',
|
||||
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
|
||||
'backoff_tasks': ['svc.tasks.tasks.task2'] # optional
|
||||
},
|
||||
}
|
||||
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
|
||||
'sts_token_timeout': 900 # optional
|
||||
}
|
||||
|
||||
Note that FIFO and standard queues must be named accordingly (the name of
|
||||
a FIFO queue must end with the .fifo suffix).
|
||||
|
||||
backoff_policy & backoff_tasks are optional arguments. These arguments
|
||||
automatically change the message visibility timeout, in order to have
|
||||
different times between specific task retries. This would apply after
|
||||
task failure.
|
||||
|
||||
AWS STS authentication is supported, by using sts_role_arn, and
|
||||
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
|
||||
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
|
||||
to 900 seconds. After the mentioned period, a new token will be created.
|
||||
|
||||
|
||||
|
||||
If you authenticate using Okta_ (e.g. calling |gac|_), you can also specify
|
||||
a 'session_token' to connect to a queue. Note that those tokens have a
|
||||
limited lifetime and are therefore only suited for short-lived tests.
|
||||
|
||||
.. _Okta: https://www.okta.com/
|
||||
.. _gac: https://github.com/Nike-Inc/gimme-aws-creds#readme
|
||||
.. |gac| replace:: ``gimme-aws-creds``
|
||||
|
||||
|
||||
Client config
|
||||
-------------
|
||||
In some cases you may need to override the botocore config. You can do it
|
||||
as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
transport_option = {
|
||||
'client-config': {
|
||||
'connect_timeout': 5,
|
||||
},
|
||||
}
|
||||
|
||||
For a complete list of settings you can adjust using this option see
|
||||
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: Yes
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
"""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import socket
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from queue import Empty
|
||||
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import ClientError
|
||||
from vine import ensure_promise, promise, transform
|
||||
|
||||
from kombu.asynchronous import get_event_loop
|
||||
from kombu.asynchronous.aws.ext import boto3, exceptions
|
||||
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
|
||||
from kombu.asynchronous.aws.sqs.message import AsyncMessage
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils import scheduling
|
||||
from kombu.utils.encoding import bytes_to_str, safe_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# dots are replaced by dash, dash remains dash, all other punctuation
|
||||
# replaced by underscore.
|
||||
CHARS_REPLACE_TABLE = {
|
||||
ord(c): 0x5f for c in string.punctuation if c not in '-_.'
|
||||
}
|
||||
CHARS_REPLACE_TABLE[0x2e] = 0x2d # '.' -> '-'
|
||||
|
||||
#: SQS bulk get supports a maximum of 10 messages at a time.
|
||||
SQS_MAX_MESSAGES = 10
|
||||
|
||||
|
||||
def maybe_int(x):
|
||||
"""Try to convert x' to int, or return x' if that fails."""
|
||||
try:
|
||||
return int(x)
|
||||
except ValueError:
|
||||
return x
|
||||
|
||||
|
||||
class UndefinedQueueException(Exception):
|
||||
"""Predefined queues are being used and an undefined queue was used."""
|
||||
|
||||
|
||||
class InvalidQueueException(Exception):
|
||||
"""Predefined queues are being used and configuration is not valid."""
|
||||
|
||||
|
||||
class AccessDeniedQueueException(Exception):
|
||||
"""Raised when access to the AWS queue is denied.
|
||||
|
||||
This may occur if the permissions are not correctly set or the
|
||||
credentials are invalid.
|
||||
"""
|
||||
|
||||
|
||||
class DoesNotExistQueueException(Exception):
|
||||
"""The specified queue doesn't exist."""
|
||||
|
||||
|
||||
class QoS(virtual.QoS):
|
||||
"""Quality of Service guarantees implementation for SQS."""
|
||||
|
||||
def reject(self, delivery_tag, requeue=False):
|
||||
super().reject(delivery_tag, requeue=requeue)
|
||||
routing_key, message, backoff_tasks, backoff_policy = \
|
||||
self._extract_backoff_policy_configuration_and_message(
|
||||
delivery_tag)
|
||||
if routing_key and message and backoff_tasks and backoff_policy:
|
||||
self.apply_backoff_policy(
|
||||
routing_key, delivery_tag, backoff_policy, backoff_tasks)
|
||||
|
||||
def _extract_backoff_policy_configuration_and_message(self, delivery_tag):
|
||||
try:
|
||||
message = self._delivered[delivery_tag]
|
||||
routing_key = message.delivery_info['routing_key']
|
||||
except KeyError:
|
||||
return None, None, None, None
|
||||
if not routing_key or not message:
|
||||
return None, None, None, None
|
||||
queue_config = self.channel.predefined_queues.get(routing_key, {})
|
||||
backoff_tasks = queue_config.get('backoff_tasks')
|
||||
backoff_policy = queue_config.get('backoff_policy')
|
||||
return routing_key, message, backoff_tasks, backoff_policy
|
||||
|
||||
def apply_backoff_policy(self, routing_key, delivery_tag,
|
||||
backoff_policy, backoff_tasks):
|
||||
queue_url = self.channel._queue_cache[routing_key]
|
||||
task_name, number_of_retries = \
|
||||
self.extract_task_name_and_number_of_retries(delivery_tag)
|
||||
if not task_name or not number_of_retries:
|
||||
return None
|
||||
policy_value = backoff_policy.get(number_of_retries)
|
||||
if task_name in backoff_tasks and policy_value is not None:
|
||||
c = self.channel.sqs(routing_key)
|
||||
c.change_message_visibility(
|
||||
QueueUrl=queue_url,
|
||||
ReceiptHandle=delivery_tag,
|
||||
VisibilityTimeout=policy_value
|
||||
)
|
||||
|
||||
def extract_task_name_and_number_of_retries(self, delivery_tag):
|
||||
message = self._delivered[delivery_tag]
|
||||
message_headers = message.headers
|
||||
task_name = message_headers['task']
|
||||
number_of_retries = int(
|
||||
message.properties['delivery_info']['sqs_message']
|
||||
['Attributes']['ApproximateReceiveCount'])
|
||||
return task_name, number_of_retries
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""SQS Channel."""
|
||||
|
||||
default_region = 'us-east-1'
|
||||
default_visibility_timeout = 1800 # 30 minutes.
|
||||
default_wait_time_seconds = 10 # up to 20 seconds max
|
||||
domain_format = 'kombu%(vhost)s'
|
||||
_asynsqs = None
|
||||
_predefined_queue_async_clients = {} # A client for each predefined queue
|
||||
_sqs = None
|
||||
_predefined_queue_clients = {} # A client for each predefined queue
|
||||
_queue_cache = {} # SQS queue name => SQS queue URL
|
||||
_noack_queues = set()
|
||||
QoS = QoS
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if boto3 is None:
|
||||
raise ImportError('boto3 is not installed')
|
||||
super().__init__(*args, **kwargs)
|
||||
self._validate_predifined_queues()
|
||||
|
||||
# SQS blows up if you try to create a new queue when one already
|
||||
# exists but with a different visibility_timeout. This prepopulates
|
||||
# the queue_cache to protect us from recreating
|
||||
# queues that are known to already exist.
|
||||
self._update_queue_cache(self.queue_name_prefix)
|
||||
|
||||
self.hub = kwargs.get('hub') or get_event_loop()
|
||||
|
||||
def _validate_predifined_queues(self):
|
||||
"""Check that standard and FIFO queues are named properly.
|
||||
|
||||
AWS requires FIFO queues to have a name
|
||||
that ends with the .fifo suffix.
|
||||
"""
|
||||
for queue_name, q in self.predefined_queues.items():
|
||||
fifo_url = q['url'].endswith('.fifo')
|
||||
fifo_name = queue_name.endswith('.fifo')
|
||||
if fifo_url and not fifo_name:
|
||||
raise InvalidQueueException(
|
||||
"Queue with url '{}' must have a name "
|
||||
"ending with .fifo".format(q['url'])
|
||||
)
|
||||
elif not fifo_url and fifo_name:
|
||||
raise InvalidQueueException(
|
||||
"Queue with name '{}' is not a FIFO queue: "
|
||||
"'{}'".format(queue_name, q['url'])
|
||||
)
|
||||
|
||||
def _update_queue_cache(self, queue_name_prefix):
|
||||
if self.predefined_queues:
|
||||
for queue_name, q in self.predefined_queues.items():
|
||||
self._queue_cache[queue_name] = q['url']
|
||||
return
|
||||
|
||||
resp = self.sqs().list_queues(QueueNamePrefix=queue_name_prefix)
|
||||
for url in resp.get('QueueUrls', []):
|
||||
queue_name = url.split('/')[-1]
|
||||
self._queue_cache[queue_name] = url
|
||||
|
||||
def basic_consume(self, queue, no_ack, *args, **kwargs):
|
||||
if no_ack:
|
||||
self._noack_queues.add(queue)
|
||||
if self.hub:
|
||||
self._loop1(queue)
|
||||
return super().basic_consume(
|
||||
queue, no_ack, *args, **kwargs
|
||||
)
|
||||
|
||||
def basic_cancel(self, consumer_tag):
|
||||
if consumer_tag in self._consumers:
|
||||
queue = self._tag_to_queue[consumer_tag]
|
||||
self._noack_queues.discard(queue)
|
||||
return super().basic_cancel(consumer_tag)
|
||||
|
||||
def drain_events(self, timeout=None, callback=None, **kwargs):
|
||||
"""Return a single payload message from one of our queues.
|
||||
|
||||
Raises
|
||||
------
|
||||
Queue.Empty: if no messages available.
|
||||
"""
|
||||
# If we're not allowed to consume or have no consumers, raise Empty
|
||||
if not self._consumers or not self.qos.can_consume():
|
||||
raise Empty()
|
||||
|
||||
# At this point, go and get more messages from SQS
|
||||
self._poll(self.cycle, callback, timeout=timeout)
|
||||
|
||||
def _reset_cycle(self):
|
||||
"""Reset the consume cycle.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FairCycle: object that points to our _get_bulk() method
|
||||
rather than the standard _get() method. This allows for
|
||||
multiple messages to be returned at once from SQS (
|
||||
based on the prefetch limit).
|
||||
"""
|
||||
self._cycle = scheduling.FairCycle(
|
||||
self._get_bulk, self._active_queues, Empty,
|
||||
)
|
||||
|
||||
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
|
||||
"""Format AMQP queue name into a legal SQS queue name."""
|
||||
if name.endswith('.fifo'):
|
||||
partial = name[:-len('.fifo')]
|
||||
partial = str(safe_str(partial)).translate(table)
|
||||
return partial + '.fifo'
|
||||
else:
|
||||
return str(safe_str(name)).translate(table)
|
||||
|
||||
def canonical_queue_name(self, queue_name):
|
||||
return self.entity_name(self.queue_name_prefix + queue_name)
|
||||
|
||||
def _resolve_queue_url(self, queue):
|
||||
"""Try to retrieve the SQS queue URL for a given queue name."""
|
||||
# Translate to SQS name for consistency with initial
|
||||
# _queue_cache population.
|
||||
sqs_qname = self.canonical_queue_name(queue)
|
||||
|
||||
# The SQS ListQueues method only returns 1000 queues. When you have
|
||||
# so many queues, it's possible that the queue you are looking for is
|
||||
# not cached. In this case, we could update the cache with the exact
|
||||
# queue name first.
|
||||
if sqs_qname not in self._queue_cache:
|
||||
self._update_queue_cache(sqs_qname)
|
||||
try:
|
||||
return self._queue_cache[sqs_qname]
|
||||
except KeyError:
|
||||
if self.predefined_queues:
|
||||
raise UndefinedQueueException((
|
||||
"Queue with name '{}' must be "
|
||||
"defined in 'predefined_queues'."
|
||||
).format(sqs_qname))
|
||||
|
||||
raise DoesNotExistQueueException(
|
||||
f"Queue with name '{sqs_qname}' doesn't exist in SQS"
|
||||
)
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
"""Ensure a queue with given name exists in SQS.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): the AMQP queue name
|
||||
Returns
|
||||
str: the SQS queue URL
|
||||
"""
|
||||
try:
|
||||
return self._resolve_queue_url(queue)
|
||||
except DoesNotExistQueueException:
|
||||
sqs_qname = self.canonical_queue_name(queue)
|
||||
attributes = {'VisibilityTimeout': str(self.visibility_timeout)}
|
||||
if sqs_qname.endswith('.fifo'):
|
||||
attributes['FifoQueue'] = 'true'
|
||||
|
||||
resp = self._create_queue(sqs_qname, attributes)
|
||||
self._queue_cache[sqs_qname] = resp['QueueUrl']
|
||||
return resp['QueueUrl']
|
||||
|
||||
def _create_queue(self, queue_name, attributes):
|
||||
"""Create an SQS queue with a given name and nominal attributes."""
|
||||
# Allow specifying additional boto create_queue Attributes
|
||||
# via transport options
|
||||
if self.predefined_queues:
|
||||
return None
|
||||
|
||||
attributes.update(
|
||||
self.transport_options.get('sqs-creation-attributes') or {},
|
||||
)
|
||||
|
||||
return self.sqs(queue=queue_name).create_queue(
|
||||
QueueName=queue_name,
|
||||
Attributes=attributes,
|
||||
)
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
"""Delete queue by name."""
|
||||
if self.predefined_queues:
|
||||
return
|
||||
|
||||
q_url = self._resolve_queue_url(queue)
|
||||
self.sqs().delete_queue(
|
||||
QueueUrl=q_url,
|
||||
)
|
||||
self._queue_cache.pop(queue, None)
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
"""Put message onto queue."""
|
||||
q_url = self._new_queue(queue)
|
||||
kwargs = {'QueueUrl': q_url}
|
||||
if 'properties' in message:
|
||||
if 'message_attributes' in message['properties']:
|
||||
# we don't want to want to have the attribute in the body
|
||||
kwargs['MessageAttributes'] = \
|
||||
message['properties'].pop('message_attributes')
|
||||
if queue.endswith('.fifo'):
|
||||
if 'MessageGroupId' in message['properties']:
|
||||
kwargs['MessageGroupId'] = \
|
||||
message['properties']['MessageGroupId']
|
||||
else:
|
||||
kwargs['MessageGroupId'] = 'default'
|
||||
if 'MessageDeduplicationId' in message['properties']:
|
||||
kwargs['MessageDeduplicationId'] = \
|
||||
message['properties']['MessageDeduplicationId']
|
||||
else:
|
||||
kwargs['MessageDeduplicationId'] = str(uuid.uuid4())
|
||||
else:
|
||||
if "DelaySeconds" in message['properties']:
|
||||
kwargs['DelaySeconds'] = \
|
||||
message['properties']['DelaySeconds']
|
||||
|
||||
if self.sqs_base64_encoding:
|
||||
body = AsyncMessage().encode(dumps(message))
|
||||
else:
|
||||
body = dumps(message)
|
||||
kwargs['MessageBody'] = body
|
||||
|
||||
c = self.sqs(queue=self.canonical_queue_name(queue))
|
||||
if message.get('redelivered'):
|
||||
c.change_message_visibility(
|
||||
QueueUrl=q_url,
|
||||
ReceiptHandle=message['properties']['delivery_tag'],
|
||||
VisibilityTimeout=0
|
||||
)
|
||||
else:
|
||||
c.send_message(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _optional_b64_decode(byte_string):
|
||||
try:
|
||||
data = base64.b64decode(byte_string)
|
||||
if base64.b64encode(data) == byte_string:
|
||||
return data
|
||||
# else the base64 module found some embedded base64 content
|
||||
# that should be ignored.
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
return byte_string
|
||||
|
||||
def _message_to_python(self, message, queue_name, q_url):
|
||||
body = self._optional_b64_decode(message['Body'].encode())
|
||||
payload = loads(bytes_to_str(body))
|
||||
if queue_name in self._noack_queues:
|
||||
q_url = self._new_queue(queue_name)
|
||||
self.asynsqs(queue=queue_name).delete_message(
|
||||
q_url,
|
||||
message['ReceiptHandle'],
|
||||
)
|
||||
else:
|
||||
try:
|
||||
properties = payload['properties']
|
||||
delivery_info = payload['properties']['delivery_info']
|
||||
except KeyError:
|
||||
# json message not sent by kombu?
|
||||
delivery_info = {}
|
||||
properties = {'delivery_info': delivery_info}
|
||||
payload.update({
|
||||
'body': bytes_to_str(body),
|
||||
'properties': properties,
|
||||
})
|
||||
# set delivery tag to SQS receipt handle
|
||||
delivery_info.update({
|
||||
'sqs_message': message, 'sqs_queue': q_url,
|
||||
})
|
||||
properties['delivery_tag'] = message['ReceiptHandle']
|
||||
return payload
|
||||
|
||||
def _messages_to_python(self, messages, queue):
|
||||
"""Convert a list of SQS Message objects into Payloads.
|
||||
|
||||
This method handles converting SQS Message objects into
|
||||
Payloads, and appropriately updating the queue depending on
|
||||
the 'ack' settings for that queue.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
messages (SQSMessage): A list of SQS Message objects.
|
||||
queue (str): Name representing the queue they came from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List: A list of Payload objects
|
||||
"""
|
||||
q_url = self._new_queue(queue)
|
||||
return [self._message_to_python(m, queue, q_url) for m in messages]
|
||||
|
||||
def _get_bulk(self, queue,
|
||||
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
|
||||
"""Try to retrieve multiple messages off ``queue``.
|
||||
|
||||
Where :meth:`_get` returns a single Payload object, this method
|
||||
returns a list of Payload objects. The number of objects returned
|
||||
is determined by the total number of messages available in the queue
|
||||
and the number of messages the QoS object allows (based on the
|
||||
prefetch_count).
|
||||
|
||||
Note:
|
||||
----
|
||||
Ignores QoS limits so caller is responsible for checking
|
||||
that we are allowed to consume at least one message from the
|
||||
queue. get_bulk will then ask QoS for an estimate of
|
||||
the number of extra messages that we can consume.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The queue name to pull from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Message]
|
||||
"""
|
||||
# drain_events calls `can_consume` first, consuming
|
||||
# a token, so we know that we are allowed to consume at least
|
||||
# one message.
|
||||
|
||||
# Note: ignoring max_messages for SQS with boto3
|
||||
max_count = self._get_message_estimate()
|
||||
if max_count:
|
||||
q_url = self._new_queue(queue)
|
||||
resp = self.sqs(queue=queue).receive_message(
|
||||
QueueUrl=q_url, MaxNumberOfMessages=max_count,
|
||||
WaitTimeSeconds=self.wait_time_seconds)
|
||||
if resp.get('Messages'):
|
||||
for m in resp['Messages']:
|
||||
m['Body'] = AsyncMessage(body=m['Body']).decode()
|
||||
for msg in self._messages_to_python(resp['Messages'], queue):
|
||||
self.connection._deliver(msg, queue)
|
||||
return
|
||||
raise Empty()
|
||||
|
||||
def _get(self, queue):
|
||||
"""Try to retrieve a single message off ``queue``."""
|
||||
q_url = self._new_queue(queue)
|
||||
resp = self.sqs(queue=queue).receive_message(
|
||||
QueueUrl=q_url, MaxNumberOfMessages=1,
|
||||
WaitTimeSeconds=self.wait_time_seconds)
|
||||
if resp.get('Messages'):
|
||||
body = AsyncMessage(body=resp['Messages'][0]['Body']).decode()
|
||||
resp['Messages'][0]['Body'] = body
|
||||
return self._messages_to_python(resp['Messages'], queue)[0]
|
||||
raise Empty()
|
||||
|
||||
def _loop1(self, queue, _=None):
|
||||
self.hub.call_soon(self._schedule_queue, queue)
|
||||
|
||||
def _schedule_queue(self, queue):
|
||||
if queue in self._active_queues:
|
||||
if self.qos.can_consume():
|
||||
self._get_bulk_async(
|
||||
queue, callback=promise(self._loop1, (queue,)),
|
||||
)
|
||||
else:
|
||||
self._loop1(queue)
|
||||
|
||||
def _get_message_estimate(self, max_if_unlimited=SQS_MAX_MESSAGES):
|
||||
maxcount = self.qos.can_consume_max_estimate()
|
||||
return min(
|
||||
max_if_unlimited if maxcount is None else max(maxcount, 1),
|
||||
max_if_unlimited,
|
||||
)
|
||||
|
||||
def _get_bulk_async(self, queue,
|
||||
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
|
||||
maxcount = self._get_message_estimate()
|
||||
if maxcount:
|
||||
return self._get_async(queue, maxcount, callback=callback)
|
||||
# Not allowed to consume, make sure to notify callback..
|
||||
callback = ensure_promise(callback)
|
||||
callback([])
|
||||
return callback
|
||||
|
||||
def _get_async(self, queue, count=1, callback=None):
|
||||
q_url = self._new_queue(queue)
|
||||
qname = self.canonical_queue_name(queue)
|
||||
return self._get_from_sqs(
|
||||
queue_name=qname, queue_url=q_url, count=count,
|
||||
connection=self.asynsqs(queue=qname),
|
||||
callback=transform(
|
||||
self._on_messages_ready, callback, q_url, queue
|
||||
),
|
||||
)
|
||||
|
||||
def _on_messages_ready(self, queue, qname, messages):
|
||||
if 'Messages' in messages and messages['Messages']:
|
||||
callbacks = self.connection._callbacks
|
||||
for msg in messages['Messages']:
|
||||
msg_parsed = self._message_to_python(msg, qname, queue)
|
||||
callbacks[qname](msg_parsed)
|
||||
|
||||
def _get_from_sqs(self, queue_name, queue_url,
|
||||
connection, count=1, callback=None):
|
||||
"""Retrieve and handle messages from SQS.
|
||||
|
||||
Uses long polling and returns :class:`~vine.promises.promise`.
|
||||
"""
|
||||
return connection.receive_message(
|
||||
queue_name, queue_url, number_messages=count,
|
||||
wait_time_seconds=self.wait_time_seconds,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def _restore(self, message,
|
||||
unwanted_delivery_info=('sqs_message', 'sqs_queue')):
|
||||
for unwanted_key in unwanted_delivery_info:
|
||||
# Remove objects that aren't JSON serializable (Issue #1108).
|
||||
message.delivery_info.pop(unwanted_key, None)
|
||||
return super()._restore(message)
|
||||
|
||||
def basic_ack(self, delivery_tag, multiple=False):
|
||||
try:
|
||||
message = self.qos.get(delivery_tag).delivery_info
|
||||
sqs_message = message['sqs_message']
|
||||
except KeyError:
|
||||
super().basic_ack(delivery_tag)
|
||||
else:
|
||||
queue = None
|
||||
if 'routing_key' in message:
|
||||
queue = self.canonical_queue_name(message['routing_key'])
|
||||
|
||||
try:
|
||||
self.sqs(queue=queue).delete_message(
|
||||
QueueUrl=message['sqs_queue'],
|
||||
ReceiptHandle=sqs_message['ReceiptHandle']
|
||||
)
|
||||
except ClientError as exception:
|
||||
if exception.response['Error']['Code'] == 'AccessDenied':
|
||||
raise AccessDeniedQueueException(
|
||||
exception.response["Error"]["Message"]
|
||||
)
|
||||
super().basic_reject(delivery_tag)
|
||||
else:
|
||||
super().basic_ack(delivery_tag)
|
||||
|
||||
def _size(self, queue):
|
||||
"""Return the number of messages in a queue."""
|
||||
q_url = self._new_queue(queue)
|
||||
c = self.sqs(queue=self.canonical_queue_name(queue))
|
||||
resp = c.get_queue_attributes(
|
||||
QueueUrl=q_url,
|
||||
AttributeNames=['ApproximateNumberOfMessages'])
|
||||
return int(resp['Attributes']['ApproximateNumberOfMessages'])
|
||||
|
||||
def _purge(self, queue):
|
||||
"""Delete all current messages in a queue."""
|
||||
q_url = self._new_queue(queue)
|
||||
# SQS is slow at registering messages, so run for a few
|
||||
# iterations to ensure messages are detected and deleted.
|
||||
size = 0
|
||||
for i in range(10):
|
||||
size += int(self._size(queue))
|
||||
if not size:
|
||||
break
|
||||
self.sqs(queue=queue).purge_queue(QueueUrl=q_url)
|
||||
return size
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
# if self._asynsqs:
|
||||
# try:
|
||||
# self.asynsqs().close()
|
||||
# except AttributeError as exc: # FIXME ???
|
||||
# if "can't set attribute" not in str(exc):
|
||||
# raise
|
||||
|
||||
def new_sqs_client(self, region, access_key_id,
|
||||
secret_access_key, session_token=None):
|
||||
session = boto3.session.Session(
|
||||
region_name=region,
|
||||
aws_access_key_id=access_key_id,
|
||||
aws_secret_access_key=secret_access_key,
|
||||
aws_session_token=session_token,
|
||||
)
|
||||
is_secure = self.is_secure if self.is_secure is not None else True
|
||||
client_kwargs = {
|
||||
'use_ssl': is_secure
|
||||
}
|
||||
if self.endpoint_url is not None:
|
||||
client_kwargs['endpoint_url'] = self.endpoint_url
|
||||
client_config = self.transport_options.get('client-config') or {}
|
||||
config = Config(**client_config)
|
||||
return session.client('sqs', config=config, **client_kwargs)
|
||||
|
||||
def sqs(self, queue=None):
|
||||
if queue is not None and self.predefined_queues:
|
||||
|
||||
if queue not in self.predefined_queues:
|
||||
raise UndefinedQueueException(
|
||||
f"Queue with name '{queue}' must be defined"
|
||||
" in 'predefined_queues'.")
|
||||
q = self.predefined_queues[queue]
|
||||
if self.transport_options.get('sts_role_arn'):
|
||||
return self._handle_sts_session(queue, q)
|
||||
if not self.transport_options.get('sts_role_arn'):
|
||||
if queue in self._predefined_queue_clients:
|
||||
return self._predefined_queue_clients[queue]
|
||||
else:
|
||||
c = self._predefined_queue_clients[queue] = \
|
||||
self.new_sqs_client(
|
||||
region=q.get('region', self.region),
|
||||
access_key_id=q.get(
|
||||
'access_key_id', self.conninfo.userid),
|
||||
secret_access_key=q.get(
|
||||
'secret_access_key', self.conninfo.password)
|
||||
)
|
||||
return c
|
||||
|
||||
if self._sqs is not None:
|
||||
return self._sqs
|
||||
|
||||
c = self._sqs = self.new_sqs_client(
|
||||
region=self.region,
|
||||
access_key_id=self.conninfo.userid,
|
||||
secret_access_key=self.conninfo.password,
|
||||
)
|
||||
return c
|
||||
|
||||
def _handle_sts_session(self, queue, q):
|
||||
region = q.get('region', self.region)
|
||||
if not hasattr(self, 'sts_expiration'): # STS token - token init
|
||||
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
||||
# STS token - refresh if expired
|
||||
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
|
||||
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
||||
else: # STS token - ruse existing
|
||||
if queue not in self._predefined_queue_clients:
|
||||
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
||||
return self._predefined_queue_clients[queue]
|
||||
|
||||
def _new_predefined_queue_client_with_sts_session(self, queue, region):
|
||||
sts_creds = self.generate_sts_session_token(
|
||||
self.transport_options.get('sts_role_arn'),
|
||||
self.transport_options.get('sts_token_timeout', 900))
|
||||
self.sts_expiration = sts_creds['Expiration']
|
||||
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
|
||||
region=region,
|
||||
access_key_id=sts_creds['AccessKeyId'],
|
||||
secret_access_key=sts_creds['SecretAccessKey'],
|
||||
session_token=sts_creds['SessionToken'],
|
||||
)
|
||||
return c
|
||||
|
||||
def generate_sts_session_token(self, role_arn, token_expiry_seconds):
|
||||
sts_client = boto3.client('sts')
|
||||
sts_policy = sts_client.assume_role(
|
||||
RoleArn=role_arn,
|
||||
RoleSessionName='Celery',
|
||||
DurationSeconds=token_expiry_seconds
|
||||
)
|
||||
return sts_policy['Credentials']
|
||||
|
||||
def asynsqs(self, queue=None):
|
||||
if queue is not None and self.predefined_queues:
|
||||
if queue in self._predefined_queue_async_clients and \
|
||||
not hasattr(self, 'sts_expiration'):
|
||||
return self._predefined_queue_async_clients[queue]
|
||||
if queue not in self.predefined_queues:
|
||||
raise UndefinedQueueException((
|
||||
"Queue with name '{}' must be defined in "
|
||||
"'predefined_queues'."
|
||||
).format(queue))
|
||||
q = self.predefined_queues[queue]
|
||||
c = self._predefined_queue_async_clients[queue] = \
|
||||
AsyncSQSConnection(
|
||||
sqs_connection=self.sqs(queue=queue),
|
||||
region=q.get('region', self.region),
|
||||
fetch_message_attributes=self.fetch_message_attributes,
|
||||
)
|
||||
return c
|
||||
|
||||
if self._asynsqs is not None:
|
||||
return self._asynsqs
|
||||
|
||||
c = self._asynsqs = AsyncSQSConnection(
|
||||
sqs_connection=self.sqs(queue=queue),
|
||||
region=self.region,
|
||||
fetch_message_attributes=self.fetch_message_attributes,
|
||||
)
|
||||
return c
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
||||
@property
|
||||
def transport_options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def visibility_timeout(self):
|
||||
return (self.transport_options.get('visibility_timeout') or
|
||||
self.default_visibility_timeout)
|
||||
|
||||
@cached_property
|
||||
def predefined_queues(self):
|
||||
"""Map of queue_name to predefined queue settings."""
|
||||
return self.transport_options.get('predefined_queues', {})
|
||||
|
||||
@cached_property
|
||||
def queue_name_prefix(self):
|
||||
return self.transport_options.get('queue_name_prefix', '')
|
||||
|
||||
@cached_property
|
||||
def supports_fanout(self):
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def region(self):
|
||||
return (self.transport_options.get('region') or
|
||||
boto3.Session().region_name or
|
||||
self.default_region)
|
||||
|
||||
@cached_property
|
||||
def regioninfo(self):
|
||||
return self.transport_options.get('regioninfo')
|
||||
|
||||
@cached_property
|
||||
def is_secure(self):
|
||||
return self.transport_options.get('is_secure')
|
||||
|
||||
@cached_property
|
||||
def port(self):
|
||||
return self.transport_options.get('port')
|
||||
|
||||
@cached_property
|
||||
def endpoint_url(self):
|
||||
if self.conninfo.hostname is not None:
|
||||
scheme = 'https' if self.is_secure else 'http'
|
||||
if self.conninfo.port is not None:
|
||||
port = f':{self.conninfo.port}'
|
||||
else:
|
||||
port = ''
|
||||
return '{}://{}{}'.format(
|
||||
scheme,
|
||||
self.conninfo.hostname,
|
||||
port
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def wait_time_seconds(self):
|
||||
return self.transport_options.get('wait_time_seconds',
|
||||
self.default_wait_time_seconds)
|
||||
|
||||
@cached_property
|
||||
def sqs_base64_encoding(self):
|
||||
return self.transport_options.get('sqs_base64_encoding', True)
|
||||
|
||||
@cached_property
|
||||
def fetch_message_attributes(self):
|
||||
return self.transport_options.get('fetch_message_attributes')
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""SQS Transport.
|
||||
|
||||
Additional queue attributes can be supplied to SQS during queue
|
||||
creation by passing an ``sqs-creation-attributes`` key in
|
||||
transport_options. ``sqs-creation-attributes`` must be a dict whose
|
||||
key-value pairs correspond with Attributes in the
|
||||
`CreateQueue SQS API`_.
|
||||
|
||||
For example, to have SQS queues created with server-side encryption
|
||||
enabled using the default Amazon Managed Customer Master Key, you
|
||||
can set ``KmsMasterKeyId`` Attribute. When the queue is initially
|
||||
created by Kombu, encryption will be enabled.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from kombu.transport.SQS import Transport
|
||||
|
||||
transport = Transport(
|
||||
...,
|
||||
transport_options={
|
||||
'sqs-creation-attributes': {
|
||||
'KmsMasterKeyId': 'alias/aws/sqs',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
.. _CreateQueue SQS API: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_CreateQueue.html#API_CreateQueue_RequestParameters
|
||||
|
||||
The ``ApproximateReceiveCount`` message attribute is fetched by this
|
||||
transport by default. Requested message attributes can be changed by
|
||||
setting ``fetch_message_attributes`` in the transport options.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from kombu.transport.SQS import Transport
|
||||
|
||||
transport = Transport(
|
||||
...,
|
||||
transport_options={
|
||||
'fetch_message_attributes': ["All"],
|
||||
}
|
||||
)
|
||||
|
||||
.. _Message Attributes: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#SQS-ReceiveMessage-request-AttributeNames
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
Channel = Channel
|
||||
|
||||
polling_interval = 1
|
||||
wait_time_seconds = 0
|
||||
default_port = None
|
||||
connection_errors = (
|
||||
virtual.Transport.connection_errors +
|
||||
(exceptions.BotoCoreError, socket.error)
|
||||
)
|
||||
channel_errors = (
|
||||
virtual.Transport.channel_errors + (exceptions.BotoCoreError,)
|
||||
)
|
||||
driver_type = 'sqs'
|
||||
driver_name = 'sqs'
|
||||
|
||||
implements = virtual.Transport.implements.extend(
|
||||
asynchronous=True,
|
||||
exchange_type=frozenset(['direct']),
|
||||
)
|
||||
|
||||
@property
|
||||
def default_connection_params(self):
|
||||
return {'port': self.default_port}
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Built-in transports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from kombu.utils.compat import _detect_environment
|
||||
from kombu.utils.imports import symbol_by_name
|
||||
|
||||
|
||||
def supports_librabbitmq() -> bool | None:
|
||||
"""Return true if :pypi:`librabbitmq` can be used."""
|
||||
if _detect_environment() == 'default':
|
||||
try:
|
||||
import librabbitmq # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
else: # pragma: no cover
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
TRANSPORT_ALIASES = {
|
||||
'amqp': 'kombu.transport.pyamqp:Transport',
|
||||
'amqps': 'kombu.transport.pyamqp:SSLTransport',
|
||||
'pyamqp': 'kombu.transport.pyamqp:Transport',
|
||||
'librabbitmq': 'kombu.transport.librabbitmq:Transport',
|
||||
'confluentkafka': 'kombu.transport.confluentkafka:Transport',
|
||||
'kafka': 'kombu.transport.confluentkafka:Transport',
|
||||
'memory': 'kombu.transport.memory:Transport',
|
||||
'redis': 'kombu.transport.redis:Transport',
|
||||
'rediss': 'kombu.transport.redis:Transport',
|
||||
'SQS': 'kombu.transport.SQS:Transport',
|
||||
'sqs': 'kombu.transport.SQS:Transport',
|
||||
'mongodb': 'kombu.transport.mongodb:Transport',
|
||||
'zookeeper': 'kombu.transport.zookeeper:Transport',
|
||||
'sqlalchemy': 'kombu.transport.sqlalchemy:Transport',
|
||||
'sqla': 'kombu.transport.sqlalchemy:Transport',
|
||||
'SLMQ': 'kombu.transport.SLMQ.Transport',
|
||||
'slmq': 'kombu.transport.SLMQ.Transport',
|
||||
'filesystem': 'kombu.transport.filesystem:Transport',
|
||||
'qpid': 'kombu.transport.qpid:Transport',
|
||||
'sentinel': 'kombu.transport.redis:SentinelTransport',
|
||||
'consul': 'kombu.transport.consul:Transport',
|
||||
'etcd': 'kombu.transport.etcd:Transport',
|
||||
'azurestoragequeues': 'kombu.transport.azurestoragequeues:Transport',
|
||||
'azureservicebus': 'kombu.transport.azureservicebus:Transport',
|
||||
'pyro': 'kombu.transport.pyro:Transport',
|
||||
'gcpubsub': 'kombu.transport.gcpubsub:Transport',
|
||||
}
|
||||
|
||||
_transport_cache = {}
|
||||
|
||||
|
||||
def resolve_transport(transport: str | None = None) -> str | None:
|
||||
"""Get transport by name.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
transport (Union[str, type]): This can be either
|
||||
an actual transport class, or the fully qualified
|
||||
path to a transport class, or the alias of a transport.
|
||||
"""
|
||||
if isinstance(transport, str):
|
||||
try:
|
||||
transport = TRANSPORT_ALIASES[transport]
|
||||
except KeyError:
|
||||
if '.' not in transport and ':' not in transport:
|
||||
from kombu.utils.text import fmatch_best
|
||||
alt = fmatch_best(transport, TRANSPORT_ALIASES)
|
||||
if alt:
|
||||
raise KeyError(
|
||||
'No such transport: {}. Did you mean {}?'.format(
|
||||
transport, alt))
|
||||
raise KeyError(f'No such transport: {transport}')
|
||||
else:
|
||||
if callable(transport):
|
||||
transport = transport()
|
||||
return symbol_by_name(transport)
|
||||
return transport
|
||||
|
||||
|
||||
def get_transport_cls(transport: str | None = None) -> str | None:
|
||||
"""Get transport class by name.
|
||||
|
||||
The transport string is the full path to a transport class, e.g.::
|
||||
|
||||
"kombu.transport.pyamqp:Transport"
|
||||
|
||||
If the name does not include `"."` (is not fully qualified),
|
||||
the alias table will be consulted.
|
||||
"""
|
||||
if transport not in _transport_cache:
|
||||
_transport_cache[transport] = resolve_transport(transport)
|
||||
return _transport_cache[transport]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,498 @@
|
||||
"""Azure Service Bus Message Queue transport module for kombu.
|
||||
|
||||
Note that the Shared Access Policy used to connect to Azure Service Bus
|
||||
requires Manage, Send and Listen claims since the broker will create new
|
||||
queues and delete old queues as required.
|
||||
|
||||
|
||||
Notes when using with Celery if you are experiencing issues with programs not
|
||||
terminating properly. The Azure Service Bus SDK uses the Azure uAMQP library
|
||||
which in turn creates some threads. If the AzureServiceBus Channel is closed,
|
||||
said threads will be closed properly, but it seems there are times when Celery
|
||||
does not do this so these threads will be left running. As the uAMQP threads
|
||||
are not marked as Daemon threads, they will not be killed when the main thread
|
||||
exits. Setting the ``uamqp_keep_alive_interval`` transport option to 0 will
|
||||
prevent the keep_alive thread from starting
|
||||
|
||||
|
||||
More information about Azure Service Bus:
|
||||
https://azure.microsoft.com/en-us/services/service-bus/
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: *Unreviewed*
|
||||
* Supports Topic: *Unreviewed*
|
||||
* Supports Fanout: *Unreviewed*
|
||||
* Supports Priority: *Unreviewed*
|
||||
* Supports TTL: *Unreviewed*
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following formats:
|
||||
|
||||
.. code-block::
|
||||
|
||||
azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE
|
||||
azureservicebus://DefaultAzureCredential@SERVICE_BUSNAMESPACE
|
||||
azureservicebus://ManagedIdentityCredential@SERVICE_BUSNAMESPACE
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
|
||||
* ``queue_name_prefix`` - String prefix to prepend to queue names in a
|
||||
service bus namespace.
|
||||
* ``wait_time_seconds`` - Number of seconds to wait to receive messages.
|
||||
Default ``5``
|
||||
* ``peek_lock_seconds`` - Number of seconds the message is visible for before
|
||||
it is requeued and sent to another consumer. Default ``60``
|
||||
* ``uamqp_keep_alive_interval`` - Interval in seconds the Azure uAMQP library
|
||||
should send keepalive messages. Default ``30``
|
||||
* ``retry_total`` - Azure SDK retry total. Default ``3``
|
||||
* ``retry_backoff_factor`` - Azure SDK exponential backoff factor.
|
||||
Default ``0.8``
|
||||
* ``retry_backoff_max`` - Azure SDK retry total time. Default ``120``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
import azure.core.exceptions
|
||||
import azure.servicebus.exceptions
|
||||
import isodate
|
||||
from azure.servicebus import (ServiceBusClient, ServiceBusMessage,
|
||||
ServiceBusReceiveMode, ServiceBusReceiver,
|
||||
ServiceBusSender)
|
||||
from azure.servicebus.management import ServiceBusAdministrationClient
|
||||
|
||||
try:
|
||||
from azure.identity import (DefaultAzureCredential,
|
||||
ManagedIdentityCredential)
|
||||
except ImportError:
|
||||
DefaultAzureCredential = None
|
||||
ManagedIdentityCredential = None
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str, safe_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
# dots are replaced by dash, all other punctuation replaced by underscore.
|
||||
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
|
||||
CHARS_REPLACE_TABLE = {
|
||||
ord('.'): ord('-'),
|
||||
**{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE}
|
||||
}
|
||||
|
||||
|
||||
class SendReceive:
|
||||
"""Container for Sender and Receiver."""
|
||||
|
||||
def __init__(self,
|
||||
receiver: ServiceBusReceiver | None = None,
|
||||
sender: ServiceBusSender | None = None):
|
||||
self.receiver: ServiceBusReceiver = receiver
|
||||
self.sender: ServiceBusSender = sender
|
||||
|
||||
def close(self) -> None:
|
||||
if self.receiver:
|
||||
self.receiver.close()
|
||||
self.receiver = None
|
||||
if self.sender:
|
||||
self.sender.close()
|
||||
self.sender = None
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Azure Service Bus channel."""
|
||||
|
||||
default_wait_time_seconds: int = 5 # in seconds
|
||||
default_peek_lock_seconds: int = 60 # in seconds (default 60, max 300)
|
||||
# in seconds (is the default from service bus repo)
|
||||
default_uamqp_keep_alive_interval: int = 30
|
||||
# number of retries (is the default from service bus repo)
|
||||
default_retry_total: int = 3
|
||||
# exponential backoff factor (is the default from service bus repo)
|
||||
default_retry_backoff_factor: float = 0.8
|
||||
# Max time to backoff (is the default from service bus repo)
|
||||
default_retry_backoff_max: int = 120
|
||||
domain_format: str = 'kombu%(vhost)s'
|
||||
_queue_cache: dict[str, SendReceive] = {}
|
||||
_noack_queues: set[str] = set()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._namespace = None
|
||||
self._policy = None
|
||||
self._sas_key = None
|
||||
self._connection_string = None
|
||||
|
||||
self._try_parse_connection_string()
|
||||
|
||||
self.qos.restore_at_shutdown = False
|
||||
|
||||
def _try_parse_connection_string(self) -> None:
|
||||
self._namespace, self._credential = Transport.parse_uri(
|
||||
self.conninfo.hostname)
|
||||
|
||||
if (
|
||||
DefaultAzureCredential is not None
|
||||
and isinstance(self._credential, DefaultAzureCredential)
|
||||
) or (
|
||||
ManagedIdentityCredential is not None
|
||||
and isinstance(self._credential, ManagedIdentityCredential)
|
||||
):
|
||||
return None
|
||||
|
||||
if ":" in self._credential:
|
||||
self._policy, self._sas_key = self._credential.split(':', 1)
|
||||
|
||||
conn_dict = {
|
||||
'Endpoint': 'sb://' + self._namespace,
|
||||
'SharedAccessKeyName': self._policy,
|
||||
'SharedAccessKey': self._sas_key,
|
||||
}
|
||||
self._connection_string = ';'.join(
|
||||
[key + '=' + value for key, value in conn_dict.items()])
|
||||
|
||||
def basic_consume(self, queue, no_ack, *args, **kwargs):
|
||||
if no_ack:
|
||||
self._noack_queues.add(queue)
|
||||
return super().basic_consume(
|
||||
queue, no_ack, *args, **kwargs
|
||||
)
|
||||
|
||||
def basic_cancel(self, consumer_tag):
|
||||
if consumer_tag in self._consumers:
|
||||
queue = self._tag_to_queue[consumer_tag]
|
||||
self._noack_queues.discard(queue)
|
||||
return super().basic_cancel(consumer_tag)
|
||||
|
||||
def _add_queue_to_cache(
|
||||
self, name: str,
|
||||
receiver: ServiceBusReceiver | None = None,
|
||||
sender: ServiceBusSender | None = None
|
||||
) -> SendReceive:
|
||||
if name in self._queue_cache:
|
||||
obj = self._queue_cache[name]
|
||||
obj.sender = obj.sender or sender
|
||||
obj.receiver = obj.receiver or receiver
|
||||
else:
|
||||
obj = SendReceive(receiver, sender)
|
||||
self._queue_cache[name] = obj
|
||||
return obj
|
||||
|
||||
def _get_asb_sender(self, queue: str) -> SendReceive:
|
||||
queue_obj = self._queue_cache.get(queue, None)
|
||||
if queue_obj is None or queue_obj.sender is None:
|
||||
sender = self.queue_service.get_queue_sender(
|
||||
queue, keep_alive=self.uamqp_keep_alive_interval)
|
||||
queue_obj = self._add_queue_to_cache(queue, sender=sender)
|
||||
return queue_obj
|
||||
|
||||
def _get_asb_receiver(
|
||||
self, queue: str,
|
||||
recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK,
|
||||
queue_cache_key: str | None = None) -> SendReceive:
|
||||
cache_key = queue_cache_key or queue
|
||||
queue_obj = self._queue_cache.get(cache_key, None)
|
||||
if queue_obj is None or queue_obj.receiver is None:
|
||||
receiver = self.queue_service.get_queue_receiver(
|
||||
queue_name=queue, receive_mode=recv_mode,
|
||||
keep_alive=self.uamqp_keep_alive_interval)
|
||||
queue_obj = self._add_queue_to_cache(cache_key, receiver=receiver)
|
||||
return queue_obj
|
||||
|
||||
def entity_name(
|
||||
self, name: str, table: dict[int, int] | None = None) -> str:
|
||||
"""Format AMQP queue name into a valid ServiceBus queue name."""
|
||||
return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE)
|
||||
|
||||
def _restore(self, message: virtual.base.Message) -> None:
|
||||
# Not be needed as ASB handles unacked messages
|
||||
# Remove 'azure_message' as its not JSON serializable
|
||||
# message.delivery_info.pop('azure_message', None)
|
||||
# super()._restore(message)
|
||||
pass
|
||||
|
||||
def _new_queue(self, queue: str, **kwargs) -> SendReceive:
|
||||
"""Ensure a queue exists in ServiceBus."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
try:
|
||||
return self._queue_cache[queue]
|
||||
except KeyError:
|
||||
# Converts seconds into ISO8601 duration format
|
||||
# ie 66seconds = P1M6S
|
||||
lock_duration = isodate.duration_isoformat(
|
||||
isodate.Duration(seconds=self.peek_lock_seconds))
|
||||
try:
|
||||
self.queue_mgmt_service.create_queue(
|
||||
queue_name=queue, lock_duration=lock_duration)
|
||||
except azure.core.exceptions.ResourceExistsError:
|
||||
pass
|
||||
return self._add_queue_to_cache(queue)
|
||||
|
||||
def _delete(self, queue: str, *args, **kwargs) -> None:
|
||||
"""Delete queue by name."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
self.queue_mgmt_service.delete_queue(queue)
|
||||
send_receive_obj = self._queue_cache.pop(queue, None)
|
||||
if send_receive_obj:
|
||||
send_receive_obj.close()
|
||||
|
||||
def _put(self, queue: str, message, **kwargs) -> None:
|
||||
"""Put message onto queue."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
msg = ServiceBusMessage(dumps(message))
|
||||
|
||||
queue_obj = self._get_asb_sender(queue)
|
||||
queue_obj.sender.send_messages(msg)
|
||||
|
||||
def _get(
|
||||
self, queue: str,
|
||||
timeout: float | int | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Try to retrieve a single message off ``queue``."""
|
||||
# If we're not ack'ing for this queue, just change receive_mode
|
||||
recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE \
|
||||
if queue in self._noack_queues else ServiceBusReceiveMode.PEEK_LOCK
|
||||
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
queue_obj = self._get_asb_receiver(queue, recv_mode)
|
||||
messages = queue_obj.receiver.receive_messages(
|
||||
max_message_count=1,
|
||||
max_wait_time=timeout or self.wait_time_seconds)
|
||||
|
||||
if not messages:
|
||||
raise Empty()
|
||||
|
||||
# message.body is either byte or generator[bytes]
|
||||
message = messages[0]
|
||||
if not isinstance(message.body, bytes):
|
||||
body = b''.join(message.body)
|
||||
else:
|
||||
body = message.body
|
||||
|
||||
msg = loads(bytes_to_str(body))
|
||||
msg['properties']['delivery_info']['azure_message'] = message
|
||||
msg['properties']['delivery_info']['azure_queue_name'] = queue
|
||||
|
||||
return msg
|
||||
|
||||
def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None:
|
||||
try:
|
||||
delivery_info = self.qos.get(delivery_tag).delivery_info
|
||||
except KeyError:
|
||||
super().basic_ack(delivery_tag)
|
||||
else:
|
||||
queue = delivery_info['azure_queue_name']
|
||||
# recv_mode is PEEK_LOCK when ack'ing messages
|
||||
queue_obj = self._get_asb_receiver(queue)
|
||||
|
||||
try:
|
||||
queue_obj.receiver.complete_message(
|
||||
delivery_info['azure_message'])
|
||||
except azure.servicebus.exceptions.MessageAlreadySettled:
|
||||
super().basic_ack(delivery_tag)
|
||||
except Exception:
|
||||
super().basic_reject(delivery_tag)
|
||||
else:
|
||||
super().basic_ack(delivery_tag)
|
||||
|
||||
def _size(self, queue: str) -> int:
|
||||
"""Return the number of messages in a queue."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
props = self.queue_mgmt_service.get_queue_runtime_properties(queue)
|
||||
|
||||
return props.total_message_count
|
||||
|
||||
def _purge(self, queue) -> int:
|
||||
"""Delete all current messages in a queue."""
|
||||
# Azure doesn't provide a purge api yet
|
||||
n = 0
|
||||
max_purge_count = 10
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
# By default all the receivers will be in PEEK_LOCK receive mode
|
||||
queue_obj = self._queue_cache.get(queue, None)
|
||||
if queue not in self._noack_queues or \
|
||||
queue_obj is None or queue_obj.receiver is None:
|
||||
queue_obj = self._get_asb_receiver(
|
||||
queue,
|
||||
ServiceBusReceiveMode.RECEIVE_AND_DELETE, 'purge_' + queue
|
||||
)
|
||||
|
||||
while True:
|
||||
messages = queue_obj.receiver.receive_messages(
|
||||
max_message_count=max_purge_count,
|
||||
max_wait_time=0.2
|
||||
)
|
||||
n += len(messages)
|
||||
|
||||
if len(messages) < max_purge_count:
|
||||
break
|
||||
|
||||
return n
|
||||
|
||||
def close(self) -> None:
|
||||
# receivers and senders spawn threads so clean them up
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
for queue_obj in self._queue_cache.values():
|
||||
queue_obj.close()
|
||||
self._queue_cache.clear()
|
||||
|
||||
if self.connection is not None:
|
||||
self.connection.close_channel(self)
|
||||
|
||||
@cached_property
|
||||
def queue_service(self) -> ServiceBusClient:
|
||||
if self._connection_string:
|
||||
return ServiceBusClient.from_connection_string(
|
||||
self._connection_string,
|
||||
retry_total=self.retry_total,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
retry_backoff_max=self.retry_backoff_max
|
||||
)
|
||||
|
||||
return ServiceBusClient(
|
||||
self._namespace,
|
||||
self._credential,
|
||||
retry_total=self.retry_total,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
retry_backoff_max=self.retry_backoff_max
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
|
||||
if self._connection_string:
|
||||
return ServiceBusAdministrationClient.from_connection_string(
|
||||
self._connection_string
|
||||
)
|
||||
|
||||
return ServiceBusAdministrationClient(
|
||||
self._namespace, self._credential
|
||||
)
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
||||
@property
|
||||
def transport_options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def queue_name_prefix(self) -> str:
|
||||
return self.transport_options.get('queue_name_prefix', '')
|
||||
|
||||
@cached_property
|
||||
def wait_time_seconds(self) -> int:
|
||||
return self.transport_options.get('wait_time_seconds',
|
||||
self.default_wait_time_seconds)
|
||||
|
||||
@cached_property
|
||||
def peek_lock_seconds(self) -> int:
|
||||
return min(self.transport_options.get('peek_lock_seconds',
|
||||
self.default_peek_lock_seconds),
|
||||
300) # Limit upper bounds to 300
|
||||
|
||||
@cached_property
|
||||
def uamqp_keep_alive_interval(self) -> int:
|
||||
return self.transport_options.get(
|
||||
'uamqp_keep_alive_interval',
|
||||
self.default_uamqp_keep_alive_interval
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def retry_total(self) -> int:
|
||||
return self.transport_options.get(
|
||||
'retry_total', self.default_retry_total)
|
||||
|
||||
@cached_property
|
||||
def retry_backoff_factor(self) -> float:
|
||||
return self.transport_options.get(
|
||||
'retry_backoff_factor', self.default_retry_backoff_factor)
|
||||
|
||||
@cached_property
|
||||
def retry_backoff_max(self) -> int:
|
||||
return self.transport_options.get(
|
||||
'retry_backoff_max', self.default_retry_backoff_max)
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Azure Service Bus transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
polling_interval = 1
|
||||
default_port = None
|
||||
can_parse_url = True
|
||||
|
||||
@staticmethod
|
||||
def parse_uri(uri: str) -> tuple[str, str | DefaultAzureCredential |
|
||||
ManagedIdentityCredential]:
|
||||
# URL like:
|
||||
# azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
|
||||
# urllib parse does not work as the sas key could contain a slash
|
||||
# e.g.: azureservicebus://rootpolicy:some/key@somenamespace
|
||||
|
||||
# > 'rootpolicy:some/key@somenamespace'
|
||||
uri = uri.replace('azureservicebus://', '')
|
||||
# > 'rootpolicy:some/key', 'somenamespace'
|
||||
credential, namespace = uri.rsplit('@', 1)
|
||||
|
||||
if not namespace.endswith('.net'):
|
||||
namespace += '.servicebus.windows.net'
|
||||
|
||||
if "DefaultAzureCredential".lower() == credential.lower():
|
||||
if DefaultAzureCredential is None:
|
||||
raise ImportError('Azure Service Bus transport with a '
|
||||
'DefaultAzureCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = DefaultAzureCredential()
|
||||
elif "ManagedIdentityCredential".lower() == credential.lower():
|
||||
if ManagedIdentityCredential is None:
|
||||
raise ImportError('Azure Service Bus transport with a '
|
||||
'ManagedIdentityCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = ManagedIdentityCredential()
|
||||
else:
|
||||
# > 'rootpolicy', 'some/key'
|
||||
policy, sas_key = credential.split(':', 1)
|
||||
credential = f"{policy}:{sas_key}"
|
||||
|
||||
# Validate ASB connection string
|
||||
if not all([namespace, credential]):
|
||||
raise ValueError(
|
||||
'Need a URI like '
|
||||
'azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} ' # noqa
|
||||
'or the azure Endpoint connection string'
|
||||
)
|
||||
|
||||
return namespace, credential
|
||||
|
||||
@classmethod
|
||||
def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
|
||||
namespace, credential = cls.parse_uri(uri)
|
||||
if isinstance(credential, str) and ":" in credential:
|
||||
policy, sas_key = credential.split(':', 1)
|
||||
return 'azureservicebus://{}:{}@{}'.format(
|
||||
policy,
|
||||
sas_key if include_password else mask,
|
||||
namespace
|
||||
)
|
||||
|
||||
return 'azureservicebus://{}@{}'.format(
|
||||
credential.__class__.__name__,
|
||||
namespace
|
||||
)
|
||||
@@ -0,0 +1,263 @@
|
||||
"""Azure Storage Queues transport module for kombu.
|
||||
|
||||
More information about Azure Storage Queues:
|
||||
https://azure.microsoft.com/en-us/services/storage/queues/
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: *Unreviewed*
|
||||
* Supports Topic: *Unreviewed*
|
||||
* Supports Fanout: *Unreviewed*
|
||||
* Supports Priority: *Unreviewed*
|
||||
* Supports TTL: *Unreviewed*
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following formats:
|
||||
|
||||
.. code-block::
|
||||
|
||||
azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
|
||||
azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
|
||||
azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
|
||||
azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
|
||||
|
||||
Note that if the access key for the storage account contains a forward slash
|
||||
(``/``), it will have to be regenerated before it can be used in the connection
|
||||
URL.
|
||||
|
||||
.. code-block::
|
||||
|
||||
azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
|
||||
azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
|
||||
|
||||
If you wish to use an `Azure Managed Identity` you may use the
|
||||
``DefaultAzureCredential`` format of the connection string which will use
|
||||
``DefaultAzureCredential`` class in the azure-identity package. You may want to
|
||||
read the `azure-identity documentation` for more information on how the
|
||||
``DefaultAzureCredential`` works.
|
||||
|
||||
.. _azure-identity documentation:
|
||||
https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
|
||||
.. _Azure Managed Identity:
|
||||
https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
|
||||
* ``queue_name_prefix``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from azure.core.exceptions import ResourceExistsError
|
||||
|
||||
from kombu.utils.encoding import safe_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
from azure.storage.queue import QueueServiceClient
|
||||
except ImportError: # pragma: no cover
|
||||
QueueServiceClient = None
|
||||
|
||||
try:
|
||||
from azure.identity import (DefaultAzureCredential,
|
||||
ManagedIdentityCredential)
|
||||
except ImportError:
|
||||
DefaultAzureCredential = None
|
||||
ManagedIdentityCredential = None
|
||||
|
||||
# Azure storage queues allow only alphanumeric and dashes
|
||||
# so, replace everything with a dash
|
||||
CHARS_REPLACE_TABLE = {
|
||||
ord(c): 0x2d for c in string.punctuation
|
||||
}
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Azure Storage Queues channel."""
|
||||
|
||||
domain_format: str = 'kombu%(vhost)s'
|
||||
_queue_service: QueueServiceClient | None = None
|
||||
_queue_name_cache: dict[Any, Any] = {}
|
||||
no_ack: bool = True
|
||||
_noack_queues: set[Any] = set()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if QueueServiceClient is None:
|
||||
raise ImportError('Azure Storage Queues transport requires the '
|
||||
'azure-storage-queue library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._credential, self._url = Transport.parse_uri(
|
||||
self.conninfo.hostname
|
||||
)
|
||||
|
||||
for queue in self.queue_service.list_queues():
|
||||
self._queue_name_cache[queue['name']] = queue
|
||||
|
||||
def basic_consume(self, queue, no_ack, *args, **kwargs):
|
||||
if no_ack:
|
||||
self._noack_queues.add(queue)
|
||||
|
||||
return super().basic_consume(queue, no_ack,
|
||||
*args, **kwargs)
|
||||
|
||||
def entity_name(self, name, table=CHARS_REPLACE_TABLE) -> str:
|
||||
"""Format AMQP queue name into a valid Azure Storage Queue name."""
|
||||
return str(safe_str(name)).translate(table)
|
||||
|
||||
def _ensure_queue(self, queue):
|
||||
"""Ensure a queue exists."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
try:
|
||||
q = self._queue_service.get_queue_client(
|
||||
queue=self._queue_name_cache[queue]
|
||||
)
|
||||
except KeyError:
|
||||
try:
|
||||
q = self.queue_service.create_queue(queue)
|
||||
except ResourceExistsError:
|
||||
q = self._queue_service.get_queue_client(queue=queue)
|
||||
|
||||
self._queue_name_cache[queue] = q.get_queue_properties()
|
||||
return q
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
"""Delete queue by name."""
|
||||
queue_name = self.entity_name(queue)
|
||||
self._queue_name_cache.pop(queue_name, None)
|
||||
self.queue_service.delete_queue(queue_name)
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
"""Put message onto queue."""
|
||||
q = self._ensure_queue(queue)
|
||||
encoded_message = dumps(message)
|
||||
q.send_message(encoded_message)
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
"""Try to retrieve a single message off ``queue``."""
|
||||
q = self._ensure_queue(queue)
|
||||
|
||||
messages = q.receive_messages(messages_per_page=1, timeout=timeout)
|
||||
try:
|
||||
message = next(messages)
|
||||
except StopIteration:
|
||||
raise Empty()
|
||||
|
||||
content = loads(message.content)
|
||||
|
||||
q.delete_message(message=message)
|
||||
|
||||
return content
|
||||
|
||||
def _size(self, queue):
|
||||
"""Return the number of messages in a queue."""
|
||||
q = self._ensure_queue(queue)
|
||||
return q.get_queue_properties().approximate_message_count
|
||||
|
||||
def _purge(self, queue):
|
||||
"""Delete all current messages in a queue."""
|
||||
q = self._ensure_queue(queue)
|
||||
n = self._size(q.queue_name)
|
||||
q.clear_messages()
|
||||
return n
|
||||
|
||||
@property
|
||||
def queue_service(self) -> QueueServiceClient:
|
||||
if self._queue_service is None:
|
||||
self._queue_service = QueueServiceClient(
|
||||
account_url=self._url, credential=self._credential
|
||||
)
|
||||
|
||||
return self._queue_service
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
||||
@property
|
||||
def transport_options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def queue_name_prefix(self) -> str:
|
||||
return self.transport_options.get('queue_name_prefix', '')
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Azure Storage Queues transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
polling_interval: int = 1
|
||||
default_port: int | None = None
|
||||
can_parse_url: bool = True
|
||||
|
||||
@staticmethod
|
||||
def parse_uri(uri: str) -> tuple[str | dict, str]:
|
||||
# URL like:
|
||||
# azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
|
||||
# azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
|
||||
# azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
|
||||
# azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
|
||||
|
||||
# urllib parse does not work as the sas key could contain a slash
|
||||
# e.g.: azurestoragequeues://some/key@someurl
|
||||
|
||||
try:
|
||||
# > 'some/key@url'
|
||||
uri = uri.replace('azurestoragequeues://', '')
|
||||
# > 'some/key', 'url'
|
||||
credential, url = uri.rsplit('@', 1)
|
||||
|
||||
if "DefaultAzureCredential".lower() == credential.lower():
|
||||
if DefaultAzureCredential is None:
|
||||
raise ImportError('Azure Storage Queues transport with a '
|
||||
'DefaultAzureCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = DefaultAzureCredential()
|
||||
elif "ManagedIdentityCredential".lower() == credential.lower():
|
||||
if ManagedIdentityCredential is None:
|
||||
raise ImportError('Azure Storage Queues transport with a '
|
||||
'ManagedIdentityCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = ManagedIdentityCredential()
|
||||
elif "devstoreaccount1" in url and ".core.windows.net" not in url:
|
||||
# parse credential as a dict if Azurite is being used
|
||||
credential = {
|
||||
"account_name": "devstoreaccount1",
|
||||
"account_key": credential,
|
||||
}
|
||||
|
||||
# Validate parameters
|
||||
assert all([credential, url])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
'Need a URI like '
|
||||
'azurestoragequeues://{SAS or access key}@{URL}, '
|
||||
'azurestoragequeues://DefaultAzureCredential@{URL}, '
|
||||
', or '
|
||||
'azurestoragequeues://ManagedIdentityCredential@{URL}'
|
||||
)
|
||||
|
||||
return credential, url
|
||||
|
||||
@classmethod
|
||||
def as_uri(
|
||||
cls, uri: str, include_password: bool = False, mask: str = "**"
|
||||
) -> str:
|
||||
credential, url = cls.parse_uri(uri)
|
||||
return "azurestoragequeues://{}@{}".format(
|
||||
credential if include_password else mask, url
|
||||
)
|
||||
@@ -0,0 +1,271 @@
|
||||
"""Base transport interface."""
|
||||
# flake8: noqa
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from amqp.exceptions import RecoverableConnectionError
|
||||
|
||||
from kombu.exceptions import ChannelError, ConnectionError
|
||||
from kombu.message import Message
|
||||
from kombu.utils.functional import dictfilter
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.time import maybe_s_to_ms
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
__all__ = ('Message', 'StdChannel', 'Management', 'Transport')
|
||||
|
||||
RABBITMQ_QUEUE_ARGUMENTS = {
|
||||
'expires': ('x-expires', maybe_s_to_ms),
|
||||
'message_ttl': ('x-message-ttl', maybe_s_to_ms),
|
||||
'max_length': ('x-max-length', int),
|
||||
'max_length_bytes': ('x-max-length-bytes', int),
|
||||
'max_priority': ('x-max-priority', int),
|
||||
} # type: Mapping[str, Tuple[str, Callable]]
|
||||
|
||||
|
||||
def to_rabbitmq_queue_arguments(arguments, **options):
|
||||
# type: (Mapping, **Any) -> Dict
|
||||
"""Convert queue arguments to RabbitMQ queue arguments.
|
||||
|
||||
This is the implementation for Channel.prepare_queue_arguments
|
||||
for AMQP-based transports. It's used by both the pyamqp and librabbitmq
|
||||
transports.
|
||||
|
||||
Arguments:
|
||||
arguments (Mapping):
|
||||
User-supplied arguments (``Queue.queue_arguments``).
|
||||
|
||||
Keyword Arguments:
|
||||
expires (float): Queue expiry time in seconds.
|
||||
This will be converted to ``x-expires`` in int milliseconds.
|
||||
message_ttl (float): Message TTL in seconds.
|
||||
This will be converted to ``x-message-ttl`` in int milliseconds.
|
||||
max_length (int): Max queue length (in number of messages).
|
||||
This will be converted to ``x-max-length`` int.
|
||||
max_length_bytes (int): Max queue size in bytes.
|
||||
This will be converted to ``x-max-length-bytes`` int.
|
||||
max_priority (int): Max priority steps for queue.
|
||||
This will be converted to ``x-max-priority`` int.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict: RabbitMQ compatible queue arguments.
|
||||
"""
|
||||
prepared = dictfilter(dict(
|
||||
_to_rabbitmq_queue_argument(key, value)
|
||||
for key, value in options.items()
|
||||
))
|
||||
return dict(arguments, **prepared) if prepared else arguments
|
||||
|
||||
|
||||
def _to_rabbitmq_queue_argument(key, value):
|
||||
# type: (str, Any) -> Tuple[str, Any]
|
||||
opt, typ = RABBITMQ_QUEUE_ARGUMENTS[key]
|
||||
return opt, typ(value) if value is not None else value
|
||||
|
||||
|
||||
def _LeftBlank(obj, method):
|
||||
return NotImplementedError(
|
||||
'Transport {0.__module__}.{0.__name__} does not implement {1}'.format(
|
||||
obj.__class__, method))
|
||||
|
||||
|
||||
class StdChannel:
|
||||
"""Standard channel base class."""
|
||||
|
||||
no_ack_consumers = None
|
||||
|
||||
def Consumer(self, *args, **kwargs):
|
||||
from kombu.messaging import Consumer
|
||||
return Consumer(self, *args, **kwargs)
|
||||
|
||||
def Producer(self, *args, **kwargs):
|
||||
from kombu.messaging import Producer
|
||||
return Producer(self, *args, **kwargs)
|
||||
|
||||
def get_bindings(self):
|
||||
raise _LeftBlank(self, 'get_bindings')
|
||||
|
||||
def after_reply_message_received(self, queue):
|
||||
"""Callback called after RPC reply received.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Reply queue semantics: can be used to delete the queue
|
||||
after transient reply message received.
|
||||
"""
|
||||
|
||||
def prepare_queue_arguments(self, arguments, **kwargs):
|
||||
return arguments
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class Management:
|
||||
"""AMQP Management API (incomplete)."""
|
||||
|
||||
def __init__(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def get_bindings(self):
|
||||
raise _LeftBlank(self, 'get_bindings')
|
||||
|
||||
|
||||
class Implements(dict):
|
||||
"""Helper class used to define transport features."""
|
||||
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self[key] = value
|
||||
|
||||
def extend(self, **kwargs):
|
||||
return self.__class__(self, **kwargs)
|
||||
|
||||
|
||||
default_transport_capabilities = Implements(
|
||||
asynchronous=False,
|
||||
exchange_type=frozenset(['direct', 'topic', 'fanout', 'headers']),
|
||||
heartbeats=False,
|
||||
)
|
||||
|
||||
|
||||
class Transport:
|
||||
"""Base class for transports."""
|
||||
|
||||
Management = Management
|
||||
|
||||
#: The :class:`~kombu.Connection` owning this instance.
|
||||
client = None
|
||||
|
||||
#: Set to True if :class:`~kombu.Connection` should pass the URL
|
||||
#: unmodified.
|
||||
can_parse_url = False
|
||||
|
||||
#: Default port used when no port has been specified.
|
||||
default_port = None
|
||||
|
||||
#: Tuple of errors that can happen due to connection failure.
|
||||
connection_errors = (ConnectionError,)
|
||||
|
||||
#: Tuple of errors that can happen due to channel/method failure.
|
||||
channel_errors = (ChannelError,)
|
||||
|
||||
#: Type of driver, can be used to separate transports
|
||||
#: using the AMQP protocol (driver_type: 'amqp'),
|
||||
#: Redis (driver_type: 'redis'), etc...
|
||||
driver_type = 'N/A'
|
||||
|
||||
#: Name of driver library (e.g. 'py-amqp', 'redis').
|
||||
driver_name = 'N/A'
|
||||
|
||||
__reader = None
|
||||
|
||||
implements = default_transport_capabilities.extend()
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
self.client = client
|
||||
|
||||
def establish_connection(self):
|
||||
raise _LeftBlank(self, 'establish_connection')
|
||||
|
||||
def close_connection(self, connection):
|
||||
raise _LeftBlank(self, 'close_connection')
|
||||
|
||||
def create_channel(self, connection):
|
||||
raise _LeftBlank(self, 'create_channel')
|
||||
|
||||
def close_channel(self, connection):
|
||||
raise _LeftBlank(self, 'close_channel')
|
||||
|
||||
def drain_events(self, connection, **kwargs):
|
||||
raise _LeftBlank(self, 'drain_events')
|
||||
|
||||
def heartbeat_check(self, connection, rate=2):
|
||||
pass
|
||||
|
||||
def driver_version(self):
|
||||
return 'N/A'
|
||||
|
||||
def get_heartbeat_interval(self, connection):
|
||||
return 0
|
||||
|
||||
def register_with_event_loop(self, connection, loop):
|
||||
pass
|
||||
|
||||
def unregister_from_event_loop(self, connection, loop):
|
||||
pass
|
||||
|
||||
def verify_connection(self, connection):
|
||||
return True
|
||||
|
||||
def _make_reader(self, connection, timeout=socket.timeout,
|
||||
error=socket.error, _unavail=(errno.EAGAIN, errno.EINTR)):
|
||||
drain_events = connection.drain_events
|
||||
|
||||
def _read(loop):
|
||||
if not connection.connected:
|
||||
raise RecoverableConnectionError('Socket was disconnected')
|
||||
try:
|
||||
drain_events(timeout=0)
|
||||
except timeout:
|
||||
return
|
||||
except error as exc:
|
||||
if exc.errno in _unavail:
|
||||
return
|
||||
raise
|
||||
loop.call_soon(_read, loop)
|
||||
|
||||
return _read
|
||||
|
||||
def qos_semantics_matches_spec(self, connection):
|
||||
return True
|
||||
|
||||
def on_readable(self, connection, loop):
|
||||
reader = self.__reader
|
||||
if reader is None:
|
||||
reader = self.__reader = self._make_reader(connection)
|
||||
reader(loop)
|
||||
|
||||
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
|
||||
"""Customise the display format of the URI."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def default_connection_params(self):
|
||||
return {}
|
||||
|
||||
def get_manager(self, *args, **kwargs):
|
||||
return self.Management(self)
|
||||
|
||||
@cached_property
|
||||
def manager(self):
|
||||
return self.get_manager()
|
||||
|
||||
@property
|
||||
def supports_heartbeats(self):
|
||||
return self.implements.heartbeats
|
||||
|
||||
@property
|
||||
def supports_ev(self):
|
||||
return self.implements.asynchronous
|
||||
@@ -0,0 +1,380 @@
|
||||
"""confluent-kafka transport module for Kombu.
|
||||
|
||||
Kafka transport using confluent-kafka library.
|
||||
|
||||
**References**
|
||||
|
||||
- http://docs.confluent.io/current/clients/confluent-kafka-python
|
||||
|
||||
**Limitations**
|
||||
|
||||
The confluent-kafka transport does not support PyPy environment.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: No
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
Connection string has the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
confluentkafka://[USER:PASSWORD@]KAFKA_ADDRESS[:PORT]
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
* ``connection_wait_time_seconds`` - Time in seconds to wait for connection
|
||||
to succeed. Default ``5``
|
||||
* ``wait_time_seconds`` - Time in seconds to wait to receive messages.
|
||||
Default ``5``
|
||||
* ``security_protocol`` - Protocol used to communicate with broker.
|
||||
Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
|
||||
an explanation of valid values. Default ``plaintext``
|
||||
* ``sasl_mechanism`` - SASL mechanism to use for authentication.
|
||||
Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
|
||||
an explanation of valid values.
|
||||
* ``num_partitions`` - Number of partitions to create. Default ``1``
|
||||
* ``replication_factor`` - Replication factor of partitions. Default ``1``
|
||||
* ``topic_config`` - Topic configuration. Must be a dict whose key-value pairs
|
||||
correspond with attributes in the
|
||||
http://kafka.apache.org/documentation.html#topicconfigs.
|
||||
* ``kafka_common_config`` - Configuration applied to producer, consumer and
|
||||
admin client. Must be a dict whose key-value pairs correspond with attributes
|
||||
in the https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
|
||||
* ``kafka_producer_config`` - Producer configuration. Must be a dict whose
|
||||
key-value pairs correspond with attributes in the
|
||||
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
|
||||
* ``kafka_consumer_config`` - Consumer configuration. Must be a dict whose
|
||||
key-value pairs correspond with attributes in the
|
||||
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
|
||||
* ``kafka_admin_config`` - Admin client configuration. Must be a dict whose
|
||||
key-value pairs correspond with attributes in the
|
||||
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from queue import Empty
|
||||
|
||||
from kombu.transport import virtual
|
||||
from kombu.utils import cached_property
|
||||
from kombu.utils.encoding import str_to_bytes
|
||||
from kombu.utils.json import dumps, loads
|
||||
|
||||
try:
|
||||
import confluent_kafka
|
||||
from confluent_kafka import (Consumer, KafkaException, Producer,
|
||||
TopicPartition)
|
||||
from confluent_kafka.admin import AdminClient, NewTopic
|
||||
|
||||
KAFKA_CONNECTION_ERRORS = ()
|
||||
KAFKA_CHANNEL_ERRORS = ()
|
||||
|
||||
except ImportError:
|
||||
confluent_kafka = None
|
||||
KAFKA_CONNECTION_ERRORS = KAFKA_CHANNEL_ERRORS = ()
|
||||
|
||||
from kombu.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_PORT = 9092
|
||||
|
||||
|
||||
class NoBrokersAvailable(KafkaException):
|
||||
"""Kafka broker is not available exception."""
|
||||
|
||||
retriable = True
|
||||
|
||||
|
||||
class Message(virtual.Message):
|
||||
"""Message object."""
|
||||
|
||||
def __init__(self, payload, channel=None, **kwargs):
|
||||
self.topic = payload.get('topic')
|
||||
super().__init__(payload, channel=channel, **kwargs)
|
||||
|
||||
|
||||
class QoS(virtual.QoS):
|
||||
"""Quality of Service guarantees."""
|
||||
|
||||
_not_yet_acked = {}
|
||||
|
||||
def can_consume(self):
|
||||
"""Return true if the channel can be consumed from.
|
||||
|
||||
:returns: True, if this QoS object can accept a message.
|
||||
:rtype: bool
|
||||
"""
|
||||
return not self.prefetch_count or len(self._not_yet_acked) < self \
|
||||
.prefetch_count
|
||||
|
||||
def can_consume_max_estimate(self):
|
||||
if self.prefetch_count:
|
||||
return self.prefetch_count - len(self._not_yet_acked)
|
||||
else:
|
||||
return 1
|
||||
|
||||
def append(self, message, delivery_tag):
|
||||
self._not_yet_acked[delivery_tag] = message
|
||||
|
||||
def get(self, delivery_tag):
|
||||
return self._not_yet_acked[delivery_tag]
|
||||
|
||||
def ack(self, delivery_tag):
|
||||
if delivery_tag not in self._not_yet_acked:
|
||||
return
|
||||
message = self._not_yet_acked.pop(delivery_tag)
|
||||
consumer = self.channel._get_consumer(message.topic)
|
||||
consumer.commit()
|
||||
|
||||
def reject(self, delivery_tag, requeue=False):
|
||||
"""Reject a message by delivery tag.
|
||||
|
||||
If requeue is True, then the last consumed message is reverted so
|
||||
it'll be refetched on the next attempt.
|
||||
If False, that message is consumed and ignored.
|
||||
"""
|
||||
if requeue:
|
||||
message = self._not_yet_acked.pop(delivery_tag)
|
||||
consumer = self.channel._get_consumer(message.topic)
|
||||
for assignment in consumer.assignment():
|
||||
topic_partition = TopicPartition(message.topic,
|
||||
assignment.partition)
|
||||
[committed_offset] = consumer.committed([topic_partition])
|
||||
consumer.seek(committed_offset)
|
||||
else:
|
||||
self.ack(delivery_tag)
|
||||
|
||||
def restore_unacked_once(self, stderr=None):
|
||||
pass
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Kafka Channel."""
|
||||
|
||||
QoS = QoS
|
||||
Message = Message
|
||||
|
||||
default_wait_time_seconds = 5
|
||||
default_connection_wait_time_seconds = 5
|
||||
_client = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._kafka_consumers = {}
|
||||
self._kafka_producers = {}
|
||||
|
||||
self._client = self._open()
|
||||
|
||||
def sanitize_queue_name(self, queue):
|
||||
"""Need to sanitize the name, celery sometimes pushes in @ signs."""
|
||||
return str(queue).replace('@', '')
|
||||
|
||||
def _get_producer(self, queue):
|
||||
"""Create/get a producer instance for the given topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
producer = self._kafka_producers.get(queue, None)
|
||||
if producer is None:
|
||||
producer = Producer({
|
||||
**self.common_config,
|
||||
**(self.options.get('kafka_producer_config') or {}),
|
||||
})
|
||||
self._kafka_producers[queue] = producer
|
||||
|
||||
return producer
|
||||
|
||||
def _get_consumer(self, queue):
|
||||
"""Create/get a consumer instance for the given topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
consumer = self._kafka_consumers.get(queue, None)
|
||||
if consumer is None:
|
||||
consumer = Consumer({
|
||||
'group.id': f'{queue}-consumer-group',
|
||||
'auto.offset.reset': 'earliest',
|
||||
'enable.auto.commit': False,
|
||||
**self.common_config,
|
||||
**(self.options.get('kafka_consumer_config') or {}),
|
||||
})
|
||||
consumer.subscribe([queue])
|
||||
self._kafka_consumers[queue] = consumer
|
||||
|
||||
return consumer
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
"""Put a message on the topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
producer = self._get_producer(queue)
|
||||
producer.produce(queue, str_to_bytes(dumps(message)))
|
||||
producer.flush()
|
||||
|
||||
def _get(self, queue, **kwargs):
|
||||
"""Get a message from the topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
consumer = self._get_consumer(queue)
|
||||
message = None
|
||||
|
||||
try:
|
||||
message = consumer.poll(self.wait_time_seconds)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
if not message:
|
||||
raise Empty()
|
||||
|
||||
error = message.error()
|
||||
if error:
|
||||
logger.error(error)
|
||||
raise Empty()
|
||||
|
||||
return {**loads(message.value()), 'topic': message.topic()}
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
"""Delete a queue/topic."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
self._kafka_consumers[queue].close()
|
||||
self._kafka_consumers.pop(queue)
|
||||
self.client.delete_topics([queue])
|
||||
|
||||
def _size(self, queue):
|
||||
"""Get the number of pending messages in the topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
|
||||
consumer = self._kafka_consumers.get(queue, None)
|
||||
if consumer is None:
|
||||
return 0
|
||||
|
||||
size = 0
|
||||
for assignment in consumer.assignment():
|
||||
topic_partition = TopicPartition(queue, assignment.partition)
|
||||
(_, end_offset) = consumer.get_watermark_offsets(topic_partition)
|
||||
[committed_offset] = consumer.committed([topic_partition])
|
||||
size += end_offset - committed_offset.offset
|
||||
return size
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
"""Create a new topic if it does not exist."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
if queue in self.client.list_topics().topics:
|
||||
return
|
||||
|
||||
topic = NewTopic(
|
||||
queue,
|
||||
num_partitions=self.options.get('num_partitions', 1),
|
||||
replication_factor=self.options.get('replication_factor', 1),
|
||||
config=self.options.get('topic_config', {})
|
||||
)
|
||||
self.client.create_topics(new_topics=[topic])
|
||||
|
||||
def _has_queue(self, queue, **kwargs):
|
||||
"""Check if a topic already exists."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
return queue in self.client.list_topics().topics
|
||||
|
||||
def _open(self):
|
||||
client = AdminClient({
|
||||
**self.common_config,
|
||||
**(self.options.get('kafka_admin_config') or {}),
|
||||
})
|
||||
|
||||
try:
|
||||
# seems to be the only way to check connection
|
||||
client.list_topics(timeout=self.wait_time_seconds)
|
||||
except confluent_kafka.KafkaException as e:
|
||||
raise NoBrokersAvailable(e)
|
||||
|
||||
return client
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
if self._client is None:
|
||||
self._client = self._open()
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
||||
@cached_property
|
||||
def wait_time_seconds(self):
|
||||
return self.options.get(
|
||||
'wait_time_seconds', self.default_wait_time_seconds
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def connection_wait_time_seconds(self):
|
||||
return self.options.get(
|
||||
'connection_wait_time_seconds',
|
||||
self.default_connection_wait_time_seconds,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def common_config(self):
|
||||
conninfo = self.connection.client
|
||||
config = {
|
||||
'bootstrap.servers':
|
||||
f'{conninfo.hostname}:{int(conninfo.port) or DEFAULT_PORT}',
|
||||
}
|
||||
security_protocol = self.options.get('security_protocol', 'plaintext')
|
||||
if security_protocol.lower() != 'plaintext':
|
||||
config.update({
|
||||
'security.protocol': security_protocol,
|
||||
'sasl.username': conninfo.userid,
|
||||
'sasl.password': conninfo.password,
|
||||
'sasl.mechanism': self.options.get('sasl_mechanism'),
|
||||
})
|
||||
|
||||
config.update(self.options.get('kafka_common_config') or {})
|
||||
return config
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
self._kafka_producers = {}
|
||||
|
||||
for consumer in self._kafka_consumers.values():
|
||||
consumer.close()
|
||||
|
||||
self._kafka_consumers = {}
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Kafka Transport."""
|
||||
|
||||
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
|
||||
pass
|
||||
|
||||
Channel = Channel
|
||||
|
||||
default_port = DEFAULT_PORT
|
||||
|
||||
driver_type = 'kafka'
|
||||
driver_name = 'confluentkafka'
|
||||
|
||||
recoverable_connection_errors = (
|
||||
NoBrokersAvailable,
|
||||
)
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
if confluent_kafka is None:
|
||||
raise ImportError('The confluent-kafka library is not installed')
|
||||
super().__init__(client, **kwargs)
|
||||
|
||||
def driver_version(self):
|
||||
return confluent_kafka.__version__
|
||||
|
||||
def establish_connection(self):
|
||||
return super().establish_connection()
|
||||
|
||||
def close_connection(self, connection):
|
||||
return super().close_connection(connection)
|
||||
@@ -0,0 +1,323 @@
|
||||
"""Consul Transport module for Kombu.
|
||||
|
||||
Features
|
||||
========
|
||||
|
||||
It uses Consul.io's Key/Value store to transport messages in Queues
|
||||
|
||||
It uses python-consul for talking to Consul's HTTP API
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Native
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: *Unreviewed*
|
||||
* Supports Fanout: *Unreviewed*
|
||||
* Supports Priority: *Unreviewed*
|
||||
* Supports TTL: *Unreviewed*
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
consul://CONSUL_ADDRESS[:PORT]
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from queue import Empty
|
||||
from time import monotonic
|
||||
|
||||
from kombu.exceptions import ChannelError
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
import consul
|
||||
except ImportError:
|
||||
consul = None
|
||||
|
||||
logger = get_logger('kombu.transport.consul')
|
||||
|
||||
DEFAULT_PORT = 8500
|
||||
DEFAULT_HOST = 'localhost'
|
||||
|
||||
|
||||
class LockError(Exception):
|
||||
"""An error occurred while trying to acquire the lock."""
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Consul Channel class which talks to the Consul Key/Value store."""
|
||||
|
||||
prefix = 'kombu'
|
||||
index = None
|
||||
timeout = '10s'
|
||||
session_ttl = 30
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if consul is None:
|
||||
raise ImportError('Missing python-consul library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
port = self.connection.client.port or self.connection.default_port
|
||||
host = self.connection.client.hostname or DEFAULT_HOST
|
||||
|
||||
logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout)
|
||||
|
||||
self.queues = defaultdict(dict)
|
||||
|
||||
self.client = consul.Consul(host=host, port=int(port))
|
||||
|
||||
def _lock_key(self, queue):
|
||||
return f'{self.prefix}/{queue}.lock'
|
||||
|
||||
def _key_prefix(self, queue):
|
||||
return f'{self.prefix}/{queue}'
|
||||
|
||||
def _get_or_create_session(self, queue):
|
||||
"""Get or create consul session.
|
||||
|
||||
Try to renew the session if it exists, otherwise create a new
|
||||
session in Consul.
|
||||
|
||||
This session is used to acquire a lock inside Consul so that we achieve
|
||||
read-consistency between the nodes.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the Queue.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: The ID of the session.
|
||||
"""
|
||||
try:
|
||||
session_id = self.queues[queue]['session_id']
|
||||
except KeyError:
|
||||
session_id = None
|
||||
return (self._renew_existing_session(session_id)
|
||||
if session_id is not None else self._create_new_session())
|
||||
|
||||
def _renew_existing_session(self, session_id):
|
||||
logger.debug('Trying to renew existing session %s', session_id)
|
||||
session = self.client.session.renew(session_id=session_id)
|
||||
return session.get('ID')
|
||||
|
||||
def _create_new_session(self):
|
||||
logger.debug('Creating session %s with TTL %s',
|
||||
self.lock_name, self.session_ttl)
|
||||
session_id = self.client.session.create(
|
||||
name=self.lock_name, ttl=self.session_ttl)
|
||||
logger.debug('Created session %s with id %s',
|
||||
self.lock_name, session_id)
|
||||
return session_id
|
||||
|
||||
@contextmanager
|
||||
def _queue_lock(self, queue, raising=LockError):
|
||||
"""Try to acquire a lock on the Queue.
|
||||
|
||||
It does so by creating a object called 'lock' which is locked by the
|
||||
current session..
|
||||
|
||||
This way other nodes are not able to write to the lock object which
|
||||
means that they have to wait before the lock is released.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the Queue.
|
||||
raising (Exception): Set custom lock error class.
|
||||
|
||||
Raises
|
||||
------
|
||||
LockError: if the lock cannot be acquired.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: success?
|
||||
"""
|
||||
self._acquire_lock(queue, raising=raising)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._release_lock(queue)
|
||||
|
||||
def _acquire_lock(self, queue, raising=LockError):
|
||||
session_id = self._get_or_create_session(queue)
|
||||
lock_key = self._lock_key(queue)
|
||||
|
||||
logger.debug('Trying to create lock object %s with session %s',
|
||||
lock_key, session_id)
|
||||
|
||||
if self.client.kv.put(key=lock_key,
|
||||
acquire=session_id,
|
||||
value=self.lock_name):
|
||||
self.queues[queue]['session_id'] = session_id
|
||||
return
|
||||
logger.info('Could not acquire lock on key %s', lock_key)
|
||||
raise raising()
|
||||
|
||||
def _release_lock(self, queue):
|
||||
"""Try to release a lock.
|
||||
|
||||
It does so by simply removing the lock key in Consul.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue we want to release
|
||||
the lock from.
|
||||
"""
|
||||
logger.debug('Removing lock key %s', self._lock_key(queue))
|
||||
self.client.kv.delete(key=self._lock_key(queue))
|
||||
|
||||
def _destroy_session(self, queue):
|
||||
"""Destroy a previously created Consul session.
|
||||
|
||||
Will release all locks it still might hold.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the Queue.
|
||||
"""
|
||||
logger.debug('Destroying session %s', self.queues[queue]['session_id'])
|
||||
self.client.session.destroy(self.queues[queue]['session_id'])
|
||||
|
||||
def _new_queue(self, queue, **_):
|
||||
self.queues[queue] = {'session_id': None}
|
||||
return self.client.kv.put(key=self._key_prefix(queue), value=None)
|
||||
|
||||
def _delete(self, queue, *args, **_):
|
||||
self._destroy_session(queue)
|
||||
self.queues.pop(queue, None)
|
||||
self._purge(queue)
|
||||
|
||||
def _put(self, queue, payload, **_):
|
||||
"""Put `message` onto `queue`.
|
||||
|
||||
This simply writes a key to the K/V store of Consul
|
||||
"""
|
||||
key = '{}/msg/{}_{}'.format(
|
||||
self._key_prefix(queue),
|
||||
int(round(monotonic() * 1000)),
|
||||
uuid.uuid4(),
|
||||
)
|
||||
if not self.client.kv.put(key=key, value=dumps(payload), cas=0):
|
||||
raise ChannelError(f'Cannot add key {key!r} to consul')
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
"""Get the first available message from the queue.
|
||||
|
||||
Before it does so it acquires a lock on the Key/Value store so
|
||||
only one node reads at the same time. This is for read consistency
|
||||
"""
|
||||
with self._queue_lock(queue, raising=Empty):
|
||||
key = f'{self._key_prefix(queue)}/msg/'
|
||||
logger.debug('Fetching key %s with index %s', key, self.index)
|
||||
self.index, data = self.client.kv.get(
|
||||
key=key, recurse=True,
|
||||
index=self.index, wait=self.timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
if data is None:
|
||||
raise Empty()
|
||||
|
||||
logger.debug('Removing key %s with modifyindex %s',
|
||||
data[0]['Key'], data[0]['ModifyIndex'])
|
||||
|
||||
self.client.kv.delete(key=data[0]['Key'],
|
||||
cas=data[0]['ModifyIndex'])
|
||||
|
||||
return loads(data[0]['Value'])
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
raise Empty()
|
||||
|
||||
def _purge(self, queue):
|
||||
self._destroy_session(queue)
|
||||
return self.client.kv.delete(
|
||||
key=f'{self._key_prefix(queue)}/msg/',
|
||||
recurse=True,
|
||||
)
|
||||
|
||||
def _size(self, queue):
|
||||
size = 0
|
||||
try:
|
||||
key = f'{self._key_prefix(queue)}/msg/'
|
||||
logger.debug('Fetching key recursively %s with index %s',
|
||||
key, self.index)
|
||||
self.index, data = self.client.kv.get(
|
||||
key=key, recurse=True,
|
||||
index=self.index, wait=self.timeout,
|
||||
)
|
||||
size = len(data)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
logger.debug('Found %s keys under %s with index %s',
|
||||
size, key, self.index)
|
||||
return size
|
||||
|
||||
@cached_property
|
||||
def lock_name(self):
|
||||
return f'{socket.gethostname()}'
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Consul K/V storage Transport for Kombu."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
default_port = DEFAULT_PORT
|
||||
driver_type = 'consul'
|
||||
driver_name = 'consul'
|
||||
|
||||
if consul:
|
||||
connection_errors = (
|
||||
virtual.Transport.connection_errors + (
|
||||
consul.ConsulException, consul.base.ConsulException
|
||||
)
|
||||
)
|
||||
|
||||
channel_errors = (
|
||||
virtual.Transport.channel_errors + (
|
||||
consul.ConsulException, consul.base.ConsulException
|
||||
)
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if consul is None:
|
||||
raise ImportError('Missing python-consul library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def verify_connection(self, connection):
|
||||
port = connection.client.port or self.default_port
|
||||
host = connection.client.hostname or DEFAULT_HOST
|
||||
|
||||
logger.debug('Verify Consul connection to %s:%s', host, port)
|
||||
|
||||
try:
|
||||
client = consul.Consul(host=host, port=int(port))
|
||||
client.agent.self()
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def driver_version(self):
|
||||
return consul.__version__
|
||||
@@ -0,0 +1,300 @@
|
||||
"""Etcd Transport module for Kombu.
|
||||
|
||||
It uses Etcd as a store to transport messages in Queues
|
||||
|
||||
It uses python-etcd for talking to Etcd's HTTP API
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: *Unreviewed*
|
||||
* Supports Topic: *Unreviewed*
|
||||
* Supports Fanout: *Unreviewed*
|
||||
* Supports Priority: *Unreviewed*
|
||||
* Supports TTL: *Unreviewed*
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
'etcd'://SERVER:PORT
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import socket
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from queue import Empty
|
||||
|
||||
from kombu.exceptions import ChannelError
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
import etcd
|
||||
except ImportError:
|
||||
etcd = None
|
||||
|
||||
logger = get_logger('kombu.transport.etcd')
|
||||
|
||||
DEFAULT_PORT = 2379
|
||||
DEFAULT_HOST = 'localhost'
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Etcd Channel class which talks to the Etcd."""
|
||||
|
||||
prefix = 'kombu'
|
||||
index = None
|
||||
timeout = 10
|
||||
session_ttl = 30
|
||||
lock_ttl = 10
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if etcd is None:
|
||||
raise ImportError('Missing python-etcd library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
port = self.connection.client.port or self.connection.default_port
|
||||
host = self.connection.client.hostname or DEFAULT_HOST
|
||||
|
||||
logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout)
|
||||
|
||||
self.queues = defaultdict(dict)
|
||||
|
||||
self.client = etcd.Client(host=host, port=int(port))
|
||||
|
||||
def _key_prefix(self, queue):
|
||||
"""Create and return the `queue` with the proper prefix.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
return f'{self.prefix}/{queue}'
|
||||
|
||||
@contextmanager
|
||||
def _queue_lock(self, queue):
|
||||
"""Try to acquire a lock on the Queue.
|
||||
|
||||
It does so by creating a object called 'lock' which is locked by the
|
||||
current session..
|
||||
|
||||
This way other nodes are not able to write to the lock object which
|
||||
means that they have to wait before the lock is released.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
lock = etcd.Lock(self.client, queue)
|
||||
lock._uuid = self.lock_value
|
||||
logger.debug(f'Acquiring lock {lock.name}')
|
||||
lock.acquire(blocking=True, lock_ttl=self.lock_ttl)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.debug(f'Releasing lock {lock.name}')
|
||||
lock.release()
|
||||
|
||||
def _new_queue(self, queue, **_):
|
||||
"""Create a new `queue` if the `queue` doesn't already exist.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
self.queues[queue] = queue
|
||||
with self._queue_lock(queue):
|
||||
try:
|
||||
return self.client.write(
|
||||
key=self._key_prefix(queue), dir=True, value=None)
|
||||
except etcd.EtcdNotFile:
|
||||
logger.debug(f'Queue "{queue}" already exists')
|
||||
return self.client.read(key=self._key_prefix(queue))
|
||||
|
||||
def _has_queue(self, queue, **kwargs):
|
||||
"""Verify that queue exists.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: Should return :const:`True` if the queue exists
|
||||
or :const:`False` otherwise.
|
||||
"""
|
||||
try:
|
||||
self.client.read(self._key_prefix(queue))
|
||||
return True
|
||||
except etcd.EtcdKeyNotFound:
|
||||
return False
|
||||
|
||||
def _delete(self, queue, *args, **_):
|
||||
"""Delete a `queue`.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
self.queues.pop(queue, None)
|
||||
self._purge(queue)
|
||||
|
||||
def _put(self, queue, payload, **_):
|
||||
"""Put `message` onto `queue`.
|
||||
|
||||
This simply writes a key to the Etcd store
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
payload (dict): Message data which will be dumped to etcd.
|
||||
"""
|
||||
with self._queue_lock(queue):
|
||||
key = self._key_prefix(queue)
|
||||
if not self.client.write(
|
||||
key=key,
|
||||
value=dumps(payload),
|
||||
append=True):
|
||||
raise ChannelError(f'Cannot add key {key!r} to etcd')
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
"""Get the first available message from the queue.
|
||||
|
||||
Before it does so it acquires a lock on the store so
|
||||
only one node reads at the same time. This is for read consistency
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
timeout (int): Optional seconds to wait for a response.
|
||||
"""
|
||||
with self._queue_lock(queue):
|
||||
key = self._key_prefix(queue)
|
||||
logger.debug('Fetching key %s with index %s', key, self.index)
|
||||
|
||||
try:
|
||||
result = self.client.read(
|
||||
key=key, recursive=True,
|
||||
index=self.index, timeout=self.timeout)
|
||||
|
||||
if result is None:
|
||||
raise Empty()
|
||||
|
||||
item = result._children[-1]
|
||||
logger.debug('Removing key {}'.format(item['key']))
|
||||
|
||||
msg_content = loads(item['value'])
|
||||
self.client.delete(key=item['key'])
|
||||
return msg_content
|
||||
except (TypeError, IndexError, etcd.EtcdException) as error:
|
||||
logger.debug(f'_get failed: {type(error)}:{error}')
|
||||
|
||||
raise Empty()
|
||||
|
||||
def _purge(self, queue):
|
||||
"""Remove all `message`s from a `queue`.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
with self._queue_lock(queue):
|
||||
key = self._key_prefix(queue)
|
||||
logger.debug(f'Purging queue at key {key}')
|
||||
return self.client.delete(key=key, recursive=True)
|
||||
|
||||
def _size(self, queue):
|
||||
"""Return the size of the `queue`.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
with self._queue_lock(queue):
|
||||
size = 0
|
||||
try:
|
||||
key = self._key_prefix(queue)
|
||||
logger.debug('Fetching key recursively %s with index %s',
|
||||
key, self.index)
|
||||
result = self.client.read(
|
||||
key=key, recursive=True,
|
||||
index=self.index)
|
||||
size = len(result._children)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
logger.debug('Found %s keys under %s with index %s',
|
||||
size, key, self.index)
|
||||
return size
|
||||
|
||||
@cached_property
|
||||
def lock_value(self):
|
||||
return f'{socket.gethostname()}.{os.getpid()}'
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Etcd storage Transport for Kombu."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
default_port = DEFAULT_PORT
|
||||
driver_type = 'etcd'
|
||||
driver_name = 'python-etcd'
|
||||
polling_interval = 3
|
||||
|
||||
implements = virtual.Transport.implements.extend(
|
||||
exchange_type=frozenset(['direct']))
|
||||
|
||||
if etcd:
|
||||
connection_errors = (
|
||||
virtual.Transport.connection_errors + (etcd.EtcdException, )
|
||||
)
|
||||
|
||||
channel_errors = (
|
||||
virtual.Transport.channel_errors + (etcd.EtcdException, )
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Create a new instance of etcd.Transport."""
|
||||
if etcd is None:
|
||||
raise ImportError('Missing python-etcd library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def verify_connection(self, connection):
|
||||
"""Verify the connection works."""
|
||||
port = connection.client.port or self.default_port
|
||||
host = connection.client.hostname or DEFAULT_HOST
|
||||
|
||||
logger.debug('Verify Etcd connection to %s:%s', host, port)
|
||||
|
||||
try:
|
||||
etcd.Client(host=host, port=int(port))
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def driver_version(self):
|
||||
"""Return the version of the etcd library.
|
||||
|
||||
.. note::
|
||||
python-etcd has no __version__. This is a workaround.
|
||||
"""
|
||||
try:
|
||||
import pip.commands.freeze
|
||||
for x in pip.commands.freeze.freeze():
|
||||
if x.startswith('python-etcd'):
|
||||
return x.split('==')[1]
|
||||
except (ImportError, IndexError):
|
||||
logger.warning('Unable to find the python-etcd version.')
|
||||
return 'Unknown'
|
||||
@@ -0,0 +1,352 @@
|
||||
"""File-system Transport module for kombu.
|
||||
|
||||
Transport using the file-system as the message store. Messages written to the
|
||||
queue are stored in `data_folder_in` directory and
|
||||
messages read from the queue are read from `data_folder_out` directory. Both
|
||||
directories must be created manually. Simple example:
|
||||
|
||||
* Producer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import kombu
|
||||
|
||||
conn = kombu.Connection(
|
||||
'filesystem://', transport_options={
|
||||
'data_folder_in': 'data_in', 'data_folder_out': 'data_out'
|
||||
}
|
||||
)
|
||||
conn.connect()
|
||||
|
||||
test_queue = kombu.Queue('test', routing_key='test')
|
||||
|
||||
with conn as conn:
|
||||
with conn.default_channel as channel:
|
||||
producer = kombu.Producer(channel)
|
||||
producer.publish(
|
||||
{'hello': 'world'},
|
||||
retry=True,
|
||||
exchange=test_queue.exchange,
|
||||
routing_key=test_queue.routing_key,
|
||||
declare=[test_queue],
|
||||
serializer='pickle'
|
||||
)
|
||||
|
||||
* Consumer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import kombu
|
||||
|
||||
conn = kombu.Connection(
|
||||
'filesystem://', transport_options={
|
||||
'data_folder_in': 'data_out', 'data_folder_out': 'data_in'
|
||||
}
|
||||
)
|
||||
conn.connect()
|
||||
|
||||
def callback(body, message):
|
||||
print(body, message)
|
||||
message.ack()
|
||||
|
||||
test_queue = kombu.Queue('test', routing_key='test')
|
||||
|
||||
with conn as conn:
|
||||
with conn.default_channel as channel:
|
||||
consumer = kombu.Consumer(
|
||||
conn, [test_queue], accept=['pickle']
|
||||
)
|
||||
consumer.register_callback(callback)
|
||||
with consumer:
|
||||
conn.drain_events(timeout=1)
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: Yes
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
Connection string is in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
filesystem://
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
* ``data_folder_in`` - directory where are messages stored when written
|
||||
to queue.
|
||||
* ``data_folder_out`` - directory from which are messages read when read from
|
||||
queue.
|
||||
* ``store_processed`` - if set to True, all processed messages are backed up to
|
||||
``processed_folder``.
|
||||
* ``processed_folder`` - directory where are backed up processed files.
|
||||
* ``control_folder`` - directory where are exchange-queue table stored.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from time import monotonic
|
||||
|
||||
from kombu.exceptions import ChannelError
|
||||
from kombu.transport import virtual
|
||||
from kombu.utils.encoding import bytes_to_str, str_to_bytes
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
VERSION = (1, 0, 0)
|
||||
__version__ = '.'.join(map(str, VERSION))
|
||||
|
||||
# needs win32all to work on Windows
|
||||
if os.name == 'nt':
|
||||
|
||||
import pywintypes
|
||||
import win32con
|
||||
import win32file
|
||||
|
||||
LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK
|
||||
# 0 is the default
|
||||
LOCK_SH = 0
|
||||
LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY
|
||||
__overlapped = pywintypes.OVERLAPPED()
|
||||
|
||||
def lock(file, flags):
|
||||
"""Create file lock."""
|
||||
hfile = win32file._get_osfhandle(file.fileno())
|
||||
win32file.LockFileEx(hfile, flags, 0, 0xffff0000, __overlapped)
|
||||
|
||||
def unlock(file):
|
||||
"""Remove file lock."""
|
||||
hfile = win32file._get_osfhandle(file.fileno())
|
||||
win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped)
|
||||
|
||||
|
||||
elif os.name == 'posix':
|
||||
|
||||
import fcntl
|
||||
from fcntl import LOCK_EX, LOCK_SH
|
||||
|
||||
def lock(file, flags):
|
||||
"""Create file lock."""
|
||||
fcntl.flock(file.fileno(), flags)
|
||||
|
||||
def unlock(file):
|
||||
"""Remove file lock."""
|
||||
fcntl.flock(file.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Filesystem plugin only defined for NT and POSIX platforms')
|
||||
|
||||
|
||||
exchange_queue_t = namedtuple("exchange_queue_t",
|
||||
["routing_key", "pattern", "queue"])
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Filesystem Channel."""
|
||||
|
||||
supports_fanout = True
|
||||
|
||||
def get_table(self, exchange):
|
||||
file = self.control_folder / f"{exchange}.exchange"
|
||||
try:
|
||||
f_obj = file.open("r")
|
||||
try:
|
||||
lock(f_obj, LOCK_SH)
|
||||
exchange_table = loads(bytes_to_str(f_obj.read()))
|
||||
return [exchange_queue_t(*q) for q in exchange_table]
|
||||
finally:
|
||||
unlock(f_obj)
|
||||
f_obj.close()
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
except OSError:
|
||||
raise ChannelError(f"Cannot open {file}")
|
||||
|
||||
def _queue_bind(self, exchange, routing_key, pattern, queue):
|
||||
file = self.control_folder / f"{exchange}.exchange"
|
||||
self.control_folder.mkdir(exist_ok=True)
|
||||
queue_val = exchange_queue_t(routing_key or "", pattern or "",
|
||||
queue or "")
|
||||
try:
|
||||
if file.exists():
|
||||
f_obj = file.open("rb+", buffering=0)
|
||||
lock(f_obj, LOCK_EX)
|
||||
exchange_table = loads(bytes_to_str(f_obj.read()))
|
||||
queues = [exchange_queue_t(*q) for q in exchange_table]
|
||||
if queue_val not in queues:
|
||||
queues.insert(0, queue_val)
|
||||
f_obj.seek(0)
|
||||
f_obj.write(str_to_bytes(dumps(queues)))
|
||||
else:
|
||||
f_obj = file.open("wb", buffering=0)
|
||||
lock(f_obj, LOCK_EX)
|
||||
queues = [queue_val]
|
||||
f_obj.write(str_to_bytes(dumps(queues)))
|
||||
finally:
|
||||
unlock(f_obj)
|
||||
f_obj.close()
|
||||
|
||||
def _put_fanout(self, exchange, payload, routing_key, **kwargs):
|
||||
for q in self.get_table(exchange):
|
||||
self._put(q.queue, payload, **kwargs)
|
||||
|
||||
def _put(self, queue, payload, **kwargs):
|
||||
"""Put `message` onto `queue`."""
|
||||
filename = '{}_{}.{}.msg'.format(int(round(monotonic() * 1000)),
|
||||
uuid.uuid4(), queue)
|
||||
filename = os.path.join(self.data_folder_out, filename)
|
||||
|
||||
try:
|
||||
f = open(filename, 'wb', buffering=0)
|
||||
lock(f, LOCK_EX)
|
||||
f.write(str_to_bytes(dumps(payload)))
|
||||
except OSError:
|
||||
raise ChannelError(
|
||||
f'Cannot add file {filename!r} to directory')
|
||||
finally:
|
||||
unlock(f)
|
||||
f.close()
|
||||
|
||||
def _get(self, queue):
|
||||
"""Get next message from `queue`."""
|
||||
queue_find = '.' + queue + '.msg'
|
||||
folder = os.listdir(self.data_folder_in)
|
||||
folder = sorted(folder)
|
||||
while len(folder) > 0:
|
||||
filename = folder.pop(0)
|
||||
|
||||
# only handle message for the requested queue
|
||||
if filename.find(queue_find) < 0:
|
||||
continue
|
||||
|
||||
if self.store_processed:
|
||||
processed_folder = self.processed_folder
|
||||
else:
|
||||
processed_folder = tempfile.gettempdir()
|
||||
|
||||
try:
|
||||
# move the file to the tmp/processed folder
|
||||
shutil.move(os.path.join(self.data_folder_in, filename),
|
||||
processed_folder)
|
||||
except OSError:
|
||||
# file could be locked, or removed in meantime so ignore
|
||||
continue
|
||||
|
||||
filename = os.path.join(processed_folder, filename)
|
||||
try:
|
||||
f = open(filename, 'rb')
|
||||
payload = f.read()
|
||||
f.close()
|
||||
if not self.store_processed:
|
||||
os.remove(filename)
|
||||
except OSError:
|
||||
raise ChannelError(
|
||||
f'Cannot read file {filename!r} from queue.')
|
||||
|
||||
return loads(bytes_to_str(payload))
|
||||
|
||||
raise Empty()
|
||||
|
||||
def _purge(self, queue):
|
||||
"""Remove all messages from `queue`."""
|
||||
count = 0
|
||||
queue_find = '.' + queue + '.msg'
|
||||
|
||||
folder = os.listdir(self.data_folder_in)
|
||||
while len(folder) > 0:
|
||||
filename = folder.pop()
|
||||
try:
|
||||
# only purge messages for the requested queue
|
||||
if filename.find(queue_find) < 0:
|
||||
continue
|
||||
|
||||
filename = os.path.join(self.data_folder_in, filename)
|
||||
os.remove(filename)
|
||||
|
||||
count += 1
|
||||
|
||||
except OSError:
|
||||
# we simply ignore its existence, as it was probably
|
||||
# processed by another worker
|
||||
pass
|
||||
|
||||
return count
|
||||
|
||||
def _size(self, queue):
|
||||
"""Return the number of messages in `queue` as an :class:`int`."""
|
||||
count = 0
|
||||
|
||||
queue_find = f'.{queue}.msg'
|
||||
folder = os.listdir(self.data_folder_in)
|
||||
while len(folder) > 0:
|
||||
filename = folder.pop()
|
||||
|
||||
# only handle message for the requested queue
|
||||
if filename.find(queue_find) < 0:
|
||||
continue
|
||||
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@property
|
||||
def transport_options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def data_folder_in(self):
|
||||
return self.transport_options.get('data_folder_in', 'data_in')
|
||||
|
||||
@cached_property
|
||||
def data_folder_out(self):
|
||||
return self.transport_options.get('data_folder_out', 'data_out')
|
||||
|
||||
@cached_property
|
||||
def store_processed(self):
|
||||
return self.transport_options.get('store_processed', False)
|
||||
|
||||
@cached_property
|
||||
def processed_folder(self):
|
||||
return self.transport_options.get('processed_folder', 'processed')
|
||||
|
||||
@property
|
||||
def control_folder(self):
|
||||
return Path(self.transport_options.get('control_folder', 'control'))
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Filesystem Transport."""
|
||||
|
||||
implements = virtual.Transport.implements.extend(
|
||||
asynchronous=False,
|
||||
exchange_type=frozenset(['direct', 'topic', 'fanout'])
|
||||
)
|
||||
|
||||
Channel = Channel
|
||||
# filesystem backend state is global.
|
||||
global_state = virtual.BrokerState()
|
||||
default_port = 0
|
||||
driver_type = 'filesystem'
|
||||
driver_name = 'filesystem'
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
super().__init__(client, **kwargs)
|
||||
self.state = self.global_state
|
||||
|
||||
def driver_version(self):
|
||||
return 'N/A'
|
||||
@@ -0,0 +1,810 @@
|
||||
"""GCP Pub/Sub transport module for kombu.
|
||||
|
||||
More information about GCP Pub/Sub:
|
||||
https://cloud.google.com/pubsub
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: No
|
||||
* Supports Fanout: Yes
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following formats:
|
||||
|
||||
.. code-block::
|
||||
|
||||
gcpubsub://projects/project-name
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
* ``queue_name_prefix``: (str) Prefix for queue names.
|
||||
* ``ack_deadline_seconds``: (int) The maximum time after receiving a message
|
||||
and acknowledging it before pub/sub redelivers the message.
|
||||
* ``expiration_seconds``: (int) Subscriptions without any subscriber
|
||||
activity or changes made to their properties are removed after this period.
|
||||
Examples of subscriber activities include open connections,
|
||||
active pulls, or successful pushes.
|
||||
* ``wait_time_seconds``: (int) The maximum time to wait for new messages.
|
||||
Defaults to 10.
|
||||
* ``retry_timeout_seconds``: (int) The maximum time to wait before retrying.
|
||||
* ``bulk_max_messages``: (int) The maximum number of messages to pull in bulk.
|
||||
Defaults to 32.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import datetime
|
||||
import string
|
||||
import threading
|
||||
from concurrent.futures import (FIRST_COMPLETED, Future, ThreadPoolExecutor,
|
||||
wait)
|
||||
from contextlib import suppress
|
||||
from os import getpid
|
||||
from queue import Empty
|
||||
from threading import Lock
|
||||
from time import monotonic, sleep
|
||||
from uuid import NAMESPACE_OID, uuid3
|
||||
|
||||
from _socket import gethostname
|
||||
from _socket import timeout as socket_timeout
|
||||
from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded,
|
||||
PermissionDenied)
|
||||
from google.api_core.retry import Retry
|
||||
from google.cloud import monitoring_v3
|
||||
from google.cloud.monitoring_v3 import query
|
||||
from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
|
||||
from google.cloud.pubsub_v1 import exceptions as pubsub_exceptions
|
||||
from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions
|
||||
from google.cloud.pubsub_v1.subscriber import \
|
||||
exceptions as subscriber_exceptions
|
||||
from google.pubsub_v1 import gapic_version as package_version
|
||||
|
||||
from kombu.entity import TRANSIENT_DELIVERY_MODE
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils.encoding import bytes_to_str, safe_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
logger = get_logger('kombu.transport.gcpubsub')
|
||||
|
||||
# dots are replaced by dash, all other punctuation replaced by underscore.
|
||||
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
|
||||
CHARS_REPLACE_TABLE = {
|
||||
ord('.'): ord('-'),
|
||||
**{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE},
|
||||
}
|
||||
|
||||
|
||||
class UnackedIds:
|
||||
"""Threadsafe list of ack_ids."""
|
||||
|
||||
def __init__(self):
|
||||
self._list = []
|
||||
self._lock = Lock()
|
||||
|
||||
def append(self, val):
|
||||
# append is atomic
|
||||
self._list.append(val)
|
||||
|
||||
def extend(self, vals: list):
|
||||
# extend is atomic
|
||||
self._list.extend(vals)
|
||||
|
||||
def pop(self, index=-1):
|
||||
with self._lock:
|
||||
return self._list.pop(index)
|
||||
|
||||
def remove(self, val):
|
||||
with self._lock, suppress(ValueError):
|
||||
self._list.remove(val)
|
||||
|
||||
def __len__(self):
|
||||
with self._lock:
|
||||
return len(self._list)
|
||||
|
||||
def __getitem__(self, item):
|
||||
# getitem is atomic
|
||||
return self._list[item]
|
||||
|
||||
|
||||
class AtomicCounter:
|
||||
"""Threadsafe counter.
|
||||
|
||||
Returns the value after inc/dec operations.
|
||||
"""
|
||||
|
||||
def __init__(self, initial=0):
|
||||
self._value = initial
|
||||
self._lock = Lock()
|
||||
|
||||
def inc(self, n=1):
|
||||
with self._lock:
|
||||
self._value += n
|
||||
return self._value
|
||||
|
||||
def dec(self, n=1):
|
||||
with self._lock:
|
||||
self._value -= n
|
||||
return self._value
|
||||
|
||||
def get(self):
|
||||
with self._lock:
|
||||
return self._value
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class QueueDescriptor:
|
||||
"""Pub/Sub queue descriptor."""
|
||||
|
||||
name: str
|
||||
topic_path: str # projects/{project_id}/topics/{topic_id}
|
||||
subscription_id: str
|
||||
subscription_path: str # projects/{project_id}/subscriptions/{subscription_id}
|
||||
unacked_ids: UnackedIds = dataclasses.field(default_factory=UnackedIds)
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""GCP Pub/Sub channel."""
|
||||
|
||||
supports_fanout = True
|
||||
do_restore = False # pub/sub does that for us
|
||||
default_wait_time_seconds = 10
|
||||
default_ack_deadline_seconds = 240
|
||||
default_expiration_seconds = 86400
|
||||
default_retry_timeout_seconds = 300
|
||||
default_bulk_max_messages = 32
|
||||
|
||||
_min_ack_deadline = 10
|
||||
_fanout_exchanges = set()
|
||||
_unacked_extender: threading.Thread = None
|
||||
_stop_extender = threading.Event()
|
||||
_n_channels = AtomicCounter()
|
||||
_queue_cache: dict[str, QueueDescriptor] = {}
|
||||
_tmp_subscriptions: set[str] = set()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.pool = ThreadPoolExecutor()
|
||||
logger.info('new GCP pub/sub channel: %s', self.conninfo.hostname)
|
||||
|
||||
self.project_id = Transport.parse_uri(self.conninfo.hostname)
|
||||
if self._n_channels.inc() == 1:
|
||||
Channel._unacked_extender = threading.Thread(
|
||||
target=self._extend_unacked_deadline,
|
||||
daemon=True,
|
||||
)
|
||||
self._stop_extender.clear()
|
||||
Channel._unacked_extender.start()
|
||||
|
||||
def entity_name(self, name: str, table=CHARS_REPLACE_TABLE) -> str:
|
||||
"""Format AMQP queue name into a valid Pub/Sub queue name."""
|
||||
if not name.startswith(self.queue_name_prefix):
|
||||
name = self.queue_name_prefix + name
|
||||
|
||||
return str(safe_str(name)).translate(table)
|
||||
|
||||
def _queue_bind(self, exchange, routing_key, pattern, queue):
|
||||
exchange_type = self.typeof(exchange).type
|
||||
queue = self.entity_name(queue)
|
||||
logger.debug(
|
||||
'binding queue: %s to %s exchange: %s with routing_key: %s',
|
||||
queue,
|
||||
exchange_type,
|
||||
exchange,
|
||||
routing_key,
|
||||
)
|
||||
|
||||
filter_args = {}
|
||||
if exchange_type == 'direct':
|
||||
# Direct exchange is implemented as a single subscription
|
||||
# E.g. for exchange 'test_direct':
|
||||
# -topic:'test_direct'
|
||||
# -bound queue:'direct1':
|
||||
# -subscription: direct1' on topic 'test_direct'
|
||||
# -filter:routing_key'
|
||||
filter_args = {
|
||||
'filter': f'attributes.routing_key="{routing_key}"'
|
||||
}
|
||||
subscription_path = self.subscriber.subscription_path(
|
||||
self.project_id, queue
|
||||
)
|
||||
message_retention_duration = self.expiration_seconds
|
||||
elif exchange_type == 'fanout':
|
||||
# Fanout exchange is implemented as a separate subscription.
|
||||
# E.g. for exchange 'test_fanout':
|
||||
# -topic:'test_fanout'
|
||||
# -bound queue 'fanout1':
|
||||
# -subscription:'fanout1-uuid' on topic 'test_fanout'
|
||||
# -bound queue 'fanout2':
|
||||
# -subscription:'fanout2-uuid' on topic 'test_fanout'
|
||||
uid = f'{uuid3(NAMESPACE_OID, f"{gethostname()}.{getpid()}")}'
|
||||
uniq_sub_name = f'{queue}-{uid}'
|
||||
subscription_path = self.subscriber.subscription_path(
|
||||
self.project_id, uniq_sub_name
|
||||
)
|
||||
self._tmp_subscriptions.add(subscription_path)
|
||||
self._fanout_exchanges.add(exchange)
|
||||
message_retention_duration = 600
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'exchange type {exchange_type} not implemented'
|
||||
)
|
||||
exchange_topic = self._create_topic(
|
||||
self.project_id, exchange, message_retention_duration
|
||||
)
|
||||
self._create_subscription(
|
||||
topic_path=exchange_topic,
|
||||
subscription_path=subscription_path,
|
||||
filter_args=filter_args,
|
||||
msg_retention=message_retention_duration,
|
||||
)
|
||||
qdesc = QueueDescriptor(
|
||||
name=queue,
|
||||
topic_path=exchange_topic,
|
||||
subscription_id=queue,
|
||||
subscription_path=subscription_path,
|
||||
)
|
||||
self._queue_cache[queue] = qdesc
|
||||
|
||||
def _create_topic(
|
||||
self,
|
||||
project_id: str,
|
||||
topic_id: str,
|
||||
message_retention_duration: int = None,
|
||||
) -> str:
|
||||
topic_path = self.publisher.topic_path(project_id, topic_id)
|
||||
if self._is_topic_exists(topic_path):
|
||||
# topic creation takes a while, so skip if possible
|
||||
logger.debug('topic: %s exists', topic_path)
|
||||
return topic_path
|
||||
try:
|
||||
logger.debug('creating topic: %s', topic_path)
|
||||
request = {'name': topic_path}
|
||||
if message_retention_duration:
|
||||
request[
|
||||
'message_retention_duration'
|
||||
] = f'{message_retention_duration}s'
|
||||
self.publisher.create_topic(request=request)
|
||||
except AlreadyExists:
|
||||
pass
|
||||
|
||||
return topic_path
|
||||
|
||||
def _is_topic_exists(self, topic_path: str) -> bool:
|
||||
topics = self.publisher.list_topics(
|
||||
request={"project": f'projects/{self.project_id}'}
|
||||
)
|
||||
for t in topics:
|
||||
if t.name == topic_path:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _create_subscription(
|
||||
self,
|
||||
project_id: str = None,
|
||||
topic_id: str = None,
|
||||
topic_path: str = None,
|
||||
subscription_path: str = None,
|
||||
filter_args=None,
|
||||
msg_retention: int = None,
|
||||
) -> str:
|
||||
subscription_path = (
|
||||
subscription_path
|
||||
or self.subscriber.subscription_path(self.project_id, topic_id)
|
||||
)
|
||||
topic_path = topic_path or self.publisher.topic_path(
|
||||
project_id, topic_id
|
||||
)
|
||||
try:
|
||||
logger.debug(
|
||||
'creating subscription: %s, topic: %s, filter: %s',
|
||||
subscription_path,
|
||||
topic_path,
|
||||
filter_args,
|
||||
)
|
||||
msg_retention = msg_retention or self.expiration_seconds
|
||||
self.subscriber.create_subscription(
|
||||
request={
|
||||
"name": subscription_path,
|
||||
"topic": topic_path,
|
||||
'ack_deadline_seconds': self.ack_deadline_seconds,
|
||||
'expiration_policy': {
|
||||
'ttl': f'{self.expiration_seconds}s'
|
||||
},
|
||||
'message_retention_duration': f'{msg_retention}s',
|
||||
**(filter_args or {}),
|
||||
}
|
||||
)
|
||||
except AlreadyExists:
|
||||
pass
|
||||
return subscription_path
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
"""Delete a queue by name."""
|
||||
queue = self.entity_name(queue)
|
||||
logger.info('deleting queue: %s', queue)
|
||||
qdesc = self._queue_cache.get(queue)
|
||||
if not qdesc:
|
||||
return
|
||||
self.subscriber.delete_subscription(
|
||||
request={"subscription": qdesc.subscription_path}
|
||||
)
|
||||
self._queue_cache.pop(queue, None)
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
"""Put a message onto the queue."""
|
||||
queue = self.entity_name(queue)
|
||||
qdesc = self._queue_cache[queue]
|
||||
routing_key = self._get_routing_key(message)
|
||||
logger.debug(
|
||||
'putting message to queue: %s, topic: %s, routing_key: %s',
|
||||
queue,
|
||||
qdesc.topic_path,
|
||||
routing_key,
|
||||
)
|
||||
encoded_message = dumps(message)
|
||||
self.publisher.publish(
|
||||
qdesc.topic_path,
|
||||
encoded_message.encode("utf-8"),
|
||||
routing_key=routing_key,
|
||||
)
|
||||
|
||||
def _put_fanout(self, exchange, message, routing_key, **kwargs):
|
||||
"""Put a message onto fanout exchange."""
|
||||
self._lookup(exchange, routing_key)
|
||||
topic_path = self.publisher.topic_path(self.project_id, exchange)
|
||||
logger.debug(
|
||||
'putting msg to fanout exchange: %s, topic: %s',
|
||||
exchange,
|
||||
topic_path,
|
||||
)
|
||||
encoded_message = dumps(message)
|
||||
self.publisher.publish(
|
||||
topic_path,
|
||||
encoded_message.encode("utf-8"),
|
||||
retry=Retry(deadline=self.retry_timeout_seconds),
|
||||
)
|
||||
|
||||
def _get(self, queue: str, timeout: float = None):
|
||||
"""Retrieves a single message from a queue."""
|
||||
queue = self.entity_name(queue)
|
||||
qdesc = self._queue_cache[queue]
|
||||
try:
|
||||
response = self.subscriber.pull(
|
||||
request={
|
||||
'subscription': qdesc.subscription_path,
|
||||
'max_messages': 1,
|
||||
},
|
||||
retry=Retry(deadline=self.retry_timeout_seconds),
|
||||
timeout=timeout or self.wait_time_seconds,
|
||||
)
|
||||
except DeadlineExceeded:
|
||||
raise Empty()
|
||||
|
||||
if len(response.received_messages) == 0:
|
||||
raise Empty()
|
||||
|
||||
message = response.received_messages[0]
|
||||
ack_id = message.ack_id
|
||||
payload = loads(message.message.data)
|
||||
delivery_info = payload['properties']['delivery_info']
|
||||
logger.debug(
|
||||
'queue:%s got message, ack_id: %s, payload: %s',
|
||||
queue,
|
||||
ack_id,
|
||||
payload['properties'],
|
||||
)
|
||||
if self._is_auto_ack(payload['properties']):
|
||||
logger.debug('auto acking message ack_id: %s', ack_id)
|
||||
self._do_ack([ack_id], qdesc.subscription_path)
|
||||
else:
|
||||
delivery_info['gcpubsub_message'] = {
|
||||
'queue': queue,
|
||||
'ack_id': ack_id,
|
||||
'message_id': message.message.message_id,
|
||||
'subscription_path': qdesc.subscription_path,
|
||||
}
|
||||
qdesc.unacked_ids.append(ack_id)
|
||||
|
||||
return payload
|
||||
|
||||
def _is_auto_ack(self, payload_properties: dict):
|
||||
exchange = payload_properties['delivery_info']['exchange']
|
||||
delivery_mode = payload_properties['delivery_mode']
|
||||
return (
|
||||
delivery_mode == TRANSIENT_DELIVERY_MODE
|
||||
or exchange in self._fanout_exchanges
|
||||
)
|
||||
|
||||
def _get_bulk(self, queue: str, timeout: float):
|
||||
"""Retrieves bulk of messages from a queue."""
|
||||
prefixed_queue = self.entity_name(queue)
|
||||
qdesc = self._queue_cache[prefixed_queue]
|
||||
max_messages = self._get_max_messages_estimate()
|
||||
if not max_messages:
|
||||
raise Empty()
|
||||
try:
|
||||
response = self.subscriber.pull(
|
||||
request={
|
||||
'subscription': qdesc.subscription_path,
|
||||
'max_messages': max_messages,
|
||||
},
|
||||
retry=Retry(deadline=self.retry_timeout_seconds),
|
||||
timeout=timeout or self.wait_time_seconds,
|
||||
)
|
||||
except DeadlineExceeded:
|
||||
raise Empty()
|
||||
|
||||
received_messages = response.received_messages
|
||||
if len(received_messages) == 0:
|
||||
raise Empty()
|
||||
|
||||
auto_ack_ids = []
|
||||
ret_payloads = []
|
||||
logger.debug(
|
||||
'batching %d messages from queue: %s',
|
||||
len(received_messages),
|
||||
prefixed_queue,
|
||||
)
|
||||
for message in received_messages:
|
||||
ack_id = message.ack_id
|
||||
payload = loads(bytes_to_str(message.message.data))
|
||||
delivery_info = payload['properties']['delivery_info']
|
||||
delivery_info['gcpubsub_message'] = {
|
||||
'queue': prefixed_queue,
|
||||
'ack_id': ack_id,
|
||||
'message_id': message.message.message_id,
|
||||
'subscription_path': qdesc.subscription_path,
|
||||
}
|
||||
if self._is_auto_ack(payload['properties']):
|
||||
auto_ack_ids.append(ack_id)
|
||||
else:
|
||||
qdesc.unacked_ids.append(ack_id)
|
||||
ret_payloads.append(payload)
|
||||
if auto_ack_ids:
|
||||
logger.debug('auto acking ack_ids: %s', auto_ack_ids)
|
||||
self._do_ack(auto_ack_ids, qdesc.subscription_path)
|
||||
|
||||
return queue, ret_payloads
|
||||
|
||||
def _get_max_messages_estimate(self) -> int:
|
||||
max_allowed = self.qos.can_consume_max_estimate()
|
||||
max_if_unlimited = self.bulk_max_messages
|
||||
return max_if_unlimited if max_allowed is None else max_allowed
|
||||
|
||||
def _lookup(self, exchange, routing_key, default=None):
|
||||
exchange_info = self.state.exchanges.get(exchange, {})
|
||||
if not exchange_info:
|
||||
return super()._lookup(exchange, routing_key, default)
|
||||
ret = self.typeof(exchange).lookup(
|
||||
self.get_table(exchange),
|
||||
exchange,
|
||||
routing_key,
|
||||
default,
|
||||
)
|
||||
if ret:
|
||||
return ret
|
||||
logger.debug(
|
||||
'no queues bound to exchange: %s, binding on the fly',
|
||||
exchange,
|
||||
)
|
||||
self.queue_bind(exchange, exchange, routing_key)
|
||||
return [exchange]
|
||||
|
||||
def _size(self, queue: str) -> int:
|
||||
"""Return the number of messages in a queue.
|
||||
|
||||
This is a *rough* estimation, as Pub/Sub doesn't provide
|
||||
an exact API.
|
||||
"""
|
||||
queue = self.entity_name(queue)
|
||||
if queue not in self._queue_cache:
|
||||
return 0
|
||||
qdesc = self._queue_cache[queue]
|
||||
result = query.Query(
|
||||
self.monitor,
|
||||
self.project_id,
|
||||
'pubsub.googleapis.com/subscription/num_undelivered_messages',
|
||||
end_time=datetime.datetime.now(),
|
||||
minutes=1,
|
||||
).select_resources(subscription_id=qdesc.subscription_id)
|
||||
|
||||
# monitoring API requires the caller to have the monitoring.viewer
|
||||
# role. Since we can live without the exact number of messages
|
||||
# in the queue, we can ignore the exception and allow users to
|
||||
# use the transport without this role.
|
||||
with suppress(PermissionDenied):
|
||||
return sum(
|
||||
content.points[0].value.int64_value for content in result
|
||||
)
|
||||
return -1
|
||||
|
||||
def basic_ack(self, delivery_tag, multiple=False):
|
||||
"""Acknowledge one message."""
|
||||
if multiple:
|
||||
raise NotImplementedError('multiple acks not implemented')
|
||||
|
||||
delivery_info = self.qos.get(delivery_tag).delivery_info
|
||||
pubsub_message = delivery_info['gcpubsub_message']
|
||||
ack_id = pubsub_message['ack_id']
|
||||
queue = pubsub_message['queue']
|
||||
logger.debug('ack message. queue: %s ack_id: %s', queue, ack_id)
|
||||
subscription_path = pubsub_message['subscription_path']
|
||||
self._do_ack([ack_id], subscription_path)
|
||||
qdesc = self._queue_cache[queue]
|
||||
qdesc.unacked_ids.remove(ack_id)
|
||||
super().basic_ack(delivery_tag)
|
||||
|
||||
def _do_ack(self, ack_ids: list[str], subscription_path: str):
|
||||
self.subscriber.acknowledge(
|
||||
request={"subscription": subscription_path, "ack_ids": ack_ids},
|
||||
retry=Retry(deadline=self.retry_timeout_seconds),
|
||||
)
|
||||
|
||||
def _purge(self, queue: str):
|
||||
"""Delete all current messages in a queue."""
|
||||
queue = self.entity_name(queue)
|
||||
qdesc = self._queue_cache.get(queue)
|
||||
if not qdesc:
|
||||
return
|
||||
|
||||
n = self._size(queue)
|
||||
self.subscriber.seek(
|
||||
request={
|
||||
"subscription": qdesc.subscription_path,
|
||||
"time": datetime.datetime.now(),
|
||||
}
|
||||
)
|
||||
return n
|
||||
|
||||
def _extend_unacked_deadline(self):
|
||||
thread_id = threading.get_native_id()
|
||||
logger.info(
|
||||
'unacked deadline extension thread: [%s] started',
|
||||
thread_id,
|
||||
)
|
||||
min_deadline_sleep = self._min_ack_deadline / 2
|
||||
sleep_time = max(min_deadline_sleep, self.ack_deadline_seconds / 4)
|
||||
while not self._stop_extender.wait(sleep_time):
|
||||
for qdesc in self._queue_cache.values():
|
||||
if len(qdesc.unacked_ids) == 0:
|
||||
logger.debug(
|
||||
'thread [%s]: no unacked messages for %s',
|
||||
thread_id,
|
||||
qdesc.subscription_path,
|
||||
)
|
||||
continue
|
||||
logger.debug(
|
||||
'thread [%s]: extend ack deadline for %s: %d msgs [%s]',
|
||||
thread_id,
|
||||
qdesc.subscription_path,
|
||||
len(qdesc.unacked_ids),
|
||||
list(qdesc.unacked_ids),
|
||||
)
|
||||
self.subscriber.modify_ack_deadline(
|
||||
request={
|
||||
"subscription": qdesc.subscription_path,
|
||||
"ack_ids": list(qdesc.unacked_ids),
|
||||
"ack_deadline_seconds": self.ack_deadline_seconds,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
'unacked deadline extension thread [%s] stopped', thread_id
|
||||
)
|
||||
|
||||
def after_reply_message_received(self, queue: str):
|
||||
queue = self.entity_name(queue)
|
||||
sub = self.subscriber.subscription_path(self.project_id, queue)
|
||||
logger.debug(
|
||||
'after_reply_message_received: queue: %s, sub: %s', queue, sub
|
||||
)
|
||||
self._tmp_subscriptions.add(sub)
|
||||
|
||||
@cached_property
|
||||
def subscriber(self):
|
||||
return SubscriberClient()
|
||||
|
||||
@cached_property
|
||||
def publisher(self):
|
||||
return PublisherClient()
|
||||
|
||||
@cached_property
|
||||
def monitor(self):
|
||||
return monitoring_v3.MetricServiceClient()
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
||||
@property
|
||||
def transport_options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def wait_time_seconds(self):
|
||||
return self.transport_options.get(
|
||||
'wait_time_seconds', self.default_wait_time_seconds
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def retry_timeout_seconds(self):
|
||||
return self.transport_options.get(
|
||||
'retry_timeout_seconds', self.default_retry_timeout_seconds
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def ack_deadline_seconds(self):
|
||||
return self.transport_options.get(
|
||||
'ack_deadline_seconds', self.default_ack_deadline_seconds
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def queue_name_prefix(self):
|
||||
return self.transport_options.get('queue_name_prefix', 'kombu-')
|
||||
|
||||
@cached_property
|
||||
def expiration_seconds(self):
|
||||
return self.transport_options.get(
|
||||
'expiration_seconds', self.default_expiration_seconds
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def bulk_max_messages(self):
|
||||
return self.transport_options.get(
|
||||
'bulk_max_messages', self.default_bulk_max_messages
|
||||
)
|
||||
|
||||
def close(self):
|
||||
"""Close the channel."""
|
||||
logger.debug('closing channel')
|
||||
while self._tmp_subscriptions:
|
||||
sub = self._tmp_subscriptions.pop()
|
||||
with suppress(Exception):
|
||||
logger.debug('deleting subscription: %s', sub)
|
||||
self.subscriber.delete_subscription(
|
||||
request={"subscription": sub}
|
||||
)
|
||||
if not self._n_channels.dec():
|
||||
self._stop_extender.set()
|
||||
Channel._unacked_extender.join()
|
||||
super().close()
|
||||
|
||||
@staticmethod
|
||||
def _get_routing_key(message):
|
||||
routing_key = (
|
||||
message['properties']
|
||||
.get('delivery_info', {})
|
||||
.get('routing_key', '')
|
||||
)
|
||||
return routing_key
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""GCP Pub/Sub transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
can_parse_url = True
|
||||
polling_interval = 0.1
|
||||
connection_errors = virtual.Transport.connection_errors + (
|
||||
pubsub_exceptions.TimeoutError,
|
||||
)
|
||||
channel_errors = (
|
||||
virtual.Transport.channel_errors
|
||||
+ (
|
||||
publisher_exceptions.FlowControlLimitError,
|
||||
publisher_exceptions.MessageTooLargeError,
|
||||
publisher_exceptions.PublishError,
|
||||
publisher_exceptions.TimeoutError,
|
||||
publisher_exceptions.PublishToPausedOrderingKeyException,
|
||||
)
|
||||
+ (subscriber_exceptions.AcknowledgeError,)
|
||||
)
|
||||
|
||||
driver_type = 'gcpubsub'
|
||||
driver_name = 'pubsub_v1'
|
||||
|
||||
implements = virtual.Transport.implements.extend(
|
||||
exchange_type=frozenset(['direct', 'fanout']),
|
||||
)
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
super().__init__(client, **kwargs)
|
||||
self._pool = ThreadPoolExecutor()
|
||||
self._get_bulk_future_to_queue: dict[Future, str] = dict()
|
||||
|
||||
def driver_version(self):
|
||||
return package_version.__version__
|
||||
|
||||
@staticmethod
|
||||
def parse_uri(uri: str) -> str:
|
||||
# URL like:
|
||||
# gcpubsub://projects/project-name
|
||||
|
||||
project = uri.split('gcpubsub://projects/')[1]
|
||||
return project.strip('/')
|
||||
|
||||
@classmethod
|
||||
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
|
||||
return uri or 'gcpubsub://'
|
||||
|
||||
def drain_events(self, connection, timeout=None):
|
||||
time_start = monotonic()
|
||||
polling_interval = self.polling_interval
|
||||
if timeout and polling_interval and polling_interval > timeout:
|
||||
polling_interval = timeout
|
||||
while 1:
|
||||
try:
|
||||
self._drain_from_active_queues(timeout=timeout)
|
||||
except Empty:
|
||||
if timeout and monotonic() - time_start >= timeout:
|
||||
raise socket_timeout()
|
||||
if polling_interval:
|
||||
sleep(polling_interval)
|
||||
else:
|
||||
break
|
||||
|
||||
def _drain_from_active_queues(self, timeout):
|
||||
# cleanup empty requests from prev run
|
||||
self._rm_empty_bulk_requests()
|
||||
|
||||
# submit new requests for all active queues
|
||||
# longer timeout means less frequent polling
|
||||
# and more messages in a single bulk
|
||||
self._submit_get_bulk_requests(timeout=10)
|
||||
|
||||
done, _ = wait(
|
||||
self._get_bulk_future_to_queue,
|
||||
timeout=timeout,
|
||||
return_when=FIRST_COMPLETED,
|
||||
)
|
||||
empty = {f for f in done if f.exception()}
|
||||
done -= empty
|
||||
for f in empty:
|
||||
self._get_bulk_future_to_queue.pop(f, None)
|
||||
|
||||
if not done:
|
||||
raise Empty()
|
||||
|
||||
logger.debug('got %d done get_bulk tasks', len(done))
|
||||
for f in done:
|
||||
queue, payloads = f.result()
|
||||
for payload in payloads:
|
||||
logger.debug('consuming message from queue: %s', queue)
|
||||
if queue not in self._callbacks:
|
||||
logger.warning(
|
||||
'Message for queue %s without consumers', queue
|
||||
)
|
||||
continue
|
||||
self._deliver(payload, queue)
|
||||
self._get_bulk_future_to_queue.pop(f, None)
|
||||
|
||||
def _rm_empty_bulk_requests(self):
|
||||
empty = {
|
||||
f
|
||||
for f in self._get_bulk_future_to_queue
|
||||
if f.done() and f.exception()
|
||||
}
|
||||
for f in empty:
|
||||
self._get_bulk_future_to_queue.pop(f, None)
|
||||
|
||||
def _submit_get_bulk_requests(self, timeout):
|
||||
queues_with_submitted_get_bulk = set(
|
||||
self._get_bulk_future_to_queue.values()
|
||||
)
|
||||
|
||||
for channel in self.channels:
|
||||
for queue in channel._active_queues:
|
||||
if queue in queues_with_submitted_get_bulk:
|
||||
continue
|
||||
future = self._pool.submit(channel._get_bulk, queue, timeout)
|
||||
self._get_bulk_future_to_queue[future] = queue
|
||||
@@ -0,0 +1,190 @@
|
||||
"""`librabbitmq`_ transport.
|
||||
|
||||
.. _`librabbitmq`: https://pypi.org/project/librabbitmq/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import socket
|
||||
import warnings
|
||||
|
||||
import librabbitmq as amqp
|
||||
from librabbitmq import ChannelError, ConnectionError
|
||||
|
||||
from kombu.utils.amq_manager import get_manager
|
||||
from kombu.utils.text import version_string_as_tuple
|
||||
|
||||
from . import base
|
||||
from .base import to_rabbitmq_queue_arguments
|
||||
|
||||
W_VERSION = """
|
||||
librabbitmq version too old to detect RabbitMQ version information
|
||||
so make sure you are using librabbitmq 1.5 when using rabbitmq > 3.3
|
||||
"""
|
||||
DEFAULT_PORT = 5672
|
||||
DEFAULT_SSL_PORT = 5671
|
||||
|
||||
NO_SSL_ERROR = """\
|
||||
ssl not supported by librabbitmq, please use pyamqp:// or stunnel\
|
||||
"""
|
||||
|
||||
|
||||
class Message(base.Message):
|
||||
"""AMQP Message (librabbitmq)."""
|
||||
|
||||
def __init__(self, channel, props, info, body):
|
||||
super().__init__(
|
||||
channel=channel,
|
||||
body=body,
|
||||
delivery_info=info,
|
||||
properties=props,
|
||||
delivery_tag=info.get('delivery_tag'),
|
||||
content_type=props.get('content_type'),
|
||||
content_encoding=props.get('content_encoding'),
|
||||
headers=props.get('headers'))
|
||||
|
||||
|
||||
class Channel(amqp.Channel, base.StdChannel):
|
||||
"""AMQP Channel (librabbitmq)."""
|
||||
|
||||
Message = Message
|
||||
|
||||
def prepare_message(self, body, priority=None,
|
||||
content_type=None, content_encoding=None,
|
||||
headers=None, properties=None):
|
||||
"""Encapsulate data into a AMQP message."""
|
||||
properties = properties if properties is not None else {}
|
||||
properties.update({'content_type': content_type,
|
||||
'content_encoding': content_encoding,
|
||||
'headers': headers})
|
||||
# Don't include priority if it's not an integer.
|
||||
# If that's the case librabbitmq will fail
|
||||
# and raise an exception.
|
||||
if priority is not None:
|
||||
properties['priority'] = priority
|
||||
return body, properties
|
||||
|
||||
def prepare_queue_arguments(self, arguments, **kwargs):
|
||||
arguments = to_rabbitmq_queue_arguments(arguments, **kwargs)
|
||||
return {k.encode('utf8'): v for k, v in arguments.items()}
|
||||
|
||||
|
||||
class Connection(amqp.Connection):
|
||||
"""AMQP Connection (librabbitmq)."""
|
||||
|
||||
Channel = Channel
|
||||
Message = Message
|
||||
|
||||
|
||||
class Transport(base.Transport):
|
||||
"""AMQP Transport (librabbitmq)."""
|
||||
|
||||
Connection = Connection
|
||||
|
||||
default_port = DEFAULT_PORT
|
||||
default_ssl_port = DEFAULT_SSL_PORT
|
||||
|
||||
connection_errors = (
|
||||
base.Transport.connection_errors + (
|
||||
ConnectionError, socket.error, IOError, OSError)
|
||||
)
|
||||
channel_errors = (
|
||||
base.Transport.channel_errors + (ChannelError,)
|
||||
)
|
||||
driver_type = 'amqp'
|
||||
driver_name = 'librabbitmq'
|
||||
|
||||
implements = base.Transport.implements.extend(
|
||||
asynchronous=True,
|
||||
heartbeats=False,
|
||||
)
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
self.client = client
|
||||
self.default_port = kwargs.get('default_port') or self.default_port
|
||||
self.default_ssl_port = (kwargs.get('default_ssl_port') or
|
||||
self.default_ssl_port)
|
||||
self.__reader = None
|
||||
|
||||
def driver_version(self):
|
||||
return amqp.__version__
|
||||
|
||||
def create_channel(self, connection):
|
||||
return connection.channel()
|
||||
|
||||
def drain_events(self, connection, **kwargs):
|
||||
return connection.drain_events(**kwargs)
|
||||
|
||||
def establish_connection(self):
|
||||
"""Establish connection to the AMQP broker."""
|
||||
conninfo = self.client
|
||||
for name, default_value in self.default_connection_params.items():
|
||||
if not getattr(conninfo, name, None):
|
||||
setattr(conninfo, name, default_value)
|
||||
if conninfo.ssl:
|
||||
raise NotImplementedError(NO_SSL_ERROR)
|
||||
opts = dict({
|
||||
'host': conninfo.host,
|
||||
'userid': conninfo.userid,
|
||||
'password': conninfo.password,
|
||||
'virtual_host': conninfo.virtual_host,
|
||||
'login_method': conninfo.login_method,
|
||||
'insist': conninfo.insist,
|
||||
'ssl': conninfo.ssl,
|
||||
'connect_timeout': conninfo.connect_timeout,
|
||||
}, **conninfo.transport_options or {})
|
||||
conn = self.Connection(**opts)
|
||||
conn.client = self.client
|
||||
self.client.drain_events = conn.drain_events
|
||||
return conn
|
||||
|
||||
def close_connection(self, connection):
|
||||
"""Close the AMQP broker connection."""
|
||||
self.client.drain_events = None
|
||||
connection.close()
|
||||
|
||||
def _collect(self, connection):
|
||||
if connection is not None:
|
||||
for channel in connection.channels.values():
|
||||
channel.connection = None
|
||||
try:
|
||||
os.close(connection.fileno())
|
||||
except (OSError, ValueError):
|
||||
pass
|
||||
connection.channels.clear()
|
||||
connection.callbacks.clear()
|
||||
self.client.drain_events = None
|
||||
self.client = None
|
||||
|
||||
def verify_connection(self, connection):
|
||||
return connection.connected
|
||||
|
||||
def register_with_event_loop(self, connection, loop):
|
||||
loop.add_reader(
|
||||
connection.fileno(), self.on_readable, connection, loop,
|
||||
)
|
||||
|
||||
def get_manager(self, *args, **kwargs):
|
||||
return get_manager(self.client, *args, **kwargs)
|
||||
|
||||
def qos_semantics_matches_spec(self, connection):
|
||||
try:
|
||||
props = connection.server_properties
|
||||
except AttributeError:
|
||||
warnings.warn(UserWarning(W_VERSION))
|
||||
else:
|
||||
if props.get('product') == 'RabbitMQ':
|
||||
return version_string_as_tuple(props['version']) < (3, 3)
|
||||
return True
|
||||
|
||||
@property
|
||||
def default_connection_params(self):
|
||||
return {
|
||||
'userid': 'guest',
|
||||
'password': 'guest',
|
||||
'port': (self.default_ssl_port if self.client.ssl
|
||||
else self.default_port),
|
||||
'hostname': 'localhost',
|
||||
'login_method': 'PLAIN',
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
"""In-memory transport module for Kombu.
|
||||
|
||||
Simple transport using memory for storing messages.
|
||||
Messages can be passed only between threads.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: No
|
||||
* Supports Priority: No
|
||||
* Supports TTL: Yes
|
||||
|
||||
Connection String
|
||||
=================
|
||||
Connection string is in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
memory://
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from queue import Queue
|
||||
|
||||
from . import base, virtual
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""In-memory Channel."""
|
||||
|
||||
events = defaultdict(set)
|
||||
queues = {}
|
||||
do_restore = False
|
||||
supports_fanout = True
|
||||
|
||||
def _has_queue(self, queue, **kwargs):
|
||||
return queue in self.queues
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
if queue not in self.queues:
|
||||
self.queues[queue] = Queue()
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
return self._queue_for(queue).get(block=False)
|
||||
|
||||
def _queue_for(self, queue):
|
||||
if queue not in self.queues:
|
||||
self.queues[queue] = Queue()
|
||||
return self.queues[queue]
|
||||
|
||||
def _queue_bind(self, *args):
|
||||
pass
|
||||
|
||||
def _put_fanout(self, exchange, message, routing_key=None, **kwargs):
|
||||
for queue in self._lookup(exchange, routing_key):
|
||||
self._queue_for(queue).put(message)
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
self._queue_for(queue).put(message)
|
||||
|
||||
def _size(self, queue):
|
||||
return self._queue_for(queue).qsize()
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
self.queues.pop(queue, None)
|
||||
|
||||
def _purge(self, queue):
|
||||
q = self._queue_for(queue)
|
||||
size = q.qsize()
|
||||
q.queue.clear()
|
||||
return size
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
for queue in self.queues.values():
|
||||
queue.empty()
|
||||
self.queues = {}
|
||||
|
||||
def after_reply_message_received(self, queue):
|
||||
pass
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""In-memory Transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
#: memory backend state is global.
|
||||
global_state = virtual.BrokerState()
|
||||
|
||||
implements = base.Transport.implements
|
||||
|
||||
driver_type = 'memory'
|
||||
driver_name = 'memory'
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
super().__init__(client, **kwargs)
|
||||
self.state = self.global_state
|
||||
|
||||
def driver_version(self):
|
||||
return 'N/A'
|
||||
@@ -0,0 +1,534 @@
|
||||
# copyright: (c) 2010 - 2013 by Flavio Percoco Premoli.
|
||||
# license: BSD, see LICENSE for more details.
|
||||
|
||||
"""MongoDB transport module for kombu.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: Yes
|
||||
* Supports Priority: Yes
|
||||
* Supports TTL: Yes
|
||||
|
||||
Connection String
|
||||
=================
|
||||
*Unreviewed*
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
|
||||
* ``connect_timeout``,
|
||||
* ``ssl``,
|
||||
* ``ttl``,
|
||||
* ``capped_queue_size``,
|
||||
* ``default_hostname``,
|
||||
* ``default_port``,
|
||||
* ``default_database``,
|
||||
* ``messages_collection``,
|
||||
* ``routing_collection``,
|
||||
* ``broadcast_collection``,
|
||||
* ``queues_collection``,
|
||||
* ``calc_queue_size``,
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from queue import Empty
|
||||
|
||||
import pymongo
|
||||
from pymongo import MongoClient, errors, uri_parser
|
||||
from pymongo.cursor import CursorType
|
||||
|
||||
from kombu.exceptions import VersionMismatch
|
||||
from kombu.utils.compat import _detect_environment
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.url import maybe_sanitize_url
|
||||
|
||||
from . import virtual
|
||||
from .base import to_rabbitmq_queue_arguments
|
||||
|
||||
E_SERVER_VERSION = """\
|
||||
Kombu requires MongoDB version 1.3+ (server is {0})\
|
||||
"""
|
||||
|
||||
E_NO_TTL_INDEXES = """\
|
||||
Kombu requires MongoDB version 2.2+ (server is {0}) for TTL indexes support\
|
||||
"""
|
||||
|
||||
|
||||
class BroadcastCursor:
|
||||
"""Cursor for broadcast queues."""
|
||||
|
||||
def __init__(self, cursor):
|
||||
self._cursor = cursor
|
||||
self._offset = 0
|
||||
self.purge(rewind=False)
|
||||
|
||||
def get_size(self):
|
||||
return self._cursor.collection.count_documents({}) - self._offset
|
||||
|
||||
def close(self):
|
||||
self._cursor.close()
|
||||
|
||||
def purge(self, rewind=True):
|
||||
if rewind:
|
||||
self._cursor.rewind()
|
||||
|
||||
# Fast-forward the cursor past old events
|
||||
self._offset = self._cursor.collection.count_documents({})
|
||||
self._cursor = self._cursor.skip(self._offset)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
while True:
|
||||
try:
|
||||
msg = next(self._cursor)
|
||||
except pymongo.errors.OperationFailure as exc:
|
||||
# In some cases tailed cursor can become invalid
|
||||
# and have to be reinitalized
|
||||
if 'not valid at server' in str(exc):
|
||||
self.purge()
|
||||
|
||||
continue
|
||||
|
||||
raise
|
||||
else:
|
||||
break
|
||||
|
||||
self._offset += 1
|
||||
|
||||
return msg
|
||||
next = __next__
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""MongoDB Channel."""
|
||||
|
||||
supports_fanout = True
|
||||
|
||||
# Mutable container. Shared by all class instances
|
||||
_fanout_queues = {}
|
||||
|
||||
# Options
|
||||
ssl = False
|
||||
ttl = False
|
||||
connect_timeout = None
|
||||
capped_queue_size = 100000
|
||||
calc_queue_size = True
|
||||
|
||||
default_hostname = '127.0.0.1'
|
||||
default_port = 27017
|
||||
default_database = 'kombu_default'
|
||||
|
||||
messages_collection = 'messages'
|
||||
routing_collection = 'messages.routing'
|
||||
broadcast_collection = 'messages.broadcast'
|
||||
queues_collection = 'messages.queues'
|
||||
|
||||
from_transport_options = (virtual.Channel.from_transport_options + (
|
||||
'connect_timeout', 'ssl', 'ttl', 'capped_queue_size',
|
||||
'default_hostname', 'default_port', 'default_database',
|
||||
'messages_collection', 'routing_collection',
|
||||
'broadcast_collection', 'queues_collection',
|
||||
'calc_queue_size',
|
||||
))
|
||||
|
||||
def __init__(self, *vargs, **kwargs):
|
||||
super().__init__(*vargs, **kwargs)
|
||||
|
||||
self._broadcast_cursors = {}
|
||||
|
||||
# Evaluate connection
|
||||
self.client
|
||||
|
||||
# AbstractChannel/Channel interface implementation
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
if self.ttl:
|
||||
self.queues.update_one(
|
||||
{'_id': queue},
|
||||
{
|
||||
'$set': {
|
||||
'_id': queue,
|
||||
'options': kwargs,
|
||||
'expire_at': self._get_queue_expire(
|
||||
kwargs, 'x-expires'
|
||||
),
|
||||
},
|
||||
},
|
||||
upsert=True)
|
||||
|
||||
def _get(self, queue):
|
||||
if queue in self._fanout_queues:
|
||||
try:
|
||||
msg = next(self._get_broadcast_cursor(queue))
|
||||
except StopIteration:
|
||||
msg = None
|
||||
else:
|
||||
msg = self.messages.find_one_and_delete(
|
||||
{'queue': queue},
|
||||
sort=[('priority', pymongo.ASCENDING)],
|
||||
)
|
||||
|
||||
if self.ttl:
|
||||
self._update_queues_expire(queue)
|
||||
|
||||
if msg is None:
|
||||
raise Empty()
|
||||
|
||||
return loads(bytes_to_str(msg['payload']))
|
||||
|
||||
def _size(self, queue):
|
||||
# Do not calculate actual queue size if requested
|
||||
# for performance considerations
|
||||
if not self.calc_queue_size:
|
||||
return super()._size(queue)
|
||||
|
||||
if queue in self._fanout_queues:
|
||||
return self._get_broadcast_cursor(queue).get_size()
|
||||
|
||||
return self.messages.count_documents({'queue': queue})
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
data = {
|
||||
'payload': dumps(message),
|
||||
'queue': queue,
|
||||
'priority': self._get_message_priority(message, reverse=True)
|
||||
}
|
||||
|
||||
if self.ttl:
|
||||
data['expire_at'] = self._get_queue_expire(queue, 'x-message-ttl')
|
||||
msg_expire = self._get_message_expire(message)
|
||||
if msg_expire is not None and (
|
||||
data['expire_at'] is None or msg_expire < data['expire_at']
|
||||
):
|
||||
data['expire_at'] = msg_expire
|
||||
|
||||
self.messages.insert_one(data)
|
||||
|
||||
def _put_fanout(self, exchange, message, routing_key, **kwargs):
|
||||
self.broadcast.insert_one({'payload': dumps(message),
|
||||
'queue': exchange})
|
||||
|
||||
def _purge(self, queue):
|
||||
size = self._size(queue)
|
||||
|
||||
if queue in self._fanout_queues:
|
||||
self._get_broadcast_cursor(queue).purge()
|
||||
else:
|
||||
self.messages.delete_many({'queue': queue})
|
||||
|
||||
return size
|
||||
|
||||
def get_table(self, exchange):
|
||||
localRoutes = frozenset(self.state.exchanges[exchange]['table'])
|
||||
brokerRoutes = self.routing.find(
|
||||
{'exchange': exchange}
|
||||
)
|
||||
|
||||
return localRoutes | frozenset(
|
||||
(r['routing_key'], r['pattern'], r['queue'])
|
||||
for r in brokerRoutes
|
||||
)
|
||||
|
||||
def _queue_bind(self, exchange, routing_key, pattern, queue):
|
||||
if self.typeof(exchange).type == 'fanout':
|
||||
self._create_broadcast_cursor(
|
||||
exchange, routing_key, pattern, queue)
|
||||
self._fanout_queues[queue] = exchange
|
||||
|
||||
lookup = {
|
||||
'exchange': exchange,
|
||||
'queue': queue,
|
||||
'routing_key': routing_key,
|
||||
'pattern': pattern,
|
||||
}
|
||||
|
||||
data = lookup.copy()
|
||||
|
||||
if self.ttl:
|
||||
data['expire_at'] = self._get_queue_expire(queue, 'x-expires')
|
||||
|
||||
self.routing.update_one(lookup, {'$set': data}, upsert=True)
|
||||
|
||||
def queue_delete(self, queue, **kwargs):
|
||||
self.routing.delete_many({'queue': queue})
|
||||
|
||||
if self.ttl:
|
||||
self.queues.delete_one({'_id': queue})
|
||||
|
||||
super().queue_delete(queue, **kwargs)
|
||||
|
||||
if queue in self._fanout_queues:
|
||||
try:
|
||||
cursor = self._broadcast_cursors.pop(queue)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
cursor.close()
|
||||
|
||||
self._fanout_queues.pop(queue)
|
||||
|
||||
# Implementation details
|
||||
|
||||
def _parse_uri(self, scheme='mongodb://'):
|
||||
# See mongodb uri documentation:
|
||||
# https://docs.mongodb.org/manual/reference/connection-string/
|
||||
client = self.connection.client
|
||||
hostname = client.hostname
|
||||
|
||||
if hostname.startswith('srv://'):
|
||||
scheme = 'mongodb+srv://'
|
||||
hostname = 'mongodb+' + hostname
|
||||
|
||||
if not hostname.startswith(scheme):
|
||||
hostname = scheme + hostname
|
||||
|
||||
if not hostname[len(scheme):]:
|
||||
hostname += self.default_hostname
|
||||
|
||||
if client.userid and '@' not in hostname:
|
||||
head, tail = hostname.split('://')
|
||||
|
||||
credentials = client.userid
|
||||
if client.password:
|
||||
credentials += ':' + client.password
|
||||
|
||||
hostname = head + '://' + credentials + '@' + tail
|
||||
|
||||
port = client.port if client.port else self.default_port
|
||||
|
||||
# We disable validating and normalization parameters here,
|
||||
# because pymongo will validate and normalize parameters later in __init__ of MongoClient
|
||||
parsed = uri_parser.parse_uri(hostname, port, validate=False)
|
||||
|
||||
dbname = parsed['database'] or client.virtual_host
|
||||
|
||||
if dbname in ('/', None):
|
||||
dbname = self.default_database
|
||||
|
||||
options = {
|
||||
'auto_start_request': True,
|
||||
'ssl': self.ssl,
|
||||
'connectTimeoutMS': (int(self.connect_timeout * 1000)
|
||||
if self.connect_timeout else None),
|
||||
}
|
||||
options.update(parsed['options'])
|
||||
options = self._prepare_client_options(options)
|
||||
|
||||
if 'tls' in options:
|
||||
options.pop('ssl')
|
||||
|
||||
return hostname, dbname, options
|
||||
|
||||
def _prepare_client_options(self, options):
|
||||
if pymongo.version_tuple >= (3,):
|
||||
options.pop('auto_start_request', None)
|
||||
if isinstance(options.get('readpreference'), int):
|
||||
modes = pymongo.read_preferences._MONGOS_MODES
|
||||
options['readpreference'] = modes[options['readpreference']]
|
||||
return options
|
||||
|
||||
def prepare_queue_arguments(self, arguments, **kwargs):
|
||||
return to_rabbitmq_queue_arguments(arguments, **kwargs)
|
||||
|
||||
def _open(self, scheme='mongodb://'):
|
||||
hostname, dbname, conf = self._parse_uri(scheme=scheme)
|
||||
|
||||
conf['host'] = hostname
|
||||
|
||||
env = _detect_environment()
|
||||
if env == 'gevent':
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
elif env == 'eventlet':
|
||||
from eventlet import monkey_patch
|
||||
monkey_patch()
|
||||
|
||||
mongoconn = MongoClient(**conf)
|
||||
database = mongoconn[dbname]
|
||||
|
||||
version_str = mongoconn.server_info()['version']
|
||||
version_str = version_str.split('-')[0]
|
||||
version = tuple(map(int, version_str.split('.')))
|
||||
|
||||
if version < (1, 3):
|
||||
raise VersionMismatch(E_SERVER_VERSION.format(version_str))
|
||||
elif self.ttl and version < (2, 2):
|
||||
raise VersionMismatch(E_NO_TTL_INDEXES.format(version_str))
|
||||
|
||||
return database
|
||||
|
||||
def _create_broadcast(self, database):
|
||||
"""Create capped collection for broadcast messages."""
|
||||
if self.broadcast_collection in database.list_collection_names():
|
||||
return
|
||||
|
||||
database.create_collection(self.broadcast_collection,
|
||||
size=self.capped_queue_size,
|
||||
capped=True)
|
||||
|
||||
def _ensure_indexes(self, database):
|
||||
"""Ensure indexes on collections."""
|
||||
messages = database[self.messages_collection]
|
||||
messages.create_index(
|
||||
[('queue', 1), ('priority', 1), ('_id', 1)], background=True,
|
||||
)
|
||||
|
||||
database[self.broadcast_collection].create_index([('queue', 1)])
|
||||
|
||||
routing = database[self.routing_collection]
|
||||
routing.create_index([('queue', 1), ('exchange', 1)])
|
||||
|
||||
if self.ttl:
|
||||
messages.create_index([('expire_at', 1)], expireAfterSeconds=0)
|
||||
routing.create_index([('expire_at', 1)], expireAfterSeconds=0)
|
||||
|
||||
database[self.queues_collection].create_index(
|
||||
[('expire_at', 1)], expireAfterSeconds=0)
|
||||
|
||||
def _create_client(self):
|
||||
"""Actually creates connection."""
|
||||
database = self._open()
|
||||
self._create_broadcast(database)
|
||||
self._ensure_indexes(database)
|
||||
|
||||
return database
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
return self._create_client()
|
||||
|
||||
@cached_property
|
||||
def messages(self):
|
||||
return self.client[self.messages_collection]
|
||||
|
||||
@cached_property
|
||||
def routing(self):
|
||||
return self.client[self.routing_collection]
|
||||
|
||||
@cached_property
|
||||
def broadcast(self):
|
||||
return self.client[self.broadcast_collection]
|
||||
|
||||
@cached_property
|
||||
def queues(self):
|
||||
return self.client[self.queues_collection]
|
||||
|
||||
def _get_broadcast_cursor(self, queue):
|
||||
try:
|
||||
return self._broadcast_cursors[queue]
|
||||
except KeyError:
|
||||
# Cursor may be absent when Channel created more than once.
|
||||
# _fanout_queues is a class-level mutable attribute so it's
|
||||
# shared over all Channel instances.
|
||||
return self._create_broadcast_cursor(
|
||||
self._fanout_queues[queue], None, None, queue,
|
||||
)
|
||||
|
||||
def _create_broadcast_cursor(self, exchange, routing_key, pattern, queue):
|
||||
if pymongo.version_tuple >= (3, ):
|
||||
query = {
|
||||
'filter': {'queue': exchange},
|
||||
'cursor_type': CursorType.TAILABLE,
|
||||
}
|
||||
else:
|
||||
query = {
|
||||
'query': {'queue': exchange},
|
||||
'tailable': True,
|
||||
}
|
||||
|
||||
cursor = self.broadcast.find(**query)
|
||||
ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor)
|
||||
return ret
|
||||
|
||||
def _get_message_expire(self, message):
|
||||
value = message.get('properties', {}).get('expiration')
|
||||
if value is not None:
|
||||
return self.get_now() + datetime.timedelta(milliseconds=int(value))
|
||||
|
||||
def _get_queue_expire(self, queue, argument):
|
||||
"""Get expiration header named `argument` of queue definition.
|
||||
|
||||
Note:
|
||||
----
|
||||
`queue` must be either queue name or options itself.
|
||||
"""
|
||||
if isinstance(queue, str):
|
||||
doc = self.queues.find_one({'_id': queue})
|
||||
|
||||
if not doc:
|
||||
return
|
||||
|
||||
data = doc['options']
|
||||
else:
|
||||
data = queue
|
||||
|
||||
try:
|
||||
value = data['arguments'][argument]
|
||||
except (KeyError, TypeError):
|
||||
return
|
||||
|
||||
return self.get_now() + datetime.timedelta(milliseconds=value)
|
||||
|
||||
def _update_queues_expire(self, queue):
|
||||
"""Update expiration field on queues documents."""
|
||||
expire_at = self._get_queue_expire(queue, 'x-expires')
|
||||
|
||||
if not expire_at:
|
||||
return
|
||||
|
||||
self.routing.update_many(
|
||||
{'queue': queue}, {'$set': {'expire_at': expire_at}})
|
||||
self.queues.update_many(
|
||||
{'_id': queue}, {'$set': {'expire_at': expire_at}})
|
||||
|
||||
def get_now(self):
|
||||
"""Return current time in UTC."""
|
||||
return datetime.datetime.utcnow()
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""MongoDB Transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
can_parse_url = True
|
||||
polling_interval = 1
|
||||
default_port = Channel.default_port
|
||||
connection_errors = (
|
||||
virtual.Transport.connection_errors + (errors.ConnectionFailure,)
|
||||
)
|
||||
channel_errors = (
|
||||
virtual.Transport.channel_errors + (
|
||||
errors.ConnectionFailure,
|
||||
errors.OperationFailure)
|
||||
)
|
||||
driver_type = 'mongodb'
|
||||
driver_name = 'pymongo'
|
||||
|
||||
implements = virtual.Transport.implements.extend(
|
||||
exchange_type=frozenset(['direct', 'topic', 'fanout']),
|
||||
)
|
||||
|
||||
def driver_version(self):
|
||||
return pymongo.version
|
||||
|
||||
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
|
||||
if not uri:
|
||||
return 'mongodb://'
|
||||
if include_password:
|
||||
return uri
|
||||
|
||||
if ',' not in uri:
|
||||
return maybe_sanitize_url(uri)
|
||||
|
||||
uri1, remainder = uri.split(',', 1)
|
||||
return ','.join([maybe_sanitize_url(uri1), remainder])
|
||||
@@ -0,0 +1,134 @@
|
||||
"""Native Delayed Delivery API.
|
||||
|
||||
Only relevant for RabbitMQ.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from kombu import Connection, Exchange, Queue, binding
|
||||
from kombu.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
MAX_NUMBER_OF_BITS_TO_USE = 28
|
||||
MAX_LEVEL = MAX_NUMBER_OF_BITS_TO_USE - 1
|
||||
CELERY_DELAYED_DELIVERY_EXCHANGE = "celery_delayed_delivery"
|
||||
|
||||
|
||||
def level_name(level: int) -> str:
|
||||
"""Generates the delayed queue/exchange name based on the level."""
|
||||
if level < 0:
|
||||
raise ValueError("level must be a non-negative number")
|
||||
|
||||
return f"celery_delayed_{level}"
|
||||
|
||||
|
||||
def declare_native_delayed_delivery_exchanges_and_queues(connection: Connection, queue_type: str) -> None:
|
||||
"""Declares all native delayed delivery exchanges and queues."""
|
||||
if queue_type != "classic" and queue_type != "quorum":
|
||||
raise ValueError("queue_type must be either classic or quorum")
|
||||
|
||||
channel = connection.channel()
|
||||
|
||||
routing_key: str = "1.#"
|
||||
|
||||
for level in range(27, -1, - 1):
|
||||
current_level = level_name(level)
|
||||
next_level = level_name(level - 1) if level > 0 else None
|
||||
|
||||
delayed_exchange: Exchange = Exchange(
|
||||
current_level, type="topic").bind(channel)
|
||||
delayed_exchange.declare()
|
||||
|
||||
queue_arguments = {
|
||||
"x-queue-type": queue_type,
|
||||
"x-overflow": "reject-publish",
|
||||
"x-message-ttl": pow(2, level) * 1000,
|
||||
"x-dead-letter-exchange": next_level if level > 0 else CELERY_DELAYED_DELIVERY_EXCHANGE,
|
||||
}
|
||||
|
||||
if queue_type == 'quorum':
|
||||
queue_arguments["x-dead-letter-strategy"] = "at-least-once"
|
||||
|
||||
delayed_queue: Queue = Queue(
|
||||
current_level,
|
||||
queue_arguments=queue_arguments
|
||||
).bind(channel)
|
||||
delayed_queue.declare()
|
||||
delayed_queue.bind_to(current_level, routing_key)
|
||||
|
||||
routing_key = "*." + routing_key
|
||||
|
||||
routing_key = "0.#"
|
||||
for level in range(27, 0, - 1):
|
||||
current_level = level_name(level)
|
||||
next_level = level_name(level - 1) if level > 0 else None
|
||||
|
||||
next_level_exchange: Exchange = Exchange(
|
||||
next_level, type="topic").bind(channel)
|
||||
|
||||
next_level_exchange.bind_to(current_level, routing_key)
|
||||
|
||||
routing_key = "*." + routing_key
|
||||
|
||||
delivery_exchange: Exchange = Exchange(
|
||||
CELERY_DELAYED_DELIVERY_EXCHANGE, type="topic").bind(channel)
|
||||
delivery_exchange.declare()
|
||||
delivery_exchange.bind_to(level_name(0), routing_key)
|
||||
|
||||
|
||||
def bind_queue_to_native_delayed_delivery_exchange(connection: Connection, queue: Queue) -> None:
|
||||
"""Bind a queue to the native delayed delivery exchange.
|
||||
|
||||
When a message arrives at the delivery exchange, it must be forwarded to
|
||||
the original exchange and queue. To accomplish this, the function retrieves
|
||||
the exchange or binding objects associated with the queue and binds them to
|
||||
the delivery exchange.
|
||||
|
||||
|
||||
:param connection: The connection object used to create and manage the channel.
|
||||
:type connection: Connection
|
||||
:param queue: The queue to be bound to the native delayed delivery exchange.
|
||||
:type queue: Queue
|
||||
|
||||
Warning:
|
||||
-------
|
||||
If a direct exchange is detected, a warning will be logged because
|
||||
native delayed delivery does not support direct exchanges.
|
||||
"""
|
||||
channel = connection.channel()
|
||||
queue = queue.bind(channel)
|
||||
|
||||
bindings: set[binding] = set()
|
||||
|
||||
if queue.exchange:
|
||||
bindings.add(binding(
|
||||
queue.exchange,
|
||||
routing_key=queue.routing_key,
|
||||
arguments=queue.binding_arguments
|
||||
))
|
||||
elif queue.bindings:
|
||||
bindings = queue.bindings
|
||||
|
||||
for binding_entry in bindings:
|
||||
exchange: Exchange = binding_entry.exchange.bind(channel)
|
||||
if exchange.type == 'direct':
|
||||
logger.warning(f"Exchange {exchange.name} is a direct exchange "
|
||||
f"and native delayed delivery do not support direct exchanges.\n"
|
||||
f"ETA tasks published to this exchange will block the worker until the ETA arrives.")
|
||||
continue
|
||||
|
||||
routing_key = binding_entry.routing_key if binding_entry.routing_key.startswith(
|
||||
'#') else f"#.{binding_entry.routing_key}"
|
||||
exchange.bind_to(CELERY_DELAYED_DELIVERY_EXCHANGE, routing_key=routing_key)
|
||||
queue.bind_to(exchange.name, routing_key=routing_key)
|
||||
|
||||
|
||||
def calculate_routing_key(countdown: int, routing_key: str) -> str:
|
||||
"""Calculate the routing key for publishing a delayed message based on the countdown."""
|
||||
if countdown < 1:
|
||||
raise ValueError("countdown must be a positive number")
|
||||
|
||||
if not routing_key:
|
||||
raise ValueError("routing_key must be non-empty")
|
||||
|
||||
return '.'.join(list(f'{countdown:028b}')) + f'.{routing_key}'
|
||||
@@ -0,0 +1,253 @@
|
||||
"""pyamqp transport module for Kombu.
|
||||
|
||||
Pure-Python amqp transport using py-amqp library.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Native
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: Yes
|
||||
* Supports Priority: Yes
|
||||
* Supports TTL: Yes
|
||||
|
||||
Connection String
|
||||
=================
|
||||
Connection string can have the following formats:
|
||||
|
||||
.. code-block::
|
||||
|
||||
amqp://[USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]
|
||||
[USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]
|
||||
amqp://
|
||||
|
||||
For TLS encryption use:
|
||||
|
||||
.. code-block::
|
||||
|
||||
amqps://[USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
Transport Options are passed to constructor of underlying py-amqp
|
||||
:class:`~kombu.connection.Connection` class.
|
||||
|
||||
Using TLS
|
||||
=========
|
||||
Transport over TLS can be enabled by ``ssl`` parameter of
|
||||
:class:`~kombu.Connection` class. By setting ``ssl=True``, TLS transport is
|
||||
used::
|
||||
|
||||
conn = Connect('amqp://', ssl=True)
|
||||
|
||||
This is equivalent to ``amqps://`` transport URI::
|
||||
|
||||
conn = Connect('amqps://')
|
||||
|
||||
For adding additional parameters to underlying TLS, ``ssl`` parameter should
|
||||
be set with dict instead of True::
|
||||
|
||||
conn = Connect('amqp://broker.example.com', ssl={
|
||||
'keyfile': '/path/to/keyfile'
|
||||
'certfile': '/path/to/certfile',
|
||||
'ca_certs': '/path/to/ca_certfile'
|
||||
}
|
||||
)
|
||||
|
||||
All parameters are passed to ``ssl`` parameter of
|
||||
:class:`amqp.connection.Connection` class.
|
||||
|
||||
SSL option ``server_hostname`` can be set to ``None`` which is causing using
|
||||
hostname from broker URL. This is useful when failover is used to fill
|
||||
``server_hostname`` with currently used broker::
|
||||
|
||||
conn = Connect('amqp://broker1.example.com;broker2.example.com', ssl={
|
||||
'server_hostname': None
|
||||
}
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import amqp
|
||||
|
||||
from kombu.utils.amq_manager import get_manager
|
||||
from kombu.utils.text import version_string_as_tuple
|
||||
|
||||
from . import base
|
||||
from .base import to_rabbitmq_queue_arguments
|
||||
|
||||
DEFAULT_PORT = 5672
|
||||
DEFAULT_SSL_PORT = 5671
|
||||
|
||||
|
||||
class Message(base.Message):
|
||||
"""AMQP Message."""
|
||||
|
||||
def __init__(self, msg, channel=None, **kwargs):
|
||||
props = msg.properties
|
||||
super().__init__(
|
||||
body=msg.body,
|
||||
channel=channel,
|
||||
delivery_tag=msg.delivery_tag,
|
||||
content_type=props.get('content_type'),
|
||||
content_encoding=props.get('content_encoding'),
|
||||
delivery_info=msg.delivery_info,
|
||||
properties=msg.properties,
|
||||
headers=props.get('application_headers') or {},
|
||||
**kwargs)
|
||||
|
||||
|
||||
class Channel(amqp.Channel, base.StdChannel):
|
||||
"""AMQP Channel."""
|
||||
|
||||
Message = Message
|
||||
|
||||
def prepare_message(self, body, priority=None,
|
||||
content_type=None, content_encoding=None,
|
||||
headers=None, properties=None, _Message=amqp.Message):
|
||||
"""Prepare message so that it can be sent using this transport."""
|
||||
return _Message(
|
||||
body,
|
||||
priority=priority,
|
||||
content_type=content_type,
|
||||
content_encoding=content_encoding,
|
||||
application_headers=headers,
|
||||
**properties or {}
|
||||
)
|
||||
|
||||
def prepare_queue_arguments(self, arguments, **kwargs):
|
||||
return to_rabbitmq_queue_arguments(arguments, **kwargs)
|
||||
|
||||
def message_to_python(self, raw_message):
|
||||
"""Convert encoded message body back to a Python value."""
|
||||
return self.Message(raw_message, channel=self)
|
||||
|
||||
|
||||
class Connection(amqp.Connection):
|
||||
"""AMQP Connection."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
|
||||
class Transport(base.Transport):
|
||||
"""AMQP Transport."""
|
||||
|
||||
Connection = Connection
|
||||
|
||||
default_port = DEFAULT_PORT
|
||||
default_ssl_port = DEFAULT_SSL_PORT
|
||||
|
||||
# it's very annoying that pyamqp sometimes raises AttributeError
|
||||
# if the connection is lost, but nothing we can do about that here.
|
||||
connection_errors = amqp.Connection.connection_errors
|
||||
channel_errors = amqp.Connection.channel_errors
|
||||
recoverable_connection_errors = \
|
||||
amqp.Connection.recoverable_connection_errors
|
||||
recoverable_channel_errors = amqp.Connection.recoverable_channel_errors
|
||||
|
||||
driver_name = 'py-amqp'
|
||||
driver_type = 'amqp'
|
||||
|
||||
implements = base.Transport.implements.extend(
|
||||
asynchronous=True,
|
||||
heartbeats=True,
|
||||
)
|
||||
|
||||
def __init__(self, client,
|
||||
default_port=None, default_ssl_port=None, **kwargs):
|
||||
self.client = client
|
||||
self.default_port = default_port or self.default_port
|
||||
self.default_ssl_port = default_ssl_port or self.default_ssl_port
|
||||
|
||||
def driver_version(self):
|
||||
return amqp.__version__
|
||||
|
||||
def create_channel(self, connection):
|
||||
return connection.channel()
|
||||
|
||||
def drain_events(self, connection, **kwargs):
|
||||
return connection.drain_events(**kwargs)
|
||||
|
||||
def _collect(self, connection):
|
||||
if connection is not None:
|
||||
connection.collect()
|
||||
|
||||
def establish_connection(self):
|
||||
"""Establish connection to the AMQP broker."""
|
||||
conninfo = self.client
|
||||
for name, default_value in self.default_connection_params.items():
|
||||
if not getattr(conninfo, name, None):
|
||||
setattr(conninfo, name, default_value)
|
||||
if conninfo.hostname == 'localhost':
|
||||
conninfo.hostname = '127.0.0.1'
|
||||
# when server_hostname is None, use hostname from URI.
|
||||
if isinstance(conninfo.ssl, dict) and \
|
||||
'server_hostname' in conninfo.ssl and \
|
||||
conninfo.ssl['server_hostname'] is None:
|
||||
conninfo.ssl['server_hostname'] = conninfo.hostname
|
||||
opts = dict({
|
||||
'host': conninfo.host,
|
||||
'userid': conninfo.userid,
|
||||
'password': conninfo.password,
|
||||
'login_method': conninfo.login_method,
|
||||
'virtual_host': conninfo.virtual_host,
|
||||
'insist': conninfo.insist,
|
||||
'ssl': conninfo.ssl,
|
||||
'connect_timeout': conninfo.connect_timeout,
|
||||
'heartbeat': conninfo.heartbeat,
|
||||
}, **conninfo.transport_options or {})
|
||||
conn = self.Connection(**opts)
|
||||
conn.client = self.client
|
||||
conn.connect()
|
||||
return conn
|
||||
|
||||
def verify_connection(self, connection):
|
||||
return connection.connected
|
||||
|
||||
def close_connection(self, connection):
|
||||
"""Close the AMQP broker connection."""
|
||||
connection.client = None
|
||||
connection.close()
|
||||
|
||||
def get_heartbeat_interval(self, connection):
|
||||
return connection.heartbeat
|
||||
|
||||
def register_with_event_loop(self, connection, loop):
|
||||
connection.transport.raise_on_initial_eintr = True
|
||||
loop.add_reader(connection.sock, self.on_readable, connection, loop)
|
||||
|
||||
def heartbeat_check(self, connection, rate=2):
|
||||
return connection.heartbeat_tick(rate=rate)
|
||||
|
||||
def qos_semantics_matches_spec(self, connection):
|
||||
props = connection.server_properties
|
||||
if props.get('product') == 'RabbitMQ':
|
||||
return version_string_as_tuple(props['version']) < (3, 3)
|
||||
return True
|
||||
|
||||
@property
|
||||
def default_connection_params(self):
|
||||
return {
|
||||
'userid': 'guest',
|
||||
'password': 'guest',
|
||||
'port': (self.default_ssl_port if self.client.ssl
|
||||
else self.default_port),
|
||||
'hostname': 'localhost',
|
||||
'login_method': 'PLAIN',
|
||||
}
|
||||
|
||||
def get_manager(self, *args, **kwargs):
|
||||
return get_manager(self.client, *args, **kwargs)
|
||||
|
||||
|
||||
class SSLTransport(Transport):
|
||||
"""AMQP SSL Transport."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# ugh, not exactly pure, but hey, it's python.
|
||||
if not self.client.ssl: # not dict or False
|
||||
self.client.ssl = True
|
||||
@@ -0,0 +1,212 @@
|
||||
"""Pyro transport module for kombu.
|
||||
|
||||
Pyro transport, and Kombu Broker daemon.
|
||||
|
||||
Requires the :mod:`Pyro4` library to be installed.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: No
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
To use the Pyro transport with Kombu, use an url of the form:
|
||||
|
||||
.. code-block::
|
||||
|
||||
pyro://localhost/kombu.broker
|
||||
|
||||
The hostname is where the transport will be looking for a Pyro name server,
|
||||
which is used in turn to locate the kombu.broker Pyro service.
|
||||
This broker can be launched by simply executing this transport module directly,
|
||||
with the command: ``python -m kombu.transport.pyro``
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
"""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from queue import Empty, Queue
|
||||
|
||||
from kombu.exceptions import reraise
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
import Pyro4 as pyro
|
||||
from Pyro4.errors import NamingError
|
||||
from Pyro4.util import SerializerBase
|
||||
except ImportError: # pragma: no cover
|
||||
pyro = NamingError = SerializerBase = None
|
||||
|
||||
DEFAULT_PORT = 9090
|
||||
E_NAMESERVER = """\
|
||||
Unable to locate pyro nameserver on host {0.hostname}\
|
||||
"""
|
||||
E_LOOKUP = """\
|
||||
Unable to lookup '{0.virtual_host}' in pyro nameserver on host {0.hostname}\
|
||||
"""
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Pyro Channel."""
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
if self.shared_queues:
|
||||
self.shared_queues._pyroRelease()
|
||||
|
||||
def queues(self):
|
||||
return self.shared_queues.get_queue_names()
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
if queue not in self.queues():
|
||||
self.shared_queues.new_queue(queue)
|
||||
|
||||
def _has_queue(self, queue, **kwargs):
|
||||
return self.shared_queues.has_queue(queue)
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
queue = self._queue_for(queue)
|
||||
return self.shared_queues.get(queue)
|
||||
|
||||
def _queue_for(self, queue):
|
||||
if queue not in self.queues():
|
||||
self.shared_queues.new_queue(queue)
|
||||
return queue
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
queue = self._queue_for(queue)
|
||||
self.shared_queues.put(queue, message)
|
||||
|
||||
def _size(self, queue):
|
||||
return self.shared_queues.size(queue)
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
self.shared_queues.delete(queue)
|
||||
|
||||
def _purge(self, queue):
|
||||
return self.shared_queues.purge(queue)
|
||||
|
||||
def after_reply_message_received(self, queue):
|
||||
pass
|
||||
|
||||
@cached_property
|
||||
def shared_queues(self):
|
||||
return self.connection.shared_queues
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Pyro Transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
#: memory backend state is global.
|
||||
# TODO: To be checked whether state can be per-Transport
|
||||
global_state = virtual.BrokerState()
|
||||
|
||||
default_port = DEFAULT_PORT
|
||||
|
||||
driver_type = driver_name = 'pyro'
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
super().__init__(client, **kwargs)
|
||||
self.state = self.global_state
|
||||
|
||||
def _open(self):
|
||||
logger.debug("trying Pyro nameserver to find the broker daemon")
|
||||
conninfo = self.client
|
||||
try:
|
||||
nameserver = pyro.locateNS(host=conninfo.hostname,
|
||||
port=self.default_port)
|
||||
except NamingError:
|
||||
reraise(NamingError, NamingError(E_NAMESERVER.format(conninfo)),
|
||||
sys.exc_info()[2])
|
||||
try:
|
||||
# name of registered pyro object
|
||||
uri = nameserver.lookup(conninfo.virtual_host)
|
||||
return pyro.Proxy(uri)
|
||||
except NamingError:
|
||||
reraise(NamingError, NamingError(E_LOOKUP.format(conninfo)),
|
||||
sys.exc_info()[2])
|
||||
|
||||
def driver_version(self):
|
||||
return pyro.__version__
|
||||
|
||||
@cached_property
|
||||
def shared_queues(self):
|
||||
return self._open()
|
||||
|
||||
|
||||
if pyro is not None:
|
||||
SerializerBase.register_dict_to_class("queue.Empty",
|
||||
lambda cls, data: Empty())
|
||||
|
||||
@pyro.expose
|
||||
@pyro.behavior(instance_mode="single")
|
||||
class KombuBroker:
|
||||
"""Kombu Broker used by the Pyro transport.
|
||||
|
||||
You have to run this as a separate (Pyro) service.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queues = {}
|
||||
|
||||
def get_queue_names(self):
|
||||
return list(self.queues)
|
||||
|
||||
def new_queue(self, queue):
|
||||
if queue in self.queues:
|
||||
return # silently ignore the fact that queue already exists
|
||||
self.queues[queue] = Queue()
|
||||
|
||||
def has_queue(self, queue):
|
||||
return queue in self.queues
|
||||
|
||||
def get(self, queue):
|
||||
return self.queues[queue].get(block=False)
|
||||
|
||||
def put(self, queue, message):
|
||||
self.queues[queue].put(message)
|
||||
|
||||
def size(self, queue):
|
||||
return self.queues[queue].qsize()
|
||||
|
||||
def delete(self, queue):
|
||||
del self.queues[queue]
|
||||
|
||||
def purge(self, queue):
|
||||
while True:
|
||||
try:
|
||||
self.queues[queue].get(blocking=False)
|
||||
except Empty:
|
||||
break
|
||||
|
||||
|
||||
# launch a Kombu Broker daemon with the command:
|
||||
# ``python -m kombu.transport.pyro``
|
||||
if __name__ == "__main__":
|
||||
print("Launching Broker for Kombu's Pyro transport.")
|
||||
with pyro.Daemon() as daemon:
|
||||
print("(Expecting a Pyro name server at {}:{})"
|
||||
.format(pyro.config.NS_HOST, pyro.config.NS_PORT))
|
||||
with pyro.locateNS() as ns:
|
||||
print("You can connect with Kombu using the url "
|
||||
"'pyro://{}/kombu.broker'".format(pyro.config.NS_HOST))
|
||||
uri = daemon.register(KombuBroker)
|
||||
ns.register("kombu.broker", uri)
|
||||
daemon.requestLoop()
|
||||
1748
ETB-API/venv/lib/python3.12/site-packages/kombu/transport/qpid.py
Normal file
1748
ETB-API/venv/lib/python3.12/site-packages/kombu/transport/qpid.py
Normal file
File diff suppressed because it is too large
Load Diff
1460
ETB-API/venv/lib/python3.12/site-packages/kombu/transport/redis.py
Normal file
1460
ETB-API/venv/lib/python3.12/site-packages/kombu/transport/redis.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,256 @@
|
||||
"""SQLAlchemy Transport module for kombu.
|
||||
|
||||
Kombu transport using SQL Database as the message store.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: yes
|
||||
* Supports Topic: yes
|
||||
* Supports Fanout: no
|
||||
* Supports Priority: no
|
||||
* Supports TTL: no
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
.. code-block::
|
||||
|
||||
sqla+SQL_ALCHEMY_CONNECTION_STRING
|
||||
sqlalchemy+SQL_ALCHEMY_CONNECTION_STRING
|
||||
|
||||
For details about ``SQL_ALCHEMY_CONNECTION_STRING`` see SQLAlchemy Engine Configuration documentation.
|
||||
|
||||
Examples
|
||||
--------
|
||||
.. code-block::
|
||||
|
||||
# PostgreSQL with default driver
|
||||
sqla+postgresql://scott:tiger@localhost/mydatabase
|
||||
|
||||
# PostgreSQL with psycopg2 driver
|
||||
sqla+postgresql+psycopg2://scott:tiger@localhost/mydatabase
|
||||
|
||||
# PostgreSQL with pg8000 driver
|
||||
sqla+postgresql+pg8000://scott:tiger@localhost/mydatabase
|
||||
|
||||
# MySQL with default driver
|
||||
sqla+mysql://scott:tiger@localhost/foo
|
||||
|
||||
# MySQL with mysqlclient driver (a maintained fork of MySQL-Python)
|
||||
sqla+mysql+mysqldb://scott:tiger@localhost/foo
|
||||
|
||||
# MySQL with PyMySQL driver
|
||||
sqla+mysql+pymysql://scott:tiger@localhost/foo
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
|
||||
* ``queue_tablename``: Name of table storing queues.
|
||||
* ``message_tablename``: Name of table storing messages.
|
||||
|
||||
Moreover parameters of :func:`sqlalchemy.create_engine()` function can be passed as transport options.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from json import dumps, loads
|
||||
from queue import Empty
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from kombu.transport import virtual
|
||||
from kombu.utils import cached_property
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
|
||||
from .models import Message as MessageBase
|
||||
from .models import ModelBase
|
||||
from .models import Queue as QueueBase
|
||||
from .models import class_registry, metadata
|
||||
|
||||
# SQLAlchemy overrides != False to have special meaning and pep8 complains
|
||||
# flake8: noqa
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
VERSION = (1, 4, 1)
|
||||
__version__ = '.'.join(map(str, VERSION))
|
||||
|
||||
_MUTEX = threading.RLock()
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""The channel class."""
|
||||
|
||||
_session = None
|
||||
_engines = {} # engine cache
|
||||
|
||||
def __init__(self, connection, **kwargs):
|
||||
self._configure_entity_tablenames(connection.client.transport_options)
|
||||
super().__init__(connection, **kwargs)
|
||||
|
||||
def _configure_entity_tablenames(self, opts):
|
||||
self.queue_tablename = opts.get('queue_tablename', 'kombu_queue')
|
||||
self.message_tablename = opts.get('message_tablename', 'kombu_message')
|
||||
|
||||
#
|
||||
# Define the model definitions. This registers the declarative
|
||||
# classes with the active SQLAlchemy metadata object. This *must* be
|
||||
# done prior to the ``create_engine`` call.
|
||||
#
|
||||
self.queue_cls and self.message_cls
|
||||
|
||||
def _engine_from_config(self):
|
||||
conninfo = self.connection.client
|
||||
transport_options = conninfo.transport_options.copy()
|
||||
transport_options.pop('queue_tablename', None)
|
||||
transport_options.pop('message_tablename', None)
|
||||
transport_options.pop('callback', None)
|
||||
transport_options.pop('errback', None)
|
||||
transport_options.pop('max_retries', None)
|
||||
transport_options.pop('interval_start', None)
|
||||
transport_options.pop('interval_step', None)
|
||||
transport_options.pop('interval_max', None)
|
||||
transport_options.pop('retry_errors', None)
|
||||
|
||||
return create_engine(conninfo.hostname, **transport_options)
|
||||
|
||||
def _open(self):
|
||||
conninfo = self.connection.client
|
||||
if conninfo.hostname not in self._engines:
|
||||
with _MUTEX:
|
||||
if conninfo.hostname in self._engines:
|
||||
# Engine was created while we were waiting to
|
||||
# acquire the lock.
|
||||
return self._engines[conninfo.hostname]
|
||||
|
||||
engine = self._engine_from_config()
|
||||
Session = sessionmaker(bind=engine)
|
||||
metadata.create_all(engine)
|
||||
self._engines[conninfo.hostname] = engine, Session
|
||||
|
||||
return self._engines[conninfo.hostname]
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
if self._session is None:
|
||||
_, Session = self._open()
|
||||
self._session = Session()
|
||||
return self._session
|
||||
|
||||
def _get_or_create(self, queue):
|
||||
obj = self.session.query(self.queue_cls) \
|
||||
.filter(self.queue_cls.name == queue).first()
|
||||
if not obj:
|
||||
with _MUTEX:
|
||||
obj = self.session.query(self.queue_cls) \
|
||||
.filter(self.queue_cls.name == queue).first()
|
||||
if obj:
|
||||
# Queue was created while we were waiting to
|
||||
# acquire the lock.
|
||||
return obj
|
||||
|
||||
obj = self.queue_cls(queue)
|
||||
self.session.add(obj)
|
||||
try:
|
||||
self.session.commit()
|
||||
except OperationalError:
|
||||
self.session.rollback()
|
||||
|
||||
return obj
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
self._get_or_create(queue)
|
||||
|
||||
def _put(self, queue, payload, **kwargs):
|
||||
obj = self._get_or_create(queue)
|
||||
message = self.message_cls(dumps(payload), obj)
|
||||
self.session.add(message)
|
||||
try:
|
||||
self.session.commit()
|
||||
except OperationalError:
|
||||
self.session.rollback()
|
||||
|
||||
def _get(self, queue):
|
||||
obj = self._get_or_create(queue)
|
||||
if self.session.bind.name == 'sqlite':
|
||||
self.session.execute(text('BEGIN IMMEDIATE TRANSACTION'))
|
||||
try:
|
||||
msg = self.session.query(self.message_cls) \
|
||||
.with_for_update() \
|
||||
.filter(self.message_cls.queue_id == obj.id) \
|
||||
.filter(self.message_cls.visible != False) \
|
||||
.order_by(self.message_cls.sent_at) \
|
||||
.order_by(self.message_cls.id) \
|
||||
.limit(1) \
|
||||
.first()
|
||||
if msg:
|
||||
msg.visible = False
|
||||
return loads(bytes_to_str(msg.payload))
|
||||
raise Empty()
|
||||
finally:
|
||||
self.session.commit()
|
||||
|
||||
def _query_all(self, queue):
|
||||
obj = self._get_or_create(queue)
|
||||
return self.session.query(self.message_cls) \
|
||||
.filter(self.message_cls.queue_id == obj.id)
|
||||
|
||||
def _purge(self, queue):
|
||||
count = self._query_all(queue).delete(synchronize_session=False)
|
||||
try:
|
||||
self.session.commit()
|
||||
except OperationalError:
|
||||
self.session.rollback()
|
||||
return count
|
||||
|
||||
def _size(self, queue):
|
||||
return self._query_all(queue).count()
|
||||
|
||||
def _declarative_cls(self, name, base, ns):
|
||||
if name not in class_registry:
|
||||
with _MUTEX:
|
||||
if name in class_registry:
|
||||
# Class was registered while we were waiting to
|
||||
# acquire the lock.
|
||||
return class_registry[name]
|
||||
|
||||
return type(str(name), (base, ModelBase), ns)
|
||||
|
||||
return class_registry[name]
|
||||
|
||||
@cached_property
|
||||
def queue_cls(self):
|
||||
return self._declarative_cls(
|
||||
'Queue',
|
||||
QueueBase,
|
||||
{'__tablename__': self.queue_tablename}
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def message_cls(self):
|
||||
return self._declarative_cls(
|
||||
'Message',
|
||||
MessageBase,
|
||||
{'__tablename__': self.message_tablename}
|
||||
)
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""The transport class."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
can_parse_url = True
|
||||
default_port = 0
|
||||
driver_type = 'sql'
|
||||
driver_name = 'sqlalchemy'
|
||||
connection_errors = (OperationalError, )
|
||||
|
||||
def driver_version(self):
|
||||
import sqlalchemy
|
||||
return sqlalchemy.__version__
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,76 @@
|
||||
"""Kombu transport using SQLAlchemy as the message store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer,
|
||||
Sequence, SmallInteger, String, Text)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.schema import MetaData
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base, declared_attr
|
||||
except ImportError:
|
||||
# TODO: Remove this once we drop support for SQLAlchemy < 1.4.
|
||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
|
||||
class_registry = {}
|
||||
metadata = MetaData()
|
||||
ModelBase = declarative_base(metadata=metadata, class_registry=class_registry)
|
||||
|
||||
|
||||
class Queue:
|
||||
"""The queue class."""
|
||||
|
||||
__table_args__ = {'sqlite_autoincrement': True, 'mysql_engine': 'InnoDB'}
|
||||
|
||||
id = Column(Integer, Sequence('queue_id_sequence'), primary_key=True,
|
||||
autoincrement=True)
|
||||
name = Column(String(200), unique=True)
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __str__(self):
|
||||
return f'<Queue({self.name})>'
|
||||
|
||||
@declared_attr
|
||||
def messages(cls):
|
||||
return relationship('Message', backref='queue', lazy='noload')
|
||||
|
||||
|
||||
class Message:
|
||||
"""The message class."""
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_kombu_message_timestamp_id', 'timestamp', 'id'),
|
||||
{'sqlite_autoincrement': True, 'mysql_engine': 'InnoDB'}
|
||||
)
|
||||
|
||||
id = Column(Integer, Sequence('message_id_sequence'),
|
||||
primary_key=True, autoincrement=True)
|
||||
visible = Column(Boolean, default=True, index=True)
|
||||
sent_at = Column('timestamp', DateTime, nullable=True, index=True,
|
||||
onupdate=datetime.datetime.now)
|
||||
payload = Column(Text, nullable=False)
|
||||
version = Column(SmallInteger, nullable=False, default=1)
|
||||
|
||||
__mapper_args__ = {'version_id_col': version}
|
||||
|
||||
def __init__(self, payload, queue):
|
||||
self.payload = payload
|
||||
self.queue = queue
|
||||
|
||||
def __str__(self):
|
||||
return '<Message: {0.sent_at} {0.payload} {0.queue_id}>'.format(self)
|
||||
|
||||
@declared_attr
|
||||
def queue_id(self):
|
||||
return Column(
|
||||
Integer,
|
||||
ForeignKey(
|
||||
'%s.id' % class_registry['Queue'].__tablename__,
|
||||
name='FK_kombu_message_queue'
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import (AbstractChannel, Base64, BrokerState, Channel, Empty,
|
||||
Management, Message, NotEquivalentError, QoS, Transport,
|
||||
UndeliverableWarning, binding_key_t, queue_binding_t)
|
||||
|
||||
__all__ = (
|
||||
'Base64', 'NotEquivalentError', 'UndeliverableWarning', 'BrokerState',
|
||||
'QoS', 'Message', 'AbstractChannel', 'Channel', 'Management', 'Transport',
|
||||
'Empty', 'binding_key_t', 'queue_binding_t',
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,164 @@
|
||||
"""Virtual AMQ Exchange.
|
||||
|
||||
Implementations of the standard exchanges defined
|
||||
by the AMQ protocol (excluding the `headers` exchange).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from kombu.utils.text import escape_regex
|
||||
|
||||
|
||||
class ExchangeType:
|
||||
"""Base class for exchanges.
|
||||
|
||||
Implements the specifics for an exchange type.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
channel (ChannelT): AMQ Channel.
|
||||
"""
|
||||
|
||||
type = None
|
||||
|
||||
def __init__(self, channel):
|
||||
self.channel = channel
|
||||
|
||||
def lookup(self, table, exchange, routing_key, default):
|
||||
"""Lookup all queues matching `routing_key` in `exchange`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: queue name, or 'default' if no queues matched.
|
||||
"""
|
||||
raise NotImplementedError('subclass responsibility')
|
||||
|
||||
def prepare_bind(self, queue, exchange, routing_key, arguments):
|
||||
"""Prepare queue-binding.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[str, Pattern, str]: of `(routing_key, regex, queue)`
|
||||
to be stored for bindings to this exchange.
|
||||
"""
|
||||
return routing_key, None, queue
|
||||
|
||||
def equivalent(self, prev, exchange, type,
|
||||
durable, auto_delete, arguments):
|
||||
"""Return true if `prev` and `exchange` is equivalent."""
|
||||
return (type == prev['type'] and
|
||||
durable == prev['durable'] and
|
||||
auto_delete == prev['auto_delete'] and
|
||||
(arguments or {}) == (prev['arguments'] or {}))
|
||||
|
||||
|
||||
class DirectExchange(ExchangeType):
|
||||
"""Direct exchange.
|
||||
|
||||
The `direct` exchange routes based on exact routing keys.
|
||||
"""
|
||||
|
||||
type = 'direct'
|
||||
|
||||
def lookup(self, table, exchange, routing_key, default):
|
||||
return {
|
||||
queue for rkey, _, queue in table
|
||||
if rkey == routing_key
|
||||
}
|
||||
|
||||
def deliver(self, message, exchange, routing_key, **kwargs):
|
||||
_lookup = self.channel._lookup
|
||||
_put = self.channel._put
|
||||
for queue in _lookup(exchange, routing_key):
|
||||
_put(queue, message, **kwargs)
|
||||
|
||||
|
||||
class TopicExchange(ExchangeType):
|
||||
"""Topic exchange.
|
||||
|
||||
The `topic` exchange routes messages based on words separated by
|
||||
dots, using wildcard characters ``*`` (any single word), and ``#``
|
||||
(one or more words).
|
||||
"""
|
||||
|
||||
type = 'topic'
|
||||
|
||||
#: map of wildcard to regex conversions
|
||||
wildcards = {'*': r'.*?[^\.]',
|
||||
'#': r'.*?'}
|
||||
|
||||
#: compiled regex cache
|
||||
_compiled = {}
|
||||
|
||||
def lookup(self, table, exchange, routing_key, default):
|
||||
return {
|
||||
queue for rkey, pattern, queue in table
|
||||
if self._match(pattern, routing_key)
|
||||
}
|
||||
|
||||
def deliver(self, message, exchange, routing_key, **kwargs):
|
||||
_lookup = self.channel._lookup
|
||||
_put = self.channel._put
|
||||
deadletter = self.channel.deadletter_queue
|
||||
for queue in [q for q in _lookup(exchange, routing_key)
|
||||
if q and q != deadletter]:
|
||||
_put(queue, message, **kwargs)
|
||||
|
||||
def prepare_bind(self, queue, exchange, routing_key, arguments):
|
||||
return routing_key, self.key_to_pattern(routing_key), queue
|
||||
|
||||
def key_to_pattern(self, rkey):
|
||||
"""Get the corresponding regex for any routing key."""
|
||||
return '^%s$' % (r'\.'.join(
|
||||
self.wildcards.get(word, word)
|
||||
for word in escape_regex(rkey, '.#*').split('.')
|
||||
))
|
||||
|
||||
def _match(self, pattern, string):
|
||||
"""Match regular expression (cached).
|
||||
|
||||
Same as :func:`re.match`, except the regex is compiled and cached,
|
||||
then reused on subsequent matches with the same pattern.
|
||||
"""
|
||||
try:
|
||||
compiled = self._compiled[pattern]
|
||||
except KeyError:
|
||||
compiled = self._compiled[pattern] = re.compile(pattern, re.U)
|
||||
return compiled.match(string)
|
||||
|
||||
|
||||
class FanoutExchange(ExchangeType):
|
||||
"""Fanout exchange.
|
||||
|
||||
The `fanout` exchange implements broadcast messaging by delivering
|
||||
copies of all messages to all queues bound to the exchange.
|
||||
|
||||
To support fanout the virtual channel needs to store the table
|
||||
as shared state. This requires that the `Channel.supports_fanout`
|
||||
attribute is set to true, and the `Channel._queue_bind` and
|
||||
`Channel.get_table` methods are implemented.
|
||||
|
||||
See Also
|
||||
--------
|
||||
the redis backend for an example implementation of these methods.
|
||||
"""
|
||||
|
||||
type = 'fanout'
|
||||
|
||||
def lookup(self, table, exchange, routing_key, default):
|
||||
return {queue for _, _, queue in table}
|
||||
|
||||
def deliver(self, message, exchange, routing_key, **kwargs):
|
||||
if self.channel.supports_fanout:
|
||||
self.channel._put_fanout(
|
||||
exchange, message, routing_key, **kwargs)
|
||||
|
||||
|
||||
#: Map of standard exchange types and corresponding classes.
|
||||
STANDARD_EXCHANGE_TYPES = {
|
||||
'direct': DirectExchange,
|
||||
'topic': TopicExchange,
|
||||
'fanout': FanoutExchange,
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
# copyright: (c) 2010 - 2013 by Mahendra M.
|
||||
# license: BSD, see LICENSE for more details.
|
||||
|
||||
"""Zookeeper transport module for kombu.
|
||||
|
||||
Zookeeper based transport. This transport uses the built-in kazoo Zookeeper
|
||||
based queue implementation.
|
||||
|
||||
**References**
|
||||
|
||||
- https://zookeeper.apache.org/doc/current/recipes.html#sc_recipes_Queues
|
||||
- https://kazoo.readthedocs.io/en/latest/api/recipe/queue.html
|
||||
|
||||
**Limitations**
|
||||
This queue does not offer reliable consumption. An entry is removed from
|
||||
the queue prior to being processed. So if an error occurs, the consumer
|
||||
has to re-queue the item or it will be lost.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: No
|
||||
* Supports Priority: Yes
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
Connects to a zookeeper node as:
|
||||
|
||||
.. code-block::
|
||||
|
||||
zookeeper://SERVER:PORT/VHOST
|
||||
|
||||
The <vhost> becomes the base for all the other znodes. So we can use
|
||||
it like a vhost.
|
||||
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import socket
|
||||
from queue import Empty
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str, ensure_bytes
|
||||
from kombu.utils.json import dumps, loads
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
import kazoo
|
||||
from kazoo.client import KazooClient
|
||||
from kazoo.recipe.queue import Queue
|
||||
|
||||
KZ_CONNECTION_ERRORS = (
|
||||
kazoo.exceptions.SystemErrorException,
|
||||
kazoo.exceptions.ConnectionLossException,
|
||||
kazoo.exceptions.MarshallingErrorException,
|
||||
kazoo.exceptions.UnimplementedException,
|
||||
kazoo.exceptions.OperationTimeoutException,
|
||||
kazoo.exceptions.NoAuthException,
|
||||
kazoo.exceptions.InvalidACLException,
|
||||
kazoo.exceptions.AuthFailedException,
|
||||
kazoo.exceptions.SessionExpiredException,
|
||||
)
|
||||
|
||||
KZ_CHANNEL_ERRORS = (
|
||||
kazoo.exceptions.RuntimeInconsistencyException,
|
||||
kazoo.exceptions.DataInconsistencyException,
|
||||
kazoo.exceptions.BadArgumentsException,
|
||||
kazoo.exceptions.MarshallingErrorException,
|
||||
kazoo.exceptions.UnimplementedException,
|
||||
kazoo.exceptions.OperationTimeoutException,
|
||||
kazoo.exceptions.ApiErrorException,
|
||||
kazoo.exceptions.NoNodeException,
|
||||
kazoo.exceptions.NoAuthException,
|
||||
kazoo.exceptions.NodeExistsException,
|
||||
kazoo.exceptions.NoChildrenForEphemeralsException,
|
||||
kazoo.exceptions.NotEmptyException,
|
||||
kazoo.exceptions.SessionExpiredException,
|
||||
kazoo.exceptions.InvalidCallbackException,
|
||||
socket.error,
|
||||
)
|
||||
except ImportError:
|
||||
kazoo = None
|
||||
KZ_CONNECTION_ERRORS = KZ_CHANNEL_ERRORS = ()
|
||||
|
||||
DEFAULT_PORT = 2181
|
||||
|
||||
__author__ = 'Mahendra M <mahendra.m@gmail.com>'
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Zookeeper Channel."""
|
||||
|
||||
_client = None
|
||||
_queues = {}
|
||||
|
||||
def __init__(self, connection, **kwargs):
|
||||
super().__init__(connection, **kwargs)
|
||||
vhost = self.connection.client.virtual_host
|
||||
self._vhost = '/{}'.format(vhost.strip('/'))
|
||||
|
||||
def _get_path(self, queue_name):
|
||||
return os.path.join(self._vhost, queue_name)
|
||||
|
||||
def _get_queue(self, queue_name):
|
||||
queue = self._queues.get(queue_name, None)
|
||||
|
||||
if queue is None:
|
||||
queue = Queue(self.client, self._get_path(queue_name))
|
||||
self._queues[queue_name] = queue
|
||||
|
||||
# Ensure that the queue is created
|
||||
len(queue)
|
||||
|
||||
return queue
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
return self._get_queue(queue).put(
|
||||
ensure_bytes(dumps(message)),
|
||||
priority=self._get_message_priority(message, reverse=True),
|
||||
)
|
||||
|
||||
def _get(self, queue):
|
||||
queue = self._get_queue(queue)
|
||||
msg = queue.get()
|
||||
|
||||
if msg is None:
|
||||
raise Empty()
|
||||
|
||||
return loads(bytes_to_str(msg))
|
||||
|
||||
def _purge(self, queue):
|
||||
count = 0
|
||||
queue = self._get_queue(queue)
|
||||
|
||||
while True:
|
||||
msg = queue.get()
|
||||
if msg is None:
|
||||
break
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
if self._has_queue(queue):
|
||||
self._purge(queue)
|
||||
self.client.delete(self._get_path(queue))
|
||||
|
||||
def _size(self, queue):
|
||||
queue = self._get_queue(queue)
|
||||
return len(queue)
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
if not self._has_queue(queue):
|
||||
queue = self._get_queue(queue)
|
||||
|
||||
def _has_queue(self, queue):
|
||||
return self.client.exists(self._get_path(queue)) is not None
|
||||
|
||||
def _open(self):
|
||||
conninfo = self.connection.client
|
||||
hosts = []
|
||||
if conninfo.alt:
|
||||
for host_port in conninfo.alt:
|
||||
if host_port.startswith('zookeeper://'):
|
||||
host_port = host_port[len('zookeeper://'):]
|
||||
if not host_port:
|
||||
continue
|
||||
try:
|
||||
host, port = host_port.split(':', 1)
|
||||
host_port = (host, int(port))
|
||||
except ValueError:
|
||||
if host_port == conninfo.hostname:
|
||||
host_port = (host_port, conninfo.port or DEFAULT_PORT)
|
||||
else:
|
||||
host_port = (host_port, DEFAULT_PORT)
|
||||
hosts.append(host_port)
|
||||
host_port = (conninfo.hostname, conninfo.port or DEFAULT_PORT)
|
||||
if host_port not in hosts:
|
||||
hosts.insert(0, host_port)
|
||||
conn_str = ','.join([f'{h}:{p}' for h, p in hosts])
|
||||
conn = KazooClient(conn_str)
|
||||
conn.start()
|
||||
return conn
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
if self._client is None:
|
||||
self._client = self._open()
|
||||
return self._client
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Zookeeper Transport."""
|
||||
|
||||
Channel = Channel
|
||||
polling_interval = 1
|
||||
default_port = DEFAULT_PORT
|
||||
connection_errors = (
|
||||
virtual.Transport.connection_errors + KZ_CONNECTION_ERRORS
|
||||
)
|
||||
channel_errors = (
|
||||
virtual.Transport.channel_errors + KZ_CHANNEL_ERRORS
|
||||
)
|
||||
driver_type = 'zookeeper'
|
||||
driver_name = 'kazoo'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if kazoo is None:
|
||||
raise ImportError('The kazoo library is not installed')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def driver_version(self):
|
||||
return kazoo.__version__
|
||||
Reference in New Issue
Block a user