Mirror of https://github.com/apache/superset.git (synced 2026-04-18 07:35:09 +00:00)
fix(mcp): return all statement results for multi-statement SQL queries (#38388)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -43,6 +43,7 @@ from superset.mcp_service.sql_lab.schemas import (
|
||||
ColumnInfo,
|
||||
ExecuteSqlRequest,
|
||||
ExecuteSqlResponse,
|
||||
StatementData,
|
||||
StatementInfo,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
@@ -162,56 +163,69 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
|
||||
error_type=result.status.value,
|
||||
)
|
||||
|
||||
# Build statement info list
|
||||
statements = [
|
||||
StatementInfo(
|
||||
original_sql=stmt.original_sql,
|
||||
executed_sql=stmt.executed_sql,
|
||||
row_count=stmt.row_count,
|
||||
execution_time_ms=stmt.execution_time_ms,
|
||||
)
|
||||
for stmt in result.statements
|
||||
]
|
||||
# Build statement info list, including per-statement row data
|
||||
# for data-bearing statements (e.g., SELECT).
|
||||
statements: list[StatementInfo] = []
|
||||
data_bearing_count = 0
|
||||
|
||||
# Find the last statement with data (SELECT results).
|
||||
# For single statements this is the same as first.
|
||||
# For multi-statement queries (e.g., SET ...; SELECT ...) this skips
|
||||
# non-data statements and returns the actual query results.
|
||||
for stmt in result.statements:
|
||||
stmt_data: StatementData | None = None
|
||||
if stmt.data is not None:
|
||||
df = stmt.data
|
||||
stmt_data = StatementData(
|
||||
rows=df.to_dict(orient="records"),
|
||||
columns=[
|
||||
ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns
|
||||
],
|
||||
)
|
||||
data_bearing_count += 1
|
||||
|
||||
statements.append(
|
||||
StatementInfo(
|
||||
original_sql=stmt.original_sql,
|
||||
executed_sql=stmt.executed_sql,
|
||||
row_count=stmt.row_count,
|
||||
execution_time_ms=stmt.execution_time_ms,
|
||||
data=stmt_data,
|
||||
)
|
||||
)
|
||||
|
||||
# Top-level rows/columns come from the last data-bearing statement
|
||||
# for backward compatibility.
|
||||
rows: list[dict[str, Any]] | None = None
|
||||
columns: list[ColumnInfo] | None = None
|
||||
row_count: int | None = None
|
||||
affected_rows: int | None = None
|
||||
|
||||
data_stmt = None
|
||||
for stmt in reversed(result.statements):
|
||||
last_data_stmt = None
|
||||
for stmt in reversed(statements):
|
||||
if stmt.data is not None:
|
||||
data_stmt = stmt
|
||||
last_data_stmt = stmt
|
||||
break
|
||||
|
||||
if data_stmt is not None and data_stmt.data is not None:
|
||||
# SELECT query - convert DataFrame
|
||||
import pandas as pd
|
||||
|
||||
df = data_stmt.data
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
logger.error(
|
||||
"Expected DataFrame but got %s for statement data",
|
||||
type(df).__name__,
|
||||
)
|
||||
return ExecuteSqlResponse(
|
||||
success=False,
|
||||
error=f"Internal error: unexpected data type ({type(df).__name__})",
|
||||
error_type="data_conversion_error",
|
||||
statements=statements,
|
||||
)
|
||||
rows = df.to_dict(orient="records")
|
||||
columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
|
||||
row_count = len(df)
|
||||
if last_data_stmt is not None and last_data_stmt.data is not None:
|
||||
rows = last_data_stmt.data.rows
|
||||
columns = last_data_stmt.data.columns
|
||||
row_count = len(last_data_stmt.data.rows)
|
||||
elif result.statements:
|
||||
# DML-only query
|
||||
last_stmt = result.statements[-1]
|
||||
affected_rows = last_stmt.row_count
|
||||
|
||||
# Warn when multiple data-bearing statements exist so the LLM
|
||||
# knows to inspect the statements array for all results.
|
||||
multi_statement_warning: str | None = None
|
||||
if data_bearing_count > 1:
|
||||
multi_statement_warning = (
|
||||
f"This query contained {data_bearing_count} "
|
||||
"data-bearing statements. "
|
||||
"The top-level rows/columns contain only the "
|
||||
"last data-bearing statement's results. "
|
||||
"Check the 'data' field in each entry of the "
|
||||
"'statements' array to see results from ALL "
|
||||
"statements."
|
||||
)
|
||||
|
||||
return ExecuteSqlResponse(
|
||||
success=True,
|
||||
rows=rows,
|
||||
@@ -224,4 +238,5 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
|
||||
else None
|
||||
),
|
||||
statements=statements,
|
||||
multi_statement_warning=multi_statement_warning,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user