Compare commits

...

3 Commits

Author SHA1 Message Date
alexandrusoare
879b1c2255 taking care of PR comments 2026-05-06 11:28:03 +03:00
alexandrusoare
81d6fc26e2 improvements 2026-05-05 13:30:02 +03:00
alexandrusoare
8b2a8d21c1 fix(recommandation): fix chart recommandation 2026-05-04 13:19:38 +03:00
2 changed files with 317 additions and 14 deletions

View File

@@ -52,10 +52,177 @@ from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.utils.core import merge_extra_filters
from superset.utils.core import GenericDataType, merge_extra_filters
logger = logging.getLogger(__name__)
# Translates the GenericDataType codes found in a query result's "coltypes"
# list into the data_type labels stored on each DataColumn. The lookup site
# treats any code missing from this table as "string".
_GENERIC_TYPE_MAP: dict[int, str] = {
GenericDataType.NUMERIC: "numeric",
GenericDataType.STRING: "string",
GenericDataType.TEMPORAL: "temporal",
GenericDataType.BOOLEAN: "boolean",
}
# Maps Superset viz_type strings to canonical categories so we can
# avoid recommending a chart type the user already has.
# A viz_type that is missing from this table is used as its own category
# by _recommend_visualizations, so it never matches a candidate unless its
# raw name happens to equal one of the category strings below.
_VIZ_CATEGORY: dict[str, str] = {
"echarts_timeseries_line": "line",
"echarts_timeseries_smooth": "line",
"echarts_timeseries_step": "line",
"echarts_timeseries": "line",
"echarts_timeseries_bar": "bar",
"echarts_area": "area",
"echarts_timeseries_scatter": "scatter",
"mixed_timeseries": "line",
"table": "table",
"pie": "pie",
"big_number": "kpi",
"big_number_total": "kpi",
"pop_kpi": "kpi",
"dist_bar": "bar",
"line": "line",
"area": "area",
"scatter": "scatter",
"bubble": "bubble",
"treemap_v2": "treemap",
"sunburst_v2": "treemap",
"heatmap_v2": "heatmap",
"gauge_chart": "gauge",
"funnel": "funnel",
"histogram": "histogram",
"histogram_v2": "histogram",
"box_plot": "box_plot",
"world_map": "map",
"pivot_table_v2": "table",
}
# Upper bound on how many suggestions _filter_candidates keeps.
_MAX_RECOMMENDATIONS = 4
def _recommend_visualizations(
    viz_type: str,
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Suggest alternative visualization types for a chart.

    Suggestions are derived from the column data types, their cardinality and
    the number of rows, and exclude the category the chart already uses.

    Args:
        viz_type: The chart's current Superset viz_type. Unknown values are
            used verbatim as their own category.
        columns: Column metadata for the chart's result set.
        row_count: Number of rows in the result set.

    Returns:
        A non-empty list of human-readable chart names, capped by
        ``_filter_candidates``.
    """
    if not columns:
        return ["table"]

    current_category = _VIZ_CATEGORY.get(viz_type, viz_type)
    candidates = _build_candidates(columns, row_count) or ["table", "bar chart"]

    # Filtering can remove every candidate, e.g. a scatter chart over exactly
    # two numeric columns only ever yields "scatter plot". Fall back to a
    # table instead of handing the caller an empty recommendation list.
    return _filter_candidates(candidates, current_category) or ["table"]
def _build_candidates(
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Pick the candidate builder that matches the mix of column types.

    Columns are bucketed by ``data_type`` into temporal, numeric and
    categorical ("string" or "boolean"). Returns an empty list when there is
    no numeric column to plot.
    """
    temporal: list[DataColumn] = []
    numeric: list[DataColumn] = []
    categorical: list[DataColumn] = []
    for column in columns:
        kind = column.data_type
        if kind == "temporal":
            temporal.append(column)
        elif kind == "numeric":
            numeric.append(column)
        elif kind in ("string", "boolean"):
            categorical.append(column)

    # Every builder needs at least one numeric column.
    if not numeric:
        return []
    if temporal:
        return _candidates_temporal_numeric(numeric, row_count)
    if categorical:
        return _candidates_categorical_numeric(numeric, categorical)
    if len(numeric) >= 2:
        return _candidates_multi_numeric(numeric, categorical)
    return _candidates_single_numeric(numeric[0], row_count)
def _candidates_temporal_numeric(
    numeric: list[DataColumn], row_count: int
) -> list[str]:
    """Return candidates for a time column paired with numeric column(s)."""
    # With fewer than five points a line is hard to read; offer bars instead.
    if row_count < 5:
        return ["bar chart", "table"]

    suggestions = ["line chart", "area chart", "bar chart"]
    if len(numeric) > 1:
        suggestions.append("multi-line chart")
    return suggestions
def _candidates_categorical_numeric(
    numeric: list[DataColumn],
    categorical: list[DataColumn],
) -> list[str]:
    """Return candidates for categorical column(s) paired with numeric ones.

    ``categorical`` must be non-empty: the pie check reads its first entry.
    """
    metric_count = len(numeric)
    suggestions = ["bar chart"]

    # Pie is only offered for one metric split by at most ten categories,
    # judged on the first categorical column.
    if metric_count == 1 and categorical[0].unique_count <= 10:
        suggestions.append("pie chart")
    if metric_count >= 2:
        suggestions.extend(["scatter plot", "heatmap"])

    has_wide_dimension = any(dim.unique_count > 5 for dim in categorical)
    if has_wide_dimension:
        suggestions.append("treemap")
    return suggestions
def _candidates_single_numeric(col: DataColumn, row_count: int) -> list[str]:
    """Return candidates for a result set holding a single numeric column."""
    # A histogram leads the list only when there are more than 20 rows and
    # more than 10 distinct values.
    shows_distribution = row_count > 20 and col.unique_count > 10
    leading = ["histogram"] if shows_distribution else []
    return leading + ["big number / KPI", "gauge chart"]
def _candidates_multi_numeric(
    numeric: list[DataColumn],
    categorical: list[DataColumn],
) -> list[str]:
    """Return candidates for result sets with two or more numeric columns."""
    # Each optional suggestion is paired with the condition that enables it;
    # "scatter plot" is always offered first.
    optional = (
        ("bubble chart", len(numeric) >= 3),
        ("heatmap", bool(categorical)),
    )
    return ["scatter plot", *(label for label, wanted in optional if wanted)]
# Maps each candidate string to a canonical category for dedup
# against the current viz_type.
# _filter_candidates compares these values with the category looked up in
# _VIZ_CATEGORY, so the two tables must use the same category spellings.
# A candidate missing from this table is never filtered out.
_CANDIDATE_CATEGORY: dict[str, str] = {
"line chart": "line",
"multi-line chart": "line",
"area chart": "area",
"bar chart": "bar",
"scatter plot": "scatter",
"bubble chart": "bubble",
"pie chart": "pie",
"treemap": "treemap",
"heatmap": "heatmap",
"big number / KPI": "kpi",
"gauge chart": "gauge",
"histogram": "histogram",
"table": "table",
}
def _filter_candidates(
    candidates: list[str],
    current_category: str,
) -> list[str]:
    """Drop duplicates and the current chart's category, then cap the list.

    Order of first appearance is preserved. Collection stops as soon as
    ``_MAX_RECOMMENDATIONS`` entries have been kept.
    """
    kept: list[str] = []
    # dict.fromkeys removes repeats while keeping first-seen order.
    for label in dict.fromkeys(candidates):
        if _CANDIDATE_CATEGORY.get(label) == current_category:
            continue
        kept.append(label)
        if len(kept) >= _MAX_RECOMMENDATIONS:
            break
    return kept
def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
"""Wrap chart data read-path descriptive fields before LLM exposure."""
@@ -620,8 +787,9 @@ async def get_chart_data( # noqa: C901
)
# Create rich column metadata
coltypes = query_result.get("coltypes", [])
columns = []
for col_name in raw_columns:
for idx, col_name in enumerate(raw_columns):
# Sample some values for metadata
sample_values = [
row.get(col_name)
@@ -629,13 +797,16 @@ async def get_chart_data( # noqa: C901
if row.get(col_name) is not None
]
# Infer data type
# Use SQL-derived GenericDataType when available,
# fall back to Python isinstance heuristic
data_type = "string"
if sample_values:
if all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
elif all(isinstance(v, bool) for v in sample_values):
if coltypes:
data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string")
elif sample_values:
if all(isinstance(v, bool) for v in sample_values):
data_type = "boolean"
elif all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
columns.append(
DataColumn(
@@ -678,13 +849,11 @@ async def get_chart_data( # noqa: C901
else:
insights.append("Fresh data retrieved from database")
recommended_visualizations = []
if any(
"time" in col.lower() or "date" in col.lower() for col in raw_columns
):
recommended_visualizations.extend(["line chart", "time series"])
if len(raw_columns) <= 3:
recommended_visualizations.extend(["bar chart", "scatter plot"])
recommended_visualizations = _recommend_visualizations(
viz_type=chart.viz_type or "unknown",
columns=columns,
row_count=len(data),
)
# Performance metadata with cache awareness
execution_time = int((time.time() - start_time) * 1000)

View File

@@ -30,10 +30,14 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
)
from superset.mcp_service.chart.tool.get_chart_data import (
_GENERIC_TYPE_MAP,
_MAX_RECOMMENDATIONS,
_recommend_visualizations,
_sanitize_chart_data_for_llm_context,
)
from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.sanitization import LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER
from superset.utils.core import GenericDataType
def _collect_groupby_extras(
@@ -988,3 +992,133 @@ class TestChartDataCommandValidation:
)
mock_command.run.assert_not_called()
# ---------------------------------------------------------------------------
# Tests for _recommend_visualizations
# ---------------------------------------------------------------------------
def _col(
    name: str,
    data_type: str = "string",
    unique_count: int = 5,
    null_count: int = 0,
) -> DataColumn:
    """Build a DataColumn for tests; the display name mirrors ``name``."""
    fields = {
        "name": name,
        "display_name": name,
        "data_type": data_type,
        "sample_values": [],
        "null_count": null_count,
        "unique_count": unique_count,
    }
    return DataColumn(**fields)
def test_recommend_temporal_and_numeric_suggests_line_chart():
    """A time column plus a metric should offer line and area charts."""
    columns = [_col("created_at", "temporal"), _col("revenue", "numeric")]
    suggestions = _recommend_visualizations("table", columns, row_count=50)
    for expected in ("line chart", "area chart"):
        assert expected in suggestions
def test_recommend_categorical_and_numeric_suggests_bar_chart():
    """A category column plus a metric should offer a bar chart."""
    region = _col("region", "string", unique_count=5)
    sales = _col("sales", "numeric")
    suggestions = _recommend_visualizations(
        "echarts_timeseries_line", [region, sales], row_count=50
    )
    assert "bar chart" in suggestions
def test_recommend_excludes_current_viz_type():
    """A chart that is already a line chart must not be told to use one."""
    columns = [_col("created_at", "temporal"), _col("revenue", "numeric")]
    suggestions = _recommend_visualizations(
        "echarts_timeseries_line", columns, row_count=50
    )
    assert "line chart" not in suggestions
def test_recommend_multiple_numeric_suggests_scatter():
    """Several numeric columns on their own should offer a scatter plot."""
    columns = [_col(label, "numeric") for label in ("height", "weight", "age")]
    suggestions = _recommend_visualizations("table", columns, row_count=100)
    assert "scatter plot" in suggestions
def test_recommend_single_numeric_suggests_kpi():
    """A lone numeric value should offer the big-number KPI display."""
    only_metric = _col("total_revenue", "numeric")
    suggestions = _recommend_visualizations("table", [only_metric], row_count=1)
    assert "big number / KPI" in suggestions
def test_recommend_all_strings_falls_back():
    """With no numeric column the generic fallback list is returned intact."""
    cols = [_col("name", "string"), _col("address", "string")]
    result = _recommend_visualizations("pie", cols, row_count=100)
    # Neither fallback entry is in the "pie" category, so both must survive
    # filtering, in order. The previous `A or B` check would still pass if
    # one of them went missing.
    assert result == ["table", "bar chart"]
def test_recommend_high_cardinality_no_pie():
    """Hundreds of distinct categories must not produce a pie suggestion."""
    user_id = _col("user_id", "string", unique_count=900)
    score = _col("score", "numeric")
    suggestions = _recommend_visualizations(
        "table", [user_id, score], row_count=1000
    )
    assert "pie chart" not in suggestions
def test_recommend_caps_at_max():
    """The suggestion list never exceeds the configured maximum."""
    columns = [
        _col("ts", "temporal"),
        _col("a", "numeric"),
        _col("b", "numeric"),
    ]
    suggestions = _recommend_visualizations("table", columns, row_count=100)
    assert len(suggestions) <= _MAX_RECOMMENDATIONS
def test_recommend_empty_columns_returns_table():
    """Without any columns the only suggestion is a table."""
    no_columns: list = []
    assert _recommend_visualizations("table", no_columns, row_count=0) == ["table"]
def test_recommend_pie_only_for_low_cardinality():
    """Pie is offered for a small category set and withheld for a large one."""
    headcount = _col("headcount", "numeric")
    few_slices = [_col("department", "string", unique_count=4), headcount]
    many_slices = [_col("department", "string", unique_count=25), headcount]

    # Positive case: without it this test only repeated the high-cardinality
    # check and never proved that pie is suggested at all.
    low = _recommend_visualizations("table", few_slices, row_count=100)
    assert "pie chart" in low

    high = _recommend_visualizations("table", many_slices, row_count=100)
    assert "pie chart" not in high
def test_recommend_temporal_few_rows_prefers_bar():
    """A time series with only a few points should offer bars, not lines."""
    columns = [_col("date", "temporal"), _col("revenue", "numeric")]
    suggestions = set(_recommend_visualizations("table", columns, row_count=3))
    assert "bar chart" in suggestions
    assert "line chart" not in suggestions
def test_recommend_single_numeric_high_cardinality_suggests_histogram():
    """Many rows of a widely varying metric should offer a histogram."""
    salary = _col("salary", "numeric", unique_count=500)
    suggestions = _recommend_visualizations("table", [salary], row_count=1000)
    assert "histogram" in suggestions
def test_coltypes_populates_data_type():
    """Each GenericDataType member maps to its expected data_type label."""
    expected_labels = (
        (GenericDataType.NUMERIC, "numeric"),
        (GenericDataType.STRING, "string"),
        (GenericDataType.TEMPORAL, "temporal"),
        (GenericDataType.BOOLEAN, "boolean"),
    )
    for generic_type, label in expected_labels:
        assert _GENERIC_TYPE_MAP[generic_type] == label
def test_bool_isinstance_check_before_int():
    """Booleans must be classified before the int/float check runs.

    bool is a subclass of int, so ``isinstance(True, int)`` is True and an
    int-first check would label boolean samples as numeric. This replays the
    fallback heuristic's ordering on a local copy of the logic; it does not
    call the production code.
    """
    sample_values = [True, False, True]

    if all(isinstance(value, bool) for value in sample_values):
        data_type = "boolean"
    elif all(isinstance(value, (int, float)) for value in sample_values):
        data_type = "numeric"
    else:
        data_type = "string"

    assert data_type == "boolean"