Updates
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from .migration import Migration, swappable_dependency # NOQA
|
||||
from .operations import * # NOQA
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,60 @@
|
||||
from django.db import DatabaseError
|
||||
|
||||
|
||||
class AmbiguityError(Exception):
|
||||
"""More than one migration matches a name prefix."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BadMigrationError(Exception):
|
||||
"""There's a bad migration (unreadable/bad format/etc.)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CircularDependencyError(Exception):
|
||||
"""There's an impossible-to-resolve circular dependency."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InconsistentMigrationHistory(Exception):
|
||||
"""An applied migration has some of its dependencies not applied."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidBasesError(ValueError):
|
||||
"""A model's base classes can't be resolved."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IrreversibleError(RuntimeError):
|
||||
"""An irreversible migration is about to be reversed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NodeNotFoundError(LookupError):
|
||||
"""An attempt on a node is made that is not available in the graph."""
|
||||
|
||||
def __init__(self, message, node, origin=None):
|
||||
self.message = message
|
||||
self.origin = origin
|
||||
self.node = node
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
def __repr__(self):
|
||||
return "NodeNotFoundError(%r)" % (self.node,)
|
||||
|
||||
|
||||
class MigrationSchemaMissing(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidMigrationPlan(ValueError):
|
||||
pass
|
||||
@@ -0,0 +1,413 @@
|
||||
from django.apps.registry import apps as global_apps
|
||||
from django.db import migrations, router
|
||||
|
||||
from .exceptions import InvalidMigrationPlan
|
||||
from .loader import MigrationLoader
|
||||
from .recorder import MigrationRecorder
|
||||
from .state import ProjectState
|
||||
|
||||
|
||||
class MigrationExecutor:
|
||||
"""
|
||||
End-to-end migration execution - load migrations and run them up or down
|
||||
to a specified set of targets.
|
||||
"""
|
||||
|
||||
def __init__(self, connection, progress_callback=None):
|
||||
self.connection = connection
|
||||
self.loader = MigrationLoader(self.connection)
|
||||
self.recorder = MigrationRecorder(self.connection)
|
||||
self.progress_callback = progress_callback
|
||||
|
||||
def migration_plan(self, targets, clean_start=False):
|
||||
"""
|
||||
Given a set of targets, return a list of (Migration instance, backwards?).
|
||||
"""
|
||||
plan = []
|
||||
if clean_start:
|
||||
applied = {}
|
||||
else:
|
||||
applied = dict(self.loader.applied_migrations)
|
||||
for target in targets:
|
||||
# If the target is (app_label, None), that means unmigrate everything
|
||||
if target[1] is None:
|
||||
for root in self.loader.graph.root_nodes():
|
||||
if root[0] == target[0]:
|
||||
for migration in self.loader.graph.backwards_plan(root):
|
||||
if migration in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], True))
|
||||
applied.pop(migration)
|
||||
# If the migration is already applied, do backwards mode,
|
||||
# otherwise do forwards mode.
|
||||
elif target in applied:
|
||||
# If the target is missing, it's likely a replaced migration.
|
||||
# Reload the graph without replacements.
|
||||
if (
|
||||
self.loader.replace_migrations
|
||||
and target not in self.loader.graph.node_map
|
||||
):
|
||||
self.loader.replace_migrations = False
|
||||
self.loader.build_graph()
|
||||
return self.migration_plan(targets, clean_start=clean_start)
|
||||
# Don't migrate backwards all the way to the target node (that
|
||||
# may roll back dependencies in other apps that don't need to
|
||||
# be rolled back); instead roll back through target's immediate
|
||||
# child(ren) in the same app, and no further.
|
||||
next_in_app = sorted(
|
||||
n
|
||||
for n in self.loader.graph.node_map[target].children
|
||||
if n[0] == target[0]
|
||||
)
|
||||
for node in next_in_app:
|
||||
for migration in self.loader.graph.backwards_plan(node):
|
||||
if migration in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], True))
|
||||
applied.pop(migration)
|
||||
else:
|
||||
for migration in self.loader.graph.forwards_plan(target):
|
||||
if migration not in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], False))
|
||||
applied[migration] = self.loader.graph.nodes[migration]
|
||||
return plan
|
||||
|
||||
def _create_project_state(self, with_applied_migrations=False):
|
||||
"""
|
||||
Create a project state including all the applications without
|
||||
migrations and applied migrations if with_applied_migrations=True.
|
||||
"""
|
||||
state = ProjectState(real_apps=self.loader.unmigrated_apps)
|
||||
if with_applied_migrations:
|
||||
# Create the forwards plan Django would follow on an empty database
|
||||
full_plan = self.migration_plan(
|
||||
self.loader.graph.leaf_nodes(), clean_start=True
|
||||
)
|
||||
applied_migrations = {
|
||||
self.loader.graph.nodes[key]
|
||||
for key in self.loader.applied_migrations
|
||||
if key in self.loader.graph.nodes
|
||||
}
|
||||
for migration, _ in full_plan:
|
||||
if migration in applied_migrations:
|
||||
migration.mutate_state(state, preserve=False)
|
||||
return state
|
||||
|
||||
def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
|
||||
"""
|
||||
Migrate the database up to the given targets.
|
||||
|
||||
Django first needs to create all project states before a migration is
|
||||
(un)applied and in a second step run all the database operations.
|
||||
"""
|
||||
# The django_migrations table must be present to record applied
|
||||
# migrations, but don't create it if there are no migrations to apply.
|
||||
if plan == []:
|
||||
if not self.recorder.has_table():
|
||||
return self._create_project_state(with_applied_migrations=False)
|
||||
else:
|
||||
self.recorder.ensure_schema()
|
||||
|
||||
if plan is None:
|
||||
plan = self.migration_plan(targets)
|
||||
# Create the forwards plan Django would follow on an empty database
|
||||
full_plan = self.migration_plan(
|
||||
self.loader.graph.leaf_nodes(), clean_start=True
|
||||
)
|
||||
|
||||
all_forwards = all(not backwards for mig, backwards in plan)
|
||||
all_backwards = all(backwards for mig, backwards in plan)
|
||||
|
||||
if not plan:
|
||||
if state is None:
|
||||
# The resulting state should include applied migrations.
|
||||
state = self._create_project_state(with_applied_migrations=True)
|
||||
elif all_forwards == all_backwards:
|
||||
# This should only happen if there's a mixed plan
|
||||
raise InvalidMigrationPlan(
|
||||
"Migration plans with both forwards and backwards migrations "
|
||||
"are not supported. Please split your migration process into "
|
||||
"separate plans of only forwards OR backwards migrations.",
|
||||
plan,
|
||||
)
|
||||
elif all_forwards:
|
||||
if state is None:
|
||||
# The resulting state should still include applied migrations.
|
||||
state = self._create_project_state(with_applied_migrations=True)
|
||||
state = self._migrate_all_forwards(
|
||||
state, plan, full_plan, fake=fake, fake_initial=fake_initial
|
||||
)
|
||||
else:
|
||||
# No need to check for `elif all_backwards` here, as that condition
|
||||
# would always evaluate to true.
|
||||
state = self._migrate_all_backwards(plan, full_plan, fake=fake)
|
||||
|
||||
self.check_replacements()
|
||||
|
||||
return state
|
||||
|
||||
def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
|
||||
"""
|
||||
Take a list of 2-tuples of the form (migration instance, False) and
|
||||
apply them in the order they occur in the full_plan.
|
||||
"""
|
||||
migrations_to_run = {m[0] for m in plan}
|
||||
for migration, _ in full_plan:
|
||||
if not migrations_to_run:
|
||||
# We remove every migration that we applied from these sets so
|
||||
# that we can bail out once the last migration has been applied
|
||||
# and don't always run until the very end of the migration
|
||||
# process.
|
||||
break
|
||||
if migration in migrations_to_run:
|
||||
if "apps" not in state.__dict__:
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_start")
|
||||
state.apps # Render all -- performance critical
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_success")
|
||||
state = self.apply_migration(
|
||||
state, migration, fake=fake, fake_initial=fake_initial
|
||||
)
|
||||
migrations_to_run.remove(migration)
|
||||
|
||||
return state
|
||||
|
||||
def _migrate_all_backwards(self, plan, full_plan, fake):
|
||||
"""
|
||||
Take a list of 2-tuples of the form (migration instance, True) and
|
||||
unapply them in reverse order they occur in the full_plan.
|
||||
|
||||
Since unapplying a migration requires the project state prior to that
|
||||
migration, Django will compute the migration states before each of them
|
||||
in a first run over the plan and then unapply them in a second run over
|
||||
the plan.
|
||||
"""
|
||||
migrations_to_run = {m[0] for m in plan}
|
||||
# Holds all migration states prior to the migrations being unapplied
|
||||
states = {}
|
||||
state = self._create_project_state()
|
||||
applied_migrations = {
|
||||
self.loader.graph.nodes[key]
|
||||
for key in self.loader.applied_migrations
|
||||
if key in self.loader.graph.nodes
|
||||
}
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_start")
|
||||
for migration, _ in full_plan:
|
||||
if not migrations_to_run:
|
||||
# We remove every migration that we applied from this set so
|
||||
# that we can bail out once the last migration has been applied
|
||||
# and don't always run until the very end of the migration
|
||||
# process.
|
||||
break
|
||||
if migration in migrations_to_run:
|
||||
if "apps" not in state.__dict__:
|
||||
state.apps # Render all -- performance critical
|
||||
# The state before this migration
|
||||
states[migration] = state
|
||||
# The old state keeps as-is, we continue with the new state
|
||||
state = migration.mutate_state(state, preserve=True)
|
||||
migrations_to_run.remove(migration)
|
||||
elif migration in applied_migrations:
|
||||
# Only mutate the state if the migration is actually applied
|
||||
# to make sure the resulting state doesn't include changes
|
||||
# from unrelated migrations.
|
||||
migration.mutate_state(state, preserve=False)
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_success")
|
||||
|
||||
for migration, _ in plan:
|
||||
self.unapply_migration(states[migration], migration, fake=fake)
|
||||
applied_migrations.remove(migration)
|
||||
|
||||
# Generate the post migration state by starting from the state before
|
||||
# the last migration is unapplied and mutating it to include all the
|
||||
# remaining applied migrations.
|
||||
last_unapplied_migration = plan[-1][0]
|
||||
state = states[last_unapplied_migration]
|
||||
# Avoid mutating state with apps rendered as it's an expensive
|
||||
# operation.
|
||||
del state.apps
|
||||
for index, (migration, _) in enumerate(full_plan):
|
||||
if migration == last_unapplied_migration:
|
||||
for migration, _ in full_plan[index:]:
|
||||
if migration in applied_migrations:
|
||||
migration.mutate_state(state, preserve=False)
|
||||
break
|
||||
|
||||
return state
|
||||
|
||||
def apply_migration(self, state, migration, fake=False, fake_initial=False):
|
||||
"""Run a migration forwards."""
|
||||
migration_recorded = False
|
||||
if self.progress_callback:
|
||||
self.progress_callback("apply_start", migration, fake)
|
||||
if not fake:
|
||||
if fake_initial:
|
||||
# Test to see if this is an already-applied initial migration
|
||||
applied, state = self.detect_soft_applied(state, migration)
|
||||
if applied:
|
||||
fake = True
|
||||
if not fake:
|
||||
# Alright, do it normally
|
||||
with self.connection.schema_editor(
|
||||
atomic=migration.atomic
|
||||
) as schema_editor:
|
||||
state = migration.apply(state, schema_editor)
|
||||
if not schema_editor.deferred_sql:
|
||||
self.record_migration(migration)
|
||||
migration_recorded = True
|
||||
if not migration_recorded:
|
||||
self.record_migration(migration)
|
||||
# Report progress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("apply_success", migration, fake)
|
||||
return state
|
||||
|
||||
def record_migration(self, migration):
|
||||
# For replacement migrations, record individual statuses
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_applied(app_label, name)
|
||||
else:
|
||||
self.recorder.record_applied(migration.app_label, migration.name)
|
||||
|
||||
def unapply_migration(self, state, migration, fake=False):
|
||||
"""Run a migration backwards."""
|
||||
if self.progress_callback:
|
||||
self.progress_callback("unapply_start", migration, fake)
|
||||
if not fake:
|
||||
with self.connection.schema_editor(
|
||||
atomic=migration.atomic
|
||||
) as schema_editor:
|
||||
state = migration.unapply(state, schema_editor)
|
||||
# For replacement migrations, also record individual statuses.
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_unapplied(app_label, name)
|
||||
self.recorder.record_unapplied(migration.app_label, migration.name)
|
||||
# Report progress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("unapply_success", migration, fake)
|
||||
return state
|
||||
|
||||
def check_replacements(self):
|
||||
"""
|
||||
Mark replacement migrations applied if their replaced set all are.
|
||||
|
||||
Do this unconditionally on every migrate, rather than just when
|
||||
migrations are applied or unapplied, to correctly handle the case
|
||||
when a new squash migration is pushed to a deployment that already had
|
||||
all its replaced migrations applied. In this case no new migration will
|
||||
be applied, but the applied state of the squashed migration must be
|
||||
maintained.
|
||||
"""
|
||||
applied = self.recorder.applied_migrations()
|
||||
for key, migration in self.loader.replacements.items():
|
||||
all_applied = all(m in applied for m in migration.replaces)
|
||||
if all_applied and key not in applied:
|
||||
self.recorder.record_applied(*key)
|
||||
|
||||
def detect_soft_applied(self, project_state, migration):
|
||||
"""
|
||||
Test whether a migration has been implicitly applied - that the
|
||||
tables or columns it would create exist. This is intended only for use
|
||||
on initial migrations (as it only looks for CreateModel and AddField).
|
||||
"""
|
||||
|
||||
def should_skip_detecting_model(migration, model):
|
||||
"""
|
||||
No need to detect tables for proxy models, unmanaged models, or
|
||||
models that can't be migrated on the current database.
|
||||
"""
|
||||
return (
|
||||
model._meta.proxy
|
||||
or not model._meta.managed
|
||||
or not router.allow_migrate(
|
||||
self.connection.alias,
|
||||
migration.app_label,
|
||||
model_name=model._meta.model_name,
|
||||
)
|
||||
)
|
||||
|
||||
if migration.initial is None:
|
||||
# Bail if the migration isn't the first one in its app
|
||||
if any(app == migration.app_label for app, name in migration.dependencies):
|
||||
return False, project_state
|
||||
elif migration.initial is False:
|
||||
# Bail if it's NOT an initial migration
|
||||
return False, project_state
|
||||
|
||||
if project_state is None:
|
||||
after_state = self.loader.project_state(
|
||||
(migration.app_label, migration.name), at_end=True
|
||||
)
|
||||
else:
|
||||
after_state = migration.mutate_state(project_state)
|
||||
apps = after_state.apps
|
||||
found_create_model_migration = False
|
||||
found_add_field_migration = False
|
||||
fold_identifier_case = self.connection.features.ignores_table_name_case
|
||||
with self.connection.cursor() as cursor:
|
||||
existing_table_names = set(
|
||||
self.connection.introspection.table_names(cursor)
|
||||
)
|
||||
if fold_identifier_case:
|
||||
existing_table_names = {
|
||||
name.casefold() for name in existing_table_names
|
||||
}
|
||||
# Make sure all create model and add field operations are done
|
||||
for operation in migration.operations:
|
||||
if isinstance(operation, migrations.CreateModel):
|
||||
model = apps.get_model(migration.app_label, operation.name)
|
||||
if model._meta.swapped:
|
||||
# We have to fetch the model to test with from the
|
||||
# main app cache, as it's not a direct dependency.
|
||||
model = global_apps.get_model(model._meta.swapped)
|
||||
if should_skip_detecting_model(migration, model):
|
||||
continue
|
||||
db_table = model._meta.db_table
|
||||
if fold_identifier_case:
|
||||
db_table = db_table.casefold()
|
||||
if db_table not in existing_table_names:
|
||||
return False, project_state
|
||||
found_create_model_migration = True
|
||||
elif isinstance(operation, migrations.AddField):
|
||||
model = apps.get_model(migration.app_label, operation.model_name)
|
||||
if model._meta.swapped:
|
||||
# We have to fetch the model to test with from the
|
||||
# main app cache, as it's not a direct dependency.
|
||||
model = global_apps.get_model(model._meta.swapped)
|
||||
if should_skip_detecting_model(migration, model):
|
||||
continue
|
||||
|
||||
table = model._meta.db_table
|
||||
field = model._meta.get_field(operation.name)
|
||||
|
||||
# Handle implicit many-to-many tables created by AddField.
|
||||
if field.many_to_many:
|
||||
through_db_table = field.remote_field.through._meta.db_table
|
||||
if fold_identifier_case:
|
||||
through_db_table = through_db_table.casefold()
|
||||
if through_db_table not in existing_table_names:
|
||||
return False, project_state
|
||||
else:
|
||||
found_add_field_migration = True
|
||||
continue
|
||||
with self.connection.cursor() as cursor:
|
||||
columns = self.connection.introspection.get_table_description(
|
||||
cursor, table
|
||||
)
|
||||
for column in columns:
|
||||
field_column = field.column
|
||||
column_name = column.name
|
||||
if fold_identifier_case:
|
||||
column_name = column_name.casefold()
|
||||
field_column = field_column.casefold()
|
||||
if column_name == field_column:
|
||||
found_add_field_migration = True
|
||||
break
|
||||
else:
|
||||
return False, project_state
|
||||
# If we get this far and we found at least one CreateModel or AddField
|
||||
# migration, the migration is considered implicitly applied.
|
||||
return (found_create_model_migration or found_add_field_migration), after_state
|
||||
@@ -0,0 +1,333 @@
|
||||
from functools import total_ordering
|
||||
|
||||
from django.db.migrations.state import ProjectState
|
||||
|
||||
from .exceptions import CircularDependencyError, NodeNotFoundError
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Node:
|
||||
"""
|
||||
A single node in the migration graph. Contains direct links to adjacent
|
||||
nodes in either direction.
|
||||
"""
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.children = set()
|
||||
self.parents = set()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.key == other
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.key < other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.key)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.key[item]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.key)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: (%r, %r)>" % (self.__class__.__name__, self.key[0], self.key[1])
|
||||
|
||||
def add_child(self, child):
|
||||
self.children.add(child)
|
||||
|
||||
def add_parent(self, parent):
|
||||
self.parents.add(parent)
|
||||
|
||||
|
||||
class DummyNode(Node):
|
||||
"""
|
||||
A node that doesn't correspond to a migration file on disk.
|
||||
(A squashed migration that was removed, for example.)
|
||||
|
||||
After the migration graph is processed, all dummy nodes should be removed.
|
||||
If there are any left, a nonexistent dependency error is raised.
|
||||
"""
|
||||
|
||||
def __init__(self, key, origin, error_message):
|
||||
super().__init__(key)
|
||||
self.origin = origin
|
||||
self.error_message = error_message
|
||||
|
||||
def raise_error(self):
|
||||
raise NodeNotFoundError(self.error_message, self.key, origin=self.origin)
|
||||
|
||||
|
||||
class MigrationGraph:
|
||||
"""
|
||||
Represent the digraph of all migrations in a project.
|
||||
|
||||
Each migration is a node, and each dependency is an edge. There are
|
||||
no implicit dependencies between numbered migrations - the numbering is
|
||||
merely a convention to aid file listing. Every new numbered migration
|
||||
has a declared dependency to the previous number, meaning that VCS
|
||||
branch merges can be detected and resolved.
|
||||
|
||||
Migrations files can be marked as replacing another set of migrations -
|
||||
this is to support the "squash" feature. The graph handler isn't responsible
|
||||
for these; instead, the code to load them in here should examine the
|
||||
migration files and if the replaced migrations are all either unapplied
|
||||
or not present, it should ignore the replaced ones, load in just the
|
||||
replacing migration, and repoint any dependencies that pointed to the
|
||||
replaced migrations to point to the replacing one.
|
||||
|
||||
A node should be a tuple: (app_path, migration_name). The tree special-cases
|
||||
things within an app - namely, root nodes and leaf nodes ignore dependencies
|
||||
to other apps.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.node_map = {}
|
||||
self.nodes = {}
|
||||
|
||||
def add_node(self, key, migration):
|
||||
assert key not in self.node_map
|
||||
node = Node(key)
|
||||
self.node_map[key] = node
|
||||
self.nodes[key] = migration
|
||||
|
||||
def add_dummy_node(self, key, origin, error_message):
|
||||
node = DummyNode(key, origin, error_message)
|
||||
self.node_map[key] = node
|
||||
self.nodes[key] = None
|
||||
|
||||
def add_dependency(self, migration, child, parent, skip_validation=False):
|
||||
"""
|
||||
This may create dummy nodes if they don't yet exist. If
|
||||
`skip_validation=True`, validate_consistency() should be called
|
||||
afterward.
|
||||
"""
|
||||
if child not in self.nodes:
|
||||
error_message = (
|
||||
"Migration %s dependencies reference nonexistent"
|
||||
" child node %r" % (migration, child)
|
||||
)
|
||||
self.add_dummy_node(child, migration, error_message)
|
||||
if parent not in self.nodes:
|
||||
error_message = (
|
||||
"Migration %s dependencies reference nonexistent"
|
||||
" parent node %r" % (migration, parent)
|
||||
)
|
||||
self.add_dummy_node(parent, migration, error_message)
|
||||
self.node_map[child].add_parent(self.node_map[parent])
|
||||
self.node_map[parent].add_child(self.node_map[child])
|
||||
if not skip_validation:
|
||||
self.validate_consistency()
|
||||
|
||||
def remove_replaced_nodes(self, replacement, replaced):
|
||||
"""
|
||||
Remove each of the `replaced` nodes (when they exist). Any
|
||||
dependencies that were referencing them are changed to reference the
|
||||
`replacement` node instead.
|
||||
"""
|
||||
# Cast list of replaced keys to set to speed up lookup later.
|
||||
replaced = set(replaced)
|
||||
try:
|
||||
replacement_node = self.node_map[replacement]
|
||||
except KeyError as err:
|
||||
raise NodeNotFoundError(
|
||||
"Unable to find replacement node %r. It was either never added"
|
||||
" to the migration graph, or has been removed." % (replacement,),
|
||||
replacement,
|
||||
) from err
|
||||
for replaced_key in replaced:
|
||||
self.nodes.pop(replaced_key, None)
|
||||
replaced_node = self.node_map.pop(replaced_key, None)
|
||||
if replaced_node:
|
||||
for child in replaced_node.children:
|
||||
child.parents.remove(replaced_node)
|
||||
# We don't want to create dependencies between the replaced
|
||||
# node and the replacement node as this would lead to
|
||||
# self-referencing on the replacement node at a later iteration.
|
||||
if child.key not in replaced:
|
||||
replacement_node.add_child(child)
|
||||
child.add_parent(replacement_node)
|
||||
for parent in replaced_node.parents:
|
||||
parent.children.remove(replaced_node)
|
||||
# Again, to avoid self-referencing.
|
||||
if parent.key not in replaced:
|
||||
replacement_node.add_parent(parent)
|
||||
parent.add_child(replacement_node)
|
||||
|
||||
def remove_replacement_node(self, replacement, replaced):
|
||||
"""
|
||||
The inverse operation to `remove_replaced_nodes`. Almost. Remove the
|
||||
replacement node `replacement` and remap its child nodes to `replaced`
|
||||
- the list of nodes it would have replaced. Don't remap its parent
|
||||
nodes as they are expected to be correct already.
|
||||
"""
|
||||
self.nodes.pop(replacement, None)
|
||||
try:
|
||||
replacement_node = self.node_map.pop(replacement)
|
||||
except KeyError as err:
|
||||
raise NodeNotFoundError(
|
||||
"Unable to remove replacement node %r. It was either never added"
|
||||
" to the migration graph, or has been removed already."
|
||||
% (replacement,),
|
||||
replacement,
|
||||
) from err
|
||||
replaced_nodes = set()
|
||||
replaced_nodes_parents = set()
|
||||
for key in replaced:
|
||||
replaced_node = self.node_map.get(key)
|
||||
if replaced_node:
|
||||
replaced_nodes.add(replaced_node)
|
||||
replaced_nodes_parents |= replaced_node.parents
|
||||
# We're only interested in the latest replaced node, so filter out
|
||||
# replaced nodes that are parents of other replaced nodes.
|
||||
replaced_nodes -= replaced_nodes_parents
|
||||
for child in replacement_node.children:
|
||||
child.parents.remove(replacement_node)
|
||||
for replaced_node in replaced_nodes:
|
||||
replaced_node.add_child(child)
|
||||
child.add_parent(replaced_node)
|
||||
for parent in replacement_node.parents:
|
||||
parent.children.remove(replacement_node)
|
||||
# NOTE: There is no need to remap parent dependencies as we can
|
||||
# assume the replaced nodes already have the correct ancestry.
|
||||
|
||||
def validate_consistency(self):
|
||||
"""Ensure there are no dummy nodes remaining in the graph."""
|
||||
[n.raise_error() for n in self.node_map.values() if isinstance(n, DummyNode)]
|
||||
|
||||
def forwards_plan(self, target):
|
||||
"""
|
||||
Given a node, return a list of which previous nodes (dependencies) must
|
||||
be applied, ending with the node itself. This is the list you would
|
||||
follow if applying the migrations to a database.
|
||||
"""
|
||||
if target not in self.nodes:
|
||||
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
|
||||
return self.iterative_dfs(self.node_map[target])
|
||||
|
||||
def backwards_plan(self, target):
|
||||
"""
|
||||
Given a node, return a list of which dependent nodes (dependencies)
|
||||
must be unapplied, ending with the node itself. This is the list you
|
||||
would follow if removing the migrations from a database.
|
||||
"""
|
||||
if target not in self.nodes:
|
||||
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
|
||||
return self.iterative_dfs(self.node_map[target], forwards=False)
|
||||
|
||||
def iterative_dfs(self, start, forwards=True):
|
||||
"""Iterative depth-first search for finding dependencies."""
|
||||
visited = []
|
||||
visited_set = set()
|
||||
stack = [(start, False)]
|
||||
while stack:
|
||||
node, processed = stack.pop()
|
||||
if node in visited_set:
|
||||
pass
|
||||
elif processed:
|
||||
visited_set.add(node)
|
||||
visited.append(node.key)
|
||||
else:
|
||||
stack.append((node, True))
|
||||
stack += [
|
||||
(n, False)
|
||||
for n in sorted(node.parents if forwards else node.children)
|
||||
]
|
||||
return visited
|
||||
|
||||
def root_nodes(self, app=None):
|
||||
"""
|
||||
Return all root nodes - that is, nodes with no dependencies inside
|
||||
their app. These are the starting point for an app.
|
||||
"""
|
||||
roots = set()
|
||||
for node in self.nodes:
|
||||
if all(key[0] != node[0] for key in self.node_map[node].parents) and (
|
||||
not app or app == node[0]
|
||||
):
|
||||
roots.add(node)
|
||||
return sorted(roots)
|
||||
|
||||
def leaf_nodes(self, app=None):
|
||||
"""
|
||||
Return all leaf nodes - that is, nodes with no dependents in their app.
|
||||
These are the "most current" version of an app's schema.
|
||||
Having more than one per app is technically an error, but one that
|
||||
gets handled further up, in the interactive command - it's usually the
|
||||
result of a VCS merge and needs some user input.
|
||||
"""
|
||||
leaves = set()
|
||||
for node in self.nodes:
|
||||
if all(key[0] != node[0] for key in self.node_map[node].children) and (
|
||||
not app or app == node[0]
|
||||
):
|
||||
leaves.add(node)
|
||||
return sorted(leaves)
|
||||
|
||||
def ensure_not_cyclic(self):
|
||||
# Algo from GvR:
|
||||
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
|
||||
todo = set(self.nodes)
|
||||
while todo:
|
||||
node = todo.pop()
|
||||
stack = [node]
|
||||
while stack:
|
||||
top = stack[-1]
|
||||
for child in self.node_map[top].children:
|
||||
# Use child.key instead of child to speed up the frequent
|
||||
# hashing.
|
||||
node = child.key
|
||||
if node in stack:
|
||||
cycle = stack[stack.index(node) :]
|
||||
raise CircularDependencyError(
|
||||
", ".join("%s.%s" % n for n in cycle)
|
||||
)
|
||||
if node in todo:
|
||||
stack.append(node)
|
||||
todo.remove(node)
|
||||
break
|
||||
else:
|
||||
node = stack.pop()
|
||||
|
||||
def __str__(self):
|
||||
return "Graph: %s nodes, %s edges" % self._nodes_and_edges()
|
||||
|
||||
def __repr__(self):
|
||||
nodes, edges = self._nodes_and_edges()
|
||||
return "<%s: nodes=%s, edges=%s>" % (self.__class__.__name__, nodes, edges)
|
||||
|
||||
def _nodes_and_edges(self):
|
||||
return len(self.nodes), sum(
|
||||
len(node.parents) for node in self.node_map.values()
|
||||
)
|
||||
|
||||
def _generate_plan(self, nodes, at_end):
|
||||
plan = []
|
||||
for node in nodes:
|
||||
for migration in self.forwards_plan(node):
|
||||
if migration not in plan and (at_end or migration not in nodes):
|
||||
plan.append(migration)
|
||||
return plan
|
||||
|
||||
def make_state(self, nodes=None, at_end=True, real_apps=None):
|
||||
"""
|
||||
Given a migration node or nodes, return a complete ProjectState for it.
|
||||
If at_end is False, return the state before the migration has run.
|
||||
If nodes is not provided, return the overall most current project state.
|
||||
"""
|
||||
if nodes is None:
|
||||
nodes = list(self.leaf_nodes())
|
||||
if not nodes:
|
||||
return ProjectState()
|
||||
if not isinstance(nodes[0], tuple):
|
||||
nodes = [nodes]
|
||||
plan = self._generate_plan(nodes, at_end)
|
||||
project_state = ProjectState(real_apps=real_apps)
|
||||
for node in plan:
|
||||
project_state = self.nodes[node].mutate_state(project_state, preserve=False)
|
||||
return project_state
|
||||
|
||||
def __contains__(self, node):
|
||||
return node in self.nodes
|
||||
@@ -0,0 +1,385 @@
|
||||
import pkgutil
|
||||
import sys
|
||||
from importlib import import_module, reload
|
||||
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
from django.db.migrations.graph import MigrationGraph
|
||||
from django.db.migrations.recorder import MigrationRecorder
|
||||
|
||||
from .exceptions import (
|
||||
AmbiguityError,
|
||||
BadMigrationError,
|
||||
InconsistentMigrationHistory,
|
||||
NodeNotFoundError,
|
||||
)
|
||||
|
||||
MIGRATIONS_MODULE_NAME = "migrations"
|
||||
|
||||
|
||||
class MigrationLoader:
|
||||
"""
|
||||
Load migration files from disk and their status from the database.
|
||||
|
||||
Migration files are expected to live in the "migrations" directory of
|
||||
an app. Their names are entirely unimportant from a code perspective,
|
||||
but will probably follow the 1234_name.py convention.
|
||||
|
||||
On initialization, this class will scan those directories, and open and
|
||||
read the Python files, looking for a class called Migration, which should
|
||||
inherit from django.db.migrations.Migration. See
|
||||
django.db.migrations.migration for what that looks like.
|
||||
|
||||
Some migrations will be marked as "replacing" another set of migrations.
|
||||
These are loaded into a separate set of migrations away from the main ones.
|
||||
If all the migrations they replace are either unapplied or missing from
|
||||
disk, then they are injected into the main set, replacing the named migrations.
|
||||
Any dependency pointers to the replaced migrations are re-pointed to the
|
||||
new migration.
|
||||
|
||||
This does mean that this class MUST also talk to the database as well as
|
||||
to disk, but this is probably fine. We're already not just operating
|
||||
in memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection,
|
||||
load=True,
|
||||
ignore_no_migrations=False,
|
||||
replace_migrations=True,
|
||||
):
|
||||
self.connection = connection
|
||||
self.disk_migrations = None
|
||||
self.applied_migrations = None
|
||||
self.ignore_no_migrations = ignore_no_migrations
|
||||
self.replace_migrations = replace_migrations
|
||||
if load:
|
||||
self.build_graph()
|
||||
|
||||
@classmethod
|
||||
def migrations_module(cls, app_label):
|
||||
"""
|
||||
Return the path to the migrations module for the specified app_label
|
||||
and a boolean indicating if the module is specified in
|
||||
settings.MIGRATION_MODULE.
|
||||
"""
|
||||
if app_label in settings.MIGRATION_MODULES:
|
||||
return settings.MIGRATION_MODULES[app_label], True
|
||||
else:
|
||||
app_package_name = apps.get_app_config(app_label).name
|
||||
return "%s.%s" % (app_package_name, MIGRATIONS_MODULE_NAME), False
|
||||
|
||||
def load_disk(self):
|
||||
"""Load the migrations from all INSTALLED_APPS from disk."""
|
||||
self.disk_migrations = {}
|
||||
self.unmigrated_apps = set()
|
||||
self.migrated_apps = set()
|
||||
for app_config in apps.get_app_configs():
|
||||
# Get the migrations module directory
|
||||
module_name, explicit = self.migrations_module(app_config.label)
|
||||
if module_name is None:
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
was_loaded = module_name in sys.modules
|
||||
try:
|
||||
module = import_module(module_name)
|
||||
except ModuleNotFoundError as e:
|
||||
if (explicit and self.ignore_no_migrations) or (
|
||||
not explicit and MIGRATIONS_MODULE_NAME in e.name.split(".")
|
||||
):
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
raise
|
||||
else:
|
||||
# Module is not a package (e.g. migrations.py).
|
||||
if not hasattr(module, "__path__"):
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
# Empty directories are namespaces. Namespace packages have no
|
||||
# __file__ and don't use a list for __path__. See
|
||||
# https://docs.python.org/3/reference/import.html#namespace-packages
|
||||
if getattr(module, "__file__", None) is None and not isinstance(
|
||||
module.__path__, list
|
||||
):
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
# Force a reload if it's already loaded (tests need this)
|
||||
if was_loaded:
|
||||
reload(module)
|
||||
self.migrated_apps.add(app_config.label)
|
||||
migration_names = {
|
||||
name
|
||||
for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
|
||||
if not is_pkg and name[0] not in "_~"
|
||||
}
|
||||
# Load migrations
|
||||
for migration_name in migration_names:
|
||||
migration_path = "%s.%s" % (module_name, migration_name)
|
||||
try:
|
||||
migration_module = import_module(migration_path)
|
||||
except ImportError as e:
|
||||
if "bad magic number" in str(e):
|
||||
raise ImportError(
|
||||
"Couldn't import %r as it appears to be a stale "
|
||||
".pyc file." % migration_path
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
if not hasattr(migration_module, "Migration"):
|
||||
raise BadMigrationError(
|
||||
"Migration %s in app %s has no Migration class"
|
||||
% (migration_name, app_config.label)
|
||||
)
|
||||
self.disk_migrations[app_config.label, migration_name] = (
|
||||
migration_module.Migration(
|
||||
migration_name,
|
||||
app_config.label,
|
||||
)
|
||||
)
|
||||
|
||||
def get_migration(self, app_label, name_prefix):
|
||||
"""Return the named migration or raise NodeNotFoundError."""
|
||||
return self.graph.nodes[app_label, name_prefix]
|
||||
|
||||
def get_migration_by_prefix(self, app_label, name_prefix):
|
||||
"""
|
||||
Return the migration(s) which match the given app label and name_prefix.
|
||||
"""
|
||||
# Do the search
|
||||
results = []
|
||||
for migration_app_label, migration_name in self.disk_migrations:
|
||||
if migration_app_label == app_label and migration_name.startswith(
|
||||
name_prefix
|
||||
):
|
||||
results.append((migration_app_label, migration_name))
|
||||
if len(results) > 1:
|
||||
raise AmbiguityError(
|
||||
"There is more than one migration for '%s' with the prefix '%s'"
|
||||
% (app_label, name_prefix)
|
||||
)
|
||||
elif not results:
|
||||
raise KeyError(
|
||||
f"There is no migration for '{app_label}' with the prefix "
|
||||
f"'{name_prefix}'"
|
||||
)
|
||||
else:
|
||||
return self.disk_migrations[results[0]]
|
||||
|
||||
def check_key(self, key, current_app):
|
||||
if (key[1] != "__first__" and key[1] != "__latest__") or key in self.graph:
|
||||
return key
|
||||
# Special-case __first__, which means "the first migration" for
|
||||
# migrated apps, and is ignored for unmigrated apps. It allows
|
||||
# makemigrations to declare dependencies on apps before they even have
|
||||
# migrations.
|
||||
if key[0] == current_app:
|
||||
# Ignore __first__ references to the same app (#22325)
|
||||
return
|
||||
if key[0] in self.unmigrated_apps:
|
||||
# This app isn't migrated, but something depends on it.
|
||||
# The models will get auto-added into the state, though
|
||||
# so we're fine.
|
||||
return
|
||||
if key[0] in self.migrated_apps:
|
||||
try:
|
||||
if key[1] == "__first__":
|
||||
return self.graph.root_nodes(key[0])[0]
|
||||
else: # "__latest__"
|
||||
return self.graph.leaf_nodes(key[0])[0]
|
||||
except IndexError:
|
||||
if self.ignore_no_migrations:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Dependency on app with no migrations: %s" % key[0]
|
||||
)
|
||||
raise ValueError("Dependency on unknown app: %s" % key[0])
|
||||
|
||||
def add_internal_dependencies(self, key, migration):
|
||||
"""
|
||||
Internal dependencies need to be added first to ensure `__first__`
|
||||
dependencies find the correct root node.
|
||||
"""
|
||||
for parent in migration.dependencies:
|
||||
# Ignore __first__ references to the same app.
|
||||
if parent[0] == key[0] and parent[1] != "__first__":
|
||||
self.graph.add_dependency(migration, key, parent, skip_validation=True)
|
||||
|
||||
def add_external_dependencies(self, key, migration):
|
||||
for parent in migration.dependencies:
|
||||
# Skip internal dependencies
|
||||
if key[0] == parent[0]:
|
||||
continue
|
||||
parent = self.check_key(parent, key[0])
|
||||
if parent is not None:
|
||||
self.graph.add_dependency(migration, key, parent, skip_validation=True)
|
||||
for child in migration.run_before:
|
||||
child = self.check_key(child, key[0])
|
||||
if child is not None:
|
||||
self.graph.add_dependency(migration, child, key, skip_validation=True)
|
||||
|
||||
def build_graph(self):
|
||||
"""
|
||||
Build a migration dependency graph using both the disk and database.
|
||||
You'll need to rebuild the graph if you apply migrations. This isn't
|
||||
usually a problem as generally migration stuff runs in a one-shot process.
|
||||
"""
|
||||
# Load disk data
|
||||
self.load_disk()
|
||||
# Load database data
|
||||
if self.connection is None:
|
||||
self.applied_migrations = {}
|
||||
else:
|
||||
recorder = MigrationRecorder(self.connection)
|
||||
self.applied_migrations = recorder.applied_migrations()
|
||||
# To start, populate the migration graph with nodes for ALL migrations
|
||||
# and their dependencies. Also make note of replacing migrations at this step.
|
||||
self.graph = MigrationGraph()
|
||||
self.replacements = {}
|
||||
for key, migration in self.disk_migrations.items():
|
||||
self.graph.add_node(key, migration)
|
||||
# Replacing migrations.
|
||||
if migration.replaces:
|
||||
self.replacements[key] = migration
|
||||
for key, migration in self.disk_migrations.items():
|
||||
# Internal (same app) dependencies.
|
||||
self.add_internal_dependencies(key, migration)
|
||||
# Add external dependencies now that the internal ones have been resolved.
|
||||
for key, migration in self.disk_migrations.items():
|
||||
self.add_external_dependencies(key, migration)
|
||||
# Carry out replacements where possible and if enabled.
|
||||
if self.replace_migrations:
|
||||
for key, migration in self.replacements.items():
|
||||
# Get applied status of each of this migration's replacement
|
||||
# targets.
|
||||
applied_statuses = [
|
||||
(target in self.applied_migrations) for target in migration.replaces
|
||||
]
|
||||
# The replacing migration is only marked as applied if all of
|
||||
# its replacement targets are.
|
||||
if all(applied_statuses):
|
||||
self.applied_migrations[key] = migration
|
||||
else:
|
||||
self.applied_migrations.pop(key, None)
|
||||
# A replacing migration can be used if either all or none of
|
||||
# its replacement targets have been applied.
|
||||
if all(applied_statuses) or (not any(applied_statuses)):
|
||||
self.graph.remove_replaced_nodes(key, migration.replaces)
|
||||
else:
|
||||
# This replacing migration cannot be used because it is
|
||||
# partially applied. Remove it from the graph and remap
|
||||
# dependencies to it (#25945).
|
||||
self.graph.remove_replacement_node(key, migration.replaces)
|
||||
# Ensure the graph is consistent.
|
||||
try:
|
||||
self.graph.validate_consistency()
|
||||
except NodeNotFoundError as exc:
|
||||
# Check if the missing node could have been replaced by any squash
|
||||
# migration but wasn't because the squash migration was partially
|
||||
# applied before. In that case raise a more understandable exception
|
||||
# (#23556).
|
||||
# Get reverse replacements.
|
||||
reverse_replacements = {}
|
||||
for key, migration in self.replacements.items():
|
||||
for replaced in migration.replaces:
|
||||
reverse_replacements.setdefault(replaced, set()).add(key)
|
||||
# Try to reraise exception with more detail.
|
||||
if exc.node in reverse_replacements:
|
||||
candidates = reverse_replacements.get(exc.node, set())
|
||||
is_replaced = any(
|
||||
candidate in self.graph.nodes for candidate in candidates
|
||||
)
|
||||
if not is_replaced:
|
||||
tries = ", ".join("%s.%s" % c for c in candidates)
|
||||
raise NodeNotFoundError(
|
||||
"Migration {0} depends on nonexistent node ('{1}', '{2}'). "
|
||||
"Django tried to replace migration {1}.{2} with any of [{3}] "
|
||||
"but wasn't able to because some of the replaced migrations "
|
||||
"are already applied.".format(
|
||||
exc.origin, exc.node[0], exc.node[1], tries
|
||||
),
|
||||
exc.node,
|
||||
) from exc
|
||||
raise
|
||||
self.graph.ensure_not_cyclic()
|
||||
|
||||
def check_consistent_history(self, connection):
|
||||
"""
|
||||
Raise InconsistentMigrationHistory if any applied migrations have
|
||||
unapplied dependencies.
|
||||
"""
|
||||
recorder = MigrationRecorder(connection)
|
||||
applied = recorder.applied_migrations()
|
||||
for migration in applied:
|
||||
# If the migration is unknown, skip it.
|
||||
if migration not in self.graph.nodes:
|
||||
continue
|
||||
for parent in self.graph.node_map[migration].parents:
|
||||
if parent not in applied:
|
||||
# Skip unapplied squashed migrations that have all of their
|
||||
# `replaces` applied.
|
||||
if parent in self.replacements:
|
||||
if all(
|
||||
m in applied for m in self.replacements[parent].replaces
|
||||
):
|
||||
continue
|
||||
raise InconsistentMigrationHistory(
|
||||
"Migration {}.{} is applied before its dependency "
|
||||
"{}.{} on database '{}'.".format(
|
||||
migration[0],
|
||||
migration[1],
|
||||
parent[0],
|
||||
parent[1],
|
||||
connection.alias,
|
||||
)
|
||||
)
|
||||
|
||||
def detect_conflicts(self):
|
||||
"""
|
||||
Look through the loaded graph and detect any conflicts - apps
|
||||
with more than one leaf migration. Return a dict of the app labels
|
||||
that conflict with the migration names that conflict.
|
||||
"""
|
||||
seen_apps = {}
|
||||
conflicting_apps = set()
|
||||
for app_label, migration_name in self.graph.leaf_nodes():
|
||||
if app_label in seen_apps:
|
||||
conflicting_apps.add(app_label)
|
||||
seen_apps.setdefault(app_label, set()).add(migration_name)
|
||||
return {
|
||||
app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps
|
||||
}
|
||||
|
||||
def project_state(self, nodes=None, at_end=True):
|
||||
"""
|
||||
Return a ProjectState object representing the most recent state
|
||||
that the loaded migrations represent.
|
||||
|
||||
See graph.make_state() for the meaning of "nodes" and "at_end".
|
||||
"""
|
||||
return self.graph.make_state(
|
||||
nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps
|
||||
)
|
||||
|
||||
def collect_sql(self, plan):
|
||||
"""
|
||||
Take a migration plan and return a list of collected SQL statements
|
||||
that represent the best-efforts version of that plan.
|
||||
"""
|
||||
statements = []
|
||||
state = None
|
||||
for migration, backwards in plan:
|
||||
with self.connection.schema_editor(
|
||||
collect_sql=True, atomic=migration.atomic
|
||||
) as schema_editor:
|
||||
if state is None:
|
||||
state = self.project_state(
|
||||
(migration.app_label, migration.name), at_end=False
|
||||
)
|
||||
if not backwards:
|
||||
state = migration.apply(state, schema_editor, collect_sql=True)
|
||||
else:
|
||||
state = migration.unapply(state, schema_editor, collect_sql=True)
|
||||
statements.extend(schema_editor.collected_sql)
|
||||
return statements
|
||||
@@ -0,0 +1,239 @@
|
||||
import re
|
||||
|
||||
from django.db.migrations.utils import get_migration_name_timestamp
|
||||
from django.db.transaction import atomic
|
||||
|
||||
from .exceptions import IrreversibleError
|
||||
|
||||
|
||||
class Migration:
|
||||
"""
|
||||
The base class for all migrations.
|
||||
|
||||
Migration files will import this from django.db.migrations.Migration
|
||||
and subclass it as a class called Migration. It will have one or more
|
||||
of the following attributes:
|
||||
|
||||
- operations: A list of Operation instances, probably from
|
||||
django.db.migrations.operations
|
||||
- dependencies: A list of tuples of (app_path, migration_name)
|
||||
- run_before: A list of tuples of (app_path, migration_name)
|
||||
- replaces: A list of migration_names
|
||||
|
||||
Note that all migrations come out of migrations and into the Loader or
|
||||
Graph as instances, having been initialized with their app label and name.
|
||||
"""
|
||||
|
||||
# Operations to apply during this migration, in order.
|
||||
operations = []
|
||||
|
||||
# Other migrations that should be run before this migration.
|
||||
# Should be a list of (app, migration_name).
|
||||
dependencies = []
|
||||
|
||||
# Other migrations that should be run after this one (i.e. have
|
||||
# this migration added to their dependencies). Useful to make third-party
|
||||
# apps' migrations run after your AUTH_USER replacement, for example.
|
||||
run_before = []
|
||||
|
||||
# Migration names in this app that this migration replaces. If this is
|
||||
# non-empty, this migration will only be applied if all these migrations
|
||||
# are not applied.
|
||||
replaces = []
|
||||
|
||||
# Is this an initial migration? Initial migrations are skipped on
|
||||
# --fake-initial if the table or fields already exist. If None, check if
|
||||
# the migration has any dependencies to determine if there are dependencies
|
||||
# to tell if db introspection needs to be done. If True, always perform
|
||||
# introspection. If False, never perform introspection.
|
||||
initial = None
|
||||
|
||||
# Whether to wrap the whole migration in a transaction. Only has an effect
|
||||
# on database backends which support transactional DDL.
|
||||
atomic = True
|
||||
|
||||
def __init__(self, name, app_label):
|
||||
self.name = name
|
||||
self.app_label = app_label
|
||||
# Copy dependencies & other attrs as we might mutate them at runtime
|
||||
self.operations = list(self.__class__.operations)
|
||||
self.dependencies = list(self.__class__.dependencies)
|
||||
self.run_before = list(self.__class__.run_before)
|
||||
self.replaces = list(self.__class__.replaces)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, Migration)
|
||||
and self.name == other.name
|
||||
and self.app_label == other.app_label
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<Migration %s.%s>" % (self.app_label, self.name)
|
||||
|
||||
def __str__(self):
|
||||
return "%s.%s" % (self.app_label, self.name)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("%s.%s" % (self.app_label, self.name))
|
||||
|
||||
def mutate_state(self, project_state, preserve=True):
|
||||
"""
|
||||
Take a ProjectState and return a new one with the migration's
|
||||
operations applied to it. Preserve the original object state by
|
||||
default and return a mutated state from a copy.
|
||||
"""
|
||||
new_state = project_state
|
||||
if preserve:
|
||||
new_state = project_state.clone()
|
||||
|
||||
for operation in self.operations:
|
||||
operation.state_forwards(self.app_label, new_state)
|
||||
return new_state
|
||||
|
||||
def apply(self, project_state, schema_editor, collect_sql=False):
|
||||
"""
|
||||
Take a project_state representing all migrations prior to this one
|
||||
and a schema_editor for a live database and apply the migration
|
||||
in a forwards order.
|
||||
|
||||
Return the resulting project state for efficient reuse by following
|
||||
Migrations.
|
||||
"""
|
||||
for operation in self.operations:
|
||||
# If this operation cannot be represented as SQL, place a comment
|
||||
# there instead
|
||||
if collect_sql:
|
||||
schema_editor.collected_sql.append("--")
|
||||
schema_editor.collected_sql.append("-- %s" % operation.describe())
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
schema_editor.collected_sql.append(
|
||||
"-- THIS OPERATION CANNOT BE WRITTEN AS SQL"
|
||||
)
|
||||
continue
|
||||
collected_sql_before = len(schema_editor.collected_sql)
|
||||
# Save the state before the operation has run
|
||||
old_state = project_state.clone()
|
||||
operation.state_forwards(self.app_label, project_state)
|
||||
# Run the operation
|
||||
atomic_operation = operation.atomic or (
|
||||
self.atomic and operation.atomic is not False
|
||||
)
|
||||
if not schema_editor.atomic_migration and atomic_operation:
|
||||
# Force a transaction on a non-transactional-DDL backend or an
|
||||
# atomic operation inside a non-atomic migration.
|
||||
with atomic(schema_editor.connection.alias):
|
||||
operation.database_forwards(
|
||||
self.app_label, schema_editor, old_state, project_state
|
||||
)
|
||||
else:
|
||||
# Normal behaviour
|
||||
operation.database_forwards(
|
||||
self.app_label, schema_editor, old_state, project_state
|
||||
)
|
||||
if collect_sql and collected_sql_before == len(schema_editor.collected_sql):
|
||||
schema_editor.collected_sql.append("-- (no-op)")
|
||||
return project_state
|
||||
|
||||
def unapply(self, project_state, schema_editor, collect_sql=False):
|
||||
"""
|
||||
Take a project_state representing all migrations prior to this one
|
||||
and a schema_editor for a live database and apply the migration
|
||||
in a reverse order.
|
||||
|
||||
The backwards migration process consists of two phases:
|
||||
|
||||
1. The intermediate states from right before the first until right
|
||||
after the last operation inside this migration are preserved.
|
||||
2. The operations are applied in reverse order using the states
|
||||
recorded in step 1.
|
||||
"""
|
||||
# Construct all the intermediate states we need for a reverse migration
|
||||
to_run = []
|
||||
new_state = project_state
|
||||
# Phase 1
|
||||
for operation in self.operations:
|
||||
# If it's irreversible, error out
|
||||
if not operation.reversible:
|
||||
raise IrreversibleError(
|
||||
"Operation %s in %s is not reversible" % (operation, self)
|
||||
)
|
||||
# Preserve new state from previous run to not tamper the same state
|
||||
# over all operations
|
||||
new_state = new_state.clone()
|
||||
old_state = new_state.clone()
|
||||
operation.state_forwards(self.app_label, new_state)
|
||||
to_run.insert(0, (operation, old_state, new_state))
|
||||
|
||||
# Phase 2
|
||||
for operation, to_state, from_state in to_run:
|
||||
if collect_sql:
|
||||
schema_editor.collected_sql.append("--")
|
||||
schema_editor.collected_sql.append("-- %s" % operation.describe())
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
schema_editor.collected_sql.append(
|
||||
"-- THIS OPERATION CANNOT BE WRITTEN AS SQL"
|
||||
)
|
||||
continue
|
||||
collected_sql_before = len(schema_editor.collected_sql)
|
||||
atomic_operation = operation.atomic or (
|
||||
self.atomic and operation.atomic is not False
|
||||
)
|
||||
if not schema_editor.atomic_migration and atomic_operation:
|
||||
# Force a transaction on a non-transactional-DDL backend or an
|
||||
# atomic operation inside a non-atomic migration.
|
||||
with atomic(schema_editor.connection.alias):
|
||||
operation.database_backwards(
|
||||
self.app_label, schema_editor, from_state, to_state
|
||||
)
|
||||
else:
|
||||
# Normal behaviour
|
||||
operation.database_backwards(
|
||||
self.app_label, schema_editor, from_state, to_state
|
||||
)
|
||||
if collect_sql and collected_sql_before == len(schema_editor.collected_sql):
|
||||
schema_editor.collected_sql.append("-- (no-op)")
|
||||
return project_state
|
||||
|
||||
def suggest_name(self):
|
||||
"""
|
||||
Suggest a name for the operations this migration might represent. Names
|
||||
are not guaranteed to be unique, but put some effort into the fallback
|
||||
name to avoid VCS conflicts if possible.
|
||||
"""
|
||||
if self.initial:
|
||||
return "initial"
|
||||
|
||||
raw_fragments = [op.migration_name_fragment for op in self.operations]
|
||||
fragments = [re.sub(r"\W+", "_", name) for name in raw_fragments if name]
|
||||
|
||||
if not fragments or len(fragments) != len(self.operations):
|
||||
return "auto_%s" % get_migration_name_timestamp()
|
||||
|
||||
name = fragments[0]
|
||||
for fragment in fragments[1:]:
|
||||
new_name = f"{name}_{fragment}"
|
||||
if len(new_name) > 52:
|
||||
name = f"{name}_and_more"
|
||||
break
|
||||
name = new_name
|
||||
return name
|
||||
|
||||
|
||||
class SwappableTuple(tuple):
|
||||
"""
|
||||
Subclass of tuple so Django can tell this was originally a swappable
|
||||
dependency when it reads the migration file.
|
||||
"""
|
||||
|
||||
def __new__(cls, value, setting):
|
||||
self = tuple.__new__(cls, value)
|
||||
self.setting = setting
|
||||
return self
|
||||
|
||||
|
||||
def swappable_dependency(value):
|
||||
"""Turn a setting value into a dependency."""
|
||||
return SwappableTuple((value.split(".", 1)[0], "__first__"), value)
|
||||
@@ -0,0 +1,46 @@
|
||||
from .fields import AddField, AlterField, RemoveField, RenameField
|
||||
from .models import (
|
||||
AddConstraint,
|
||||
AddIndex,
|
||||
AlterConstraint,
|
||||
AlterIndexTogether,
|
||||
AlterModelManagers,
|
||||
AlterModelOptions,
|
||||
AlterModelTable,
|
||||
AlterModelTableComment,
|
||||
AlterOrderWithRespectTo,
|
||||
AlterUniqueTogether,
|
||||
CreateModel,
|
||||
DeleteModel,
|
||||
RemoveConstraint,
|
||||
RemoveIndex,
|
||||
RenameIndex,
|
||||
RenameModel,
|
||||
)
|
||||
from .special import RunPython, RunSQL, SeparateDatabaseAndState
|
||||
|
||||
__all__ = [
|
||||
"CreateModel",
|
||||
"DeleteModel",
|
||||
"AlterModelTable",
|
||||
"AlterModelTableComment",
|
||||
"AlterUniqueTogether",
|
||||
"RenameModel",
|
||||
"AlterIndexTogether",
|
||||
"AlterModelOptions",
|
||||
"AddIndex",
|
||||
"RemoveIndex",
|
||||
"RenameIndex",
|
||||
"AddField",
|
||||
"RemoveField",
|
||||
"AlterField",
|
||||
"RenameField",
|
||||
"AddConstraint",
|
||||
"RemoveConstraint",
|
||||
"AlterConstraint",
|
||||
"SeparateDatabaseAndState",
|
||||
"RunSQL",
|
||||
"RunPython",
|
||||
"AlterOrderWithRespectTo",
|
||||
"AlterModelManagers",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,166 @@
|
||||
import enum
|
||||
|
||||
from django.db import router
|
||||
|
||||
|
||||
class OperationCategory(str, enum.Enum):
|
||||
ADDITION = "+"
|
||||
REMOVAL = "-"
|
||||
ALTERATION = "~"
|
||||
PYTHON = "p"
|
||||
SQL = "s"
|
||||
MIXED = "?"
|
||||
|
||||
|
||||
class Operation:
|
||||
"""
|
||||
Base class for migration operations.
|
||||
|
||||
It's responsible for both mutating the in-memory model state
|
||||
(see db/migrations/state.py) to represent what it performs, as well
|
||||
as actually performing it against a live database.
|
||||
|
||||
Note that some operations won't modify memory state at all (e.g. data
|
||||
copying operations), and some will need their modifications to be
|
||||
optionally specified by the user (e.g. custom Python code snippets)
|
||||
|
||||
Due to the way this class deals with deconstruction, it should be
|
||||
considered immutable.
|
||||
"""
|
||||
|
||||
# If this migration can be run in reverse.
|
||||
# Some operations are impossible to reverse, like deleting data.
|
||||
reversible = True
|
||||
|
||||
# Can this migration be represented as SQL? (things like RunPython cannot)
|
||||
reduces_to_sql = True
|
||||
|
||||
# Should this operation be forced as atomic even on backends with no
|
||||
# DDL transaction support (i.e., does it have no DDL, like RunPython)
|
||||
atomic = False
|
||||
|
||||
# Should this operation be considered safe to elide and optimize across?
|
||||
elidable = False
|
||||
|
||||
serialization_expand_args = []
|
||||
|
||||
category = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# We capture the arguments to make returning them trivial
|
||||
self = object.__new__(cls)
|
||||
self._constructor_args = (args, kwargs)
|
||||
return self
|
||||
|
||||
def deconstruct(self):
|
||||
"""
|
||||
Return a 3-tuple of class import path (or just name if it lives
|
||||
under django.db.migrations), positional arguments, and keyword
|
||||
arguments.
|
||||
"""
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
self._constructor_args[0],
|
||||
self._constructor_args[1],
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
"""
|
||||
Take the state from the previous migration, and mutate it
|
||||
so that it matches what this migration would perform.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of Operation must provide a state_forwards() method"
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
"""
|
||||
Perform the mutation on the database schema in the normal
|
||||
(forwards) direction.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of Operation must provide a database_forwards() method"
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
"""
|
||||
Perform the mutation on the database schema in the reverse
|
||||
direction - e.g. if this were CreateModel, it would in fact
|
||||
drop the model's table.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of Operation must provide a database_backwards() method"
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
"""
|
||||
Output a brief summary of what the action does.
|
||||
"""
|
||||
return "%s: %s" % (self.__class__.__name__, self._constructor_args)
|
||||
|
||||
def formatted_description(self):
|
||||
"""Output a description prefixed by a category symbol."""
|
||||
description = self.describe()
|
||||
if self.category is None:
|
||||
return f"{OperationCategory.MIXED.value} {description}"
|
||||
return f"{self.category.value} {description}"
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
"""
|
||||
A filename part suitable for automatically naming a migration
|
||||
containing this operation, or None if not applicable.
|
||||
"""
|
||||
return None
|
||||
|
||||
def references_model(self, name, app_label):
|
||||
"""
|
||||
Return True if there is a chance this operation references the given
|
||||
model name (as a string), with an app label for accuracy.
|
||||
|
||||
Used for optimization. If in doubt, return True;
|
||||
returning a false positive will merely make the optimizer a little
|
||||
less efficient, while returning a false negative may result in an
|
||||
unusable optimized migration.
|
||||
"""
|
||||
return True
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
"""
|
||||
Return True if there is a chance this operation references the given
|
||||
field name, with an app label for accuracy.
|
||||
|
||||
Used for optimization. If in doubt, return True.
|
||||
"""
|
||||
return self.references_model(model_name, app_label)
|
||||
|
||||
def allow_migrate_model(self, connection_alias, model):
|
||||
"""
|
||||
Return whether or not a model may be migrated.
|
||||
|
||||
This is a thin wrapper around router.allow_migrate_model() that
|
||||
preemptively rejects any proxy, swapped out, or unmanaged model.
|
||||
"""
|
||||
if not model._meta.can_migrate(connection_alias):
|
||||
return False
|
||||
|
||||
return router.allow_migrate_model(connection_alias, model)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
"""
|
||||
Return either a list of operations the actual operation should be
|
||||
replaced with or a boolean that indicates whether or not the specified
|
||||
operation can be optimized across.
|
||||
"""
|
||||
if self.elidable:
|
||||
return [operation]
|
||||
elif operation.elidable:
|
||||
return [self]
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s %s%s>" % (
|
||||
self.__class__.__name__,
|
||||
", ".join(map(repr, self._constructor_args[0])),
|
||||
",".join(" %s=%r" % x for x in self._constructor_args[1].items()),
|
||||
)
|
||||
@@ -0,0 +1,365 @@
|
||||
from django.db.migrations.utils import field_references
|
||||
from django.db.models import NOT_PROVIDED
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from .base import Operation, OperationCategory
|
||||
|
||||
|
||||
class FieldOperation(Operation):
|
||||
def __init__(self, model_name, name, field=None):
|
||||
self.model_name = model_name
|
||||
self.name = name
|
||||
self.field = field
|
||||
|
||||
@cached_property
|
||||
def model_name_lower(self):
|
||||
return self.model_name.lower()
|
||||
|
||||
@cached_property
|
||||
def name_lower(self):
|
||||
return self.name.lower()
|
||||
|
||||
def is_same_model_operation(self, operation):
|
||||
return self.model_name_lower == operation.model_name_lower
|
||||
|
||||
def is_same_field_operation(self, operation):
|
||||
return (
|
||||
self.is_same_model_operation(operation)
|
||||
and self.name_lower == operation.name_lower
|
||||
)
|
||||
|
||||
def references_model(self, name, app_label):
|
||||
name_lower = name.lower()
|
||||
if name_lower == self.model_name_lower:
|
||||
return True
|
||||
if self.field:
|
||||
return bool(
|
||||
field_references(
|
||||
(app_label, self.model_name_lower),
|
||||
self.field,
|
||||
(app_label, name_lower),
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
model_name_lower = model_name.lower()
|
||||
# Check if this operation locally references the field.
|
||||
if model_name_lower == self.model_name_lower:
|
||||
if name == self.name:
|
||||
return True
|
||||
elif (
|
||||
self.field
|
||||
and hasattr(self.field, "from_fields")
|
||||
and name in self.field.from_fields
|
||||
):
|
||||
return True
|
||||
# Check if this operation remotely references the field.
|
||||
if self.field is None:
|
||||
return False
|
||||
return bool(
|
||||
field_references(
|
||||
(app_label, self.model_name_lower),
|
||||
self.field,
|
||||
(app_label, model_name_lower),
|
||||
name,
|
||||
)
|
||||
)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
return super().reduce(operation, app_label) or not operation.references_field(
|
||||
self.model_name, self.name, app_label
|
||||
)
|
||||
|
||||
|
||||
class AddField(FieldOperation):
|
||||
"""Add a field to a model."""
|
||||
|
||||
category = OperationCategory.ADDITION
|
||||
|
||||
def __init__(self, model_name, name, field, preserve_default=True):
|
||||
self.preserve_default = preserve_default
|
||||
super().__init__(model_name, name, field)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
"model_name": self.model_name,
|
||||
"name": self.name,
|
||||
"field": self.field,
|
||||
}
|
||||
if self.preserve_default is not True:
|
||||
kwargs["preserve_default"] = self.preserve_default
|
||||
return (self.__class__.__name__, [], kwargs)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.add_field(
|
||||
app_label,
|
||||
self.model_name_lower,
|
||||
self.name,
|
||||
self.field,
|
||||
self.preserve_default,
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
field = to_model._meta.get_field(self.name)
|
||||
if not self.preserve_default:
|
||||
field.default = self.field.default
|
||||
schema_editor.add_field(
|
||||
from_model,
|
||||
field,
|
||||
)
|
||||
if not self.preserve_default:
|
||||
field.default = NOT_PROVIDED
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
|
||||
schema_editor.remove_field(
|
||||
from_model, from_model._meta.get_field(self.name)
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Add field %s to %s" % (self.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return "%s_%s" % (self.model_name_lower, self.name_lower)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
if isinstance(operation, FieldOperation) and self.is_same_field_operation(
|
||||
operation
|
||||
):
|
||||
if isinstance(operation, AlterField):
|
||||
return [
|
||||
AddField(
|
||||
model_name=self.model_name,
|
||||
name=operation.name,
|
||||
field=operation.field,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, RemoveField):
|
||||
return []
|
||||
elif isinstance(operation, RenameField):
|
||||
return [
|
||||
AddField(
|
||||
model_name=self.model_name,
|
||||
name=operation.new_name,
|
||||
field=self.field,
|
||||
),
|
||||
]
|
||||
return super().reduce(operation, app_label)
|
||||
|
||||
|
||||
class RemoveField(FieldOperation):
|
||||
"""Remove a field from a model."""
|
||||
|
||||
category = OperationCategory.REMOVAL
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
"model_name": self.model_name,
|
||||
"name": self.name,
|
||||
}
|
||||
return (self.__class__.__name__, [], kwargs)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.remove_field(app_label, self.model_name_lower, self.name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
|
||||
schema_editor.remove_field(
|
||||
from_model, from_model._meta.get_field(self.name)
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
|
||||
|
||||
def describe(self):
|
||||
return "Remove field %s from %s" % (self.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return "remove_%s_%s" % (self.model_name_lower, self.name_lower)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
from .models import DeleteModel
|
||||
|
||||
if (
|
||||
isinstance(operation, DeleteModel)
|
||||
and operation.name_lower == self.model_name_lower
|
||||
):
|
||||
return [operation]
|
||||
return super().reduce(operation, app_label)
|
||||
|
||||
|
||||
class AlterField(FieldOperation):
|
||||
"""
|
||||
Alter a field's database column (e.g. null, max_length) to the provided
|
||||
new field.
|
||||
"""
|
||||
|
||||
category = OperationCategory.ALTERATION
|
||||
|
||||
def __init__(self, model_name, name, field, preserve_default=True):
|
||||
self.preserve_default = preserve_default
|
||||
super().__init__(model_name, name, field)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
"model_name": self.model_name,
|
||||
"name": self.name,
|
||||
"field": self.field,
|
||||
}
|
||||
if self.preserve_default is not True:
|
||||
kwargs["preserve_default"] = self.preserve_default
|
||||
return (self.__class__.__name__, [], kwargs)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.alter_field(
|
||||
app_label,
|
||||
self.model_name_lower,
|
||||
self.name,
|
||||
self.field,
|
||||
self.preserve_default,
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
from_field = from_model._meta.get_field(self.name)
|
||||
to_field = to_model._meta.get_field(self.name)
|
||||
if not self.preserve_default:
|
||||
to_field.default = self.field.default
|
||||
schema_editor.alter_field(from_model, from_field, to_field)
|
||||
if not self.preserve_default:
|
||||
to_field.default = NOT_PROVIDED
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def describe(self):
|
||||
return "Alter field %s on %s" % (self.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return "alter_%s_%s" % (self.model_name_lower, self.name_lower)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
if isinstance(
|
||||
operation, (AlterField, RemoveField)
|
||||
) and self.is_same_field_operation(operation):
|
||||
return [operation]
|
||||
elif (
|
||||
isinstance(operation, RenameField)
|
||||
and self.is_same_field_operation(operation)
|
||||
and self.field.db_column is None
|
||||
):
|
||||
return [
|
||||
operation,
|
||||
AlterField(
|
||||
model_name=self.model_name,
|
||||
name=operation.new_name,
|
||||
field=self.field,
|
||||
),
|
||||
]
|
||||
return super().reduce(operation, app_label)
|
||||
|
||||
|
||||
class RenameField(FieldOperation):
|
||||
"""Rename a field on the model. Might affect db_column too."""
|
||||
|
||||
category = OperationCategory.ALTERATION
|
||||
|
||||
def __init__(self, model_name, old_name, new_name):
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
super().__init__(model_name, old_name)
|
||||
|
||||
@cached_property
|
||||
def old_name_lower(self):
|
||||
return self.old_name.lower()
|
||||
|
||||
@cached_property
|
||||
def new_name_lower(self):
|
||||
return self.new_name.lower()
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
"model_name": self.model_name,
|
||||
"old_name": self.old_name,
|
||||
"new_name": self.new_name,
|
||||
}
|
||||
return (self.__class__.__name__, [], kwargs)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.rename_field(
|
||||
app_label, self.model_name_lower, self.old_name, self.new_name
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.alter_field(
|
||||
from_model,
|
||||
from_model._meta.get_field(self.old_name),
|
||||
to_model._meta.get_field(self.new_name),
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.alter_field(
|
||||
from_model,
|
||||
from_model._meta.get_field(self.new_name),
|
||||
to_model._meta.get_field(self.old_name),
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Rename field %s on %s to %s" % (
|
||||
self.old_name,
|
||||
self.model_name,
|
||||
self.new_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return "rename_%s_%s_%s" % (
|
||||
self.old_name_lower,
|
||||
self.model_name_lower,
|
||||
self.new_name_lower,
|
||||
)
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
return self.references_model(model_name, app_label) and (
|
||||
name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
|
||||
)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
if (
|
||||
isinstance(operation, RenameField)
|
||||
and self.is_same_model_operation(operation)
|
||||
and self.new_name_lower == operation.old_name_lower
|
||||
):
|
||||
return [
|
||||
RenameField(
|
||||
self.model_name,
|
||||
self.old_name,
|
||||
operation.new_name,
|
||||
),
|
||||
]
|
||||
# Skip `FieldOperation.reduce` as we want to run `references_field`
|
||||
# against self.old_name and self.new_name.
|
||||
return super(FieldOperation, self).reduce(operation, app_label) or not (
|
||||
operation.references_field(self.model_name, self.old_name, app_label)
|
||||
or operation.references_field(self.model_name, self.new_name, app_label)
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,211 @@
|
||||
from django.db import router
|
||||
|
||||
from .base import Operation, OperationCategory
|
||||
|
||||
|
||||
class SeparateDatabaseAndState(Operation):
|
||||
"""
|
||||
Take two lists of operations - ones that will be used for the database,
|
||||
and ones that will be used for the state change. This allows operations
|
||||
that don't support state change to have it applied, or have operations
|
||||
that affect the state or not the database, or so on.
|
||||
"""
|
||||
|
||||
category = OperationCategory.MIXED
|
||||
serialization_expand_args = ["database_operations", "state_operations"]
|
||||
|
||||
def __init__(self, database_operations=None, state_operations=None):
|
||||
self.database_operations = database_operations or []
|
||||
self.state_operations = state_operations or []
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {}
|
||||
if self.database_operations:
|
||||
kwargs["database_operations"] = self.database_operations
|
||||
if self.state_operations:
|
||||
kwargs["state_operations"] = self.state_operations
|
||||
return (self.__class__.__qualname__, [], kwargs)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
for state_operation in self.state_operations:
|
||||
state_operation.state_forwards(app_label, state)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# We calculate state separately in here since our state functions aren't useful
|
||||
for database_operation in self.database_operations:
|
||||
to_state = from_state.clone()
|
||||
database_operation.state_forwards(app_label, to_state)
|
||||
database_operation.database_forwards(
|
||||
app_label, schema_editor, from_state, to_state
|
||||
)
|
||||
from_state = to_state
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# We calculate state separately in here since our state functions aren't useful
|
||||
to_states = {}
|
||||
for dbop in self.database_operations:
|
||||
to_states[dbop] = to_state
|
||||
to_state = to_state.clone()
|
||||
dbop.state_forwards(app_label, to_state)
|
||||
# to_state now has the states of all the database_operations applied
|
||||
# which is the from_state for the backwards migration of the last
|
||||
# operation.
|
||||
for database_operation in reversed(self.database_operations):
|
||||
from_state = to_state
|
||||
to_state = to_states[database_operation]
|
||||
database_operation.database_backwards(
|
||||
app_label, schema_editor, from_state, to_state
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Custom state/database change combination"
|
||||
|
||||
|
||||
class RunSQL(Operation):
|
||||
"""
|
||||
Run some raw SQL. A reverse SQL statement may be provided.
|
||||
|
||||
Also accept a list of operations that represent the state change effected
|
||||
by this SQL change, in case it's custom column/table creation/deletion.
|
||||
"""
|
||||
|
||||
category = OperationCategory.SQL
|
||||
noop = ""
|
||||
|
||||
def __init__(
|
||||
self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False
|
||||
):
|
||||
self.sql = sql
|
||||
self.reverse_sql = reverse_sql
|
||||
self.state_operations = state_operations or []
|
||||
self.hints = hints or {}
|
||||
self.elidable = elidable
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
"sql": self.sql,
|
||||
}
|
||||
if self.reverse_sql is not None:
|
||||
kwargs["reverse_sql"] = self.reverse_sql
|
||||
if self.state_operations:
|
||||
kwargs["state_operations"] = self.state_operations
|
||||
if self.hints:
|
||||
kwargs["hints"] = self.hints
|
||||
return (self.__class__.__qualname__, [], kwargs)
|
||||
|
||||
@property
|
||||
def reversible(self):
|
||||
return self.reverse_sql is not None
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
for state_operation in self.state_operations:
|
||||
state_operation.state_forwards(app_label, state)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if router.allow_migrate(
|
||||
schema_editor.connection.alias, app_label, **self.hints
|
||||
):
|
||||
self._run_sql(schema_editor, self.sql)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if self.reverse_sql is None:
|
||||
raise NotImplementedError("You cannot reverse this operation")
|
||||
if router.allow_migrate(
|
||||
schema_editor.connection.alias, app_label, **self.hints
|
||||
):
|
||||
self._run_sql(schema_editor, self.reverse_sql)
|
||||
|
||||
def describe(self):
|
||||
return "Raw SQL operation"
|
||||
|
||||
def _run_sql(self, schema_editor, sqls):
|
||||
if isinstance(sqls, (list, tuple)):
|
||||
for sql in sqls:
|
||||
params = None
|
||||
if isinstance(sql, (list, tuple)):
|
||||
elements = len(sql)
|
||||
if elements == 2:
|
||||
sql, params = sql
|
||||
else:
|
||||
raise ValueError("Expected a 2-tuple but got %d" % elements)
|
||||
schema_editor.execute(sql, params=params)
|
||||
elif sqls != RunSQL.noop:
|
||||
statements = schema_editor.connection.ops.prepare_sql_script(sqls)
|
||||
for statement in statements:
|
||||
schema_editor.execute(statement, params=None)
|
||||
|
||||
|
||||
class RunPython(Operation):
|
||||
"""
|
||||
Run Python code in a context suitable for doing versioned ORM operations.
|
||||
"""
|
||||
|
||||
category = OperationCategory.PYTHON
|
||||
reduces_to_sql = False
|
||||
|
||||
def __init__(
|
||||
self, code, reverse_code=None, atomic=None, hints=None, elidable=False
|
||||
):
|
||||
self.atomic = atomic
|
||||
# Forwards code
|
||||
if not callable(code):
|
||||
raise ValueError("RunPython must be supplied with a callable")
|
||||
self.code = code
|
||||
# Reverse code
|
||||
if reverse_code is None:
|
||||
self.reverse_code = None
|
||||
else:
|
||||
if not callable(reverse_code):
|
||||
raise ValueError("RunPython must be supplied with callable arguments")
|
||||
self.reverse_code = reverse_code
|
||||
self.hints = hints or {}
|
||||
self.elidable = elidable
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
"code": self.code,
|
||||
}
|
||||
if self.reverse_code is not None:
|
||||
kwargs["reverse_code"] = self.reverse_code
|
||||
if self.atomic is not None:
|
||||
kwargs["atomic"] = self.atomic
|
||||
if self.hints:
|
||||
kwargs["hints"] = self.hints
|
||||
return (self.__class__.__qualname__, [], kwargs)
|
||||
|
||||
@property
|
||||
def reversible(self):
|
||||
return self.reverse_code is not None
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
# RunPython objects have no state effect. To add some, combine this
|
||||
# with SeparateDatabaseAndState.
|
||||
pass
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# RunPython has access to all models. Ensure that all models are
|
||||
# reloaded in case any are delayed.
|
||||
from_state.clear_delayed_apps_cache()
|
||||
if router.allow_migrate(
|
||||
schema_editor.connection.alias, app_label, **self.hints
|
||||
):
|
||||
# We now execute the Python code in a context that contains a 'models'
|
||||
# object, representing the versioned models as an app registry.
|
||||
# We could try to override the global cache, but then people will still
|
||||
# use direct imports, so we go with a documentation approach instead.
|
||||
self.code(from_state.apps, schema_editor)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if self.reverse_code is None:
|
||||
raise NotImplementedError("You cannot reverse this operation")
|
||||
if router.allow_migrate(
|
||||
schema_editor.connection.alias, app_label, **self.hints
|
||||
):
|
||||
self.reverse_code(from_state.apps, schema_editor)
|
||||
|
||||
def describe(self):
|
||||
return "Raw Python operation"
|
||||
|
||||
@staticmethod
|
||||
def noop(apps, schema_editor):
|
||||
return None
|
||||
@@ -0,0 +1,69 @@
|
||||
class MigrationOptimizer:
|
||||
"""
|
||||
Power the optimization process, where you provide a list of Operations
|
||||
and you are returned a list of equal or shorter length - operations
|
||||
are merged into one if possible.
|
||||
|
||||
For example, a CreateModel and an AddField can be optimized into a
|
||||
new CreateModel, and CreateModel and DeleteModel can be optimized into
|
||||
nothing.
|
||||
"""
|
||||
|
||||
def optimize(self, operations, app_label):
|
||||
"""
|
||||
Main optimization entry point. Pass in a list of Operation instances,
|
||||
get out a new list of Operation instances.
|
||||
|
||||
Unfortunately, due to the scope of the optimization (two combinable
|
||||
operations might be separated by several hundred others), this can't be
|
||||
done as a peephole optimization with checks/output implemented on
|
||||
the Operations themselves; instead, the optimizer looks at each
|
||||
individual operation and scans forwards in the list to see if there
|
||||
are any matches, stopping at boundaries - operations which can't
|
||||
be optimized over (RunSQL, operations on the same field/model, etc.)
|
||||
|
||||
The inner loop is run until the starting list is the same as the result
|
||||
list, and then the result is returned. This means that operation
|
||||
optimization must be stable and always return an equal or shorter list.
|
||||
"""
|
||||
# Internal tracking variable for test assertions about # of loops
|
||||
if app_label is None:
|
||||
raise TypeError("app_label must be a str.")
|
||||
self._iterations = 0
|
||||
while True:
|
||||
result = self.optimize_inner(operations, app_label)
|
||||
self._iterations += 1
|
||||
if result == operations:
|
||||
return result
|
||||
operations = result
|
||||
|
||||
def optimize_inner(self, operations, app_label):
|
||||
"""Inner optimization loop."""
|
||||
new_operations = []
|
||||
for i, operation in enumerate(operations):
|
||||
right = True # Should we reduce on the right or on the left.
|
||||
# Compare it to each operation after it
|
||||
for j, other in enumerate(operations[i + 1 :]):
|
||||
result = operation.reduce(other, app_label)
|
||||
if isinstance(result, list):
|
||||
in_between = operations[i + 1 : i + j + 1]
|
||||
if right:
|
||||
new_operations.extend(in_between)
|
||||
new_operations.extend(result)
|
||||
elif all(op.reduce(other, app_label) is True for op in in_between):
|
||||
# Perform a left reduction if all of the in-between
|
||||
# operations can optimize through other.
|
||||
new_operations.extend(result)
|
||||
new_operations.extend(in_between)
|
||||
else:
|
||||
# Otherwise keep trying.
|
||||
new_operations.append(operation)
|
||||
break
|
||||
new_operations.extend(operations[i + j + 2 :])
|
||||
return new_operations
|
||||
elif not result:
|
||||
# Can't perform a right reduction.
|
||||
right = False
|
||||
else:
|
||||
new_operations.append(operation)
|
||||
return new_operations
|
||||
@@ -0,0 +1,347 @@
|
||||
import datetime
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.management.base import OutputWrapper
|
||||
from django.db.models import NOT_PROVIDED
|
||||
from django.utils import timezone
|
||||
from django.utils.version import get_docs_version
|
||||
|
||||
from .loader import MigrationLoader
|
||||
|
||||
|
||||
class MigrationQuestioner:
|
||||
"""
|
||||
Give the autodetector responses to questions it might have.
|
||||
This base class has a built-in noninteractive mode, but the
|
||||
interactive subclass is what the command-line arguments will use.
|
||||
"""
|
||||
|
||||
def __init__(self, defaults=None, specified_apps=None, dry_run=None):
|
||||
self.defaults = defaults or {}
|
||||
self.specified_apps = specified_apps or set()
|
||||
self.dry_run = dry_run
|
||||
|
||||
def ask_initial(self, app_label):
|
||||
"""Should we create an initial migration for the app?"""
|
||||
# If it was specified on the command line, definitely true
|
||||
if app_label in self.specified_apps:
|
||||
return True
|
||||
# Otherwise, we look to see if it has a migrations module
|
||||
# without any Python files in it, apart from __init__.py.
|
||||
# Apps from the new app template will have these; the Python
|
||||
# file check will ensure we skip South ones.
|
||||
try:
|
||||
app_config = apps.get_app_config(app_label)
|
||||
except LookupError: # It's a fake app.
|
||||
return self.defaults.get("ask_initial", False)
|
||||
migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label)
|
||||
if migrations_import_path is None:
|
||||
# It's an application with migrations disabled.
|
||||
return self.defaults.get("ask_initial", False)
|
||||
try:
|
||||
migrations_module = importlib.import_module(migrations_import_path)
|
||||
except ImportError:
|
||||
return self.defaults.get("ask_initial", False)
|
||||
else:
|
||||
if getattr(migrations_module, "__file__", None):
|
||||
filenames = os.listdir(os.path.dirname(migrations_module.__file__))
|
||||
elif hasattr(migrations_module, "__path__"):
|
||||
if len(migrations_module.__path__) > 1:
|
||||
return False
|
||||
filenames = os.listdir(list(migrations_module.__path__)[0])
|
||||
return not any(x.endswith(".py") for x in filenames if x != "__init__.py")
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
"""Adding a NOT NULL field to a model."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
"""Changing a NULL field to NOT NULL."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
def ask_rename(self, model_name, old_name, new_name, field_instance):
|
||||
"""Was this field really renamed?"""
|
||||
return self.defaults.get("ask_rename", False)
|
||||
|
||||
def ask_rename_model(self, old_model_state, new_model_state):
|
||||
"""Was this model really renamed?"""
|
||||
return self.defaults.get("ask_rename_model", False)
|
||||
|
||||
def ask_merge(self, app_label):
|
||||
"""Should these migrations really be merged?"""
|
||||
return self.defaults.get("ask_merge", False)
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
"""Adding an auto_now_add field to a model."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
def ask_unique_callable_default_addition(self, field_name, model_name):
|
||||
"""Adding a unique field with a callable default."""
|
||||
# None means continue.
|
||||
return None
|
||||
|
||||
|
||||
class InteractiveMigrationQuestioner(MigrationQuestioner):
|
||||
def __init__(
|
||||
self, defaults=None, specified_apps=None, dry_run=None, prompt_output=None
|
||||
):
|
||||
super().__init__(
|
||||
defaults=defaults, specified_apps=specified_apps, dry_run=dry_run
|
||||
)
|
||||
self.prompt_output = prompt_output or OutputWrapper(sys.stdout)
|
||||
|
||||
def _boolean_input(self, question, default=None):
|
||||
self.prompt_output.write(f"{question} ", ending="")
|
||||
result = input()
|
||||
if not result and default is not None:
|
||||
return default
|
||||
while not result or result[0].lower() not in "yn":
|
||||
self.prompt_output.write("Please answer yes or no: ", ending="")
|
||||
result = input()
|
||||
return result[0].lower() == "y"
|
||||
|
||||
def _choice_input(self, question, choices):
|
||||
self.prompt_output.write(f"{question}")
|
||||
for i, choice in enumerate(choices):
|
||||
self.prompt_output.write(" %s) %s" % (i + 1, choice))
|
||||
self.prompt_output.write("Select an option: ", ending="")
|
||||
while True:
|
||||
try:
|
||||
result = input()
|
||||
value = int(result)
|
||||
except ValueError:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
self.prompt_output.write("\nCancelled.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
if 0 < value <= len(choices):
|
||||
return value
|
||||
self.prompt_output.write("Please select a valid option: ", ending="")
|
||||
|
||||
def _ask_default(self, default=""):
|
||||
"""
|
||||
Prompt for a default value.
|
||||
|
||||
The ``default`` argument allows providing a custom default value (as a
|
||||
string) which will be shown to the user and used as the return value
|
||||
if the user doesn't provide any other input.
|
||||
"""
|
||||
self.prompt_output.write("Please enter the default value as valid Python.")
|
||||
if default:
|
||||
self.prompt_output.write(
|
||||
f"Accept the default '{default}' by pressing 'Enter' or "
|
||||
f"provide another value."
|
||||
)
|
||||
self.prompt_output.write(
|
||||
"The datetime and django.utils.timezone modules are available, so "
|
||||
"it is possible to provide e.g. timezone.now as a value."
|
||||
)
|
||||
self.prompt_output.write("Type 'exit' to exit this prompt")
|
||||
while True:
|
||||
if default:
|
||||
prompt = "[default: {}] >>> ".format(default)
|
||||
else:
|
||||
prompt = ">>> "
|
||||
self.prompt_output.write(prompt, ending="")
|
||||
try:
|
||||
code = input()
|
||||
except KeyboardInterrupt:
|
||||
self.prompt_output.write("\nCancelled.")
|
||||
sys.exit(1)
|
||||
if not code and default:
|
||||
code = default
|
||||
if not code:
|
||||
self.prompt_output.write(
|
||||
"Please enter some code, or 'exit' (without quotes) to exit."
|
||||
)
|
||||
elif code == "exit":
|
||||
sys.exit(1)
|
||||
else:
|
||||
try:
|
||||
return eval(code, {}, {"datetime": datetime, "timezone": timezone})
|
||||
except Exception as e:
|
||||
self.prompt_output.write(f"{e.__class__.__name__}: {e}")
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
"""Adding a NOT NULL field to a model."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
f"It is impossible to add a non-nullable field '{field_name}' "
|
||||
f"to {model_name} without specifying a default. This is "
|
||||
f"because the database needs something to populate existing "
|
||||
f"rows.\n"
|
||||
f"Please select a fix:",
|
||||
[
|
||||
(
|
||||
"Provide a one-off default now (will be set on all existing "
|
||||
"rows with a null value for this column)"
|
||||
),
|
||||
"Quit and manually define a default value in models.py.",
|
||||
],
|
||||
)
|
||||
if choice == 2:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default()
|
||||
return None
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
"""Changing a NULL field to NOT NULL."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
f"It is impossible to change a nullable field '{field_name}' "
|
||||
f"on {model_name} to non-nullable without providing a "
|
||||
f"default. This is because the database needs something to "
|
||||
f"populate existing rows.\n"
|
||||
f"Please select a fix:",
|
||||
[
|
||||
(
|
||||
"Provide a one-off default now (will be set on all existing "
|
||||
"rows with a null value for this column)"
|
||||
),
|
||||
"Ignore for now. Existing rows that contain NULL values "
|
||||
"will have to be handled manually, for example with a "
|
||||
"RunPython or RunSQL operation.",
|
||||
"Quit and manually define a default value in models.py.",
|
||||
],
|
||||
)
|
||||
if choice == 2:
|
||||
return NOT_PROVIDED
|
||||
elif choice == 3:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default()
|
||||
return None
|
||||
|
||||
def ask_rename(self, model_name, old_name, new_name, field_instance):
|
||||
"""Was this field really renamed?"""
|
||||
msg = "Was %s.%s renamed to %s.%s (a %s)? [y/N]"
|
||||
return self._boolean_input(
|
||||
msg
|
||||
% (
|
||||
model_name,
|
||||
old_name,
|
||||
model_name,
|
||||
new_name,
|
||||
field_instance.__class__.__name__,
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
def ask_rename_model(self, old_model_state, new_model_state):
|
||||
"""Was this model really renamed?"""
|
||||
msg = "Was the model %s.%s renamed to %s? [y/N]"
|
||||
return self._boolean_input(
|
||||
msg
|
||||
% (old_model_state.app_label, old_model_state.name, new_model_state.name),
|
||||
False,
|
||||
)
|
||||
|
||||
def ask_merge(self, app_label):
|
||||
return self._boolean_input(
|
||||
"\nMerging will only work if the operations printed above do not conflict\n"
|
||||
+ "with each other (working on different fields or models)\n"
|
||||
+ "Should these migration branches be merged? [y/N]",
|
||||
False,
|
||||
)
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
"""Adding an auto_now_add field to a model."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
f"It is impossible to add the field '{field_name}' with "
|
||||
f"'auto_now_add=True' to {model_name} without providing a "
|
||||
f"default. This is because the database needs something to "
|
||||
f"populate existing rows.\n",
|
||||
[
|
||||
"Provide a one-off default now which will be set on all "
|
||||
"existing rows",
|
||||
"Quit and manually define a default value in models.py.",
|
||||
],
|
||||
)
|
||||
if choice == 2:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default(default="timezone.now")
|
||||
return None
|
||||
|
||||
def ask_unique_callable_default_addition(self, field_name, model_name):
|
||||
"""Adding a unique field with a callable default."""
|
||||
if not self.dry_run:
|
||||
version = get_docs_version()
|
||||
choice = self._choice_input(
|
||||
f"Callable default on unique field {model_name}.{field_name} "
|
||||
f"will not generate unique values upon migrating.\n"
|
||||
f"Please choose how to proceed:\n",
|
||||
[
|
||||
f"Continue making this migration as the first step in "
|
||||
f"writing a manual migration to generate unique values "
|
||||
f"described here: "
|
||||
f"https://docs.djangoproject.com/en/{version}/howto/"
|
||||
f"writing-migrations/#migrations-that-add-unique-fields.",
|
||||
"Quit and edit field options in models.py.",
|
||||
],
|
||||
)
|
||||
if choice == 2:
|
||||
sys.exit(3)
|
||||
return None
|
||||
|
||||
|
||||
class NonInteractiveMigrationQuestioner(MigrationQuestioner):
|
||||
def __init__(
|
||||
self,
|
||||
defaults=None,
|
||||
specified_apps=None,
|
||||
dry_run=None,
|
||||
verbosity=1,
|
||||
log=None,
|
||||
):
|
||||
self.verbosity = verbosity
|
||||
self.log = log
|
||||
super().__init__(
|
||||
defaults=defaults,
|
||||
specified_apps=specified_apps,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
def log_lack_of_migration(self, field_name, model_name, reason):
|
||||
if self.verbosity > 0:
|
||||
self.log(
|
||||
f"Field '{field_name}' on model '{model_name}' not migrated: "
|
||||
f"{reason}."
|
||||
)
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
# We can't ask the user, so act like the user aborted.
|
||||
self.log_lack_of_migration(
|
||||
field_name,
|
||||
model_name,
|
||||
"it is impossible to add a non-nullable field without specifying "
|
||||
"a default",
|
||||
)
|
||||
sys.exit(3)
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
# We can't ask the user, so set as not provided.
|
||||
self.log(
|
||||
f"Field '{field_name}' on model '{model_name}' given a default of "
|
||||
f"NOT PROVIDED and must be corrected."
|
||||
)
|
||||
return NOT_PROVIDED
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
# We can't ask the user, so act like the user aborted.
|
||||
self.log_lack_of_migration(
|
||||
field_name,
|
||||
model_name,
|
||||
"it is impossible to add a field with 'auto_now_add=True' without "
|
||||
"specifying a default",
|
||||
)
|
||||
sys.exit(3)
|
||||
@@ -0,0 +1,111 @@
|
||||
from django.apps.registry import Apps
|
||||
from django.db import DatabaseError, models
|
||||
from django.utils.functional import classproperty
|
||||
from django.utils.timezone import now
|
||||
|
||||
from .exceptions import MigrationSchemaMissing
|
||||
|
||||
|
||||
class MigrationRecorder:
|
||||
"""
|
||||
Deal with storing migration records in the database.
|
||||
|
||||
Because this table is actually itself used for dealing with model
|
||||
creation, it's the one thing we can't do normally via migrations.
|
||||
We manually handle table creation/schema updating (using schema backend)
|
||||
and then have a floating model to do queries with.
|
||||
|
||||
If a migration is unapplied its row is removed from the table. Having
|
||||
a row in the table always means a migration is applied.
|
||||
"""
|
||||
|
||||
_migration_class = None
|
||||
|
||||
@classproperty
|
||||
def Migration(cls):
|
||||
"""
|
||||
Lazy load to avoid AppRegistryNotReady if installed apps import
|
||||
MigrationRecorder.
|
||||
"""
|
||||
if cls._migration_class is None:
|
||||
|
||||
class Migration(models.Model):
|
||||
app = models.CharField(max_length=255)
|
||||
name = models.CharField(max_length=255)
|
||||
applied = models.DateTimeField(default=now)
|
||||
|
||||
class Meta:
|
||||
apps = Apps()
|
||||
app_label = "migrations"
|
||||
db_table = "django_migrations"
|
||||
|
||||
def __str__(self):
|
||||
return "Migration %s for %s" % (self.name, self.app)
|
||||
|
||||
cls._migration_class = Migration
|
||||
return cls._migration_class
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self._has_table = False
|
||||
|
||||
@property
|
||||
def migration_qs(self):
|
||||
return self.Migration.objects.using(self.connection.alias)
|
||||
|
||||
def has_table(self):
|
||||
"""Return True if the django_migrations table exists."""
|
||||
# If the migrations table has already been confirmed to exist, don't
|
||||
# recheck it's existence.
|
||||
if self._has_table:
|
||||
return True
|
||||
# It hasn't been confirmed to exist, recheck.
|
||||
with self.connection.cursor() as cursor:
|
||||
tables = self.connection.introspection.table_names(cursor)
|
||||
|
||||
self._has_table = self.Migration._meta.db_table in tables
|
||||
return self._has_table
|
||||
|
||||
def ensure_schema(self):
|
||||
"""Ensure the table exists and has the correct schema."""
|
||||
# If the table's there, that's fine - we've never changed its schema
|
||||
# in the codebase.
|
||||
if self.has_table():
|
||||
return
|
||||
# Make the table
|
||||
try:
|
||||
with self.connection.schema_editor() as editor:
|
||||
editor.create_model(self.Migration)
|
||||
except DatabaseError as exc:
|
||||
raise MigrationSchemaMissing(
|
||||
"Unable to create the django_migrations table (%s)" % exc
|
||||
)
|
||||
|
||||
def applied_migrations(self):
|
||||
"""
|
||||
Return a dict mapping (app_name, migration_name) to Migration instances
|
||||
for all applied migrations.
|
||||
"""
|
||||
if self.has_table():
|
||||
return {
|
||||
(migration.app, migration.name): migration
|
||||
for migration in self.migration_qs
|
||||
}
|
||||
else:
|
||||
# If the django_migrations table doesn't exist, then no migrations
|
||||
# are applied.
|
||||
return {}
|
||||
|
||||
def record_applied(self, app, name):
|
||||
"""Record that a migration was applied."""
|
||||
self.ensure_schema()
|
||||
self.migration_qs.create(app=app, name=name)
|
||||
|
||||
def record_unapplied(self, app, name):
|
||||
"""Record that a migration was unapplied."""
|
||||
self.ensure_schema()
|
||||
self.migration_qs.filter(app=app, name=name).delete()
|
||||
|
||||
def flush(self):
|
||||
"""Delete all migration records. Useful for testing migrations."""
|
||||
self.migration_qs.all().delete()
|
||||
@@ -0,0 +1,405 @@
|
||||
import builtins
|
||||
import collections.abc
|
||||
import datetime
|
||||
import decimal
|
||||
import enum
|
||||
import functools
|
||||
import math
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import types
|
||||
import uuid
|
||||
|
||||
from django.conf import SettingsReference
|
||||
from django.db import models
|
||||
from django.db.migrations.operations.base import Operation
|
||||
from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
|
||||
from django.utils.functional import LazyObject, Promise
|
||||
from django.utils.version import PY311, get_docs_version
|
||||
|
||||
FUNCTION_TYPES = (types.FunctionType, types.BuiltinFunctionType, types.MethodType)
|
||||
|
||||
if isinstance(functools._lru_cache_wrapper, type):
|
||||
# When using CPython's _functools C module, LRU cache function decorators
|
||||
# present as a class and not a function, so add that class to the list of
|
||||
# function types. In the pure Python implementation and PyPy they present
|
||||
# as normal functions which are already handled.
|
||||
FUNCTION_TYPES += (functools._lru_cache_wrapper,)
|
||||
|
||||
|
||||
class BaseSerializer:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def serialize(self):
|
||||
raise NotImplementedError(
|
||||
"Subclasses of BaseSerializer must implement the serialize() method."
|
||||
)
|
||||
|
||||
|
||||
class BaseSequenceSerializer(BaseSerializer):
|
||||
def _format(self):
|
||||
raise NotImplementedError(
|
||||
"Subclasses of BaseSequenceSerializer must implement the _format() method."
|
||||
)
|
||||
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for item in self.value:
|
||||
item_string, item_imports = serializer_factory(item).serialize()
|
||||
imports.update(item_imports)
|
||||
strings.append(item_string)
|
||||
value = self._format()
|
||||
return value % (", ".join(strings)), imports
|
||||
|
||||
|
||||
class BaseUnorderedSequenceSerializer(BaseSequenceSerializer):
|
||||
def __init__(self, value):
|
||||
super().__init__(sorted(value, key=repr))
|
||||
|
||||
|
||||
class BaseSimpleSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return repr(self.value), set()
|
||||
|
||||
|
||||
class ChoicesSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return serializer_factory(self.value.value).serialize()
|
||||
|
||||
|
||||
class DateTimeSerializer(BaseSerializer):
|
||||
"""For datetime.*, except datetime.datetime."""
|
||||
|
||||
def serialize(self):
|
||||
return repr(self.value), {"import datetime"}
|
||||
|
||||
|
||||
class DatetimeDatetimeSerializer(BaseSerializer):
|
||||
"""For datetime.datetime."""
|
||||
|
||||
def serialize(self):
|
||||
if self.value.tzinfo is not None and self.value.tzinfo != datetime.timezone.utc:
|
||||
self.value = self.value.astimezone(datetime.timezone.utc)
|
||||
imports = ["import datetime"]
|
||||
return repr(self.value), set(imports)
|
||||
|
||||
|
||||
class DecimalSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return repr(self.value), {"from decimal import Decimal"}
|
||||
|
||||
|
||||
class DeconstructableSerializer(BaseSerializer):
|
||||
@staticmethod
|
||||
def serialize_deconstructed(path, args, kwargs):
|
||||
name, imports = DeconstructableSerializer._serialize_path(path)
|
||||
strings = []
|
||||
for arg in args:
|
||||
arg_string, arg_imports = serializer_factory(arg).serialize()
|
||||
strings.append(arg_string)
|
||||
imports.update(arg_imports)
|
||||
for kw, arg in sorted(kwargs.items()):
|
||||
arg_string, arg_imports = serializer_factory(arg).serialize()
|
||||
imports.update(arg_imports)
|
||||
strings.append("%s=%s" % (kw, arg_string))
|
||||
return "%s(%s)" % (name, ", ".join(strings)), imports
|
||||
|
||||
@staticmethod
|
||||
def _serialize_path(path):
|
||||
module, name = path.rsplit(".", 1)
|
||||
if module == "django.db.models":
|
||||
imports = {"from django.db import models"}
|
||||
name = "models.%s" % name
|
||||
else:
|
||||
imports = {"import %s" % module}
|
||||
name = path
|
||||
return name, imports
|
||||
|
||||
def serialize(self):
|
||||
return self.serialize_deconstructed(*self.value.deconstruct())
|
||||
|
||||
|
||||
class DictionarySerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for k, v in sorted(self.value.items()):
|
||||
k_string, k_imports = serializer_factory(k).serialize()
|
||||
v_string, v_imports = serializer_factory(v).serialize()
|
||||
imports.update(k_imports)
|
||||
imports.update(v_imports)
|
||||
strings.append((k_string, v_string))
|
||||
return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports
|
||||
|
||||
|
||||
class EnumSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
enum_class = self.value.__class__
|
||||
module = enum_class.__module__
|
||||
if issubclass(enum_class, enum.Flag):
|
||||
if PY311:
|
||||
members = list(self.value)
|
||||
else:
|
||||
members, _ = enum._decompose(enum_class, self.value)
|
||||
members = reversed(members)
|
||||
else:
|
||||
members = (self.value,)
|
||||
return (
|
||||
" | ".join(
|
||||
[
|
||||
f"{module}.{enum_class.__qualname__}[{item.name!r}]"
|
||||
for item in members
|
||||
]
|
||||
),
|
||||
{"import %s" % module},
|
||||
)
|
||||
|
||||
|
||||
class FloatSerializer(BaseSimpleSerializer):
|
||||
def serialize(self):
|
||||
if math.isnan(self.value) or math.isinf(self.value):
|
||||
return 'float("{}")'.format(self.value), set()
|
||||
return super().serialize()
|
||||
|
||||
|
||||
class FrozensetSerializer(BaseUnorderedSequenceSerializer):
|
||||
def _format(self):
|
||||
return "frozenset([%s])"
|
||||
|
||||
|
||||
class FunctionTypeSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
if getattr(self.value, "__self__", None) and isinstance(
|
||||
self.value.__self__, type
|
||||
):
|
||||
klass = self.value.__self__
|
||||
module = klass.__module__
|
||||
return "%s.%s.%s" % (module, klass.__qualname__, self.value.__name__), {
|
||||
"import %s" % module
|
||||
}
|
||||
# Further error checking
|
||||
if self.value.__name__ == "<lambda>":
|
||||
raise ValueError("Cannot serialize function: lambda")
|
||||
if self.value.__module__ is None:
|
||||
raise ValueError("Cannot serialize function %r: No module" % self.value)
|
||||
|
||||
module_name = self.value.__module__
|
||||
|
||||
if "<" not in self.value.__qualname__: # Qualname can include <locals>
|
||||
return "%s.%s" % (module_name, self.value.__qualname__), {
|
||||
"import %s" % self.value.__module__
|
||||
}
|
||||
|
||||
raise ValueError(
|
||||
"Could not find function %s in %s.\n" % (self.value.__name__, module_name)
|
||||
)
|
||||
|
||||
|
||||
class FunctoolsPartialSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
# Serialize functools.partial() arguments
|
||||
func_string, func_imports = serializer_factory(self.value.func).serialize()
|
||||
args_string, args_imports = serializer_factory(self.value.args).serialize()
|
||||
keywords_string, keywords_imports = serializer_factory(
|
||||
self.value.keywords
|
||||
).serialize()
|
||||
# Add any imports needed by arguments
|
||||
imports = {"import functools", *func_imports, *args_imports, *keywords_imports}
|
||||
return (
|
||||
"functools.%s(%s, *%s, **%s)"
|
||||
% (
|
||||
self.value.__class__.__name__,
|
||||
func_string,
|
||||
args_string,
|
||||
keywords_string,
|
||||
),
|
||||
imports,
|
||||
)
|
||||
|
||||
|
||||
class IterableSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for item in self.value:
|
||||
item_string, item_imports = serializer_factory(item).serialize()
|
||||
imports.update(item_imports)
|
||||
strings.append(item_string)
|
||||
# When len(strings)==0, the empty iterable should be serialized as
|
||||
# "()", not "(,)" because (,) is invalid Python syntax.
|
||||
value = "(%s)" if len(strings) != 1 else "(%s,)"
|
||||
return value % (", ".join(strings)), imports
|
||||
|
||||
|
||||
class ModelFieldSerializer(DeconstructableSerializer):
|
||||
def serialize(self):
|
||||
attr_name, path, args, kwargs = self.value.deconstruct()
|
||||
return self.serialize_deconstructed(path, args, kwargs)
|
||||
|
||||
|
||||
class ModelManagerSerializer(DeconstructableSerializer):
|
||||
def serialize(self):
|
||||
as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
|
||||
if as_manager:
|
||||
name, imports = self._serialize_path(qs_path)
|
||||
return "%s.as_manager()" % name, imports
|
||||
else:
|
||||
return self.serialize_deconstructed(manager_path, args, kwargs)
|
||||
|
||||
|
||||
class OperationSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
from django.db.migrations.writer import OperationWriter
|
||||
|
||||
string, imports = OperationWriter(self.value, indentation=0).serialize()
|
||||
# Nested operation, trailing comma is handled in upper OperationWriter._write()
|
||||
return string.rstrip(","), imports
|
||||
|
||||
|
||||
class PathLikeSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return repr(os.fspath(self.value)), {}
|
||||
|
||||
|
||||
class PathSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
# Convert concrete paths to pure paths to avoid issues with migrations
|
||||
# generated on one platform being used on a different platform.
|
||||
prefix = "Pure" if isinstance(self.value, pathlib.Path) else ""
|
||||
return "pathlib.%s%r" % (prefix, self.value), {"import pathlib"}
|
||||
|
||||
|
||||
class RegexSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
regex_pattern, pattern_imports = serializer_factory(
|
||||
self.value.pattern
|
||||
).serialize()
|
||||
# Turn off default implicit flags (e.g. re.U) because regexes with the
|
||||
# same implicit and explicit flags aren't equal.
|
||||
flags = self.value.flags ^ re.compile("").flags
|
||||
regex_flags, flag_imports = serializer_factory(flags).serialize()
|
||||
imports = {"import re", *pattern_imports, *flag_imports}
|
||||
args = [regex_pattern]
|
||||
if flags:
|
||||
args.append(regex_flags)
|
||||
return "re.compile(%s)" % ", ".join(args), imports
|
||||
|
||||
|
||||
class SequenceSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
return "[%s]"
|
||||
|
||||
|
||||
class SetSerializer(BaseUnorderedSequenceSerializer):
|
||||
def _format(self):
|
||||
# Serialize as a set literal except when value is empty because {}
|
||||
# is an empty dict.
|
||||
return "{%s}" if self.value else "set(%s)"
|
||||
|
||||
|
||||
class SettingsReferenceSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return "settings.%s" % self.value.setting_name, {
|
||||
"from django.conf import settings"
|
||||
}
|
||||
|
||||
|
||||
class TupleSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
# When len(value)==0, the empty tuple should be serialized as "()",
|
||||
# not "(,)" because (,) is invalid Python syntax.
|
||||
return "(%s)" if len(self.value) != 1 else "(%s,)"
|
||||
|
||||
|
||||
class TypeSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
special_cases = [
|
||||
(models.Model, "models.Model", ["from django.db import models"]),
|
||||
(types.NoneType, "types.NoneType", ["import types"]),
|
||||
]
|
||||
for case, string, imports in special_cases:
|
||||
if case is self.value:
|
||||
return string, set(imports)
|
||||
if hasattr(self.value, "__module__"):
|
||||
module = self.value.__module__
|
||||
if module == builtins.__name__:
|
||||
return self.value.__name__, set()
|
||||
else:
|
||||
return "%s.%s" % (module, self.value.__qualname__), {
|
||||
"import %s" % module
|
||||
}
|
||||
|
||||
|
||||
class UUIDSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return "uuid.%s" % repr(self.value), {"import uuid"}
|
||||
|
||||
|
||||
class Serializer:
|
||||
_registry = {
|
||||
# Some of these are order-dependent.
|
||||
frozenset: FrozensetSerializer,
|
||||
list: SequenceSerializer,
|
||||
set: SetSerializer,
|
||||
tuple: TupleSerializer,
|
||||
dict: DictionarySerializer,
|
||||
models.Choices: ChoicesSerializer,
|
||||
enum.Enum: EnumSerializer,
|
||||
datetime.datetime: DatetimeDatetimeSerializer,
|
||||
(datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
|
||||
SettingsReference: SettingsReferenceSerializer,
|
||||
float: FloatSerializer,
|
||||
(bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
|
||||
decimal.Decimal: DecimalSerializer,
|
||||
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
|
||||
FUNCTION_TYPES: FunctionTypeSerializer,
|
||||
collections.abc.Iterable: IterableSerializer,
|
||||
(COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
|
||||
uuid.UUID: UUIDSerializer,
|
||||
pathlib.PurePath: PathSerializer,
|
||||
os.PathLike: PathLikeSerializer,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register(cls, type_, serializer):
|
||||
if not issubclass(serializer, BaseSerializer):
|
||||
raise ValueError(
|
||||
"'%s' must inherit from 'BaseSerializer'." % serializer.__name__
|
||||
)
|
||||
cls._registry[type_] = serializer
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, type_):
|
||||
cls._registry.pop(type_)
|
||||
|
||||
|
||||
def serializer_factory(value):
|
||||
if isinstance(value, Promise):
|
||||
value = str(value)
|
||||
elif isinstance(value, LazyObject):
|
||||
# The unwrapped value is returned as the first item of the arguments
|
||||
# tuple.
|
||||
value = value.__reduce__()[1][0]
|
||||
|
||||
if isinstance(value, models.Field):
|
||||
return ModelFieldSerializer(value)
|
||||
if isinstance(value, models.manager.BaseManager):
|
||||
return ModelManagerSerializer(value)
|
||||
if isinstance(value, Operation):
|
||||
return OperationSerializer(value)
|
||||
if isinstance(value, type):
|
||||
return TypeSerializer(value)
|
||||
# Anything that knows how to deconstruct itself.
|
||||
if hasattr(value, "deconstruct"):
|
||||
return DeconstructableSerializer(value)
|
||||
for type_, serializer_cls in Serializer._registry.items():
|
||||
if isinstance(value, type_):
|
||||
return serializer_cls(value)
|
||||
raise ValueError(
|
||||
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
|
||||
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
|
||||
"topics/migrations/#migration-serializing" % (value, get_docs_version())
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,129 @@
|
||||
import datetime
|
||||
import re
|
||||
from collections import namedtuple
|
||||
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
|
||||
FieldReference = namedtuple("FieldReference", "to through")
|
||||
|
||||
COMPILED_REGEX_TYPE = type(re.compile(""))
|
||||
|
||||
|
||||
class RegexObject:
|
||||
def __init__(self, obj):
|
||||
self.pattern = obj.pattern
|
||||
self.flags = obj.flags
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, RegexObject):
|
||||
return NotImplemented
|
||||
return self.pattern == other.pattern and self.flags == other.flags
|
||||
|
||||
|
||||
def get_migration_name_timestamp():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M")
|
||||
|
||||
|
||||
def resolve_relation(model, app_label=None, model_name=None):
|
||||
"""
|
||||
Turn a model class or model reference string and return a model tuple.
|
||||
|
||||
app_label and model_name are used to resolve the scope of recursive and
|
||||
unscoped model relationship.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
|
||||
if app_label is None or model_name is None:
|
||||
raise TypeError(
|
||||
"app_label and model_name must be provided to resolve "
|
||||
"recursive relationships."
|
||||
)
|
||||
return app_label, model_name
|
||||
if "." in model:
|
||||
app_label, model_name = model.split(".", 1)
|
||||
return app_label, model_name.lower()
|
||||
if app_label is None:
|
||||
raise TypeError(
|
||||
"app_label must be provided to resolve unscoped model relationships."
|
||||
)
|
||||
return app_label, model.lower()
|
||||
return model._meta.app_label, model._meta.model_name
|
||||
|
||||
|
||||
def field_references(
|
||||
model_tuple,
|
||||
field,
|
||||
reference_model_tuple,
|
||||
reference_field_name=None,
|
||||
reference_field=None,
|
||||
):
|
||||
"""
|
||||
Return either False or a FieldReference if `field` references provided
|
||||
context.
|
||||
|
||||
False positives can be returned if `reference_field_name` is provided
|
||||
without `reference_field` because of the introspection limitation it
|
||||
incurs. This should not be an issue when this function is used to determine
|
||||
whether or not an optimization can take place.
|
||||
"""
|
||||
remote_field = field.remote_field
|
||||
if not remote_field:
|
||||
return False
|
||||
references_to = None
|
||||
references_through = None
|
||||
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
|
||||
to_fields = getattr(field, "to_fields", None)
|
||||
if (
|
||||
reference_field_name is None
|
||||
or
|
||||
# Unspecified to_field(s).
|
||||
to_fields is None
|
||||
or
|
||||
# Reference to primary key.
|
||||
(
|
||||
None in to_fields
|
||||
and (reference_field is None or reference_field.primary_key)
|
||||
)
|
||||
or
|
||||
# Reference to field.
|
||||
reference_field_name in to_fields
|
||||
):
|
||||
references_to = (remote_field, to_fields)
|
||||
through = getattr(remote_field, "through", None)
|
||||
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
|
||||
through_fields = remote_field.through_fields
|
||||
if (
|
||||
reference_field_name is None
|
||||
or
|
||||
# Unspecified through_fields.
|
||||
through_fields is None
|
||||
or
|
||||
# Reference to field.
|
||||
reference_field_name in through_fields
|
||||
):
|
||||
references_through = (remote_field, through_fields)
|
||||
if not (references_to or references_through):
|
||||
return False
|
||||
return FieldReference(references_to, references_through)
|
||||
|
||||
|
||||
def get_references(state, model_tuple, field_tuple=()):
|
||||
"""
|
||||
Generator of (model_state, name, field, reference) referencing
|
||||
provided context.
|
||||
|
||||
If field_tuple is provided only references to this particular field of
|
||||
model_tuple will be generated.
|
||||
"""
|
||||
for state_model_tuple, model_state in state.models.items():
|
||||
for name, field in model_state.fields.items():
|
||||
reference = field_references(
|
||||
state_model_tuple, field, model_tuple, *field_tuple
|
||||
)
|
||||
if reference:
|
||||
yield model_state, name, field, reference
|
||||
|
||||
|
||||
def field_is_referenced(state, model_tuple, field_tuple):
|
||||
"""Return whether `field_tuple` is referenced by any state models."""
|
||||
return next(get_references(state, model_tuple, field_tuple), None) is not None
|
||||
@@ -0,0 +1,316 @@
|
||||
import os
|
||||
import re
|
||||
from importlib import import_module
|
||||
|
||||
from django import get_version
|
||||
from django.apps import apps
|
||||
|
||||
# SettingsReference imported for backwards compatibility in Django 2.2.
|
||||
from django.conf import SettingsReference # NOQA
|
||||
from django.db import migrations
|
||||
from django.db.migrations.loader import MigrationLoader
|
||||
from django.db.migrations.serializer import Serializer, serializer_factory
|
||||
from django.utils.inspect import get_func_args
|
||||
from django.utils.module_loading import module_dir
|
||||
from django.utils.timezone import now
|
||||
|
||||
|
||||
class OperationWriter:
|
||||
def __init__(self, operation, indentation=2):
|
||||
self.operation = operation
|
||||
self.buff = []
|
||||
self.indentation = indentation
|
||||
|
||||
def serialize(self):
|
||||
def _write(_arg_name, _arg_value):
|
||||
if _arg_name in self.operation.serialization_expand_args and isinstance(
|
||||
_arg_value, (list, tuple, dict)
|
||||
):
|
||||
if isinstance(_arg_value, dict):
|
||||
self.feed("%s={" % _arg_name)
|
||||
self.indent()
|
||||
for key, value in _arg_value.items():
|
||||
key_string, key_imports = MigrationWriter.serialize(key)
|
||||
arg_string, arg_imports = MigrationWriter.serialize(value)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
self.feed("%s: %s" % (key_string, args[0]))
|
||||
for arg in args[1:-1]:
|
||||
self.feed(arg)
|
||||
self.feed("%s," % args[-1])
|
||||
else:
|
||||
self.feed("%s: %s," % (key_string, arg_string))
|
||||
imports.update(key_imports)
|
||||
imports.update(arg_imports)
|
||||
self.unindent()
|
||||
self.feed("},")
|
||||
else:
|
||||
self.feed("%s=[" % _arg_name)
|
||||
self.indent()
|
||||
for item in _arg_value:
|
||||
arg_string, arg_imports = MigrationWriter.serialize(item)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
for arg in args[:-1]:
|
||||
self.feed(arg)
|
||||
self.feed("%s," % args[-1])
|
||||
else:
|
||||
self.feed("%s," % arg_string)
|
||||
imports.update(arg_imports)
|
||||
self.unindent()
|
||||
self.feed("],")
|
||||
else:
|
||||
arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
self.feed("%s=%s" % (_arg_name, args[0]))
|
||||
for arg in args[1:-1]:
|
||||
self.feed(arg)
|
||||
self.feed("%s," % args[-1])
|
||||
else:
|
||||
self.feed("%s=%s," % (_arg_name, arg_string))
|
||||
imports.update(arg_imports)
|
||||
|
||||
imports = set()
|
||||
name, args, kwargs = self.operation.deconstruct()
|
||||
operation_args = get_func_args(self.operation.__init__)
|
||||
|
||||
# See if this operation is in django.db.migrations. If it is,
|
||||
# We can just use the fact we already have that imported,
|
||||
# otherwise, we need to add an import for the operation class.
|
||||
if getattr(migrations, name, None) == self.operation.__class__:
|
||||
self.feed("migrations.%s(" % name)
|
||||
else:
|
||||
imports.add("import %s" % (self.operation.__class__.__module__))
|
||||
self.feed("%s.%s(" % (self.operation.__class__.__module__, name))
|
||||
|
||||
self.indent()
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
arg_value = arg
|
||||
arg_name = operation_args[i]
|
||||
_write(arg_name, arg_value)
|
||||
|
||||
i = len(args)
|
||||
# Only iterate over remaining arguments
|
||||
for arg_name in operation_args[i:]:
|
||||
if arg_name in kwargs: # Don't sort to maintain signature order
|
||||
arg_value = kwargs[arg_name]
|
||||
_write(arg_name, arg_value)
|
||||
|
||||
self.unindent()
|
||||
self.feed("),")
|
||||
return self.render(), imports
|
||||
|
||||
def indent(self):
|
||||
self.indentation += 1
|
||||
|
||||
def unindent(self):
|
||||
self.indentation -= 1
|
||||
|
||||
def feed(self, line):
|
||||
self.buff.append(" " * (self.indentation * 4) + line)
|
||||
|
||||
def render(self):
|
||||
return "\n".join(self.buff)
|
||||
|
||||
|
||||
class MigrationWriter:
|
||||
"""
|
||||
Take a Migration instance and is able to produce the contents
|
||||
of the migration file from it.
|
||||
"""
|
||||
|
||||
def __init__(self, migration, include_header=True):
|
||||
self.migration = migration
|
||||
self.include_header = include_header
|
||||
self.needs_manual_porting = False
|
||||
|
||||
def as_string(self):
|
||||
"""Return a string of the file contents."""
|
||||
items = {
|
||||
"replaces_str": "",
|
||||
"initial_str": "",
|
||||
}
|
||||
|
||||
imports = set()
|
||||
|
||||
# Deconstruct operations
|
||||
operations = []
|
||||
for operation in self.migration.operations:
|
||||
operation_string, operation_imports = OperationWriter(operation).serialize()
|
||||
imports.update(operation_imports)
|
||||
operations.append(operation_string)
|
||||
items["operations"] = "\n".join(operations) + "\n" if operations else ""
|
||||
|
||||
# Format dependencies and write out swappable dependencies right
|
||||
dependencies = []
|
||||
for dependency in self.migration.dependencies:
|
||||
if dependency[0] == "__setting__":
|
||||
dependencies.append(
|
||||
" migrations.swappable_dependency(settings.%s),"
|
||||
% dependency[1]
|
||||
)
|
||||
imports.add("from django.conf import settings")
|
||||
else:
|
||||
dependencies.append(" %s," % self.serialize(dependency)[0])
|
||||
items["dependencies"] = (
|
||||
"\n".join(sorted(dependencies)) + "\n" if dependencies else ""
|
||||
)
|
||||
|
||||
# Format imports nicely, swapping imports of functions from migration files
|
||||
# for comments
|
||||
migration_imports = set()
|
||||
for line in list(imports):
|
||||
if re.match(r"^import (.*)\.\d+[^\s]*$", line):
|
||||
migration_imports.add(line.split("import")[1].strip())
|
||||
imports.remove(line)
|
||||
self.needs_manual_porting = True
|
||||
|
||||
# django.db.migrations is always used, but models import may not be.
|
||||
# If models import exists, merge it with migrations import.
|
||||
if "from django.db import models" in imports:
|
||||
imports.discard("from django.db import models")
|
||||
imports.add("from django.db import migrations, models")
|
||||
else:
|
||||
imports.add("from django.db import migrations")
|
||||
|
||||
# Sort imports by the package / module to be imported (the part after
|
||||
# "from" in "from ... import ..." or after "import" in "import ...").
|
||||
# First group the "import" statements, then "from ... import ...".
|
||||
sorted_imports = sorted(
|
||||
imports, key=lambda i: (i.split()[0] == "from", i.split()[1])
|
||||
)
|
||||
items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
|
||||
if migration_imports:
|
||||
items["imports"] += (
|
||||
"\n\n# Functions from the following migrations need manual "
|
||||
"copying.\n# Move them and any dependencies into this file, "
|
||||
"then update the\n# RunPython operations to refer to the local "
|
||||
"versions:\n# %s"
|
||||
) % "\n# ".join(sorted(migration_imports))
|
||||
# If there's a replaces, make a string for it
|
||||
if self.migration.replaces:
|
||||
items["replaces_str"] = (
|
||||
"\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
|
||||
)
|
||||
# Hinting that goes into comment
|
||||
if self.include_header:
|
||||
items["migration_header"] = MIGRATION_HEADER_TEMPLATE % {
|
||||
"version": get_version(),
|
||||
"timestamp": now().strftime("%Y-%m-%d %H:%M"),
|
||||
}
|
||||
else:
|
||||
items["migration_header"] = ""
|
||||
|
||||
if self.migration.initial:
|
||||
items["initial_str"] = "\n initial = True\n"
|
||||
|
||||
return MIGRATION_TEMPLATE % items
|
||||
|
||||
@property
|
||||
def basedir(self):
|
||||
migrations_package_name, _ = MigrationLoader.migrations_module(
|
||||
self.migration.app_label
|
||||
)
|
||||
|
||||
if migrations_package_name is None:
|
||||
raise ValueError(
|
||||
"Django can't create migrations for app '%s' because "
|
||||
"migrations have been disabled via the MIGRATION_MODULES "
|
||||
"setting." % self.migration.app_label
|
||||
)
|
||||
|
||||
# See if we can import the migrations module directly
|
||||
try:
|
||||
migrations_module = import_module(migrations_package_name)
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
return module_dir(migrations_module)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Alright, see if it's a direct submodule of the app
|
||||
app_config = apps.get_app_config(self.migration.app_label)
|
||||
(
|
||||
maybe_app_name,
|
||||
_,
|
||||
migrations_package_basename,
|
||||
) = migrations_package_name.rpartition(".")
|
||||
if app_config.name == maybe_app_name:
|
||||
return os.path.join(app_config.path, migrations_package_basename)
|
||||
|
||||
# In case of using MIGRATION_MODULES setting and the custom package
|
||||
# doesn't exist, create one, starting from an existing package
|
||||
existing_dirs, missing_dirs = migrations_package_name.split("."), []
|
||||
while existing_dirs:
|
||||
missing_dirs.insert(0, existing_dirs.pop(-1))
|
||||
try:
|
||||
base_module = import_module(".".join(existing_dirs))
|
||||
except (ImportError, ValueError):
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
base_dir = module_dir(base_module)
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not locate an appropriate location to create "
|
||||
"migrations package %s. Make sure the toplevel "
|
||||
"package exists and can be imported." % migrations_package_name
|
||||
)
|
||||
|
||||
final_dir = os.path.join(base_dir, *missing_dirs)
|
||||
os.makedirs(final_dir, exist_ok=True)
|
||||
for missing_dir in missing_dirs:
|
||||
base_dir = os.path.join(base_dir, missing_dir)
|
||||
with open(os.path.join(base_dir, "__init__.py"), "w"):
|
||||
pass
|
||||
|
||||
return final_dir
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return "%s.py" % self.migration.name
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return os.path.join(self.basedir, self.filename)
|
||||
|
||||
@classmethod
|
||||
def serialize(cls, value):
|
||||
return serializer_factory(value).serialize()
|
||||
|
||||
@classmethod
|
||||
def register_serializer(cls, type_, serializer):
|
||||
Serializer.register(type_, serializer)
|
||||
|
||||
@classmethod
|
||||
def unregister_serializer(cls, type_):
|
||||
Serializer.unregister(type_)
|
||||
|
||||
|
||||
MIGRATION_HEADER_TEMPLATE = """\
|
||||
# Generated by Django %(version)s on %(timestamp)s
|
||||
|
||||
"""
|
||||
|
||||
|
||||
MIGRATION_TEMPLATE = """\
|
||||
%(migration_header)s%(imports)s
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
%(replaces_str)s%(initial_str)s
|
||||
dependencies = [
|
||||
%(dependencies)s\
|
||||
]
|
||||
|
||||
operations = [
|
||||
%(operations)s\
|
||||
]
|
||||
"""
|
||||
Reference in New Issue
Block a user