perf: refactor SIP-68 db migrations with INSERT SELECT FROM (#19421)

This commit is contained in:
Jesse Yang
2022-04-19 18:58:18 -07:00
committed by GitHub
parent 1c5d3b73df
commit 231716cb50
30 changed files with 2356 additions and 1988 deletions

View File

@@ -15,42 +15,22 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Iterator, Optional, Set
import os
import time
from typing import Any
from uuid import uuid4
from alembic import op
from sqlalchemy import engine_from_config
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.engine import reflection
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Session
try:
from sqloxide import parse_sql
except ImportError:
parse_sql = None
logger = logging.getLogger(__name__)
from superset.sql_parse import ParsedQuery, Table
logger = logging.getLogger("alembic")
# Lookup table: sqloxide dialect name -> the SQLAlchemy dialect names that
# should be parsed with it. Key order matters to callers that scan the
# mapping and stop at the first match.
sqloxide_dialects = {
    "ansi": {"presto", "trino", "trinonative"},
    "hive": {"databricks", "hive"},
    "ms": {"mssql"},
    "mysql": {"mysql"},
    "postgres": {
        "cockroachdb",
        "hana",
        "netezza",
        "postgres",
        "postgresql",
        "redshift",
        "vertica",
    },
    "snowflake": {"snowflake"},
    "sqlite": {"gsheets", "shillelagh", "sqlite"},
    "clickhouse": {"clickhouse"},
}
# Default number of rows handled per batch; can be overridden with the
# BATCH_SIZE environment variable. The value is clamped to at least 1 because
# a zero or negative size would stop a `start += batch_size` loop from ever
# advancing.
DEFAULT_BATCH_SIZE = max(1, int(os.environ.get("BATCH_SIZE", 1000)))
def table_has_column(table: str, column: str) -> bool:
@@ -61,7 +41,6 @@ def table_has_column(table: str, column: str) -> bool:
:param column: A column name
:returns: True iff the column exists in the table
"""
config = op.get_context().config
engine = engine_from_config(
config.get_section(config.config_ini_section), prefix="sqlalchemy."
@@ -73,42 +52,44 @@ def table_has_column(table: str, column: str) -> bool:
return False
def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
    """
    Find all nodes in a SQL tree matching a given key.

    The tree is walked depth-first through nested lists and dicts.

    :param element: A list, a dict, or any other value (other values yield
        nothing)
    :param target: The dict key to look for
    :returns: An iterator over every value stored under ``target``
    """
    if isinstance(element, list):
        for child in element:
            yield from find_nodes_by_key(child, target)
    elif isinstance(element, dict):
        for key, value in element.items():
            if key == target:
                # A match is yielded as-is; its own contents are not searched.
                yield value
            else:
                yield from find_nodes_by_key(value, target)
# Dialect-specific SQL expressions, keyed by SQLAlchemy dialect class; each is
# meant to be used as the right-hand side of an `UPDATE ... SET uuid = ...`.
_MYSQL_UUID_SQL = "UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''))"
_PG_UUID_SQL = "uuid_in(md5(random()::text || clock_timestamp()::text)::cstring)"

uuid_by_dialect = {
    MySQLDialect: _MYSQL_UUID_SQL,
    PGDialect: _PG_UUID_SQL,
}
def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]:
    """
    Return all the dependencies from a SQL sql_text.

    sqloxide is used when it is importable; if it is missing, or if it fails
    to parse the statement, the tables reported by ``ParsedQuery`` are
    returned instead.

    :param sql_text: The SQL to inspect
    :param sqla_dialect: Name of the SQLAlchemy dialect the SQL is written for
    :returns: The set of tables referenced by the SQL
    """
    if not parse_sql:
        parsed = ParsedQuery(sql_text)
        return parsed.tables

    # Map the SQLAlchemy dialect to a sqloxide dialect. A separate loop
    # variable is used so that "generic" survives when nothing matches,
    # instead of being overwritten by the last key in the mapping.
    dialect = "generic"
    for candidate, sqla_dialects in sqloxide_dialects.items():
        if sqla_dialect in sqla_dialects:
            dialect = candidate
            break

    try:
        tree = parse_sql(sql_text, dialect=dialect)
    except Exception:  # pylint: disable=broad-except
        logger.warning("Unable to parse query with sqloxide: %s", sql_text)
        # fallback to sqlparse
        parsed = ParsedQuery(sql_text)
        return parsed.tables

    # Name parts are passed to Table in reverse order.
    # NOTE(review): presumably Table takes (table, schema, catalog) - confirm.
    return {
        Table(*[part["value"] for part in table["name"][::-1]])
        for table in find_nodes_by_key(tree, "Table")
    }


def assign_uuids(
    model: Any, session: Session, batch_size: int = DEFAULT_BATCH_SIZE
) -> None:
    """
    Generate new UUIDs for all rows in a table.

    A single native ``UPDATE`` is issued when the bound dialect has an entry
    in ``uuid_by_dialect``; otherwise rows are loaded in batches and given a
    Python-generated ``uuid4``, committing after each batch.

    :param model: Mapped class exposing ``__tablename__`` and a ``uuid``
        attribute
    :param session: Session used to count, load and update the rows
    :param batch_size: Rows per batch for the Python fallback
    :raises ValueError: If the Python fallback is needed and ``batch_size``
        is smaller than 1
    """
    bind = op.get_bind()
    table_name = model.__tablename__
    count = session.query(model).count()

    # silently skip if the table is empty (suitable for db initialization)
    if count == 0:
        return

    start_time = time.time()
    print(f"\nAdding uuids for `{table_name}`...")

    # Use dialect specific native SQL queries if possible
    for dialect, sql in uuid_by_dialect.items():
        if isinstance(bind.dialect, dialect):
            quoted_table = dialect().identifier_preparer.quote(table_name)
            op.execute(f"UPDATE {quoted_table} SET uuid = {sql}")
            print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")
            return

    # A step below 1 would never move `start` past `count`, so fail loudly
    # instead of looping forever.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Otherwise use the Python uuid function
    # NOTE(review): the slices below carry no explicit ORDER BY - confirm the
    # backend returns rows in a stable order across batches.
    start = 0
    while start < count:
        end = min(start + batch_size, count)
        for obj in session.query(model)[start:end]:
            obj.uuid = uuid4()
            session.merge(obj)
        session.commit()
        # Progress is only reported while more batches remain.
        if start + batch_size < count:
            print(f" uuid assigned to {end} out of {count}\r", end="")
        start += batch_size

    print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")