mirror of
https://github.com/apache/superset.git
synced 2026-05-30 04:39:20 +00:00
refactor(mcp): complete plugin protocol — registry bootstrap, mypy fixes, test repairs
On top of the dead-code elimination in the previous commit: - Add lazy _ensure_plugins_loaded() bootstrap to ChartTypeRegistry so the registry is populated even without importing app.py (fixes isolated test runs) - Delegate _RegistryProxy methods to module-level functions so bootstrap runs - Guard register() against empty chart_type strings - Add generate_name + resolve_viz_type to ChartTypePlugin Protocol and BaseChartPlugin; delegate generate_chart_name/_resolve_viz_type in chart_utils to the plugin registry - Add _with_context static helper to BaseChartPlugin (shared by all plugins) - Fix stale 'five methods' → 'eight methods' docstring in plugin.py - Add TypeVar _C to normalize_column_names so mypy infers correct return type - Fix broken tests: update _pre_validate_big_number_config → _pre_validate_chart_type, remove deleted TestNormalizeXYConfig/TestNormalizeTableConfig classes, update runtime validator tests for removed _validate_format_compatibility / _validate_cardinality methods, add x is not None narrowing guards Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1145,15 +1145,9 @@ def generate_chart_name(
|
||||
) -> str:
|
||||
"""Generate a descriptive chart name following a standard format.
|
||||
|
||||
Format conventions (by chart type):
|
||||
Aggregated (bar/scatter with group_by): [Metric] by [Dimension]
|
||||
Time-series (line/area, no group_by): [Metric] Over Time
|
||||
Table (no aggregates): [Dataset] Records
|
||||
Table (with aggregates): [Metric] Summary
|
||||
Pie: [Dimension] by [Metric]
|
||||
Pivot Table: Pivot Table – [Row1, Row2]
|
||||
Mixed Timeseries: [Primary] + [Secondary]
|
||||
An en-dash followed by context (filters / time grain) is appended
|
||||
Delegates to each plugin's ``generate_name()`` method.
|
||||
See each plugin's ``generate_name`` for chart-type-specific format conventions.
|
||||
An en-dash followed by context (filters / time grain) is appended by the plugin
|
||||
when such information is available.
|
||||
"""
|
||||
from superset.mcp_service.chart.registry import get_registry
|
||||
|
||||
@@ -38,9 +38,9 @@ class ChartTypePlugin(Protocol):
|
||||
"""
|
||||
Protocol that every chart-type plugin must satisfy.
|
||||
|
||||
Implementing all five methods in a single class guarantees that adding a
|
||||
Implementing all eight methods in a single class guarantees that adding a
|
||||
new chart type requires only one new file — the plugin — rather than edits
|
||||
across four separate files.
|
||||
across multiple separate files.
|
||||
"""
|
||||
|
||||
#: Discriminator value matching ChartConfig's chart_type field.
|
||||
@@ -236,3 +236,8 @@ class BaseChartPlugin:
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def _with_context(what: str, context: str | None) -> str:
|
||||
"""Combine a 'what' label and optional context with an en-dash."""
|
||||
return f"{what} – {context}" if context else what
|
||||
|
||||
@@ -184,7 +184,7 @@ class BigNumberChartPlugin(BaseChartPlugin):
|
||||
|
||||
what = _big_number_chart_what(config)
|
||||
context = _summarize_filters(getattr(config, "filters", None))
|
||||
return f"{what} \u2013 {context}" if context else what
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
show_trendline = getattr(config, "show_trendline", False)
|
||||
|
||||
@@ -150,7 +150,7 @@ class HandlebarsChartPlugin(BaseChartPlugin):
|
||||
|
||||
what = _handlebars_chart_what(config)
|
||||
context = _summarize_filters(getattr(config, "filters", None))
|
||||
return f"{what} \u2013 {context}" if context else what
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
return "handlebars"
|
||||
|
||||
@@ -118,7 +118,7 @@ class MixedTimeseriesChartPlugin(BaseChartPlugin):
|
||||
|
||||
what = _mixed_timeseries_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
return f"{what} \u2013 {context}" if context else what
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
return "mixed_timeseries"
|
||||
|
||||
@@ -92,7 +92,7 @@ class PieChartPlugin(BaseChartPlugin):
|
||||
|
||||
what = _pie_chart_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
return f"{what} \u2013 {context}" if context else what
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
return "pie"
|
||||
|
||||
@@ -115,7 +115,7 @@ class PivotTableChartPlugin(BaseChartPlugin):
|
||||
|
||||
what = _pivot_table_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
return f"{what} \u2013 {context}" if context else what
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
return "pivot_table_v2"
|
||||
|
||||
@@ -97,7 +97,7 @@ class TableChartPlugin(BaseChartPlugin):
|
||||
|
||||
what = _table_chart_what(config, dataset_name)
|
||||
context = _summarize_filters(config.filters)
|
||||
return f"{what} \u2013 {context}" if context else what
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
return getattr(config, "viz_type", "table")
|
||||
|
||||
@@ -130,7 +130,7 @@ class XYChartPlugin(BaseChartPlugin):
|
||||
|
||||
what = _xy_chart_what(config)
|
||||
context = _xy_chart_context(config)
|
||||
return f"{what} \u2013 {context}" if context else what
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
kind = getattr(config, "kind", "line")
|
||||
@@ -168,7 +168,7 @@ class XYChartPlugin(BaseChartPlugin):
|
||||
CardinalityValidator,
|
||||
)
|
||||
|
||||
chart_kind = config.kind if hasattr(config, "kind") else "default"
|
||||
chart_kind = config.kind
|
||||
group_by_col = config.group_by[0].name if config.group_by else None
|
||||
if config.x is not None:
|
||||
_ok, card_info = CardinalityValidator.check_cardinality(
|
||||
|
||||
@@ -45,10 +45,26 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_REGISTRY: dict[str, "ChartTypePlugin"] = {}
|
||||
_plugins_loaded = False
|
||||
|
||||
|
||||
def _ensure_plugins_loaded() -> None:
|
||||
"""Lazily import the plugins package to populate _REGISTRY.
|
||||
|
||||
Called before every registry lookup so the registry is always populated,
|
||||
even when callers (tests, chart_utils, validators) import this module
|
||||
directly without first importing app.py.
|
||||
"""
|
||||
global _plugins_loaded
|
||||
if not _plugins_loaded:
|
||||
_plugins_loaded = True
|
||||
import superset.mcp_service.chart.plugins # noqa: F401
|
||||
|
||||
|
||||
def register(plugin: "ChartTypePlugin") -> None:
|
||||
"""Register a chart type plugin in the global registry."""
|
||||
if not plugin.chart_type:
|
||||
raise ValueError(f"{type(plugin).__name__} must define a non-empty chart_type")
|
||||
if plugin.chart_type in _REGISTRY:
|
||||
logger.warning(
|
||||
"Overwriting existing plugin for chart_type=%r", plugin.chart_type
|
||||
@@ -59,16 +75,19 @@ def register(plugin: "ChartTypePlugin") -> None:
|
||||
|
||||
def get(chart_type: str) -> "ChartTypePlugin | None":
|
||||
"""Return the plugin for a given chart_type, or None if not registered."""
|
||||
_ensure_plugins_loaded()
|
||||
return _REGISTRY.get(chart_type)
|
||||
|
||||
|
||||
def all_types() -> list[str]:
|
||||
"""Return all registered chart type strings in insertion order."""
|
||||
_ensure_plugins_loaded()
|
||||
return list(_REGISTRY.keys())
|
||||
|
||||
|
||||
def is_registered(chart_type: str) -> bool:
|
||||
"""Return True if chart_type has a registered plugin."""
|
||||
_ensure_plugins_loaded()
|
||||
return chart_type in _REGISTRY
|
||||
|
||||
|
||||
@@ -84,6 +103,7 @@ def display_name_for_viz_type(viz_type: str) -> str | None:
|
||||
display_name_for_viz_type("pivot_table_v2") # "Pivot Table"
|
||||
display_name_for_viz_type("unknown_type") # None
|
||||
"""
|
||||
_ensure_plugins_loaded()
|
||||
for plugin in _REGISTRY.values():
|
||||
name = plugin.native_viz_types.get(viz_type)
|
||||
if name is not None:
|
||||
@@ -100,13 +120,13 @@ class _RegistryProxy:
|
||||
"""Thin proxy exposing registry functions as instance methods."""
|
||||
|
||||
def get(self, chart_type: str) -> "ChartTypePlugin | None":
|
||||
return _REGISTRY.get(chart_type)
|
||||
return get(chart_type)
|
||||
|
||||
def all_types(self) -> list[str]:
|
||||
return list(_REGISTRY.keys())
|
||||
return all_types()
|
||||
|
||||
def is_registered(self, chart_type: str) -> bool:
|
||||
return chart_type in _REGISTRY
|
||||
return is_registered(chart_type)
|
||||
|
||||
def display_name_for_viz_type(self, viz_type: str) -> str | None:
|
||||
return display_name_for_viz_type(viz_type)
|
||||
|
||||
@@ -22,12 +22,14 @@ Validates that referenced columns exist in the dataset schema.
|
||||
|
||||
import difflib
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, TypeVar
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ChartConfig,
|
||||
ColumnRef,
|
||||
)
|
||||
|
||||
_C = TypeVar("_C", bound=ChartConfig)
|
||||
from superset.mcp_service.common.error_schemas import (
|
||||
ChartGenerationError,
|
||||
ColumnSuggestion,
|
||||
@@ -346,10 +348,10 @@ class DatasetValidator:
|
||||
|
||||
@staticmethod
|
||||
def normalize_column_names(
|
||||
config: ChartConfig,
|
||||
config: _C,
|
||||
dataset_id: int | str,
|
||||
dataset_context: DatasetContext | None = None,
|
||||
) -> ChartConfig:
|
||||
) -> _C:
|
||||
"""
|
||||
Normalize column names in config to match the canonical dataset column names.
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class TestBigNumberChartConfig:
|
||||
"chart_type": "big_number",
|
||||
"metric": {"name": "total_sales", "saved_metric": True},
|
||||
}
|
||||
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
|
||||
is_valid, error = SchemaValidator._pre_validate_chart_type("big_number", data)
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
|
||||
@@ -222,9 +222,9 @@ class TestBigNumberChartFallback:
|
||||
"viz_type": viz_type,
|
||||
}
|
||||
metrics, _ = _extract_metrics_and_groupby(form_data)
|
||||
assert metrics == [{"label": "plural_metric"}], (
|
||||
f"{viz_type} should use plural metrics"
|
||||
)
|
||||
assert metrics == [
|
||||
{"label": "plural_metric"}
|
||||
], f"{viz_type} should use plural metrics"
|
||||
|
||||
def test_pop_kpi_uses_singular_metric(self):
|
||||
"""Test that pop_kpi (BigNumberPeriodOverPeriod) uses singular metric."""
|
||||
|
||||
@@ -117,83 +117,6 @@ class TestGetCanonicalColumnName:
|
||||
assert result == "unknown_column"
|
||||
|
||||
|
||||
class TestNormalizeXYConfig:
|
||||
"""Test _normalize_xy_config static method."""
|
||||
|
||||
def test_normalize_x_axis_column(
|
||||
self, mock_dataset_context: DatasetContext
|
||||
) -> None:
|
||||
"""Test that x-axis column name is normalized."""
|
||||
config_dict: Dict[str, Any] = {
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "orderdate"},
|
||||
"y": [{"name": "Sales", "aggregate": "SUM"}],
|
||||
"kind": "line",
|
||||
}
|
||||
|
||||
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
|
||||
|
||||
assert config_dict["x"]["name"] == "OrderDate"
|
||||
|
||||
def test_normalize_y_axis_columns(
|
||||
self, mock_dataset_context: DatasetContext
|
||||
) -> None:
|
||||
"""Test that y-axis column names are normalized."""
|
||||
config_dict: Dict[str, Any] = {
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "OrderDate"},
|
||||
"y": [
|
||||
{"name": "sales", "aggregate": "SUM"},
|
||||
{"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
|
||||
],
|
||||
"kind": "bar",
|
||||
}
|
||||
|
||||
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
|
||||
|
||||
assert config_dict["y"][0]["name"] == "Sales"
|
||||
assert config_dict["y"][1]["name"] == "quantity_ordered"
|
||||
|
||||
def test_normalize_group_by_column(
|
||||
self, mock_dataset_context: DatasetContext
|
||||
) -> None:
|
||||
"""Test that group_by column name is normalized."""
|
||||
config_dict: Dict[str, Any] = {
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "OrderDate"},
|
||||
"y": [{"name": "Sales", "aggregate": "SUM"}],
|
||||
"kind": "line",
|
||||
"group_by": [{"name": "productline"}],
|
||||
}
|
||||
|
||||
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
|
||||
|
||||
assert config_dict["group_by"][0]["name"] == "ProductLine"
|
||||
|
||||
|
||||
class TestNormalizeTableConfig:
|
||||
"""Test _normalize_table_config static method."""
|
||||
|
||||
def test_normalize_table_columns(
|
||||
self, mock_dataset_context: DatasetContext
|
||||
) -> None:
|
||||
"""Test that table column names are normalized."""
|
||||
config_dict: Dict[str, Any] = {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "orderdate"},
|
||||
{"name": "PRODUCTLINE"},
|
||||
{"name": "sales", "aggregate": "SUM"},
|
||||
],
|
||||
}
|
||||
|
||||
DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)
|
||||
|
||||
assert config_dict["columns"][0]["name"] == "OrderDate"
|
||||
assert config_dict["columns"][1]["name"] == "ProductLine"
|
||||
assert config_dict["columns"][2]["name"] == "Sales"
|
||||
|
||||
|
||||
class TestNormalizeFilters:
|
||||
"""Test _normalize_filters static method."""
|
||||
|
||||
|
||||
@@ -58,12 +58,12 @@ class TestRuntimeValidatorNonBlocking:
|
||||
x_axis=AxisConfig(format="$,.2f"), # Currency format for date - mismatch
|
||||
)
|
||||
|
||||
# Mock the format validator to return warnings
|
||||
# Mock the plugin runtime dispatcher to return format warnings
|
||||
with patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_format_compatibility"
|
||||
) as mock_format:
|
||||
mock_format.return_value = [
|
||||
"_validate_plugin_runtime"
|
||||
) as mock_plugin:
|
||||
mock_plugin.return_value = [
|
||||
"Currency format '$,.2f' may not display dates correctly"
|
||||
]
|
||||
|
||||
@@ -87,15 +87,14 @@ class TestRuntimeValidatorNonBlocking:
|
||||
kind="bar",
|
||||
)
|
||||
|
||||
# Mock the cardinality validator to return warnings
|
||||
# Mock the plugin runtime dispatcher to return cardinality warnings
|
||||
with patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_cardinality"
|
||||
) as mock_cardinality:
|
||||
mock_cardinality.return_value = (
|
||||
["High cardinality detected: 10000+ unique values"],
|
||||
["Consider using aggregation or filtering"],
|
||||
)
|
||||
"_validate_plugin_runtime"
|
||||
) as mock_plugin:
|
||||
mock_plugin.return_value = [
|
||||
"High cardinality detected: 10000+ unique values"
|
||||
]
|
||||
|
||||
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
|
||||
config, 1
|
||||
@@ -148,26 +147,21 @@ class TestRuntimeValidatorNonBlocking:
|
||||
x_axis=AxisConfig(format="smart_date"), # Wrong format for user_id
|
||||
)
|
||||
|
||||
# Mock all validators to return warnings
|
||||
# Mock plugin runtime and chart type validators to return warnings
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_format_compatibility"
|
||||
) as mock_format,
|
||||
patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_cardinality"
|
||||
) as mock_cardinality,
|
||||
"_validate_plugin_runtime"
|
||||
) as mock_plugin,
|
||||
patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_chart_type"
|
||||
) as mock_type,
|
||||
):
|
||||
mock_format.return_value = ["Format mismatch warning"]
|
||||
mock_cardinality.return_value = (
|
||||
["High cardinality warning"],
|
||||
["Cardinality suggestion"],
|
||||
)
|
||||
mock_plugin.return_value = [
|
||||
"Format mismatch warning",
|
||||
"High cardinality warning",
|
||||
]
|
||||
mock_type.return_value = (
|
||||
["Chart type warning"],
|
||||
["Chart type suggestion"],
|
||||
@@ -197,13 +191,13 @@ class TestRuntimeValidatorNonBlocking:
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_format_compatibility"
|
||||
) as mock_format,
|
||||
"_validate_plugin_runtime"
|
||||
) as mock_plugin,
|
||||
patch(
|
||||
"superset.mcp_service.chart.validation.runtime.logger"
|
||||
) as mock_logger,
|
||||
):
|
||||
mock_format.return_value = ["Test warning message"]
|
||||
mock_plugin.return_value = ["Test warning message"]
|
||||
|
||||
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
|
||||
config, 1
|
||||
@@ -217,7 +211,7 @@ class TestRuntimeValidatorNonBlocking:
|
||||
assert "warnings" in warnings_metadata
|
||||
|
||||
def test_validate_table_chart_skips_xy_validations(self):
|
||||
"""Test that table charts skip XY-specific validations."""
|
||||
"""Test that table charts produce no XY-specific runtime warnings."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
@@ -226,28 +220,15 @@ class TestRuntimeValidatorNonBlocking:
|
||||
],
|
||||
)
|
||||
|
||||
# These should not be called for table charts
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_format_compatibility"
|
||||
) as mock_format,
|
||||
patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_cardinality"
|
||||
) as mock_cardinality,
|
||||
patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_chart_type"
|
||||
) as mock_chart_type,
|
||||
):
|
||||
# Mock chart type validator to return no warnings
|
||||
# Plugin runtime dispatches to TableChartPlugin which returns no warnings.
|
||||
# Chart type suggester is also stubbed to return no warnings.
|
||||
with patch(
|
||||
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
|
||||
"_validate_chart_type"
|
||||
) as mock_chart_type:
|
||||
mock_chart_type.return_value = ([], [])
|
||||
|
||||
is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1)
|
||||
|
||||
# Format and cardinality validation should not be called for table charts
|
||||
mock_format.assert_not_called()
|
||||
mock_cardinality.assert_not_called()
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
Reference in New Issue
Block a user