diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py index 0faa92a8855..d226efbfd30 100644 --- a/scripts/benchmark_migration.py +++ b/scripts/benchmark_migration.py @@ -25,10 +25,12 @@ from types import ModuleType from typing import Dict, List, Set, Type import click +from flask import current_app from flask_appbuilder import Model from flask_migrate import downgrade, upgrade from graphlib import TopologicalSorter # pylint: disable=wrong-import-order -from sqlalchemy import inspect +from sqlalchemy import create_engine, inspect, Table +from sqlalchemy.ext.automap import automap_base from superset import db from superset.utils.mock_data import add_sample_rows @@ -83,11 +85,18 @@ def find_models(module: ModuleType) -> List[Type[Model]]: elif isinstance(obj, dict): queue.extend(obj.values()) - # add implicit models - # pylint: disable=no-member, protected-access - for obj in Model._decl_class_registry.values(): - if hasattr(obj, "__table__") and obj.__table__.fullname in tables: - models.append(obj) + # build models by automapping the existing tables, instead of using current + # code; this is needed for migrations that modify schemas (eg, add a column), + # where the current model is out-of-sync with the existing table after a + # downgrade + sqlalchemy_uri = current_app.config["SQLALCHEMY_DATABASE_URI"] + engine = create_engine(sqlalchemy_uri) + Base = automap_base() + Base.prepare(engine, reflect=True) + for table in tables: + model = getattr(Base.classes, table) + model.__tablename__ = table + models.append(model) # sort topologically so we can create entities in order and # maintain relationships (eg, create a database before creating @@ -133,15 +142,6 @@ def main( ).scalar() print(f"Current version of the DB is {current_revision}") - print("\nIdentifying models used in the migration:") - models = find_models(module) - model_rows: Dict[Type[Model], int] = {} - for model in models: - rows = session.query(model).count() - print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})") - model_rows[model] = rows - session.close() - if current_revision != down_revision: if not force: click.confirm( @@ -152,6 +152,15 @@ def main( ) downgrade(revision=down_revision) + print("\nIdentifying models used in the migration:") + models = find_models(module) + model_rows: Dict[Type[Model], int] = {} + for model in models: + rows = session.query(model).count() + print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})") + model_rows[model] = rows + session.close() + print("Benchmarking migration") results: Dict[str, float] = {} start = time.time() diff --git a/superset/migrations/versions/27ae655e4247_make_creator_owners.py b/superset/migrations/versions/27ae655e4247_make_creator_owners.py index 561a8ca9a54..c373c0f7e90 100644 --- a/superset/migrations/versions/27ae655e4247_make_creator_owners.py +++ b/superset/migrations/versions/27ae655e4247_make_creator_owners.py @@ -27,10 +27,10 @@ revision = "27ae655e4247" down_revision = "d8bc074f7aad" from alembic import op +from flask import g from flask_appbuilder import Model -from flask_appbuilder.models.mixins import AuditMixin from sqlalchemy import Column, ForeignKey, Integer, Table -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import relationship from superset import db @@ -62,6 +62,29 @@ dashboard_user = Table( ) +class AuditMixin: + @classmethod + def get_user_id(cls): + try: + return g.user.id + except Exception: + return None + + @declared_attr + def created_by_fk(cls): + return Column( + Integer, ForeignKey("ab_user.id"), default=cls.get_user_id, nullable=False + ) + + @declared_attr + def created_by(cls): + return relationship( + "User", + primaryjoin="%s.created_by_fk == User.id" % cls.__name__, + enable_typechecks=False, + ) + + class Slice(Base, AuditMixin): """Declarative class to do query in upgrade""" diff --git a/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py b/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py index 42fdb7e4668..3bab3f6ec3a 100644 --- a/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py +++ b/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py @@ -26,16 +26,46 @@ Create Date: 2018-07-26 11:10:23.653524 revision = "c82ee8a39623" down_revision = "c617da68de7d" -from alembic import op -from sqlalchemy import Column, Enum, ForeignKey, Integer, String -from sqlalchemy.ext.declarative import declarative_base +from datetime import datetime + +from alembic import op +from flask_appbuilder.models.mixins import AuditMixin +from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String +from sqlalchemy.ext.declarative import declarative_base, declared_attr -from superset.models.helpers import AuditMixinNullable from superset.models.tags import ObjectTypes, TagTypes Base = declarative_base() +class AuditMixinNullable(AuditMixin): + """Altering the AuditMixin to use nullable fields + + Allows creating objects programmatically outside of CRUD + """ + + created_on = Column(DateTime, default=datetime.now, nullable=True) + changed_on = Column( + DateTime, default=datetime.now, onupdate=datetime.now, nullable=True + ) + + @declared_attr + def created_by_fk(self) -> Column: + return Column( + Integer, ForeignKey("ab_user.id"), default=self.get_user_id, nullable=True, + ) + + @declared_attr + def changed_by_fk(self) -> Column: + return Column( + Integer, + ForeignKey("ab_user.id"), + default=self.get_user_id, + onupdate=self.get_user_id, + nullable=True, + ) + + class Tag(Base, AuditMixinNullable): """A tag attached to an object (query, chart or dashboard).""" diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index 06327ef8926..84981ca59b8 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -29,6 +29,7 @@ import sqlalchemy.sql.sqltypes import sqlalchemy_utils from flask_appbuilder import Model from sqlalchemy import Column, inspect, MetaData, Table +from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session from sqlalchemy.sql import func from sqlalchemy.sql.visitors import VisitableType @@ -146,6 +147,9 @@ def get_type_generator(sqltype: sqlalchemy.sql.sqltypes) -> Callable[[], Any]: if isinstance(sqltype, sqlalchemy_utils.types.uuid.UUIDType): return uuid4 + if isinstance(sqltype, postgresql.base.UUID): + return lambda: str(uuid4()) + if isinstance(sqltype, sqlalchemy.sql.sqltypes.BLOB): length = random.randrange(sqltype.length or 255) return lambda: os.urandom(length) @@ -153,7 +157,7 @@ def get_type_generator(sqltype: sqlalchemy.sql.sqltypes) -> Callable[[], Any]: logger.warning( "Unknown type %s. Please add it to `get_type_generator`.", type(sqltype) ) - return lambda: "UNKNOWN TYPE" + return lambda: b"UNKNOWN TYPE" def add_data(