mirror of
https://github.com/apache/superset.git
synced 2026-04-20 08:34:37 +00:00
fix(mcp): include x_axis column in query context for series charts with group_by (#37639)
This commit is contained in:
@@ -219,3 +219,39 @@ class TestXYChartConfig:
|
||||
kind="area",
|
||||
)
|
||||
assert config.kind == "area"
|
||||
|
||||
def test_unknown_fields_rejected(self) -> None:
|
||||
"""Test that unknown fields like 'series' are rejected."""
|
||||
with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
|
||||
XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="territory"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="bar",
|
||||
series=ColumnRef(name="year"),
|
||||
)
|
||||
|
||||
def test_group_by_accepted(self) -> None:
|
||||
"""Test that group_by is the correct field for series grouping."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="territory"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="bar",
|
||||
group_by=ColumnRef(name="year"),
|
||||
)
|
||||
assert config.group_by is not None
|
||||
assert config.group_by.name == "year"
|
||||
|
||||
|
||||
class TestTableChartConfigExtraFields:
|
||||
"""Test TableChartConfig rejects unknown fields."""
|
||||
|
||||
def test_unknown_fields_rejected(self) -> None:
|
||||
"""Test that unknown fields are rejected."""
|
||||
with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product")],
|
||||
foo="bar",
|
||||
)
|
||||
|
||||
158
tests/unit_tests/mcp_service/chart/test_preview_utils.py
Normal file
158
tests/unit_tests/mcp_service/chart/test_preview_utils.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Tests for preview_utils query context column building.
|
||||
"""
|
||||
|
||||
|
||||
class TestPreviewUtilsColumnBuilding:
|
||||
"""Tests for x_axis + groupby column building in generate_preview_from_form_data.
|
||||
|
||||
The function must build the columns list from both x_axis and groupby for
|
||||
XY charts, and fall back to form_data["columns"] for table charts.
|
||||
"""
|
||||
|
||||
def test_xy_chart_uses_x_axis_and_groupby(self):
|
||||
"""Test XY chart form_data builds columns from x_axis + groupby."""
|
||||
form_data = {
|
||||
"x_axis": "territory",
|
||||
"groupby": ["year"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
raw_columns = form_data.get("columns", [])
|
||||
|
||||
columns = (
|
||||
raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
|
||||
)
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
elif x_axis_config and isinstance(x_axis_config, dict):
|
||||
col_name = x_axis_config.get("column_name")
|
||||
if col_name and col_name not in columns:
|
||||
columns.insert(0, col_name)
|
||||
|
||||
assert columns == ["territory", "year"]
|
||||
|
||||
def test_table_chart_uses_columns_field(self):
|
||||
"""Test table chart form_data uses 'columns' field directly."""
|
||||
form_data = {
|
||||
"columns": ["name", "region", "sales"],
|
||||
"metrics": [],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
raw_columns = form_data.get("columns", [])
|
||||
|
||||
columns = (
|
||||
raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
|
||||
)
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["name", "region", "sales"]
|
||||
|
||||
def test_xy_chart_x_axis_dict_format(self):
|
||||
"""Test XY chart with x_axis as dict (column_name key)."""
|
||||
form_data = {
|
||||
"x_axis": {"column_name": "order_date"},
|
||||
"groupby": ["product_type"],
|
||||
"metrics": [{"label": "SUM(revenue)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
raw_columns = form_data.get("columns", [])
|
||||
|
||||
columns = (
|
||||
raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
|
||||
)
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
elif x_axis_config and isinstance(x_axis_config, dict):
|
||||
col_name = x_axis_config.get("column_name")
|
||||
if col_name and col_name not in columns:
|
||||
columns.insert(0, col_name)
|
||||
|
||||
assert columns == ["order_date", "product_type"]
|
||||
|
||||
def test_no_x_axis_no_columns_uses_groupby(self):
|
||||
"""Test fallback to groupby when no x_axis and no columns."""
|
||||
form_data = {
|
||||
"groupby": ["category"],
|
||||
"metrics": [{"label": "COUNT(*)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
raw_columns = form_data.get("columns", [])
|
||||
|
||||
columns = (
|
||||
raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
|
||||
)
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["category"]
|
||||
|
||||
def test_empty_form_data_returns_empty_columns(self):
|
||||
"""Test empty form_data returns empty columns list."""
|
||||
form_data: dict = {
|
||||
"metrics": [{"label": "COUNT(*)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
raw_columns = form_data.get("columns", [])
|
||||
|
||||
columns = (
|
||||
raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
|
||||
)
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == []
|
||||
|
||||
def test_x_axis_not_duplicated_when_in_groupby(self):
|
||||
"""Test x_axis is not added if already present in groupby."""
|
||||
form_data = {
|
||||
"x_axis": "territory",
|
||||
"groupby": ["territory", "year"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
raw_columns = form_data.get("columns", [])
|
||||
|
||||
columns = (
|
||||
raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
|
||||
)
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["territory", "year"]
|
||||
@@ -152,6 +152,115 @@ class TestBigNumberChartFallback:
|
||||
assert groupby_columns == []
|
||||
|
||||
|
||||
class TestXAxisInQueryContext:
|
||||
"""Tests for x_axis inclusion in fallback query context columns."""
|
||||
|
||||
def test_x_axis_string_included_in_columns(self):
|
||||
"""Test that x_axis (string format) is included alongside groupby columns."""
|
||||
form_data = {
|
||||
"x_axis": "territory",
|
||||
"groupby": ["year"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
"viz_type": "echarts_timeseries_bar",
|
||||
}
|
||||
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["territory", "year"]
|
||||
|
||||
def test_x_axis_dict_included_in_columns(self):
|
||||
"""Test that x_axis (dict format with column_name) is included."""
|
||||
form_data = {
|
||||
"x_axis": {"column_name": "territory"},
|
||||
"groupby": ["year"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
elif x_axis_config and isinstance(x_axis_config, dict):
|
||||
col_name = x_axis_config.get("column_name")
|
||||
if col_name and col_name not in columns:
|
||||
columns.insert(0, col_name)
|
||||
|
||||
assert columns == ["territory", "year"]
|
||||
|
||||
def test_no_x_axis_uses_groupby_only(self):
|
||||
"""Test that without x_axis, only groupby columns are used."""
|
||||
form_data = {
|
||||
"groupby": ["region", "category"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["region", "category"]
|
||||
|
||||
def test_x_axis_not_duplicated_if_in_groupby(self):
|
||||
"""Test that x_axis is not duplicated if already in groupby list."""
|
||||
form_data = {
|
||||
"x_axis": "territory",
|
||||
"groupby": ["territory", "year"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["territory", "year"]
|
||||
|
||||
def test_x_axis_without_groupby(self):
|
||||
"""Test that x_axis works when there's no groupby."""
|
||||
form_data = {
|
||||
"x_axis": "date",
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["date"]
|
||||
|
||||
def test_empty_groupby_with_x_axis(self):
|
||||
"""Test x_axis with explicitly empty groupby."""
|
||||
form_data = {
|
||||
"x_axis": "platform",
|
||||
"groupby": [],
|
||||
"metrics": [{"label": "SUM(global_sales)"}],
|
||||
}
|
||||
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["platform"]
|
||||
|
||||
|
||||
class TestGetChartDataRequestSchema:
|
||||
"""Test the GetChartDataRequest schema validation."""
|
||||
|
||||
|
||||
@@ -29,6 +29,119 @@ from superset.mcp_service.chart.schemas import (
|
||||
)
|
||||
|
||||
|
||||
class TestPreviewXAxisInQueryContext:
|
||||
"""Tests for x_axis inclusion in preview query context columns.
|
||||
|
||||
When generating chart previews (table, vega_lite), the query context must
|
||||
include both x_axis and groupby columns. Previously only groupby was used,
|
||||
causing series charts with group_by to lose the x_axis dimension.
|
||||
"""
|
||||
|
||||
def test_table_preview_includes_x_axis_and_groupby(self):
|
||||
"""Test that table preview builds columns with both x_axis and groupby."""
|
||||
form_data = {
|
||||
"x_axis": "territory",
|
||||
"groupby": ["year"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["territory", "year"]
|
||||
|
||||
def test_vega_lite_preview_includes_x_axis_and_groupby(self):
|
||||
"""Test that vega_lite preview builds columns with both x_axis and groupby."""
|
||||
form_data = {
|
||||
"x_axis": "platform",
|
||||
"groupby": ["genre"],
|
||||
"metrics": [{"label": "SUM(global_sales)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["platform", "genre"]
|
||||
|
||||
def test_preview_x_axis_dict_format(self):
|
||||
"""Test preview column building with x_axis as dict."""
|
||||
form_data = {
|
||||
"x_axis": {"column_name": "order_date"},
|
||||
"groupby": ["region"],
|
||||
"metrics": [{"label": "SUM(revenue)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
elif x_axis_config and isinstance(x_axis_config, dict):
|
||||
col_name = x_axis_config.get("column_name")
|
||||
if col_name and col_name not in columns:
|
||||
columns.insert(0, col_name)
|
||||
|
||||
assert columns == ["order_date", "region"]
|
||||
|
||||
def test_preview_no_groupby_x_axis_only(self):
|
||||
"""Test preview with x_axis but no groupby."""
|
||||
form_data = {
|
||||
"x_axis": "date",
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["date"]
|
||||
|
||||
def test_preview_no_x_axis_groupby_only(self):
|
||||
"""Test preview with groupby but no x_axis (e.g., table chart)."""
|
||||
form_data = {
|
||||
"groupby": ["category", "region"],
|
||||
"metrics": [{"label": "COUNT(*)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["category", "region"]
|
||||
|
||||
def test_preview_x_axis_not_duplicated(self):
|
||||
"""Test x_axis isn't duplicated if already in groupby."""
|
||||
form_data = {
|
||||
"x_axis": "territory",
|
||||
"groupby": ["territory", "year"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
|
||||
assert columns == ["territory", "year"]
|
||||
|
||||
|
||||
class TestGetChartPreview:
|
||||
"""Tests for get_chart_preview MCP tool."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user