diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py b/superset/mcp_service/sql_lab/tool/execute_sql.py
index fee996c97c4..d10f700f52e 100644
--- a/superset/mcp_service/sql_lab/tool/execute_sql.py
+++ b/superset/mcp_service/sql_lab/tool/execute_sql.py
@@ -164,22 +164,31 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
         for stmt in result.statements
     ]
 
-    # Get first statement's data for backward compatibility
-    first_stmt = result.statements[0] if result.statements else None
+    # Find the last statement with data (SELECT results).
+    # For single statements this is the same as first.
+    # For multi-statement queries (e.g., SET ...; SELECT ...) this skips
+    # non-data statements and returns the actual query results.
     rows: list[dict[str, Any]] | None = None
     columns: list[ColumnInfo] | None = None
     row_count: int | None = None
     affected_rows: int | None = None
 
-    if first_stmt and first_stmt.data is not None:
+    data_stmt = None
+    for stmt in reversed(result.statements):
+        if stmt.data is not None:
+            data_stmt = stmt
+            break
+
+    if data_stmt is not None and data_stmt.data is not None:
         # SELECT query - convert DataFrame
-        df = first_stmt.data
+        df = data_stmt.data
         rows = df.to_dict(orient="records")
         columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
         row_count = len(df)
-    elif first_stmt:
-        # DML query
-        affected_rows = first_stmt.row_count
+    elif result.statements:
+        # DML-only query
+        last_stmt = result.statements[-1]
+        affected_rows = last_stmt.row_count
 
     return ExecuteSqlResponse(
         success=True,
diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
index 86b9af9188c..77f7f18d82a 100644
--- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
+++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
@@ -526,10 +526,144 @@ class TestExecuteSql:
             assert data["statements"][0]["original_sql"] == "SELECT 1 as a"
             assert data["statements"][1]["original_sql"] == "SELECT 2 as b"
 
-            # rows/columns should be from first statement for backward compat
-            assert data["rows"] == [{"a": 1}]
+            # rows/columns should be from last data-bearing statement
+            assert data["rows"] == [{"b": 2}]
             assert data["row_count"] == 1
 
+    @patch("superset.security_manager")
+    @patch("superset.db")
+    @pytest.mark.asyncio
+    async def test_execute_sql_multi_statement_set_then_select(
+        self, mock_db, mock_security_manager, mcp_server
+    ):
+        """Test multi-statement where first stmt is SET (no data) and second is SELECT.
+
+        This covers the edge case where a SET command (e.g., SET search_path)
+        precedes the actual query. The response should contain the SELECT
+        results, not the SET's affected_rows.
+        """
+        mock_database = _mock_database()
+        mock_database.execute.return_value = QueryResult(
+            status=QueryStatus.SUCCESS,
+            statements=[
+                StatementResult(
+                    original_sql="SET search_path TO sales",
+                    executed_sql="SET search_path TO sales",
+                    data=None,
+                    row_count=0,
+                    execution_time_ms=1.0,
+                ),
+                StatementResult(
+                    original_sql=(
+                        "WITH cte AS (SELECT id, amount FROM orders) SELECT * FROM cte"
+                    ),
+                    executed_sql=(
+                        "WITH cte AS (SELECT id, amount FROM orders) SELECT * FROM cte"
+                    ),
+                    data=pd.DataFrame(
+                        [{"id": 1, "amount": 99.99}, {"id": 2, "amount": 150.00}]
+                    ),
+                    row_count=2,
+                    execution_time_ms=12.0,
+                ),
+            ],
+            query_id=None,
+            total_execution_time_ms=13.0,
+            is_cached=False,
+        )
+        mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
+            mock_database
+        )
+        mock_security_manager.can_access_database.return_value = True
+
+        request = {
+            "database_id": 1,
+            "sql": (
+                "SET search_path TO sales;"
+                " WITH cte AS (SELECT id, amount FROM orders)"
+                " SELECT * FROM cte"
+            ),
+            "limit": 100,
+        }
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool("execute_sql", {"request": request})
+
+            data = result.structured_content
+            assert data["success"] is True
+            assert data["statements"] is not None
+            assert len(data["statements"]) == 2
+
+            # The response should contain the SELECT results, not affected_rows
+            assert data["rows"] is not None
+            assert len(data["rows"]) == 2
+            assert data["rows"][0]["id"] == 1
+            assert data["rows"][0]["amount"] == 99.99
+            assert data["row_count"] == 2
+            assert data["affected_rows"] is None
+
+            # Verify columns come from the SELECT statement
+            assert data["columns"] is not None
+            assert len(data["columns"]) == 2
+            column_names = [c["name"] for c in data["columns"]]
+            assert "id" in column_names
+            assert "amount" in column_names
+
+    @patch("superset.security_manager")
+    @patch("superset.db")
+    @pytest.mark.asyncio
+    async def test_execute_sql_multi_statement_all_dml(
+        self, mock_db, mock_security_manager, mcp_server
+    ):
+        """Test multi-statement where all statements are DML (no data).
+
+        When no statement has data, the response should use affected_rows
+        from the last statement.
+        """
+        mock_database = _mock_database(allow_dml=True)
+        mock_database.execute.return_value = QueryResult(
+            status=QueryStatus.SUCCESS,
+            statements=[
+                StatementResult(
+                    original_sql="SET search_path TO sales",
+                    executed_sql="SET search_path TO sales",
+                    data=None,
+                    row_count=0,
+                    execution_time_ms=1.0,
+                ),
+                StatementResult(
+                    original_sql="UPDATE orders SET status = 'shipped'",
+                    executed_sql="UPDATE orders SET status = 'shipped'",
+                    data=None,
+                    row_count=5,
+                    execution_time_ms=8.0,
+                ),
+            ],
+            query_id=None,
+            total_execution_time_ms=9.0,
+            is_cached=False,
+        )
+        mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
+            mock_database
+        )
+        mock_security_manager.can_access_database.return_value = True
+
+        request = {
+            "database_id": 1,
+            "sql": "SET search_path TO sales; UPDATE orders SET status = 'shipped'",
+            "limit": 100,
+        }
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool("execute_sql", {"request": request})
+
+            data = result.structured_content
+            assert data["success"] is True
+            assert data["rows"] is None
+            assert data["row_count"] is None
+            # affected_rows should come from the last statement
+            assert data["affected_rows"] == 5
+
     @pytest.mark.asyncio
     async def test_execute_sql_empty_query_validation(self, mcp_server):
         """Test validation of empty SQL query."""