mirror of
https://github.com/apache/superset.git
synced 2026-04-17 23:25:05 +00:00
refactor: Migrates the MCP execute_sql tool to use the SQL execution API (#36739)
Co-authored-by: codeant-ai-for-open-source[bot] <244253245+codeant-ai-for-open-source[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c0bcf28947
commit
6b25d0663e
@@ -18,19 +18,26 @@
|
||||
"""
|
||||
Execute SQL MCP Tool
|
||||
|
||||
Tool for executing SQL queries against databases with security validation
|
||||
and timeout protection.
|
||||
Tool for executing SQL queries against databases using the unified
|
||||
Database.execute() API with RLS, template rendering, and security validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.api.types import CacheOptions, QueryOptions, QueryResult, QueryStatus
|
||||
from superset_core.mcp import tool
|
||||
|
||||
from superset.mcp_service.sql_lab.execute_sql_core import ExecuteSqlCore
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetErrorException, SupersetSecurityException
|
||||
from superset.mcp_service.sql_lab.schemas import (
|
||||
ColumnInfo,
|
||||
ExecuteSqlRequest,
|
||||
ExecuteSqlResponse,
|
||||
StatementInfo,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
@@ -40,10 +47,7 @@ logger = logging.getLogger(__name__)
|
||||
@tool(tags=["mutate"])
|
||||
@parse_request(ExecuteSqlRequest)
|
||||
async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse:
|
||||
"""Execute SQL query against database.
|
||||
|
||||
Returns query results with security validation and timeout protection.
|
||||
"""
|
||||
"""Execute SQL query against database using the unified Database.execute() API."""
|
||||
await ctx.info(
|
||||
"Starting SQL execution: database_id=%s, timeout=%s, limit=%s, schema=%s"
|
||||
% (request.database_id, request.timeout, request.limit, request.schema_name)
|
||||
@@ -52,36 +56,78 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
|
||||
# Log SQL query details (truncated for security)
|
||||
sql_preview = request.sql[:100] + "..." if len(request.sql) > 100 else request.sql
|
||||
await ctx.debug(
|
||||
"SQL query details: sql_preview=%r, sql_length=%s, has_parameters=%s"
|
||||
"SQL query details: sql_preview=%r, sql_length=%s, has_template_params=%s"
|
||||
% (
|
||||
sql_preview,
|
||||
len(request.sql),
|
||||
bool(request.parameters),
|
||||
bool(request.template_params),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Executing SQL query on database ID: %s", request.database_id)
|
||||
|
||||
try:
|
||||
# Use the ExecuteSqlCore to handle all the logic
|
||||
sql_tool = ExecuteSqlCore(use_command_mode=False, logger=logger)
|
||||
result = sql_tool.run_tool(request)
|
||||
# Import inside function to avoid initialization issues
|
||||
from superset import db, security_manager
|
||||
from superset.models.core import Database
|
||||
|
||||
# 1. Get database and check access
|
||||
database = db.session.query(Database).filter_by(id=request.database_id).first()
|
||||
if not database:
|
||||
raise SupersetErrorException(
|
||||
SupersetError(
|
||||
message=f"Database with ID {request.database_id} not found",
|
||||
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
if not security_manager.can_access_database(database):
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
message=f"Access denied to database {database.database_name}",
|
||||
error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Build QueryOptions
|
||||
# Caching is enabled by default to reduce database load.
|
||||
# force_refresh bypasses cache when user explicitly requests fresh data.
|
||||
cache_opts = CacheOptions(force_refresh=True) if request.force_refresh else None
|
||||
options = QueryOptions(
|
||||
catalog=request.catalog,
|
||||
schema=request.schema_name,
|
||||
limit=request.limit,
|
||||
timeout_seconds=request.timeout,
|
||||
template_params=request.template_params,
|
||||
dry_run=request.dry_run,
|
||||
cache=cache_opts,
|
||||
)
|
||||
|
||||
# 3. Execute query
|
||||
result = database.execute(request.sql, options)
|
||||
|
||||
# 4. Convert to MCP response format
|
||||
response = _convert_to_response(result)
|
||||
|
||||
# Log successful execution
|
||||
if hasattr(result, "data") and result.data:
|
||||
row_count = len(result.data) if isinstance(result.data, list) else 1
|
||||
if response.success:
|
||||
await ctx.info(
|
||||
"SQL execution completed successfully: rows_returned=%s, "
|
||||
"query_duration_ms=%s"
|
||||
"execution_time=%s"
|
||||
% (
|
||||
row_count,
|
||||
getattr(result, "query_duration_ms", None),
|
||||
response.row_count,
|
||||
response.execution_time,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await ctx.info("SQL execution completed: status=no_data_returned")
|
||||
await ctx.info(
|
||||
"SQL execution failed: error=%s, error_type=%s"
|
||||
% (response.error, response.error_type)
|
||||
)
|
||||
|
||||
return result
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
@@ -92,3 +138,55 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
|
||||
"""Convert QueryResult to ExecuteSqlResponse."""
|
||||
if result.status != QueryStatus.SUCCESS:
|
||||
return ExecuteSqlResponse(
|
||||
success=False,
|
||||
error=result.error_message,
|
||||
error_type=result.status.value,
|
||||
)
|
||||
|
||||
# Build statement info list
|
||||
statements = [
|
||||
StatementInfo(
|
||||
original_sql=stmt.original_sql,
|
||||
executed_sql=stmt.executed_sql,
|
||||
row_count=stmt.row_count,
|
||||
execution_time_ms=stmt.execution_time_ms,
|
||||
)
|
||||
for stmt in result.statements
|
||||
]
|
||||
|
||||
# Get first statement's data for backward compatibility
|
||||
first_stmt = result.statements[0] if result.statements else None
|
||||
rows: list[dict[str, Any]] | None = None
|
||||
columns: list[ColumnInfo] | None = None
|
||||
row_count: int | None = None
|
||||
affected_rows: int | None = None
|
||||
|
||||
if first_stmt and first_stmt.data is not None:
|
||||
# SELECT query - convert DataFrame
|
||||
df = first_stmt.data
|
||||
rows = df.to_dict(orient="records")
|
||||
columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
|
||||
row_count = len(df)
|
||||
elif first_stmt:
|
||||
# DML query
|
||||
affected_rows = first_stmt.row_count
|
||||
|
||||
return ExecuteSqlResponse(
|
||||
success=True,
|
||||
rows=rows,
|
||||
columns=columns,
|
||||
row_count=row_count,
|
||||
affected_rows=affected_rows,
|
||||
execution_time=(
|
||||
result.total_execution_time_ms / 1000
|
||||
if result.total_execution_time_ms is not None
|
||||
else None
|
||||
),
|
||||
statements=statements,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user