update to python fastpi

This commit is contained in:
Iliyan Angelov
2025-11-16 15:59:05 +02:00
parent 93d4c1df80
commit 98ccd5b6ff
4464 changed files with 773233 additions and 13740 deletions

View File

@@ -0,0 +1,6 @@
from . import mssql
from . import mysql
from . import oracle
from . import postgresql
from . import sqlite
from .impl import DefaultImpl

View File

@@ -0,0 +1,332 @@
from __future__ import annotations
import functools
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import exc
from sqlalchemy import Integer
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import DDLElement
from sqlalchemy.sql.elements import quoted_name
from ..util.sqla_compat import _columns_for_constraint # noqa
from ..util.sqla_compat import _find_columns # noqa
from ..util.sqla_compat import _fk_spec # noqa
from ..util.sqla_compat import _is_type_bound # noqa
from ..util.sqla_compat import _table_for_constraint # noqa
if TYPE_CHECKING:
from typing import Any
from sqlalchemy.sql.compiler import Compiled
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import FetchedValue
from sqlalchemy.sql.type_api import TypeEngine
from .impl import DefaultImpl
from ..util.sqla_compat import Computed
from ..util.sqla_compat import Identity
_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
class AlterTable(DDLElement):
"""Represent an ALTER TABLE statement.
Only the string name and optional schema name of the table
is required, not a full Table object.
"""
def __init__(
self,
table_name: str,
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
self.table_name = table_name
self.schema = schema
class RenameTable(AlterTable):
def __init__(
self,
old_table_name: str,
new_table_name: Union[quoted_name, str],
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
class AlterColumn(AlterTable):
def __init__(
self,
name: str,
column_name: str,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_nullable: Optional[bool] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_comment: Optional[str] = None,
) -> None:
super().__init__(name, schema=schema)
self.column_name = column_name
self.existing_type = (
sqltypes.to_instance(existing_type)
if existing_type is not None
else None
)
self.existing_nullable = existing_nullable
self.existing_server_default = existing_server_default
self.existing_comment = existing_comment
class ColumnNullable(AlterColumn):
def __init__(
self, name: str, column_name: str, nullable: bool, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.nullable = nullable
class ColumnType(AlterColumn):
def __init__(
self, name: str, column_name: str, type_: TypeEngine, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
class ColumnName(AlterColumn):
def __init__(
self, name: str, column_name: str, newname: str, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.newname = newname
class ColumnDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: Optional[_ServerDefault],
**kw,
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
class ComputedColumnDefault(AlterColumn):
def __init__(
self, name: str, column_name: str, default: Optional[Computed], **kw
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
class IdentityColumnDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: Optional[Identity],
impl: DefaultImpl,
**kw,
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
self.impl = impl
class AddColumn(AlterTable):
def __init__(
self,
name: str,
column: Column[Any],
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(name, schema=schema)
self.column = column
class DropColumn(AlterTable):
def __init__(
self, name: str, column: Column[Any], schema: Optional[str] = None
) -> None:
super().__init__(name, schema=schema)
self.column = column
class ColumnComment(AlterColumn):
def __init__(
self, name: str, column_name: str, comment: Optional[str], **kw
) -> None:
super().__init__(name, column_name, **kw)
self.comment = comment
@compiles(RenameTable)
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, element.schema),
)
@compiles(AddColumn)
def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
)
@compiles(DropColumn)
def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
drop_column(compiler, element.column.name, **kw),
)
@compiles(ColumnNullable)
def visit_column_nullable(
element: ColumnNullable, compiler: DDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DROP NOT NULL" if element.nullable else "SET NOT NULL",
)
@compiles(ColumnType)
def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
)
@compiles(ColumnName)
def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnDefault)
def visit_column_default(
element: ColumnDefault, compiler: DDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@compiles(ComputedColumnDefault)
def visit_computed_column(
element: ComputedColumnDefault, compiler: DDLCompiler, **kw
):
raise exc.CompileError(
'Adding or removing a "computed" construct, e.g. GENERATED '
"ALWAYS AS, to or from an existing column is not supported."
)
@compiles(IdentityColumnDefault)
def visit_identity_column(
element: IdentityColumnDefault, compiler: DDLCompiler, **kw
):
raise exc.CompileError(
'Adding, removing or modifying an "identity" construct, '
"e.g. GENERATED AS IDENTITY, to or from an existing "
"column is not supported in this dialect."
)
def quote_dotted(
name: Union[quoted_name, str], quote: functools.partial
) -> Union[quoted_name, str]:
"""quote the elements of a dotted name"""
if isinstance(name, quoted_name):
return quote(name)
result = ".".join([quote(x) for x in name.split(".")])
return result
def format_table_name(
compiler: Compiled,
name: Union[quoted_name, str],
schema: Optional[Union[quoted_name, str]],
) -> Union[quoted_name, str]:
quote = functools.partial(compiler.preparer.quote)
if schema:
return quote_dotted(schema, quote) + "." + quote(name)
else:
return quote(name)
def format_column_name(
compiler: DDLCompiler, name: Optional[Union[quoted_name, str]]
) -> Union[quoted_name, str]:
return compiler.preparer.quote(name) # type: ignore[arg-type]
def format_server_default(
compiler: DDLCompiler,
default: Optional[_ServerDefault],
) -> str:
return compiler.get_column_default_string(
Column("x", Integer, server_default=default)
)
def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
return compiler.dialect.type_compiler.process(type_)
def alter_table(
compiler: DDLCompiler,
name: str,
schema: Optional[str],
) -> str:
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
def drop_column(compiler: DDLCompiler, name: str, **kw) -> str:
return "DROP COLUMN %s" % format_column_name(compiler, name)
def alter_column(compiler: DDLCompiler, name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)
def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
const = " ".join(
compiler.process(constraint) for constraint in column.constraints
)
if const:
text += " " + const
return text

View File

@@ -0,0 +1,747 @@
from __future__ import annotations
from collections import namedtuple
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import schema
from sqlalchemy import text
from . import base
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from typing import TextIO
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
from ..autogenerate.api import AutogenContext
from ..operations.batch import ApplyBatchImpl
from ..operations.batch import BatchOperationsImpl
class ImplMeta(type):
def __init__(
cls,
classname: str,
bases: Tuple[Type[DefaultImpl]],
dict_: Dict[str, Any],
):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls # type: ignore[assignment]
return newtype
_impls: Dict[str, Type[DefaultImpl]] = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
class DefaultImpl(metaclass=ImplMeta):
"""Provide the entrypoint for major migration operations,
including database-specific behavioral variances.
While individual SQL/DDL constructs already provide
for database-specific implementations, variances here
allow for entirely different sequences of operations
to take place for a particular migration, such as
SQL Server's special 'IDENTITY INSERT' step for
bulk inserts.
"""
__dialect__ = "default"
transactional_ddl = False
command_terminator = ";"
type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
type_arg_extract: Sequence[str] = ()
# These attributes are deprecated in SQLAlchemy via #10247. They need to
# be ignored to support older version that did not use dialect kwargs.
# They only apply to Oracle and are replaced by oracle_order,
# oracle_on_null
identity_attrs_ignore: Tuple[str, ...] = ("order", "on_null")
def __init__(
self,
dialect: Dialect,
connection: Optional[Connection],
as_sql: bool,
transactional_ddl: Optional[bool],
output_buffer: Optional[TextIO],
context_opts: Dict[str, Any],
) -> None:
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
self.memo: dict = {}
self.context_opts = context_opts
if transactional_ddl is not None:
self.transactional_ddl = transactional_ddl
if self.literal_binds:
if not self.as_sql:
raise util.CommandError(
"Can't use literal_binds setting without as_sql mode"
)
@classmethod
def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]:
return _impls[dialect.name]
def static_output(self, text: str) -> None:
assert self.output_buffer is not None
self.output_buffer.write(text + "\n\n")
self.output_buffer.flush()
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
Normally, only returns True on SQLite when operations other
than add_column are present.
"""
return False
def prep_table_for_batch(
self, batch_impl: ApplyBatchImpl, table: Table
) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.
the PG dialect uses this to drop constraints on the table
before the new one uses those same names.
"""
@property
def bind(self) -> Optional[Connection]:
return self.connection
def _exec(
self,
construct: Union[Executable, str],
execution_options: Optional[dict[str, Any]] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, Any] = util.immutabledict(),
) -> Optional[CursorResult]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
if multiparams or params:
# TODO: coverage
raise Exception("Execution arguments not allowed with as_sql")
compile_kw: dict[str, Any]
if self.literal_binds and not isinstance(
construct, schema.DDLElement
):
compile_kw = dict(compile_kwargs={"literal_binds": True})
else:
compile_kw = {}
if TYPE_CHECKING:
assert isinstance(construct, ClauseElement)
compiled = construct.compile(dialect=self.dialect, **compile_kw)
self.static_output(
str(compiled).replace("\t", " ").strip()
+ self.command_terminator
)
return None
else:
conn = self.connection
assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute(construct, multiparams)
def execute(
self,
sql: Union[Executable, str],
execution_options: Optional[dict[str, Any]] = None,
) -> None:
self._exec(sql, execution_options)
def alter_column(
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
existing_comment: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
**kw: Any,
) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
"only make sense for MySQL",
stacklevel=3,
)
if nullable is not None:
self._exec(
base.ColumnNullable(
table_name,
column_name,
nullable,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
if server_default is not False:
kw = {}
cls_: Type[
Union[
base.ComputedColumnDefault,
base.IdentityColumnDefault,
base.ColumnDefault,
]
]
if sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
cls_ = base.ComputedColumnDefault
elif sqla_compat._server_default_is_identity(
server_default, existing_server_default
):
cls_ = base.IdentityColumnDefault
kw["impl"] = self
else:
cls_ = base.ColumnDefault
self._exec(
cls_(
table_name,
column_name,
server_default, # type:ignore[arg-type]
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
**kw,
)
)
if type_ is not None:
self._exec(
base.ColumnType(
table_name,
column_name,
type_,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
if comment is not False:
self._exec(
base.ColumnComment(
table_name,
column_name,
comment,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
# do the new name last ;)
if name is not None:
self._exec(
base.ColumnName(
table_name,
column_name,
name,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
)
)
def add_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
def drop_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[str] = None,
**kw,
) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
def drop_constraint(self, const: Constraint) -> None:
self._exec(schema.DropConstraint(const))
def rename_table(
self,
old_table_name: str,
new_table_name: Union[str, quoted_name],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
def create_table(self, table: Table) -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.CreateTable(table))
table.dispatch.after_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
for index in table.indexes:
self._exec(schema.CreateIndex(index))
with_comment = (
self.dialect.supports_comments and not self.dialect.inline_comments
)
comment = table.comment
if comment and with_comment:
self.create_table_comment(table)
for column in table.columns:
comment = column.comment
if comment and with_comment:
self.create_column_comment(column)
def drop_table(self, table: Table) -> None:
table.dispatch.before_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.DropTable(table))
table.dispatch.after_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
def create_index(self, index: Index, **kw: Any) -> None:
self._exec(schema.CreateIndex(index, **kw))
def create_table_comment(self, table: Table) -> None:
self._exec(schema.SetTableComment(table))
def drop_table_comment(self, table: Table) -> None:
self._exec(schema.DropTableComment(table))
def create_column_comment(self, column: ColumnElement[Any]) -> None:
self._exec(schema.SetColumnComment(column))
def drop_index(self, index: Index, **kw: Any) -> None:
self._exec(schema.DropIndex(index, **kw))
def bulk_insert(
self,
table: Union[TableClause, Table],
rows: List[dict],
multiinsert: bool = True,
) -> None:
if not isinstance(rows, list):
raise TypeError("List expected")
elif rows and not isinstance(rows[0], dict):
raise TypeError("List of dictionaries expected")
if self.as_sql:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(
**{
k: sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
else v
for k, v in row.items()
}
)
)
else:
if rows:
if multiinsert:
self._exec(
sqla_compat._insert_inline(table), multiparams=rows
)
else:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(**row)
)
def _tokenize_column_type(self, column: Column) -> Params:
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
# the two can be compared.
#
# examples:
# NUMERIC(10, 5)
# TIMESTAMP WITH TIMEZONE
# INTEGER UNSIGNED
# INTEGER (10) UNSIGNED
# INTEGER(10) UNSIGNED
# varchar character set utf8
#
tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)
term_tokens = []
paren_term = None
for token in tokens:
if re.match(r"^\(.*\)$", token):
paren_term = token
else:
term_tokens.append(token)
params = Params(term_tokens[0], term_tokens[1:], [], {})
if paren_term:
for term in re.findall("[^(),]+", paren_term):
if "=" in term:
key, val = term.split("=")
params.kwargs[key.strip()] = val.strip()
else:
params.args.append(term.strip())
return params
def _column_types_match(
self, inspector_params: Params, metadata_params: Params
) -> bool:
if inspector_params.token0 == metadata_params.token0:
return True
synonyms = [{t.lower() for t in batch} for batch in self.type_synonyms]
inspector_all_terms = " ".join(
[inspector_params.token0] + inspector_params.tokens
)
metadata_all_terms = " ".join(
[metadata_params.token0] + metadata_params.tokens
)
for batch in synonyms:
if {inspector_all_terms, metadata_all_terms}.issubset(batch) or {
inspector_params.token0,
metadata_params.token0,
}.issubset(batch):
return True
return False
def _column_args_match(
self, inspected_params: Params, meta_params: Params
) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
specifies it, dont flag it for being less specific
"""
if (
len(meta_params.tokens) == len(inspected_params.tokens)
and meta_params.tokens != inspected_params.tokens
):
return False
if (
len(meta_params.args) == len(inspected_params.args)
and meta_params.args != inspected_params.args
):
return False
insp = " ".join(inspected_params.tokens).lower()
meta = " ".join(meta_params.tokens).lower()
for reg in self.type_arg_extract:
mi = re.search(reg, insp)
mm = re.search(reg, meta)
if mi and mm and mi.group(1) != mm.group(1):
return False
return True
def compare_type(
self, inspector_column: Column[Any], metadata_column: Column
) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
and metadata types
"""
inspector_params = self._tokenize_column_type(inspector_column)
metadata_params = self._tokenize_column_type(metadata_column)
if not self._column_types_match(inspector_params, metadata_params):
return True
if not self._column_args_match(inspector_params, metadata_params):
return True
return False
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
self,
conn_uniques: Set[UniqueConstraint],
conn_indexes: Set[Index],
metadata_unique_constraints: Set[UniqueConstraint],
metadata_indexes: Set[Index],
) -> None:
pass
def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
if existing.type._type_affinity is not new_type._type_affinity:
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw: Any
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
"""
compile_kw = {"literal_binds": True, "include_table": False}
return str(
expr.compile(dialect=self.dialect, compile_kwargs=compile_kw)
)
def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable:
return self.autogen_column_reflect
def correct_for_autogen_foreignkeys(
self,
conn_fks: Set[ForeignKeyConstraint],
metadata_fks: Set[ForeignKeyConstraint],
) -> None:
pass
def autogen_column_reflect(self, inspector, table, column_info):
"""A hook that is attached to the 'column_reflect' event for when
a Table is reflected from the database during the autogenerate
process.
Dialects can elect to modify the information gathered here.
"""
def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
Implementations can set up per-migration-run state here.
"""
def emit_begin(self) -> None:
"""Emit the string ``BEGIN``, or the backend-specific
equivalent, on the current connection context.
This is used in offline mode and typically
via :meth:`.EnvironmentContext.begin_transaction`.
"""
self.static_output("BEGIN" + self.command_terminator)
def emit_commit(self) -> None:
"""Emit the string ``COMMIT``, or the backend-specific
equivalent, on the current connection context.
This is used in offline mode and typically
via :meth:`.EnvironmentContext.begin_transaction`.
"""
self.static_output("COMMIT" + self.command_terminator)
def render_type(
self, type_obj: TypeEngine, autogen_context: AutogenContext
) -> Union[str, Literal[False]]:
return False
def _compare_identity_default(self, metadata_identity, inspector_identity):
# ignored contains the attributes that were not considered
# because assumed to their default values in the db.
diff, ignored = _compare_identity_options(
metadata_identity,
inspector_identity,
sqla_compat.Identity(),
skip={"always"},
)
meta_always = getattr(metadata_identity, "always", None)
inspector_always = getattr(inspector_identity, "always", None)
# None and False are the same in this comparison
if bool(meta_always) != bool(inspector_always):
diff.add("always")
diff.difference_update(self.identity_attrs_ignore)
# returns 3 values:
return (
# different identity attributes
diff,
# ignored identity attributes
ignored,
# if the two identity should be considered different
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
)
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
# order of col matters in an index
return tuple(col.name for col in index.columns)
def create_unique_constraint_sig(
self, const: UniqueConstraint
) -> Tuple[Any, ...]:
# order of col does not matters in an unique constraint
return tuple(sorted([col.name for col in const.columns]))
def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
conn_indexes_by_name = {c.name: c for c in conn_indexes}
for idx in list(metadata_indexes):
if idx.name in conn_indexes_by_name:
continue
iex = sqla_compat.is_expression_index(idx)
if iex:
util.warn(
"autogenerate skipping metadata-specified "
"expression-based index "
f"{idx.name!r}; dialect {self.__dialect__!r} under "
f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't "
"reflect these indexes so they can't be compared"
)
metadata_indexes.discard(idx)
def adjust_reflected_dialect_options(
self, reflected_object: Dict[str, Any], kind: str
) -> Dict[str, Any]:
return reflected_object.get("dialect_options", {})
def _compare_identity_options(
metadata_io: Union[schema.Identity, schema.Sequence, None],
inspector_io: Union[schema.Identity, schema.Sequence, None],
default_io: Union[schema.Identity, schema.Sequence],
skip: Set[str],
):
# this can be used for identity or sequence compare.
# default_io is an instance of IdentityOption with all attributes to the
# default value.
meta_d = sqla_compat._get_identity_options_dict(metadata_io)
insp_d = sqla_compat._get_identity_options_dict(inspector_io)
diff = set()
ignored_attr = set()
def check_dicts(
meta_dict: Mapping[str, Any],
insp_dict: Mapping[str, Any],
default_dict: Mapping[str, Any],
attrs: Iterable[str],
):
for attr in set(attrs).difference(skip):
meta_value = meta_dict.get(attr)
insp_value = insp_dict.get(attr)
if insp_value != meta_value:
default_value = default_dict.get(attr)
if meta_value == default_value:
ignored_attr.add(attr)
else:
diff.add(attr)
check_dicts(
meta_d,
insp_d,
sqla_compat._get_identity_options_dict(default_io),
set(meta_d).union(insp_d),
)
if sqla_compat.identity_has_dialect_kwargs:
# use only the dialect kwargs in inspector_io since metadata_io
# can have options for many backends
check_dicts(
getattr(metadata_io, "dialect_kwargs", {}),
getattr(inspector_io, "dialect_kwargs", {}),
default_io.dialect_kwargs, # type: ignore[union-attr]
getattr(inspector_io, "dialect_kwargs", {}),
)
return diff, ignored_attr

View File

@@ -0,0 +1,416 @@
from __future__ import annotations
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import ClauseElement
from .base import AddColumn
from .base import alter_column
from .base import alter_table
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mssql.base import MSDDLCompiler
from sqlalchemy.dialects.mssql.base import MSSQLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
class MSSQLImpl(DefaultImpl):
__dialect__ = "mssql"
transactional_ddl = True
batch_separator = "GO"
type_synonyms = DefaultImpl.type_synonyms + ({"VARCHAR", "NVARCHAR"},)
identity_attrs_ignore = DefaultImpl.identity_attrs_ignore + (
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
)
def __init__(self, *arg, **kw) -> None:
super().__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"mssql_batch_separator", self.batch_separator
)
def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
result = super()._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
def emit_begin(self) -> None:
self.static_output("BEGIN TRANSACTION" + self.command_terminator)
def emit_commit(self) -> None:
super().emit_commit()
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
**kw: Any,
) -> None:
if nullable is not None:
if type_ is not None:
# the NULL/NOT NULL alter will handle
# the type alteration
existing_type = type_
type_ = None
elif existing_type is None:
raise util.CommandError(
"MS-SQL ALTER COLUMN operations "
"with NULL or NOT NULL require the "
"existing_type or a new type_ be passed."
)
elif existing_nullable is not None and type_ is not None:
nullable = existing_nullable
# the NULL/NOT NULL alter will handle
# the type alteration
existing_type = type_
type_ = None
elif type_ is not None:
util.warn(
"MS-SQL ALTER COLUMN operations that specify type_= "
"should also specify a nullable= or "
"existing_nullable= argument to avoid implicit conversion "
"of NOT NULL columns to NULL."
)
used_default = False
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
used_default = True
kw["server_default"] = server_default
kw["existing_server_default"] = existing_server_default
super().alter_column(
table_name,
column_name,
nullable=nullable,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
**kw,
)
if server_default is not False and used_default is False:
if existing_server_default is not False or server_default is None:
self._exec(
_ExecDropConstraint(
table_name,
column_name,
"sys.default_constraints",
schema,
)
)
if server_default is not None:
super().alter_column(
table_name,
column_name,
schema=schema,
server_default=server_default,
)
if name is not None:
super().alter_column(
table_name, column_name, schema=schema, name=name
)
def create_index(self, index: Index, **kw: Any) -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
mssql_include = index.kwargs.get("mssql_include", None) or ()
assert index.table is not None
for col in mssql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
self._exec(CreateIndex(index, **kw))
def bulk_insert( # type:ignore[override]
self, table: Union[TableClause, Table], rows: List[dict], **kw: Any
) -> None:
if self.as_sql:
self._exec(
"SET IDENTITY_INSERT %s ON"
% self.dialect.identifier_preparer.format_table(table)
)
super().bulk_insert(table, rows, **kw)
self._exec(
"SET IDENTITY_INSERT %s OFF"
% self.dialect.identifier_preparer.format_table(table)
)
else:
super().bulk_insert(table, rows, **kw)
def drop_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[str] = None,
**kw,
) -> None:
drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
self._exec(
_ExecDropConstraint(
table_name, column, "sys.default_constraints", schema
)
)
drop_check = kw.pop("mssql_drop_check", False)
if drop_check:
self._exec(
_ExecDropConstraint(
table_name, column, "sys.check_constraints", schema
)
)
drop_fks = kw.pop("mssql_drop_foreign_key", False)
if drop_fks:
self._exec(_ExecDropFKConstraint(table_name, column, schema))
super().drop_column(table_name, column, schema=schema, **kw)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"[\(\) \"\']", "", rendered_metadata_default
)
if rendered_inspector_default is not None:
# SQL Server collapses whitespace and adds arbitrary parenthesis
# within expressions. our only option is collapse all of it
rendered_inspector_default = re.sub(
r"[\(\) \"\']", "", rendered_inspector_default
)
return rendered_inspector_default != rendered_metadata_default
def _compare_identity_default(self, metadata_identity, inspector_identity):
diff, ignored, is_alter = super()._compare_identity_default(
metadata_identity, inspector_identity
)
if (
metadata_identity is None
and inspector_identity is not None
and not diff
and inspector_identity.column is not None
and inspector_identity.column.primary_key
):
# mssql reflect primary keys with autoincrement as identity
# columns. if no different attributes are present ignore them
is_alter = False
return diff, ignored, is_alter
def adjust_reflected_dialect_options(
self, reflected_object: Dict[str, Any], kind: str
) -> Dict[str, Any]:
options: Dict[str, Any]
options = reflected_object.get("dialect_options", {}).copy()
if not options.get("mssql_include"):
options.pop("mssql_include", None)
if not options.get("mssql_clustered"):
options.pop("mssql_clustered", None)
return options
class _ExecDropConstraint(Executable, ClauseElement):
inherit_cache = False
def __init__(
self,
tname: str,
colname: Union[Column[Any], str],
type_: str,
schema: Optional[str],
) -> None:
self.tname = tname
self.colname = colname
self.type_ = type_
self.schema = schema
class _ExecDropFKConstraint(Executable, ClauseElement):
inherit_cache = False
def __init__(
self, tname: str, colname: Column[Any], schema: Optional[str]
) -> None:
self.tname = tname
self.colname = colname
self.schema = schema
@compiles(_ExecDropConstraint, "mssql")
def _exec_drop_col_constraint(
element: _ExecDropConstraint, compiler: MSSQLCompiler, **kw
) -> str:
schema, tname, colname, type_ = (
element.schema,
element.tname,
element.colname,
element.type_,
)
# from http://www.mssqltips.com/sqlservertip/1425/\
# working-with-default-constraints-in-sql-server/
return """declare @const_name varchar(256)
select @const_name = QUOTENAME([name]) from %(type)s
where parent_object_id = object_id('%(schema_dot)s%(tname)s')
and col_name(parent_object_id, parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
"type": type_,
"tname": tname,
"colname": colname,
"tname_quoted": format_table_name(compiler, tname, schema),
"schema_dot": schema + "." if schema else "",
}
@compiles(_ExecDropFKConstraint, "mssql")
def _exec_drop_col_fk_constraint(
element: _ExecDropFKConstraint, compiler: MSSQLCompiler, **kw
) -> str:
schema, tname, colname = element.schema, element.tname, element.colname
return """declare @const_name varchar(256)
select @const_name = QUOTENAME([name]) from
sys.foreign_keys fk join sys.foreign_key_columns fkc
on fk.object_id=fkc.constraint_object_id
where fkc.parent_object_id = object_id('%(schema_dot)s%(tname)s')
and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
"tname": tname,
"colname": colname,
"tname_quoted": format_table_name(compiler, tname, schema),
"schema_dot": schema + "." if schema else "",
}
@compiles(AddColumn, "mssql")
def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
mssql_add_column(compiler, element.column, **kw),
)
def mssql_add_column(
compiler: MSDDLCompiler, column: Column[Any], **kw
) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(ColumnNullable, "mssql")
def visit_column_nullable(
element: ColumnNullable, compiler: MSDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.existing_type), # type: ignore[arg-type]
"NULL" if element.nullable else "NOT NULL",
)
@compiles(ColumnDefault, "mssql")
def visit_column_default(
element: ColumnDefault, compiler: MSDDLCompiler, **kw
) -> str:
# TODO: there can also be a named constraint
# with ADD CONSTRAINT here
return "%s ADD DEFAULT %s FOR %s" % (
alter_table(compiler, element.table_name, element.schema),
format_server_default(compiler, element.default),
format_column_name(compiler, element.column_name),
)
@compiles(ColumnName, "mssql")
def visit_rename_column(
element: ColumnName, compiler: MSDDLCompiler, **kw
) -> str:
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
format_table_name(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnType, "mssql")
def visit_column_type(
element: ColumnType, compiler: MSDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.type_),
)
@compiles(RenameTable, "mssql")
def visit_rename_table(
element: RenameTable, compiler: MSDDLCompiler, **kw
) -> str:
return "EXEC sp_rename '%s', %s" % (
format_table_name(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)

View File

@@ -0,0 +1,471 @@
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from .base import alter_table
from .base import AlterColumn
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .impl import DefaultImpl
from .. import util
from ..autogenerate import compare
from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
from sqlalchemy.sql.ddl import DropConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
class MySQLImpl(DefaultImpl):
__dialect__ = "mysql"
transactional_ddl = False
type_synonyms = DefaultImpl.type_synonyms + (
{"BOOL", "TINYINT"},
{"JSON", "LONGTEXT"},
)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
autoincrement: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
existing_comment: Optional[str] = None,
**kw: Any,
) -> None:
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
# modifying computed or identity columns is not supported
# the default will raise
super().alter_column(
table_name,
column_name,
nullable=nullable,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
server_default=server_default,
existing_server_default=existing_server_default,
**kw,
)
if name is not None or self._is_mysql_allowed_functional_default(
type_ if type_ is not None else existing_type, server_default
):
self._exec(
MySQLChangeColumn(
table_name,
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif (
nullable is not None
or type_ is not None
or autoincrement is not None
or comment is not False
):
self._exec(
MySQLModifyColumn(
table_name,
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif server_default is not False:
self._exec(
MySQLAlterDefault(
table_name, column_name, server_default, schema=schema
)
)
def drop_constraint(
self,
const: Constraint,
) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
super().drop_constraint(const)
def _is_mysql_allowed_functional_default(
self,
type_: Optional[TypeEngine],
server_default: Union[_ServerDefault, Literal[False]],
) -> bool:
return (
type_ is not None
and type_._type_affinity # type:ignore[attr-defined]
is sqltypes.DateTime
and server_default is not None
)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
# partially a workaround for SQLAlchemy issue #3023; if the
# column were created without "NOT NULL", MySQL may have added
# an implicit default of '0' which we need to skip
# TODO: this is not really covered anymore ?
if (
metadata_column.type._type_affinity is sqltypes.Integer
and inspector_column.primary_key
and not inspector_column.autoincrement
and not rendered_metadata_default
and rendered_inspector_default == "'0'"
):
return False
elif (
rendered_inspector_default
and inspector_column.type._type_affinity is sqltypes.Integer
):
rendered_inspector_default = (
re.sub(r"^'|'$", "", rendered_inspector_default)
if rendered_inspector_default is not None
else None
)
return rendered_inspector_default != rendered_metadata_default
elif (
rendered_metadata_default
and metadata_column.type._type_affinity is sqltypes.String
):
metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
return rendered_inspector_default != f"'{metadata_default}'"
elif rendered_inspector_default and rendered_metadata_default:
# adjust for "function()" vs. "FUNCTION" as can occur particularly
# for the CURRENT_TIMESTAMP function on newer MariaDB versions
# SQLAlchemy MySQL dialect bundles ON UPDATE into the server
# default; adjust for this possibly being present.
onupdate_ins = re.match(
r"(.*) (on update.*?)(?:\(\))?$",
rendered_inspector_default.lower(),
)
onupdate_met = re.match(
r"(.*) (on update.*?)(?:\(\))?$",
rendered_metadata_default.lower(),
)
if onupdate_ins:
if not onupdate_met:
return True
elif onupdate_ins.group(2) != onupdate_met.group(2):
return True
rendered_inspector_default = onupdate_ins.group(1)
rendered_metadata_default = onupdate_met.group(1)
return re.sub(
r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower()
) != re.sub(
r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower()
)
else:
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
# TODO: if SQLA 1.0, make use of "duplicates_index"
# metadata
removed = set()
for idx in list(conn_indexes):
if idx.unique:
continue
# MySQL puts implicit indexes on FK columns, even if
# composite and even if MyISAM, so can't check this too easily.
# the name of the index may be the column name or it may
# be the name of the FK constraint.
for col in idx.columns:
if idx.name == col.name:
conn_indexes.remove(idx)
removed.add(idx.name)
break
for fk in col.foreign_keys:
if fk.name == idx.name:
conn_indexes.remove(idx)
removed.add(idx.name)
break
if idx.name in removed:
break
# then remove indexes from the "metadata_indexes"
# that we've removed from reflected, otherwise they come out
# as adds (see #202)
for idx in list(metadata_indexes):
if idx.name in removed:
metadata_indexes.remove(idx)
def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
conn_fk_by_sig = {
compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
}
metadata_fk_by_sig = {
compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
}
for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
mdfk = metadata_fk_by_sig[sig]
cnfk = conn_fk_by_sig[sig]
# MySQL considers RESTRICT to be the default and doesn't
# report on it. if the model has explicit RESTRICT and
# the conn FK has None, set it to RESTRICT
if (
mdfk.ondelete is not None
and mdfk.ondelete.lower() == "restrict"
and cnfk.ondelete is None
):
cnfk.ondelete = "RESTRICT"
if (
mdfk.onupdate is not None
and mdfk.onupdate.lower() == "restrict"
and cnfk.onupdate is None
):
cnfk.onupdate = "RESTRICT"
class MariaDBImpl(MySQLImpl):
__dialect__ = "mariadb"
class MySQLAlterDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: _ServerDefault,
schema: Optional[str] = None,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.default = default
class MySQLChangeColumn(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
schema: Optional[str] = None,
newname: Optional[str] = None,
type_: Optional[TypeEngine] = None,
nullable: Optional[bool] = None,
default: Optional[Union[_ServerDefault, Literal[False]]] = False,
autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
self.newname = newname
self.default = default
self.autoincrement = autoincrement
self.comment = comment
if type_ is None:
raise util.CommandError(
"All MySQL CHANGE/MODIFY COLUMN operations "
"require the existing type."
)
self.type_ = sqltypes.to_instance(type_)
class MySQLModifyColumn(MySQLChangeColumn):
pass
@compiles(ColumnNullable, "mysql", "mariadb")
@compiles(ColumnName, "mysql", "mariadb")
@compiles(ColumnDefault, "mysql", "mariadb")
@compiles(ColumnType, "mysql", "mariadb")
def _mysql_doesnt_support_individual(element, compiler, **kw):
raise NotImplementedError(
"Individual alter column constructs not supported by MySQL"
)
@compiles(MySQLAlterDefault, "mysql", "mariadb")
def _mysql_alter_default(
element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@compiles(MySQLModifyColumn, "mysql", "mariadb")
def _mysql_modify_column(
element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s MODIFY %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
_mysql_colspec(
compiler,
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
autoincrement=element.autoincrement,
comment=element.comment,
),
)
@compiles(MySQLChangeColumn, "mysql", "mariadb")
def _mysql_change_column(
element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s CHANGE %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
_mysql_colspec(
compiler,
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
autoincrement=element.autoincrement,
comment=element.comment,
),
)
def _mysql_colspec(
compiler: MySQLDDLCompiler,
nullable: Optional[bool],
server_default: Optional[Union[_ServerDefault, Literal[False]]],
type_: TypeEngine,
autoincrement: Optional[bool],
comment: Optional[Union[str, Literal[False]]],
) -> str:
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
"NULL" if nullable else "NOT NULL",
)
if autoincrement:
spec += " AUTO_INCREMENT"
if server_default is not False and server_default is not None:
spec += " DEFAULT %s" % format_server_default(compiler, server_default)
if comment:
spec += " COMMENT %s" % compiler.sql_compiler.render_literal_value(
comment, sqltypes.String()
)
return spec
@compiles(schema.DropConstraint, "mysql", "mariadb")
def _mysql_drop_constraint(
element: DropConstraint, compiler: MySQLDDLCompiler, **kw
) -> str:
"""Redefine SQLAlchemy's drop constraint to
raise errors for invalid constraint type."""
constraint = element.element
if isinstance(
constraint,
(
schema.ForeignKeyConstraint,
schema.PrimaryKeyConstraint,
schema.UniqueConstraint,
),
):
assert not kw
return compiler.visit_drop_constraint(element)
elif isinstance(constraint, schema.CheckConstraint):
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
# here.
if _is_mariadb(compiler.dialect):
return "ALTER TABLE %s DROP CONSTRAINT %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),
)
else:
return "ALTER TABLE %s DROP CHECK %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),
)
else:
raise NotImplementedError(
"No generic 'DROP CONSTRAINT' in MySQL - "
"please specify constraint type"
)

View File

@@ -0,0 +1,197 @@
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
from .base import AddColumn
from .base import alter_table
from .base import ColumnComment
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.schema import Column
class OracleImpl(DefaultImpl):
__dialect__ = "oracle"
transactional_ddl = False
batch_separator = "/"
command_terminator = ""
type_synonyms = DefaultImpl.type_synonyms + (
{"VARCHAR", "VARCHAR2"},
{"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"},
{"DOUBLE", "FLOAT", "DOUBLE_PRECISION"},
)
identity_attrs_ignore = ()
def __init__(self, *arg, **kw) -> None:
super().__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"oracle_batch_separator", self.batch_separator
)
def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
result = super()._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_metadata_default
)
rendered_metadata_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
)
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_inspector_default
)
rendered_inspector_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
)
rendered_inspector_default = rendered_inspector_default.strip()
return rendered_inspector_default != rendered_metadata_default
def emit_begin(self) -> None:
self._exec("SET TRANSACTION READ WRITE")
def emit_commit(self) -> None:
self._exec("COMMIT")
@compiles(AddColumn, "oracle")
def visit_add_column(
element: AddColumn, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
)
@compiles(ColumnNullable, "oracle")
def visit_column_nullable(
element: ColumnNullable, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"NULL" if element.nullable else "NOT NULL",
)
@compiles(ColumnType, "oracle")
def visit_column_type(
element: ColumnType, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"%s" % format_type(compiler, element.type_),
)
@compiles(ColumnName, "oracle")
def visit_column_name(
element: ColumnName, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnDefault, "oracle")
def visit_column_default(
element: ColumnDefault, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DEFAULT NULL",
)
@compiles(ColumnComment, "oracle")
def visit_column_comment(
element: ColumnComment, compiler: OracleDDLCompiler, **kw
) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = compiler.sql_compiler.render_literal_value(
(element.comment if element.comment is not None else ""),
sqltypes.String(),
)
return ddl.format(
table_name=element.table_name,
column_name=element.column_name,
comment=comment,
)
@compiles(RenameTable, "oracle")
def visit_rename_table(
element: RenameTable, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
return "MODIFY %s" % format_column_name(compiler, name)
def add_column(compiler: OracleDDLCompiler, column: Column[Any], **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(IdentityColumnDefault, "oracle")
def visit_identity_column(
element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw
):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
)
if element.default is None:
# drop identity
text += "DROP IDENTITY"
return text
else:
text += compiler.visit_identity_column(element.default)
return text

View File

@@ -0,0 +1,774 @@
from __future__ import annotations
import logging
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import Column
from sqlalchemy import literal_column
from sqlalchemy import Numeric
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.types import NULLTYPE
from .base import alter_column
from .base import alter_table
from .base import AlterColumn
from .base import ColumnComment
from .base import compiles
from .base import format_column_name
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..autogenerate import render
from ..operations import ops
from ..operations import schemaobj
from ..operations.base import BatchOperations
from ..operations.base import Operations
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy import Index
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql.array import ARRAY
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
from sqlalchemy.dialects.postgresql.hstore import HSTORE
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
from ..autogenerate.api import AutogenContext
from ..autogenerate.render import _f_name
from ..runtime.migration import MigrationContext
log = logging.getLogger(__name__)
class PostgresqlImpl(DefaultImpl):
__dialect__ = "postgresql"
transactional_ddl = True
type_synonyms = DefaultImpl.type_synonyms + (
{"FLOAT", "DOUBLE PRECISION"},
)
def create_index(self, index: Index, **kw: Any) -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
postgresql_include = index.kwargs.get("postgresql_include", None) or ()
for col in postgresql_include:
if col not in index.table.c: # type: ignore[union-attr]
index.table.append_column( # type: ignore[union-attr]
Column(col, sqltypes.NullType)
)
self._exec(CreateIndex(index, **kw))
def prep_table_for_batch(self, batch_impl, table):
for constraint in table.constraints:
if (
constraint.name is not None
and constraint.name in batch_impl.named_constraints
):
self.drop_constraint(constraint)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
# don't do defaults for SERIAL columns
if (
metadata_column.primary_key
and metadata_column is metadata_column.table._autoincrement_column
):
return False
conn_col_default = rendered_inspector_default
defaults_equal = conn_col_default == rendered_metadata_default
if defaults_equal:
return False
if None in (
conn_col_default,
rendered_metadata_default,
metadata_column.server_default,
):
return not defaults_equal
metadata_default = metadata_column.server_default.arg
if isinstance(metadata_default, str):
if not isinstance(inspector_column.type, Numeric):
metadata_default = re.sub(r"^'|'$", "", metadata_default)
metadata_default = f"'{metadata_default}'"
metadata_default = literal_column(metadata_default)
# run a real compare against the server
return not self.connection.scalar(
sqla_compat._select(
literal_column(conn_col_default) == metadata_default
)
)
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
autoincrement: Optional[bool] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
**kw: Any,
) -> None:
using = kw.pop("postgresql_using", None)
if using is not None and type_ is None:
raise util.CommandError(
"postgresql_using must be used with the type_ parameter"
)
if type_ is not None:
self._exec(
PostgresqlColumnType(
table_name,
column_name,
type_,
schema=schema,
using=using,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
)
)
super().alter_column(
table_name,
column_name,
nullable=nullable,
server_default=server_default,
name=name,
schema=schema,
autoincrement=autoincrement,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_autoincrement=existing_autoincrement,
**kw,
)
def autogen_column_reflect(self, inspector, table, column_info):
if column_info.get("default") and isinstance(
column_info["type"], (INTEGER, BIGINT)
):
seq_match = re.match(
r"nextval\('(.+?)'::regclass\)", column_info["default"]
)
if seq_match:
info = sqla_compat._exec_on_inspector(
inspector,
text(
"select c.relname, a.attname "
"from pg_class as c join "
"pg_depend d on d.objid=c.oid and "
"d.classid='pg_class'::regclass and "
"d.refclassid='pg_class'::regclass "
"join pg_class t on t.oid=d.refobjid "
"join pg_attribute a on a.attrelid=t.oid and "
"a.attnum=d.refobjsubid "
"where c.relkind='S' and c.relname=:seqname"
),
seqname=seq_match.group(1),
).first()
if info:
seqname, colname = info
if colname == column_info["name"]:
log.info(
"Detected sequence named '%s' as "
"owned by integer column '%s(%s)', "
"assuming SERIAL and omitting",
seqname,
table.name,
colname,
)
# sequence, and the owner is this column,
# its a SERIAL - whack it!
del column_info["default"]
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
doubled_constraints = {
index
for index in conn_indexes
if index.info.get("duplicates_constraint")
}
for ix in doubled_constraints:
conn_indexes.remove(ix)
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
def _cleanup_index_expr(
self, index: Index, expr: str, remove_suffix: str
) -> str:
# start = expr
expr = expr.lower().replace('"', "").replace("'", "")
if index.table is not None:
# should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
if "::" in expr:
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
if remove_suffix and expr.endswith(remove_suffix):
expr = expr[: -len(remove_suffix)]
# print(f"START: {start} END: {expr}")
return expr
def _default_modifiers(self, exp: ClauseElement) -> str:
to_remove = ""
while isinstance(exp, UnaryExpression):
if exp.modifier is None:
exp = exp.element
else:
op = exp.modifier
if isinstance(exp.element, UnaryExpression):
inner_op = exp.element.modifier
else:
inner_op = None
if inner_op is None:
if op == operators.asc_op:
# default is asc
to_remove = " asc"
elif op == operators.nullslast_op:
# default is nulls last
to_remove = " nulls last"
else:
if (
inner_op == operators.asc_op
and op == operators.nullslast_op
):
# default is asc nulls last
to_remove = " asc nulls last"
elif (
inner_op == operators.desc_op
and op == operators.nullsfirst_op
):
# default for desc is nulls first
to_remove = " nulls first"
break
return to_remove
def _dialect_sig(
self, item: Union[Index, UniqueConstraint]
) -> Tuple[Any, ...]:
# only the positive case is returned by sqlalchemy reflection so
# None and False are threated the same
if item.dialect_kwargs.get("postgresql_nulls_not_distinct"):
return ("nulls_not_distinct",)
return ()
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
return tuple(
self._cleanup_index_expr(
index,
*(
(e, "")
if isinstance(e, str)
else (self._compile_element(e), self._default_modifiers(e))
),
)
for e in index.expressions
) + self._dialect_sig(index)
def create_unique_constraint_sig(
self, const: UniqueConstraint
) -> Tuple[Any, ...]:
return tuple(
sorted([col.name for col in const.columns])
) + self._dialect_sig(const)
def adjust_reflected_dialect_options(
self, reflected_options: Dict[str, Any], kind: str
) -> Dict[str, Any]:
options: Dict[str, Any]
options = reflected_options.get("dialect_options", {}).copy()
if not options.get("postgresql_include"):
options.pop("postgresql_include", None)
return options
def _compile_element(self, element: ClauseElement) -> str:
return element.compile(
dialect=self.dialect,
compile_kwargs={"literal_binds": True, "include_table": False},
).string
def render_ddl_sql_expr(
self,
expr: ClauseElement,
is_server_default: bool = False,
is_index: bool = False,
**kw: Any,
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
"""
# apply self_group to index expressions;
# see https://github.com/sqlalchemy/sqlalchemy/blob/
# 82fa95cfce070fab401d020c6e6e4a6a96cc2578/
# lib/sqlalchemy/dialects/postgresql/base.py#L2261
if is_index and not isinstance(expr, ColumnClause):
expr = expr.self_group()
return super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, is_index=is_index, **kw
)
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext
) -> Union[str, Literal[False]]:
mod = type(type_).__module__
if not mod.startswith("sqlalchemy.dialects.postgresql"):
return False
if hasattr(self, "_render_%s_type" % type_.__visit_name__):
meth = getattr(self, "_render_%s_type" % type_.__visit_name__)
return meth(type_, autogen_context)
return False
def _render_HSTORE_type(
self, type_: HSTORE, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
),
)
def _render_ARRAY_type(
self, type_: ARRAY, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "item_type", r"(.+?\()"
),
)
def _render_JSON_type(
self, type_: JSON, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
),
)
def _render_JSONB_type(
self, type_: JSONB, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
),
)
class PostgresqlColumnType(AlterColumn):
def __init__(
self, name: str, column_name: str, type_: TypeEngine, **kw
) -> None:
using = kw.pop("using", None)
super().__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
self.using = using
@compiles(RenameTable, "postgresql")
def visit_rename_table(
element: RenameTable, compiler: PGDDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
@compiles(PostgresqlColumnType, "postgresql")
def visit_column_type(
element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
"USING %s" % element.using if element.using else "",
)
@compiles(ColumnComment, "postgresql")
def visit_column_comment(
element: ColumnComment, compiler: PGDDLCompiler, **kw
) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = (
compiler.sql_compiler.render_literal_value(
element.comment, sqltypes.String()
)
if element.comment is not None
else "NULL"
)
return ddl.format(
table_name=format_table_name(
compiler, element.table_name, element.schema
),
column_name=format_column_name(compiler, element.column_name),
comment=comment,
)
@compiles(IdentityColumnDefault, "postgresql")
def visit_identity_column(
element: IdentityColumnDefault, compiler: PGDDLCompiler, **kw
):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
)
if element.default is None:
# drop identity
text += "DROP IDENTITY"
return text
elif element.existing_server_default is None:
# add identity options
text += "ADD "
text += compiler.visit_identity_column(element.default)
return text
else:
# alter identity
diff, _, _ = element.impl._compare_identity_default(
element.default, element.existing_server_default
)
identity = element.default
for attr in sorted(diff):
if attr == "always":
text += "SET GENERATED %s " % (
"ALWAYS" if identity.always else "BY DEFAULT"
)
else:
text += "SET %s " % compiler.get_identity_options(
sqla_compat.Identity(**{attr: getattr(identity, attr)})
)
return text
@Operations.register_operation("create_exclude_constraint")
@BatchOperations.register_operation(
"create_exclude_constraint", "batch_create_exclude_constraint"
)
@ops.AddConstraintOp.register_add_constraint("exclude_constraint")
class CreateExcludeConstraintOp(ops.AddConstraintOp):
"""Represent a create exclude constraint operation."""
constraint_type = "exclude"
def __init__(
self,
constraint_name: sqla_compat._ConstraintName,
table_name: Union[str, quoted_name],
elements: Union[
Sequence[Tuple[str, str]],
Sequence[Tuple[ColumnClause[Any], str]],
],
where: Optional[Union[ColumnElement[bool], str]] = None,
schema: Optional[str] = None,
_orig_constraint: Optional[ExcludeConstraint] = None,
**kw,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.elements = elements
self.where = where
self.schema = schema
self._orig_constraint = _orig_constraint
self.kw = kw
@classmethod
def from_constraint( # type:ignore[override]
cls, constraint: ExcludeConstraint
) -> CreateExcludeConstraintOp:
constraint_table = sqla_compat._table_for_constraint(constraint)
return cls(
constraint.name,
constraint_table.name,
[
(expr, op)
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
],
where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
_orig_constraint=constraint,
deferrable=constraint.deferrable,
initially=constraint.initially,
using=constraint.using,
)
def to_constraint(
self, migration_context: Optional[MigrationContext] = None
) -> ExcludeConstraint:
if self._orig_constraint is not None:
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
t = schema_obj.table(self.table_name, schema=self.schema)
excl = ExcludeConstraint(
*self.elements,
name=self.constraint_name,
where=self.where,
**self.kw,
)
for (
expr,
name,
oper,
) in excl._render_exprs: # type:ignore[attr-defined]
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@classmethod
def create_exclude_constraint(
cls,
operations: Operations,
constraint_name: str,
table_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
.. note:: This method is Postgresql specific, and additionally
requires at least SQLAlchemy 1.0.
e.g.::
from alembic import op
op.create_exclude_constraint(
"user_excl",
"user",
("period", "&&"),
("group", "="),
where=("group != 'some group'"),
)
Note that the expressions work the same way as that of
the ``ExcludeConstraint`` object itself; if plain strings are
passed, quoting rules must be applied manually.
:param name: Name of the constraint.
:param table_name: String name of the source table.
:param elements: exclude conditions.
:param where: SQL expression or SQL string with optional WHERE
clause.
:param deferrable: optional bool. If set, emit DEFERRABLE or
NOT DEFERRABLE when issuing DDL for this constraint.
:param initially: optional string. If set, emit INITIALLY <value>
when issuing DDL for this constraint.
:param schema: Optional schema name to operate within.
"""
op = cls(constraint_name, table_name, elements, **kw)
return operations.invoke(op)
@classmethod
def batch_create_exclude_constraint(
cls,
operations: BatchOperations,
constraint_name: str,
*elements: Any,
**kw: Any,
):
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
.. note:: This method is Postgresql specific, and additionally
requires at least SQLAlchemy 1.0.
.. seealso::
:meth:`.Operations.create_exclude_constraint`
"""
kw["schema"] = operations.impl.schema
op = cls(constraint_name, operations.impl.table_name, elements, **kw)
return operations.invoke(op)
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
def _add_exclude_constraint(
autogen_context: AutogenContext, op: CreateExcludeConstraintOp
) -> str:
return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(
constraint: ExcludeConstraint,
autogen_context: AutogenContext,
namespace_metadata: MetaData,
) -> str:
rendered = render._user_defined_render(
"exclude", constraint, autogen_context
)
if rendered is not False:
return rendered
return _exclude_constraint(constraint, autogen_context, False)
def _postgresql_autogenerate_prefix(autogen_context: AutogenContext) -> str:
imports = autogen_context.imports
if imports is not None:
imports.add("from sqlalchemy.dialects import postgresql")
return "postgresql."
def _exclude_constraint(
constraint: ExcludeConstraint,
autogen_context: AutogenContext,
alter: bool,
) -> str:
opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = []
has_batch = autogen_context._has_batch
if constraint.deferrable:
opts.append(("deferrable", str(constraint.deferrable)))
if constraint.initially:
opts.append(("initially", str(constraint.initially)))
if constraint.using:
opts.append(("using", str(constraint.using)))
if not has_batch and alter and constraint.table.schema:
opts.append(("schema", render._ident(constraint.table.schema)))
if not alter and constraint.name:
opts.append(
("name", render._render_gen_name(autogen_context, constraint.name))
)
def do_expr_where_opts():
args = [
"(%s, %r)"
% (
_render_potential_column(sqltext, autogen_context),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
if constraint.where is not None:
args.append(
"where=%s"
% render._render_potential_expr(
constraint.where, autogen_context
)
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
return args
if alter:
args = [
repr(render._render_gen_name(autogen_context, constraint.name))
]
if not has_batch:
args += [repr(render._ident(constraint.table.name))]
args.extend(do_expr_where_opts())
return "%(prefix)screate_exclude_constraint(%(args)s)" % {
"prefix": render._alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
else:
args = do_expr_where_opts()
return "%(prefix)sExcludeConstraint(%(args)s)" % {
"prefix": _postgresql_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
def _render_potential_column(
value: Union[
ColumnClause[Any], Column[Any], TextClause, FunctionElement[Any]
],
autogen_context: AutogenContext,
) -> str:
if isinstance(value, ColumnClause):
if value.is_literal:
# like literal_column("int8range(from, to)") in ExcludeConstraint
template = "%(prefix)sliteral_column(%(name)r)"
else:
template = "%(prefix)scolumn(%(name)r)"
return template % {
"prefix": render._sqlalchemy_autogenerate_prefix(autogen_context),
"name": value.name,
}
else:
return render._render_potential_expr(
value,
autogen_context,
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
)

View File

@@ -0,0 +1,223 @@
from __future__ import annotations
import re
from typing import Any
from typing import Dict
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import JSON
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy.ext.compiler import compiles
from .base import alter_table
from .base import format_table_name
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import Cast
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from ..operations.batch import BatchOperationsImpl
class SQLiteImpl(DefaultImpl):
__dialect__ = "sqlite"
transactional_ddl = False
"""SQLite supports transactional DDL, but pysqlite does not:
see: http://bugs.python.org/issue10740
"""
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
Normally, only returns True on SQLite when operations other
than add_column are present.
"""
for op in batch_op.batch:
if op[0] == "add_column":
col = op[1][1]
if isinstance(
col.server_default, schema.DefaultClause
) and isinstance(col.server_default.arg, sql.ClauseElement):
return True
elif (
isinstance(col.server_default, util.sqla_compat.Computed)
and col.server_default.persisted
):
return True
elif op[0] not in ("create_index", "drop_index"):
return True
else:
return False
def add_constraint(self, const: Constraint):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
elif const._create_rule(self): # type:ignore[attr-defined]
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
def drop_constraint(self, const: Constraint):
if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
def compare_server_default(
self,
inspector_column: Column[Any],
metadata_column: Column[Any],
rendered_metadata_default: Optional[str],
rendered_inspector_default: Optional[str],
) -> bool:
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_metadata_default
)
rendered_metadata_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
)
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_inspector_default
)
rendered_inspector_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
)
return rendered_inspector_default != rendered_metadata_default
def _guess_if_default_is_unparenthesized_sql_expr(
self, expr: Optional[str]
) -> bool:
"""Determine if a server default is a SQL expression or a constant.
There are too many assertions that expect server defaults to round-trip
identically without parenthesis added so we will add parens only in
very specific cases.
"""
if not expr:
return False
elif re.match(r"^[0-9\.]$", expr):
return False
elif re.match(r"^'.+'$", expr):
return False
elif re.match(r"^\(.+\)$", expr):
return False
else:
return True
def autogen_column_reflect(
self,
inspector: Inspector,
table: Table,
column_info: Dict[str, Any],
) -> None:
# SQLite expression defaults require parenthesis when sent
# as DDL
if self._guess_if_default_is_unparenthesized_sql_expr(
column_info.get("default", None)
):
column_info["default"] = "(%s)" % (column_info["default"],)
def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw
) -> str:
# SQLite expression defaults require parenthesis when sent
# as DDL
str_expr = super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, **kw
)
if (
is_server_default
and self._guess_if_default_is_unparenthesized_sql_expr(str_expr)
):
str_expr = "(%s)" % (str_expr,)
return str_expr
def cast_for_batch_migrate(
self,
existing: Column[Any],
existing_transfer: Dict[str, Union[TypeEngine, Cast]],
new_type: TypeEngine,
) -> None:
if (
existing.type._type_affinity # type:ignore[attr-defined]
is not new_type._type_affinity # type:ignore[attr-defined]
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
self._skip_functional_indexes(metadata_indexes, conn_indexes)
@compiles(RenameTable, "sqlite")
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
# @compiles(AddColumn, 'sqlite')
# def visit_add_column(element, compiler, **kw):
# return "%s %s" % (
# alter_table(compiler, element.table_name, element.schema),
# add_column(compiler, element.column, **kw)
# )
# def add_column(compiler, column, **kw):
# text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
# need to modify SQLAlchemy so that the CHECK associated with a Boolean
# or Enum gets placed as part of the column constraints, not the Table
# see ticket 98
# for const in column.constraints:
# text += compiler.process(AddConstraint(const))
# return text