fix(recommandation): fix chart recommandation

This commit is contained in:
alexandrusoare
2026-05-04 13:19:38 +03:00
parent dc1c0f6ba1
commit 8b2a8d21c1
2 changed files with 240 additions and 11 deletions

View File

@@ -52,10 +52,143 @@ from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.utils.core import merge_extra_filters
from superset.utils.core import GenericDataType, merge_extra_filters
logger = logging.getLogger(__name__)
_GENERIC_TYPE_MAP: dict[int, str] = {
GenericDataType.NUMERIC: "numeric",
GenericDataType.STRING: "string",
GenericDataType.TEMPORAL: "temporal",
GenericDataType.BOOLEAN: "boolean",
}
# Maps Superset viz_type strings to canonical categories so we can
# avoid recommending a chart type the user already has.
_VIZ_CATEGORY: dict[str, str] = {
"echarts_timeseries_line": "line",
"echarts_timeseries_bar": "bar",
"echarts_area": "area",
"echarts_timeseries_scatter": "scatter",
"table": "table",
"pie": "pie",
"big_number": "kpi",
"big_number_total": "kpi",
"dist_bar": "bar",
"line": "line",
"area": "area",
"scatter": "scatter",
"bubble": "bubble",
"treemap_v2": "treemap",
"heatmap_v2": "heatmap",
"gauge_chart": "gauge",
"histogram": "histogram",
"box_plot": "box_plot",
"world_map": "map",
}
_MAX_RECOMMENDATIONS = 4
def _recommend_visualizations(
viz_type: str,
columns: list[DataColumn],
row_count: int,
) -> list[str]:
"""Suggest visualization types based on column types,
cardinality, and the chart's current viz_type.
"""
if not columns:
return ["table"]
current_category = _VIZ_CATEGORY.get(viz_type, viz_type)
candidates = _build_candidates(columns, row_count)
if not candidates:
candidates = ["table", "bar chart"]
return _filter_candidates(candidates, current_category)
def _build_candidates(
columns: list[DataColumn],
row_count: int,
) -> list[str]:
"""Build candidate visualization list from column metadata."""
temporal = [c for c in columns if c.data_type == "temporal"]
numeric = [c for c in columns if c.data_type == "numeric"]
categorical = [
c
for c in columns
if c.data_type in ("string", "boolean")
and (
c.unique_count <= 20 or (row_count > 0 and c.unique_count / row_count < 0.5)
)
]
candidates: list[str] = []
if temporal and numeric:
candidates.append("line chart")
candidates.append("area chart")
candidates.append("bar chart")
if len(numeric) > 1:
candidates.append("multi-line chart")
elif categorical and numeric:
candidates.append("bar chart")
if len(numeric) == 1 and categorical and categorical[0].unique_count <= 10:
candidates.append("pie chart")
if any(c.unique_count > 5 for c in categorical):
candidates.append("treemap")
elif len(numeric) >= 2:
candidates.append("scatter plot")
if len(numeric) >= 3:
candidates.append("bubble chart")
if categorical:
candidates.append("heatmap")
elif len(numeric) == 1 and not temporal and not categorical:
candidates.append("big number / KPI")
candidates.append("gauge chart")
return candidates
# Maps each candidate string to a canonical category for dedup
# against the current viz_type.
_CANDIDATE_CATEGORY: dict[str, str] = {
"line chart": "line",
"multi-line chart": "line",
"area chart": "area",
"bar chart": "bar",
"scatter plot": "scatter",
"bubble chart": "bubble",
"pie chart": "pie",
"treemap": "treemap",
"heatmap": "heatmap",
"big number / KPI": "kpi",
"gauge chart": "gauge",
"table": "table",
}
def _filter_candidates(
candidates: list[str],
current_category: str,
) -> list[str]:
"""Deduplicate, exclude the current viz category, and cap."""
seen: set[str] = set()
result: list[str] = []
for c in candidates:
if c in seen:
continue
if _CANDIDATE_CATEGORY.get(c) == current_category:
continue
seen.add(c)
result.append(c)
if len(result) >= _MAX_RECOMMENDATIONS:
break
return result
def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
"""Wrap chart data read-path descriptive fields before LLM exposure."""
@@ -620,8 +753,9 @@ async def get_chart_data( # noqa: C901
)
# Create rich column metadata
coltypes = query_result.get("coltypes", [])
columns = []
for col_name in raw_columns:
for idx, col_name in enumerate(raw_columns):
# Sample some values for metadata
sample_values = [
row.get(col_name)
@@ -629,9 +763,12 @@ async def get_chart_data( # noqa: C901
if row.get(col_name) is not None
]
# Infer data type
# Use SQL-derived GenericDataType when available,
# fall back to Python isinstance heuristic
data_type = "string"
if sample_values:
if idx < len(coltypes):
data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string")
elif sample_values:
if all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
elif all(isinstance(v, bool) for v in sample_values):
@@ -678,13 +815,11 @@ async def get_chart_data( # noqa: C901
else:
insights.append("Fresh data retrieved from database")
recommended_visualizations = []
if any(
"time" in col.lower() or "date" in col.lower() for col in raw_columns
):
recommended_visualizations.extend(["line chart", "time series"])
if len(raw_columns) <= 3:
recommended_visualizations.extend(["bar chart", "scatter plot"])
recommended_visualizations = _recommend_visualizations(
viz_type=chart.viz_type or "unknown",
columns=columns,
row_count=len(data),
)
# Performance metadata with cache awareness
execution_time = int((time.time() - start_time) * 1000)

View File

@@ -30,6 +30,8 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
)
from superset.mcp_service.chart.tool.get_chart_data import (
_MAX_RECOMMENDATIONS,
_recommend_visualizations,
_sanitize_chart_data_for_llm_context,
)
from superset.mcp_service.utils import sanitize_for_llm_context
@@ -988,3 +990,95 @@ class TestChartDataCommandValidation:
)
mock_command.run.assert_not_called()
# ---------------------------------------------------------------------------
# Tests for _recommend_visualizations
# ---------------------------------------------------------------------------
def _col(
name: str,
data_type: str = "string",
unique_count: int = 5,
null_count: int = 0,
) -> DataColumn:
"""Shortcut to build a DataColumn for tests."""
return DataColumn(
name=name,
display_name=name,
data_type=data_type,
sample_values=[],
null_count=null_count,
unique_count=unique_count,
)
def test_recommend_temporal_and_numeric_suggests_line_chart():
cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
result = _recommend_visualizations("table", cols, row_count=50)
assert "line chart" in result
assert "area chart" in result
def test_recommend_categorical_and_numeric_suggests_bar_chart():
cols = [_col("region", "string", unique_count=5), _col("sales", "numeric")]
result = _recommend_visualizations("echarts_timeseries_line", cols, row_count=50)
assert "bar chart" in result
def test_recommend_excludes_current_viz_type():
cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
result = _recommend_visualizations("echarts_timeseries_line", cols, row_count=50)
assert "line chart" not in result
def test_recommend_multiple_numeric_suggests_scatter():
cols = [
_col("height", "numeric"),
_col("weight", "numeric"),
_col("age", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=100)
assert "scatter plot" in result
def test_recommend_single_numeric_suggests_kpi():
cols = [_col("total_revenue", "numeric")]
result = _recommend_visualizations("table", cols, row_count=1)
assert "big number / KPI" in result
def test_recommend_all_strings_falls_back():
cols = [_col("name", "string"), _col("address", "string")]
result = _recommend_visualizations("pie", cols, row_count=100)
assert "table" in result or "bar chart" in result
def test_recommend_high_cardinality_no_pie():
cols = [
_col("user_id", "string", unique_count=900),
_col("score", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=1000)
assert "pie chart" not in result
def test_recommend_caps_at_max():
cols = [_col("ts", "temporal"), _col("a", "numeric"), _col("b", "numeric")]
result = _recommend_visualizations("table", cols, row_count=100)
assert len(result) <= _MAX_RECOMMENDATIONS
def test_recommend_empty_columns_returns_table():
result = _recommend_visualizations("table", [], row_count=0)
assert result == ["table"]
def test_recommend_pie_only_for_low_cardinality():
cols = [
_col("department", "string", unique_count=25),
_col("headcount", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=100)
assert "pie chart" not in result