Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,37 @@
"""Utility functions.
Don't import from here directly anymore, as these are only
here for backwards compatibility.
"""
from kombu.utils.objects import cached_property
from kombu.utils.uuid import uuid
from .functional import chunks, memoize, noop
from .imports import gen_task_name, import_from_cwd, instantiate
from .imports import qualname as get_full_cls_name
from .imports import symbol_by_name as get_cls_by_name
# ------------------------------------------------------------------------ #
# > XXX Compat
from .log import LOG_LEVELS
from .nodenames import nodename, nodesplit, worker_direct
gen_unique_id = uuid
__all__ = (
'LOG_LEVELS',
'cached_property',
'chunks',
    'gen_task_name',
'gen_unique_id',
'get_cls_by_name',
'get_full_cls_name',
'import_from_cwd',
'instantiate',
'memoize',
'nodename',
'nodesplit',
'noop',
'uuid',
'worker_direct'
)
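
This module is only a compatibility shim; new code should import from the canonical locations. A minimal illustration (both imports refer to the same objects, per the re-exports above):

# old (still works through this shim):
from celery.utils import uuid, cached_property

# new (preferred canonical locations):
from kombu.utils.uuid import uuid
from kombu.utils.objects import cached_property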

View File

@@ -0,0 +1,146 @@
"""Abstract classes."""
from abc import ABCMeta, abstractmethod
from collections.abc import Callable
__all__ = ('CallableTask', 'CallableSignature')
def _hasattr(C, attr):
return any(attr in B.__dict__ for B in C.__mro__)
class _AbstractClass(metaclass=ABCMeta):
__required_attributes__ = frozenset()
@classmethod
def _subclasshook_using(cls, parent, C):
return (
cls is parent and
all(_hasattr(C, attr) for attr in cls.__required_attributes__)
) or NotImplemented
@classmethod
def register(cls, other):
# we override `register` to return other for use as a decorator.
type(cls).register(cls, other)
return other
class CallableTask(_AbstractClass, Callable): # pragma: no cover
"""Task interface."""
__required_attributes__ = frozenset({
'delay', 'apply_async', 'apply',
})
@abstractmethod
def delay(self, *args, **kwargs):
pass
@abstractmethod
def apply_async(self, *args, **kwargs):
pass
@abstractmethod
def apply(self, *args, **kwargs):
pass
@classmethod
def __subclasshook__(cls, C):
return cls._subclasshook_using(CallableTask, C)
class CallableSignature(CallableTask): # pragma: no cover
"""Celery Signature interface."""
__required_attributes__ = frozenset({
'clone', 'freeze', 'set', 'link', 'link_error', '__or__',
})
@property
@abstractmethod
def name(self):
pass
@property
@abstractmethod
def type(self):
pass
@property
@abstractmethod
def app(self):
pass
@property
@abstractmethod
def id(self):
pass
@property
@abstractmethod
def task(self):
pass
@property
@abstractmethod
def args(self):
pass
@property
@abstractmethod
def kwargs(self):
pass
@property
@abstractmethod
def options(self):
pass
@property
@abstractmethod
def subtask_type(self):
pass
@property
@abstractmethod
def chord_size(self):
pass
@property
@abstractmethod
def immutable(self):
pass
@abstractmethod
def clone(self, args=None, kwargs=None):
pass
@abstractmethod
def freeze(self, id=None, group_id=None, chord=None, root_id=None,
group_index=None):
pass
@abstractmethod
def set(self, immutable=None, **options):
pass
@abstractmethod
def link(self, callback):
pass
@abstractmethod
def link_error(self, errback):
pass
@abstractmethod
def __or__(self, other):
pass
@abstractmethod
def __invert__(self):
pass
@classmethod
def __subclasshook__(cls, C):
return cls._subclasshook_using(CallableSignature, C)
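
Because ``__subclasshook__`` delegates to ``_subclasshook_using``, these ABCs perform structural (duck-typed) checks: any class defining the names in ``__required_attributes__`` passes ``issubclass``/``isinstance`` without inheriting from the ABC. A minimal sketch (``DuckTask`` is hypothetical, for illustration only):

from celery.utils.abstract import CallableTask

class DuckTask:
    def __call__(self, *args, **kwargs): ...
    def delay(self, *args, **kwargs): ...
    def apply_async(self, *args, **kwargs): ...
    def apply(self, *args, **kwargs): ...

# Passes the structural check without subclassing CallableTask.
assert issubclass(DuckTask, CallableTask)
assert isinstance(DuckTask(), CallableTask)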

View File

@@ -0,0 +1,864 @@
"""Custom maps, sets, sequences, and other data structures."""
import time
from collections import OrderedDict as _OrderedDict
from collections import deque
from collections.abc import Callable, Mapping, MutableMapping, MutableSet, Sequence
from heapq import heapify, heappop, heappush
from itertools import chain, count
from queue import Empty
from typing import Any, Dict, Iterable, List # noqa
from .functional import first, uniq
from .text import match_case
try:
# pypy: dicts are ordered in recent versions
from __pypy__ import reversed_dict as _dict_is_ordered
except ImportError:
_dict_is_ordered = None
try:
from django.utils.functional import LazyObject, LazySettings
except ImportError:
class LazyObject:
pass
LazySettings = LazyObject
__all__ = (
'AttributeDictMixin', 'AttributeDict', 'BufferMap', 'ChainMap',
'ConfigurationView', 'DictAttribute', 'Evictable',
'LimitedSet', 'Messagebuffer', 'OrderedDict',
'force_mapping', 'lpmerge',
)
REPR_LIMITED_SET = """\
<{name}({size}): maxlen={0.maxlen}, expires={0.expires}, minlen={0.minlen}>\
"""
def force_mapping(m):
# type: (Any) -> Mapping
"""Wrap object into supporting the mapping interface if necessary."""
if isinstance(m, (LazyObject, LazySettings)):
m = m._wrapped
return DictAttribute(m) if not isinstance(m, Mapping) else m
def lpmerge(L, R):
# type: (Mapping, Mapping) -> Mapping
"""In place left precedent dictionary merge.
Keeps values from `L`, if the value in `R` is :const:`None`.
"""
setitem = L.__setitem__
[setitem(k, v) for k, v in R.items() if v is not None]
return L
class OrderedDict(_OrderedDict):
"""Dict where insertion order matters."""
def _LRUkey(self):
# type: () -> Any
# return value of od.keys does not support __next__,
# but this version will also not create a copy of the list.
return next(iter(self.keys()))
if not hasattr(_OrderedDict, 'move_to_end'):
if _dict_is_ordered: # pragma: no cover
def move_to_end(self, key, last=True):
# type: (Any, bool) -> None
if not last:
# we don't use this argument, and the only way to
# implement this on PyPy seems to be O(n): creating a
# copy with the order changed, so we just raise.
                    raise NotImplementedError('no last=False on PyPy')
self[key] = self.pop(key)
else:
def move_to_end(self, key, last=True):
# type: (Any, bool) -> None
link = self._OrderedDict__map[key]
link_prev = link[0]
link_next = link[1]
link_prev[1] = link_next
link_next[0] = link_prev
root = self._OrderedDict__root
if last:
last = root[0]
link[0] = last
link[1] = root
last[1] = root[0] = link
else:
first_node = root[1]
link[0] = root
link[1] = first_node
root[1] = first_node[0] = link
class AttributeDictMixin:
"""Mixin for Mapping interface that adds attribute access.
    I.e., `d.key -> d[key]`.
"""
def __getattr__(self, k):
# type: (str) -> Any
"""`d.key -> d[key]`."""
try:
return self[k]
except KeyError:
raise AttributeError(
f'{type(self).__name__!r} object has no attribute {k!r}')
def __setattr__(self, key: str, value) -> None:
"""`d[key] = value -> d.key = value`."""
self[key] = value
class AttributeDict(dict, AttributeDictMixin):
"""Dict subclass with attribute access."""
class DictAttribute:
"""Dict interface to attributes.
`obj[k] -> obj.k`
`obj[k] = val -> obj.k = val`
"""
obj = None
def __init__(self, obj):
# type: (Any) -> None
object.__setattr__(self, 'obj', obj)
def __getattr__(self, key):
# type: (Any) -> Any
return getattr(self.obj, key)
def __setattr__(self, key, value):
# type: (Any, Any) -> None
return setattr(self.obj, key, value)
def get(self, key, default=None):
# type: (Any, Any) -> Any
try:
return self[key]
except KeyError:
return default
def setdefault(self, key, default=None):
# type: (Any, Any) -> None
if key not in self:
self[key] = default
def __getitem__(self, key):
# type: (Any) -> Any
try:
return getattr(self.obj, key)
except AttributeError:
raise KeyError(key)
def __setitem__(self, key, value):
# type: (Any, Any) -> Any
setattr(self.obj, key, value)
def __contains__(self, key):
# type: (Any) -> bool
return hasattr(self.obj, key)
def _iterate_keys(self):
# type: () -> Iterable
return iter(dir(self.obj))
iterkeys = _iterate_keys
def __iter__(self):
# type: () -> Iterable
return self._iterate_keys()
def _iterate_items(self):
# type: () -> Iterable
for key in self._iterate_keys():
yield key, getattr(self.obj, key)
iteritems = _iterate_items
def _iterate_values(self):
# type: () -> Iterable
for key in self._iterate_keys():
yield getattr(self.obj, key)
itervalues = _iterate_values
items = _iterate_items
keys = _iterate_keys
values = _iterate_values
MutableMapping.register(DictAttribute)
class ChainMap(MutableMapping):
"""Key lookup on a sequence of maps."""
key_t = None
changes = None
defaults = None
maps = None
_observers = ()
def __init__(self, *maps, **kwargs):
# type: (*Mapping, **Any) -> None
maps = list(maps or [{}])
self.__dict__.update(
key_t=kwargs.get('key_t'),
maps=maps,
changes=maps[0],
defaults=maps[1:],
_observers=[],
)
def add_defaults(self, d):
# type: (Mapping) -> None
d = force_mapping(d)
self.defaults.insert(0, d)
self.maps.insert(1, d)
def pop(self, key, *default):
# type: (Any, *Any) -> Any
try:
return self.maps[0].pop(key, *default)
except KeyError:
raise KeyError(
f'Key not found in the first mapping: {key!r}')
def __missing__(self, key):
# type: (Any) -> Any
raise KeyError(key)
def _key(self, key):
# type: (Any) -> Any
return self.key_t(key) if self.key_t is not None else key
def __getitem__(self, key):
# type: (Any) -> Any
_key = self._key(key)
for mapping in self.maps:
try:
return mapping[_key]
except KeyError:
pass
return self.__missing__(key)
def __setitem__(self, key, value):
# type: (Any, Any) -> None
self.changes[self._key(key)] = value
def __delitem__(self, key):
# type: (Any) -> None
try:
del self.changes[self._key(key)]
except KeyError:
raise KeyError(f'Key not found in first mapping: {key!r}')
def clear(self):
# type: () -> None
self.changes.clear()
def get(self, key, default=None):
# type: (Any, Any) -> Any
try:
return self[self._key(key)]
except KeyError:
return default
def __len__(self):
# type: () -> int
return len(set().union(*self.maps))
def __iter__(self):
return self._iterate_keys()
def __contains__(self, key):
# type: (Any) -> bool
key = self._key(key)
return any(key in m for m in self.maps)
def __bool__(self):
# type: () -> bool
return any(self.maps)
__nonzero__ = __bool__ # Py2
def setdefault(self, key, default=None):
# type: (Any, Any) -> None
key = self._key(key)
if key not in self:
self[key] = default
def update(self, *args, **kwargs):
# type: (*Any, **Any) -> Any
result = self.changes.update(*args, **kwargs)
for callback in self._observers:
callback(*args, **kwargs)
return result
def __repr__(self):
# type: () -> str
return '{0.__class__.__name__}({1})'.format(
self, ', '.join(map(repr, self.maps)))
@classmethod
def fromkeys(cls, iterable, *args):
# type: (type, Iterable, *Any) -> 'ChainMap'
"""Create a ChainMap with a single dict created from the iterable."""
return cls(dict.fromkeys(iterable, *args))
def copy(self):
# type: () -> 'ChainMap'
return self.__class__(self.maps[0].copy(), *self.maps[1:])
__copy__ = copy # Py2
def _iter(self, op):
# type: (Callable) -> Iterable
# defaults must be first in the stream, so values in
# changes take precedence.
# pylint: disable=bad-reversed-sequence
# Someone should teach pylint about properties.
return chain(*(op(d) for d in reversed(self.maps)))
def _iterate_keys(self):
# type: () -> Iterable
return uniq(self._iter(lambda d: d.keys()))
iterkeys = _iterate_keys
def _iterate_items(self):
# type: () -> Iterable
return ((key, self[key]) for key in self)
iteritems = _iterate_items
def _iterate_values(self):
# type: () -> Iterable
return (self[key] for key in self)
itervalues = _iterate_values
def bind_to(self, callback):
self._observers.append(callback)
keys = _iterate_keys
items = _iterate_items
values = _iterate_values
class ConfigurationView(ChainMap, AttributeDictMixin):
"""A view over an applications configuration dictionaries.
Custom (but older) version of :class:`collections.ChainMap`.
If the key does not exist in ``changes``, the ``defaults``
dictionaries are consulted.
Arguments:
changes (Mapping): Map of configuration changes.
defaults (List[Mapping]): List of dictionaries containing
the default configuration.
"""
def __init__(self, changes, defaults=None, keys=None, prefix=None):
# type: (Mapping, Mapping, List[str], str) -> None
defaults = [] if defaults is None else defaults
super().__init__(changes, *defaults)
self.__dict__.update(
prefix=prefix.rstrip('_') + '_' if prefix else prefix,
_keys=keys,
)
def _to_keys(self, key):
# type: (str) -> Sequence[str]
prefix = self.prefix
if prefix:
pkey = prefix + key if not key.startswith(prefix) else key
return match_case(pkey, prefix), key
return key,
def __getitem__(self, key):
# type: (str) -> Any
keys = self._to_keys(key)
getitem = super().__getitem__
for k in keys + (
tuple(f(key) for f in self._keys) if self._keys else ()):
try:
return getitem(k)
except KeyError:
pass
try:
# support subclasses implementing __missing__
return self.__missing__(key)
except KeyError:
if len(keys) > 1:
                raise KeyError(
                    'Key not found: {1!r} (with prefix: {0!r})'.format(*keys))
raise
def __setitem__(self, key, value):
# type: (str, Any) -> Any
self.changes[self._key(key)] = value
def first(self, *keys):
# type: (*str) -> Any
return first(None, (self.get(key) for key in keys))
def get(self, key, default=None):
# type: (str, Any) -> Any
try:
return self[key]
except KeyError:
return default
def clear(self):
# type: () -> None
"""Remove all changes, but keep defaults."""
self.changes.clear()
def __contains__(self, key):
# type: (str) -> bool
keys = self._to_keys(key)
return any(any(k in m for k in keys) for m in self.maps)
def swap_with(self, other):
# type: (ConfigurationView) -> None
changes = other.__dict__['changes']
defaults = other.__dict__['defaults']
self.__dict__.update(
changes=changes,
defaults=defaults,
key_t=other.__dict__['key_t'],
prefix=other.__dict__['prefix'],
maps=[changes] + defaults
)
class LimitedSet:
"""Kind-of Set (or priority queue) with limitations.
Good for when you need to test for membership (`a in set`),
but the set should not grow unbounded.
``maxlen`` is enforced at all times, so if the limit is reached
we'll also remove non-expired items.
You can also configure ``minlen``: this is the minimal residual size
of the set.
All arguments are optional, and no limits are enabled by default.
Arguments:
maxlen (int): Optional max number of items.
Adding more items than ``maxlen`` will result in immediate
removal of items sorted by oldest insertion time.
expires (float): TTL for all items.
Expired items are purged as keys are inserted.
minlen (int): Minimal residual size of this set.
.. versionadded:: 4.0
Value must be less than ``maxlen`` if both are configured.
Older expired items will be deleted, only after the set
exceeds ``minlen`` number of items.
data (Sequence): Initial data to initialize set with.
Can be an iterable of ``(key, value)`` pairs,
a dict (``{key: insertion_time}``), or another instance
of :class:`LimitedSet`.
Example:
>>> s = LimitedSet(maxlen=50000, expires=3600, minlen=4000)
>>> for i in range(60000):
... s.add(i)
... s.add(str(i))
...
>>> 57000 in s # last 50k inserted values are kept
True
        >>> '10' in s  # '10' was evicted when maxlen was reached.
False
>>> len(s) # maxlen is reached
50000
>>> s.purge(now=time.monotonic() + 7200) # clock + 2 hours
>>> len(s) # now only minlen items are cached
4000
        >>> 57000 in s  # even this item is gone now
False
"""
max_heap_percent_overload = 15
def __init__(self, maxlen=0, expires=0, data=None, minlen=0):
# type: (int, float, Mapping, int) -> None
self.maxlen = 0 if maxlen is None else maxlen
self.minlen = 0 if minlen is None else minlen
self.expires = 0 if expires is None else expires
self._data = {}
self._heap = []
if data:
# import items from data
self.update(data)
if not self.maxlen >= self.minlen >= 0:
raise ValueError(
                'minlen must be a positive number, less than or equal to maxlen.')
if self.expires < 0:
raise ValueError('expires cannot be negative!')
def _refresh_heap(self):
# type: () -> None
"""Time consuming recreating of heap. Don't run this too often."""
self._heap[:] = [entry for entry in self._data.values()]
heapify(self._heap)
def _maybe_refresh_heap(self):
# type: () -> None
if self._heap_overload >= self.max_heap_percent_overload:
self._refresh_heap()
def clear(self):
# type: () -> None
"""Clear all data, start from scratch again."""
self._data.clear()
self._heap[:] = []
def add(self, item, now=None):
# type: (Any, float) -> None
"""Add a new item, or reset the expiry time of an existing item."""
now = now or time.monotonic()
if item in self._data:
self.discard(item)
entry = (now, item)
self._data[item] = entry
heappush(self._heap, entry)
if self.maxlen and len(self._data) >= self.maxlen:
self.purge()
def update(self, other):
# type: (Iterable) -> None
"""Update this set from other LimitedSet, dict or iterable."""
if not other:
return
if isinstance(other, LimitedSet):
self._data.update(other._data)
self._refresh_heap()
self.purge()
elif isinstance(other, dict):
# revokes are sent as a dict
for key, inserted in other.items():
if isinstance(inserted, (tuple, list)):
# in case someone uses ._data directly for sending update
inserted = inserted[0]
if not isinstance(inserted, float):
raise ValueError(
'Expecting float timestamp, got type '
f'{type(inserted)!r} with value: {inserted}')
self.add(key, inserted)
else:
# XXX AVOID THIS, it could keep old data if more parties
# exchange them all over and over again
for obj in other:
self.add(obj)
def discard(self, item):
# type: (Any) -> None
        # Mark an existing item as removed. If the item is not present,
        # do nothing.
self._data.pop(item, None)
self._maybe_refresh_heap()
pop_value = discard
def purge(self, now=None):
# type: (float) -> None
"""Check oldest items and remove them if needed.
Arguments:
now (float): Time of purging -- by default right now.
This can be useful for unit testing.
"""
now = now or time.monotonic()
now = now() if isinstance(now, Callable) else now
if self.maxlen:
while len(self._data) > self.maxlen:
self.pop()
# time based expiring:
if self.expires:
while len(self._data) > self.minlen >= 0:
inserted_time, _ = self._heap[0]
if inserted_time + self.expires > now:
break # oldest item hasn't expired yet
self.pop()
def pop(self, default=None) -> Any:
# type: (Any) -> Any
"""Remove and return the oldest item, or :const:`None` when empty."""
while self._heap:
_, item = heappop(self._heap)
try:
self._data.pop(item)
except KeyError:
pass
else:
return item
return default
def as_dict(self):
# type: () -> Dict
"""Whole set as serializable dictionary.
Example:
>>> s = LimitedSet(maxlen=200)
>>> r = LimitedSet(maxlen=200)
>>> for i in range(500):
... s.add(i)
...
>>> r.update(s.as_dict())
>>> r == s
True
"""
return {key: inserted for inserted, key in self._data.values()}
def __eq__(self, other):
# type: (Any) -> bool
return self._data == other._data
def __repr__(self):
# type: () -> str
return REPR_LIMITED_SET.format(
self, name=type(self).__name__, size=len(self),
)
def __iter__(self):
# type: () -> Iterable
return (i for _, i in sorted(self._data.values()))
def __len__(self):
# type: () -> int
return len(self._data)
def __contains__(self, key):
# type: (Any) -> bool
return key in self._data
def __reduce__(self):
# type: () -> Any
return self.__class__, (
self.maxlen, self.expires, self.as_dict(), self.minlen)
def __bool__(self):
# type: () -> bool
return bool(self._data)
__nonzero__ = __bool__ # Py2
@property
def _heap_overload(self):
# type: () -> float
"""Compute how much is heap bigger than data [percents]."""
return len(self._heap) * 100 / max(len(self._data), 1) - 100
MutableSet.register(LimitedSet)
class Evictable:
"""Mixin for classes supporting the ``evict`` method."""
Empty = Empty
def evict(self) -> None:
"""Force evict until maxsize is enforced."""
self._evict(range=count)
def _evict(self, limit: int = 100, range=range) -> None:
        try:
            for _ in range(limit):
                self._evict1()
        except IndexError:
            pass
def _evict1(self) -> None:
if self._evictcount <= self.maxsize:
raise IndexError()
try:
self._pop_to_evict()
except self.Empty:
raise IndexError()
class Messagebuffer(Evictable):
"""A buffer of pending messages."""
Empty = Empty
def __init__(self, maxsize, iterable=None, deque=deque):
# type: (int, Iterable, Any) -> None
self.maxsize = maxsize
self.data = deque(iterable or [])
self._append = self.data.append
self._pop = self.data.popleft
self._len = self.data.__len__
self._extend = self.data.extend
def put(self, item):
# type: (Any) -> None
self._append(item)
self.maxsize and self._evict()
def extend(self, it):
# type: (Iterable) -> None
self._extend(it)
self.maxsize and self._evict()
def take(self, *default):
# type: (*Any) -> Any
try:
return self._pop()
except IndexError:
if default:
return default[0]
raise self.Empty()
def _pop_to_evict(self):
# type: () -> None
return self.take()
def __repr__(self):
# type: () -> str
return f'<{type(self).__name__}: {len(self)}/{self.maxsize}>'
def __iter__(self):
# type: () -> Iterable
while 1:
try:
yield self._pop()
except IndexError:
break
def __len__(self):
# type: () -> int
return self._len()
def __contains__(self, item) -> bool:
return item in self.data
def __reversed__(self):
# type: () -> Iterable
return reversed(self.data)
def __getitem__(self, index):
# type: (Any) -> Any
return self.data[index]
@property
def _evictcount(self):
# type: () -> int
return len(self)
Sequence.register(Messagebuffer)
class BufferMap(OrderedDict, Evictable):
"""Map of buffers."""
Buffer = Messagebuffer
Empty = Empty
maxsize = None
total = 0
bufmaxsize = None
def __init__(self, maxsize, iterable=None, bufmaxsize=1000):
# type: (int, Iterable, int) -> None
super().__init__()
self.maxsize = maxsize
        self.bufmaxsize = bufmaxsize
if iterable:
self.update(iterable)
        self.total = sum(len(buf) for buf in self.values())
def put(self, key, item):
# type: (Any, Any) -> None
self._get_or_create_buffer(key).put(item)
self.total += 1
self.move_to_end(key) # least recently used.
self.maxsize and self._evict()
def extend(self, key, it):
# type: (Any, Iterable) -> None
self._get_or_create_buffer(key).extend(it)
self.total += len(it)
self.maxsize and self._evict()
def take(self, key, *default):
# type: (Any, *Any) -> Any
item, throw = None, False
try:
buf = self[key]
except KeyError:
throw = True
else:
try:
item = buf.take()
self.total -= 1
except self.Empty:
throw = True
else:
self.move_to_end(key) # mark as LRU
if throw:
if default:
return default[0]
raise self.Empty()
return item
def _get_or_create_buffer(self, key):
# type: (Any) -> Messagebuffer
try:
return self[key]
except KeyError:
buf = self[key] = self._new_buffer()
return buf
def _new_buffer(self):
# type: () -> Messagebuffer
return self.Buffer(maxsize=self.bufmaxsize)
def _LRUpop(self, *default):
# type: (*Any) -> Any
return self[self._LRUkey()].take(*default)
def _pop_to_evict(self):
# type: () -> None
for _ in range(100):
key = self._LRUkey()
buf = self[key]
try:
buf.take()
except (IndexError, self.Empty):
# buffer empty, remove it from mapping.
self.pop(key)
else:
# we removed one item
self.total -= 1
# if buffer is empty now, remove it from mapping.
if not len(buf):
self.pop(key)
else:
# move to least recently used.
self.move_to_end(key)
break
def __repr__(self):
# type: () -> str
return f'<{type(self).__name__}: {self.total}/{self.maxsize}>'
@property
def _evictcount(self):
# type: () -> int
return self.total
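
A short sketch of how ``Messagebuffer`` and ``BufferMap`` behave, based on the code above (the import path assumes the upstream Celery layout). Note that iterating a ``Messagebuffer`` pops items, so iteration drains the buffer:

from celery.utils.collections import BufferMap, Messagebuffer

buf = Messagebuffer(maxsize=3)
for i in range(5):
    buf.put(i)
# put() evicted the oldest items to enforce maxsize, and list()
# drains the buffer via the destructive __iter__.
assert list(buf) == [2, 3, 4]

bufmap = BufferMap(maxsize=10)
bufmap.put('queue-a', 'message')
assert bufmap.take('queue-a') == 'message'
assert bufmap.take('queue-a', 'default') == 'default'  # empty -> default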

View File

@@ -0,0 +1,193 @@
"""Utilities for debugging memory usage, blocking calls, etc."""
import os
import sys
import traceback
from contextlib import contextmanager
from functools import partial
from pprint import pprint
from celery.platforms import signals
from celery.utils.text import WhateverIO
try:
from psutil import Process
except ImportError:
Process = None
__all__ = (
'blockdetection', 'sample_mem', 'memdump', 'sample',
'humanbytes', 'mem_rss', 'ps', 'cry',
)
UNITS = (
(2 ** 40.0, 'TB'),
(2 ** 30.0, 'GB'),
(2 ** 20.0, 'MB'),
(2 ** 10.0, 'KB'),
(0.0, 'b'),
)
_process = None
_mem_sample = []
def _on_blocking(signum, frame):
import inspect
raise RuntimeError(
f'Blocking detection timed-out at: {inspect.getframeinfo(frame)}'
)
@contextmanager
def blockdetection(timeout):
"""Context that raises an exception if process is blocking.
Uses ``SIGALRM`` to detect blocking functions.
"""
if not timeout:
yield
else:
old_handler = signals['ALRM']
old_handler = None if old_handler == _on_blocking else old_handler
signals['ALRM'] = _on_blocking
try:
yield signals.arm_alarm(timeout)
finally:
if old_handler:
signals['ALRM'] = old_handler
signals.reset_alarm()
def sample_mem():
"""Sample RSS memory usage.
Statistics can then be output by calling :func:`memdump`.
"""
current_rss = mem_rss()
_mem_sample.append(current_rss)
return current_rss
def _memdump(samples=10): # pragma: no cover
S = _mem_sample
prev = list(S) if len(S) <= samples else sample(S, samples)
_mem_sample[:] = []
import gc
gc.collect()
after_collect = mem_rss()
return prev, after_collect
def memdump(samples=10, file=None): # pragma: no cover
"""Dump memory statistics.
Will print a sample of all RSS memory samples added by
calling :func:`sample_mem`, and in addition print
used RSS memory after :func:`gc.collect`.
"""
say = partial(print, file=file)
if ps() is None:
say('- rss: (psutil not installed).')
return
prev, after_collect = _memdump(samples)
if prev:
say('- rss (sample):')
for mem in prev:
say(f'- > {mem},')
say(f'- rss (end): {after_collect}.')
def sample(x, n, k=0):
"""Given a list `x` a sample of length ``n`` of that list is returned.
For example, if `n` is 10, and `x` has 100 items, a list of every tenth.
item is returned.
``k`` can be used as offset.
"""
j = len(x) // n
for _ in range(n):
try:
yield x[k]
except IndexError:
break
k += j
def hfloat(f, p=5):
"""Convert float to value suitable for humans.
Arguments:
f (float): The floating point number.
p (int): Floating point precision (default is 5).
"""
i = int(f)
return i if i == f else '{0:.{p}}'.format(f, p=p)
def humanbytes(s):
"""Convert bytes to human-readable form (e.g., KB, MB)."""
return next(
f'{hfloat(s / div if div else s)}{unit}'
for div, unit in UNITS if s >= div
)
def mem_rss():
"""Return RSS memory usage as a humanized string."""
p = ps()
if p is not None:
return humanbytes(_process_memory_info(p).rss)
def ps(): # pragma: no cover
"""Return the global :class:`psutil.Process` instance.
Note:
Returns :const:`None` if :pypi:`psutil` is not installed.
"""
global _process
if _process is None and Process is not None:
_process = Process(os.getpid())
return _process
def _process_memory_info(process):
try:
return process.memory_info()
except AttributeError:
return process.get_memory_info()
def cry(out=None, sepchr='=', seplen=49): # pragma: no cover
"""Return stack-trace of all active threads.
See Also:
Taken from https://gist.github.com/737056.
"""
import threading
out = WhateverIO() if out is None else out
P = partial(print, file=out)
# get a map of threads by their ID so we can print their names
# during the traceback dump
tmap = {t.ident: t for t in threading.enumerate()}
sep = sepchr * seplen
for tid, frame in sys._current_frames().items():
thread = tmap.get(tid)
if not thread:
# skip old junk (left-overs from a fork)
continue
P(f'{thread.name}')
P(sep)
traceback.print_stack(frame, file=out)
P(sep)
P('LOCAL VARIABLES')
P(sep)
pprint(frame.f_locals, stream=out)
P('\n')
return out.getvalue()
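
A few concrete values for the helpers above, computed from the UNITS table and ``sample()`` as defined (the import path assumes the upstream Celery layout):

from celery.utils.debug import humanbytes, sample

assert humanbytes(2 * 2 ** 20) == '2MB'   # exactly 2 MiB
assert humanbytes(512) == '512b'          # below 1KB falls through to bytes
# sample() yields every len(x)//n-th item: here, every 10th of 100.
assert list(sample(list(range(100)), 10)) == [
    0, 10, 20, 30, 40, 50, 60, 70, 80, 90]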

View File

@@ -0,0 +1,113 @@
"""Deprecation utilities."""
import warnings
from vine.utils import wraps
from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
__all__ = ('Callable', 'Property', 'warn')
PENDING_DEPRECATION_FMT = """
{description} is scheduled for deprecation in \
version {deprecation} and removal in version v{removal}. \
{alternative}
"""
DEPRECATION_FMT = """
{description} is deprecated and scheduled for removal in
version {removal}. {alternative}
"""
def warn(description=None, deprecation=None,
removal=None, alternative=None, stacklevel=2):
"""Warn of (pending) deprecation."""
ctx = {'description': description,
'deprecation': deprecation, 'removal': removal,
'alternative': alternative}
if deprecation is not None:
w = CPendingDeprecationWarning(PENDING_DEPRECATION_FMT.format(**ctx))
else:
w = CDeprecationWarning(DEPRECATION_FMT.format(**ctx))
warnings.warn(w, stacklevel=stacklevel)
def Callable(deprecation=None, removal=None,
alternative=None, description=None):
"""Decorator for deprecated functions.
A deprecation warning will be emitted when the function is called.
Arguments:
        deprecation (str): Version that marks first deprecation. If this
            argument is set, a ``CPendingDeprecationWarning`` is emitted;
            otherwise a ``CDeprecationWarning`` is used.
removal (str): Future version when this feature will be removed.
alternative (str): Instructions for an alternative solution (if any).
description (str): Description of what's being deprecated.
"""
def _inner(fun):
@wraps(fun)
def __inner(*args, **kwargs):
from .imports import qualname
warn(description=description or qualname(fun),
deprecation=deprecation,
removal=removal,
alternative=alternative,
stacklevel=3)
return fun(*args, **kwargs)
return __inner
return _inner
def Property(deprecation=None, removal=None,
alternative=None, description=None):
"""Decorator for deprecated properties."""
def _inner(fun):
return _deprecated_property(
fun, deprecation=deprecation, removal=removal,
alternative=alternative, description=description or fun.__name__)
return _inner
class _deprecated_property:
def __init__(self, fget=None, fset=None, fdel=None, doc=None, **depreinfo):
self.__get = fget
self.__set = fset
self.__del = fdel
self.__name__, self.__module__, self.__doc__ = (
fget.__name__, fget.__module__, fget.__doc__,
)
self.depreinfo = depreinfo
self.depreinfo.setdefault('stacklevel', 3)
def __get__(self, obj, type=None):
if obj is None:
return self
warn(**self.depreinfo)
return self.__get(obj)
def __set__(self, obj, value):
if obj is None:
return self
if self.__set is None:
raise AttributeError('cannot set attribute')
warn(**self.depreinfo)
self.__set(obj, value)
def __delete__(self, obj):
if obj is None:
return self
if self.__del is None:
raise AttributeError('cannot delete attribute')
warn(**self.depreinfo)
self.__del(obj)
def setter(self, fset):
return self.__class__(self.__get, fset, self.__del, **self.depreinfo)
def deleter(self, fdel):
return self.__class__(self.__get, self.__set, fdel, **self.depreinfo)
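
Typical usage of the ``Callable`` decorator, as a sketch; the function and version numbers below are hypothetical:

from celery.utils import deprecated

@deprecated.Callable(deprecation='5.4', removal='6.0',
                     alternative='Use new_helper() instead.')
def old_helper():
    return 42

# Since `deprecation` is set, calling this emits CPendingDeprecationWarning.
old_helper()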

View File

@@ -0,0 +1,4 @@
"""Observer pattern."""
from .signal import Signal
__all__ = ('Signal',)

View File

@@ -0,0 +1,354 @@
"""Implementation of the Observer pattern."""
import sys
import threading
import warnings
import weakref
from weakref import WeakMethod
from kombu.utils.functional import retry_over_time
from celery.exceptions import CDeprecationWarning
from celery.local import PromiseProxy, Proxy
from celery.utils.functional import fun_accepts_kwargs
from celery.utils.log import get_logger
from celery.utils.time import humanize_seconds
__all__ = ('Signal',)
logger = get_logger(__name__)
def _make_id(target): # pragma: no cover
if isinstance(target, Proxy):
target = target._get_current_object()
if isinstance(target, (bytes, str)):
# see Issue #2475
return target
if hasattr(target, '__func__'):
return id(target.__func__)
return id(target)
def _boundmethod_safe_weakref(obj):
"""Get weakref constructor appropriate for `obj`. `obj` may be a bound method.
Bound method objects must be special-cased because they're usually garbage
collected immediately, even if the instance they're bound to persists.
Returns:
a (weakref constructor, main object) tuple. `weakref constructor` is
either :class:`weakref.ref` or :class:`weakref.WeakMethod`. `main
object` is the instance that `obj` is bound to if it is a bound method;
        otherwise `main object` is simply `obj`.
"""
try:
obj.__func__
obj.__self__
# Bound method
return WeakMethod, obj.__self__
except AttributeError:
# Not a bound method
return weakref.ref, obj
def _make_lookup_key(receiver, sender, dispatch_uid):
if dispatch_uid:
return (dispatch_uid, _make_id(sender))
else:
return (_make_id(receiver), _make_id(sender))
NONE_ID = _make_id(None)
NO_RECEIVERS = object()
RECEIVER_RETRY_ERROR = """\
Could not process signal receiver %(receiver)s. Retrying %(when)s...\
"""
class Signal: # pragma: no cover
"""Create new signal.
Keyword Arguments:
providing_args (List): A list of the arguments this signal can pass
along in a :meth:`send` call.
use_caching (bool): Enable receiver cache.
name (str): Name of signal, used for debugging purposes.
"""
#: Holds a dictionary of
#: ``{receiverkey (id): weakref(receiver)}`` mappings.
receivers = None
def __init__(self, providing_args=None, use_caching=False, name=None):
self.receivers = []
self.providing_args = set(
providing_args if providing_args is not None else [])
self.lock = threading.Lock()
self.use_caching = use_caching
self.name = name
# For convenience we create empty caches even if they are not used.
# A note about caching: if use_caching is defined, then for each
# distinct sender we cache the receivers that sender has in
# 'sender_receivers_cache'. The cache is cleaned when .connect() or
# .disconnect() is called and populated on .send().
self.sender_receivers_cache = (
weakref.WeakKeyDictionary() if use_caching else {}
)
self._dead_receivers = False
def _connect_proxy(self, fun, sender, weak, dispatch_uid):
return self.connect(
fun, sender=sender._get_current_object(),
weak=weak, dispatch_uid=dispatch_uid,
)
def connect(self, *args, **kwargs):
"""Connect receiver to sender for signal.
Arguments:
receiver (Callable): A function or an instance method which is to
receive signals. Receivers must be hashable objects.
            If `weak` is :const:`True`, then `receiver` must be
weak-referenceable.
Receivers must be able to accept keyword arguments.
If receivers have a `dispatch_uid` attribute, the receiver will
not be added if another receiver already exists with that
`dispatch_uid`.
sender (Any): The sender to which the receiver should respond.
Must either be a Python object, or :const:`None` to
receive events from any sender.
weak (bool): Whether to use weak references to the receiver.
By default, the module will attempt to use weak references to
the receiver objects. If this parameter is false, then strong
references will be used.
dispatch_uid (Hashable): An identifier used to uniquely identify a
particular instance of a receiver. This will usually be a
string, though it may be anything hashable.
retry (bool): If the signal receiver raises an exception
(e.g. ConnectionError), the receiver will be retried until it
runs successfully. A strong ref to the receiver will be stored
and the `weak` option will be ignored.
"""
def _handle_options(sender=None, weak=True, dispatch_uid=None,
retry=False):
def _connect_signal(fun):
options = {'dispatch_uid': dispatch_uid,
'weak': weak}
def _retry_receiver(retry_fun):
def _try_receiver_over_time(*args, **kwargs):
def on_error(exc, intervals, retries):
interval = next(intervals)
err_msg = RECEIVER_RETRY_ERROR % \
{'receiver': retry_fun,
'when': humanize_seconds(interval, 'in', ' ')}
logger.error(err_msg)
return interval
return retry_over_time(retry_fun, Exception, args,
kwargs, on_error)
return _try_receiver_over_time
if retry:
options['weak'] = False
if not dispatch_uid:
# if there's no dispatch_uid then we need to set the
# dispatch uid to the original func id so we can look
# it up later with the original func id
options['dispatch_uid'] = _make_id(fun)
fun = _retry_receiver(fun)
self._connect_signal(fun, sender, options['weak'],
options['dispatch_uid'])
return fun
return _connect_signal
if args and callable(args[0]):
return _handle_options(*args[1:], **kwargs)(args[0])
return _handle_options(*args, **kwargs)
def _connect_signal(self, receiver, sender, weak, dispatch_uid):
assert callable(receiver), 'Signal receivers must be callable'
if not fun_accepts_kwargs(receiver):
raise ValueError(
'Signal receiver must accept keyword arguments.')
if isinstance(sender, PromiseProxy):
sender.__then__(
self._connect_proxy, receiver, sender, weak, dispatch_uid,
)
return receiver
lookup_key = _make_lookup_key(receiver, sender, dispatch_uid)
if weak:
ref, receiver_object = _boundmethod_safe_weakref(receiver)
receiver = ref(receiver)
weakref.finalize(receiver_object, self._remove_receiver)
with self.lock:
self._clear_dead_receivers()
for r_key, _ in self.receivers:
if r_key == lookup_key:
break
else:
self.receivers.append((lookup_key, receiver))
self.sender_receivers_cache.clear()
return receiver
def disconnect(self, receiver=None, sender=None, weak=None,
dispatch_uid=None):
"""Disconnect receiver from sender for signal.
If weak references are used, disconnect needn't be called.
The receiver will be removed from dispatch automatically.
Arguments:
receiver (Callable): The registered receiver to disconnect.
May be none if `dispatch_uid` is specified.
sender (Any): The registered sender to disconnect.
weak (bool): The weakref state to disconnect.
dispatch_uid (Hashable): The unique identifier of the receiver
to disconnect.
"""
if weak is not None:
warnings.warn(
'Passing `weak` to disconnect has no effect.',
CDeprecationWarning, stacklevel=2)
lookup_key = _make_lookup_key(receiver, sender, dispatch_uid)
disconnected = False
with self.lock:
self._clear_dead_receivers()
for index in range(len(self.receivers)):
(r_key, _) = self.receivers[index]
if r_key == lookup_key:
disconnected = True
del self.receivers[index]
break
self.sender_receivers_cache.clear()
return disconnected
def has_listeners(self, sender=None):
return bool(self._live_receivers(sender))
def send(self, sender, **named):
"""Send signal from sender to all connected receivers.
If any receiver raises an error, the exception is returned as the
corresponding response. (This is different from the "send" in
Django signals. In Celery "send" and "send_robust" do the same thing.)
Arguments:
sender (Any): The sender of the signal.
Either a specific object or :const:`None`.
**named (Any): Named arguments which will be passed to receivers.
Returns:
List: of tuple pairs: `[(receiver, response), … ]`.
"""
responses = []
if not self.receivers or \
self.sender_receivers_cache.get(sender) is NO_RECEIVERS:
return responses
for receiver in self._live_receivers(sender):
try:
response = receiver(signal=self, sender=sender, **named)
except Exception as exc: # pylint: disable=broad-except
if not hasattr(exc, '__traceback__'):
exc.__traceback__ = sys.exc_info()[2]
logger.exception(
'Signal handler %r raised: %r', receiver, exc)
responses.append((receiver, exc))
else:
responses.append((receiver, response))
return responses
send_robust = send # Compat with Django interface.
def _clear_dead_receivers(self):
# Warning: caller is assumed to hold self.lock
if self._dead_receivers:
self._dead_receivers = False
new_receivers = []
for r in self.receivers:
if isinstance(r[1], weakref.ReferenceType) and r[1]() is None:
continue
new_receivers.append(r)
self.receivers = new_receivers
def _live_receivers(self, sender):
"""Filter sequence of receivers to get resolved, live receivers.
This checks for weak references and resolves them, then returning only
live receivers.
"""
receivers = None
if self.use_caching and not self._dead_receivers:
receivers = self.sender_receivers_cache.get(sender)
# We could end up here with NO_RECEIVERS even if we do check this
            # case in .send() prior to calling _live_receivers() due to a
            # concurrent .send() call.
if receivers is NO_RECEIVERS:
return []
if receivers is None:
with self.lock:
self._clear_dead_receivers()
senderkey = _make_id(sender)
receivers = []
for (receiverkey, r_senderkey), receiver in self.receivers:
if r_senderkey == NONE_ID or r_senderkey == senderkey:
receivers.append(receiver)
if self.use_caching:
if not receivers:
self.sender_receivers_cache[sender] = NO_RECEIVERS
else:
# Note: we must cache the weakref versions.
self.sender_receivers_cache[sender] = receivers
non_weak_receivers = []
for receiver in receivers:
if isinstance(receiver, weakref.ReferenceType):
# Dereference the weak reference.
receiver = receiver()
if receiver is not None:
non_weak_receivers.append(receiver)
else:
non_weak_receivers.append(receiver)
return non_weak_receivers
def _remove_receiver(self, receiver=None):
"""Remove dead receivers from connections."""
        # Mark that the self.receivers list has dead weakrefs. If so,
        # we will clean those up in connect, disconnect and _live_receivers
        # while holding self.lock. Note that doing the cleanup here isn't a
        # good idea, _remove_receiver() will be called as a side effect of
        # garbage collection, and so the call can happen while we are already
        # holding self.lock.
self._dead_receivers = True
def __repr__(self):
"""``repr(signal)``."""
return f'<{type(self).__name__}: {self.name} providing_args={self.providing_args!r}>'
def __str__(self):
"""``str(signal)``."""
return repr(self)
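
A minimal usage sketch, assuming a hypothetical signal and receiver; ``connect()`` works both as a plain call and, as here, as a decorator:

from celery.utils.dispatch import Signal

user_created = Signal(name='user_created', providing_args=['username'])

@user_created.connect
def on_user_created(sender=None, username=None, **kwargs):
    # Receivers must accept keyword arguments (send passes signal=, sender=).
    print(f'{sender!r} created {username!r}')

# Returns a list of (receiver, response) pairs; receiver exceptions are
# returned as responses rather than raised.
user_created.send(sender='signup-view', username='alice')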

View File

@@ -0,0 +1,402 @@
"""Functional-style utilities."""
import inspect
from collections import UserList
from functools import partial
from itertools import islice, tee, zip_longest
from typing import Any, Callable
from kombu.utils.functional import LRUCache, dictfilter, is_list, lazy, maybe_evaluate, maybe_list, memoize
from vine import promise
from celery.utils.log import get_logger
logger = get_logger(__name__)
__all__ = (
'LRUCache', 'is_list', 'maybe_list', 'memoize', 'mlazy', 'noop',
'first', 'firstmethod', 'chunks', 'padlist', 'mattrgetter', 'uniq',
'regen', 'dictfilter', 'lazy', 'maybe_evaluate', 'head_from_fun',
'maybe', 'fun_accepts_kwargs',
)
FUNHEAD_TEMPLATE = """
def {fun_name}({fun_args}):
return {fun_value}
"""
class DummyContext:
def __enter__(self):
return self
def __exit__(self, *exc_info):
pass
class mlazy(lazy):
"""Memoized lazy evaluation.
The function is only evaluated once, every subsequent access
will return the same value.
"""
#: Set to :const:`True` after the object has been evaluated.
evaluated = False
_value = None
def evaluate(self):
if not self.evaluated:
self._value = super().evaluate()
self.evaluated = True
return self._value
def noop(*args, **kwargs):
"""No operation.
Takes any arguments/keyword arguments and does nothing.
"""
def pass1(arg, *args, **kwargs):
"""Return the first positional argument."""
return arg
def evaluate_promises(it):
for value in it:
if isinstance(value, promise):
value = value()
yield value
def first(predicate, it):
"""Return the first element in ``it`` that ``predicate`` accepts.
If ``predicate`` is None it will return the first item that's not
:const:`None`.
"""
return next(
(v for v in evaluate_promises(it) if (
predicate(v) if predicate is not None else v is not None)),
None,
)
def firstmethod(method, on_call=None):
"""Multiple dispatch.
Return a function that with a list of instances,
finds the first instance that gives a value for the given method.
The list can also contain lazy instances
(:class:`~kombu.utils.functional.lazy`.)
"""
def _matcher(it, *args, **kwargs):
for obj in it:
try:
meth = getattr(maybe_evaluate(obj), method)
reply = (on_call(meth, *args, **kwargs) if on_call
else meth(*args, **kwargs))
except AttributeError:
pass
else:
if reply is not None:
return reply
return _matcher
def chunks(it, n):
"""Split an iterator into chunks with `n` elements each.
    Warning:
        ``it`` must be an actual iterator: passing a concrete
        sequence instead will get you repeating elements.
        So ``chunks(iter(range(1000)), 10)`` is fine, but
        ``chunks(range(1000), 10)`` is not.
Example:
# n == 2
>>> x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
>>> list(x)
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]]
# n == 3
>>> x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
>>> list(x)
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]]
"""
for item in it:
yield [item] + list(islice(it, n - 1))
def padlist(container, size, default=None):
"""Pad list with default elements.
Example:
>>> first, last, city = padlist(['George', 'Costanza', 'NYC'], 3)
('George', 'Costanza', 'NYC')
>>> first, last, city = padlist(['George', 'Costanza'], 3)
('George', 'Costanza', None)
>>> first, last, city, planet = padlist(
... ['George', 'Costanza', 'NYC'], 4, default='Earth',
... )
('George', 'Costanza', 'NYC', 'Earth')
"""
return list(container)[:size] + [default] * (size - len(container))
def mattrgetter(*attrs):
"""Get attributes, ignoring attribute errors.
Like :func:`operator.itemgetter` but return :const:`None` on missing
attributes instead of raising :exc:`AttributeError`.
"""
return lambda obj: {attr: getattr(obj, attr, None) for attr in attrs}
def uniq(it):
"""Return all unique elements in ``it``, preserving order."""
seen = set()
return (seen.add(obj) or obj for obj in it if obj not in seen)
def lookahead(it):
"""Yield pairs of (current, next) items in `it`.
`next` is None if `current` is the last item.
Example:
>>> list(lookahead(x for x in range(6)))
[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, None)]
"""
a, b = tee(it)
next(b, None)
return zip_longest(a, b)
def regen(it):
"""Convert iterator to an object that can be consumed multiple times.
    ``regen`` takes any iterable, and if the object is a
    generator it will cache the evaluated list on first access,
    so that the generator can be "consumed" multiple times.
"""
if isinstance(it, (list, tuple)):
return it
return _regen(it)
class _regen(UserList, list):
# must be subclass of list so that json can encode.
def __init__(self, it):
# pylint: disable=super-init-not-called
# UserList creates a new list and sets .data, so we don't
# want to call init here.
self.__it = it
self.__consumed = []
self.__done = False
def __reduce__(self):
return list, (self.data,)
def map(self, func):
self.__consumed = [func(el) for el in self.__consumed]
self.__it = map(func, self.__it)
def __length_hint__(self):
return self.__it.__length_hint__()
def __lookahead_consume(self, limit=None):
if not self.__done and (limit is None or limit > 0):
it = iter(self.__it)
try:
now = next(it)
except StopIteration:
return
self.__consumed.append(now)
# Maintain a single look-ahead to ensure we set `__done` when the
# underlying iterator gets exhausted
while not self.__done:
try:
next_ = next(it)
self.__consumed.append(next_)
except StopIteration:
self.__done = True
break
finally:
yield now
now = next_
# We can break out when `limit` is exhausted
if limit is not None:
limit -= 1
if limit <= 0:
break
def __iter__(self):
yield from self.__consumed
yield from self.__lookahead_consume()
def __getitem__(self, index):
if index < 0:
return self.data[index]
# Consume elements up to the desired index prior to attempting to
# access it from within `__consumed`
consume_count = index - len(self.__consumed) + 1
for _ in self.__lookahead_consume(limit=consume_count):
pass
return self.__consumed[index]
def __bool__(self):
if len(self.__consumed):
return True
try:
next(iter(self))
except StopIteration:
return False
else:
return True
@property
def data(self):
if not self.__done:
self.__consumed.extend(self.__it)
self.__done = True
return self.__consumed
def __repr__(self):
return "<{}: [{}{}]>".format(
self.__class__.__name__,
", ".join(repr(e) for e in self.__consumed),
"..." if not self.__done else "",
)
def _argsfromspec(spec, replace_defaults=True):
if spec.defaults:
split = len(spec.defaults)
defaults = (list(range(len(spec.defaults))) if replace_defaults
else spec.defaults)
positional = spec.args[:-split]
optional = list(zip(spec.args[-split:], defaults))
else:
positional, optional = spec.args, []
varargs = spec.varargs
varkw = spec.varkw
if spec.kwonlydefaults:
kwonlyargs = set(spec.kwonlyargs) - set(spec.kwonlydefaults.keys())
if replace_defaults:
kwonlyargs_optional = [
(kw, i) for i, kw in enumerate(spec.kwonlydefaults.keys())
]
else:
kwonlyargs_optional = list(spec.kwonlydefaults.items())
else:
kwonlyargs, kwonlyargs_optional = spec.kwonlyargs, []
return ', '.join(filter(None, [
', '.join(positional),
', '.join(f'{k}={v}' for k, v in optional),
f'*{varargs}' if varargs else None,
'*' if (kwonlyargs or kwonlyargs_optional) and not varargs else None,
', '.join(kwonlyargs) if kwonlyargs else None,
', '.join(f'{k}="{v}"' for k, v in kwonlyargs_optional),
f'**{varkw}' if varkw else None,
]))
def head_from_fun(fun: Callable[..., Any], bound: bool = False) -> Callable[..., Any]:
"""Generate signature function from actual function."""
# we could use inspect.Signature here, but that implementation
# is very slow since it implements the argument checking
# in pure-Python. Instead we use exec to create a new function
# with an empty body, meaning it has the same performance as
    # just calling a function.
is_function = inspect.isfunction(fun)
is_callable = callable(fun)
is_cython = fun.__class__.__name__ == 'cython_function_or_method'
is_method = inspect.ismethod(fun)
if not is_function and is_callable and not is_method and not is_cython:
name, fun = fun.__class__.__name__, fun.__call__
else:
name = fun.__name__
definition = FUNHEAD_TEMPLATE.format(
fun_name=name,
fun_args=_argsfromspec(inspect.getfullargspec(fun)),
fun_value=1,
)
logger.debug(definition)
namespace = {'__name__': fun.__module__}
# pylint: disable=exec-used
# Tasks are rarely, if ever, created at runtime - exec here is fine.
exec(definition, namespace)
result = namespace[name]
result._source = definition
if bound:
return partial(result, object())
return result
def arity_greater(fun, n):
argspec = inspect.getfullargspec(fun)
return argspec.varargs or len(argspec.args) > n
def fun_takes_argument(name, fun, position=None):
spec = inspect.getfullargspec(fun)
return (
spec.varkw or spec.varargs or
(len(spec.args) >= position if position else name in spec.args)
)
def fun_accepts_kwargs(fun):
"""Return true if function accepts arbitrary keyword arguments."""
return any(
p for p in inspect.signature(fun).parameters.values()
if p.kind == p.VAR_KEYWORD
)
def maybe(typ, val):
"""Call typ on value if val is defined."""
return typ(val) if val is not None else val
def seq_concat_item(seq, item):
"""Return copy of sequence seq with item added.
Returns:
Sequence: if seq is a tuple, the result will be a tuple,
otherwise it depends on the implementation of ``__add__``.
"""
return seq + (item,) if isinstance(seq, tuple) else seq + [item]
def seq_concat_seq(a, b):
"""Concatenate two sequences: ``a + b``.
Returns:
Sequence: The return value will depend on the largest sequence
- if b is larger and is a tuple, the return value will be a tuple.
- if a is larger and is a list, the return value will be a list,
"""
# find the type of the largest sequence
prefer = type(max([a, b], key=len))
# convert the smallest list to the type of the largest sequence.
if not isinstance(a, prefer):
a = prefer(a)
if not isinstance(b, prefer):
b = prefer(b)
return a + b
def is_numeric_value(value):
return isinstance(value, (int, float)) and not isinstance(value, bool)
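
A short sketch of ``regen`` from above: the first pass consumes the wrapped generator and caches the results, so later passes and indexing are served from the cache:

from celery.utils.functional import regen

g = regen(iter(range(5)))
assert list(g) == [0, 1, 2, 3, 4]    # first pass consumes and caches
assert list(g) == [0, 1, 2, 3, 4]    # second pass served from the cache
assert g[2] == 2                     # indexing also hits the cache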

View File

@@ -0,0 +1,309 @@
"""Dependency graph implementation."""
from collections import Counter
from textwrap import dedent
from kombu.utils.encoding import bytes_to_str, safe_str
__all__ = ('DOT', 'CycleError', 'DependencyGraph', 'GraphFormatter')
class DOT:
"""Constants related to the dot format."""
HEAD = dedent("""
{IN}{type} {id} {{
{INp}graph [{attrs}]
""")
ATTR = '{name}={value}'
NODE = '{INp}"{0}" [{attrs}]'
EDGE = '{INp}"{0}" {dir} "{1}" [{attrs}]'
ATTRSEP = ', '
DIRS = {'graph': '--', 'digraph': '->'}
TAIL = '{IN}}}'
class CycleError(Exception):
"""A cycle was detected in an acyclic graph."""
class DependencyGraph:
"""A directed acyclic graph of objects and their dependencies.
Supports a robust topological sort
to detect the order in which they must be handled.
Takes an optional iterator of ``(obj, dependencies)``
tuples to build the graph from.
Warning:
Does not support cycle detection.
"""
def __init__(self, it=None, formatter=None):
self.formatter = formatter or GraphFormatter()
self.adjacent = {}
if it is not None:
self.update(it)
def add_arc(self, obj):
"""Add an object to the graph."""
self.adjacent.setdefault(obj, [])
def add_edge(self, A, B):
"""Add an edge from object ``A`` to object ``B``.
I.e. ``A`` depends on ``B``.
"""
self[A].append(B)
def connect(self, graph):
"""Add nodes from another graph."""
self.adjacent.update(graph.adjacent)
def topsort(self):
"""Sort the graph topologically.
Returns:
List: of objects in the order in which they must be handled.
"""
graph = DependencyGraph()
components = self._tarjan72()
NC = {
node: component for component in components for node in component
}
for component in components:
graph.add_arc(component)
for node in self:
node_c = NC[node]
for successor in self[node]:
successor_c = NC[successor]
if node_c != successor_c:
graph.add_edge(node_c, successor_c)
return [t[0] for t in graph._khan62()]
def valency_of(self, obj):
"""Return the valency (degree) of a vertex in the graph."""
try:
l = [len(self[obj])]
except KeyError:
return 0
for node in self[obj]:
l.append(self.valency_of(node))
return sum(l)
def update(self, it):
"""Update graph with data from a list of ``(obj, deps)`` tuples."""
tups = list(it)
for obj, _ in tups:
self.add_arc(obj)
for obj, deps in tups:
for dep in deps:
self.add_edge(obj, dep)
def edges(self):
"""Return generator that yields for all edges in the graph."""
return (obj for obj, adj in self.items() if adj)
def _khan62(self):
"""Perform Khan's simple topological sort algorithm from '62.
See https://en.wikipedia.org/wiki/Topological_sorting
"""
count = Counter()
result = []
for node in self:
for successor in self[node]:
count[successor] += 1
ready = [node for node in self if not count[node]]
while ready:
node = ready.pop()
result.append(node)
for successor in self[node]:
count[successor] -= 1
if count[successor] == 0:
ready.append(successor)
result.reverse()
return result
def _tarjan72(self):
"""Perform Tarjan's algorithm to find strongly connected components.
See Also:
:wikipedia:`Tarjan%27s_strongly_connected_components_algorithm`
"""
result, stack, low = [], [], {}
def visit(node):
if node in low:
return
num = len(low)
low[node] = num
stack_pos = len(stack)
stack.append(node)
for successor in self[node]:
visit(successor)
low[node] = min(low[node], low[successor])
if num == low[node]:
component = tuple(stack[stack_pos:])
stack[stack_pos:] = []
result.append(component)
for item in component:
low[item] = len(self)
for node in self:
visit(node)
return result
def to_dot(self, fh, formatter=None):
"""Convert the graph to DOT format.
Arguments:
fh (IO): A file, or a file-like object to write the graph to.
formatter (celery.utils.graph.GraphFormatter): Custom graph
formatter to use.
"""
seen = set()
draw = formatter or self.formatter
def P(s):
print(bytes_to_str(s), file=fh)
def if_not_seen(fun, obj):
if draw.label(obj) not in seen:
P(fun(obj))
seen.add(draw.label(obj))
P(draw.head())
for obj, adjacent in self.items():
if not adjacent:
if_not_seen(draw.terminal_node, obj)
for req in adjacent:
if_not_seen(draw.node, obj)
P(draw.edge(obj, req))
P(draw.tail())
def format(self, obj):
return self.formatter(obj) if self.formatter else obj
def __iter__(self):
return iter(self.adjacent)
def __getitem__(self, node):
return self.adjacent[node]
def __len__(self):
return len(self.adjacent)
def __contains__(self, obj):
return obj in self.adjacent
def _iterate_items(self):
return self.adjacent.items()
items = iteritems = _iterate_items
def __repr__(self):
return '\n'.join(self.repr_node(N) for N in self)
def repr_node(self, obj, level=1, fmt='{0}({1})'):
output = [fmt.format(obj, self.valency_of(obj))]
if obj in self:
for other in self[obj]:
d = fmt.format(other, self.valency_of(other))
output.append(' ' * level + d)
output.extend(self.repr_node(other, level + 1).split('\n')[1:])
return '\n'.join(output)
class GraphFormatter:
"""Format dependency graphs."""
_attr = DOT.ATTR.strip()
_node = DOT.NODE.strip()
_edge = DOT.EDGE.strip()
_head = DOT.HEAD.strip()
_tail = DOT.TAIL.strip()
_attrsep = DOT.ATTRSEP
_dirs = dict(DOT.DIRS)
scheme = {
'shape': 'box',
'arrowhead': 'vee',
'style': 'filled',
'fontname': 'HelveticaNeue',
}
edge_scheme = {
'color': 'darkseagreen4',
'arrowcolor': 'black',
'arrowsize': 0.7,
}
node_scheme = {'fillcolor': 'palegreen3', 'color': 'palegreen4'}
term_scheme = {'fillcolor': 'palegreen1', 'color': 'palegreen2'}
graph_scheme = {'bgcolor': 'mintcream'}
def __init__(self, root=None, type=None, id=None,
indent=0, inw=' ' * 4, **scheme):
self.id = id or 'dependencies'
self.root = root
self.type = type or 'digraph'
self.direction = self._dirs[self.type]
self.IN = inw * (indent or 0)
self.INp = self.IN + inw
self.scheme = dict(self.scheme, **scheme)
self.graph_scheme = dict(self.graph_scheme, root=self.label(self.root))
def attr(self, name, value):
value = f'"{value}"'
return self.FMT(self._attr, name=name, value=value)
def attrs(self, d, scheme=None):
d = dict(self.scheme, **dict(scheme, **d or {}) if scheme else d)
return self._attrsep.join(
safe_str(self.attr(k, v)) for k, v in d.items()
)
def head(self, **attrs):
return self.FMT(
self._head, id=self.id, type=self.type,
attrs=self.attrs(attrs, self.graph_scheme),
)
def tail(self):
return self.FMT(self._tail)
def label(self, obj):
return obj
def node(self, obj, **attrs):
return self.draw_node(obj, self.node_scheme, attrs)
def terminal_node(self, obj, **attrs):
return self.draw_node(obj, self.term_scheme, attrs)
def edge(self, a, b, **attrs):
return self.draw_edge(a, b, **attrs)
def _enc(self, s):
return s.encode('utf-8', 'ignore')
def FMT(self, fmt, *args, **kwargs):
return self._enc(fmt.format(
*args, **dict(kwargs, IN=self.IN, INp=self.INp)
))
def draw_edge(self, a, b, scheme=None, attrs=None):
return self.FMT(
self._edge, self.label(a), self.label(b),
dir=self.direction, attrs=self.attrs(attrs, self.edge_scheme),
)
def draw_node(self, obj, scheme=None, attrs=None):
return self.FMT(
self._node, self.label(obj), attrs=self.attrs(attrs, scheme),
)
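# --- Editorial usage sketch (not part of the original module) -------------
# A minimal demo of the graph API above.  It assumes the enclosing class is
# ``DependencyGraph`` (as in ``celery.utils.graph``) and that its
# constructor accepts an iterable of ``(obj, dependencies)`` pairs.
if __name__ == '__main__':  # pragma: no cover
    import io
    graph = DependencyGraph([
        ('a', []),
        ('b', ['a']),
        ('c', ['a', 'b']),
    ])
    print(graph.topsort())   # a valid dependency order, e.g. ['a', 'b', 'c']
    buf = io.StringIO()
    graph.to_dot(buf)        # render the same graph in DOT format
    print(buf.getvalue())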

View File

@@ -0,0 +1,163 @@
"""Utilities related to importing modules and symbols by name."""
import os
import sys
import warnings
from contextlib import contextmanager
from importlib import import_module, reload
try:
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import entry_points
from kombu.utils.imports import symbol_by_name
#: Billiard sets this when execv is enabled.
#: We use it to find out the name of the original ``__main__``
#: module, so that we can properly rewrite the name of the
#: task to be that of ``App.main``.
MP_MAIN_FILE = os.environ.get('MP_MAIN_FILE')
__all__ = (
'NotAPackage', 'qualname', 'instantiate', 'symbol_by_name',
'cwd_in_path', 'find_module', 'import_from_cwd',
'reload_from_cwd', 'module_file', 'gen_task_name',
)
class NotAPackage(Exception):
"""Raised when importing a package, but it's not a package."""
def qualname(obj):
"""Return object name."""
if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
obj = obj.__class__
q = getattr(obj, '__qualname__', None)
if '.' not in q:
q = '.'.join((obj.__module__, q))
return q
def instantiate(name, *args, **kwargs):
"""Instantiate class by name.
See Also:
:func:`symbol_by_name`.
"""
return symbol_by_name(name)(*args, **kwargs)
@contextmanager
def cwd_in_path():
"""Context adding the current working directory to sys.path."""
cwd = os.getcwd()
if cwd in sys.path:
yield
else:
sys.path.insert(0, cwd)
try:
yield cwd
finally:
try:
sys.path.remove(cwd)
except ValueError: # pragma: no cover
pass
def find_module(module, path=None, imp=None):
"""Version of :func:`imp.find_module` supporting dots."""
if imp is None:
imp = import_module
with cwd_in_path():
try:
return imp(module)
except ImportError:
# Raise a more specific error if the problem is that one of the
# dot-separated segments of the module name is not a package.
if '.' in module:
parts = module.split('.')
for i, part in enumerate(parts[:-1]):
package = '.'.join(parts[:i + 1])
try:
mpart = imp(package)
except ImportError:
# Break out and re-raise the original ImportError
# instead.
break
try:
mpart.__path__
except AttributeError:
raise NotAPackage(package)
raise
def import_from_cwd(module, imp=None, package=None):
"""Import module, temporarily including modules in the current directory.
Modules located in the current directory have
precedence over modules located in `sys.path`.
"""
if imp is None:
imp = import_module
with cwd_in_path():
return imp(module, package=package)
def reload_from_cwd(module, reloader=None):
"""Reload module (ensuring that CWD is in sys.path)."""
if reloader is None:
reloader = reload
with cwd_in_path():
return reloader(module)
def module_file(module):
"""Return the correct original file name of a module."""
name = module.__file__
return name[:-1] if name.endswith('.pyc') else name
def gen_task_name(app, name, module_name):
"""Generate task name from name/module pair."""
module_name = module_name or '__main__'
try:
module = sys.modules[module_name]
except KeyError:
# Fix for manage.py shell_plus (Issue #366)
module = None
if module is not None:
module_name = module.__name__
# - If the task module is used as the __main__ script
# - we need to rewrite the module part of the task name
# - to match App.main.
if MP_MAIN_FILE and module.__file__ == MP_MAIN_FILE:
# - see comment about :envvar:`MP_MAIN_FILE` above.
module_name = '__main__'
if module_name == '__main__' and app.main:
return '.'.join([app.main, name])
return '.'.join(p for p in (module_name, name) if p)
def load_extension_class_names(namespace):
if sys.version_info >= (3, 10):
_entry_points = entry_points(group=namespace)
else:
try:
_entry_points = entry_points().get(namespace, [])
except AttributeError:
_entry_points = entry_points().select(group=namespace)
for ep in _entry_points:
yield ep.name, ep.value
def load_extension_classes(namespace):
for name, class_name in load_extension_class_names(namespace):
try:
cls = symbol_by_name(class_name)
except (ImportError, SyntaxError) as exc:
warnings.warn(
f'Cannot load {namespace} extension {class_name!r}: {exc!r}')
else:
yield name, cls
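# --- Editorial usage sketch (not part of the original module) -------------
# Exercises the helpers above; ``_App`` is a hypothetical stand-in for a
# Celery app exposing only the ``main`` attribute that gen_task_name reads.
if __name__ == '__main__':  # pragma: no cover
    print(qualname(dict))                            # -> 'builtins.dict'
    od = instantiate('collections.OrderedDict')      # import + call by name
    print(type(od).__name__)                         # -> 'OrderedDict'

    class _App:
        main = 'proj'

    print(gen_task_name(_App(), 'add', '__main__'))  # -> 'proj.add'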

View File

@@ -0,0 +1,76 @@
"""Parse ISO8601 dates.
Originally taken from :pypi:`pyiso8601`
(https://bitbucket.org/micktwomey/pyiso8601)
Modified to match the behavior of ``dateutil.parser``:
- raise :exc:`ValueError` instead of ``ParseError``
- return naive :class:`~datetime.datetime` by default
This is the original License:
Copyright (c) 2007 Michael Twomey
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sub-license, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
import re
from datetime import datetime, timedelta, timezone
from celery.utils.deprecated import warn
__all__ = ('parse_iso8601',)
# Adapted from http://delete.me.uk/2005/03/iso8601.html
ISO8601_REGEX = re.compile(
r'(?P<year>[0-9]{4})(-(?P<month>[0-9]{1,2})(-(?P<day>[0-9]{1,2})'
r'((?P<separator>.)(?P<hour>[0-9]{2}):(?P<minute>[0-9]{2})'
r'(:(?P<second>[0-9]{2})(\.(?P<fraction>[0-9]+))?)?'
r'(?P<timezone>Z|(([-+])([0-9]{2}):([0-9]{2})))?)?)?)?'
)
TIMEZONE_REGEX = re.compile(
r'(?P<prefix>[+-])(?P<hours>[0-9]{2}).(?P<minutes>[0-9]{2})'
)
def parse_iso8601(datestring):
"""Parse and convert ISO-8601 string to datetime."""
warn("parse_iso8601", "v5.3", "v6", "datetime.datetime.fromisoformat")
m = ISO8601_REGEX.match(datestring)
if not m:
raise ValueError('unable to parse date string %r' % datestring)
groups = m.groupdict()
tz = groups['timezone']
if tz == 'Z':
tz = timezone(timedelta(0))
elif tz:
m = TIMEZONE_REGEX.match(tz)
prefix, hours, minutes = m.groups()
hours, minutes = int(hours), int(minutes)
if prefix == '-':
hours = -hours
minutes = -minutes
tz = timezone(timedelta(minutes=minutes, hours=hours))
return datetime(
int(groups['year']), int(groups['month']),
int(groups['day']), int(groups['hour'] or 0),
int(groups['minute'] or 0), int(groups['second'] or 0),
int(groups['fraction'] or 0), tz
)
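# --- Editorial usage sketch (not part of the original module) -------------
# parse_iso8601 is deprecated in favour of datetime.fromisoformat, but it
# still accepts the common forms below (a deprecation warning is emitted).
if __name__ == '__main__':  # pragma: no cover
    print(parse_iso8601('2009-11-04T21:32:17Z'))       # aware, UTC
    print(parse_iso8601('2009-11-04T21:32:17+02:00'))  # aware, fixed offset
    print(parse_iso8601('2009-11-04'))                 # naive, at midnight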

View File

@@ -0,0 +1,295 @@
"""Logging utilities."""
import logging
import numbers
import os
import sys
import threading
import traceback
from contextlib import contextmanager
from typing import AnyStr, Sequence # noqa
from kombu.log import LOG_LEVELS
from kombu.log import get_logger as _get_logger
from kombu.utils.encoding import safe_str
from .term import colored
__all__ = (
'ColorFormatter', 'LoggingProxy', 'base_logger',
'set_in_sighandler', 'in_sighandler', 'get_logger',
'get_task_logger', 'mlevel',
'get_multiprocessing_logger', 'reset_multiprocessing_logger', 'LOG_LEVELS'
)
_process_aware = False
_in_sighandler = False
MP_LOG = os.environ.get('MP_LOG', False)
RESERVED_LOGGER_NAMES = {'celery', 'celery.task'}
# Sets up our logging hierarchy.
#
# Every logger in the celery package inherits from the "celery"
# logger, and every task logger inherits from the "celery.task"
# logger.
base_logger = logger = _get_logger('celery')
def set_in_sighandler(value):
"""Set flag signifiying that we're inside a signal handler."""
global _in_sighandler
_in_sighandler = value
def iter_open_logger_fds():
seen = set()
loggers = (list(logging.Logger.manager.loggerDict.values()) +
[logging.getLogger(None)])
for l in loggers:
try:
for handler in l.handlers:
try:
if handler not in seen: # pragma: no cover
yield handler.stream
seen.add(handler)
except AttributeError:
pass
except AttributeError: # PlaceHolder does not have handlers
pass
@contextmanager
def in_sighandler():
"""Context that records that we are in a signal handler."""
set_in_sighandler(True)
try:
yield
finally:
set_in_sighandler(False)
def logger_isa(l, p, max=1000):
this, seen = l, set()
for _ in range(max):
if this == p:
return True
else:
if this in seen:
raise RuntimeError(
f'Logger {l.name!r} parents recursive',
)
seen.add(this)
this = this.parent
if not this:
break
else: # pragma: no cover
raise RuntimeError(f'Logger hierarchy exceeds {max}')
return False
def _using_logger_parent(parent_logger, logger_):
if not logger_isa(logger_, parent_logger):
logger_.parent = parent_logger
return logger_
def get_logger(name):
"""Get logger by name."""
l = _get_logger(name)
if logging.root not in (l, l.parent) and l is not base_logger:
l = _using_logger_parent(base_logger, l)
return l
task_logger = get_logger('celery.task')
worker_logger = get_logger('celery.worker')
def get_task_logger(name):
"""Get logger for task module by name."""
if name in RESERVED_LOGGER_NAMES:
raise RuntimeError(f'Logger name {name!r} is reserved!')
return _using_logger_parent(task_logger, get_logger(name))
def mlevel(level):
"""Convert level name/int to log level."""
if level and not isinstance(level, numbers.Integral):
return LOG_LEVELS[level.upper()]
return level
class ColorFormatter(logging.Formatter):
"""Logging formatter that adds colors based on severity."""
#: Loglevel -> Color mapping.
COLORS = colored().names
colors = {
'DEBUG': COLORS['blue'],
'WARNING': COLORS['yellow'],
'ERROR': COLORS['red'],
'CRITICAL': COLORS['magenta'],
}
def __init__(self, fmt=None, use_color=True):
super().__init__(fmt)
self.use_color = use_color
def formatException(self, ei):
if ei and not isinstance(ei, tuple):
ei = sys.exc_info()
r = super().formatException(ei)
return r
def format(self, record):
msg = super().format(record)
color = self.colors.get(record.levelname)
# reset exception info later for other handlers...
einfo = sys.exc_info() if record.exc_info == 1 else record.exc_info
if color and self.use_color:
try:
# safe_str will repr the color object
# and color will break on non-string objects
# so need to reorder calls based on type.
# Issue #427
try:
if isinstance(msg, str):
return str(color(safe_str(msg)))
return safe_str(color(msg))
except UnicodeDecodeError: # pragma: no cover
return safe_str(msg) # skip colors
except Exception as exc: # pylint: disable=broad-except
prev_msg, record.exc_info, record.msg = (
record.msg, 1, '<Unrepresentable {!r}: {!r}>'.format(
type(msg), exc
),
)
try:
return super().format(record)
finally:
record.msg, record.exc_info = prev_msg, einfo
else:
return safe_str(msg)
class LoggingProxy:
"""Forward file object to :class:`logging.Logger` instance.
Arguments:
logger (~logging.Logger): Logger instance to forward to.
loglevel (int, str): Log level to use when logging messages.
"""
mode = 'w'
name = None
closed = False
loglevel = logging.ERROR
_thread = threading.local()
def __init__(self, logger, loglevel=None):
# pylint: disable=redefined-outer-name
# Note that the logger global is redefined here, be careful changing.
self.logger = logger
self.loglevel = mlevel(loglevel or self.logger.level or self.loglevel)
self._safewrap_handlers()
def _safewrap_handlers(self):
# Make the logger handlers dump internal errors to
# :data:`sys.__stderr__` instead of :data:`sys.stderr` to circumvent
# infinite loops.
def wrap_handler(handler): # pragma: no cover
class WithSafeHandleError(logging.Handler):
def handleError(self, record):
try:
traceback.print_exc(None, sys.__stderr__)
except OSError:
pass # see python issue 5971
handler.handleError = WithSafeHandleError().handleError
return [wrap_handler(h) for h in self.logger.handlers]
def write(self, data):
# type: (AnyStr) -> int
"""Write message to logging object."""
if _in_sighandler:
safe_data = safe_str(data)
print(safe_data, file=sys.__stderr__)
return len(safe_data)
if getattr(self._thread, 'recurse_protection', False):
# Logger is logging back to this file, so stop recursing.
return 0
if data and not self.closed:
self._thread.recurse_protection = True
try:
safe_data = safe_str(data).rstrip('\n')
if safe_data:
self.logger.log(self.loglevel, safe_data)
return len(safe_data)
finally:
self._thread.recurse_protection = False
return 0
def writelines(self, sequence):
# type: (Sequence[str]) -> None
"""Write list of strings to file.
The sequence can be any iterable object producing strings.
This is equivalent to calling :meth:`write` for each string.
"""
for part in sequence:
self.write(part)
def flush(self):
# This object is not buffered so any :meth:`flush`
# requests are ignored.
pass
def close(self):
# when the object is closed, no write requests are
# forwarded to the logging object anymore.
self.closed = True
def isatty(self):
"""Here for file support."""
return False
def get_multiprocessing_logger():
"""Return the multiprocessing logger."""
try:
from billiard import util
except ImportError:
pass
else:
return util.get_logger()
def reset_multiprocessing_logger():
"""Reset multiprocessing logging setup."""
try:
from billiard import util
except ImportError:
pass
else:
if hasattr(util, '_logger'): # pragma: no cover
util._logger = None
def current_process():
try:
from billiard import process
except ImportError:
pass
else:
return process.current_process()
def current_process_index(base=1):
index = getattr(current_process(), 'index', None)
return index + base if index is not None else index
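# --- Editorial usage sketch (not part of the original module) -------------
# Task loggers are children of the reserved ``celery.task`` logger, and
# LoggingProxy turns writes to a file-like object into log records.
if __name__ == '__main__':  # pragma: no cover
    lg = get_task_logger('proj.tasks')
    assert lg.parent is task_logger
    logging.basicConfig(level=logging.INFO)
    proxy = LoggingProxy(lg, loglevel='INFO')
    proxy.write('this line ends up in the logging subsystem\n')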

View File

@@ -0,0 +1,102 @@
"""Worker name utilities."""
import os
import socket
from functools import partial
from kombu.entity import Exchange, Queue
from .functional import memoize
from .text import simple_format
#: Exchange for worker direct queues.
WORKER_DIRECT_EXCHANGE = Exchange('C.dq2')
#: Format for worker direct queue names.
WORKER_DIRECT_QUEUE_FORMAT = '{hostname}.dq2'
#: Separator for worker node name and hostname.
NODENAME_SEP = '@'
NODENAME_DEFAULT = 'celery'
gethostname = memoize(1, Cache=dict)(socket.gethostname)
__all__ = (
'worker_direct', 'gethostname', 'nodename',
'anon_nodename', 'nodesplit', 'default_nodename',
'node_format', 'host_format',
)
def worker_direct(hostname):
"""Return the :class:`kombu.Queue` being a direct route to a worker.
Arguments:
hostname (str, ~kombu.Queue): The fully qualified node name of
a worker (e.g., ``w1@example.com``). If passed a
:class:`kombu.Queue` instance it will simply return
that instead.
"""
if isinstance(hostname, Queue):
return hostname
return Queue(
WORKER_DIRECT_QUEUE_FORMAT.format(hostname=hostname),
WORKER_DIRECT_EXCHANGE,
hostname,
)
def nodename(name, hostname):
"""Create node name from name/hostname pair."""
return NODENAME_SEP.join((name, hostname))
def anon_nodename(hostname=None, prefix='gen'):
"""Return the nodename for this process (not a worker).
This is used for e.g. the origin task message field.
"""
return nodename(''.join([prefix, str(os.getpid())]),
hostname or gethostname())
def nodesplit(name):
"""Split node name into tuple of name/hostname."""
parts = name.split(NODENAME_SEP, 1)
if len(parts) == 1:
return None, parts[0]
return parts
def default_nodename(hostname):
"""Return the default nodename for this process."""
name, host = nodesplit(hostname or '')
return nodename(name or NODENAME_DEFAULT, host or gethostname())
def node_format(s, name, **extra):
"""Format worker node name (name@host.com)."""
shortname, host = nodesplit(name)
return host_format(
s, host, shortname or NODENAME_DEFAULT, p=name, **extra)
def _fmt_process_index(prefix='', default='0'):
from .log import current_process_index
index = current_process_index()
return f'{prefix}{index}' if index else default
_fmt_process_index_with_prefix = partial(_fmt_process_index, '-', '')
def host_format(s, host=None, name=None, **extra):
"""Format host %x abbreviations."""
host = host or gethostname()
hname, _, domain = host.partition('.')
name = name or hname
keys = dict({
'h': host, 'n': name, 'd': domain,
'i': _fmt_process_index, 'I': _fmt_process_index_with_prefix,
}, **extra)
return simple_format(s, keys)
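# --- Editorial usage sketch (not part of the original module) -------------
# Node names are ``name@host`` pairs; host_format expands the %h/%n/%d
# abbreviations built above.
if __name__ == '__main__':  # pragma: no cover
    print(nodename('w1', 'example.com'))       # -> 'w1@example.com'
    print(nodesplit('w1@example.com'))         # -> ['w1', 'example.com']
    print(default_nodename('w1'))              # -> 'celery@w1'
    print(host_format('%n.%d', host='w1.example.com'))  # -> 'w1.example.com'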

View File

@@ -0,0 +1,142 @@
"""Object related utilities, including introspection, etc."""
from functools import reduce
__all__ = ('Bunch', 'FallbackContext', 'getitem_property', 'mro_lookup')
class Bunch:
"""Object that enables you to modify attributes."""
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def mro_lookup(cls, attr, stop=None, monkey_patched=None):
"""Return the first node by MRO order that defines an attribute.
Arguments:
cls (Any): Child class to traverse.
attr (str): Name of attribute to find.
stop (Set[Any]): A set of types that if reached will stop
the search.
monkey_patched (Sequence): Use one of the stop classes
if the attribute's module origin isn't in this list.
Used to detect monkey patched attributes.
Returns:
Any: The class that defines the attribute, or :const:`None` if not found.
"""
stop = set() if not stop else stop
monkey_patched = [] if not monkey_patched else monkey_patched
for node in cls.mro():
if node in stop:
try:
value = node.__dict__[attr]
module_origin = value.__module__
except (AttributeError, KeyError):
pass
else:
if module_origin not in monkey_patched:
return node
return
if attr in node.__dict__:
return node
class FallbackContext:
"""Context workaround.
The built-in ``@contextmanager`` utility does not work well
when wrapping other contexts, as the traceback is wrong when
the wrapped context raises.
This solves this problem and can be used instead of ``@contextmanager``
in this example::
@contextmanager
def connection_or_default_connection(connection=None):
if connection:
# user already has a connection, shouldn't close
# after use
yield connection
else:
# must acquire a new connection, and also close the connection
# after the block returns
with create_new_connection() as connection:
yield connection
This wrapper can be used instead for the above like this::
def connection_or_default_connection(connection=None):
return FallbackContext(connection, create_new_connection)
"""
def __init__(self, provided, fallback, *fb_args, **fb_kwargs):
self.provided = provided
self.fallback = fallback
self.fb_args = fb_args
self.fb_kwargs = fb_kwargs
self._context = None
def __enter__(self):
if self.provided is not None:
return self.provided
context = self._context = self.fallback(
*self.fb_args, **self.fb_kwargs
).__enter__()
return context
def __exit__(self, *exc_info):
if self._context is not None:
return self._context.__exit__(*exc_info)
class getitem_property:
"""Attribute -> dict key descriptor.
The target object must support ``__getitem__``,
and optionally ``__setitem__``.
Example:
>>> from collections import defaultdict
>>> class Me(dict):
... deep = defaultdict(dict)
...
... foo = getitem_property('foo')
... deep_thing = getitem_property('deep.thing')
>>> me = Me()
>>> me.foo
None
>>> me.foo = 10
>>> me.foo
10
>>> me['foo']
10
>>> me.deep_thing = 42
>>> me.deep_thing
42
>>> me.deep
defaultdict(<type 'dict'>, {'thing': 42})
"""
def __init__(self, keypath, doc=None):
path, _, self.key = keypath.rpartition('.')
self.path = path.split('.') if path else None
self.__doc__ = doc
def _path(self, obj):
return (reduce(lambda d, k: d[k], [obj] + self.path) if self.path
else obj)
def __get__(self, obj, type=None):
if obj is None:
return type
return self._path(obj).get(self.key)
def __set__(self, obj, value):
self._path(obj)[self.key] = value
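# --- Editorial usage sketch (not part of the original module) -------------
# getitem_property exposes mapping keys as attributes; mro_lookup reports
# which class in the MRO actually defines a name.
if __name__ == '__main__':  # pragma: no cover
    class Request(dict):
        id = getitem_property('id', 'Request id.')

    req = Request()
    req.id = 'abc123'
    print(req['id'], req.id)    # -> abc123 abc123

    class A:
        x = 1

    class B(A):
        pass

    print(mro_lookup(B, 'x'))   # -> the class A, which defines ``x``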

View File

@@ -0,0 +1,266 @@
"""Streaming, truncating, non-recursive version of :func:`repr`.
Differences from regular :func:`repr`:
- Sets are represented the Python 3 way: ``{1, 2}`` vs ``set([1, 2])``.
- Unicode strings do not have the ``u'`` prefix, even on Python 2.
- Empty set formatted as ``set()`` (Python 3), not ``set([])`` (Python 2).
- Longs don't have the ``L`` suffix.
Very slow with no limits, super quick with limits.
"""
import traceback
from collections import deque, namedtuple
from decimal import Decimal
from itertools import chain
from numbers import Number
from pprint import _recursion
from typing import Any, AnyStr, Callable, Dict, Iterator, List, Sequence, Set, Tuple # noqa
from .text import truncate
__all__ = ('saferepr', 'reprstream')
#: Node representing literal text.
#: - .value: is the literal text value
#: - .truncate: specifies if this text can be truncated, for things like
#: LIT_DICT_END this will be False, as we always display
#: the ending brackets, e.g: [[[1, 2, 3, ...,], ..., ]]
#: - .direction: If +1 the current level is increment by one,
#: if -1 the current level is decremented by one, and
#: if 0 the current level is unchanged.
_literal = namedtuple('_literal', ('value', 'truncate', 'direction'))
#: Node representing a dictionary key.
_key = namedtuple('_key', ('value',))
#: Node representing quoted text, e.g. a string value.
_quoted = namedtuple('_quoted', ('value',))
#: Recursion protection.
_dirty = namedtuple('_dirty', ('objid',))
#: Types that are represented as chars.
chars_t = (bytes, str)
#: Types that are regarded as safe to call repr on.
safe_t = (Number,)
#: Set types.
set_t = (frozenset, set)
LIT_DICT_START = _literal('{', False, +1)
LIT_DICT_KVSEP = _literal(': ', True, 0)
LIT_DICT_END = _literal('}', False, -1)
LIT_LIST_START = _literal('[', False, +1)
LIT_LIST_END = _literal(']', False, -1)
LIT_LIST_SEP = _literal(', ', True, 0)
LIT_SET_START = _literal('{', False, +1)
LIT_SET_END = _literal('}', False, -1)
LIT_TUPLE_START = _literal('(', False, +1)
LIT_TUPLE_END = _literal(')', False, -1)
LIT_TUPLE_END_SV = _literal(',)', False, -1)
def saferepr(o, maxlen=None, maxlevels=3, seen=None):
# type: (Any, int, int, Set) -> str
"""Safe version of :func:`repr`.
Warning:
Make sure you set the maxlen argument, or it will be very slow
for recursive objects. With the maxlen set, it's often faster
than built-in repr.
"""
return ''.join(_saferepr(
o, maxlen=maxlen, maxlevels=maxlevels, seen=seen
))
def _chaindict(mapping,
LIT_DICT_KVSEP=LIT_DICT_KVSEP,
LIT_LIST_SEP=LIT_LIST_SEP):
# type: (Dict, _literal, _literal) -> Iterator[Any]
size = len(mapping)
for i, (k, v) in enumerate(mapping.items()):
yield _key(k)
yield LIT_DICT_KVSEP
yield v
if i < (size - 1):
yield LIT_LIST_SEP
def _chainlist(it, LIT_LIST_SEP=LIT_LIST_SEP):
# type: (List) -> Iterator[Any]
size = len(it)
for i, v in enumerate(it):
yield v
if i < (size - 1):
yield LIT_LIST_SEP
def _repr_empty_set(s):
# type: (Set) -> str
return f'{type(s).__name__}()'
def _safetext(val):
# type: (AnyStr) -> str
if isinstance(val, bytes):
try:
val.encode('utf-8')
except UnicodeDecodeError:
# is bytes with unrepresentable characters, attempt
# to convert back to unicode
return val.decode('utf-8', errors='backslashreplace')
return val
def _format_binary_bytes(val, maxlen, ellipsis='...'):
# type: (bytes, int, str) -> str
if maxlen and len(val) > maxlen:
# we don't want to copy all the data, just take what we need.
chunk = memoryview(val)[:maxlen].tobytes()
return _bytes_prefix(f"'{_repr_binary_bytes(chunk)}{ellipsis}'")
return _bytes_prefix(f"'{_repr_binary_bytes(val)}'")
def _bytes_prefix(s):
return 'b' + s
def _repr_binary_bytes(val):
# type: (bytes) -> str
try:
return val.decode('utf-8')
except UnicodeDecodeError:
# possibly not unicode, but binary data so format as hex.
return val.hex()
def _format_chars(val, maxlen):
# type: (AnyStr, int) -> str
if isinstance(val, bytes): # pragma: no cover
return _format_binary_bytes(val, maxlen)
else:
return "'{}'".format(truncate(val, maxlen).replace("'", "\\'"))
def _repr(obj):
# type: (Any) -> str
try:
return repr(obj)
except Exception as exc:
stack = '\n'.join(traceback.format_stack())
return f'<Unrepresentable {type(obj)!r}{id(obj):#x}: {exc!r} {stack!r}>'
def _saferepr(o, maxlen=None, maxlevels=3, seen=None):
# type: (Any, int, int, Set) -> str
stack = deque([iter([o])])
for token, it in reprstream(stack, seen=seen, maxlevels=maxlevels):
if maxlen is not None and maxlen <= 0:
yield ', ...'
# move rest back to stack, so that we can include
# dangling parens.
stack.append(it)
break
if isinstance(token, _literal):
val = token.value
elif isinstance(token, _key):
val = saferepr(token.value, maxlen, maxlevels)
elif isinstance(token, _quoted):
val = _format_chars(token.value, maxlen)
else:
val = _safetext(truncate(token, maxlen))
yield val
if maxlen is not None:
maxlen -= len(val)
for rest1 in stack:
# maxlen exceeded, process any dangling parens.
for rest2 in rest1:
if isinstance(rest2, _literal) and not rest2.truncate:
yield rest2.value
def _reprseq(val, lit_start, lit_end, builtin_type, chainer):
# type: (Sequence, _literal, _literal, Any, Any) -> Tuple[Any, ...]
if type(val) is builtin_type:
return lit_start, lit_end, chainer(val)
return (
_literal(f'{type(val).__name__}({lit_start.value}', False, +1),
_literal(f'{lit_end.value})', False, -1),
chainer(val)
)
def reprstream(stack, seen=None, maxlevels=3, level=0, isinstance=isinstance):
"""Streaming repr, yielding tokens."""
# type: (deque, Set, int, int, Callable) -> Iterator[Any]
seen = seen or set()
append = stack.append
popleft = stack.popleft
is_in_seen = seen.__contains__
discard_from_seen = seen.discard
add_to_seen = seen.add
while stack:
lit_start = lit_end = None
it = popleft()
for val in it:
orig = val
if isinstance(val, _dirty):
discard_from_seen(val.objid)
continue
elif isinstance(val, _literal):
level += val.direction
yield val, it
elif isinstance(val, _key):
yield val, it
elif isinstance(val, Decimal):
yield _repr(val), it
elif isinstance(val, safe_t):
yield str(val), it
elif isinstance(val, chars_t):
yield _quoted(val), it
elif isinstance(val, range): # pragma: no cover
yield _repr(val), it
else:
if isinstance(val, set_t):
if not val:
yield _repr_empty_set(val), it
continue
lit_start, lit_end, val = _reprseq(
val, LIT_SET_START, LIT_SET_END, set, _chainlist,
)
elif isinstance(val, tuple):
lit_start, lit_end, val = (
LIT_TUPLE_START,
LIT_TUPLE_END_SV if len(val) == 1 else LIT_TUPLE_END,
_chainlist(val))
elif isinstance(val, dict):
lit_start, lit_end, val = (
LIT_DICT_START, LIT_DICT_END, _chaindict(val))
elif isinstance(val, list):
lit_start, lit_end, val = (
LIT_LIST_START, LIT_LIST_END, _chainlist(val))
else:
# other type of object
yield _repr(val), it
continue
if maxlevels and level >= maxlevels:
yield f'{lit_start.value}...{lit_end.value}', it
continue
objid = id(orig)
if is_in_seen(objid):
yield _recursion(orig), it
continue
add_to_seen(objid)
# Recurse into the new list/tuple/dict/etc by tacking
# the rest of our iterable onto the new it: this way
# it works similar to a linked list.
append(chain([lit_start], val, [_dirty(objid), lit_end], it))
break
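# --- Editorial usage sketch (not part of the original module) -------------
# saferepr truncates long values and survives self-referencing containers.
if __name__ == '__main__':  # pragma: no cover
    print(saferepr(list(range(100)), maxlen=40))   # truncated with ', ...'
    cycle = []
    cycle.append(cycle)
    print(saferepr(cycle))    # recursion marker instead of an infinite loop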

View File

@@ -0,0 +1,273 @@
"""Utilities for safely pickling exceptions."""
import datetime
import numbers
import sys
from base64 import b64decode as base64decode
from base64 import b64encode as base64encode
from functools import partial
from inspect import getmro
from itertools import takewhile
from kombu.utils.encoding import bytes_to_str, safe_repr, str_to_bytes
try:
import cPickle as pickle
except ImportError:
import pickle
__all__ = (
'UnpickleableExceptionWrapper', 'subclass_exception',
'find_pickleable_exception', 'create_exception_cls',
'get_pickleable_exception', 'get_pickleable_etype',
'get_pickled_exception', 'strtobool',
)
#: List of base classes we probably don't want to reduce to.
unwanted_base_classes = (Exception, BaseException, object)
STRTOBOOL_DEFAULT_TABLE = {'false': False, 'no': False, '0': False,
'true': True, 'yes': True, '1': True,
'on': True, 'off': False}
def subclass_exception(name, parent, module):
"""Create new exception class."""
return type(name, (parent,), {'__module__': module})
def find_pickleable_exception(exc, loads=pickle.loads,
dumps=pickle.dumps):
"""Find first pickleable exception base class.
With an exception instance, iterate over its super classes (by MRO)
and find the first super exception that's pickleable. It does
not go below :exc:`Exception` (i.e., it skips :exc:`Exception`,
:class:`BaseException` and :class:`object`). If that happens
you should use :exc:`UnpickleableExceptionWrapper` instead.
Arguments:
exc (BaseException): An exception instance.
loads: decoder to use.
dumps: encoder to use
Returns:
Exception: An instance of the nearest pickleable parent exception
class (excluding :exc:`Exception` and above), or :const:`None`
if no pickleable parent exception could be found.
"""
exc_args = getattr(exc, 'args', [])
for supercls in itermro(exc.__class__, unwanted_base_classes):
try:
superexc = supercls(*exc_args)
loads(dumps(superexc))
except Exception: # pylint: disable=broad-except
pass
else:
return superexc
def itermro(cls, stop):
return takewhile(lambda sup: sup not in stop, getmro(cls))
def create_exception_cls(name, module, parent=None):
"""Dynamically create an exception class."""
if not parent:
parent = Exception
return subclass_exception(name, parent, module)
def ensure_serializable(items, encoder):
"""Ensure items will serialize.
For a given list of arbitrary objects, return the object
or a string representation, safe for serialization.
Arguments:
items (Iterable[Any]): Objects to serialize.
encoder (Callable): Callable function to serialize with.
"""
safe_exc_args = []
for arg in items:
try:
encoder(arg)
safe_exc_args.append(arg)
except Exception: # pylint: disable=broad-except
safe_exc_args.append(safe_repr(arg))
return tuple(safe_exc_args)
class UnpickleableExceptionWrapper(Exception):
"""Wraps unpickleable exceptions.
Arguments:
exc_module (str): See :attr:`exc_module`.
exc_cls_name (str): See :attr:`exc_cls_name`.
exc_args (Tuple[Any, ...]): See :attr:`exc_args`.
Example:
>>> def pickle_it(raising_function):
... try:
... raising_function()
... except Exception as e:
... exc = UnpickleableExceptionWrapper(
... e.__class__.__module__,
... e.__class__.__name__,
... e.args,
... )
... pickle.dumps(exc) # Works fine.
"""
#: The module of the original exception.
exc_module = None
#: The name of the original exception class.
exc_cls_name = None
#: The arguments for the original exception.
exc_args = None
def __init__(self, exc_module, exc_cls_name, exc_args, text=None):
safe_exc_args = ensure_serializable(
exc_args, lambda v: pickle.loads(pickle.dumps(v))
)
self.exc_module = exc_module
self.exc_cls_name = exc_cls_name
self.exc_args = safe_exc_args
self.text = text
super().__init__(exc_module, exc_cls_name, safe_exc_args,
text)
def restore(self):
return create_exception_cls(self.exc_cls_name,
self.exc_module)(*self.exc_args)
def __str__(self):
return self.text
@classmethod
def from_exception(cls, exc):
res = cls(
exc.__class__.__module__,
exc.__class__.__name__,
getattr(exc, 'args', []),
safe_repr(exc)
)
if hasattr(exc, "__traceback__"):
res = res.with_traceback(exc.__traceback__)
return res
def get_pickleable_exception(exc):
"""Make sure exception is pickleable."""
try:
pickle.loads(pickle.dumps(exc))
except Exception: # pylint: disable=broad-except
pass
else:
return exc
nearest = find_pickleable_exception(exc)
if nearest:
return nearest
return UnpickleableExceptionWrapper.from_exception(exc)
def get_pickleable_etype(cls, loads=pickle.loads, dumps=pickle.dumps):
"""Get pickleable exception type."""
try:
loads(dumps(cls))
except Exception: # pylint: disable=broad-except
return Exception
else:
return cls
def get_pickled_exception(exc):
"""Reverse of :meth:`get_pickleable_exception`."""
if isinstance(exc, UnpickleableExceptionWrapper):
return exc.restore()
return exc
def b64encode(s):
return bytes_to_str(base64encode(str_to_bytes(s)))
def b64decode(s):
return base64decode(str_to_bytes(s))
def strtobool(term, table=None):
"""Convert common terms for true/false to bool.
Examples (true/false/yes/no/on/off/1/0).
"""
if table is None:
table = STRTOBOOL_DEFAULT_TABLE
if isinstance(term, str):
try:
return table[term.lower()]
except KeyError:
raise TypeError(f'Cannot coerce {term!r} to type bool')
return term
def _datetime_to_json(dt):
# See "Date Time String Format" in the ECMA-262 specification.
if isinstance(dt, datetime.datetime):
r = dt.isoformat()
if dt.microsecond:
r = r[:23] + r[26:]
if r.endswith('+00:00'):
r = r[:-6] + 'Z'
return r
elif isinstance(dt, datetime.time):
r = dt.isoformat()
if dt.microsecond:
r = r[:12]
return r
else:
return dt.isoformat()
def jsonify(obj,
builtin_types=(numbers.Real, str), key=None,
keyfilter=None,
unknown_type_filter=None):
"""Transform object making it suitable for json serialization."""
from kombu.abstract import Object as KombuDictType
_jsonify = partial(jsonify, builtin_types=builtin_types, key=key,
keyfilter=keyfilter,
unknown_type_filter=unknown_type_filter)
if isinstance(obj, KombuDictType):
obj = obj.as_dict(recurse=True)
if obj is None or isinstance(obj, builtin_types):
return obj
elif isinstance(obj, (tuple, list)):
return [_jsonify(v) for v in obj]
elif isinstance(obj, dict):
return {
k: _jsonify(v, key=k) for k, v in obj.items()
if (keyfilter(k) if keyfilter else 1)
}
elif isinstance(obj, (datetime.date, datetime.time)):
return _datetime_to_json(obj)
elif isinstance(obj, datetime.timedelta):
return str(obj)
else:
if unknown_type_filter is None:
raise ValueError(
f'Unsupported type: {type(obj)!r} {obj!r} (parent: {key})'
)
return unknown_type_filter(obj)
def raise_with_context(exc):
exc_info = sys.exc_info()
if not exc_info:
raise exc
elif exc_info[1] is exc:
raise
raise exc from exc_info[1]
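# --- Editorial usage sketch (not part of the original module) -------------
# ``BadError`` is a hypothetical exception carrying an unpickleable
# argument; get_pickleable_exception swaps it for a wrapper that survives
# a pickle round-trip.
if __name__ == '__main__':  # pragma: no cover
    import threading

    class BadError(Exception):
        pass

    exc = BadError(threading.Lock())   # thread locks cannot be pickled
    safe = get_pickleable_exception(exc)
    print(type(safe).__name__)         # -> 'UnpickleableExceptionWrapper'
    restored = get_pickled_exception(pickle.loads(pickle.dumps(safe)))
    print(type(restored).__name__)     # -> 'BadError'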

View File

@@ -0,0 +1,14 @@
"""Static files."""
import os
def get_file(*args):
# type: (*str) -> str
"""Get filename for static file."""
return os.path.join(os.path.abspath(os.path.dirname(__file__)), *args)
def logo():
# type: () -> bytes
"""Celery logo image."""
return get_file('celery_128.png')

Binary file not shown.


View File

@@ -0,0 +1,48 @@
"""System information utilities."""
import os
from math import ceil
from kombu.utils.objects import cached_property
__all__ = ('load_average', 'df')
if hasattr(os, 'getloadavg'):
def _load_average():
return tuple(ceil(l * 1e2) / 1e2 for l in os.getloadavg())
else: # pragma: no cover
# Windows doesn't have getloadavg
def _load_average():
return (0.0, 0.0, 0.0)
def load_average():
"""Return system load average as a triple."""
return _load_average()
class df:
"""Disk information."""
def __init__(self, path):
self.path = path
@property
def total_blocks(self):
return self.stat.f_blocks * self.stat.f_frsize / 1024
@property
def available(self):
return self.stat.f_bavail * self.stat.f_frsize / 1024
@property
def capacity(self):
avail = self.stat.f_bavail
used = self.stat.f_blocks - self.stat.f_bfree
return int(ceil(used * 100.0 / (used + avail) + 0.5))
@cached_property
def stat(self):
return os.statvfs(os.path.abspath(self.path))
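# --- Editorial usage sketch (not part of the original module) -------------
# ``df`` is lazy (statvfs runs on first ``.stat`` access) and therefore
# only works on platforms providing os.statvfs (i.e., not Windows).
if __name__ == '__main__':  # pragma: no cover
    print(load_average())      # e.g. (0.42, 0.3, 0.25)
    disk = df('/')
    print(disk.available, 'KiB available,', disk.capacity, '% in use')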

View File

@@ -0,0 +1,177 @@
"""Terminals and colors."""
import base64
import codecs
import os
import platform
import sys
from functools import reduce
__all__ = ('colored',)
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
OP_SEQ = '\033[%dm'
RESET_SEQ = '\033[0m'
COLOR_SEQ = '\033[1;%dm'
IS_WINDOWS = platform.system() == 'Windows'
ITERM_PROFILE = os.environ.get('ITERM_PROFILE')
TERM = os.environ.get('TERM')
TERM_IS_SCREEN = TERM and TERM.startswith('screen')
# tmux requires unrecognized OSC sequences to be wrapped with DCS tmux;
# <sequence> ST, and for all ESCs in <sequence> to be replaced with ESC ESC.
# It only accepts ESC backslash for ST.
_IMG_PRE = '\033Ptmux;\033\033]' if TERM_IS_SCREEN else '\033]'
_IMG_POST = '\a\033\\' if TERM_IS_SCREEN else '\a'
def fg(s):
return COLOR_SEQ % s
class colored:
"""Terminal colored text.
Example:
>>> c = colored(enabled=True)
>>> print(str(c.red('the quick '), c.blue('brown ', c.bold('fox ')),
... c.magenta(c.underline('jumps over')),
... c.yellow(' the lazy '),
... c.green('dog ')))
"""
def __init__(self, *s, **kwargs):
self.s = s
self.enabled = not IS_WINDOWS and kwargs.get('enabled', True)
self.op = kwargs.get('op', '')
self.names = {
'black': self.black,
'red': self.red,
'green': self.green,
'yellow': self.yellow,
'blue': self.blue,
'magenta': self.magenta,
'cyan': self.cyan,
'white': self.white,
}
def _add(self, a, b):
return str(a) + str(b)
def _fold_no_color(self, a, b):
try:
A = a.no_color()
except AttributeError:
A = str(a)
try:
B = b.no_color()
except AttributeError:
B = str(b)
return ''.join((str(A), str(B)))
def no_color(self):
if self.s:
return str(reduce(self._fold_no_color, self.s))
return ''
def embed(self):
prefix = ''
if self.enabled:
prefix = self.op
return ''.join((str(prefix), str(reduce(self._add, self.s))))
def __str__(self):
suffix = ''
if self.enabled:
suffix = RESET_SEQ
return str(''.join((self.embed(), str(suffix))))
def node(self, s, op):
return self.__class__(enabled=self.enabled, op=op, *s)
def black(self, *s):
return self.node(s, fg(30 + BLACK))
def red(self, *s):
return self.node(s, fg(30 + RED))
def green(self, *s):
return self.node(s, fg(30 + GREEN))
def yellow(self, *s):
return self.node(s, fg(30 + YELLOW))
def blue(self, *s):
return self.node(s, fg(30 + BLUE))
def magenta(self, *s):
return self.node(s, fg(30 + MAGENTA))
def cyan(self, *s):
return self.node(s, fg(30 + CYAN))
def white(self, *s):
return self.node(s, fg(30 + WHITE))
def __repr__(self):
return repr(self.no_color())
def bold(self, *s):
return self.node(s, OP_SEQ % 1)
def underline(self, *s):
return self.node(s, OP_SEQ % 4)
def blink(self, *s):
return self.node(s, OP_SEQ % 5)
def reverse(self, *s):
return self.node(s, OP_SEQ % 7)
def bright(self, *s):
return self.node(s, OP_SEQ % 8)
def ired(self, *s):
return self.node(s, fg(40 + RED))
def igreen(self, *s):
return self.node(s, fg(40 + GREEN))
def iyellow(self, *s):
return self.node(s, fg(40 + YELLOW))
def iblue(self, *s):
return self.node(s, fg(40 + BLUE))
def imagenta(self, *s):
return self.node(s, fg(40 + MAGENTA))
def icyan(self, *s):
return self.node(s, fg(40 + CYAN))
def iwhite(self, *s):
return self.node(s, fg(40 + WHITE))
def reset(self, *s):
return self.node(s or [''], RESET_SEQ)
def __add__(self, other):
return str(self) + str(other)
def supports_images():
return sys.stdin.isatty() and ITERM_PROFILE
def _read_as_base64(path):
with codecs.open(path, mode='rb') as fh:
encoded = base64.b64encode(fh.read())
return encoded if isinstance(encoded, str) else encoded.decode('ascii')
def imgcat(path, inline=1, preserve_aspect_ratio=0, **kwargs):
return '\n%s1337;File=inline=%d;preserveAspectRatio=%d:%s%s' % (
_IMG_PRE, inline, preserve_aspect_ratio,
_read_as_base64(path), _IMG_POST)
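# --- Editorial usage sketch (not part of the original module) -------------
# colored nodes nest arbitrarily; str() appends the reset sequence when
# colors are enabled, and no_color() strips the escape codes again.
if __name__ == '__main__':  # pragma: no cover
    c = colored(enabled=True)
    print(c.red('error: ', c.bold('out of coffee')))
    print(c.blue('info').no_color())   # -> 'info' (plain text)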

View File

@@ -0,0 +1,198 @@
"""Text formatting utilities."""
from __future__ import annotations
import io
import re
from functools import partial
from pprint import pformat
from re import Match
from textwrap import fill
from typing import Any, Callable, Pattern
__all__ = (
'abbr', 'abbrtask', 'dedent', 'dedent_initial',
'ensure_newlines', 'ensure_sep',
'fill_paragraphs', 'indent', 'join',
'pluralize', 'pretty', 'str_to_list', 'simple_format', 'truncate',
)
UNKNOWN_SIMPLE_FORMAT_KEY = """
Unknown format %{0} in string {1!r}.
Possible causes: Did you forget to escape the expand sign (use '%%{0!r}'),
or did you escape and the value was expanded twice? (%%N -> %N -> %hostname)?
""".strip()
RE_FORMAT = re.compile(r'%(\w)')
def str_to_list(s: str) -> list[str]:
"""Convert string to list."""
if isinstance(s, str):
return s.split(',')
return s
def dedent_initial(s: str, n: int = 4) -> str:
"""Remove indentation from first line of text."""
return s[n:] if s[:n] == ' ' * n else s
def dedent(s: str, sep: str = '\n') -> str:
"""Remove indentation."""
return sep.join(dedent_initial(l) for l in s.splitlines())
def fill_paragraphs(s: str, width: int, sep: str = '\n') -> str:
"""Fill paragraphs with newlines (or custom separator)."""
return sep.join(fill(p, width) for p in s.split(sep))
def join(l: list[str], sep: str = '\n') -> str:
"""Concatenate list of strings."""
return sep.join(v for v in l if v)
def ensure_sep(sep: str, s: str, n: int = 2) -> str:
"""Ensure text s ends in separator sep'."""
return s + sep * (n - s.count(sep))
ensure_newlines = partial(ensure_sep, '\n')
def abbr(S: str, max: int, ellipsis: str | bool = '...') -> str:
"""Abbreviate word."""
if S is None:
return '???'
if len(S) > max:
return isinstance(ellipsis, str) and (
S[: max - len(ellipsis)] + ellipsis) or S[: max]
return S
def abbrtask(S: str, max: int) -> str:
"""Abbreviate task name."""
if S is None:
return '???'
if len(S) > max:
module, _, cls = S.rpartition('.')
module = abbr(module, max - len(cls) - 3, False)
return module + '[.]' + cls
return S
def indent(t: str, indent: int = 0, sep: str = '\n') -> str:
"""Indent text."""
return sep.join(' ' * indent + p for p in t.split(sep))
def truncate(s: str, maxlen: int = 128, suffix: str = '...') -> str:
"""Truncate text to a maximum number of characters."""
if maxlen and len(s) >= maxlen:
return s[:maxlen].rsplit(' ', 1)[0] + suffix
return s
def pluralize(n: float, text: str, suffix: str = 's') -> str:
"""Pluralize term when n is greater than one."""
if n != 1:
return text + suffix
return text
def pretty(value: str, width: int = 80, nl_width: int = 80,
           sep: str = '\n', **kw: Any) -> str:
"""Format value for printing to console."""
if isinstance(value, dict):
return f'{sep} {pformat(value, 4, nl_width)[1:]}'
elif isinstance(value, tuple):
return '{}{}{}'.format(
sep, ' ' * 4, pformat(value, width=nl_width, **kw),
)
else:
return pformat(value, width=width, **kw)
def match_case(s: str, other: str) -> str:
return s.upper() if other.isupper() else s.lower()
def simple_format(
s: str, keys: dict[str, str | Callable],
pattern: Pattern[str] = RE_FORMAT, expand: str = r'\1') -> str:
"""Format string, expanding abbreviations in keys'."""
if s:
keys.setdefault('%', '%')
def resolve(match: Match) -> str | Any:
key = match.expand(expand)
try:
resolver = keys[key]
except KeyError:
raise ValueError(UNKNOWN_SIMPLE_FORMAT_KEY.format(key, s))
if callable(resolver):
return resolver()
return resolver
return pattern.sub(resolve, s)
return s
def remove_repeating_from_task(task_name: str, s: str) -> str:
"""Given task name, remove repeating module names.
Example:
>>> remove_repeating_from_task(
... 'tasks.add',
... 'tasks.add(2, 2), tasks.mul(3), tasks.div(4)')
'tasks.add(2, 2), mul(3), div(4)'
"""
# This is used by e.g. repr(chain), to remove repeating module names.
# - extract the module part of the task name
module = str(task_name).rpartition('.')[0] + '.'
return remove_repeating(module, s)
def remove_repeating(substr: str, s: str) -> str:
"""Remove repeating module names from string.
Arguments:
substr (str): The substring to deduplicate; occurrences after
the first are removed from ``s``.
s (str): The string we want to work on.
Example:
>>> remove_repeating(
... 'x.tasks.add',
... 'x.tasks.add(2, 2) | x.tasks.add(4) | x.tasks.mul(8)',
... )
'x.tasks.add(2, 2) | add(4) | mul(8)'
"""
# find the first occurrence of substr in the string.
index = s.find(substr)
if index >= 0:
return ''.join([
# leave the first occurrence of substr untouched.
s[:index + len(substr)],
# strip seen substr from the rest of the string.
s[index + len(substr):].replace(substr, ''),
])
return s
StringIO = io.StringIO
_SIO_write = StringIO.write
_SIO_init = StringIO.__init__
class WhateverIO(StringIO):
"""StringIO that takes bytes or str."""
def __init__(
self, v: bytes | str | None = None, *a: Any, **kw: Any) -> None:
_SIO_init(self, v.decode() if isinstance(v, bytes) else v, *a, **kw)
def write(self, data: bytes | str) -> int:
return _SIO_write(self, data.decode()
if isinstance(data, bytes) else data)
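# --- Editorial usage sketch (not part of the original module) -------------
# simple_format drives the %-style expansions used for node and log-file
# names; the other helpers are plain string utilities.
if __name__ == '__main__':  # pragma: no cover
    print(simple_format('%h.log', {'h': 'worker1'}))          # -> 'worker1.log'
    print(truncate('the quick brown fox jumps', maxlen=12))   # -> 'the quick...'
    print(pluralize(3, 'worker'))                             # -> 'workers'
    print(abbrtask('proj.tasks.some.very.long.task_name', 16))  # -> 'proj[.]task_name'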

View File

@@ -0,0 +1,331 @@
"""Threading primitives and utilities."""
import os
import socket
import sys
import threading
import traceback
from contextlib import contextmanager
from threading import TIMEOUT_MAX as THREAD_TIMEOUT_MAX
from celery.local import Proxy
try:
from greenlet import getcurrent as get_ident
except ImportError:
try:
from _thread import get_ident
except ImportError:
try:
from thread import get_ident
except ImportError:
try:
from _dummy_thread import get_ident
except ImportError:
from dummy_thread import get_ident
__all__ = (
'bgThread', 'Local', 'LocalStack', 'LocalManager',
'get_ident', 'default_socket_timeout',
)
USE_FAST_LOCALS = os.environ.get('USE_FAST_LOCALS')
@contextmanager
def default_socket_timeout(timeout):
"""Context temporarily setting the default socket timeout."""
prev = socket.getdefaulttimeout()
socket.setdefaulttimeout(timeout)
yield
socket.setdefaulttimeout(prev)
class bgThread(threading.Thread):
"""Background service thread."""
def __init__(self, name=None, **kwargs):
super().__init__()
self.__is_shutdown = threading.Event()
self.__is_stopped = threading.Event()
self.daemon = True
self.name = name or self.__class__.__name__
def body(self):
raise NotImplementedError()
def on_crash(self, msg, *fmt, **kwargs):
print(msg.format(*fmt), file=sys.stderr)
traceback.print_exc(None, sys.stderr)
def run(self):
body = self.body
shutdown_set = self.__is_shutdown.is_set
try:
while not shutdown_set():
try:
body()
except Exception as exc: # pylint: disable=broad-except
try:
self.on_crash('{0!r} crashed: {1!r}', self.name, exc)
self._set_stopped()
finally:
sys.stderr.flush()
os._exit(1) # exiting by normal means won't work
finally:
self._set_stopped()
def _set_stopped(self):
try:
self.__is_stopped.set()
except TypeError: # pragma: no cover
# we lost the race at interpreter shutdown,
# so gc collected built-in modules.
pass
def stop(self):
"""Graceful shutdown."""
self.__is_shutdown.set()
self.__is_stopped.wait()
if self.is_alive():
self.join(THREAD_TIMEOUT_MAX)
def release_local(local):
"""Release the contents of the local for the current context.
This makes it possible to use locals without a manager.
With this function one can release :class:`Local` objects as well as
:class:`LocalStack` objects. However, it's not possible to
release data held by proxies that way, one always has to retain
a reference to the underlying local object in order to be able
to release it.
Example:
>>> loc = Local()
>>> loc.foo = 42
>>> release_local(loc)
>>> hasattr(loc, 'foo')
False
"""
local.__release_local__()
class Local:
"""Local object."""
__slots__ = ('__storage__', '__ident_func__')
def __init__(self):
object.__setattr__(self, '__storage__', {})
object.__setattr__(self, '__ident_func__', get_ident)
def __iter__(self):
return iter(self.__storage__.items())
def __call__(self, proxy):
"""Create a proxy for a name."""
return Proxy(self, proxy)
def __release_local__(self):
self.__storage__.pop(self.__ident_func__(), None)
def __getattr__(self, name):
try:
return self.__storage__[self.__ident_func__()][name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name, value):
ident = self.__ident_func__()
storage = self.__storage__
try:
storage[ident][name] = value
except KeyError:
storage[ident] = {name: value}
def __delattr__(self, name):
try:
del self.__storage__[self.__ident_func__()][name]
except KeyError:
raise AttributeError(name)
class _LocalStack:
"""Local stack.
This class works similar to a :class:`Local` but keeps a stack
of objects instead. This is best explained with an example::
>>> ls = LocalStack()
>>> ls.push(42)
>>> ls.top
42
>>> ls.push(23)
>>> ls.top
23
>>> ls.pop()
23
>>> ls.top
42
They can be force released by using a :class:`LocalManager` or with
the :func:`release_local` function but the correct way is to pop the
item from the stack after using. When the stack is empty it will
no longer be bound to the current context (and as such released).
By calling the stack without arguments it will return a proxy that
resolves to the topmost item on the stack.
"""
def __init__(self):
self._local = Local()
def __release_local__(self):
self._local.__release_local__()
def _get__ident_func__(self):
return self._local.__ident_func__
def _set__ident_func__(self, value):
object.__setattr__(self._local, '__ident_func__', value)
__ident_func__ = property(_get__ident_func__, _set__ident_func__)
del _get__ident_func__, _set__ident_func__
def __call__(self):
def _lookup():
rv = self.top
if rv is None:
raise RuntimeError('object unbound')
return rv
return Proxy(_lookup)
def push(self, obj):
"""Push a new item to the stack."""
rv = getattr(self._local, 'stack', None)
if rv is None:
# pylint: disable=assigning-non-slot
# This attribute is defined now.
self._local.stack = rv = []
rv.append(obj)
return rv
def pop(self):
"""Remove the topmost item from the stack.
Note:
Will return the old value or `None` if the stack was already empty.
"""
stack = getattr(self._local, 'stack', None)
if stack is None:
return None
elif len(stack) == 1:
release_local(self._local)
return stack[-1]
else:
return stack.pop()
def __len__(self):
stack = getattr(self._local, 'stack', None)
return len(stack) if stack else 0
@property
def stack(self):
# get_current_worker_task uses this to find
# the original task that was executed by the worker.
stack = getattr(self._local, 'stack', None)
if stack is not None:
return stack
return []
@property
def top(self):
"""The topmost item on the stack.
Note:
If the stack is empty, :const:`None` is returned.
"""
try:
return self._local.stack[-1]
except (AttributeError, IndexError):
return None
class LocalManager:
"""Local objects cannot manage themselves.
For that you need a local manager.
You can pass a local manager multiple locals or add them
later by appending them to ``manager.locals``. Every time the manager
cleans up, it will clean up all the data left in the locals for this
context.
The ``ident_func`` parameter can be added to override the default ident
function for the wrapped locals.
"""
def __init__(self, locals=None, ident_func=None):
if locals is None:
self.locals = []
elif isinstance(locals, Local):
self.locals = [locals]
else:
self.locals = list(locals)
if ident_func is not None:
self.ident_func = ident_func
for local in self.locals:
object.__setattr__(local, '__ident_func__', ident_func)
else:
self.ident_func = get_ident
def get_ident(self):
"""Return context identifier.
This is the identifier the local objects use internally
for this context. You cannot override this method to change the
behavior but use it to link other context local objects (such as
SQLAlchemy's scoped sessions) to these locals.
"""
return self.ident_func()
def cleanup(self):
"""Manually clean up the data in the locals for this context.
Call this at the end of the request or task.
"""
for local in self.locals:
release_local(local)
def __repr__(self):
return '<{} storages: {}>'.format(
self.__class__.__name__, len(self.locals))
class _FastLocalStack(threading.local):
def __init__(self):
self.stack = []
self.push = self.stack.append
self.pop = self.stack.pop
super().__init__()
@property
def top(self):
try:
return self.stack[-1]
except (AttributeError, IndexError):
return None
def __len__(self):
return len(self.stack)
if USE_FAST_LOCALS: # pragma: no cover
LocalStack = _FastLocalStack
else: # pragma: no cover
# - See #706
# since each thread has its own greenlet we can just use those as
# identifiers for the context. If greenlets aren't available we
# fall back to the current thread ident.
LocalStack = _LocalStack
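# --- Editorial usage sketch (not part of the original module) -------------
# LocalStack keeps a per-context stack; calling the stack returns a proxy
# to the topmost item (this assumes the default _LocalStack implementation,
# i.e. USE_FAST_LOCALS unset).
if __name__ == '__main__':  # pragma: no cover
    stack = LocalStack()
    stack.push({'request': 1})
    current = stack()            # Proxy resolving to stack.top
    print(current['request'])    # -> 1
    stack.pop()
    print(stack.top)             # -> None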

View File

@@ -0,0 +1,429 @@
"""Utilities related to dates, times, intervals, and timezones."""
from __future__ import annotations
import numbers
import os
import random
import sys
import time as _time
from calendar import monthrange
from datetime import date, datetime, timedelta
from datetime import timezone as datetime_timezone
from datetime import tzinfo
from types import ModuleType
from typing import Any, Callable
from dateutil import tz as dateutil_tz
from kombu.utils.functional import reprcall
from kombu.utils.objects import cached_property
from .functional import dictfilter
from .text import pluralize
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
from backports.zoneinfo import ZoneInfo
__all__ = (
'LocalTimezone', 'timezone', 'maybe_timedelta',
'delta_resolution', 'remaining', 'rate', 'weekday',
'humanize_seconds', 'maybe_iso8601', 'is_naive',
'make_aware', 'localize', 'to_utc', 'maybe_make_aware',
'ffwd', 'utcoffset', 'adjust_timestamp',
'get_exponential_backoff_interval',
)
C_REMDEBUG = os.environ.get('C_REMDEBUG', False)
DAYNAMES = 'sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat'
WEEKDAYS = dict(zip(DAYNAMES, range(7)))
RATE_MODIFIER_MAP = {
's': lambda n: n,
'm': lambda n: n / 60.0,
'h': lambda n: n / 60.0 / 60.0,
}
TIME_UNITS = (
('day', 60 * 60 * 24.0, lambda n: format(n, '.2f')),
('hour', 60 * 60.0, lambda n: format(n, '.2f')),
('minute', 60.0, lambda n: format(n, '.2f')),
('second', 1.0, lambda n: format(n, '.2f')),
)
ZERO = timedelta(0)
_local_timezone = None
class LocalTimezone(tzinfo):
"""Local time implementation. Provided in _Zone to the app when `enable_utc` is disabled.
Otherwise, _Zone provides a UTC ZoneInfo instance as the timezone implementation for the application.
Note:
Used only when the :setting:`enable_utc` setting is disabled.
"""
_offset_cache: dict[int, tzinfo] = {}
def __init__(self) -> None:
# This code lives in __init__ so that it executes as late as possible.
# See get_default_timezone().
self.STDOFFSET = timedelta(seconds=-_time.timezone)
if _time.daylight:
self.DSTOFFSET = timedelta(seconds=-_time.altzone)
else:
self.DSTOFFSET = self.STDOFFSET
self.DSTDIFF = self.DSTOFFSET - self.STDOFFSET
super().__init__()
def __repr__(self) -> str:
return f'<LocalTimezone: UTC{int(self.DSTOFFSET.total_seconds() / 3600):+03d}>'
def utcoffset(self, dt: datetime) -> timedelta:
return self.DSTOFFSET if self._isdst(dt) else self.STDOFFSET
def dst(self, dt: datetime) -> timedelta:
return self.DSTDIFF if self._isdst(dt) else ZERO
def tzname(self, dt: datetime) -> str:
return _time.tzname[self._isdst(dt)]
def fromutc(self, dt: datetime) -> datetime:
# The base tzinfo class no longer implements a DST
# offset aware .fromutc() in Python 3 (Issue #2306).
offset = int(self.utcoffset(dt).seconds / 60.0)
try:
tz = self._offset_cache[offset]
except KeyError:
tz = self._offset_cache[offset] = datetime_timezone(
timedelta(minutes=offset))
return tz.fromutc(dt.replace(tzinfo=tz))
def _isdst(self, dt: datetime) -> bool:
tt = (dt.year, dt.month, dt.day,
dt.hour, dt.minute, dt.second,
dt.weekday(), 0, 0)
stamp = _time.mktime(tt)
tt = _time.localtime(stamp)
return tt.tm_isdst > 0
class _Zone:
"""Timezone class that provides the timezone for the application.
If `enable_utc` is disabled, LocalTimezone is provided as the timezone provider through local().
Otherwise, this class provides a UTC ZoneInfo instance as the timezone provider for the application.
Additionally this class provides a few utility methods for converting datetimes.
"""
def tz_or_local(self, tzinfo: tzinfo | None = None) -> tzinfo:
"""Return either our local timezone or the provided timezone."""
# pylint: disable=redefined-outer-name
if tzinfo is None:
return self.local
return self.get_timezone(tzinfo)
def to_local(self, dt: datetime, local=None, orig=None):
"""Converts a datetime to the local timezone."""
if is_naive(dt):
dt = make_aware(dt, orig or self.utc)
return localize(dt, self.tz_or_local(local))
def to_system(self, dt: datetime) -> datetime:
"""Converts a datetime to the system timezone."""
# tz=None is a special case since Python 3.3, and will
# convert to the current local timezone (Issue #2306).
return dt.astimezone(tz=None)
def to_local_fallback(self, dt: datetime) -> datetime:
"""Converts a datetime to the local timezone, or the system timezone."""
if is_naive(dt):
return make_aware(dt, self.local)
return localize(dt, self.local)
def get_timezone(self, zone: str | tzinfo) -> tzinfo:
"""Returns ZoneInfo timezone if the provided zone is a string, otherwise return the zone."""
if isinstance(zone, str):
return ZoneInfo(zone)
return zone
@cached_property
def local(self) -> LocalTimezone:
"""Return LocalTimezone instance for the application."""
return LocalTimezone()
@cached_property
def utc(self) -> tzinfo:
"""Return UTC timezone created with ZoneInfo."""
return self.get_timezone('UTC')
timezone = _Zone()
def maybe_timedelta(delta: int) -> timedelta:
"""Convert integer to timedelta, if argument is an integer."""
if isinstance(delta, numbers.Real):
return timedelta(seconds=delta)
return delta
def delta_resolution(dt: datetime, delta: timedelta) -> datetime:
"""Round a :class:`~datetime.datetime` to the resolution of timedelta.
If the :class:`~datetime.timedelta` is in days, the
:class:`~datetime.datetime` will be rounded to the nearest days,
if the :class:`~datetime.timedelta` is in hours the
:class:`~datetime.datetime` will be rounded to the nearest hour,
and so on until seconds, which will just return the original
:class:`~datetime.datetime`.
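
    Example:
        A two-hour delta truncates to the hour (illustrative values):

        >>> delta_resolution(datetime(2024, 5, 17, 13, 45, 30),
        ...                  timedelta(hours=2))
        datetime.datetime(2024, 5, 17, 13, 0)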
"""
delta = max(delta.total_seconds(), 0)
resolutions = ((3, lambda x: x / 86400),
(4, lambda x: x / 3600),
(5, lambda x: x / 60))
args = dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second
for res, predicate in resolutions:
if predicate(delta) >= 1.0:
return datetime(*args[:res], tzinfo=dt.tzinfo)
return dt
def remaining(
        start: datetime, ends_in: timedelta, now: datetime | None = None,
relative: bool = False) -> timedelta:
"""Calculate the remaining time for a start date and a timedelta.
For example, "how many seconds left for 30 seconds after start?"
Arguments:
start (~datetime.datetime): Starting date.
ends_in (~datetime.timedelta): The end delta.
relative (bool): If enabled the end time will be calculated
using :func:`delta_resolution` (i.e., rounded to the
resolution of `ends_in`).
        now (~datetime.datetime): The current date and time.
            Defaults to the result of :func:`datetime.utcnow`.
Returns:
~datetime.timedelta: Remaining time.
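    Example (illustrative, using naive datetimes):
        >>> start = datetime(2024, 1, 1, 12, 0, 0)
        >>> remaining(start, timedelta(hours=1),
        ...           now=datetime(2024, 1, 1, 12, 30, 0))
        datetime.timedelta(seconds=1800)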
"""
now = now or datetime.utcnow()
    if (str(start.tzinfo) == str(now.tzinfo) and
            now.utcoffset() != start.utcoffset()):
# DST started/ended
start = start.replace(tzinfo=now.tzinfo)
end_date = start + ends_in
if relative:
end_date = delta_resolution(end_date, ends_in).replace(microsecond=0)
ret = end_date - now
if C_REMDEBUG: # pragma: no cover
print('rem: NOW:{!r} START:{!r} ENDS_IN:{!r} END_DATE:{} REM:{}'.format(
now, start, ends_in, end_date, ret))
return ret
def rate(r: str) -> float:
"""Convert rate string (`"100/m"`, `"2/h"` or `"0.5/s"`) to seconds."""
if r:
if isinstance(r, str):
ops, _, modifier = r.partition('/')
return RATE_MODIFIER_MAP[modifier or 's'](float(ops)) or 0
return r or 0
return 0
def weekday(name: str) -> int:
"""Return the position of a weekday: 0 - 7, where 0 is Sunday.
Example:
>>> weekday('sunday'), weekday('sun'), weekday('mon')
(0, 0, 1)
"""
abbreviation = name[0:3].lower()
try:
return WEEKDAYS[abbreviation]
except KeyError:
# Show original day name in exception, instead of abbr.
raise KeyError(name)
def humanize_seconds(
secs: int, prefix: str = '', sep: str = '', now: str = 'now',
microseconds: bool = False) -> str:
"""Show seconds in human form.
For example, 60 becomes "1 minute", and 7200 becomes "2 hours".
    Arguments:
        secs (int): Number of seconds to format.
        prefix (str): can be used to add a preposition to the output
            (e.g., 'in' will give 'in 1 second', but add nothing to 'now').
        sep (str): Separator inserted between `prefix` and the time value.
        now (str): Literal 'now'.
        microseconds (bool): Include microseconds.
"""
secs = float(format(float(secs), '.2f'))
for unit, divider, formatter in TIME_UNITS:
if secs >= divider:
w = secs / float(divider)
return '{}{}{} {}'.format(prefix, sep, formatter(w),
pluralize(w, unit))
if microseconds and secs > 0.0:
        return f'{prefix}{sep}{secs:.2f} seconds'
return now
def maybe_iso8601(dt: datetime | str | None) -> None | datetime:
"""Either ``datetime | str -> datetime`` or ``None -> None``."""
if not dt:
return
if isinstance(dt, datetime):
return dt
return datetime.fromisoformat(dt)
def is_naive(dt: datetime) -> bool:
"""Return True if :class:`~datetime.datetime` is naive, meaning it doesn't have timezone info set."""
return dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None
def _can_detect_ambiguous(tz: tzinfo) -> bool:
"""Helper function to determine if a timezone can detect ambiguous times using dateutil."""
return isinstance(tz, ZoneInfo) or hasattr(tz, "is_ambiguous")
def _is_ambiguous(dt: datetime, tz: tzinfo) -> bool:
    """Helper function to determine if a datetime is ambiguous using python's dateutil module.
Returns False if the timezone cannot detect ambiguity, or if there is no ambiguity, otherwise True.
In order to detect ambiguous datetimes, the timezone must be built using ZoneInfo, or have an is_ambiguous
method. Previously, pytz timezones would throw an AmbiguousTimeError if the localized dt was ambiguous,
but now we need to specifically check for ambiguity with dateutil, as pytz is deprecated.
"""
return _can_detect_ambiguous(tz) and dateutil_tz.datetime_ambiguous(dt)
def make_aware(dt: datetime, tz: tzinfo) -> datetime:
"""Set timezone for a :class:`~datetime.datetime` object."""
dt = dt.replace(tzinfo=tz)
    if _is_ambiguous(dt, tz):
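        # both fold candidates share tzinfo and wall time, so min() keeps
        # the first one (fold=0, the pre-transition occurrence).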
dt = min(dt.replace(fold=0), dt.replace(fold=1))
return dt
def localize(dt: datetime, tz: tzinfo) -> datetime:
"""Convert aware :class:`~datetime.datetime` to another timezone.
Using a ZoneInfo timezone will give the most flexibility in terms of ambiguous DST handling.
"""
if is_naive(dt): # Ensure timezone aware datetime
dt = make_aware(dt, tz)
if dt.tzinfo == ZoneInfo("UTC"):
dt = dt.astimezone(tz) # Always safe to call astimezone on utc zones
return dt
def to_utc(dt: datetime) -> datetime:
"""Convert naive :class:`~datetime.datetime` to UTC."""
return make_aware(dt, timezone.utc)
def maybe_make_aware(dt: datetime, tz: tzinfo | None = None,
naive_as_utc: bool = True) -> datetime:
"""Convert dt to aware datetime, do nothing if dt is already aware."""
if is_naive(dt):
if naive_as_utc:
dt = to_utc(dt)
return localize(
dt, timezone.utc if tz is None else timezone.tz_or_local(tz),
)
return dt
class ffwd:
"""Version of ``dateutil.relativedelta`` that only supports addition."""
def __init__(self, year=None, month=None, weeks=0, weekday=None, day=None,
hour=None, minute=None, second=None, microsecond=None,
**kwargs: Any):
# pylint: disable=redefined-outer-name
# weekday is also a function in outer scope.
self.year = year
self.month = month
self.weeks = weeks
self.weekday = weekday
self.day = day
self.hour = hour
self.minute = minute
self.second = second
self.microsecond = microsecond
self.days = weeks * 7
self._has_time = self.hour is not None or self.minute is not None
def __repr__(self) -> str:
return reprcall('ffwd', (), self._fields(weeks=self.weeks,
weekday=self.weekday))
    def __radd__(self, other: Any) -> date:
if not isinstance(other, date):
return NotImplemented
year = self.year or other.year
month = self.month or other.month
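        # clamp the day to the last valid day of the target month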
day = min(monthrange(year, month)[1], self.day or other.day)
ret = other.replace(**dict(dictfilter(self._fields()),
year=year, month=month, day=day))
if self.weekday is not None:
ret += timedelta(days=(7 - ret.weekday() + self.weekday) % 7)
return ret + timedelta(days=self.days)
def _fields(self, **extra: Any) -> dict[str, Any]:
return dictfilter({
'year': self.year, 'month': self.month, 'day': self.day,
'hour': self.hour, 'minute': self.minute,
'second': self.second, 'microsecond': self.microsecond,
}, **extra)
def utcoffset(
time: ModuleType = _time,
localtime: Callable[..., _time.struct_time] = _time.localtime) -> float:
"""Return the current offset to UTC in hours."""
if localtime().tm_isdst:
return time.altzone // 3600
return time.timezone // 3600
def adjust_timestamp(ts: float, offset: int,
here: Callable[..., float] = utcoffset) -> float:
"""Adjust timestamp based on provided utcoffset."""
return ts - (offset - here()) * 3600
def get_exponential_backoff_interval(
factor: int,
retries: int,
maximum: int,
full_jitter: bool = False
) -> int:
"""Calculate the exponential backoff wait time."""
# Will be zero if factor equals 0
countdown = min(maximum, factor * (2 ** retries))
# Full jitter according to
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
if full_jitter:
countdown = random.randrange(countdown + 1)
# Adjust according to maximum wait time and account for negative values.
return max(0, countdown)

View File

@@ -0,0 +1,154 @@
"""Scheduler for Python functions.
.. note::
This is used for the thread-based worker only,
not for amqp/redis/sqs/qpid where :mod:`kombu.asynchronous.timer` is used.
"""
import os
import sys
import threading
from itertools import count
from threading import TIMEOUT_MAX as THREAD_TIMEOUT_MAX
from time import sleep
from kombu.asynchronous.timer import Entry
from kombu.asynchronous.timer import Timer as Schedule
from kombu.asynchronous.timer import logger, to_timestamp
TIMER_DEBUG = os.environ.get('TIMER_DEBUG')
__all__ = ('Entry', 'Schedule', 'Timer', 'to_timestamp')
class Timer(threading.Thread):
"""Timer thread.
Note:
This is only used for transports not supporting AsyncIO.
"""
Entry = Entry
Schedule = Schedule
running = False
on_tick = None
_timer_count = count(1)
if TIMER_DEBUG: # pragma: no cover
def start(self, *args, **kwargs):
import traceback
print('- Timer starting')
traceback.print_stack()
super().start(*args, **kwargs)
def __init__(self, schedule=None, on_error=None, on_tick=None,
on_start=None, max_interval=None, **kwargs):
self.schedule = schedule or self.Schedule(on_error=on_error,
max_interval=max_interval)
self.on_start = on_start
self.on_tick = on_tick or self.on_tick
super().__init__()
# `_is_stopped` is likely to be an attribute on `Thread` objects so we
# double underscore these names to avoid shadowing anything and
# potentially getting confused by the superclass turning these into
# something other than an `Event` instance (e.g. a `bool`)
self.__is_shutdown = threading.Event()
self.__is_stopped = threading.Event()
self.mutex = threading.Lock()
self.not_empty = threading.Condition(self.mutex)
self.daemon = True
self.name = f'Timer-{next(self._timer_count)}'
def _next_entry(self):
with self.not_empty:
delay, entry = next(self.scheduler)
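            # ``entry`` is None while nothing is due; ``delay`` is then the
            # time until the next event, or None if the schedule is empty
            # (in which case we block briefly instead of busy-looping).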
if entry is None:
if delay is None:
self.not_empty.wait(1.0)
return delay
return self.schedule.apply_entry(entry)
    __next__ = next = _next_entry  # retained legacy alias (originally for 2to3)
def run(self):
try:
self.running = True
self.scheduler = iter(self.schedule)
while not self.__is_shutdown.is_set():
delay = self._next_entry()
if delay:
if self.on_tick:
self.on_tick(delay)
if sleep is None: # pragma: no cover
break
sleep(delay)
try:
self.__is_stopped.set()
except TypeError: # pragma: no cover
# we lost the race at interpreter shutdown,
# so gc collected built-in modules.
pass
except Exception as exc:
logger.error('Thread Timer crashed: %r', exc, exc_info=True)
sys.stderr.flush()
os._exit(1)
def stop(self):
self.__is_shutdown.set()
if self.running:
self.__is_stopped.wait()
self.join(THREAD_TIMEOUT_MAX)
self.running = False
def ensure_started(self):
if not self.running and not self.is_alive():
if self.on_start:
self.on_start(self)
self.start()
def _do_enter(self, meth, *args, **kwargs):
self.ensure_started()
with self.mutex:
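            # mutate the schedule while holding the mutex, then wake the
            # run loop so the new entry's ETA is considered immediately.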
entry = getattr(self.schedule, meth)(*args, **kwargs)
self.not_empty.notify()
return entry
def enter(self, entry, eta, priority=None):
return self._do_enter('enter_at', entry, eta, priority=priority)
def call_at(self, *args, **kwargs):
return self._do_enter('call_at', *args, **kwargs)
def enter_after(self, *args, **kwargs):
return self._do_enter('enter_after', *args, **kwargs)
def call_after(self, *args, **kwargs):
return self._do_enter('call_after', *args, **kwargs)
def call_repeatedly(self, *args, **kwargs):
return self._do_enter('call_repeatedly', *args, **kwargs)
    def exit_after(self, secs, priority=10):
        # pass ``priority`` by keyword: positionally it would be taken as
        # the schedule entry's ``args`` and break the sys.exit call.
        self.call_after(secs, sys.exit, priority=priority)
def cancel(self, tref):
tref.cancel()
def clear(self):
self.schedule.clear()
def empty(self):
return not len(self)
def __len__(self):
return len(self.schedule)
def __bool__(self):
"""``bool(timer)``."""
return True
__nonzero__ = __bool__
@property
def queue(self):
return self.schedule.queue
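# Illustrative usage (a sketch; ``some_callback`` is a hypothetical function,
# and the timer auto-starts on first scheduling via ensure_started()):
#   timer = Timer()
#   tref = timer.call_repeatedly(1.0, some_callback)
#   ...
#   tref.cancel()
#   timer.stop()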