diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index 655c95a5f4d..31fc9537051 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -1132,13 +1132,7 @@ def _big_number_chart_what(config: BigNumberChartConfig) -> str: def generate_chart_name( - config: TableChartConfig - | XYChartConfig - | PieChartConfig - | PivotTableChartConfig - | MixedTimeseriesChartConfig - | HandlebarsChartConfig - | BigNumberChartConfig, + config: Any, dataset_name: str | None = None, ) -> str: """Generate a descriptive chart name following a standard format. @@ -1154,65 +1148,22 @@ def generate_chart_name( An en-dash followed by context (filters / time grain) is appended when such information is available. """ - if isinstance(config, TableChartConfig): - what = _table_chart_what(config, dataset_name) - context = _summarize_filters(config.filters) - elif isinstance(config, XYChartConfig): - what = _xy_chart_what(config) - context = _xy_chart_context(config) - elif isinstance(config, PieChartConfig): - what = _pie_chart_what(config) - context = _summarize_filters(config.filters) - elif isinstance(config, PivotTableChartConfig): - what = _pivot_table_what(config) - context = _summarize_filters(config.filters) - elif isinstance(config, MixedTimeseriesChartConfig): - what = _mixed_timeseries_what(config) - context = _summarize_filters(config.filters) - elif isinstance(config, HandlebarsChartConfig): - what = _handlebars_chart_what(config) - context = _summarize_filters(getattr(config, "filters", None)) - elif isinstance(config, BigNumberChartConfig): - what = _big_number_chart_what(config) - context = _summarize_filters(getattr(config, "filters", None)) - else: - return "Chart" + from superset.mcp_service.chart.registry import get_registry - name = what - if context: - name = f"{what} \u2013 {context}" - return _truncate(name) + plugin = get_registry().get(getattr(config, "chart_type", "")) + if plugin is None: + return "Chart" + return _truncate(plugin.generate_name(config, dataset_name)) def _resolve_viz_type(config: Any) -> str: """Resolve the Superset viz_type from a chart config object.""" - chart_type = getattr(config, "chart_type", "unknown") - if chart_type == "xy": - kind = getattr(config, "kind", "line") - viz_type_map = { - "line": "echarts_timeseries_line", - "bar": "echarts_timeseries_bar", - "area": "echarts_area", - "scatter": "echarts_timeseries_scatter", - } - return viz_type_map.get(kind, "echarts_timeseries_line") - elif chart_type == "table": - return getattr(config, "viz_type", "table") - elif chart_type == "pie": - return "pie" - elif chart_type == "pivot_table": - return "pivot_table_v2" - elif chart_type == "mixed_timeseries": - return "mixed_timeseries" - elif chart_type == "handlebars": - return "handlebars" - elif chart_type == "big_number": - show_trendline = getattr(config, "show_trendline", False) - temporal_column = getattr(config, "temporal_column", None) - return ( - "big_number" if show_trendline and temporal_column else "big_number_total" - ) - return "unknown" + from superset.mcp_service.chart.registry import get_registry + + plugin = get_registry().get(getattr(config, "chart_type", "")) + if plugin is None: + return "unknown" + return plugin.resolve_viz_type(config) def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities: diff --git a/superset/mcp_service/chart/plugin.py b/superset/mcp_service/chart/plugin.py index 0221c079032..b310bc1da07 100644 --- a/superset/mcp_service/chart/plugin.py +++ b/superset/mcp_service/chart/plugin.py @@ -148,6 +148,30 @@ class ChartTypePlugin(Protocol): """ ... + def generate_name( + self, + config: Any, + dataset_name: str | None = None, + ) -> str: + """ + Return a descriptive chart name for the given config. + + Called by chart_utils.generate_chart_name(). The name should follow + the standard format conventions documented in that function. Plugins + that do not override this return the generic fallback "Chart". + """ + ... + + def resolve_viz_type(self, config: Any) -> str: + """ + Return the Superset-internal viz_type string for this config. + + Called by chart_utils._resolve_viz_type(). The returned string must + match a registered Superset viz plugin (e.g. "echarts_timeseries_line"). + Plugins that do not override this return "unknown". + """ + ... + class BaseChartPlugin: """ @@ -202,3 +226,13 @@ class BaseChartPlugin: dataset_id: int | str, ) -> list[str]: return [] + + def generate_name( + self, + config: Any, + dataset_name: str | None = None, + ) -> str: + return "Chart" + + def resolve_viz_type(self, config: Any) -> str: + return "unknown" diff --git a/superset/mcp_service/chart/plugins/big_number.py b/superset/mcp_service/chart/plugins/big_number.py index 6c7fc199f28..d924cd7735f 100644 --- a/superset/mcp_service/chart/plugins/big_number.py +++ b/superset/mcp_service/chart/plugins/big_number.py @@ -176,6 +176,23 @@ class BigNumberChartPlugin(BaseChartPlugin): return None + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + from superset.mcp_service.chart.chart_utils import ( + _big_number_chart_what, + _summarize_filters, + ) + + what = _big_number_chart_what(config) + context = _summarize_filters(getattr(config, "filters", None)) + return f"{what} \u2013 {context}" if context else what + + def resolve_viz_type(self, config: Any) -> str: + show_trendline = getattr(config, "show_trendline", False) + temporal_column = getattr(config, "temporal_column", None) + if show_trendline and temporal_column: + return "big_number" + return "big_number_total" + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: from superset.mcp_service.chart.schemas import BigNumberChartConfig from superset.mcp_service.chart.validation.dataset_validator import ( diff --git a/superset/mcp_service/chart/plugins/handlebars.py b/superset/mcp_service/chart/plugins/handlebars.py index 918110a0c37..d177ede9457 100644 --- a/superset/mcp_service/chart/plugins/handlebars.py +++ b/superset/mcp_service/chart/plugins/handlebars.py @@ -142,6 +142,19 @@ class HandlebarsChartPlugin(BaseChartPlugin): return map_handlebars_config(config) + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + from superset.mcp_service.chart.chart_utils import ( + _handlebars_chart_what, + _summarize_filters, + ) + + what = _handlebars_chart_what(config) + context = _summarize_filters(getattr(config, "filters", None)) + return f"{what} \u2013 {context}" if context else what + + def resolve_viz_type(self, config: Any) -> str: + return "handlebars" + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: from superset.mcp_service.chart.schemas import HandlebarsChartConfig from superset.mcp_service.chart.validation.dataset_validator import ( diff --git a/superset/mcp_service/chart/plugins/mixed_timeseries.py b/superset/mcp_service/chart/plugins/mixed_timeseries.py index 1560948701e..706df74499d 100644 --- a/superset/mcp_service/chart/plugins/mixed_timeseries.py +++ b/superset/mcp_service/chart/plugins/mixed_timeseries.py @@ -110,6 +110,19 @@ class MixedTimeseriesChartPlugin(BaseChartPlugin): return map_mixed_timeseries_config(config, dataset_id=dataset_id) + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + from superset.mcp_service.chart.chart_utils import ( + _mixed_timeseries_what, + _summarize_filters, + ) + + what = _mixed_timeseries_what(config) + context = _summarize_filters(config.filters) + return f"{what} \u2013 {context}" if context else what + + def resolve_viz_type(self, config: Any) -> str: + return "mixed_timeseries" + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: from superset.mcp_service.chart.schemas import MixedTimeseriesChartConfig from superset.mcp_service.chart.validation.dataset_validator import ( diff --git a/superset/mcp_service/chart/plugins/pie.py b/superset/mcp_service/chart/plugins/pie.py index 132e169b66d..0160e134a3d 100644 --- a/superset/mcp_service/chart/plugins/pie.py +++ b/superset/mcp_service/chart/plugins/pie.py @@ -84,6 +84,19 @@ class PieChartPlugin(BaseChartPlugin): return map_pie_config(config) + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + from superset.mcp_service.chart.chart_utils import ( + _pie_chart_what, + _summarize_filters, + ) + + what = _pie_chart_what(config) + context = _summarize_filters(config.filters) + return f"{what} \u2013 {context}" if context else what + + def resolve_viz_type(self, config: Any) -> str: + return "pie" + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: from superset.mcp_service.chart.schemas import PieChartConfig from superset.mcp_service.chart.validation.dataset_validator import ( diff --git a/superset/mcp_service/chart/plugins/pivot_table.py b/superset/mcp_service/chart/plugins/pivot_table.py index c9b55c59115..e514a3307a0 100644 --- a/superset/mcp_service/chart/plugins/pivot_table.py +++ b/superset/mcp_service/chart/plugins/pivot_table.py @@ -107,6 +107,19 @@ class PivotTableChartPlugin(BaseChartPlugin): return map_pivot_table_config(config) + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + from superset.mcp_service.chart.chart_utils import ( + _pivot_table_what, + _summarize_filters, + ) + + what = _pivot_table_what(config) + context = _summarize_filters(config.filters) + return f"{what} \u2013 {context}" if context else what + + def resolve_viz_type(self, config: Any) -> str: + return "pivot_table_v2" + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: from superset.mcp_service.chart.schemas import PivotTableChartConfig from superset.mcp_service.chart.validation.dataset_validator import ( diff --git a/superset/mcp_service/chart/plugins/table.py b/superset/mcp_service/chart/plugins/table.py index 16152ef9baf..ee7d9f27d4b 100644 --- a/superset/mcp_service/chart/plugins/table.py +++ b/superset/mcp_service/chart/plugins/table.py @@ -89,6 +89,19 @@ class TableChartPlugin(BaseChartPlugin): return map_table_config(config) + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + from superset.mcp_service.chart.chart_utils import ( + _summarize_filters, + _table_chart_what, + ) + + what = _table_chart_what(config, dataset_name) + context = _summarize_filters(config.filters) + return f"{what} \u2013 {context}" if context else what + + def resolve_viz_type(self, config: Any) -> str: + return getattr(config, "viz_type", "table") + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: from superset.mcp_service.chart.schemas import TableChartConfig from superset.mcp_service.chart.validation.dataset_validator import ( @@ -96,6 +109,10 @@ class TableChartPlugin(BaseChartPlugin): ) config_dict = config.model_dump() - DatasetValidator._normalize_table_config(config_dict, dataset_context) + get_canonical = DatasetValidator._get_canonical_column_name + + for col in config_dict.get("columns") or []: + col["name"] = get_canonical(col["name"], dataset_context) + DatasetValidator._normalize_filters(config_dict, dataset_context) return TableChartConfig.model_validate(config_dict) diff --git a/superset/mcp_service/chart/plugins/xy.py b/superset/mcp_service/chart/plugins/xy.py index 75eea3f124d..92088ed66ab 100644 --- a/superset/mcp_service/chart/plugins/xy.py +++ b/superset/mcp_service/chart/plugins/xy.py @@ -19,12 +19,15 @@ from __future__ import annotations +import logging from typing import Any from superset.mcp_service.chart.plugin import BaseChartPlugin from superset.mcp_service.chart.schemas import ColumnRef from superset.mcp_service.common.error_schemas import ChartGenerationError +logger = logging.getLogger(__name__) + class XYChartPlugin(BaseChartPlugin): """Plugin for xy chart type (line, bar, area, scatter).""" @@ -105,17 +108,43 @@ class XYChartPlugin(BaseChartPlugin): ) config_dict = config.model_dump() - DatasetValidator._normalize_xy_config(config_dict, dataset_context) + get_canonical = DatasetValidator._get_canonical_column_name + + if config_dict.get("x"): + config_dict["x"]["name"] = get_canonical( + config_dict["x"]["name"], dataset_context + ) + for y_col in config_dict.get("y") or []: + y_col["name"] = get_canonical(y_col["name"], dataset_context) + for gb_col in config_dict.get("group_by") or []: + gb_col["name"] = get_canonical(gb_col["name"], dataset_context) + DatasetValidator._normalize_filters(config_dict, dataset_context) return XYChartConfig.model_validate(config_dict) + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + from superset.mcp_service.chart.chart_utils import ( + _xy_chart_context, + _xy_chart_what, + ) + + what = _xy_chart_what(config) + context = _xy_chart_context(config) + return f"{what} \u2013 {context}" if context else what + + def resolve_viz_type(self, config: Any) -> str: + kind = getattr(config, "kind", "line") + return { + "line": "echarts_timeseries_line", + "bar": "echarts_timeseries_bar", + "area": "echarts_area", + "scatter": "echarts_timeseries_scatter", + }.get(kind, "echarts_timeseries_line") + def get_runtime_warnings(self, config: Any, dataset_id: int | str) -> list[str]: """Return format-compatibility and cardinality warnings for XY charts.""" - import logging - from superset.mcp_service.chart.schemas import XYChartConfig - logger = logging.getLogger(__name__) if not isinstance(config, XYChartConfig): return [] diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py index b4dbbc489fc..f1ef0eacabf 100644 --- a/superset/mcp_service/chart/validation/dataset_validator.py +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -330,42 +330,6 @@ class DatasetValidator: # Return original if not found (validation should catch this case) return column_name - @staticmethod - def _normalize_xy_config( - config_dict: Dict[str, Any], dataset_context: DatasetContext - ) -> None: - """Normalize column names in an XY chart config dict in place.""" - # Normalize x-axis column - if "x" in config_dict and config_dict["x"]: - config_dict["x"]["name"] = DatasetValidator._get_canonical_column_name( - config_dict["x"]["name"], dataset_context - ) - - # Normalize y-axis columns - if "y" in config_dict and config_dict["y"]: - for y_col in config_dict["y"]: - y_col["name"] = DatasetValidator._get_canonical_column_name( - y_col["name"], dataset_context - ) - - # Normalize group_by columns - if "group_by" in config_dict and config_dict["group_by"]: - for gb_col in config_dict["group_by"]: - gb_col["name"] = DatasetValidator._get_canonical_column_name( - gb_col["name"], dataset_context - ) - - @staticmethod - def _normalize_table_config( - config_dict: Dict[str, Any], dataset_context: DatasetContext - ) -> None: - """Normalize column names in a table chart config dict in place.""" - if "columns" in config_dict and config_dict["columns"]: - for col in config_dict["columns"]: - col["name"] = DatasetValidator._get_canonical_column_name( - col["name"], dataset_context - ) - @staticmethod def _normalize_filters( config_dict: Dict[str, Any], dataset_context: DatasetContext diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py index d5ee9651a11..969779c26a8 100644 --- a/superset/mcp_service/chart/validation/schema_validator.py +++ b/superset/mcp_service/chart/validation/schema_validator.py @@ -186,353 +186,12 @@ class SchemaValidator: return False, error return True, None - @staticmethod - def _pre_validate_xy_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate XY chart configuration.""" - # x is optional — defaults to dataset's main_dttm_col in map_xy_config - if "y" not in config: - return False, ChartGenerationError( - error_type="missing_xy_fields", - message="XY chart missing required field: 'y' (Y-axis metrics)", - details="XY charts require Y-axis (metrics) specifications. " - "X-axis is optional and defaults to the dataset's primary " - "datetime column when omitted.", - suggestions=[ - "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] " - "for Y-axis", - "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, " - "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}", - ], - error_code="MISSING_XY_FIELDS", - ) - - # Validate Y is a list - if not isinstance(config.get("y", []), list): - return False, ChartGenerationError( - error_type="invalid_y_format", - message="Y-axis must be a list of metrics", - details="The 'y' field must be an array of metric specifications", - suggestions=[ - "Wrap Y-axis metric in array: 'y': [{'name': 'column', " - "'aggregate': 'SUM'}]", - "Multiple metrics supported: 'y': [metric1, metric2, ...]", - ], - error_code="INVALID_Y_FORMAT", - ) - - return True, None - - @staticmethod - def _pre_validate_table_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate table chart configuration.""" - if "columns" not in config: - return False, ChartGenerationError( - error_type="missing_columns", - message="Table chart missing required field: columns", - details="Table charts require a 'columns' array to specify which " - "columns to display", - suggestions=[ - "Add 'columns' field with array of column specifications", - "Example: 'columns': [{'name': 'product'}, {'name': 'sales', " - "'aggregate': 'SUM'}]", - "Each column can have optional 'aggregate' for metrics", - ], - error_code="MISSING_COLUMNS", - ) - - if not isinstance(config.get("columns", []), list): - return False, ChartGenerationError( - error_type="invalid_columns_format", - message="Columns must be a list", - details="The 'columns' field must be an array of column specifications", - suggestions=[ - "Ensure columns is an array: 'columns': [...]", - "Each column should be an object with 'name' field", - ], - error_code="INVALID_COLUMNS_FORMAT", - ) - - return True, None - - @staticmethod - def _pre_validate_pie_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate pie chart configuration.""" - missing_fields = [] - - if "dimension" not in config: - missing_fields.append("'dimension' (category column for slices)") - if "metric" not in config: - missing_fields.append("'metric' (value metric for slice sizes)") - - if missing_fields: - return False, ChartGenerationError( - error_type="missing_pie_fields", - message=f"Pie chart missing required " - f"fields: {', '.join(missing_fields)}", - details="Pie charts require a dimension (categories) and a metric " - "(values)", - suggestions=[ - "Add 'dimension' field: {'name': 'category_column'}", - "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}", - "Example: {'chart_type': 'pie', 'dimension': {'name': " - "'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}", - ], - error_code="MISSING_PIE_FIELDS", - ) - - return True, None - - @staticmethod - def _pre_validate_handlebars_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate handlebars chart configuration.""" - if "handlebars_template" not in config: - return False, ChartGenerationError( - error_type="missing_handlebars_template", - message="Handlebars chart missing required field: handlebars_template", - details="Handlebars charts require a 'handlebars_template' string " - "containing Handlebars HTML template markup", - suggestions=[ - "Add 'handlebars_template' with a Handlebars HTML template", - "Data is available as {{data}} array in the template", - "Example: '