fix(mcp): sanitize read path output for LLM context (#39738)

This commit is contained in:
Richard Fogaca Nienkotter
2026-04-29 19:06:19 -03:00
committed by GitHub
parent 81a08f0a0e
commit c2b9272f4c
22 changed files with 2781 additions and 321 deletions

View File

@@ -46,6 +46,7 @@ from superset.mcp_service.chart.schemas import (
GetChartDataRequest,
PerformanceMetadata,
)
from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
@@ -56,6 +57,40 @@ from superset.utils.core import merge_extra_filters
logger = logging.getLogger(__name__)
def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
    """Wrap chart data read-path descriptive fields before LLM exposure."""
    raw = chart_data.model_dump(mode="python")

    # Scalar descriptive fields are sanitized individually under their own path.
    for name in ("chart_name", "summary", "csv_data"):
        raw[name] = sanitize_for_llm_context(raw.get(name), field_path=(name,))

    raw["insights"] = sanitize_for_llm_context(
        raw.get("insights", []),
        field_path=("insights",),
    )
    # Row data carries arbitrary result-set keys, so no field-name exclusions apply.
    raw["data"] = sanitize_for_llm_context(
        raw.get("data", []),
        field_path=("data",),
        excluded_field_names=frozenset(),
    )

    # Each column keeps its metadata; only its sample values are rewritten.
    sanitized_columns = []
    for idx, col in enumerate(raw.get("columns", [])):
        samples = sanitize_for_llm_context(
            col.get("sample_values", []),
            field_path=("columns", str(idx), "sample_values"),
            excluded_field_names=frozenset(),
        )
        sanitized_columns.append({**col, "sample_values": samples})
    raw["columns"] = sanitized_columns

    return ChartData.model_validate(raw)
def _apply_extra_form_data(
form_data: dict[str, Any], extra_form_data: dict[str, Any] | None
) -> None:
@@ -745,21 +780,23 @@ async def get_chart_data( # noqa: C901
)
# Default JSON format
return ChartData(
chart_id=chart.id,
chart_name=chart.slice_name or f"Chart {chart.id}",
chart_type=chart.viz_type or "unknown",
columns=columns,
data=data[: request.limit] if request.limit else data,
row_count=len(data),
total_rows=query_result.get("rowcount"),
summary=summary,
insights=insights,
data_quality={"completeness": data_completeness},
recommended_visualizations=recommended_visualizations,
data_freshness=None, # Add missing field
performance=performance,
cache_status=cache_status,
return _sanitize_chart_data_for_llm_context(
ChartData(
chart_id=chart.id,
chart_name=chart.slice_name or f"Chart {chart.id}",
chart_type=chart.viz_type or "unknown",
columns=columns,
data=data[: request.limit] if request.limit else data,
row_count=len(data),
total_rows=query_result.get("rowcount"),
summary=summary,
insights=insights,
data_quality={"completeness": data_completeness},
recommended_visualizations=recommended_visualizations,
data_freshness=None, # Add missing field
performance=performance,
cache_status=cache_status,
)
)
except (CommandException, SupersetException, ValueError) as data_error:
@@ -929,30 +966,32 @@ async def _query_from_form_data(
)
await ctx.report_progress(4, 4, "Building response")
return ChartData(
chart_id=0,
chart_name=chart_name,
chart_type=viz_type,
columns=columns,
data=data[: request.limit] if request.limit else data,
row_count=len(data),
total_rows=query_result.get("rowcount"),
summary=summary,
insights=["This is an unsaved chart queried from cached form_data."],
data_quality={
"completeness": 1.0
- (
sum(col.null_count for col in columns)
/ max(len(data) * len(columns), 1)
)
},
recommended_visualizations=[],
data_freshness=None,
performance=PerformanceMetadata(
query_duration_ms=0,
cache_status="fresh_query",
),
cache_status=cache_status,
return _sanitize_chart_data_for_llm_context(
ChartData(
chart_id=0,
chart_name=chart_name,
chart_type=viz_type,
columns=columns,
data=data[: request.limit] if request.limit else data,
row_count=len(data),
total_rows=query_result.get("rowcount"),
summary=summary,
insights=["This is an unsaved chart queried from cached form_data."],
data_quality={
"completeness": 1.0
- (
sum(col.null_count for col in columns)
/ max(len(data) * len(columns), 1)
)
},
recommended_visualizations=[],
data_freshness=None,
performance=PerformanceMetadata(
query_duration_ms=0,
cache_status="fresh_query",
),
cache_status=cache_status,
)
)
except (CommandException, SupersetException, ValueError) as e:
@@ -1001,24 +1040,26 @@ def _export_data_as_csv(
# Return as ChartData with CSV content in a special field
from superset.mcp_service.chart.schemas import ChartData
return ChartData(
chart_id=chart.id,
chart_name=chart.slice_name or f"Chart {chart.id}",
chart_type=chart.viz_type or "unknown",
columns=[], # Column names are embedded in CSV content
data=[], # CSV content is in csv_data field
row_count=len(data),
total_rows=len(data),
summary=f"CSV export of chart '{chart.slice_name}' with {len(data)} rows",
insights=[f"Data exported as CSV format ({len(csv_content)} characters)"],
data_quality={},
recommended_visualizations=[],
data_freshness=None,
performance=performance,
cache_status=cache_status,
# Store CSV content in data field as string for the response
csv_data=csv_content,
format="csv",
return _sanitize_chart_data_for_llm_context(
ChartData(
chart_id=chart.id,
chart_name=chart.slice_name or f"Chart {chart.id}",
chart_type=chart.viz_type or "unknown",
columns=[], # Column names are embedded in CSV content
data=[], # CSV content is in csv_data field
row_count=len(data),
total_rows=len(data),
summary=f"CSV export of chart '{chart.slice_name}' with {len(data)} rows",
insights=[f"Data exported as CSV format ({len(csv_content)} characters)"],
data_quality={},
recommended_visualizations=[],
data_freshness=None,
performance=performance,
cache_status=cache_status,
# Store CSV content in data field as string for the response
csv_data=csv_content,
format="csv",
)
)
@@ -1156,23 +1197,25 @@ def _create_excel_chart_data(
chart_name = chart.slice_name or f"Chart {chart.id}"
summary = f"Excel export of chart '{chart.slice_name}' with {len(data)} rows"
return ChartData(
chart_id=chart.id,
chart_name=chart_name,
chart_type=chart.viz_type or "unknown",
columns=[], # Column names are embedded in the Excel file
data=[],
row_count=len(data),
total_rows=len(data),
summary=summary,
insights=["Data exported as Excel format (base64 encoded)"],
data_quality={},
recommended_visualizations=[],
data_freshness=None,
performance=performance,
cache_status=cache_status,
excel_data=excel_b64,
format="excel",
return _sanitize_chart_data_for_llm_context(
ChartData(
chart_id=chart.id,
chart_name=chart_name,
chart_type=chart.viz_type or "unknown",
columns=[], # Column names are embedded in the Excel file
data=[],
row_count=len(data),
total_rows=len(data),
summary=summary,
insights=["Data exported as Excel format (base64 encoded)"],
data_quality={},
recommended_visualizations=[],
data_freshness=None,
performance=performance,
cache_status=cache_status,
excel_data=excel_b64,
format="excel",
)
)
@@ -1189,21 +1232,23 @@ def _create_excel_chart_data_xlsxwriter(
chart_name = chart.slice_name or f"Chart {chart.id}"
summary = f"Excel export of chart '{chart.slice_name}' with {len(data)} rows"
return ChartData(
chart_id=chart.id,
chart_name=chart_name,
chart_type=chart.viz_type or "unknown",
columns=[], # Column names are embedded in the Excel file
data=[],
row_count=len(data),
total_rows=len(data),
summary=summary,
insights=["Data exported as Excel format (base64 encoded, xlsxwriter)"],
data_quality={},
recommended_visualizations=[],
data_freshness=None,
performance=performance,
cache_status=cache_status,
excel_data=excel_b64,
format="excel",
return _sanitize_chart_data_for_llm_context(
ChartData(
chart_id=chart.id,
chart_name=chart_name,
chart_type=chart.viz_type or "unknown",
columns=[], # Column names are embedded in the Excel file
data=[],
row_count=len(data),
total_rows=len(data),
summary=summary,
insights=["Data exported as Excel format (base64 encoded, xlsxwriter)"],
data_quality={},
recommended_visualizations=[],
data_freshness=None,
performance=performance,
cache_status=cache_status,
excel_data=excel_b64,
format="excel",
)
)