refactor: Migrates the MCP execute_sql tool to use the SQL execution API (#36739)

Co-authored-by: codeant-ai-for-open-source[bot] <244253245+codeant-ai-for-open-source[bot]@users.noreply.github.com>
This commit is contained in:
Michael S. Molina
2025-12-22 09:48:28 -03:00
committed by GitHub
parent c0bcf28947
commit 6b25d0663e
8 changed files with 492 additions and 838 deletions

View File

@@ -1,137 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for MCP SQL Lab utility functions."""
from unittest.mock import MagicMock
import pytest
from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit
from superset.sql.parse import SQLScript
class TestSQLScriptMutationDetection:
"""Tests for SQLScript.has_mutation() used in query type detection."""
def test_simple_select_no_mutation(self):
"""Test simple SELECT query is not a mutation."""
script = SQLScript("SELECT * FROM table1", "sqlite")
assert script.has_mutation() is False
def test_cte_query_no_mutation(self):
"""Test CTE (WITH clause) query is not a mutation."""
cte_sql = """
WITH cte_name AS (
SELECT * FROM table1
)
SELECT * FROM cte_name
"""
script = SQLScript(cte_sql, "sqlite")
assert script.has_mutation() is False
def test_recursive_cte_no_mutation(self):
"""Test recursive CTE is not a mutation."""
recursive_cte = """
WITH RECURSIVE cte AS (
SELECT 1 AS n
UNION ALL
SELECT n + 1 FROM cte WHERE n < 10
)
SELECT n FROM cte
"""
script = SQLScript(recursive_cte, "sqlite")
assert script.has_mutation() is False
def test_multiple_ctes_no_mutation(self):
"""Test query with multiple CTEs is not a mutation."""
multiple_ctes = """
WITH
cte1 AS (SELECT 1 as a),
cte2 AS (SELECT 2 as b)
SELECT * FROM cte1, cte2
"""
script = SQLScript(multiple_ctes, "sqlite")
assert script.has_mutation() is False
def test_insert_is_mutation(self):
"""Test INSERT query is a mutation."""
script = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite")
assert script.has_mutation() is True
def test_update_is_mutation(self):
"""Test UPDATE query is a mutation."""
script = SQLScript("UPDATE table1 SET col = 1", "sqlite")
assert script.has_mutation() is True
def test_delete_is_mutation(self):
"""Test DELETE query is a mutation."""
script = SQLScript("DELETE FROM table1", "sqlite")
assert script.has_mutation() is True
def test_create_is_mutation(self):
"""Test CREATE query is a mutation."""
script = SQLScript("CREATE TABLE table1 (id INT)", "sqlite")
assert script.has_mutation() is True
class TestApplyLimit:
"""Tests for _apply_limit function using SQLScript."""
@pytest.fixture
def mock_database(self):
"""Create a mock database with sqlite engine spec."""
db = MagicMock()
db.db_engine_spec.engine = "sqlite"
return db
def test_adds_limit_to_select(self, mock_database):
"""Test LIMIT is added to SELECT query."""
result = _apply_limit("SELECT * FROM table1", 100, mock_database)
assert "LIMIT 100" in result
def test_adds_limit_to_cte(self, mock_database):
"""Test LIMIT is added to CTE query."""
cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte"
result = _apply_limit(cte_sql, 50, mock_database)
assert "LIMIT 50" in result
def test_preserves_existing_limit(self, mock_database):
"""Test existing LIMIT is not modified."""
sql = "SELECT * FROM table1 LIMIT 10"
result = _apply_limit(sql, 100, mock_database)
assert "LIMIT 10" in result
assert "LIMIT 100" not in result
def test_preserves_existing_limit_in_cte(self, mock_database):
"""Test existing LIMIT in CTE query is not modified."""
cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5"
result = _apply_limit(cte_sql, 100, mock_database)
assert "LIMIT 5" in result
assert "LIMIT 100" not in result
def test_no_limit_on_insert(self, mock_database):
"""Test LIMIT is not added to INSERT query."""
sql = "INSERT INTO table1 VALUES (1)"
result = _apply_limit(sql, 100, mock_database)
assert result == sql
def test_no_limit_on_update(self, mock_database):
"""Test LIMIT is not added to UPDATE query."""
sql = "UPDATE table1 SET col = 1"
result = _apply_limit(sql, 100, mock_database)
assert result == sql

View File

@@ -17,14 +17,20 @@
"""
Unit tests for execute_sql MCP tool
These tests mock Database.execute() to test the MCP tool's parameter mapping
and response conversion logic.
"""
import logging
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from superset_core.api.types import QueryResult, QueryStatus, StatementResult
from superset.mcp_service.app import mcp
@@ -48,6 +54,71 @@ def mock_auth():
yield mock_get_user
def _create_select_result(
rows: list[dict[str, Any]],
columns: list[str],
original_sql: str = "SELECT * FROM users",
executed_sql: str | None = None,
execution_time_ms: float = 10.0,
) -> QueryResult:
"""Create a mock QueryResult for SELECT queries."""
df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=columns)
return QueryResult(
status=QueryStatus.SUCCESS,
statements=[
StatementResult(
original_sql=original_sql,
executed_sql=executed_sql or original_sql,
data=df,
row_count=len(df),
execution_time_ms=execution_time_ms,
)
],
query_id=None,
total_execution_time_ms=execution_time_ms,
is_cached=False,
)
def _create_dml_result(
affected_rows: int,
original_sql: str = "UPDATE users SET active = true",
executed_sql: str | None = None,
execution_time_ms: float = 5.0,
) -> QueryResult:
"""Create a mock QueryResult for DML queries."""
return QueryResult(
status=QueryStatus.SUCCESS,
statements=[
StatementResult(
original_sql=original_sql,
executed_sql=executed_sql or original_sql,
data=None,
row_count=affected_rows,
execution_time_ms=execution_time_ms,
)
],
query_id=None,
total_execution_time_ms=execution_time_ms,
is_cached=False,
)
def _create_error_result(
error_message: str,
status: QueryStatus = QueryStatus.FAILED,
) -> QueryResult:
"""Create a mock QueryResult for failed queries."""
return QueryResult(
status=status,
statements=[],
query_id=None,
total_execution_time_ms=0,
is_cached=False,
error_message=error_message,
)
def _mock_database(
id: int = 1,
database_name: str = "test_db",
@@ -58,26 +129,6 @@ def _mock_database(
database.id = id
database.database_name = database_name
database.allow_dml = allow_dml
# Mock raw connection context manager
mock_cursor = Mock()
mock_cursor.description = [
("id", "INTEGER", None, None, None, None, False),
("name", "VARCHAR", None, None, None, None, True),
]
mock_cursor.fetchmany.return_value = [(1, "test_name")]
mock_cursor.rowcount = 1
mock_conn = Mock()
mock_conn.cursor.return_value = mock_cursor
mock_conn.commit = Mock()
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_conn
mock_context.__exit__.return_value = None
database.get_raw_connection.return_value = mock_context
return database
@@ -91,8 +142,11 @@ class TestExecuteSql:
self, mock_db, mock_security_manager, mcp_server
):
"""Test basic SELECT query execution."""
# Setup mocks
mock_database = _mock_database()
mock_database.execute.return_value = _create_select_result(
rows=[{"id": 1, "name": "test_name"}],
columns=["id", "name"],
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -107,25 +161,41 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.row_count == 1
assert len(result.data.rows) == 1
assert result.data.rows[0]["id"] == 1
assert result.data.rows[0]["name"] == "test_name"
assert len(result.data.columns) == 2
assert result.data.columns[0].name == "id"
assert result.data.columns[0].type == "INTEGER"
assert result.data.execution_time > 0
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
assert data["error"] is None
assert data["row_count"] == 1
assert len(data["rows"]) == 1
assert data["rows"][0]["id"] == 1
assert data["rows"][0]["name"] == "test_name"
assert len(data["columns"]) == 2
assert data["columns"][0]["name"] == "id"
assert data["execution_time"] > 0
# Verify Database.execute() was called with correct QueryOptions
mock_database.execute.assert_called_once()
call_args = mock_database.execute.call_args
assert call_args[0][0] == request["sql"]
options = call_args[0][1]
assert options.limit == 10
# Caching is enabled by default (force_refresh=False means cache=None)
assert options.cache is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_with_parameters(
async def test_execute_sql_with_template_params(
self, mock_db, mock_security_manager, mcp_server
):
"""Test SQL execution with parameter substitution."""
"""Test SQL execution with Jinja2 template parameters."""
mock_database = _mock_database()
mock_database.execute.return_value = _create_select_result(
rows=[{"order_id": 1, "status": "active"}],
columns=["order_id", "status"],
original_sql="SELECT * FROM {{ table }} WHERE status = '{{ status }}'",
executed_sql="SELECT * FROM orders WHERE status = 'active'",
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -133,33 +203,42 @@ class TestExecuteSql:
request = {
"database_id": 1,
"sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT {limit}",
"parameters": {"table": "orders", "status": "active", "limit": "5"},
"sql": "SELECT * FROM {{ table }} WHERE status = '{{ status }}'",
"template_params": {"table": "orders", "status": "active"},
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
# Verify parameter substitution happened
mock_database.get_raw_connection.assert_called_once()
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
# Check that the SQL was formatted with parameters
executed_sql = cursor.execute.call_args[0][0]
assert "orders" in executed_sql
assert "active" in executed_sql
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
assert data["error"] is None
# Verify template_params were passed to QueryOptions
call_args = mock_database.execute.call_args
options = call_args[0][1]
assert options.template_params == {"table": "orders", "status": "active"}
# Verify statements contain both original and executed SQL
assert data["statements"] is not None
assert len(data["statements"]) == 1
assert "{{ table }}" in data["statements"][0]["original_sql"]
assert "orders" in data["statements"][0]["executed_sql"]
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_database_not_found(
self, mock_db, mock_security_manager, mcp_server
self,
mock_db,
mock_security_manager, # noqa: PT019
mcp_server,
):
"""Test error when database is not found."""
# mock_security_manager is patched but not used (error happens first)
del mock_security_manager # Silence unused variable warning
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
None
)
@@ -171,15 +250,10 @@ class TestExecuteSql:
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
with pytest.raises(ToolError, match="Database with ID 999 not found"):
await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Database with ID 999 not found" in result.data.error
assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR"
assert result.data.rows is None
@patch("superset.security_manager")
@patch("superset.security_manager", new_callable=MagicMock)
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_access_denied(
@@ -190,10 +264,7 @@ class TestExecuteSql:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
# Use Mock instead of AsyncMock for synchronous call
from unittest.mock import Mock
mock_security_manager.can_access_database = Mock(return_value=False)
mock_security_manager.can_access_database.return_value = False
request = {
"database_id": 1,
@@ -202,58 +273,27 @@ class TestExecuteSql:
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Access denied to database" in result.data.error
assert result.data.error_type == "SECURITY_ERROR"
with pytest.raises(ToolError, match="Access denied to database"):
await client.call_tool("execute_sql", {"request": request})
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_dml_not_allowed(
async def test_execute_sql_dml_success(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when DML operations are not allowed."""
mock_database = _mock_database(allow_dml=False)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "UPDATE users SET name = 'test' WHERE id = 1",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert result.data.error_type == "DML_NOT_ALLOWED"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_dml_allowed(
self, mock_db, mock_security_manager, mcp_server
):
"""Test successful DML execution when allowed."""
"""Test successful DML execution."""
mock_database = _mock_database(allow_dml=True)
dml_sql = "UPDATE users SET active = true WHERE last_login > '2024-01-01'"
mock_database.execute.return_value = _create_dml_result(
affected_rows=3,
original_sql=dml_sql,
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock cursor for DML operation
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.rowcount = 3 # 3 rows affected
request = {
"database_id": 1,
"sql": "UPDATE users SET active = true WHERE last_login > '2024-01-01'",
@@ -263,15 +303,13 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.affected_rows == 3
assert result.data.rows == [] # Empty rows for DML
assert result.data.row_count == 0
# Verify commit was called
(
mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once()
)
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
assert data["error"] is None
assert data["affected_rows"] == 3
assert data["rows"] is None # None for DML
assert data["row_count"] is None
@patch("superset.security_manager")
@patch("superset.db")
@@ -281,17 +319,15 @@ class TestExecuteSql:
):
"""Test query that returns no results."""
mock_database = _mock_database()
mock_database.execute.return_value = _create_select_result(
rows=[],
columns=["id", "name"],
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock empty results
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.fetchmany.return_value = []
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 999999",
@@ -301,77 +337,26 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.row_count == 0
assert len(result.data.rows) == 0
assert len(result.data.columns) == 2 # Column metadata still returned
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
assert data["error"] is None
assert data["row_count"] == 0
assert len(data["rows"]) == 0
assert len(data["columns"]) == 2 # Column metadata still returned
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_missing_parameter(
async def test_execute_sql_with_schema_and_catalog(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when required parameter is missing."""
"""Test SQL execution with schema and catalog specification."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
mock_database.execute.return_value = _create_select_result(
rows=[{"total": 100}],
columns=["total"],
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table_name} WHERE id = {user_id}",
"parameters": {"table_name": "users"}, # Missing user_id
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "user_id" in result.data.error # Error contains parameter name
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_empty_parameters_with_placeholders(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when empty parameters dict is provided but SQL has
placeholders."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table_name} LIMIT 5",
"parameters": {}, # Empty dict but SQL has {table_name}
"limit": 5,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Missing parameter: table_name" in result.data.error
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_with_schema(
self, mock_db, mock_security_manager, mcp_server
):
"""Test SQL execution with schema specification."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -381,28 +366,47 @@ class TestExecuteSql:
"database_id": 1,
"sql": "SELECT COUNT(*) as total FROM orders",
"schema": "sales",
"catalog": "prod_catalog",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
# Verify schema was passed to get_raw_connection
# Verify schema was passed
call_args = mock_database.get_raw_connection.call_args
assert call_args[1]["schema"] == "sales"
assert call_args[1]["catalog"] is None
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
# Verify schema and catalog were passed to QueryOptions
call_args = mock_database.execute.call_args
options = call_args[0][1]
assert options.schema == "sales"
assert options.catalog == "prod_catalog"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_limit_enforcement(
async def test_execute_sql_dry_run(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that LIMIT is added to SELECT queries without one."""
"""Test dry_run mode returns transformed SQL without executing."""
mock_database = _mock_database()
executed_sql = "SELECT * FROM users WHERE user_id IN (SELECT ...) LIMIT 100"
mock_database.execute.return_value = QueryResult(
status=QueryStatus.SUCCESS,
statements=[
StatementResult(
original_sql="SELECT * FROM {{ table }}",
executed_sql=executed_sql,
data=None,
row_count=0,
execution_time_ms=0,
)
],
query_id=None,
total_execution_time_ms=0,
is_cached=False,
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -410,25 +414,32 @@ class TestExecuteSql:
request = {
"database_id": 1,
"sql": "SELECT * FROM users", # No LIMIT
"limit": 50,
"sql": "SELECT * FROM {{ table }}",
"template_params": {"table": "users"},
"dry_run": True,
"limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
# Verify LIMIT was added
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
executed_sql = cursor.execute.call_args[0][0]
assert "LIMIT 50" in executed_sql
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
# Verify dry_run was passed
call_args = mock_database.execute.call_args
options = call_args[0][1]
assert options.dry_run is True
# Verify statements show transformed SQL
assert data["statements"] is not None
assert "{{ table }}" in data["statements"][0]["original_sql"]
assert "users" in data["statements"][0]["executed_sql"]
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_sql_injection_prevention(
async def test_execute_sql_timeout_error(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that SQL injection attempts are handled safely.
@@ -437,6 +448,10 @@ class TestExecuteSql:
before execution when DML is not allowed on the database.
"""
mock_database = _mock_database()
mock_database.execute.return_value = _create_error_result(
error_message="Query exceeded the timeout limit",
status=QueryStatus.TIMED_OUT,
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -444,19 +459,76 @@ class TestExecuteSql:
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
"sql": "SELECT * FROM large_table",
"timeout": 5,
"limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is False
assert data["error"] == "Query exceeded the timeout limit"
assert data["error_type"] == "timed_out"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_multi_statement(
self, mock_db, mock_security_manager, mcp_server
):
"""Test multi-statement SQL execution."""
mock_database = _mock_database()
mock_database.execute.return_value = QueryResult(
status=QueryStatus.SUCCESS,
statements=[
StatementResult(
original_sql="SELECT 1 as a",
executed_sql="SELECT 1 as a",
data=pd.DataFrame([{"a": 1}]),
row_count=1,
execution_time_ms=5.0,
),
StatementResult(
original_sql="SELECT 2 as b",
executed_sql="SELECT 2 as b",
data=pd.DataFrame([{"b": 2}]),
row_count=1,
execution_time_ms=3.0,
),
],
query_id=None,
total_execution_time_ms=8.0,
is_cached=False,
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT 1 as a; SELECT 2 as b;",
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
# SQLScript correctly detects DROP TABLE as a mutation
# and blocks it before execution (improved security)
assert result.data.success is False
assert result.data.error is not None
assert "DML" in result.data.error or "mutates" in result.data.error
assert result.data.error_type == "DML_NOT_ALLOWED"
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
# Statements should contain both
assert data["statements"] is not None
assert len(data["statements"]) == 2
assert data["statements"][0]["original_sql"] == "SELECT 1 as a"
assert data["statements"][1]["original_sql"] == "SELECT 2 as b"
# rows/columns should be from first statement for backward compat
assert data["rows"] == [{"a": 1}]
assert data["row_count"] == 1
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):
@@ -495,3 +567,39 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="less than or equal to 10000"):
await client.call_tool("execute_sql", {"request": request})
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_force_refresh(
self, mock_db, mock_security_manager, mcp_server
):
"""Test force_refresh bypasses cache."""
mock_database = _mock_database()
mock_database.execute.return_value = _create_select_result(
rows=[{"id": 1}],
columns=["id"],
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT id FROM users",
"limit": 10,
"force_refresh": True,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
data = result.structured_content
assert data["success"] is True
# Verify force_refresh was passed to CacheOptions
call_args = mock_database.execute.call_args
options = call_args[0][1]
assert options.cache is not None
assert options.cache.force_refresh is True