fix(mcp): detect unknown chart config fields and suggest correct ones (#38848)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kamil Gabryjelski
2026-03-25 18:38:23 +01:00
committed by GitHub
parent 04e07acf98
commit 16f5a2a41a
6 changed files with 266 additions and 61 deletions

View File

@@ -159,8 +159,8 @@ CRITICAL RULES - NEVER VIOLATE:
open_sql_lab_with_context, etc.) and use the URL it returns.
- To modify an existing chart's filters, metrics, or dimensions, use update_chart.
Do NOT use execute_sql for chart modifications.
- Parameter name reminders: open_sql_lab_with_context uses "sql" (not "query"),
execute_sql uses "sql" (not "query").
- Parameter name reminders: ALWAYS use the EXACT parameter names from the tool schema.
Do NOT use Superset's internal form_data names.
IMPORTANT - Tool-Only Interaction:
- Do NOT generate code artifacts, HTML pages, JavaScript snippets, or any code intended

View File

@@ -21,11 +21,13 @@ Pydantic schemas for chart-related responses
from __future__ import annotations
import difflib
from datetime import datetime, timezone
from typing import Annotated, Any, Dict, List, Literal, Protocol
from pydantic import (
AliasChoices,
AliasPath,
BaseModel,
ConfigDict,
Field,
@@ -395,6 +397,67 @@ def _normalize_group_by_input(v: Any) -> Any:
return v
def _top_level_key(alias: str | AliasPath) -> str | None:
"""Extract the top-level dict key from a str or AliasPath."""
if isinstance(alias, str):
return alias
if isinstance(alias, AliasPath) and alias.path and isinstance(alias.path[0], str):
return alias.path[0]
return None
def _get_known_fields(model_class: type[BaseModel]) -> set[str]:
"""Collect all valid field names including validation aliases."""
known: set[str] = set()
for field_name, field_info in model_class.model_fields.items():
known.add(field_name)
alias = field_info.validation_alias
if isinstance(alias, AliasChoices):
for choice in alias.choices:
key = _top_level_key(choice)
if key:
known.add(key)
elif alias is not None:
key = _top_level_key(alias)
if key:
known.add(key)
return known
def _check_unknown_fields(data: Any, model_class: type[BaseModel]) -> Any:
"""Raise ValueError for unrecognized fields with 'did you mean?' suggestions.
Catches fields that would be silently dropped by extra='ignore' and provides
actionable error messages to help LLMs self-correct parameter names.
"""
if not isinstance(data, dict):
return data
known = _get_known_fields(model_class)
unknown = set(data.keys()) - known
if not unknown:
return data
messages = []
for field in sorted(unknown):
matches = difflib.get_close_matches(field, sorted(known), n=1, cutoff=0.6)
if matches:
messages.append(f"Unknown field '{field}' — did you mean '{matches[0]}'?")
else:
messages.append(
f"Unknown field '{field}'. Valid fields: {', '.join(sorted(known))}"
)
raise ValueError(" | ".join(messages))
class UnknownFieldCheckMixin(BaseModel):
    """Base model whose subclasses reject unrecognised input keys.

    A ``mode="before"`` validator hands the raw input to
    ``_check_unknown_fields`` together with the concrete class being
    validated, so unknown keys raise with a 'did you mean?' hint.
    """

    @model_validator(mode="before")
    @classmethod
    def check_unknown_fields(cls, data: Any) -> Any:
        """Return ``data`` after checking its keys against ``cls``'s fields."""
        checked = _check_unknown_fields(data, cls)
        return checked
class ColumnRef(BaseModel):
model_config = ConfigDict(populate_by_name=True)
@@ -519,7 +582,7 @@ class FilterConfig(BaseModel):
# Actual chart types
class PieChartConfig(BaseModel):
class PieChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["pie"] = "pie"
@@ -559,7 +622,7 @@ class PieChartConfig(BaseModel):
)
class PivotTableChartConfig(BaseModel):
class PivotTableChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["pivot_table"] = "pivot_table"
@@ -603,7 +666,7 @@ class PivotTableChartConfig(BaseModel):
value_format: str = Field("SMART_NUMBER", max_length=50)
class MixedTimeseriesChartConfig(BaseModel):
class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["mixed_timeseries"] = "mixed_timeseries"
@@ -612,7 +675,11 @@ class MixedTimeseriesChartConfig(BaseModel):
description="Shared temporal X-axis column",
validation_alias=AliasChoices("x", "x_axis"),
)
time_grain: TimeGrain | None = Field(None, description="PT1H, P1D, P1W, P1M, P1Y")
time_grain: TimeGrain | None = Field(
None,
description="PT1H, P1D, P1W, P1M, P1Y",
validation_alias=AliasChoices("time_grain", "time_grain_sqla"),
)
# Primary series (Query A)
y: List[ColumnRef] = Field(
...,
@@ -659,8 +726,8 @@ class MixedTimeseriesChartConfig(BaseModel):
return _normalize_group_by_input(v)
class HandlebarsChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
class HandlebarsChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore")
chart_type: Literal["handlebars"] = Field(
...,
@@ -703,6 +770,7 @@ class HandlebarsChartConfig(BaseModel):
"Columns to group by in aggregate mode (query_mode='aggregate'). "
"These become the dimensions for aggregation."
),
validation_alias=AliasChoices("groupby", "group_by"),
)
metrics: list[ColumnRef] | None = Field(
None,
@@ -761,7 +829,7 @@ class HandlebarsChartConfig(BaseModel):
return self
class TableChartConfig(BaseModel):
class TableChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["table"] = "table"
@@ -779,7 +847,10 @@ class TableChartConfig(BaseModel):
description="Structured filters (column/op/value). "
"Do NOT use adhoc_filters or raw SQL expressions.",
)
sort_by: List[str] | None = None
sort_by: List[str] | None = Field(
None,
validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"),
)
row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000)
@model_validator(mode="after")
@@ -810,7 +881,7 @@ class TableChartConfig(BaseModel):
return self
class XYChartConfig(BaseModel):
class XYChartConfig(UnknownFieldCheckMixin):
model_config = ConfigDict(extra="ignore", populate_by_name=True)
chart_type: Literal["xy"] = "xy"
@@ -827,12 +898,17 @@ class XYChartConfig(BaseModel):
)
kind: Literal["line", "bar", "area", "scatter"] = "line"
time_grain: TimeGrain | None = Field(
None, description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y"
None,
description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y",
validation_alias=AliasChoices("time_grain", "time_grain_sqla"),
)
orientation: Literal["vertical", "horizontal"] | None = Field(
None, description="Bar orientation (only for kind='bar')"
)
stacked: bool = False
stacked: bool = Field(
False,
validation_alias=AliasChoices("stacked", "stack"),
)
group_by: List[ColumnRef] | None = Field(
None,
description="Series breakdown columns",

View File

@@ -33,6 +33,7 @@ class ExecuteSqlRequest(BaseModel):
sql: str = Field(
...,
description="SQL query to execute (supports Jinja2 {{ var }} template syntax)",
validation_alias=AliasChoices("sql", "query"),
)
schema_name: str | None = Field(
None, description="Schema to use for query execution", alias="schema"

View File

@@ -235,17 +235,16 @@ class TestXYChartConfig:
)
assert config.kind == "area"
def test_unknown_fields_ignored(self) -> None:
"""Test that unknown fields are silently ignored (extra='ignore')."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="territory"),
y=[ColumnRef(name="sales", aggregate="SUM")],
kind="bar",
unknown_field="bad",
)
assert config.kind == "bar"
assert not hasattr(config, "unknown_field")
def test_unknown_fields_raise_error(self) -> None:
"""Test that unknown fields raise ValueError with suggestions."""
with pytest.raises(ValidationError, match="Unknown field"):
XYChartConfig(
chart_type="xy",
x=ColumnRef(name="territory"),
y=[ColumnRef(name="sales", aggregate="SUM")],
kind="bar",
unknown_field="bad",
)
def test_series_alias_accepted(self) -> None:
"""Test that 'series' is accepted as alias for 'group_by'."""
@@ -430,12 +429,144 @@ class TestRowLimit:
class TestTableChartConfigExtraFields:
"""Test TableChartConfig rejects unknown fields."""
def test_unknown_fields_ignored(self) -> None:
"""Test that unknown fields are silently ignored (extra='ignore')."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product")],
foo="bar",
def test_unknown_fields_raise_error(self) -> None:
"""Test that unknown fields raise ValueError with valid field list."""
with pytest.raises(ValidationError, match="Unknown field 'foo'"):
TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product")],
foo="bar",
)
class TestAliasChoices:
"""Test that common Superset form_data aliases are accepted."""
def test_xy_stack_alias_for_stacked(self) -> None:
"""Test that 'stack' is accepted as alias for 'stacked'."""
config = XYChartConfig.model_validate(
{
"chart_type": "xy",
"x": {"name": "category"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"stack": True,
}
)
assert len(config.columns) == 1
assert not hasattr(config, "foo")
assert config.stacked is True
def test_xy_stacked_still_works(self) -> None:
"""Test that 'stacked' still works as primary field name."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="category"),
y=[ColumnRef(name="sales", aggregate="SUM")],
stacked=True,
)
assert config.stacked is True
def test_xy_time_grain_sqla_alias(self) -> None:
"""Test that 'time_grain_sqla' is accepted as alias for 'time_grain'."""
config = XYChartConfig.model_validate(
{
"chart_type": "xy",
"x": {"name": "order_date"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"time_grain_sqla": "P1D",
}
)
assert config.time_grain is not None
def test_table_order_by_alias_for_sort_by(self) -> None:
"""Test that 'order_by' is accepted as alias for 'sort_by'."""
config = TableChartConfig.model_validate(
{
"chart_type": "table",
"columns": [{"name": "product"}],
"order_by": ["product"],
}
)
assert config.sort_by == ["product"]
def test_mixed_timeseries_time_grain_sqla_alias(self) -> None:
"""Test that 'time_grain_sqla' works for MixedTimeseriesChartConfig."""
from superset.mcp_service.chart.schemas import MixedTimeseriesChartConfig
config = MixedTimeseriesChartConfig.model_validate(
{
"chart_type": "mixed_timeseries",
"x": {"name": "order_date"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"y_secondary": [{"name": "profit", "aggregate": "SUM"}],
"time_grain_sqla": "P1M",
}
)
assert config.time_grain is not None
class TestUnknownFieldDetection:
    """Unknown config keys must fail validation with an actionable message."""

    def test_near_miss_suggests_correct_field(self) -> None:
        """A key close to a real one ('stacks') triggers a 'did you mean' hint."""
        payload = {
            "chart_type": "xy",
            "x": {"name": "category"},
            "y": [{"name": "sales", "aggregate": "SUM"}],
            "stacks": True,
        }
        with pytest.raises(ValidationError, match="did you mean"):
            XYChartConfig.model_validate(payload)

    def test_completely_unknown_field_lists_valid_fields(self) -> None:
        """A key with no close match falls back to listing the valid fields."""
        payload = {
            "chart_type": "xy",
            "x": {"name": "category"},
            "y": [{"name": "sales", "aggregate": "SUM"}],
            "zzz_nonexistent": True,
        }
        with pytest.raises(ValidationError, match="Valid fields:"):
            XYChartConfig.model_validate(payload)

    def test_pie_chart_unknown_field(self) -> None:
        """PieChartConfig reports an unknown key."""
        from superset.mcp_service.chart.schemas import PieChartConfig

        payload = {
            "chart_type": "pie",
            "dimension": {"name": "category"},
            "metric": {"name": "sales", "aggregate": "SUM"},
            "bad_field": True,
        }
        with pytest.raises(ValidationError, match="Unknown field"):
            PieChartConfig.model_validate(payload)

    def test_table_chart_unknown_field(self) -> None:
        """TableChartConfig reports an unknown key."""
        payload = {
            "chart_type": "table",
            "columns": [{"name": "product"}],
            "invalid_param": "test",
        }
        with pytest.raises(ValidationError, match="Unknown field"):
            TableChartConfig.model_validate(payload)

    def test_known_aliases_not_flagged_as_unknown(self) -> None:
        """Input written entirely with accepted aliases validates cleanly."""
        payload = {
            "chart_type": "xy",
            "x_axis": {"name": "category"},
            "metrics": [{"name": "sales", "aggregate": "SUM"}],
            "groupby": [{"name": "region"}],
            "stack": True,
            "time_grain_sqla": "P1D",
        }
        parsed = XYChartConfig.model_validate(payload)
        assert parsed.stacked is True
        assert parsed.row_limit == 10000
        assert parsed.group_by is not None

View File

@@ -119,7 +119,7 @@ class TestHandlebarsChartConfig:
)
def test_extra_fields_forbidden(self) -> None:
with pytest.raises(ValueError, match="Extra inputs"):
with pytest.raises(ValueError, match="Unknown field 'unknown_field'"):
HandlebarsChartConfig(
chart_type="handlebars",
handlebars_template="<p>test</p>",

View File

@@ -103,15 +103,14 @@ class TestPieChartConfigSchema:
assert config.filters is not None
assert len(config.filters) == 1
def test_pie_config_ignores_extra_fields(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
unknown_field="bad",
)
assert config.dimension.name == "product"
assert not hasattr(config, "unknown_field")
def test_pie_config_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError, match="Unknown field"):
PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
unknown_field="bad",
)
def test_pie_config_missing_dimension(self) -> None:
with pytest.raises(ValidationError):
@@ -324,15 +323,14 @@ class TestPivotTableChartConfigSchema:
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
def test_pivot_table_ignores_extra_fields(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
unknown_field="bad",
)
assert config.rows[0].name == "product"
assert not hasattr(config, "unknown_field")
def test_pivot_table_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError, match="Unknown field"):
PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
unknown_field="bad",
)
def test_pivot_table_valid_aggregate_functions(self) -> None:
for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]:
@@ -504,16 +502,15 @@ class TestMixedTimeseriesChartConfigSchema:
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
def test_mixed_timeseries_ignores_extra_fields(self) -> None:
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
unknown_field="bad",
)
assert config.x.name == "date"
assert not hasattr(config, "unknown_field")
def test_mixed_timeseries_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError, match="Unknown field"):
MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
unknown_field="bad",
)
def test_mixed_timeseries_default_row_limit(self) -> None:
config = MixedTimeseriesChartConfig(