fix(mcp): detect unknown chart config fields and suggest correct ones (#38848)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kamil Gabryjelski
2026-03-25 18:38:23 +01:00
committed by GitHub
parent 04e07acf98
commit 16f5a2a41a
6 changed files with 266 additions and 61 deletions

View File

@@ -159,8 +159,8 @@ CRITICAL RULES - NEVER VIOLATE:
open_sql_lab_with_context, etc.) and use the URL it returns. open_sql_lab_with_context, etc.) and use the URL it returns.
- To modify an existing chart's filters, metrics, or dimensions, use update_chart. - To modify an existing chart's filters, metrics, or dimensions, use update_chart.
Do NOT use execute_sql for chart modifications. Do NOT use execute_sql for chart modifications.
- Parameter name reminders: open_sql_lab_with_context uses "sql" (not "query"), - Parameter name reminders: ALWAYS use the EXACT parameter names from the tool schema.
execute_sql uses "sql" (not "query"). Do NOT use Superset's internal form_data names.
IMPORTANT - Tool-Only Interaction: IMPORTANT - Tool-Only Interaction:
- Do NOT generate code artifacts, HTML pages, JavaScript snippets, or any code intended - Do NOT generate code artifacts, HTML pages, JavaScript snippets, or any code intended

View File

@@ -21,11 +21,13 @@ Pydantic schemas for chart-related responses
from __future__ import annotations from __future__ import annotations
import difflib
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Annotated, Any, Dict, List, Literal, Protocol from typing import Annotated, Any, Dict, List, Literal, Protocol
from pydantic import ( from pydantic import (
AliasChoices, AliasChoices,
AliasPath,
BaseModel, BaseModel,
ConfigDict, ConfigDict,
Field, Field,
@@ -395,6 +397,67 @@ def _normalize_group_by_input(v: Any) -> Any:
return v return v
def _top_level_key(alias: str | AliasPath) -> str | None:
"""Extract the top-level dict key from a str or AliasPath."""
if isinstance(alias, str):
return alias
if isinstance(alias, AliasPath) and alias.path and isinstance(alias.path[0], str):
return alias.path[0]
return None
def _get_known_fields(model_class: type[BaseModel]) -> set[str]:
"""Collect all valid field names including validation aliases."""
known: set[str] = set()
for field_name, field_info in model_class.model_fields.items():
known.add(field_name)
alias = field_info.validation_alias
if isinstance(alias, AliasChoices):
for choice in alias.choices:
key = _top_level_key(choice)
if key:
known.add(key)
elif alias is not None:
key = _top_level_key(alias)
if key:
known.add(key)
return known
def _check_unknown_fields(data: Any, model_class: type[BaseModel]) -> Any:
"""Raise ValueError for unrecognized fields with 'did you mean?' suggestions.
Catches fields that would be silently dropped by extra='ignore' and provides
actionable error messages to help LLMs self-correct parameter names.
"""
if not isinstance(data, dict):
return data
known = _get_known_fields(model_class)
unknown = set(data.keys()) - known
if not unknown:
return data
messages = []
for field in sorted(unknown):
matches = difflib.get_close_matches(field, sorted(known), n=1, cutoff=0.6)
if matches:
messages.append(f"Unknown field '{field}' — did you mean '{matches[0]}'?")
else:
messages.append(
f"Unknown field '{field}'. Valid fields: {', '.join(sorted(known))}"
)
raise ValueError(" | ".join(messages))
class UnknownFieldCheckMixin(BaseModel):
"""Mixin that rejects unknown fields with 'did you mean?' suggestions."""
@model_validator(mode="before")
@classmethod
def check_unknown_fields(cls, data: Any) -> Any:
return _check_unknown_fields(data, cls)
class ColumnRef(BaseModel): class ColumnRef(BaseModel):
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
@@ -519,7 +582,7 @@ class FilterConfig(BaseModel):
# Actual chart types # Actual chart types
class PieChartConfig(BaseModel): class PieChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True) model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["pie"] = "pie" chart_type: Literal["pie"] = "pie"
@@ -559,7 +622,7 @@ class PieChartConfig(BaseModel):
) )
class PivotTableChartConfig(BaseModel): class PivotTableChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True) model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["pivot_table"] = "pivot_table" chart_type: Literal["pivot_table"] = "pivot_table"
@@ -603,7 +666,7 @@ class PivotTableChartConfig(BaseModel):
value_format: str = Field("SMART_NUMBER", max_length=50) value_format: str = Field("SMART_NUMBER", max_length=50)
class MixedTimeseriesChartConfig(BaseModel): class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True) model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["mixed_timeseries"] = "mixed_timeseries" chart_type: Literal["mixed_timeseries"] = "mixed_timeseries"
@@ -612,7 +675,11 @@ class MixedTimeseriesChartConfig(BaseModel):
description="Shared temporal X-axis column", description="Shared temporal X-axis column",
validation_alias=AliasChoices("x", "x_axis"), validation_alias=AliasChoices("x", "x_axis"),
) )
time_grain: TimeGrain | None = Field(None, description="PT1H, P1D, P1W, P1M, P1Y") time_grain: TimeGrain | None = Field(
None,
description="PT1H, P1D, P1W, P1M, P1Y",
validation_alias=AliasChoices("time_grain", "time_grain_sqla"),
)
# Primary series (Query A) # Primary series (Query A)
y: List[ColumnRef] = Field( y: List[ColumnRef] = Field(
..., ...,
@@ -659,8 +726,8 @@ class MixedTimeseriesChartConfig(BaseModel):
return _normalize_group_by_input(v) return _normalize_group_by_input(v)
class HandlebarsChartConfig(BaseModel): class HandlebarsChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="ignore")
chart_type: Literal["handlebars"] = Field( chart_type: Literal["handlebars"] = Field(
..., ...,
@@ -703,6 +770,7 @@ class HandlebarsChartConfig(BaseModel):
"Columns to group by in aggregate mode (query_mode='aggregate'). " "Columns to group by in aggregate mode (query_mode='aggregate'). "
"These become the dimensions for aggregation." "These become the dimensions for aggregation."
), ),
validation_alias=AliasChoices("groupby", "group_by"),
) )
metrics: list[ColumnRef] | None = Field( metrics: list[ColumnRef] | None = Field(
None, None,
@@ -761,7 +829,7 @@ class HandlebarsChartConfig(BaseModel):
return self return self
class TableChartConfig(BaseModel): class TableChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True) model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["table"] = "table" chart_type: Literal["table"] = "table"
@@ -779,7 +847,10 @@ class TableChartConfig(BaseModel):
description="Structured filters (column/op/value). " description="Structured filters (column/op/value). "
"Do NOT use adhoc_filters or raw SQL expressions.", "Do NOT use adhoc_filters or raw SQL expressions.",
) )
sort_by: List[str] | None = None sort_by: List[str] | None = Field(
None,
validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"),
)
row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000) row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000)
@model_validator(mode="after") @model_validator(mode="after")
@@ -810,7 +881,7 @@ class TableChartConfig(BaseModel):
return self return self
class XYChartConfig(BaseModel): class XYChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True) model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["xy"] = "xy" chart_type: Literal["xy"] = "xy"
@@ -827,12 +898,17 @@ class XYChartConfig(BaseModel):
) )
kind: Literal["line", "bar", "area", "scatter"] = "line" kind: Literal["line", "bar", "area", "scatter"] = "line"
time_grain: TimeGrain | None = Field( time_grain: TimeGrain | None = Field(
None, description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y" None,
description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y",
validation_alias=AliasChoices("time_grain", "time_grain_sqla"),
) )
orientation: Literal["vertical", "horizontal"] | None = Field( orientation: Literal["vertical", "horizontal"] | None = Field(
None, description="Bar orientation (only for kind='bar')" None, description="Bar orientation (only for kind='bar')"
) )
stacked: bool = False stacked: bool = Field(
False,
validation_alias=AliasChoices("stacked", "stack"),
)
group_by: List[ColumnRef] | None = Field( group_by: List[ColumnRef] | None = Field(
None, None,
description="Series breakdown columns", description="Series breakdown columns",

View File

@@ -33,6 +33,7 @@ class ExecuteSqlRequest(BaseModel):
sql: str = Field( sql: str = Field(
..., ...,
description="SQL query to execute (supports Jinja2 {{ var }} template syntax)", description="SQL query to execute (supports Jinja2 {{ var }} template syntax)",
validation_alias=AliasChoices("sql", "query"),
) )
schema_name: str | None = Field( schema_name: str | None = Field(
None, description="Schema to use for query execution", alias="schema" None, description="Schema to use for query execution", alias="schema"

View File

@@ -235,17 +235,16 @@ class TestXYChartConfig:
) )
assert config.kind == "area" assert config.kind == "area"
def test_unknown_fields_ignored(self) -> None: def test_unknown_fields_raise_error(self) -> None:
"""Test that unknown fields are silently ignored (extra='ignore').""" """Test that unknown fields raise ValueError with suggestions."""
config = XYChartConfig( with pytest.raises(ValidationError, match="Unknown field"):
chart_type="xy", XYChartConfig(
x=ColumnRef(name="territory"), chart_type="xy",
y=[ColumnRef(name="sales", aggregate="SUM")], x=ColumnRef(name="territory"),
kind="bar", y=[ColumnRef(name="sales", aggregate="SUM")],
unknown_field="bad", kind="bar",
) unknown_field="bad",
assert config.kind == "bar" )
assert not hasattr(config, "unknown_field")
def test_series_alias_accepted(self) -> None: def test_series_alias_accepted(self) -> None:
"""Test that 'series' is accepted as alias for 'group_by'.""" """Test that 'series' is accepted as alias for 'group_by'."""
@@ -430,12 +429,144 @@ class TestRowLimit:
class TestTableChartConfigExtraFields: class TestTableChartConfigExtraFields:
"""Test TableChartConfig rejects unknown fields.""" """Test TableChartConfig rejects unknown fields."""
def test_unknown_fields_ignored(self) -> None: def test_unknown_fields_raise_error(self) -> None:
"""Test that unknown fields are silently ignored (extra='ignore').""" """Test that unknown fields raise ValueError with valid field list."""
config = TableChartConfig( with pytest.raises(ValidationError, match="Unknown field 'foo'"):
chart_type="table", TableChartConfig(
columns=[ColumnRef(name="product")], chart_type="table",
foo="bar", columns=[ColumnRef(name="product")],
foo="bar",
)
class TestAliasChoices:
"""Test that common Superset form_data aliases are accepted."""
def test_xy_stack_alias_for_stacked(self) -> None:
"""Test that 'stack' is accepted as alias for 'stacked'."""
config = XYChartConfig.model_validate(
{
"chart_type": "xy",
"x": {"name": "category"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"stack": True,
}
) )
assert len(config.columns) == 1 assert config.stacked is True
assert not hasattr(config, "foo")
def test_xy_stacked_still_works(self) -> None:
"""Test that 'stacked' still works as primary field name."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="category"),
y=[ColumnRef(name="sales", aggregate="SUM")],
stacked=True,
)
assert config.stacked is True
def test_xy_time_grain_sqla_alias(self) -> None:
"""Test that 'time_grain_sqla' is accepted as alias for 'time_grain'."""
config = XYChartConfig.model_validate(
{
"chart_type": "xy",
"x": {"name": "order_date"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"time_grain_sqla": "P1D",
}
)
assert config.time_grain is not None
def test_table_order_by_alias_for_sort_by(self) -> None:
"""Test that 'order_by' is accepted as alias for 'sort_by'."""
config = TableChartConfig.model_validate(
{
"chart_type": "table",
"columns": [{"name": "product"}],
"order_by": ["product"],
}
)
assert config.sort_by == ["product"]
def test_mixed_timeseries_time_grain_sqla_alias(self) -> None:
"""Test that 'time_grain_sqla' works for MixedTimeseriesChartConfig."""
from superset.mcp_service.chart.schemas import MixedTimeseriesChartConfig
config = MixedTimeseriesChartConfig.model_validate(
{
"chart_type": "mixed_timeseries",
"x": {"name": "order_date"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"y_secondary": [{"name": "profit", "aggregate": "SUM"}],
"time_grain_sqla": "P1M",
}
)
assert config.time_grain is not None
class TestUnknownFieldDetection:
"""Test that unknown fields produce helpful error messages."""
def test_near_miss_suggests_correct_field(self) -> None:
"""Test that a near-miss field name produces 'did you mean?' suggestion."""
with pytest.raises(ValidationError, match="did you mean"):
XYChartConfig.model_validate(
{
"chart_type": "xy",
"x": {"name": "category"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"stacks": True,
}
)
def test_completely_unknown_field_lists_valid_fields(self) -> None:
"""Test that a completely unknown field lists valid fields."""
with pytest.raises(ValidationError, match="Valid fields:"):
XYChartConfig.model_validate(
{
"chart_type": "xy",
"x": {"name": "category"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"zzz_nonexistent": True,
}
)
def test_pie_chart_unknown_field(self) -> None:
"""Test unknown field detection on PieChartConfig."""
from superset.mcp_service.chart.schemas import PieChartConfig
with pytest.raises(ValidationError, match="Unknown field"):
PieChartConfig.model_validate(
{
"chart_type": "pie",
"dimension": {"name": "category"},
"metric": {"name": "sales", "aggregate": "SUM"},
"bad_field": True,
}
)
def test_table_chart_unknown_field(self) -> None:
"""Test unknown field detection on TableChartConfig."""
with pytest.raises(ValidationError, match="Unknown field"):
TableChartConfig.model_validate(
{
"chart_type": "table",
"columns": [{"name": "product"}],
"invalid_param": "test",
}
)
def test_known_aliases_not_flagged_as_unknown(self) -> None:
"""Test that known aliases pass validation without errors."""
config = XYChartConfig.model_validate(
{
"chart_type": "xy",
"x_axis": {"name": "category"},
"metrics": [{"name": "sales", "aggregate": "SUM"}],
"groupby": [{"name": "region"}],
"stack": True,
"time_grain_sqla": "P1D",
}
)
assert config.stacked is True
assert config.row_limit == 10000
assert config.group_by is not None

View File

@@ -119,7 +119,7 @@ class TestHandlebarsChartConfig:
) )
def test_extra_fields_forbidden(self) -> None: def test_extra_fields_forbidden(self) -> None:
with pytest.raises(ValueError, match="Extra inputs"): with pytest.raises(ValueError, match="Unknown field 'unknown_field'"):
HandlebarsChartConfig( HandlebarsChartConfig(
chart_type="handlebars", chart_type="handlebars",
handlebars_template="<p>test</p>", handlebars_template="<p>test</p>",

View File

@@ -103,15 +103,14 @@ class TestPieChartConfigSchema:
assert config.filters is not None assert config.filters is not None
assert len(config.filters) == 1 assert len(config.filters) == 1
def test_pie_config_ignores_extra_fields(self) -> None: def test_pie_config_rejects_extra_fields(self) -> None:
config = PieChartConfig( with pytest.raises(ValidationError, match="Unknown field"):
chart_type="pie", PieChartConfig(
dimension=ColumnRef(name="product"), chart_type="pie",
metric=ColumnRef(name="revenue", aggregate="SUM"), dimension=ColumnRef(name="product"),
unknown_field="bad", metric=ColumnRef(name="revenue", aggregate="SUM"),
) unknown_field="bad",
assert config.dimension.name == "product" )
assert not hasattr(config, "unknown_field")
def test_pie_config_missing_dimension(self) -> None: def test_pie_config_missing_dimension(self) -> None:
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
@@ -324,15 +323,14 @@ class TestPivotTableChartConfigSchema:
metrics=[ColumnRef(name="revenue", aggregate="SUM")], metrics=[ColumnRef(name="revenue", aggregate="SUM")],
) )
def test_pivot_table_ignores_extra_fields(self) -> None: def test_pivot_table_rejects_extra_fields(self) -> None:
config = PivotTableChartConfig( with pytest.raises(ValidationError, match="Unknown field"):
chart_type="pivot_table", PivotTableChartConfig(
rows=[ColumnRef(name="product")], chart_type="pivot_table",
metrics=[ColumnRef(name="revenue", aggregate="SUM")], rows=[ColumnRef(name="product")],
unknown_field="bad", metrics=[ColumnRef(name="revenue", aggregate="SUM")],
) unknown_field="bad",
assert config.rows[0].name == "product" )
assert not hasattr(config, "unknown_field")
def test_pivot_table_valid_aggregate_functions(self) -> None: def test_pivot_table_valid_aggregate_functions(self) -> None:
for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]: for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]:
@@ -504,16 +502,15 @@ class TestMixedTimeseriesChartConfigSchema:
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
) )
def test_mixed_timeseries_ignores_extra_fields(self) -> None: def test_mixed_timeseries_rejects_extra_fields(self) -> None:
config = MixedTimeseriesChartConfig( with pytest.raises(ValidationError, match="Unknown field"):
chart_type="mixed_timeseries", MixedTimeseriesChartConfig(
x=ColumnRef(name="date"), chart_type="mixed_timeseries",
y=[ColumnRef(name="revenue", aggregate="SUM")], x=ColumnRef(name="date"),
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], y=[ColumnRef(name="revenue", aggregate="SUM")],
unknown_field="bad", y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
) unknown_field="bad",
assert config.x.name == "date" )
assert not hasattr(config, "unknown_field")
def test_mixed_timeseries_default_row_limit(self) -> None: def test_mixed_timeseries_default_row_limit(self) -> None:
config = MixedTimeseriesChartConfig( config = MixedTimeseriesChartConfig(