fix(mcp): use last data-bearing statement in execute_sql response (#37968)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-02-17 07:13:55 -05:00
committed by GitHub
parent f7218e7a19
commit aec1f6edce
2 changed files with 152 additions and 9 deletions

View File

@@ -164,22 +164,31 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
for stmt in result.statements
]
# Get first statement's data for backward compatibility
first_stmt = result.statements[0] if result.statements else None
# Find the last statement with data (SELECT results).
# For single statements this is the same as first.
# For multi-statement queries (e.g., SET ...; SELECT ...) this skips
# non-data statements and returns the actual query results.
rows: list[dict[str, Any]] | None = None
columns: list[ColumnInfo] | None = None
row_count: int | None = None
affected_rows: int | None = None
if first_stmt and first_stmt.data is not None:
data_stmt = None
for stmt in reversed(result.statements):
if stmt.data is not None:
data_stmt = stmt
break
if data_stmt is not None and data_stmt.data is not None:
# SELECT query - convert DataFrame
df = first_stmt.data
df = data_stmt.data
rows = df.to_dict(orient="records")
columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
row_count = len(df)
elif first_stmt:
# DML query
affected_rows = first_stmt.row_count
elif result.statements:
# DML-only query
last_stmt = result.statements[-1]
affected_rows = last_stmt.row_count
return ExecuteSqlResponse(
success=True,

View File

@@ -526,10 +526,144 @@ class TestExecuteSql:
assert data["statements"][0]["original_sql"] == "SELECT 1 as a"
assert data["statements"][1]["original_sql"] == "SELECT 2 as b"
# rows/columns should be from first statement for backward compat
assert data["rows"] == [{"a": 1}]
# rows/columns should be from last data-bearing statement
assert data["rows"] == [{"b": 2}]
assert data["row_count"] == 1
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_multi_statement_set_then_select(
    self, mock_db, mock_security_manager, mcp_server
):
    """Return the SELECT's rows when a data-less SET statement comes first.

    The mocked database yields two statement results: a SET whose ``data``
    is ``None``, followed by a CTE SELECT carrying a two-row DataFrame.
    The response is expected to expose the SELECT's rows, columns and row
    count, and to leave ``affected_rows`` unset.
    """
    set_sql = "SET search_path TO sales"
    select_sql = "WITH cte AS (SELECT id, amount FROM orders) SELECT * FROM cte"

    # First statement: session setup only, produces no result set.
    set_result = StatementResult(
        original_sql=set_sql,
        executed_sql=set_sql,
        data=None,
        row_count=0,
        execution_time_ms=1.0,
    )
    # Second statement: the query whose rows the response should surface.
    select_frame = pd.DataFrame(
        [{"id": 1, "amount": 99.99}, {"id": 2, "amount": 150.00}]
    )
    select_result = StatementResult(
        original_sql=select_sql,
        executed_sql=select_sql,
        data=select_frame,
        row_count=2,
        execution_time_ms=12.0,
    )

    database = _mock_database()
    database.execute.return_value = QueryResult(
        status=QueryStatus.SUCCESS,
        statements=[set_result, select_result],
        query_id=None,
        total_execution_time_ms=13.0,
        is_cached=False,
    )
    # The tool looks the database up through the session; hand back the mock.
    lookup = mock_db.session.query.return_value.filter_by.return_value
    lookup.first.return_value = database
    mock_security_manager.can_access_database.return_value = True

    payload = {
        "database_id": 1,
        "sql": f"{set_sql}; {select_sql}",
        "limit": 100,
    }

    async with Client(mcp_server) as client:
        response = await client.call_tool("execute_sql", {"request": payload})

        data = response.structured_content
        assert data["success"] is True

        statements = data["statements"]
        assert statements is not None
        assert len(statements) == 2

        # Top-level fields must reflect the SELECT, not the SET.
        rows = data["rows"]
        assert rows is not None
        assert len(rows) == 2
        first_row = rows[0]
        assert first_row["id"] == 1
        assert first_row["amount"] == 99.99
        assert data["row_count"] == 2
        assert data["affected_rows"] is None

        # Column metadata must be taken from the SELECT's DataFrame.
        columns = data["columns"]
        assert columns is not None
        assert len(columns) == 2
        seen_names = {column["name"] for column in columns}
        assert "id" in seen_names
        assert "amount" in seen_names
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_multi_statement_all_dml(
    self, mock_db, mock_security_manager, mcp_server
):
    """Report the last statement's row count when no statement has data.

    Both mocked statement results (a SET and an UPDATE) carry
    ``data=None``. The response is expected to have no rows and no row
    count, and ``affected_rows`` equal to the final statement's
    ``row_count``.
    """
    set_sql = "SET search_path TO sales"
    update_sql = "UPDATE orders SET status = 'shipped'"

    # (sql, row_count, execution_time_ms) for each data-less statement.
    statement_specs = (
        (set_sql, 0, 1.0),
        (update_sql, 5, 8.0),
    )
    dml_results = [
        StatementResult(
            original_sql=sql,
            executed_sql=sql,
            data=None,
            row_count=count,
            execution_time_ms=elapsed,
        )
        for sql, count, elapsed in statement_specs
    ]

    database = _mock_database(allow_dml=True)
    database.execute.return_value = QueryResult(
        status=QueryStatus.SUCCESS,
        statements=dml_results,
        query_id=None,
        total_execution_time_ms=9.0,
        is_cached=False,
    )
    # The tool looks the database up through the session; hand back the mock.
    lookup = mock_db.session.query.return_value.filter_by.return_value
    lookup.first.return_value = database
    mock_security_manager.can_access_database.return_value = True

    payload = {
        "database_id": 1,
        "sql": f"{set_sql}; {update_sql}",
        "limit": 100,
    }

    async with Client(mcp_server) as client:
        response = await client.call_tool("execute_sql", {"request": payload})

        data = response.structured_content
        assert data["success"] is True
        assert data["rows"] is None
        assert data["row_count"] is None
        # With no data-bearing statement, the final statement supplies the count.
        assert data["affected_rows"] == 5
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):
"""Test validation of empty SQL query."""