mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
fix(mcp): detect unknown chart config fields and suggest correct ones (#38848)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
04e07acf98
commit
16f5a2a41a
@@ -159,8 +159,8 @@ CRITICAL RULES - NEVER VIOLATE:
|
||||
open_sql_lab_with_context, etc.) and use the URL it returns.
|
||||
- To modify an existing chart's filters, metrics, or dimensions, use update_chart.
|
||||
Do NOT use execute_sql for chart modifications.
|
||||
- Parameter name reminders: open_sql_lab_with_context uses "sql" (not "query"),
|
||||
execute_sql uses "sql" (not "query").
|
||||
- Parameter name reminders: ALWAYS use the EXACT parameter names from the tool schema.
|
||||
Do NOT use Superset's internal form_data names.
|
||||
|
||||
IMPORTANT - Tool-Only Interaction:
|
||||
- Do NOT generate code artifacts, HTML pages, JavaScript snippets, or any code intended
|
||||
|
||||
@@ -21,11 +21,13 @@ Pydantic schemas for chart-related responses
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Dict, List, Literal, Protocol
|
||||
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
AliasPath,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
@@ -395,6 +397,67 @@ def _normalize_group_by_input(v: Any) -> Any:
|
||||
return v
|
||||
|
||||
|
||||
def _top_level_key(alias: str | AliasPath) -> str | None:
|
||||
"""Extract the top-level dict key from a str or AliasPath."""
|
||||
if isinstance(alias, str):
|
||||
return alias
|
||||
if isinstance(alias, AliasPath) and alias.path and isinstance(alias.path[0], str):
|
||||
return alias.path[0]
|
||||
return None
|
||||
|
||||
|
||||
def _get_known_fields(model_class: type[BaseModel]) -> set[str]:
|
||||
"""Collect all valid field names including validation aliases."""
|
||||
known: set[str] = set()
|
||||
for field_name, field_info in model_class.model_fields.items():
|
||||
known.add(field_name)
|
||||
alias = field_info.validation_alias
|
||||
if isinstance(alias, AliasChoices):
|
||||
for choice in alias.choices:
|
||||
key = _top_level_key(choice)
|
||||
if key:
|
||||
known.add(key)
|
||||
elif alias is not None:
|
||||
key = _top_level_key(alias)
|
||||
if key:
|
||||
known.add(key)
|
||||
return known
|
||||
|
||||
|
||||
def _check_unknown_fields(data: Any, model_class: type[BaseModel]) -> Any:
|
||||
"""Raise ValueError for unrecognized fields with 'did you mean?' suggestions.
|
||||
|
||||
Catches fields that would be silently dropped by extra='ignore' and provides
|
||||
actionable error messages to help LLMs self-correct parameter names.
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
known = _get_known_fields(model_class)
|
||||
unknown = set(data.keys()) - known
|
||||
if not unknown:
|
||||
return data
|
||||
|
||||
messages = []
|
||||
for field in sorted(unknown):
|
||||
matches = difflib.get_close_matches(field, sorted(known), n=1, cutoff=0.6)
|
||||
if matches:
|
||||
messages.append(f"Unknown field '{field}' — did you mean '{matches[0]}'?")
|
||||
else:
|
||||
messages.append(
|
||||
f"Unknown field '{field}'. Valid fields: {', '.join(sorted(known))}"
|
||||
)
|
||||
raise ValueError(" | ".join(messages))
|
||||
|
||||
|
||||
class UnknownFieldCheckMixin(BaseModel):
|
||||
"""Mixin that rejects unknown fields with 'did you mean?' suggestions."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_unknown_fields(cls, data: Any) -> Any:
|
||||
return _check_unknown_fields(data, cls)
|
||||
|
||||
|
||||
class ColumnRef(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@@ -519,7 +582,7 @@ class FilterConfig(BaseModel):
|
||||
|
||||
|
||||
# Actual chart types
|
||||
class PieChartConfig(BaseModel):
|
||||
class PieChartConfig(UnknownFieldCheckMixin):
|
||||
model_config = ConfigDict(extra="ignore", populate_by_name=True)
|
||||
|
||||
chart_type: Literal["pie"] = "pie"
|
||||
@@ -559,7 +622,7 @@ class PieChartConfig(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class PivotTableChartConfig(BaseModel):
|
||||
class PivotTableChartConfig(UnknownFieldCheckMixin):
|
||||
model_config = ConfigDict(extra="ignore", populate_by_name=True)
|
||||
|
||||
chart_type: Literal["pivot_table"] = "pivot_table"
|
||||
@@ -603,7 +666,7 @@ class PivotTableChartConfig(BaseModel):
|
||||
value_format: str = Field("SMART_NUMBER", max_length=50)
|
||||
|
||||
|
||||
class MixedTimeseriesChartConfig(BaseModel):
|
||||
class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
|
||||
model_config = ConfigDict(extra="ignore", populate_by_name=True)
|
||||
|
||||
chart_type: Literal["mixed_timeseries"] = "mixed_timeseries"
|
||||
@@ -612,7 +675,11 @@ class MixedTimeseriesChartConfig(BaseModel):
|
||||
description="Shared temporal X-axis column",
|
||||
validation_alias=AliasChoices("x", "x_axis"),
|
||||
)
|
||||
time_grain: TimeGrain | None = Field(None, description="PT1H, P1D, P1W, P1M, P1Y")
|
||||
time_grain: TimeGrain | None = Field(
|
||||
None,
|
||||
description="PT1H, P1D, P1W, P1M, P1Y",
|
||||
validation_alias=AliasChoices("time_grain", "time_grain_sqla"),
|
||||
)
|
||||
# Primary series (Query A)
|
||||
y: List[ColumnRef] = Field(
|
||||
...,
|
||||
@@ -659,8 +726,8 @@ class MixedTimeseriesChartConfig(BaseModel):
|
||||
return _normalize_group_by_input(v)
|
||||
|
||||
|
||||
class HandlebarsChartConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
class HandlebarsChartConfig(UnknownFieldCheckMixin):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
chart_type: Literal["handlebars"] = Field(
|
||||
...,
|
||||
@@ -703,6 +770,7 @@ class HandlebarsChartConfig(BaseModel):
|
||||
"Columns to group by in aggregate mode (query_mode='aggregate'). "
|
||||
"These become the dimensions for aggregation."
|
||||
),
|
||||
validation_alias=AliasChoices("groupby", "group_by"),
|
||||
)
|
||||
metrics: list[ColumnRef] | None = Field(
|
||||
None,
|
||||
@@ -761,7 +829,7 @@ class HandlebarsChartConfig(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class TableChartConfig(BaseModel):
|
||||
class TableChartConfig(UnknownFieldCheckMixin):
|
||||
model_config = ConfigDict(extra="ignore", populate_by_name=True)
|
||||
|
||||
chart_type: Literal["table"] = "table"
|
||||
@@ -779,7 +847,10 @@ class TableChartConfig(BaseModel):
|
||||
description="Structured filters (column/op/value). "
|
||||
"Do NOT use adhoc_filters or raw SQL expressions.",
|
||||
)
|
||||
sort_by: List[str] | None = None
|
||||
sort_by: List[str] | None = Field(
|
||||
None,
|
||||
validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"),
|
||||
)
|
||||
row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -810,7 +881,7 @@ class TableChartConfig(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class XYChartConfig(BaseModel):
|
||||
class XYChartConfig(UnknownFieldCheckMixin):
|
||||
model_config = ConfigDict(extra="ignore", populate_by_name=True)
|
||||
|
||||
chart_type: Literal["xy"] = "xy"
|
||||
@@ -827,12 +898,17 @@ class XYChartConfig(BaseModel):
|
||||
)
|
||||
kind: Literal["line", "bar", "area", "scatter"] = "line"
|
||||
time_grain: TimeGrain | None = Field(
|
||||
None, description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y"
|
||||
None,
|
||||
description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y",
|
||||
validation_alias=AliasChoices("time_grain", "time_grain_sqla"),
|
||||
)
|
||||
orientation: Literal["vertical", "horizontal"] | None = Field(
|
||||
None, description="Bar orientation (only for kind='bar')"
|
||||
)
|
||||
stacked: bool = False
|
||||
stacked: bool = Field(
|
||||
False,
|
||||
validation_alias=AliasChoices("stacked", "stack"),
|
||||
)
|
||||
group_by: List[ColumnRef] | None = Field(
|
||||
None,
|
||||
description="Series breakdown columns",
|
||||
|
||||
@@ -33,6 +33,7 @@ class ExecuteSqlRequest(BaseModel):
|
||||
sql: str = Field(
|
||||
...,
|
||||
description="SQL query to execute (supports Jinja2 {{ var }} template syntax)",
|
||||
validation_alias=AliasChoices("sql", "query"),
|
||||
)
|
||||
schema_name: str | None = Field(
|
||||
None, description="Schema to use for query execution", alias="schema"
|
||||
|
||||
@@ -235,17 +235,16 @@ class TestXYChartConfig:
|
||||
)
|
||||
assert config.kind == "area"
|
||||
|
||||
def test_unknown_fields_ignored(self) -> None:
|
||||
"""Test that unknown fields are silently ignored (extra='ignore')."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="territory"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="bar",
|
||||
unknown_field="bad",
|
||||
)
|
||||
assert config.kind == "bar"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
def test_unknown_fields_raise_error(self) -> None:
|
||||
"""Test that unknown fields raise ValueError with suggestions."""
|
||||
with pytest.raises(ValidationError, match="Unknown field"):
|
||||
XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="territory"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="bar",
|
||||
unknown_field="bad",
|
||||
)
|
||||
|
||||
def test_series_alias_accepted(self) -> None:
|
||||
"""Test that 'series' is accepted as alias for 'group_by'."""
|
||||
@@ -430,12 +429,144 @@ class TestRowLimit:
|
||||
class TestTableChartConfigExtraFields:
|
||||
"""Test TableChartConfig rejects unknown fields."""
|
||||
|
||||
def test_unknown_fields_ignored(self) -> None:
|
||||
"""Test that unknown fields are silently ignored (extra='ignore')."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product")],
|
||||
foo="bar",
|
||||
def test_unknown_fields_raise_error(self) -> None:
|
||||
"""Test that unknown fields raise ValueError with valid field list."""
|
||||
with pytest.raises(ValidationError, match="Unknown field 'foo'"):
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product")],
|
||||
foo="bar",
|
||||
)
|
||||
|
||||
|
||||
class TestAliasChoices:
|
||||
"""Test that common Superset form_data aliases are accepted."""
|
||||
|
||||
def test_xy_stack_alias_for_stacked(self) -> None:
|
||||
"""Test that 'stack' is accepted as alias for 'stacked'."""
|
||||
config = XYChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "category"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"stack": True,
|
||||
}
|
||||
)
|
||||
assert len(config.columns) == 1
|
||||
assert not hasattr(config, "foo")
|
||||
assert config.stacked is True
|
||||
|
||||
def test_xy_stacked_still_works(self) -> None:
|
||||
"""Test that 'stacked' still works as primary field name."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="category"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
stacked=True,
|
||||
)
|
||||
assert config.stacked is True
|
||||
|
||||
def test_xy_time_grain_sqla_alias(self) -> None:
|
||||
"""Test that 'time_grain_sqla' is accepted as alias for 'time_grain'."""
|
||||
config = XYChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "order_date"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"time_grain_sqla": "P1D",
|
||||
}
|
||||
)
|
||||
assert config.time_grain is not None
|
||||
|
||||
def test_table_order_by_alias_for_sort_by(self) -> None:
|
||||
"""Test that 'order_by' is accepted as alias for 'sort_by'."""
|
||||
config = TableChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "product"}],
|
||||
"order_by": ["product"],
|
||||
}
|
||||
)
|
||||
assert config.sort_by == ["product"]
|
||||
|
||||
def test_mixed_timeseries_time_grain_sqla_alias(self) -> None:
|
||||
"""Test that 'time_grain_sqla' works for MixedTimeseriesChartConfig."""
|
||||
from superset.mcp_service.chart.schemas import MixedTimeseriesChartConfig
|
||||
|
||||
config = MixedTimeseriesChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "mixed_timeseries",
|
||||
"x": {"name": "order_date"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"y_secondary": [{"name": "profit", "aggregate": "SUM"}],
|
||||
"time_grain_sqla": "P1M",
|
||||
}
|
||||
)
|
||||
assert config.time_grain is not None
|
||||
|
||||
|
||||
class TestUnknownFieldDetection:
|
||||
"""Test that unknown fields produce helpful error messages."""
|
||||
|
||||
def test_near_miss_suggests_correct_field(self) -> None:
|
||||
"""Test that a near-miss field name produces 'did you mean?' suggestion."""
|
||||
with pytest.raises(ValidationError, match="did you mean"):
|
||||
XYChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "category"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"stacks": True,
|
||||
}
|
||||
)
|
||||
|
||||
def test_completely_unknown_field_lists_valid_fields(self) -> None:
|
||||
"""Test that a completely unknown field lists valid fields."""
|
||||
with pytest.raises(ValidationError, match="Valid fields:"):
|
||||
XYChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "category"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"zzz_nonexistent": True,
|
||||
}
|
||||
)
|
||||
|
||||
def test_pie_chart_unknown_field(self) -> None:
|
||||
"""Test unknown field detection on PieChartConfig."""
|
||||
from superset.mcp_service.chart.schemas import PieChartConfig
|
||||
|
||||
with pytest.raises(ValidationError, match="Unknown field"):
|
||||
PieChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "pie",
|
||||
"dimension": {"name": "category"},
|
||||
"metric": {"name": "sales", "aggregate": "SUM"},
|
||||
"bad_field": True,
|
||||
}
|
||||
)
|
||||
|
||||
def test_table_chart_unknown_field(self) -> None:
|
||||
"""Test unknown field detection on TableChartConfig."""
|
||||
with pytest.raises(ValidationError, match="Unknown field"):
|
||||
TableChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "product"}],
|
||||
"invalid_param": "test",
|
||||
}
|
||||
)
|
||||
|
||||
def test_known_aliases_not_flagged_as_unknown(self) -> None:
|
||||
"""Test that known aliases pass validation without errors."""
|
||||
config = XYChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "xy",
|
||||
"x_axis": {"name": "category"},
|
||||
"metrics": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"groupby": [{"name": "region"}],
|
||||
"stack": True,
|
||||
"time_grain_sqla": "P1D",
|
||||
}
|
||||
)
|
||||
assert config.stacked is True
|
||||
assert config.row_limit == 10000
|
||||
assert config.group_by is not None
|
||||
|
||||
@@ -119,7 +119,7 @@ class TestHandlebarsChartConfig:
|
||||
)
|
||||
|
||||
def test_extra_fields_forbidden(self) -> None:
|
||||
with pytest.raises(ValueError, match="Extra inputs"):
|
||||
with pytest.raises(ValueError, match="Unknown field 'unknown_field'"):
|
||||
HandlebarsChartConfig(
|
||||
chart_type="handlebars",
|
||||
handlebars_template="<p>test</p>",
|
||||
|
||||
@@ -103,15 +103,14 @@ class TestPieChartConfigSchema:
|
||||
assert config.filters is not None
|
||||
assert len(config.filters) == 1
|
||||
|
||||
def test_pie_config_ignores_extra_fields(self) -> None:
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
unknown_field="bad",
|
||||
)
|
||||
assert config.dimension.name == "product"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
def test_pie_config_rejects_extra_fields(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Unknown field"):
|
||||
PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
unknown_field="bad",
|
||||
)
|
||||
|
||||
def test_pie_config_missing_dimension(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
@@ -324,15 +323,14 @@ class TestPivotTableChartConfigSchema:
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
)
|
||||
|
||||
def test_pivot_table_ignores_extra_fields(self) -> None:
|
||||
config = PivotTableChartConfig(
|
||||
chart_type="pivot_table",
|
||||
rows=[ColumnRef(name="product")],
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
unknown_field="bad",
|
||||
)
|
||||
assert config.rows[0].name == "product"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
def test_pivot_table_rejects_extra_fields(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Unknown field"):
|
||||
PivotTableChartConfig(
|
||||
chart_type="pivot_table",
|
||||
rows=[ColumnRef(name="product")],
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
unknown_field="bad",
|
||||
)
|
||||
|
||||
def test_pivot_table_valid_aggregate_functions(self) -> None:
|
||||
for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]:
|
||||
@@ -504,16 +502,15 @@ class TestMixedTimeseriesChartConfigSchema:
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
)
|
||||
|
||||
def test_mixed_timeseries_ignores_extra_fields(self) -> None:
|
||||
config = MixedTimeseriesChartConfig(
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
unknown_field="bad",
|
||||
)
|
||||
assert config.x.name == "date"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
def test_mixed_timeseries_rejects_extra_fields(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Unknown field"):
|
||||
MixedTimeseriesChartConfig(
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
unknown_field="bad",
|
||||
)
|
||||
|
||||
def test_mixed_timeseries_default_row_limit(self) -> None:
|
||||
config = MixedTimeseriesChartConfig(
|
||||
|
||||
Reference in New Issue
Block a user