mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
fix: Adjust viz migrations to also migrate the queries object (#33285)
Co-authored-by: Michael S. Molina <michael.s.molina@gmail.com>
Co-authored-by: Michael S. Molina <70410625+michael-s-molina@users.noreply.github.com>
(cherry picked from commit 57183da315)
This commit is contained in:
committed by
Michael Molina
parent
685b259f6f
commit
cfba5cdc57
@@ -44,6 +44,7 @@ class Slice(Base): # type: ignore
|
||||
|
||||
|
||||
FORM_DATA_BAK_FIELD_NAME = "form_data_bak"
|
||||
QUERIES_BAK_FIELD_NAME = "queries_bak"
|
||||
|
||||
|
||||
class MigrateViz:
|
||||
@@ -156,14 +157,24 @@ class MigrateViz:
|
||||
# because a source viz can be mapped to different target viz types
|
||||
slc.viz_type = clz.target_viz_type
|
||||
|
||||
# only backup params
|
||||
slc.params = json.dumps(
|
||||
{**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak}
|
||||
)
|
||||
backup = {FORM_DATA_BAK_FIELD_NAME: form_data_bak}
|
||||
|
||||
query_context = try_load_json(slc.query_context)
|
||||
|
||||
if query_context:
|
||||
if "form_data" in query_context:
|
||||
query_context["form_data"] = clz.data
|
||||
|
||||
queries_bak = copy.deepcopy(query_context["queries"])
|
||||
|
||||
queries = clz._build_query()["queries"]
|
||||
query_context["queries"] = queries
|
||||
|
||||
if "form_data" in (query_context := try_load_json(slc.query_context)):
|
||||
query_context["form_data"] = clz.data
|
||||
slc.query_context = json.dumps(query_context)
|
||||
backup[QUERIES_BAK_FIELD_NAME] = queries_bak
|
||||
|
||||
slc.params = json.dumps({**clz.data, **backup})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to migrate slice {slc.id}: {e}")
|
||||
|
||||
@@ -177,9 +188,12 @@ class MigrateViz:
|
||||
slc.params = json.dumps(form_data_bak)
|
||||
slc.viz_type = form_data_bak.get("viz_type")
|
||||
query_context = try_load_json(slc.query_context)
|
||||
queries_bak = form_data.get(QUERIES_BAK_FIELD_NAME, {})
|
||||
query_context["queries"] = queries_bak
|
||||
if "form_data" in query_context:
|
||||
query_context["form_data"] = form_data_bak
|
||||
slc.query_context = json.dumps(query_context)
|
||||
|
||||
slc.query_context = json.dumps(query_context)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to downgrade slice {slc.id}: {e}")
|
||||
|
||||
@@ -205,3 +219,6 @@ class MigrateViz:
|
||||
lambda current, total: logger.info(f"Downgraded {current}/{total} charts"),
|
||||
):
|
||||
cls.downgrade_slice(slc)
|
||||
|
||||
def _build_query(self) -> Any | dict[str, Any]:
|
||||
"""Builds a query based on the form data."""
|
||||
|
||||
@@ -16,6 +16,32 @@
|
||||
# under the License.
|
||||
from typing import Any
|
||||
|
||||
from superset.migrations.shared.migrate_viz.query_functions import (
|
||||
build_query_context,
|
||||
contribution_operator,
|
||||
ensure_is_array,
|
||||
extract_extra_metrics,
|
||||
flatten_operator,
|
||||
get_column_label,
|
||||
get_metric_label,
|
||||
get_x_axis_column,
|
||||
histogram_operator,
|
||||
is_physical_column,
|
||||
is_time_comparison,
|
||||
is_x_axis_set,
|
||||
normalize_order_by,
|
||||
pivot_operator,
|
||||
prophet_operator,
|
||||
rank_operator,
|
||||
remove_form_data_suffix,
|
||||
rename_operator,
|
||||
resample_operator,
|
||||
retain_form_data_suffix,
|
||||
rolling_window_operator,
|
||||
sort_operator,
|
||||
time_compare_operator,
|
||||
time_compare_pivot_operator,
|
||||
)
|
||||
from superset.utils.core import as_list
|
||||
|
||||
from .base import MigrateViz
|
||||
@@ -35,6 +61,19 @@ class MigrateTreeMap(MigrateViz):
|
||||
):
|
||||
self.data["metric"] = self.data["metrics"][0]
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Build the query context for the migrated chart.

    ``metric`` and ``sort_by_metric`` are read from ``self.data`` up front.
    When ``sort_by_metric`` is truthy, the single query object gets
    ``orderby`` set to ``[[metric, False]]``.

    Returns:
        The value returned by ``build_query_context`` for ``self.data``.
    """
    order_metric = self.data.get("metric")
    order_by_metric = self.data.get("sort_by_metric")

    def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
        # Shallow copy: the object handed in by the caller is left untouched.
        query_object = base_query_object.copy()
        if not order_by_metric:
            return [query_object]
        query_object["orderby"] = [[order_metric, False]]
        return [query_object]

    return build_query_context(self.data, process)
|
||||
|
||||
|
||||
class MigratePivotTable(MigrateViz):
|
||||
source_viz_type = "pivot_table"
|
||||
@@ -70,6 +109,58 @@ class MigratePivotTable(MigrateViz):
|
||||
|
||||
self.data["rowOrder"] = "value_z_to_a"
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Build the query context for the migrated pivot table.

    The query columns are the de-duplicated concatenation of
    ``groupbyColumns`` and ``groupbyRows``. A physical column is replaced by
    a ``BASE_AXIS`` column dict carrying the time grain when a time grain is
    set and the column is either truthy in ``temporal_columns_lookup`` or
    equal to ``granularity_sqla``.

    If ``extra_form_data`` or ``temporal_columns_lookup`` is present in the
    form data but holds ``None``, it is treated as an empty mapping instead
    of raising ``AttributeError``.

    Returns:
        The value returned by ``build_query_context`` for ``self.data``.
    """
    groupby_columns = self.data.get("groupbyColumns", [])
    groupby_rows = self.data.get("groupbyRows", [])
    # ``or {}`` (rather than a ``get`` default) also covers keys that exist
    # with a ``None`` value, for which the default would not be used.
    extra_form_data = self.data.get("extra_form_data") or {}
    time_grain_sqla = extra_form_data.get("time_grain_sqla") or self.data.get(
        "time_grain_sqla"
    )
    # Loop-invariant lookups, resolved once instead of once per column.
    temporal_columns_lookup = self.data.get("temporal_columns_lookup") or {}
    granularity_sqla = self.data.get("granularity_sqla")

    # May still contain duplicates; they are dropped while building
    # ``columns`` below, keeping first-seen order.
    candidate_columns = ensure_is_array(groupby_columns) + ensure_is_array(
        groupby_rows
    )

    columns: list[Any] = []
    for col in candidate_columns:
        if (
            is_physical_column(col)
            and time_grain_sqla
            and (temporal_columns_lookup.get(col) or granularity_sqla == col)
        ):
            entry: Any = {
                "timeGrain": time_grain_sqla,
                "columnType": "BASE_AXIS",
                "sqlExpression": col,
                "label": col,
                "expressionType": "SQL",
            }
        else:
            entry = col
        if entry not in columns:
            columns.append(entry)

    def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
        series_limit_metric = base_query_object.get("series_limit_metric")
        metrics = base_query_object.get("metrics")
        # The flag stored in the order-by pair is the negation of
        # ``order_desc``.
        order_flag = not base_query_object.get("order_desc")

        new_query_object = base_query_object.copy()
        # Prefer the series limit metric; otherwise fall back to the first
        # metric. With neither, ``orderby`` is left as it was.
        if series_limit_metric:
            new_query_object["orderby"] = [[series_limit_metric, order_flag]]
        elif isinstance(metrics, list) and metrics and metrics[0]:
            new_query_object["orderby"] = [[metrics[0], order_flag]]
        new_query_object["columns"] = columns
        return [new_query_object]

    return build_query_context(self.data, process)
|
||||
|
||||
|
||||
class MigrateDualLine(MigrateViz):
|
||||
has_x_axis_control = True
|
||||
@@ -94,12 +185,73 @@ class MigrateDualLine(MigrateViz):
|
||||
super()._migrate_temporal_filter(rv_data)
|
||||
rv_data["adhoc_filters_b"] = rv_data.get("adhoc_filters") or []
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Build one query context holding the queries of both series.

    A copy of ``self.data`` is passed through ``remove_form_data_suffix``
    and ``retain_form_data_suffix`` (both with the ``"_b"`` suffix) to get
    two form data variants. A query context is built for each variant, and
    the result is the first context with the second context's ``queries``
    appended to its own.
    """
    form_data = self.data.copy()
    first_form_data = remove_form_data_suffix(form_data, "_b")
    second_form_data = retain_form_data_suffix(form_data, "_b")

    def process_fn(fd: dict[str, Any]) -> dict[str, Any]:
        """Build the query context for one form data variant."""

        def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
            query_object = base_query_object.copy()

            # The x-axis always comes from the full form data; only the
            # group-by columns are specific to the variant.
            if is_x_axis_set(self.data):
                x_axis_columns = ensure_is_array(get_x_axis_column(self.data))
            else:
                x_axis_columns = []
            query_object["columns"] = x_axis_columns + ensure_is_array(
                fd.get("groupby")
            )
            query_object["series_columns"] = fd.get("groupby")
            if not is_x_axis_set(self.data):
                query_object["is_timeseries"] = True

            if is_time_comparison(fd, query_object):
                pivot = time_compare_pivot_operator(fd, query_object)
            else:
                pivot = pivot_operator(fd, query_object)

            # The operators below receive ``query_object``; their output and
            # the time offsets are stored on a separate copy.
            final_query_object = query_object.copy()
            if is_time_comparison(fd, query_object):
                final_query_object["time_offsets"] = fd.get("time_compare")
            else:
                final_query_object["time_offsets"] = []
            final_query_object["post_processing"] = [
                pivot,
                rolling_window_operator(fd, query_object),
                time_compare_operator(fd, query_object),
                resample_operator(fd, query_object),
                rename_operator(fd, query_object),
                flatten_operator(fd, query_object),
            ]

            # A missing group-by is dropped rather than emitted as ``None``.
            if final_query_object["series_columns"] is None:
                del final_query_object["series_columns"]
            return [normalize_order_by(final_query_object)]

        return build_query_context(fd, process)

    first_context = process_fn(first_form_data)
    second_context = process_fn(second_form_data)

    merged = first_context.copy()
    merged["queries"] = first_context.get("queries", []) + second_context.get(
        "queries", []
    )
    return merged
|
||||
|
||||
|
||||
class MigrateSunburst(MigrateViz):
    """Migration from the ``sunburst`` viz type to ``sunburst_v2``."""

    source_viz_type = "sunburst"
    target_viz_type = "sunburst_v2"
    rename_keys = {"groupby": "columns"}

    def _build_query(self) -> dict[str, Any]:
        """Build the query context for the migrated chart.

        When ``sort_by_metric`` in ``self.data`` is truthy, the single query
        object gets ``orderby`` set to ``[[metric, False]]``.
        """
        order_metric = self.data.get("metric")
        order_by_metric = self.data.get("sort_by_metric")

        def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
            # Shallow copy so the caller's query object is not modified.
            query_object = base_query_object.copy()
            if not order_by_metric:
                return [query_object]
            query_object["orderby"] = [[order_metric, False]]
            return [query_object]

        return build_query_context(self.data, process)
|
||||
|
||||
|
||||
class TimeseriesChart(MigrateViz):
|
||||
has_x_axis_control = True
|
||||
@@ -155,6 +307,63 @@ class TimeseriesChart(MigrateViz):
|
||||
if x_ticks_layout := self.data.get("x_ticks_layout"):
|
||||
self.data["x_ticks_layout"] = 45 if x_ticks_layout == "45°" else 0
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Build the query context for a migrated timeseries chart.

    Returns:
        The value returned by ``build_query_context`` for ``self.data``.
    """
    series_columns = self.data.get("groupby")

    def query_builder(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
        """
        Derive the single query object from ``base_query_object``.

        The pivot step is chosen at runtime: ``time_compare_pivot_operator``
        when ``is_time_comparison`` holds for the query, ``pivot_operator``
        otherwise.
        """
        extra_metrics = extract_extra_metrics(self.data)

        if is_time_comparison(self.data, base_query_object):
            pivot = time_compare_pivot_operator(self.data, base_query_object)
        else:
            pivot = pivot_operator(self.data, base_query_object)

        if is_x_axis_set(self.data):
            x_axis_columns = ensure_is_array(get_x_axis_column(self.data))
        else:
            x_axis_columns = []
        columns = x_axis_columns + ensure_is_array(series_columns)

        if is_time_comparison(self.data, base_query_object):
            time_offsets = self.data.get("time_compare")
        else:
            time_offsets = []

        # Keys are assigned in a fixed order so that keys not already in the
        # base query object are appended in that order.
        query_object = {**base_query_object}
        query_object["metrics"] = (
            base_query_object.get("metrics") or []
        ) + extra_metrics
        query_object["columns"] = columns
        query_object["series_columns"] = series_columns
        if not is_x_axis_set(self.data):
            query_object["is_timeseries"] = True
        # TODO: move `normalize_order_by` to `extract_query_fields`
        query_object["orderby"] = normalize_order_by(base_query_object).get(
            "orderby"
        )
        query_object["time_offsets"] = time_offsets
        query_object["post_processing"] = [
            pivot,
            rolling_window_operator(self.data, base_query_object),
            time_compare_operator(self.data, base_query_object),
            resample_operator(self.data, base_query_object),
            rename_operator(self.data, base_query_object),
            contribution_operator(self.data, base_query_object, time_offsets),
            sort_operator(self.data, base_query_object),
            flatten_operator(self.data, base_query_object),
            # TODO: move prophet before flatten
            prophet_operator(self.data, base_query_object),
        ]
        return [query_object]

    return build_query_context(self.data, query_builder)
|
||||
|
||||
|
||||
class MigrateLineChart(TimeseriesChart):
|
||||
source_viz_type = "line"
|
||||
@@ -173,6 +382,9 @@ class MigrateLineChart(TimeseriesChart):
|
||||
self.target_viz_type = "echarts_timeseries_step"
|
||||
self.data["seriesType"] = "end"
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Return the query context built by the parent class, unchanged."""
    query_context = super()._build_query()
    return query_context
|
||||
|
||||
|
||||
class MigrateAreaChart(TimeseriesChart):
|
||||
source_viz_type = "area"
|
||||
@@ -194,6 +406,9 @@ class MigrateAreaChart(TimeseriesChart):
|
||||
|
||||
self.data["opacity"] = 0.7
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Return the query context built by the parent class, unchanged."""
    query_context = super()._build_query()
    return query_context
|
||||
|
||||
|
||||
class MigrateBarChart(TimeseriesChart):
|
||||
source_viz_type = "bar"
|
||||
@@ -208,6 +423,9 @@ class MigrateBarChart(TimeseriesChart):
|
||||
|
||||
self.data["stack"] = "Stack" if self.data.get("bar_stacked") else None
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Return the query context built by the parent class, unchanged."""
    query_context = super()._build_query()
    return query_context
|
||||
|
||||
|
||||
class MigrateDistBarChart(TimeseriesChart):
|
||||
source_viz_type = "dist_bar"
|
||||
@@ -238,6 +456,9 @@ class MigrateDistBarChart(TimeseriesChart):
|
||||
self.data["stack"] = "Stack" if self.data.get("bar_stacked") else None
|
||||
self.data["x_ticks_layout"] = 45
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Return the query context built by the parent class, unchanged."""
    query_context = super()._build_query()
    return query_context
|
||||
|
||||
|
||||
class MigrateBubbleChart(MigrateViz):
|
||||
source_viz_type = "bubble"
|
||||
@@ -267,6 +488,30 @@ class MigrateBubbleChart(MigrateViz):
|
||||
# Truncate y-axis by default to preserve layout
|
||||
self.data["y_axis_showminmax"] = True
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Build the query context for the migrated bubble chart.

    The query columns are ``entity`` followed by ``series``. If the base
    query object has a truthy ``orderby``, it is replaced by a single pair
    made of its first element and the negation of ``order_desc``.
    """
    entity_columns = ensure_is_array(self.data.get("entity"))
    series_columns = ensure_is_array(self.data.get("series"))
    columns = entity_columns + series_columns

    def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
        existing_orderby = base_query_object.get("orderby")
        if existing_orderby:
            # Only the first order-by element of the base query is kept.
            replacement = [
                [
                    base_query_object["orderby"][0],
                    not base_query_object.get("order_desc", False),
                ]
            ]
            return [
                {**base_query_object, "columns": columns, "orderby": replacement}
            ]
        return [{**base_query_object, "columns": columns}]

    return build_query_context(self.data, process)
|
||||
|
||||
|
||||
class MigrateHeatmapChart(MigrateViz):
|
||||
source_viz_type = "heatmap"
|
||||
@@ -282,6 +527,53 @@ class MigrateHeatmapChart(MigrateViz):
|
||||
def _pre_action(self) -> None:
|
||||
self.data["legend_type"] = "continuous"
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Build the query context for the migrated heatmap.

    Columns are the x-axis column followed by ``groupby``. ``sort_x_axis``
    and ``sort_y_axis`` each contribute one order-by pair when truthy: the
    target is the metric label if the option contains ``"value"``, otherwise
    the first (x) or second (y) column; the flag is whether the option
    contains ``"asc"``. One ``rank_operator`` post-processing step is
    attached; its ``group_by`` is the x-axis or ``groupby`` column label,
    depending on ``normalize_across``, or ``None``.
    """
    groupby = self.data.get("groupby")
    normalize_across = self.data.get("normalize_across")
    x_axis = self.data.get("x_axis")
    sort_options = (self.data.get("sort_x_axis"), self.data.get("sort_y_axis"))

    metric = get_metric_label(self.data.get("metric"))
    columns = ensure_is_array(get_x_axis_column(self.data)) + ensure_is_array(
        groupby
    )

    def order_entry(sort_option: Any, column_index: int) -> list[Any]:
        # ``columns`` is only indexed when the option does not sort by value.
        target = metric if "value" in sort_option else columns[column_index]
        return [target, "asc" in sort_option]

    # Index 0 pairs with the x option, index 1 with the y option.
    orderby = [
        order_entry(option, index)
        for index, option in enumerate(sort_options)
        if option
    ]

    if normalize_across == "x":
        rank_group_by = get_column_label(x_axis)
    elif normalize_across == "y":
        rank_group_by = get_column_label(groupby)
    else:
        rank_group_by = None

    def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
        new_query_object = base_query_object.copy()
        new_query_object["columns"] = columns
        if orderby:
            new_query_object["orderby"] = orderby
        rank_options = {"metric": metric, "group_by": rank_group_by}
        new_query_object["post_processing"] = [
            rank_operator(self.data, base_query_object, rank_options)
        ]
        return [new_query_object]

    return build_query_context(self.data, process)
|
||||
|
||||
|
||||
class MigrateHistogramChart(MigrateViz):
|
||||
source_viz_type = "histogram"
|
||||
@@ -305,6 +597,22 @@ class MigrateHistogramChart(MigrateViz):
|
||||
if not groupby:
|
||||
self.data["groupby"] = []
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Build the query context for the migrated histogram.

    The query columns are ``groupby`` followed by ``column``. A single
    ``histogram_operator`` post-processing step is attached and any
    ``metrics`` entry is removed from the query object.

    A ``groupby`` that is missing or ``None`` is treated as an empty list.

    Returns:
        The value returned by ``build_query_context`` for ``self.data``.
    """
    column = self.data.get("column")
    # ``or []`` also covers a ``groupby`` key that exists with a ``None``
    # value, which would make the list concatenation below raise TypeError.
    groupby = self.data.get("groupby") or []

    def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
        result = base_query_object.copy()
        result["columns"] = groupby + [column]
        result["post_processing"] = [
            histogram_operator(self.data, base_query_object)
        ]
        # ``pop`` with a default is already a no-op when the key is absent,
        # so no membership check is needed.
        result.pop("metrics", None)
        return [result]

    return build_query_context(self.data, process)
|
||||
|
||||
|
||||
class MigrateSankey(MigrateViz):
|
||||
source_viz_type = "sankey"
|
||||
@@ -316,3 +624,19 @@ class MigrateSankey(MigrateViz):
|
||||
if groupby and len(groupby) > 1:
|
||||
self.data["source"] = groupby[0]
|
||||
self.data["target"] = groupby[1]
|
||||
|
||||
def _build_query(self) -> dict[str, Any]:
    """Build the query context for the migrated sankey chart.

    The query object's ``groupby`` is set to ``[source, target]`` as read
    from ``self.data``. When ``sort_by_metric`` is truthy, ``orderby`` is
    set to ``[[metric, False]]``.
    """
    order_metric = self.data.get("metric")
    order_by_metric = self.data.get("sort_by_metric")
    endpoints = [self.data.get("source"), self.data.get("target")]

    def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
        query_object = {**base_query_object, "groupby": endpoints}
        if order_by_metric:
            query_object["orderby"] = [[order_metric, False]]
        return [query_object]

    return build_query_context(self.data, process)
|
||||
|
||||
1507
superset/migrations/shared/migrate_viz/query_functions.py
Normal file
1507
superset/migrations/shared/migrate_viz/query_functions.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user