fix(mcp): return all statement results for multi-statement SQL queries (#38388)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-03-13 16:53:52 +01:00
committed by GitHub
parent f4a57a13bc
commit b6c3b3ef46
3 changed files with 188 additions and 36 deletions

View File

@@ -43,6 +43,7 @@ from superset.mcp_service.sql_lab.schemas import (
ColumnInfo,
ExecuteSqlRequest,
ExecuteSqlResponse,
StatementData,
StatementInfo,
)
from superset.mcp_service.utils.schema_utils import parse_request
@@ -162,56 +163,69 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
error_type=result.status.value,
)
# Build statement info list
statements = [
StatementInfo(
original_sql=stmt.original_sql,
executed_sql=stmt.executed_sql,
row_count=stmt.row_count,
execution_time_ms=stmt.execution_time_ms,
)
for stmt in result.statements
]
# Build statement info list, including per-statement row data
# for data-bearing statements (e.g., SELECT).
statements: list[StatementInfo] = []
data_bearing_count = 0
# Find the last statement with data (SELECT results).
# For a single-statement query this is simply that one statement.
# For multi-statement queries (e.g., SET ...; SELECT ...) this skips
# non-data statements and returns the actual query results.
for stmt in result.statements:
stmt_data: StatementData | None = None
if stmt.data is not None:
df = stmt.data
stmt_data = StatementData(
rows=df.to_dict(orient="records"),
columns=[
ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns
],
)
data_bearing_count += 1
statements.append(
StatementInfo(
original_sql=stmt.original_sql,
executed_sql=stmt.executed_sql,
row_count=stmt.row_count,
execution_time_ms=stmt.execution_time_ms,
data=stmt_data,
)
)
# Top-level rows/columns come from the last data-bearing statement
# for backward compatibility.
rows: list[dict[str, Any]] | None = None
columns: list[ColumnInfo] | None = None
row_count: int | None = None
affected_rows: int | None = None
data_stmt = None
for stmt in reversed(result.statements):
last_data_stmt = None
for stmt in reversed(statements):
if stmt.data is not None:
data_stmt = stmt
last_data_stmt = stmt
break
if data_stmt is not None and data_stmt.data is not None:
# SELECT query - convert DataFrame
import pandas as pd
df = data_stmt.data
if not isinstance(df, pd.DataFrame):
logger.error(
"Expected DataFrame but got %s for statement data",
type(df).__name__,
)
return ExecuteSqlResponse(
success=False,
error=f"Internal error: unexpected data type ({type(df).__name__})",
error_type="data_conversion_error",
statements=statements,
)
rows = df.to_dict(orient="records")
columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
row_count = len(df)
if last_data_stmt is not None and last_data_stmt.data is not None:
rows = last_data_stmt.data.rows
columns = last_data_stmt.data.columns
row_count = len(last_data_stmt.data.rows)
elif result.statements:
# DML-only query
last_stmt = result.statements[-1]
affected_rows = last_stmt.row_count
# Warn when multiple data-bearing statements exist so the LLM
# knows to inspect the statements array for all results.
multi_statement_warning: str | None = None
if data_bearing_count > 1:
multi_statement_warning = (
f"This query contained {data_bearing_count} "
"data-bearing statements. "
"The top-level rows/columns contain only the "
"last data-bearing statement's results. "
"Check the 'data' field in each entry of the "
"'statements' array to see results from ALL "
"statements."
)
return ExecuteSqlResponse(
success=True,
rows=rows,
@@ -224,4 +238,5 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
else None
),
statements=statements,
multi_statement_warning=multi_statement_warning,
)