feat(mcp): support saved metrics from datasets in chart generation (#38955)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kamil Gabryjelski
2026-03-30 16:38:31 +02:00
committed by GitHub
parent d331a043a3
commit 15bab227bb
6 changed files with 323 additions and 30 deletions

View File

@@ -570,3 +570,58 @@ class TestUnknownFieldDetection:
assert config.stacked is True
assert config.row_limit == 10000
assert config.group_by is not None
class TestColumnRefSavedMetric:
"""Test ColumnRef saved_metric support."""
def test_saved_metric_defaults_to_false(self) -> None:
col = ColumnRef(name="revenue", aggregate="SUM")
assert col.saved_metric is False
def test_saved_metric_flag_accepted(self) -> None:
col = ColumnRef(name="total_revenue", saved_metric=True)
assert col.saved_metric is True
assert col.name == "total_revenue"
def test_saved_metric_clears_aggregate(self) -> None:
col = ColumnRef(name="total_revenue", saved_metric=True, aggregate="SUM")
assert col.saved_metric is True
assert col.aggregate is None
def test_saved_metric_preserves_label(self) -> None:
col = ColumnRef(name="total_revenue", saved_metric=True, label="Revenue")
assert col.label == "Revenue"
def test_saved_metric_in_table_config_unique_labels(self) -> None:
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="product_line"),
ColumnRef(name="total_revenue", saved_metric=True),
],
)
assert len(config.columns) == 2
def test_saved_metric_in_xy_config_unique_labels(self) -> None:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="total_revenue", saved_metric=True)],
)
assert len(config.y) == 1
def test_saved_metric_duplicate_label_rejected(self) -> None:
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
XYChartConfig(
chart_type="xy",
x=ColumnRef(name="order_date"),
y=[
ColumnRef(name="total_revenue", saved_metric=True),
ColumnRef(
name="other_metric",
saved_metric=True,
label="total_revenue",
),
],
)

View File

@@ -55,6 +55,7 @@ class TestCreateMetricObject:
col = ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue")
result = create_metric_object(col)
assert isinstance(result, dict)
assert result["aggregate"] == "SUM"
assert result["column"]["column_name"] == "revenue"
assert result["label"] == "Total Revenue"
@@ -66,11 +67,28 @@ class TestCreateMetricObject:
col = ColumnRef(name="orders")
result = create_metric_object(col)
assert isinstance(result, dict)
assert result["aggregate"] == "SUM"
assert result["column"]["column_name"] == "orders"
assert result["label"] == "SUM(orders)"
assert result["optionName"] == "metric_orders"
def test_create_metric_object_saved_metric_returns_string(self) -> None:
"""Test that saved metrics return a plain string metric name"""
col = ColumnRef(name="total_revenue", saved_metric=True)
result = create_metric_object(col)
assert result == "total_revenue"
assert isinstance(result, str)
def test_create_metric_object_saved_metric_ignores_aggregate(self) -> None:
"""Test that saved metrics ignore aggregate even if somehow set"""
col = ColumnRef(name="total_revenue", saved_metric=True, aggregate="SUM")
result = create_metric_object(col)
# saved_metric validator clears aggregate, result is plain string
assert result == "total_revenue"
class TestMapFilterOperator:
"""Test map_filter_operator function"""
@@ -338,6 +356,38 @@ class TestMapTableConfig:
assert result["row_limit"] == 1000
def test_map_table_config_saved_metric_as_metric(self) -> None:
"""Test that saved metrics are routed to metrics, not raw columns."""
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="product_line"),
ColumnRef(name="total_revenue", saved_metric=True),
],
)
result = map_table_config(config)
assert result["query_mode"] == "aggregate"
assert result["metrics"] == ["total_revenue"]
assert "product_line" in result["groupby"]
def test_map_table_config_saved_metric_only(self) -> None:
"""Test table with only saved metrics (no raw columns)."""
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="total_revenue", saved_metric=True),
ColumnRef(name="avg_order_value", saved_metric=True),
],
)
result = map_table_config(config)
assert result["query_mode"] == "aggregate"
assert result["metrics"] == ["total_revenue", "avg_order_value"]
assert "all_columns" not in result
class TestAddAdhocFilters:
"""Test _add_adhoc_filters helper function"""
@@ -651,6 +701,44 @@ class TestMapXYConfig:
assert result["row_limit"] == 10000
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_map_xy_config_saved_metric(self, mock_is_temporal: Any) -> None:
"""Test XY config with saved metric emits string in metrics list"""
mock_is_temporal.return_value = True
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="total_revenue", saved_metric=True)],
kind="line",
)
result = map_xy_config(config, dataset_id=1)
assert result["metrics"] == ["total_revenue"]
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_map_xy_config_mixed_saved_and_adhoc_metrics(
self, mock_is_temporal: Any
) -> None:
"""Test XY config with both saved and ad-hoc metrics"""
mock_is_temporal.return_value = True
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="order_date"),
y=[
ColumnRef(name="total_revenue", saved_metric=True),
ColumnRef(name="quantity", aggregate="SUM"),
],
kind="line",
)
result = map_xy_config(config, dataset_id=1)
assert len(result["metrics"]) == 2
assert result["metrics"][0] == "total_revenue"
assert isinstance(result["metrics"][1], dict)
assert result["metrics"][1]["aggregate"] == "SUM"
class TestMapConfigToFormData:
"""Test map_config_to_form_data function"""

View File

@@ -679,3 +679,53 @@ class TestNormalizeXAxisFilterConsistency:
assert normalized.group_by is not None
assert normalized.filters is not None
assert normalized.group_by[0].name == normalized.filters[0].column == "AIRLINE"
class TestValidateSavedMetrics:
"""Test that saved_metric refs are validated against dataset metrics."""
def test_valid_saved_metric_passes(
self, mock_dataset_context: DatasetContext
) -> None:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="OrderDate"),
y=[ColumnRef(name="TotalRevenue", saved_metric=True)],
)
is_valid, error = DatasetValidator.validate_against_dataset(
config, dataset_id=18, dataset_context=mock_dataset_context
)
assert is_valid
assert error is None
def test_column_name_as_saved_metric_fails(
self, mock_dataset_context: DatasetContext
) -> None:
"""A regular column marked as saved_metric should be rejected."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="OrderDate"),
y=[ColumnRef(name="Sales", saved_metric=True)],
)
is_valid, error = DatasetValidator.validate_against_dataset(
config, dataset_id=18, dataset_context=mock_dataset_context
)
assert not is_valid
assert error is not None
assert error.error_code == "INVALID_SAVED_METRIC"
def test_nonexistent_saved_metric_fails(
self, mock_dataset_context: DatasetContext
) -> None:
"""A nonexistent saved metric should produce a specific error."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="OrderDate"),
y=[ColumnRef(name="nonexistent_metric", saved_metric=True)],
)
is_valid, error = DatasetValidator.validate_against_dataset(
config, dataset_id=18, dataset_context=mock_dataset_context
)
assert not is_valid
assert error is not None
assert error.error_code == "INVALID_SAVED_METRIC"