mirror of
https://github.com/apache/superset.git
synced 2026-04-10 20:06:13 +00:00
645 lines
23 KiB
Python
645 lines
23 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
from typing import Any
|
|
|
|
from superset.migrations.shared.migrate_viz.query_functions import (
|
|
build_query_context,
|
|
contribution_operator,
|
|
ensure_is_array,
|
|
extract_extra_metrics,
|
|
flatten_operator,
|
|
get_column_label,
|
|
get_metric_label,
|
|
get_x_axis_column,
|
|
histogram_operator,
|
|
is_physical_column,
|
|
is_time_comparison,
|
|
is_x_axis_set,
|
|
normalize_order_by,
|
|
pivot_operator,
|
|
prophet_operator,
|
|
rank_operator,
|
|
remove_form_data_suffix,
|
|
rename_operator,
|
|
resample_operator,
|
|
retain_form_data_suffix,
|
|
rolling_window_operator,
|
|
sort_operator,
|
|
time_compare_operator,
|
|
time_compare_pivot_operator,
|
|
)
|
|
from superset.utils.core import as_list
|
|
|
|
from .base import MigrateViz
|
|
|
|
|
|
class MigrateTreeMap(MigrateViz):
|
|
source_viz_type = "treemap"
|
|
target_viz_type = "treemap_v2"
|
|
remove_keys = {"metrics"}
|
|
rename_keys = {"order_desc": "sort_by_metric"}
|
|
|
|
def _pre_action(self) -> None:
|
|
if (
|
|
"metrics" in self.data
|
|
and isinstance(self.data["metrics"], list)
|
|
and len(self.data["metrics"]) > 0
|
|
):
|
|
self.data["metric"] = self.data["metrics"][0]
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
metric = self.data.get("metric")
|
|
sort_by_metric = self.data.get("sort_by_metric")
|
|
|
|
def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
new_query_object = base_query_object.copy()
|
|
|
|
if sort_by_metric:
|
|
new_query_object["orderby"] = [[metric, False]]
|
|
return [new_query_object]
|
|
|
|
return build_query_context(self.data, process)
|
|
|
|
|
|
class MigratePivotTable(MigrateViz):
|
|
source_viz_type = "pivot_table"
|
|
target_viz_type = "pivot_table_v2"
|
|
remove_keys = {"pivot_margins"}
|
|
rename_keys = {
|
|
"columns": "groupbyColumns",
|
|
"combine_metric": "combineMetric",
|
|
"groupby": "groupbyRows",
|
|
"number_format": "valueFormat",
|
|
"pandas_aggfunc": "aggregateFunction",
|
|
"row_limit": "series_limit",
|
|
"timeseries_limit_metric": "series_limit_metric",
|
|
"transpose_pivot": "transposePivot",
|
|
}
|
|
aggregation_mapping = {
|
|
"sum": "Sum",
|
|
"mean": "Average",
|
|
"median": "Median",
|
|
"min": "Minimum",
|
|
"max": "Maximum",
|
|
"std": "Sample Standard Deviation",
|
|
"var": "Sample Variance",
|
|
}
|
|
|
|
def _pre_action(self) -> None:
|
|
if pivot_margins := self.data.get("pivot_margins"):
|
|
self.data["colTotals"] = pivot_margins
|
|
self.data["colSubTotals"] = pivot_margins
|
|
|
|
if pandas_aggfunc := self.data.get("pandas_aggfunc"):
|
|
self.data["pandas_aggfunc"] = self.aggregation_mapping[pandas_aggfunc]
|
|
|
|
self.data["rowOrder"] = "value_z_to_a"
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
groupby_columns = self.data.get("groupbyColumns", [])
|
|
groupby_rows = self.data.get("groupbyRows", [])
|
|
extra_form_data = self.data.get("extra_form_data", {})
|
|
time_grain_sqla = extra_form_data.get("time_grain_sqla") or self.data.get(
|
|
"time_grain_sqla"
|
|
)
|
|
|
|
unique_columns = ensure_is_array(groupby_columns) + ensure_is_array(
|
|
groupby_rows
|
|
)
|
|
|
|
columns = []
|
|
for col in unique_columns:
|
|
if (
|
|
is_physical_column(col)
|
|
and time_grain_sqla
|
|
and (
|
|
self.data.get("temporal_columns_lookup", {}).get(col)
|
|
or self.data.get("granularity_sqla") == col
|
|
)
|
|
):
|
|
col_dict = {
|
|
"timeGrain": time_grain_sqla,
|
|
"columnType": "BASE_AXIS",
|
|
"sqlExpression": col,
|
|
"label": col,
|
|
"expressionType": "SQL",
|
|
}
|
|
if col_dict not in columns:
|
|
columns.append(col_dict)
|
|
else:
|
|
if col not in columns:
|
|
columns.append(col)
|
|
|
|
def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
series_limit_metric = base_query_object.get("series_limit_metric")
|
|
metrics = base_query_object.get("metrics")
|
|
order_desc = base_query_object.get("order_desc")
|
|
orderby = None
|
|
if series_limit_metric:
|
|
orderby = [[series_limit_metric, not order_desc]]
|
|
elif isinstance(metrics, list) and metrics and metrics[0]:
|
|
orderby = [[metrics[0], not order_desc]]
|
|
new_query_object = base_query_object.copy()
|
|
if orderby is not None:
|
|
new_query_object["orderby"] = orderby
|
|
new_query_object["columns"] = columns
|
|
return [new_query_object]
|
|
|
|
return build_query_context(self.data, process)
|
|
|
|
|
|
class MigrateDualLine(MigrateViz):
|
|
has_x_axis_control = True
|
|
source_viz_type = "dual_line"
|
|
target_viz_type = "mixed_timeseries"
|
|
rename_keys = {
|
|
"x_axis_format": "x_axis_time_format",
|
|
"y_axis_2_format": "y_axis_format_secondary",
|
|
"y_axis_2_bounds": "y_axis_bounds_secondary",
|
|
}
|
|
remove_keys = {"metric", "metric_2"}
|
|
|
|
def _pre_action(self) -> None:
|
|
self.data["yAxisIndex"] = 0
|
|
self.data["yAxisIndexB"] = 1
|
|
self.data["adhoc_filters_b"] = self.data.get("adhoc_filters")
|
|
self.data["truncateYAxis"] = True
|
|
self.data["metrics"] = [self.data.get("metric")]
|
|
self.data["metrics_b"] = [self.data.get("metric_2")]
|
|
|
|
def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None:
|
|
super()._migrate_temporal_filter(rv_data)
|
|
rv_data["adhoc_filters_b"] = rv_data.get("adhoc_filters") or []
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
base_form_data = self.data.copy()
|
|
form_data1 = remove_form_data_suffix(base_form_data, "_b")
|
|
form_data2 = retain_form_data_suffix(base_form_data, "_b")
|
|
|
|
def process_fn(fd: dict[str, Any]) -> dict[str, Any]:
|
|
def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
query_object = base_query_object.copy()
|
|
query_object["columns"] = (
|
|
ensure_is_array(get_x_axis_column(self.data))
|
|
if is_x_axis_set(self.data)
|
|
else []
|
|
) + ensure_is_array(fd.get("groupby"))
|
|
query_object["series_columns"] = fd.get("groupby")
|
|
if not is_x_axis_set(self.data):
|
|
query_object["is_timeseries"] = True
|
|
pivot_operator_runtime = (
|
|
time_compare_pivot_operator(fd, query_object)
|
|
if is_time_comparison(fd, query_object)
|
|
else pivot_operator(fd, query_object)
|
|
)
|
|
tmp_query_object = query_object.copy()
|
|
tmp_query_object["time_offsets"] = (
|
|
fd.get("time_compare")
|
|
if is_time_comparison(fd, query_object)
|
|
else []
|
|
)
|
|
tmp_query_object["post_processing"] = [
|
|
pivot_operator_runtime,
|
|
rolling_window_operator(fd, query_object),
|
|
time_compare_operator(fd, query_object),
|
|
resample_operator(fd, query_object),
|
|
rename_operator(fd, query_object),
|
|
flatten_operator(fd, query_object),
|
|
]
|
|
|
|
if tmp_query_object["series_columns"] is None:
|
|
tmp_query_object.pop("series_columns")
|
|
return [normalize_order_by(tmp_query_object)]
|
|
|
|
return build_query_context(fd, process)
|
|
|
|
query_contexts = [process_fn(form_data1), process_fn(form_data2)]
|
|
qc0 = query_contexts[0]
|
|
qc1 = query_contexts[1]
|
|
merged = qc0.copy()
|
|
merged["queries"] = qc0.get("queries", []) + qc1.get("queries", [])
|
|
return merged
|
|
|
|
|
|
class MigrateSunburst(MigrateViz):
|
|
source_viz_type = "sunburst"
|
|
target_viz_type = "sunburst_v2"
|
|
rename_keys = {"groupby": "columns"}
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
metric = self.data.get("metric")
|
|
sort_by_metric = self.data.get("sort_by_metric")
|
|
|
|
def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
result = base_query_object.copy()
|
|
if sort_by_metric:
|
|
result["orderby"] = [[metric, False]]
|
|
return [result]
|
|
|
|
return build_query_context(self.data, process)
|
|
|
|
|
|
class TimeseriesChart(MigrateViz):
|
|
has_x_axis_control = True
|
|
rename_keys = {
|
|
"bottom_margin": "x_axis_title_margin",
|
|
"left_margin": "y_axis_title_margin",
|
|
"show_controls": "show_extra_controls",
|
|
"x_axis_label": "x_axis_title",
|
|
"x_axis_format": "x_axis_time_format",
|
|
"x_axis_showminmax": "truncateXAxis",
|
|
"x_ticks_layout": "xAxisLabelRotation",
|
|
"y_axis_label": "y_axis_title",
|
|
"y_axis_showminmax": "truncateYAxis",
|
|
"y_log_scale": "logAxis",
|
|
}
|
|
remove_keys = {
|
|
"contribution",
|
|
"line_interpolation",
|
|
"reduce_x_ticks",
|
|
"show_brush",
|
|
"show_markers",
|
|
}
|
|
|
|
def _pre_action(self) -> None:
|
|
self.data["contributionMode"] = "row" if self.data.get("contribution") else None
|
|
self.data["zoomable"] = self.data.get("show_brush") == "yes"
|
|
self.data["markerEnabled"] = self.data.get("show_markers") or False
|
|
self.data["y_axis_showminmax"] = True
|
|
|
|
bottom_margin = self.data.get("bottom_margin")
|
|
if self.data.get("x_axis_label") and (
|
|
not bottom_margin or bottom_margin == "auto"
|
|
):
|
|
self.data["bottom_margin"] = 30
|
|
|
|
left_margin = self.data.get("left_margin")
|
|
if self.data.get("y_axis_label") and (not left_margin or left_margin == "auto"):
|
|
self.data["left_margin"] = 30
|
|
|
|
if (rolling_type := self.data.get("rolling_type")) and rolling_type != "None":
|
|
self.data["rolling_type"] = rolling_type
|
|
|
|
if (time_compare := self.data.get("time_compare")) is not None:
|
|
self.data["time_compare"] = [
|
|
v if v.endswith(" ago") else v + " ago"
|
|
for value in as_list(time_compare)
|
|
if (v := value.strip())
|
|
]
|
|
|
|
comparison_type = self.data.get("comparison_type") or "values"
|
|
self.data["comparison_type"] = (
|
|
"difference" if comparison_type == "absolute" else comparison_type
|
|
)
|
|
|
|
if x_ticks_layout := self.data.get("x_ticks_layout"):
|
|
self.data["x_ticks_layout"] = 45 if x_ticks_layout == "45°" else 0
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
groupby = self.data.get("groupby")
|
|
|
|
def query_builder(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
"""
|
|
The `pivot_operator_in_runtime` determines how to pivot the dataframe
|
|
returned from the raw query.
|
|
1. If it's a time compared query, there will return a pivoted
|
|
dataframe that append time compared metrics.
|
|
"""
|
|
extra_metrics = extract_extra_metrics(self.data)
|
|
|
|
pivot_operator_in_runtime = (
|
|
time_compare_pivot_operator(self.data, base_query_object)
|
|
if is_time_comparison(self.data, base_query_object)
|
|
else pivot_operator(self.data, base_query_object)
|
|
)
|
|
|
|
columns = (
|
|
ensure_is_array(get_x_axis_column(self.data))
|
|
if is_x_axis_set(self.data)
|
|
else []
|
|
) + ensure_is_array(groupby)
|
|
|
|
time_offsets = (
|
|
self.data.get("time_compare")
|
|
if is_time_comparison(self.data, base_query_object)
|
|
else []
|
|
)
|
|
|
|
result = {
|
|
**base_query_object,
|
|
"metrics": (base_query_object.get("metrics") or []) + extra_metrics,
|
|
"columns": columns,
|
|
"series_columns": groupby,
|
|
**({"is_timeseries": True} if not is_x_axis_set(self.data) else {}),
|
|
# todo: move `normalize_order_by to extract_query_fields`
|
|
"orderby": normalize_order_by(base_query_object).get("orderby"),
|
|
"time_offsets": time_offsets,
|
|
"post_processing": [
|
|
pivot_operator_in_runtime,
|
|
rolling_window_operator(self.data, base_query_object),
|
|
time_compare_operator(self.data, base_query_object),
|
|
resample_operator(self.data, base_query_object),
|
|
rename_operator(self.data, base_query_object),
|
|
contribution_operator(self.data, base_query_object, time_offsets),
|
|
sort_operator(self.data, base_query_object),
|
|
flatten_operator(self.data, base_query_object),
|
|
# todo: move prophet before flatten
|
|
prophet_operator(self.data, base_query_object),
|
|
],
|
|
}
|
|
|
|
return [result]
|
|
|
|
return build_query_context(self.data, query_builder)
|
|
|
|
|
|
class MigrateLineChart(TimeseriesChart):
|
|
source_viz_type = "line"
|
|
target_viz_type = "echarts_timeseries_line"
|
|
|
|
def _pre_action(self) -> None:
|
|
super()._pre_action()
|
|
|
|
line_interpolation = self.data.get("line_interpolation")
|
|
if line_interpolation == "cardinal":
|
|
self.target_viz_type = "echarts_timeseries_smooth"
|
|
elif line_interpolation == "step-before":
|
|
self.target_viz_type = "echarts_timeseries_step"
|
|
self.data["seriesType"] = "start"
|
|
elif line_interpolation == "step-after":
|
|
self.target_viz_type = "echarts_timeseries_step"
|
|
self.data["seriesType"] = "end"
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
return super()._build_query()
|
|
|
|
|
|
class MigrateAreaChart(TimeseriesChart):
|
|
source_viz_type = "area"
|
|
target_viz_type = "echarts_area"
|
|
stacked_map = {
|
|
"expand": "Expand",
|
|
"stack": "Stack",
|
|
"stream": "Stream",
|
|
}
|
|
|
|
def _pre_action(self) -> None:
|
|
super()._pre_action()
|
|
|
|
self.remove_keys.add("stacked_style")
|
|
|
|
self.data["stack"] = self.stacked_map.get(
|
|
self.data.get("stacked_style") or "stack"
|
|
)
|
|
|
|
self.data["opacity"] = 0.7
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
return super()._build_query()
|
|
|
|
|
|
class MigrateBarChart(TimeseriesChart):
|
|
source_viz_type = "bar"
|
|
target_viz_type = "echarts_timeseries_bar"
|
|
|
|
def _pre_action(self) -> None:
|
|
super()._pre_action()
|
|
|
|
self.rename_keys["show_bar_value"] = "show_value"
|
|
|
|
self.remove_keys.add("bar_stacked")
|
|
|
|
self.data["stack"] = "Stack" if self.data.get("bar_stacked") else None
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
return super()._build_query()
|
|
|
|
|
|
class MigrateDistBarChart(TimeseriesChart):
|
|
source_viz_type = "dist_bar"
|
|
target_viz_type = "echarts_timeseries_bar"
|
|
has_x_axis_control = False
|
|
|
|
def _pre_action(self) -> None:
|
|
super()._pre_action()
|
|
|
|
groupby = self.data.get("groupby") or []
|
|
columns = self.data.get("columns") or []
|
|
if len(groupby) > 0:
|
|
# x-axis supports only one value
|
|
self.data["x_axis"] = groupby[0]
|
|
|
|
self.data["groupby"] = []
|
|
if len(groupby) > 1:
|
|
# rest of groupby will go into dimensions
|
|
self.data["groupby"] += groupby[1:]
|
|
if len(columns) > 0:
|
|
self.data["groupby"] += columns
|
|
|
|
self.rename_keys["show_bar_value"] = "show_value"
|
|
|
|
self.remove_keys.add("columns")
|
|
self.remove_keys.add("bar_stacked")
|
|
|
|
self.data["stack"] = "Stack" if self.data.get("bar_stacked") else None
|
|
self.data["x_ticks_layout"] = 45
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
return super()._build_query()
|
|
|
|
|
|
class MigrateBubbleChart(MigrateViz):
|
|
source_viz_type = "bubble"
|
|
target_viz_type = "bubble_v2"
|
|
rename_keys = {
|
|
"bottom_margin": "x_axis_title_margin",
|
|
"left_margin": "y_axis_title_margin",
|
|
"limit": "row_limit",
|
|
"x_axis_format": "xAxisFormat",
|
|
"x_log_scale": "logXAxis",
|
|
"x_ticks_layout": "xAxisLabelRotation",
|
|
"y_axis_showminmax": "truncateYAxis",
|
|
"y_log_scale": "logYAxis",
|
|
}
|
|
remove_keys = {"x_axis_showminmax"}
|
|
|
|
def _pre_action(self) -> None:
|
|
bottom_margin = self.data.get("bottom_margin")
|
|
if self.data.get("x_axis_label") and (
|
|
not bottom_margin or bottom_margin == "auto"
|
|
):
|
|
self.data["bottom_margin"] = 30
|
|
|
|
if x_ticks_layout := self.data.get("x_ticks_layout"):
|
|
self.data["x_ticks_layout"] = 45 if x_ticks_layout == "45°" else 0
|
|
|
|
# Truncate y-axis by default to preserve layout
|
|
self.data["y_axis_showminmax"] = True
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
columns = ensure_is_array(self.data.get("entity")) + ensure_is_array(
|
|
self.data.get("series")
|
|
)
|
|
|
|
def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
if base_query_object.get("orderby"):
|
|
orderby = [
|
|
[
|
|
base_query_object["orderby"][0],
|
|
not base_query_object.get("order_desc", False),
|
|
]
|
|
]
|
|
else:
|
|
orderby = None
|
|
|
|
new_query_object = {**base_query_object, "columns": columns}
|
|
if orderby is not None:
|
|
new_query_object["orderby"] = orderby
|
|
|
|
return [new_query_object]
|
|
|
|
return build_query_context(self.data, process)
|
|
|
|
|
|
class MigrateHeatmapChart(MigrateViz):
|
|
source_viz_type = "heatmap"
|
|
target_viz_type = "heatmap_v2"
|
|
rename_keys = {
|
|
"all_columns_x": "x_axis",
|
|
"all_columns_y": "groupby",
|
|
"y_axis_bounds": "value_bounds",
|
|
"show_perc": "show_percentage",
|
|
}
|
|
remove_keys = {"sort_by_metric", "canvas_image_rendering"}
|
|
|
|
def _pre_action(self) -> None:
|
|
self.data["legend_type"] = "continuous"
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
groupby = self.data.get("groupby")
|
|
normalize_across = self.data.get("normalize_across")
|
|
sort_x_axis = self.data.get("sort_x_axis")
|
|
sort_y_axis = self.data.get("sort_y_axis")
|
|
x_axis = self.data.get("x_axis")
|
|
|
|
metric = get_metric_label(self.data.get("metric"))
|
|
|
|
columns = ensure_is_array(get_x_axis_column(self.data)) + ensure_is_array(
|
|
groupby
|
|
)
|
|
|
|
orderby = []
|
|
if sort_x_axis:
|
|
chosen = metric if "value" in sort_x_axis else columns[0]
|
|
ascending = "asc" in sort_x_axis
|
|
orderby.append([chosen, ascending])
|
|
if sort_y_axis:
|
|
chosen = metric if "value" in sort_y_axis else columns[1]
|
|
ascending = "asc" in sort_y_axis
|
|
orderby.append([chosen, ascending])
|
|
|
|
if normalize_across == "x":
|
|
group_by = get_column_label(x_axis)
|
|
elif normalize_across == "y":
|
|
group_by = get_column_label(groupby)
|
|
else:
|
|
group_by = None
|
|
|
|
def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
new_query_object = base_query_object.copy()
|
|
new_query_object["columns"] = columns
|
|
if orderby:
|
|
new_query_object["orderby"] = orderby
|
|
new_query_object["post_processing"] = [
|
|
rank_operator(
|
|
self.data,
|
|
base_query_object,
|
|
{"metric": metric, "group_by": group_by},
|
|
)
|
|
]
|
|
|
|
return [new_query_object]
|
|
|
|
return build_query_context(self.data, process)
|
|
|
|
|
|
class MigrateHistogramChart(MigrateViz):
|
|
source_viz_type = "histogram"
|
|
target_viz_type = "histogram_v2"
|
|
rename_keys = {
|
|
"x_axis_label": "x_axis_title",
|
|
"y_axis_label": "y_axis_title",
|
|
"normalized": "normalize",
|
|
}
|
|
remove_keys = {"all_columns_x", "link_length", "queryFields"}
|
|
|
|
def _pre_action(self) -> None:
|
|
all_columns_x = self.data.get("all_columns_x")
|
|
if all_columns_x and len(all_columns_x) > 0:
|
|
self.data["column"] = all_columns_x[0]
|
|
|
|
link_length = self.data.get("link_length")
|
|
self.data["bins"] = int(link_length) if link_length else 5
|
|
|
|
groupby = self.data.get("groupby")
|
|
if not groupby:
|
|
self.data["groupby"] = []
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
column = self.data.get("column")
|
|
groupby = self.data.get("groupby", [])
|
|
|
|
def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
result = base_query_object.copy()
|
|
result["columns"] = groupby + [column]
|
|
result["post_processing"] = [
|
|
histogram_operator(self.data, base_query_object)
|
|
]
|
|
if "metrics" in result.keys():
|
|
result.pop("metrics", None)
|
|
return [result]
|
|
|
|
return build_query_context(self.data, process)
|
|
|
|
|
|
class MigrateSankey(MigrateViz):
|
|
source_viz_type = "sankey"
|
|
target_viz_type = "sankey_v2"
|
|
remove_keys = {"groupby"}
|
|
|
|
def _pre_action(self) -> None:
|
|
groupby = self.data.get("groupby")
|
|
if groupby and len(groupby) > 1:
|
|
self.data["source"] = groupby[0]
|
|
self.data["target"] = groupby[1]
|
|
|
|
def _build_query(self) -> dict[str, Any]:
|
|
metric = self.data.get("metric")
|
|
sort_by_metric = self.data.get("sort_by_metric")
|
|
source = self.data.get("source")
|
|
target = self.data.get("target")
|
|
groupby = [source, target]
|
|
|
|
def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]:
|
|
result = base_query_object.copy()
|
|
result["groupby"] = groupby
|
|
if sort_by_metric:
|
|
result["orderby"] = [[metric, False]]
|
|
return [result]
|
|
|
|
return build_query_context(self.data, process)
|