diff --git a/superset/mcp_service/chart/preview_utils.py b/superset/mcp_service/chart/preview_utils.py index 3db475c0da1..677d3034fd4 100644 --- a/superset/mcp_service/chart/preview_utils.py +++ b/superset/mcp_service/chart/preview_utils.py @@ -36,6 +36,23 @@ from superset.mcp_service.chart.schemas import ( logger = logging.getLogger(__name__) +def _build_query_columns(form_data: Dict[str, Any]) -> list[str]: + """Build query columns list from form_data, including both x_axis and groupby.""" + x_axis_config = form_data.get("x_axis") + groupby_columns: list[str] = form_data.get("groupby") or [] + raw_columns: list[str] = form_data.get("columns") or [] + + columns = raw_columns.copy() if "columns" in form_data else groupby_columns.copy() + if x_axis_config and isinstance(x_axis_config, str): + if x_axis_config not in columns: + columns.insert(0, x_axis_config) + elif x_axis_config and isinstance(x_axis_config, dict): + col_name = x_axis_config.get("column_name") + if col_name and col_name not in columns: + columns.insert(0, col_name) + return columns + + def generate_preview_from_form_data( form_data: Dict[str, Any], dataset_id: int, preview_format: str ) -> Any: @@ -64,12 +81,16 @@ def generate_preview_from_form_data( # Create query context from form data using factory from superset.common.query_context_factory import QueryContextFactory + # Build columns list: include x_axis and groupby for XY charts, + # fall back to form_data "columns" for table charts + columns = _build_query_columns(form_data) + factory = QueryContextFactory() query_context_obj = factory.create( datasource={"id": dataset_id, "type": "table"}, queries=[ { - "columns": form_data.get("columns", []), + "columns": columns, "metrics": form_data.get("metrics", []), "orderby": form_data.get("orderby", []), "row_limit": form_data.get("row_limit", 100), diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index b813bc4ebc5..6b5ed699672 100644 --- 
a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -606,6 +606,8 @@ class FilterConfig(BaseModel): # Actual chart types class TableChartConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + chart_type: Literal["table"] = Field( ..., description="Chart type (REQUIRED: must be 'table')" ) @@ -659,6 +661,8 @@ class TableChartConfig(BaseModel): class XYChartConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + chart_type: Literal["xy"] = Field( ..., description=( @@ -692,7 +696,11 @@ class XYChartConfig(BaseModel): False, description="Stack bars/areas on top of each other instead of side-by-side", ) - group_by: ColumnRef | None = Field(None, description="Column to group by") + group_by: ColumnRef | None = Field( + None, + description="Column to group by (creates series/breakdown). " + "Use this field for series grouping — do NOT use 'series'.", + ) x_axis: AxisConfig | None = Field(None, description="X-axis configuration") y_axis: AxisConfig | None = Field(None, description="Y-axis configuration") legend: LegendConfig | None = Field(None, description="Legend configuration") diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index 9984aa1d9bb..91d478b9a54 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -176,7 +176,18 @@ async def get_chart_data( # noqa: C901 else: # Standard charts use "metrics" (plural) and "groupby" metrics = form_data.get("metrics", []) - groupby_columns = form_data.get("groupby", []) + groupby_columns = form_data.get("groupby") or [] + + # Build query columns list: include both x_axis and groupby + x_axis_config = form_data.get("x_axis") + query_columns = groupby_columns.copy() + if x_axis_config and isinstance(x_axis_config, str): + if x_axis_config not in query_columns: + query_columns.insert(0, x_axis_config) + elif x_axis_config and 
isinstance(x_axis_config, dict): + col_name = x_axis_config.get("column_name") + if col_name and col_name not in query_columns: + query_columns.insert(0, col_name) query_context = factory.create( datasource={ @@ -186,7 +197,7 @@ async def get_chart_data( # noqa: C901 queries=[ { "filters": form_data.get("filters", []), - "columns": groupby_columns, + "columns": query_columns, "metrics": metrics, "row_limit": row_limit, "order_desc": True, diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index 9e001540396..fbc1a5802be 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -56,6 +56,22 @@ class ChartLike(Protocol): uuid: Any +def _build_query_columns(form_data: Dict[str, Any]) -> list[str]: + """Build query columns list from form_data, including both x_axis and groupby.""" + x_axis_config = form_data.get("x_axis") + groupby_columns: list[str] = form_data.get("groupby") or [] + + columns = groupby_columns.copy() + if x_axis_config and isinstance(x_axis_config, str): + if x_axis_config not in columns: + columns.insert(0, x_axis_config) + elif x_axis_config and isinstance(x_axis_config, dict): + col_name = x_axis_config.get("column_name") + if col_name and col_name not in columns: + columns.insert(0, col_name) + return columns + + class PreviewFormatStrategy: """Base class for preview format strategies.""" @@ -185,6 +201,8 @@ class TablePreviewStrategy(PreviewFormatStrategy): error_type="InvalidChart", ) + columns = _build_query_columns(form_data) + factory = QueryContextFactory() query_context = factory.create( datasource={ @@ -194,7 +212,7 @@ class TablePreviewStrategy(PreviewFormatStrategy): queries=[ { "filters": form_data.get("filters", []), - "columns": form_data.get("groupby", []), + "columns": columns, "metrics": form_data.get("metrics", []), "row_limit": 20, "order_desc": True, @@ -279,6 +297,9 @@ class 
def test_unknown_fields_rejected(self) -> None:
    """Reject an XY config that passes the unknown 'series' field."""
    payload = {
        "chart_type": "xy",
        "x": ColumnRef(name="territory"),
        "y": [ColumnRef(name="sales", aggregate="SUM")],
        "kind": "bar",
        # Not a declared field; validation is expected to refuse it.
        "series": ColumnRef(name="year"),
    }
    with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
        XYChartConfig(**payload)

def test_group_by_accepted(self) -> None:
    """Accept 'group_by' as the field for series grouping."""
    series_column = ColumnRef(name="year")
    config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="territory"),
        y=[ColumnRef(name="sales", aggregate="SUM")],
        kind="bar",
        group_by=series_column,
    )
    grouped = config.group_by
    assert grouped is not None
    assert grouped.name == "year"
class TestPreviewUtilsColumnBuilding:
    """Tests for the column builder behind generate_preview_from_form_data.

    XY charts must query both the x_axis column and the groupby columns, while
    table charts use form_data["columns"] directly. Every test calls the
    production helper instead of re-implementing its logic, so a regression in
    preview_utils fails here.
    """

    @staticmethod
    def _columns(form_data: dict) -> list:
        """Return the columns the production helper builds for *form_data*."""
        # Imported lazily so the dependency stays local to this test class.
        from superset.mcp_service.chart.preview_utils import _build_query_columns

        return _build_query_columns(form_data)

    def test_xy_chart_uses_x_axis_and_groupby(self):
        """Test XY chart form_data builds columns from x_axis + groupby."""
        form_data = {
            "x_axis": "territory",
            "groupby": ["year"],
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._columns(form_data) == ["territory", "year"]

    def test_table_chart_uses_columns_field(self):
        """Test table chart form_data uses 'columns' field directly."""
        form_data = {
            "columns": ["name", "region", "sales"],
            "metrics": [],
        }

        assert self._columns(form_data) == ["name", "region", "sales"]

    def test_xy_chart_x_axis_dict_format(self):
        """Test XY chart with x_axis as dict (column_name key)."""
        form_data = {
            "x_axis": {"column_name": "order_date"},
            "groupby": ["product_type"],
            "metrics": [{"label": "SUM(revenue)"}],
        }

        assert self._columns(form_data) == ["order_date", "product_type"]

    def test_no_x_axis_no_columns_uses_groupby(self):
        """Test fallback to groupby when no x_axis and no columns."""
        form_data = {
            "groupby": ["category"],
            "metrics": [{"label": "COUNT(*)"}],
        }

        assert self._columns(form_data) == ["category"]

    def test_empty_form_data_returns_empty_columns(self):
        """Test form_data without any column keys returns an empty list."""
        form_data: dict = {
            "metrics": [{"label": "COUNT(*)"}],
        }

        assert self._columns(form_data) == []

    def test_x_axis_not_duplicated_when_in_groupby(self):
        """Test x_axis is not added if already present in groupby."""
        form_data = {
            "x_axis": "territory",
            "groupby": ["territory", "year"],
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._columns(form_data) == ["territory", "year"]

    def test_input_lists_not_mutated(self):
        """Test the helper works on a copy and leaves form_data untouched."""
        groupby = ["year"]
        form_data = {"x_axis": "territory", "groupby": groupby}

        assert self._columns(form_data) == ["territory", "year"]
        assert groupby == ["year"]
class TestXAxisInQueryContext:
    """Tests for x_axis inclusion in fallback query context columns."""

    @staticmethod
    def _query_columns(form_data: dict) -> list:
        """Build query columns the way the get_chart_data fallback does.

        NOTE(review): this mirrors the inline logic of get_chart_data rather
        than calling it, so the two must be kept in sync (or the logic should
        be extracted into a shared helper that these tests can import).
        """
        # A missing or None groupby counts as "no grouping columns".
        columns = list(form_data.get("groupby") or [])

        # x_axis may be a column name or a dict carrying "column_name".
        x_axis = form_data.get("x_axis")
        if isinstance(x_axis, dict):
            x_axis = x_axis.get("column_name")
        elif not isinstance(x_axis, str):
            x_axis = None

        if x_axis and x_axis not in columns:
            columns.insert(0, x_axis)
        return columns

    def test_x_axis_string_included_in_columns(self):
        """Test that x_axis (string format) is included alongside groupby columns."""
        form_data = {
            "x_axis": "territory",
            "groupby": ["year"],
            "metrics": [{"label": "SUM(sales)"}],
            "viz_type": "echarts_timeseries_bar",
        }

        assert self._query_columns(form_data) == ["territory", "year"]

    def test_x_axis_dict_included_in_columns(self):
        """Test that x_axis (dict format with column_name) is included."""
        form_data = {
            "x_axis": {"column_name": "territory"},
            "groupby": ["year"],
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._query_columns(form_data) == ["territory", "year"]

    def test_no_x_axis_uses_groupby_only(self):
        """Test that without x_axis, only groupby columns are used."""
        form_data = {
            "groupby": ["region", "category"],
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._query_columns(form_data) == ["region", "category"]

    def test_x_axis_not_duplicated_if_in_groupby(self):
        """Test that x_axis is not duplicated if already in groupby list."""
        form_data = {
            "x_axis": "territory",
            "groupby": ["territory", "year"],
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._query_columns(form_data) == ["territory", "year"]

    def test_x_axis_without_groupby(self):
        """Test that x_axis works when there's no groupby."""
        form_data = {
            "x_axis": "date",
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._query_columns(form_data) == ["date"]

    def test_empty_groupby_with_x_axis(self):
        """Test x_axis with explicitly empty groupby."""
        form_data = {
            "x_axis": "platform",
            "groupby": [],
            "metrics": [{"label": "SUM(global_sales)"}],
        }

        assert self._query_columns(form_data) == ["platform"]

    def test_none_groupby_with_x_axis(self):
        """Test x_axis with groupby explicitly set to None."""
        form_data = {
            "x_axis": "platform",
            "groupby": None,
            "metrics": [{"label": "SUM(global_sales)"}],
        }

        assert self._query_columns(form_data) == ["platform"]
class TestPreviewXAxisInQueryContext:
    """Tests for x_axis inclusion in preview query context columns.

    When generating chart previews (table, vega_lite), the query context must
    include both x_axis and groupby columns. Previously only groupby was used,
    causing series charts with group_by to lose the x_axis dimension.

    Every test calls the production helper from get_chart_preview instead of
    re-implementing its logic, so a regression there fails here.
    """

    @staticmethod
    def _columns(form_data: dict) -> list:
        """Return the columns the production helper builds for *form_data*."""
        # Imported lazily so the dependency stays local to this test class.
        from superset.mcp_service.chart.tool.get_chart_preview import (
            _build_query_columns,
        )

        return _build_query_columns(form_data)

    def test_table_preview_includes_x_axis_and_groupby(self):
        """Test that table preview builds columns with both x_axis and groupby."""
        form_data = {
            "x_axis": "territory",
            "groupby": ["year"],
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._columns(form_data) == ["territory", "year"]

    def test_vega_lite_preview_includes_x_axis_and_groupby(self):
        """Test that vega_lite preview builds columns with both x_axis and groupby."""
        form_data = {
            "x_axis": "platform",
            "groupby": ["genre"],
            "metrics": [{"label": "SUM(global_sales)"}],
        }

        assert self._columns(form_data) == ["platform", "genre"]

    def test_preview_x_axis_dict_format(self):
        """Test preview column building with x_axis as dict."""
        form_data = {
            "x_axis": {"column_name": "order_date"},
            "groupby": ["region"],
            "metrics": [{"label": "SUM(revenue)"}],
        }

        assert self._columns(form_data) == ["order_date", "region"]

    def test_preview_no_groupby_x_axis_only(self):
        """Test preview with x_axis but no groupby."""
        form_data = {
            "x_axis": "date",
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._columns(form_data) == ["date"]

    def test_preview_no_x_axis_groupby_only(self):
        """Test preview with groupby but no x_axis (e.g., table chart)."""
        form_data = {
            "groupby": ["category", "region"],
            "metrics": [{"label": "COUNT(*)"}],
        }

        assert self._columns(form_data) == ["category", "region"]

    def test_preview_x_axis_not_duplicated(self):
        """Test x_axis isn't duplicated if already in groupby."""
        form_data = {
            "x_axis": "territory",
            "groupby": ["territory", "year"],
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._columns(form_data) == ["territory", "year"]

    def test_preview_none_groupby_with_x_axis(self):
        """Test preview with groupby explicitly set to None."""
        form_data = {
            "x_axis": "date",
            "groupby": None,
            "metrics": [{"label": "SUM(sales)"}],
        }

        assert self._columns(form_data) == ["date"]