Files
superset2/superset/mcp_service/chart/tool/get_chart_sql.py
2026-04-29 19:06:19 -03:00

617 lines
21 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
MCP tool: get_chart_sql
"""
import logging
from typing import Any, TYPE_CHECKING
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
if TYPE_CHECKING:
from superset.models.slice import Slice
from superset.commands.exceptions import CommandException
from superset.commands.explore.form_data.parameters import CommandParameters
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
ChartError,
ChartSql,
GetChartSqlRequest,
)
from superset.mcp_service.utils import sanitize_for_llm_context
logger = logging.getLogger(__name__)
def _sanitize_chart_sql_for_llm_context(chart_sql: ChartSql) -> ChartSql:
    """Return a re-validated copy of *chart_sql* with its text fields sanitized.

    The model is dumped to a plain dict, the ``chart_name``,
    ``datasource_name``, ``sql`` and ``error`` entries are each passed through
    ``sanitize_for_llm_context`` (with the field name as the one-element
    ``field_path``), and the resulting dict is validated into a new
    ``ChartSql``.
    """
    text_fields = ("chart_name", "datasource_name", "sql", "error")
    data = chart_sql.model_dump(mode="python")
    data.update(
        {
            name: sanitize_for_llm_context(data.get(name), field_path=(name,))
            for name in text_fields
        }
    )
    return ChartSql.model_validate(data)
def _get_cached_form_data(form_data_key: str) -> str | None:
    """Fetch the cached form_data stored under *form_data_key*.

    Runs ``GetFormDataCommand`` for the key and returns its result (the JSON
    string of the form_data when found). A ``KeyError``, ``ValueError`` or
    ``CommandException`` raised while doing so is logged as a warning and
    reported as ``None``.
    """
    from superset.commands.explore.form_data.get import GetFormDataCommand

    try:
        command = GetFormDataCommand(CommandParameters(key=form_data_key))
        cached = command.run()
    except (KeyError, ValueError, CommandException) as exc:
        logger.warning("Failed to retrieve form_data from cache: %s", exc)
        return None
    return cached
def _resolve_metrics(form_data: dict[str, Any], viz_type: str) -> list[Any]:
"""Extract metrics from form_data, handling chart-type-specific fields."""
# Bubble charts store measures in x, y, size fields
if viz_type == "bubble":
return [m for field in ("x", "y", "size") if (m := form_data.get(field))]
metrics = form_data.get("metrics", [])
# Fallback: some chart types store the measure as singular "metric"
if not metrics and (metric := form_data.get("metric")):
metrics = [metric]
return metrics
def _resolve_groupby(form_data: dict[str, Any]) -> list[str]:
"""Extract groupby columns from form_data with fallback aliases.
Normalises scalar strings (e.g. heatmap_v2 migrated from legacy
``all_columns_y``) so that ``list("country")`` does not split into
individual characters.
"""
raw_groupby = form_data.get("groupby") or []
if isinstance(raw_groupby, str):
groupby: list[str] = [raw_groupby]
else:
groupby = list(raw_groupby)
if groupby:
return groupby
# Fallback: some chart types store dimensions in entity/series/columns
for field in ("entity", "series"):
value = form_data.get(field)
if isinstance(value, str) and value not in groupby:
groupby.append(value)
form_columns = form_data.get("columns")
if isinstance(form_columns, list):
for col in form_columns:
if isinstance(col, str) and col not in groupby:
groupby.append(col)
return groupby
def _resolve_metrics_and_groupby(
form_data: dict[str, Any],
chart: "Slice | None",
) -> tuple[list[Any], list[str]]:
"""Resolve metrics and groupby columns from form_data.
Handles chart-type-specific field names: singular ``metric`` for
big-number variants, bubble ``x``/``y``/``size``, and fallback
fields ``entity``, ``series``, and ``columns`` for dimensions.
"""
viz_type = form_data.get(
"viz_type", getattr(chart, "viz_type", "") if chart else ""
)
singular_metric_no_groupby = (
"big_number",
"big_number_total",
"pop_kpi",
)
if viz_type in singular_metric_no_groupby:
metrics: list[Any] = [metric] if (metric := form_data.get("metric")) else []
return metrics, []
return _resolve_metrics(form_data, viz_type), _resolve_groupby(form_data)
def _resolve_engine(
datasource_id: Any,
datasource_type: str,
) -> str:
"""Return the DB engine name for *datasource_id*, or ``"base"`` on any error."""
if not isinstance(datasource_id, (int, str)):
return "base"
try:
from superset.daos.datasource import DatasourceDAO
from superset.utils.core import DatasourceType
ds = DatasourceDAO.get_datasource(
datasource_type=DatasourceType(datasource_type),
database_id_or_uuid=datasource_id,
)
return ds.database.db_engine_spec.engine
except Exception: # noqa: BLE001
logger.debug("Could not resolve engine for datasource %s", datasource_id)
return "base"
def _build_query_context_from_form_data(
    form_data: dict[str, Any],
    chart: "Slice | None" = None,
) -> Any:
    """Build a QueryContext from form_data with result_type=QUERY.

    This constructs a query context that will return the rendered SQL
    instead of executing the query.

    Args:
        form_data: Chart form_data. It is modified in place by
            ``merge_extra_filters`` and
            ``split_adhoc_filters_into_base_filters``.
        chart: Optional saved chart used as a fallback source for the
            datasource id/type and the viz type.

    Returns:
        The object produced by ``QueryContextFactory().create``.

    Raises:
        ValueError: If no ``int``/``str`` datasource id can be determined.
            Raised before *form_data* is mutated.
    """
    from superset.common.chart_data import ChartDataResultType
    from superset.common.query_context_factory import QueryContextFactory
    from superset.utils.core import (
        merge_extra_filters,
        split_adhoc_filters_into_base_filters,
    )

    datasource_id = form_data.get("datasource_id")
    datasource_type = form_data.get("datasource_type")

    # Unsaved Explore state often stores datasource as a combined field
    # like "123__table" instead of separate datasource_id/datasource_type.
    if not datasource_id and (combined := form_data.get("datasource")):
        if isinstance(combined, str) and "__" in combined:
            raw_id, _, raw_type = combined.partition("__")
            # isdecimal() (not isdigit()) so that only text int() can parse
            # is converted; e.g. "²" is a digit but int("²") raises.
            datasource_id = int(raw_id) if raw_id.isdecimal() else raw_id
            datasource_type = raw_type

    if not datasource_id and chart:
        datasource_id = getattr(chart, "datasource_id", None)
    if not datasource_type and chart:
        datasource_type = getattr(chart, "datasource_type", None)

    # Ensure datasource fields satisfy DatasourceDict typing requirements.
    # datasource_id must be int | str; datasource_type must be str.
    # Checked up front so a bad request fails before form_data is touched.
    if not isinstance(datasource_id, (int, str)):
        raise ValueError(
            "Cannot determine datasource ID from form_data. "
            "Provide a chart identifier or ensure form_data contains "
            "'datasource_id' or 'datasource'."
        )
    resolved_id: int | str = datasource_id
    resolved_type_str: str = (
        datasource_type if isinstance(datasource_type, str) else "table"
    )

    metrics, groupby = _resolve_metrics_and_groupby(form_data, chart)

    # Preprocess adhoc_filters into where/having/filters on form_data so
    # that the QueryObject receives concrete filter clauses. This mirrors
    # the view-layer call in viz.py:process_query_filters.
    engine = _resolve_engine(resolved_id, resolved_type_str)
    merge_extra_filters(form_data)
    split_adhoc_filters_into_base_filters(form_data, engine)

    # Build query dict with temporal and filter fields.
    # QueryObjectFactory.create() accepts time_range as a top-level kwarg
    # and converts it to from_dttm/to_dttm for the QueryObject.
    query_dict: dict[str, Any] = {
        "columns": groupby,
        "metrics": metrics,
    }
    if time_range := form_data.get("time_range"):
        query_dict["time_range"] = time_range
    if filters := form_data.get("filters"):
        query_dict["filters"] = filters
    if (row_limit := form_data.get("row_limit")) is not None:
        query_dict["row_limit"] = row_limit

    # Free-form SQL adhoc filters end up in form_data["where"] / ["having"]
    # after the split above; forward them so they are not silently dropped
    # from the rendered SQL.
    # NOTE(review): relies on the query object reading these predicates from
    # extras["where"] / extras["having"] — confirm against QueryObject.
    extras = {
        clause: predicate
        for clause in ("where", "having")
        if (predicate := form_data.get(clause))
    }
    if extras:
        query_dict["extras"] = extras

    return QueryContextFactory().create(
        datasource={"id": resolved_id, "type": resolved_type_str},
        queries=[query_dict],
        form_data=form_data,
        result_type=ChartDataResultType.QUERY,
        force=False,
    )
def _find_chart_by_identifier(
    identifier: int | str,
) -> "Slice | None":
    """Look up a chart by numeric ID or UUID string.

    Args:
        identifier: An integer chart id, a string of decimal digits (treated
            as a chart id), or any other string (treated as a UUID).

    Returns:
        Whatever ``ChartDAO.find_by_id`` returns for the lookup, or ``None``
        when *identifier* is neither an ``int`` nor a ``str``.
    """
    from superset.daos.chart import ChartDAO

    if isinstance(identifier, int):
        return ChartDAO.find_by_id(identifier)
    if isinstance(identifier, str):
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscripts ("²") that int() cannot parse, which used to
        # surface as a ValueError instead of a plain failed lookup.
        if identifier.isdecimal():
            return ChartDAO.find_by_id(int(identifier))
        return ChartDAO.find_by_id(identifier, id_column="uuid")
    return None
def _resolve_effective_form_data(
    chart: "Slice",
    form_data_key: str | None,
) -> tuple[dict[str, Any], bool]:
    """Resolve the effective form_data for a chart.

    Cached (unsaved) form_data referenced by *form_data_key* takes precedence
    when it can be fetched and parses to a JSON object; otherwise the chart's
    saved ``params`` are used.

    Args:
        chart: The saved chart whose ``params`` act as the fallback.
        form_data_key: Optional cache key for unsaved Explore state.

    Returns:
        ``(form_data_dict, using_unsaved_state)``. The flag is ``True`` only
        when the cached form_data was used. Saved params that are missing,
        unparseable or not a JSON object yield an empty dict.
    """
    from superset.utils import json as utils_json

    if form_data_key:
        if cached_raw := _get_cached_form_data(form_data_key):
            try:
                parsed = utils_json.loads(cached_raw)
            except (TypeError, ValueError) as exc:
                # Still best effort, but no longer silent: the caller asked
                # for unsaved state and is about to get the saved one.
                logger.warning(
                    "Cached form_data could not be parsed; "
                    "falling back to saved chart params: %s",
                    exc,
                )
            else:
                if isinstance(parsed, dict):
                    return parsed, True
                logger.warning(
                    "Cached form_data is not a JSON object; "
                    "falling back to saved chart params"
                )

    try:
        saved = utils_json.loads(chart.params) if chart.params else {}
    except (TypeError, ValueError):
        saved = {}
    return (saved if isinstance(saved, dict) else {}), False
def _sql_from_saved_query_context(
    chart: "Slice",
) -> ChartSql | ChartError | None:
    """Render SQL from the query_context stored on *chart*, if possible.

    The stored JSON is loaded, forced to ``result_type=QUERY`` with
    ``force=False``, turned into a query context through
    ``ChartDataQueryContextSchema`` and run through ``ChartDataCommand``.

    Returns ``None`` when the chart has no query_context or when a
    ``SupersetException``, ``CommandException``, ``TypeError`` or
    ``ValueError`` is raised along the way. ``SupersetSecurityException`` is
    re-raised so that access denials reach the caller.

    NOTE(review): exceptions outside that tuple (for instance a schema
    validation error that is not a ``ValueError``) are not caught here —
    confirm that is intended.
    """
    from superset.charts.schemas import ChartDataQueryContextSchema
    from superset.commands.chart.data.get_data_command import ChartDataCommand
    from superset.common.chart_data import ChartDataResultType
    from superset.utils import json as utils_json

    raw_context = chart.query_context
    if not raw_context:
        return None

    try:
        context_payload = utils_json.loads(raw_context)
        context_payload["result_type"] = ChartDataResultType.QUERY
        context_payload["force"] = False

        loaded_context = ChartDataQueryContextSchema().load(context_payload)
        # Set again on the loaded object, independent of what the schema
        # did with the payload value.
        loaded_context.result_type = ChartDataResultType.QUERY

        data_command = ChartDataCommand(loaded_context)
        data_command.validate()
        return _extract_sql_from_result(
            data_command.run(),
            chart_id=chart.id,
            chart_name=chart.slice_name,
            datasource_name=chart.datasource_name,
        )
    except SupersetSecurityException:
        raise  # Let access denials propagate for consistent error handling
    except (SupersetException, CommandException, TypeError, ValueError):
        return None
def _resolve_datasource_name(
form_data: dict[str, Any],
chart: "Slice | None",
) -> str | None:
"""Resolve datasource name from form_data or chart.
For unsaved charts (chart=None), looks up the datasource by ID
from form_data so that the response includes a meaningful name.
"""
if chart:
return getattr(chart, "datasource_name", None)
# Unsaved chart — resolve from form_data
datasource_id = form_data.get("datasource_id")
datasource_type = form_data.get("datasource_type", "table")
if not datasource_id and (combined := form_data.get("datasource")):
if isinstance(combined, str) and "__" in combined:
parts = combined.split("__", 1)
datasource_id = int(parts[0]) if parts[0].isdigit() else parts[0]
datasource_type = parts[1] if len(parts) > 1 else "table"
if not datasource_id:
return None
try:
from superset.daos.datasource import DatasourceDAO
from superset.daos.exceptions import (
DatasourceNotFound,
DatasourceTypeNotSupportedError,
DatasourceValueIsIncorrect,
)
from superset.utils.core import DatasourceType
datasource = DatasourceDAO.get_datasource(
datasource_type=DatasourceType(datasource_type),
database_id_or_uuid=datasource_id,
)
return getattr(datasource, "name", None)
except (
ValueError,
DatasourceNotFound,
DatasourceTypeNotSupportedError,
DatasourceValueIsIncorrect,
):
return None
def _sql_from_form_data(
    form_data: dict[str, Any],
    chart: "Slice | None",
) -> ChartSql | ChartError:
    """Render SQL by building a fresh query context from *form_data*.

    This is the fallback path: a query context is built from the form_data,
    validated and run through ``ChartDataCommand``, and the command result is
    converted with ``_extract_sql_from_result``. Chart id and name come from
    *chart* when one is given, otherwise they are ``None``.
    """
    from superset.commands.chart.data.get_data_command import ChartDataCommand

    data_command = ChartDataCommand(
        _build_query_context_from_form_data(form_data, chart)
    )
    data_command.validate()
    command_result = data_command.run()

    # Metadata is gathered only after the command has run successfully.
    chart_id = getattr(chart, "id", None)
    chart_name = getattr(chart, "slice_name", None)
    datasource_name = _resolve_datasource_name(form_data, chart)
    return _extract_sql_from_result(
        command_result, chart_id, chart_name, datasource_name
    )
def _extract_sql_from_result(
    result: dict[str, Any],
    chart_id: int | None,
    chart_name: str | None,
    datasource_name: str | None,
) -> ChartSql | ChartError:
    """Turn a ChartDataCommand result into a ``ChartSql`` or ``ChartError``.

    Every dict entry under ``result["queries"]`` contributes its ``query``
    text and its ``error`` (when present), so charts that issue several
    queries keep all of their SQL. With more than one entry each statement is
    prefixed with a ``-- Query N`` marker. The last string ``language`` seen
    wins, defaulting to ``"sql"``.

    Returns:
        * ``ChartError`` (``EmptyQuery``) when there are no query entries.
        * ``ChartError`` (``QueryGenerationFailed``) when errors were
          reported and no SQL was collected.
        * Otherwise a sanitized ``ChartSql``; per-query errors, if any, are
          joined into its ``error`` field.
    """
    query_results = result.get("queries", [])
    if not query_results:
        return ChartError(
            error=(
                "No query results returned. The chart may have an empty configuration."
            ),
            error_type="EmptyQuery",
        )

    is_single_query = len(query_results) == 1
    statements: list[str] = []
    failures: list[str] = []
    dialect = "sql"

    for idx, entry in enumerate(query_results, start=1):
        if not isinstance(entry, dict):
            continue

        reported_language = entry.get("language")
        if isinstance(reported_language, str):
            dialect = reported_language

        query_sql = entry.get("query", "")
        if query_sql:
            statements.append(
                query_sql if is_single_query else f"-- Query {idx}\n{query_sql}"
            )

        query_error = entry.get("error")
        if query_error:
            failures.append(f"Query {idx}: {query_error}")

    if failures and not statements:
        return ChartError(
            error="SQL generation failed: " + "; ".join(failures),
            error_type="QueryGenerationFailed",
        )

    chart_sql = ChartSql(
        chart_id=chart_id,
        chart_name=chart_name,
        sql="\n\n".join(statements),
        language=dialect,
        datasource_name=datasource_name,
        error="; ".join(failures) if failures else None,
    )
    return _sanitize_chart_sql_for_llm_context(chart_sql)
@tool(
    tags=["data"],
    class_permission_name="SQLLab",
    method_permission_name="execute_sql_query",
    annotations=ToolAnnotations(
        title="Get chart SQL",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def get_chart_sql(
    request: GetChartSqlRequest, ctx: Context
) -> ChartSql | ChartError:
    """Get the rendered SQL query for a chart without executing it.

    Returns the SQL that a chart would execute, similar to the "View Query"
    feature in the Superset UI. Useful for understanding, debugging, or
    auditing the SQL generated by a chart's configuration.

    Note: The returned SQL is RLS-scoped to the caller and contains
    table/schema names. Same disclosure surface as the Superset UI
    "View Query" modal. Consumers should not log or share the output
    without considering data-access implications.

    Supports:
    - Numeric ID or UUID lookup
    - form_data_key: get SQL for unsaved chart state from Explore view

    Example usage:
    ```json
    {
        "identifier": 123
    }
    ```

    Returns the SQL query string and metadata about the chart.
    """
    start_message = (
        "Starting chart SQL retrieval: identifier=%s, form_data_key=%s"
        % (request.identifier, request.form_data_key)
    )
    await ctx.info(start_message)

    # Exceptions caught here become ChartError payloads instead of
    # propagating to the MCP caller.
    try:
        outcome = await _handle_chart_sql_request(request, ctx)
    except SupersetException as e:
        logger.exception("Superset error in get_chart_sql")
        return ChartError(
            error=f"Superset error: {e}",
            error_type="SupersetError",
        )
    except (ValueError, TypeError) as e:
        logger.exception("Unexpected error in get_chart_sql")
        return ChartError(
            error=f"Failed to generate chart SQL: {e}",
            error_type="QueryGenerationFailed",
        )
    return outcome
async def _handle_chart_sql_request(
    request: GetChartSqlRequest, ctx: Context
) -> ChartSql | ChartError:
    """Core logic for get_chart_sql, extracted for complexity.

    Steps, in order:

    1. A request with a ``form_data_key`` but no (truthy) identifier is
       delegated to ``_handle_unsaved_chart_sql``.
    2. Otherwise the chart is looked up by identifier; a missing identifier
       or an unknown chart yields a ``ChartError``.
    3. ``validate_chart_dataset`` is run with ``check_access=True``; an
       invalid result yields a ``DatasetNotAccessible`` error, and any
       warnings are forwarded to the MCP context.
    4. SQL is taken from the chart's saved query_context unless cached
       (unsaved) form_data is in effect, with a fallback to SQL built from
       the effective form_data.

    Returns:
        ``ChartSql`` on success, otherwise a ``ChartError`` describing the
        failure.
    """
    # Handle unsaved chart (form_data_key only, no identifier)
    if not request.identifier and request.form_data_key:
        return await _handle_unsaved_chart_sql(request.form_data_key, ctx)
    # Find the chart by identifier
    if request.identifier is None:
        return ChartError(
            error="Chart identifier is required.",
            error_type="ValidationError",
        )
    with event_logger.log_context(action="mcp.get_chart_sql.chart_lookup"):
        chart = _find_chart_by_identifier(request.identifier)
    if not chart:
        return ChartError(
            error=f"No chart found with identifier: {request.identifier}",
            error_type="NotFound",
        )
    await ctx.info(
        "Chart found: chart_id=%s, chart_name=%s, viz_type=%s"
        % (chart.id, chart.slice_name, chart.viz_type)
    )
    # Validate the chart's dataset is accessible
    validation_result = validate_chart_dataset(chart, check_access=True)
    if not validation_result.is_valid:
        await ctx.warning(
            "Chart found but dataset is not accessible: %s" % (validation_result.error,)
        )
        return ChartError(
            error=validation_result.error or "Chart's dataset is not accessible.",
            error_type="DatasetNotAccessible",
        )
    # Non-fatal validation findings are surfaced to the client as warnings.
    for warning in validation_result.warnings:
        await ctx.warning("Dataset warning: %s" % (warning,))
    # Resolve effective form_data
    effective_form_data, using_unsaved_state = _resolve_effective_form_data(
        chart, request.form_data_key
    )
    # Try saved query_context first (faster, more accurate)
    with event_logger.log_context(action="mcp.get_chart_sql.build_query"):
        # The saved query_context describes the saved chart, so it is skipped
        # when unsaved (cached) form_data is the effective state.
        if not using_unsaved_state:
            saved_result = _sql_from_saved_query_context(chart)
            if saved_result is not None:
                return saved_result
            await ctx.warning(
                "Could not extract SQL from saved query_context. "
                "Falling back to form_data."
            )
        # Fallback: build query context from form_data
        try:
            return _sql_from_form_data(effective_form_data, chart)
        except (SupersetException, CommandException, ValueError) as e:
            await ctx.warning("Failed to build SQL from form_data: %s" % str(e))
            return ChartError(
                error="Failed to generate SQL for chart %s: %s" % (chart.id, e),
                error_type="QueryGenerationFailed",
            )
async def _handle_unsaved_chart_sql(
    form_data_key: str, ctx: Context
) -> ChartSql | ChartError:
    """Handle SQL retrieval for unsaved charts (form_data_key only).

    The cached form_data is fetched by key, parsed as JSON, checked to be a
    JSON object and then passed to ``_sql_from_form_data`` with no chart.

    Returns:
        ``ChartSql`` on success, otherwise a ``ChartError`` with
        ``error_type`` ``NotFound`` (nothing cached for the key),
        ``ParseError`` (cached value is not parseable or not an object) or
        ``QueryGenerationFailed`` (SQL could not be built).
    """
    from superset.utils import json as utils_json

    with event_logger.log_context(action="mcp.get_chart_sql.unsaved_chart_from_cache"):
        await ctx.info(
            "No chart identifier - getting SQL from unsaved chart cache: "
            "form_data_key=%s" % (form_data_key,)
        )
        cached_form_data = _get_cached_form_data(form_data_key)
    # A falsy value covers both a missing entry and a failed cache read.
    if not cached_form_data:
        return ChartError(
            error="No cached chart data found for form_data_key. "
            "The cache may have expired.",
            error_type="NotFound",
        )
    try:
        form_data = utils_json.loads(cached_form_data)
    except (TypeError, ValueError) as e:
        return ChartError(
            error=f"Failed to parse cached form_data: {e}",
            error_type="ParseError",
        )
    # Valid JSON that is not an object (list, string, number, ...) cannot be
    # used as form_data.
    if not isinstance(form_data, dict):
        return ChartError(
            error="Cached form_data is not a valid JSON object.",
            error_type="ParseError",
        )
    try:
        return _sql_from_form_data(form_data, chart=None)
    except (SupersetException, CommandException, ValueError) as e:
        await ctx.warning("Failed to generate SQL from form_data: %s" % str(e))
        return ChartError(
            error="Failed to generate SQL from cached form_data: %s" % str(e),
            error_type="QueryGenerationFailed",
        )