fix(mcp): Handle big_number charts and make semantic warnings non-blocking (#37142)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-01-16 12:19:06 -05:00
committed by GitHub
parent 4532ccf638
commit d0783da3e5
5 changed files with 394 additions and 17 deletions

View File

@@ -16,7 +16,7 @@
# under the License.
"""
Tests for the get_chart_data request schema
Tests for the get_chart_data request schema and big_number chart handling.
"""
import pytest
@@ -24,6 +24,134 @@ import pytest
from superset.mcp_service.chart.schemas import GetChartDataRequest
class TestBigNumberChartFallback:
    """Unit tests for the fallback query-building logic of big_number charts.

    big_number-family visualizations store a single ``metric`` key in their
    form_data instead of the plural ``metrics`` list used by most chart
    types, and they never group by any column. These tests replicate that
    extraction logic against representative form_data payloads.
    """

    def test_big_number_uses_singular_metric(self):
        """big_number form_data exposes its metric under the singular key."""
        fd = {
            "metric": {"label": "Count", "expressionType": "SIMPLE", "column": None},
            "viz_type": "big_number",
        }
        # Replicate the extraction: wrap the lone metric in a list.
        single = fd.get("metric")
        metric_list = []
        if single:
            metric_list.append(single)
        assert len(metric_list) == 1
        assert metric_list[0]["label"] == "Count"

    def test_big_number_total_uses_singular_metric(self):
        """big_number_total also carries a singular 'metric' entry."""
        fd = {
            "metric": {"label": "Total Sales", "expressionType": "SQL"},
            "viz_type": "big_number_total",
        }
        single = fd.get("metric")
        metric_list = [single] if single else []
        assert len(metric_list) == 1
        assert metric_list[0]["label"] == "Total Sales"

    def test_big_number_empty_metric_returns_empty_list(self):
        """A big_number chart without a configured metric yields no metrics."""
        fd = {
            "metric": None,
            "viz_type": "big_number",
        }
        single = fd.get("metric")
        metric_list = [single] if single else []
        assert len(metric_list) == 0

    def test_big_number_no_groupby_columns(self):
        """Any groupby entries present in big_number form_data are dropped."""
        fd = {
            "metric": {"label": "Count"},
            "viz_type": "big_number",
            "groupby": ["should_be_ignored"],  # must not survive extraction
        }
        chart_kind = fd.get("viz_type", "")
        if chart_kind.startswith("big_number"):
            # big_number charts never aggregate by dimension.
            groupby: list[str] = []
        else:
            groupby = fd.get("groupby", [])
        assert groupby == []

    def test_standard_chart_uses_plural_metrics(self):
        """Non-big_number charts keep the plural 'metrics' and their groupby."""
        fd = {
            "metrics": [
                {"label": "Sum of Sales"},
                {"label": "Avg of Quantity"},
            ],
            "groupby": ["region", "category"],
            "viz_type": "table",
        }
        chart_kind = fd.get("viz_type", "")
        if chart_kind.startswith("big_number"):
            single = fd.get("metric")
            metric_list = [single] if single else []
            groupby: list[str] = []
        else:
            metric_list = fd.get("metrics", [])
            groupby = fd.get("groupby", [])
        assert len(metric_list) == 2
        assert len(groupby) == 2

    def test_viz_type_detection_for_single_metric_charts(self):
        """viz_type detection covers every single-metric chart family."""
        # Chart types that store "metric" (singular) rather than "metrics".
        singular_types = ("big_number", "pop_kpi")
        # The big_number variants are matched by prefix.
        for chart_kind in ["big_number", "big_number_total"]:
            detected = (
                chart_kind.startswith("big_number") or chart_kind in singular_types
            )
            assert detected is True
        # pop_kpi (BigNumberPeriodOverPeriod) is matched exactly.
        assert "pop_kpi" in singular_types
        # Multi-metric chart types must not be flagged as single-metric.
        for chart_kind in ["table", "line", "bar", "pie", "echarts_timeseries"]:
            detected = (
                chart_kind.startswith("big_number") or chart_kind in singular_types
            )
            assert detected is False

    def test_pop_kpi_uses_singular_metric(self):
        """pop_kpi (BigNumberPeriodOverPeriod) behaves like big_number."""
        fd = {
            "metric": {"label": "Period Comparison", "expressionType": "SQL"},
            "viz_type": "pop_kpi",
        }
        chart_kind = fd.get("viz_type", "")
        singular_types = ("big_number", "pop_kpi")
        if chart_kind.startswith("big_number") or chart_kind in singular_types:
            single = fd.get("metric")
            metric_list = [single] if single else []
            groupby: list[str] = []
        else:
            metric_list = fd.get("metrics", [])
            groupby = fd.get("groupby", [])
        assert len(metric_list) == 1
        assert metric_list[0]["label"] == "Period Comparison"
        assert groupby == []
class TestGetChartDataRequestSchema:
"""Test the GetChartDataRequest schema validation."""

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,221 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Tests for RuntimeValidator.
These tests verify that semantic warnings are non-blocking and
chart generation succeeds even when warnings are present.
"""
from unittest.mock import patch
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
TableChartConfig,
XYChartConfig,
)
from superset.mcp_service.chart.validation.runtime import RuntimeValidator
class TestRuntimeValidatorNonBlocking:
    """Verify RuntimeValidator reports semantic issues without failing charts.

    Each scenario patches one or more of the validator's internal check
    methods so they emit warnings or suggestions, then confirms that
    ``validate_runtime_issues`` still returns ``(True, None)`` — i.e.
    semantic warnings are advisory and never block chart generation.
    """

    def test_validate_runtime_issues_returns_valid_without_warnings(self):
        """A clean config validates as (True, None)."""
        cfg = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
        )
        ok, err = RuntimeValidator.validate_runtime_issues(cfg, 1)
        assert ok is True
        assert err is None

    def test_validate_runtime_issues_non_blocking_with_format_warnings(self):
        """Axis-format mismatch warnings must not fail validation."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="value", aggregate="SUM")],
            kind="line",
            x_axis=AxisConfig(format="$,.2f"),  # currency format on a date axis
        )
        # Force the format checker to report a problem.
        with patch(
            "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
            "_validate_format_compatibility"
        ) as fmt_check:
            fmt_check.return_value = [
                "Currency format '$,.2f' may not display dates correctly"
            ]
            ok, err = RuntimeValidator.validate_runtime_issues(cfg, 1)
            # The warning is advisory only; validation still succeeds.
            assert ok is True
            assert err is None

    def test_validate_runtime_issues_non_blocking_with_cardinality_warnings(self):
        """High-cardinality warnings must not fail validation."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="user_id"),  # high-cardinality dimension
            y=[ColumnRef(name="count", aggregate="COUNT")],
            kind="bar",
        )
        # Force the cardinality checker to report warnings + suggestions.
        with patch(
            "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
            "_validate_cardinality"
        ) as card_check:
            card_check.return_value = (
                ["High cardinality detected: 10000+ unique values"],
                ["Consider using aggregation or filtering"],
            )
            ok, err = RuntimeValidator.validate_runtime_issues(cfg, 1)
            # Cardinality concerns do not block generation.
            assert ok is True
            assert err is None

    def test_validate_runtime_issues_non_blocking_with_chart_type_suggestions(self):
        """Chart-type suggestions must not fail validation."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="category"),  # categorical axis
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="line",  # questionable pairing: line chart over categories
        )
        # Force the chart-type checker to suggest an alternative.
        with patch(
            "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
            "_validate_chart_type"
        ) as type_check:
            type_check.return_value = (
                ["Line chart may not be ideal for categorical X axis"],
                ["Consider using bar chart for better visualization"],
            )
            ok, err = RuntimeValidator.validate_runtime_issues(cfg, 1)
            # Suggestions are advisory only.
            assert ok is True
            assert err is None

    def test_validate_runtime_issues_non_blocking_with_multiple_warnings(self):
        """Even several warnings at once must not fail validation."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="user_id"),
            y=[ColumnRef(name="amount", aggregate="SUM")],
            kind="scatter",
            x_axis=AxisConfig(format="smart_date"),  # date format on an id axis
        )
        # Make every internal checker emit something.
        with (
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_format_compatibility"
            ) as fmt_check,
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_cardinality"
            ) as card_check,
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_chart_type"
            ) as type_check,
        ):
            fmt_check.return_value = ["Format mismatch warning"]
            card_check.return_value = (
                ["High cardinality warning"],
                ["Cardinality suggestion"],
            )
            type_check.return_value = (
                ["Chart type warning"],
                ["Chart type suggestion"],
            )
            ok, err = RuntimeValidator.validate_runtime_issues(cfg, 1)
            # Accumulated warnings still do not block generation.
            assert ok is True
            assert err is None

    def test_validate_runtime_issues_logs_warnings(self):
        """Warnings must be logged so they remain visible for debugging."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="value", aggregate="SUM")],
            kind="line",
        )
        with (
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_format_compatibility"
            ) as fmt_check,
            patch(
                "superset.mcp_service.chart.validation.runtime.logger"
            ) as log_spy,
        ):
            fmt_check.return_value = ["Test warning message"]
            ok, err = RuntimeValidator.validate_runtime_issues(cfg, 1)
            # The warning path must emit via logger.info.
            assert log_spy.info.called
            assert ok is True
            assert err is None

    def test_validate_table_chart_skips_xy_validations(self):
        """Table charts bypass the XY-only format and cardinality checks."""
        cfg = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="region"),
                ColumnRef(name="sales", aggregate="SUM"),
            ],
        )
        # Spy on the XY-specific checkers; neither should fire for tables.
        with (
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_format_compatibility"
            ) as fmt_check,
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_cardinality"
            ) as card_check,
        ):
            ok, err = RuntimeValidator.validate_runtime_issues(cfg, 1)
            fmt_check.assert_not_called()
            card_check.assert_not_called()
            assert ok is True
            assert err is None