diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index ece8e1310a3..1d507909d91 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -566,22 +566,12 @@ def map_xy_config( # Configure temporal handling based on whether column is truly temporal configure_temporal_handling(form_data, x_is_temporal, config.time_grain) - # CRITICAL FIX: For time series charts, handle groupby carefully to avoid duplicates - # The x_axis field already tells Superset which column to use for time grouping - groupby_columns = [] - - # Only add groupby columns if there's an explicit group_by specified - # The x_axis column should NOT be duplicated in groupby as it causes - # "Duplicate column/metric labels" errors in Superset - # Only add group_by column if it's specified AND different from x_axis - # NEVER add the x_axis column to groupby as it creates duplicate labels - if config.group_by and config.group_by.name != config.x.name: - groupby_columns.append(config.group_by.name) - - # Set the groupby in form_data only if we have valid columns - # Don't set empty groupby - let Superset handle x_axis grouping automatically - if groupby_columns: - form_data["groupby"] = groupby_columns + # Only add groupby columns that differ from x_axis to avoid + # "Duplicate column/metric labels" errors in Superset. + if config.group_by: + groupby_columns = [c.name for c in config.group_by if c.name != config.x.name] + if groupby_columns: + form_data["groupby"] = groupby_columns _add_adhoc_filters(form_data, config.filters) @@ -742,12 +732,18 @@ def map_mixed_timeseries_config( configure_temporal_handling(form_data, x_is_temporal, config.time_grain) # Primary groupby (Query A) - if config.group_by and config.group_by.name != config.x.name: - form_data["groupby"] = [config.group_by.name] + if config.group_by: + groupby = [c.name for c in config.group_by if c.name != config.x.name] + if groupby: + form_data["groupby"] = groupby # Secondary groupby (Query B) - if config.group_by_secondary and config.group_by_secondary.name != config.x.name: - form_data["groupby_b"] = [config.group_by_secondary.name] + if config.group_by_secondary: + groupby_b = [ + c.name for c in config.group_by_secondary if c.name != config.x.name + ] + if groupby_b: + form_data["groupby_b"] = groupby_b form_data["row_limit"] = config.row_limit @@ -831,10 +827,10 @@ def _xy_chart_what(config: XYChartConfig) -> str: primary_metric = _humanize_column(config.y[0]) if config.y else "Value" dimension = _humanize_column(config.x) - if config.kind in ("line", "area") and config.group_by is None: + if config.kind in ("line", "area") and not config.group_by: return f"{primary_metric} Over Time" - if config.group_by is not None: - group_label = _humanize_column(config.group_by) + if config.group_by: + group_label = _humanize_column(config.group_by[0]) return f"{primary_metric} by {group_label}" if config.kind == "scatter": return f"{primary_metric} vs {dimension}" diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 91eeac5fc6b..42e19f0073a 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -25,6 +25,7 @@ from datetime import datetime, timezone from typing import Annotated, Any, Dict, List, Literal, Protocol from pydantic import ( + AliasChoices, BaseModel, ConfigDict, Field, @@ -382,11 +383,14 @@ class ChartList(BaseModel): # Common pieces class ColumnRef(BaseModel): + model_config = ConfigDict(populate_by_name=True) + name: str = Field( ..., min_length=1, max_length=255, pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + validation_alias=AliasChoices("name", "column_name"), ) label: str | None = Field(None, max_length=500) dtype: str | None = None @@ -435,11 +439,14 @@ class LegendConfig(BaseModel): class FilterConfig(BaseModel): + model_config = ConfigDict(populate_by_name=True) + column: str = Field( ..., min_length=1, max_length=255, pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + validation_alias=AliasChoices("column", "col"), ) op: Literal[ "=", @@ -456,10 +463,12 @@ class FilterConfig(BaseModel): ] = Field( ..., description="LIKE/ILIKE use % wildcards. IN/NOT IN take a list.", + validation_alias=AliasChoices("op", "operator", "opr"), ) value: str | int | float | bool | list[str | int | float | bool] = Field( ..., description="For IN/NOT IN, provide a list.", + validation_alias=AliasChoices("value", "val"), ) @field_validator("column") @@ -498,10 +507,14 @@ class FilterConfig(BaseModel): # Actual chart types class PieChartConfig(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["pie"] = "pie" - dimension: ColumnRef = Field(..., description="Category column for slices") + dimension: ColumnRef = Field( + ..., + description="Category column for slices", + validation_alias=AliasChoices("dimension", "groupby"), + ) metric: ColumnRef = Field( ..., description="Value metric (needs aggregate e.g. SUM, COUNT)" ) @@ -518,7 +531,11 @@ class PieChartConfig(BaseModel): ] = "key_value_percent" sort_by_metric: bool = True show_legend: bool = True - filters: List[FilterConfig] | None = None + filters: List[FilterConfig] | None = Field( + None, + description="Structured filters (column/op/value). " + "Do NOT use adhoc_filters or raw SQL expressions.", + ) row_limit: int = Field(100, description="Max slices", ge=1, le=10000) number_format: str = Field("SMART_NUMBER", max_length=50) show_total: bool = Field(False, description="Show total in center") @@ -530,10 +547,15 @@ class PieChartConfig(BaseModel): class PivotTableChartConfig(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["pivot_table"] = "pivot_table" - rows: List[ColumnRef] = Field(..., min_length=1, description="Row grouping columns") + rows: List[ColumnRef] = Field( + ..., + min_length=1, + description="Row grouping columns", + validation_alias=AliasChoices("rows", "groupby", "dimension"), + ) columns: List[ColumnRef] | None = Field( None, description="Column groups for cross-tabulation" ) @@ -559,40 +581,78 @@ class PivotTableChartConfig(BaseModel): show_column_totals: bool = True transpose: bool = False combine_metric: bool = Field(False, description="Metrics side by side in columns") - filters: List[FilterConfig] | None = None + filters: List[FilterConfig] | None = Field( + None, + description="Structured filters (column/op/value). " + "Do NOT use adhoc_filters or raw SQL expressions.", + ) row_limit: int = Field(10000, description="Max cells", ge=1, le=50000) value_format: str = Field("SMART_NUMBER", max_length=50) class MixedTimeseriesChartConfig(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["mixed_timeseries"] = "mixed_timeseries" - x: ColumnRef = Field(..., description="Shared temporal X-axis column") + x: ColumnRef = Field( + ..., + description="Shared temporal X-axis column", + validation_alias=AliasChoices("x", "x_axis"), + ) time_grain: TimeGrain | None = Field(None, description="PT1H, P1D, P1W, P1M, P1Y") # Primary series (Query A) - y: List[ColumnRef] = Field(..., min_length=1, description="Primary Y-axis metrics") + y: List[ColumnRef] = Field( + ..., + min_length=1, + description="Primary Y-axis metrics", + validation_alias=AliasChoices("y", "metrics"), + ) primary_kind: Literal["line", "bar", "area", "scatter"] = "line" - group_by: ColumnRef | None = Field(None, description="Primary series group by") + group_by: List[ColumnRef] | None = Field( + None, + description="Primary series group by", + validation_alias=AliasChoices("group_by", "groupby", "series", "dimension"), + ) # Secondary series (Query B) y_secondary: List[ColumnRef] = Field( - ..., min_length=1, description="Secondary Y-axis metrics" + ..., + min_length=1, + description="Secondary Y-axis metrics", + validation_alias=AliasChoices("y_secondary", "metrics_b"), ) secondary_kind: Literal["line", "bar", "area", "scatter"] = "bar" - group_by_secondary: ColumnRef | None = Field( - None, description="Secondary series group by" + group_by_secondary: List[ColumnRef] | None = Field( + None, + description="Secondary series group by", + validation_alias=AliasChoices("group_by_secondary", "groupby_b"), ) # Display options show_legend: bool = True x_axis: AxisConfig | None = None y_axis: AxisConfig | None = None y_axis_secondary: AxisConfig | None = None - filters: List[FilterConfig] | None = None + filters: List[FilterConfig] | None = Field( + None, + description="Structured filters (column/op/value). " + "Do NOT use adhoc_filters or raw SQL expressions.", + ) row_limit: int = Field(10000, description="Max data points", ge=1, le=50000) + @field_validator("group_by", "group_by_secondary", mode="before") + @classmethod + def wrap_single_group_by(cls, v: Any) -> Any: + """Accept a single ColumnRef/dict/str and normalize to list of dicts.""" + if isinstance(v, str): + return [{"name": v}] + if isinstance(v, (dict, ColumnRef)): + return [v] + if isinstance(v, list): + return [{"name": item} if isinstance(item, str) else item for item in v] + return v + class TableChartConfig(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["table"] = "table" viz_type: Literal["table", "ag-grid-table"] = Field( @@ -602,8 +662,13 @@ class TableChartConfig(BaseModel): ..., min_length=1, description="Columns with unique labels", + validation_alias=AliasChoices("columns", "all_columns", "groupby"), + ) + filters: List[FilterConfig] | None = Field( + None, + description="Structured filters (column/op/value). " + "Do NOT use adhoc_filters or raw SQL expressions.", ) - filters: List[FilterConfig] | None = None sort_by: List[str] | None = None row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000) @@ -636,12 +701,19 @@ class TableChartConfig(BaseModel): class XYChartConfig(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["xy"] = "xy" - x: ColumnRef = Field(..., description="X-axis column") + x: ColumnRef = Field( + ..., + description="X-axis column", + validation_alias=AliasChoices("x", "x_axis", "x_column"), + ) y: List[ColumnRef] = Field( - ..., min_length=1, description="Y-axis metrics (unique labels)" + ..., + min_length=1, + description="Y-axis metrics (unique labels)", + validation_alias=AliasChoices("y", "metrics"), ) kind: Literal["line", "bar", "area", "scatter"] = "line" time_grain: TimeGrain | None = Field( @@ -651,15 +723,35 @@ class XYChartConfig(BaseModel): None, description="Bar orientation (only for kind='bar')" ) stacked: bool = False - group_by: ColumnRef | None = Field( - None, description="Series breakdown column (not 'series')" + group_by: List[ColumnRef] | None = Field( + None, + description="Series breakdown columns", + validation_alias=AliasChoices( + "group_by", "groupby", "series", "breakdown", "dimension" + ), ) x_axis: AxisConfig | None = None y_axis: AxisConfig | None = None legend: LegendConfig | None = None - filters: List[FilterConfig] | None = None + filters: List[FilterConfig] | None = Field( + None, + description="Structured filters (column/op/value). " + "Do NOT use adhoc_filters or raw SQL expressions.", + ) row_limit: int = Field(10000, description="Max data points", ge=1, le=50000) + @field_validator("group_by", mode="before") + @classmethod + def wrap_single_group_by(cls, v: Any) -> Any: + """Accept a single ColumnRef/dict/str and normalize to list of dicts.""" + if isinstance(v, str): + return [{"name": v}] + if isinstance(v, (dict, ColumnRef)): + return [v] + if isinstance(v, list): + return [{"name": item} if isinstance(item, str) else item for item in v] + return v + @model_validator(mode="after") def validate_unique_column_labels(self) -> "XYChartConfig": """Ensure all column labels are unique across x, y, and group_by.""" @@ -684,14 +776,22 @@ class XYChartConfig(BaseModel): else: labels_seen[label] = f"y[{i}]" - # Check group_by label if present + # Check group_by labels if present if self.group_by: - group_label = self.group_by.label or self.group_by.name - if group_label in labels_seen: - duplicates.append( - f"group_by: '{group_label}' " - f"(conflicts with {labels_seen[group_label]})" - ) + for i, col in enumerate(self.group_by): + if col.name == self.x.name: + # map_xy_config() strips group_by entries that match x + # to prevent Superset "duplicate label" errors, so + # we allow them through validation. + continue + group_label = col.label or col.name + if group_label in labels_seen: + duplicates.append( + f"group_by[{i}]: '{group_label}' " + f"(conflicts with {labels_seen[group_label]})" + ) + else: + labels_seen[group_label] = f"group_by[{i}]" if duplicates: raise ValueError( diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py index c50fccc6938..4f7e78cc4f0 100644 --- a/superset/mcp_service/chart/validation/dataset_validator.py +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -187,7 +187,7 @@ class DatasetValidator: refs.append(config.x) refs.extend(config.y) if config.group_by: - refs.append(config.group_by) + refs.extend(config.group_by) # Add filter columns if hasattr(config, "filters") and config.filters: @@ -265,13 +265,12 @@ class DatasetValidator: y_col["name"], dataset_context ) - # Normalize group_by column + # Normalize group_by columns if "group_by" in config_dict and config_dict["group_by"]: - config_dict["group_by"]["name"] = ( - DatasetValidator._get_canonical_column_name( - config_dict["group_by"]["name"], dataset_context + for gb_col in config_dict["group_by"]: + gb_col["name"] = DatasetValidator._get_canonical_column_name( + gb_col["name"], dataset_context ) - ) @staticmethod def _normalize_table_config( diff --git a/superset/mcp_service/chart/validation/runtime/__init__.py b/superset/mcp_service/chart/validation/runtime/__init__.py index 7228d73481b..b9414551737 100644 --- a/superset/mcp_service/chart/validation/runtime/__init__.py +++ b/superset/mcp_service/chart/validation/runtime/__init__.py @@ -138,7 +138,7 @@ class RuntimeValidator: dataset_id=dataset_id, x_column=config.x.name, chart_type=chart_type, - group_by_column=config.group_by.name if config.group_by else None, + group_by_column=config.group_by[0].name if config.group_by else None, ) if not is_ok and cardinality_info: diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index f0b35f74376..55b26c3826a 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -158,16 +158,31 @@ class TestXYChartConfig: ) assert len(config.y) == 2 - def test_group_by_duplicate_with_x_rejected(self) -> None: - """Test that group_by conflicts with x are rejected.""" + def test_group_by_duplicate_label_with_x_rejected(self) -> None: + """Test that group_by with a custom label conflicting with x is rejected.""" with pytest.raises(ValidationError, match="Duplicate column/metric labels"): XYChartConfig( chart_type="xy", x=ColumnRef(name="region"), y=[ColumnRef(name="sales", aggregate="SUM")], - group_by=ColumnRef(name="category", label="region"), + group_by=[ColumnRef(name="category", label="region")], ) + def test_group_by_same_column_as_x_allowed(self) -> None: + """Test that group_by with the same column name as x is allowed. + + The mapping layer filters these out, so validation should not reject them. + """ + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="line", + group_by=[ColumnRef(name="date")], + ) + assert config.group_by is not None + assert config.group_by[0].name == "date" + def test_realistic_chart_configurations(self) -> None: """Test realistic chart configurations.""" # This should work - COUNT(product_line) != product_line @@ -220,19 +235,34 @@ class TestXYChartConfig: ) assert config.kind == "area" - def test_unknown_fields_rejected(self) -> None: - """Test that unknown fields like 'series' are rejected.""" - with pytest.raises(ValidationError, match="Extra inputs are not permitted"): - XYChartConfig( - chart_type="xy", - x=ColumnRef(name="territory"), - y=[ColumnRef(name="sales", aggregate="SUM")], - kind="bar", - series=ColumnRef(name="year"), - ) + def test_unknown_fields_ignored(self) -> None: + """Test that unknown fields are silently ignored (extra='ignore').""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="territory"), + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="bar", + unknown_field="bad", + ) + assert config.kind == "bar" + assert not hasattr(config, "unknown_field") + + def test_series_alias_accepted(self) -> None: + """Test that 'series' is accepted as alias for 'group_by'.""" + config = XYChartConfig.model_validate( + { + "chart_type": "xy", + "x": {"name": "territory"}, + "y": [{"name": "sales", "aggregate": "SUM"}], + "kind": "bar", + "series": {"name": "year"}, + } + ) + assert config.group_by is not None + assert config.group_by[0].name == "year" def test_group_by_accepted(self) -> None: - """Test that group_by is the correct field for series grouping.""" + """Test that group_by accepts a single ColumnRef (auto-wrapped in list).""" config = XYChartConfig( chart_type="xy", x=ColumnRef(name="territory"), @@ -241,7 +271,20 @@ class TestXYChartConfig: group_by=ColumnRef(name="year"), ) assert config.group_by is not None - assert config.group_by.name == "year" + assert len(config.group_by) == 1 + assert config.group_by[0].name == "year" + + def test_group_by_multiple(self) -> None: + """Test that group_by accepts a list of ColumnRefs.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="territory"), + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="bar", + group_by=[ColumnRef(name="year"), ColumnRef(name="region")], + ) + assert config.group_by is not None + assert len(config.group_by) == 2 def test_orientation_horizontal_accepted(self) -> None: """Test that orientation='horizontal' is accepted for bar charts.""" @@ -374,11 +417,12 @@ class TestRowLimit: class TestTableChartConfigExtraFields: """Test TableChartConfig rejects unknown fields.""" - def test_unknown_fields_rejected(self) -> None: - """Test that unknown fields are rejected.""" - with pytest.raises(ValidationError, match="Extra inputs are not permitted"): - TableChartConfig( - chart_type="table", - columns=[ColumnRef(name="product")], - foo="bar", - ) + def test_unknown_fields_ignored(self) -> None: + """Test that unknown fields are silently ignored (extra='ignore').""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + foo="bar", + ) + assert len(config.columns) == 1 + assert not hasattr(config, "foo") diff --git a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py index 9221d2d2bf1..9390530843b 100644 --- a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py +++ b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py @@ -103,14 +103,15 @@ class TestPieChartConfigSchema: assert config.filters is not None assert len(config.filters) == 1 - def test_pie_config_rejects_extra_fields(self) -> None: - with pytest.raises(ValidationError): - PieChartConfig( - chart_type="pie", - dimension=ColumnRef(name="product"), - metric=ColumnRef(name="revenue", aggregate="SUM"), - unknown_field="bad", - ) + def test_pie_config_ignores_extra_fields(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + unknown_field="bad", + ) + assert config.dimension.name == "product" + assert not hasattr(config, "unknown_field") def test_pie_config_missing_dimension(self) -> None: with pytest.raises(ValidationError): @@ -323,14 +324,15 @@ class TestPivotTableChartConfigSchema: metrics=[ColumnRef(name="revenue", aggregate="SUM")], ) - def test_pivot_table_rejects_extra_fields(self) -> None: - with pytest.raises(ValidationError): - PivotTableChartConfig( - chart_type="pivot_table", - rows=[ColumnRef(name="product")], - metrics=[ColumnRef(name="revenue", aggregate="SUM")], - unknown_field="bad", - ) + def test_pivot_table_ignores_extra_fields(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + unknown_field="bad", + ) + assert config.rows[0].name == "product" + assert not hasattr(config, "unknown_field") def test_pivot_table_valid_aggregate_functions(self) -> None: for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]: @@ -459,10 +461,10 @@ class TestMixedTimeseriesChartConfigSchema: time_grain="P1M", y=[ColumnRef(name="revenue", aggregate="SUM")], primary_kind="area", - group_by=ColumnRef(name="region"), + group_by=[ColumnRef(name="region")], y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], secondary_kind="scatter", - group_by_secondary=ColumnRef(name="channel"), + group_by_secondary=[ColumnRef(name="channel")], show_legend=False, x_axis=AxisConfig(title="Date"), y_axis=AxisConfig(title="Revenue", format="$,.2f"), @@ -473,9 +475,9 @@ class TestMixedTimeseriesChartConfigSchema: assert config.secondary_kind == "scatter" assert config.time_grain == "P1M" assert config.group_by is not None - assert config.group_by.name == "region" + assert config.group_by[0].name == "region" assert config.group_by_secondary is not None - assert config.group_by_secondary.name == "channel" + assert config.group_by_secondary[0].name == "channel" def test_mixed_timeseries_missing_y(self) -> None: with pytest.raises(ValidationError): @@ -502,15 +504,16 @@ class TestMixedTimeseriesChartConfigSchema: y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], ) - def test_mixed_timeseries_rejects_extra_fields(self) -> None: - with pytest.raises(ValidationError): - MixedTimeseriesChartConfig( - chart_type="mixed_timeseries", - x=ColumnRef(name="date"), - y=[ColumnRef(name="revenue", aggregate="SUM")], - y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], - unknown_field="bad", - ) + def test_mixed_timeseries_ignores_extra_fields(self) -> None: + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + unknown_field="bad", + ) + assert config.x.name == "date" + assert not hasattr(config, "unknown_field") def test_mixed_timeseries_default_row_limit(self) -> None: config = MixedTimeseriesChartConfig( @@ -606,9 +609,9 @@ class TestMapMixedTimeseriesConfig: chart_type="mixed_timeseries", x=ColumnRef(name="date"), y=[ColumnRef(name="revenue", aggregate="SUM")], - group_by=ColumnRef(name="region"), + group_by=[ColumnRef(name="region")], y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], - group_by_secondary=ColumnRef(name="channel"), + group_by_secondary=[ColumnRef(name="channel")], ) result = map_mixed_timeseries_config(config, dataset_id=1) @@ -623,9 +626,9 @@ class TestMapMixedTimeseriesConfig: chart_type="mixed_timeseries", x=ColumnRef(name="date"), y=[ColumnRef(name="revenue", aggregate="SUM")], - group_by=ColumnRef(name="date"), # same as x + group_by=[ColumnRef(name="date")], # same as x y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], - group_by_secondary=ColumnRef(name="date"), # same as x + group_by_secondary=[ColumnRef(name="date")], # same as x ) result = map_mixed_timeseries_config(config, dataset_id=1) diff --git a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py index 77fdf64143f..8c68f738ed1 100644 --- a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py +++ b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py @@ -163,12 +163,12 @@ class TestNormalizeXYConfig: "x": {"name": "OrderDate"}, "y": [{"name": "Sales", "aggregate": "SUM"}], "kind": "line", - "group_by": {"name": "productline"}, + "group_by": [{"name": "productline"}], } DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context) - assert config_dict["group_by"]["name"] == "ProductLine" + assert config_dict["group_by"][0]["name"] == "ProductLine" class TestNormalizeTableConfig: @@ -397,7 +397,7 @@ class TestNormalizeUppercaseDataset: assert normalized.x.name == "ds" assert normalized.y[0].name == "DISTANCE" assert normalized.group_by is not None - assert normalized.group_by.name == "AIRLINE" + assert normalized.group_by[0].name == "AIRLINE" assert normalized.filters is not None assert normalized.filters[0].column == "AIRLINE" @@ -531,7 +531,7 @@ class TestNormalizeEdgeCases: assert normalized.y[0].name == "Sales" assert normalized.y[1].name == "quantity_ordered" assert normalized.group_by is not None - assert normalized.group_by.name == "ProductLine" + assert normalized.group_by[0].name == "ProductLine" assert normalized.filters is not None assert normalized.filters[0].column == "ProductLine" assert normalized.filters[1].column == "OrderDate" @@ -678,4 +678,4 @@ class TestNormalizeXAxisFilterConsistency: assert normalized.group_by is not None assert normalized.filters is not None - assert normalized.group_by.name == normalized.filters[0].column == "AIRLINE" + assert normalized.group_by[0].name == normalized.filters[0].column == "AIRLINE"