Updates
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,320 @@
|
||||
"""Amazon SQS Connection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from botocore.serialize import Serializer
|
||||
from vine import transform
|
||||
|
||||
from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection
|
||||
from kombu.asynchronous.aws.ext import AWSRequest
|
||||
|
||||
from .ext import boto3
|
||||
from .message import AsyncMessage
|
||||
from .queue import AsyncQueue
|
||||
|
||||
__all__ = ('AsyncSQSConnection',)
|
||||
|
||||
|
||||
class AsyncSQSConnection(AsyncAWSQueryConnection):
|
||||
"""Async SQS Connection."""
|
||||
|
||||
def __init__(self, sqs_connection, debug=0, region=None, fetch_message_attributes=None, **kwargs):
|
||||
if boto3 is None:
|
||||
raise ImportError('boto3 is not installed')
|
||||
super().__init__(
|
||||
sqs_connection,
|
||||
region_name=region, debug=debug,
|
||||
**kwargs
|
||||
)
|
||||
self.fetch_message_attributes = (
|
||||
fetch_message_attributes if fetch_message_attributes is not None
|
||||
else ["ApproximateReceiveCount"]
|
||||
)
|
||||
|
||||
def _create_query_request(self, operation, params, queue_url, method):
|
||||
params = params.copy()
|
||||
if operation:
|
||||
params['Action'] = operation
|
||||
|
||||
# defaults for non-get
|
||||
param_payload = {'data': params}
|
||||
headers = {}
|
||||
if method.lower() == 'get':
|
||||
# query-based opts
|
||||
param_payload = {'params': params}
|
||||
|
||||
if method.lower() == 'post':
|
||||
headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
|
||||
|
||||
return AWSRequest(method=method, url=queue_url, headers=headers, **param_payload)
|
||||
|
||||
def _create_json_request(self, operation, params, queue_url):
|
||||
params = params.copy()
|
||||
params['QueueUrl'] = queue_url
|
||||
|
||||
service_model = self.sqs_connection.meta.service_model
|
||||
operation_model = service_model.operation_model(operation)
|
||||
|
||||
url = self.sqs_connection._endpoint.host
|
||||
|
||||
headers = {}
|
||||
# Content-Type
|
||||
json_version = operation_model.metadata['jsonVersion']
|
||||
content_type = f'application/x-amz-json-{json_version}'
|
||||
headers['Content-Type'] = content_type
|
||||
|
||||
# X-Amz-Target
|
||||
target = '{}.{}'.format(
|
||||
operation_model.metadata['targetPrefix'],
|
||||
operation_model.name,
|
||||
)
|
||||
headers['X-Amz-Target'] = target
|
||||
|
||||
param_payload = {
|
||||
'data': json.dumps(params).encode(),
|
||||
'headers': headers
|
||||
}
|
||||
|
||||
method = operation_model.http.get('method', Serializer.DEFAULT_METHOD)
|
||||
return AWSRequest(
|
||||
method=method,
|
||||
url=url,
|
||||
**param_payload
|
||||
)
|
||||
|
||||
def make_request(self, operation_name, params, queue_url, verb, callback=None, protocol_params=None):
|
||||
"""Override make_request to support different protocols.
|
||||
|
||||
botocore has changed the default protocol of communicating
|
||||
with SQS backend from 'query' to 'json', so we need a special
|
||||
implementation of make_request for SQS. More information on this can
|
||||
be found in: https://github.com/celery/kombu/pull/1807.
|
||||
|
||||
protocol_params: Optional[dict[str, dict]] of per-protocol additional parameters.
|
||||
Supported for the SQS query to json protocol transition.
|
||||
"""
|
||||
signer = self.sqs_connection._request_signer
|
||||
|
||||
service_model = self.sqs_connection.meta.service_model
|
||||
protocol = service_model.protocol
|
||||
all_params = {**(params or {}), **protocol_params.get(protocol, {})}
|
||||
|
||||
if protocol == 'query':
|
||||
request = self._create_query_request(
|
||||
operation_name, all_params, queue_url, verb)
|
||||
elif protocol == 'json':
|
||||
request = self._create_json_request(
|
||||
operation_name, all_params, queue_url)
|
||||
else:
|
||||
raise Exception(f'Unsupported protocol: {protocol}.')
|
||||
|
||||
signing_type = 'presign-url' if request.method.lower() == 'get' \
|
||||
else 'standard'
|
||||
|
||||
signer.sign(operation_name, request, signing_type=signing_type)
|
||||
prepared_request = request.prepare()
|
||||
|
||||
return self._mexe(prepared_request, callback=callback)
|
||||
|
||||
def create_queue(self, queue_name,
|
||||
visibility_timeout=None, callback=None):
|
||||
params = {'QueueName': queue_name}
|
||||
if visibility_timeout:
|
||||
params['DefaultVisibilityTimeout'] = format(
|
||||
visibility_timeout, 'd',
|
||||
)
|
||||
return self.get_object('CreateQueue', params,
|
||||
callback=callback)
|
||||
|
||||
def delete_queue(self, queue, force_deletion=False, callback=None):
|
||||
return self.get_status('DeleteQueue', None, queue.id,
|
||||
callback=callback)
|
||||
|
||||
def get_queue_url(self, queue):
|
||||
res = self.sqs_connection.get_queue_url(QueueName=queue)
|
||||
return res['QueueUrl']
|
||||
|
||||
def get_queue_attributes(self, queue, attribute='All', callback=None):
|
||||
return self.get_object(
|
||||
'GetQueueAttributes', {'AttributeName': attribute},
|
||||
queue.id, callback=callback,
|
||||
)
|
||||
|
||||
def set_queue_attribute(self, queue, attribute, value, callback=None):
|
||||
return self.get_status(
|
||||
'SetQueueAttribute',
|
||||
{},
|
||||
queue.id, callback=callback,
|
||||
protocol_params={
|
||||
'json': {'Attributes': {attribute: value}},
|
||||
'query': {'Attribute.Name': attribute, 'Attribute.Value': value},
|
||||
},
|
||||
)
|
||||
|
||||
def receive_message(
|
||||
self, queue, queue_url, number_messages=1, visibility_timeout=None,
|
||||
attributes=None, wait_time_seconds=None,
|
||||
callback=None
|
||||
):
|
||||
params = {'MaxNumberOfMessages': number_messages}
|
||||
proto_params = {'query': {}, 'json': {}}
|
||||
attrs = attributes if attributes is not None else self.fetch_message_attributes
|
||||
|
||||
if visibility_timeout:
|
||||
params['VisibilityTimeout'] = visibility_timeout
|
||||
if attrs:
|
||||
proto_params['json'].update({'AttributeNames': list(attrs)})
|
||||
proto_params['query'].update(_query_object_encode({'AttributeName': list(attrs)}))
|
||||
if wait_time_seconds is not None:
|
||||
params['WaitTimeSeconds'] = wait_time_seconds
|
||||
|
||||
return self.get_list(
|
||||
'ReceiveMessage', params, [('Message', AsyncMessage)],
|
||||
queue_url, callback=callback, parent=queue,
|
||||
protocol_params=proto_params,
|
||||
)
|
||||
|
||||
def delete_message(self, queue, receipt_handle, callback=None):
|
||||
return self.delete_message_from_handle(
|
||||
queue, receipt_handle, callback,
|
||||
)
|
||||
|
||||
def delete_message_batch(self, queue, messages, callback=None):
|
||||
p_params = {
|
||||
'json': {
|
||||
'Entries': [{'Id': m.id, 'ReceiptHandle': m.receipt_handle} for m in messages],
|
||||
},
|
||||
'query': _query_object_encode({
|
||||
'DeleteMessageBatchRequestEntry': [
|
||||
{'Id': m.id, 'ReceiptHandle': m.receipt_handle}
|
||||
for m in messages
|
||||
],
|
||||
}),
|
||||
}
|
||||
|
||||
return self.get_object(
|
||||
'DeleteMessageBatch', {}, queue.id,
|
||||
verb='POST', callback=callback, protocol_params=p_params,
|
||||
)
|
||||
|
||||
def delete_message_from_handle(self, queue, receipt_handle,
|
||||
callback=None):
|
||||
return self.get_status(
|
||||
'DeleteMessage', {'ReceiptHandle': receipt_handle},
|
||||
queue, callback=callback,
|
||||
)
|
||||
|
||||
def send_message(self, queue, message_content,
|
||||
delay_seconds=None, callback=None):
|
||||
params = {'MessageBody': message_content}
|
||||
if delay_seconds:
|
||||
params['DelaySeconds'] = int(delay_seconds)
|
||||
return self.get_object(
|
||||
'SendMessage', params, queue.id,
|
||||
verb='POST', callback=callback,
|
||||
)
|
||||
|
||||
def send_message_batch(self, queue, messages, callback=None):
|
||||
params = {}
|
||||
for i, msg in enumerate(messages):
|
||||
prefix = f'SendMessageBatchRequestEntry.{i + 1}'
|
||||
params.update({
|
||||
f'{prefix}.Id': msg[0],
|
||||
f'{prefix}.MessageBody': msg[1],
|
||||
f'{prefix}.DelaySeconds': msg[2],
|
||||
})
|
||||
return self.get_object(
|
||||
'SendMessageBatch', params, queue.id,
|
||||
verb='POST', callback=callback,
|
||||
)
|
||||
|
||||
def change_message_visibility(self, queue, receipt_handle,
|
||||
visibility_timeout, callback=None):
|
||||
return self.get_status(
|
||||
'ChangeMessageVisibility',
|
||||
{'ReceiptHandle': receipt_handle,
|
||||
'VisibilityTimeout': visibility_timeout},
|
||||
queue.id, callback=callback,
|
||||
)
|
||||
|
||||
def change_message_visibility_batch(self, queue, messages, callback=None):
|
||||
entries = [
|
||||
{'Id': t[0].id, 'ReceiptHandle': t[0].receipt_handle, 'VisibilityTimeout': t[1]}
|
||||
for t in messages
|
||||
]
|
||||
|
||||
p_params = {
|
||||
'json': {'Entries': entries},
|
||||
'query': _query_object_encode({'ChangeMessageVisibilityBatchRequestEntry': entries}),
|
||||
}
|
||||
|
||||
return self.get_object(
|
||||
'ChangeMessageVisibilityBatch', {}, queue.id,
|
||||
verb='POST', callback=callback,
|
||||
protocol_params=p_params,
|
||||
)
|
||||
|
||||
def get_all_queues(self, prefix='', callback=None):
|
||||
params = {}
|
||||
if prefix:
|
||||
params['QueueNamePrefix'] = prefix
|
||||
return self.get_list(
|
||||
'ListQueues', params, [('QueueUrl', AsyncQueue)],
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def get_queue(self, queue_name, callback=None):
|
||||
# TODO Does not support owner_acct_id argument
|
||||
return self.get_all_queues(
|
||||
queue_name,
|
||||
transform(self._on_queue_ready, callback, queue_name),
|
||||
)
|
||||
lookup = get_queue
|
||||
|
||||
def _on_queue_ready(self, name, queues):
|
||||
return next(
|
||||
(q for q in queues if q.url.endswith(name)), None,
|
||||
)
|
||||
|
||||
def get_dead_letter_source_queues(self, queue, callback=None):
|
||||
return self.get_list(
|
||||
'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
|
||||
[('QueueUrl', AsyncQueue)],
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def add_permission(self, queue, label, aws_account_id, action_name,
|
||||
callback=None):
|
||||
return self.get_status(
|
||||
'AddPermission',
|
||||
{'Label': label,
|
||||
'AWSAccountId': aws_account_id,
|
||||
'ActionName': action_name},
|
||||
queue.id, callback=callback,
|
||||
)
|
||||
|
||||
def remove_permission(self, queue, label, callback=None):
|
||||
return self.get_status(
|
||||
'RemovePermission', {'Label': label}, queue.id, callback=callback,
|
||||
)
|
||||
|
||||
|
||||
def _query_object_encode(items):
|
||||
params = {}
|
||||
_query_object_encode_part(params, '', items)
|
||||
return {k: v for k, v in params.items()}
|
||||
|
||||
|
||||
def _query_object_encode_part(params, prefix, part):
|
||||
dotted = f'{prefix}.' if prefix else prefix
|
||||
|
||||
if isinstance(part, (list, tuple)):
|
||||
for i, item in enumerate(part):
|
||||
_query_object_encode_part(params, f'{dotted}{i + 1}', item)
|
||||
elif isinstance(part, dict):
|
||||
for key, value in part.items():
|
||||
_query_object_encode_part(params, f'{dotted}{key}', value)
|
||||
else:
|
||||
params[prefix] = str(part)
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Amazon SQS boto3 interface."""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
try:
|
||||
import boto3
|
||||
except ImportError:
|
||||
boto3 = None
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Amazon SQS message implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
from kombu.message import Message
|
||||
from kombu.utils.encoding import str_to_bytes
|
||||
|
||||
|
||||
class BaseAsyncMessage(Message):
|
||||
"""Base class for messages received on async client."""
|
||||
|
||||
|
||||
class AsyncRawMessage(BaseAsyncMessage):
|
||||
"""Raw Message."""
|
||||
|
||||
|
||||
class AsyncMessage(BaseAsyncMessage):
|
||||
"""Serialized message."""
|
||||
|
||||
def encode(self, value):
|
||||
"""Encode/decode the value using Base64 encoding."""
|
||||
return base64.b64encode(str_to_bytes(value)).decode()
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""Support Boto3-style access on a message."""
|
||||
if item == 'ReceiptHandle':
|
||||
return self.receipt_handle
|
||||
elif item == 'Body':
|
||||
return self.get_body()
|
||||
elif item == 'queue':
|
||||
return self.queue
|
||||
else:
|
||||
raise KeyError(item)
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Amazon SQS queue implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from vine import transform
|
||||
|
||||
from .message import AsyncMessage
|
||||
|
||||
_all__ = ['AsyncQueue']
|
||||
|
||||
|
||||
def list_first(rs):
|
||||
"""Get the first item in a list, or None if list empty."""
|
||||
return rs[0] if len(rs) == 1 else None
|
||||
|
||||
|
||||
class AsyncQueue:
|
||||
"""Async SQS Queue."""
|
||||
|
||||
def __init__(self, connection=None, url=None, message_class=AsyncMessage):
|
||||
self.connection = connection
|
||||
self.url = url
|
||||
self.message_class = message_class
|
||||
self.visibility_timeout = None
|
||||
|
||||
def _NA(self, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
count_slow = dump = save_to_file = save_to_filename = save = \
|
||||
save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \
|
||||
load = clear = _NA
|
||||
|
||||
def get_attributes(self, attributes='All', callback=None):
|
||||
return self.connection.get_queue_attributes(
|
||||
self, attributes, callback,
|
||||
)
|
||||
|
||||
def set_attribute(self, attribute, value, callback=None):
|
||||
return self.connection.set_queue_attribute(
|
||||
self, attribute, value, callback,
|
||||
)
|
||||
|
||||
def get_timeout(self, callback=None, _attr='VisibilityTimeout'):
|
||||
return self.get_attributes(
|
||||
_attr, transform(
|
||||
self._coerce_field_value, callback, _attr, int,
|
||||
),
|
||||
)
|
||||
|
||||
def _coerce_field_value(self, key, type, response):
|
||||
return type(response[key])
|
||||
|
||||
def set_timeout(self, visibility_timeout, callback=None):
|
||||
return self.set_attribute(
|
||||
'VisibilityTimeout', visibility_timeout,
|
||||
transform(
|
||||
self._on_timeout_set, callback,
|
||||
)
|
||||
)
|
||||
|
||||
def _on_timeout_set(self, visibility_timeout):
|
||||
if visibility_timeout:
|
||||
self.visibility_timeout = visibility_timeout
|
||||
return self.visibility_timeout
|
||||
|
||||
def add_permission(self, label, aws_account_id, action_name,
|
||||
callback=None):
|
||||
return self.connection.add_permission(
|
||||
self, label, aws_account_id, action_name, callback,
|
||||
)
|
||||
|
||||
def remove_permission(self, label, callback=None):
|
||||
return self.connection.remove_permission(self, label, callback)
|
||||
|
||||
def read(self, visibility_timeout=None, wait_time_seconds=None,
|
||||
callback=None):
|
||||
return self.get_messages(
|
||||
1, visibility_timeout,
|
||||
wait_time_seconds=wait_time_seconds,
|
||||
callback=transform(list_first, callback),
|
||||
)
|
||||
|
||||
def write(self, message, delay_seconds=None, callback=None):
|
||||
return self.connection.send_message(
|
||||
self, message.get_body_encoded(), delay_seconds,
|
||||
callback=transform(self._on_message_sent, callback, message),
|
||||
)
|
||||
|
||||
def write_batch(self, messages, callback=None):
|
||||
return self.connection.send_message_batch(
|
||||
self, messages, callback=callback,
|
||||
)
|
||||
|
||||
def _on_message_sent(self, orig_message, new_message):
|
||||
orig_message.id = new_message.id
|
||||
orig_message.md5 = new_message.md5
|
||||
return new_message
|
||||
|
||||
def get_messages(self, num_messages=1, visibility_timeout=None,
|
||||
attributes=None, wait_time_seconds=None, callback=None):
|
||||
return self.connection.receive_message(
|
||||
self, number_messages=num_messages,
|
||||
visibility_timeout=visibility_timeout,
|
||||
attributes=attributes,
|
||||
wait_time_seconds=wait_time_seconds,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def delete_message(self, message, callback=None):
|
||||
return self.connection.delete_message(self, message, callback)
|
||||
|
||||
def delete_message_batch(self, messages, callback=None):
|
||||
return self.connection.delete_message_batch(
|
||||
self, messages, callback=callback,
|
||||
)
|
||||
|
||||
def change_message_visibility_batch(self, messages, callback=None):
|
||||
return self.connection.change_message_visibility_batch(
|
||||
self, messages, callback=callback,
|
||||
)
|
||||
|
||||
def delete(self, callback=None):
|
||||
return self.connection.delete_queue(self, callback=callback)
|
||||
|
||||
def count(self, page_size=10, vtimeout=10, callback=None,
|
||||
_attr='ApproximateNumberOfMessages'):
|
||||
return self.get_attributes(
|
||||
_attr, callback=transform(
|
||||
self._coerce_field_value, callback, _attr, int,
|
||||
),
|
||||
)
|
||||
Reference in New Issue
Block a user