This commit is contained in:
Iliyan Angelov
2025-09-19 11:58:53 +03:00
parent 306b20e24a
commit 6b247e5b9f
11423 changed files with 1500615 additions and 778 deletions

View File

@@ -0,0 +1,20 @@
"""DEPRECATED - Import from modules below."""
from __future__ import annotations
from .collections import EqualityDict
from .compat import fileno, maybe_fileno, nested, register_after_fork
from .div import emergency_dump_state
from .functional import (fxrange, fxrangemax, maybe_list, reprcall,
retry_over_time)
from .imports import symbol_by_name
from .objects import cached_property
from .uuid import uuid
__all__ = (
'EqualityDict', 'uuid', 'maybe_list',
'fxrange', 'fxrangemax', 'retry_over_time',
'emergency_dump_state', 'cached_property',
'register_after_fork', 'reprkwargs', 'reprcall',
'symbol_by_name', 'nested', 'fileno', 'maybe_fileno',
)

View File

@@ -0,0 +1,22 @@
"""AMQP Management API utilities."""
from __future__ import annotations
def get_manager(client, hostname=None, port=None, userid=None,
password=None):
"""Get pyrabbit manager."""
import pyrabbit
opt = client.transport_options.get
def get(name, val, default):
return (val if val is not None
else opt('manager_%s' % name) or
getattr(client, name, None) or default)
host = get('hostname', hostname, 'localhost')
port = port if port is not None else opt('manager_port', 15672)
userid = get('userid', userid, 'guest')
password = get('password', password, 'guest')
return pyrabbit.Client(f'{host}:{port}', userid, password)

View File

@@ -0,0 +1,45 @@
"""Custom maps, sequences, etc."""
from __future__ import annotations
class HashedSeq(list):
"""Hashed Sequence.
Type used for hash() to make sure the hash is not generated
multiple times.
"""
__slots__ = 'hashvalue'
def __init__(self, *seq):
self[:] = seq
self.hashvalue = hash(seq)
def __hash__(self):
return self.hashvalue
def eqhash(o):
"""Call ``obj.__eqhash__``."""
try:
return o.__eqhash__()
except AttributeError:
return hash(o)
class EqualityDict(dict):
"""Dict using the eq operator for keying."""
def __getitem__(self, key):
h = eqhash(key)
if h not in self:
return self.__missing__(key)
return super().__getitem__(h)
def __setitem__(self, key, value):
return super().__setitem__(eqhash(key), value)
def __delitem__(self, key):
return super().__delitem__(eqhash(key))

View File

@@ -0,0 +1,137 @@
"""Python Compatibility Utilities."""
from __future__ import annotations
import numbers
import sys
from contextlib import contextmanager
from functools import wraps
from importlib import metadata as importlib_metadata
from io import UnsupportedOperation
from kombu.exceptions import reraise
FILENO_ERRORS = (AttributeError, ValueError, UnsupportedOperation)
try:
from billiard.util import register_after_fork
except ImportError: # pragma: no cover
try:
from multiprocessing.util import register_after_fork
except ImportError:
register_after_fork = None
_environment = None
def coro(gen):
"""Decorator to mark generator as co-routine."""
@wraps(gen)
def wind_up(*args, **kwargs):
it = gen(*args, **kwargs)
next(it)
return it
return wind_up
def _detect_environment():
# ## -eventlet-
if 'eventlet' in sys.modules:
try:
import socket
from eventlet.patcher import is_monkey_patched as is_eventlet
if is_eventlet(socket):
return 'eventlet'
except ImportError:
pass
# ## -gevent-
if 'gevent' in sys.modules:
try:
import socket
from gevent import socket as _gsocket
if socket.socket is _gsocket.socket:
return 'gevent'
except ImportError:
pass
return 'default'
def detect_environment():
"""Detect the current environment: default, eventlet, or gevent."""
global _environment
if _environment is None:
_environment = _detect_environment()
return _environment
def entrypoints(namespace):
"""Return setuptools entrypoints for namespace."""
if sys.version_info >= (3,10):
entry_points = importlib_metadata.entry_points(group=namespace)
else:
entry_points = importlib_metadata.entry_points()
try:
entry_points = entry_points.get(namespace, [])
except AttributeError:
entry_points = entry_points.select(group=namespace)
return (
(ep, ep.load())
for ep in entry_points
)
def fileno(f):
"""Get fileno from file-like object."""
if isinstance(f, numbers.Integral):
return f
return f.fileno()
def maybe_fileno(f):
"""Get object fileno, or :const:`None` if not defined."""
try:
return fileno(f)
except FILENO_ERRORS:
pass
@contextmanager
def nested(*managers): # pragma: no cover
"""Nest context managers."""
# flake8: noqa
exits = []
vars = []
exc = (None, None, None)
try:
try:
for mgr in managers:
exit = mgr.__exit__
enter = mgr.__enter__
vars.append(enter())
exits.append(exit)
yield vars
except:
exc = sys.exc_info()
finally:
while exits:
exit = exits.pop()
try:
if exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()
if exc != (None, None, None):
# Don't rely on sys.exc_info() still containing
# the right information. Another exception may
# have been raised and caught by an exit method
reraise(exc[0], exc[1], exc[2])
finally:
del(exc)

View File

@@ -0,0 +1,77 @@
"""Debugging support."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from vine.utils import wraps
from kombu.log import get_logger
if TYPE_CHECKING:
from logging import Logger
from typing import Any, Callable
from kombu.transport.base import Transport
__all__ = ('setup_logging', 'Logwrapped')
def setup_logging(
loglevel: int | None = logging.DEBUG,
loggers: list[str] | None = None
) -> None:
"""Setup logging to stdout."""
loggers = ['kombu.connection', 'kombu.channel'] if not loggers else loggers
for logger_name in loggers:
logger = get_logger(logger_name)
logger.addHandler(logging.StreamHandler())
logger.setLevel(loglevel)
class Logwrapped:
"""Wrap all object methods, to log on call."""
__ignore = ('__enter__', '__exit__')
def __init__(
self,
instance: Transport,
logger: Logger | None = None,
ident: str | None = None
):
self.instance = instance
self.logger = get_logger(logger)
self.ident = ident
def __getattr__(self, key: str) -> Callable:
meth = getattr(self.instance, key)
if not callable(meth) or key in self.__ignore:
return meth
@wraps(meth)
def __wrapped(*args: list[Any], **kwargs: dict[str, Any]) -> Callable:
info = ''
if self.ident:
info += self.ident.format(self.instance)
info += f'{meth.__name__}('
if args:
info += ', '.join(map(repr, args))
if kwargs:
if args:
info += ', '
info += ', '.join(f'{key}={value!r}'
for key, value in kwargs.items())
info += ')'
self.logger.debug(info)
return meth(*args, **kwargs)
return __wrapped
def __repr__(self) -> str:
return repr(self.instance)
def __dir__(self) -> list[str]:
return dir(self.instance)

View File

@@ -0,0 +1,37 @@
"""Div. Utilities."""
from __future__ import annotations
import os
import sys
from .encoding import default_encode
def emergency_dump_state(state, open_file=open, dump=None, stderr=None):
"""Dump message state to stdout or file."""
from pprint import pformat
from tempfile import mkstemp
stderr = sys.stderr if stderr is None else stderr
if dump is None:
import pickle
dump = pickle.dump
fd, persist = mkstemp()
os.close(fd)
print(f'EMERGENCY DUMP STATE TO FILE -> {persist} <-',
file=stderr)
fh = open_file(persist, 'w')
try:
try:
dump(state, fh, protocol=0)
except Exception as exc:
print(
f'Cannot pickle state: {exc!r}. Fallback to pformat.',
file=stderr,
)
fh.write(default_encode(pformat(state)))
finally:
fh.flush()
fh.close()
return persist

View File

@@ -0,0 +1,97 @@
"""Text encoding utilities.
Utilities to encode text, and to safely emit text from running
applications without crashing from the infamous
:exc:`UnicodeDecodeError` exception.
"""
from __future__ import annotations
import sys
import traceback
#: safe_str takes encoding from this file by default.
#: :func:`set_default_encoding_file` can used to set the
#: default output file.
default_encoding_file = None
def set_default_encoding_file(file):
"""Set file used to get codec information."""
global default_encoding_file
default_encoding_file = file
def get_default_encoding_file():
"""Get file used to get codec information."""
return default_encoding_file
if sys.platform.startswith('java'): # pragma: no cover
def default_encoding(file=None):
"""Get default encoding."""
return 'utf-8'
else:
def default_encoding(file=None):
"""Get default encoding."""
file = file or get_default_encoding_file()
return getattr(file, 'encoding', None) or sys.getfilesystemencoding()
def str_to_bytes(s):
"""Convert str to bytes."""
if isinstance(s, str):
return s.encode()
return s
def bytes_to_str(s):
"""Convert bytes to str."""
if isinstance(s, bytes):
return s.decode(errors='replace')
return s
def from_utf8(s, *args, **kwargs):
"""Get str from utf-8 encoding."""
return s
def ensure_bytes(s):
"""Ensure s is bytes, not str."""
if not isinstance(s, bytes):
return str_to_bytes(s)
return s
def default_encode(obj):
"""Encode using default encoding."""
return obj
def safe_str(s, errors='replace'):
"""Safe form of str(), void of unicode errors."""
s = bytes_to_str(s)
if not isinstance(s, (str, bytes)):
return safe_repr(s, errors)
return _safe_str(s, errors)
def _safe_str(s, errors='replace', file=None):
if isinstance(s, str):
return s
try:
return str(s)
except Exception as exc:
return '<Unrepresentable {!r}: {!r} {!r}>'.format(
type(s), exc, '\n'.join(traceback.format_stack()))
def safe_repr(o, errors='replace'):
"""Safe form of repr, void of Unicode errors."""
try:
return repr(o)
except Exception:
return _safe_str(o, errors)

View File

@@ -0,0 +1,329 @@
"""Selector Utilities."""
from __future__ import annotations
import errno
import math
import select as __select__
import sys
from numbers import Integral
from . import fileno
from .compat import detect_environment
__all__ = ('poll',)
_selectf = __select__.select
_selecterr = __select__.error
xpoll = getattr(__select__, 'poll', None)
epoll = getattr(__select__, 'epoll', None)
kqueue = getattr(__select__, 'kqueue', None)
kevent = getattr(__select__, 'kevent', None)
KQ_EV_ADD = getattr(__select__, 'KQ_EV_ADD', 1)
KQ_EV_DELETE = getattr(__select__, 'KQ_EV_DELETE', 2)
KQ_EV_ENABLE = getattr(__select__, 'KQ_EV_ENABLE', 4)
KQ_EV_CLEAR = getattr(__select__, 'KQ_EV_CLEAR', 32)
KQ_EV_ERROR = getattr(__select__, 'KQ_EV_ERROR', 16384)
KQ_EV_EOF = getattr(__select__, 'KQ_EV_EOF', 32768)
KQ_FILTER_READ = getattr(__select__, 'KQ_FILTER_READ', -1)
KQ_FILTER_WRITE = getattr(__select__, 'KQ_FILTER_WRITE', -2)
KQ_FILTER_AIO = getattr(__select__, 'KQ_FILTER_AIO', -3)
KQ_FILTER_VNODE = getattr(__select__, 'KQ_FILTER_VNODE', -4)
KQ_FILTER_PROC = getattr(__select__, 'KQ_FILTER_PROC', -5)
KQ_FILTER_SIGNAL = getattr(__select__, 'KQ_FILTER_SIGNAL', -6)
KQ_FILTER_TIMER = getattr(__select__, 'KQ_FILTER_TIMER', -7)
KQ_NOTE_LOWAT = getattr(__select__, 'KQ_NOTE_LOWAT', 1)
KQ_NOTE_DELETE = getattr(__select__, 'KQ_NOTE_DELETE', 1)
KQ_NOTE_WRITE = getattr(__select__, 'KQ_NOTE_WRITE', 2)
KQ_NOTE_EXTEND = getattr(__select__, 'KQ_NOTE_EXTEND', 4)
KQ_NOTE_ATTRIB = getattr(__select__, 'KQ_NOTE_ATTRIB', 8)
KQ_NOTE_LINK = getattr(__select__, 'KQ_NOTE_LINK', 16)
KQ_NOTE_RENAME = getattr(__select__, 'KQ_NOTE_RENAME', 32)
KQ_NOTE_REVOKE = getattr(__select__, 'KQ_NOTE_REVOKE', 64)
POLLIN = getattr(__select__, 'POLLIN', 1)
POLLOUT = getattr(__select__, 'POLLOUT', 4)
POLLERR = getattr(__select__, 'POLLERR', 8)
POLLHUP = getattr(__select__, 'POLLHUP', 16)
POLLNVAL = getattr(__select__, 'POLLNVAL', 32)
READ = POLL_READ = 0x001
WRITE = POLL_WRITE = 0x004
ERR = POLL_ERR = 0x008 | 0x010
try:
SELECT_BAD_FD = {errno.EBADF, errno.WSAENOTSOCK}
except AttributeError:
SELECT_BAD_FD = {errno.EBADF}
class _epoll:
def __init__(self):
self._epoll = epoll()
def register(self, fd, events):
try:
self._epoll.register(fd, events)
except Exception as exc:
if getattr(exc, 'errno', None) != errno.EEXIST:
raise
return fd
def unregister(self, fd):
try:
self._epoll.unregister(fd)
except (OSError, ValueError, KeyError, TypeError):
pass
except OSError as exc:
if getattr(exc, 'errno', None) not in (errno.ENOENT, errno.EPERM):
raise
def poll(self, timeout):
try:
return self._epoll.poll(timeout if timeout is not None else -1)
except Exception as exc:
if getattr(exc, 'errno', None) != errno.EINTR:
raise
def close(self):
self._epoll.close()
class _kqueue:
w_fflags = (KQ_NOTE_WRITE | KQ_NOTE_EXTEND |
KQ_NOTE_ATTRIB | KQ_NOTE_DELETE)
def __init__(self):
self._kqueue = kqueue()
self._active = {}
self.on_file_change = None
self._kcontrol = self._kqueue.control
def register(self, fd, events):
self._control(fd, events, KQ_EV_ADD)
self._active[fd] = events
return fd
def unregister(self, fd):
events = self._active.pop(fd, None)
if events:
try:
self._control(fd, events, KQ_EV_DELETE)
except OSError:
pass
def watch_file(self, fd):
ev = kevent(fd,
filter=KQ_FILTER_VNODE,
flags=KQ_EV_ADD | KQ_EV_ENABLE | KQ_EV_CLEAR,
fflags=self.w_fflags)
self._kcontrol([ev], 0)
def unwatch_file(self, fd):
ev = kevent(fd,
filter=KQ_FILTER_VNODE,
flags=KQ_EV_DELETE,
fflags=self.w_fflags)
self._kcontrol([ev], 0)
def _control(self, fd, events, flags):
if not events:
return
kevents = []
if events & WRITE:
kevents.append(kevent(fd,
filter=KQ_FILTER_WRITE,
flags=flags))
if not kevents or events & READ:
kevents.append(
kevent(fd, filter=KQ_FILTER_READ, flags=flags),
)
control = self._kcontrol
for e in kevents:
try:
control([e], 0)
except ValueError:
pass
def poll(self, timeout):
try:
kevents = self._kcontrol(None, 1000, timeout)
except Exception as exc:
if getattr(exc, 'errno', None) == errno.EINTR:
return
raise
events, file_changes = {}, []
for k in kevents:
fd = k.ident
if k.filter == KQ_FILTER_READ:
events[fd] = events.get(fd, 0) | READ
elif k.filter == KQ_FILTER_WRITE:
if k.flags & KQ_EV_EOF:
events[fd] = ERR
else:
events[fd] = events.get(fd, 0) | WRITE
elif k.filter == KQ_EV_ERROR:
events[fd] = events.get(fd, 0) | ERR
elif k.filter == KQ_FILTER_VNODE:
if k.fflags & KQ_NOTE_DELETE:
self.unregister(fd)
file_changes.append(k)
if file_changes:
self.on_file_change(file_changes)
return list(events.items())
def close(self):
self._kqueue.close()
class _poll:
def __init__(self):
self._poller = xpoll()
self._quick_poll = self._poller.poll
self._quick_register = self._poller.register
self._quick_unregister = self._poller.unregister
def register(self, fd, events):
fd = fileno(fd)
poll_flags = 0
if events & ERR:
poll_flags |= POLLERR
if events & WRITE:
poll_flags |= POLLOUT
if events & READ:
poll_flags |= POLLIN
self._quick_register(fd, poll_flags)
return fd
def unregister(self, fd):
try:
fd = fileno(fd)
except OSError as exc:
# we don't know the previous fd of this object
# but it will be removed by the next poll iteration.
if getattr(exc, 'errno', None) in SELECT_BAD_FD:
return fd
raise
self._quick_unregister(fd)
return fd
def poll(self, timeout, round=math.ceil,
POLLIN=POLLIN, POLLOUT=POLLOUT, POLLERR=POLLERR,
READ=READ, WRITE=WRITE, ERR=ERR, Integral=Integral):
timeout = 0 if timeout and timeout < 0 else round((timeout or 0) * 1e3)
try:
event_list = self._quick_poll(timeout)
except (_selecterr, OSError) as exc:
if getattr(exc, 'errno', None) == errno.EINTR:
return
raise
ready = []
for fd, event in event_list:
events = 0
if event & POLLIN:
events |= READ
if event & POLLOUT:
events |= WRITE
if event & POLLERR or event & POLLNVAL or event & POLLHUP:
events |= ERR
assert events
if not isinstance(fd, Integral):
fd = fd.fileno()
ready.append((fd, events))
return ready
def close(self):
self._poller = None
class _select:
def __init__(self):
self._all = (self._rfd,
self._wfd,
self._efd) = set(), set(), set()
def register(self, fd, events):
fd = fileno(fd)
if events & ERR:
self._efd.add(fd)
if events & WRITE:
self._wfd.add(fd)
if events & READ:
self._rfd.add(fd)
return fd
def _remove_bad(self):
for fd in self._rfd | self._wfd | self._efd:
try:
_selectf([fd], [], [], 0)
except (_selecterr, OSError) as exc:
if getattr(exc, 'errno', None) in SELECT_BAD_FD:
self.unregister(fd)
def unregister(self, fd):
try:
fd = fileno(fd)
except OSError as exc:
# we don't know the previous fd of this object
# but it will be removed by the next poll iteration.
if getattr(exc, 'errno', None) in SELECT_BAD_FD:
return
raise
self._rfd.discard(fd)
self._wfd.discard(fd)
self._efd.discard(fd)
def poll(self, timeout):
try:
read, write, error = _selectf(
self._rfd, self._wfd, self._efd, timeout,
)
except (_selecterr, OSError) as exc:
if getattr(exc, 'errno', None) == errno.EINTR:
return
elif getattr(exc, 'errno', None) in SELECT_BAD_FD:
return self._remove_bad()
raise
events = {}
for fd in read:
if not isinstance(fd, Integral):
fd = fd.fileno()
events[fd] = events.get(fd, 0) | READ
for fd in write:
if not isinstance(fd, Integral):
fd = fd.fileno()
events[fd] = events.get(fd, 0) | WRITE
for fd in error:
if not isinstance(fd, Integral):
fd = fd.fileno()
events[fd] = events.get(fd, 0) | ERR
return list(events.items())
def close(self):
self._rfd.clear()
self._wfd.clear()
self._efd.clear()
def _get_poller():
if detect_environment() != 'default':
# greenlet
return _select
elif epoll:
# Py2.6+ Linux
return _epoll
elif kqueue and 'netbsd' in sys.platform:
return _kqueue
elif xpoll:
return _poll
else:
return _select
def poll(*args, **kwargs):
"""Create new poller instance."""
return _get_poller()(*args, **kwargs)

View File

@@ -0,0 +1,360 @@
"""Functional Utilities."""
from __future__ import annotations
import inspect
import random
import threading
from collections import OrderedDict, UserDict
from collections.abc import Iterable, Mapping
from itertools import count, repeat
from time import sleep, time
from vine.utils import wraps
from .encoding import safe_repr as _safe_repr
__all__ = (
'LRUCache', 'memoize', 'lazy', 'maybe_evaluate',
'is_list', 'maybe_list', 'dictfilter', 'retry_over_time',
)
KEYWORD_MARK = object()
class ChannelPromise:
def __init__(self, contract):
self.__contract__ = contract
def __call__(self):
try:
return self.__value__
except AttributeError:
value = self.__value__ = self.__contract__()
return value
def __repr__(self):
try:
return repr(self.__value__)
except AttributeError:
return f'<promise: 0x{id(self.__contract__):x}>'
class LRUCache(UserDict):
"""LRU Cache implementation using a doubly linked list to track access.
Arguments:
---------
limit (int): The maximum number of keys to keep in the cache.
When a new key is inserted and the limit has been exceeded,
the *Least Recently Used* key will be discarded from the
cache.
"""
def __init__(self, limit=None):
self.limit = limit
self.mutex = threading.RLock()
self.data = OrderedDict()
def __getitem__(self, key):
with self.mutex:
value = self[key] = self.data.pop(key)
return value
def update(self, *args, **kwargs):
with self.mutex:
data, limit = self.data, self.limit
data.update(*args, **kwargs)
if limit and len(data) > limit:
# pop additional items in case limit exceeded
for _ in range(len(data) - limit):
data.popitem(last=False)
def popitem(self, last=True):
with self.mutex:
return self.data.popitem(last)
def __setitem__(self, key, value):
# remove least recently used key.
with self.mutex:
if self.limit and len(self.data) >= self.limit:
self.data.pop(next(iter(self.data)))
self.data[key] = value
def __iter__(self):
return iter(self.data)
def _iterate_items(self):
with self.mutex:
for k in self:
try:
yield (k, self.data[k])
except KeyError: # pragma: no cover
pass
iteritems = _iterate_items
def _iterate_values(self):
with self.mutex:
for k in self:
try:
yield self.data[k]
except KeyError: # pragma: no cover
pass
itervalues = _iterate_values
def _iterate_keys(self):
# userdict.keys in py3k calls __getitem__
with self.mutex:
return self.data.keys()
iterkeys = _iterate_keys
def incr(self, key, delta=1):
with self.mutex:
# this acts as memcached does- store as a string, but return a
# integer as long as it exists and we can cast it
newval = int(self.data.pop(key)) + delta
self[key] = str(newval)
return newval
def __getstate__(self):
d = dict(vars(self))
d.pop('mutex')
return d
def __setstate__(self, state):
self.__dict__ = state
self.mutex = threading.RLock()
keys = _iterate_keys
values = _iterate_values
items = _iterate_items
def memoize(maxsize=None, keyfun=None, Cache=LRUCache):
"""Decorator to cache function return value."""
def _memoize(fun):
mutex = threading.Lock()
cache = Cache(limit=maxsize)
@wraps(fun)
def _M(*args, **kwargs):
if keyfun:
key = keyfun(args, kwargs)
else:
key = args + (KEYWORD_MARK,) + tuple(sorted(kwargs.items()))
try:
with mutex:
value = cache[key]
except KeyError:
value = fun(*args, **kwargs)
_M.misses += 1
with mutex:
cache[key] = value
else:
_M.hits += 1
return value
def clear():
"""Clear the cache and reset cache statistics."""
cache.clear()
_M.hits = _M.misses = 0
_M.hits = _M.misses = 0
_M.clear = clear
_M.original_func = fun
return _M
return _memoize
class lazy:
"""Holds lazy evaluation.
Evaluated when called or if the :meth:`evaluate` method is called.
The function is re-evaluated on every call.
Overloaded operations that will evaluate the promise:
:meth:`__str__`, :meth:`__repr__`, :meth:`__cmp__`.
"""
def __init__(self, fun, *args, **kwargs):
self._fun = fun
self._args = args
self._kwargs = kwargs
def __call__(self):
return self.evaluate()
def evaluate(self):
return self._fun(*self._args, **self._kwargs)
def __str__(self):
return str(self())
def __repr__(self):
return repr(self())
def __eq__(self, rhs):
return self() == rhs
def __ne__(self, rhs):
return self() != rhs
def __deepcopy__(self, memo):
memo[id(self)] = self
return self
def __reduce__(self):
return (self.__class__, (self._fun,), {'_args': self._args,
'_kwargs': self._kwargs})
def maybe_evaluate(value):
"""Evaluate value only if value is a :class:`lazy` instance."""
if isinstance(value, lazy):
return value.evaluate()
return value
def is_list(obj, scalars=(Mapping, str), iters=(Iterable,)):
"""Return true if the object is iterable.
Note:
----
Returns false if object is a mapping or string.
"""
return isinstance(obj, iters) and not isinstance(obj, scalars or ())
def maybe_list(obj, scalars=(Mapping, str)):
"""Return list of one element if ``l`` is a scalar."""
return obj if obj is None or is_list(obj, scalars) else [obj]
def dictfilter(d=None, **kw):
"""Remove all keys from dict ``d`` whose value is :const:`None`."""
d = kw if d is None else (dict(d, **kw) if kw else d)
return {k: v for k, v in d.items() if v is not None}
def shufflecycle(it):
it = list(it) # don't modify callers list
shuffle = random.shuffle
for _ in repeat(None):
shuffle(it)
yield it[0]
def fxrange(start=1.0, stop=None, step=1.0, repeatlast=False):
cur = start * 1.0
while 1:
if not stop or cur <= stop:
yield cur
cur += step
else:
if not repeatlast:
break
yield cur - step
def fxrangemax(start=1.0, stop=None, step=1.0, max=100.0):
sum_, cur = 0, start * 1.0
while 1:
if sum_ >= max:
break
yield cur
if stop:
cur = min(cur + step, stop)
else:
cur += step
sum_ += cur
def retry_over_time(fun, catch, args=None, kwargs=None, errback=None,
max_retries=None, interval_start=2, interval_step=2,
interval_max=30, callback=None, timeout=None):
"""Retry the function over and over until max retries is exceeded.
For each retry we sleep a for a while before we try again, this interval
is increased for every retry until the max seconds is reached.
Arguments:
---------
fun (Callable): The function to try
catch (Tuple[BaseException]): Exceptions to catch, can be either
tuple or a single exception class.
Keyword Arguments:
-----------------
args (Tuple): Positional arguments passed on to the function.
kwargs (Dict): Keyword arguments passed on to the function.
errback (Callable): Callback for when an exception in ``catch``
is raised. The callback must take three arguments:
``exc``, ``interval_range`` and ``retries``, where ``exc``
is the exception instance, ``interval_range`` is an iterator
which return the time in seconds to sleep next, and ``retries``
is the number of previous retries.
max_retries (int): Maximum number of retries before we give up.
If neither of this and timeout is set, we will retry forever.
If one of this and timeout is reached, stop.
interval_start (float): How long (in seconds) we start sleeping
between retries.
interval_step (float): By how much the interval is increased for
each retry.
interval_max (float): Maximum number of seconds to sleep
between retries.
timeout (int): Maximum seconds waiting before we give up.
"""
kwargs = {} if not kwargs else kwargs
args = [] if not args else args
interval_range = fxrange(interval_start,
interval_max + interval_start,
interval_step, repeatlast=True)
end = time() + timeout if timeout else None
for retries in count():
try:
return fun(*args, **kwargs)
except catch as exc:
if max_retries is not None and retries >= max_retries:
raise
if end and time() > end:
raise
if callback:
callback()
tts = float(errback(exc, interval_range, retries) if errback
else next(interval_range))
if tts:
for _ in range(int(tts)):
if callback:
callback()
sleep(1.0)
# sleep remainder after int truncation above.
sleep(abs(int(tts) - tts))
def reprkwargs(kwargs, sep=', ', fmt='{0}={1}'):
return sep.join(fmt.format(k, _safe_repr(v)) for k, v in kwargs.items())
def reprcall(name, args=(), kwargs=None, sep=', '):
kwargs = {} if not kwargs else kwargs
return '{}({}{}{})'.format(
name, sep.join(map(_safe_repr, args or ())),
(args and kwargs) and sep or '',
reprkwargs(kwargs, sep),
)
def accepts_argument(func, argument_name):
argument_spec = inspect.getfullargspec(func)
return (
argument_name in argument_spec.args or
argument_name in argument_spec.kwonlyargs
)
# Compat names (before kombu 3.0)
promise = lazy
maybe_promise = maybe_evaluate

View File

@@ -0,0 +1,68 @@
"""Import related utilities."""
from __future__ import annotations
import importlib
import sys
from kombu.exceptions import reraise
def symbol_by_name(name, aliases=None, imp=None, package=None,
sep='.', default=None, **kwargs):
"""Get symbol by qualified name.
The name should be the full dot-separated path to the class::
modulename.ClassName
Example::
celery.concurrency.processes.TaskPool
^- class name
or using ':' to separate module and symbol::
celery.concurrency.processes:TaskPool
If `aliases` is provided, a dict containing short name/long name
mappings, the name is looked up in the aliases first.
Examples
--------
>>> symbol_by_name('celery.concurrency.processes.TaskPool')
<class 'celery.concurrency.processes.TaskPool'>
>>> symbol_by_name('default', {
... 'default': 'celery.concurrency.processes.TaskPool'})
<class 'celery.concurrency.processes.TaskPool'>
# Does not try to look up non-string names.
>>> from celery.concurrency.processes import TaskPool
>>> symbol_by_name(TaskPool) is TaskPool
True
"""
aliases = {} if not aliases else aliases
if imp is None:
imp = importlib.import_module
if not isinstance(name, str):
return name # already a class
name = aliases.get(name) or name
sep = ':' if ':' in name else sep
module_name, _, cls_name = name.rpartition(sep)
if not module_name:
cls_name, module_name = None, package if package else cls_name
try:
try:
module = imp(module_name, package=package, **kwargs)
except ValueError as exc:
reraise(ValueError,
ValueError(f"Couldn't import {name!r}: {exc}"),
sys.exc_info()[2])
return getattr(module, cls_name) if cls_name else module
except (ImportError, AttributeError):
if default is None:
raise
return default

View File

@@ -0,0 +1,146 @@
"""JSON Serialization Utilities."""
from __future__ import annotations
import base64
import json
import uuid
from datetime import date, datetime, time
from decimal import Decimal
from typing import Any, Callable, TypeVar
textual_types = ()
try:
from django.utils.functional import Promise
textual_types += (Promise,)
except ImportError:
pass
class JSONEncoder(json.JSONEncoder):
"""Kombu custom json encoder."""
def default(self, o):
reducer = getattr(o, "__json__", None)
if reducer is not None:
return reducer()
if isinstance(o, textual_types):
return str(o)
for t, (marker, encoder) in _encoders.items():
if isinstance(o, t):
return (
encoder(o) if marker is None else _as(marker, encoder(o))
)
# Bytes is slightly trickier, so we cannot put them directly
# into _encoders, because we use two formats: bytes, and base64.
if isinstance(o, bytes):
try:
return _as("bytes", o.decode("utf-8"))
except UnicodeDecodeError:
return _as("base64", base64.b64encode(o).decode("utf-8"))
return super().default(o)
def _as(t: str, v: Any):
return {"__type__": t, "__value__": v}
def dumps(
s,
_dumps=json.dumps,
cls=JSONEncoder,
default_kwargs=None,
**kwargs
):
"""Serialize object to json string."""
default_kwargs = default_kwargs or {}
return _dumps(s, cls=cls, **dict(default_kwargs, **kwargs))
def object_hook(o: dict):
"""Hook function to perform custom deserialization."""
if o.keys() == {"__type__", "__value__"}:
decoder = _decoders.get(o["__type__"])
if decoder:
return decoder(o["__value__"])
else:
raise ValueError("Unsupported type", type, o)
else:
return o
def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
"""Deserialize json from string."""
# None of the json implementations supports decoding from
# a buffer/memoryview, or even reading from a stream
# (load is just loads(fp.read()))
# but this is Python, we love copying strings, preferably many times
# over. Note that pickle does support buffer/memoryview
# </rant>
if isinstance(s, memoryview):
s = s.tobytes().decode("utf-8")
elif isinstance(s, bytearray):
s = s.decode("utf-8")
elif decode_bytes and isinstance(s, bytes):
s = s.decode("utf-8")
return _loads(s, object_hook=object_hook)
DecoderT = EncoderT = Callable[[Any], Any]
T = TypeVar("T")
EncodedT = TypeVar("EncodedT")
def register_type(
t: type[T],
marker: str | None,
encoder: Callable[[T], EncodedT],
decoder: Callable[[EncodedT], T] = lambda d: d,
):
"""Add support for serializing/deserializing native python type.
If marker is `None`, the encoding is a pure transformation and the result
is not placed in an envelope, so `decoder` is unnecessary. Decoding must
instead be handled outside this library.
"""
_encoders[t] = (marker, encoder)
if marker is not None:
_decoders[marker] = decoder
_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_decoders: dict[str, DecoderT] = {
"bytes": lambda o: o.encode("utf-8"),
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
}
def _register_default_types():
# NOTE: datetime should be registered before date,
# because datetime is also instance of date.
register_type(datetime, "datetime", datetime.isoformat,
datetime.fromisoformat)
register_type(
date,
"date",
lambda o: o.isoformat(),
lambda o: datetime.fromisoformat(o).date(),
)
register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
register_type(Decimal, "decimal", str, Decimal)
register_type(
uuid.UUID,
"uuid",
lambda o: {"hex": o.hex},
lambda o: uuid.UUID(**o),
)
_register_default_types()

View File

@@ -0,0 +1,87 @@
"""Token bucket implementation for rate limiting."""
from __future__ import annotations
from collections import deque
from time import monotonic
__all__ = ('TokenBucket',)
class TokenBucket:
"""Token Bucket Algorithm.
See Also
--------
https://en.wikipedia.org/wiki/Token_Bucket
Most of this code was stolen from an entry in the ASPN Python Cookbook:
https://code.activestate.com/recipes/511490/
Warning:
-------
Thread Safety: This implementation is not thread safe.
Access to a `TokenBucket` instance should occur within the critical
section of any multithreaded code.
"""
#: The rate in tokens/second that the bucket will be refilled.
fill_rate = None
#: Maximum number of tokens in the bucket.
capacity = 1
#: Timestamp of the last time a token was taken out of the bucket.
timestamp = None
def __init__(self, fill_rate, capacity=1):
self.capacity = float(capacity)
self._tokens = capacity
self.fill_rate = float(fill_rate)
self.timestamp = monotonic()
self.contents = deque()
def add(self, item):
self.contents.append(item)
def pop(self):
return self.contents.popleft()
def clear_pending(self):
self.contents.clear()
def can_consume(self, tokens=1):
"""Check if one or more tokens can be consumed.
Returns
-------
bool: true if the number of tokens can be consumed
from the bucket. If they can be consumed, a call will also
consume the requested number of tokens from the bucket.
Calls will only consume `tokens` (the number requested)
or zero tokens -- it will never consume a partial number
of tokens.
"""
if tokens <= self._get_tokens():
self._tokens -= tokens
return True
return False
def expected_time(self, tokens=1):
"""Return estimated time of token availability.
Returns
-------
float: the time in seconds.
"""
_tokens = self._get_tokens()
tokens = max(tokens, _tokens)
return (tokens - _tokens) / self.fill_rate
def _get_tokens(self):
if self._tokens < self.capacity:
now = monotonic()
delta = self.fill_rate * (now - self.timestamp)
self._tokens = min(self.capacity, self._tokens + delta)
self.timestamp = now
return self._tokens

View File

@@ -0,0 +1,67 @@
"""Object Utilities."""
from __future__ import annotations
from threading import RLock
__all__ = ('cached_property',)
try:
from functools import cached_property as _cached_property
except ImportError:
# TODO: Remove this fallback once we drop support for Python < 3.8
from cached_property import threaded_cached_property as _cached_property
_NOT_FOUND = object()
class cached_property(_cached_property):
"""Implementation of Cached property."""
def __init__(self, fget=None, fset=None, fdel=None):
super().__init__(fget)
self.__set = fset
self.__del = fdel
if not hasattr(self, 'attrname'):
# This is a backport so we set this ourselves.
self.attrname = self.func.__name__
if not hasattr(self, 'lock'):
# Prior to Python 3.12, functools.cached_property has an
# undocumented lock which is required for thread-safe __set__
# and __delete__. Create one if it isn't already present.
self.lock = RLock()
def __get__(self, instance, owner=None):
# TODO: Remove this after we drop support for Python<3.8
# or fix the signature in the cached_property package
with self.lock:
return super().__get__(instance, owner)
def __set__(self, instance, value):
if instance is None:
return self
with self.lock:
if self.__set is not None:
value = self.__set(instance, value)
cache = instance.__dict__
cache[self.attrname] = value
def __delete__(self, instance):
if instance is None:
return self
with self.lock:
value = instance.__dict__.pop(self.attrname, _NOT_FOUND)
if self.__del and value is not _NOT_FOUND:
self.__del(instance, value)
def setter(self, fset):
return self.__class__(self.func, fset, self.__del)
def deleter(self, fdel):
return self.__class__(self.func, self.__set, fdel)

View File

@@ -0,0 +1,111 @@
"""Scheduling Utilities."""
from __future__ import annotations
from itertools import count
from .imports import symbol_by_name
__all__ = (
'FairCycle', 'priority_cycle', 'round_robin_cycle', 'sorted_cycle',
)
CYCLE_ALIASES = {
'priority': 'kombu.utils.scheduling:priority_cycle',
'round_robin': 'kombu.utils.scheduling:round_robin_cycle',
'sorted': 'kombu.utils.scheduling:sorted_cycle',
}
class FairCycle:
"""Cycle between resources.
Consume from a set of resources, where each resource gets
an equal chance to be consumed from.
Arguments:
---------
fun (Callable): Callback to call.
resources (Sequence[Any]): List of resources.
predicate (type): Exception predicate.
"""
def __init__(self, fun, resources, predicate=Exception):
self.fun = fun
self.resources = resources
self.predicate = predicate
self.pos = 0
def _next(self):
while 1:
try:
resource = self.resources[self.pos]
self.pos += 1
return resource
except IndexError:
self.pos = 0
if not self.resources:
raise self.predicate()
def get(self, callback, **kwargs):
"""Get from next resource."""
for tried in count(0): # for infinity
resource = self._next()
try:
return self.fun(resource, callback, **kwargs)
except self.predicate:
# reraise when retries exhausted.
if tried >= len(self.resources) - 1:
raise
def close(self):
"""Close cycle."""
def __repr__(self):
"""``repr(cycle)``."""
return '<FairCycle: {self.pos}/{size} {self.resources}>'.format(
self=self, size=len(self.resources))
class round_robin_cycle:
"""Iterator that cycles between items in round-robin."""
def __init__(self, it=None):
self.items = it if it is not None else []
def update(self, it):
"""Update items from iterable."""
self.items[:] = it
def consume(self, n):
"""Consume n items."""
return self.items[:n]
def rotate(self, last_used):
"""Move most recently used item to end of list."""
items = self.items
try:
items.append(items.pop(items.index(last_used)))
except ValueError:
pass
return last_used
class priority_cycle(round_robin_cycle):
"""Cycle that repeats items in order."""
def rotate(self, last_used):
"""Unused in this implementation."""
class sorted_cycle(priority_cycle):
"""Cycle in sorted order."""
def consume(self, n):
"""Consume n items."""
return sorted(self.items[:n])
def cycle_by_name(name):
"""Get cycle class by name."""
return symbol_by_name(name, CYCLE_ALIASES)

View File

@@ -0,0 +1,73 @@
"""Text Utilities."""
# flake8: noqa
from __future__ import annotations
from difflib import SequenceMatcher
from typing import Iterable, Iterator
from kombu import version_info_t
def escape_regex(p, white=''):
# type: (str, str) -> str
"""Escape string for use within a regular expression."""
# what's up with re.escape? that code must be neglected or something
return ''.join(c if c.isalnum() or c in white
else ('\\000' if c == '\000' else '\\' + c)
for c in p)
def fmatch_iter(needle: str, haystack: Iterable[str], min_ratio: float = 0.6) -> Iterator[tuple[float, str]]:
"""Fuzzy match: iteratively.
Yields
------
Tuple: of ratio and key.
"""
for key in haystack:
ratio = SequenceMatcher(None, needle, key).ratio()
if ratio >= min_ratio:
yield ratio, key
def fmatch_best(needle: str, haystack: Iterable[str], min_ratio: float = 0.6) -> str | None:
"""Fuzzy match - Find best match (scalar)."""
try:
return sorted(
fmatch_iter(needle, haystack, min_ratio), reverse=True,
)[0][1]
except IndexError:
return None
def version_string_as_tuple(s: str) -> version_info_t:
"""Convert version string to version info tuple."""
v = _unpack_version(*s.split('.'))
# X.Y.3a1 -> (X, Y, 3, 'a1')
if isinstance(v.micro, str):
v = version_info_t(v.major, v.minor, *_splitmicro(*v[2:]))
# X.Y.3a1-40 -> (X, Y, 3, 'a1', '40')
if not v.serial and v.releaselevel and '-' in v.releaselevel:
v = version_info_t(*list(v[0:3]) + v.releaselevel.split('-'))
return v
def _unpack_version(
major: str,
minor: str | int = 0,
micro: str | int = 0,
releaselevel: str = '',
serial: str = ''
) -> version_info_t:
return version_info_t(int(major), int(minor), micro, releaselevel, serial)
def _splitmicro(micro: str, releaselevel: str = '', serial: str = '') -> tuple[int, str, str]:
for index, char in enumerate(micro):
if not char.isdigit():
break
else:
return int(micro or 0), releaselevel, serial
return int(micro[:index]), micro[index:], serial

View File

@@ -0,0 +1,9 @@
"""Time Utilities."""
from __future__ import annotations
__all__ = ('maybe_s_to_ms',)
def maybe_s_to_ms(v: int | float | None) -> int | None:
"""Convert seconds to milliseconds, but return None for None."""
return int(float(v) * 1000.0) if v is not None else v

View File

@@ -0,0 +1,132 @@
"""URL Utilities."""
# flake8: noqa
from __future__ import annotations
from collections.abc import Mapping
from functools import partial
from typing import NamedTuple
from urllib.parse import parse_qsl, quote, unquote, urlparse
try:
import ssl
ssl_available = True
except ImportError: # pragma: no cover
ssl_available = False
from ..log import get_logger
safequote = partial(quote, safe='')
logger = get_logger(__name__)
class urlparts(NamedTuple):
"""Named tuple representing parts of the URL."""
scheme: str
hostname: str
port: int
username: str
password: str
path: str
query: Mapping
def parse_url(url):
# type: (str) -> Dict
"""Parse URL into mapping of components."""
scheme, host, port, user, password, path, query = _parse_url(url)
if query:
keys = [key for key in query.keys() if key.startswith('ssl_')]
for key in keys:
if key == "ssl_check_hostname":
query[key] = query[key].lower() != 'false'
elif key == 'ssl_cert_reqs':
query[key] = parse_ssl_cert_reqs(query[key])
if query[key] is None:
logger.warning('Defaulting to insecure SSL behaviour.')
if 'ssl' not in query:
query['ssl'] = {}
query['ssl'][key] = query[key]
del query[key]
return dict(transport=scheme, hostname=host,
port=port, userid=user,
password=password, virtual_host=path, **query)
def url_to_parts(url):
# type: (str) -> urlparts
"""Parse URL into :class:`urlparts` tuple of components."""
scheme = urlparse(url).scheme
schemeless = url[len(scheme) + 3:]
# parse with HTTP URL semantics
parts = urlparse('http://' + schemeless)
path = parts.path or ''
path = path[1:] if path and path[0] == '/' else path
return urlparts(
scheme,
unquote(parts.hostname or '') or None,
parts.port,
unquote(parts.username or '') or None,
unquote(parts.password or '') or None,
unquote(path or '') or None,
dict(parse_qsl(parts.query)),
)
_parse_url = url_to_parts
def as_url(scheme, host=None, port=None, user=None, password=None,
path=None, query=None, sanitize=False, mask='**'):
# type: (str, str, int, str, str, str, str, bool, str) -> str
"""Generate URL from component parts."""
parts = [f'{scheme}://']
if user or password:
if user:
parts.append(safequote(user))
if password:
if sanitize:
parts.extend([':', mask] if mask else [':'])
else:
parts.extend([':', safequote(password)])
parts.append('@')
parts.append(safequote(host) if host else '')
if port:
parts.extend([':', port])
parts.extend(['/', path])
return ''.join(str(part) for part in parts if part)
def sanitize_url(url, mask='**'):
# type: (str, str) -> str
"""Return copy of URL with password removed."""
return as_url(*_parse_url(url), sanitize=True, mask=mask)
def maybe_sanitize_url(url, mask='**'):
# type: (Any, str) -> Any
"""Sanitize url, or do nothing if url undefined."""
if isinstance(url, str) and '://' in url:
return sanitize_url(url, mask)
return url
def parse_ssl_cert_reqs(query_value):
# type: (str) -> Any
"""Given the query parameter for ssl_cert_reqs, return the SSL constant or None."""
if ssl_available:
query_value_to_constant = {
'CERT_REQUIRED': ssl.CERT_REQUIRED,
'CERT_OPTIONAL': ssl.CERT_OPTIONAL,
'CERT_NONE': ssl.CERT_NONE,
'required': ssl.CERT_REQUIRED,
'optional': ssl.CERT_OPTIONAL,
'none': ssl.CERT_NONE,
}
return query_value_to_constant[query_value]
else:
return None

View File

@@ -0,0 +1,15 @@
"""UUID utilities."""
from __future__ import annotations
from typing import Callable
from uuid import UUID, uuid4
def uuid(_uuid: Callable[[], UUID] = uuid4) -> str:
"""Generate unique id in UUID4 format.
See Also
--------
For now this is provided by :func:`uuid.uuid4`.
"""
return str(_uuid())