diff --git a/superset/mcp_service/chart/resources/chart_configs.py b/superset/mcp_service/chart/resources/chart_configs.py index ea40a89fdc6..8af128ae259 100644 --- a/superset/mcp_service/chart/resources/chart_configs.py +++ b/superset/mcp_service/chart/resources/chart_configs.py @@ -80,7 +80,7 @@ def get_chart_configs_resource() -> str: "y": [ {"name": "revenue", "aggregate": "SUM", "label": "Revenue"}, ], - "group_by": {"name": "region", "label": "Region"}, + "group_by": [{"name": "region", "label": "Region"}], "stacked": True, "legend": {"show": True, "position": "right"}, }, @@ -123,7 +123,7 @@ def get_chart_configs_resource() -> str: "label": "Avg Conversion Rate", } ], - "group_by": {"name": "campaign_type", "label": "Campaign"}, + "group_by": [{"name": "campaign_type", "label": "Campaign"}], "x_axis": {"format": "$,.0f"}, "y_axis": {"format": ".2%"}, }, @@ -136,7 +136,7 @@ def get_chart_configs_resource() -> str: "kind": "area", "x": {"name": "order_date", "label": "Date"}, "y": [{"name": "signups", "aggregate": "SUM", "label": "Signups"}], - "group_by": {"name": "channel", "label": "Channel"}, + "group_by": [{"name": "channel", "label": "Channel"}], "stacked": True, "time_grain": "P1W", }, diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 42e19f0073a..b84bf9bc38b 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -382,6 +382,19 @@ class ChartList(BaseModel): # Common pieces + + +def _normalize_group_by_input(v: Any) -> Any: + """Accept a single ColumnRef/dict/str and normalize to list of dicts.""" + if isinstance(v, str): + return [{"name": v}] + if isinstance(v, (dict, ColumnRef)): + return [v] + if isinstance(v, list): + return [{"name": item} if isinstance(item, str) else item for item in v] + return v + + class ColumnRef(BaseModel): model_config = ConfigDict(populate_by_name=True) @@ -624,7 +637,9 @@ class MixedTimeseriesChartConfig(BaseModel): group_by_secondary: List[ColumnRef] | None = Field( None, description="Secondary series group by", - validation_alias=AliasChoices("group_by_secondary", "groupby_b"), + validation_alias=AliasChoices( + "group_by_secondary", "groupby_b", "groupby_secondary" + ), ) # Display options show_legend: bool = True @@ -641,14 +656,7 @@ class MixedTimeseriesChartConfig(BaseModel): @field_validator("group_by", "group_by_secondary", mode="before") @classmethod def wrap_single_group_by(cls, v: Any) -> Any: - """Accept a single ColumnRef/dict/str and normalize to list of dicts.""" - if isinstance(v, str): - return [{"name": v}] - if isinstance(v, (dict, ColumnRef)): - return [v] - if isinstance(v, list): - return [{"name": item} if isinstance(item, str) else item for item in v] - return v + return _normalize_group_by_input(v) class TableChartConfig(BaseModel): @@ -743,14 +751,7 @@ class XYChartConfig(BaseModel): @field_validator("group_by", mode="before") @classmethod def wrap_single_group_by(cls, v: Any) -> Any: - """Accept a single ColumnRef/dict/str and normalize to list of dicts.""" - if isinstance(v, str): - return [{"name": v}] - if isinstance(v, (dict, ColumnRef)): - return [v] - if isinstance(v, list): - return [{"name": item} if isinstance(item, str) else item for item in v] - return v + return _normalize_group_by_input(v) @model_validator(mode="after") def validate_unique_column_labels(self) -> "XYChartConfig": diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index 312ebf1d2d6..6ab08f96b94 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -286,6 +286,74 @@ class TestXYChartConfig: assert config.group_by is not None assert len(config.group_by) == 2 + def test_group_by_bare_string(self) -> None: + """Test that group_by accepts a bare string (auto-wrapped).""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="territory"), + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="bar", + group_by="region", + ) + assert config.group_by is not None + assert len(config.group_by) == 1 + assert config.group_by[0].name == "region" + + def test_orientation_horizontal_accepted(self) -> None: + """Test that orientation='horizontal' is accepted for bar charts.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="department"), + y=[ColumnRef(name="headcount", aggregate="SUM")], + kind="bar", + orientation="horizontal", + ) + assert config.orientation == "horizontal" + + def test_orientation_vertical_accepted(self) -> None: + """Test that orientation='vertical' is accepted for bar charts.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="category"), + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="bar", + orientation="vertical", + ) + assert config.orientation == "vertical" + + def test_orientation_none_by_default(self) -> None: + """Test that orientation defaults to None when not specified.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="category"), + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="bar", + ) + assert config.orientation is None + + def test_orientation_invalid_value_rejected(self) -> None: + """Test that invalid orientation values are rejected.""" + with pytest.raises(ValidationError): + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="category"), + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="bar", + orientation="diagonal", + ) + + def test_orientation_with_non_bar_chart(self) -> None: + """Test that orientation field is accepted on non-bar charts at schema level.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="line", + orientation="horizontal", + ) + # Schema allows it; the chart_utils layer decides whether to apply it + assert config.orientation == "horizontal" + class TestRowLimit: """Test row_limit field on chart configs."""