Compare commits

...

2 Commits

Author SHA1 Message Date
alexandrusoare
794ed48e1e fix(type): improve type casting 2026-05-12 16:09:13 +03:00
alexandrusoare
752434ce9a fix(preview): fix chart preview bugs 2026-05-12 16:01:33 +03:00
2 changed files with 181 additions and 24 deletions

View File

@@ -56,6 +56,7 @@ from superset.mcp_service.utils.oauth2_utils import (
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.superset_typing import Column, Metric
logger = logging.getLogger(__name__)
@@ -145,22 +146,89 @@ class ChartLike(Protocol):
uuid: Any
def _build_query_columns(form_data: Dict[str, Any]) -> list[str]:
"""Build query columns list from form_data, including both x_axis and groupby."""
x_axis_config = form_data.get("x_axis")
groupby_columns: list[str] = form_data.get("groupby") or []
def _build_query_columns(form_data: Dict[str, Any]) -> list[Column]:
"""Build query columns list from form_data, including both x_axis and groupby.
Handles chart-type-specific keys:
- Standard charts: ``groupby`` + ``x_axis``
- Pivot tables: ``groupbyColumns`` + ``groupbyRows`` (when ``groupby`` is absent)
- Mixed timeseries: ``groupby_b`` (secondary groupby)
"""
x_axis_config: Column | None = form_data.get("x_axis")
groupby_columns: list[Column] = form_data.get("groupby") or []
# Pivot tables store dimensions under groupbyColumns / groupbyRows
if not groupby_columns:
pivot_rows: list[Column] = form_data.get("groupbyRows") or []
pivot_cols: list[Column] = form_data.get("groupbyColumns") or []
groupby_columns = list(pivot_rows) + list(pivot_cols)
# Mixed timeseries stores secondary groupby under groupby_b
groupby_b: list[Column] = form_data.get("groupby_b") or []
for col in groupby_b:
if col not in groupby_columns:
groupby_columns.append(col)
# Deduplicate while preserving order
seen: set[str] = set()
columns: list[Column] = []
def _add_unique(col: Column) -> None:
key = col if isinstance(col, str) else col.get("label", str(col))
if key not in seen:
columns.append(col)
seen.add(key)
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
if x_axis_config not in columns:
columns.insert(0, x_axis_config)
_add_unique(x_axis_config)
elif x_axis_config and isinstance(x_axis_config, dict):
col_name = x_axis_config.get("column_name")
if col_name and col_name not in columns:
columns.insert(0, col_name)
if col_name and isinstance(col_name, str):
_add_unique(col_name)
for col in groupby_columns:
_add_unique(col)
return columns
def _build_query_metrics(form_data: Dict[str, Any]) -> list[Metric]:
"""Extract metrics from form_data, handling chart-type variations.
Handles:
- ``metrics`` (plural) — most chart types
- ``metric`` (singular) — Pie charts
- ``metrics_b`` — secondary y-axis in Mixed Timeseries charts
"""
metrics: list[Metric] = list(form_data.get("metrics") or [])
if not metrics:
singular: Metric | None = form_data.get("metric")
if singular:
metrics = [singular]
# Mixed timeseries stores the second y-axis metrics under metrics_b
metrics_b: list[Metric] = form_data.get("metrics_b") or []
for m in metrics_b:
if m not in metrics:
metrics.append(m)
return metrics
def _build_chart_description(chart: ChartLike) -> str:
"""Build a human-readable chart description, with hints for special chart types."""
base = (
f"Preview of {chart.viz_type or 'chart'}: "
f"{chart.slice_name or f'Chart {chart.id}'}"
)
if chart.viz_type == "handlebars":
base += (
". Note: Handlebars charts use browser-side template rendering; "
"this preview shows the raw underlying data, not the rendered template"
)
return base
class PreviewFormatStrategy:
"""Base class for preview format strategies."""
@@ -215,9 +283,7 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy):
)
# Build query for chart data
x_axis_config = form_data.get("x_axis")
groupby_columns = form_data.get("groupby", [])
metrics = form_data.get("metrics", [])
metrics = _build_query_metrics(form_data)
# Table charts in raw mode use all_columns or columns
all_columns = form_data.get("all_columns", [])
@@ -225,12 +291,7 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy):
if form_data.get("query_mode") == "raw" and (all_columns or raw_columns):
columns = list(all_columns or raw_columns)
else:
columns = groupby_columns.copy()
if x_axis_config and isinstance(x_axis_config, str):
columns.append(x_axis_config)
elif x_axis_config and isinstance(x_axis_config, dict):
if "column_name" in x_axis_config:
columns.append(x_axis_config["column_name"])
columns = _build_query_columns(form_data)
if not columns and not metrics:
return ChartError(
@@ -327,7 +388,7 @@ class TablePreviewStrategy(PreviewFormatStrategy):
{
"filters": form_data.get("filters", []),
"columns": columns,
"metrics": form_data.get("metrics", []),
"metrics": _build_query_metrics(form_data),
"row_limit": 20,
"order_desc": True,
}
@@ -433,7 +494,7 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
{
"filters": form_data.get("filters", []),
"columns": columns,
"metrics": form_data.get("metrics", []),
"metrics": _build_query_metrics(form_data),
"row_limit": 1000, # More data for visualization
"order_desc": True,
}
@@ -1371,10 +1432,7 @@ async def _get_chart_preview_internal( # noqa: C901
chart_type=chart.viz_type or "unknown",
explore_url=f"{get_superset_base_url()}/explore/?slice_id={chart.id}",
content=content,
chart_description=(
f"Preview of {chart.viz_type or 'chart'}: "
f"{chart.slice_name or f'Chart {chart.id}'}"
),
chart_description=_build_chart_description(chart),
accessibility=accessibility,
performance=performance,
)

View File

@@ -19,6 +19,8 @@
Unit tests for get_chart_preview MCP tool
"""
from unittest.mock import MagicMock
import pytest
from superset.mcp_service.chart.schemas import (
@@ -33,6 +35,9 @@ from superset.mcp_service.chart.schemas import (
VegaLitePreview,
)
from superset.mcp_service.chart.tool.get_chart_preview import (
_build_chart_description,
_build_query_columns,
_build_query_metrics,
_sanitize_chart_preview_for_llm_context,
)
from superset.mcp_service.utils import sanitize_for_llm_context
@@ -597,6 +602,100 @@ Market Share
# These demonstrate the expected ASCII formats for different chart types
def test_build_query_columns_standard_groupby():
form_data = {"x_axis": "date", "groupby": ["region"]}
assert _build_query_columns(form_data) == ["date", "region"]
def test_build_query_columns_pivot_table():
"""Pivot tables use groupbyColumns/groupbyRows instead of groupby."""
form_data = {
"groupbyRows": ["product"],
"groupbyColumns": ["region"],
"metrics": [{"label": "SUM(sales)"}],
}
columns = _build_query_columns(form_data)
assert "product" in columns
assert "region" in columns
def test_build_query_columns_mixed_timeseries_groupby_b():
"""Mixed timeseries stores secondary groupby under groupby_b."""
form_data = {
"x_axis": "date",
"groupby": ["series_a"],
"groupby_b": ["series_b"],
}
columns = _build_query_columns(form_data)
assert "date" in columns
assert "series_a" in columns
assert "series_b" in columns
def test_build_query_columns_no_duplicates():
form_data = {
"x_axis": "date",
"groupby": ["date", "region"],
}
columns = _build_query_columns(form_data)
assert columns.count("date") == 1
def test_build_query_metrics_plural():
form_data = {"metrics": [{"label": "SUM(sales)"}, {"label": "COUNT(*)"}]}
assert _build_query_metrics(form_data) == [
{"label": "SUM(sales)"},
{"label": "COUNT(*)"},
]
def test_build_query_metrics_singular_for_pie():
"""Pie charts use metric (singular) instead of metrics."""
form_data = {"metric": "SUM(amount)"}
assert _build_query_metrics(form_data) == ["SUM(amount)"]
def test_build_query_metrics_mixed_timeseries():
"""Mixed timeseries stores secondary metrics under metrics_b."""
form_data = {
"metrics": [{"label": "SUM(revenue)"}],
"metrics_b": [{"label": "AVG(cost)"}],
}
result = _build_query_metrics(form_data)
assert {"label": "SUM(revenue)"} in result
assert {"label": "AVG(cost)"} in result
def test_build_query_metrics_empty():
assert _build_query_metrics({}) == []
def test_build_query_columns_pivot_overlapping_rows_and_columns():
"""Overlapping values in groupbyRows and groupbyColumns are deduplicated."""
form_data = {
"groupbyRows": ["country", "region"],
"groupbyColumns": ["region", "city"],
}
columns = _build_query_columns(form_data)
assert columns.count("region") == 1
assert "country" in columns
assert "city" in columns
def test_build_chart_description_standard():
chart = MagicMock(viz_type="line", slice_name="Sales Trend", id=1)
desc = _build_chart_description(chart)
assert desc == "Preview of line: Sales Trend"
def test_build_chart_description_handlebars():
chart = MagicMock(viz_type="handlebars", slice_name="My Template", id=2)
desc = _build_chart_description(chart)
assert "Handlebars" in desc
assert "raw underlying data" in desc
assert "template rendering" in desc
class TestDetachedInstanceError:
"""Tests that DetachedInstanceError is handled gracefully.