From 4a153a0ec3254c5759981c9a038ad85da1b0153e Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Fri, 13 Mar 2026 16:53:52 +0100 Subject: [PATCH] fix(mcp): return all statement results for multi-statement SQL queries (#38388) Co-authored-by: Claude Opus 4.6 (cherry picked from commit b6c3b3ef46d9386f8e318458a243695da95ed3e4) --- superset/mcp_service/sql_lab/schemas.py | 26 ++++ .../mcp_service/sql_lab/tool/execute_sql.py | 87 ++++++++------ .../sql_lab/tool/test_execute_sql.py | 111 ++++++++++++++++++ 3 files changed, 188 insertions(+), 36 deletions(-) diff --git a/superset/mcp_service/sql_lab/schemas.py b/superset/mcp_service/sql_lab/schemas.py index 436e9330733..e55ca53130f 100644 --- a/superset/mcp_service/sql_lab/schemas.py +++ b/superset/mcp_service/sql_lab/schemas.py @@ -84,6 +84,15 @@ class ColumnInfo(BaseModel): is_nullable: bool | None = Field(None, description="Whether column allows NULL") +class StatementData(BaseModel): + """Row data and column metadata for a single SQL statement.""" + + rows: list[dict[str, Any]] = Field( + ..., description="Result rows as list of dictionaries" + ) + columns: list[ColumnInfo] = Field(..., description="Column metadata information") + + class StatementInfo(BaseModel): """Information about a single SQL statement execution.""" @@ -95,6 +104,14 @@ class StatementInfo(BaseModel): execution_time_ms: float | None = Field( None, description="Statement execution time in milliseconds" ) + data: StatementData | None = Field( + None, + description=( + "Row data and column metadata for this statement. " + "Present for data-bearing statements (e.g., SELECT), " + "absent for DML/DDL statements (e.g., SET, UPDATE)." + ), + ) class ExecuteSqlResponse(BaseModel): @@ -119,6 +136,15 @@ class ExecuteSqlResponse(BaseModel): statements: list[StatementInfo] | None = Field( None, description="Per-statement execution info (for multi-statement queries)" ) + multi_statement_warning: str | None = Field( + None, + description=( + "Warning when multiple data-bearing statements were executed. " + "The top-level rows/columns contain only the last " + "data-bearing statement's results. " + "Check each entry in the statements array for per-statement data." + ), + ) class OpenSqlLabRequest(BaseModel): diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py b/superset/mcp_service/sql_lab/tool/execute_sql.py index 35bfd2cc672..4093765938b 100644 --- a/superset/mcp_service/sql_lab/tool/execute_sql.py +++ b/superset/mcp_service/sql_lab/tool/execute_sql.py @@ -43,6 +43,7 @@ from superset.mcp_service.sql_lab.schemas import ( ColumnInfo, ExecuteSqlRequest, ExecuteSqlResponse, + StatementData, StatementInfo, ) from superset.mcp_service.utils.schema_utils import parse_request @@ -158,56 +159,69 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse: error_type=result.status.value, ) - # Build statement info list - statements = [ - StatementInfo( - original_sql=stmt.original_sql, - executed_sql=stmt.executed_sql, - row_count=stmt.row_count, - execution_time_ms=stmt.execution_time_ms, - ) - for stmt in result.statements - ] + # Build statement info list, including per-statement row data + # for data-bearing statements (e.g., SELECT). + statements: list[StatementInfo] = [] + data_bearing_count = 0 - # Find the last statement with data (SELECT results). - # For single statements this is the same as first. - # For multi-statement queries (e.g., SET ...; SELECT ...) this skips - # non-data statements and returns the actual query results. + for stmt in result.statements: + stmt_data: StatementData | None = None + if stmt.data is not None: + df = stmt.data + stmt_data = StatementData( + rows=df.to_dict(orient="records"), + columns=[ + ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns + ], + ) + data_bearing_count += 1 + + statements.append( + StatementInfo( + original_sql=stmt.original_sql, + executed_sql=stmt.executed_sql, + row_count=stmt.row_count, + execution_time_ms=stmt.execution_time_ms, + data=stmt_data, + ) + ) + + # Top-level rows/columns come from the last data-bearing statement + # for backward compatibility. rows: list[dict[str, Any]] | None = None columns: list[ColumnInfo] | None = None row_count: int | None = None affected_rows: int | None = None - data_stmt = None - for stmt in reversed(result.statements): + last_data_stmt = None + for stmt in reversed(statements): if stmt.data is not None: - data_stmt = stmt + last_data_stmt = stmt break - if data_stmt is not None and data_stmt.data is not None: - # SELECT query - convert DataFrame - import pandas as pd - - df = data_stmt.data - if not isinstance(df, pd.DataFrame): - logger.error( - "Expected DataFrame but got %s for statement data", - type(df).__name__, - ) - return ExecuteSqlResponse( - success=False, - error=f"Internal error: unexpected data type ({type(df).__name__})", - error_type="data_conversion_error", - statements=statements, - ) - rows = df.to_dict(orient="records") - columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns] - row_count = len(df) + if last_data_stmt is not None and last_data_stmt.data is not None: + rows = last_data_stmt.data.rows + columns = last_data_stmt.data.columns + row_count = len(last_data_stmt.data.rows) elif result.statements: # DML-only query last_stmt = result.statements[-1] affected_rows = last_stmt.row_count + # Warn when multiple data-bearing statements exist so the LLM + # knows to inspect the statements array for all results. + multi_statement_warning: str | None = None + if data_bearing_count > 1: + multi_statement_warning = ( + f"This query contained {data_bearing_count} " + "data-bearing statements. " + "The top-level rows/columns contain only the " + "last data-bearing statement's results. " + "Check the 'data' field in each entry of the " + "'statements' array to see results from ALL " + "statements." + ) + return ExecuteSqlResponse( success=True, rows=rows, @@ -220,4 +234,5 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse: else None ), statements=statements, + multi_statement_warning=multi_statement_warning, ) diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py index 3576acd3d37..764e1bf2cee 100644 --- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py +++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py @@ -530,6 +530,21 @@ class TestExecuteSql: assert data["rows"] == [{"b": 2}] assert data["row_count"] == 1 + # Per-statement data should be present for both statements + assert data["statements"][0]["data"] is not None + assert data["statements"][0]["data"]["rows"] == [{"a": 1}] + assert len(data["statements"][0]["data"]["columns"]) == 1 + assert data["statements"][0]["data"]["columns"][0]["name"] == "a" + + assert data["statements"][1]["data"] is not None + assert data["statements"][1]["data"]["rows"] == [{"b": 2}] + assert len(data["statements"][1]["data"]["columns"]) == 1 + assert data["statements"][1]["data"]["columns"][0]["name"] == "b" + + # Warning should be present for multi-data-bearing queries + assert data["multi_statement_warning"] is not None + assert "2 data-bearing statements" in data["multi_statement_warning"] + @patch("superset.security_manager") @patch("superset.db") @pytest.mark.asyncio @@ -609,6 +624,14 @@ class TestExecuteSql: assert "id" in column_names assert "amount" in column_names + # SET statement should have no data, SELECT should have data + assert data["statements"][0]["data"] is None + assert data["statements"][1]["data"] is not None + assert len(data["statements"][1]["data"]["rows"]) == 2 + + # No warning since only one data-bearing statement + assert data["multi_statement_warning"] is None + @patch("superset.security_manager") @patch("superset.db") @pytest.mark.asyncio @@ -664,6 +687,94 @@ class TestExecuteSql: # affected_rows should come from the last statement assert data["affected_rows"] == 5 + # DML statements should have no per-statement data + assert data["statements"][0]["data"] is None + assert data["statements"][1]["data"] is None + + # No warning for DML-only queries + assert data["multi_statement_warning"] is None + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_multi_statement_preserves_all_data( + self, mock_db, mock_security_manager, mcp_server + ) -> None: + """Test that multi-statement SQL returns per-statement data for ALL results. + + Regression test: previously, running two SELECT statements would only + return the last statement's rows in the top-level response and + completely lose the first statement's row data. + """ + mock_database = _mock_database() + mock_database.execute.return_value = QueryResult( + status=QueryStatus.SUCCESS, + statements=[ + StatementResult( + original_sql="SELECT COUNT(*) AS order_count FROM orders", + executed_sql="SELECT COUNT(*) AS order_count FROM orders", + data=pd.DataFrame([{"order_count": 42}]), + row_count=1, + execution_time_ms=5.0, + ), + StatementResult( + original_sql="SELECT SUM(revenue) AS total_revenue FROM orders", + executed_sql="SELECT SUM(revenue) AS total_revenue FROM orders", + data=pd.DataFrame([{"total_revenue": 12345.67}]), + row_count=1, + execution_time_ms=7.0, + ), + ], + query_id=None, + total_execution_time_ms=12.0, + is_cached=False, + ) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": ( + "SELECT COUNT(*) AS order_count FROM orders;" + " SELECT SUM(revenue) AS total_revenue FROM orders" + ), + "limit": 100, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + data = result.structured_content + assert data["success"] is True + + # Top-level rows/columns should be from the LAST data-bearing stmt + assert data["rows"] == [{"total_revenue": 12345.67}] + assert data["row_count"] == 1 + + # Both statements should have per-statement data + assert len(data["statements"]) == 2 + + # First statement's data is accessible + first_stmt = data["statements"][0] + assert first_stmt["data"] is not None + assert first_stmt["data"]["rows"] == [{"order_count": 42}] + assert len(first_stmt["data"]["columns"]) == 1 + assert first_stmt["data"]["columns"][0]["name"] == "order_count" + + # Second statement's data is accessible + second_stmt = data["statements"][1] + assert second_stmt["data"] is not None + assert second_stmt["data"]["rows"] == [{"total_revenue": 12345.67}] + assert len(second_stmt["data"]["columns"]) == 1 + assert second_stmt["data"]["columns"][0]["name"] == "total_revenue" + + # Warning should tell LLM to check statements array + assert data["multi_statement_warning"] is not None + assert "2 data-bearing statements" in data["multi_statement_warning"] + assert "statements" in data["multi_statement_warning"] + @pytest.mark.asyncio async def test_execute_sql_empty_query_validation(self, mcp_server): """Test validation of empty SQL query."""