fix(mcp): get_chart_sql drops x_axis on echarts_timeseries_* and only renders one query for mixed_timeseries (#39865)

This commit is contained in:
Amin Ghadersohi
2026-05-08 15:29:28 -04:00
committed by GitHub
parent dce3317bc9
commit ff7dc53853
2 changed files with 494 additions and 14 deletions

View File

@@ -141,6 +141,21 @@ def _resolve_metrics_and_groupby(
return _resolve_metrics(form_data, viz_type), _resolve_groupby(form_data)
def _extract_x_axis_col(form_data: dict[str, Any]) -> str | None:
"""Return the x_axis column name from form_data, or None if not set.
``x_axis`` may be stored as a plain column-name string or as an adhoc
column dict (``{"column_name": "...", ...}``).
"""
x_axis = form_data.get("x_axis")
if isinstance(x_axis, str) and x_axis:
return x_axis
if isinstance(x_axis, dict):
col_name = x_axis.get("column_name")
return col_name if isinstance(col_name, str) and col_name else None
return None
def _resolve_engine(
datasource_id: Any,
datasource_type: str,
@@ -162,6 +177,58 @@ def _resolve_engine(
return "base"
def _build_single_query_dict(
form_data: dict[str, Any],
columns: list[Any],
metrics: list[Any],
) -> dict[str, Any]:
"""Build one query entry for QueryContextFactory from form_data fields."""
qd: dict[str, Any] = {"columns": columns, "metrics": metrics}
if time_range := form_data.get("time_range"):
qd["time_range"] = time_range
if filters := form_data.get("filters"):
qd["filters"] = filters
if (row_limit := form_data.get("row_limit")) is not None:
qd["row_limit"] = row_limit
return qd
def _build_mixed_timeseries_secondary(
form_data: dict[str, Any],
x_axis_col: str | None,
engine: str = "base",
) -> dict[str, Any]:
"""Build the secondary query dict for the ``mixed_timeseries`` viz type.
``mixed_timeseries`` has two independent series layers; the secondary
layer uses ``metrics_b`` / ``groupby_b`` instead of the primary fields.
Secondary-specific overrides (``time_range_b``, ``row_limit_b``,
``adhoc_filters_b``) replace the corresponding primary values so the
generated SQL accurately reflects each series' independent configuration.
"""
metrics_b: list[Any] = list(form_data.get("metrics_b") or [])
raw_b = form_data.get("groupby_b") or []
groupby_b: list[Any] = [raw_b] if isinstance(raw_b, str) else list(raw_b)
if x_axis_col and x_axis_col not in groupby_b:
groupby_b = [x_axis_col] + groupby_b
qd = _build_single_query_dict(form_data, groupby_b, metrics_b)
if time_range_b := form_data.get("time_range_b"):
qd["time_range"] = time_range_b
if (row_limit_b := form_data.get("row_limit_b")) is not None:
qd["row_limit"] = row_limit_b
# Process adhoc_filters_b into concrete filter clauses for the secondary
# query, mirroring how split_adhoc_filters_into_base_filters handles the
# primary adhoc_filters in _build_query_context_from_form_data.
if adhoc_filters_b := form_data.get("adhoc_filters_b"):
from superset.utils.core import split_adhoc_filters_into_base_filters
secondary_fd: dict[str, Any] = {"adhoc_filters": adhoc_filters_b}
split_adhoc_filters_into_base_filters(secondary_fd, engine)
if secondary_filters := secondary_fd.get("filters"):
qd["filters"] = secondary_filters
return qd
def _build_query_context_from_form_data(
form_data: dict[str, Any],
chart: "Slice | None" = None,
@@ -209,22 +276,33 @@ def _build_query_context_from_form_data(
merge_extra_filters(form_data)
split_adhoc_filters_into_base_filters(form_data, engine)
# Build query dict with temporal and filter fields.
# QueryObjectFactory.create() accepts time_range as a top-level kwarg
# and converts it to from_dttm/to_dttm for the QueryObject.
query_dict: dict[str, Any] = {
"columns": groupby,
"metrics": metrics,
}
viz_type: str = (
form_data.get("viz_type")
or (getattr(chart, "viz_type", "") if chart else "")
or ""
)
is_timeseries = (
viz_type.startswith("echarts_timeseries") or viz_type == "mixed_timeseries"
)
if time_range := form_data.get("time_range"):
query_dict["time_range"] = time_range
# For echarts_timeseries_* and mixed_timeseries charts the temporal
# column is stored in x_axis rather than groupby. Prepend it so the
# generated SQL includes the time axis.
x_axis_col: str | None = None
if is_timeseries:
x_axis_col = _extract_x_axis_col(form_data)
if x_axis_col and x_axis_col not in groupby:
groupby = [x_axis_col] + groupby
if filters := form_data.get("filters"):
query_dict["filters"] = filters
queries: list[dict[str, Any]] = [
_build_single_query_dict(form_data, groupby, metrics)
]
if (row_limit := form_data.get("row_limit")) is not None:
query_dict["row_limit"] = row_limit
# mixed_timeseries exposes two independent query layers (primary and
# secondary). Build the second query from metrics_b / groupby_b so
# that get_chart_sql returns SQL for both and neither is silently lost.
if viz_type == "mixed_timeseries":
queries.append(_build_mixed_timeseries_secondary(form_data, x_axis_col, engine))
# Ensure datasource fields satisfy DatasourceDict typing requirements.
# datasource_id must be int | str; datasource_type must be str.
@@ -238,7 +316,7 @@ def _build_query_context_from_form_data(
return factory.create(
datasource={"id": resolved_id, "type": resolved_type_str},
queries=[query_dict],
queries=queries,
form_data=form_data,
result_type=ChartDataResultType.QUERY,
force=False,

View File

@@ -33,6 +33,7 @@ from superset.mcp_service.chart.schemas import (
from superset.mcp_service.chart.tool.get_chart_sql import (
_build_query_context_from_form_data,
_extract_sql_from_result,
_extract_x_axis_col,
_find_chart_by_identifier,
_resolve_datasource_name,
_resolve_effective_form_data,
@@ -468,6 +469,407 @@ class TestBuildQueryContextFromFormData:
assert queries[0]["columns"] == ["product"]
class TestExtractXAxisCol:
"""Tests for the _extract_x_axis_col helper."""
def test_string_x_axis(self):
"""Plain string x_axis returns the string directly."""
assert _extract_x_axis_col({"x_axis": "order_date"}) == "order_date"
def test_dict_x_axis(self):
"""Adhoc column dict x_axis returns column_name."""
assert (
_extract_x_axis_col(
{
"x_axis": {
"column_name": "ds",
"label": "ds",
"expressionType": "SIMPLE",
}
}
)
== "ds"
)
def test_missing_x_axis_returns_none(self):
"""Missing x_axis key returns None."""
assert _extract_x_axis_col({}) is None
def test_none_x_axis_returns_none(self):
"""Explicit None x_axis returns None."""
assert _extract_x_axis_col({"x_axis": None}) is None
def test_empty_string_x_axis_returns_none(self):
"""Empty string x_axis returns None."""
assert _extract_x_axis_col({"x_axis": ""}) is None
def test_dict_missing_column_name_returns_none(self):
"""Adhoc column dict without column_name returns None."""
assert _extract_x_axis_col({"x_axis": {"label": "ds"}}) is None
def test_sql_expression_x_axis_returns_none(self):
"""SQL expression adhoc columns have no column_name; returns None."""
assert (
_extract_x_axis_col(
{
"x_axis": {
"expressionType": "SQL",
"sqlExpression": "DATE_TRUNC('day', created_at)",
"label": "day",
}
}
)
is None
)
class TestBuildQueryContextTimeseriesAndMixed:
"""Regression tests for x_axis and mixed_timeseries query-context fixes.
Guards against two bugs: x_axis column dropped for echarts_timeseries_*
charts, and only one query rendered for mixed_timeseries charts.
"""
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_echarts_timeseries_x_axis_included_in_columns(
self, mock_get_ds, mock_factory_cls
):
"""x_axis column is prepended to query columns for echarts_timeseries charts."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "echarts_timeseries_line",
"x_axis": "ds",
"metrics": ["sum__sales"],
"groupby": ["region"],
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert len(queries) == 1
assert queries[0]["columns"][0] == "ds"
assert "region" in queries[0]["columns"]
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_echarts_timeseries_dict_x_axis_included_in_columns(
self, mock_get_ds, mock_factory_cls
):
"""Adhoc-column x_axis dict is resolved and prepended for echarts_timeseries."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "echarts_timeseries_bar",
"x_axis": {"column_name": "order_date", "expressionType": "SIMPLE"},
"metrics": ["count"],
"groupby": [],
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert queries[0]["columns"][0] == "order_date"
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_echarts_timeseries_x_axis_not_duplicated_if_already_in_groupby(
self, mock_get_ds, mock_factory_cls
):
"""x_axis is not duplicated if it is already in groupby."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "echarts_timeseries_line",
"x_axis": "ds",
"metrics": ["count"],
"groupby": ["ds"], # already present
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert queries[0]["columns"].count("ds") == 1
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_non_timeseries_x_axis_not_added(self, mock_get_ds, mock_factory_cls):
"""x_axis is not added for non-timeseries chart types (e.g. table)."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "table",
"x_axis": "ds",
"metrics": ["count"],
"groupby": ["region"],
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert "ds" not in queries[0]["columns"]
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_mixed_timeseries_produces_two_queries(self, mock_get_ds, mock_factory_cls):
"""mixed_timeseries builds two query dicts — one per series layer."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "mixed_timeseries",
"x_axis": "ds",
"metrics": ["sum__revenue"],
"groupby": ["country"],
"metrics_b": ["count"],
"groupby_b": ["channel"],
"time_range": "Last 30 days",
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert len(queries) == 2
# Primary query
assert "ds" in queries[0]["columns"]
assert "country" in queries[0]["columns"]
assert queries[0]["metrics"] == ["sum__revenue"]
assert queries[0]["time_range"] == "Last 30 days"
# Secondary query
assert "ds" in queries[1]["columns"]
assert "channel" in queries[1]["columns"]
assert queries[1]["metrics"] == ["count"]
assert queries[1]["time_range"] == "Last 30 days"
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_mixed_timeseries_x_axis_not_duplicated_in_secondary(
self, mock_get_ds, mock_factory_cls
):
"""x_axis is not duplicated in the secondary query if already in groupby_b."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "mixed_timeseries",
"x_axis": "ds",
"metrics": ["count"],
"groupby": [],
"metrics_b": ["sum__sales"],
"groupby_b": ["ds"], # x_axis already present
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert queries[1]["columns"].count("ds") == 1
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_mixed_timeseries_empty_secondary(self, mock_get_ds, mock_factory_cls):
"""mixed_timeseries with no metrics_b/groupby_b still produces two queries."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "mixed_timeseries",
"x_axis": "ds",
"metrics": ["count"],
"groupby": [],
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert len(queries) == 2
assert queries[1]["metrics"] == []
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_mixed_timeseries_time_range_b_overrides_secondary(
self, mock_get_ds, mock_factory_cls
):
"""time_range_b overrides the primary time_range for the secondary query."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "mixed_timeseries",
"x_axis": "ds",
"metrics": ["sum__revenue"],
"groupby": [],
"metrics_b": ["count"],
"groupby_b": [],
"time_range": "Last 30 days",
"time_range_b": "Last 7 days",
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert len(queries) == 2
assert queries[0]["time_range"] == "Last 30 days"
assert queries[1]["time_range"] == "Last 7 days"
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_mixed_timeseries_row_limit_b_overrides_secondary(
self, mock_get_ds, mock_factory_cls
):
"""row_limit_b overrides the primary row_limit for the secondary query."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "mixed_timeseries",
"x_axis": "ds",
"metrics": ["sum__revenue"],
"groupby": [],
"metrics_b": ["count"],
"groupby_b": [],
"row_limit": 100,
"row_limit_b": 50,
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert len(queries) == 2
assert queries[0]["row_limit"] == 100
assert queries[1]["row_limit"] == 50
@patch("superset.common.query_context_factory.QueryContextFactory")
@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_mixed_timeseries_adhoc_filters_b_applied_to_secondary(
self, mock_get_ds, mock_factory_cls
):
"""adhoc_filters_b is processed and applied to the secondary query filters."""
mock_ds = Mock()
mock_ds.database.db_engine_spec.engine = "postgresql"
mock_get_ds.return_value = mock_ds
mock_factory = Mock()
mock_factory.create.return_value = Mock()
mock_factory_cls.return_value = mock_factory
form_data = {
"datasource_id": 1,
"datasource_type": "table",
"viz_type": "mixed_timeseries",
"x_axis": "ds",
"metrics": ["sum__revenue"],
"groupby": [],
"metrics_b": ["count"],
"groupby_b": [],
"adhoc_filters_b": [
{
"clause": "WHERE",
"expressionType": "SIMPLE",
"subject": "channel",
"operator": "==",
"comparator": "organic",
}
],
}
with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
mock_rt.QUERY = "QUERY"
_build_query_context_from_form_data(form_data, chart=None)
queries = mock_factory.create.call_args[1]["queries"]
assert len(queries) == 2
secondary_filters = queries[1].get("filters", [])
assert {"col": "channel", "op": "==", "val": "organic"} in secondary_filters
class TestResolveDatasourceName:
"""Tests for _resolve_datasource_name helper."""