diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index a99df31dd94..327052b423f 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -165,21 +165,90 @@ async def get_chart_data( # noqa: C901 or current_app.config["ROW_LIMIT"] ) - # Handle different chart types that have different form_data structures - # Some charts use "metric" (singular), not "metrics" (plural): - # - big_number, big_number_total - # - pop_kpi (BigNumberPeriodOverPeriod) - # These charts also don't have groupby columns + # Handle different chart types that have different form_data + # structures. Chart types that exclusively use "metric" + # (singular) with no groupby: + # big_number, big_number_total, pop_kpi + # Chart types that use "metric" (singular) but may have + # groupby-like fields (entity, series, columns): + # world_map, treemap_v2, sunburst_v2, gauge_chart + # Bubble charts use x/y/size as separate metric fields. viz_type = chart.viz_type or "" - if viz_type in ("big_number", "big_number_total", "pop_kpi"): + + singular_metric_no_groupby = ( + "big_number", + "big_number_total", + "pop_kpi", + ) + singular_metric_types = ( + *singular_metric_no_groupby, + "world_map", + "treemap_v2", + "sunburst_v2", + "gauge_chart", + ) + + if viz_type == "bubble": + # Bubble charts store metrics in x, y, size fields + bubble_metrics = [] + for field in ("x", "y", "size"): + m = form_data.get(field) + if m: + bubble_metrics.append(m) + metrics = bubble_metrics + groupby_columns: list[str] = list( + form_data.get("entity", None) and [form_data["entity"]] or [] + ) + series_field = form_data.get("series") + if series_field and series_field not in groupby_columns: + groupby_columns.append(series_field) + elif viz_type in singular_metric_types: # These chart types use "metric" (singular) metric = form_data.get("metric") metrics = [metric] if metric else [] - groupby_columns: list[str] = [] # These charts don't group by + if viz_type in singular_metric_no_groupby: + groupby_columns = [] + else: + # Some singular-metric charts use groupby, entity, + # series, or columns for dimensional breakdown + groupby_columns = list(form_data.get("groupby") or []) + entity = form_data.get("entity") + if entity and entity not in groupby_columns: + groupby_columns.append(entity) + series = form_data.get("series") + if series and series not in groupby_columns: + groupby_columns.append(series) + form_columns = form_data.get("columns") + if form_columns and isinstance(form_columns, list): + for col in form_columns: + if isinstance(col, str) and col not in groupby_columns: + groupby_columns.append(col) else: # Standard charts use "metrics" (plural) and "groupby" metrics = form_data.get("metrics", []) - groupby_columns = form_data.get("groupby") or [] + groupby_columns = list(form_data.get("groupby") or []) + # Some chart types use "columns" instead of "groupby" + if not groupby_columns: + form_columns = form_data.get("columns") + if form_columns and isinstance(form_columns, list): + for col in form_columns: + if isinstance(col, str): + groupby_columns.append(col) + + # Fallback: if metrics is still empty, try singular "metric" + if not metrics: + fallback_metric = form_data.get("metric") + if fallback_metric: + metrics = [fallback_metric] + + # Fallback: try entity/series if groupby is still empty + if not groupby_columns: + entity = form_data.get("entity") + if entity: + groupby_columns.append(entity) + series = form_data.get("series") + if series and series not in groupby_columns: + groupby_columns.append(series) # Build query columns list: include both x_axis and groupby x_axis_config = form_data.get("x_axis") @@ -192,6 +261,28 @@ async def get_chart_data( # noqa: C901 if col_name and col_name not in query_columns: query_columns.insert(0, col_name) + # Safety net: if we could not extract any metrics or + # columns, return a clear error instead of the cryptic + # "Empty query?" that comes from deeper in the stack. + if not metrics and not query_columns: + await ctx.error( + "Cannot construct fallback query for chart %s " + "(viz_type=%s): no metrics, columns, or groupby " + "could be extracted from form_data. " + "Re-save the chart to populate query_context." + % (chart.id, viz_type) + ) + return ChartError( + error=( + f"Chart {chart.id} (type: {viz_type}) has no " + f"saved query_context and its form_data does " + f"not contain recognizable metrics or columns. " + f"Please open this chart in Superset and " + f"re-save it to generate a query_context." + ), + error_type="MissingQueryContext", + ) + query_context = factory.create( datasource={ "id": chart.datasource_id, diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py index d276691425f..7850ef82a9a 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py @@ -16,28 +16,114 @@ # under the License. """ -Tests for the get_chart_data request schema and big_number chart handling. +Tests for the get_chart_data request schema and chart type fallback handling. """ +from typing import Any + import pytest from superset.mcp_service.chart.schemas import GetChartDataRequest +def _collect_groupby_extras( + form_data: dict[str, Any], + groupby_columns: list[str], +) -> None: + """Append entity/series/columns from form_data into groupby_columns.""" + entity = form_data.get("entity") + if entity and entity not in groupby_columns: + groupby_columns.append(entity) + series = form_data.get("series") + if series and series not in groupby_columns: + groupby_columns.append(series) + form_columns = form_data.get("columns") + if form_columns and isinstance(form_columns, list): + for col in form_columns: + if isinstance(col, str) and col not in groupby_columns: + groupby_columns.append(col) + + +def _extract_bubble( + form_data: dict[str, Any], +) -> tuple[list[Any], list[str]]: + """Extract metrics and groupby for bubble charts.""" + metrics: list[Any] = [] + for field in ("x", "y", "size"): + m = form_data.get(field) + if m: + metrics.append(m) + entity = form_data.get("entity") + groupby: list[str] = [entity] if entity else [] + series_field = form_data.get("series") + if series_field and series_field not in groupby: + groupby.append(series_field) + return metrics, groupby + + +_SINGULAR_METRIC_NO_GROUPBY = ( + "big_number", + "big_number_total", + "pop_kpi", +) +_SINGULAR_METRIC_TYPES = ( + *_SINGULAR_METRIC_NO_GROUPBY, + "world_map", + "treemap_v2", + "sunburst_v2", + "gauge_chart", +) + + +def _extract_metrics_and_groupby( + form_data: dict[str, Any], +) -> tuple[list[Any], list[str]]: + """Mirror the fallback metric/groupby extraction logic from get_chart_data.py.""" + viz_type = form_data.get("viz_type", "") + + groupby_columns: list[str] + if viz_type == "bubble": + metrics, groupby_columns = _extract_bubble(form_data) + elif viz_type in _SINGULAR_METRIC_TYPES: + metric = form_data.get("metric") + metrics = [metric] if metric else [] + if viz_type in _SINGULAR_METRIC_NO_GROUPBY: + groupby_columns = [] + else: + groupby_columns = list(form_data.get("groupby") or []) + _collect_groupby_extras(form_data, groupby_columns) + else: + metrics = form_data.get("metrics", []) + groupby_columns = list(form_data.get("groupby") or []) + if not groupby_columns: + form_columns = form_data.get("columns") + if form_columns and isinstance(form_columns, list): + groupby_columns = [c for c in form_columns if isinstance(c, str)] + + # Fallback: try singular metric if metrics still empty + if not metrics: + fallback_metric = form_data.get("metric") + if fallback_metric: + metrics = [fallback_metric] + + # Fallback: try entity/series if groupby still empty + if not groupby_columns: + _collect_groupby_extras(form_data, groupby_columns) + + return metrics, groupby_columns + + class TestBigNumberChartFallback: """Tests for big_number chart fallback query construction.""" def test_big_number_uses_singular_metric(self): """Test that big_number charts use 'metric' (singular) from form_data.""" - # Mock form_data for big_number chart form_data = { "metric": {"label": "Count", "expressionType": "SIMPLE", "column": None}, "viz_type": "big_number", } - # Verify the metric extraction logic - metric = form_data.get("metric") - metrics = [metric] if metric else [] + metrics, groupby = _extract_metrics_and_groupby(form_data) assert len(metrics) == 1 assert metrics[0]["label"] == "Count" @@ -49,8 +135,7 @@ class TestBigNumberChartFallback: "viz_type": "big_number_total", } - metric = form_data.get("metric") - metrics = [metric] if metric else [] + metrics, groupby = _extract_metrics_and_groupby(form_data) assert len(metrics) == 1 assert metrics[0]["label"] == "Total Sales" @@ -62,8 +147,7 @@ class TestBigNumberChartFallback: "viz_type": "big_number", } - metric = form_data.get("metric") - metrics = [metric] if metric else [] + metrics, groupby = _extract_metrics_and_groupby(form_data) assert len(metrics) == 0 @@ -75,13 +159,9 @@ class TestBigNumberChartFallback: "groupby": ["should_be_ignored"], # This should be ignored } - viz_type = form_data.get("viz_type", "") - if viz_type.startswith("big_number"): - groupby_columns: list[str] = [] # big_number charts don't group by - else: - groupby_columns = form_data.get("groupby", []) + metrics, groupby = _extract_metrics_and_groupby(form_data) - assert groupby_columns == [] + assert groupby == [] def test_standard_chart_uses_plural_metrics(self): """Test that non-big_number charts use 'metrics' (plural).""" @@ -94,41 +174,43 @@ class TestBigNumberChartFallback: "viz_type": "table", } - viz_type = form_data.get("viz_type", "") - if viz_type.startswith("big_number"): - metric = form_data.get("metric") - metrics = [metric] if metric else [] - groupby_columns: list[str] = [] - else: - metrics = form_data.get("metrics", []) - groupby_columns = form_data.get("groupby", []) + metrics, groupby = _extract_metrics_and_groupby(form_data) assert len(metrics) == 2 - assert len(groupby_columns) == 2 + assert len(groupby) == 2 def test_viz_type_detection_for_single_metric_charts(self): """Test viz_type detection handles all single-metric chart types.""" - # Chart types that use "metric" (singular) instead of "metrics" (plural) - single_metric_types = ("big_number", "pop_kpi") + singular_metric_types = ( + "big_number", + "big_number_total", + "pop_kpi", + "world_map", + "treemap_v2", + "sunburst_v2", + "gauge_chart", + ) - # big_number variants match via startswith - big_number_types = ["big_number", "big_number_total"] - for viz_type in big_number_types: - is_single_metric = ( - viz_type.startswith("big_number") or viz_type in single_metric_types - ) - assert is_single_metric is True + for viz_type in singular_metric_types: + form_data = { + "metric": {"label": "test_metric"}, + "viz_type": viz_type, + } + metrics, _ = _extract_metrics_and_groupby(form_data) + assert len(metrics) == 1, f"{viz_type} should extract singular metric" - # pop_kpi (BigNumberPeriodOverPeriod) matches via exact match - assert "pop_kpi" in single_metric_types - - # Verify standard chart types don't match + # Verify standard chart types don't use singular metric path other_types = ["table", "line", "bar", "pie", "echarts_timeseries"] for viz_type in other_types: - is_single_metric = ( - viz_type.startswith("big_number") or viz_type in single_metric_types + form_data = { + "metric": {"label": "should_be_ignored"}, + "metrics": [{"label": "plural_metric"}], + "viz_type": viz_type, + } + metrics, _ = _extract_metrics_and_groupby(form_data) + assert metrics == [{"label": "plural_metric"}], ( + f"{viz_type} should use plural metrics" ) - assert is_single_metric is False def test_pop_kpi_uses_singular_metric(self): """Test that pop_kpi (BigNumberPeriodOverPeriod) uses singular metric.""" @@ -137,19 +219,274 @@ class TestBigNumberChartFallback: "viz_type": "pop_kpi", } - viz_type = form_data.get("viz_type", "") - single_metric_types = ("big_number", "pop_kpi") - if viz_type.startswith("big_number") or viz_type in single_metric_types: - metric = form_data.get("metric") - metrics = [metric] if metric else [] - groupby_columns: list[str] = [] - else: - metrics = form_data.get("metrics", []) - groupby_columns = form_data.get("groupby", []) + metrics, groupby = _extract_metrics_and_groupby(form_data) assert len(metrics) == 1 assert metrics[0]["label"] == "Period Comparison" - assert groupby_columns == [] + assert groupby == [] + + +class TestWorldMapChartFallback: + """Tests for world_map chart fallback query construction.""" + + def test_world_map_uses_singular_metric(self): + """Test that world_map charts use 'metric' (singular).""" + form_data = { + "metric": {"label": "Population"}, + "entity": "country_code", + "viz_type": "world_map", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert metrics[0]["label"] == "Population" + + def test_world_map_extracts_entity_as_groupby(self): + """Test that world_map entity field becomes groupby.""" + form_data = { + "metric": {"label": "Population"}, + "entity": "country_code", + "viz_type": "world_map", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert "country_code" in groupby + + def test_world_map_extracts_series(self): + """Test that world_map series field is added to groupby.""" + form_data = { + "metric": {"label": "Population"}, + "entity": "country_code", + "series": "region", + "viz_type": "world_map", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert "country_code" in groupby + assert "region" in groupby + + +class TestTreemapAndSunburstFallback: + """Tests for treemap_v2 and sunburst_v2 chart fallback query construction.""" + + def test_treemap_v2_uses_singular_metric(self): + """Test that treemap_v2 charts use 'metric' (singular).""" + form_data = { + "metric": {"label": "Revenue"}, + "groupby": ["category", "sub_category"], + "viz_type": "treemap_v2", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert metrics[0]["label"] == "Revenue" + assert groupby == ["category", "sub_category"] + + def test_sunburst_v2_uses_singular_metric(self): + """Test that sunburst_v2 charts use 'metric' (singular).""" + form_data = { + "metric": {"label": "Count"}, + "columns": ["level1", "level2", "level3"], + "viz_type": "sunburst_v2", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert metrics[0]["label"] == "Count" + # columns should be picked up as groupby alternatives + assert "level1" in groupby + assert "level2" in groupby + assert "level3" in groupby + + def test_treemap_with_columns_field(self): + """Test that treemap_v2 uses columns field when groupby is missing.""" + form_data = { + "metric": {"label": "Revenue"}, + "columns": ["region", "product"], + "viz_type": "treemap_v2", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert "region" in groupby + assert "product" in groupby + + +class TestGaugeChartFallback: + """Tests for gauge_chart fallback query construction.""" + + def test_gauge_chart_uses_singular_metric(self): + """Test that gauge_chart uses 'metric' (singular).""" + form_data = { + "metric": {"label": "Completion %"}, + "viz_type": "gauge_chart", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert metrics[0]["label"] == "Completion %" + + def test_gauge_chart_with_groupby(self): + """Test that gauge_chart respects groupby if present.""" + form_data = { + "metric": {"label": "Completion %"}, + "groupby": ["department"], + "viz_type": "gauge_chart", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert groupby == ["department"] + + +class TestBubbleChartFallback: + """Tests for bubble chart fallback query construction.""" + + def test_bubble_extracts_x_y_size_as_metrics(self): + """Test that bubble charts extract x, y, size as separate metrics.""" + form_data = { + "x": {"label": "GDP"}, + "y": {"label": "Life Expectancy"}, + "size": {"label": "Population"}, + "entity": "country", + "viz_type": "bubble", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 3 + assert metrics[0]["label"] == "GDP" + assert metrics[1]["label"] == "Life Expectancy" + assert metrics[2]["label"] == "Population" + + def test_bubble_extracts_entity_as_groupby(self): + """Test that bubble charts use entity as groupby.""" + form_data = { + "x": {"label": "GDP"}, + "y": {"label": "Life Expectancy"}, + "size": {"label": "Population"}, + "entity": "country", + "viz_type": "bubble", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert "country" in groupby + + def test_bubble_extracts_series(self): + """Test that bubble charts include series in groupby.""" + form_data = { + "x": {"label": "GDP"}, + "y": {"label": "Life Expectancy"}, + "size": {"label": "Population"}, + "entity": "country", + "series": "continent", + "viz_type": "bubble", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert "country" in groupby + assert "continent" in groupby + + def test_bubble_partial_metrics(self): + """Test bubble chart with only some metric fields set.""" + form_data = { + "x": {"label": "GDP"}, + "y": None, + "size": {"label": "Population"}, + "entity": "country", + "viz_type": "bubble", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 2 + labels = [m["label"] for m in metrics] + assert "GDP" in labels + assert "Population" in labels + + +class TestFallbackMetricExtraction: + """Tests for the fallback singular metric extraction.""" + + def test_standard_chart_falls_back_to_singular_metric(self): + """Test that standard charts try singular metric if plural is empty.""" + form_data = { + "metric": {"label": "Fallback Metric"}, + "metrics": [], + "groupby": ["region"], + "viz_type": "bar", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert metrics[0]["label"] == "Fallback Metric" + + def test_standard_chart_no_metrics_at_all(self): + """Test standard chart with neither metrics nor metric.""" + form_data = { + "groupby": ["region"], + "viz_type": "bar", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 0 + assert groupby == ["region"] + + def test_standard_chart_uses_columns_as_groupby_fallback(self): + """Test that standard charts use columns field when groupby is empty.""" + form_data = { + "metrics": [{"label": "Count"}], + "columns": ["col_a", "col_b"], + "viz_type": "table", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert "col_a" in groupby + assert "col_b" in groupby + + def test_entity_series_fallback_for_unknown_chart(self): + """Test that entity/series are used as groupby fallback.""" + form_data = { + "metric": {"label": "Some Metric"}, + "entity": "name_col", + "series": "type_col", + "viz_type": "some_unknown_type", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert len(metrics) == 1 + assert "name_col" in groupby + assert "type_col" in groupby + + +class TestSafetyNetEmptyQuery: + """Tests for the safety net when no metrics/columns can be extracted.""" + + def test_completely_empty_form_data_yields_empty(self): + """Test that form_data with nothing extractable returns empty.""" + form_data = { + "viz_type": "mystery_chart", + } + + metrics, groupby = _extract_metrics_and_groupby(form_data) + + assert metrics == [] + assert groupby == [] class TestXAxisInQueryContext: