mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
fix(recommendation): fix chart recommendation
This commit is contained in:
@@ -52,10 +52,143 @@ from superset.mcp_service.utils.oauth2_utils import (
|
||||
build_oauth2_redirect_message,
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
from superset.utils.core import merge_extra_filters
|
||||
from superset.utils.core import GenericDataType, merge_extra_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Translates GenericDataType codes into the lowercase type labels that the
# recommendation logic compares against ("numeric", "string", "temporal",
# "boolean").
_GENERIC_TYPE_MAP: dict[int, str] = {
    GenericDataType.NUMERIC: "numeric",
    GenericDataType.STRING: "string",
    GenericDataType.TEMPORAL: "temporal",
    GenericDataType.BOOLEAN: "boolean",
}
|
||||
|
||||
# Maps Superset viz_type strings to canonical categories so we can
# avoid recommending a chart type the user already has.
# A viz_type missing from this map is used as its own category by the caller,
# so variants of a family (e.g. smooth/stepped lines) must be listed explicitly
# or their own family would be recommended back to the user.
_VIZ_CATEGORY: dict[str, str] = {
    # Line family
    "echarts_timeseries_line": "line",
    "echarts_timeseries_smooth": "line",
    "echarts_timeseries_step": "line",
    "line": "line",
    # Bar family
    "echarts_timeseries_bar": "bar",
    "dist_bar": "bar",
    # Area family
    "echarts_area": "area",
    "area": "area",
    # Scatter / bubble
    "echarts_timeseries_scatter": "scatter",
    "scatter": "scatter",
    "bubble": "bubble",
    "bubble_v2": "bubble",
    # Tables
    "table": "table",
    "pivot_table_v2": "table",
    # Single-value displays
    "big_number": "kpi",
    "big_number_total": "kpi",
    "gauge_chart": "gauge",
    # Part-to-whole / matrix
    "pie": "pie",
    "treemap_v2": "treemap",
    "heatmap_v2": "heatmap",
    # Distributions
    "histogram": "histogram",
    "histogram_v2": "histogram",
    "box_plot": "box_plot",
    # Geo
    "world_map": "map",
}
|
||||
|
||||
# Upper bound on how many recommendations _filter_candidates returns.
_MAX_RECOMMENDATIONS = 4
|
||||
|
||||
|
||||
def _recommend_visualizations(
    viz_type: str,
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Suggest alternative visualization types for a chart's data.

    Args:
        viz_type: The chart's current Superset ``viz_type``. It is mapped to a
            canonical category (unmapped values are used as-is) so the chart's
            own type is not suggested back to the user.
        columns: Column metadata describing the chart's data.
        row_count: Number of rows in the data; used for cardinality ratios.

    Returns:
        A non-empty list of human-readable chart names, excluding the current
        chart's category and capped by ``_filter_candidates``.
    """
    if not columns:
        return ["table"]

    current_category = _VIZ_CATEGORY.get(viz_type, viz_type)

    # Generic fallback when the column profile matches none of the rules.
    candidates = _build_candidates(columns, row_count) or ["table", "bar chart"]

    recommendations = _filter_candidates(candidates, current_category)

    # Filtering can remove every candidate, e.g. a bar chart whose data only
    # qualifies for "bar chart". Fall back to a table rather than returning
    # nothing, mirroring the empty-columns case above.
    return recommendations or ["table"]
|
||||
|
||||
|
||||
def _build_candidates(
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Build the ordered candidate visualization list from column metadata.

    Columns are bucketed by ``data_type`` into temporal, numeric and
    categorical. A string/boolean column counts as categorical only when it
    has at most 20 distinct values, or when its distinct values make up less
    than half of ``row_count``.

    The first matching rule wins:

    1. temporal + numeric -> line / area / bar (+ multi-line for 2+ metrics)
    2. categorical + numeric -> bar (+ pie, + treemap depending on cardinality)
    3. two or more numeric -> scatter (+ bubble for 3+ metrics)
    4. exactly one numeric -> big number / gauge

    Args:
        columns: Column metadata; only ``data_type`` and ``unique_count`` are
            read.
        row_count: Number of rows in the result set.

    Returns:
        Candidate chart names in priority order, or an empty list when there
        is no numeric column.
    """
    temporal = [c for c in columns if c.data_type == "temporal"]
    numeric = [c for c in columns if c.data_type == "numeric"]
    # The ``row_count > 0`` guard avoids a division by zero on empty results.
    categorical = [
        c
        for c in columns
        if c.data_type in ("string", "boolean")
        and (
            c.unique_count <= 20 or (row_count > 0 and c.unique_count / row_count < 0.5)
        )
    ]

    candidates: list[str] = []

    if temporal and numeric:
        candidates.append("line chart")
        candidates.append("area chart")
        candidates.append("bar chart")
        if len(numeric) > 1:
            candidates.append("multi-line chart")
    elif categorical and numeric:
        candidates.append("bar chart")
        # Pie charts only read well with a single metric and few slices; the
        # first categorical column is the one assumed to drive the slices.
        if len(numeric) == 1 and categorical[0].unique_count <= 10:
            candidates.append("pie chart")
        if any(c.unique_count > 5 for c in categorical):
            candidates.append("treemap")
    elif len(numeric) >= 2:
        # ``categorical`` is always empty here (otherwise the previous branch
        # would have matched), so only numeric-vs-numeric charts apply.
        candidates.append("scatter plot")
        if len(numeric) >= 3:
            candidates.append("bubble chart")
    elif len(numeric) == 1:
        # Reaching this branch already implies no temporal and no categorical
        # columns: a lone metric.
        candidates.append("big number / KPI")
        candidates.append("gauge chart")

    return candidates
|
||||
|
||||
|
||||
# Canonical category of every human-readable candidate label, used to drop
# candidates that belong to the same category as the chart's current viz_type.
_CANDIDATE_CATEGORY: dict[str, str] = {
    # Time-series / trend charts
    "line chart": "line",
    "multi-line chart": "line",
    "area chart": "area",
    # Comparison charts
    "bar chart": "bar",
    "pie chart": "pie",
    "treemap": "treemap",
    "heatmap": "heatmap",
    # Correlation charts
    "scatter plot": "scatter",
    "bubble chart": "bubble",
    # Single-value displays
    "big number / KPI": "kpi",
    "gauge chart": "gauge",
    # Generic fallback
    "table": "table",
}
|
||||
|
||||
|
||||
def _filter_candidates(
    candidates: list[str],
    current_category: str,
) -> list[str]:
    """Return the distinct candidates outside ``current_category``, capped.

    Input order is preserved. A candidate is dropped when it was already kept
    or when ``_CANDIDATE_CATEGORY`` maps it to ``current_category``. Collection
    stops once ``_MAX_RECOMMENDATIONS`` entries have been kept.
    """
    kept: list[str] = []
    already_kept: set[str] = set()
    for candidate in candidates:
        is_duplicate = candidate in already_kept
        is_current = _CANDIDATE_CATEGORY.get(candidate) == current_category
        if is_duplicate or is_current:
            continue
        already_kept.add(candidate)
        kept.append(candidate)
        if len(kept) >= _MAX_RECOMMENDATIONS:
            break
    return kept
|
||||
|
||||
|
||||
def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
|
||||
"""Wrap chart data read-path descriptive fields before LLM exposure."""
|
||||
@@ -620,8 +753,9 @@ async def get_chart_data( # noqa: C901
|
||||
)
|
||||
|
||||
# Create rich column metadata
|
||||
coltypes = query_result.get("coltypes", [])
|
||||
columns = []
|
||||
for col_name in raw_columns:
|
||||
for idx, col_name in enumerate(raw_columns):
|
||||
# Sample some values for metadata
|
||||
sample_values = [
|
||||
row.get(col_name)
|
||||
@@ -629,9 +763,12 @@ async def get_chart_data( # noqa: C901
|
||||
if row.get(col_name) is not None
|
||||
]
|
||||
|
||||
# Infer data type
|
||||
# Use SQL-derived GenericDataType when available,
|
||||
# fall back to Python isinstance heuristic
|
||||
data_type = "string"
|
||||
if sample_values:
|
||||
if idx < len(coltypes):
|
||||
data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string")
|
||||
elif sample_values:
|
||||
if all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
elif all(isinstance(v, bool) for v in sample_values):
|
||||
@@ -678,13 +815,11 @@ async def get_chart_data( # noqa: C901
|
||||
else:
|
||||
insights.append("Fresh data retrieved from database")
|
||||
|
||||
recommended_visualizations = []
|
||||
if any(
|
||||
"time" in col.lower() or "date" in col.lower() for col in raw_columns
|
||||
):
|
||||
recommended_visualizations.extend(["line chart", "time series"])
|
||||
if len(raw_columns) <= 3:
|
||||
recommended_visualizations.extend(["bar chart", "scatter plot"])
|
||||
recommended_visualizations = _recommend_visualizations(
|
||||
viz_type=chart.viz_type or "unknown",
|
||||
columns=columns,
|
||||
row_count=len(data),
|
||||
)
|
||||
|
||||
# Performance metadata with cache awareness
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
@@ -30,6 +30,8 @@ from superset.mcp_service.chart.schemas import (
|
||||
PerformanceMetadata,
|
||||
)
|
||||
from superset.mcp_service.chart.tool.get_chart_data import (
|
||||
_MAX_RECOMMENDATIONS,
|
||||
_recommend_visualizations,
|
||||
_sanitize_chart_data_for_llm_context,
|
||||
)
|
||||
from superset.mcp_service.utils import sanitize_for_llm_context
|
||||
@@ -988,3 +990,95 @@ class TestChartDataCommandValidation:
|
||||
)
|
||||
|
||||
mock_command.run.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _recommend_visualizations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _col(
    name: str,
    data_type: str = "string",
    unique_count: int = 5,
    null_count: int = 0,
) -> DataColumn:
    """Build a DataColumn test fixture.

    ``display_name`` mirrors ``name`` and ``sample_values`` is left empty.
    """
    fields = {
        "name": name,
        "display_name": name,
        "data_type": data_type,
        "sample_values": [],
        "null_count": null_count,
        "unique_count": unique_count,
    }
    return DataColumn(**fields)
|
||||
|
||||
|
||||
def test_recommend_temporal_and_numeric_suggests_line_chart():
    """A temporal column plus a metric offers both line and area charts."""
    columns = [
        _col("created_at", "temporal"),
        _col("revenue", "numeric"),
    ]
    recommendations = _recommend_visualizations("table", columns, row_count=50)
    for expected in ("line chart", "area chart"):
        assert expected in recommendations
|
||||
|
||||
|
||||
def test_recommend_categorical_and_numeric_suggests_bar_chart():
    """A low-cardinality dimension plus a metric offers a bar chart."""
    region = _col("region", "string", unique_count=5)
    sales = _col("sales", "numeric")
    recommendations = _recommend_visualizations(
        "echarts_timeseries_line", [region, sales], row_count=50
    )
    assert "bar chart" in recommendations
|
||||
|
||||
|
||||
def test_recommend_excludes_current_viz_type():
    """A chart that is already a line chart is not offered a line chart."""
    columns = [
        _col("created_at", "temporal"),
        _col("revenue", "numeric"),
    ]
    recommendations = _recommend_visualizations(
        "echarts_timeseries_line", columns, row_count=50
    )
    assert "line chart" not in recommendations
|
||||
|
||||
|
||||
def test_recommend_multiple_numeric_suggests_scatter():
    """Several metrics with no dimension offer a scatter plot."""
    columns = [_col(name, "numeric") for name in ("height", "weight", "age")]
    recommendations = _recommend_visualizations("table", columns, row_count=100)
    assert "scatter plot" in recommendations
|
||||
|
||||
|
||||
def test_recommend_single_numeric_suggests_kpi():
    """A lone metric offers a big-number / KPI display."""
    total_revenue = _col("total_revenue", "numeric")
    recommendations = _recommend_visualizations(
        "table", [total_revenue], row_count=1
    )
    assert "big number / KPI" in recommendations
|
||||
|
||||
|
||||
def test_recommend_all_strings_falls_back():
    """Without any numeric column the generic fallback list is returned.

    No rule matches string-only data, so the fallback candidates apply; the
    current chart is a pie, so neither fallback entry is filtered out.
    """
    cols = [_col("name", "string"), _col("address", "string")]
    result = _recommend_visualizations("pie", cols, row_count=100)
    # Pin the exact fallback: the previous ``or`` check would still pass if
    # one of the two fallback entries silently disappeared.
    assert result == ["table", "bar chart"]
|
||||
|
||||
|
||||
def test_recommend_high_cardinality_no_pie():
    """A near-unique string column never leads to a pie chart suggestion."""
    user_id = _col("user_id", "string", unique_count=900)
    score = _col("score", "numeric")
    recommendations = _recommend_visualizations(
        "table", [user_id, score], row_count=1000
    )
    assert "pie chart" not in recommendations
|
||||
|
||||
|
||||
def test_recommend_caps_at_max():
    """The result never holds more than _MAX_RECOMMENDATIONS entries."""
    # NOTE(review): this temporal + two-metric input appears to produce
    # exactly four candidates, so it checks the bound without forcing
    # truncation - confirm whether a stricter case is wanted.
    columns = [
        _col("ts", "temporal"),
        _col("a", "numeric"),
        _col("b", "numeric"),
    ]
    recommendations = _recommend_visualizations("table", columns, row_count=100)
    assert len(recommendations) <= _MAX_RECOMMENDATIONS
|
||||
|
||||
|
||||
def test_recommend_empty_columns_returns_table():
    """With no column metadata the only suggestion is a table."""
    recommendations = _recommend_visualizations("table", [], row_count=0)
    assert recommendations == ["table"]
|
||||
|
||||
|
||||
def test_recommend_pie_only_for_low_cardinality():
    """A dimension with 25 distinct values is not offered a pie chart."""
    department = _col("department", "string", unique_count=25)
    headcount = _col("headcount", "numeric")
    recommendations = _recommend_visualizations(
        "table", [department, headcount], row_count=100
    )
    assert "pie chart" not in recommendations
|
||||
|
||||
Reference in New Issue
Block a user