mirror of
https://github.com/apache/superset.git
synced 2026-04-14 05:34:38 +00:00
297 lines
10 KiB
Python
297 lines
10 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
"""
|
|
Execute SQL MCP Tool
|
|
|
|
Tool for executing SQL queries against databases using the unified
|
|
Database.execute() API with RLS, template rendering, and security validation.
|
|
"""
|
|
|
|
import logging
|
|
from decimal import Decimal
|
|
from typing import Any
|
|
|
|
import pandas as pd
|
|
from fastmcp import Context
|
|
from superset_core.mcp.decorators import tool, ToolAnnotations
|
|
from superset_core.queries.types import (
|
|
CacheOptions,
|
|
QueryOptions,
|
|
QueryResult,
|
|
QueryStatus,
|
|
)
|
|
|
|
from superset.errors import SupersetErrorType
|
|
from superset.extensions import event_logger
|
|
from superset.mcp_service.sql_lab.schemas import (
|
|
ColumnInfo,
|
|
ExecuteSqlRequest,
|
|
ExecuteSqlResponse,
|
|
StatementData,
|
|
StatementInfo,
|
|
)
|
|
|
|
# Module-level logger for this tool. Server-side log lines go through this
# logger, while client-facing progress/error messages go through ``ctx``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@tool(
    tags=["mutate"],
    class_permission_name="SQLLab",
    method_permission_name="execute_sql_query",
    annotations=ToolAnnotations(
        title="Execute SQL query",
        readOnlyHint=False,
        destructiveHint=True,
    ),
)
async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse:
    """Execute SQL query against database using the unified Database.execute() API.

    Looks up the target database, checks that the current user may access it,
    then runs ``request.sql`` through ``Database.execute()``. The request's
    catalog, schema, row limit, timeout, template parameters, dry-run flag and
    cache-refresh flag are passed along as ``QueryOptions``.

    :param request: Tool input describing the query to run.
    :param ctx: FastMCP request context, used to send log messages back to
        the client.
    :returns: An ``ExecuteSqlResponse``. A missing database, denied access and
        a non-successful query status are all reported with ``success=False``
        rather than raised.
    :raises Exception: Any unexpected exception is reported to the client via
        ``ctx.error`` and then re-raised unchanged.
    """
    await ctx.info(
        "Starting SQL execution: database_id=%s, timeout=%s, limit=%s, schema=%s"
        % (request.database_id, request.timeout, request.limit, request.schema_name)
    )

    # Log SQL query details (truncated for security)
    sql_preview = request.sql[:100] + "..." if len(request.sql) > 100 else request.sql
    await ctx.debug(
        "SQL query details: sql_preview=%r, sql_length=%s, has_template_params=%s"
        % (
            sql_preview,
            len(request.sql),
            bool(request.template_params),
        )
    )

    logger.info("Executing SQL query on database ID: %s", request.database_id)

    try:
        # Import inside function to avoid initialization issues
        from superset import db, security_manager
        from superset.models.core import Database

        # 1. Get database and check access
        with event_logger.log_context(action="mcp.execute_sql.db_validation"):
            database = (
                db.session.query(Database).filter_by(id=request.database_id).first()
            )
            if not database:
                await ctx.error(
                    "Database not found: database_id=%s" % request.database_id
                )
                return ExecuteSqlResponse(
                    success=False,
                    error=f"Database with ID {request.database_id} not found",
                    error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR.value,
                )

            if not security_manager.can_access_database(database):
                await ctx.error(
                    "Access denied to database: %s" % database.database_name
                )
                return ExecuteSqlResponse(
                    success=False,
                    error=f"Access denied to database {database.database_name}",
                    error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR.value,
                )

        # 2. Build QueryOptions and execute query
        cache_opts = CacheOptions(force_refresh=True) if request.force_refresh else None
        options = QueryOptions(
            catalog=request.catalog,
            schema=request.schema_name,
            limit=request.limit,
            timeout_seconds=request.timeout,
            template_params=request.template_params,
            dry_run=request.dry_run,
            cache=cache_opts,
        )

        # 3. Execute query
        with event_logger.log_context(action="mcp.execute_sql.query_execution"):
            result = database.execute(request.sql, options)

        # 4. Convert to MCP response format
        with event_logger.log_context(action="mcp.execute_sql.response_conversion"):
            response = _convert_to_response(result)

        if response.success:
            await ctx.info(
                "SQL execution completed successfully: rows_returned=%s, "
                "execution_time=%s"
                % (
                    response.row_count,
                    response.execution_time,
                )
            )
        else:
            # Report failed queries at error level, consistent with every
            # other failure path in this tool (missing database, access
            # denied, unexpected exception).
            await ctx.error(
                "SQL execution failed: error=%s, error_type=%s"
                % (response.error, response.error_type)
            )

        return response

    except Exception as e:
        await ctx.error(
            "SQL execution failed: error=%s, database_id=%s"
            % (
                str(e),
                request.database_id,
            )
        )
        raise
|
|
|
|
|
|
def _sanitize_row_values(rows: list[dict[str, Any]]) -> None:
|
|
"""Sanitize non-serializable values in rows for JSON serialization."""
|
|
for row in rows:
|
|
for key, value in row.items():
|
|
if isinstance(value, (bytes, memoryview)):
|
|
raw = bytes(value) if isinstance(value, memoryview) else value
|
|
try:
|
|
row[key] = raw.decode("utf-8")
|
|
except (UnicodeDecodeError, AttributeError):
|
|
row[key] = raw.hex()
|
|
elif isinstance(value, Decimal):
|
|
row[key] = float(value)
|
|
elif not isinstance(value, (str, int, float, bool, type(None), list, dict)):
|
|
row[key] = str(value)
|
|
|
|
|
|
def _data_to_statement_data(data: Any) -> StatementData:
    """Convert statement data (DataFrame, list, dict, bytes) to StatementData.

    When results come from cache, data may be a dict/list/bytes instead of
    a pandas DataFrame. This function handles all cases defensively:

    * ``DataFrame``: rows come from ``to_dict(orient="records")`` and column
      types from the frame's dtypes.
    * ``list``: used directly as the row list.
    * ``dict``: its ``"data"`` entry is used as the rows when present.
      Otherwise the dict itself is treated as a single row.
    * ``bytes``: parsed as JSON. Undecodable payloads yield no rows.
    * anything else: a single ``{"value": str(data)}`` row.

    In the non-DataFrame cases, every row that is not a dict (e.g. a bare
    scalar or a positional tuple) is wrapped as ``{"value": row}``. Columns
    are the union of all row keys, in first-seen order, with type
    ``"object"``.

    :param data: Raw statement data as produced by query execution or cache.
    :returns: A ``StatementData`` with sanitized rows and column metadata.
    """
    from superset.utils import json as json_utils

    if isinstance(data, pd.DataFrame):
        frame_rows = data.to_dict(orient="records")
        _sanitize_row_values(frame_rows)
        return StatementData(
            rows=frame_rows,
            # Iterate the dtypes Series instead of indexing data[col]. With
            # duplicate column labels (e.g. two joined "id" columns),
            # data[col] is a DataFrame, which has no ``.dtype`` attribute.
            columns=[
                ColumnInfo(name=col, type=str(dtype))
                for col, dtype in data.dtypes.items()
            ],
        )

    raw_rows: Any
    if isinstance(data, list):
        raw_rows = data
    elif isinstance(data, dict):
        raw_rows = data.get("data", [data])
        if not isinstance(raw_rows, list):
            raw_rows = [raw_rows]
    elif isinstance(data, bytes):
        try:
            decoded = json_utils.loads(data)
        except (ValueError, UnicodeDecodeError):
            decoded = []
        raw_rows = decoded if isinstance(decoded, list) else [decoded]
    else:
        raw_rows = [{"value": str(data)}]

    # Cached payloads are not guaranteed to be lists of dicts. Wrap anything
    # else so that row sanitization and key extraction below cannot fail.
    rows_data: list[dict[str, Any]] = [
        row if isinstance(row, dict) else {"value": row} for row in raw_rows
    ]
    _sanitize_row_values(rows_data)

    # Union of keys across all rows (first-seen order), so rows with
    # differing keys do not silently drop columns.
    col_names = list(dict.fromkeys(key for row in rows_data for key in row))
    return StatementData(
        rows=rows_data,
        columns=[ColumnInfo(name=col, type="object") for col in col_names],
    )
|
|
|
|
|
|
def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
    """Map a ``QueryResult`` onto the MCP-facing ``ExecuteSqlResponse``.

    A result whose status is not ``QueryStatus.SUCCESS`` becomes an error
    response with ``error_type`` set to the status value.

    Otherwise every executed statement is listed in ``statements``, with its
    row data attached when the statement produced any. For backward
    compatibility, the top-level ``rows``/``columns``/``row_count`` mirror
    the last statement that produced data. When no statement produced data,
    ``affected_rows`` is taken from the last statement's ``row_count``
    instead. A warning is attached when more than one statement produced
    data.
    """
    if result.status != QueryStatus.SUCCESS:
        return ExecuteSqlResponse(
            success=False,
            error=result.error_message,
            error_type=result.status.value,
        )

    statements: list[StatementInfo] = []
    payloads: list[StatementData] = []
    for executed in result.statements:
        payload: StatementData | None = None
        if executed.data is not None:
            # Data-bearing statement (e.g. SELECT): attach its rows.
            payload = _data_to_statement_data(executed.data)
            payloads.append(payload)
        statements.append(
            StatementInfo(
                original_sql=executed.original_sql,
                executed_sql=executed.executed_sql,
                row_count=executed.row_count,
                execution_time_ms=executed.execution_time_ms,
                data=payload,
            )
        )

    rows: list[dict[str, Any]] | None = None
    columns: list[ColumnInfo] | None = None
    row_count: int | None = None
    affected_rows: int | None = None

    if payloads:
        newest = payloads[-1]
        rows = newest.rows
        columns = newest.columns
        row_count = len(newest.rows)
    elif result.statements:
        # No statement returned rows (DML only): report the last statement's
        # row count as the number of affected rows.
        affected_rows = result.statements[-1].row_count

    # Tell the LLM that the top-level fields only show the final result set,
    # and that it must inspect ``statements`` to see the others.
    multi_statement_warning: str | None = None
    if len(payloads) > 1:
        multi_statement_warning = (
            f"This query contained {len(payloads)} "
            "data-bearing statements. "
            "The top-level rows/columns contain only the "
            "last data-bearing statement's results. "
            "Check the 'data' field in each entry of the "
            "'statements' array to see results from ALL "
            "statements."
        )

    total_ms = result.total_execution_time_ms
    return ExecuteSqlResponse(
        success=True,
        rows=rows,
        columns=columns,
        row_count=row_count,
        affected_rows=affected_rows,
        execution_time=None if total_ms is None else total_ms / 1000,
        statements=statements,
        multi_statement_warning=multi_statement_warning,
    )
|