diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 6d726abed78..3fb388f8aaf 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -159,8 +159,8 @@ CRITICAL RULES - NEVER VIOLATE: open_sql_lab_with_context, etc.) and use the URL it returns. - To modify an existing chart's filters, metrics, or dimensions, use update_chart. Do NOT use execute_sql for chart modifications. -- Parameter name reminders: open_sql_lab_with_context uses "sql" (not "query"), - execute_sql uses "sql" (not "query"). +- Parameter name reminders: ALWAYS use the EXACT parameter names from the tool schema. + Do NOT use Superset's internal form_data names. IMPORTANT - Tool-Only Interaction: - Do NOT generate code artifacts, HTML pages, JavaScript snippets, or any code intended diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index f78c15f510d..edf952dffb9 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -21,11 +21,13 @@ Pydantic schemas for chart-related responses from __future__ import annotations +import difflib from datetime import datetime, timezone from typing import Annotated, Any, Dict, List, Literal, Protocol from pydantic import ( AliasChoices, + AliasPath, BaseModel, ConfigDict, Field, @@ -395,6 +397,67 @@ def _normalize_group_by_input(v: Any) -> Any: return v +def _top_level_key(alias: str | AliasPath) -> str | None: + """Extract the top-level dict key from a str or AliasPath.""" + if isinstance(alias, str): + return alias + if isinstance(alias, AliasPath) and alias.path and isinstance(alias.path[0], str): + return alias.path[0] + return None + + +def _get_known_fields(model_class: type[BaseModel]) -> set[str]: + """Collect all valid field names including validation aliases.""" + known: set[str] = set() + for field_name, field_info in model_class.model_fields.items(): + known.add(field_name) + alias = field_info.validation_alias + if isinstance(alias, AliasChoices): + for choice in alias.choices: + key = _top_level_key(choice) + if key: + known.add(key) + elif alias is not None: + key = _top_level_key(alias) + if key: + known.add(key) + return known + + +def _check_unknown_fields(data: Any, model_class: type[BaseModel]) -> Any: + """Raise ValueError for unrecognized fields with 'did you mean?' suggestions. + + Catches fields that would be silently dropped by extra='ignore' and provides + actionable error messages to help LLMs self-correct parameter names. + """ + if not isinstance(data, dict): + return data + known = _get_known_fields(model_class) + unknown = set(data.keys()) - known + if not unknown: + return data + + messages = [] + for field in sorted(unknown): + matches = difflib.get_close_matches(field, sorted(known), n=1, cutoff=0.6) + if matches: + messages.append(f"Unknown field '{field}' — did you mean '{matches[0]}'?") + else: + messages.append( + f"Unknown field '{field}'. Valid fields: {', '.join(sorted(known))}" + ) + raise ValueError(" | ".join(messages)) + + +class UnknownFieldCheckMixin(BaseModel): + """Mixin that rejects unknown fields with 'did you mean?' suggestions.""" + + @model_validator(mode="before") + @classmethod + def check_unknown_fields(cls, data: Any) -> Any: + return _check_unknown_fields(data, cls) + + class ColumnRef(BaseModel): model_config = ConfigDict(populate_by_name=True) @@ -519,7 +582,7 @@ class FilterConfig(BaseModel): # Actual chart types -class PieChartConfig(BaseModel): +class PieChartConfig(UnknownFieldCheckMixin): model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["pie"] = "pie" @@ -559,7 +622,7 @@ class PieChartConfig(BaseModel): ) -class PivotTableChartConfig(BaseModel): +class PivotTableChartConfig(UnknownFieldCheckMixin): model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["pivot_table"] = "pivot_table" @@ -603,7 +666,7 @@ class PivotTableChartConfig(BaseModel): value_format: str = Field("SMART_NUMBER", max_length=50) -class MixedTimeseriesChartConfig(BaseModel): +class MixedTimeseriesChartConfig(UnknownFieldCheckMixin): model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["mixed_timeseries"] = "mixed_timeseries" @@ -612,7 +675,11 @@ class MixedTimeseriesChartConfig(BaseModel): description="Shared temporal X-axis column", validation_alias=AliasChoices("x", "x_axis"), ) - time_grain: TimeGrain | None = Field(None, description="PT1H, P1D, P1W, P1M, P1Y") + time_grain: TimeGrain | None = Field( + None, + description="PT1H, P1D, P1W, P1M, P1Y", + validation_alias=AliasChoices("time_grain", "time_grain_sqla"), + ) # Primary series (Query A) y: List[ColumnRef] = Field( ..., @@ -659,8 +726,8 @@ class MixedTimeseriesChartConfig(BaseModel): return _normalize_group_by_input(v) -class HandlebarsChartConfig(BaseModel): - model_config = ConfigDict(extra="forbid") +class HandlebarsChartConfig(UnknownFieldCheckMixin): + model_config = ConfigDict(extra="ignore") chart_type: Literal["handlebars"] = Field( ..., @@ -703,6 +770,7 @@ class HandlebarsChartConfig(BaseModel): "Columns to group by in aggregate mode (query_mode='aggregate'). " "These become the dimensions for aggregation." ), + validation_alias=AliasChoices("groupby", "group_by"), ) metrics: list[ColumnRef] | None = Field( None, @@ -761,7 +829,7 @@ class HandlebarsChartConfig(BaseModel): return self -class TableChartConfig(BaseModel): +class TableChartConfig(UnknownFieldCheckMixin): model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["table"] = "table" @@ -779,7 +847,10 @@ class TableChartConfig(BaseModel): description="Structured filters (column/op/value). " "Do NOT use adhoc_filters or raw SQL expressions.", ) - sort_by: List[str] | None = None + sort_by: List[str] | None = Field( + None, + validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"), + ) row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000) @model_validator(mode="after") @@ -810,7 +881,7 @@ class TableChartConfig(BaseModel): return self -class XYChartConfig(BaseModel): +class XYChartConfig(UnknownFieldCheckMixin): model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["xy"] = "xy" @@ -827,12 +898,17 @@ class XYChartConfig(BaseModel): ) kind: Literal["line", "bar", "area", "scatter"] = "line" time_grain: TimeGrain | None = Field( - None, description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y" + None, + description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y", + validation_alias=AliasChoices("time_grain", "time_grain_sqla"), ) orientation: Literal["vertical", "horizontal"] | None = Field( None, description="Bar orientation (only for kind='bar')" ) - stacked: bool = False + stacked: bool = Field( + False, + validation_alias=AliasChoices("stacked", "stack"), + ) group_by: List[ColumnRef] | None = Field( None, description="Series breakdown columns", diff --git a/superset/mcp_service/sql_lab/schemas.py b/superset/mcp_service/sql_lab/schemas.py index 2e0268774cc..b20bdc419e4 100644 --- a/superset/mcp_service/sql_lab/schemas.py +++ b/superset/mcp_service/sql_lab/schemas.py @@ -33,6 +33,7 @@ class ExecuteSqlRequest(BaseModel): sql: str = Field( ..., description="SQL query to execute (supports Jinja2 {{ var }} template syntax)", + validation_alias=AliasChoices("sql", "query"), ) schema_name: str | None = Field( None, description="Schema to use for query execution", alias="schema" diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index 6ab08f96b94..3ca4793e89a 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -235,17 +235,16 @@ class TestXYChartConfig: ) assert config.kind == "area" - def test_unknown_fields_ignored(self) -> None: - """Test that unknown fields are silently ignored (extra='ignore').""" - config = XYChartConfig( - chart_type="xy", - x=ColumnRef(name="territory"), - y=[ColumnRef(name="sales", aggregate="SUM")], - kind="bar", - unknown_field="bad", - ) - assert config.kind == "bar" - assert not hasattr(config, "unknown_field") + def test_unknown_fields_raise_error(self) -> None: + """Test that unknown fields raise ValueError with suggestions.""" + with pytest.raises(ValidationError, match="Unknown field"): + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="territory"), + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="bar", + unknown_field="bad", + ) def test_series_alias_accepted(self) -> None: """Test that 'series' is accepted as alias for 'group_by'.""" @@ -430,12 +429,144 @@ class TestRowLimit: class TestTableChartConfigExtraFields: """Test TableChartConfig rejects unknown fields.""" - def test_unknown_fields_ignored(self) -> None: - """Test that unknown fields are silently ignored (extra='ignore').""" - config = TableChartConfig( - chart_type="table", - columns=[ColumnRef(name="product")], - foo="bar", + def test_unknown_fields_raise_error(self) -> None: + """Test that unknown fields raise ValueError with valid field list.""" + with pytest.raises(ValidationError, match="Unknown field 'foo'"): + TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + foo="bar", + ) + + +class TestAliasChoices: + """Test that common Superset form_data aliases are accepted.""" + + def test_xy_stack_alias_for_stacked(self) -> None: + """Test that 'stack' is accepted as alias for 'stacked'.""" + config = XYChartConfig.model_validate( + { + "chart_type": "xy", + "x": {"name": "category"}, + "y": [{"name": "sales", "aggregate": "SUM"}], + "stack": True, + } ) - assert len(config.columns) == 1 - assert not hasattr(config, "foo") + assert config.stacked is True + + def test_xy_stacked_still_works(self) -> None: + """Test that 'stacked' still works as primary field name.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="category"), + y=[ColumnRef(name="sales", aggregate="SUM")], + stacked=True, + ) + assert config.stacked is True + + def test_xy_time_grain_sqla_alias(self) -> None: + """Test that 'time_grain_sqla' is accepted as alias for 'time_grain'.""" + config = XYChartConfig.model_validate( + { + "chart_type": "xy", + "x": {"name": "order_date"}, + "y": [{"name": "sales", "aggregate": "SUM"}], + "time_grain_sqla": "P1D", + } + ) + assert config.time_grain is not None + + def test_table_order_by_alias_for_sort_by(self) -> None: + """Test that 'order_by' is accepted as alias for 'sort_by'.""" + config = TableChartConfig.model_validate( + { + "chart_type": "table", + "columns": [{"name": "product"}], + "order_by": ["product"], + } + ) + assert config.sort_by == ["product"] + + def test_mixed_timeseries_time_grain_sqla_alias(self) -> None: + """Test that 'time_grain_sqla' works for MixedTimeseriesChartConfig.""" + from superset.mcp_service.chart.schemas import MixedTimeseriesChartConfig + + config = MixedTimeseriesChartConfig.model_validate( + { + "chart_type": "mixed_timeseries", + "x": {"name": "order_date"}, + "y": [{"name": "sales", "aggregate": "SUM"}], + "y_secondary": [{"name": "profit", "aggregate": "SUM"}], + "time_grain_sqla": "P1M", + } + ) + assert config.time_grain is not None + + +class TestUnknownFieldDetection: + """Test that unknown fields produce helpful error messages.""" + + def test_near_miss_suggests_correct_field(self) -> None: + """Test that a near-miss field name produces 'did you mean?' suggestion.""" + with pytest.raises(ValidationError, match="did you mean"): + XYChartConfig.model_validate( + { + "chart_type": "xy", + "x": {"name": "category"}, + "y": [{"name": "sales", "aggregate": "SUM"}], + "stacks": True, + } + ) + + def test_completely_unknown_field_lists_valid_fields(self) -> None: + """Test that a completely unknown field lists valid fields.""" + with pytest.raises(ValidationError, match="Valid fields:"): + XYChartConfig.model_validate( + { + "chart_type": "xy", + "x": {"name": "category"}, + "y": [{"name": "sales", "aggregate": "SUM"}], + "zzz_nonexistent": True, + } + ) + + def test_pie_chart_unknown_field(self) -> None: + """Test unknown field detection on PieChartConfig.""" + from superset.mcp_service.chart.schemas import PieChartConfig + + with pytest.raises(ValidationError, match="Unknown field"): + PieChartConfig.model_validate( + { + "chart_type": "pie", + "dimension": {"name": "category"}, + "metric": {"name": "sales", "aggregate": "SUM"}, + "bad_field": True, + } + ) + + def test_table_chart_unknown_field(self) -> None: + """Test unknown field detection on TableChartConfig.""" + with pytest.raises(ValidationError, match="Unknown field"): + TableChartConfig.model_validate( + { + "chart_type": "table", + "columns": [{"name": "product"}], + "invalid_param": "test", + } + ) + + def test_known_aliases_not_flagged_as_unknown(self) -> None: + """Test that known aliases pass validation without errors.""" + config = XYChartConfig.model_validate( + { + "chart_type": "xy", + "x_axis": {"name": "category"}, + "metrics": [{"name": "sales", "aggregate": "SUM"}], + "groupby": [{"name": "region"}], + "stack": True, + "time_grain_sqla": "P1D", + } + ) + assert config.stacked is True + assert config.row_limit == 10000 + assert config.group_by is not None diff --git a/tests/unit_tests/mcp_service/chart/test_handlebars_chart.py b/tests/unit_tests/mcp_service/chart/test_handlebars_chart.py index ed0317265fc..377a9cb9006 100644 --- a/tests/unit_tests/mcp_service/chart/test_handlebars_chart.py +++ b/tests/unit_tests/mcp_service/chart/test_handlebars_chart.py @@ -119,7 +119,7 @@ class TestHandlebarsChartConfig: ) def test_extra_fields_forbidden(self) -> None: - with pytest.raises(ValueError, match="Extra inputs"): + with pytest.raises(ValueError, match="Unknown field 'unknown_field'"): HandlebarsChartConfig( chart_type="handlebars", handlebars_template="

test

", diff --git a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py index 9390530843b..9469f63b39a 100644 --- a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py +++ b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py @@ -103,15 +103,14 @@ class TestPieChartConfigSchema: assert config.filters is not None assert len(config.filters) == 1 - def test_pie_config_ignores_extra_fields(self) -> None: - config = PieChartConfig( - chart_type="pie", - dimension=ColumnRef(name="product"), - metric=ColumnRef(name="revenue", aggregate="SUM"), - unknown_field="bad", - ) - assert config.dimension.name == "product" - assert not hasattr(config, "unknown_field") + def test_pie_config_rejects_extra_fields(self) -> None: + with pytest.raises(ValidationError, match="Unknown field"): + PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + unknown_field="bad", + ) def test_pie_config_missing_dimension(self) -> None: with pytest.raises(ValidationError): @@ -324,15 +323,14 @@ class TestPivotTableChartConfigSchema: metrics=[ColumnRef(name="revenue", aggregate="SUM")], ) - def test_pivot_table_ignores_extra_fields(self) -> None: - config = PivotTableChartConfig( - chart_type="pivot_table", - rows=[ColumnRef(name="product")], - metrics=[ColumnRef(name="revenue", aggregate="SUM")], - unknown_field="bad", - ) - assert config.rows[0].name == "product" - assert not hasattr(config, "unknown_field") + def test_pivot_table_rejects_extra_fields(self) -> None: + with pytest.raises(ValidationError, match="Unknown field"): + PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + unknown_field="bad", + ) def test_pivot_table_valid_aggregate_functions(self) -> None: for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]: @@ -504,16 +502,15 @@ class TestMixedTimeseriesChartConfigSchema: y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], ) - def test_mixed_timeseries_ignores_extra_fields(self) -> None: - config = MixedTimeseriesChartConfig( - chart_type="mixed_timeseries", - x=ColumnRef(name="date"), - y=[ColumnRef(name="revenue", aggregate="SUM")], - y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], - unknown_field="bad", - ) - assert config.x.name == "date" - assert not hasattr(config, "unknown_field") + def test_mixed_timeseries_rejects_extra_fields(self) -> None: + with pytest.raises(ValidationError, match="Unknown field"): + MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + unknown_field="bad", + ) def test_mixed_timeseries_default_row_limit(self) -> None: config = MixedTimeseriesChartConfig(