mirror of
https://github.com/apache/superset.git
synced 2026-04-20 16:44:46 +00:00
perf: refactor SIP-68 db migrations with INSERT SELECT FROM (#19421)
This commit is contained in:
@@ -15,42 +15,22 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Iterator, Optional, Set
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy.dialects.mysql.base import MySQLDialect
|
||||
from sqlalchemy.dialects.postgresql.base import PGDialect
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
try:
|
||||
from sqloxide import parse_sql
|
||||
except ImportError:
|
||||
parse_sql = None
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
|
||||
logger = logging.getLogger("alembic")
|
||||
|
||||
|
||||
# mapping between sqloxide and SQLAlchemy dialects
|
||||
sqloxide_dialects = {
|
||||
"ansi": {"trino", "trinonative", "presto"},
|
||||
"hive": {"hive", "databricks"},
|
||||
"ms": {"mssql"},
|
||||
"mysql": {"mysql"},
|
||||
"postgres": {
|
||||
"cockroachdb",
|
||||
"hana",
|
||||
"netezza",
|
||||
"postgres",
|
||||
"postgresql",
|
||||
"redshift",
|
||||
"vertica",
|
||||
},
|
||||
"snowflake": {"snowflake"},
|
||||
"sqlite": {"sqlite", "gsheets", "shillelagh"},
|
||||
"clickhouse": {"clickhouse"},
|
||||
}
|
||||
DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000))
|
||||
|
||||
|
||||
def table_has_column(table: str, column: str) -> bool:
|
||||
@@ -61,7 +41,6 @@ def table_has_column(table: str, column: str) -> bool:
|
||||
:param column: A column name
|
||||
:returns: True iff the column exists in the table
|
||||
"""
|
||||
|
||||
config = op.get_context().config
|
||||
engine = engine_from_config(
|
||||
config.get_section(config.config_ini_section), prefix="sqlalchemy."
|
||||
@@ -73,42 +52,44 @@ def table_has_column(table: str, column: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
|
||||
"""
|
||||
Find all nodes in a SQL tree matching a given key.
|
||||
"""
|
||||
if isinstance(element, list):
|
||||
for child in element:
|
||||
yield from find_nodes_by_key(child, target)
|
||||
elif isinstance(element, dict):
|
||||
for key, value in element.items():
|
||||
if key == target:
|
||||
yield value
|
||||
else:
|
||||
yield from find_nodes_by_key(value, target)
|
||||
uuid_by_dialect = {
|
||||
MySQLDialect: "UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''))",
|
||||
PGDialect: "uuid_in(md5(random()::text || clock_timestamp()::text)::cstring)",
|
||||
}
|
||||
|
||||
|
||||
def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]:
|
||||
"""
|
||||
Return all the dependencies from a SQL sql_text.
|
||||
"""
|
||||
if not parse_sql:
|
||||
parsed = ParsedQuery(sql_text)
|
||||
return parsed.tables
|
||||
def assign_uuids(
|
||||
model: Any, session: Session, batch_size: int = DEFAULT_BATCH_SIZE
|
||||
) -> None:
|
||||
"""Generate new UUIDs for all rows in a table"""
|
||||
bind = op.get_bind()
|
||||
table_name = model.__tablename__
|
||||
count = session.query(model).count()
|
||||
# silently skip if the table is empty (suitable for db initialization)
|
||||
if count == 0:
|
||||
return
|
||||
|
||||
dialect = "generic"
|
||||
for dialect, sqla_dialects in sqloxide_dialects.items():
|
||||
if sqla_dialect in sqla_dialects:
|
||||
break
|
||||
try:
|
||||
tree = parse_sql(sql_text, dialect=dialect)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.warning("Unable to parse query with sqloxide: %s", sql_text)
|
||||
# fallback to sqlparse
|
||||
parsed = ParsedQuery(sql_text)
|
||||
return parsed.tables
|
||||
start_time = time.time()
|
||||
print(f"\nAdding uuids for `{table_name}`...")
|
||||
# Use dialect specific native SQL queries if possible
|
||||
for dialect, sql in uuid_by_dialect.items():
|
||||
if isinstance(bind.dialect, dialect):
|
||||
op.execute(
|
||||
f"UPDATE {dialect().identifier_preparer.quote(table_name)} SET uuid = {sql}"
|
||||
)
|
||||
print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")
|
||||
return
|
||||
|
||||
return {
|
||||
Table(*[part["value"] for part in table["name"][::-1]])
|
||||
for table in find_nodes_by_key(tree, "Table")
|
||||
}
|
||||
# Othwewise Use Python uuid function
|
||||
start = 0
|
||||
while start < count:
|
||||
end = min(start + batch_size, count)
|
||||
for obj in session.query(model)[start:end]:
|
||||
obj.uuid = uuid4()
|
||||
session.merge(obj)
|
||||
session.commit()
|
||||
if start + batch_size < count:
|
||||
print(f" uuid assigned to {end} out of {count}\r", end="")
|
||||
start += batch_size
|
||||
|
||||
print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")
|
||||
|
||||
Reference in New Issue
Block a user