diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index ca4f6c54fde..d2313ed3b60 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -1145,15 +1145,9 @@ def generate_chart_name( ) -> str: """Generate a descriptive chart name following a standard format. - Format conventions (by chart type): - Aggregated (bar/scatter with group_by): [Metric] by [Dimension] - Time-series (line/area, no group_by): [Metric] Over Time - Table (no aggregates): [Dataset] Records - Table (with aggregates): [Metric] Summary - Pie: [Dimension] by [Metric] - Pivot Table: Pivot Table – [Row1, Row2] - Mixed Timeseries: [Primary] + [Secondary] - An en-dash followed by context (filters / time grain) is appended + Delegates to each plugin's ``generate_name()`` method. + See each plugin's ``generate_name`` for chart-type-specific format conventions. + An en-dash followed by context (filters / time grain) is appended by the plugin when such information is available. """ from superset.mcp_service.chart.registry import get_registry diff --git a/superset/mcp_service/chart/plugin.py b/superset/mcp_service/chart/plugin.py index b310bc1da07..d0e68208664 100644 --- a/superset/mcp_service/chart/plugin.py +++ b/superset/mcp_service/chart/plugin.py @@ -38,9 +38,9 @@ class ChartTypePlugin(Protocol): """ Protocol that every chart-type plugin must satisfy. - Implementing all five methods in a single class guarantees that adding a + Implementing all eight methods in a single class guarantees that adding a new chart type requires only one new file — the plugin — rather than edits - across four separate files. + across multiple separate files. """ #: Discriminator value matching ChartConfig's chart_type field. @@ -236,3 +236,8 @@ class BaseChartPlugin: def resolve_viz_type(self, config: Any) -> str: return "unknown" + + @staticmethod + def _with_context(what: str, context: str | None) -> str: + """Combine a 'what' label and optional context with an en-dash.""" + return f"{what} – {context}" if context else what diff --git a/superset/mcp_service/chart/plugins/big_number.py b/superset/mcp_service/chart/plugins/big_number.py index d924cd7735f..17ee2b3d896 100644 --- a/superset/mcp_service/chart/plugins/big_number.py +++ b/superset/mcp_service/chart/plugins/big_number.py @@ -184,7 +184,7 @@ class BigNumberChartPlugin(BaseChartPlugin): what = _big_number_chart_what(config) context = _summarize_filters(getattr(config, "filters", None)) - return f"{what} \u2013 {context}" if context else what + return self._with_context(what, context) def resolve_viz_type(self, config: Any) -> str: show_trendline = getattr(config, "show_trendline", False) diff --git a/superset/mcp_service/chart/plugins/handlebars.py b/superset/mcp_service/chart/plugins/handlebars.py index d177ede9457..c242030572a 100644 --- a/superset/mcp_service/chart/plugins/handlebars.py +++ b/superset/mcp_service/chart/plugins/handlebars.py @@ -150,7 +150,7 @@ class HandlebarsChartPlugin(BaseChartPlugin): what = _handlebars_chart_what(config) context = _summarize_filters(getattr(config, "filters", None)) - return f"{what} \u2013 {context}" if context else what + return self._with_context(what, context) def resolve_viz_type(self, config: Any) -> str: return "handlebars" diff --git a/superset/mcp_service/chart/plugins/mixed_timeseries.py b/superset/mcp_service/chart/plugins/mixed_timeseries.py index 706df74499d..079151eed3f 100644 --- a/superset/mcp_service/chart/plugins/mixed_timeseries.py +++ b/superset/mcp_service/chart/plugins/mixed_timeseries.py @@ -118,7 +118,7 @@ class MixedTimeseriesChartPlugin(BaseChartPlugin): what = _mixed_timeseries_what(config) context = _summarize_filters(config.filters) - return f"{what} \u2013 {context}" if context else what + return self._with_context(what, context) def resolve_viz_type(self, config: Any) -> str: return "mixed_timeseries" diff --git a/superset/mcp_service/chart/plugins/pie.py b/superset/mcp_service/chart/plugins/pie.py index 0160e134a3d..2be4cdd98a4 100644 --- a/superset/mcp_service/chart/plugins/pie.py +++ b/superset/mcp_service/chart/plugins/pie.py @@ -92,7 +92,7 @@ class PieChartPlugin(BaseChartPlugin): what = _pie_chart_what(config) context = _summarize_filters(config.filters) - return f"{what} \u2013 {context}" if context else what + return self._with_context(what, context) def resolve_viz_type(self, config: Any) -> str: return "pie" diff --git a/superset/mcp_service/chart/plugins/pivot_table.py b/superset/mcp_service/chart/plugins/pivot_table.py index e514a3307a0..ae85cd1d816 100644 --- a/superset/mcp_service/chart/plugins/pivot_table.py +++ b/superset/mcp_service/chart/plugins/pivot_table.py @@ -115,7 +115,7 @@ class PivotTableChartPlugin(BaseChartPlugin): what = _pivot_table_what(config) context = _summarize_filters(config.filters) - return f"{what} \u2013 {context}" if context else what + return self._with_context(what, context) def resolve_viz_type(self, config: Any) -> str: return "pivot_table_v2" diff --git a/superset/mcp_service/chart/plugins/table.py b/superset/mcp_service/chart/plugins/table.py index ee7d9f27d4b..13bd1640dc9 100644 --- a/superset/mcp_service/chart/plugins/table.py +++ b/superset/mcp_service/chart/plugins/table.py @@ -97,7 +97,7 @@ class TableChartPlugin(BaseChartPlugin): what = _table_chart_what(config, dataset_name) context = _summarize_filters(config.filters) - return f"{what} \u2013 {context}" if context else what + return self._with_context(what, context) def resolve_viz_type(self, config: Any) -> str: return getattr(config, "viz_type", "table") diff --git a/superset/mcp_service/chart/plugins/xy.py b/superset/mcp_service/chart/plugins/xy.py index 92088ed66ab..9176d083f29 100644 --- a/superset/mcp_service/chart/plugins/xy.py +++ b/superset/mcp_service/chart/plugins/xy.py @@ -130,7 +130,7 @@ class XYChartPlugin(BaseChartPlugin): what = _xy_chart_what(config) context = _xy_chart_context(config) - return f"{what} \u2013 {context}" if context else what + return self._with_context(what, context) def resolve_viz_type(self, config: Any) -> str: kind = getattr(config, "kind", "line") @@ -168,7 +168,7 @@ class XYChartPlugin(BaseChartPlugin): CardinalityValidator, ) - chart_kind = config.kind if hasattr(config, "kind") else "default" + chart_kind = config.kind group_by_col = config.group_by[0].name if config.group_by else None if config.x is not None: _ok, card_info = CardinalityValidator.check_cardinality( diff --git a/superset/mcp_service/chart/registry.py b/superset/mcp_service/chart/registry.py index 872a2578e1c..5b8f4c9515f 100644 --- a/superset/mcp_service/chart/registry.py +++ b/superset/mcp_service/chart/registry.py @@ -45,10 +45,26 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) _REGISTRY: dict[str, "ChartTypePlugin"] = {} +_plugins_loaded = False + + +def _ensure_plugins_loaded() -> None: + """Lazily import the plugins package to populate _REGISTRY. + + Called before every registry lookup so the registry is always populated, + even when callers (tests, chart_utils, validators) import this module + directly without first importing app.py. + """ + global _plugins_loaded + if not _plugins_loaded: + _plugins_loaded = True + import superset.mcp_service.chart.plugins # noqa: F401 def register(plugin: "ChartTypePlugin") -> None: """Register a chart type plugin in the global registry.""" + if not plugin.chart_type: + raise ValueError(f"{type(plugin).__name__} must define a non-empty chart_type") if plugin.chart_type in _REGISTRY: logger.warning( "Overwriting existing plugin for chart_type=%r", plugin.chart_type @@ -59,16 +75,19 @@ def register(plugin: "ChartTypePlugin") -> None: def get(chart_type: str) -> "ChartTypePlugin | None": """Return the plugin for a given chart_type, or None if not registered.""" + _ensure_plugins_loaded() return _REGISTRY.get(chart_type) def all_types() -> list[str]: """Return all registered chart type strings in insertion order.""" + _ensure_plugins_loaded() return list(_REGISTRY.keys()) def is_registered(chart_type: str) -> bool: """Return True if chart_type has a registered plugin.""" + _ensure_plugins_loaded() return chart_type in _REGISTRY @@ -84,6 +103,7 @@ def display_name_for_viz_type(viz_type: str) -> str | None: display_name_for_viz_type("pivot_table_v2") # "Pivot Table" display_name_for_viz_type("unknown_type") # None """ + _ensure_plugins_loaded() for plugin in _REGISTRY.values(): name = plugin.native_viz_types.get(viz_type) if name is not None: @@ -100,13 +120,13 @@ class _RegistryProxy: """Thin proxy exposing registry functions as instance methods.""" def get(self, chart_type: str) -> "ChartTypePlugin | None": - return _REGISTRY.get(chart_type) + return get(chart_type) def all_types(self) -> list[str]: - return list(_REGISTRY.keys()) + return all_types() def is_registered(self, chart_type: str) -> bool: - return chart_type in _REGISTRY + return is_registered(chart_type) def display_name_for_viz_type(self, viz_type: str) -> str | None: return display_name_for_viz_type(viz_type) diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py index f1ef0eacabf..7d07cc89055 100644 --- a/superset/mcp_service/chart/validation/dataset_validator.py +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -22,12 +22,14 @@ Validates that referenced columns exist in the dataset schema. import difflib import logging -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, TypeVar from superset.mcp_service.chart.schemas import ( ChartConfig, ColumnRef, ) + +_C = TypeVar("_C", bound=ChartConfig) from superset.mcp_service.common.error_schemas import ( ChartGenerationError, ColumnSuggestion, @@ -346,10 +348,10 @@ class DatasetValidator: @staticmethod def normalize_column_names( - config: ChartConfig, + config: _C, dataset_id: int | str, dataset_context: DatasetContext | None = None, - ) -> ChartConfig: + ) -> _C: """ Normalize column names in config to match the canonical dataset column names. diff --git a/tests/unit_tests/mcp_service/chart/test_big_number_chart.py b/tests/unit_tests/mcp_service/chart/test_big_number_chart.py index 59e142333bd..c832d7793d1 100644 --- a/tests/unit_tests/mcp_service/chart/test_big_number_chart.py +++ b/tests/unit_tests/mcp_service/chart/test_big_number_chart.py @@ -90,7 +90,7 @@ class TestBigNumberChartConfig: "chart_type": "big_number", "metric": {"name": "total_sales", "saved_metric": True}, } - is_valid, error = SchemaValidator._pre_validate_big_number_config(data) + is_valid, error = SchemaValidator._pre_validate_chart_type("big_number", data) assert is_valid is True assert error is None diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py index 5404f8985b6..c1e347af55c 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py @@ -222,9 +222,9 @@ class TestBigNumberChartFallback: "viz_type": viz_type, } metrics, _ = _extract_metrics_and_groupby(form_data) - assert metrics == [{"label": "plural_metric"}], ( - f"{viz_type} should use plural metrics" - ) + assert metrics == [ + {"label": "plural_metric"} + ], f"{viz_type} should use plural metrics" def test_pop_kpi_uses_singular_metric(self): """Test that pop_kpi (BigNumberPeriodOverPeriod) uses singular metric.""" diff --git a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py index dbebe268b4b..a81f0864f26 100644 --- a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py +++ b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py @@ -117,83 +117,6 @@ class TestGetCanonicalColumnName: assert result == "unknown_column" -class TestNormalizeXYConfig: - """Test _normalize_xy_config static method.""" - - def test_normalize_x_axis_column( - self, mock_dataset_context: DatasetContext - ) -> None: - """Test that x-axis column name is normalized.""" - config_dict: Dict[str, Any] = { - "chart_type": "xy", - "x": {"name": "orderdate"}, - "y": [{"name": "Sales", "aggregate": "SUM"}], - "kind": "line", - } - - DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context) - - assert config_dict["x"]["name"] == "OrderDate" - - def test_normalize_y_axis_columns( - self, mock_dataset_context: DatasetContext - ) -> None: - """Test that y-axis column names are normalized.""" - config_dict: Dict[str, Any] = { - "chart_type": "xy", - "x": {"name": "OrderDate"}, - "y": [ - {"name": "sales", "aggregate": "SUM"}, - {"name": "QUANTITY_ORDERED", "aggregate": "COUNT"}, - ], - "kind": "bar", - } - - DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context) - - assert config_dict["y"][0]["name"] == "Sales" - assert config_dict["y"][1]["name"] == "quantity_ordered" - - def test_normalize_group_by_column( - self, mock_dataset_context: DatasetContext - ) -> None: - """Test that group_by column name is normalized.""" - config_dict: Dict[str, Any] = { - "chart_type": "xy", - "x": {"name": "OrderDate"}, - "y": [{"name": "Sales", "aggregate": "SUM"}], - "kind": "line", - "group_by": [{"name": "productline"}], - } - - DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context) - - assert config_dict["group_by"][0]["name"] == "ProductLine" - - -class TestNormalizeTableConfig: - """Test _normalize_table_config static method.""" - - def test_normalize_table_columns( - self, mock_dataset_context: DatasetContext - ) -> None: - """Test that table column names are normalized.""" - config_dict: Dict[str, Any] = { - "chart_type": "table", - "columns": [ - {"name": "orderdate"}, - {"name": "PRODUCTLINE"}, - {"name": "sales", "aggregate": "SUM"}, - ], - } - - DatasetValidator._normalize_table_config(config_dict, mock_dataset_context) - - assert config_dict["columns"][0]["name"] == "OrderDate" - assert config_dict["columns"][1]["name"] == "ProductLine" - assert config_dict["columns"][2]["name"] == "Sales" - - class TestNormalizeFilters: """Test _normalize_filters static method.""" diff --git a/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py b/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py index c49677cb99f..6aed0b11269 100644 --- a/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py +++ b/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py @@ -58,12 +58,12 @@ class TestRuntimeValidatorNonBlocking: x_axis=AxisConfig(format="$,.2f"), # Currency format for date - mismatch ) - # Mock the format validator to return warnings + # Mock the plugin runtime dispatcher to return format warnings with patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_format_compatibility" - ) as mock_format: - mock_format.return_value = [ + "_validate_plugin_runtime" + ) as mock_plugin: + mock_plugin.return_value = [ "Currency format '$,.2f' may not display dates correctly" ] @@ -87,15 +87,14 @@ class TestRuntimeValidatorNonBlocking: kind="bar", ) - # Mock the cardinality validator to return warnings + # Mock the plugin runtime dispatcher to return cardinality warnings with patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_cardinality" - ) as mock_cardinality: - mock_cardinality.return_value = ( - ["High cardinality detected: 10000+ unique values"], - ["Consider using aggregation or filtering"], - ) + "_validate_plugin_runtime" + ) as mock_plugin: + mock_plugin.return_value = [ + "High cardinality detected: 10000+ unique values" + ] is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues( config, 1 @@ -148,26 +147,21 @@ class TestRuntimeValidatorNonBlocking: x_axis=AxisConfig(format="smart_date"), # Wrong format for user_id ) - # Mock all validators to return warnings + # Mock plugin runtime and chart type validators to return warnings with ( patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_format_compatibility" - ) as mock_format, - patch( - "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_cardinality" - ) as mock_cardinality, + "_validate_plugin_runtime" + ) as mock_plugin, patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." "_validate_chart_type" ) as mock_type, ): - mock_format.return_value = ["Format mismatch warning"] - mock_cardinality.return_value = ( - ["High cardinality warning"], - ["Cardinality suggestion"], - ) + mock_plugin.return_value = [ + "Format mismatch warning", + "High cardinality warning", + ] mock_type.return_value = ( ["Chart type warning"], ["Chart type suggestion"], @@ -197,13 +191,13 @@ class TestRuntimeValidatorNonBlocking: with ( patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_format_compatibility" - ) as mock_format, + "_validate_plugin_runtime" + ) as mock_plugin, patch( "superset.mcp_service.chart.validation.runtime.logger" ) as mock_logger, ): - mock_format.return_value = ["Test warning message"] + mock_plugin.return_value = ["Test warning message"] is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues( config, 1 @@ -217,7 +211,7 @@ class TestRuntimeValidatorNonBlocking: assert "warnings" in warnings_metadata def test_validate_table_chart_skips_xy_validations(self): - """Test that table charts skip XY-specific validations.""" + """Test that table charts produce no XY-specific runtime warnings.""" config = TableChartConfig( chart_type="table", columns=[ @@ -226,28 +220,15 @@ class TestRuntimeValidatorNonBlocking: ], ) - # These should not be called for table charts - with ( - patch( - "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_format_compatibility" - ) as mock_format, - patch( - "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_cardinality" - ) as mock_cardinality, - patch( - "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_chart_type" - ) as mock_chart_type, - ): - # Mock chart type validator to return no warnings + # Plugin runtime dispatches to TableChartPlugin which returns no warnings. + # Chart type suggester is also stubbed to return no warnings. + with patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_chart_type" + ) as mock_chart_type: mock_chart_type.return_value = ([], []) is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) - # Format and cardinality validation should not be called for table charts - mock_format.assert_not_called() - mock_cardinality.assert_not_called() assert is_valid is True assert error is None