mirror of
https://github.com/apache/superset.git
synced 2026-04-07 10:31:50 +00:00
fix(mcp): use last data-bearing statement in execute_sql response (#37968)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -164,22 +164,31 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
|
||||
for stmt in result.statements
|
||||
]
|
||||
|
||||
# Get first statement's data for backward compatibility
|
||||
first_stmt = result.statements[0] if result.statements else None
|
||||
# Find the last statement with data (SELECT results).
|
||||
# For a single-statement query this is the same as the first statement.
|
||||
# For multi-statement queries (e.g., SET ...; SELECT ...) this skips
|
||||
# non-data statements and returns the actual query results.
|
||||
rows: list[dict[str, Any]] | None = None
|
||||
columns: list[ColumnInfo] | None = None
|
||||
row_count: int | None = None
|
||||
affected_rows: int | None = None
|
||||
|
||||
if first_stmt and first_stmt.data is not None:
|
||||
data_stmt = None
|
||||
for stmt in reversed(result.statements):
|
||||
if stmt.data is not None:
|
||||
data_stmt = stmt
|
||||
break
|
||||
|
||||
if data_stmt is not None and data_stmt.data is not None:
|
||||
# SELECT query - convert DataFrame
|
||||
df = first_stmt.data
|
||||
df = data_stmt.data
|
||||
rows = df.to_dict(orient="records")
|
||||
columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
|
||||
row_count = len(df)
|
||||
elif first_stmt:
|
||||
# DML query
|
||||
affected_rows = first_stmt.row_count
|
||||
elif result.statements:
|
||||
# No statement returned data (e.g., DML-only query): report the last statement's row_count
|
||||
last_stmt = result.statements[-1]
|
||||
affected_rows = last_stmt.row_count
|
||||
|
||||
return ExecuteSqlResponse(
|
||||
success=True,
|
||||
|
||||
@@ -526,10 +526,144 @@ class TestExecuteSql:
|
||||
assert data["statements"][0]["original_sql"] == "SELECT 1 as a"
|
||||
assert data["statements"][1]["original_sql"] == "SELECT 2 as b"
|
||||
|
||||
# rows/columns should be from first statement for backward compat
|
||||
assert data["rows"] == [{"a": 1}]
|
||||
# rows/columns should be from last data-bearing statement
|
||||
assert data["rows"] == [{"b": 2}]
|
||||
assert data["row_count"] == 1
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_multi_statement_set_then_select(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test multi-statement where first stmt is SET (no data) and second is SELECT.
|
||||
|
||||
This covers the edge case where a SET command (e.g., SET search_path)
|
||||
precedes the actual query. The response should contain the SELECT
|
||||
results, not the SET's affected_rows.
|
||||
"""
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.return_value = QueryResult(
|
||||
status=QueryStatus.SUCCESS,
|
||||
statements=[
|
||||
StatementResult(
|
||||
original_sql="SET search_path TO sales",
|
||||
executed_sql="SET search_path TO sales",
|
||||
data=None,
|
||||
row_count=0,
|
||||
execution_time_ms=1.0,
|
||||
),
|
||||
StatementResult(
|
||||
original_sql=(
|
||||
"WITH cte AS (SELECT id, amount FROM orders) SELECT * FROM cte"
|
||||
),
|
||||
executed_sql=(
|
||||
"WITH cte AS (SELECT id, amount FROM orders) SELECT * FROM cte"
|
||||
),
|
||||
data=pd.DataFrame(
|
||||
[{"id": 1, "amount": 99.99}, {"id": 2, "amount": 150.00}]
|
||||
),
|
||||
row_count=2,
|
||||
execution_time_ms=12.0,
|
||||
),
|
||||
],
|
||||
query_id=None,
|
||||
total_execution_time_ms=13.0,
|
||||
is_cached=False,
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": (
|
||||
"SET search_path TO sales;"
|
||||
" WITH cte AS (SELECT id, amount FROM orders)"
|
||||
" SELECT * FROM cte"
|
||||
),
|
||||
"limit": 100,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
assert data["statements"] is not None
|
||||
assert len(data["statements"]) == 2
|
||||
|
||||
# The response should contain the SELECT results, not affected_rows
|
||||
assert data["rows"] is not None
|
||||
assert len(data["rows"]) == 2
|
||||
assert data["rows"][0]["id"] == 1
|
||||
assert data["rows"][0]["amount"] == 99.99
|
||||
assert data["row_count"] == 2
|
||||
assert data["affected_rows"] is None
|
||||
|
||||
# Verify columns come from the SELECT statement
|
||||
assert data["columns"] is not None
|
||||
assert len(data["columns"]) == 2
|
||||
column_names = [c["name"] for c in data["columns"]]
|
||||
assert "id" in column_names
|
||||
assert "amount" in column_names
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_multi_statement_all_dml(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test multi-statement where no statement returns data (SET then UPDATE).
|
||||
|
||||
When no statement has data, the response should use affected_rows
|
||||
from the last statement.
|
||||
"""
|
||||
mock_database = _mock_database(allow_dml=True)
|
||||
mock_database.execute.return_value = QueryResult(
|
||||
status=QueryStatus.SUCCESS,
|
||||
statements=[
|
||||
StatementResult(
|
||||
original_sql="SET search_path TO sales",
|
||||
executed_sql="SET search_path TO sales",
|
||||
data=None,
|
||||
row_count=0,
|
||||
execution_time_ms=1.0,
|
||||
),
|
||||
StatementResult(
|
||||
original_sql="UPDATE orders SET status = 'shipped'",
|
||||
executed_sql="UPDATE orders SET status = 'shipped'",
|
||||
data=None,
|
||||
row_count=5,
|
||||
execution_time_ms=8.0,
|
||||
),
|
||||
],
|
||||
query_id=None,
|
||||
total_execution_time_ms=9.0,
|
||||
is_cached=False,
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SET search_path TO sales; UPDATE orders SET status = 'shipped'",
|
||||
"limit": 100,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
assert data["rows"] is None
|
||||
assert data["row_count"] is None
|
||||
# affected_rows should come from the last statement
|
||||
assert data["affected_rows"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_empty_query_validation(self, mcp_server):
|
||||
"""Test validation of empty SQL query."""
|
||||
|
||||
Reference in New Issue
Block a user