This commit is contained in:
Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from typing import Any
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
def connect_sqs(
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
**kwargs: Any
) -> AsyncSQSConnection:
"""Return async connection to Amazon SQS."""
from .sqs.connection import AsyncSQSConnection
return AsyncSQSConnection(
aws_access_key_id, aws_secret_access_key, **kwargs
)

View File

@@ -0,0 +1,278 @@
"""Amazon AWS Connection."""
from __future__ import annotations
from email import message_from_bytes
from email.mime.message import MIMEMessage
from vine import promise, transform
from kombu.asynchronous.aws.ext import AWSRequest, get_cert_path, get_response
from kombu.asynchronous.http import Headers, Request, get_client
def message_from_headers(hdr):
bs = "\r\n".join("{}: {}".format(*h) for h in hdr)
return message_from_bytes(bs.encode())
__all__ = (
'AsyncHTTPSConnection', 'AsyncConnection',
)
class AsyncHTTPResponse:
"""Async HTTP Response."""
def __init__(self, response):
self.response = response
self._msg = None
self.version = 10
def read(self, *args, **kwargs):
return self.response.body
def getheader(self, name, default=None):
return self.response.headers.get(name, default)
def getheaders(self):
return list(self.response.headers.items())
@property
def msg(self):
if self._msg is None:
self._msg = MIMEMessage(message_from_headers(self.getheaders()))
return self._msg
@property
def status(self):
return self.response.code
@property
def reason(self):
if self.response.error:
return self.response.error.message
return ''
def __repr__(self):
return repr(self.response)
class AsyncHTTPSConnection:
"""Async HTTP Connection."""
Request = Request
Response = AsyncHTTPResponse
method = 'GET'
path = '/'
body = None
default_ports = {'http': 80, 'https': 443}
def __init__(self, strict=None, timeout=20.0, http_client=None):
self.headers = []
self.timeout = timeout
self.strict = strict
self.http_client = http_client or get_client()
def request(self, method, path, body=None, headers=None):
self.path = path
self.method = method
if body is not None:
try:
read = body.read
except AttributeError:
self.body = body
else:
self.body = read()
if headers is not None:
self.headers.extend(list(headers.items()))
def getrequest(self):
headers = Headers(self.headers)
return self.Request(self.path, method=self.method, headers=headers,
body=self.body, connect_timeout=self.timeout,
request_timeout=self.timeout,
validate_cert=True, ca_certs=get_cert_path(True))
def getresponse(self, callback=None):
request = self.getrequest()
request.then(transform(self.Response, callback))
return self.http_client.add_request(request)
def set_debuglevel(self, level):
pass
def connect(self):
pass
def close(self):
pass
def putrequest(self, method, path):
self.method = method
self.path = path
def putheader(self, header, value):
self.headers.append((header, value))
def endheaders(self):
pass
def send(self, data):
if self.body:
self.body += data
else:
self.body = data
def __repr__(self):
return f'<AsyncHTTPConnection: {self.getrequest()!r}>'
class AsyncConnection:
"""Async AWS Connection."""
def __init__(self, sqs_connection, http_client=None, **kwargs):
self.sqs_connection = sqs_connection
self._httpclient = http_client or get_client()
def get_http_connection(self):
return AsyncHTTPSConnection(http_client=self._httpclient)
def _mexe(self, request, sender=None, callback=None):
callback = callback or promise()
conn = self.get_http_connection()
if callable(sender):
sender(conn, request.method, request.path, request.body,
request.headers, callback)
else:
conn.request(request.method, request.url,
request.body, request.headers)
conn.getresponse(callback=callback)
return callback
class AsyncAWSQueryConnection(AsyncConnection):
"""Async AWS Query Connection."""
STATUS_CODE_OK = 200
STATUS_CODE_REQUEST_TIMEOUT = 408
STATUS_CODE_NETWORK_CONNECT_TIMEOUT_ERROR = 599
STATUS_CODE_INTERNAL_ERROR = 500
STATUS_CODE_BAD_GATEWAY = 502
STATUS_CODE_SERVICE_UNAVAILABLE_ERROR = 503
STATUS_CODE_GATEWAY_TIMEOUT = 504
STATUS_CODES_SERVER_ERRORS = (
STATUS_CODE_INTERNAL_ERROR,
STATUS_CODE_BAD_GATEWAY,
STATUS_CODE_SERVICE_UNAVAILABLE_ERROR
)
STATUS_CODES_TIMEOUT = (
STATUS_CODE_REQUEST_TIMEOUT,
STATUS_CODE_NETWORK_CONNECT_TIMEOUT_ERROR,
STATUS_CODE_GATEWAY_TIMEOUT
)
def __init__(self, sqs_connection, http_client=None,
http_client_params=None, **kwargs):
if not http_client_params:
http_client_params = {}
super().__init__(sqs_connection, http_client,
**http_client_params)
def make_request(self, operation, params_, path, verb, callback=None, protocol_params=None):
params = params_.copy()
params.update((protocol_params or {}).get('query', {}))
if operation:
params['Action'] = operation
signer = self.sqs_connection._request_signer
# defaults for non-get
signing_type = 'standard'
param_payload = {'data': params}
if verb.lower() == 'get':
# query-based opts
signing_type = 'presign-url'
param_payload = {'params': params}
request = AWSRequest(method=verb, url=path, **param_payload)
signer.sign(operation, request, signing_type=signing_type)
prepared_request = request.prepare()
return self._mexe(prepared_request, callback=callback)
def get_list(self, operation, params, markers, path='/', parent=None, verb='POST', callback=None,
protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_list_ready, callback, parent or self, markers,
operation
),
protocol_params=protocol_params,
)
def get_object(self, operation, params, path='/', parent=None, verb='GET', callback=None, protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_obj_ready, callback, parent or self, operation
),
protocol_params=protocol_params,
)
def get_status(self, operation, params, path='/', parent=None, verb='GET', callback=None, protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_status_ready, callback, parent or self, operation
),
protocol_params=protocol_params,
)
def _on_list_ready(self, parent, markers, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
_, parsed = get_response(
service_model.operation_model(operation), response.response
)
return parsed
elif (
response.status in self.STATUS_CODES_TIMEOUT or
response.status in self.STATUS_CODES_SERVER_ERRORS
):
# When the server returns a timeout or 50X server error,
# the response is interpreted as an empty list.
# This prevents hanging the Celery worker.
return []
else:
raise self._for_status(response, response.read())
def _on_obj_ready(self, parent, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
_, parsed = get_response(
service_model.operation_model(operation), response.response
)
return parsed
else:
raise self._for_status(response, response.read())
def _on_status_ready(self, parent, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
httpres, _ = get_response(
service_model.operation_model(operation), response.response
)
return httpres.code
else:
raise self._for_status(response, response.read())
def _for_status(self, response, body):
context = 'Empty body' if not body else 'HTTP Error'
return Exception("Request {} HTTP {} {} ({})".format(
context, response.status, response.reason, body
))

View File

@@ -0,0 +1,28 @@
"""Amazon boto3 interface."""
from __future__ import annotations
try:
import boto3
from botocore import exceptions
from botocore.awsrequest import AWSRequest
from botocore.httpsession import get_cert_path
from botocore.response import get_response
except ImportError:
boto3 = None
class _void:
pass
class BotoCoreError(Exception):
pass
exceptions = _void()
exceptions.BotoCoreError = BotoCoreError
AWSRequest = _void()
get_response = _void()
get_cert_path = _void()
__all__ = (
'exceptions', 'AWSRequest', 'get_response', 'get_cert_path',
)

View File

@@ -0,0 +1,320 @@
"""Amazon SQS Connection."""
from __future__ import annotations
import json
from botocore.serialize import Serializer
from vine import transform
from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection
from kombu.asynchronous.aws.ext import AWSRequest
from .ext import boto3
from .message import AsyncMessage
from .queue import AsyncQueue
__all__ = ('AsyncSQSConnection',)
class AsyncSQSConnection(AsyncAWSQueryConnection):
"""Async SQS Connection."""
def __init__(self, sqs_connection, debug=0, region=None, fetch_message_attributes=None, **kwargs):
if boto3 is None:
raise ImportError('boto3 is not installed')
super().__init__(
sqs_connection,
region_name=region, debug=debug,
**kwargs
)
self.fetch_message_attributes = (
fetch_message_attributes if fetch_message_attributes is not None
else ["ApproximateReceiveCount"]
)
def _create_query_request(self, operation, params, queue_url, method):
params = params.copy()
if operation:
params['Action'] = operation
# defaults for non-get
param_payload = {'data': params}
headers = {}
if method.lower() == 'get':
# query-based opts
param_payload = {'params': params}
if method.lower() == 'post':
headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
return AWSRequest(method=method, url=queue_url, headers=headers, **param_payload)
def _create_json_request(self, operation, params, queue_url):
params = params.copy()
params['QueueUrl'] = queue_url
service_model = self.sqs_connection.meta.service_model
operation_model = service_model.operation_model(operation)
url = self.sqs_connection._endpoint.host
headers = {}
# Content-Type
json_version = operation_model.metadata['jsonVersion']
content_type = f'application/x-amz-json-{json_version}'
headers['Content-Type'] = content_type
# X-Amz-Target
target = '{}.{}'.format(
operation_model.metadata['targetPrefix'],
operation_model.name,
)
headers['X-Amz-Target'] = target
param_payload = {
'data': json.dumps(params).encode(),
'headers': headers
}
method = operation_model.http.get('method', Serializer.DEFAULT_METHOD)
return AWSRequest(
method=method,
url=url,
**param_payload
)
def make_request(self, operation_name, params, queue_url, verb, callback=None, protocol_params=None):
"""Override make_request to support different protocols.
botocore has changed the default protocol of communicating
with SQS backend from 'query' to 'json', so we need a special
implementation of make_request for SQS. More information on this can
be found in: https://github.com/celery/kombu/pull/1807.
protocol_params: Optional[dict[str, dict]] of per-protocol additional parameters.
Supported for the SQS query to json protocol transition.
"""
signer = self.sqs_connection._request_signer
service_model = self.sqs_connection.meta.service_model
protocol = service_model.protocol
all_params = {**(params or {}), **protocol_params.get(protocol, {})}
if protocol == 'query':
request = self._create_query_request(
operation_name, all_params, queue_url, verb)
elif protocol == 'json':
request = self._create_json_request(
operation_name, all_params, queue_url)
else:
raise Exception(f'Unsupported protocol: {protocol}.')
signing_type = 'presign-url' if request.method.lower() == 'get' \
else 'standard'
signer.sign(operation_name, request, signing_type=signing_type)
prepared_request = request.prepare()
return self._mexe(prepared_request, callback=callback)
def create_queue(self, queue_name,
visibility_timeout=None, callback=None):
params = {'QueueName': queue_name}
if visibility_timeout:
params['DefaultVisibilityTimeout'] = format(
visibility_timeout, 'd',
)
return self.get_object('CreateQueue', params,
callback=callback)
def delete_queue(self, queue, force_deletion=False, callback=None):
return self.get_status('DeleteQueue', None, queue.id,
callback=callback)
def get_queue_url(self, queue):
res = self.sqs_connection.get_queue_url(QueueName=queue)
return res['QueueUrl']
def get_queue_attributes(self, queue, attribute='All', callback=None):
return self.get_object(
'GetQueueAttributes', {'AttributeName': attribute},
queue.id, callback=callback,
)
def set_queue_attribute(self, queue, attribute, value, callback=None):
return self.get_status(
'SetQueueAttribute',
{},
queue.id, callback=callback,
protocol_params={
'json': {'Attributes': {attribute: value}},
'query': {'Attribute.Name': attribute, 'Attribute.Value': value},
},
)
def receive_message(
self, queue, queue_url, number_messages=1, visibility_timeout=None,
attributes=None, wait_time_seconds=None,
callback=None
):
params = {'MaxNumberOfMessages': number_messages}
proto_params = {'query': {}, 'json': {}}
attrs = attributes if attributes is not None else self.fetch_message_attributes
if visibility_timeout:
params['VisibilityTimeout'] = visibility_timeout
if attrs:
proto_params['json'].update({'AttributeNames': list(attrs)})
proto_params['query'].update(_query_object_encode({'AttributeName': list(attrs)}))
if wait_time_seconds is not None:
params['WaitTimeSeconds'] = wait_time_seconds
return self.get_list(
'ReceiveMessage', params, [('Message', AsyncMessage)],
queue_url, callback=callback, parent=queue,
protocol_params=proto_params,
)
def delete_message(self, queue, receipt_handle, callback=None):
return self.delete_message_from_handle(
queue, receipt_handle, callback,
)
def delete_message_batch(self, queue, messages, callback=None):
p_params = {
'json': {
'Entries': [{'Id': m.id, 'ReceiptHandle': m.receipt_handle} for m in messages],
},
'query': _query_object_encode({
'DeleteMessageBatchRequestEntry': [
{'Id': m.id, 'ReceiptHandle': m.receipt_handle}
for m in messages
],
}),
}
return self.get_object(
'DeleteMessageBatch', {}, queue.id,
verb='POST', callback=callback, protocol_params=p_params,
)
def delete_message_from_handle(self, queue, receipt_handle,
callback=None):
return self.get_status(
'DeleteMessage', {'ReceiptHandle': receipt_handle},
queue, callback=callback,
)
def send_message(self, queue, message_content,
delay_seconds=None, callback=None):
params = {'MessageBody': message_content}
if delay_seconds:
params['DelaySeconds'] = int(delay_seconds)
return self.get_object(
'SendMessage', params, queue.id,
verb='POST', callback=callback,
)
def send_message_batch(self, queue, messages, callback=None):
params = {}
for i, msg in enumerate(messages):
prefix = f'SendMessageBatchRequestEntry.{i + 1}'
params.update({
f'{prefix}.Id': msg[0],
f'{prefix}.MessageBody': msg[1],
f'{prefix}.DelaySeconds': msg[2],
})
return self.get_object(
'SendMessageBatch', params, queue.id,
verb='POST', callback=callback,
)
def change_message_visibility(self, queue, receipt_handle,
visibility_timeout, callback=None):
return self.get_status(
'ChangeMessageVisibility',
{'ReceiptHandle': receipt_handle,
'VisibilityTimeout': visibility_timeout},
queue.id, callback=callback,
)
def change_message_visibility_batch(self, queue, messages, callback=None):
entries = [
{'Id': t[0].id, 'ReceiptHandle': t[0].receipt_handle, 'VisibilityTimeout': t[1]}
for t in messages
]
p_params = {
'json': {'Entries': entries},
'query': _query_object_encode({'ChangeMessageVisibilityBatchRequestEntry': entries}),
}
return self.get_object(
'ChangeMessageVisibilityBatch', {}, queue.id,
verb='POST', callback=callback,
protocol_params=p_params,
)
def get_all_queues(self, prefix='', callback=None):
params = {}
if prefix:
params['QueueNamePrefix'] = prefix
return self.get_list(
'ListQueues', params, [('QueueUrl', AsyncQueue)],
callback=callback,
)
def get_queue(self, queue_name, callback=None):
# TODO Does not support owner_acct_id argument
return self.get_all_queues(
queue_name,
transform(self._on_queue_ready, callback, queue_name),
)
lookup = get_queue
def _on_queue_ready(self, name, queues):
return next(
(q for q in queues if q.url.endswith(name)), None,
)
def get_dead_letter_source_queues(self, queue, callback=None):
return self.get_list(
'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
[('QueueUrl', AsyncQueue)],
callback=callback,
)
def add_permission(self, queue, label, aws_account_id, action_name,
callback=None):
return self.get_status(
'AddPermission',
{'Label': label,
'AWSAccountId': aws_account_id,
'ActionName': action_name},
queue.id, callback=callback,
)
def remove_permission(self, queue, label, callback=None):
return self.get_status(
'RemovePermission', {'Label': label}, queue.id, callback=callback,
)
def _query_object_encode(items):
params = {}
_query_object_encode_part(params, '', items)
return {k: v for k, v in params.items()}
def _query_object_encode_part(params, prefix, part):
dotted = f'{prefix}.' if prefix else prefix
if isinstance(part, (list, tuple)):
for i, item in enumerate(part):
_query_object_encode_part(params, f'{dotted}{i + 1}', item)
elif isinstance(part, dict):
for key, value in part.items():
_query_object_encode_part(params, f'{dotted}{key}', value)
else:
params[prefix] = str(part)

View File

@@ -0,0 +1,9 @@
"""Amazon SQS boto3 interface."""
from __future__ import annotations
try:
import boto3
except ImportError:
boto3 = None

View File

@@ -0,0 +1,35 @@
"""Amazon SQS message implementation."""
from __future__ import annotations
import base64
from kombu.message import Message
from kombu.utils.encoding import str_to_bytes
class BaseAsyncMessage(Message):
"""Base class for messages received on async client."""
class AsyncRawMessage(BaseAsyncMessage):
"""Raw Message."""
class AsyncMessage(BaseAsyncMessage):
"""Serialized message."""
def encode(self, value):
"""Encode/decode the value using Base64 encoding."""
return base64.b64encode(str_to_bytes(value)).decode()
def __getitem__(self, item):
"""Support Boto3-style access on a message."""
if item == 'ReceiptHandle':
return self.receipt_handle
elif item == 'Body':
return self.get_body()
elif item == 'queue':
return self.queue
else:
raise KeyError(item)

View File

@@ -0,0 +1,130 @@
"""Amazon SQS queue implementation."""
from __future__ import annotations
from vine import transform
from .message import AsyncMessage
_all__ = ['AsyncQueue']
def list_first(rs):
"""Get the first item in a list, or None if list empty."""
return rs[0] if len(rs) == 1 else None
class AsyncQueue:
"""Async SQS Queue."""
def __init__(self, connection=None, url=None, message_class=AsyncMessage):
self.connection = connection
self.url = url
self.message_class = message_class
self.visibility_timeout = None
def _NA(self, *args, **kwargs):
raise NotImplementedError()
count_slow = dump = save_to_file = save_to_filename = save = \
save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \
load = clear = _NA
def get_attributes(self, attributes='All', callback=None):
return self.connection.get_queue_attributes(
self, attributes, callback,
)
def set_attribute(self, attribute, value, callback=None):
return self.connection.set_queue_attribute(
self, attribute, value, callback,
)
def get_timeout(self, callback=None, _attr='VisibilityTimeout'):
return self.get_attributes(
_attr, transform(
self._coerce_field_value, callback, _attr, int,
),
)
def _coerce_field_value(self, key, type, response):
return type(response[key])
def set_timeout(self, visibility_timeout, callback=None):
return self.set_attribute(
'VisibilityTimeout', visibility_timeout,
transform(
self._on_timeout_set, callback,
)
)
def _on_timeout_set(self, visibility_timeout):
if visibility_timeout:
self.visibility_timeout = visibility_timeout
return self.visibility_timeout
def add_permission(self, label, aws_account_id, action_name,
callback=None):
return self.connection.add_permission(
self, label, aws_account_id, action_name, callback,
)
def remove_permission(self, label, callback=None):
return self.connection.remove_permission(self, label, callback)
def read(self, visibility_timeout=None, wait_time_seconds=None,
callback=None):
return self.get_messages(
1, visibility_timeout,
wait_time_seconds=wait_time_seconds,
callback=transform(list_first, callback),
)
def write(self, message, delay_seconds=None, callback=None):
return self.connection.send_message(
self, message.get_body_encoded(), delay_seconds,
callback=transform(self._on_message_sent, callback, message),
)
def write_batch(self, messages, callback=None):
return self.connection.send_message_batch(
self, messages, callback=callback,
)
def _on_message_sent(self, orig_message, new_message):
orig_message.id = new_message.id
orig_message.md5 = new_message.md5
return new_message
def get_messages(self, num_messages=1, visibility_timeout=None,
attributes=None, wait_time_seconds=None, callback=None):
return self.connection.receive_message(
self, number_messages=num_messages,
visibility_timeout=visibility_timeout,
attributes=attributes,
wait_time_seconds=wait_time_seconds,
callback=callback,
)
def delete_message(self, message, callback=None):
return self.connection.delete_message(self, message, callback)
def delete_message_batch(self, messages, callback=None):
return self.connection.delete_message_batch(
self, messages, callback=callback,
)
def change_message_visibility_batch(self, messages, callback=None):
return self.connection.change_message_visibility_batch(
self, messages, callback=callback,
)
def delete(self, callback=None):
return self.connection.delete_queue(self, callback=callback)
def count(self, page_size=10, vtimeout=10, callback=None,
_attr='ApproximateNumberOfMessages'):
return self.get_attributes(
_attr, callback=transform(
self._coerce_field_value, callback, _attr, int,
),
)