update
@@ -0,0 +1,222 @@
"""SQLAlchemy result store backend."""
import logging
from contextlib import contextmanager

from vine.utils import wraps

from celery import states
from celery.backends.base import BaseBackend
from celery.exceptions import ImproperlyConfigured
from celery.utils.time import maybe_timedelta

from .models import Task, TaskExtended, TaskSet
from .session import SessionManager

try:
    from sqlalchemy.exc import DatabaseError, InvalidRequestError
    from sqlalchemy.orm.exc import StaleDataError
except ImportError:
    raise ImproperlyConfigured(
        'The database result backend requires SQLAlchemy to be installed. '
        'See https://pypi.org/project/SQLAlchemy/')

logger = logging.getLogger(__name__)

__all__ = ('DatabaseBackend',)


@contextmanager
def session_cleanup(session):
    try:
        yield
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


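# Illustrative sketch (comment only, not part of the backend): wrapping
# session work in session_cleanup() guarantees rollback on error and
# close on exit.  The session construction shown here is one possible way
# to obtain a session:
#
#     session = SessionManager().session_factory(dburi='sqlite://')
#     with session_cleanup(session):
#         ...  # query and commit; rollback/close happen automatically

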
def retry(fun):
    """Retry a database operation when transient SQLAlchemy errors occur."""

    @wraps(fun)
    def _inner(*args, **kwargs):
        max_retries = kwargs.pop('max_retries', 3)

        for retries in range(max_retries):
            try:
                return fun(*args, **kwargs)
            except (DatabaseError, InvalidRequestError, StaleDataError):
                logger.warning(
                    'Failed operation %s. Retrying %s more times.',
                    fun.__name__, max_retries - retries - 1,
                    exc_info=True)
                if retries + 1 >= max_retries:
                    raise

    return _inner


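# Illustrative sketch (comment only): any function wrapped with @retry is
# re-run on transient SQLAlchemy errors; callers may tune the limit via a
# `max_retries` keyword, which _inner pops before calling the function:
#
#     @retry
#     def fetch_meta(task_id):
#         ...  # may raise DatabaseError on a flaky connection
#
#     fetch_meta('some-task-id', max_retries=5)

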
class DatabaseBackend(BaseBackend):
    """The database result backend."""

    # ResultSet.iterate should sleep this much between each poll,
    # to not bombard the database with queries.
    subpolling_interval = 0.5

    task_cls = Task
    taskset_cls = TaskSet

    def __init__(self, dburi=None, engine_options=None, url=None, **kwargs):
        # The `url` argument was added later and is used by
        # the app to set backend by url (celery.app.backends.by_url).
        super().__init__(expires_type=maybe_timedelta,
                         url=url, **kwargs)
        conf = self.app.conf

        if self.extended_result:
            self.task_cls = TaskExtended

        self.url = url or dburi or conf.database_url
        self.engine_options = dict(
            engine_options or {},
            **conf.database_engine_options or {})
        self.short_lived_sessions = kwargs.get(
            'short_lived_sessions',
            conf.database_short_lived_sessions)

        schemas = conf.database_table_schemas or {}
        tablenames = conf.database_table_names or {}
        self.task_cls.configure(
            schema=schemas.get('task'),
            name=tablenames.get('task'))
        self.taskset_cls.configure(
            schema=schemas.get('group'),
            name=tablenames.get('group'))

        if not self.url:
            raise ImproperlyConfigured(
                'Missing connection string! Do you have the'
                ' database_url setting set to a real value?')

    @property
    def extended_result(self):
        return self.app.conf.find_value_for_key('extended', 'result')

    def ResultSession(self, session_manager=SessionManager()):
        return session_manager.session_factory(
            dburi=self.url,
            short_lived_sessions=self.short_lived_sessions,
            **self.engine_options)

    @retry
    def _store_result(self, task_id, result, state, traceback=None,
                      request=None, **kwargs):
        """Store return value and state of an executed task."""
        session = self.ResultSession()
        with session_cleanup(session):
            task = list(session.query(self.task_cls)
                        .filter(self.task_cls.task_id == task_id))
            task = task and task[0]
            if not task:
                task = self.task_cls(task_id)
                task.task_id = task_id
                session.add(task)
                session.flush()

            self._update_result(task, result, state, traceback=traceback,
                                request=request)
            session.commit()

    def _update_result(self, task, result, state, traceback=None,
                       request=None):

        meta = self._get_result_meta(result=result, state=state,
                                     traceback=traceback, request=request,
                                     format_date=False, encode=True)

        # Exclude the primary key id and task_id columns,
        # as we should not set them to None.
        columns = [column.name for column in self.task_cls.__table__.columns
                   if column.name not in {'id', 'task_id'}]

        # Iterate through the column names of the table
        # and set each value from meta.
        # If a value is not present in meta, set it to None.
        for column in columns:
            value = meta.get(column)
            setattr(task, column, value)

    @retry
    def _get_task_meta_for(self, task_id):
        """Get task meta-data for a task by id."""
        session = self.ResultSession()
        with session_cleanup(session):
            task = list(session.query(self.task_cls)
                        .filter(self.task_cls.task_id == task_id))
            task = task and task[0]
            if not task:
                task = self.task_cls(task_id)
                task.status = states.PENDING
                task.result = None
            data = task.to_dict()
            if data.get('args', None) is not None:
                data['args'] = self.decode(data['args'])
            if data.get('kwargs', None) is not None:
                data['kwargs'] = self.decode(data['kwargs'])
            return self.meta_from_decoded(data)

    @retry
    def _save_group(self, group_id, result):
        """Store the result of an executed group."""
        session = self.ResultSession()
        with session_cleanup(session):
            group = self.taskset_cls(group_id, result)
            session.add(group)
            session.flush()
            session.commit()
            return result

    @retry
    def _restore_group(self, group_id):
        """Get meta-data for group by id."""
        session = self.ResultSession()
        with session_cleanup(session):
            group = session.query(self.taskset_cls).filter(
                self.taskset_cls.taskset_id == group_id).first()
            if group:
                return group.to_dict()

    @retry
    def _delete_group(self, group_id):
        """Delete meta-data for group by id."""
        session = self.ResultSession()
        with session_cleanup(session):
            session.query(self.taskset_cls).filter(
                self.taskset_cls.taskset_id == group_id).delete()
            session.flush()
            session.commit()

    @retry
    def _forget(self, task_id):
        """Forget about result."""
        session = self.ResultSession()
        with session_cleanup(session):
            session.query(self.task_cls).filter(
                self.task_cls.task_id == task_id).delete()
            session.commit()

    def cleanup(self):
        """Delete expired meta-data."""
        session = self.ResultSession()
        expires = self.expires
        now = self.app.now()
        with session_cleanup(session):
            session.query(self.task_cls).filter(
                self.task_cls.date_done < (now - expires)).delete()
            session.query(self.taskset_cls).filter(
                self.taskset_cls.date_done < (now - expires)).delete()
            session.commit()

    def __reduce__(self, args=(), kwargs=None):
        kwargs = {} if not kwargs else kwargs
        kwargs.update(
            {'dburi': self.url,
             'expires': self.expires,
             'engine_options': self.engine_options})
        return super().__reduce__(args, kwargs)

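A minimal usage sketch (illustrative; the SQLite URL is an arbitrary example): the backend above is selected by giving result_backend a SQLAlchemy URL with the 'db+' prefix, and the database_* settings feed the constructor arguments shown in __init__.

from celery import Celery

# 'db+' selects DatabaseBackend; the remainder is a plain SQLAlchemy URL.
app = Celery('tasks', broker='memory://',
             backend='db+sqlite:///results.sqlite')
app.conf.database_short_lived_sessions = True  # optional

@app.task
def add(x, y):
    return x + y
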
@@ -0,0 +1,108 @@
"""Database models used by the SQLAlchemy result store backend."""
from datetime import datetime

import sqlalchemy as sa
from sqlalchemy.types import PickleType

from celery import states

from .session import ResultModelBase

__all__ = ('Task', 'TaskExtended', 'TaskSet')


class Task(ResultModelBase):
    """Task result/status."""

    __tablename__ = 'celery_taskmeta'
    __table_args__ = {'sqlite_autoincrement': True}

    id = sa.Column(sa.Integer, sa.Sequence('task_id_sequence'),
                   primary_key=True, autoincrement=True)
    task_id = sa.Column(sa.String(155), unique=True)
    status = sa.Column(sa.String(50), default=states.PENDING)
    result = sa.Column(PickleType, nullable=True)
    date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
                          onupdate=datetime.utcnow, nullable=True)
    traceback = sa.Column(sa.Text, nullable=True)

    def __init__(self, task_id):
        self.task_id = task_id

    def to_dict(self):
        return {
            'task_id': self.task_id,
            'status': self.status,
            'result': self.result,
            'traceback': self.traceback,
            'date_done': self.date_done,
        }

    def __repr__(self):
        return '<Task {0.task_id} state: {0.status}>'.format(self)

    @classmethod
    def configure(cls, schema=None, name=None):
        cls.__table__.schema = schema
        cls.id.default.schema = schema
        cls.__table__.name = name or cls.__tablename__


class TaskExtended(Task):
    """For the extended result."""

    __tablename__ = 'celery_taskmeta'
    __table_args__ = {'sqlite_autoincrement': True, 'extend_existing': True}

    name = sa.Column(sa.String(155), nullable=True)
    args = sa.Column(sa.LargeBinary, nullable=True)
    kwargs = sa.Column(sa.LargeBinary, nullable=True)
    worker = sa.Column(sa.String(155), nullable=True)
    retries = sa.Column(sa.Integer, nullable=True)
    queue = sa.Column(sa.String(155), nullable=True)

    def to_dict(self):
        task_dict = super().to_dict()
        task_dict.update({
            'name': self.name,
            'args': self.args,
            'kwargs': self.kwargs,
            'worker': self.worker,
            'retries': self.retries,
            'queue': self.queue,
        })
        return task_dict


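# Illustrative note: the extra columns this class adds are only filled in
# when the application enables extended results, e.g.
#
#     app.conf.result_extended = True
#
# in which case DatabaseBackend swaps task_cls from Task to TaskExtended
# (see its __init__ and extended_result property in the first file above).

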
class TaskSet(ResultModelBase):
    """TaskSet result."""

    __tablename__ = 'celery_tasksetmeta'
    __table_args__ = {'sqlite_autoincrement': True}

    id = sa.Column(sa.Integer, sa.Sequence('taskset_id_sequence'),
                   autoincrement=True, primary_key=True)
    taskset_id = sa.Column(sa.String(155), unique=True)
    result = sa.Column(PickleType, nullable=True)
    date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
                          nullable=True)

    def __init__(self, taskset_id, result):
        self.taskset_id = taskset_id
        self.result = result

    def to_dict(self):
        return {
            'taskset_id': self.taskset_id,
            'result': self.result,
            'date_done': self.date_done,
        }

    def __repr__(self):
        return f'<TaskSet: {self.taskset_id}>'

    @classmethod
    def configure(cls, schema=None, name=None):
        cls.__table__.schema = schema
        cls.id.default.schema = schema
        cls.__table__.name = name or cls.__tablename__

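A short sketch of the configure() hook (the schema and table names here are hypothetical): it rebinds the table's schema and name before the metadata is created, which is how the database_table_schemas and database_table_names settings take effect in DatabaseBackend.__init__.

from celery.backends.database.models import Task, TaskSet

Task.configure(schema='celery', name='my_taskmeta')
TaskSet.configure(schema='celery', name='my_groupmeta')
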
@@ -0,0 +1,89 @@
"""SQLAlchemy session."""
import time

from kombu.utils.compat import register_after_fork
from sqlalchemy import create_engine
from sqlalchemy.exc import DatabaseError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

from celery.utils.time import get_exponential_backoff_interval

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    # TODO: Remove this once we drop support for SQLAlchemy < 1.4.
    from sqlalchemy.ext.declarative import declarative_base

ResultModelBase = declarative_base()

__all__ = ('SessionManager',)

PREPARE_MODELS_MAX_RETRIES = 10


def _after_fork_cleanup_session(session):
    session._after_fork()


class SessionManager:
    """Manage SQLAlchemy sessions."""

    def __init__(self):
        self._engines = {}
        self._sessions = {}
        self.forked = False
        self.prepared = False
        if register_after_fork is not None:
            register_after_fork(self, _after_fork_cleanup_session)

    def _after_fork(self):
        self.forked = True

    def get_engine(self, dburi, **kwargs):
        if self.forked:
            # After the worker has forked it is safe to cache one
            # engine per URI.
            try:
                return self._engines[dburi]
            except KeyError:
                engine = self._engines[dburi] = create_engine(dburi, **kwargs)
                return engine
        else:
            # Before forking, avoid a shared connection pool: drop any
            # pool options and use NullPool so no connections are carried
            # across the fork.
            kwargs = {k: v for k, v in kwargs.items() if
                      not k.startswith('pool')}
            return create_engine(dburi, poolclass=NullPool, **kwargs)

    def create_session(self, dburi, short_lived_sessions=False, **kwargs):
        engine = self.get_engine(dburi, **kwargs)
        if self.forked:
            if short_lived_sessions or dburi not in self._sessions:
                self._sessions[dburi] = sessionmaker(bind=engine)
            return engine, self._sessions[dburi]
        return engine, sessionmaker(bind=engine)

    def prepare_models(self, engine):
        if not self.prepared:
            # SQLAlchemy will check if the items exist before trying to
            # create them, which is a race condition. If it raises an error
            # in one iteration, the next may pass all the existence checks
            # and the call will succeed.
            retries = 0
            while True:
                try:
                    ResultModelBase.metadata.create_all(engine)
                except DatabaseError:
                    if retries < PREPARE_MODELS_MAX_RETRIES:
                        sleep_amount_ms = get_exponential_backoff_interval(
                            10, retries, 1000, True)
                        time.sleep(sleep_amount_ms / 1000)
                        retries += 1
                    else:
                        raise
                else:
                    break
            self.prepared = True

    def session_factory(self, dburi, **kwargs):
        engine, session = self.create_session(dburi, **kwargs)
        self.prepare_models(engine)
        return session()

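A minimal sketch of the lifecycle this module provides (the SQLite URL is an arbitrary example): session_factory() builds or reuses an engine, creates the result tables once per manager, and hands back a ready session.

from celery.backends.database.session import SessionManager

manager = SessionManager()
session = manager.session_factory(dburi='sqlite:///results.sqlite')
try:
    # celery_taskmeta and celery_tasksetmeta now exist in the database.
    print(session.get_bind())
finally:
    session.close()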