diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py
index decd3c46921..1fdb4f43ab6 100644
--- a/superset/mcp_service/chart/schemas.py
+++ b/superset/mcp_service/chart/schemas.py
@@ -57,6 +57,10 @@ from superset.mcp_service.system.schemas import (
PaginationInfo,
TagInfo,
)
+from superset.mcp_service.utils import (
+ escape_llm_context_delimiters,
+ sanitize_for_llm_context,
+)
from superset.mcp_service.utils.sanitization import (
sanitize_filter_value,
sanitize_user_input,
@@ -188,6 +192,12 @@ class ChartError(BaseModel):
)
model_config = ConfigDict(ser_json_timedelta="iso8601")
+ @field_validator("error")
+ @classmethod
+ def sanitize_error_for_llm_context(cls, value: str) -> str:
+ """Wrap error text before it is exposed to LLM context."""
+ return sanitize_for_llm_context(value, field_path=("error",))
+
class ChartCapabilities(BaseModel):
"""Describes what the chart can do for LLM understanding."""
@@ -375,6 +385,83 @@ def extract_filters_from_form_data(
)
+CHART_FORM_DATA_EXCLUDED_FIELD_NAMES = frozenset(
+ {
+ "all_columns",
+ "columns",
+ "datasource",
+ "datasource_id",
+ "datasource_name",
+ "datasource_type",
+ "entity",
+ "form_data_key",
+ "groupby",
+ "metric",
+ "metrics",
+ "series",
+ "slice_id",
+ "viz_type",
+ "x",
+ "y",
+ "size",
+ }
+)
+
+
+def sanitize_chart_info_for_llm_context(chart_info: ChartInfo) -> ChartInfo:
+ """Wrap chart read-path descriptive fields before LLM exposure."""
+ payload = chart_info.model_dump(mode="python")
+
+ for field_name in (
+ "slice_name",
+ "description",
+ "certified_by",
+ "certification_details",
+ ):
+ payload[field_name] = sanitize_for_llm_context(
+ payload.get(field_name),
+ field_path=(field_name,),
+ )
+
+ payload["datasource_name"] = escape_llm_context_delimiters(
+ payload.get("datasource_name")
+ )
+
+ if payload.get("filters") is not None:
+ payload["filters"] = sanitize_for_llm_context(
+ payload["filters"],
+ field_path=("filters",),
+ excluded_field_names=frozenset(),
+ )
+
+ if payload.get("form_data") is not None:
+ payload["form_data"] = sanitize_for_llm_context(
+ payload["form_data"],
+ field_path=("form_data",),
+ excluded_field_names=(
+ CHART_FORM_DATA_EXCLUDED_FIELD_NAMES
+ | frozenset({"cache_key", "database", "database_name", "schema"})
+ ),
+ )
+
+ payload["tags"] = [
+ {
+ **tag,
+ "name": sanitize_for_llm_context(
+ tag.get("name"),
+ field_path=("tags", str(index), "name"),
+ ),
+ "description": sanitize_for_llm_context(
+ tag.get("description"),
+ field_path=("tags", str(index), "description"),
+ ),
+ }
+ for index, tag in enumerate(payload.get("tags", []))
+ ]
+
+ return ChartInfo.model_validate(payload)
+
+
def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
if not chart:
return None
@@ -401,30 +488,38 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
# Extract structured filter information
filters_info = extract_filters_from_form_data(chart_form_data)
- return ChartInfo(
- id=chart_id,
- slice_name=getattr(chart, "slice_name", None),
- viz_type=getattr(chart, "viz_type", None),
- datasource_name=getattr(chart, "datasource_name", None),
- datasource_type=getattr(chart, "datasource_type", None),
- url=chart_url,
- description=getattr(chart, "description", None),
- certified_by=getattr(chart, "certified_by", None),
- certification_details=getattr(chart, "certification_details", None),
- cache_timeout=getattr(chart, "cache_timeout", None),
- form_data=chart_form_data,
- filters=filters_info,
- changed_on=getattr(chart, "changed_on", None),
- changed_on_humanized=_humanize_timestamp(getattr(chart, "changed_on", None)),
- created_on=getattr(chart, "created_on", None),
- created_on_humanized=_humanize_timestamp(getattr(chart, "created_on", None)),
- uuid=str(getattr(chart, "uuid", "")) if getattr(chart, "uuid", None) else None,
- tags=[
- TagInfo.model_validate(tag, from_attributes=True)
- for tag in getattr(chart, "tags", [])
- ]
- if getattr(chart, "tags", None)
- else [],
+ return sanitize_chart_info_for_llm_context(
+ ChartInfo(
+ id=chart_id,
+ slice_name=getattr(chart, "slice_name", None),
+ viz_type=getattr(chart, "viz_type", None),
+ datasource_name=getattr(chart, "datasource_name", None),
+ datasource_type=getattr(chart, "datasource_type", None),
+ url=chart_url,
+ description=getattr(chart, "description", None),
+ certified_by=getattr(chart, "certified_by", None),
+ certification_details=getattr(chart, "certification_details", None),
+ cache_timeout=getattr(chart, "cache_timeout", None),
+ form_data=chart_form_data,
+ filters=filters_info,
+ changed_on=getattr(chart, "changed_on", None),
+ changed_on_humanized=_humanize_timestamp(
+ getattr(chart, "changed_on", None)
+ ),
+ created_on=getattr(chart, "created_on", None),
+ created_on_humanized=_humanize_timestamp(
+ getattr(chart, "created_on", None)
+ ),
+ uuid=str(getattr(chart, "uuid", ""))
+ if getattr(chart, "uuid", None)
+ else None,
+ tags=[
+ TagInfo.model_validate(tag, from_attributes=True)
+ for tag in getattr(chart, "tags", [])
+ ]
+ if getattr(chart, "tags", None)
+ else [],
+ )
)
diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py
index f3f7efa5a83..646ac4d4c2a 100644
--- a/superset/mcp_service/chart/tool/generate_chart.py
+++ b/superset/mcp_service/chart/tool/generate_chart.py
@@ -41,11 +41,13 @@ from superset.mcp_service.chart.chart_utils import (
)
from superset.mcp_service.chart.schemas import (
AccessibilityMetadata,
+ CHART_FORM_DATA_EXCLUDED_FIELD_NAMES,
ChartError,
GenerateChartRequest,
GenerateChartResponse,
PerformanceMetadata,
)
+from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
@@ -55,6 +57,22 @@ from superset.utils import json
logger = logging.getLogger(__name__)
+GENERATE_CHART_FORM_DATA_EXCLUDED_FIELD_NAMES = (
+ CHART_FORM_DATA_EXCLUDED_FIELD_NAMES
+ | frozenset({"cache_key", "database", "database_name", "schema"})
+)
+
+
+def _sanitize_generate_chart_form_data_for_llm_context(
+ form_data: dict[str, Any],
+) -> dict[str, Any]:
+ """Wrap generated-chart form_data before returning it to LLM clients."""
+ return sanitize_for_llm_context(
+ form_data,
+ field_path=("form_data",),
+ excluded_field_names=GENERATE_CHART_FORM_DATA_EXCLUDED_FIELD_NAMES,
+ )
+
@dataclass
class CompileResult:
@@ -476,7 +494,11 @@ async def generate_chart( # noqa: C901
{
"chart": None,
"error": error.model_dump(),
- "form_data": form_data,
+ "form_data": (
+ _sanitize_generate_chart_form_data_for_llm_context(
+ form_data
+ )
+ ),
"performance": {
"query_duration_ms": execution_time,
"cache_status": "error",
@@ -603,7 +625,11 @@ async def generate_chart( # noqa: C901
{
"chart": None,
"error": error.model_dump(),
- "form_data": form_data,
+ "form_data": (
+ _sanitize_generate_chart_form_data_for_llm_context(
+ form_data
+ )
+ ),
"performance": {
"query_duration_ms": execution_time,
"cache_status": "error",
@@ -799,7 +825,7 @@ async def generate_chart( # noqa: C901
"semantics": semantics.model_dump() if semantics else None,
"explore_url": explore_url,
# Form data fields - REQUIRED for chatbot/external client rendering
- "form_data": form_data,
+ "form_data": _sanitize_generate_chart_form_data_for_llm_context(form_data),
"form_data_key": form_data_key,
"api_endpoints": {
"data": f"{get_superset_base_url()}/api/v1/chart/{chart.id}/data/"
diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py
index 9a564395000..7421ab0cdc0 100644
--- a/superset/mcp_service/chart/tool/get_chart_data.py
+++ b/superset/mcp_service/chart/tool/get_chart_data.py
@@ -46,6 +46,7 @@ from superset.mcp_service.chart.schemas import (
GetChartDataRequest,
PerformanceMetadata,
)
+from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
@@ -56,6 +57,40 @@ from superset.utils.core import merge_extra_filters
logger = logging.getLogger(__name__)
+def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
+ """Wrap chart data read-path descriptive fields before LLM exposure."""
+ payload = chart_data.model_dump(mode="python")
+
+ for field_name in ("chart_name", "summary", "csv_data"):
+ payload[field_name] = sanitize_for_llm_context(
+ payload.get(field_name),
+ field_path=(field_name,),
+ )
+
+ payload["insights"] = sanitize_for_llm_context(
+ payload.get("insights", []),
+ field_path=("insights",),
+ )
+ payload["data"] = sanitize_for_llm_context(
+ payload.get("data", []),
+ field_path=("data",),
+ excluded_field_names=frozenset(),
+ )
+ payload["columns"] = [
+ {
+ **column,
+ "sample_values": sanitize_for_llm_context(
+ column.get("sample_values", []),
+ field_path=("columns", str(index), "sample_values"),
+ excluded_field_names=frozenset(),
+ ),
+ }
+ for index, column in enumerate(payload.get("columns", []))
+ ]
+
+ return ChartData.model_validate(payload)
+
+
def _apply_extra_form_data(
form_data: dict[str, Any], extra_form_data: dict[str, Any] | None
) -> None:
@@ -745,21 +780,23 @@ async def get_chart_data( # noqa: C901
)
# Default JSON format
- return ChartData(
- chart_id=chart.id,
- chart_name=chart.slice_name or f"Chart {chart.id}",
- chart_type=chart.viz_type or "unknown",
- columns=columns,
- data=data[: request.limit] if request.limit else data,
- row_count=len(data),
- total_rows=query_result.get("rowcount"),
- summary=summary,
- insights=insights,
- data_quality={"completeness": data_completeness},
- recommended_visualizations=recommended_visualizations,
- data_freshness=None, # Add missing field
- performance=performance,
- cache_status=cache_status,
+ return _sanitize_chart_data_for_llm_context(
+ ChartData(
+ chart_id=chart.id,
+ chart_name=chart.slice_name or f"Chart {chart.id}",
+ chart_type=chart.viz_type or "unknown",
+ columns=columns,
+ data=data[: request.limit] if request.limit else data,
+ row_count=len(data),
+ total_rows=query_result.get("rowcount"),
+ summary=summary,
+ insights=insights,
+ data_quality={"completeness": data_completeness},
+ recommended_visualizations=recommended_visualizations,
+            data_freshness=None,  # freshness metadata is not computed on this path
+ performance=performance,
+ cache_status=cache_status,
+ )
)
except (CommandException, SupersetException, ValueError) as data_error:
@@ -929,30 +966,32 @@ async def _query_from_form_data(
)
await ctx.report_progress(4, 4, "Building response")
- return ChartData(
- chart_id=0,
- chart_name=chart_name,
- chart_type=viz_type,
- columns=columns,
- data=data[: request.limit] if request.limit else data,
- row_count=len(data),
- total_rows=query_result.get("rowcount"),
- summary=summary,
- insights=["This is an unsaved chart queried from cached form_data."],
- data_quality={
- "completeness": 1.0
- - (
- sum(col.null_count for col in columns)
- / max(len(data) * len(columns), 1)
- )
- },
- recommended_visualizations=[],
- data_freshness=None,
- performance=PerformanceMetadata(
- query_duration_ms=0,
- cache_status="fresh_query",
- ),
- cache_status=cache_status,
+ return _sanitize_chart_data_for_llm_context(
+ ChartData(
+ chart_id=0,
+ chart_name=chart_name,
+ chart_type=viz_type,
+ columns=columns,
+ data=data[: request.limit] if request.limit else data,
+ row_count=len(data),
+ total_rows=query_result.get("rowcount"),
+ summary=summary,
+ insights=["This is an unsaved chart queried from cached form_data."],
+ data_quality={
+ "completeness": 1.0
+ - (
+ sum(col.null_count for col in columns)
+ / max(len(data) * len(columns), 1)
+ )
+ },
+ recommended_visualizations=[],
+ data_freshness=None,
+ performance=PerformanceMetadata(
+ query_duration_ms=0,
+ cache_status="fresh_query",
+ ),
+ cache_status=cache_status,
+ )
)
except (CommandException, SupersetException, ValueError) as e:
@@ -1001,24 +1040,26 @@ def _export_data_as_csv(
# Return as ChartData with CSV content in a special field
from superset.mcp_service.chart.schemas import ChartData
- return ChartData(
- chart_id=chart.id,
- chart_name=chart.slice_name or f"Chart {chart.id}",
- chart_type=chart.viz_type or "unknown",
- columns=[], # Column names are embedded in CSV content
- data=[], # CSV content is in csv_data field
- row_count=len(data),
- total_rows=len(data),
- summary=f"CSV export of chart '{chart.slice_name}' with {len(data)} rows",
- insights=[f"Data exported as CSV format ({len(csv_content)} characters)"],
- data_quality={},
- recommended_visualizations=[],
- data_freshness=None,
- performance=performance,
- cache_status=cache_status,
- # Store CSV content in data field as string for the response
- csv_data=csv_content,
- format="csv",
+ return _sanitize_chart_data_for_llm_context(
+ ChartData(
+ chart_id=chart.id,
+ chart_name=chart.slice_name or f"Chart {chart.id}",
+ chart_type=chart.viz_type or "unknown",
+ columns=[], # Column names are embedded in CSV content
+ data=[], # CSV content is in csv_data field
+ row_count=len(data),
+ total_rows=len(data),
+ summary=f"CSV export of chart '{chart.slice_name}' with {len(data)} rows",
+ insights=[f"Data exported as CSV format ({len(csv_content)} characters)"],
+ data_quality={},
+ recommended_visualizations=[],
+ data_freshness=None,
+ performance=performance,
+ cache_status=cache_status,
+            # CSV content is returned via the csv_data field (data stays empty)
+ csv_data=csv_content,
+ format="csv",
+ )
)
@@ -1156,23 +1197,25 @@ def _create_excel_chart_data(
chart_name = chart.slice_name or f"Chart {chart.id}"
summary = f"Excel export of chart '{chart.slice_name}' with {len(data)} rows"
- return ChartData(
- chart_id=chart.id,
- chart_name=chart_name,
- chart_type=chart.viz_type or "unknown",
- columns=[], # Column names are embedded in the Excel file
- data=[],
- row_count=len(data),
- total_rows=len(data),
- summary=summary,
- insights=["Data exported as Excel format (base64 encoded)"],
- data_quality={},
- recommended_visualizations=[],
- data_freshness=None,
- performance=performance,
- cache_status=cache_status,
- excel_data=excel_b64,
- format="excel",
+ return _sanitize_chart_data_for_llm_context(
+ ChartData(
+ chart_id=chart.id,
+ chart_name=chart_name,
+ chart_type=chart.viz_type or "unknown",
+ columns=[], # Column names are embedded in the Excel file
+ data=[],
+ row_count=len(data),
+ total_rows=len(data),
+ summary=summary,
+ insights=["Data exported as Excel format (base64 encoded)"],
+ data_quality={},
+ recommended_visualizations=[],
+ data_freshness=None,
+ performance=performance,
+ cache_status=cache_status,
+ excel_data=excel_b64,
+ format="excel",
+ )
)
@@ -1189,21 +1232,23 @@ def _create_excel_chart_data_xlsxwriter(
chart_name = chart.slice_name or f"Chart {chart.id}"
summary = f"Excel export of chart '{chart.slice_name}' with {len(data)} rows"
- return ChartData(
- chart_id=chart.id,
- chart_name=chart_name,
- chart_type=chart.viz_type or "unknown",
- columns=[], # Column names are embedded in the Excel file
- data=[],
- row_count=len(data),
- total_rows=len(data),
- summary=summary,
- insights=["Data exported as Excel format (base64 encoded, xlsxwriter)"],
- data_quality={},
- recommended_visualizations=[],
- data_freshness=None,
- performance=performance,
- cache_status=cache_status,
- excel_data=excel_b64,
- format="excel",
+ return _sanitize_chart_data_for_llm_context(
+ ChartData(
+ chart_id=chart.id,
+ chart_name=chart_name,
+ chart_type=chart.viz_type or "unknown",
+ columns=[], # Column names are embedded in the Excel file
+ data=[],
+ row_count=len(data),
+ total_rows=len(data),
+ summary=summary,
+ insights=["Data exported as Excel format (base64 encoded, xlsxwriter)"],
+ data_quality={},
+ recommended_visualizations=[],
+ data_freshness=None,
+ performance=performance,
+ cache_status=cache_status,
+ excel_data=excel_b64,
+ format="excel",
+ )
)
diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py
index 68c13e4bafb..79ae03da41c 100644
--- a/superset/mcp_service/chart/tool/get_chart_info.py
+++ b/superset/mcp_service/chart/tool/get_chart_info.py
@@ -29,10 +29,12 @@ from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import get_cached_form_data
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
+ CHART_FORM_DATA_EXCLUDED_FIELD_NAMES,
ChartError,
ChartInfo,
extract_filters_from_form_data,
GetChartInfoRequest,
+ sanitize_chart_info_for_llm_context,
serialize_chart_object,
)
from superset.mcp_service.mcp_core import ModelGetInfoCore
@@ -40,6 +42,7 @@ from superset.mcp_service.privacy import (
redact_chart_data_model_fields,
user_can_view_data_model_metadata,
)
+from superset.mcp_service.utils import sanitize_for_llm_context
logger = logging.getLogger(__name__)
@@ -67,17 +70,25 @@ def _build_unsaved_chart_info(form_data_key: str) -> ChartInfo | ChartError:
error="Cached form_data is not a valid JSON object.",
error_type="ParseError",
)
- return ChartInfo(
- viz_type=form_data.get("viz_type"),
- datasource_name=form_data.get("datasource_name"),
- datasource_type=form_data.get("datasource_type"),
- filters=extract_filters_from_form_data(form_data),
- form_data=form_data,
- form_data_key=form_data_key,
- is_unsaved_state=True,
+ return sanitize_chart_info_for_llm_context(
+ ChartInfo(
+ viz_type=form_data.get("viz_type"),
+ datasource_name=form_data.get("datasource_name"),
+ datasource_type=form_data.get("datasource_type"),
+ filters=extract_filters_from_form_data(form_data),
+ form_data=form_data,
+ form_data_key=form_data_key,
+ is_unsaved_state=True,
+ )
)
+FORM_DATA_OVERRIDE_EXCLUDED_FIELD_NAMES = (
+ CHART_FORM_DATA_EXCLUDED_FIELD_NAMES
+ | frozenset({"cache_key", "database", "database_name", "schema"})
+)
+
+
def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None:
"""Override a ChartInfo's form_data with cached unsaved state."""
from superset.utils import json as utils_json
@@ -106,6 +117,23 @@ def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None
"The cache may have expired. Using saved chart configuration."
)
+ payload = result.model_dump(mode="python")
+ if payload.get("filters") is not None:
+ payload["filters"] = sanitize_for_llm_context(
+ payload["filters"],
+ field_path=("filters",),
+ excluded_field_names=frozenset(),
+ )
+ if payload.get("form_data") is not None:
+ payload["form_data"] = sanitize_for_llm_context(
+ payload["form_data"],
+ field_path=("form_data",),
+ excluded_field_names=FORM_DATA_OVERRIDE_EXCLUDED_FIELD_NAMES,
+ )
+ sanitized = ChartInfo.model_validate(payload)
+ result.filters = sanitized.filters
+ result.form_data = sanitized.form_data
+
@tool(
tags=["discovery"],
diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py
index 7215170f8a6..1fb3740f116 100644
--- a/superset/mcp_service/chart/tool/get_chart_preview.py
+++ b/superset/mcp_service/chart/tool/get_chart_preview.py
@@ -47,6 +47,7 @@ from superset.mcp_service.chart.schemas import (
URLPreview,
VegaLitePreview,
)
+from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
@@ -56,6 +57,78 @@ from superset.mcp_service.utils.url_utils import get_superset_base_url
logger = logging.getLogger(__name__)
+def _sanitize_preview_content_for_llm_context(content: dict[str, Any]) -> None:
+ """Wrap string-bearing preview content while preserving routing fields."""
+ content_type = content.get("type")
+
+ if content_type == "ascii":
+ content["ascii_content"] = sanitize_for_llm_context(
+ content.get("ascii_content"),
+ field_path=("content", "ascii_content"),
+ )
+ return
+
+ if content_type == "table":
+ content["table_data"] = sanitize_for_llm_context(
+ content.get("table_data"),
+ field_path=("content", "table_data"),
+ )
+ return
+
+ if content_type == "interactive":
+ content["html_content"] = sanitize_for_llm_context(
+ content.get("html_content"),
+ field_path=("content", "html_content"),
+ )
+ return
+
+ if content_type != "vega_lite":
+ return
+
+ specification = content.get("specification")
+ if not isinstance(specification, dict):
+ return
+
+ if "description" in specification:
+ specification["description"] = sanitize_for_llm_context(
+ specification.get("description"),
+ field_path=("content", "specification", "description"),
+ )
+
+ data = specification.get("data")
+ if isinstance(data, dict) and (values := data.get("values")) is not None:
+ data["values"] = sanitize_for_llm_context(
+ values,
+ field_path=("content", "specification", "data", "values"),
+ excluded_field_names=frozenset(),
+ )
+
+
+def _sanitize_chart_preview_for_llm_context(
+ chart_preview: ChartPreview,
+) -> ChartPreview:
+ """Wrap chart preview read-path descriptive fields before LLM exposure."""
+ payload = chart_preview.model_dump(mode="python")
+
+ for field_name in ("chart_name", "chart_description", "ascii_chart", "table_data"):
+ payload[field_name] = sanitize_for_llm_context(
+ payload.get(field_name),
+ field_path=(field_name,),
+ )
+
+ if accessibility := payload.get("accessibility"):
+ accessibility["alt_text"] = sanitize_for_llm_context(
+ accessibility.get("alt_text"),
+ field_path=("accessibility", "alt_text"),
+ )
+
+ content = payload.get("content")
+ if isinstance(content, dict):
+ _sanitize_preview_content_for_llm_context(content)
+
+ return ChartPreview.model_validate(payload)
+
+
class ChartLike(Protocol):
"""Protocol for chart-like objects with required attributes for preview."""
@@ -1296,7 +1369,7 @@ async def _get_chart_preview_internal( # noqa: C901
result.width = content.width
result.height = content.height
- return result
+ return _sanitize_chart_preview_for_llm_context(result)
except (
CommandException,
diff --git a/superset/mcp_service/chart/tool/get_chart_sql.py b/superset/mcp_service/chart/tool/get_chart_sql.py
index cb07a9c3637..d586817d261 100644
--- a/superset/mcp_service/chart/tool/get_chart_sql.py
+++ b/superset/mcp_service/chart/tool/get_chart_sql.py
@@ -38,10 +38,24 @@ from superset.mcp_service.chart.schemas import (
ChartSql,
GetChartSqlRequest,
)
+from superset.mcp_service.utils import sanitize_for_llm_context
logger = logging.getLogger(__name__)
+def _sanitize_chart_sql_for_llm_context(chart_sql: ChartSql) -> ChartSql:
+ """Wrap chart SQL read-path descriptive fields before LLM exposure."""
+ payload = chart_sql.model_dump(mode="python")
+
+ for field_name in ("chart_name", "datasource_name", "sql", "error"):
+ payload[field_name] = sanitize_for_llm_context(
+ payload.get(field_name),
+ field_path=(field_name,),
+ )
+
+ return ChartSql.model_validate(payload)
+
+
def _get_cached_form_data(form_data_key: str) -> str | None:
"""Retrieve form_data from cache using form_data_key.
@@ -423,13 +437,15 @@ def _extract_sql_from_result(
error_type="QueryGenerationFailed",
)
- return ChartSql(
- chart_id=chart_id,
- chart_name=chart_name,
- sql="\n\n".join(sql_parts),
- language=language,
- datasource_name=datasource_name,
- error="; ".join(errors) if errors else None,
+ return _sanitize_chart_sql_for_llm_context(
+ ChartSql(
+ chart_id=chart_id,
+ chart_name=chart_name,
+ sql="\n\n".join(sql_parts),
+ language=language,
+ datasource_name=datasource_name,
+ error="; ".join(errors) if errors else None,
+ )
)
diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py
index 268df90f87f..8a92d585896 100644
--- a/superset/mcp_service/dashboard/schemas.py
+++ b/superset/mcp_service/dashboard/schemas.py
@@ -100,6 +100,10 @@ from superset.mcp_service.system.schemas import (
RoleInfo,
TagInfo,
)
+from superset.mcp_service.utils import (
+ escape_llm_context_delimiters,
+ sanitize_for_llm_context,
+)
from superset.mcp_service.utils.sanitization import (
sanitize_user_input,
sanitize_user_input_with_changes,
@@ -116,6 +120,12 @@ class DashboardError(BaseModel):
model_config = ConfigDict(ser_json_timedelta="iso8601")
+ @field_validator("error")
+ @classmethod
+ def sanitize_error_for_llm_context(cls, value: str) -> str:
+ """Wrap error text before it is exposed to LLM context."""
+ return sanitize_for_llm_context(value, field_path=("error",))
+
@classmethod
def create(cls, error: str, error_type: str) -> "DashboardError":
"""Create a standardized DashboardError with timestamp."""
@@ -748,6 +758,83 @@ def redact_filter_state_data_model_metadata(
}
+def _sanitize_dashboard_info_for_llm_context(
+ dashboard_info: DashboardInfo,
+) -> DashboardInfo:
+ """Wrap dashboard read-path descriptive fields before LLM exposure."""
+ payload = dashboard_info.model_dump(mode="python")
+
+ for field_name in (
+ "dashboard_title",
+ "description",
+ "css",
+ "certified_by",
+ "certification_details",
+ ):
+ payload[field_name] = sanitize_for_llm_context(
+ payload.get(field_name),
+ field_path=(field_name,),
+ )
+
+ payload["native_filters"] = [
+ {
+ **native_filter,
+ "name": sanitize_for_llm_context(
+ native_filter.get("name"),
+ field_path=("native_filters", str(index), "name"),
+ ),
+ "targets": sanitize_for_llm_context(
+ native_filter.get("targets", []),
+ field_path=("native_filters", str(index), "targets"),
+ excluded_field_names=frozenset(),
+ ),
+ }
+ for index, native_filter in enumerate(payload.get("native_filters", []))
+ ]
+
+ payload["charts"] = [
+ {
+ **chart,
+ "slice_name": sanitize_for_llm_context(
+ chart.get("slice_name"),
+ field_path=("charts", str(index), "slice_name"),
+ ),
+ "description": sanitize_for_llm_context(
+ chart.get("description"),
+ field_path=("charts", str(index), "description"),
+ ),
+ "datasource_name": escape_llm_context_delimiters(
+ chart.get("datasource_name"),
+ ),
+ }
+ for index, chart in enumerate(payload.get("charts", []))
+ ]
+
+ if payload.get("filter_state") is not None:
+ payload["filter_state"] = sanitize_for_llm_context(
+ payload["filter_state"],
+ field_path=("filter_state",),
+ excluded_field_names=frozenset(),
+ )
+
+ payload["tags"] = [
+ {
+ **tag,
+ "name": sanitize_for_llm_context(
+ tag.get("name"),
+ field_path=("tags", str(index), "name"),
+ ),
+ "description": sanitize_for_llm_context(
+ tag.get("description"),
+ field_path=("tags", str(index), "description"),
+ ),
+ }
+ for index, tag in enumerate(payload.get("tags", []))
+ ]
+
+ return DashboardInfo.model_validate(payload)
+
+
def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
from superset.mcp_service.utils.url_utils import get_superset_base_url
@@ -758,51 +845,54 @@ def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
json_metadata_str = getattr(dashboard, "json_metadata", None)
position_json_str = getattr(dashboard, "position_json", None)
- return DashboardInfo(
- id=dashboard.id,
- dashboard_title=dashboard.dashboard_title or "Untitled",
- slug=dashboard.slug or "",
- description=dashboard.description,
- css=dashboard.css,
- certified_by=dashboard.certified_by,
- certification_details=dashboard.certification_details,
- published=dashboard.published,
- is_managed_externally=dashboard.is_managed_externally,
- external_url=dashboard.external_url,
- created_on=dashboard.created_on,
- changed_on=dashboard.changed_on,
- uuid=str(dashboard.uuid) if dashboard.uuid else None,
- url=absolute_url,
- created_on_humanized=dashboard.created_on_humanized,
- changed_on_humanized=dashboard.changed_on_humanized,
- chart_count=len(dashboard.slices) if dashboard.slices else 0,
- native_filters=_extract_native_filters(
- json_metadata_str,
- include_data_model_metadata=include_data_model_metadata,
- ),
- cross_filters_enabled=_extract_cross_filters_enabled(json_metadata_str),
- omitted_fields=_build_omitted_fields(
- json_metadata_str,
- position_json_str,
- ),
- tags=[
- TagInfo.model_validate(tag, from_attributes=True) for tag in dashboard.tags
- ]
- if dashboard.tags
- else [],
- charts=[
- summary
- for chart in dashboard.slices
- if (
- summary := serialize_chart_summary(
- chart,
- include_data_model_metadata=include_data_model_metadata,
+ return _sanitize_dashboard_info_for_llm_context(
+ DashboardInfo(
+ id=dashboard.id,
+ dashboard_title=dashboard.dashboard_title or "Untitled",
+ slug=dashboard.slug or "",
+ description=dashboard.description,
+ css=dashboard.css,
+ certified_by=dashboard.certified_by,
+ certification_details=dashboard.certification_details,
+ published=dashboard.published,
+ is_managed_externally=dashboard.is_managed_externally,
+ external_url=dashboard.external_url,
+ created_on=dashboard.created_on,
+ changed_on=dashboard.changed_on,
+ uuid=str(dashboard.uuid) if dashboard.uuid else None,
+ url=absolute_url,
+ created_on_humanized=dashboard.created_on_humanized,
+ changed_on_humanized=dashboard.changed_on_humanized,
+ chart_count=len(dashboard.slices) if dashboard.slices else 0,
+ native_filters=_extract_native_filters(
+ json_metadata_str,
+ include_data_model_metadata=include_data_model_metadata,
+ ),
+ cross_filters_enabled=_extract_cross_filters_enabled(json_metadata_str),
+ omitted_fields=_build_omitted_fields(
+ json_metadata_str,
+ position_json_str,
+ ),
+ tags=[
+ TagInfo.model_validate(tag, from_attributes=True)
+ for tag in dashboard.tags
+ ]
+ if dashboard.tags
+ else [],
+ charts=[
+ summary
+ for chart in dashboard.slices
+ if (
+ summary := serialize_chart_summary(
+ chart,
+ include_data_model_metadata=include_data_model_metadata,
+ )
)
- )
- is not None
- ]
- if dashboard.slices
- else [],
+ is not None
+ ]
+ if dashboard.slices
+ else [],
+ )
)
@@ -831,53 +921,55 @@ def serialize_dashboard_object(dashboard: Any) -> DashboardInfo:
position_json_str = getattr(dashboard, "position_json", None)
include_data_model_metadata = user_can_view_data_model_metadata()
- return DashboardInfo(
- id=dashboard_id,
- dashboard_title=getattr(dashboard, "dashboard_title", None),
- slug=slug or "",
- url=dashboard_url,
- published=getattr(dashboard, "published", None),
- changed_on=getattr(dashboard, "changed_on", None),
- changed_on_humanized=_humanize_timestamp(
- getattr(dashboard, "changed_on", None)
- ),
- created_on=getattr(dashboard, "created_on", None),
- created_on_humanized=_humanize_timestamp(
- getattr(dashboard, "created_on", None)
- ),
- description=getattr(dashboard, "description", None),
- css=getattr(dashboard, "css", None),
- certified_by=getattr(dashboard, "certified_by", None),
- certification_details=getattr(dashboard, "certification_details", None),
- native_filters=_extract_native_filters(
- json_metadata_str,
- include_data_model_metadata=include_data_model_metadata,
- ),
- cross_filters_enabled=_extract_cross_filters_enabled(json_metadata_str),
- omitted_fields=_build_omitted_fields(json_metadata_str, position_json_str),
- is_managed_externally=getattr(dashboard, "is_managed_externally", None),
- external_url=getattr(dashboard, "external_url", None),
- uuid=str(getattr(dashboard, "uuid", ""))
- if getattr(dashboard, "uuid", None)
- else None,
- chart_count=len(getattr(dashboard, "slices", [])),
- tags=[
- TagInfo.model_validate(tag, from_attributes=True)
- for tag in getattr(dashboard, "tags", [])
- ]
- if getattr(dashboard, "tags", None)
- else [],
- charts=[
- summary
- for chart in getattr(dashboard, "slices", [])
- if (
- summary := serialize_chart_summary(
- chart,
- include_data_model_metadata=include_data_model_metadata,
+ return _sanitize_dashboard_info_for_llm_context(
+ DashboardInfo(
+ id=dashboard_id,
+ dashboard_title=getattr(dashboard, "dashboard_title", None),
+ slug=slug or "",
+ url=dashboard_url,
+ published=getattr(dashboard, "published", None),
+ changed_on=getattr(dashboard, "changed_on", None),
+ changed_on_humanized=_humanize_timestamp(
+ getattr(dashboard, "changed_on", None)
+ ),
+ created_on=getattr(dashboard, "created_on", None),
+ created_on_humanized=_humanize_timestamp(
+ getattr(dashboard, "created_on", None)
+ ),
+ description=getattr(dashboard, "description", None),
+ css=getattr(dashboard, "css", None),
+ certified_by=getattr(dashboard, "certified_by", None),
+ certification_details=getattr(dashboard, "certification_details", None),
+ native_filters=_extract_native_filters(
+ json_metadata_str,
+ include_data_model_metadata=include_data_model_metadata,
+ ),
+ cross_filters_enabled=_extract_cross_filters_enabled(json_metadata_str),
+ omitted_fields=_build_omitted_fields(json_metadata_str, position_json_str),
+ is_managed_externally=getattr(dashboard, "is_managed_externally", None),
+ external_url=getattr(dashboard, "external_url", None),
+ uuid=str(getattr(dashboard, "uuid", ""))
+ if getattr(dashboard, "uuid", None)
+ else None,
+ chart_count=len(getattr(dashboard, "slices", [])),
+ tags=[
+ TagInfo.model_validate(tag, from_attributes=True)
+ for tag in getattr(dashboard, "tags", [])
+ ]
+ if getattr(dashboard, "tags", None)
+ else [],
+ charts=[
+ summary
+ for chart in getattr(dashboard, "slices", [])
+ if (
+ summary := serialize_chart_summary(
+ chart,
+ include_data_model_metadata=include_data_model_metadata,
+ )
)
- )
- is not None
- ]
- if getattr(dashboard, "slices", None)
- else [],
+ is not None
+ ]
+ if getattr(dashboard, "slices", None)
+ else [],
+ )
)
diff --git a/superset/mcp_service/dashboard/tool/get_dashboard_info.py b/superset/mcp_service/dashboard/tool/get_dashboard_info.py
index 9c52bbc7f2d..6acc51b6bc0 100644
--- a/superset/mcp_service/dashboard/tool/get_dashboard_info.py
+++ b/superset/mcp_service/dashboard/tool/get_dashboard_info.py
@@ -26,12 +26,14 @@ import logging
from datetime import datetime, timezone
from fastmcp import Context
+from flask import g, has_request_context
from sqlalchemy.orm import subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError
from superset.dashboards.permalink.types import DashboardPermalinkValue
from superset.extensions import event_logger
+from superset.mcp_service.auth import load_user_with_relationships
from superset.mcp_service.dashboard.schemas import (
dashboard_serializer,
DashboardError,
@@ -41,10 +43,51 @@ from superset.mcp_service.dashboard.schemas import (
)
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.privacy import user_can_view_data_model_metadata
+from superset.mcp_service.utils import sanitize_for_llm_context
logger = logging.getLogger(__name__)
+def _refresh_request_user_for_permalink_access() -> None:
+ """Reload the request user before permalink access checks."""
+ if not has_request_context() or not getattr(g, "user", None):
+ return
+
+ current_user = g.user
+ if getattr(current_user, "is_anonymous", False):
+ return
+
+ username = getattr(current_user, "username", None)
+ email = getattr(current_user, "email", None)
+ if not username and not email:
+ return
+
+ refreshed_user = (
+ load_user_with_relationships(username=username)
+ if username
+ else load_user_with_relationships(email=email)
+ )
+ if refreshed_user is not None:
+ g.user = refreshed_user
+
+
+def _apply_permalink_state(
+ result: DashboardInfo,
+ permalink_key: str,
+ permalink_state: dict[str, object],
+) -> DashboardInfo:
+ """Sanitize only the raw permalink fields added after serialization."""
+ payload = result.model_dump(mode="python")
+ payload["permalink_key"] = permalink_key
+ payload["filter_state"] = sanitize_for_llm_context(
+ permalink_state,
+ field_path=("filter_state",),
+ excluded_field_names=frozenset(),
+ )
+ payload["is_permalink_state"] = True
+ return DashboardInfo.model_validate(payload)
+
+
def _get_permalink_state(permalink_key: str) -> DashboardPermalinkValue | None:
"""Retrieve dashboard filter state from permalink.
@@ -136,6 +179,7 @@ async def get_dashboard_info(
"Retrieving filter state from permalink: permalink_key=%s"
% (request.permalink_key,)
)
+ _refresh_request_user_for_permalink_access()
permalink_value = _get_permalink_state(request.permalink_key)
if permalink_value:
@@ -171,9 +215,11 @@ async def get_dashboard_info(
permalink_state = redact_filter_state_data_model_metadata(
permalink_state
)
- result.permalink_key = request.permalink_key
- result.filter_state = permalink_state
- result.is_permalink_state = True
+ result = _apply_permalink_state(
+ result,
+ request.permalink_key,
+ permalink_state,
+ )
await ctx.info(
"Filter state retrieved from permalink: "
diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py
index bbfe018dbfc..dfb7f8f9faa 100644
--- a/superset/mcp_service/dataset/schemas.py
+++ b/superset/mcp_service/dataset/schemas.py
@@ -47,6 +47,10 @@ from superset.mcp_service.system.schemas import (
PaginationInfo,
TagInfo,
)
+from superset.mcp_service.utils import (
+ escape_llm_context_delimiters,
+ sanitize_for_llm_context,
+)
from superset.utils import json
@@ -286,6 +290,12 @@ class DatasetError(BaseModel):
timestamp: str | datetime | None = Field(None, description="Error timestamp")
model_config = ConfigDict(ser_json_timedelta="iso8601")
+ @field_validator("error")
+ @classmethod
+ def sanitize_error_for_llm_context(cls, value: str) -> str:
+ """Wrap error text before it is exposed to LLM context."""
+ return sanitize_for_llm_context(value, field_path=("error",))
+
@classmethod
def create(cls, error: str, error_type: str) -> "DatasetError":
"""Create a standardized DatasetError with timestamp."""
@@ -404,6 +414,90 @@ def _humanize_timestamp(dt: datetime | None) -> str | None:
return humanize.naturaltime(datetime.now() - dt)
+def _sanitize_dataset_info_for_llm_context(dataset_info: DatasetInfo) -> DatasetInfo:
+ """Wrap dataset read-path descriptive fields before LLM exposure."""
+ payload = dataset_info.model_dump(mode="python")
+
+ for field_name in ("description", "certified_by", "certification_details", "sql"):
+ payload[field_name] = sanitize_for_llm_context(
+ payload.get(field_name),
+ field_path=(field_name,),
+ )
+
+ for field_name in ("table_name", "schema_name", "database_name", "schema_perm"):
+ payload[field_name] = escape_llm_context_delimiters(payload.get(field_name))
+
+ payload["extra"] = sanitize_for_llm_context(
+ payload.get("extra"),
+ field_path=("extra",),
+ excluded_field_names=frozenset(),
+ )
+
+ for field_name in ("params", "template_params"):
+ payload[field_name] = sanitize_for_llm_context(
+ payload.get(field_name),
+ field_path=(field_name,),
+ excluded_field_names=frozenset(),
+ )
+
+ payload["columns"] = [
+ {
+ **column,
+ "column_name": escape_llm_context_delimiters(
+ column.get("column_name"),
+ ),
+ "description": sanitize_for_llm_context(
+ column.get("description"),
+ field_path=("columns", str(index), "description"),
+ ),
+ "verbose_name": sanitize_for_llm_context(
+ column.get("verbose_name"),
+ field_path=("columns", str(index), "verbose_name"),
+ ),
+ }
+ for index, column in enumerate(payload.get("columns", []))
+ ]
+
+ payload["metrics"] = [
+ {
+ **metric,
+ "metric_name": escape_llm_context_delimiters(
+ metric.get("metric_name"),
+ ),
+ "expression": sanitize_for_llm_context(
+ metric.get("expression"),
+ field_path=("metrics", str(index), "expression"),
+ ),
+ "description": sanitize_for_llm_context(
+ metric.get("description"),
+ field_path=("metrics", str(index), "description"),
+ ),
+ "verbose_name": sanitize_for_llm_context(
+ metric.get("verbose_name"),
+ field_path=("metrics", str(index), "verbose_name"),
+ ),
+ }
+ for index, metric in enumerate(payload.get("metrics", []))
+ ]
+
+ payload["tags"] = [
+ {
+ **tag,
+ "name": sanitize_for_llm_context(
+ tag.get("name"),
+ field_path=("tags", str(index), "name"),
+ ),
+ "description": sanitize_for_llm_context(
+ tag.get("description"),
+ field_path=("tags", str(index), "description"),
+ ),
+ }
+ for index, tag in enumerate(payload.get("tags", []))
+ ]
+
+ return DatasetInfo.model_validate(payload)
+
+
def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
if not dataset:
return None
@@ -438,46 +532,52 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
)
for metric in getattr(dataset, "metrics", [])
]
- return DatasetInfo(
- id=getattr(dataset, "id", None),
- table_name=getattr(dataset, "table_name", None),
- schema_name=getattr(dataset, "schema", None),
- database_name=getattr(dataset.database, "database_name", None)
- if getattr(dataset, "database", None)
- else None,
- description=getattr(dataset, "description", None),
- certified_by=getattr(dataset, "certified_by", None),
- certification_details=getattr(dataset, "certification_details", None),
- changed_on=getattr(dataset, "changed_on", None),
- changed_on_humanized=_humanize_timestamp(getattr(dataset, "changed_on", None)),
- created_on=getattr(dataset, "created_on", None),
- created_on_humanized=_humanize_timestamp(getattr(dataset, "created_on", None)),
- tags=[
- TagInfo.model_validate(tag, from_attributes=True)
- for tag in getattr(dataset, "tags", [])
- ]
- if getattr(dataset, "tags", None)
- else [],
- is_virtual=getattr(dataset, "is_virtual", None),
- database_id=getattr(dataset, "database_id", None),
- uuid=str(getattr(dataset, "uuid", ""))
- if getattr(dataset, "uuid", None)
- else None,
- schema_perm=getattr(dataset, "schema_perm", None),
- url=(
- f"{get_superset_base_url()}/tablemodelview/edit/"
- f"{getattr(dataset, 'id', None)}"
- if getattr(dataset, "id", None)
- else None
- ),
- sql=getattr(dataset, "sql", None),
- main_dttm_col=getattr(dataset, "main_dttm_col", None),
- offset=getattr(dataset, "offset", None),
- cache_timeout=getattr(dataset, "cache_timeout", None),
- params=params,
- template_params=_parse_json_field(dataset, "template_params"),
- extra=_parse_json_field(dataset, "extra"),
- columns=columns,
- metrics=metrics,
- is_favorite=getattr(dataset, "is_favorite", None),
+ return _sanitize_dataset_info_for_llm_context(
+ DatasetInfo(
+ id=getattr(dataset, "id", None),
+ table_name=getattr(dataset, "table_name", None),
+ schema_name=getattr(dataset, "schema", None),
+ database_name=getattr(dataset.database, "database_name", None)
+ if getattr(dataset, "database", None)
+ else None,
+ description=getattr(dataset, "description", None),
+ certified_by=getattr(dataset, "certified_by", None),
+ certification_details=getattr(dataset, "certification_details", None),
+ changed_on=getattr(dataset, "changed_on", None),
+ changed_on_humanized=_humanize_timestamp(
+ getattr(dataset, "changed_on", None)
+ ),
+ created_on=getattr(dataset, "created_on", None),
+ created_on_humanized=_humanize_timestamp(
+ getattr(dataset, "created_on", None)
+ ),
+ tags=[
+ TagInfo.model_validate(tag, from_attributes=True)
+ for tag in getattr(dataset, "tags", [])
+ ]
+ if getattr(dataset, "tags", None)
+ else [],
+ is_virtual=getattr(dataset, "is_virtual", None),
+ database_id=getattr(dataset, "database_id", None),
+ uuid=str(getattr(dataset, "uuid", ""))
+ if getattr(dataset, "uuid", None)
+ else None,
+ schema_perm=getattr(dataset, "schema_perm", None),
+ url=(
+ f"{get_superset_base_url()}/tablemodelview/edit/"
+ f"{getattr(dataset, 'id', None)}"
+ if getattr(dataset, "id", None)
+ else None
+ ),
+ sql=getattr(dataset, "sql", None),
+ main_dttm_col=getattr(dataset, "main_dttm_col", None),
+ offset=getattr(dataset, "offset", None),
+ cache_timeout=getattr(dataset, "cache_timeout", None),
+ params=params,
+ template_params=_parse_json_field(dataset, "template_params"),
+ extra=_parse_json_field(dataset, "extra"),
+ columns=columns,
+ metrics=metrics,
+ is_favorite=getattr(dataset, "is_favorite", None),
+ )
)
diff --git a/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py b/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py
index 04fb93d7d80..1d7af0ba6f2 100644
--- a/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py
+++ b/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py
@@ -22,7 +22,7 @@ Tool for generating SQL Lab URLs with pre-populated sql and context.
"""
import logging
-from urllib.parse import urlencode
+from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
@@ -32,10 +32,51 @@ from superset.mcp_service.sql_lab.schemas import (
OpenSqlLabRequest,
SqlLabResponse,
)
+from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.url_utils import get_superset_base_url
logger = logging.getLogger(__name__)
+SQL_LAB_QUERY_PARAMS_TO_SANITIZE = frozenset({"sql", "title"})
+
+
+def _sanitize_sql_lab_url_for_llm_context(url: str) -> str:
+ """Wrap user-controlled SQL Lab query values while preserving navigation."""
+ if not url:
+ return url
+
+ parsed = urlsplit(url)
+ query_params = parse_qsl(parsed.query, keep_blank_values=True)
+ if not query_params:
+ return url
+
+ sanitized_params = [
+ (
+ name,
+ sanitize_for_llm_context(value, field_path=(name,))
+ if name in SQL_LAB_QUERY_PARAMS_TO_SANITIZE
+ else value,
+ )
+ for name, value in query_params
+ ]
+ return urlunsplit(parsed._replace(query=urlencode(sanitized_params)))
+
+
+def _sanitize_sql_lab_response_for_llm_context(
+ response: SqlLabResponse,
+) -> SqlLabResponse:
+ """Wrap user-controlled SQL Lab response content before LLM exposure."""
+ payload = response.model_dump(mode="python")
+ payload["url"] = _sanitize_sql_lab_url_for_llm_context(payload.get("url", ""))
+
+ for field_name in ("title", "error"):
+ payload[field_name] = sanitize_for_llm_context(
+ payload.get(field_name),
+ field_path=(field_name,),
+ )
+
+ return SqlLabResponse.model_validate(payload)
+
@tool(
tags=["explore"],
@@ -61,12 +102,17 @@ def open_sql_lab_with_context(
# Validate database exists and is accessible
database = DatabaseDAO.find_by_id(request.database_connection_id)
if not database:
- return SqlLabResponse(
- url="",
- database_id=request.database_connection_id,
- schema_name=request.schema_name,
- title=request.title,
- error=f"Database with ID {request.database_connection_id} not found",
+ error_message = (
+ f"Database with ID {request.database_connection_id} not found"
+ )
+ return _sanitize_sql_lab_response_for_llm_context(
+ SqlLabResponse(
+ url="",
+ database_id=request.database_connection_id,
+ schema_name=request.schema_name,
+ title=request.title,
+ error=error_message,
+ )
)
# Build query parameters for SQL Lab URL
@@ -109,12 +155,14 @@ def open_sql_lab_with_context(
"Generated SQL Lab URL for database %s", request.database_connection_id
)
- return SqlLabResponse(
- url=url,
- database_id=request.database_connection_id,
- schema_name=request.schema_name,
- title=request.title,
- error=None,
+ return _sanitize_sql_lab_response_for_llm_context(
+ SqlLabResponse(
+ url=url,
+ database_id=request.database_connection_id,
+ schema_name=request.schema_name,
+ title=request.title,
+ error=None,
+ )
)
except Exception as e:
@@ -128,10 +176,12 @@ def open_sql_lab_with_context(
"Database rollback failed during error handling", exc_info=True
)
logger.error("Error generating SQL Lab URL: %s", e)
- return SqlLabResponse(
- url="",
- database_id=request.database_connection_id,
- schema_name=request.schema_name,
- title=request.title,
- error=f"Failed to generate SQL Lab URL: {str(e)}",
+ return _sanitize_sql_lab_response_for_llm_context(
+ SqlLabResponse(
+ url="",
+ database_id=request.database_connection_id,
+ schema_name=request.schema_name,
+ title=request.title,
+ error=f"Failed to generate SQL Lab URL: {str(e)}",
+ )
)
diff --git a/superset/mcp_service/utils/__init__.py b/superset/mcp_service/utils/__init__.py
index b962f652be8..b405537c2d5 100644
--- a/superset/mcp_service/utils/__init__.py
+++ b/superset/mcp_service/utils/__init__.py
@@ -17,6 +17,11 @@
from __future__ import annotations
+from superset.mcp_service.utils.sanitization import (
+ escape_llm_context_delimiters as escape_llm_context_delimiters,
+ sanitize_for_llm_context as sanitize_for_llm_context,
+)
+
def _is_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
diff --git a/superset/mcp_service/utils/sanitization.py b/superset/mcp_service/utils/sanitization.py
index 48d32b43882..ababfb69337 100644
--- a/superset/mcp_service/utils/sanitization.py
+++ b/superset/mcp_service/utils/sanitization.py
@@ -31,9 +31,145 @@ Key features:
import html
import re
+from typing import Any
import nh3
+LLM_CONTEXT_OPEN_DELIMITER = "