From c4be22e801b25f8a00428fb24bc0bdb204fbbfbf Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 13 May 2026 21:27:39 +0000 Subject: [PATCH] =?UTF-8?q?refactor(mcp):=20address=20Codex=20review=20?= =?UTF-8?q?=E2=80=94=20fix=20registry=20bug,=20DRY=20schema=20hints,=20rem?= =?UTF-8?q?ove=20column=20regex?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P1.1 registry.py: move _plugins_loaded=True to after successful import so a failed load doesn't permanently poison the registry. P1.3 schemas.py: remove overly restrictive ColumnRef.name / FilterConfig.column / BigNumberChartConfig.temporal_column regex that blocked valid column names containing parentheses, slashes, and other SQL-common characters. P2.3 (DRY): eliminate _CHART_TYPE_ERROR_HINTS second-registry in schema_validator.py by adding schema_error_hint() to ChartTypePlugin protocol, BaseChartPlugin default, and all 7 plugin classes. SchemaValidator now delegates to the plugin registry instead of maintaining a parallel dict. P3.3 test_registry.py: add full registry unit-test coverage (register, get, all_types, is_registered, display_name_for_viz_type, proxy methods, duplicate warning, empty chart_type validation, insertion-order guarantee). 
Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/chart/plugin.py | 12 ++ .../mcp_service/chart/plugins/big_number.py | 17 +++ .../mcp_service/chart/plugins/handlebars.py | 20 +++ .../chart/plugins/mixed_timeseries.py | 20 +++ superset/mcp_service/chart/plugins/pie.py | 17 +++ .../mcp_service/chart/plugins/pivot_table.py | 19 +++ superset/mcp_service/chart/plugins/table.py | 18 +++ superset/mcp_service/chart/plugins/xy.py | 18 +++ superset/mcp_service/chart/registry.py | 8 +- superset/mcp_service/chart/schemas.py | 3 - .../chart/validation/schema_validator.py | 119 +-------------- .../mcp_service/chart/test_registry.py | 143 ++++++++++++++++++ 12 files changed, 297 insertions(+), 117 deletions(-) mode change 100644 => 100755 superset/mcp_service/chart/plugin.py mode change 100644 => 100755 superset/mcp_service/chart/plugins/big_number.py mode change 100644 => 100755 superset/mcp_service/chart/plugins/handlebars.py mode change 100644 => 100755 superset/mcp_service/chart/plugins/mixed_timeseries.py mode change 100644 => 100755 superset/mcp_service/chart/plugins/pie.py mode change 100644 => 100755 superset/mcp_service/chart/plugins/pivot_table.py mode change 100644 => 100755 superset/mcp_service/chart/plugins/table.py mode change 100644 => 100755 superset/mcp_service/chart/plugins/xy.py mode change 100644 => 100755 superset/mcp_service/chart/schemas.py mode change 100644 => 100755 superset/mcp_service/chart/validation/schema_validator.py create mode 100644 tests/unit_tests/mcp_service/chart/test_registry.py diff --git a/superset/mcp_service/chart/plugin.py b/superset/mcp_service/chart/plugin.py old mode 100644 new mode 100755 index d0e68208664..6d88c7194d9 --- a/superset/mcp_service/chart/plugin.py +++ b/superset/mcp_service/chart/plugin.py @@ -172,6 +172,15 @@ class ChartTypePlugin(Protocol): """ ... + def schema_error_hint(self) -> "ChartGenerationError | None": + """ + Return a user-friendly error for Pydantic discriminated-union parse failures. 
+ + Called by SchemaValidator when Pydantic cannot parse the config union and + the chart_type is known. Returning None falls back to the generic error. + """ + ... + class BaseChartPlugin: """ @@ -237,6 +246,9 @@ class BaseChartPlugin: def resolve_viz_type(self, config: Any) -> str: return "unknown" + def schema_error_hint(self) -> ChartGenerationError | None: + return None + @staticmethod def _with_context(what: str, context: str | None) -> str: """Combine a 'what' label and optional context with an en-dash.""" diff --git a/superset/mcp_service/chart/plugins/big_number.py b/superset/mcp_service/chart/plugins/big_number.py old mode 100644 new mode 100755 index 4be98e1c58b..e542f8e75f0 --- a/superset/mcp_service/chart/plugins/big_number.py +++ b/superset/mcp_service/chart/plugins/big_number.py @@ -201,3 +201,20 @@ class BigNumberChartPlugin(BaseChartPlugin): ) DatasetValidator._normalize_filters(config_dict, dataset_context) return BigNumberChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="big_number_validation_error", + message="Big Number chart configuration validation failed", + details=( + "The Big Number chart configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Ensure 'metric' field has 'name' and 'aggregate'", + "Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}", + "For trendline: add show_trendline=true and temporal_column='col'", + "Without trendline: just provide the metric", + ], + error_code="BIG_NUMBER_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/handlebars.py b/superset/mcp_service/chart/plugins/handlebars.py old mode 100644 new mode 100755 index 85c7cc1511b..53d78cd8b82 --- a/superset/mcp_service/chart/plugins/handlebars.py +++ b/superset/mcp_service/chart/plugins/handlebars.py @@ -167,3 +167,23 @@ class HandlebarsChartPlugin(BaseChartPlugin): _norm_list("groupby") 
DatasetValidator._normalize_filters(config_dict, dataset_context) return HandlebarsChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="handlebars_validation_error", + message="Handlebars chart configuration validation failed", + details=( + "The handlebars chart configuration is missing " + "required fields or has invalid structure" + ), + suggestions=[ + "Ensure 'handlebars_template' is a non-empty string", + "For aggregate mode: add 'metrics' with aggregate functions", + "For raw mode: set 'query_mode': 'raw' and add 'columns'", + "Example: {'chart_type': 'handlebars', " + "'handlebars_template': " + "'', " + "'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}", + ], + error_code="HANDLEBARS_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/mixed_timeseries.py b/superset/mcp_service/chart/plugins/mixed_timeseries.py old mode 100644 new mode 100755 index e2bc062b32c..0cf7b82e80e --- a/superset/mcp_service/chart/plugins/mixed_timeseries.py +++ b/superset/mcp_service/chart/plugins/mixed_timeseries.py @@ -143,3 +143,23 @@ class MixedTimeseriesChartPlugin(BaseChartPlugin): _norm_list("group_by_secondary") DatasetValidator._normalize_filters(config_dict, dataset_context) return MixedTimeseriesChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="mixed_timeseries_validation_error", + message="Mixed timeseries chart configuration validation failed", + details=( + "The mixed timeseries configuration is missing " + "required fields or has invalid structure" + ), + suggestions=[ + "Ensure 'x' field has 'name' for the time axis column", + "Ensure 'y' is an array of primary-axis metrics", + "Ensure 'y_secondary' is an array of secondary-axis metrics", + "Example: {'chart_type': 'mixed_timeseries', " + "'x': {'name': 'order_date'}, " + "'y': [{'name': 'revenue', 
'aggregate': 'SUM'}], " + "'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}", + ], + error_code="MIXED_TIMESERIES_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/pie.py b/superset/mcp_service/chart/plugins/pie.py old mode 100644 new mode 100755 index 01d7921cf62..3d87fe7f05f --- a/superset/mcp_service/chart/plugins/pie.py +++ b/superset/mcp_service/chart/plugins/pie.py @@ -109,3 +109,20 @@ class PieChartPlugin(BaseChartPlugin): ) DatasetValidator._normalize_filters(config_dict, dataset_context) return PieChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="pie_validation_error", + message="Pie chart configuration validation failed", + details=( + "The pie chart configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Ensure 'dimension' field has 'name' for the slice label", + "Ensure 'metric' field has 'name' and 'aggregate'", + "Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, " + "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}", + ], + error_code="PIE_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/pivot_table.py b/superset/mcp_service/chart/plugins/pivot_table.py old mode 100644 new mode 100755 index e00f5ec0577..038f8c79416 --- a/superset/mcp_service/chart/plugins/pivot_table.py +++ b/superset/mcp_service/chart/plugins/pivot_table.py @@ -132,3 +132,22 @@ class PivotTableChartPlugin(BaseChartPlugin): _norm_col_list("columns") DatasetValidator._normalize_filters(config_dict, dataset_context) return PivotTableChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="pivot_table_validation_error", + message="Pivot table configuration validation failed", + details=( + "The pivot table configuration is missing required " + "fields or has invalid structure" + ), + 
suggestions=[ + "Ensure 'rows' field is an array of column specs", + "Ensure 'metrics' field is an array with aggregate funcs", + "Optional: add 'columns' for column grouping", + "Example: {'chart_type': 'pivot_table', " + "'rows': [{'name': 'region'}], " + "'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}", + ], + error_code="PIVOT_TABLE_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/table.py b/superset/mcp_service/chart/plugins/table.py old mode 100644 new mode 100755 index e8d2516db71..86f5dcaead2 --- a/superset/mcp_service/chart/plugins/table.py +++ b/superset/mcp_service/chart/plugins/table.py @@ -108,3 +108,21 @@ class TableChartPlugin(BaseChartPlugin): DatasetValidator._normalize_filters(config_dict, dataset_context) return TableChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="table_validation_error", + message="Table chart configuration validation failed", + details=( + "The table chart configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Ensure 'columns' field is an array of column specifications", + "Each column needs {'name': 'column_name'}", + "Optional: add 'aggregate' for metrics", + "Example: 'columns': [{'name': 'product'}, " + "{'name': 'sales', 'aggregate': 'SUM'}]", + ], + error_code="TABLE_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/xy.py b/superset/mcp_service/chart/plugins/xy.py old mode 100644 new mode 100755 index 477a30a6e4d..076826f3f08 --- a/superset/mcp_service/chart/plugins/xy.py +++ b/superset/mcp_service/chart/plugins/xy.py @@ -172,3 +172,21 @@ class XYChartPlugin(BaseChartPlugin): logger.warning("XY cardinality validation failed: %s", exc) return warnings + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="xy_validation_error", + message="XY chart configuration validation failed", + 
details=( + "The XY chart configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Note: 'x' is optional and defaults to the dataset's primary " + "datetime column", + "Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]", + "Check that all column names are strings", + "Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX", + ], + error_code="XY_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/registry.py b/superset/mcp_service/chart/registry.py index 005b17ce327..920cfcc5592 100755 --- a/superset/mcp_service/chart/registry.py +++ b/superset/mcp_service/chart/registry.py @@ -62,8 +62,12 @@ def _ensure_plugins_loaded() -> None: return with _plugins_lock: if not _plugins_loaded: - _plugins_loaded = True - import superset.mcp_service.chart.plugins # noqa: F401 + try: + import superset.mcp_service.chart.plugins # noqa: F401 + + _plugins_loaded = True + except Exception: + logger.exception("Failed to load built-in chart type plugins") def register(plugin: "ChartTypePlugin") -> None: diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py old mode 100644 new mode 100755 index 9d8507a9254..40a2ef1f3c2 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -680,7 +680,6 @@ class ColumnRef(BaseModel): ..., min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", validation_alias=AliasChoices("name", "column_name"), ) label: str | None = Field(None, max_length=500) @@ -754,7 +753,6 @@ class FilterConfig(BaseModel): ..., min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", validation_alias=AliasChoices("column", "col"), ) op: Literal[ @@ -1117,7 +1115,6 @@ class BigNumberChartConfig(UnknownFieldCheckMixin): ), min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", ) time_grain: TimeGrain | None = Field( None, diff --git 
a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py old mode 100644 new mode 100755 index eb404d11f19..11f0f6ada8a --- a/superset/mcp_service/chart/validation/schema_validator.py +++ b/superset/mcp_service/chart/validation/schema_validator.py @@ -186,115 +186,6 @@ class SchemaValidator: return False, error return True, None - # Per-chart-type error details used by _enhance_validation_error. - # Keyed by chart_type discriminator value. - # NOTE: Keep this dict in sync with the plugin registry in - # superset/mcp_service/chart/plugins/ — each registered chart_type must - # have a corresponding entry here so Pydantic parse errors produce - # helpful, type-specific messages. - _CHART_TYPE_ERROR_HINTS: Dict[str, Dict[str, Any]] = { - "xy": { - "error_type": "xy_validation_error", - "message": "XY chart configuration validation failed", - "details": "The XY chart configuration is missing required " - "fields or has invalid structure", - "suggestions": [ - "Note: 'x' is optional and defaults to the dataset's primary " - "datetime column", - "Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]", - "Check that all column names are strings", - "Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX", - ], - "error_code": "XY_VALIDATION_ERROR", - }, - "table": { - "error_type": "table_validation_error", - "message": "Table chart configuration validation failed", - "details": "The table chart configuration is missing required " - "fields or has invalid structure", - "suggestions": [ - "Ensure 'columns' field is an array of column specifications", - "Each column needs {'name': 'column_name'}", - "Optional: add 'aggregate' for metrics", - "Example: 'columns': [{'name': 'product'}, " - "{'name': 'sales', 'aggregate': 'SUM'}]", - ], - "error_code": "TABLE_VALIDATION_ERROR", - }, - "pie": { - "error_type": "pie_validation_error", - "message": "Pie chart configuration validation failed", - 
"details": "The pie chart configuration is missing required " - "fields or has invalid structure", - "suggestions": [ - "Ensure 'dimension' field has 'name' for the slice label", - "Ensure 'metric' field has 'name' and 'aggregate'", - "Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, " - "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}", - ], - "error_code": "PIE_VALIDATION_ERROR", - }, - "pivot_table": { - "error_type": "pivot_table_validation_error", - "message": "Pivot table configuration validation failed", - "details": "The pivot table configuration is missing required " - "fields or has invalid structure", - "suggestions": [ - "Ensure 'rows' field is an array of column specs", - "Ensure 'metrics' field is an array with aggregate funcs", - "Optional: add 'columns' for column grouping", - "Example: {'chart_type': 'pivot_table', 'rows': [{'name': 'region'}], " - "'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}", - ], - "error_code": "PIVOT_TABLE_VALIDATION_ERROR", - }, - "mixed_timeseries": { - "error_type": "mixed_timeseries_validation_error", - "message": "Mixed timeseries chart configuration validation failed", - "details": "The mixed timeseries configuration is missing " - "required fields or has invalid structure", - "suggestions": [ - "Ensure 'x' field has 'name' for the time axis column", - "Ensure 'y' is an array of primary-axis metrics", - "Ensure 'y_secondary' is an array of secondary-axis metrics", - "Example: {'chart_type': 'mixed_timeseries', " - "'x': {'name': 'order_date'}, " - "'y': [{'name': 'revenue', 'aggregate': 'SUM'}], " - "'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}", - ], - "error_code": "MIXED_TIMESERIES_VALIDATION_ERROR", - }, - "handlebars": { - "error_type": "handlebars_validation_error", - "message": "Handlebars chart configuration validation failed", - "details": "The handlebars chart configuration is missing " - "required fields or has invalid structure", - "suggestions": [ - "Ensure 
'handlebars_template' is a non-empty string", - "For aggregate mode: add 'metrics' with aggregate functions", - "For raw mode: set 'query_mode': 'raw' and add 'columns'", - "Example: {'chart_type': 'handlebars', " - "'handlebars_template': '', " - "'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}", - ], - "error_code": "HANDLEBARS_VALIDATION_ERROR", - }, - "big_number": { - "error_type": "big_number_validation_error", - "message": "Big Number chart configuration validation failed", - "details": "The Big Number chart configuration is missing required " - "fields or has invalid structure", - "suggestions": [ - "Ensure 'metric' field has 'name' and 'aggregate'", - "Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}", - "For trendline: add show_trendline=true and temporal_column='col'", - "Without trendline: just provide the metric", - ], - "error_code": "BIG_NUMBER_VALIDATION_ERROR", - }, - } - @staticmethod def _enhance_validation_error( error: PydanticValidationError, request_data: Dict[str, Any] @@ -307,10 +198,14 @@ class SchemaValidator: if err.get("type") == "union_tag_invalid" or "discriminator" in str( err.get("ctx", {}) ): + from superset.mcp_service.chart.registry import get_registry + chart_type = request_data.get("config", {}).get("chart_type", "") - hint = SchemaValidator._CHART_TYPE_ERROR_HINTS.get(chart_type) - if hint: - return ChartGenerationError(**hint) + plugin = get_registry().get(chart_type) + if plugin is not None: + hint = plugin.schema_error_hint() + if hint is not None: + return hint # Default enhanced error error_details = [] diff --git a/tests/unit_tests/mcp_service/chart/test_registry.py b/tests/unit_tests/mcp_service/chart/test_registry.py new file mode 100644 index 00000000000..0351b2d2bd3 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/test_registry.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for the chart type plugin registry.""" + +import pytest + +import superset.mcp_service.chart.registry as registry_module +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.registry import ( + _RegistryProxy, + all_types, + display_name_for_viz_type, + get, + get_registry, + is_registered, + register, +) + + +@pytest.fixture(autouse=True) +def _isolated_registry(monkeypatch): + """Run each test against a clean registry without touching the real one.""" + monkeypatch.setattr(registry_module, "_REGISTRY", {}) + monkeypatch.setattr(registry_module, "_plugins_loaded", True) + + +class _FakePlugin(BaseChartPlugin): + chart_type = "fake" + display_name = "Fake Chart" + native_viz_types = {"fake_viz": "Fake Viz"} + + +class _AnotherPlugin(BaseChartPlugin): + chart_type = "another" + display_name = "Another Chart" + native_viz_types = {"another_viz": "Another Viz"} + + +def test_register_adds_plugin(): + plugin = _FakePlugin() + register(plugin) + assert get("fake") is plugin + + +def test_get_returns_none_for_unknown(): + assert get("nonexistent") is None + + +def test_all_types_returns_registered_keys(): + register(_FakePlugin()) + register(_AnotherPlugin()) + types = all_types() + assert "fake" in types + assert "another" in types + + 
+def test_all_types_insertion_order(): + register(_FakePlugin()) + register(_AnotherPlugin()) + types = all_types() + assert types.index("fake") < types.index("another") + + +def test_is_registered_true_for_known(): + register(_FakePlugin()) + assert is_registered("fake") is True + + +def test_is_registered_false_for_unknown(): + assert is_registered("nonexistent") is False + + +def test_register_warns_on_duplicate(caplog): + register(_FakePlugin()) + with caplog.at_level("WARNING"): + register(_FakePlugin()) + assert "Overwriting" in caplog.text + + +def test_register_raises_for_empty_chart_type(): + class _BadPlugin(BaseChartPlugin): + chart_type = "" + + with pytest.raises(ValueError, match="non-empty chart_type"): + register(_BadPlugin()) + + +def test_display_name_for_viz_type_found(): + register(_FakePlugin()) + assert display_name_for_viz_type("fake_viz") == "Fake Viz" + + +def test_display_name_for_viz_type_not_found(): + register(_FakePlugin()) + assert display_name_for_viz_type("unknown_viz") is None + + +def test_display_name_searches_all_plugins(): + register(_FakePlugin()) + register(_AnotherPlugin()) + assert display_name_for_viz_type("another_viz") == "Another Viz" + + +def test_get_registry_returns_proxy(): + assert isinstance(get_registry(), _RegistryProxy) + + +def test_registry_proxy_get(): + plugin = _FakePlugin() + register(plugin) + assert get_registry().get("fake") is plugin + + +def test_registry_proxy_all_types(): + register(_FakePlugin()) + assert "fake" in get_registry().all_types() + + +def test_registry_proxy_is_registered(): + register(_FakePlugin()) + assert get_registry().is_registered("fake") is True + assert get_registry().is_registered("missing") is False + + +def test_registry_proxy_display_name_for_viz_type(): + register(_FakePlugin()) + assert get_registry().display_name_for_viz_type("fake_viz") == "Fake Viz" + assert get_registry().display_name_for_viz_type("unknown") is None