mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix(mcp): return all statement results for multi-statement SQL queries (#38388)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -84,6 +84,15 @@ class ColumnInfo(BaseModel):
|
||||
is_nullable: bool | None = Field(None, description="Whether column allows NULL")
|
||||
|
||||
|
||||
# Container for the tabular output of one statement: the rows plus the
# metadata describing their columns.
# NOTE(review): the class docstring and Field descriptions are presumably
# surfaced in the generated schema -- confirm before rewording them.
class StatementData(BaseModel):
    """Row data and column metadata for a single SQL statement."""

    # One dictionary per result row. The Ellipsis default marks the field
    # as required.
    rows: list[dict[str, Any]] = Field(
        ..., description="Result rows as list of dictionaries"
    )
    # One ColumnInfo entry per result column. Also required.
    columns: list[ColumnInfo] = Field(..., description="Column metadata information")
|
||||
|
||||
|
||||
class StatementInfo(BaseModel):
|
||||
"""Information about a single SQL statement execution."""
|
||||
|
||||
@@ -95,6 +104,14 @@ class StatementInfo(BaseModel):
|
||||
execution_time_ms: float | None = Field(
|
||||
None, description="Statement execution time in milliseconds"
|
||||
)
|
||||
data: StatementData | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Row data and column metadata for this statement. "
|
||||
"Present for data-bearing statements (e.g., SELECT), "
|
||||
"absent for DML/DDL statements (e.g., SET, UPDATE)."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ExecuteSqlResponse(BaseModel):
|
||||
@@ -119,6 +136,15 @@ class ExecuteSqlResponse(BaseModel):
|
||||
statements: list[StatementInfo] | None = Field(
|
||||
None, description="Per-statement execution info (for multi-statement queries)"
|
||||
)
|
||||
multi_statement_warning: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Warning when multiple data-bearing statements were executed. "
|
||||
"The top-level rows/columns contain only the last "
|
||||
"data-bearing statement's results. "
|
||||
"Check each entry in the statements array for per-statement data."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class OpenSqlLabRequest(BaseModel):
|
||||
|
||||
@@ -43,6 +43,7 @@ from superset.mcp_service.sql_lab.schemas import (
|
||||
ColumnInfo,
|
||||
ExecuteSqlRequest,
|
||||
ExecuteSqlResponse,
|
||||
StatementData,
|
||||
StatementInfo,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
@@ -162,56 +163,69 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
|
||||
error_type=result.status.value,
|
||||
)
|
||||
|
||||
# Build statement info list
|
||||
statements = [
|
||||
StatementInfo(
|
||||
original_sql=stmt.original_sql,
|
||||
executed_sql=stmt.executed_sql,
|
||||
row_count=stmt.row_count,
|
||||
execution_time_ms=stmt.execution_time_ms,
|
||||
)
|
||||
for stmt in result.statements
|
||||
]
|
||||
# Build statement info list, including per-statement row data
|
||||
# for data-bearing statements (e.g., SELECT).
|
||||
statements: list[StatementInfo] = []
|
||||
data_bearing_count = 0
|
||||
|
||||
# Find the last statement with data (SELECT results).
|
||||
# For single statements this is the same as first.
|
||||
# For multi-statement queries (e.g., SET ...; SELECT ...) this skips
|
||||
# non-data statements and returns the actual query results.
|
||||
for stmt in result.statements:
|
||||
stmt_data: StatementData | None = None
|
||||
if stmt.data is not None:
|
||||
df = stmt.data
|
||||
stmt_data = StatementData(
|
||||
rows=df.to_dict(orient="records"),
|
||||
columns=[
|
||||
ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns
|
||||
],
|
||||
)
|
||||
data_bearing_count += 1
|
||||
|
||||
statements.append(
|
||||
StatementInfo(
|
||||
original_sql=stmt.original_sql,
|
||||
executed_sql=stmt.executed_sql,
|
||||
row_count=stmt.row_count,
|
||||
execution_time_ms=stmt.execution_time_ms,
|
||||
data=stmt_data,
|
||||
)
|
||||
)
|
||||
|
||||
# Top-level rows/columns come from the last data-bearing statement
|
||||
# for backward compatibility.
|
||||
rows: list[dict[str, Any]] | None = None
|
||||
columns: list[ColumnInfo] | None = None
|
||||
row_count: int | None = None
|
||||
affected_rows: int | None = None
|
||||
|
||||
data_stmt = None
|
||||
for stmt in reversed(result.statements):
|
||||
last_data_stmt = None
|
||||
for stmt in reversed(statements):
|
||||
if stmt.data is not None:
|
||||
data_stmt = stmt
|
||||
last_data_stmt = stmt
|
||||
break
|
||||
|
||||
if data_stmt is not None and data_stmt.data is not None:
|
||||
# SELECT query - convert DataFrame
|
||||
import pandas as pd
|
||||
|
||||
df = data_stmt.data
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
logger.error(
|
||||
"Expected DataFrame but got %s for statement data",
|
||||
type(df).__name__,
|
||||
)
|
||||
return ExecuteSqlResponse(
|
||||
success=False,
|
||||
error=f"Internal error: unexpected data type ({type(df).__name__})",
|
||||
error_type="data_conversion_error",
|
||||
statements=statements,
|
||||
)
|
||||
rows = df.to_dict(orient="records")
|
||||
columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
|
||||
row_count = len(df)
|
||||
if last_data_stmt is not None and last_data_stmt.data is not None:
|
||||
rows = last_data_stmt.data.rows
|
||||
columns = last_data_stmt.data.columns
|
||||
row_count = len(last_data_stmt.data.rows)
|
||||
elif result.statements:
|
||||
# DML-only query
|
||||
last_stmt = result.statements[-1]
|
||||
affected_rows = last_stmt.row_count
|
||||
|
||||
# Warn when multiple data-bearing statements exist so the LLM
|
||||
# knows to inspect the statements array for all results.
|
||||
multi_statement_warning: str | None = None
|
||||
if data_bearing_count > 1:
|
||||
multi_statement_warning = (
|
||||
f"This query contained {data_bearing_count} "
|
||||
"data-bearing statements. "
|
||||
"The top-level rows/columns contain only the "
|
||||
"last data-bearing statement's results. "
|
||||
"Check the 'data' field in each entry of the "
|
||||
"'statements' array to see results from ALL "
|
||||
"statements."
|
||||
)
|
||||
|
||||
return ExecuteSqlResponse(
|
||||
success=True,
|
||||
rows=rows,
|
||||
@@ -224,4 +238,5 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
|
||||
else None
|
||||
),
|
||||
statements=statements,
|
||||
multi_statement_warning=multi_statement_warning,
|
||||
)
|
||||
|
||||
@@ -530,6 +530,21 @@ class TestExecuteSql:
|
||||
assert data["rows"] == [{"b": 2}]
|
||||
assert data["row_count"] == 1
|
||||
|
||||
# Per-statement data should be present for both statements
|
||||
assert data["statements"][0]["data"] is not None
|
||||
assert data["statements"][0]["data"]["rows"] == [{"a": 1}]
|
||||
assert len(data["statements"][0]["data"]["columns"]) == 1
|
||||
assert data["statements"][0]["data"]["columns"][0]["name"] == "a"
|
||||
|
||||
assert data["statements"][1]["data"] is not None
|
||||
assert data["statements"][1]["data"]["rows"] == [{"b": 2}]
|
||||
assert len(data["statements"][1]["data"]["columns"]) == 1
|
||||
assert data["statements"][1]["data"]["columns"][0]["name"] == "b"
|
||||
|
||||
# Warning should be present for multi-data-bearing queries
|
||||
assert data["multi_statement_warning"] is not None
|
||||
assert "2 data-bearing statements" in data["multi_statement_warning"]
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
@@ -609,6 +624,14 @@ class TestExecuteSql:
|
||||
assert "id" in column_names
|
||||
assert "amount" in column_names
|
||||
|
||||
# SET statement should have no data, SELECT should have data
|
||||
assert data["statements"][0]["data"] is None
|
||||
assert data["statements"][1]["data"] is not None
|
||||
assert len(data["statements"][1]["data"]["rows"]) == 2
|
||||
|
||||
# No warning since only one data-bearing statement
|
||||
assert data["multi_statement_warning"] is None
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
@@ -664,6 +687,94 @@ class TestExecuteSql:
|
||||
# affected_rows should come from the last statement
|
||||
assert data["affected_rows"] == 5
|
||||
|
||||
# DML statements should have no per-statement data
|
||||
assert data["statements"][0]["data"] is None
|
||||
assert data["statements"][1]["data"] is None
|
||||
|
||||
# No warning for DML-only queries
|
||||
assert data["multi_statement_warning"] is None
|
||||
|
||||
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_multi_statement_preserves_all_data(
    self, mock_db, mock_security_manager, mcp_server
) -> None:
    """Verify every data-bearing statement keeps its own rows and columns.

    Guards against a regression in which two SELECT statements sent in one
    request surfaced only the final statement's rows at the top level, while
    the first statement's rows were dropped from the response entirely.
    """
    count_sql = "SELECT COUNT(*) AS order_count FROM orders"
    revenue_sql = "SELECT SUM(revenue) AS total_revenue FROM orders"
    count_rows = [{"order_count": 42}]
    revenue_rows = [{"total_revenue": 12345.67}]

    # Database double whose execute() yields two successful SELECT results.
    fake_database = _mock_database()
    fake_database.execute.return_value = QueryResult(
        status=QueryStatus.SUCCESS,
        statements=[
            StatementResult(
                original_sql=count_sql,
                executed_sql=count_sql,
                data=pd.DataFrame(count_rows),
                row_count=1,
                execution_time_ms=5.0,
            ),
            StatementResult(
                original_sql=revenue_sql,
                executed_sql=revenue_sql,
                data=pd.DataFrame(revenue_rows),
                row_count=1,
                execution_time_ms=7.0,
            ),
        ],
        query_id=None,
        total_execution_time_ms=12.0,
        is_cached=False,
    )

    # Route the database lookup to the double and grant access to it.
    first_lookup = mock_db.session.query.return_value.filter_by.return_value.first
    first_lookup.return_value = fake_database
    mock_security_manager.can_access_database.return_value = True

    payload = {
        "database_id": 1,
        "sql": f"{count_sql}; {revenue_sql}",
        "limit": 100,
    }

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool("execute_sql", {"request": payload})

    body = tool_result.structured_content
    assert body["success"] is True

    # The top-level rows/columns mirror the LAST data-bearing statement.
    assert body["rows"] == revenue_rows
    assert body["row_count"] == 1

    # Each statement must expose its own data, in submission order.
    per_statement = body["statements"]
    assert len(per_statement) == 2

    expected = (
        ("order_count", count_rows),
        ("total_revenue", revenue_rows),
    )
    for entry, (column_name, expected_rows) in zip(per_statement, expected):
        stmt_data = entry["data"]
        assert stmt_data is not None
        assert stmt_data["rows"] == expected_rows
        assert len(stmt_data["columns"]) == 1
        assert stmt_data["columns"][0]["name"] == column_name

    # The warning must point the caller at the statements array.
    warning = body["multi_statement_warning"]
    assert warning is not None
    assert "2 data-bearing statements" in warning
    assert "statements" in warning
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_empty_query_validation(self, mcp_server):
|
||||
"""Test validation of empty SQL query."""
|
||||
|
||||
Reference in New Issue
Block a user