diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index d4d1804af4a..f72da49660f 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -161,6 +161,24 @@ async def get_chart_data( # noqa: C901 or form_data.get("row_limit") or current_app.config["ROW_LIMIT"] ) + + # Handle different chart types that have different form_data structures + # Some charts use "metric" (singular), not "metrics" (plural): + # - big_number, big_number_total + # - pop_kpi (BigNumberPeriodOverPeriod) + # These charts also don't have groupby columns + viz_type = chart.viz_type or "" + single_metric_types = ("big_number", "pop_kpi") + if viz_type.startswith("big_number") or viz_type in single_metric_types: + # These chart types use "metric" (singular) + metric = form_data.get("metric") + metrics = [metric] if metric else [] + groupby_columns: list[str] = [] # These charts don't group by + else: + # Standard charts use "metrics" (plural) and "groupby" + metrics = form_data.get("metrics", []) + groupby_columns = form_data.get("groupby", []) + query_context = factory.create( datasource={ "id": chart.datasource_id, @@ -169,8 +187,8 @@ async def get_chart_data( # noqa: C901 queries=[ { "filters": form_data.get("filters", []), - "columns": form_data.get("groupby", []), - "metrics": form_data.get("metrics", []), + "columns": groupby_columns, + "metrics": metrics, "row_limit": row_limit, "order_desc": True, } diff --git a/superset/mcp_service/chart/validation/runtime/__init__.py b/superset/mcp_service/chart/validation/runtime/__init__.py index e48585fe19c..659ada61de2 100644 --- a/superset/mcp_service/chart/validation/runtime/__init__.py +++ b/superset/mcp_service/chart/validation/runtime/__init__.py @@ -75,22 +75,16 @@ class RuntimeValidator: warnings.extend(type_warnings) suggestions.extend(type_suggestions) - # If we have warnings, return them as a validation error + # Semantic warnings are informational, not blocking errors. + # Log them for debugging but allow chart generation to proceed. if warnings: - from superset.mcp_service.utils.error_builder import ( - ChartErrorBuilder, - ) - - return False, ChartErrorBuilder.build_error( - error_type="runtime_semantic_warning", - template_key="performance_warning", - template_vars={ - "reason": "; ".join(warnings[:3]) - + ("..." if len(warnings) > 3 else "") - }, - custom_suggestions=suggestions[:5], # Limit suggestions - error_code="RUNTIME_SEMANTIC_WARNING", + logger.info( + "Runtime semantic warnings for dataset %s: %s", + dataset_id, + "; ".join(warnings[:3]) + ("..." if len(warnings) > 3 else ""), ) + if suggestions: + logger.info("Suggestions: %s", "; ".join(suggestions[:3])) return True, None diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py index 30b61980a90..2669366f526 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py @@ -16,7 +16,7 @@ # under the License. """ -Tests for the get_chart_data request schema +Tests for the get_chart_data request schema and big_number chart handling. """ import pytest @@ -24,6 +24,134 @@ import pytest from superset.mcp_service.chart.schemas import GetChartDataRequest +class TestBigNumberChartFallback: + """Tests for big_number chart fallback query construction.""" + + def test_big_number_uses_singular_metric(self): + """Test that big_number charts use 'metric' (singular) from form_data.""" + # Mock form_data for big_number chart + form_data = { + "metric": {"label": "Count", "expressionType": "SIMPLE", "column": None}, + "viz_type": "big_number", + } + + # Verify the metric extraction logic + metric = form_data.get("metric") + metrics = [metric] if metric else [] + + assert len(metrics) == 1 + assert metrics[0]["label"] == "Count" + + def test_big_number_total_uses_singular_metric(self): + """Test that big_number_total charts use 'metric' (singular).""" + form_data = { + "metric": {"label": "Total Sales", "expressionType": "SQL"}, + "viz_type": "big_number_total", + } + + metric = form_data.get("metric") + metrics = [metric] if metric else [] + + assert len(metrics) == 1 + assert metrics[0]["label"] == "Total Sales" + + def test_big_number_empty_metric_returns_empty_list(self): + """Test handling of big_number chart with no metric configured.""" + form_data = { + "metric": None, + "viz_type": "big_number", + } + + metric = form_data.get("metric") + metrics = [metric] if metric else [] + + assert len(metrics) == 0 + + def test_big_number_no_groupby_columns(self): + """Test that big_number charts don't have groupby columns.""" + form_data = { + "metric": {"label": "Count"}, + "viz_type": "big_number", + "groupby": ["should_be_ignored"], # This should be ignored + } + + viz_type = form_data.get("viz_type", "") + if viz_type.startswith("big_number"): + groupby_columns: list[str] = [] # big_number charts don't group by + else: + groupby_columns = form_data.get("groupby", []) + + assert groupby_columns == [] + + def test_standard_chart_uses_plural_metrics(self): + """Test that non-big_number charts use 'metrics' (plural).""" + form_data = { + "metrics": [ + {"label": "Sum of Sales"}, + {"label": "Avg of Quantity"}, + ], + "groupby": ["region", "category"], + "viz_type": "table", + } + + viz_type = form_data.get("viz_type", "") + if viz_type.startswith("big_number"): + metric = form_data.get("metric") + metrics = [metric] if metric else [] + groupby_columns: list[str] = [] + else: + metrics = form_data.get("metrics", []) + groupby_columns = form_data.get("groupby", []) + + assert len(metrics) == 2 + assert len(groupby_columns) == 2 + + def test_viz_type_detection_for_single_metric_charts(self): + """Test viz_type detection handles all single-metric chart types.""" + # Chart types that use "metric" (singular) instead of "metrics" (plural) + single_metric_types = ("big_number", "pop_kpi") + + # big_number variants match via startswith + big_number_types = ["big_number", "big_number_total"] + for viz_type in big_number_types: + is_single_metric = ( + viz_type.startswith("big_number") or viz_type in single_metric_types + ) + assert is_single_metric is True + + # pop_kpi (BigNumberPeriodOverPeriod) matches via exact match + assert "pop_kpi" in single_metric_types + + # Verify standard chart types don't match + other_types = ["table", "line", "bar", "pie", "echarts_timeseries"] + for viz_type in other_types: + is_single_metric = ( + viz_type.startswith("big_number") or viz_type in single_metric_types + ) + assert is_single_metric is False + + def test_pop_kpi_uses_singular_metric(self): + """Test that pop_kpi (BigNumberPeriodOverPeriod) uses singular metric.""" + form_data = { + "metric": {"label": "Period Comparison", "expressionType": "SQL"}, + "viz_type": "pop_kpi", + } + + viz_type = form_data.get("viz_type", "") + single_metric_types = ("big_number", "pop_kpi") + if viz_type.startswith("big_number") or viz_type in single_metric_types: + metric = form_data.get("metric") + metrics = [metric] if metric else [] + groupby_columns: list[str] = [] + else: + metrics = form_data.get("metrics", []) + groupby_columns = form_data.get("groupby", []) + + assert len(metrics) == 1 + assert metrics[0]["label"] == "Period Comparison" + assert groupby_columns == [] + + class TestGetChartDataRequestSchema: """Test the GetChartDataRequest schema validation.""" diff --git a/tests/unit_tests/mcp_service/chart/validation/__init__.py b/tests/unit_tests/mcp_service/chart/validation/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/validation/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py b/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py new file mode 100644 index 00000000000..ba8cc3fb9f9 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Tests for RuntimeValidator. + +These tests verify that semantic warnings are non-blocking and +chart generation succeeds even when warnings are present. +""" + +from unittest.mock import patch + +from superset.mcp_service.chart.schemas import ( + AxisConfig, + ColumnRef, + TableChartConfig, + XYChartConfig, +) +from superset.mcp_service.chart.validation.runtime import RuntimeValidator + + +class TestRuntimeValidatorNonBlocking: + """Test that RuntimeValidator treats semantic warnings as non-blocking.""" + + def test_validate_runtime_issues_returns_valid_without_warnings(self): + """Test that validation returns (True, None) when no issues found.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + ) + + is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) + + assert is_valid is True + assert error is None + + def test_validate_runtime_issues_non_blocking_with_format_warnings(self): + """Test that format compatibility warnings do NOT block chart generation.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="value", aggregate="SUM")], + kind="line", + x_axis=AxisConfig(format="$,.2f"), # Currency format for date - mismatch + ) + + # Mock the format validator to return warnings + with patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_format_compatibility" + ) as mock_format: + mock_format.return_value = [ + "Currency format '$,.2f' may not display dates correctly" + ] + + is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) + + # Should still return valid despite warnings + assert is_valid is True + assert error is None + + def test_validate_runtime_issues_non_blocking_with_cardinality_warnings(self): + """Test that cardinality warnings do NOT block chart generation.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="user_id"), # High cardinality column + y=[ColumnRef(name="count", aggregate="COUNT")], + kind="bar", + ) + + # Mock the cardinality validator to return warnings + with patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_cardinality" + ) as mock_cardinality: + mock_cardinality.return_value = ( + ["High cardinality detected: 10000+ unique values"], + ["Consider using aggregation or filtering"], + ) + + is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) + + # Should still return valid despite high cardinality warning + assert is_valid is True + assert error is None + + def test_validate_runtime_issues_non_blocking_with_chart_type_suggestions(self): + """Test that chart type suggestions do NOT block chart generation.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="category"), # Categorical data + y=[ColumnRef(name="sales", aggregate="SUM")], + kind="line", # Line chart for categorical data + ) + + # Mock the chart type suggester to return suggestions + with patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_chart_type" + ) as mock_suggester: + mock_suggester.return_value = ( + ["Line chart may not be ideal for categorical X axis"], + ["Consider using bar chart for better visualization"], + ) + + is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) + + # Should still return valid despite suggestion + assert is_valid is True + assert error is None + + def test_validate_runtime_issues_non_blocking_with_multiple_warnings(self): + """Test that multiple warnings combined do NOT block chart generation.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="user_id"), + y=[ColumnRef(name="amount", aggregate="SUM")], + kind="scatter", + x_axis=AxisConfig(format="smart_date"), # Wrong format for user_id + ) + + # Mock all validators to return warnings + with ( + patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_format_compatibility" + ) as mock_format, + patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_cardinality" + ) as mock_cardinality, + patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_chart_type" + ) as mock_type, + ): + mock_format.return_value = ["Format mismatch warning"] + mock_cardinality.return_value = ( + ["High cardinality warning"], + ["Cardinality suggestion"], + ) + mock_type.return_value = ( + ["Chart type warning"], + ["Chart type suggestion"], + ) + + is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) + + # Should still return valid despite multiple warnings + assert is_valid is True + assert error is None + + def test_validate_runtime_issues_logs_warnings(self): + """Test that warnings are logged for debugging purposes.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="value", aggregate="SUM")], + kind="line", + ) + + with ( + patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_format_compatibility" + ) as mock_format, + patch( + "superset.mcp_service.chart.validation.runtime.logger" + ) as mock_logger, + ): + mock_format.return_value = ["Test warning message"] + + is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) + + # Verify warnings were logged + assert mock_logger.info.called + assert is_valid is True + assert error is None + + def test_validate_table_chart_skips_xy_validations(self): + """Test that table charts skip XY-specific validations.""" + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="region"), + ColumnRef(name="sales", aggregate="SUM"), + ], + ) + + # These should not be called for table charts + with ( + patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_format_compatibility" + ) as mock_format, + patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_cardinality" + ) as mock_cardinality, + ): + is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) + + # Format and cardinality validation should not be called for table charts + mock_format.assert_not_called() + mock_cardinality.assert_not_called() + assert is_valid is True + assert error is None