This commit is contained in:
Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,222 @@
"""SQLAlchemy result store backend."""
import logging
from contextlib import contextmanager
from vine.utils import wraps
from celery import states
from celery.backends.base import BaseBackend
from celery.exceptions import ImproperlyConfigured
from celery.utils.time import maybe_timedelta
from .models import Task, TaskExtended, TaskSet
from .session import SessionManager
try:
from sqlalchemy.exc import DatabaseError, InvalidRequestError
from sqlalchemy.orm.exc import StaleDataError
except ImportError:
raise ImproperlyConfigured(
'The database result backend requires SQLAlchemy to be installed.'
'See https://pypi.org/project/SQLAlchemy/')
logger = logging.getLogger(__name__)
__all__ = ('DatabaseBackend',)
@contextmanager
def session_cleanup(session):
try:
yield
except Exception:
session.rollback()
raise
finally:
session.close()
def retry(fun):
@wraps(fun)
def _inner(*args, **kwargs):
max_retries = kwargs.pop('max_retries', 3)
for retries in range(max_retries):
try:
return fun(*args, **kwargs)
except (DatabaseError, InvalidRequestError, StaleDataError):
logger.warning(
'Failed operation %s. Retrying %s more times.',
fun.__name__, max_retries - retries - 1,
exc_info=True)
if retries + 1 >= max_retries:
raise
return _inner
class DatabaseBackend(BaseBackend):
"""The database result backend."""
# ResultSet.iterate should sleep this much between each pool,
# to not bombard the database with queries.
subpolling_interval = 0.5
task_cls = Task
taskset_cls = TaskSet
def __init__(self, dburi=None, engine_options=None, url=None, **kwargs):
# The `url` argument was added later and is used by
# the app to set backend by url (celery.app.backends.by_url)
super().__init__(expires_type=maybe_timedelta,
url=url, **kwargs)
conf = self.app.conf
if self.extended_result:
self.task_cls = TaskExtended
self.url = url or dburi or conf.database_url
self.engine_options = dict(
engine_options or {},
**conf.database_engine_options or {})
self.short_lived_sessions = kwargs.get(
'short_lived_sessions',
conf.database_short_lived_sessions)
schemas = conf.database_table_schemas or {}
tablenames = conf.database_table_names or {}
self.task_cls.configure(
schema=schemas.get('task'),
name=tablenames.get('task'))
self.taskset_cls.configure(
schema=schemas.get('group'),
name=tablenames.get('group'))
if not self.url:
raise ImproperlyConfigured(
'Missing connection string! Do you have the'
' database_url setting set to a real value?')
@property
def extended_result(self):
return self.app.conf.find_value_for_key('extended', 'result')
def ResultSession(self, session_manager=SessionManager()):
return session_manager.session_factory(
dburi=self.url,
short_lived_sessions=self.short_lived_sessions,
**self.engine_options)
@retry
def _store_result(self, task_id, result, state, traceback=None,
request=None, **kwargs):
"""Store return value and state of an executed task."""
session = self.ResultSession()
with session_cleanup(session):
task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id))
task = task and task[0]
if not task:
task = self.task_cls(task_id)
task.task_id = task_id
session.add(task)
session.flush()
self._update_result(task, result, state, traceback=traceback, request=request)
session.commit()
def _update_result(self, task, result, state, traceback=None,
request=None):
meta = self._get_result_meta(result=result, state=state,
traceback=traceback, request=request,
format_date=False, encode=True)
# Exclude the primary key id and task_id columns
# as we should not set it None
columns = [column.name for column in self.task_cls.__table__.columns
if column.name not in {'id', 'task_id'}]
# Iterate through the columns name of the table
# to set the value from meta.
# If the value is not present in meta, set None
for column in columns:
value = meta.get(column)
setattr(task, column, value)
@retry
def _get_task_meta_for(self, task_id):
"""Get task meta-data for a task by id."""
session = self.ResultSession()
with session_cleanup(session):
task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id))
task = task and task[0]
if not task:
task = self.task_cls(task_id)
task.status = states.PENDING
task.result = None
data = task.to_dict()
if data.get('args', None) is not None:
data['args'] = self.decode(data['args'])
if data.get('kwargs', None) is not None:
data['kwargs'] = self.decode(data['kwargs'])
return self.meta_from_decoded(data)
@retry
def _save_group(self, group_id, result):
"""Store the result of an executed group."""
session = self.ResultSession()
with session_cleanup(session):
group = self.taskset_cls(group_id, result)
session.add(group)
session.flush()
session.commit()
return result
@retry
def _restore_group(self, group_id):
"""Get meta-data for group by id."""
session = self.ResultSession()
with session_cleanup(session):
group = session.query(self.taskset_cls).filter(
self.taskset_cls.taskset_id == group_id).first()
if group:
return group.to_dict()
@retry
def _delete_group(self, group_id):
"""Delete meta-data for group by id."""
session = self.ResultSession()
with session_cleanup(session):
session.query(self.taskset_cls).filter(
self.taskset_cls.taskset_id == group_id).delete()
session.flush()
session.commit()
@retry
def _forget(self, task_id):
"""Forget about result."""
session = self.ResultSession()
with session_cleanup(session):
session.query(self.task_cls).filter(self.task_cls.task_id == task_id).delete()
session.commit()
def cleanup(self):
"""Delete expired meta-data."""
session = self.ResultSession()
expires = self.expires
now = self.app.now()
with session_cleanup(session):
session.query(self.task_cls).filter(
self.task_cls.date_done < (now - expires)).delete()
session.query(self.taskset_cls).filter(
self.taskset_cls.date_done < (now - expires)).delete()
session.commit()
def __reduce__(self, args=(), kwargs=None):
kwargs = {} if not kwargs else kwargs
kwargs.update(
{'dburi': self.url,
'expires': self.expires,
'engine_options': self.engine_options})
return super().__reduce__(args, kwargs)

View File

@@ -0,0 +1,108 @@
"""Database models used by the SQLAlchemy result store backend."""
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy.types import PickleType
from celery import states
from .session import ResultModelBase
__all__ = ('Task', 'TaskExtended', 'TaskSet')
class Task(ResultModelBase):
"""Task result/status."""
__tablename__ = 'celery_taskmeta'
__table_args__ = {'sqlite_autoincrement': True}
id = sa.Column(sa.Integer, sa.Sequence('task_id_sequence'),
primary_key=True, autoincrement=True)
task_id = sa.Column(sa.String(155), unique=True)
status = sa.Column(sa.String(50), default=states.PENDING)
result = sa.Column(PickleType, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
onupdate=datetime.utcnow, nullable=True)
traceback = sa.Column(sa.Text, nullable=True)
def __init__(self, task_id):
self.task_id = task_id
def to_dict(self):
return {
'task_id': self.task_id,
'status': self.status,
'result': self.result,
'traceback': self.traceback,
'date_done': self.date_done,
}
def __repr__(self):
return '<Task {0.task_id} state: {0.status}>'.format(self)
@classmethod
def configure(cls, schema=None, name=None):
cls.__table__.schema = schema
cls.id.default.schema = schema
cls.__table__.name = name or cls.__tablename__
class TaskExtended(Task):
"""For the extend result."""
__tablename__ = 'celery_taskmeta'
__table_args__ = {'sqlite_autoincrement': True, 'extend_existing': True}
name = sa.Column(sa.String(155), nullable=True)
args = sa.Column(sa.LargeBinary, nullable=True)
kwargs = sa.Column(sa.LargeBinary, nullable=True)
worker = sa.Column(sa.String(155), nullable=True)
retries = sa.Column(sa.Integer, nullable=True)
queue = sa.Column(sa.String(155), nullable=True)
def to_dict(self):
task_dict = super().to_dict()
task_dict.update({
'name': self.name,
'args': self.args,
'kwargs': self.kwargs,
'worker': self.worker,
'retries': self.retries,
'queue': self.queue,
})
return task_dict
class TaskSet(ResultModelBase):
"""TaskSet result."""
__tablename__ = 'celery_tasksetmeta'
__table_args__ = {'sqlite_autoincrement': True}
id = sa.Column(sa.Integer, sa.Sequence('taskset_id_sequence'),
autoincrement=True, primary_key=True)
taskset_id = sa.Column(sa.String(155), unique=True)
result = sa.Column(PickleType, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
nullable=True)
def __init__(self, taskset_id, result):
self.taskset_id = taskset_id
self.result = result
def to_dict(self):
return {
'taskset_id': self.taskset_id,
'result': self.result,
'date_done': self.date_done,
}
def __repr__(self):
return f'<TaskSet: {self.taskset_id}>'
@classmethod
def configure(cls, schema=None, name=None):
cls.__table__.schema = schema
cls.id.default.schema = schema
cls.__table__.name = name or cls.__tablename__

View File

@@ -0,0 +1,89 @@
"""SQLAlchemy session."""
import time
from kombu.utils.compat import register_after_fork
from sqlalchemy import create_engine
from sqlalchemy.exc import DatabaseError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from celery.utils.time import get_exponential_backoff_interval
try:
from sqlalchemy.orm import declarative_base
except ImportError:
# TODO: Remove this once we drop support for SQLAlchemy < 1.4.
from sqlalchemy.ext.declarative import declarative_base
ResultModelBase = declarative_base()
__all__ = ('SessionManager',)
PREPARE_MODELS_MAX_RETRIES = 10
def _after_fork_cleanup_session(session):
session._after_fork()
class SessionManager:
"""Manage SQLAlchemy sessions."""
def __init__(self):
self._engines = {}
self._sessions = {}
self.forked = False
self.prepared = False
if register_after_fork is not None:
register_after_fork(self, _after_fork_cleanup_session)
def _after_fork(self):
self.forked = True
def get_engine(self, dburi, **kwargs):
if self.forked:
try:
return self._engines[dburi]
except KeyError:
engine = self._engines[dburi] = create_engine(dburi, **kwargs)
return engine
else:
kwargs = {k: v for k, v in kwargs.items() if
not k.startswith('pool')}
return create_engine(dburi, poolclass=NullPool, **kwargs)
def create_session(self, dburi, short_lived_sessions=False, **kwargs):
engine = self.get_engine(dburi, **kwargs)
if self.forked:
if short_lived_sessions or dburi not in self._sessions:
self._sessions[dburi] = sessionmaker(bind=engine)
return engine, self._sessions[dburi]
return engine, sessionmaker(bind=engine)
def prepare_models(self, engine):
if not self.prepared:
# SQLAlchemy will check if the items exist before trying to
# create them, which is a race condition. If it raises an error
# in one iteration, the next may pass all the existence checks
# and the call will succeed.
retries = 0
while True:
try:
ResultModelBase.metadata.create_all(engine)
except DatabaseError:
if retries < PREPARE_MODELS_MAX_RETRIES:
sleep_amount_ms = get_exponential_backoff_interval(
10, retries, 1000, True
)
time.sleep(sleep_amount_ms / 1000)
retries += 1
else:
raise
else:
break
self.prepared = True
def session_factory(self, dburi, **kwargs):
engine, session = self.create_session(dburi, **kwargs)
self.prepare_models(engine)
return session()