Files
superset2/superset/mcp_service/sql_lab/tool/execute_sql.py

297 lines
10 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Execute SQL MCP Tool
Tool for executing SQL queries against databases using the unified
Database.execute() API with RLS, template rendering, and security validation.
"""
import logging
from decimal import Decimal
from typing import Any
import pandas as pd
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset_core.queries.types import (
CacheOptions,
QueryOptions,
QueryResult,
QueryStatus,
)
from superset.errors import SupersetErrorType
from superset.extensions import event_logger
from superset.mcp_service.sql_lab.schemas import (
ColumnInfo,
ExecuteSqlRequest,
ExecuteSqlResponse,
StatementData,
StatementInfo,
)
logger = logging.getLogger(__name__)
@tool(
    tags=["mutate"],
    class_permission_name="SQLLab",
    method_permission_name="execute_sql_query",
    annotations=ToolAnnotations(
        title="Execute SQL query",
        readOnlyHint=False,
        destructiveHint=True,
    ),
)
async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse:
    """Execute SQL query against database using the unified Database.execute() API.

    Resolves the target database, verifies the caller's access, builds
    ``QueryOptions`` from the request, executes the SQL, and converts the
    engine result into an ``ExecuteSqlResponse``. Unexpected exceptions are
    logged to the MCP context and re-raised.
    """
    await ctx.info(
        "Starting SQL execution: database_id=%s, timeout=%s, limit=%s, schema=%s"
        % (request.database_id, request.timeout, request.limit, request.schema_name)
    )
    # Only a truncated preview of the SQL goes into the debug log, to avoid
    # leaking full query text.
    sql_preview = (
        request.sql if len(request.sql) <= 100 else request.sql[:100] + "..."
    )
    await ctx.debug(
        "SQL query details: sql_preview=%r, sql_length=%s, has_template_params=%s"
        % (sql_preview, len(request.sql), bool(request.template_params))
    )
    logger.info("Executing SQL query on database ID: %s", request.database_id)
    try:
        # Deferred imports: importing these at module load time can trigger
        # app-initialization ordering issues.
        from superset import db, security_manager
        from superset.models.core import Database

        # Step 1: resolve the target database and check the caller may use it.
        with event_logger.log_context(action="mcp.execute_sql.db_validation"):
            target_db = (
                db.session.query(Database)
                .filter_by(id=request.database_id)
                .first()
            )
            if target_db is None:
                await ctx.error(
                    "Database not found: database_id=%s" % request.database_id
                )
                return ExecuteSqlResponse(
                    success=False,
                    error=f"Database with ID {request.database_id} not found",
                    error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR.value,
                )
            if not security_manager.can_access_database(target_db):
                await ctx.error(
                    "Access denied to database: %s" % target_db.database_name
                )
                return ExecuteSqlResponse(
                    success=False,
                    error=f"Access denied to database {target_db.database_name}",
                    error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR.value,
                )

        # Step 2: assemble execution options from the request.
        cache_settings = None
        if request.force_refresh:
            cache_settings = CacheOptions(force_refresh=True)
        query_options = QueryOptions(
            catalog=request.catalog,
            schema=request.schema_name,
            limit=request.limit,
            timeout_seconds=request.timeout,
            template_params=request.template_params,
            dry_run=request.dry_run,
            cache=cache_settings,
        )

        # Step 3: run the query.
        with event_logger.log_context(action="mcp.execute_sql.query_execution"):
            query_result = target_db.execute(request.sql, query_options)

        # Step 4: translate the engine result into the MCP response schema.
        with event_logger.log_context(action="mcp.execute_sql.response_conversion"):
            mcp_response = _convert_to_response(query_result)

        if mcp_response.success:
            await ctx.info(
                "SQL execution completed successfully: rows_returned=%s, "
                "execution_time=%s"
                % (mcp_response.row_count, mcp_response.execution_time)
            )
        else:
            await ctx.info(
                "SQL execution failed: error=%s, error_type=%s"
                % (mcp_response.error, mcp_response.error_type)
            )
        return mcp_response
    except Exception as e:
        await ctx.error(
            "SQL execution failed: error=%s, database_id=%s"
            % (str(e), request.database_id)
        )
        raise
def _sanitize_row_values(rows: list[dict[str, Any]]) -> None:
"""Sanitize non-serializable values in rows for JSON serialization."""
for row in rows:
for key, value in row.items():
if isinstance(value, (bytes, memoryview)):
raw = bytes(value) if isinstance(value, memoryview) else value
try:
row[key] = raw.decode("utf-8")
except (UnicodeDecodeError, AttributeError):
row[key] = raw.hex()
elif isinstance(value, Decimal):
row[key] = float(value)
elif not isinstance(value, (str, int, float, bool, type(None), list, dict)):
row[key] = str(value)
def _data_to_statement_data(data: Any) -> StatementData:
    """Convert statement data (DataFrame, list, dict, bytes) to StatementData.

    When results come from cache, ``data`` may be a dict/list/bytes instead of
    a pandas DataFrame. This function handles all cases defensively: cached
    lists whose entries are not dicts (e.g. lists of scalars) are wrapped as
    ``{"value": item}`` so row shape stays uniform — previously such payloads
    crashed in ``_sanitize_row_values`` (``row.items()`` on a non-dict) and at
    ``rows_data[0].keys()``.
    """
    from superset.utils import json as json_utils

    # DataFrame path: dtypes are known, so columns carry real type names.
    if isinstance(data, pd.DataFrame):
        rows_data = data.to_dict(orient="records")
        _sanitize_row_values(rows_data)
        return StatementData(
            rows=rows_data,
            columns=[
                ColumnInfo(name=col, type=str(data[col].dtype))
                for col in data.columns
            ],
        )

    if isinstance(data, list):
        rows_data = data
    elif isinstance(data, dict):
        # Cached result envelopes keep rows under "data"; otherwise treat the
        # dict itself as a single row.
        rows_data = data.get("data", [data])
        if not isinstance(rows_data, list):
            rows_data = [rows_data]
    elif isinstance(data, bytes):
        try:
            decoded = json_utils.loads(data)
            rows_data = decoded if isinstance(decoded, list) else [decoded]
        except (ValueError, UnicodeDecodeError):
            # Undecodable cached bytes: deliberately degrade to no rows.
            rows_data = []
    else:
        rows_data = [{"value": str(data)}]

    # Defensive normalization: every row must be a dict before sanitization
    # and column derivation below.
    rows_data = [
        row if isinstance(row, dict) else {"value": row} for row in rows_data
    ]
    _sanitize_row_values(rows_data)
    # Without dtype info, column names come from the first row and types are
    # reported as generic "object".
    col_names = list(rows_data[0].keys()) if rows_data else []
    return StatementData(
        rows=rows_data,
        columns=[ColumnInfo(name=col, type="object") for col in col_names],
    )
def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
    """Convert QueryResult to ExecuteSqlResponse."""
    # Failed executions short-circuit into an error response.
    if result.status != QueryStatus.SUCCESS:
        return ExecuteSqlResponse(
            success=False,
            error=result.error_message,
            error_type=result.status.value,
        )

    # Per-statement info, carrying row data for data-bearing statements
    # (e.g. SELECT).
    statements: list[StatementInfo] = []
    data_bearing_count = 0
    for stmt in result.statements:
        converted: StatementData | None = None
        if stmt.data is not None:
            converted = _data_to_statement_data(stmt.data)
            data_bearing_count += 1
        statements.append(
            StatementInfo(
                original_sql=stmt.original_sql,
                executed_sql=stmt.executed_sql,
                row_count=stmt.row_count,
                execution_time_ms=stmt.execution_time_ms,
                data=converted,
            )
        )

    # Top-level rows/columns mirror the last data-bearing statement, for
    # backward compatibility with the single-statement response shape.
    rows: list[dict[str, Any]] | None = None
    columns: list[ColumnInfo] | None = None
    row_count: int | None = None
    affected_rows: int | None = None
    last_data_stmt = next(
        (s for s in reversed(statements) if s.data is not None), None
    )
    if last_data_stmt is not None and last_data_stmt.data is not None:
        rows = last_data_stmt.data.rows
        columns = last_data_stmt.data.columns
        row_count = len(last_data_stmt.data.rows)
    elif result.statements:
        # DML-only query: surface the final statement's affected-row count.
        affected_rows = result.statements[-1].row_count

    # Warn when multiple data-bearing statements exist so the LLM knows to
    # inspect the statements array for all results.
    multi_statement_warning: str | None = None
    if data_bearing_count > 1:
        multi_statement_warning = (
            f"This query contained {data_bearing_count} "
            "data-bearing statements. "
            "The top-level rows/columns contain only the "
            "last data-bearing statement's results. "
            "Check the 'data' field in each entry of the "
            "'statements' array to see results from ALL "
            "statements."
        )

    # Engine reports milliseconds; the response contract is seconds.
    execution_time = None
    if result.total_execution_time_ms is not None:
        execution_time = result.total_execution_time_ms / 1000

    return ExecuteSqlResponse(
        success=True,
        rows=rows,
        columns=columns,
        row_count=row_count,
        affected_rows=affected_rows,
        execution_time=execution_time,
        statements=statements,
        multi_statement_warning=multi_statement_warning,
    )