refactor: Migrates the MCP execute_sql tool to use the SQL execution API (#36739)

Co-authored-by: codeant-ai-for-open-source[bot] <244253245+codeant-ai-for-open-source[bot]@users.noreply.github.com>
This commit is contained in:
Michael S. Molina
2025-12-22 09:48:28 -03:00
committed by GitHub
parent c0bcf28947
commit 6b25d0663e
8 changed files with 492 additions and 838 deletions

View File

@@ -18,19 +18,26 @@
"""
Execute SQL MCP Tool
Tool for executing SQL queries against databases with security validation
and timeout protection.
Tool for executing SQL queries against databases using the unified
Database.execute() API with RLS, template rendering, and security validation.
"""
from __future__ import annotations
import logging
from typing import Any
from fastmcp import Context
from superset_core.api.types import CacheOptions, QueryOptions, QueryResult, QueryStatus
from superset_core.mcp import tool
from superset.mcp_service.sql_lab.execute_sql_core import ExecuteSqlCore
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException, SupersetSecurityException
from superset.mcp_service.sql_lab.schemas import (
ColumnInfo,
ExecuteSqlRequest,
ExecuteSqlResponse,
StatementInfo,
)
from superset.mcp_service.utils.schema_utils import parse_request
@@ -40,10 +47,7 @@ logger = logging.getLogger(__name__)
# NOTE(review): this span is a unified-diff rendering with the +/- markers
# stripped — old-side and new-side lines are interleaved and hunk headers
# remain. It is NOT runnable source as-is; comments below mark which lines
# belong to which side of the change.
@tool(tags=["mutate"])
@parse_request(ExecuteSqlRequest)
async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse:
# Old docstring (pre-change side of the diff):
"""Execute SQL query against database.
Returns query results with security validation and timeout protection.
"""
# New docstring (post-change side of the diff):
"""Execute SQL query against database using the unified Database.execute() API."""
await ctx.info(
"Starting SQL execution: database_id=%s, timeout=%s, limit=%s, schema=%s"
% (request.database_id, request.timeout, request.limit, request.schema_name)
@@ -52,36 +56,78 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
# Log SQL query details (truncated for security)
sql_preview = request.sql[:100] + "..." if len(request.sql) > 100 else request.sql
await ctx.debug(
# Old format string referenced request.parameters; the new one (next line)
# references request.template_params instead.
"SQL query details: sql_preview=%r, sql_length=%s, has_parameters=%s"
"SQL query details: sql_preview=%r, sql_length=%s, has_template_params=%s"
% (
sql_preview,
len(request.sql),
# Old argument (pre-change) followed by its new-side replacement:
bool(request.parameters),
bool(request.template_params),
)
)
logger.info("Executing SQL query on database ID: %s", request.database_id)
try:
# Removed path (old side of the diff): delegate everything to ExecuteSqlCore.
# Use the ExecuteSqlCore to handle all the logic
sql_tool = ExecuteSqlCore(use_command_mode=False, logger=logger)
result = sql_tool.run_tool(request)
# Added path (new side of the diff): call Database.execute() directly.
# Import inside function to avoid initialization issues
from superset import db, security_manager
from superset.models.core import Database
# 1. Get database and check access
database = db.session.query(Database).filter_by(id=request.database_id).first()
if not database:
raise SupersetErrorException(
SupersetError(
message=f"Database with ID {request.database_id} not found",
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
level=ErrorLevel.ERROR,
)
)
if not security_manager.can_access_database(database):
raise SupersetSecurityException(
SupersetError(
message=f"Access denied to database {database.database_name}",
error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR,
level=ErrorLevel.ERROR,
)
)
# 2. Build QueryOptions
# Caching is enabled by default to reduce database load.
# force_refresh bypasses cache when user explicitly requests fresh data.
# NOTE(review): cache=None presumably lets the execute() API apply its
# default caching policy — confirm against QueryOptions semantics.
cache_opts = CacheOptions(force_refresh=True) if request.force_refresh else None
options = QueryOptions(
catalog=request.catalog,
schema=request.schema_name,
limit=request.limit,
timeout_seconds=request.timeout,
template_params=request.template_params,
dry_run=request.dry_run,
cache=cache_opts,
)
# 3. Execute query
result = database.execute(request.sql, options)
# 4. Convert to MCP response format
response = _convert_to_response(result)
# Log successful execution
# Old logging branch (pre-change): inspected raw result attributes.
if hasattr(result, "data") and result.data:
row_count = len(result.data) if isinstance(result.data, list) else 1
# New logging branch: driven by the converted ExecuteSqlResponse.
if response.success:
await ctx.info(
"SQL execution completed successfully: rows_returned=%s, "
# Old vs. new trailing fragment of the success log format string:
"query_duration_ms=%s"
"execution_time=%s"
% (
# Old log arguments (pre-change) followed by new-side replacements:
row_count,
getattr(result, "query_duration_ms", None),
response.row_count,
response.execution_time,
)
)
else:
# Old else-branch message, replaced by the failure log below.
await ctx.info("SQL execution completed: status=no_data_returned")
await ctx.info(
"SQL execution failed: error=%s, error_type=%s"
% (response.error, response.error_type)
)
# Old return (removed) followed by the new return of the converted response:
return result
return response
except Exception as e:
await ctx.error(
@@ -92,3 +138,55 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
)
)
# Re-raise so the MCP framework reports the failure to the caller.
raise
def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
    """Translate a Database.execute() QueryResult into an MCP ExecuteSqlResponse.

    On failure, returns success=False with the engine's error message and the
    status value as error_type. On success, exposes per-statement metadata and
    — for backward compatibility — the FIRST statement's rows/columns (SELECT)
    or affected row count (DML).
    """
    # Failure path: surface the error; no rows/statements are reported.
    if result.status != QueryStatus.SUCCESS:
        return ExecuteSqlResponse(
            success=False,
            error=result.error_message,
            error_type=result.status.value,
        )

    # Per-statement execution metadata, one entry per executed statement.
    per_statement = []
    for stmt in result.statements:
        per_statement.append(
            StatementInfo(
                original_sql=stmt.original_sql,
                executed_sql=stmt.executed_sql,
                row_count=stmt.row_count,
                execution_time_ms=stmt.execution_time_ms,
            )
        )

    rows: list[dict[str, Any]] | None = None
    columns: list[ColumnInfo] | None = None
    row_count: int | None = None
    affected_rows: int | None = None

    # Only the first statement's result is surfaced at the top level
    # (backward compatibility with the previous response shape).
    head = result.statements[0] if result.statements else None
    if head is not None and head.data is not None:
        # SELECT-style statement: data is a pandas DataFrame.
        frame = head.data
        rows = frame.to_dict(orient="records")
        # Column "type" is the pandas dtype string, not a SQL type name.
        columns = [ColumnInfo(name=col, type=str(frame[col].dtype)) for col in frame.columns]
        row_count = len(frame)
    elif head is not None:
        # DML-style statement: no result set, only an affected-row count.
        affected_rows = head.row_count

    total_ms = result.total_execution_time_ms
    return ExecuteSqlResponse(
        success=True,
        rows=rows,
        columns=columns,
        row_count=row_count,
        affected_rows=affected_rows,
        # Convert milliseconds to seconds, preserving None.
        execution_time=None if total_ms is None else total_ms / 1000,
        statements=per_statement,
    )