chore(pre-commit): Add pyupgrade and pycln hooks (#24197)

This commit is contained in:
John Bodley
2023-06-01 12:01:10 -07:00
committed by GitHub
parent 7d7ce63970
commit a4d5d7c6b9
448 changed files with 3084 additions and 3305 deletions

View File

@@ -23,7 +23,7 @@ from graphlib import TopologicalSorter
from inspect import getsource
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, List, Set, Type
from typing import Any
import click
from flask import current_app
@@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
raise Exception(
"No module spec found in location: `{path}`".format(path=str(filepath))
)
raise Exception(f"No module spec found in location: `{str(filepath)}`")
def extract_modified_tables(module: ModuleType) -> Set[str]:
def extract_modified_tables(module: ModuleType) -> set[str]:
"""
Extract the tables being modified by a migration script.
@@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
actually traversing the AST.
"""
tables: Set[str] = set()
tables: set[str] = set()
for function in {"upgrade", "downgrade"}:
source = getsource(getattr(module, function))
tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
@@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
return tables
def find_models(module: ModuleType) -> List[Type[Model]]:
def find_models(module: ModuleType) -> list[type[Model]]:
"""
Find all models in a migration script.
"""
models: List[Type[Model]] = []
models: list[type[Model]] = []
tables = extract_modified_tables(module)
# add models defined explicitly in the migration script
@@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
sorter: TopologicalSorter[Any] = TopologicalSorter()
for model in models:
inspector = inspect(model)
dependent_tables: List[str] = []
dependent_tables: list[str] = []
for column in inspector.columns.values():
for foreign_key in column.foreign_keys:
if foreign_key.column.table.name != model.__tablename__:
@@ -174,7 +172,7 @@ def main(
print("\nIdentifying models used in the migration:")
models = find_models(module)
model_rows: Dict[Type[Model], int] = {}
model_rows: dict[type[Model], int] = {}
for model in models:
rows = session.query(model).count()
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
@@ -182,7 +180,7 @@ def main(
session.close()
print("Benchmarking migration")
results: Dict[str, float] = {}
results: dict[str, float] = {}
start = time.time()
upgrade(revision=revision)
duration = time.time() - start
@@ -190,14 +188,14 @@ def main(
print(f"Migration on current DB took: {duration:.2f} seconds")
min_entities = 10
new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
new_models: dict[type[Model], list[Model]] = defaultdict(list)
while min_entities <= limit:
downgrade(revision=down_revision)
print(f"Running with at least {min_entities} entities of each model")
for model in models:
missing = min_entities - model_rows[model]
if missing > 0:
entities: List[Model] = []
entities: list[Model] = []
print(f"- Adding {missing} entities to the {model.__name__} model")
bar = ChargingBar("Processing", max=missing)
try: