fix: benchmark migration script (#15032)

This commit is contained in:
Beto Dealmeida
2021-06-08 13:57:07 -07:00
committed by GitHub
parent b75df937e9
commit 3b97074ecb
4 changed files with 88 additions and 22 deletions

View File

@@ -25,10 +25,12 @@ from types import ModuleType
from typing import Dict, List, Set, Type
import click
from flask import current_app
from flask_appbuilder import Model
from flask_migrate import downgrade, upgrade
from graphlib import TopologicalSorter # pylint: disable=wrong-import-order
from sqlalchemy import inspect
from sqlalchemy import create_engine, inspect, Table
from sqlalchemy.ext.automap import automap_base
from superset import db
from superset.utils.mock_data import add_sample_rows
@@ -83,11 +85,18 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
elif isinstance(obj, dict):
queue.extend(obj.values())
# add implicit models
# pylint: disable=no-member, protected-access
for obj in Model._decl_class_registry.values():
if hasattr(obj, "__table__") and obj.__table__.fullname in tables:
models.append(obj)
# build models by automapping the existing tables, instead of using current
# code; this is needed for migrations that modify schemas (eg, add a column),
# where the current model is out-of-sync with the existing table after a
# downgrade
sqlalchemy_uri = current_app.config["SQLALCHEMY_DATABASE_URI"]
engine = create_engine(sqlalchemy_uri)
Base = automap_base()
Base.prepare(engine, reflect=True)
for table in tables:
model = getattr(Base.classes, table)
model.__tablename__ = table
models.append(model)
# sort topologically so we can create entities in order and
# maintain relationships (eg, create a database before creating
@@ -133,15 +142,6 @@ def main(
).scalar()
print(f"Current version of the DB is {current_revision}")
print("\nIdentifying models used in the migration:")
models = find_models(module)
model_rows: Dict[Type[Model], int] = {}
for model in models:
rows = session.query(model).count()
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
model_rows[model] = rows
session.close()
if current_revision != down_revision:
if not force:
click.confirm(
@@ -152,6 +152,15 @@ def main(
)
downgrade(revision=down_revision)
print("\nIdentifying models used in the migration:")
models = find_models(module)
model_rows: Dict[Type[Model], int] = {}
for model in models:
rows = session.query(model).count()
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
model_rows[model] = rows
session.close()
print("Benchmarking migration")
results: Dict[str, float] = {}
start = time.time()