From 558ff4452b8d2a5e7d5f4cc6f21964f72b21c1c7 Mon Sep 17 00:00:00 2001 From: Alexandru Soare <37236580+alexandrusoare@users.noreply.github.com> Date: Fri, 22 May 2026 13:42:59 +0300 Subject: [PATCH] fix(preview): fix chart preview bugs (#40063) --- .../chart/tool/get_chart_preview.py | 91 ++++++++++++++--- .../chart/tool/test_get_chart_preview.py | 98 +++++++++++++++++++ 2 files changed, 176 insertions(+), 13 deletions(-) diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index 798fdd4fda0..f8da956b66c 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -59,6 +59,7 @@ from superset.mcp_service.utils.oauth2_utils import ( OAUTH2_CONFIG_ERROR_MESSAGE, ) from superset.mcp_service.utils.url_utils import get_superset_base_url +from superset.superset_typing import Column, Metric logger = logging.getLogger(__name__) @@ -148,22 +149,89 @@ class ChartLike(Protocol): uuid: Any -def _build_query_columns(form_data: Dict[str, Any]) -> list[str]: - """Build query columns list from form_data, including both x_axis and groupby.""" - x_axis_config = form_data.get("x_axis") - groupby_columns: list[str] = form_data.get("groupby") or [] +def _build_query_columns(form_data: Dict[str, Any]) -> list[Column]: + """Build query columns list from form_data, including both x_axis and groupby. + + Handles chart-type-specific keys: + - Standard charts: ``groupby`` + ``x_axis`` + - Pivot tables: ``groupbyColumns`` + ``groupbyRows`` (when ``groupby`` is absent) + - Mixed timeseries: ``groupby_b`` (secondary groupby) + """ + x_axis_config: Column | None = form_data.get("x_axis") + groupby_columns: list[Column] = form_data.get("groupby") or [] + + # Pivot tables store dimensions under groupbyColumns / groupbyRows + if not groupby_columns: + pivot_rows: list[Column] = form_data.get("groupbyRows") or [] + pivot_cols: list[Column] = form_data.get("groupbyColumns") or [] + groupby_columns = list(pivot_rows) + list(pivot_cols) + + # Mixed timeseries stores secondary groupby under groupby_b + groupby_b: list[Column] = form_data.get("groupby_b") or [] + for col in groupby_b: + if col not in groupby_columns: + groupby_columns.append(col) + + # Deduplicate while preserving order + seen: set[str] = set() + columns: list[Column] = [] + + def _add_unique(col: Column) -> None: + key = col if isinstance(col, str) else col.get("label", str(col)) + if key not in seen: + columns.append(col) + seen.add(key) - columns = groupby_columns.copy() if x_axis_config and isinstance(x_axis_config, str): - if x_axis_config not in columns: - columns.insert(0, x_axis_config) + _add_unique(x_axis_config) elif x_axis_config and isinstance(x_axis_config, dict): col_name = x_axis_config.get("column_name") - if col_name and col_name not in columns: - columns.insert(0, col_name) + if col_name and isinstance(col_name, str): + _add_unique(col_name) + + for col in groupby_columns: + _add_unique(col) + return columns +def _build_query_metrics(form_data: Dict[str, Any]) -> list[Metric]: + """Extract metrics from form_data, handling chart-type variations. + + Handles: + - ``metrics`` (plural) — most chart types + - ``metric`` (singular) — Pie charts + - ``metrics_b`` — secondary y-axis in Mixed Timeseries charts + """ + metrics: list[Metric] = list(form_data.get("metrics") or []) + if not metrics: + singular: Metric | None = form_data.get("metric") + if singular: + metrics = [singular] + + # Mixed timeseries stores the second y-axis metrics under metrics_b + metrics_b: list[Metric] = form_data.get("metrics_b") or [] + for m in metrics_b: + if m not in metrics: + metrics.append(m) + + return metrics + + +def _build_chart_description(chart: ChartLike) -> str: + """Build a human-readable chart description, with hints for special chart types.""" + base = ( + f"Preview of {chart.viz_type or 'chart'}: " + f"{chart.slice_name or f'Chart {chart.id}'}" + ) + if chart.viz_type == "handlebars": + base += ( + ". Note: Handlebars charts use browser-side template rendering; " + "this preview shows the raw underlying data, not the rendered template" + ) + return base + + class PreviewFormatStrategy: """Base class for preview format strategies.""" @@ -1304,10 +1372,7 @@ async def _get_chart_preview_internal( # noqa: C901 chart_type=chart.viz_type or "unknown", explore_url=f"{get_superset_base_url()}/explore/?slice_id={chart.id}", content=content, - chart_description=( - f"Preview of {chart.viz_type or 'chart'}: " - f"{chart.slice_name or f'Chart {chart.id}'}" - ), + chart_description=_build_chart_description(chart), accessibility=accessibility, performance=performance, ) diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py index 98b5e5fff7b..1a296e8674c 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py @@ -22,6 +22,7 @@ Unit tests for get_chart_preview MCP tool import importlib from types import SimpleNamespace from typing import Any +from unittest.mock import MagicMock import pytest @@ -37,6 +38,9 @@ from superset.mcp_service.chart.schemas import ( VegaLitePreview, ) from superset.mcp_service.chart.tool.get_chart_preview import ( + _build_chart_description, + _build_query_columns, + _build_query_metrics, _sanitize_chart_preview_for_llm_context, ASCIIPreviewStrategy, TablePreviewStrategy, @@ -983,6 +987,100 @@ Market Share # These demonstrate the expected ASCII formats for different chart types +def test_build_query_columns_standard_groupby(): + form_data = {"x_axis": "date", "groupby": ["region"]} + assert _build_query_columns(form_data) == ["date", "region"] + + +def test_build_query_columns_pivot_table(): + """Pivot tables use groupbyColumns/groupbyRows instead of groupby.""" + form_data = { + "groupbyRows": ["product"], + "groupbyColumns": ["region"], + "metrics": [{"label": "SUM(sales)"}], + } + columns = _build_query_columns(form_data) + assert "product" in columns + assert "region" in columns + + +def test_build_query_columns_mixed_timeseries_groupby_b(): + """Mixed timeseries stores secondary groupby under groupby_b.""" + form_data = { + "x_axis": "date", + "groupby": ["series_a"], + "groupby_b": ["series_b"], + } + columns = _build_query_columns(form_data) + assert "date" in columns + assert "series_a" in columns + assert "series_b" in columns + + +def test_build_query_columns_no_duplicates(): + form_data = { + "x_axis": "date", + "groupby": ["date", "region"], + } + columns = _build_query_columns(form_data) + assert columns.count("date") == 1 + + +def test_build_query_metrics_plural(): + form_data = {"metrics": [{"label": "SUM(sales)"}, {"label": "COUNT(*)"}]} + assert _build_query_metrics(form_data) == [ + {"label": "SUM(sales)"}, + {"label": "COUNT(*)"}, + ] + + +def test_build_query_metrics_singular_for_pie(): + """Pie charts use metric (singular) instead of metrics.""" + form_data = {"metric": "SUM(amount)"} + assert _build_query_metrics(form_data) == ["SUM(amount)"] + + +def test_build_query_metrics_mixed_timeseries(): + """Mixed timeseries stores secondary metrics under metrics_b.""" + form_data = { + "metrics": [{"label": "SUM(revenue)"}], + "metrics_b": [{"label": "AVG(cost)"}], + } + result = _build_query_metrics(form_data) + assert {"label": "SUM(revenue)"} in result + assert {"label": "AVG(cost)"} in result + + +def test_build_query_metrics_empty(): + assert _build_query_metrics({}) == [] + + +def test_build_query_columns_pivot_overlapping_rows_and_columns(): + """Overlapping values in groupbyRows and groupbyColumns are deduplicated.""" + form_data = { + "groupbyRows": ["country", "region"], + "groupbyColumns": ["region", "city"], + } + columns = _build_query_columns(form_data) + assert columns.count("region") == 1 + assert "country" in columns + assert "city" in columns + + +def test_build_chart_description_standard(): + chart = MagicMock(viz_type="line", slice_name="Sales Trend", id=1) + desc = _build_chart_description(chart) + assert desc == "Preview of line: Sales Trend" + + +def test_build_chart_description_handlebars(): + chart = MagicMock(viz_type="handlebars", slice_name="My Template", id=2) + desc = _build_chart_description(chart) + assert "Handlebars" in desc + assert "raw underlying data" in desc + assert "template rendering" in desc + + class TestDetachedInstanceError: """Tests that DetachedInstanceError is handled gracefully.