This commit is contained in:
Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
"""Event loop."""
from __future__ import annotations
from kombu.utils.eventio import ERR, READ, WRITE
from .hub import Hub, get_event_loop, set_event_loop
__all__ = ('READ', 'WRITE', 'ERR', 'Hub', 'get_event_loop', 'set_event_loop')

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from typing import Any
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
def connect_sqs(
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
**kwargs: Any
) -> AsyncSQSConnection:
"""Return async connection to Amazon SQS."""
from .sqs.connection import AsyncSQSConnection
return AsyncSQSConnection(
aws_access_key_id, aws_secret_access_key, **kwargs
)

View File

@@ -0,0 +1,278 @@
"""Amazon AWS Connection."""
from __future__ import annotations
from email import message_from_bytes
from email.mime.message import MIMEMessage
from vine import promise, transform
from kombu.asynchronous.aws.ext import AWSRequest, get_cert_path, get_response
from kombu.asynchronous.http import Headers, Request, get_client
def message_from_headers(hdr):
bs = "\r\n".join("{}: {}".format(*h) for h in hdr)
return message_from_bytes(bs.encode())
__all__ = (
'AsyncHTTPSConnection', 'AsyncConnection',
)
class AsyncHTTPResponse:
"""Async HTTP Response."""
def __init__(self, response):
self.response = response
self._msg = None
self.version = 10
def read(self, *args, **kwargs):
return self.response.body
def getheader(self, name, default=None):
return self.response.headers.get(name, default)
def getheaders(self):
return list(self.response.headers.items())
@property
def msg(self):
if self._msg is None:
self._msg = MIMEMessage(message_from_headers(self.getheaders()))
return self._msg
@property
def status(self):
return self.response.code
@property
def reason(self):
if self.response.error:
return self.response.error.message
return ''
def __repr__(self):
return repr(self.response)
class AsyncHTTPSConnection:
"""Async HTTP Connection."""
Request = Request
Response = AsyncHTTPResponse
method = 'GET'
path = '/'
body = None
default_ports = {'http': 80, 'https': 443}
def __init__(self, strict=None, timeout=20.0, http_client=None):
self.headers = []
self.timeout = timeout
self.strict = strict
self.http_client = http_client or get_client()
def request(self, method, path, body=None, headers=None):
self.path = path
self.method = method
if body is not None:
try:
read = body.read
except AttributeError:
self.body = body
else:
self.body = read()
if headers is not None:
self.headers.extend(list(headers.items()))
def getrequest(self):
headers = Headers(self.headers)
return self.Request(self.path, method=self.method, headers=headers,
body=self.body, connect_timeout=self.timeout,
request_timeout=self.timeout,
validate_cert=True, ca_certs=get_cert_path(True))
def getresponse(self, callback=None):
request = self.getrequest()
request.then(transform(self.Response, callback))
return self.http_client.add_request(request)
def set_debuglevel(self, level):
pass
def connect(self):
pass
def close(self):
pass
def putrequest(self, method, path):
self.method = method
self.path = path
def putheader(self, header, value):
self.headers.append((header, value))
def endheaders(self):
pass
def send(self, data):
if self.body:
self.body += data
else:
self.body = data
def __repr__(self):
return f'<AsyncHTTPConnection: {self.getrequest()!r}>'
class AsyncConnection:
"""Async AWS Connection."""
def __init__(self, sqs_connection, http_client=None, **kwargs):
self.sqs_connection = sqs_connection
self._httpclient = http_client or get_client()
def get_http_connection(self):
return AsyncHTTPSConnection(http_client=self._httpclient)
def _mexe(self, request, sender=None, callback=None):
callback = callback or promise()
conn = self.get_http_connection()
if callable(sender):
sender(conn, request.method, request.path, request.body,
request.headers, callback)
else:
conn.request(request.method, request.url,
request.body, request.headers)
conn.getresponse(callback=callback)
return callback
class AsyncAWSQueryConnection(AsyncConnection):
"""Async AWS Query Connection."""
STATUS_CODE_OK = 200
STATUS_CODE_REQUEST_TIMEOUT = 408
STATUS_CODE_NETWORK_CONNECT_TIMEOUT_ERROR = 599
STATUS_CODE_INTERNAL_ERROR = 500
STATUS_CODE_BAD_GATEWAY = 502
STATUS_CODE_SERVICE_UNAVAILABLE_ERROR = 503
STATUS_CODE_GATEWAY_TIMEOUT = 504
STATUS_CODES_SERVER_ERRORS = (
STATUS_CODE_INTERNAL_ERROR,
STATUS_CODE_BAD_GATEWAY,
STATUS_CODE_SERVICE_UNAVAILABLE_ERROR
)
STATUS_CODES_TIMEOUT = (
STATUS_CODE_REQUEST_TIMEOUT,
STATUS_CODE_NETWORK_CONNECT_TIMEOUT_ERROR,
STATUS_CODE_GATEWAY_TIMEOUT
)
def __init__(self, sqs_connection, http_client=None,
http_client_params=None, **kwargs):
if not http_client_params:
http_client_params = {}
super().__init__(sqs_connection, http_client,
**http_client_params)
def make_request(self, operation, params_, path, verb, callback=None, protocol_params=None):
params = params_.copy()
params.update((protocol_params or {}).get('query', {}))
if operation:
params['Action'] = operation
signer = self.sqs_connection._request_signer
# defaults for non-get
signing_type = 'standard'
param_payload = {'data': params}
if verb.lower() == 'get':
# query-based opts
signing_type = 'presign-url'
param_payload = {'params': params}
request = AWSRequest(method=verb, url=path, **param_payload)
signer.sign(operation, request, signing_type=signing_type)
prepared_request = request.prepare()
return self._mexe(prepared_request, callback=callback)
def get_list(self, operation, params, markers, path='/', parent=None, verb='POST', callback=None,
protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_list_ready, callback, parent or self, markers,
operation
),
protocol_params=protocol_params,
)
def get_object(self, operation, params, path='/', parent=None, verb='GET', callback=None, protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_obj_ready, callback, parent or self, operation
),
protocol_params=protocol_params,
)
def get_status(self, operation, params, path='/', parent=None, verb='GET', callback=None, protocol_params=None):
return self.make_request(
operation, params, path, verb,
callback=transform(
self._on_status_ready, callback, parent or self, operation
),
protocol_params=protocol_params,
)
def _on_list_ready(self, parent, markers, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
_, parsed = get_response(
service_model.operation_model(operation), response.response
)
return parsed
elif (
response.status in self.STATUS_CODES_TIMEOUT or
response.status in self.STATUS_CODES_SERVER_ERRORS
):
# When the server returns a timeout or 50X server error,
# the response is interpreted as an empty list.
# This prevents hanging the Celery worker.
return []
else:
raise self._for_status(response, response.read())
def _on_obj_ready(self, parent, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
_, parsed = get_response(
service_model.operation_model(operation), response.response
)
return parsed
else:
raise self._for_status(response, response.read())
def _on_status_ready(self, parent, operation, response):
service_model = self.sqs_connection.meta.service_model
if response.status == self.STATUS_CODE_OK:
httpres, _ = get_response(
service_model.operation_model(operation), response.response
)
return httpres.code
else:
raise self._for_status(response, response.read())
def _for_status(self, response, body):
context = 'Empty body' if not body else 'HTTP Error'
return Exception("Request {} HTTP {} {} ({})".format(
context, response.status, response.reason, body
))

View File

@@ -0,0 +1,28 @@
"""Amazon boto3 interface."""
from __future__ import annotations
try:
import boto3
from botocore import exceptions
from botocore.awsrequest import AWSRequest
from botocore.httpsession import get_cert_path
from botocore.response import get_response
except ImportError:
boto3 = None
class _void:
pass
class BotoCoreError(Exception):
pass
exceptions = _void()
exceptions.BotoCoreError = BotoCoreError
AWSRequest = _void()
get_response = _void()
get_cert_path = _void()
__all__ = (
'exceptions', 'AWSRequest', 'get_response', 'get_cert_path',
)

View File

@@ -0,0 +1,320 @@
"""Amazon SQS Connection."""
from __future__ import annotations
import json
from botocore.serialize import Serializer
from vine import transform
from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection
from kombu.asynchronous.aws.ext import AWSRequest
from .ext import boto3
from .message import AsyncMessage
from .queue import AsyncQueue
__all__ = ('AsyncSQSConnection',)
class AsyncSQSConnection(AsyncAWSQueryConnection):
"""Async SQS Connection."""
def __init__(self, sqs_connection, debug=0, region=None, fetch_message_attributes=None, **kwargs):
if boto3 is None:
raise ImportError('boto3 is not installed')
super().__init__(
sqs_connection,
region_name=region, debug=debug,
**kwargs
)
self.fetch_message_attributes = (
fetch_message_attributes if fetch_message_attributes is not None
else ["ApproximateReceiveCount"]
)
def _create_query_request(self, operation, params, queue_url, method):
params = params.copy()
if operation:
params['Action'] = operation
# defaults for non-get
param_payload = {'data': params}
headers = {}
if method.lower() == 'get':
# query-based opts
param_payload = {'params': params}
if method.lower() == 'post':
headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
return AWSRequest(method=method, url=queue_url, headers=headers, **param_payload)
def _create_json_request(self, operation, params, queue_url):
params = params.copy()
params['QueueUrl'] = queue_url
service_model = self.sqs_connection.meta.service_model
operation_model = service_model.operation_model(operation)
url = self.sqs_connection._endpoint.host
headers = {}
# Content-Type
json_version = operation_model.metadata['jsonVersion']
content_type = f'application/x-amz-json-{json_version}'
headers['Content-Type'] = content_type
# X-Amz-Target
target = '{}.{}'.format(
operation_model.metadata['targetPrefix'],
operation_model.name,
)
headers['X-Amz-Target'] = target
param_payload = {
'data': json.dumps(params).encode(),
'headers': headers
}
method = operation_model.http.get('method', Serializer.DEFAULT_METHOD)
return AWSRequest(
method=method,
url=url,
**param_payload
)
def make_request(self, operation_name, params, queue_url, verb, callback=None, protocol_params=None):
"""Override make_request to support different protocols.
botocore has changed the default protocol of communicating
with SQS backend from 'query' to 'json', so we need a special
implementation of make_request for SQS. More information on this can
be found in: https://github.com/celery/kombu/pull/1807.
protocol_params: Optional[dict[str, dict]] of per-protocol additional parameters.
Supported for the SQS query to json protocol transition.
"""
signer = self.sqs_connection._request_signer
service_model = self.sqs_connection.meta.service_model
protocol = service_model.protocol
all_params = {**(params or {}), **protocol_params.get(protocol, {})}
if protocol == 'query':
request = self._create_query_request(
operation_name, all_params, queue_url, verb)
elif protocol == 'json':
request = self._create_json_request(
operation_name, all_params, queue_url)
else:
raise Exception(f'Unsupported protocol: {protocol}.')
signing_type = 'presign-url' if request.method.lower() == 'get' \
else 'standard'
signer.sign(operation_name, request, signing_type=signing_type)
prepared_request = request.prepare()
return self._mexe(prepared_request, callback=callback)
def create_queue(self, queue_name,
visibility_timeout=None, callback=None):
params = {'QueueName': queue_name}
if visibility_timeout:
params['DefaultVisibilityTimeout'] = format(
visibility_timeout, 'd',
)
return self.get_object('CreateQueue', params,
callback=callback)
def delete_queue(self, queue, force_deletion=False, callback=None):
return self.get_status('DeleteQueue', None, queue.id,
callback=callback)
def get_queue_url(self, queue):
res = self.sqs_connection.get_queue_url(QueueName=queue)
return res['QueueUrl']
def get_queue_attributes(self, queue, attribute='All', callback=None):
return self.get_object(
'GetQueueAttributes', {'AttributeName': attribute},
queue.id, callback=callback,
)
def set_queue_attribute(self, queue, attribute, value, callback=None):
return self.get_status(
'SetQueueAttribute',
{},
queue.id, callback=callback,
protocol_params={
'json': {'Attributes': {attribute: value}},
'query': {'Attribute.Name': attribute, 'Attribute.Value': value},
},
)
def receive_message(
self, queue, queue_url, number_messages=1, visibility_timeout=None,
attributes=None, wait_time_seconds=None,
callback=None
):
params = {'MaxNumberOfMessages': number_messages}
proto_params = {'query': {}, 'json': {}}
attrs = attributes if attributes is not None else self.fetch_message_attributes
if visibility_timeout:
params['VisibilityTimeout'] = visibility_timeout
if attrs:
proto_params['json'].update({'AttributeNames': list(attrs)})
proto_params['query'].update(_query_object_encode({'AttributeName': list(attrs)}))
if wait_time_seconds is not None:
params['WaitTimeSeconds'] = wait_time_seconds
return self.get_list(
'ReceiveMessage', params, [('Message', AsyncMessage)],
queue_url, callback=callback, parent=queue,
protocol_params=proto_params,
)
def delete_message(self, queue, receipt_handle, callback=None):
return self.delete_message_from_handle(
queue, receipt_handle, callback,
)
def delete_message_batch(self, queue, messages, callback=None):
p_params = {
'json': {
'Entries': [{'Id': m.id, 'ReceiptHandle': m.receipt_handle} for m in messages],
},
'query': _query_object_encode({
'DeleteMessageBatchRequestEntry': [
{'Id': m.id, 'ReceiptHandle': m.receipt_handle}
for m in messages
],
}),
}
return self.get_object(
'DeleteMessageBatch', {}, queue.id,
verb='POST', callback=callback, protocol_params=p_params,
)
def delete_message_from_handle(self, queue, receipt_handle,
callback=None):
return self.get_status(
'DeleteMessage', {'ReceiptHandle': receipt_handle},
queue, callback=callback,
)
def send_message(self, queue, message_content,
delay_seconds=None, callback=None):
params = {'MessageBody': message_content}
if delay_seconds:
params['DelaySeconds'] = int(delay_seconds)
return self.get_object(
'SendMessage', params, queue.id,
verb='POST', callback=callback,
)
def send_message_batch(self, queue, messages, callback=None):
params = {}
for i, msg in enumerate(messages):
prefix = f'SendMessageBatchRequestEntry.{i + 1}'
params.update({
f'{prefix}.Id': msg[0],
f'{prefix}.MessageBody': msg[1],
f'{prefix}.DelaySeconds': msg[2],
})
return self.get_object(
'SendMessageBatch', params, queue.id,
verb='POST', callback=callback,
)
def change_message_visibility(self, queue, receipt_handle,
visibility_timeout, callback=None):
return self.get_status(
'ChangeMessageVisibility',
{'ReceiptHandle': receipt_handle,
'VisibilityTimeout': visibility_timeout},
queue.id, callback=callback,
)
def change_message_visibility_batch(self, queue, messages, callback=None):
entries = [
{'Id': t[0].id, 'ReceiptHandle': t[0].receipt_handle, 'VisibilityTimeout': t[1]}
for t in messages
]
p_params = {
'json': {'Entries': entries},
'query': _query_object_encode({'ChangeMessageVisibilityBatchRequestEntry': entries}),
}
return self.get_object(
'ChangeMessageVisibilityBatch', {}, queue.id,
verb='POST', callback=callback,
protocol_params=p_params,
)
def get_all_queues(self, prefix='', callback=None):
params = {}
if prefix:
params['QueueNamePrefix'] = prefix
return self.get_list(
'ListQueues', params, [('QueueUrl', AsyncQueue)],
callback=callback,
)
def get_queue(self, queue_name, callback=None):
# TODO Does not support owner_acct_id argument
return self.get_all_queues(
queue_name,
transform(self._on_queue_ready, callback, queue_name),
)
lookup = get_queue
def _on_queue_ready(self, name, queues):
return next(
(q for q in queues if q.url.endswith(name)), None,
)
def get_dead_letter_source_queues(self, queue, callback=None):
return self.get_list(
'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
[('QueueUrl', AsyncQueue)],
callback=callback,
)
def add_permission(self, queue, label, aws_account_id, action_name,
callback=None):
return self.get_status(
'AddPermission',
{'Label': label,
'AWSAccountId': aws_account_id,
'ActionName': action_name},
queue.id, callback=callback,
)
def remove_permission(self, queue, label, callback=None):
return self.get_status(
'RemovePermission', {'Label': label}, queue.id, callback=callback,
)
def _query_object_encode(items):
params = {}
_query_object_encode_part(params, '', items)
return {k: v for k, v in params.items()}
def _query_object_encode_part(params, prefix, part):
dotted = f'{prefix}.' if prefix else prefix
if isinstance(part, (list, tuple)):
for i, item in enumerate(part):
_query_object_encode_part(params, f'{dotted}{i + 1}', item)
elif isinstance(part, dict):
for key, value in part.items():
_query_object_encode_part(params, f'{dotted}{key}', value)
else:
params[prefix] = str(part)

View File

@@ -0,0 +1,9 @@
"""Amazon SQS boto3 interface."""
from __future__ import annotations
try:
import boto3
except ImportError:
boto3 = None

View File

@@ -0,0 +1,35 @@
"""Amazon SQS message implementation."""
from __future__ import annotations
import base64
from kombu.message import Message
from kombu.utils.encoding import str_to_bytes
class BaseAsyncMessage(Message):
"""Base class for messages received on async client."""
class AsyncRawMessage(BaseAsyncMessage):
"""Raw Message."""
class AsyncMessage(BaseAsyncMessage):
"""Serialized message."""
def encode(self, value):
"""Encode/decode the value using Base64 encoding."""
return base64.b64encode(str_to_bytes(value)).decode()
def __getitem__(self, item):
"""Support Boto3-style access on a message."""
if item == 'ReceiptHandle':
return self.receipt_handle
elif item == 'Body':
return self.get_body()
elif item == 'queue':
return self.queue
else:
raise KeyError(item)

View File

@@ -0,0 +1,130 @@
"""Amazon SQS queue implementation."""
from __future__ import annotations
from vine import transform
from .message import AsyncMessage
_all__ = ['AsyncQueue']
def list_first(rs):
"""Get the first item in a list, or None if list empty."""
return rs[0] if len(rs) == 1 else None
class AsyncQueue:
"""Async SQS Queue."""
def __init__(self, connection=None, url=None, message_class=AsyncMessage):
self.connection = connection
self.url = url
self.message_class = message_class
self.visibility_timeout = None
def _NA(self, *args, **kwargs):
raise NotImplementedError()
count_slow = dump = save_to_file = save_to_filename = save = \
save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \
load = clear = _NA
def get_attributes(self, attributes='All', callback=None):
return self.connection.get_queue_attributes(
self, attributes, callback,
)
def set_attribute(self, attribute, value, callback=None):
return self.connection.set_queue_attribute(
self, attribute, value, callback,
)
def get_timeout(self, callback=None, _attr='VisibilityTimeout'):
return self.get_attributes(
_attr, transform(
self._coerce_field_value, callback, _attr, int,
),
)
def _coerce_field_value(self, key, type, response):
return type(response[key])
def set_timeout(self, visibility_timeout, callback=None):
return self.set_attribute(
'VisibilityTimeout', visibility_timeout,
transform(
self._on_timeout_set, callback,
)
)
def _on_timeout_set(self, visibility_timeout):
if visibility_timeout:
self.visibility_timeout = visibility_timeout
return self.visibility_timeout
def add_permission(self, label, aws_account_id, action_name,
callback=None):
return self.connection.add_permission(
self, label, aws_account_id, action_name, callback,
)
def remove_permission(self, label, callback=None):
return self.connection.remove_permission(self, label, callback)
def read(self, visibility_timeout=None, wait_time_seconds=None,
callback=None):
return self.get_messages(
1, visibility_timeout,
wait_time_seconds=wait_time_seconds,
callback=transform(list_first, callback),
)
def write(self, message, delay_seconds=None, callback=None):
return self.connection.send_message(
self, message.get_body_encoded(), delay_seconds,
callback=transform(self._on_message_sent, callback, message),
)
def write_batch(self, messages, callback=None):
return self.connection.send_message_batch(
self, messages, callback=callback,
)
def _on_message_sent(self, orig_message, new_message):
orig_message.id = new_message.id
orig_message.md5 = new_message.md5
return new_message
def get_messages(self, num_messages=1, visibility_timeout=None,
attributes=None, wait_time_seconds=None, callback=None):
return self.connection.receive_message(
self, number_messages=num_messages,
visibility_timeout=visibility_timeout,
attributes=attributes,
wait_time_seconds=wait_time_seconds,
callback=callback,
)
def delete_message(self, message, callback=None):
return self.connection.delete_message(self, message, callback)
def delete_message_batch(self, messages, callback=None):
return self.connection.delete_message_batch(
self, messages, callback=callback,
)
def change_message_visibility_batch(self, messages, callback=None):
return self.connection.change_message_visibility_batch(
self, messages, callback=callback,
)
def delete(self, callback=None):
return self.connection.delete_queue(self, callback=callback)
def count(self, page_size=10, vtimeout=10, callback=None,
_attr='ApproximateNumberOfMessages'):
return self.get_attributes(
_attr, callback=transform(
self._coerce_field_value, callback, _attr, int,
),
)

View File

@@ -0,0 +1,67 @@
"""Event-loop debugging tools."""
from __future__ import annotations
from kombu.utils.eventio import ERR, READ, WRITE
from kombu.utils.functional import reprcall
def repr_flag(flag):
"""Return description of event loop flag."""
return '{}{}{}'.format('R' if flag & READ else '',
'W' if flag & WRITE else '',
'!' if flag & ERR else '')
def _rcb(obj):
if obj is None:
return '<missing>'
if isinstance(obj, str):
return obj
if isinstance(obj, tuple):
cb, args = obj
return reprcall(cb.__name__, args=args)
return obj.__name__
def repr_active(h):
"""Return description of active readers and writers."""
return ', '.join(repr_readers(h) + repr_writers(h))
def repr_events(h, events):
"""Return description of events returned by poll."""
return ', '.join(
'{}({})->{}'.format(
_rcb(callback_for(h, fd, fl, '(GONE)')), fd,
repr_flag(fl),
)
for fd, fl in events
)
def repr_readers(h):
"""Return description of pending readers."""
return [f'({fd}){_rcb(cb)}->{repr_flag(READ | ERR)}'
for fd, cb in h.readers.items()]
def repr_writers(h):
"""Return description of pending writers."""
return [f'({fd}){_rcb(cb)}->{repr_flag(WRITE)}'
for fd, cb in h.writers.items()]
def callback_for(h, fd, flag, *default):
"""Return the callback used for hub+fd+flag."""
try:
if flag & READ:
return h.readers[fd]
if flag & WRITE:
if fd in h.consolidate:
return h.consolidate_callback
return h.writers[fd]
except KeyError:
if default:
return default[0]
raise

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from kombu.asynchronous import get_event_loop
from kombu.asynchronous.http.base import BaseClient, Headers, Request, Response
from kombu.asynchronous.hub import Hub
__all__ = ('Client', 'Headers', 'Response', 'Request', 'get_client')
def Client(hub: Hub | None = None, **kwargs: int) -> BaseClient:
"""Create new HTTP client."""
from .urllib3_client import Urllib3Client
return Urllib3Client(hub, **kwargs)
def get_client(hub: Hub | None = None, **kwargs: int) -> BaseClient:
"""Get or create HTTP client bound to the current event loop."""
hub = hub or get_event_loop()
try:
return hub._current_http_client
except AttributeError:
client = hub._current_http_client = Client(hub, **kwargs)
return client

View File

@@ -0,0 +1,280 @@
"""Base async HTTP client implementation."""
from __future__ import annotations
import sys
from http.client import responses
from typing import TYPE_CHECKING
from vine import Thenable, maybe_promise, promise
from kombu.exceptions import HttpError
from kombu.utils.compat import coro
from kombu.utils.encoding import bytes_to_str
from kombu.utils.functional import maybe_list, memoize
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('Headers', 'Response', 'Request', 'BaseClient')
PYPY = hasattr(sys, 'pypy_version_info')
@memoize(maxsize=1000)
def normalize_header(key):
return '-'.join(p.capitalize() for p in key.split('-'))
class Headers(dict):
"""Represents a mapping of HTTP headers."""
# TODO: This is just a regular dict and will not perform normalization
# when looking up keys etc.
#: Set when all of the headers have been read.
complete = False
#: Internal attribute used to keep track of continuation lines.
_prev_key = None
@Thenable.register
class Request:
"""A HTTP Request.
Arguments:
---------
url (str): The URL to request.
method (str): The HTTP method to use (defaults to ``GET``).
Keyword Arguments:
-----------------
headers (Dict, ~kombu.asynchronous.http.Headers): Optional headers for
this request
body (str): Optional body for this request.
connect_timeout (float): Connection timeout in float seconds
Default is 30.0.
timeout (float): Time in float seconds before the request times out
Default is 30.0.
follow_redirects (bool): Specify if the client should follow redirects
Enabled by default.
max_redirects (int): Maximum number of redirects (default 6).
use_gzip (bool): Allow the server to use gzip compression.
Enabled by default.
validate_cert (bool): Set to true if the server certificate should be
verified when performing ``https://`` requests.
Enabled by default.
auth_username (str): Username for HTTP authentication.
auth_password (str): Password for HTTP authentication.
auth_mode (str): Type of HTTP authentication (``basic`` or ``digest``).
user_agent (str): Custom user agent for this request.
network_interface (str): Network interface to use for this request.
on_ready (Callable): Callback to be called when the response has been
received. Must accept single ``response`` argument.
on_stream (Callable): Optional callback to be called every time body
content has been read from the socket. If specified then the
response body and buffer attributes will not be available.
on_timeout (callable): Optional callback to be called if the request
times out.
on_header (Callable): Optional callback to be called for every header
line received from the server. The signature
is ``(headers, line)`` and note that if you want
``response.headers`` to be populated then your callback needs to
also call ``client.on_header(headers, line)``.
on_prepare (Callable): Optional callback that is implementation
specific (e.g. curl client will pass the ``curl`` instance to
this callback).
proxy_host (str): Optional proxy host. Note that a ``proxy_port`` must
also be provided or a :exc:`ValueError` will be raised.
proxy_username (str): Optional username to use when logging in
to the proxy.
proxy_password (str): Optional password to use when authenticating
with the proxy server.
ca_certs (str): Custom CA certificates file to use.
client_key (str): Optional filename for client SSL key.
client_cert (str): Optional filename for client SSL certificate.
"""
body = user_agent = network_interface = \
auth_username = auth_password = auth_mode = \
proxy_host = proxy_port = proxy_username = proxy_password = \
ca_certs = client_key = client_cert = None
connect_timeout = 30.0
request_timeout = 30.0
follow_redirects = True
max_redirects = 6
use_gzip = True
validate_cert = True
if not PYPY: # pragma: no cover
__slots__ = ('url', 'method', 'on_ready', 'on_timeout', 'on_stream',
'on_prepare', 'on_header', 'headers',
'__weakref__', '__dict__')
def __init__(self, url, method='GET', on_ready=None, on_timeout=None,
on_stream=None, on_prepare=None, on_header=None,
headers=None, **kwargs):
self.url = url
self.method = method or self.method
self.on_ready = maybe_promise(on_ready) or promise()
self.on_timeout = maybe_promise(on_timeout)
self.on_stream = maybe_promise(on_stream)
self.on_prepare = maybe_promise(on_prepare)
self.on_header = maybe_promise(on_header)
if kwargs:
for k, v in kwargs.items():
setattr(self, k, v)
if not isinstance(headers, Headers):
headers = Headers(headers or {})
self.headers = headers
def then(self, callback, errback=None):
self.on_ready.then(callback, errback)
def __repr__(self):
return '<Request: {0.method} {0.url} {0.body}>'.format(self)
class Response:
"""HTTP Response.
Arguments
---------
request (~kombu.asynchronous.http.Request): See :attr:`request`.
code (int): See :attr:`code`.
headers (~kombu.asynchronous.http.Headers): See :attr:`headers`.
buffer (bytes): See :attr:`buffer`
effective_url (str): See :attr:`effective_url`.
status (str): See :attr:`status`.
Attributes
----------
request (~kombu.asynchronous.http.Request): object used to
get this response.
code (int): HTTP response code (e.g. 200, 404, or 500).
headers (~kombu.asynchronous.http.Headers): HTTP headers
for this response.
buffer (bytes): Socket read buffer.
effective_url (str): The destination url for this request after
following redirects.
error (Exception): Error instance if the request resulted in
a HTTP error code.
status (str): Human equivalent of :attr:`code`,
e.g. ``OK``, `Not found`, or 'Internal Server Error'.
"""
if not PYPY: # pragma: no cover
__slots__ = ('request', 'code', 'headers', 'buffer', 'effective_url',
'error', 'status', '_body', '__weakref__')
def __init__(self, request, code, headers=None, buffer=None,
effective_url=None, error=None, status=None):
self.request = request
self.code = code
self.headers = headers if headers is not None else Headers()
self.buffer = buffer
self.effective_url = effective_url or request.url
self._body = None
self.status = status or responses.get(self.code, 'Unknown')
self.error = error
if self.error is None and (self.code < 200 or self.code > 299):
self.error = HttpError(self.code, self.status, self)
def raise_for_error(self):
"""Raise if the request resulted in an HTTP error code.
Raises
------
:class:`~kombu.exceptions.HttpError`
"""
if self.error:
raise self.error
@property
def body(self):
"""The full contents of the response body.
Note:
----
Accessing this property will evaluate the buffer
and subsequent accesses will be cached.
"""
if self._body is None:
if self.buffer is not None:
self._body = self.buffer.getvalue()
return self._body
# these are for compatibility with Requests
@property
def status_code(self):
return self.code
@property
def content(self):
return self.body
@coro
def header_parser(keyt=normalize_header):
while 1:
(line, headers) = yield
if line.startswith('HTTP/'):
continue
elif not line:
headers.complete = True
continue
elif line[0].isspace():
pkey = headers._prev_key
headers[pkey] = ' '.join([headers.get(pkey) or '', line.lstrip()])
else:
key, value = line.split(':', 1)
key = headers._prev_key = keyt(key)
headers[key] = value.strip()
class BaseClient:
"""Base class for HTTP clients.
This class provides the basic structure and functionality for HTTP clients.
Subclasses should implement specific HTTP client behavior.
"""
Headers = Headers
Request = Request
Response = Response
def __init__(self, hub, **kwargs):
self.hub = hub
self._header_parser = header_parser()
def perform(self, request, **kwargs):
for req in maybe_list(request) or []:
if not isinstance(req, self.Request):
req = self.Request(req, **kwargs)
self.add_request(req)
def add_request(self, request):
raise NotImplementedError('must implement add_request')
def close(self):
pass
def on_header(self, headers, line):
try:
self._header_parser.send((bytes_to_str(line), headers))
except StopIteration:
self._header_parser = header_parser()
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.close()

View File

@@ -0,0 +1,219 @@
from __future__ import annotations
from collections import deque
from io import BytesIO
import urllib3
from kombu.asynchronous.hub import Hub, get_event_loop
from kombu.exceptions import HttpError
from .base import BaseClient, Request
__all__ = ('Urllib3Client',)
from ...utils.encoding import bytes_to_str
DEFAULT_USER_AGENT = 'Mozilla/5.0 (compatible; urllib3)'
EXTRA_METHODS = frozenset(['DELETE', 'OPTIONS', 'PATCH'])
def _get_pool_key_parts(request: Request) -> list[str]:
_pool_key_parts = []
if request.network_interface:
_pool_key_parts.append(f"interface={request.network_interface}")
if request.validate_cert:
_pool_key_parts.append("validate_cert=True")
else:
_pool_key_parts.append("validate_cert=False")
if request.ca_certs:
_pool_key_parts.append(f"ca_certs={request.ca_certs}")
if request.client_cert:
_pool_key_parts.append(f"client_cert={request.client_cert}")
if request.client_key:
_pool_key_parts.append(f"client_key={request.client_key}")
return _pool_key_parts
class Urllib3Client(BaseClient):
"""Urllib3 HTTP Client."""
_pools = {}
def __init__(self, hub: Hub | None = None, max_clients: int = 10):
hub = hub or get_event_loop()
super().__init__(hub)
self.max_clients = max_clients
self._pending = deque()
self._timeout_check_tref = self.hub.call_repeatedly(
1.0, self._timeout_check,
)
def pools_close(self):
for pool in self._pools.values():
pool.close()
self._pools.clear()
def close(self):
self._timeout_check_tref.cancel()
self.pools_close()
def add_request(self, request):
self._pending.append(request)
self._process_queue()
return request
def get_pool(self, request: Request):
_pool_key_parts = _get_pool_key_parts(request=request)
_proxy_url = None
proxy_headers = None
if request.proxy_host:
_proxy_url = urllib3.util.Url(
scheme=None,
host=request.proxy_host,
port=request.proxy_port,
)
if request.proxy_username:
proxy_headers = urllib3.make_headers(
proxy_basic_auth=(
f"{request.proxy_username}"
f":{request.proxy_password}"
)
)
else:
proxy_headers = None
_proxy_url = _proxy_url.url
_pool_key_parts.append(f"proxy={_proxy_url}")
if proxy_headers:
_pool_key_parts.append(f"proxy_headers={str(proxy_headers)}")
_pool_key = "|".join(_pool_key_parts)
if _pool_key in self._pools:
return self._pools[_pool_key]
# create new pool
if _proxy_url:
_pool = urllib3.ProxyManager(
proxy_url=_proxy_url,
num_pools=self.max_clients,
proxy_headers=proxy_headers
)
else:
_pool = urllib3.PoolManager(num_pools=self.max_clients)
# Network Interface
if request.network_interface:
_pool.connection_pool_kw['source_address'] = (
request.network_interface,
0
)
# SSL Verification
if request.validate_cert:
_pool.connection_pool_kw['cert_reqs'] = 'CERT_REQUIRED'
else:
_pool.connection_pool_kw['cert_reqs'] = 'CERT_NONE'
# CA Certificates
if request.ca_certs is not None:
_pool.connection_pool_kw['ca_certs'] = request.ca_certs
elif request.validate_cert is True:
try:
from certifi import where
_pool.connection_pool_kw['ca_certs'] = where()
except ImportError:
pass
# Client Certificates
if request.client_cert is not None:
_pool.connection_pool_kw['cert_file'] = request.client_cert
if request.client_key is not None:
_pool.connection_pool_kw['key_file'] = request.client_key
self._pools[_pool_key] = _pool
return _pool
def _timeout_check(self):
self._process_pending_requests()
def _process_pending_requests(self):
while self._pending:
request = self._pending.popleft()
self._process_request(request)
def _process_request(self, request: Request):
# Prepare headers
headers = request.headers
headers.setdefault(
'User-Agent',
bytes_to_str(request.user_agent or DEFAULT_USER_AGENT)
)
headers.setdefault(
'Accept-Encoding',
'gzip,deflate' if request.use_gzip else 'none'
)
# Authentication
if request.auth_username is not None:
headers.update(
urllib3.util.make_headers(
basic_auth=(
f"{request.auth_username}"
f":{request.auth_password or ''}"
)
)
)
# Make the request using urllib3
try:
_pool = self.get_pool(request=request)
response = _pool.request(
request.method,
request.url,
headers=headers,
body=request.body,
preload_content=False,
redirect=request.follow_redirects,
)
buffer = BytesIO(response.data)
response_obj = self.Response(
request=request,
code=response.status,
headers=response.headers,
buffer=buffer,
effective_url=response.geturl(),
error=None
)
except urllib3.exceptions.HTTPError as e:
response_obj = self.Response(
request=request,
code=599,
headers={},
buffer=None,
effective_url=None,
error=HttpError(599, str(e))
)
request.on_ready(response_obj)
def _process_queue(self):
self._process_pending_requests()
def on_readable(self, fd):
pass
def on_writable(self, fd):
pass
def _setup_request(self, curl, request, buffer, headers):
pass

View File

@@ -0,0 +1,399 @@
"""Event loop implementation."""
from __future__ import annotations
import errno
import threading
from contextlib import contextmanager
from copy import copy
from queue import Empty
from time import sleep
from types import GeneratorType as generator
from vine import Thenable, promise
from kombu.log import get_logger
from kombu.utils.compat import fileno
from kombu.utils.eventio import ERR, READ, WRITE, poll
from kombu.utils.objects import cached_property
from .timer import Timer
__all__ = ('Hub', 'get_event_loop', 'set_event_loop')
logger = get_logger(__name__)
_current_loop: Hub | None = None
W_UNKNOWN_EVENT = """\
Received unknown event %r for fd %r, please contact support!\
"""
class Stop(BaseException):
"""Stops the event loop."""
def _raise_stop_error():
raise Stop()
@contextmanager
def _dummy_context(*args, **kwargs):
yield
def get_event_loop() -> Hub | None:
"""Get current event loop object."""
return _current_loop
def set_event_loop(loop: Hub | None) -> Hub | None:
"""Set the current event loop object."""
global _current_loop
_current_loop = loop
return loop
class Hub:
"""Event loop object.
Arguments:
---------
timer (kombu.asynchronous.Timer): Specify custom timer instance.
"""
#: Flag set if reading from an fd will not block.
READ = READ
#: Flag set if writing to an fd will not block.
WRITE = WRITE
#: Flag set on error, and the fd should be read from asap.
ERR = ERR
#: List of callbacks to be called when the loop is exiting,
#: applied with the hub instance as sole argument.
on_close = None
def __init__(self, timer=None):
self.timer = timer if timer is not None else Timer()
self.readers = {}
self.writers = {}
self.on_tick = set()
self.on_close = set()
self._ready = set()
self._ready_lock = threading.Lock()
self._running = False
self._loop = None
# The eventloop (in celery.worker.loops)
# will merge fds in this set and then instead of calling
# the callback for each ready fd it will call the
# :attr:`consolidate_callback` with the list of ready_fds
# as an argument. This API is internal and is only
# used by the multiprocessing pool to find inqueues
# that are ready to write.
self.consolidate = set()
self.consolidate_callback = None
self.propagate_errors = ()
self._create_poller()
@property
def poller(self):
if not self._poller:
self._create_poller()
return self._poller
@poller.setter
def poller(self, value):
self._poller = value
def reset(self):
self.close()
self._create_poller()
def _create_poller(self):
self._poller = poll()
self._register_fd = self._poller.register
self._unregister_fd = self._poller.unregister
def _close_poller(self):
if self._poller is not None:
self._poller.close()
self._poller = None
self._register_fd = None
self._unregister_fd = None
def stop(self):
self.call_soon(_raise_stop_error)
def __repr__(self):
return '<Hub@{:#x}: R:{} W:{}>'.format(
id(self), len(self.readers), len(self.writers),
)
def fire_timers(self, min_delay=1, max_delay=10, max_timers=10,
propagate=()):
timer = self.timer
delay = None
if timer and timer._queue:
for i in range(max_timers):
delay, entry = next(self.scheduler)
if entry is None:
break
try:
entry()
except propagate:
raise
except (MemoryError, AssertionError):
raise
except OSError as exc:
if exc.errno == errno.ENOMEM:
raise
logger.error('Error in timer: %r', exc, exc_info=1)
except Exception as exc:
logger.error('Error in timer: %r', exc, exc_info=1)
return min(delay or min_delay, max_delay)
def _remove_from_loop(self, fd):
try:
self._unregister(fd)
finally:
self._discard(fd)
def add(self, fd, callback, flags, args=(), consolidate=False):
fd = fileno(fd)
try:
self.poller.register(fd, flags)
except ValueError:
self._remove_from_loop(fd)
raise
else:
dest = self.readers if flags & READ else self.writers
if consolidate:
self.consolidate.add(fd)
dest[fd] = None
else:
dest[fd] = callback, args
def remove(self, fd):
fd = fileno(fd)
self._remove_from_loop(fd)
def run_forever(self):
self._running = True
try:
while 1:
try:
self.run_once()
except Stop:
break
finally:
self._running = False
def run_once(self):
try:
next(self.loop)
except StopIteration:
self._loop = None
def call_soon(self, callback, *args):
if not isinstance(callback, Thenable):
callback = promise(callback, args)
with self._ready_lock:
self._ready.add(callback)
return callback
def call_later(self, delay, callback, *args):
return self.timer.call_after(delay, callback, args)
def call_at(self, when, callback, *args):
return self.timer.call_at(when, callback, args)
def call_repeatedly(self, delay, callback, *args):
return self.timer.call_repeatedly(delay, callback, args)
def add_reader(self, fds, callback, *args):
return self.add(fds, callback, READ | ERR, args)
def add_writer(self, fds, callback, *args):
return self.add(fds, callback, WRITE, args)
def remove_reader(self, fd):
writable = fd in self.writers
on_write = self.writers.get(fd)
try:
self._remove_from_loop(fd)
finally:
if writable:
cb, args = on_write
self.add(fd, cb, WRITE, args)
def remove_writer(self, fd):
readable = fd in self.readers
on_read = self.readers.get(fd)
try:
self._remove_from_loop(fd)
finally:
if readable:
cb, args = on_read
self.add(fd, cb, READ | ERR, args)
def _unregister(self, fd):
try:
self.poller.unregister(fd)
except (AttributeError, KeyError, OSError):
pass
def _pop_ready(self):
with self._ready_lock:
ready = self._ready
self._ready = set()
return ready
def close(self, *args):
[self._unregister(fd) for fd in self.readers]
self.readers.clear()
[self._unregister(fd) for fd in self.writers]
self.writers.clear()
self.consolidate.clear()
self._close_poller()
for callback in self.on_close:
callback(self)
# Complete remaining todo before Hub close
# Eg: Acknowledge message
# To avoid infinite loop where one of the callables adds items
# to self._ready (via call_soon or otherwise).
# we create new list with current self._ready
todos = self._pop_ready()
for item in todos:
item()
def _discard(self, fd):
fd = fileno(fd)
self.readers.pop(fd, None)
self.writers.pop(fd, None)
self.consolidate.discard(fd)
def on_callback_error(self, callback, exc):
logger.error(
'Callback %r raised exception: %r', callback, exc, exc_info=1,
)
def create_loop(self,
generator=generator, sleep=sleep, min=min, next=next,
Empty=Empty, StopIteration=StopIteration,
KeyError=KeyError, READ=READ, WRITE=WRITE, ERR=ERR):
readers, writers = self.readers, self.writers
poll = self.poller.poll
fire_timers = self.fire_timers
hub_remove = self.remove
scheduled = self.timer._queue
consolidate = self.consolidate
consolidate_callback = self.consolidate_callback
propagate = self.propagate_errors
while 1:
todo = self._pop_ready()
for item in todo:
if item:
item()
poll_timeout = fire_timers(propagate=propagate) if scheduled else 1
for tick_callback in copy(self.on_tick):
tick_callback()
# print('[[[HUB]]]: %s' % (self.repr_active(),))
if readers or writers:
to_consolidate = []
try:
events = poll(poll_timeout)
# print('[EVENTS]: %s' % (self.repr_events(events),))
except ValueError: # Issue celery/#882
return
for fd, event in events or ():
general_error = False
if fd in consolidate and \
writers.get(fd) is None:
to_consolidate.append(fd)
continue
cb = cbargs = None
if event & READ:
try:
cb, cbargs = readers[fd]
except KeyError:
self.remove_reader(fd)
continue
elif event & WRITE:
try:
cb, cbargs = writers[fd]
except KeyError:
self.remove_writer(fd)
continue
elif event & ERR:
general_error = True
else:
logger.info(W_UNKNOWN_EVENT, event, fd)
general_error = True
if general_error:
try:
cb, cbargs = (readers.get(fd) or
writers.get(fd))
except TypeError:
pass
if cb is None:
self.remove(fd)
continue
if isinstance(cb, generator):
try:
next(cb)
except OSError as exc:
if exc.errno != errno.EBADF:
raise
hub_remove(fd)
except StopIteration:
pass
except Exception:
hub_remove(fd)
raise
else:
try:
cb(*cbargs)
except Empty:
pass
if to_consolidate:
consolidate_callback(to_consolidate)
else:
# no sockets yet, startup is probably not done.
sleep(min(poll_timeout, 0.1))
yield
def repr_active(self):
from .debug import repr_active
return repr_active(self)
def repr_events(self, events):
from .debug import repr_events
return repr_events(self, events or [])
@cached_property
def scheduler(self):
return iter(self.timer)
@property
def loop(self):
if self._loop is None:
self._loop = self.create_loop()
return self._loop

View File

@@ -0,0 +1,127 @@
"""Semaphores and concurrency primitives."""
from __future__ import annotations
import sys
from collections import deque
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from types import TracebackType
from typing import Callable, Deque
if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec
P = ParamSpec("P")
__all__ = ('DummyLock', 'LaxBoundedSemaphore')
class LaxBoundedSemaphore:
"""Asynchronous Bounded Semaphore.
Lax means that the value will stay within the specified
range even if released more times than it was acquired.
Example:
-------
>>> x = LaxBoundedSemaphore(2)
>>> x.acquire(print, 'HELLO 1')
HELLO 1
>>> x.acquire(print, 'HELLO 2')
HELLO 2
>>> x.acquire(print, 'HELLO 3')
>>> x._waiters # private, do not access directly
[print, ('HELLO 3',)]
>>> x.release()
HELLO 3
"""
def __init__(self, value: int) -> None:
self.initial_value = self.value = value
self._waiting: Deque[tuple] = deque()
self._add_waiter = self._waiting.append
self._pop_waiter = self._waiting.popleft
def acquire(
self,
callback: Callable[P, None],
*partial_args: P.args,
**partial_kwargs: P.kwargs
) -> bool:
"""Acquire semaphore.
This will immediately apply ``callback`` if
the resource is available, otherwise the callback is suspended
until the semaphore is released.
Arguments:
---------
callback (Callable): The callback to apply.
*partial_args (Any): partial arguments to callback.
"""
value = self.value
if value <= 0:
self._add_waiter((callback, partial_args, partial_kwargs))
return False
else:
self.value = max(value - 1, 0)
callback(*partial_args, **partial_kwargs)
return True
def release(self) -> None:
"""Release semaphore.
Note:
----
If there are any waiters this will apply the first waiter
that is waiting for the resource (FIFO order).
"""
try:
waiter, args, kwargs = self._pop_waiter()
except IndexError:
self.value = min(self.value + 1, self.initial_value)
else:
waiter(*args, **kwargs)
def grow(self, n: int = 1) -> None:
"""Change the size of the semaphore to accept more users."""
self.initial_value += n
self.value += n
for _ in range(n):
self.release()
def shrink(self, n: int = 1) -> None:
"""Change the size of the semaphore to accept less users."""
self.initial_value = max(self.initial_value - n, 0)
self.value = max(self.value - n, 0)
def clear(self) -> None:
"""Reset the semaphore, which also wipes out any waiting callbacks."""
self._waiting.clear()
self.value = self.initial_value
def __repr__(self) -> str:
return '<{} at {:#x} value:{} waiting:{}>'.format(
self.__class__.__name__, id(self), self.value, len(self._waiting),
)
class DummyLock:
"""Pretending to be a lock."""
def __enter__(self) -> DummyLock:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
pass

View File

@@ -0,0 +1,241 @@
"""Timer scheduling Python callbacks."""
from __future__ import annotations
import heapq
import sys
from collections import namedtuple
from datetime import datetime
from functools import total_ordering
from time import monotonic
from time import time as _time
from typing import TYPE_CHECKING
from weakref import proxy as weakrefproxy
from vine.utils import wraps
from kombu.log import get_logger
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
from backports.zoneinfo import ZoneInfo
if TYPE_CHECKING:
from types import TracebackType
__all__ = ('Entry', 'Timer', 'to_timestamp')
logger = get_logger(__name__)
DEFAULT_MAX_INTERVAL = 2
EPOCH = datetime.fromtimestamp(0, ZoneInfo("UTC"))
IS_PYPY = hasattr(sys, 'pypy_version_info')
scheduled = namedtuple('scheduled', ('eta', 'priority', 'entry'))
def to_timestamp(d, default_timezone=ZoneInfo("UTC"), time=monotonic):
"""Convert datetime to timestamp.
If d' is already a timestamp, then that will be used.
"""
if isinstance(d, datetime):
if d.tzinfo is None:
d = d.replace(tzinfo=default_timezone)
diff = _time() - time()
return max((d - EPOCH).total_seconds() - diff, 0)
return d
@total_ordering
class Entry:
"""Schedule Entry."""
if not IS_PYPY: # pragma: no cover
__slots__ = (
'fun', 'args', 'kwargs', 'tref', 'canceled',
'_last_run', '__weakref__',
)
def __init__(self, fun, args=None, kwargs=None):
self.fun = fun
self.args = args or []
self.kwargs = kwargs or {}
self.tref = weakrefproxy(self)
self._last_run = None
self.canceled = False
def __call__(self):
return self.fun(*self.args, **self.kwargs)
def cancel(self):
try:
self.tref.canceled = True
except ReferenceError: # pragma: no cover
pass
def __repr__(self):
return '<TimerEntry: {}(*{!r}, **{!r})'.format(
self.fun.__name__, self.args, self.kwargs)
# must not use hash() to order entries
def __lt__(self, other):
return id(self) < id(other)
@property
def cancelled(self):
return self.canceled
@cancelled.setter
def cancelled(self, value):
self.canceled = value
class Timer:
"""Async timer implementation."""
Entry = Entry
on_error = None
def __init__(self, max_interval=None, on_error=None, **kwargs):
self.max_interval = float(max_interval or DEFAULT_MAX_INTERVAL)
self.on_error = on_error or self.on_error
self._queue = []
def __enter__(self):
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
self.stop()
def call_at(self, eta, fun, args=(), kwargs=None, priority=0):
kwargs = {} if not kwargs else kwargs
return self.enter_at(self.Entry(fun, args, kwargs), eta, priority)
def call_after(self, secs, fun, args=(), kwargs=None, priority=0):
kwargs = {} if not kwargs else kwargs
return self.enter_after(secs, self.Entry(fun, args, kwargs), priority)
def call_repeatedly(self, secs, fun, args=(), kwargs=None, priority=0):
kwargs = {} if not kwargs else kwargs
tref = self.Entry(fun, args, kwargs)
@wraps(fun)
def _reschedules(*args, **kwargs):
last, now = tref._last_run, monotonic()
lsince = (now - tref._last_run) if last else secs
try:
if lsince and lsince >= secs:
tref._last_run = now
return fun(*args, **kwargs)
finally:
if not tref.canceled:
last = tref._last_run
next = secs - (now - last) if last else secs
self.enter_after(next, tref, priority)
tref.fun = _reschedules
tref._last_run = None
return self.enter_after(secs, tref, priority)
def enter_at(self, entry, eta=None, priority=0, time=monotonic):
"""Enter function into the scheduler.
Arguments:
---------
entry (~kombu.asynchronous.timer.Entry): Item to enter.
eta (datetime.datetime): Scheduled time.
priority (int): Unused.
"""
if eta is None:
eta = time()
if isinstance(eta, datetime):
try:
eta = to_timestamp(eta)
except Exception as exc:
if not self.handle_error(exc):
raise
return
return self._enter(eta, priority, entry)
def enter_after(self, secs, entry, priority=0, time=monotonic):
return self.enter_at(entry, time() + float(secs), priority)
def _enter(self, eta, priority, entry, push=heapq.heappush):
push(self._queue, scheduled(eta, priority, entry))
return entry
def apply_entry(self, entry):
try:
entry()
except Exception as exc:
if not self.handle_error(exc):
logger.error('Error in timer: %r', exc, exc_info=True)
def handle_error(self, exc_info):
if self.on_error:
self.on_error(exc_info)
return True
def stop(self):
pass
def __iter__(self, min=min, nowfun=monotonic,
pop=heapq.heappop, push=heapq.heappush):
"""Iterate over schedule.
This iterator yields a tuple of ``(wait_seconds, entry)``,
where if entry is :const:`None` the caller should wait
for ``wait_seconds`` until it polls the schedule again.
"""
max_interval = self.max_interval
queue = self._queue
while 1:
if queue:
eventA = queue[0]
now, eta = nowfun(), eventA[0]
if now < eta:
yield min(eta - now, max_interval), None
else:
eventB = pop(queue)
if eventB is eventA:
entry = eventA[2]
if not entry.canceled:
yield None, entry
continue
else:
push(queue, eventB)
else:
yield None, None
def clear(self):
self._queue[:] = [] # atomic, without creating a new list.
def cancel(self, tref):
tref.cancel()
def __len__(self):
return len(self._queue)
def __nonzero__(self):
return True
@property
def queue(self, _pop=heapq.heappop):
"""Snapshot of underlying datastructure."""
events = list(self._queue)
return [_pop(v) for v in [events] * len(events)]
@property
def schedule(self):
return self