fix(mcp): return all statement results for multi-statement SQL queries (#38388)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-03-13 16:53:52 +01:00
committed by GitHub
parent f4a57a13bc
commit b6c3b3ef46
3 changed files with 188 additions and 36 deletions

View File

@@ -84,6 +84,15 @@ class ColumnInfo(BaseModel):
is_nullable: bool | None = Field(None, description="Whether column allows NULL")
class StatementData(BaseModel):
"""Row data and column metadata for a single SQL statement."""
rows: list[dict[str, Any]] = Field(
..., description="Result rows as list of dictionaries"
)
columns: list[ColumnInfo] = Field(..., description="Column metadata information")
class StatementInfo(BaseModel):
"""Information about a single SQL statement execution."""
@@ -95,6 +104,14 @@ class StatementInfo(BaseModel):
execution_time_ms: float | None = Field(
None, description="Statement execution time in milliseconds"
)
data: StatementData | None = Field(
None,
description=(
"Row data and column metadata for this statement. "
"Present for data-bearing statements (e.g., SELECT), "
"absent for DML/DDL statements (e.g., SET, UPDATE)."
),
)
class ExecuteSqlResponse(BaseModel):
@@ -119,6 +136,15 @@ class ExecuteSqlResponse(BaseModel):
statements: list[StatementInfo] | None = Field(
None, description="Per-statement execution info (for multi-statement queries)"
)
multi_statement_warning: str | None = Field(
None,
description=(
"Warning when multiple data-bearing statements were executed. "
"The top-level rows/columns contain only the last "
"data-bearing statement's results. "
"Check each entry in the statements array for per-statement data."
),
)
class OpenSqlLabRequest(BaseModel):

View File

@@ -43,6 +43,7 @@ from superset.mcp_service.sql_lab.schemas import (
ColumnInfo,
ExecuteSqlRequest,
ExecuteSqlResponse,
StatementData,
StatementInfo,
)
from superset.mcp_service.utils.schema_utils import parse_request
@@ -162,56 +163,69 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
error_type=result.status.value,
)
# Build statement info list
statements = [
StatementInfo(
original_sql=stmt.original_sql,
executed_sql=stmt.executed_sql,
row_count=stmt.row_count,
execution_time_ms=stmt.execution_time_ms,
)
for stmt in result.statements
]
# Build statement info list, including per-statement row data
# for data-bearing statements (e.g., SELECT).
statements: list[StatementInfo] = []
data_bearing_count = 0
# Find the last statement with data (SELECT results).
# For single statements this is the same as first.
# For multi-statement queries (e.g., SET ...; SELECT ...) this skips
# non-data statements and returns the actual query results.
for stmt in result.statements:
stmt_data: StatementData | None = None
if stmt.data is not None:
df = stmt.data
stmt_data = StatementData(
rows=df.to_dict(orient="records"),
columns=[
ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns
],
)
data_bearing_count += 1
statements.append(
StatementInfo(
original_sql=stmt.original_sql,
executed_sql=stmt.executed_sql,
row_count=stmt.row_count,
execution_time_ms=stmt.execution_time_ms,
data=stmt_data,
)
)
# Top-level rows/columns come from the last data-bearing statement
# for backward compatibility.
rows: list[dict[str, Any]] | None = None
columns: list[ColumnInfo] | None = None
row_count: int | None = None
affected_rows: int | None = None
data_stmt = None
for stmt in reversed(result.statements):
last_data_stmt = None
for stmt in reversed(statements):
if stmt.data is not None:
data_stmt = stmt
last_data_stmt = stmt
break
if data_stmt is not None and data_stmt.data is not None:
# SELECT query - convert DataFrame
import pandas as pd
df = data_stmt.data
if not isinstance(df, pd.DataFrame):
logger.error(
"Expected DataFrame but got %s for statement data",
type(df).__name__,
)
return ExecuteSqlResponse(
success=False,
error=f"Internal error: unexpected data type ({type(df).__name__})",
error_type="data_conversion_error",
statements=statements,
)
rows = df.to_dict(orient="records")
columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
row_count = len(df)
if last_data_stmt is not None and last_data_stmt.data is not None:
rows = last_data_stmt.data.rows
columns = last_data_stmt.data.columns
row_count = len(last_data_stmt.data.rows)
elif result.statements:
# DML-only query
last_stmt = result.statements[-1]
affected_rows = last_stmt.row_count
# Warn when multiple data-bearing statements exist so the LLM
# knows to inspect the statements array for all results.
multi_statement_warning: str | None = None
if data_bearing_count > 1:
multi_statement_warning = (
f"This query contained {data_bearing_count} "
"data-bearing statements. "
"The top-level rows/columns contain only the "
"last data-bearing statement's results. "
"Check the 'data' field in each entry of the "
"'statements' array to see results from ALL "
"statements."
)
return ExecuteSqlResponse(
success=True,
rows=rows,
@@ -224,4 +238,5 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
else None
),
statements=statements,
multi_statement_warning=multi_statement_warning,
)

View File

@@ -530,6 +530,21 @@ class TestExecuteSql:
assert data["rows"] == [{"b": 2}]
assert data["row_count"] == 1
# Per-statement data should be present for both statements
assert data["statements"][0]["data"] is not None
assert data["statements"][0]["data"]["rows"] == [{"a": 1}]
assert len(data["statements"][0]["data"]["columns"]) == 1
assert data["statements"][0]["data"]["columns"][0]["name"] == "a"
assert data["statements"][1]["data"] is not None
assert data["statements"][1]["data"]["rows"] == [{"b": 2}]
assert len(data["statements"][1]["data"]["columns"]) == 1
assert data["statements"][1]["data"]["columns"][0]["name"] == "b"
# Warning should be present for multi-data-bearing queries
assert data["multi_statement_warning"] is not None
assert "2 data-bearing statements" in data["multi_statement_warning"]
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
@@ -609,6 +624,14 @@ class TestExecuteSql:
assert "id" in column_names
assert "amount" in column_names
# SET statement should have no data, SELECT should have data
assert data["statements"][0]["data"] is None
assert data["statements"][1]["data"] is not None
assert len(data["statements"][1]["data"]["rows"]) == 2
# No warning since only one data-bearing statement
assert data["multi_statement_warning"] is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
@@ -664,6 +687,94 @@ class TestExecuteSql:
# affected_rows should come from the last statement
assert data["affected_rows"] == 5
# DML statements should have no per-statement data
assert data["statements"][0]["data"] is None
assert data["statements"][1]["data"] is None
# No warning for DML-only queries
assert data["multi_statement_warning"] is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_multi_statement_preserves_all_data(
self, mock_db, mock_security_manager, mcp_server
) -> None:
"""Test that multi-statement SQL returns per-statement data for ALL results.
Regression test: previously, running two SELECT statements would only
return the last statement's rows in the top-level response and
completely lose the first statement's row data.
"""
mock_database = _mock_database()
mock_database.execute.return_value = QueryResult(
status=QueryStatus.SUCCESS,
statements=[
StatementResult(
original_sql="SELECT COUNT(*) AS order_count FROM orders",
executed_sql="SELECT COUNT(*) AS order_count FROM orders",
data=pd.DataFrame([{"order_count": 42}]),
row_count=1,
execution_time_ms=5.0,
),
StatementResult(
original_sql="SELECT SUM(revenue) AS total_revenue FROM orders",
executed_sql="SELECT SUM(revenue) AS total_revenue FROM orders",
data=pd.DataFrame([{"total_revenue": 12345.67}]),
row_count=1,
execution_time_ms=7.0,
),
],
query_id=None,
total_execution_time_ms=12.0,
is_cached=False,
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": (
"SELECT COUNT(*) AS order_count FROM orders;"
" SELECT SUM(revenue) AS total_revenue FROM orders"
),
"limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
data = result.structured_content
assert data["success"] is True
# Top-level rows/columns should be from the LAST data-bearing stmt
assert data["rows"] == [{"total_revenue": 12345.67}]
assert data["row_count"] == 1
# Both statements should have per-statement data
assert len(data["statements"]) == 2
# First statement's data is accessible
first_stmt = data["statements"][0]
assert first_stmt["data"] is not None
assert first_stmt["data"]["rows"] == [{"order_count": 42}]
assert len(first_stmt["data"]["columns"]) == 1
assert first_stmt["data"]["columns"][0]["name"] == "order_count"
# Second statement's data is accessible
second_stmt = data["statements"][1]
assert second_stmt["data"] is not None
assert second_stmt["data"]["rows"] == [{"total_revenue": 12345.67}]
assert len(second_stmt["data"]["columns"]) == 1
assert second_stmt["data"]["columns"][0]["name"] == "total_revenue"
# Warning should tell LLM to check statements array
assert data["multi_statement_warning"] is not None
assert "2 data-bearing statements" in data["multi_statement_warning"]
assert "statements" in data["multi_statement_warning"]
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):
"""Test validation of empty SQL query."""