mirror of
https://github.com/apache/superset.git
synced 2026-05-19 06:45:15 +00:00
Compare commits
2 Commits
docs/dashb
...
alexandrus
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
794ed48e1e | ||
|
|
752434ce9a |
@@ -56,6 +56,7 @@ from superset.mcp_service.utils.oauth2_utils import (
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.superset_typing import Column, Metric
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -145,22 +146,89 @@ class ChartLike(Protocol):
|
||||
uuid: Any
|
||||
|
||||
|
||||
def _build_query_columns(form_data: Dict[str, Any]) -> list[str]:
|
||||
"""Build query columns list from form_data, including both x_axis and groupby."""
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns: list[str] = form_data.get("groupby") or []
|
||||
def _build_query_columns(form_data: Dict[str, Any]) -> list[Column]:
|
||||
"""Build query columns list from form_data, including both x_axis and groupby.
|
||||
|
||||
Handles chart-type-specific keys:
|
||||
- Standard charts: ``groupby`` + ``x_axis``
|
||||
- Pivot tables: ``groupbyColumns`` + ``groupbyRows`` (when ``groupby`` is absent)
|
||||
- Mixed timeseries: ``groupby_b`` (secondary groupby)
|
||||
"""
|
||||
x_axis_config: Column | None = form_data.get("x_axis")
|
||||
groupby_columns: list[Column] = form_data.get("groupby") or []
|
||||
|
||||
# Pivot tables store dimensions under groupbyColumns / groupbyRows
|
||||
if not groupby_columns:
|
||||
pivot_rows: list[Column] = form_data.get("groupbyRows") or []
|
||||
pivot_cols: list[Column] = form_data.get("groupbyColumns") or []
|
||||
groupby_columns = list(pivot_rows) + list(pivot_cols)
|
||||
|
||||
# Mixed timeseries stores secondary groupby under groupby_b
|
||||
groupby_b: list[Column] = form_data.get("groupby_b") or []
|
||||
for col in groupby_b:
|
||||
if col not in groupby_columns:
|
||||
groupby_columns.append(col)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
columns: list[Column] = []
|
||||
|
||||
def _add_unique(col: Column) -> None:
|
||||
key = col if isinstance(col, str) else col.get("label", str(col))
|
||||
if key not in seen:
|
||||
columns.append(col)
|
||||
seen.add(key)
|
||||
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
if x_axis_config not in columns:
|
||||
columns.insert(0, x_axis_config)
|
||||
_add_unique(x_axis_config)
|
||||
elif x_axis_config and isinstance(x_axis_config, dict):
|
||||
col_name = x_axis_config.get("column_name")
|
||||
if col_name and col_name not in columns:
|
||||
columns.insert(0, col_name)
|
||||
if col_name and isinstance(col_name, str):
|
||||
_add_unique(col_name)
|
||||
|
||||
for col in groupby_columns:
|
||||
_add_unique(col)
|
||||
|
||||
return columns
|
||||
|
||||
|
||||
def _build_query_metrics(form_data: Dict[str, Any]) -> list[Metric]:
|
||||
"""Extract metrics from form_data, handling chart-type variations.
|
||||
|
||||
Handles:
|
||||
- ``metrics`` (plural) — most chart types
|
||||
- ``metric`` (singular) — Pie charts
|
||||
- ``metrics_b`` — secondary y-axis in Mixed Timeseries charts
|
||||
"""
|
||||
metrics: list[Metric] = list(form_data.get("metrics") or [])
|
||||
if not metrics:
|
||||
singular: Metric | None = form_data.get("metric")
|
||||
if singular:
|
||||
metrics = [singular]
|
||||
|
||||
# Mixed timeseries stores the second y-axis metrics under metrics_b
|
||||
metrics_b: list[Metric] = form_data.get("metrics_b") or []
|
||||
for m in metrics_b:
|
||||
if m not in metrics:
|
||||
metrics.append(m)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _build_chart_description(chart: ChartLike) -> str:
|
||||
"""Build a human-readable chart description, with hints for special chart types."""
|
||||
base = (
|
||||
f"Preview of {chart.viz_type or 'chart'}: "
|
||||
f"{chart.slice_name or f'Chart {chart.id}'}"
|
||||
)
|
||||
if chart.viz_type == "handlebars":
|
||||
base += (
|
||||
". Note: Handlebars charts use browser-side template rendering; "
|
||||
"this preview shows the raw underlying data, not the rendered template"
|
||||
)
|
||||
return base
|
||||
|
||||
|
||||
class PreviewFormatStrategy:
|
||||
"""Base class for preview format strategies."""
|
||||
|
||||
@@ -215,9 +283,7 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy):
|
||||
)
|
||||
|
||||
# Build query for chart data
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
metrics = form_data.get("metrics", [])
|
||||
metrics = _build_query_metrics(form_data)
|
||||
|
||||
# Table charts in raw mode use all_columns or columns
|
||||
all_columns = form_data.get("all_columns", [])
|
||||
@@ -225,12 +291,7 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy):
|
||||
if form_data.get("query_mode") == "raw" and (all_columns or raw_columns):
|
||||
columns = list(all_columns or raw_columns)
|
||||
else:
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
columns.append(x_axis_config)
|
||||
elif x_axis_config and isinstance(x_axis_config, dict):
|
||||
if "column_name" in x_axis_config:
|
||||
columns.append(x_axis_config["column_name"])
|
||||
columns = _build_query_columns(form_data)
|
||||
|
||||
if not columns and not metrics:
|
||||
return ChartError(
|
||||
@@ -327,7 +388,7 @@ class TablePreviewStrategy(PreviewFormatStrategy):
|
||||
{
|
||||
"filters": form_data.get("filters", []),
|
||||
"columns": columns,
|
||||
"metrics": form_data.get("metrics", []),
|
||||
"metrics": _build_query_metrics(form_data),
|
||||
"row_limit": 20,
|
||||
"order_desc": True,
|
||||
}
|
||||
@@ -433,7 +494,7 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
|
||||
{
|
||||
"filters": form_data.get("filters", []),
|
||||
"columns": columns,
|
||||
"metrics": form_data.get("metrics", []),
|
||||
"metrics": _build_query_metrics(form_data),
|
||||
"row_limit": 1000, # More data for visualization
|
||||
"order_desc": True,
|
||||
}
|
||||
@@ -1371,10 +1432,7 @@ async def _get_chart_preview_internal( # noqa: C901
|
||||
chart_type=chart.viz_type or "unknown",
|
||||
explore_url=f"{get_superset_base_url()}/explore/?slice_id={chart.id}",
|
||||
content=content,
|
||||
chart_description=(
|
||||
f"Preview of {chart.viz_type or 'chart'}: "
|
||||
f"{chart.slice_name or f'Chart {chart.id}'}"
|
||||
),
|
||||
chart_description=_build_chart_description(chart),
|
||||
accessibility=accessibility,
|
||||
performance=performance,
|
||||
)
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
Unit tests for get_chart_preview MCP tool
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
@@ -33,6 +35,9 @@ from superset.mcp_service.chart.schemas import (
|
||||
VegaLitePreview,
|
||||
)
|
||||
from superset.mcp_service.chart.tool.get_chart_preview import (
|
||||
_build_chart_description,
|
||||
_build_query_columns,
|
||||
_build_query_metrics,
|
||||
_sanitize_chart_preview_for_llm_context,
|
||||
)
|
||||
from superset.mcp_service.utils import sanitize_for_llm_context
|
||||
@@ -597,6 +602,100 @@ Market Share
|
||||
# These demonstrate the expected ASCII formats for different chart types
|
||||
|
||||
|
||||
def test_build_query_columns_standard_groupby():
|
||||
form_data = {"x_axis": "date", "groupby": ["region"]}
|
||||
assert _build_query_columns(form_data) == ["date", "region"]
|
||||
|
||||
|
||||
def test_build_query_columns_pivot_table():
|
||||
"""Pivot tables use groupbyColumns/groupbyRows instead of groupby."""
|
||||
form_data = {
|
||||
"groupbyRows": ["product"],
|
||||
"groupbyColumns": ["region"],
|
||||
"metrics": [{"label": "SUM(sales)"}],
|
||||
}
|
||||
columns = _build_query_columns(form_data)
|
||||
assert "product" in columns
|
||||
assert "region" in columns
|
||||
|
||||
|
||||
def test_build_query_columns_mixed_timeseries_groupby_b():
|
||||
"""Mixed timeseries stores secondary groupby under groupby_b."""
|
||||
form_data = {
|
||||
"x_axis": "date",
|
||||
"groupby": ["series_a"],
|
||||
"groupby_b": ["series_b"],
|
||||
}
|
||||
columns = _build_query_columns(form_data)
|
||||
assert "date" in columns
|
||||
assert "series_a" in columns
|
||||
assert "series_b" in columns
|
||||
|
||||
|
||||
def test_build_query_columns_no_duplicates():
|
||||
form_data = {
|
||||
"x_axis": "date",
|
||||
"groupby": ["date", "region"],
|
||||
}
|
||||
columns = _build_query_columns(form_data)
|
||||
assert columns.count("date") == 1
|
||||
|
||||
|
||||
def test_build_query_metrics_plural():
|
||||
form_data = {"metrics": [{"label": "SUM(sales)"}, {"label": "COUNT(*)"}]}
|
||||
assert _build_query_metrics(form_data) == [
|
||||
{"label": "SUM(sales)"},
|
||||
{"label": "COUNT(*)"},
|
||||
]
|
||||
|
||||
|
||||
def test_build_query_metrics_singular_for_pie():
|
||||
"""Pie charts use metric (singular) instead of metrics."""
|
||||
form_data = {"metric": "SUM(amount)"}
|
||||
assert _build_query_metrics(form_data) == ["SUM(amount)"]
|
||||
|
||||
|
||||
def test_build_query_metrics_mixed_timeseries():
|
||||
"""Mixed timeseries stores secondary metrics under metrics_b."""
|
||||
form_data = {
|
||||
"metrics": [{"label": "SUM(revenue)"}],
|
||||
"metrics_b": [{"label": "AVG(cost)"}],
|
||||
}
|
||||
result = _build_query_metrics(form_data)
|
||||
assert {"label": "SUM(revenue)"} in result
|
||||
assert {"label": "AVG(cost)"} in result
|
||||
|
||||
|
||||
def test_build_query_metrics_empty():
|
||||
assert _build_query_metrics({}) == []
|
||||
|
||||
|
||||
def test_build_query_columns_pivot_overlapping_rows_and_columns():
|
||||
"""Overlapping values in groupbyRows and groupbyColumns are deduplicated."""
|
||||
form_data = {
|
||||
"groupbyRows": ["country", "region"],
|
||||
"groupbyColumns": ["region", "city"],
|
||||
}
|
||||
columns = _build_query_columns(form_data)
|
||||
assert columns.count("region") == 1
|
||||
assert "country" in columns
|
||||
assert "city" in columns
|
||||
|
||||
|
||||
def test_build_chart_description_standard():
|
||||
chart = MagicMock(viz_type="line", slice_name="Sales Trend", id=1)
|
||||
desc = _build_chart_description(chart)
|
||||
assert desc == "Preview of line: Sales Trend"
|
||||
|
||||
|
||||
def test_build_chart_description_handlebars():
|
||||
chart = MagicMock(viz_type="handlebars", slice_name="My Template", id=2)
|
||||
desc = _build_chart_description(chart)
|
||||
assert "Handlebars" in desc
|
||||
assert "raw underlying data" in desc
|
||||
assert "template rendering" in desc
|
||||
|
||||
|
||||
class TestDetachedInstanceError:
|
||||
"""Tests that DetachedInstanceError is handled gracefully.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user