fix(mcp): handle more chart types in get_chart_data fallback query construction (#37969)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-02-17 07:02:42 -05:00
committed by GitHub
parent 9566e8a9c6
commit 5cd829f13c
2 changed files with 486 additions and 58 deletions

View File

@@ -165,21 +165,90 @@ async def get_chart_data( # noqa: C901
or current_app.config["ROW_LIMIT"]
)
# Handle different chart types that have different form_data structures
# Some charts use "metric" (singular), not "metrics" (plural):
# - big_number, big_number_total
# - pop_kpi (BigNumberPeriodOverPeriod)
# These charts also don't have groupby columns
# Handle different chart types that have different form_data
# structures. Chart types that exclusively use "metric"
# (singular) with no groupby:
# big_number, big_number_total, pop_kpi
# Chart types that use "metric" (singular) but may have
# groupby-like fields (entity, series, columns):
# world_map, treemap_v2, sunburst_v2, gauge_chart
# Bubble charts use x/y/size as separate metric fields.
viz_type = chart.viz_type or ""
if viz_type in ("big_number", "big_number_total", "pop_kpi"):
singular_metric_no_groupby = (
"big_number",
"big_number_total",
"pop_kpi",
)
singular_metric_types = (
*singular_metric_no_groupby,
"world_map",
"treemap_v2",
"sunburst_v2",
"gauge_chart",
)
if viz_type == "bubble":
# Bubble charts store metrics in x, y, size fields
bubble_metrics = []
for field in ("x", "y", "size"):
m = form_data.get(field)
if m:
bubble_metrics.append(m)
metrics = bubble_metrics
groupby_columns: list[str] = list(
form_data.get("entity", None) and [form_data["entity"]] or []
)
series_field = form_data.get("series")
if series_field and series_field not in groupby_columns:
groupby_columns.append(series_field)
elif viz_type in singular_metric_types:
# These chart types use "metric" (singular)
metric = form_data.get("metric")
metrics = [metric] if metric else []
groupby_columns: list[str] = [] # These charts don't group by
if viz_type in singular_metric_no_groupby:
groupby_columns = []
else:
# Some singular-metric charts use groupby, entity,
# series, or columns for dimensional breakdown
groupby_columns = list(form_data.get("groupby") or [])
entity = form_data.get("entity")
if entity and entity not in groupby_columns:
groupby_columns.append(entity)
series = form_data.get("series")
if series and series not in groupby_columns:
groupby_columns.append(series)
form_columns = form_data.get("columns")
if form_columns and isinstance(form_columns, list):
for col in form_columns:
if isinstance(col, str) and col not in groupby_columns:
groupby_columns.append(col)
else:
# Standard charts use "metrics" (plural) and "groupby"
metrics = form_data.get("metrics", [])
groupby_columns = form_data.get("groupby") or []
groupby_columns = list(form_data.get("groupby") or [])
# Some chart types use "columns" instead of "groupby"
if not groupby_columns:
form_columns = form_data.get("columns")
if form_columns and isinstance(form_columns, list):
for col in form_columns:
if isinstance(col, str):
groupby_columns.append(col)
# Fallback: if metrics is still empty, try singular "metric"
if not metrics:
fallback_metric = form_data.get("metric")
if fallback_metric:
metrics = [fallback_metric]
# Fallback: try entity/series if groupby is still empty
if not groupby_columns:
entity = form_data.get("entity")
if entity:
groupby_columns.append(entity)
series = form_data.get("series")
if series and series not in groupby_columns:
groupby_columns.append(series)
# Build query columns list: include both x_axis and groupby
x_axis_config = form_data.get("x_axis")
@@ -192,6 +261,28 @@ async def get_chart_data( # noqa: C901
if col_name and col_name not in query_columns:
query_columns.insert(0, col_name)
# Safety net: if we could not extract any metrics or
# columns, return a clear error instead of the cryptic
# "Empty query?" that comes from deeper in the stack.
if not metrics and not query_columns:
await ctx.error(
"Cannot construct fallback query for chart %s "
"(viz_type=%s): no metrics, columns, or groupby "
"could be extracted from form_data. "
"Re-save the chart to populate query_context."
% (chart.id, viz_type)
)
return ChartError(
error=(
f"Chart {chart.id} (type: {viz_type}) has no "
f"saved query_context and its form_data does "
f"not contain recognizable metrics or columns. "
f"Please open this chart in Superset and "
f"re-save it to generate a query_context."
),
error_type="MissingQueryContext",
)
query_context = factory.create(
datasource={
"id": chart.datasource_id,

View File

@@ -16,28 +16,114 @@
# under the License.
"""
Tests for the get_chart_data request schema and big_number chart handling.
Tests for the get_chart_data request schema and chart type fallback handling.
"""
from typing import Any
import pytest
from superset.mcp_service.chart.schemas import GetChartDataRequest
def _collect_groupby_extras(
form_data: dict[str, Any],
groupby_columns: list[str],
) -> None:
"""Append entity/series/columns from form_data into groupby_columns."""
entity = form_data.get("entity")
if entity and entity not in groupby_columns:
groupby_columns.append(entity)
series = form_data.get("series")
if series and series not in groupby_columns:
groupby_columns.append(series)
form_columns = form_data.get("columns")
if form_columns and isinstance(form_columns, list):
for col in form_columns:
if isinstance(col, str) and col not in groupby_columns:
groupby_columns.append(col)
def _extract_bubble(
form_data: dict[str, Any],
) -> tuple[list[Any], list[str]]:
"""Extract metrics and groupby for bubble charts."""
metrics: list[Any] = []
for field in ("x", "y", "size"):
m = form_data.get(field)
if m:
metrics.append(m)
entity = form_data.get("entity")
groupby: list[str] = [entity] if entity else []
series_field = form_data.get("series")
if series_field and series_field not in groupby:
groupby.append(series_field)
return metrics, groupby
_SINGULAR_METRIC_NO_GROUPBY = (
"big_number",
"big_number_total",
"pop_kpi",
)
_SINGULAR_METRIC_TYPES = (
*_SINGULAR_METRIC_NO_GROUPBY,
"world_map",
"treemap_v2",
"sunburst_v2",
"gauge_chart",
)
def _extract_metrics_and_groupby(
form_data: dict[str, Any],
) -> tuple[list[Any], list[str]]:
"""Mirror the fallback metric/groupby extraction logic from get_chart_data.py."""
viz_type = form_data.get("viz_type", "")
groupby_columns: list[str]
if viz_type == "bubble":
metrics, groupby_columns = _extract_bubble(form_data)
elif viz_type in _SINGULAR_METRIC_TYPES:
metric = form_data.get("metric")
metrics = [metric] if metric else []
if viz_type in _SINGULAR_METRIC_NO_GROUPBY:
groupby_columns = []
else:
groupby_columns = list(form_data.get("groupby") or [])
_collect_groupby_extras(form_data, groupby_columns)
else:
metrics = form_data.get("metrics", [])
groupby_columns = list(form_data.get("groupby") or [])
if not groupby_columns:
form_columns = form_data.get("columns")
if form_columns and isinstance(form_columns, list):
groupby_columns = [c for c in form_columns if isinstance(c, str)]
# Fallback: try singular metric if metrics still empty
if not metrics:
fallback_metric = form_data.get("metric")
if fallback_metric:
metrics = [fallback_metric]
# Fallback: try entity/series if groupby still empty
if not groupby_columns:
_collect_groupby_extras(form_data, groupby_columns)
return metrics, groupby_columns
class TestBigNumberChartFallback:
"""Tests for big_number chart fallback query construction."""
def test_big_number_uses_singular_metric(self):
"""Test that big_number charts use 'metric' (singular) from form_data."""
# Mock form_data for big_number chart
form_data = {
"metric": {"label": "Count", "expressionType": "SIMPLE", "column": None},
"viz_type": "big_number",
}
# Verify the metric extraction logic
metric = form_data.get("metric")
metrics = [metric] if metric else []
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert metrics[0]["label"] == "Count"
@@ -49,8 +135,7 @@ class TestBigNumberChartFallback:
"viz_type": "big_number_total",
}
metric = form_data.get("metric")
metrics = [metric] if metric else []
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert metrics[0]["label"] == "Total Sales"
@@ -62,8 +147,7 @@ class TestBigNumberChartFallback:
"viz_type": "big_number",
}
metric = form_data.get("metric")
metrics = [metric] if metric else []
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 0
@@ -75,13 +159,9 @@ class TestBigNumberChartFallback:
"groupby": ["should_be_ignored"], # This should be ignored
}
viz_type = form_data.get("viz_type", "")
if viz_type.startswith("big_number"):
groupby_columns: list[str] = [] # big_number charts don't group by
else:
groupby_columns = form_data.get("groupby", [])
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert groupby_columns == []
assert groupby == []
def test_standard_chart_uses_plural_metrics(self):
"""Test that non-big_number charts use 'metrics' (plural)."""
@@ -94,41 +174,43 @@ class TestBigNumberChartFallback:
"viz_type": "table",
}
viz_type = form_data.get("viz_type", "")
if viz_type.startswith("big_number"):
metric = form_data.get("metric")
metrics = [metric] if metric else []
groupby_columns: list[str] = []
else:
metrics = form_data.get("metrics", [])
groupby_columns = form_data.get("groupby", [])
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 2
assert len(groupby_columns) == 2
assert len(groupby) == 2
def test_viz_type_detection_for_single_metric_charts(self):
"""Test viz_type detection handles all single-metric chart types."""
# Chart types that use "metric" (singular) instead of "metrics" (plural)
single_metric_types = ("big_number", "pop_kpi")
singular_metric_types = (
"big_number",
"big_number_total",
"pop_kpi",
"world_map",
"treemap_v2",
"sunburst_v2",
"gauge_chart",
)
# big_number variants match via startswith
big_number_types = ["big_number", "big_number_total"]
for viz_type in big_number_types:
is_single_metric = (
viz_type.startswith("big_number") or viz_type in single_metric_types
)
assert is_single_metric is True
for viz_type in singular_metric_types:
form_data = {
"metric": {"label": "test_metric"},
"viz_type": viz_type,
}
metrics, _ = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1, f"{viz_type} should extract singular metric"
# pop_kpi (BigNumberPeriodOverPeriod) matches via exact match
assert "pop_kpi" in single_metric_types
# Verify standard chart types don't match
# Verify standard chart types don't use singular metric path
other_types = ["table", "line", "bar", "pie", "echarts_timeseries"]
for viz_type in other_types:
is_single_metric = (
viz_type.startswith("big_number") or viz_type in single_metric_types
form_data = {
"metric": {"label": "should_be_ignored"},
"metrics": [{"label": "plural_metric"}],
"viz_type": viz_type,
}
metrics, _ = _extract_metrics_and_groupby(form_data)
assert metrics == [{"label": "plural_metric"}], (
f"{viz_type} should use plural metrics"
)
assert is_single_metric is False
def test_pop_kpi_uses_singular_metric(self):
"""Test that pop_kpi (BigNumberPeriodOverPeriod) uses singular metric."""
@@ -137,19 +219,274 @@ class TestBigNumberChartFallback:
"viz_type": "pop_kpi",
}
viz_type = form_data.get("viz_type", "")
single_metric_types = ("big_number", "pop_kpi")
if viz_type.startswith("big_number") or viz_type in single_metric_types:
metric = form_data.get("metric")
metrics = [metric] if metric else []
groupby_columns: list[str] = []
else:
metrics = form_data.get("metrics", [])
groupby_columns = form_data.get("groupby", [])
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert metrics[0]["label"] == "Period Comparison"
assert groupby_columns == []
assert groupby == []
class TestWorldMapChartFallback:
"""Tests for world_map chart fallback query construction."""
def test_world_map_uses_singular_metric(self):
"""Test that world_map charts use 'metric' (singular)."""
form_data = {
"metric": {"label": "Population"},
"entity": "country_code",
"viz_type": "world_map",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert metrics[0]["label"] == "Population"
def test_world_map_extracts_entity_as_groupby(self):
"""Test that world_map entity field becomes groupby."""
form_data = {
"metric": {"label": "Population"},
"entity": "country_code",
"viz_type": "world_map",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert "country_code" in groupby
def test_world_map_extracts_series(self):
"""Test that world_map series field is added to groupby."""
form_data = {
"metric": {"label": "Population"},
"entity": "country_code",
"series": "region",
"viz_type": "world_map",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert "country_code" in groupby
assert "region" in groupby
class TestTreemapAndSunburstFallback:
"""Tests for treemap_v2 and sunburst_v2 chart fallback query construction."""
def test_treemap_v2_uses_singular_metric(self):
"""Test that treemap_v2 charts use 'metric' (singular)."""
form_data = {
"metric": {"label": "Revenue"},
"groupby": ["category", "sub_category"],
"viz_type": "treemap_v2",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert metrics[0]["label"] == "Revenue"
assert groupby == ["category", "sub_category"]
def test_sunburst_v2_uses_singular_metric(self):
"""Test that sunburst_v2 charts use 'metric' (singular)."""
form_data = {
"metric": {"label": "Count"},
"columns": ["level1", "level2", "level3"],
"viz_type": "sunburst_v2",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert metrics[0]["label"] == "Count"
# columns should be picked up as groupby alternatives
assert "level1" in groupby
assert "level2" in groupby
assert "level3" in groupby
def test_treemap_with_columns_field(self):
"""Test that treemap_v2 uses columns field when groupby is missing."""
form_data = {
"metric": {"label": "Revenue"},
"columns": ["region", "product"],
"viz_type": "treemap_v2",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert "region" in groupby
assert "product" in groupby
class TestGaugeChartFallback:
"""Tests for gauge_chart fallback query construction."""
def test_gauge_chart_uses_singular_metric(self):
"""Test that gauge_chart uses 'metric' (singular)."""
form_data = {
"metric": {"label": "Completion %"},
"viz_type": "gauge_chart",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert metrics[0]["label"] == "Completion %"
def test_gauge_chart_with_groupby(self):
"""Test that gauge_chart respects groupby if present."""
form_data = {
"metric": {"label": "Completion %"},
"groupby": ["department"],
"viz_type": "gauge_chart",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert groupby == ["department"]
class TestBubbleChartFallback:
"""Tests for bubble chart fallback query construction."""
def test_bubble_extracts_x_y_size_as_metrics(self):
"""Test that bubble charts extract x, y, size as separate metrics."""
form_data = {
"x": {"label": "GDP"},
"y": {"label": "Life Expectancy"},
"size": {"label": "Population"},
"entity": "country",
"viz_type": "bubble",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 3
assert metrics[0]["label"] == "GDP"
assert metrics[1]["label"] == "Life Expectancy"
assert metrics[2]["label"] == "Population"
def test_bubble_extracts_entity_as_groupby(self):
"""Test that bubble charts use entity as groupby."""
form_data = {
"x": {"label": "GDP"},
"y": {"label": "Life Expectancy"},
"size": {"label": "Population"},
"entity": "country",
"viz_type": "bubble",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert "country" in groupby
def test_bubble_extracts_series(self):
"""Test that bubble charts include series in groupby."""
form_data = {
"x": {"label": "GDP"},
"y": {"label": "Life Expectancy"},
"size": {"label": "Population"},
"entity": "country",
"series": "continent",
"viz_type": "bubble",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert "country" in groupby
assert "continent" in groupby
def test_bubble_partial_metrics(self):
"""Test bubble chart with only some metric fields set."""
form_data = {
"x": {"label": "GDP"},
"y": None,
"size": {"label": "Population"},
"entity": "country",
"viz_type": "bubble",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 2
labels = [m["label"] for m in metrics]
assert "GDP" in labels
assert "Population" in labels
class TestFallbackMetricExtraction:
"""Tests for the fallback singular metric extraction."""
def test_standard_chart_falls_back_to_singular_metric(self):
"""Test that standard charts try singular metric if plural is empty."""
form_data = {
"metric": {"label": "Fallback Metric"},
"metrics": [],
"groupby": ["region"],
"viz_type": "bar",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert metrics[0]["label"] == "Fallback Metric"
def test_standard_chart_no_metrics_at_all(self):
"""Test standard chart with neither metrics nor metric."""
form_data = {
"groupby": ["region"],
"viz_type": "bar",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 0
assert groupby == ["region"]
def test_standard_chart_uses_columns_as_groupby_fallback(self):
"""Test that standard charts use columns field when groupby is empty."""
form_data = {
"metrics": [{"label": "Count"}],
"columns": ["col_a", "col_b"],
"viz_type": "table",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert "col_a" in groupby
assert "col_b" in groupby
def test_entity_series_fallback_for_unknown_chart(self):
"""Test that entity/series are used as groupby fallback."""
form_data = {
"metric": {"label": "Some Metric"},
"entity": "name_col",
"series": "type_col",
"viz_type": "some_unknown_type",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert len(metrics) == 1
assert "name_col" in groupby
assert "type_col" in groupby
class TestSafetyNetEmptyQuery:
"""Tests for the safety net when no metrics/columns can be extracted."""
def test_completely_empty_form_data_yields_empty(self):
"""Test that form_data with nothing extractable returns empty."""
form_data = {
"viz_type": "mystery_chart",
}
metrics, groupby = _extract_metrics_and_groupby(form_data)
assert metrics == []
assert groupby == []
class TestXAxisInQueryContext: