update
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
"""Event loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from kombu.utils.eventio import ERR, READ, WRITE
|
||||
|
||||
from .hub import Hub, get_event_loop, set_event_loop
|
||||
|
||||
__all__ = ('READ', 'WRITE', 'ERR', 'Hub', 'get_event_loop', 'set_event_loop')
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
|
||||
|
||||
|
||||
def connect_sqs(
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        **kwargs: Any
) -> AsyncSQSConnection:
    """Return async connection to Amazon SQS.

    Arguments:
    ---------
        aws_access_key_id: AWS access key, forwarded to the connection.
        aws_secret_access_key: AWS secret key, forwarded to the connection.
        **kwargs: Extra options passed through to ``AsyncSQSConnection``.
    """
    # AsyncSQSConnection is already imported at module level; the previous
    # duplicate function-local import was redundant and has been removed.
    return AsyncSQSConnection(
        aws_access_key_id, aws_secret_access_key, **kwargs
    )
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,278 @@
|
||||
"""Amazon AWS Connection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from email import message_from_bytes
|
||||
from email.mime.message import MIMEMessage
|
||||
|
||||
from vine import promise, transform
|
||||
|
||||
from kombu.asynchronous.aws.ext import AWSRequest, get_cert_path, get_response
|
||||
from kombu.asynchronous.http import Headers, Request, get_client
|
||||
|
||||
|
||||
def message_from_headers(hdr):
    """Build an ``email.message.Message`` from (name, value) header pairs."""
    lines = [f'{name}: {value}' for name, value in hdr]
    raw = '\r\n'.join(lines)
    return message_from_bytes(raw.encode())
|
||||
|
||||
|
||||
__all__ = (
|
||||
'AsyncHTTPSConnection', 'AsyncConnection',
|
||||
)
|
||||
|
||||
|
||||
class AsyncHTTPResponse:
    """Async HTTP Response.

    Adapts an async client response object to the ``http.client``
    response surface (``read``, ``getheader``, ``status``, ...).
    """

    def __init__(self, response):
        self.response = response
        self._msg = None       # lazily-built MIME view of the headers
        self.version = 10      # reported HTTP version, as http.client does

    def read(self, *args, **kwargs):
        """Return the full response body."""
        return self.response.body

    def getheader(self, name, default=None):
        """Return a single header value, or *default* when absent."""
        return self.response.headers.get(name, default)

    def getheaders(self):
        """Return all headers as a list of ``(name, value)`` pairs."""
        return [(key, value) for key, value in self.response.headers.items()]

    @property
    def msg(self):
        """MIME message view of the response headers (built on demand)."""
        if self._msg is None:
            self._msg = MIMEMessage(message_from_headers(self.getheaders()))
        return self._msg

    @property
    def status(self):
        """Numeric HTTP status code of the response."""
        return self.response.code

    @property
    def reason(self):
        """Error message when the request failed, else an empty string."""
        error = self.response.error
        return error.message if error else ''

    def __repr__(self):
        return repr(self.response)
|
||||
|
||||
|
||||
class AsyncHTTPSConnection:
    """Async HTTP Connection.

    Mimics the ``http.client`` connection API on top of the async HTTP
    client: the request is accumulated via ``request``/``putrequest``/
    ``putheader``/``send`` and dispatched by ``getresponse``.
    """

    Request = Request
    Response = AsyncHTTPResponse

    method = 'GET'
    path = '/'
    body = None
    default_ports = {'http': 80, 'https': 443}

    def __init__(self, strict=None, timeout=20.0, http_client=None):
        self.headers = []
        self.timeout = timeout
        self.strict = strict
        self.http_client = http_client or get_client()

    def request(self, method, path, body=None, headers=None):
        """Record a full HTTP request to be issued by ``getresponse``."""
        self.method = method
        self.path = path
        if body is not None:
            try:
                reader = body.read
            except AttributeError:
                # Plain payload (str/bytes): store as-is.
                self.body = body
            else:
                # File-like payload: slurp it up front.
                self.body = reader()
        if headers is not None:
            for header_item in headers.items():
                self.headers.append(header_item)

    def getrequest(self):
        """Materialize the accumulated state as a Request object."""
        return self.Request(
            self.path,
            method=self.method,
            headers=Headers(self.headers),
            body=self.body,
            connect_timeout=self.timeout,
            request_timeout=self.timeout,
            validate_cert=True,
            ca_certs=get_cert_path(True),
        )

    def getresponse(self, callback=None):
        """Dispatch the request; *callback* receives a Response."""
        request = self.getrequest()
        request.then(transform(self.Response, callback))
        return self.http_client.add_request(request)

    def set_debuglevel(self, level):
        """No-op; kept for ``http.client`` API compatibility."""

    def connect(self):
        """No-op; the async client manages connections itself."""

    def close(self):
        """No-op; the async client manages connections itself."""

    def putrequest(self, method, path):
        """Begin a request (``http.client`` style)."""
        self.method = method
        self.path = path

    def putheader(self, header, value):
        """Append a single header pair."""
        self.headers.append((header, value))

    def endheaders(self):
        """No-op; kept for ``http.client`` API compatibility."""

    def send(self, data):
        """Append *data* to the pending request body."""
        if self.body:
            self.body += data
        else:
            self.body = data

    def __repr__(self):
        return f'<AsyncHTTPConnection: {self.getrequest()!r}>'
|
||||
|
||||
|
||||
class AsyncConnection:
    """Async AWS Connection."""

    def __init__(self, sqs_connection, http_client=None, **kwargs):
        self.sqs_connection = sqs_connection
        self._httpclient = http_client or get_client()

    def get_http_connection(self):
        """Create a fresh async HTTPS connection bound to our client."""
        return AsyncHTTPSConnection(http_client=self._httpclient)

    def _mexe(self, request, sender=None, callback=None):
        """Execute *request* asynchronously and return its promise.

        When *sender* is callable it is handed the connection and is
        responsible for transmitting the request itself; otherwise the
        request is issued directly on a new connection.
        """
        on_ready = callback or promise()
        conn = self.get_http_connection()

        if not callable(sender):
            conn.request(request.method, request.url,
                         request.body, request.headers)
            conn.getresponse(callback=on_ready)
        else:
            sender(conn, request.method, request.path, request.body,
                   request.headers, on_ready)
        return on_ready
|
||||
|
||||
|
||||
class AsyncAWSQueryConnection(AsyncConnection):
    """Async AWS Query Connection."""

    # HTTP status codes used to classify AWS responses.
    STATUS_CODE_OK = 200
    STATUS_CODE_REQUEST_TIMEOUT = 408
    STATUS_CODE_NETWORK_CONNECT_TIMEOUT_ERROR = 599
    STATUS_CODE_INTERNAL_ERROR = 500
    STATUS_CODE_BAD_GATEWAY = 502
    STATUS_CODE_SERVICE_UNAVAILABLE_ERROR = 503
    STATUS_CODE_GATEWAY_TIMEOUT = 504

    # 50X codes treated as transient server-side failures.
    STATUS_CODES_SERVER_ERRORS = (
        STATUS_CODE_INTERNAL_ERROR,
        STATUS_CODE_BAD_GATEWAY,
        STATUS_CODE_SERVICE_UNAVAILABLE_ERROR
    )

    # Codes treated as timeouts.
    STATUS_CODES_TIMEOUT = (
        STATUS_CODE_REQUEST_TIMEOUT,
        STATUS_CODE_NETWORK_CONNECT_TIMEOUT_ERROR,
        STATUS_CODE_GATEWAY_TIMEOUT
    )

    def __init__(self, sqs_connection, http_client=None,
                 http_client_params=None, **kwargs):
        # http_client_params are expanded into AsyncConnection's **kwargs.
        if not http_client_params:
            http_client_params = {}
        super().__init__(sqs_connection, http_client,
                         **http_client_params)

    def make_request(self, operation, params_, path, verb, callback=None, protocol_params=None):
        """Sign and dispatch an AWS Query-protocol request.

        ``protocol_params`` may carry extra parameters under a ``'query'``
        key which are merged into the request parameters.  Returns the
        promise produced by ``_mexe``.
        """
        params = params_.copy()
        params.update((protocol_params or {}).get('query', {}))
        if operation:
            params['Action'] = operation
        signer = self.sqs_connection._request_signer

        # defaults for non-get
        signing_type = 'standard'
        param_payload = {'data': params}
        if verb.lower() == 'get':
            # query-based opts
            signing_type = 'presign-url'
            param_payload = {'params': params}

        request = AWSRequest(method=verb, url=path, **param_payload)
        signer.sign(operation, request, signing_type=signing_type)
        prepared_request = request.prepare()

        return self._mexe(prepared_request, callback=callback)

    def get_list(self, operation, params, markers, path='/', parent=None, verb='POST', callback=None,
                 protocol_params=None):
        """Issue *operation* and parse the response via ``_on_list_ready``."""
        return self.make_request(
            operation, params, path, verb,
            callback=transform(
                self._on_list_ready, callback, parent or self, markers,
                operation
            ),
            protocol_params=protocol_params,
        )

    def get_object(self, operation, params, path='/', parent=None, verb='GET', callback=None, protocol_params=None):
        """Issue *operation* and parse the response via ``_on_obj_ready``."""
        return self.make_request(
            operation, params, path, verb,
            callback=transform(
                self._on_obj_ready, callback, parent or self, operation
            ),
            protocol_params=protocol_params,
        )

    def get_status(self, operation, params, path='/', parent=None, verb='GET', callback=None, protocol_params=None):
        """Issue *operation* and reduce the response to its status code."""
        return self.make_request(
            operation, params, path, verb,
            callback=transform(
                self._on_status_ready, callback, parent or self, operation
            ),
            protocol_params=protocol_params,
        )

    def _on_list_ready(self, parent, markers, operation, response):
        # Parse a successful list response with botocore's response parser.
        service_model = self.sqs_connection.meta.service_model
        if response.status == self.STATUS_CODE_OK:
            _, parsed = get_response(
                service_model.operation_model(operation), response.response
            )
            return parsed
        elif (
            response.status in self.STATUS_CODES_TIMEOUT or
            response.status in self.STATUS_CODES_SERVER_ERRORS
        ):
            # When the server returns a timeout or 50X server error,
            # the response is interpreted as an empty list.
            # This prevents hanging the Celery worker.
            return []
        else:
            raise self._for_status(response, response.read())

    def _on_obj_ready(self, parent, operation, response):
        # Parse a successful object response; any non-200 raises.
        service_model = self.sqs_connection.meta.service_model
        if response.status == self.STATUS_CODE_OK:
            _, parsed = get_response(
                service_model.operation_model(operation), response.response
            )
            return parsed
        else:
            raise self._for_status(response, response.read())

    def _on_status_ready(self, parent, operation, response):
        # Reduce a successful response to its HTTP status code.
        service_model = self.sqs_connection.meta.service_model
        if response.status == self.STATUS_CODE_OK:
            httpres, _ = get_response(
                service_model.operation_model(operation), response.response
            )
            return httpres.code
        else:
            raise self._for_status(response, response.read())

    def _for_status(self, response, body):
        """Build (not raise) an exception describing a failed response."""
        context = 'Empty body' if not body else 'HTTP Error'
        return Exception("Request {} HTTP {} {} ({})".format(
            context, response.status, response.reason, body
        ))
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Amazon boto3 interface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Import the real boto3/botocore names when available; otherwise install
# placeholders so this module can still be imported.  Callers are expected
# to check ``boto3 is None`` before using AWS functionality.
try:
    import boto3
    from botocore import exceptions
    from botocore.awsrequest import AWSRequest
    from botocore.httpsession import get_cert_path
    from botocore.response import get_response
except ImportError:
    boto3 = None

    # Generic placeholder object standing in for missing botocore names.
    class _void:
        pass

    # Stand-in so ``exceptions.BotoCoreError`` remains referencable.
    class BotoCoreError(Exception):
        pass
    exceptions = _void()
    exceptions.BotoCoreError = BotoCoreError
    AWSRequest = _void()
    get_response = _void()
    get_cert_path = _void()
|
||||
|
||||
|
||||
__all__ = (
|
||||
'exceptions', 'AWSRequest', 'get_response', 'get_cert_path',
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,320 @@
|
||||
"""Amazon SQS Connection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from botocore.serialize import Serializer
|
||||
from vine import transform
|
||||
|
||||
from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection
|
||||
from kombu.asynchronous.aws.ext import AWSRequest
|
||||
|
||||
from .ext import boto3
|
||||
from .message import AsyncMessage
|
||||
from .queue import AsyncQueue
|
||||
|
||||
__all__ = ('AsyncSQSConnection',)
|
||||
|
||||
|
||||
class AsyncSQSConnection(AsyncAWSQueryConnection):
    """Async SQS Connection."""

    def __init__(self, sqs_connection, debug=0, region=None, fetch_message_attributes=None, **kwargs):
        if boto3 is None:
            raise ImportError('boto3 is not installed')
        super().__init__(
            sqs_connection,
            region_name=region, debug=debug,
            **kwargs
        )
        # Attributes requested with every ReceiveMessage call unless the
        # caller supplies an explicit list.
        self.fetch_message_attributes = (
            fetch_message_attributes if fetch_message_attributes is not None
            else ["ApproximateReceiveCount"]
        )

    def _create_query_request(self, operation, params, queue_url, method):
        """Build an unsigned AWSRequest using the legacy 'query' protocol."""
        params = params.copy()
        if operation:
            params['Action'] = operation

        # defaults for non-get
        param_payload = {'data': params}
        headers = {}
        if method.lower() == 'get':
            # query-based opts
            param_payload = {'params': params}

        if method.lower() == 'post':
            headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'

        return AWSRequest(method=method, url=queue_url, headers=headers, **param_payload)

    def _create_json_request(self, operation, params, queue_url):
        """Build an unsigned AWSRequest using the 'json' protocol."""
        params = params.copy()
        params['QueueUrl'] = queue_url

        service_model = self.sqs_connection.meta.service_model
        operation_model = service_model.operation_model(operation)

        url = self.sqs_connection._endpoint.host

        headers = {}
        # Content-Type carries the service's JSON protocol version.
        json_version = operation_model.metadata['jsonVersion']
        content_type = f'application/x-amz-json-{json_version}'
        headers['Content-Type'] = content_type

        # X-Amz-Target selects the operation on the shared endpoint.
        target = '{}.{}'.format(
            operation_model.metadata['targetPrefix'],
            operation_model.name,
        )
        headers['X-Amz-Target'] = target

        param_payload = {
            'data': json.dumps(params).encode(),
            'headers': headers
        }

        method = operation_model.http.get('method', Serializer.DEFAULT_METHOD)
        return AWSRequest(
            method=method,
            url=url,
            **param_payload
        )

    def make_request(self, operation_name, params, queue_url, verb, callback=None, protocol_params=None):
        """Override make_request to support different protocols.

        botocore has changed the default protocol of communicating
        with SQS backend from 'query' to 'json', so we need a special
        implementation of make_request for SQS. More information on this can
        be found in: https://github.com/celery/kombu/pull/1807.

        protocol_params: Optional[dict[str, dict]] of per-protocol additional parameters.
            Supported for the SQS query to json protocol transition.
        """
        signer = self.sqs_connection._request_signer

        service_model = self.sqs_connection.meta.service_model
        protocol = service_model.protocol
        # BUGFIX: protocol_params may legitimately be None (callers such as
        # create_queue() never pass it), so guard before calling .get().
        all_params = {
            **(params or {}),
            **(protocol_params or {}).get(protocol, {}),
        }

        if protocol == 'query':
            request = self._create_query_request(
                operation_name, all_params, queue_url, verb)
        elif protocol == 'json':
            request = self._create_json_request(
                operation_name, all_params, queue_url)
        else:
            raise Exception(f'Unsupported protocol: {protocol}.')

        # GET requests are signed into the URL itself; everything else uses
        # standard header-based SigV4 signing.
        signing_type = 'presign-url' if request.method.lower() == 'get' \
            else 'standard'

        signer.sign(operation_name, request, signing_type=signing_type)
        prepared_request = request.prepare()

        return self._mexe(prepared_request, callback=callback)

    def create_queue(self, queue_name,
                     visibility_timeout=None, callback=None):
        """Create a queue, optionally with a default visibility timeout."""
        params = {'QueueName': queue_name}
        if visibility_timeout:
            params['DefaultVisibilityTimeout'] = format(
                visibility_timeout, 'd',
            )
        return self.get_object('CreateQueue', params,
                               callback=callback)

    def delete_queue(self, queue, force_deletion=False, callback=None):
        """Delete *queue*.  ``force_deletion`` is accepted but unused here."""
        return self.get_status('DeleteQueue', None, queue.id,
                               callback=callback)

    def get_queue_url(self, queue):
        """Synchronously resolve the URL for the queue named *queue*."""
        res = self.sqs_connection.get_queue_url(QueueName=queue)
        return res['QueueUrl']

    def get_queue_attributes(self, queue, attribute='All', callback=None):
        """Fetch queue attributes (defaults to all of them)."""
        return self.get_object(
            'GetQueueAttributes', {'AttributeName': attribute},
            queue.id, callback=callback,
        )

    def set_queue_attribute(self, queue, attribute, value, callback=None):
        """Set a single queue attribute, encoded for both wire protocols."""
        return self.get_status(
            'SetQueueAttribute',
            {},
            queue.id, callback=callback,
            protocol_params={
                'json': {'Attributes': {attribute: value}},
                'query': {'Attribute.Name': attribute, 'Attribute.Value': value},
            },
        )

    def receive_message(
            self, queue, queue_url, number_messages=1, visibility_timeout=None,
            attributes=None, wait_time_seconds=None,
            callback=None):
        """Fetch up to *number_messages* messages from *queue_url*."""
        params = {'MaxNumberOfMessages': number_messages}
        proto_params = {'query': {}, 'json': {}}
        attrs = attributes if attributes is not None else self.fetch_message_attributes

        if visibility_timeout:
            params['VisibilityTimeout'] = visibility_timeout
        if attrs:
            proto_params['json'].update({'AttributeNames': list(attrs)})
            proto_params['query'].update(_query_object_encode({'AttributeName': list(attrs)}))
        if wait_time_seconds is not None:
            params['WaitTimeSeconds'] = wait_time_seconds

        return self.get_list(
            'ReceiveMessage', params, [('Message', AsyncMessage)],
            queue_url, callback=callback, parent=queue,
            protocol_params=proto_params,
        )

    def delete_message(self, queue, receipt_handle, callback=None):
        """Delete a single message by its receipt handle."""
        return self.delete_message_from_handle(
            queue, receipt_handle, callback,
        )

    def delete_message_batch(self, queue, messages, callback=None):
        """Delete several messages in one DeleteMessageBatch call."""
        p_params = {
            'json': {
                'Entries': [{'Id': m.id, 'ReceiptHandle': m.receipt_handle} for m in messages],
            },
            'query': _query_object_encode({
                'DeleteMessageBatchRequestEntry': [
                    {'Id': m.id, 'ReceiptHandle': m.receipt_handle}
                    for m in messages
                ],
            }),
        }

        return self.get_object(
            'DeleteMessageBatch', {}, queue.id,
            verb='POST', callback=callback, protocol_params=p_params,
        )

    def delete_message_from_handle(self, queue, receipt_handle,
                                   callback=None):
        """Issue DeleteMessage for the given receipt handle."""
        return self.get_status(
            'DeleteMessage', {'ReceiptHandle': receipt_handle},
            queue, callback=callback,
        )

    def send_message(self, queue, message_content,
                     delay_seconds=None, callback=None):
        """Send a single message, optionally delayed by *delay_seconds*."""
        params = {'MessageBody': message_content}
        if delay_seconds:
            params['DelaySeconds'] = int(delay_seconds)
        return self.get_object(
            'SendMessage', params, queue.id,
            verb='POST', callback=callback,
        )

    def send_message_batch(self, queue, messages, callback=None):
        """Send several ``(id, body, delay)`` tuples in one batch call."""
        params = {}
        for i, msg in enumerate(messages):
            # Query-protocol batch entries are indexed from 1.
            prefix = f'SendMessageBatchRequestEntry.{i + 1}'
            params.update({
                f'{prefix}.Id': msg[0],
                f'{prefix}.MessageBody': msg[1],
                f'{prefix}.DelaySeconds': msg[2],
            })
        return self.get_object(
            'SendMessageBatch', params, queue.id,
            verb='POST', callback=callback,
        )

    def change_message_visibility(self, queue, receipt_handle,
                                  visibility_timeout, callback=None):
        """Change the visibility timeout of a single in-flight message."""
        return self.get_status(
            'ChangeMessageVisibility',
            {'ReceiptHandle': receipt_handle,
             'VisibilityTimeout': visibility_timeout},
            queue.id, callback=callback,
        )

    def change_message_visibility_batch(self, queue, messages, callback=None):
        """Change visibility for ``(message, timeout)`` pairs in one call."""
        entries = [
            {'Id': t[0].id, 'ReceiptHandle': t[0].receipt_handle, 'VisibilityTimeout': t[1]}
            for t in messages
        ]

        p_params = {
            'json': {'Entries': entries},
            'query': _query_object_encode({'ChangeMessageVisibilityBatchRequestEntry': entries}),
        }

        return self.get_object(
            'ChangeMessageVisibilityBatch', {}, queue.id,
            verb='POST', callback=callback,
            protocol_params=p_params,
        )

    def get_all_queues(self, prefix='', callback=None):
        """List queues, optionally filtered by a name *prefix*."""
        params = {}
        if prefix:
            params['QueueNamePrefix'] = prefix
        return self.get_list(
            'ListQueues', params, [('QueueUrl', AsyncQueue)],
            callback=callback,
        )

    def get_queue(self, queue_name, callback=None):
        """Resolve a queue by name (first URL ending with the name)."""
        # TODO Does not support owner_acct_id argument
        return self.get_all_queues(
            queue_name,
            transform(self._on_queue_ready, callback, queue_name),
        )
    lookup = get_queue

    def _on_queue_ready(self, name, queues):
        # Pick the queue whose URL ends with the requested name, if any.
        return next(
            (q for q in queues if q.url.endswith(name)), None,
        )

    def get_dead_letter_source_queues(self, queue, callback=None):
        """List queues configured with *queue* as their dead-letter target."""
        return self.get_list(
            'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
            [('QueueUrl', AsyncQueue)],
            callback=callback,
        )

    def add_permission(self, queue, label, aws_account_id, action_name,
                       callback=None):
        """Grant *aws_account_id* permission for *action_name* on *queue*."""
        return self.get_status(
            'AddPermission',
            {'Label': label,
             'AWSAccountId': aws_account_id,
             'ActionName': action_name},
            queue.id, callback=callback,
        )

    def remove_permission(self, queue, label, callback=None):
        """Revoke the permission previously granted under *label*."""
        return self.get_status(
            'RemovePermission', {'Label': label}, queue.id, callback=callback,
        )
|
||||
|
||||
|
||||
def _query_object_encode(items):
|
||||
params = {}
|
||||
_query_object_encode_part(params, '', items)
|
||||
return {k: v for k, v in params.items()}
|
||||
|
||||
|
||||
def _query_object_encode_part(params, prefix, part):
|
||||
dotted = f'{prefix}.' if prefix else prefix
|
||||
|
||||
if isinstance(part, (list, tuple)):
|
||||
for i, item in enumerate(part):
|
||||
_query_object_encode_part(params, f'{dotted}{i + 1}', item)
|
||||
elif isinstance(part, dict):
|
||||
for key, value in part.items():
|
||||
_query_object_encode_part(params, f'{dotted}{key}', value)
|
||||
else:
|
||||
params[prefix] = str(part)
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Amazon SQS boto3 interface."""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# boto3 is an optional dependency; fall back to ``None`` as a sentinel so
# importers can detect its absence (AsyncSQSConnection raises ImportError
# when the sentinel is seen).
try:
    import boto3
except ImportError:
    boto3 = None
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Amazon SQS message implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
from kombu.message import Message
|
||||
from kombu.utils.encoding import str_to_bytes
|
||||
|
||||
|
||||
class BaseAsyncMessage(Message):
    """Base class for messages received on async client."""
    # Intentionally empty: gives async SQS message types a common ancestor.
|
||||
|
||||
|
||||
class AsyncRawMessage(BaseAsyncMessage):
    """Raw Message."""
    # Marker subclass; adds no behavior of its own.
|
||||
|
||||
|
||||
class AsyncMessage(BaseAsyncMessage):
    """Serialized message."""

    def encode(self, value):
        """Return *value* encoded with Base64, as text."""
        return base64.b64encode(str_to_bytes(value)).decode()

    def __getitem__(self, item):
        """Support Boto3-style access on a message.

        Only the keys ``'ReceiptHandle'``, ``'Body'`` and ``'queue'`` are
        recognized; anything else raises ``KeyError``.
        """
        if item == 'ReceiptHandle':
            return self.receipt_handle
        if item == 'Body':
            return self.get_body()
        if item == 'queue':
            return self.queue
        raise KeyError(item)
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Amazon SQS queue implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from vine import transform
|
||||
|
||||
from .message import AsyncMessage
|
||||
|
||||
# Fixed typo: was ``_all__`` (single leading underscore), which Python
# ignores, so star-imports did not restrict to AsyncQueue as intended.
__all__ = ['AsyncQueue']
|
||||
|
||||
|
||||
def list_first(rs):
    """Return the sole item of a single-element list, else None.

    Note: the previous docstring claimed "None if list empty", but the
    implementation also returns None for lists with more than one item;
    the docstring now matches the actual behavior.
    """
    return rs[0] if len(rs) == 1 else None
|
||||
|
||||
|
||||
class AsyncQueue:
    """Async SQS Queue.

    Thin wrapper delegating every operation to the bound connection.
    """

    def __init__(self, connection=None, url=None, message_class=AsyncMessage):
        self.connection = connection
        self.url = url
        self.message_class = message_class
        self.visibility_timeout = None

    def _NA(self, *args, **kwargs):
        """Placeholder for legacy API methods that are not implemented."""
        raise NotImplementedError()
    count_slow = dump = save_to_file = save_to_filename = save = \
        save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \
        load = clear = _NA

    def get_attributes(self, attributes='All', callback=None):
        """Fetch queue attributes via the connection."""
        return self.connection.get_queue_attributes(
            self, attributes, callback,
        )

    def set_attribute(self, attribute, value, callback=None):
        """Set a single queue attribute via the connection."""
        return self.connection.set_queue_attribute(
            self, attribute, value, callback,
        )

    def get_timeout(self, callback=None, _attr='VisibilityTimeout'):
        """Fetch the queue's visibility timeout, coerced to int."""
        return self.get_attributes(
            _attr, transform(
                self._coerce_field_value, callback, _attr, int,
            ),
        )

    def _coerce_field_value(self, key, type, response):
        # Extract one field from the attribute response and convert it.
        return type(response[key])

    def set_timeout(self, visibility_timeout, callback=None):
        """Set the queue's visibility timeout and cache it locally."""
        return self.set_attribute(
            'VisibilityTimeout', visibility_timeout,
            transform(
                self._on_timeout_set, callback,
            )
        )

    def _on_timeout_set(self, visibility_timeout):
        # Cache the new timeout only when a truthy value was applied.
        if visibility_timeout:
            self.visibility_timeout = visibility_timeout
        return self.visibility_timeout

    def add_permission(self, label, aws_account_id, action_name,
                       callback=None):
        """Grant a permission on this queue via the connection."""
        return self.connection.add_permission(
            self, label, aws_account_id, action_name, callback,
        )

    def remove_permission(self, label, callback=None):
        """Revoke a previously granted permission via the connection."""
        return self.connection.remove_permission(self, label, callback)

    def read(self, visibility_timeout=None, wait_time_seconds=None,
             callback=None):
        """Fetch a single message (or None) from the queue."""
        return self.get_messages(
            1, visibility_timeout,
            wait_time_seconds=wait_time_seconds,
            callback=transform(list_first, callback),
        )

    def write(self, message, delay_seconds=None, callback=None):
        """Send a single message, updating it with the server's id/md5."""
        return self.connection.send_message(
            self, message.get_body_encoded(), delay_seconds,
            callback=transform(self._on_message_sent, callback, message),
        )

    def write_batch(self, messages, callback=None):
        """Send several messages in one batch call."""
        return self.connection.send_message_batch(
            self, messages, callback=callback,
        )

    def _on_message_sent(self, orig_message, new_message):
        # Copy server-assigned identifiers back onto the original message.
        orig_message.id = new_message.id
        orig_message.md5 = new_message.md5
        return new_message

    def get_messages(self, num_messages=1, visibility_timeout=None,
                     attributes=None, wait_time_seconds=None, callback=None):
        """Fetch up to *num_messages* messages from the queue."""
        # NOTE(review): AsyncSQSConnection.receive_message also requires a
        # positional ``queue_url`` argument that this call does not supply —
        # confirm whether this code path is actually exercised.
        return self.connection.receive_message(
            self, number_messages=num_messages,
            visibility_timeout=visibility_timeout,
            attributes=attributes,
            wait_time_seconds=wait_time_seconds,
            callback=callback,
        )

    def delete_message(self, message, callback=None):
        """Delete a single message via the connection."""
        return self.connection.delete_message(self, message, callback)

    def delete_message_batch(self, messages, callback=None):
        """Delete several messages in one call via the connection."""
        return self.connection.delete_message_batch(
            self, messages, callback=callback,
        )

    def change_message_visibility_batch(self, messages, callback=None):
        """Change visibility timeouts for several messages at once."""
        return self.connection.change_message_visibility_batch(
            self, messages, callback=callback,
        )

    def delete(self, callback=None):
        """Delete this queue via the connection."""
        return self.connection.delete_queue(self, callback=callback)

    def count(self, page_size=10, vtimeout=10, callback=None,
              _attr='ApproximateNumberOfMessages'):
        """Fetch the approximate number of messages, coerced to int."""
        return self.get_attributes(
            _attr, callback=transform(
                self._coerce_field_value, callback, _attr, int,
            ),
        )
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Event-loop debugging tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from kombu.utils.eventio import ERR, READ, WRITE
|
||||
from kombu.utils.functional import reprcall
|
||||
|
||||
|
||||
def repr_flag(flag):
    """Return description of event loop flag."""
    parts = []
    if flag & READ:
        parts.append('R')
    if flag & WRITE:
        parts.append('W')
    if flag & ERR:
        parts.append('!')
    return ''.join(parts)
|
||||
|
||||
|
||||
def _rcb(obj):
|
||||
if obj is None:
|
||||
return '<missing>'
|
||||
if isinstance(obj, str):
|
||||
return obj
|
||||
if isinstance(obj, tuple):
|
||||
cb, args = obj
|
||||
return reprcall(cb.__name__, args=args)
|
||||
return obj.__name__
|
||||
|
||||
|
||||
def repr_active(h):
    """Return description of active readers and writers."""
    # Readers and writers are each formatted by their own helper, then
    # joined into one comma-separated string.
    return ', '.join(repr_readers(h) + repr_writers(h))
|
||||
|
||||
|
||||
def repr_events(h, events):
    """Return description of events returned by poll.

    Each ``(fd, flag)`` pair is rendered as ``callback(fd)->flags``; a
    callback no longer registered on the hub is shown as ``(GONE)``.
    """
    return ', '.join(
        '{}({})->{}'.format(
            _rcb(callback_for(h, fd, fl, '(GONE)')), fd,
            repr_flag(fl),
        )
        for fd, fl in events
    )
|
||||
|
||||
|
||||
def repr_readers(h):
    """Return description of pending readers."""
    # Readers are rendered with fixed READ|ERR flags.
    return [f'({fd}){_rcb(cb)}->{repr_flag(READ | ERR)}'
            for fd, cb in h.readers.items()]
|
||||
|
||||
|
||||
def repr_writers(h):
    """Return description of pending writers."""
    # Writers are rendered with the fixed WRITE flag.
    return [f'({fd}){_rcb(cb)}->{repr_flag(WRITE)}'
            for fd, cb in h.writers.items()]
|
||||
|
||||
|
||||
def callback_for(h, fd, flag, *default):
    """Return the callback used for hub+fd+flag.

    Raises ``KeyError`` for an unknown fd unless a *default* is given.
    """
    try:
        if flag & READ:
            return h.readers[fd]
        if flag & WRITE:
            # Consolidated fds share a single callback.
            return (h.consolidate_callback if fd in h.consolidate
                    else h.writers[fd])
    except KeyError:
        if not default:
            raise
        return default[0]
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from kombu.asynchronous import get_event_loop
|
||||
from kombu.asynchronous.http.base import BaseClient, Headers, Request, Response
|
||||
from kombu.asynchronous.hub import Hub
|
||||
|
||||
__all__ = ('Client', 'Headers', 'Response', 'Request', 'get_client')
|
||||
|
||||
|
||||
def Client(hub: Hub | None = None, **kwargs: int) -> BaseClient:
    """Create new HTTP client."""
    # Imported lazily — presumably to avoid a circular import at module
    # load time; confirm before moving this to the top of the file.
    from .urllib3_client import Urllib3Client
    return Urllib3Client(hub, **kwargs)
|
||||
|
||||
|
||||
def get_client(hub: Hub | None = None, **kwargs: int) -> BaseClient:
    """Get or create HTTP client bound to the current event loop.

    The client is cached on the hub so repeated calls for the same loop
    return the same instance.
    """
    loop = hub or get_event_loop()
    try:
        return loop._current_http_client
    except AttributeError:
        loop._current_http_client = Client(loop, **kwargs)
        return loop._current_http_client
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,280 @@
|
||||
"""Base async HTTP client implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from http.client import responses
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vine import Thenable, maybe_promise, promise
|
||||
|
||||
from kombu.exceptions import HttpError
|
||||
from kombu.utils.compat import coro
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.functional import maybe_list, memoize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
__all__ = ('Headers', 'Response', 'Request', 'BaseClient')
|
||||
|
||||
PYPY = hasattr(sys, 'pypy_version_info')
|
||||
|
||||
|
||||
@memoize(maxsize=1000)
def normalize_header(key):
    """Normalize an HTTP header name to Dash-Separated-Capitalized form."""
    return '-'.join(map(str.capitalize, key.split('-')))
|
||||
|
||||
|
||||
class Headers(dict):
    """Represents a mapping of HTTP headers."""

    # TODO: This is just a regular dict and will not perform normalization
    # when looking up keys etc.

    #: Set when all of the headers have been read.
    complete = False

    #: Internal attribute used to keep track of continuation lines.
    # NOTE(review): appears to be maintained by the header parser for folded
    # (continuation) header lines — confirm against the client implementation.
    _prev_key = None
|
||||
|
||||
|
||||
@Thenable.register
class Request:
    """A HTTP Request.

    Arguments:
    ---------
        url (str): The URL to request.
        method (str): The HTTP method to use (defaults to ``GET``).

    Keyword Arguments:
    -----------------
        headers (Dict, ~kombu.asynchronous.http.Headers): Optional headers for
            this request
        body (str): Optional body for this request.
        connect_timeout (float): Connection timeout in float seconds
            Default is 30.0.
        timeout (float): Time in float seconds before the request times out
            Default is 30.0.
        follow_redirects (bool): Specify if the client should follow redirects
            Enabled by default.
        max_redirects (int): Maximum number of redirects (default 6).
        use_gzip (bool): Allow the server to use gzip compression.
            Enabled by default.
        validate_cert (bool): Set to true if the server certificate should be
            verified when performing ``https://`` requests.
            Enabled by default.
        auth_username (str): Username for HTTP authentication.
        auth_password (str): Password for HTTP authentication.
        auth_mode (str): Type of HTTP authentication (``basic`` or ``digest``).
        user_agent (str): Custom user agent for this request.
        network_interface (str): Network interface to use for this request.
        on_ready (Callable): Callback to be called when the response has been
            received. Must accept single ``response`` argument.
        on_stream (Callable): Optional callback to be called every time body
            content has been read from the socket. If specified then the
            response body and buffer attributes will not be available.
        on_timeout (callable): Optional callback to be called if the request
            times out.
        on_header (Callable): Optional callback to be called for every header
            line received from the server. The signature
            is ``(headers, line)`` and note that if you want
            ``response.headers`` to be populated then your callback needs to
            also call ``client.on_header(headers, line)``.
        on_prepare (Callable): Optional callback that is implementation
            specific (e.g. curl client will pass the ``curl`` instance to
            this callback).
        proxy_host (str): Optional proxy host. Note that a ``proxy_port`` must
            also be provided or a :exc:`ValueError` will be raised.
        proxy_username (str): Optional username to use when logging in
            to the proxy.
        proxy_password (str): Optional password to use when authenticating
            with the proxy server.
        ca_certs (str): Custom CA certificates file to use.
        client_key (str): Optional filename for client SSL key.
        client_cert (str): Optional filename for client SSL certificate.
    """

    # Optional attributes default to None at the class level; per-request
    # values are assigned from **kwargs in __init__ via setattr.
    body = user_agent = network_interface = \
        auth_username = auth_password = auth_mode = \
        proxy_host = proxy_port = proxy_username = proxy_password = \
        ca_certs = client_key = client_cert = None

    # Default timeout / redirect / TLS settings (see class docstring).
    connect_timeout = 30.0
    request_timeout = 30.0
    follow_redirects = True
    max_redirects = 6
    use_gzip = True
    validate_cert = True

    # ``__dict__`` is included so arbitrary **kwargs attributes can still
    # be set despite __slots__; slots are skipped entirely on PyPy.
    if not PYPY:  # pragma: no cover
        __slots__ = ('url', 'method', 'on_ready', 'on_timeout', 'on_stream',
                     'on_prepare', 'on_header', 'headers',
                     '__weakref__', '__dict__')

    def __init__(self, url, method='GET', on_ready=None, on_timeout=None,
                 on_stream=None, on_prepare=None, on_header=None,
                 headers=None, **kwargs):
        self.url = url
        # NOTE(review): there is no class-level ``method`` default, so a
        # falsy ``method`` argument would raise AttributeError here —
        # confirm callers always pass a non-empty method string.
        self.method = method or self.method
        # Callbacks are normalized to promises; on_ready always exists so
        # then() can chain unconditionally.
        self.on_ready = maybe_promise(on_ready) or promise()
        self.on_timeout = maybe_promise(on_timeout)
        self.on_stream = maybe_promise(on_stream)
        self.on_prepare = maybe_promise(on_prepare)
        self.on_header = maybe_promise(on_header)
        if kwargs:
            for k, v in kwargs.items():
                setattr(self, k, v)
        if not isinstance(headers, Headers):
            headers = Headers(headers or {})
        self.headers = headers

    def then(self, callback, errback=None):
        # Thenable protocol: chain on the ready promise.
        self.on_ready.then(callback, errback)

    def __repr__(self):
        return '<Request: {0.method} {0.url} {0.body}>'.format(self)
|
||||
|
||||
|
||||
class Response:
    """HTTP Response.

    Arguments
    ---------
        request (~kombu.asynchronous.http.Request): See :attr:`request`.
        code (int): See :attr:`code`.
        headers (~kombu.asynchronous.http.Headers): See :attr:`headers`.
        buffer (bytes): See :attr:`buffer`
        effective_url (str): See :attr:`effective_url`.
        status (str): See :attr:`status`.

    Attributes
    ----------
        request (~kombu.asynchronous.http.Request): object used to
            get this response.
        code (int): HTTP response code (e.g. 200, 404, or 500).
        headers (~kombu.asynchronous.http.Headers): HTTP headers
            for this response.
        buffer (bytes): Socket read buffer.
        effective_url (str): The destination url for this request after
            following redirects.
        error (Exception): Error instance if the request resulted in
            a HTTP error code.
        status (str): Human equivalent of :attr:`code`,
            e.g. ``OK``, `Not found`, or 'Internal Server Error'.
    """

    if not PYPY:  # pragma: no cover
        __slots__ = ('request', 'code', 'headers', 'buffer', 'effective_url',
                     'error', 'status', '_body', '__weakref__')

    def __init__(self, request, code, headers=None, buffer=None,
                 effective_url=None, error=None, status=None):
        self.request = request
        self.code = code
        self.headers = headers if headers is not None else Headers()
        self.buffer = buffer
        self.effective_url = effective_url or request.url
        # Lazily-evaluated body cache (see the ``body`` property).
        self._body = None

        # Map the numeric code to its reason phrase when not given.
        self.status = status or responses.get(self.code, 'Unknown')
        self.error = error
        # Synthesize an HttpError for any non-2xx code so callers can
        # use raise_for_error() uniformly.
        if self.error is None and (self.code < 200 or self.code > 299):
            self.error = HttpError(self.code, self.status, self)

    def raise_for_error(self):
        """Raise if the request resulted in an HTTP error code.

        Raises
        ------
            :class:`~kombu.exceptions.HttpError`
        """
        if self.error:
            raise self.error

    @property
    def body(self):
        """The full contents of the response body.

        Note:
        ----
            Accessing this property will evaluate the buffer
            and subsequent accesses will be cached.
        """
        if self._body is None:
            if self.buffer is not None:
                self._body = self.buffer.getvalue()
        return self._body

    # these are for compatibility with Requests
    @property
    def status_code(self):
        return self.code

    @property
    def content(self):
        return self.body
|
||||
|
||||
|
||||
@coro
def header_parser(keyt=normalize_header):
    """Coroutine parsing header lines sent via ``.send((line, headers))``.

    Populates *headers* in place; sets ``headers.complete`` when the
    blank line ending the header section is seen.
    """
    while True:
        line, headers = yield
        # Status lines (``HTTP/...``) carry no header data.
        if line.startswith('HTTP/'):
            continue
        # A blank line terminates the header section.
        if not line:
            headers.complete = True
            continue
        if line[0].isspace():
            # Continuation line: fold into the previously seen header.
            prev = headers._prev_key
            headers[prev] = ' '.join([headers.get(prev) or '', line.lstrip()])
        else:
            name, value = line.split(':', 1)
            name = headers._prev_key = keyt(name)
            headers[name] = value.strip()
|
||||
|
||||
|
||||
class BaseClient:
    """Base class for HTTP clients.

    This class provides the basic structure and functionality for HTTP clients.
    Subclasses should implement specific HTTP client behavior.
    """

    # Factories used by perform()/implementations; subclasses may override.
    Headers = Headers
    Request = Request
    Response = Response

    def __init__(self, hub, **kwargs):
        self.hub = hub
        # Primed coroutine fed raw header lines via on_header().
        self._header_parser = header_parser()

    def perform(self, request, **kwargs):
        # Accepts a single request or a list; plain values (e.g. URL
        # strings) are coerced into Request objects with **kwargs.
        for req in maybe_list(request) or []:
            if not isinstance(req, self.Request):
                req = self.Request(req, **kwargs)
            self.add_request(req)

    def add_request(self, request):
        # Subclass responsibility: schedule the request for execution.
        raise NotImplementedError('must implement add_request')

    def close(self):
        pass

    def on_header(self, headers, line):
        try:
            self._header_parser.send((bytes_to_str(line), headers))
        except StopIteration:
            # Parser exhausted; start a fresh one for the next response.
            self._header_parser = header_parser()

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None
    ) -> None:
        self.close()
|
||||
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from io import BytesIO
|
||||
|
||||
import urllib3
|
||||
|
||||
from kombu.asynchronous.hub import Hub, get_event_loop
|
||||
from kombu.exceptions import HttpError
|
||||
|
||||
from .base import BaseClient, Request
|
||||
|
||||
__all__ = ('Urllib3Client',)
|
||||
|
||||
from ...utils.encoding import bytes_to_str
|
||||
|
||||
DEFAULT_USER_AGENT = 'Mozilla/5.0 (compatible; urllib3)'
|
||||
EXTRA_METHODS = frozenset(['DELETE', 'OPTIONS', 'PATCH'])
|
||||
|
||||
|
||||
def _get_pool_key_parts(request: Request) -> list[str]:
|
||||
_pool_key_parts = []
|
||||
|
||||
if request.network_interface:
|
||||
_pool_key_parts.append(f"interface={request.network_interface}")
|
||||
|
||||
if request.validate_cert:
|
||||
_pool_key_parts.append("validate_cert=True")
|
||||
else:
|
||||
_pool_key_parts.append("validate_cert=False")
|
||||
|
||||
if request.ca_certs:
|
||||
_pool_key_parts.append(f"ca_certs={request.ca_certs}")
|
||||
|
||||
if request.client_cert:
|
||||
_pool_key_parts.append(f"client_cert={request.client_cert}")
|
||||
|
||||
if request.client_key:
|
||||
_pool_key_parts.append(f"client_key={request.client_key}")
|
||||
|
||||
return _pool_key_parts
|
||||
|
||||
|
||||
class Urllib3Client(BaseClient):
    """Urllib3 HTTP Client."""

    # NOTE(review): class-level mutable mapping — connection pools are
    # shared across *all* instances, and close()/pools_close() clears
    # them globally.  Confirm this sharing is intentional.
    _pools = {}

    def __init__(self, hub: Hub | None = None, max_clients: int = 10):
        hub = hub or get_event_loop()
        super().__init__(hub)
        self.max_clients = max_clients
        self._pending = deque()
        # Periodic hub timer that drains the pending-request queue.
        self._timeout_check_tref = self.hub.call_repeatedly(
            1.0, self._timeout_check,
        )

    def pools_close(self):
        """Close and forget every cached connection pool."""
        for pool in self._pools.values():
            pool.close()
        self._pools.clear()

    def close(self):
        """Cancel the drain timer and release all pools."""
        self._timeout_check_tref.cancel()
        self.pools_close()

    def add_request(self, request):
        # Queue the request and immediately try to process the queue.
        self._pending.append(request)
        self._process_queue()
        return request

    def get_pool(self, request: Request):
        """Return (creating and caching if needed) a pool for *request*."""
        _pool_key_parts = _get_pool_key_parts(request=request)

        _proxy_url = None
        proxy_headers = None
        if request.proxy_host:
            _proxy_url = urllib3.util.Url(
                scheme=None,
                host=request.proxy_host,
                port=request.proxy_port,
            )
            if request.proxy_username:
                proxy_headers = urllib3.make_headers(
                    proxy_basic_auth=(
                        f"{request.proxy_username}"
                        f":{request.proxy_password}"
                    )
                )
            else:
                proxy_headers = None

            _proxy_url = _proxy_url.url

            # Proxy settings also differentiate pools.
            _pool_key_parts.append(f"proxy={_proxy_url}")
            if proxy_headers:
                _pool_key_parts.append(f"proxy_headers={str(proxy_headers)}")

        _pool_key = "|".join(_pool_key_parts)
        if _pool_key in self._pools:
            return self._pools[_pool_key]

        # create new pool
        if _proxy_url:
            _pool = urllib3.ProxyManager(
                proxy_url=_proxy_url,
                num_pools=self.max_clients,
                proxy_headers=proxy_headers
            )
        else:
            _pool = urllib3.PoolManager(num_pools=self.max_clients)

        # Network Interface
        if request.network_interface:
            _pool.connection_pool_kw['source_address'] = (
                request.network_interface,
                0
            )

        # SSL Verification
        if request.validate_cert:
            _pool.connection_pool_kw['cert_reqs'] = 'CERT_REQUIRED'
        else:
            _pool.connection_pool_kw['cert_reqs'] = 'CERT_NONE'

        # CA Certificates: fall back to certifi's bundle when available;
        # best-effort only (missing certifi is silently ignored).
        if request.ca_certs is not None:
            _pool.connection_pool_kw['ca_certs'] = request.ca_certs
        elif request.validate_cert is True:
            try:
                from certifi import where
                _pool.connection_pool_kw['ca_certs'] = where()
            except ImportError:
                pass

        # Client Certificates
        if request.client_cert is not None:
            _pool.connection_pool_kw['cert_file'] = request.client_cert
        if request.client_key is not None:
            _pool.connection_pool_kw['key_file'] = request.client_key

        self._pools[_pool_key] = _pool

        return _pool

    def _timeout_check(self):
        # Called every second by the hub timer (see __init__).
        self._process_pending_requests()

    def _process_pending_requests(self):
        while self._pending:
            request = self._pending.popleft()
            self._process_request(request)

    def _process_request(self, request: Request):
        """Perform *request* and fulfil its ``on_ready`` promise."""
        # Prepare headers
        headers = request.headers
        headers.setdefault(
            'User-Agent',
            bytes_to_str(request.user_agent or DEFAULT_USER_AGENT)
        )
        headers.setdefault(
            'Accept-Encoding',
            'gzip,deflate' if request.use_gzip else 'none'
        )

        # Authentication
        if request.auth_username is not None:
            headers.update(
                urllib3.util.make_headers(
                    basic_auth=(
                        f"{request.auth_username}"
                        f":{request.auth_password or ''}"
                    )
                )
            )

        # Make the request using urllib3
        try:
            _pool = self.get_pool(request=request)
            # NOTE(review): response.geturl() is deprecated in urllib3 v2
            # in favour of ``response.url`` — confirm the range of urllib3
            # versions this must support.
            response = _pool.request(
                request.method,
                request.url,
                headers=headers,
                body=request.body,
                preload_content=False,
                redirect=request.follow_redirects,
            )
            buffer = BytesIO(response.data)
            response_obj = self.Response(
                request=request,
                code=response.status,
                headers=response.headers,
                buffer=buffer,
                effective_url=response.geturl(),
                error=None
            )
        except urllib3.exceptions.HTTPError as e:
            # Transport-level failure: report code 599 with an HttpError.
            response_obj = self.Response(
                request=request,
                code=599,
                headers={},
                buffer=None,
                effective_url=None,
                error=HttpError(599, str(e))
            )

        request.on_ready(response_obj)

    def _process_queue(self):
        self._process_pending_requests()

    def on_readable(self, fd):
        # No fd-based I/O: requests are performed synchronously.
        pass

    def on_writable(self, fd):
        pass

    def _setup_request(self, curl, request, buffer, headers):
        # Compatibility hook matching the curl client's API; unused here.
        pass
|
||||
@@ -0,0 +1,399 @@
|
||||
"""Event loop implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from copy import copy
|
||||
from queue import Empty
|
||||
from time import sleep
|
||||
from types import GeneratorType as generator
|
||||
|
||||
from vine import Thenable, promise
|
||||
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils.compat import fileno
|
||||
from kombu.utils.eventio import ERR, READ, WRITE, poll
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from .timer import Timer
|
||||
|
||||
__all__ = ('Hub', 'get_event_loop', 'set_event_loop')
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_current_loop: Hub | None = None
|
||||
|
||||
W_UNKNOWN_EVENT = """\
|
||||
Received unknown event %r for fd %r, please contact support!\
|
||||
"""
|
||||
|
||||
|
||||
class Stop(BaseException):
|
||||
"""Stops the event loop."""
|
||||
|
||||
|
||||
def _raise_stop_error():
|
||||
raise Stop()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _dummy_context(*args, **kwargs):
|
||||
yield
|
||||
|
||||
|
||||
def get_event_loop() -> Hub | None:
    """Get current event loop object."""
    # Returns the module-global loop set by set_event_loop(); may be None.
    return _current_loop
|
||||
|
||||
|
||||
def set_event_loop(loop: Hub | None) -> Hub | None:
    """Set the current event loop object.

    Pass ``None`` to clear the loop.  Returns *loop* for convenience.
    """
    global _current_loop
    _current_loop = loop
    return loop
|
||||
|
||||
|
||||
class Hub:
    """Event loop object.

    Arguments:
    ---------
        timer (kombu.asynchronous.Timer): Specify custom timer instance.
    """

    #: Flag set if reading from an fd will not block.
    READ = READ

    #: Flag set if writing to an fd will not block.
    WRITE = WRITE

    #: Flag set on error, and the fd should be read from asap.
    ERR = ERR

    #: List of callbacks to be called when the loop is exiting,
    #: applied with the hub instance as sole argument.
    on_close = None

    def __init__(self, timer=None):
        self.timer = timer if timer is not None else Timer()

        self.readers = {}
        self.writers = {}
        self.on_tick = set()
        self.on_close = set()
        self._ready = set()
        # Guards self._ready: call_soon may be invoked from other threads.
        self._ready_lock = threading.Lock()

        self._running = False
        self._loop = None

        # The eventloop (in celery.worker.loops)
        # will merge fds in this set and then instead of calling
        # the callback for each ready fd it will call the
        # :attr:`consolidate_callback` with the list of ready_fds
        # as an argument.  This API is internal and is only
        # used by the multiprocessing pool to find inqueues
        # that are ready to write.
        self.consolidate = set()
        self.consolidate_callback = None

        self.propagate_errors = ()

        self._create_poller()

    @property
    def poller(self):
        # Lazily re-create the poller if it was closed (see reset()).
        if not self._poller:
            self._create_poller()
        return self._poller

    @poller.setter
    def poller(self, value):
        self._poller = value

    def reset(self):
        """Close the hub and create a fresh poller."""
        self.close()
        self._create_poller()

    def _create_poller(self):
        self._poller = poll()
        # Bound methods cached for fast access in the loop.
        self._register_fd = self._poller.register
        self._unregister_fd = self._poller.unregister

    def _close_poller(self):
        if self._poller is not None:
            self._poller.close()
            self._poller = None
            self._register_fd = None
            self._unregister_fd = None

    def stop(self):
        """Request loop shutdown; raises Stop inside the running loop."""
        self.call_soon(_raise_stop_error)

    def __repr__(self):
        return '<Hub@{:#x}: R:{} W:{}>'.format(
            id(self), len(self.readers), len(self.writers),
        )

    def fire_timers(self, min_delay=1, max_delay=10, max_timers=10,
                    propagate=()):
        """Run up to *max_timers* due timer entries.

        Returns the next poll timeout, clamped to
        ``[min_delay, max_delay]``.  Exceptions from entries are logged
        unless they are in *propagate* or are fatal
        (MemoryError/AssertionError/ENOMEM).
        """
        timer = self.timer
        delay = None
        if timer and timer._queue:
            for i in range(max_timers):
                delay, entry = next(self.scheduler)
                if entry is None:
                    # Nothing due yet; ``delay`` is time until next entry.
                    break
                try:
                    entry()
                except propagate:
                    raise
                except (MemoryError, AssertionError):
                    raise
                except OSError as exc:
                    if exc.errno == errno.ENOMEM:
                        raise
                    logger.error('Error in timer: %r', exc, exc_info=1)
                except Exception as exc:
                    logger.error('Error in timer: %r', exc, exc_info=1)
        return min(delay or min_delay, max_delay)

    def _remove_from_loop(self, fd):
        try:
            self._unregister(fd)
        finally:
            self._discard(fd)

    def add(self, fd, callback, flags, args=(), consolidate=False):
        """Register *callback* for *fd* under the given READ/WRITE/ERR flags."""
        fd = fileno(fd)
        try:
            self.poller.register(fd, flags)
        except ValueError:
            self._remove_from_loop(fd)
            raise
        else:
            dest = self.readers if flags & READ else self.writers
            if consolidate:
                # Consolidated fds have no per-fd callback; see __init__.
                self.consolidate.add(fd)
                dest[fd] = None
            else:
                dest[fd] = callback, args

    def remove(self, fd):
        fd = fileno(fd)
        self._remove_from_loop(fd)

    def run_forever(self):
        """Run the loop until Stop is raised (see stop())."""
        self._running = True
        try:
            while 1:
                try:
                    self.run_once()
                except Stop:
                    break
        finally:
            self._running = False

    def run_once(self):
        """Advance the loop generator by a single iteration."""
        try:
            next(self.loop)
        except StopIteration:
            self._loop = None

    def call_soon(self, callback, *args):
        """Schedule *callback* for the next loop iteration (thread-safe)."""
        if not isinstance(callback, Thenable):
            callback = promise(callback, args)
        with self._ready_lock:
            self._ready.add(callback)
        return callback

    def call_later(self, delay, callback, *args):
        """Schedule *callback* to be called after *delay* seconds."""
        return self.timer.call_after(delay, callback, args)

    def call_at(self, when, callback, *args):
        """Schedule *callback* to be called at absolute time *when*."""
        return self.timer.call_at(when, callback, args)

    def call_repeatedly(self, delay, callback, *args):
        """Schedule *callback* to be called every *delay* seconds."""
        return self.timer.call_repeatedly(delay, callback, args)

    def add_reader(self, fds, callback, *args):
        return self.add(fds, callback, READ | ERR, args)

    def add_writer(self, fds, callback, *args):
        return self.add(fds, callback, WRITE, args)

    def remove_reader(self, fd):
        # Removing fully unregisters the fd, so re-add any write callback
        # that was also registered for it.
        writable = fd in self.writers
        on_write = self.writers.get(fd)
        try:
            self._remove_from_loop(fd)
        finally:
            if writable:
                cb, args = on_write
                self.add(fd, cb, WRITE, args)

    def remove_writer(self, fd):
        # Mirror of remove_reader: preserve any read callback.
        readable = fd in self.readers
        on_read = self.readers.get(fd)
        try:
            self._remove_from_loop(fd)
        finally:
            if readable:
                cb, args = on_read
                self.add(fd, cb, READ | ERR, args)

    def _unregister(self, fd):
        try:
            self.poller.unregister(fd)
        except (AttributeError, KeyError, OSError):
            # Already unregistered or poller gone; safe to ignore.
            pass

    def _pop_ready(self):
        # Atomically swap out the ready set so callbacks scheduled while
        # we iterate do not mutate the set being iterated.
        with self._ready_lock:
            ready = self._ready
            self._ready = set()
            return ready

    def close(self, *args):
        [self._unregister(fd) for fd in self.readers]
        self.readers.clear()
        [self._unregister(fd) for fd in self.writers]
        self.writers.clear()
        self.consolidate.clear()
        self._close_poller()
        for callback in self.on_close:
            callback(self)

        # Complete remaining todo before Hub close
        # Eg: Acknowledge message
        # To avoid infinite loop where one of the callables adds items
        # to self._ready (via call_soon or otherwise).
        # we create new list with current self._ready
        todos = self._pop_ready()
        for item in todos:
            item()

    def _discard(self, fd):
        fd = fileno(fd)
        self.readers.pop(fd, None)
        self.writers.pop(fd, None)
        self.consolidate.discard(fd)

    def on_callback_error(self, callback, exc):
        logger.error(
            'Callback %r raised exception: %r', callback, exc, exc_info=1,
        )

    def create_loop(self,
                    generator=generator, sleep=sleep, min=min, next=next,
                    Empty=Empty, StopIteration=StopIteration,
                    KeyError=KeyError, READ=READ, WRITE=WRITE, ERR=ERR):
        """Create the generator driving the loop (one iteration per next()).

        Globals/attributes are bound as default args and locals up front
        so lookups in the hot loop are fast.
        """
        readers, writers = self.readers, self.writers
        poll = self.poller.poll
        fire_timers = self.fire_timers
        hub_remove = self.remove
        scheduled = self.timer._queue
        consolidate = self.consolidate
        consolidate_callback = self.consolidate_callback
        propagate = self.propagate_errors

        while 1:
            # 1) Run callbacks scheduled via call_soon().
            todo = self._pop_ready()

            for item in todo:
                if item:
                    item()

            # 2) Fire due timers; their result bounds the poll timeout.
            poll_timeout = fire_timers(propagate=propagate) if scheduled else 1

            # 3) Per-iteration tick callbacks (copied: they may mutate set).
            for tick_callback in copy(self.on_tick):
                tick_callback()

            # print('[[[HUB]]]: %s' % (self.repr_active(),))
            if readers or writers:
                to_consolidate = []
                try:
                    events = poll(poll_timeout)
                    # print('[EVENTS]: %s' % (self.repr_events(events),))
                except ValueError:  # Issue celery/#882
                    return

                for fd, event in events or ():
                    general_error = False
                    if fd in consolidate and \
                            writers.get(fd) is None:
                        to_consolidate.append(fd)
                        continue
                    cb = cbargs = None

                    if event & READ:
                        try:
                            cb, cbargs = readers[fd]
                        except KeyError:
                            self.remove_reader(fd)
                            continue
                    elif event & WRITE:
                        try:
                            cb, cbargs = writers[fd]
                        except KeyError:
                            self.remove_writer(fd)
                            continue
                    elif event & ERR:
                        general_error = True
                    else:
                        logger.info(W_UNKNOWN_EVENT, event, fd)
                        general_error = True

                    if general_error:
                        # On error, fall back to whichever callback is
                        # registered for the fd (TypeError when neither).
                        try:
                            cb, cbargs = (readers.get(fd) or
                                          writers.get(fd))
                        except TypeError:
                            pass

                    if cb is None:
                        self.remove(fd)
                        continue

                    if isinstance(cb, generator):
                        try:
                            next(cb)
                        except OSError as exc:
                            if exc.errno != errno.EBADF:
                                raise
                            # Stale fd: drop it from the loop.
                            hub_remove(fd)
                        except StopIteration:
                            pass
                        except Exception:
                            hub_remove(fd)
                            raise
                    else:
                        try:
                            cb(*cbargs)
                        except Empty:
                            pass
                if to_consolidate:
                    consolidate_callback(to_consolidate)
            else:
                # no sockets yet, startup is probably not done.
                sleep(min(poll_timeout, 0.1))
            yield

    def repr_active(self):
        from .debug import repr_active
        return repr_active(self)

    def repr_events(self, events):
        from .debug import repr_events
        return repr_events(self, events or [])

    @cached_property
    def scheduler(self):
        # Infinite iterator over the timer (see fire_timers()).
        return iter(self.timer)

    @property
    def loop(self):
        # Lazily (re-)created loop generator; run_once() resets it to
        # None when exhausted.
        if self._loop is None:
            self._loop = self.create_loop()
        return self._loop
|
||||
@@ -0,0 +1,127 @@
|
||||
"""Semaphores and concurrency primitives."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Callable, Deque
|
||||
if sys.version_info < (3, 10):
|
||||
from typing_extensions import ParamSpec
|
||||
else:
|
||||
from typing import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
__all__ = ('DummyLock', 'LaxBoundedSemaphore')
|
||||
|
||||
|
||||
class LaxBoundedSemaphore:
    """Asynchronous Bounded Semaphore.

    Lax means that the value will stay within the specified
    range even if released more times than it was acquired.

    Example:
    -------
        >>> x = LaxBoundedSemaphore(2)

        >>> x.acquire(print, 'HELLO 1')
        HELLO 1

        >>> x.acquire(print, 'HELLO 2')
        HELLO 2

        >>> x.acquire(print, 'HELLO 3')
        >>> x._waiting  # private, do not access directly
        [print, ('HELLO 3',)]

        >>> x.release()
        HELLO 3
    """

    def __init__(self, value: int) -> None:
        self.initial_value = self.value = value
        self._waiting: Deque[tuple] = deque()
        # Bound methods cached for fast enqueue/dequeue.
        self._add_waiter = self._waiting.append
        self._pop_waiter = self._waiting.popleft

    def acquire(
        self,
        callback: Callable[P, None],
        *partial_args: P.args,
        **partial_kwargs: P.kwargs
    ) -> bool:
        """Acquire semaphore.

        This will immediately apply ``callback`` if
        the resource is available, otherwise the callback is suspended
        until the semaphore is released.

        Arguments:
        ---------
            callback (Callable): The callback to apply.
            *partial_args (Any): partial arguments to callback.

        Returns ``True`` when the callback ran immediately, ``False``
        when it was queued.
        """
        if self.value <= 0:
            # No capacity: park the callback until release().
            self._add_waiter((callback, partial_args, partial_kwargs))
            return False
        self.value = max(self.value - 1, 0)
        callback(*partial_args, **partial_kwargs)
        return True

    def release(self) -> None:
        """Release semaphore.

        Note:
        ----
            If there are any waiters this will apply the first waiter
            that is waiting for the resource (FIFO order).
        """
        if self._waiting:
            waiter, args, kwargs = self._pop_waiter()
            waiter(*args, **kwargs)
        else:
            # Lax: never grow past the initial value.
            self.value = min(self.value + 1, self.initial_value)

    def grow(self, n: int = 1) -> None:
        """Change the size of the semaphore to accept more users."""
        self.initial_value += n
        self.value += n
        # Drain up to n queued waiters now that capacity increased.
        for _ in range(n):
            self.release()

    def shrink(self, n: int = 1) -> None:
        """Change the size of the semaphore to accept less users."""
        self.initial_value = max(self.initial_value - n, 0)
        self.value = max(self.value - n, 0)

    def clear(self) -> None:
        """Reset the semaphore, which also wipes out any waiting callbacks."""
        self._waiting.clear()
        self.value = self.initial_value

    def __repr__(self) -> str:
        return (
            f'<{type(self).__name__} at {id(self):#x} '
            f'value:{self.value} waiting:{len(self._waiting)}>'
        )
|
||||
|
||||
|
||||
class DummyLock:
    """Pretending to be a lock.

    A no-op stand-in supporting the context-manager protocol; exceptions
    raised inside the ``with`` block propagate unchanged.
    """

    def __enter__(self) -> DummyLock:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None
    ) -> None:
        # Intentionally empty: returning None never suppresses exceptions.
        pass
|
||||
@@ -0,0 +1,241 @@
|
||||
"""Timer scheduling Python callbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import heapq
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
from functools import total_ordering
|
||||
from time import monotonic
|
||||
from time import time as _time
|
||||
from typing import TYPE_CHECKING
|
||||
from weakref import proxy as weakrefproxy
|
||||
|
||||
from vine.utils import wraps
|
||||
|
||||
from kombu.log import get_logger
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from zoneinfo import ZoneInfo
|
||||
else:
|
||||
from backports.zoneinfo import ZoneInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
__all__ = ('Entry', 'Timer', 'to_timestamp')
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_MAX_INTERVAL = 2
|
||||
EPOCH = datetime.fromtimestamp(0, ZoneInfo("UTC"))
|
||||
IS_PYPY = hasattr(sys, 'pypy_version_info')
|
||||
|
||||
scheduled = namedtuple('scheduled', ('eta', 'priority', 'entry'))
|
||||
|
||||
|
||||
def to_timestamp(d, default_timezone=ZoneInfo("UTC"), time=monotonic):
    """Convert datetime to timestamp.

    If *d* is already a timestamp (i.e. not a :class:`~datetime.datetime`),
    it is returned unchanged.
    """
    if isinstance(d, datetime):
        if d.tzinfo is None:
            # Naive datetimes are assumed to be in *default_timezone*.
            d = d.replace(tzinfo=default_timezone)
        # Shift from the wall clock (_time) to the clock used by the
        # scheduler (*time*, monotonic by default) so the result is
        # comparable to values from time().
        diff = _time() - time()
        return max((d - EPOCH).total_seconds() - diff, 0)
    return d
|
||||
|
||||
|
||||
@total_ordering
class Entry:
    """Schedule Entry.

    Wraps a callable with its arguments for execution by the timer.
    """

    if not IS_PYPY:  # pragma: no cover
        __slots__ = (
            'fun', 'args', 'kwargs', 'tref', 'canceled',
            '_last_run', '__weakref__',
        )

    def __init__(self, fun, args=None, kwargs=None):
        self.fun = fun
        self.args = args or []
        self.kwargs = kwargs or {}
        # Weak proxy to self: schedules hold this so a pending reference
        # does not keep the entry alive.
        self.tref = weakrefproxy(self)
        self._last_run = None
        self.canceled = False

    def __call__(self):
        """Invoke the wrapped callable with its stored arguments."""
        return self.fun(*self.args, **self.kwargs)

    def cancel(self):
        """Mark the entry as canceled (no-op if already collected)."""
        try:
            self.tref.canceled = True
        except ReferenceError:  # pragma: no cover
            pass

    def __repr__(self):
        return (
            f'<TimerEntry: {self.fun.__name__}'
            f'(*{self.args!r}, **{self.kwargs!r})'
        )

    # must not use hash() to order entries
    def __lt__(self, other):
        return id(self) < id(other)

    @property
    def cancelled(self):
        # British-spelling alias kept for API compatibility.
        return self.canceled

    @cancelled.setter
    def cancelled(self, value):
        self.canceled = value
|
||||
|
||||
|
||||
class Timer:
    """Async timer implementation."""

    # Entry factory used by call_at/call_after/call_repeatedly;
    # subclasses may override it to customize entry behaviour.
    Entry = Entry

    # Optional error callback; when set, handle_error() forwards the
    # exception here instead of letting it propagate.
    on_error = None

    def __init__(self, max_interval=None, on_error=None, **kwargs):
        # Caps how long __iter__ asks the caller to wait between polls.
        self.max_interval = float(max_interval or DEFAULT_MAX_INTERVAL)
        self.on_error = on_error or self.on_error
        # Min-heap of scheduled(eta, priority, entry) tuples.
        self._queue = []

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None
    ) -> None:
        self.stop()

    def call_at(self, eta, fun, args=(), kwargs=None, priority=0):
        """Schedule ``fun`` to be called at the absolute time ``eta``."""
        kwargs = {} if not kwargs else kwargs
        return self.enter_at(self.Entry(fun, args, kwargs), eta, priority)

    def call_after(self, secs, fun, args=(), kwargs=None, priority=0):
        """Schedule ``fun`` to be called ``secs`` seconds from now."""
        kwargs = {} if not kwargs else kwargs
        return self.enter_after(secs, self.Entry(fun, args, kwargs), priority)

    def call_repeatedly(self, secs, fun, args=(), kwargs=None, priority=0):
        """Schedule ``fun`` to be called every ``secs`` seconds."""
        kwargs = {} if not kwargs else kwargs
        tref = self.Entry(fun, args, kwargs)

        @wraps(fun)
        def _reschedules(*args, **kwargs):
            last, now = tref._last_run, monotonic()
            # Seconds elapsed since the previous run (the full interval
            # on the first invocation).
            lsince = (now - tref._last_run) if last else secs
            try:
                if lsince and lsince >= secs:
                    tref._last_run = now
                    return fun(*args, **kwargs)
            finally:
                # Re-enter the entry even if ``fun`` raised, unless it
                # has been cancelled in the meantime.
                if not tref.canceled:
                    last = tref._last_run
                    next = secs - (now - last) if last else secs
                    self.enter_after(next, tref, priority)

        tref.fun = _reschedules
        tref._last_run = None
        return self.enter_after(secs, tref, priority)

    def enter_at(self, entry, eta=None, priority=0, time=monotonic):
        """Enter function into the scheduler.

        Arguments:
        ---------
            entry (~kombu.asynchronous.timer.Entry): Item to enter.
            eta (datetime.datetime): Scheduled time.
            priority (int): Unused.
        """
        if eta is None:
            # No eta given: schedule for "now" on the monotonic clock.
            eta = time()
        if isinstance(eta, datetime):
            try:
                eta = to_timestamp(eta)
            except Exception as exc:
                # Give on_error a chance to swallow conversion errors;
                # otherwise propagate to the caller.
                if not self.handle_error(exc):
                    raise
                return
        return self._enter(eta, priority, entry)

    def enter_after(self, secs, entry, priority=0, time=monotonic):
        # Relative scheduling: convert ``secs`` to an absolute eta.
        return self.enter_at(entry, time() + float(secs), priority)

    def _enter(self, eta, priority, entry, push=heapq.heappush):
        # Push onto the heap; tuple order makes the earliest eta pop first.
        push(self._queue, scheduled(eta, priority, entry))
        return entry

    def apply_entry(self, entry):
        """Execute ``entry``, routing any exception to the error handler."""
        try:
            entry()
        except Exception as exc:
            if not self.handle_error(exc):
                logger.error('Error in timer: %r', exc, exc_info=True)

    def handle_error(self, exc_info):
        # Returns True when the error was delegated to ``on_error``
        # (falsy otherwise, so callers fall back to their own handling).
        if self.on_error:
            self.on_error(exc_info)
            return True

    def stop(self):
        # No-op hook; subclasses with real resources override this.
        pass

    def __iter__(self, min=min, nowfun=monotonic,
                 pop=heapq.heappop, push=heapq.heappush):
        """Iterate over schedule.

        This iterator yields a tuple of ``(wait_seconds, entry)``,
        where if entry is :const:`None` the caller should wait
        for ``wait_seconds`` until it polls the schedule again.
        """
        max_interval = self.max_interval
        queue = self._queue

        while 1:
            if queue:
                # Peek at the heap head without popping it.
                eventA = queue[0]
                now, eta = nowfun(), eventA[0]

                if now < eta:
                    # Not due yet: tell the caller how long to wait,
                    # capped at max_interval.
                    yield min(eta - now, max_interval), None
                else:
                    eventB = pop(queue)

                    if eventB is eventA:
                        # Head unchanged since we peeked: safe to hand
                        # the entry to the caller (skip if cancelled).
                        entry = eventA[2]
                        if not entry.canceled:
                            yield None, entry
                        continue
                    else:
                        # An earlier event was scheduled while we
                        # yielded; put this one back and re-evaluate.
                        push(queue, eventB)
            else:
                yield None, None

    def clear(self):
        self._queue[:] = []  # atomic, without creating a new list.

    def cancel(self, tref):
        tref.cancel()

    def __len__(self):
        return len(self._queue)

    def __nonzero__(self):
        # Python 2 truthiness hook; kept for backwards compatibility.
        return True

    @property
    def queue(self, _pop=heapq.heappop):
        """Snapshot of underlying datastructure."""
        # Pop from a copy so the live heap is untouched; yields the
        # events in sorted (eta, priority) order.
        events = list(self._queue)
        return [_pop(v) for v in [events] * len(events)]

    @property
    def schedule(self):
        return self