refactor(mcp): address Codex review — fix registry bug, DRY schema hints, remove column regex

P1.1 registry.py: move _plugins_loaded=True to after successful import so a
failed load doesn't permanently poison the registry.

P1.3 schemas.py: remove the overly restrictive regex pattern on ColumnRef.name,
FilterConfig.column, and BigNumberChartConfig.temporal_column, which rejected
valid column names containing parentheses, slashes, and other characters that
are common in SQL.

P2.3 (DRY): eliminate the _CHART_TYPE_ERROR_HINTS dict, a second registry kept
in schema_validator.py, by adding schema_error_hint() to the ChartTypePlugin
protocol, a None-returning default on BaseChartPlugin, and an override on all
7 plugin classes. SchemaValidator now delegates to the plugin registry instead
of maintaining a parallel dict.

P3.3 test_registry.py: add full registry unit-test coverage (register, get,
all_types, is_registered, display_name_for_viz_type, proxy methods, duplicate
warning, empty chart_type validation, insertion-order guarantee).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-05-13 21:27:39 +00:00
parent f3a30af324
commit c4be22e801
12 changed files with 297 additions and 117 deletions

12
superset/mcp_service/chart/plugin.py Normal file → Executable file
View File

@@ -172,6 +172,15 @@ class ChartTypePlugin(Protocol):
"""
...
def schema_error_hint(self) -> "ChartGenerationError | None":
    """
    Return a user-friendly error for Pydantic discriminated-union parse failures.

    Called by SchemaValidator when Pydantic cannot parse the config union and
    the chart_type is known.

    Returns:
        A ChartGenerationError describing what a valid configuration for this
        chart type looks like, or None to make the caller fall back to the
        generic error.
    """
    ...
class BaseChartPlugin:
"""
@@ -237,6 +246,9 @@ class BaseChartPlugin:
def resolve_viz_type(self, config: Any) -> str:
    """
    Return the viz type for *config*.

    This base implementation ignores *config* and always returns the
    placeholder string "unknown".
    NOTE(review): presumably concrete plugins override this - confirm.
    """
    return "unknown"
def schema_error_hint(self) -> ChartGenerationError | None:
    """
    Return the chart-type-specific hint for a config that failed to parse.

    The base implementation has no hint to offer and returns None.
    """
    return None
@staticmethod
def _with_context(what: str, context: str | None) -> str:
"""Combine a 'what' label and optional context with an en-dash."""

17
superset/mcp_service/chart/plugins/big_number.py Normal file → Executable file
View File

@@ -201,3 +201,20 @@ class BigNumberChartPlugin(BaseChartPlugin):
)
DatasetValidator._normalize_filters(config_dict, dataset_context)
return BigNumberChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
    """
    Build the Big Number-specific error for a config that failed to parse.

    Returns:
        A ChartGenerationError with error code BIG_NUMBER_VALIDATION_ERROR
        whose suggestions cover the 'metric' field and the optional
        trendline settings.
    """
    explanation = (
        "The Big Number chart configuration is missing required "
        "fields or has invalid structure"
    )
    remedies = [
        "Ensure 'metric' field has 'name' and 'aggregate'",
        "Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}",
        "For trendline: add show_trendline=true and temporal_column='col'",
        "Without trendline: just provide the metric",
    ]
    return ChartGenerationError(
        error_type="big_number_validation_error",
        message="Big Number chart configuration validation failed",
        details=explanation,
        suggestions=remedies,
        error_code="BIG_NUMBER_VALIDATION_ERROR",
    )

20
superset/mcp_service/chart/plugins/handlebars.py Normal file → Executable file
View File

@@ -167,3 +167,23 @@ class HandlebarsChartPlugin(BaseChartPlugin):
_norm_list("groupby")
DatasetValidator._normalize_filters(config_dict, dataset_context)
return HandlebarsChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
    """
    Build the handlebars-specific error for a config that failed to parse.

    Returns:
        A ChartGenerationError with error code HANDLEBARS_VALIDATION_ERROR
        whose suggestions cover the template field, the aggregate and raw
        query modes, and a complete example config.
    """
    explanation = (
        "The handlebars chart configuration is missing "
        "required fields or has invalid structure"
    )
    # A full sample config, assembled from fragments to keep lines short.
    sample_config = (
        "Example: {'chart_type': 'handlebars', "
        "'handlebars_template': "
        "'<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>', "
        "'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}"
    )
    remedies = [
        "Ensure 'handlebars_template' is a non-empty string",
        "For aggregate mode: add 'metrics' with aggregate functions",
        "For raw mode: set 'query_mode': 'raw' and add 'columns'",
        sample_config,
    ]
    return ChartGenerationError(
        error_type="handlebars_validation_error",
        message="Handlebars chart configuration validation failed",
        details=explanation,
        suggestions=remedies,
        error_code="HANDLEBARS_VALIDATION_ERROR",
    )

20
superset/mcp_service/chart/plugins/mixed_timeseries.py Normal file → Executable file
View File

@@ -143,3 +143,23 @@ class MixedTimeseriesChartPlugin(BaseChartPlugin):
_norm_list("group_by_secondary")
DatasetValidator._normalize_filters(config_dict, dataset_context)
return MixedTimeseriesChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
    """
    Build the mixed-timeseries error for a config that failed to parse.

    Returns:
        A ChartGenerationError with error code
        MIXED_TIMESERIES_VALIDATION_ERROR whose suggestions cover the 'x',
        'y' and 'y_secondary' fields and a complete example config.
    """
    explanation = (
        "The mixed timeseries configuration is missing "
        "required fields or has invalid structure"
    )
    # A full sample config, assembled from fragments to keep lines short.
    sample_config = (
        "Example: {'chart_type': 'mixed_timeseries', "
        "'x': {'name': 'order_date'}, "
        "'y': [{'name': 'revenue', 'aggregate': 'SUM'}], "
        "'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}"
    )
    remedies = [
        "Ensure 'x' field has 'name' for the time axis column",
        "Ensure 'y' is an array of primary-axis metrics",
        "Ensure 'y_secondary' is an array of secondary-axis metrics",
        sample_config,
    ]
    return ChartGenerationError(
        error_type="mixed_timeseries_validation_error",
        message="Mixed timeseries chart configuration validation failed",
        details=explanation,
        suggestions=remedies,
        error_code="MIXED_TIMESERIES_VALIDATION_ERROR",
    )

17
superset/mcp_service/chart/plugins/pie.py Normal file → Executable file
View File

@@ -109,3 +109,20 @@ class PieChartPlugin(BaseChartPlugin):
)
DatasetValidator._normalize_filters(config_dict, dataset_context)
return PieChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
    """
    Build the pie-chart-specific error for a config that failed to parse.

    Returns:
        A ChartGenerationError with error code PIE_VALIDATION_ERROR whose
        suggestions cover the 'dimension' and 'metric' fields and a complete
        example config.
    """
    explanation = (
        "The pie chart configuration is missing required "
        "fields or has invalid structure"
    )
    # A full sample config, assembled from fragments to keep lines short.
    sample_config = (
        "Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, "
        "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}"
    )
    remedies = [
        "Ensure 'dimension' field has 'name' for the slice label",
        "Ensure 'metric' field has 'name' and 'aggregate'",
        sample_config,
    ]
    return ChartGenerationError(
        error_type="pie_validation_error",
        message="Pie chart configuration validation failed",
        details=explanation,
        suggestions=remedies,
        error_code="PIE_VALIDATION_ERROR",
    )

19
superset/mcp_service/chart/plugins/pivot_table.py Normal file → Executable file
View File

@@ -132,3 +132,22 @@ class PivotTableChartPlugin(BaseChartPlugin):
_norm_col_list("columns")
DatasetValidator._normalize_filters(config_dict, dataset_context)
return PivotTableChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
    """
    Build the pivot-table-specific error for a config that failed to parse.

    Returns:
        A ChartGenerationError with error code PIVOT_TABLE_VALIDATION_ERROR
        whose suggestions cover the 'rows', 'metrics' and optional 'columns'
        fields and a complete example config.
    """
    explanation = (
        "The pivot table configuration is missing required "
        "fields or has invalid structure"
    )
    # A full sample config, assembled from fragments to keep lines short.
    sample_config = (
        "Example: {'chart_type': 'pivot_table', "
        "'rows': [{'name': 'region'}], "
        "'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}"
    )
    remedies = [
        "Ensure 'rows' field is an array of column specs",
        "Ensure 'metrics' field is an array with aggregate funcs",
        "Optional: add 'columns' for column grouping",
        sample_config,
    ]
    return ChartGenerationError(
        error_type="pivot_table_validation_error",
        message="Pivot table configuration validation failed",
        details=explanation,
        suggestions=remedies,
        error_code="PIVOT_TABLE_VALIDATION_ERROR",
    )

18
superset/mcp_service/chart/plugins/table.py Normal file → Executable file
View File

@@ -108,3 +108,21 @@ class TableChartPlugin(BaseChartPlugin):
DatasetValidator._normalize_filters(config_dict, dataset_context)
return TableChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
    """
    Build the table-chart-specific error for a config that failed to parse.

    Returns:
        A ChartGenerationError with error code TABLE_VALIDATION_ERROR whose
        suggestions describe the expected shape of the 'columns' field.
    """
    explanation = (
        "The table chart configuration is missing required "
        "fields or has invalid structure"
    )
    # Sample 'columns' value, assembled from fragments to keep lines short.
    sample_columns = (
        "Example: 'columns': [{'name': 'product'}, "
        "{'name': 'sales', 'aggregate': 'SUM'}]"
    )
    remedies = [
        "Ensure 'columns' field is an array of column specifications",
        "Each column needs {'name': 'column_name'}",
        "Optional: add 'aggregate' for metrics",
        sample_columns,
    ]
    return ChartGenerationError(
        error_type="table_validation_error",
        message="Table chart configuration validation failed",
        details=explanation,
        suggestions=remedies,
        error_code="TABLE_VALIDATION_ERROR",
    )

18
superset/mcp_service/chart/plugins/xy.py Normal file → Executable file
View File

@@ -172,3 +172,21 @@ class XYChartPlugin(BaseChartPlugin):
logger.warning("XY cardinality validation failed: %s", exc)
return warnings
def schema_error_hint(self) -> ChartGenerationError | None:
    """
    Build the XY-chart-specific error for a config that failed to parse.

    Returns:
        A ChartGenerationError with error code XY_VALIDATION_ERROR whose
        suggestions cover the optional 'x' field, the 'y' array, column-name
        types and the accepted aggregate functions.
    """
    explanation = (
        "The XY chart configuration is missing required "
        "fields or has invalid structure"
    )
    x_axis_note = (
        "Note: 'x' is optional and defaults to the dataset's primary "
        "datetime column"
    )
    remedies = [
        x_axis_note,
        "Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]",
        "Check that all column names are strings",
        "Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX",
    ]
    return ChartGenerationError(
        error_type="xy_validation_error",
        message="XY chart configuration validation failed",
        details=explanation,
        suggestions=remedies,
        error_code="XY_VALIDATION_ERROR",
    )

View File

@@ -62,8 +62,12 @@ def _ensure_plugins_loaded() -> None:
return
with _plugins_lock:
if not _plugins_loaded:
_plugins_loaded = True
import superset.mcp_service.chart.plugins # noqa: F401
try:
import superset.mcp_service.chart.plugins # noqa: F401
_plugins_loaded = True
except Exception:
logger.exception("Failed to load built-in chart type plugins")
def register(plugin: "ChartTypePlugin") -> None:

3
superset/mcp_service/chart/schemas.py Normal file → Executable file
View File

@@ -680,7 +680,6 @@ class ColumnRef(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("name", "column_name"),
)
label: str | None = Field(None, max_length=500)
@@ -754,7 +753,6 @@ class FilterConfig(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("column", "col"),
)
op: Literal[
@@ -1117,7 +1115,6 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
time_grain: TimeGrain | None = Field(
None,

119
superset/mcp_service/chart/validation/schema_validator.py Normal file → Executable file
View File

@@ -186,115 +186,6 @@ class SchemaValidator:
return False, error
return True, None
# Per-chart-type error details used by _enhance_validation_error.
# Keyed by chart_type discriminator value.
# NOTE: Keep this dict in sync with the plugin registry in
# superset/mcp_service/chart/plugins/ — each registered chart_type must
# have a corresponding entry here so Pydantic parse errors produce
# helpful, type-specific messages.
_CHART_TYPE_ERROR_HINTS: Dict[str, Dict[str, Any]] = {
"xy": {
"error_type": "xy_validation_error",
"message": "XY chart configuration validation failed",
"details": "The XY chart configuration is missing required "
"fields or has invalid structure",
"suggestions": [
"Note: 'x' is optional and defaults to the dataset's primary "
"datetime column",
"Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]",
"Check that all column names are strings",
"Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX",
],
"error_code": "XY_VALIDATION_ERROR",
},
"table": {
"error_type": "table_validation_error",
"message": "Table chart configuration validation failed",
"details": "The table chart configuration is missing required "
"fields or has invalid structure",
"suggestions": [
"Ensure 'columns' field is an array of column specifications",
"Each column needs {'name': 'column_name'}",
"Optional: add 'aggregate' for metrics",
"Example: 'columns': [{'name': 'product'}, "
"{'name': 'sales', 'aggregate': 'SUM'}]",
],
"error_code": "TABLE_VALIDATION_ERROR",
},
"pie": {
"error_type": "pie_validation_error",
"message": "Pie chart configuration validation failed",
"details": "The pie chart configuration is missing required "
"fields or has invalid structure",
"suggestions": [
"Ensure 'dimension' field has 'name' for the slice label",
"Ensure 'metric' field has 'name' and 'aggregate'",
"Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, "
"'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
],
"error_code": "PIE_VALIDATION_ERROR",
},
"pivot_table": {
"error_type": "pivot_table_validation_error",
"message": "Pivot table configuration validation failed",
"details": "The pivot table configuration is missing required "
"fields or has invalid structure",
"suggestions": [
"Ensure 'rows' field is an array of column specs",
"Ensure 'metrics' field is an array with aggregate funcs",
"Optional: add 'columns' for column grouping",
"Example: {'chart_type': 'pivot_table', 'rows': [{'name': 'region'}], "
"'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}",
],
"error_code": "PIVOT_TABLE_VALIDATION_ERROR",
},
"mixed_timeseries": {
"error_type": "mixed_timeseries_validation_error",
"message": "Mixed timeseries chart configuration validation failed",
"details": "The mixed timeseries configuration is missing "
"required fields or has invalid structure",
"suggestions": [
"Ensure 'x' field has 'name' for the time axis column",
"Ensure 'y' is an array of primary-axis metrics",
"Ensure 'y_secondary' is an array of secondary-axis metrics",
"Example: {'chart_type': 'mixed_timeseries', "
"'x': {'name': 'order_date'}, "
"'y': [{'name': 'revenue', 'aggregate': 'SUM'}], "
"'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}",
],
"error_code": "MIXED_TIMESERIES_VALIDATION_ERROR",
},
"handlebars": {
"error_type": "handlebars_validation_error",
"message": "Handlebars chart configuration validation failed",
"details": "The handlebars chart configuration is missing "
"required fields or has invalid structure",
"suggestions": [
"Ensure 'handlebars_template' is a non-empty string",
"For aggregate mode: add 'metrics' with aggregate functions",
"For raw mode: set 'query_mode': 'raw' and add 'columns'",
"Example: {'chart_type': 'handlebars', "
"'handlebars_template': '<ul>{{#each data}}<li>"
"{{this.name}}</li>{{/each}}</ul>', "
"'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
],
"error_code": "HANDLEBARS_VALIDATION_ERROR",
},
"big_number": {
"error_type": "big_number_validation_error",
"message": "Big Number chart configuration validation failed",
"details": "The Big Number chart configuration is missing required "
"fields or has invalid structure",
"suggestions": [
"Ensure 'metric' field has 'name' and 'aggregate'",
"Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}",
"For trendline: add show_trendline=true and temporal_column='col'",
"Without trendline: just provide the metric",
],
"error_code": "BIG_NUMBER_VALIDATION_ERROR",
},
}
@staticmethod
def _enhance_validation_error(
error: PydanticValidationError, request_data: Dict[str, Any]
@@ -307,10 +198,14 @@ class SchemaValidator:
if err.get("type") == "union_tag_invalid" or "discriminator" in str(
err.get("ctx", {})
):
from superset.mcp_service.chart.registry import get_registry
chart_type = request_data.get("config", {}).get("chart_type", "")
hint = SchemaValidator._CHART_TYPE_ERROR_HINTS.get(chart_type)
if hint:
return ChartGenerationError(**hint)
plugin = get_registry().get(chart_type)
if plugin is not None:
hint = plugin.schema_error_hint()
if hint is not None:
return hint
# Default enhanced error
error_details = []

View File

@@ -0,0 +1,143 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the chart type plugin registry."""
import pytest
import superset.mcp_service.chart.registry as registry_module
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.registry import (
_RegistryProxy,
all_types,
display_name_for_viz_type,
get,
get_registry,
is_registered,
register,
)
@pytest.fixture(autouse=True)
def _isolated_registry(monkeypatch):
    """
    Give every test an empty registry and restore the real one afterwards.

    `_REGISTRY` is swapped for a fresh dict, and `_plugins_loaded` is forced
    to True so the registry module treats its plugins as already loaded.
    """
    patched_state = {"_REGISTRY": {}, "_plugins_loaded": True}
    for attr_name, replacement in patched_state.items():
        monkeypatch.setattr(registry_module, attr_name, replacement)
class _FakePlugin(BaseChartPlugin):
    """Minimal plugin stub the registry tests register under "fake"."""

    # Key the tests use to register and look up this plugin.
    chart_type = "fake"
    display_name = "Fake Chart"
    # Native viz type -> display name, read back via display_name_for_viz_type.
    native_viz_types = {"fake_viz": "Fake Viz"}
class _AnotherPlugin(BaseChartPlugin):
    """Second plugin stub, used by tests that need two registered plugins."""

    # Key the tests use to register and look up this plugin.
    chart_type = "another"
    display_name = "Another Chart"
    # Native viz type -> display name, read back via display_name_for_viz_type.
    native_viz_types = {"another_viz": "Another Viz"}
def test_register_adds_plugin():
    """A registered plugin is returned, by identity, from get()."""
    fake = _FakePlugin()
    register(fake)
    looked_up = get("fake")
    assert looked_up is fake
def test_get_returns_none_for_unknown():
    """get() yields None for a chart_type nobody registered."""
    looked_up = get("nonexistent")
    assert looked_up is None
def test_all_types_returns_registered_keys():
    """all_types() contains the chart_type of every registered plugin."""
    for plugin_cls in (_FakePlugin, _AnotherPlugin):
        register(plugin_cls())
    registered = all_types()
    for expected in ("fake", "another"):
        assert expected in registered
def test_all_types_insertion_order():
    """all_types() lists chart types in the order they were registered."""
    for plugin_cls in (_FakePlugin, _AnotherPlugin):
        register(plugin_cls())
    ordered = all_types()
    first_pos = ordered.index("fake")
    second_pos = ordered.index("another")
    assert first_pos < second_pos
def test_is_registered_true_for_known():
    """is_registered() returns exactly True for a registered chart_type."""
    register(_FakePlugin())
    known = is_registered("fake")
    assert known is True
def test_is_registered_false_for_unknown():
    """is_registered() returns exactly False for an unregistered chart_type."""
    known = is_registered("nonexistent")
    assert known is False
def test_register_warns_on_duplicate(caplog):
    """Registering a second plugin under the same chart_type logs a warning."""
    register(_FakePlugin())
    with caplog.at_level("WARNING"):
        duplicate = _FakePlugin()
        register(duplicate)
    captured = caplog.text
    assert "Overwriting" in captured
def test_register_raises_for_empty_chart_type():
    """register() rejects a plugin whose chart_type is the empty string."""

    class _BadPlugin(BaseChartPlugin):
        # Deliberately invalid: the registry key must be non-empty.
        chart_type = ""

    expected_message = "non-empty chart_type"
    with pytest.raises(ValueError, match=expected_message):
        register(_BadPlugin())
def test_display_name_for_viz_type_found():
    """A viz type declared by a registered plugin resolves to its display name."""
    register(_FakePlugin())
    resolved = display_name_for_viz_type("fake_viz")
    assert resolved == "Fake Viz"
def test_display_name_for_viz_type_not_found():
    """A viz type no registered plugin declares resolves to None."""
    register(_FakePlugin())
    resolved = display_name_for_viz_type("unknown_viz")
    assert resolved is None
def test_display_name_searches_all_plugins():
    """The viz-type lookup also finds entries of the second registered plugin."""
    for plugin_cls in (_FakePlugin, _AnotherPlugin):
        register(plugin_cls())
    resolved = display_name_for_viz_type("another_viz")
    assert resolved == "Another Viz"
def test_get_registry_returns_proxy():
    """get_registry() hands back a _RegistryProxy instance."""
    proxy = get_registry()
    assert isinstance(proxy, _RegistryProxy)
def test_registry_proxy_get():
    """The proxy's get() returns the plugin registered via register()."""
    fake = _FakePlugin()
    register(fake)
    looked_up = get_registry().get("fake")
    assert looked_up is fake
def test_registry_proxy_all_types():
    """The proxy's all_types() includes a plugin added via register()."""
    register(_FakePlugin())
    registered = get_registry().all_types()
    assert "fake" in registered
def test_registry_proxy_is_registered():
    """The proxy's is_registered() is True for known and False for unknown keys."""
    register(_FakePlugin())
    # A fresh get_registry() call per check, known key first.
    for chart_type, expected in (("fake", True), ("missing", False)):
        assert get_registry().is_registered(chart_type) is expected
def test_registry_proxy_display_name_for_viz_type():
    """The proxy resolves a known viz type and returns None for an unknown one."""
    register(_FakePlugin())
    found = get_registry().display_name_for_viz_type("fake_viz")
    assert found == "Fake Viz"
    not_found = get_registry().display_name_for_viz_type("unknown")
    assert not_found is None