fix(mcp): include x_axis column in query context for series charts with group_by (#37639)

This commit is contained in:
Amin Ghadersohi
2026-02-05 11:59:44 -07:00
committed by GitHub
parent 2e463078a2
commit 47db185e3b
8 changed files with 483 additions and 6 deletions

View File

@@ -152,6 +152,115 @@ class TestBigNumberChartFallback:
assert groupby_columns == []
class TestXAxisInQueryContext:
"""Tests for x_axis inclusion in fallback query context columns."""
def test_x_axis_string_included_in_columns(self):
"""Test that x_axis (string format) is included alongside groupby columns."""
form_data = {
"x_axis": "territory",
"groupby": ["year"],
"metrics": [{"label": "SUM(sales)"}],
"viz_type": "echarts_timeseries_bar",
}
groupby_columns = form_data.get("groupby", [])
x_axis_config = form_data.get("x_axis")
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["territory", "year"]
def test_x_axis_dict_included_in_columns(self):
"""Test that x_axis (dict format with column_name) is included."""
form_data = {
"x_axis": {"column_name": "territory"},
"groupby": ["year"],
"metrics": [{"label": "SUM(sales)"}],
}
groupby_columns = form_data.get("groupby", [])
x_axis_config = form_data.get("x_axis")
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
elif x_axis_config and isinstance(x_axis_config, dict):
col_name = x_axis_config.get("column_name")
if col_name and col_name not in columns:
columns.insert(0, col_name)
assert columns == ["territory", "year"]
def test_no_x_axis_uses_groupby_only(self):
"""Test that without x_axis, only groupby columns are used."""
form_data = {
"groupby": ["region", "category"],
"metrics": [{"label": "SUM(sales)"}],
}
groupby_columns = form_data.get("groupby", [])
x_axis_config = form_data.get("x_axis")
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["region", "category"]
def test_x_axis_not_duplicated_if_in_groupby(self):
"""Test that x_axis is not duplicated if already in groupby list."""
form_data = {
"x_axis": "territory",
"groupby": ["territory", "year"],
"metrics": [{"label": "SUM(sales)"}],
}
groupby_columns = form_data.get("groupby", [])
x_axis_config = form_data.get("x_axis")
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["territory", "year"]
def test_x_axis_without_groupby(self):
"""Test that x_axis works when there's no groupby."""
form_data = {
"x_axis": "date",
"metrics": [{"label": "SUM(sales)"}],
}
groupby_columns = form_data.get("groupby", [])
x_axis_config = form_data.get("x_axis")
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["date"]
def test_empty_groupby_with_x_axis(self):
"""Test x_axis with explicitly empty groupby."""
form_data = {
"x_axis": "platform",
"groupby": [],
"metrics": [{"label": "SUM(global_sales)"}],
}
groupby_columns = form_data.get("groupby", [])
x_axis_config = form_data.get("x_axis")
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["platform"]
class TestGetChartDataRequestSchema:
"""Test the GetChartDataRequest schema validation."""

View File

@@ -29,6 +29,119 @@ from superset.mcp_service.chart.schemas import (
)
class TestPreviewXAxisInQueryContext:
"""Tests for x_axis inclusion in preview query context columns.
When generating chart previews (table, vega_lite), the query context must
include both x_axis and groupby columns. Previously only groupby was used,
causing series charts with group_by to lose the x_axis dimension.
"""
def test_table_preview_includes_x_axis_and_groupby(self):
"""Test that table preview builds columns with both x_axis and groupby."""
form_data = {
"x_axis": "territory",
"groupby": ["year"],
"metrics": [{"label": "SUM(sales)"}],
}
x_axis_config = form_data.get("x_axis")
groupby_columns = form_data.get("groupby", [])
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["territory", "year"]
def test_vega_lite_preview_includes_x_axis_and_groupby(self):
"""Test that vega_lite preview builds columns with both x_axis and groupby."""
form_data = {
"x_axis": "platform",
"groupby": ["genre"],
"metrics": [{"label": "SUM(global_sales)"}],
}
x_axis_config = form_data.get("x_axis")
groupby_columns = form_data.get("groupby", [])
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["platform", "genre"]
def test_preview_x_axis_dict_format(self):
"""Test preview column building with x_axis as dict."""
form_data = {
"x_axis": {"column_name": "order_date"},
"groupby": ["region"],
"metrics": [{"label": "SUM(revenue)"}],
}
x_axis_config = form_data.get("x_axis")
groupby_columns = form_data.get("groupby", [])
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
elif x_axis_config and isinstance(x_axis_config, dict):
col_name = x_axis_config.get("column_name")
if col_name and col_name not in columns:
columns.insert(0, col_name)
assert columns == ["order_date", "region"]
def test_preview_no_groupby_x_axis_only(self):
"""Test preview with x_axis but no groupby."""
form_data = {
"x_axis": "date",
"metrics": [{"label": "SUM(sales)"}],
}
x_axis_config = form_data.get("x_axis")
groupby_columns = form_data.get("groupby", [])
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["date"]
def test_preview_no_x_axis_groupby_only(self):
"""Test preview with groupby but no x_axis (e.g., table chart)."""
form_data = {
"groupby": ["category", "region"],
"metrics": [{"label": "COUNT(*)"}],
}
x_axis_config = form_data.get("x_axis")
groupby_columns = form_data.get("groupby", [])
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["category", "region"]
def test_preview_x_axis_not_duplicated(self):
"""Test x_axis isn't duplicated if already in groupby."""
form_data = {
"x_axis": "territory",
"groupby": ["territory", "year"],
"metrics": [{"label": "SUM(sales)"}],
}
x_axis_config = form_data.get("x_axis")
groupby_columns = form_data.get("groupby", [])
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
assert columns == ["territory", "year"]
class TestGetChartPreview:
"""Tests for get_chart_preview MCP tool."""