From c853d4df63243412fc41cee2ba8b303ecbb45c4f Mon Sep 17 00:00:00 2001 From: Kamil Gabryjelski Date: Mon, 30 Mar 2026 16:38:31 +0200 Subject: [PATCH] feat(mcp): support saved metrics from datasets in chart generation (#38955) Co-authored-by: Claude Opus 4.6 (1M context) (cherry picked from commit 15bab227bb00f34ab4b4bb87844f0916e839953a) --- superset/mcp_service/chart/chart_utils.py | 24 +++-- superset/mcp_service/chart/schemas.py | 35 +++++- .../chart/validation/dataset_validator.py | 101 ++++++++++++++---- .../mcp_service/chart/test_chart_schemas.py | 55 ++++++++++ .../mcp_service/chart/test_chart_utils.py | 88 +++++++++++++++ .../test_column_name_normalization.py | 50 +++++++++ 6 files changed, 323 insertions(+), 30 deletions(-) diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index e1292037d3f..0a4b19b0cee 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -381,8 +381,8 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]: aggregated_metrics = [] for col in config.columns: - if col.aggregate: - # Column has aggregation - treat as metric + if col.is_metric: + # Saved metric or column with aggregation - treat as metric aggregated_metrics.append(create_metric_object(col)) else: # No aggregation - treat as raw column @@ -441,8 +441,16 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]: return form_data -def create_metric_object(col: ColumnRef) -> Dict[str, Any]: - """Create a metric object for a column with enhanced validation.""" +def create_metric_object(col: ColumnRef) -> Dict[str, Any] | str: + """Create a metric object for a column with enhanced validation. + + For saved metrics, returns the metric name as a plain string which + Superset's query engine resolves via its metrics_by_name lookup. + For ad-hoc metrics, returns a SIMPLE expression dict. + """ + if col.saved_metric: + return col.name + # Ensure aggregate is valid - default to SUM if not specified or invalid valid_aggregates = { "SUM", @@ -840,6 +848,8 @@ def _humanize_column(col: ColumnRef) -> str: if col.label: return col.label name = col.name.replace("_", " ").title() + if col.saved_metric: + return name if col.aggregate: return f"{col.aggregate.capitalize()}({name})" return name @@ -874,9 +884,9 @@ def _truncate(name: str, max_length: int = 60) -> str: def _table_chart_what(config: TableChartConfig, dataset_name: str | None) -> str: """Build the descriptive fragment for a table chart.""" - has_agg = any(col.aggregate for col in config.columns) + has_agg = any(col.is_metric for col in config.columns) if has_agg: - metrics = [col for col in config.columns if col.aggregate] + metrics = [col for col in config.columns if col.is_metric] what = ", ".join(_humanize_column(m) for m in metrics[:2]) return f"{what} Summary" if dataset_name: @@ -1073,7 +1083,7 @@ def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilit # Classify data types data_types = [] if hasattr(config, "x") and config.x: - data_types.append("categorical" if not config.x.aggregate else "metric") + data_types.append("categorical" if not config.x.is_metric else "metric") if hasattr(config, "y") and config.y: data_types.extend(["metric"] * len(config.y)) if "time" in viz_type or "timeseries" in viz_type: diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index edf952dffb9..8914660c4fa 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -485,6 +485,24 @@ class ColumnRef(BaseModel): ] | None ) = Field(None, description="SQL aggregate function") + saved_metric: bool = Field( + False, + description="If true, 'name' refers to a saved metric from the dataset " + "(use get_dataset_info to see available metrics). " + "When set, 'aggregate' is ignored.", + ) + + @property + def is_metric(self) -> bool: + """Whether this ref acts as a metric (has aggregate or is a saved metric).""" + return bool(self.aggregate) or self.saved_metric + + @model_validator(mode="after") + def clear_aggregate_for_saved_metric(self) -> "ColumnRef": + """Clear aggregate when saved_metric is True since it's ignored.""" + if self.saved_metric and self.aggregate is not None: + self.aggregate = None + return self @field_validator("name") @classmethod @@ -592,7 +610,9 @@ class PieChartConfig(UnknownFieldCheckMixin): validation_alias=AliasChoices("dimension", "groupby"), ) metric: ColumnRef = Field( - ..., description="Value metric (needs aggregate e.g. SUM, COUNT)" + ..., + description="Value metric (use aggregate e.g. SUM, COUNT for ad-hoc, " + "or set saved_metric=True for a saved dataset metric)", ) donut: bool = False show_labels: bool = True @@ -638,7 +658,8 @@ class PivotTableChartConfig(UnknownFieldCheckMixin): metrics: List[ColumnRef] = Field( ..., min_length=1, - description="Metrics (need aggregate e.g. SUM, COUNT, AVG)", + description="Metrics (use aggregate e.g. SUM, COUNT, AVG for ad-hoc, " + "or set saved_metric=True for saved dataset metrics)", ) aggregate_function: Literal[ "Sum", @@ -818,7 +839,7 @@ class HandlebarsChartConfig(UnknownFieldCheckMixin): "Handlebars chart in 'aggregate' query mode requires 'metrics' " "field. Specify at least one metric with an aggregate function." ) - missing_agg = [m.name for m in self.metrics if not m.aggregate] + missing_agg = [m.name for m in self.metrics if not m.is_metric] if missing_agg: raise ValueError( f"Handlebars chart in 'aggregate' query mode requires an " @@ -861,7 +882,9 @@ class TableChartConfig(UnknownFieldCheckMixin): for i, col in enumerate(self.columns): # Generate the label that will be used (same logic as create_metric_object) - if col.aggregate: + if col.saved_metric: + label = col.label or col.name + elif col.aggregate: label = col.label or f"{col.aggregate}({col.name})" else: label = col.label or col.name @@ -943,7 +966,9 @@ class XYChartConfig(UnknownFieldCheckMixin): # Check Y-axis labels for i, col in enumerate(self.y): - if col.aggregate: + if col.saved_metric: + label = col.label or col.name + elif col.aggregate: label = col.label or f"{col.aggregate}({col.name})" else: label = col.label or col.name diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py index 4f7e78cc4f0..202d7cc5afc 100644 --- a/superset/mcp_service/chart/validation/dataset_validator.py +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -82,25 +82,19 @@ class DatasetValidator: # Collect all column references column_refs = DatasetValidator._extract_column_references(config) - # Validate each column exists - invalid_columns = [] - for col_ref in column_refs: - if not DatasetValidator._column_exists(col_ref.name, dataset_context): - invalid_columns.append(col_ref) + # Validate saved metrics exist in dataset metrics specifically + invalid_saved = DatasetValidator._validate_saved_metrics( + column_refs, dataset_context + ) + if invalid_saved: + return False, invalid_saved - if invalid_columns: - # Generate suggestions for invalid columns - suggestions_map = {} - for col_ref in invalid_columns: - suggestions = DatasetValidator._get_column_suggestions( - col_ref.name, dataset_context - ) - suggestions_map[col_ref.name] = suggestions - - # Build error with suggestions - return False, DatasetValidator._build_column_error( - invalid_columns, suggestions_map, dataset_context - ) + # Validate columns exist (skip saved metrics — already validated above) + column_error = DatasetValidator._validate_columns_exist( + column_refs, dataset_context + ) + if column_error: + return False, column_error # Validate aggregation compatibility if isinstance(config, (TableChartConfig, XYChartConfig)): @@ -112,6 +106,32 @@ class DatasetValidator: return True, None + @staticmethod + def _validate_columns_exist( + column_refs: List[ColumnRef], dataset_context: DatasetContext + ) -> ChartGenerationError | None: + """Validate that non-saved-metric column refs exist in the dataset.""" + invalid_columns = [] + for col_ref in column_refs: + if col_ref.saved_metric: + continue + if not DatasetValidator._column_exists(col_ref.name, dataset_context): + invalid_columns.append(col_ref) + + if not invalid_columns: + return None + + suggestions_map = {} + for col_ref in invalid_columns: + suggestions = DatasetValidator._get_column_suggestions( + col_ref.name, dataset_context + ) + suggestions_map[col_ref.name] = suggestions + + return DatasetValidator._build_column_error( + invalid_columns, suggestions_map, dataset_context + ) + @staticmethod def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None: """Get dataset context with column information.""" @@ -418,6 +438,49 @@ class DatasetValidator: error_code="MULTIPLE_INVALID_COLUMNS", ) + @staticmethod + def _validate_saved_metrics( + column_refs: List[ColumnRef], dataset_context: DatasetContext + ) -> ChartGenerationError | None: + """Validate that saved_metric refs exist in dataset metrics. + + A ColumnRef with saved_metric=True must match an entry in + available_metrics, not just available_columns. Without this check + a regular column name marked as saved_metric would pass + _column_exists (which checks both lists) but fail at query time. + """ + metric_names = {m["name"].lower() for m in dataset_context.available_metrics} + invalid = [ + col_ref.name + for col_ref in column_refs + if col_ref.saved_metric and col_ref.name.lower() not in metric_names + ] + if not invalid: + return None + + from superset.mcp_service.utils.error_builder import ChartErrorBuilder + + available = [m["name"] for m in dataset_context.available_metrics] + return ChartErrorBuilder.build_error( + error_type="invalid_saved_metric", + template_key="column_not_found", + template_vars={ + "column": ", ".join(invalid), + "suggestions": ( + f"Available saved metrics: {', '.join(available[:10])}" + if available + else "This dataset has no saved metrics" + ), + }, + custom_suggestions=[ + f"'{name}' is not a saved metric in this dataset. " + "Remove saved_metric=True to use it as a column with an aggregate, " + "or use get_dataset_info to see available saved metrics." + for name in invalid + ], + error_code="INVALID_SAVED_METRIC", + ) + @staticmethod def _validate_aggregations( column_refs: List[ColumnRef], dataset_context: DatasetContext @@ -426,6 +489,8 @@ class DatasetValidator: errors = [] for col_ref in column_refs: + if col_ref.saved_metric: + continue # Saved metrics have built-in aggregation if not col_ref.aggregate: continue diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index 73bf5ae9f14..840fc01d3ad 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -625,3 +625,58 @@ class TestUnknownFieldDetection: assert config.stacked is True assert config.row_limit == 10000 assert config.group_by is not None + + +class TestColumnRefSavedMetric: + """Test ColumnRef saved_metric support.""" + + def test_saved_metric_defaults_to_false(self) -> None: + col = ColumnRef(name="revenue", aggregate="SUM") + assert col.saved_metric is False + + def test_saved_metric_flag_accepted(self) -> None: + col = ColumnRef(name="total_revenue", saved_metric=True) + assert col.saved_metric is True + assert col.name == "total_revenue" + + def test_saved_metric_clears_aggregate(self) -> None: + col = ColumnRef(name="total_revenue", saved_metric=True, aggregate="SUM") + assert col.saved_metric is True + assert col.aggregate is None + + def test_saved_metric_preserves_label(self) -> None: + col = ColumnRef(name="total_revenue", saved_metric=True, label="Revenue") + assert col.label == "Revenue" + + def test_saved_metric_in_table_config_unique_labels(self) -> None: + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="product_line"), + ColumnRef(name="total_revenue", saved_metric=True), + ], + ) + assert len(config.columns) == 2 + + def test_saved_metric_in_xy_config_unique_labels(self) -> None: + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="order_date"), + y=[ColumnRef(name="total_revenue", saved_metric=True)], + ) + assert len(config.y) == 1 + + def test_saved_metric_duplicate_label_rejected(self) -> None: + with pytest.raises(ValidationError, match="Duplicate column/metric labels"): + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="order_date"), + y=[ + ColumnRef(name="total_revenue", saved_metric=True), + ColumnRef( + name="other_metric", + saved_metric=True, + label="total_revenue", + ), + ], + ) diff --git a/tests/unit_tests/mcp_service/chart/test_chart_utils.py b/tests/unit_tests/mcp_service/chart/test_chart_utils.py index 1a247bf1e89..53ff6847983 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_utils.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_utils.py @@ -55,6 +55,7 @@ class TestCreateMetricObject: col = ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue") result = create_metric_object(col) + assert isinstance(result, dict) assert result["aggregate"] == "SUM" assert result["column"]["column_name"] == "revenue" assert result["label"] == "Total Revenue" @@ -66,11 +67,28 @@ class TestCreateMetricObject: col = ColumnRef(name="orders") result = create_metric_object(col) + assert isinstance(result, dict) assert result["aggregate"] == "SUM" assert result["column"]["column_name"] == "orders" assert result["label"] == "SUM(orders)" assert result["optionName"] == "metric_orders" + def test_create_metric_object_saved_metric_returns_string(self) -> None: + """Test that saved metrics return a plain string metric name""" + col = ColumnRef(name="total_revenue", saved_metric=True) + result = create_metric_object(col) + + assert result == "total_revenue" + assert isinstance(result, str) + + def test_create_metric_object_saved_metric_ignores_aggregate(self) -> None: + """Test that saved metrics ignore aggregate even if somehow set""" + col = ColumnRef(name="total_revenue", saved_metric=True, aggregate="SUM") + result = create_metric_object(col) + + # saved_metric validator clears aggregate, result is plain string + assert result == "total_revenue" + class TestMapFilterOperator: """Test map_filter_operator function""" @@ -338,6 +356,38 @@ class TestMapTableConfig: assert result["row_limit"] == 1000 + def test_map_table_config_saved_metric_as_metric(self) -> None: + """Test that saved metrics are routed to metrics, not raw columns.""" + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="product_line"), + ColumnRef(name="total_revenue", saved_metric=True), + ], + ) + + result = map_table_config(config) + + assert result["query_mode"] == "aggregate" + assert result["metrics"] == ["total_revenue"] + assert "product_line" in result["groupby"] + + def test_map_table_config_saved_metric_only(self) -> None: + """Test table with only saved metrics (no raw columns).""" + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="total_revenue", saved_metric=True), + ColumnRef(name="avg_order_value", saved_metric=True), + ], + ) + + result = map_table_config(config) + + assert result["query_mode"] == "aggregate" + assert result["metrics"] == ["total_revenue", "avg_order_value"] + assert "all_columns" not in result + class TestAddAdhocFilters: """Test _add_adhoc_filters helper function""" @@ -651,6 +701,44 @@ class TestMapXYConfig: assert result["row_limit"] == 10000 + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_map_xy_config_saved_metric(self, mock_is_temporal: Any) -> None: + """Test XY config with saved metric emits string in metrics list""" + mock_is_temporal.return_value = True + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="order_date"), + y=[ColumnRef(name="total_revenue", saved_metric=True)], + kind="line", + ) + + result = map_xy_config(config, dataset_id=1) + + assert result["metrics"] == ["total_revenue"] + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_map_xy_config_mixed_saved_and_adhoc_metrics( + self, mock_is_temporal: Any + ) -> None: + """Test XY config with both saved and ad-hoc metrics""" + mock_is_temporal.return_value = True + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="order_date"), + y=[ + ColumnRef(name="total_revenue", saved_metric=True), + ColumnRef(name="quantity", aggregate="SUM"), + ], + kind="line", + ) + + result = map_xy_config(config, dataset_id=1) + + assert len(result["metrics"]) == 2 + assert result["metrics"][0] == "total_revenue" + assert isinstance(result["metrics"][1], dict) + assert result["metrics"][1]["aggregate"] == "SUM" + class TestMapConfigToFormData: """Test map_config_to_form_data function""" diff --git a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py index 8c68f738ed1..ab663ebc811 100644 --- a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py +++ b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py @@ -679,3 +679,53 @@ class TestNormalizeXAxisFilterConsistency: assert normalized.group_by is not None assert normalized.filters is not None assert normalized.group_by[0].name == normalized.filters[0].column == "AIRLINE" + + +class TestValidateSavedMetrics: + """Test that saved_metric refs are validated against dataset metrics.""" + + def test_valid_saved_metric_passes( + self, mock_dataset_context: DatasetContext + ) -> None: + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="OrderDate"), + y=[ColumnRef(name="TotalRevenue", saved_metric=True)], + ) + is_valid, error = DatasetValidator.validate_against_dataset( + config, dataset_id=18, dataset_context=mock_dataset_context + ) + assert is_valid + assert error is None + + def test_column_name_as_saved_metric_fails( + self, mock_dataset_context: DatasetContext + ) -> None: + """A regular column marked as saved_metric should be rejected.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="OrderDate"), + y=[ColumnRef(name="Sales", saved_metric=True)], + ) + is_valid, error = DatasetValidator.validate_against_dataset( + config, dataset_id=18, dataset_context=mock_dataset_context + ) + assert not is_valid + assert error is not None + assert error.error_code == "INVALID_SAVED_METRIC" + + def test_nonexistent_saved_metric_fails( + self, mock_dataset_context: DatasetContext + ) -> None: + """A nonexistent saved metric should produce a specific error.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="OrderDate"), + y=[ColumnRef(name="nonexistent_metric", saved_metric=True)], + ) + is_valid, error = DatasetValidator.validate_against_dataset( + config, dataset_id=18, dataset_context=mock_dataset_context + ) + assert not is_valid + assert error is not None + assert error.error_code == "INVALID_SAVED_METRIC"