mirror of
https://github.com/apache/superset.git
synced 2026-04-09 19:35:21 +00:00
234 lines
8.3 KiB
Python
234 lines
8.3 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import copy
|
|
import logging
|
|
from typing import Any
|
|
|
|
from flask import current_app
|
|
from sqlalchemy import and_, Column, Integer, String, Text
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import Session
|
|
|
|
from superset.constants import TimeGrain
|
|
from superset.migrations.shared.utils import paginated_update, try_load_json
|
|
from superset.utils import json
|
|
from superset.utils.date_parser import get_since_until
|
|
|
|
logger = logging.getLogger("alembic.env")
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
class Slice(Base): # type: ignore
|
|
__tablename__ = "slices"
|
|
|
|
id = Column(Integer, primary_key=True)
|
|
slice_name = Column(String(250))
|
|
viz_type = Column(String(250))
|
|
params = Column(Text)
|
|
query_context = Column(Text)
|
|
|
|
|
|
FORM_DATA_BAK_FIELD_NAME = "form_data_bak"
|
|
QUERIES_BAK_FIELD_NAME = "queries_bak"
|
|
|
|
|
|
class MigrateViz:
|
|
remove_keys: set[str] = set()
|
|
rename_keys: dict[str, str] = {}
|
|
source_viz_type: str
|
|
target_viz_type: str
|
|
has_x_axis_control: bool = False
|
|
|
|
def __init__(self, form_data: str) -> None:
|
|
self.data = try_load_json(form_data)
|
|
|
|
def _pre_action(self) -> None:
|
|
"""Some actions before migrate"""
|
|
|
|
def _migrate(self) -> None:
|
|
if self.data.get("viz_type") != self.source_viz_type:
|
|
return
|
|
|
|
if "viz_type" in self.data:
|
|
self.data["viz_type"] = self.target_viz_type
|
|
|
|
# Sometimes visualizations have same keys in the source form_data and rename_keys # noqa: E501
|
|
# We need to remove them from data to allow the migration to work properly with rename_keys # noqa: E501
|
|
for source_key, target_key in self.rename_keys.items():
|
|
if source_key in self.data and target_key in self.data:
|
|
self.data.pop(target_key)
|
|
|
|
rv_data = {}
|
|
for key, value in self.data.items():
|
|
if key in self.rename_keys and self.rename_keys[key] in rv_data:
|
|
raise ValueError("Duplicate key in target viz")
|
|
|
|
if key in self.rename_keys:
|
|
rv_data[self.rename_keys[key]] = value
|
|
continue
|
|
|
|
if key in self.remove_keys:
|
|
continue
|
|
|
|
rv_data[key] = value
|
|
|
|
self._migrate_temporal_filter(rv_data)
|
|
|
|
self.data = rv_data
|
|
|
|
def _post_action(self) -> None:
|
|
"""Some actions after migrate"""
|
|
|
|
def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None:
|
|
"""Adds a temporal filter."""
|
|
granularity_sqla = rv_data.pop("granularity_sqla", None)
|
|
time_range = rv_data.pop("time_range", None) or current_app.config.get(
|
|
"DEFAULT_TIME_FILTER"
|
|
)
|
|
|
|
if not granularity_sqla:
|
|
return
|
|
|
|
if self.has_x_axis_control:
|
|
rv_data["x_axis"] = granularity_sqla
|
|
rv_data["time_grain_sqla"] = rv_data.get("time_grain_sqla") or TimeGrain.DAY
|
|
|
|
temporal_filter = {
|
|
"clause": "WHERE",
|
|
"subject": granularity_sqla,
|
|
"operator": "TEMPORAL_RANGE",
|
|
"comparator": time_range,
|
|
"expressionType": "SIMPLE",
|
|
}
|
|
|
|
if isinstance(granularity_sqla, dict):
|
|
since, until = get_since_until(time_range=time_range)
|
|
if not since and not until:
|
|
temporal_filter = {}
|
|
else:
|
|
temporal_filter["comparator"] = None
|
|
temporal_filter["expressionType"] = "SQL"
|
|
temporal_filter["subject"] = granularity_sqla["label"]
|
|
|
|
start_date = since.isoformat() if since else None
|
|
end_date = until.isoformat() if until else None
|
|
if start_date and end_date:
|
|
temporal_filter["sqlExpression"] = (
|
|
f"{granularity_sqla['sqlExpression']} >= '{start_date}' AND "
|
|
f"{granularity_sqla['sqlExpression']} < '{end_date}'"
|
|
)
|
|
elif start_date:
|
|
temporal_filter["sqlExpression"] = (
|
|
f"{granularity_sqla['sqlExpression']} >= '{start_date}'"
|
|
)
|
|
elif end_date:
|
|
temporal_filter["sqlExpression"] = (
|
|
f"{granularity_sqla['sqlExpression']} < '{end_date}'"
|
|
)
|
|
|
|
rv_data["adhoc_filters"] = rv_data.get("adhoc_filters") or []
|
|
if temporal_filter:
|
|
rv_data["adhoc_filters"].append(temporal_filter)
|
|
|
|
@classmethod
|
|
def upgrade_slice(cls, slc: Slice) -> None:
|
|
try:
|
|
clz = cls(slc.params)
|
|
form_data_bak = copy.deepcopy(clz.data)
|
|
|
|
clz._pre_action()
|
|
clz._migrate()
|
|
clz._post_action()
|
|
|
|
# viz_type depends on the migration and should be set after its execution
|
|
# because a source viz can be mapped to different target viz types
|
|
slc.viz_type = clz.target_viz_type
|
|
|
|
backup: Any | dict[str, Any] = {FORM_DATA_BAK_FIELD_NAME: form_data_bak}
|
|
|
|
query_context = try_load_json(slc.query_context)
|
|
queries_bak = None
|
|
|
|
if query_context:
|
|
if "form_data" in query_context:
|
|
query_context["form_data"] = clz.data
|
|
|
|
queries_bak = copy.deepcopy(query_context["queries"])
|
|
|
|
queries = clz._build_query()["queries"]
|
|
query_context["queries"] = queries
|
|
else:
|
|
query_context = clz._build_query()
|
|
|
|
slc.query_context = json.dumps(query_context)
|
|
backup[QUERIES_BAK_FIELD_NAME] = queries_bak
|
|
slc.params = json.dumps({**clz.data, **backup})
|
|
|
|
except Exception as e:
|
|
logger.warning("Failed to migrate slice %s: %s", slc.id, e)
|
|
|
|
@classmethod
|
|
def downgrade_slice(cls, slc: Slice) -> None:
|
|
try:
|
|
form_data = try_load_json(slc.params)
|
|
if "viz_type" in (
|
|
form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})
|
|
):
|
|
slc.params = json.dumps(form_data_bak)
|
|
slc.viz_type = form_data_bak.get("viz_type")
|
|
query_context = try_load_json(slc.query_context)
|
|
queries_bak = form_data.get(QUERIES_BAK_FIELD_NAME, {})
|
|
if queries_bak:
|
|
query_context["queries"] = queries_bak
|
|
if "form_data" in query_context:
|
|
query_context["form_data"] = form_data_bak
|
|
slc.query_context = json.dumps(query_context)
|
|
else:
|
|
slc.query_context = None
|
|
|
|
except Exception as e:
|
|
logger.warning("Failed to downgrade slice %s: %s", slc.id, e)
|
|
|
|
@classmethod
|
|
def upgrade(cls, session: Session) -> None:
|
|
slices = session.query(Slice).filter(Slice.viz_type == cls.source_viz_type)
|
|
for slc in paginated_update(
|
|
slices,
|
|
lambda current, total: logger.info("Upgraded %s/%s charts", current, total),
|
|
):
|
|
cls.upgrade_slice(slc)
|
|
|
|
@classmethod
|
|
def downgrade(cls, session: Session) -> None:
|
|
slices = session.query(Slice).filter(
|
|
and_(
|
|
Slice.viz_type == cls.target_viz_type,
|
|
Slice.params.like(f"%{FORM_DATA_BAK_FIELD_NAME}%"),
|
|
)
|
|
)
|
|
for slc in paginated_update(
|
|
slices,
|
|
lambda current, total: logger.info(
|
|
"Downgraded %s/%s charts", current, total
|
|
),
|
|
):
|
|
cls.downgrade_slice(slc)
|
|
|
|
def _build_query(self) -> Any | dict[str, Any]:
|
|
"""Builds a query based on the form data."""
|