mirror of https://github.com/apache/superset.git
synced 2026-04-09 19:35:21 +00:00
700 lines · 21 KiB · Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import logging
|
|
import os
|
|
import time
|
|
from collections.abc import Iterator
|
|
from typing import Any, Callable, Optional, Union
|
|
from uuid import uuid4
|
|
|
|
from alembic import op
|
|
from sqlalchemy import (
|
|
Column,
|
|
inspect,
|
|
JSON,
|
|
MetaData,
|
|
select,
|
|
String,
|
|
Table,
|
|
text,
|
|
update,
|
|
)
|
|
from sqlalchemy.dialects.mysql.base import MySQLDialect
|
|
from sqlalchemy.dialects.postgresql.base import PGDialect
|
|
from sqlalchemy.dialects.sqlite.base import SQLiteDialect # noqa: E402
|
|
from sqlalchemy.engine.reflection import Inspector
|
|
from sqlalchemy.exc import NoSuchTableError
|
|
from sqlalchemy.orm import Query, Session
|
|
from sqlalchemy.sql.schema import SchemaItem
|
|
|
|
from superset.utils import json
|
|
|
|
# ANSI terminal escape codes used to colorize migration log output.
GREEN = "\033[32m"
RESET = "\033[0m"
YELLOW = "\033[33m"
RED = "\033[31m"
LRED = "\033[91m"

# Alembic's environment logger; messages surface during migration runs.
logger = logging.getLogger("alembic.env")

# Rows processed per batch by the helpers below; overridable via the
# BATCH_SIZE environment variable.
DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000))
|
|
|
|
|
|
def get_table_column(
    table_name: str,
    column_name: str,
) -> Optional[dict[str, Any]]:
    """
    Fetch the reflected metadata for a single column of a table.

    :param table_name: The table name
    :param column_name: The column name
    :returns: The column dictionary, or None if the table or column is missing
    """
    inspector = inspect(op.get_context().bind)

    try:
        reflected = inspector.get_columns(table_name)
    except NoSuchTableError:
        return None

    return next((col for col in reflected if col["name"] == column_name), None)
|
|
|
|
|
|
def table_has_column(table_name: str, column_name: str) -> bool:
    """
    Determine whether a given table defines a given column.

    :param table_name: A table name
    :param column_name: A column name
    :returns: True iff the column exists in the table
    """
    column = get_table_column(table_name, column_name)
    return column is not None
|
|
|
|
|
|
def table_has_index(table: str, index: str) -> bool:
    """
    Determine whether a given table defines a given index.

    :param table: A table name
    :param index: A index name
    :returns: True if the index exists in the table
    """
    inspector = inspect(op.get_context().bind)

    try:
        indexes = inspector.get_indexes(table)
    except NoSuchTableError:
        # A missing table trivially has no indexes.
        return False

    for reflected in indexes:
        if reflected["name"] == index:
            return True
    return False
|
|
|
|
|
|
# Native SQL expressions that generate a random UUID server-side, keyed by
# dialect. Dialects not listed here (e.g. SQLite) fall back to the row-by-row
# client-side assignment path in ``assign_uuids``.
uuid_by_dialect = {
    MySQLDialect: "UNHEX(REPLACE(CONVERT(UUID() using utf8mb4), '-', ''))",
    PGDialect: "uuid_in(md5(random()::text || clock_timestamp()::text)::cstring)",
}
|
|
|
|
|
|
def assign_uuids(
    model: Any, session: Session, batch_size: int = DEFAULT_BATCH_SIZE
) -> None:
    """
    Generate and assign new UUIDs for all rows in a table.

    Uses a single server-side UPDATE on dialects with native UUID generation
    (see ``uuid_by_dialect``); otherwise falls back to a paginated,
    client-generated assignment.

    :param model: The declarative model whose rows receive ``uuid`` values
    :param session: The active SQLAlchemy session
    :param batch_size: Rows per batch on the fallback path
    """
    bind = op.get_bind()
    table_name = model.__tablename__
    count = session.query(model).count()
    # silently skip if the table is empty (suitable for db initialization)
    if count == 0:
        return

    start_time = time.time()
    print(f"\nAdding uuids for `{table_name}`...")
    # Use dialect specific native SQL queries if possible
    for dialect, sql in uuid_by_dialect.items():
        if isinstance(bind.dialect, dialect):
            op.execute(
                f"UPDATE {dialect().identifier_preparer.quote(table_name)} SET uuid = {sql}"  # noqa: S608, E501
            )
            print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")
            return

    # Fallback: assign a client-generated UUID to each row, in batches.
    for obj in paginated_update(
        session.query(model),
        lambda current, total: print(
            f"  uuid assigned to {current} out of {total}", end="\r"
        ),
        batch_size=batch_size,
    ):
        # BUG FIX: call uuid4() — the original assigned the uuid4 function
        # object itself rather than a freshly generated UUID value.
        obj.uuid = uuid4()
    print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.\n")
|
|
|
|
|
|
def paginated_update(
    query: Query,
    print_page_progress: Optional[Union[Callable[[int, int], None], bool]] = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> Iterator[Any]:
    """
    Update models in small batches so we don't have to load everything in memory.

    Yields each model instance in turn and commits the session after every
    batch. ``print_page_progress`` may be a callback taking (processed, total),
    True/None for the default console progress line, or False to disable it.
    """
    total = query.count()
    processed = 0
    session: Session = inspect(query).session
    result = session.execute(query)

    if print_page_progress is None or print_page_progress is True:
        # Default progress reporter: overwrite a single console line.
        def print_page_progress(done: int, total_count: int) -> None:
            print(f"  {done}/{total_count}", end="\r")

    while batch := result.fetchmany(batch_size):
        for row in batch:
            yield row[0]

        # Persist the caller's mutations before moving to the next page.
        session.commit()
        processed += len(batch)

        if print_page_progress:
            print_page_progress(processed, total)
|
|
|
|
|
|
def try_load_json(data: Optional[str]) -> dict[str, Any]:
    """
    Parse a JSON string, coercing empty/falsy results to an empty dict.

    Preserves the original ``and``/``or`` semantics explicitly: ``None`` or
    empty input yields ``{}``, and a falsy parsed payload (e.g. ``"0"``,
    ``"[]"``) also yields ``{}`` to honor the declared dict return type.
    Invalid JSON still raises, as before.

    :param data: A JSON-encoded string, or None
    :returns: The parsed payload, or {}
    """
    if not data:
        return {}
    return json.loads(data) or {}
|
|
|
|
|
|
def has_table(table_name: str) -> bool:
    """
    Check if a table exists in the database.

    :param table_name: The table name
    :returns: True if the table exists
    """
    return inspect(op.get_context().bind).has_table(table_name)
|
|
|
|
|
|
def drop_fks_for_table(
    table_name: str, foreign_key_names: list[str] | None = None
) -> None:
    """
    Drop specific or all foreign key constraints for a table
    if they exist and the database is not sqlite.

    :param table_name: The table name to drop foreign key constraints from
    :param foreign_key_names: Optional list of specific foreign key names to drop.
        If None is provided, all will be dropped.
    """
    connection = op.get_bind()

    if isinstance(connection.dialect, SQLiteDialect):
        return  # sqlite doesn't like constraints

    if not has_table(table_name):
        return

    # Use inspect() rather than the legacy Inspector.from_engine(), consistent
    # with the other helpers in this module; built only after the early
    # returns so no reflection work happens when it would be discarded.
    inspector = inspect(connection)
    existing_fks = {fk["name"] for fk in inspector.get_foreign_keys(table_name)}

    # What to delete based on whether the list was passed
    if foreign_key_names is not None:
        foreign_key_names = list(set(foreign_key_names) & existing_fks)
    else:
        foreign_key_names = list(existing_fks)

    for fk_name in foreign_key_names:
        logger.info(
            "Dropping foreign key %s%s%s from table %s%s%s...",
            GREEN,
            fk_name,
            RESET,
            GREEN,
            table_name,
            RESET,
        )
        op.drop_constraint(fk_name, table_name, type_="foreignkey")
|
|
|
|
|
|
def create_table(table_name: str, *columns: SchemaItem, **kwargs: Any) -> None:
    """
    Create a database table with the specified name and columns.

    When a table with the given name already exists, an informational message
    is logged and creation is skipped; otherwise the table is created via
    Alembic's ``op.create_table`` with the provided schema items.

    :param table_name: The name of the table to be created.
    :param columns: A variable number of arguments representing the schema
        just like when calling alembic's method create_table()
    """
    if has_table(table_name=table_name):
        logger.info("Table %s%s%s already exists. Skipping...", LRED, table_name, RESET)
    else:
        logger.info("Creating table %s%s%s...", GREEN, table_name, RESET)
        op.create_table(table_name, *columns, **kwargs)
        logger.info("Table %s%s%s created.", GREEN, table_name, RESET)
|
|
|
|
|
|
def drop_table(table_name: str) -> None:
    """
    Drop a database table with the specified name.

    When the table does not exist, an informational message is logged and the
    drop is skipped. Otherwise all of the table's foreign key constraints are
    dropped first (via ``drop_fks_for_table``) and the table is then removed.

    :param table_name: The name of the table to be dropped.
    """
    if not has_table(table_name=table_name):
        logger.info("Table %s%s%s doesn't exist. Skipping...", GREEN, table_name, RESET)
    else:
        logger.info("Dropping table %s%s%s...", GREEN, table_name, RESET)
        drop_fks_for_table(table_name)
        op.drop_table(table_name=table_name)
        logger.info("Table %s%s%s dropped.", GREEN, table_name, RESET)
|
|
|
|
|
|
def batch_operation(
    callable: Callable[[int, int], None], count: int, batch_size: int
) -> None:
    """
    Executes an operation by dividing a task into smaller batches and tracking progress.

    This function is designed to process a large number of items in smaller batches. It takes a callable
    that performs the operation on each batch. The function logs the progress of the operation as it processes
    through the batches.

    If count is set to 0 or lower, it logs an informational message and skips the batch process.

    :param callable: A callable function that takes two integer arguments:
        the start index and the end index of the current batch.
    :param count: The total number of items to process.
    :param batch_size: The number of items to process in each batch.
    """  # noqa: E501
    if count <= 0:
        # BUG FIX: the skip message previously hard-coded the name
        # "other_callable_example"; report the actual callable instead.
        logger.info(
            "No records to process in batch %s(count <= 0)%s for callable %s%s%s. Skipping...",  # noqa: E501
            LRED,
            RESET,
            LRED,
            callable.__name__,
            RESET,
        )
        return

    for offset in range(0, count, batch_size):
        percentage = (offset / count) * 100 if count else 0
        logger.info(
            "Progress: %s/%s (%.2f%%)",
            f"{offset:,}",
            f"{count:,}",
            percentage,
        )
        # The callable receives the half-open batch bounds [offset, end).
        callable(offset, min(offset + batch_size, count))

    logger.info("Progress: %s/%s (100%%)", f"{count:,}", f"{count:,}")
    logger.info(
        "End: %s%s%s batch operation %ssuccessfully%s executed.",
        GREEN,
        callable.__name__,
        RESET,
        GREEN,
        RESET,
    )
|
|
|
|
|
|
def add_columns(table_name: str, *columns: Column) -> None:
    """
    Add new columns to an existing database table.

    Columns already present on the table are logged and skipped; the remaining
    ones are added inside a single Alembic ``batch_alter_table`` operation.

    :param table_name: The name of the table to which the columns will be added.
    :param columns: SQLAlchemy Column objects defining the columns to be added.
    """
    cols_to_add = []
    for col in columns:
        if not table_has_column(table_name=table_name, column_name=col.name):
            cols_to_add.append(col)
            continue
        logger.info(
            "Column %s%s%s already present on table %s%s%s. Skipping...",
            LRED,
            col.name,
            RESET,
            LRED,
            table_name,
            RESET,
        )

    with op.batch_alter_table(table_name) as batch_op:
        for col in cols_to_add:
            logger.info(
                "Adding column %s%s%s to table %s%s%s...",
                GREEN,
                col.name,
                RESET,
                GREEN,
                table_name,
                RESET,
            )
            batch_op.add_column(col)
|
|
|
|
|
|
def drop_columns(table_name: str, *columns: str) -> None:
    """
    Drop specified columns from an existing database table.

    Columns not present on the table are logged and skipped; the remaining
    ones are removed inside a single Alembic ``batch_alter_table`` operation.

    :param table_name: The name of the table from which the columns will be removed.
    :param columns: A list of column names to be dropped.
    """
    cols_to_drop = []
    for col in columns:
        if table_has_column(table_name=table_name, column_name=col):
            cols_to_drop.append(col)
            continue
        logger.info(
            "Column %s%s%s is not present on table %s%s%s. Skipping...",
            LRED,
            col,
            RESET,
            LRED,
            table_name,
            RESET,
        )

    with op.batch_alter_table(table_name) as batch_op:
        for col in cols_to_drop:
            logger.info(
                "Dropping column %s%s%s from table %s%s%s...",
                GREEN,
                col,
                RESET,
                GREEN,
                table_name,
                RESET,
            )
            batch_op.drop_column(col)
|
|
|
|
|
|
def create_index(
    table_name: str, index_name: str, columns: list[str], *, unique: bool = False
) -> None:
    """
    Create an index on specified columns of an existing database table.

    When the index already exists an informational message is logged and the
    creation is skipped.

    :param table_name: The name of the table on which the index will be created.
    :param index_name: The name of the index to be created.
    :param columns: A list of column names for which the index will be created
    :param unique: If True, create a unique index.
    """
    if table_has_index(table=table_name, index=index_name):
        logger.info(
            "Table %s%s%s already has index %s%s%s. Skipping...",
            LRED,
            table_name,
            RESET,
            LRED,
            index_name,
            RESET,
        )
        return

    logger.info(
        "Creating index %s%s%s on table %s%s%s",
        GREEN,
        index_name,
        RESET,
        GREEN,
        table_name,
        RESET,
    )
    op.create_index(
        table_name=table_name, index_name=index_name, unique=unique, columns=columns
    )
|
|
|
|
|
|
def drop_index(table_name: str, index_name: str) -> None:
    """
    Drop an index from an existing database table.

    When the index does not exist an informational message is logged and the
    removal is skipped.

    :param table_name: The name of the table from which the index will be dropped.
    :param index_name: The name of the index to be dropped.
    """
    if not table_has_index(table=table_name, index=index_name):
        logger.info(
            "Table %s%s%s doesn't have index %s%s%s. Skipping...",
            LRED,
            table_name,
            RESET,
            LRED,
            index_name,
            RESET,
        )
        return

    logger.info(
        "Dropping index %s%s%s from table %s%s%s...",
        GREEN,
        index_name,
        RESET,
        GREEN,
        table_name,
        RESET,
    )
    op.drop_index(table_name=table_name, index_name=index_name)
|
|
|
|
|
|
def create_fks_for_table(
    foreign_key_name: str,
    table_name: str,
    referenced_table: str,
    local_cols: list[str],
    remote_cols: list[str],
    ondelete: Optional[str] = None,
) -> None:
    """
    Create a foreign key constraint for a table, ensuring compatibility with sqlite.

    :param foreign_key_name: Foreign key constraint name.
    :param table_name: The name of the table where the foreign key will be created.
    :param referenced_table: The table the FK references.
    :param local_cols: Column names in the current table.
    :param remote_cols: Column names in the referenced table.
    :param ondelete: (Optional) The ON DELETE action (e.g., "CASCADE", "SET NULL").
    """
    connection = op.get_bind()

    if not has_table(table_name):
        logger.warning(
            "Table %s%s%s does not exist. Skipping foreign key creation.",
            LRED,
            table_name,
            RESET,
        )
        return

    if isinstance(connection.dialect, SQLiteDialect):
        # SQLite requires batch mode since ALTER TABLE is limited
        with op.batch_alter_table(table_name) as batch_op:
            logger.info(
                "Creating foreign key %s%s%s on table %s%s%s (SQLite mode)...",
                GREEN,
                foreign_key_name,
                RESET,
                GREEN,
                table_name,
                RESET,
            )
            batch_op.create_foreign_key(
                foreign_key_name,
                referenced_table,
                local_cols,
                remote_cols,
                ondelete=ondelete,
            )
        return

    # Standard FK creation for other databases
    logger.info(
        "Creating foreign key %s%s%s on table %s%s%s...",
        GREEN,
        foreign_key_name,
        RESET,
        GREEN,
        table_name,
        RESET,
    )
    op.create_foreign_key(
        foreign_key_name,
        table_name,
        referenced_table,
        local_cols,
        remote_cols,
        ondelete=ondelete,
    )
|
|
|
|
|
|
def cast_text_column_to_json(
    table: str,
    column: str,
    pk: str = "id",
    nullable: bool = True,
    suffix: str = "_tmp",
) -> None:
    """
    Cast a text column to JSON.

    SQLAlchemy now has a nice abstraction for JSON columns, even if the underlying
    database doesn't support the type natively. We should always use it when storing
    JSON payloads.

    :param table: The name of the table.
    :param column: The name of the column to be cast.
    :param pk: The name of the primary key column.
    :param nullable: Whether the new column should be nullable.
    :param suffix: The suffix to be added to the temporary column name.
    """
    conn = op.get_bind()

    if isinstance(conn.dialect, PGDialect):
        # Postgres fast path: cast in place. The helper function swallows
        # invalid_text_representation errors so malformed rows become NULL
        # instead of aborting the whole migration.
        conn.execute(
            text(
                f"""
                CREATE OR REPLACE FUNCTION safe_to_jsonb(input text)
                RETURNS jsonb
                LANGUAGE plpgsql
                IMMUTABLE
                AS $$
                BEGIN
                    RETURN input::jsonb;
                EXCEPTION WHEN invalid_text_representation THEN
                    RETURN NULL;
                END;
                $$;

                ALTER TABLE {table}
                ALTER COLUMN {column} TYPE jsonb
                USING safe_to_jsonb({column});
                """
            )
        )
        return

    # Fallback for other backends: add a temporary JSON column, copy values
    # row by row, then swap it in for the original column.
    tmp_column = column + suffix
    op.add_column(
        table,
        Column(tmp_column, JSON(), nullable=nullable),
    )

    # Reflect the table so both old and new columns are addressable.
    meta = MetaData()
    t = Table(table, meta, autoload_with=conn)
    stmt_select = select(t.c[pk], t.c[column]).where(t.c[column].is_not(None))

    for row_pk, value in conn.execute(stmt_select):
        try:
            # Validate only; rows that fail to parse are logged and left NULL
            # in the new column.
            json.loads(value)
        except json.JSONDecodeError:
            logger.warning(
                "Invalid JSON value in column %s for %s=%s: %s",
                column,
                pk,
                row_pk,
                value,
            )
            continue
        # NOTE(review): the raw text is bound to the JSON-typed tmp column; if
        # reflection types it as JSON the string may be re-serialized (double
        # encoded) — confirm against the reflected column type per backend.
        stmt_update = update(t).where(t.c[pk] == row_pk).values({tmp_column: value})
        conn.execute(stmt_update)

    # Replace the original text column with the populated JSON column.
    op.drop_column(table, column)
    op.alter_column(table, tmp_column, existing_type=JSON(), new_column_name=column)

    return
|
|
|
|
|
|
def cast_json_column_to_text(
    table: str,
    column: str,
    pk: str = "id",
    nullable: bool = True,
    suffix: str = "_tmp",
    length: int = 128,
) -> None:
    """
    Cast a JSON column back to text.

    :param table: The name of the table.
    :param column: The name of the column to be cast.
    :param pk: The name of the primary key column.
    :param nullable: Whether the new column should be nullable.
    :param suffix: The suffix to be added to the temporary column name.
    :param length: The length of the text column.
    """
    conn = op.get_bind()

    if isinstance(conn.dialect, PGDialect):
        # Postgres fast path: in-place cast of the column to text.
        # NOTE(review): this produces an unbounded text column, whereas the
        # fallback below creates VARCHAR(length) — confirm the difference is
        # acceptable for callers.
        conn.execute(
            text(
                f"""
                ALTER TABLE {table}
                ALTER COLUMN {column} TYPE text
                USING {column}::text
                """
            )
        )
        return

    # Fallback for other backends: add a temporary string column, serialize
    # each JSON value into it, then swap it in for the original column.
    tmp_column = column + suffix
    op.add_column(
        table,
        Column(tmp_column, String(length=length), nullable=nullable),
    )

    # Reflect the table so both old and new columns are addressable.
    meta = MetaData()
    t = Table(table, meta, autoload_with=conn)
    stmt_select = select(t.c[pk], t.c[column]).where(t.c[column].is_not(None))

    for row_pk, value in conn.execute(stmt_select):
        # value is the deserialized JSON payload; re-encode it as text.
        stmt_update = (
            update(t).where(t.c[pk] == row_pk).values({tmp_column: json.dumps(value)})
        )
        conn.execute(stmt_update)

    # Replace the original JSON column with the populated text column.
    op.drop_column(table, column)
    op.alter_column(
        table,
        tmp_column,
        existing_type=String(length=length),
        new_column_name=column,
    )

    return
|