diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index 024a58463e2..e73cddd82d3 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -16,14 +16,15 @@ # under the License. from __future__ import annotations +import copy import json -from typing import Dict, Set +from typing import Any, Dict, Set from alembic import op from sqlalchemy import and_, Column, Integer, String, Text from sqlalchemy.ext.declarative import declarative_base -from superset import db +from superset import conf, db, is_feature_enabled from superset.migrations.shared.utils import paginated_update, try_load_json Base = declarative_base() @@ -52,7 +53,7 @@ class MigrateViz: self.data = try_load_json(form_data) def _pre_action(self) -> None: - """some actions before migrate""" + """Some actions before migrate""" def _migrate(self) -> None: if self.data.get("viz_type") != self.source_viz_type: @@ -68,22 +69,50 @@ class MigrateViz: if key in self.rename_keys: rv_data[self.rename_keys[key]] = value + continue if key in self.remove_keys: continue rv_data[key] = value + if is_feature_enabled("GENERIC_CHART_AXES"): + self._migrate_temporal_filter(rv_data) + self.data = rv_data def _post_action(self) -> None: - """some actions after migrate""" + """Some actions after migrate""" + + def _migrate_temporal_filter(self, rv_data: Dict[str, Any]) -> None: + """Adds a temporal filter.""" + granularity_sqla = rv_data.pop("granularity_sqla", None) + time_range = rv_data.pop("time_range", None) or conf.get("DEFAULT_TIME_FILTER") + + if not granularity_sqla: + return + + temporal_filter = { + "clause": "WHERE", + "subject": granularity_sqla, + "operator": "TEMPORAL_RANGE", + "comparator": time_range, + "expressionType": "SIMPLE", + } + + if isinstance(granularity_sqla, dict): + temporal_filter["comparator"] = None + temporal_filter["expressionType"] = "SQL" + temporal_filter["subject"] = granularity_sqla["label"] + temporal_filter["sqlExpression"] = granularity_sqla["sqlExpression"] + + rv_data["adhoc_filters"] = rv_data.get("adhoc_filters", []) + [temporal_filter] @classmethod def upgrade_slice(cls, slc: Slice) -> Slice: clz = cls(slc.params) slc.viz_type = cls.target_viz_type - form_data_bak = clz.data.copy() + form_data_bak = copy.deepcopy(clz.data) clz._pre_action() clz._migrate()