diff --git a/superset-core/src/superset_core/api/models.py b/superset-core/src/superset_core/api/models.py index 59cb07dc382..346e8392f16 100644 --- a/superset-core/src/superset_core/api/models.py +++ b/superset-core/src/superset_core/api/models.py @@ -92,7 +92,11 @@ class Database(CoreModel): """ Execute SQL synchronously. - :param sql: SQL query to execute + The SQL must be written in the dialect of the target database (e.g., + PostgreSQL syntax for PostgreSQL databases, Snowflake syntax for + Snowflake, etc.). No automatic cross-dialect translation is performed. + + :param sql: SQL query to execute (in the target database's dialect) :param options: Query execution options (see `QueryOptions`). If not provided, defaults are used. :returns: QueryResult with status, data (DataFrame), and metadata @@ -139,7 +143,11 @@ class Database(CoreModel): Returns immediately with a handle for tracking progress and retrieving results from the background worker. - :param sql: SQL query to execute + The SQL must be written in the dialect of the target database (e.g., + PostgreSQL syntax for PostgreSQL databases, Snowflake syntax for + Snowflake, etc.). No automatic cross-dialect translation is performed. + + :param sql: SQL query to execute (in the target database's dialect) :param options: Query execution options (see `QueryOptions`). If not provided, defaults are used. :returns: AsyncQueryHandle for tracking the query diff --git a/superset/mcp_service/sql_lab/execute_sql_core.py b/superset/mcp_service/sql_lab/execute_sql_core.py deleted file mode 100644 index 263e0123385..00000000000 --- a/superset/mcp_service/sql_lab/execute_sql_core.py +++ /dev/null @@ -1,221 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Generic SQL execution core for MCP service. -""" - -import logging -from typing import Any - -from superset.mcp_service.mcp_core import BaseCore -from superset.mcp_service.sql_lab.schemas import ( - ExecuteSqlRequest, - ExecuteSqlResponse, -) - - -class ExecuteSqlCore(BaseCore): - """ - Generic tool for executing SQL queries with security validation. - - This tool provides a high-level interface for SQL execution that can be used - by different MCP tools or other components. It handles: - - Database access validation - - SQL query validation (DML permissions, disallowed functions) - - Parameter substitution - - Query execution with timeout - - Result formatting - - The tool can work in two modes: - 1. Simple mode: Direct SQL execution using sql_lab_utils (default) - 2. Command mode: Using ExecuteSqlCommand for full SQL Lab features - """ - - def __init__( - self, - use_command_mode: bool = False, - logger: logging.Logger | None = None, - ) -> None: - super().__init__(logger) - self.use_command_mode = use_command_mode - - def run_tool(self, request: ExecuteSqlRequest) -> ExecuteSqlResponse: - """ - Execute SQL query and return results. 
- - Args: - request: ExecuteSqlRequest with database_id, sql, and optional parameters - - Returns: - ExecuteSqlResponse with success status, results, or error information - """ - try: - # Import inside method to avoid initialization issues - from superset.mcp_service.sql_lab.sql_lab_utils import check_database_access - - # Check database access - database = check_database_access(request.database_id) - - if self.use_command_mode: - # Use full SQL Lab command for complex queries - return self._execute_with_command(request, database) - else: - # Use simplified execution for basic queries - return self._execute_simple(request, database) - - except Exception as e: - # Handle errors and return error response with proper error types - self._log_error(e, "executing SQL") - return self._handle_execution_error(e) - - def _execute_simple( - self, - request: ExecuteSqlRequest, - database: Any, - ) -> ExecuteSqlResponse: - """Execute SQL using simplified sql_lab_utils.""" - # Import inside method to avoid initialization issues - from superset.mcp_service.sql_lab.sql_lab_utils import execute_sql_query - - results = execute_sql_query( - database=database, - sql=request.sql, - schema=request.schema_name, - limit=request.limit, - timeout=request.timeout, - parameters=request.parameters, - ) - - return ExecuteSqlResponse( - success=True, - rows=results.get("rows"), - columns=results.get("columns"), - row_count=results.get("row_count"), - affected_rows=results.get("affected_rows"), - query_id=None, # Not available in simple mode - execution_time=results.get("execution_time"), - error=None, - error_type=None, - ) - - def _execute_with_command( - self, - request: ExecuteSqlRequest, - database: Any, - ) -> ExecuteSqlResponse: - """Execute SQL using full SQL Lab command (not implemented yet).""" - # This would use ExecuteSqlCommand for full SQL Lab features - # Including query caching, async execution, complex parsing, etc. 
- # For now, we'll fall back to simple execution - self._log_info("Command mode not fully implemented, using simple mode") - return self._execute_simple(request, database) - - # Future implementation would look like: - # context = SqlJsonExecutionContext( - # database_id=request.database_id, - # sql=request.sql, - # schema=request.schema_name, - # limit=request.limit, - # # ... other context fields - # ) - # - # command = ExecuteSqlCommand( - # execution_context=context, - # query_dao=QueryDAO(), - # database_dao=DatabaseDAO(), - # # ... other dependencies - # ) - # - # result = command.run() - # return self._format_command_result(result) - - def _handle_execution_error(self, e: Exception) -> ExecuteSqlResponse: - """Map exceptions to error responses.""" - error_type = self._get_error_type(e) - return ExecuteSqlResponse( - success=False, - error=str(e), - error_type=error_type, - rows=None, - columns=None, - row_count=None, - affected_rows=None, - query_id=None, - execution_time=None, - ) - - def _get_error_type(self, e: Exception) -> str: - """Determine error type from exception.""" - # Import inside method to avoid initialization issues - from superset.exceptions import ( - SupersetDisallowedSQLFunctionException, - SupersetDMLNotAllowedException, - SupersetErrorException, - SupersetSecurityException, - SupersetTimeoutException, - ) - - if isinstance(e, SupersetSecurityException): - return "SECURITY_ERROR" - elif isinstance(e, SupersetTimeoutException): - return "TIMEOUT" - elif isinstance(e, SupersetDMLNotAllowedException): - return "DML_NOT_ALLOWED" - elif isinstance(e, SupersetDisallowedSQLFunctionException): - return "DISALLOWED_FUNCTION" - elif isinstance(e, SupersetErrorException): - return self._extract_superset_error_type(e) - else: - return "EXECUTION_ERROR" - - def _extract_superset_error_type(self, e: Exception) -> str: - """Extract error type from SupersetErrorException.""" - if hasattr(e, "error") and hasattr(e.error, "error_type"): - error_type_name 
= e.error.error_type.name - # Map common error type patterns - if "INVALID_PAYLOAD" in error_type_name: - return "INVALID_PAYLOAD_FORMAT_ERROR" - elif "DATABASE_NOT_FOUND" in error_type_name: - return "DATABASE_NOT_FOUND_ERROR" - elif "SECURITY" in error_type_name: - return "SECURITY_ERROR" - elif "TIMEOUT" in error_type_name: - return "TIMEOUT" - elif "DML_NOT_ALLOWED" in error_type_name: - return "DML_NOT_ALLOWED" - else: - return error_type_name - return "EXECUTION_ERROR" - - def _format_command_result( - self, command_result: dict[str, Any] - ) -> ExecuteSqlResponse: - """Format ExecuteSqlCommand result into ExecuteSqlResponse.""" - # This would extract relevant fields from command result - # Placeholder implementation for future use - return ExecuteSqlResponse( - success=command_result.get("success", False), - rows=command_result.get("data"), - columns=command_result.get("columns"), - row_count=command_result.get("row_count"), - affected_rows=command_result.get("affected_rows"), - query_id=command_result.get("query_id"), - execution_time=command_result.get("execution_time"), - error=command_result.get("error"), - error_type=command_result.get("error_type"), - ) diff --git a/superset/mcp_service/sql_lab/schemas.py b/superset/mcp_service/sql_lab/schemas.py index fcfe7cb62ba..571abd2a238 100644 --- a/superset/mcp_service/sql_lab/schemas.py +++ b/superset/mcp_service/sql_lab/schemas.py @@ -28,10 +28,14 @@ class ExecuteSqlRequest(BaseModel): database_id: int = Field( ..., description="Database connection ID to execute query against" ) - sql: str = Field(..., description="SQL query to execute") + sql: str = Field( + ..., + description="SQL query to execute (supports Jinja2 {{ var }} template syntax)", + ) schema_name: str | None = Field( None, description="Schema to use for query execution", alias="schema" ) + catalog: str | None = Field(None, description="Catalog name for query execution") limit: int = Field( default=1000, description="Maximum number of rows to 
return", @@ -41,8 +45,21 @@ class ExecuteSqlRequest(BaseModel): timeout: int = Field( default=30, description="Query timeout in seconds", ge=1, le=300 ) - parameters: dict[str, Any] | None = Field( - None, description="Parameters for query substitution" + template_params: dict[str, Any] | None = Field( + None, description="Jinja2 template parameters for SQL rendering" + ) + dry_run: bool = Field( + default=False, + description="Return transformed SQL without executing (for debugging)", + ) + force_refresh: bool = Field( + default=False, + description=( + "Bypass cache and re-execute query. " + "IMPORTANT: Only set to true when the user EXPLICITLY requests " + "fresh/updated data (e.g., 'refresh', 'get latest', 're-run'). " + "Default to false to reduce database load." + ), ) @field_validator("sql") @@ -61,11 +78,24 @@ class ColumnInfo(BaseModel): is_nullable: bool | None = Field(None, description="Whether column allows NULL") +class StatementInfo(BaseModel): + """Information about a single SQL statement execution.""" + + original_sql: str = Field(..., description="Original SQL as submitted") + executed_sql: str = Field( + ..., description="SQL after transformations (RLS, mutations, limits)" + ) + row_count: int = Field(..., description="Number of rows returned/affected") + execution_time_ms: float | None = Field( + None, description="Statement execution time in milliseconds" + ) + + class ExecuteSqlResponse(BaseModel): """Response schema for SQL execution results.""" success: bool = Field(..., description="Whether query executed successfully") - rows: Any | None = Field( + rows: list[dict[str, Any]] | None = Field( None, description="Query result rows as list of dictionaries" ) columns: list[ColumnInfo] | None = Field( @@ -75,12 +105,14 @@ class ExecuteSqlResponse(BaseModel): affected_rows: int | None = Field( None, description="Number of rows affected (for DML queries)" ) - query_id: str | None = Field(None, description="Query tracking ID") execution_time: float | 
None = Field( None, description="Query execution time in seconds" ) error: str | None = Field(None, description="Error message if query failed") error_type: str | None = Field(None, description="Type of error if failed") + statements: list[StatementInfo] | None = Field( + None, description="Per-statement execution info (for multi-statement queries)" + ) class OpenSqlLabRequest(BaseModel): diff --git a/superset/mcp_service/sql_lab/sql_lab_utils.py b/superset/mcp_service/sql_lab/sql_lab_utils.py deleted file mode 100644 index 10c543c7686..00000000000 --- a/superset/mcp_service/sql_lab/sql_lab_utils.py +++ /dev/null @@ -1,247 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Utility functions for SQL Lab MCP tools. - -This module contains helper functions for SQL execution, validation, -and database access that are shared across SQL Lab tools. 
-""" - -import logging -from typing import Any - -logger = logging.getLogger(__name__) - - -def check_database_access(database_id: int) -> Any: - """Check if user has access to the database.""" - # Import inside function to avoid initialization issues - from superset import db, security_manager - from superset.errors import ErrorLevel, SupersetError, SupersetErrorType - from superset.exceptions import SupersetErrorException, SupersetSecurityException - from superset.models.core import Database - - # Use session query to ensure relationships are loaded - database = db.session.query(Database).filter_by(id=database_id).first() - - if not database: - raise SupersetErrorException( - SupersetError( - message=f"Database with ID {database_id} not found", - error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR, - level=ErrorLevel.ERROR, - ) - ) - - # Check database access permissions - if not security_manager.can_access_database(database): - raise SupersetSecurityException( - SupersetError( - message=f"Access denied to database {database.database_name}", - error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR, - level=ErrorLevel.ERROR, - ) - ) - - return database - - -def validate_sql_query(sql: str, database: Any) -> None: - """Validate SQL query for security and syntax.""" - # Import inside function to avoid initialization issues - from flask import current_app as app - - from superset.exceptions import ( - SupersetDisallowedSQLFunctionException, - SupersetDMLNotAllowedException, - ) - from superset.sql.parse import SQLScript - - # Use SQLScript for proper SQL parsing - script = SQLScript(sql, database.db_engine_spec.engine) - - # Check for DML operations if not allowed - if script.has_mutation() and not database.allow_dml: - raise SupersetDMLNotAllowedException() - - # Check for disallowed functions from config - disallowed_functions = app.config.get("DISALLOWED_SQL_FUNCTIONS", {}).get( - database.db_engine_spec.engine, - set(), - ) - if disallowed_functions and 
script.check_functions_present(disallowed_functions): - raise SupersetDisallowedSQLFunctionException(disallowed_functions) - - -def execute_sql_query( - database: Any, - sql: str, - schema: str | None, - limit: int, - timeout: int, - parameters: dict[str, Any] | None, -) -> dict[str, Any]: - """Execute SQL query and return results.""" - # Import inside function to avoid initialization issues - from superset.utils.dates import now_as_float - - start_time = now_as_float() - - # Apply parameters and validate - sql = _apply_parameters(sql, parameters) - validate_sql_query(sql, database) - - # Apply limit for SELECT queries using SQLScript - rendered_sql = _apply_limit(sql, limit, database) - - # Execute and get results - results = _execute_query(database, rendered_sql, schema, limit) - - # Calculate execution time - end_time = now_as_float() - results["execution_time"] = end_time - start_time - - return results - - -def _apply_parameters(sql: str, parameters: dict[str, Any] | None) -> str: - """Apply parameters to SQL query.""" - # Import inside function to avoid initialization issues - from superset.errors import ErrorLevel, SupersetError, SupersetErrorType - from superset.exceptions import SupersetErrorException - - if parameters: - try: - return sql.format(**parameters) - except KeyError as e: - raise SupersetErrorException( - SupersetError( - message=f"Missing parameter: {e}", - error_type=SupersetErrorType.INVALID_PAYLOAD_FORMAT_ERROR, - level=ErrorLevel.ERROR, - ) - ) from e - else: - # Check if SQL contains placeholders when no parameters provided - import re - - placeholders = re.findall(r"{(\w+)}", sql) - if placeholders: - raise SupersetErrorException( - SupersetError( - message=f"Missing parameter: {placeholders[0]}", - error_type=SupersetErrorType.INVALID_PAYLOAD_FORMAT_ERROR, - level=ErrorLevel.ERROR, - ) - ) - return sql - - -def _apply_limit(sql: str, limit: int, database: Any) -> str: - """Apply limit to SELECT queries using SQLScript for proper 
parsing.""" - from superset.sql.parse import LimitMethod, SQLScript - - script = SQLScript(sql, database.db_engine_spec.engine) - - # Only apply limit to non-mutating (SELECT-like) queries - if script.has_mutation(): - return sql - - # Apply limit to each statement in the script - for statement in script.statements: - # Only set limit if not already present - if statement.get_limit_value() is None: - statement.set_limit_value(limit, LimitMethod.FORCE_LIMIT) - - return script.format() - - -def _execute_query( - database: Any, - sql: str, - schema: str | None, - limit: int, -) -> dict[str, Any]: - """Execute the query and process results.""" - # Import inside function to avoid initialization issues - from superset.sql.parse import SQLScript - from superset.utils.core import QuerySource - - results = { - "rows": [], - "columns": [], - "row_count": 0, - "affected_rows": None, - "execution_time": 0.0, - } - - try: - # Execute query with timeout - with database.get_raw_connection( - catalog=None, - schema=schema, - source=QuerySource.SQL_LAB, - ) as conn: - cursor = conn.cursor() - cursor.execute(sql) - - # Use SQLScript for proper SQL parsing to determine query type - script = SQLScript(sql, database.db_engine_spec.engine) - if script.has_mutation(): - _process_dml_results(cursor, conn, results) - else: - _process_select_results(cursor, results, limit) - - except Exception as e: - logger.error("Error executing SQL: %s", e) - raise - - return results - - -def _process_select_results(cursor: Any, results: dict[str, Any], limit: int) -> None: - """Process SELECT query results.""" - # Fetch results - data = cursor.fetchmany(limit) - - # Get column metadata - column_info = [] - if cursor.description: - for col in cursor.description: - column_info.append( - { - "name": col[0], - "type": str(col[1]) if col[1] else "unknown", - "is_nullable": col[6] if len(col) > 6 else None, - } - ) - - # Set column info regardless of whether there's data - if column_info: - results["columns"] 
= column_info - - # Convert rows to dictionaries - column_names = [col["name"] for col in column_info] - results["rows"] = [dict(zip(column_names, row, strict=False)) for row in data] - results["row_count"] = len(data) - - -def _process_dml_results(cursor: Any, conn: Any, results: dict[str, Any]) -> None: - """Process DML query results.""" - results["affected_rows"] = cursor.rowcount - conn.commit() # pylint: disable=consider-using-transaction diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py b/superset/mcp_service/sql_lab/tool/execute_sql.py index cb934540843..64fe24a395a 100644 --- a/superset/mcp_service/sql_lab/tool/execute_sql.py +++ b/superset/mcp_service/sql_lab/tool/execute_sql.py @@ -18,19 +18,26 @@ """ Execute SQL MCP Tool -Tool for executing SQL queries against databases with security validation -and timeout protection. +Tool for executing SQL queries against databases using the unified +Database.execute() API with RLS, template rendering, and security validation. """ +from __future__ import annotations + import logging +from typing import Any from fastmcp import Context +from superset_core.api.types import CacheOptions, QueryOptions, QueryResult, QueryStatus from superset_core.mcp import tool -from superset.mcp_service.sql_lab.execute_sql_core import ExecuteSqlCore +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetErrorException, SupersetSecurityException from superset.mcp_service.sql_lab.schemas import ( + ColumnInfo, ExecuteSqlRequest, ExecuteSqlResponse, + StatementInfo, ) from superset.mcp_service.utils.schema_utils import parse_request @@ -40,10 +47,7 @@ logger = logging.getLogger(__name__) @tool(tags=["mutate"]) @parse_request(ExecuteSqlRequest) async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse: - """Execute SQL query against database. - - Returns query results with security validation and timeout protection. 
- """ + """Execute SQL query against database using the unified Database.execute() API.""" await ctx.info( "Starting SQL execution: database_id=%s, timeout=%s, limit=%s, schema=%s" % (request.database_id, request.timeout, request.limit, request.schema_name) @@ -52,36 +56,78 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes # Log SQL query details (truncated for security) sql_preview = request.sql[:100] + "..." if len(request.sql) > 100 else request.sql await ctx.debug( - "SQL query details: sql_preview=%r, sql_length=%s, has_parameters=%s" + "SQL query details: sql_preview=%r, sql_length=%s, has_template_params=%s" % ( sql_preview, len(request.sql), - bool(request.parameters), + bool(request.template_params), ) ) logger.info("Executing SQL query on database ID: %s", request.database_id) try: - # Use the ExecuteSqlCore to handle all the logic - sql_tool = ExecuteSqlCore(use_command_mode=False, logger=logger) - result = sql_tool.run_tool(request) + # Import inside function to avoid initialization issues + from superset import db, security_manager + from superset.models.core import Database + + # 1. Get database and check access + database = db.session.query(Database).filter_by(id=request.database_id).first() + if not database: + raise SupersetErrorException( + SupersetError( + message=f"Database with ID {request.database_id} not found", + error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + if not security_manager.can_access_database(database): + raise SupersetSecurityException( + SupersetError( + message=f"Access denied to database {database.database_name}", + error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # 2. Build QueryOptions + # Caching is enabled by default to reduce database load. + # force_refresh bypasses cache when user explicitly requests fresh data. 
+ cache_opts = CacheOptions(force_refresh=True) if request.force_refresh else None + options = QueryOptions( + catalog=request.catalog, + schema=request.schema_name, + limit=request.limit, + timeout_seconds=request.timeout, + template_params=request.template_params, + dry_run=request.dry_run, + cache=cache_opts, + ) + + # 3. Execute query + result = database.execute(request.sql, options) + + # 4. Convert to MCP response format + response = _convert_to_response(result) # Log successful execution - if hasattr(result, "data") and result.data: - row_count = len(result.data) if isinstance(result.data, list) else 1 + if response.success: await ctx.info( "SQL execution completed successfully: rows_returned=%s, " - "query_duration_ms=%s" + "execution_time=%s" % ( - row_count, - getattr(result, "query_duration_ms", None), + response.row_count, + response.execution_time, ) ) else: - await ctx.info("SQL execution completed: status=no_data_returned") + await ctx.info( + "SQL execution failed: error=%s, error_type=%s" + % (response.error, response.error_type) + ) - return result + return response except Exception as e: await ctx.error( @@ -92,3 +138,55 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes ) ) raise + + +def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse: + """Convert QueryResult to ExecuteSqlResponse.""" + if result.status != QueryStatus.SUCCESS: + return ExecuteSqlResponse( + success=False, + error=result.error_message, + error_type=result.status.value, + ) + + # Build statement info list + statements = [ + StatementInfo( + original_sql=stmt.original_sql, + executed_sql=stmt.executed_sql, + row_count=stmt.row_count, + execution_time_ms=stmt.execution_time_ms, + ) + for stmt in result.statements + ] + + # Get first statement's data for backward compatibility + first_stmt = result.statements[0] if result.statements else None + rows: list[dict[str, Any]] | None = None + columns: list[ColumnInfo] | None = None + 
row_count: int | None = None + affected_rows: int | None = None + + if first_stmt and first_stmt.data is not None: + # SELECT query - convert DataFrame + df = first_stmt.data + rows = df.to_dict(orient="records") + columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns] + row_count = len(df) + elif first_stmt: + # DML query + affected_rows = first_stmt.row_count + + return ExecuteSqlResponse( + success=True, + rows=rows, + columns=columns, + row_count=row_count, + affected_rows=affected_rows, + execution_time=( + result.total_execution_time_ms / 1000 + if result.total_execution_time_ms is not None + else None + ), + statements=statements, + ) diff --git a/superset/mcp_service/utils/schema_utils.py b/superset/mcp_service/utils/schema_utils.py index 4e97abf4055..cc1aa2392d7 100644 --- a/superset/mcp_service/utils/schema_utils.py +++ b/superset/mcp_service/utils/schema_utils.py @@ -508,10 +508,23 @@ def parse_request( new_params = [] for name, param in orig_sig.parameters.items(): # Skip ctx parameter - FastMCP tools don't expose it to clients - if param.annotation is FMContext or ( - hasattr(param.annotation, "__name__") - and param.annotation.__name__ == "Context" - ): + # Check for Context type, forward reference string, or parameter named 'ctx' + is_context = ( + param.annotation is FMContext + or ( + hasattr(param.annotation, "__name__") + and param.annotation.__name__ == "Context" + ) + or ( + isinstance(param.annotation, str) + and ( + param.annotation == "Context" + or param.annotation.endswith(".Context") + ) + ) + or name == "ctx" # Fallback: skip any param named 'ctx' + ) + if is_context: continue if name == "request": new_params.append(param.replace(annotation=str | request_class)) diff --git a/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py b/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py deleted file mode 100644 index 9a08b72a91d..00000000000 --- a/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py 
+++ /dev/null @@ -1,137 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Unit tests for MCP SQL Lab utility functions.""" - -from unittest.mock import MagicMock - -import pytest - -from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit -from superset.sql.parse import SQLScript - - -class TestSQLScriptMutationDetection: - """Tests for SQLScript.has_mutation() used in query type detection.""" - - def test_simple_select_no_mutation(self): - """Test simple SELECT query is not a mutation.""" - script = SQLScript("SELECT * FROM table1", "sqlite") - assert script.has_mutation() is False - - def test_cte_query_no_mutation(self): - """Test CTE (WITH clause) query is not a mutation.""" - cte_sql = """ - WITH cte_name AS ( - SELECT * FROM table1 - ) - SELECT * FROM cte_name - """ - script = SQLScript(cte_sql, "sqlite") - assert script.has_mutation() is False - - def test_recursive_cte_no_mutation(self): - """Test recursive CTE is not a mutation.""" - recursive_cte = """ - WITH RECURSIVE cte AS ( - SELECT 1 AS n - UNION ALL - SELECT n + 1 FROM cte WHERE n < 10 - ) - SELECT n FROM cte - """ - script = SQLScript(recursive_cte, "sqlite") - assert script.has_mutation() is False - - def test_multiple_ctes_no_mutation(self): 
- """Test query with multiple CTEs is not a mutation.""" - multiple_ctes = """ - WITH - cte1 AS (SELECT 1 as a), - cte2 AS (SELECT 2 as b) - SELECT * FROM cte1, cte2 - """ - script = SQLScript(multiple_ctes, "sqlite") - assert script.has_mutation() is False - - def test_insert_is_mutation(self): - """Test INSERT query is a mutation.""" - script = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite") - assert script.has_mutation() is True - - def test_update_is_mutation(self): - """Test UPDATE query is a mutation.""" - script = SQLScript("UPDATE table1 SET col = 1", "sqlite") - assert script.has_mutation() is True - - def test_delete_is_mutation(self): - """Test DELETE query is a mutation.""" - script = SQLScript("DELETE FROM table1", "sqlite") - assert script.has_mutation() is True - - def test_create_is_mutation(self): - """Test CREATE query is a mutation.""" - script = SQLScript("CREATE TABLE table1 (id INT)", "sqlite") - assert script.has_mutation() is True - - -class TestApplyLimit: - """Tests for _apply_limit function using SQLScript.""" - - @pytest.fixture - def mock_database(self): - """Create a mock database with sqlite engine spec.""" - db = MagicMock() - db.db_engine_spec.engine = "sqlite" - return db - - def test_adds_limit_to_select(self, mock_database): - """Test LIMIT is added to SELECT query.""" - result = _apply_limit("SELECT * FROM table1", 100, mock_database) - assert "LIMIT 100" in result - - def test_adds_limit_to_cte(self, mock_database): - """Test LIMIT is added to CTE query.""" - cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte" - result = _apply_limit(cte_sql, 50, mock_database) - assert "LIMIT 50" in result - - def test_preserves_existing_limit(self, mock_database): - """Test existing LIMIT is not modified.""" - sql = "SELECT * FROM table1 LIMIT 10" - result = _apply_limit(sql, 100, mock_database) - assert "LIMIT 10" in result - assert "LIMIT 100" not in result - - def test_preserves_existing_limit_in_cte(self, mock_database): - """Test 
existing LIMIT in CTE query is not modified.""" - cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5" - result = _apply_limit(cte_sql, 100, mock_database) - assert "LIMIT 5" in result - assert "LIMIT 100" not in result - - def test_no_limit_on_insert(self, mock_database): - """Test LIMIT is not added to INSERT query.""" - sql = "INSERT INTO table1 VALUES (1)" - result = _apply_limit(sql, 100, mock_database) - assert result == sql - - def test_no_limit_on_update(self, mock_database): - """Test LIMIT is not added to UPDATE query.""" - sql = "UPDATE table1 SET col = 1" - result = _apply_limit(sql, 100, mock_database) - assert result == sql diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py index 5b3012462db..86b9af9188c 100644 --- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py +++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py @@ -17,14 +17,20 @@ """ Unit tests for execute_sql MCP tool + +These tests mock Database.execute() to test the MCP tool's parameter mapping +and response conversion logic. 
""" import logging +from typing import Any from unittest.mock import MagicMock, Mock, patch +import pandas as pd import pytest from fastmcp import Client from fastmcp.exceptions import ToolError +from superset_core.api.types import QueryResult, QueryStatus, StatementResult from superset.mcp_service.app import mcp @@ -48,6 +54,71 @@ def mock_auth(): yield mock_get_user +def _create_select_result( + rows: list[dict[str, Any]], + columns: list[str], + original_sql: str = "SELECT * FROM users", + executed_sql: str | None = None, + execution_time_ms: float = 10.0, +) -> QueryResult: + """Create a mock QueryResult for SELECT queries.""" + df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=columns) + return QueryResult( + status=QueryStatus.SUCCESS, + statements=[ + StatementResult( + original_sql=original_sql, + executed_sql=executed_sql or original_sql, + data=df, + row_count=len(df), + execution_time_ms=execution_time_ms, + ) + ], + query_id=None, + total_execution_time_ms=execution_time_ms, + is_cached=False, + ) + + +def _create_dml_result( + affected_rows: int, + original_sql: str = "UPDATE users SET active = true", + executed_sql: str | None = None, + execution_time_ms: float = 5.0, +) -> QueryResult: + """Create a mock QueryResult for DML queries.""" + return QueryResult( + status=QueryStatus.SUCCESS, + statements=[ + StatementResult( + original_sql=original_sql, + executed_sql=executed_sql or original_sql, + data=None, + row_count=affected_rows, + execution_time_ms=execution_time_ms, + ) + ], + query_id=None, + total_execution_time_ms=execution_time_ms, + is_cached=False, + ) + + +def _create_error_result( + error_message: str, + status: QueryStatus = QueryStatus.FAILED, +) -> QueryResult: + """Create a mock QueryResult for failed queries.""" + return QueryResult( + status=status, + statements=[], + query_id=None, + total_execution_time_ms=0, + is_cached=False, + error_message=error_message, + ) + + def _mock_database( id: int = 1, database_name: str = 
"test_db", @@ -58,26 +129,6 @@ def _mock_database( database.id = id database.database_name = database_name database.allow_dml = allow_dml - - # Mock raw connection context manager - mock_cursor = Mock() - mock_cursor.description = [ - ("id", "INTEGER", None, None, None, None, False), - ("name", "VARCHAR", None, None, None, None, True), - ] - mock_cursor.fetchmany.return_value = [(1, "test_name")] - mock_cursor.rowcount = 1 - - mock_conn = Mock() - mock_conn.cursor.return_value = mock_cursor - mock_conn.commit = Mock() - - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_conn - mock_context.__exit__.return_value = None - - database.get_raw_connection.return_value = mock_context - return database @@ -91,8 +142,11 @@ class TestExecuteSql: self, mock_db, mock_security_manager, mcp_server ): """Test basic SELECT query execution.""" - # Setup mocks mock_database = _mock_database() + mock_database.execute.return_value = _create_select_result( + rows=[{"id": 1, "name": "test_name"}], + columns=["id", "name"], + ) mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( mock_database ) @@ -107,25 +161,41 @@ class TestExecuteSql: async with Client(mcp_server) as client: result = await client.call_tool("execute_sql", {"request": request}) - assert result.data.success is True - assert result.data.error is None - assert result.data.row_count == 1 - assert len(result.data.rows) == 1 - assert result.data.rows[0]["id"] == 1 - assert result.data.rows[0]["name"] == "test_name" - assert len(result.data.columns) == 2 - assert result.data.columns[0].name == "id" - assert result.data.columns[0].type == "INTEGER" - assert result.data.execution_time > 0 + # Use structured_content for dictionary access (Pydantic model responses) + data = result.structured_content + assert data["success"] is True + assert data["error"] is None + assert data["row_count"] == 1 + assert len(data["rows"]) == 1 + assert data["rows"][0]["id"] == 1 + assert 
data["rows"][0]["name"] == "test_name" + assert len(data["columns"]) == 2 + assert data["columns"][0]["name"] == "id" + assert data["execution_time"] > 0 + + # Verify Database.execute() was called with correct QueryOptions + mock_database.execute.assert_called_once() + call_args = mock_database.execute.call_args + assert call_args[0][0] == request["sql"] + options = call_args[0][1] + assert options.limit == 10 + # Caching is enabled by default (force_refresh=False means cache=None) + assert options.cache is None @patch("superset.security_manager") @patch("superset.db") @pytest.mark.asyncio - async def test_execute_sql_with_parameters( + async def test_execute_sql_with_template_params( self, mock_db, mock_security_manager, mcp_server ): - """Test SQL execution with parameter substitution.""" + """Test SQL execution with Jinja2 template parameters.""" mock_database = _mock_database() + mock_database.execute.return_value = _create_select_result( + rows=[{"order_id": 1, "status": "active"}], + columns=["order_id", "status"], + original_sql="SELECT * FROM {{ table }} WHERE status = '{{ status }}'", + executed_sql="SELECT * FROM orders WHERE status = 'active'", + ) mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( mock_database ) @@ -133,33 +203,42 @@ class TestExecuteSql: request = { "database_id": 1, - "sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT {limit}", - "parameters": {"table": "orders", "status": "active", "limit": "5"}, + "sql": "SELECT * FROM {{ table }} WHERE status = '{{ status }}'", + "template_params": {"table": "orders", "status": "active"}, "limit": 10, } async with Client(mcp_server) as client: result = await client.call_tool("execute_sql", {"request": request}) - assert result.data.success is True - assert result.data.error is None - # Verify parameter substitution happened - mock_database.get_raw_connection.assert_called_once() - cursor = ( # fmt: skip - 
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value - ) - # Check that the SQL was formatted with parameters - executed_sql = cursor.execute.call_args[0][0] - assert "orders" in executed_sql - assert "active" in executed_sql + # Use structured_content for dictionary access (Pydantic model responses) + data = result.structured_content + assert data["success"] is True + assert data["error"] is None + + # Verify template_params were passed to QueryOptions + call_args = mock_database.execute.call_args + options = call_args[0][1] + assert options.template_params == {"table": "orders", "status": "active"} + + # Verify statements contain both original and executed SQL + assert data["statements"] is not None + assert len(data["statements"]) == 1 + assert "{{ table }}" in data["statements"][0]["original_sql"] + assert "orders" in data["statements"][0]["executed_sql"] @patch("superset.security_manager") @patch("superset.db") @pytest.mark.asyncio async def test_execute_sql_database_not_found( - self, mock_db, mock_security_manager, mcp_server + self, + mock_db, + mock_security_manager, # noqa: PT019 + mcp_server, ): """Test error when database is not found.""" + # mock_security_manager is patched but not used (error happens first) + del mock_security_manager # Silence unused variable warning mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( None ) @@ -171,15 +250,10 @@ class TestExecuteSql: } async with Client(mcp_server) as client: - result = await client.call_tool("execute_sql", {"request": request}) + with pytest.raises(ToolError, match="Database with ID 999 not found"): + await client.call_tool("execute_sql", {"request": request}) - assert result.data.success is False - assert result.data.error is not None - assert "Database with ID 999 not found" in result.data.error - assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR" - assert result.data.rows is None - - @patch("superset.security_manager") + 
@patch("superset.security_manager", new_callable=MagicMock) @patch("superset.db") @pytest.mark.asyncio async def test_execute_sql_access_denied( @@ -190,10 +264,7 @@ class TestExecuteSql: mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( mock_database ) - # Use Mock instead of AsyncMock for synchronous call - from unittest.mock import Mock - - mock_security_manager.can_access_database = Mock(return_value=False) + mock_security_manager.can_access_database.return_value = False request = { "database_id": 1, @@ -202,58 +273,27 @@ class TestExecuteSql: } async with Client(mcp_server) as client: - result = await client.call_tool("execute_sql", {"request": request}) - - assert result.data.success is False - assert result.data.error is not None - assert "Access denied to database" in result.data.error - assert result.data.error_type == "SECURITY_ERROR" + with pytest.raises(ToolError, match="Access denied to database"): + await client.call_tool("execute_sql", {"request": request}) @patch("superset.security_manager") @patch("superset.db") @pytest.mark.asyncio - async def test_execute_sql_dml_not_allowed( + async def test_execute_sql_dml_success( self, mock_db, mock_security_manager, mcp_server ): - """Test error when DML operations are not allowed.""" - mock_database = _mock_database(allow_dml=False) - mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( - mock_database - ) - mock_security_manager.can_access_database.return_value = True - - request = { - "database_id": 1, - "sql": "UPDATE users SET name = 'test' WHERE id = 1", - "limit": 1, - } - - async with Client(mcp_server) as client: - result = await client.call_tool("execute_sql", {"request": request}) - - assert result.data.success is False - assert result.data.error is not None - assert result.data.error_type == "DML_NOT_ALLOWED" - - @patch("superset.security_manager") - @patch("superset.db") - @pytest.mark.asyncio - async def test_execute_sql_dml_allowed( - 
self, mock_db, mock_security_manager, mcp_server - ): - """Test successful DML execution when allowed.""" + """Test successful DML execution.""" mock_database = _mock_database(allow_dml=True) + dml_sql = "UPDATE users SET active = true WHERE last_login > '2024-01-01'" + mock_database.execute.return_value = _create_dml_result( + affected_rows=3, + original_sql=dml_sql, + ) mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( mock_database ) mock_security_manager.can_access_database.return_value = True - # Mock cursor for DML operation - cursor = ( # fmt: skip - mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value - ) - cursor.rowcount = 3 # 3 rows affected - request = { "database_id": 1, "sql": "UPDATE users SET active = true WHERE last_login > '2024-01-01'", @@ -263,15 +303,13 @@ class TestExecuteSql: async with Client(mcp_server) as client: result = await client.call_tool("execute_sql", {"request": request}) - assert result.data.success is True - assert result.data.error is None - assert result.data.affected_rows == 3 - assert result.data.rows == [] # Empty rows for DML - assert result.data.row_count == 0 - # Verify commit was called - ( - mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once() - ) + # Use structured_content for dictionary access (Pydantic model responses) + data = result.structured_content + assert data["success"] is True + assert data["error"] is None + assert data["affected_rows"] == 3 + assert data["rows"] is None # None for DML + assert data["row_count"] is None @patch("superset.security_manager") @patch("superset.db") @@ -281,17 +319,15 @@ class TestExecuteSql: ): """Test query that returns no results.""" mock_database = _mock_database() + mock_database.execute.return_value = _create_select_result( + rows=[], + columns=["id", "name"], + ) mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( mock_database ) 
mock_security_manager.can_access_database.return_value = True - # Mock empty results - cursor = ( # fmt: skip - mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value - ) - cursor.fetchmany.return_value = [] - request = { "database_id": 1, "sql": "SELECT * FROM users WHERE id = 999999", @@ -301,77 +337,26 @@ class TestExecuteSql: async with Client(mcp_server) as client: result = await client.call_tool("execute_sql", {"request": request}) - assert result.data.success is True - assert result.data.error is None - assert result.data.row_count == 0 - assert len(result.data.rows) == 0 - assert len(result.data.columns) == 2 # Column metadata still returned + # Use structured_content for dictionary access (Pydantic model responses) + data = result.structured_content + assert data["success"] is True + assert data["error"] is None + assert data["row_count"] == 0 + assert len(data["rows"]) == 0 + assert len(data["columns"]) == 2 # Column metadata still returned @patch("superset.security_manager") @patch("superset.db") @pytest.mark.asyncio - async def test_execute_sql_missing_parameter( + async def test_execute_sql_with_schema_and_catalog( self, mock_db, mock_security_manager, mcp_server ): - """Test error when required parameter is missing.""" + """Test SQL execution with schema and catalog specification.""" mock_database = _mock_database() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( - mock_database + mock_database.execute.return_value = _create_select_result( + rows=[{"total": 100}], + columns=["total"], ) - mock_security_manager.can_access_database.return_value = True - - request = { - "database_id": 1, - "sql": "SELECT * FROM {table_name} WHERE id = {user_id}", - "parameters": {"table_name": "users"}, # Missing user_id - "limit": 1, - } - - async with Client(mcp_server) as client: - result = await client.call_tool("execute_sql", {"request": request}) - - assert result.data.success is False - assert 
result.data.error is not None - assert "user_id" in result.data.error # Error contains parameter name - assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR" - - @patch("superset.security_manager") - @patch("superset.db") - @pytest.mark.asyncio - async def test_execute_sql_empty_parameters_with_placeholders( - self, mock_db, mock_security_manager, mcp_server - ): - """Test error when empty parameters dict is provided but SQL has - placeholders.""" - mock_database = _mock_database() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( - mock_database - ) - mock_security_manager.can_access_database.return_value = True - - request = { - "database_id": 1, - "sql": "SELECT * FROM {table_name} LIMIT 5", - "parameters": {}, # Empty dict but SQL has {table_name} - "limit": 5, - } - - async with Client(mcp_server) as client: - result = await client.call_tool("execute_sql", {"request": request}) - - assert result.data.success is False - assert result.data.error is not None - assert "Missing parameter: table_name" in result.data.error - assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR" - - @patch("superset.security_manager") - @patch("superset.db") - @pytest.mark.asyncio - async def test_execute_sql_with_schema( - self, mock_db, mock_security_manager, mcp_server - ): - """Test SQL execution with schema specification.""" - mock_database = _mock_database() mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( mock_database ) @@ -381,28 +366,47 @@ class TestExecuteSql: "database_id": 1, "sql": "SELECT COUNT(*) as total FROM orders", "schema": "sales", + "catalog": "prod_catalog", "limit": 1, } async with Client(mcp_server) as client: result = await client.call_tool("execute_sql", {"request": request}) - assert result.data.success is True - assert result.data.error is None - # Verify schema was passed to get_raw_connection - # Verify schema was passed - call_args = 
mock_database.get_raw_connection.call_args - assert call_args[1]["schema"] == "sales" - assert call_args[1]["catalog"] is None + # Use structured_content for dictionary access (Pydantic model responses) + data = result.structured_content + assert data["success"] is True + + # Verify schema and catalog were passed to QueryOptions + call_args = mock_database.execute.call_args + options = call_args[0][1] + assert options.schema == "sales" + assert options.catalog == "prod_catalog" @patch("superset.security_manager") @patch("superset.db") @pytest.mark.asyncio - async def test_execute_sql_limit_enforcement( + async def test_execute_sql_dry_run( self, mock_db, mock_security_manager, mcp_server ): - """Test that LIMIT is added to SELECT queries without one.""" + """Test dry_run mode returns transformed SQL without executing.""" mock_database = _mock_database() + executed_sql = "SELECT * FROM users WHERE user_id IN (SELECT ...) LIMIT 100" + mock_database.execute.return_value = QueryResult( + status=QueryStatus.SUCCESS, + statements=[ + StatementResult( + original_sql="SELECT * FROM {{ table }}", + executed_sql=executed_sql, + data=None, + row_count=0, + execution_time_ms=0, + ) + ], + query_id=None, + total_execution_time_ms=0, + is_cached=False, + ) mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( mock_database ) @@ -410,25 +414,32 @@ class TestExecuteSql: request = { "database_id": 1, - "sql": "SELECT * FROM users", # No LIMIT - "limit": 50, + "sql": "SELECT * FROM {{ table }}", + "template_params": {"table": "users"}, + "dry_run": True, + "limit": 100, } async with Client(mcp_server) as client: result = await client.call_tool("execute_sql", {"request": request}) - assert result.data.success is True - # Verify LIMIT was added - cursor = ( # fmt: skip - mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value - ) - executed_sql = cursor.execute.call_args[0][0] - assert "LIMIT 50" in executed_sql + # Use 
structured_content for dictionary access (Pydantic model responses)
+        data = result.structured_content
+        assert data["success"] is True
+        # Verify dry_run was passed
+        call_args = mock_database.execute.call_args
+        options = call_args[0][1]
+        assert options.dry_run is True
+
+        # Verify statements show transformed SQL
+        assert data["statements"] is not None
+        assert "{{ table }}" in data["statements"][0]["original_sql"]
+        assert "users" in data["statements"][0]["executed_sql"]
 
     @patch("superset.security_manager")
     @patch("superset.db")
     @pytest.mark.asyncio
-    async def test_execute_sql_timeout_error(
+    async def test_execute_sql_timeout_error(
         self, mock_db, mock_security_manager, mcp_server
     ):
-        """Test that SQL injection attempts are handled safely.
+        """Test that a query timeout is surfaced as a failed response.
@@ -437,6 +448,10 @@
         before execution when DML is not allowed on the database.
         """
         mock_database = _mock_database()
+        mock_database.execute.return_value = _create_error_result(
+            error_message="Query exceeded the timeout limit",
+            status=QueryStatus.TIMED_OUT,
+        )
         mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
             mock_database
         )
@@ -444,19 +459,76 @@
 
         request = {
             "database_id": 1,
-            "sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
+            "sql": "SELECT * FROM large_table",
+            "timeout": 5,
+            "limit": 100,
+        }
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool("execute_sql", {"request": request})
+
+            # Use structured_content for dictionary access (Pydantic model responses)
+            data = result.structured_content
+            assert data["success"] is False
+            assert data["error"] == "Query exceeded the timeout limit"
+            assert data["error_type"] == "timed_out"
+
+    @patch("superset.security_manager")
+    @patch("superset.db")
+    @pytest.mark.asyncio
+    async def test_execute_sql_multi_statement(
+        self, mock_db, mock_security_manager, mcp_server
+    ):
+        """Test multi-statement SQL execution."""
+        mock_database = _mock_database()
+ mock_database.execute.return_value = QueryResult( + status=QueryStatus.SUCCESS, + statements=[ + StatementResult( + original_sql="SELECT 1 as a", + executed_sql="SELECT 1 as a", + data=pd.DataFrame([{"a": 1}]), + row_count=1, + execution_time_ms=5.0, + ), + StatementResult( + original_sql="SELECT 2 as b", + executed_sql="SELECT 2 as b", + data=pd.DataFrame([{"b": 2}]), + row_count=1, + execution_time_ms=3.0, + ), + ], + query_id=None, + total_execution_time_ms=8.0, + is_cached=False, + ) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT 1 as a; SELECT 2 as b;", "limit": 10, } async with Client(mcp_server) as client: result = await client.call_tool("execute_sql", {"request": request}) - # SQLScript correctly detects DROP TABLE as a mutation - # and blocks it before execution (improved security) - assert result.data.success is False - assert result.data.error is not None - assert "DML" in result.data.error or "mutates" in result.data.error - assert result.data.error_type == "DML_NOT_ALLOWED" + # Use structured_content for dictionary access (Pydantic model responses) + data = result.structured_content + assert data["success"] is True + # Statements should contain both + assert data["statements"] is not None + assert len(data["statements"]) == 2 + assert data["statements"][0]["original_sql"] == "SELECT 1 as a" + assert data["statements"][1]["original_sql"] == "SELECT 2 as b" + + # rows/columns should be from first statement for backward compat + assert data["rows"] == [{"a": 1}] + assert data["row_count"] == 1 @pytest.mark.asyncio async def test_execute_sql_empty_query_validation(self, mcp_server): @@ -495,3 +567,39 @@ class TestExecuteSql: async with Client(mcp_server) as client: with pytest.raises(ToolError, match="less than or equal to 10000"): await client.call_tool("execute_sql", {"request": 
request}) + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_force_refresh( + self, mock_db, mock_security_manager, mcp_server + ): + """Test force_refresh bypasses cache.""" + mock_database = _mock_database() + mock_database.execute.return_value = _create_select_result( + rows=[{"id": 1}], + columns=["id"], + ) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT id FROM users", + "limit": 10, + "force_refresh": True, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + data = result.structured_content + assert data["success"] is True + + # Verify force_refresh was passed to CacheOptions + call_args = mock_database.execute.call_args + options = call_args[0][1] + assert options.cache is not None + assert options.cache.force_refresh is True