refactor: Migrates the MCP execute_sql tool to use the SQL execution API (#36739)

Co-authored-by: codeant-ai-for-open-source[bot] <244253245+codeant-ai-for-open-source[bot]@users.noreply.github.com>
This commit is contained in:
Michael S. Molina
2025-12-22 09:48:28 -03:00
committed by GitHub
parent c0bcf28947
commit 6b25d0663e
8 changed files with 492 additions and 838 deletions

View File

@@ -17,14 +17,20 @@
"""
Unit tests for execute_sql MCP tool
These tests mock Database.execute() to test the MCP tool's parameter mapping
and response conversion logic.
"""
import logging
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from superset_core.api.types import QueryResult, QueryStatus, StatementResult
from superset.mcp_service.app import mcp
@@ -48,6 +54,71 @@ def mock_auth():
yield mock_get_user
def _create_select_result(
    rows: list[dict[str, Any]],
    columns: list[str],
    original_sql: str = "SELECT * FROM users",
    executed_sql: str | None = None,
    execution_time_ms: float = 10.0,
) -> QueryResult:
    """
    Build a successful, single-statement ``QueryResult`` carrying tabular data.

    The DataFrame is built from ``rows``; ``columns`` is only consulted when
    ``rows`` is empty, so that an empty result still exposes its headers.
    When ``executed_sql`` is not given (or is empty), ``original_sql`` is
    reported as the executed SQL. ``execution_time_ms`` is used both for the
    statement and for the total execution time.
    """
    if rows:
        frame = pd.DataFrame(rows)
    else:
        # No data rows: keep the column headers on the empty frame.
        frame = pd.DataFrame(columns=columns)

    statement = StatementResult(
        original_sql=original_sql,
        executed_sql=executed_sql or original_sql,
        data=frame,
        row_count=len(frame),
        execution_time_ms=execution_time_ms,
    )
    return QueryResult(
        status=QueryStatus.SUCCESS,
        statements=[statement],
        query_id=None,
        total_execution_time_ms=execution_time_ms,
        is_cached=False,
    )
def _create_dml_result(
    affected_rows: int,
    original_sql: str = "UPDATE users SET active = true",
    executed_sql: str | None = None,
    execution_time_ms: float = 5.0,
) -> QueryResult:
    """
    Build a successful, single-statement ``QueryResult`` for a DML statement.

    The statement carries no data (``data=None``); ``affected_rows`` is
    reported through the statement's ``row_count``. When ``executed_sql`` is
    not given (or is empty), ``original_sql`` is reported as the executed SQL.
    """
    statement = StatementResult(
        original_sql=original_sql,
        executed_sql=executed_sql or original_sql,
        data=None,
        row_count=affected_rows,
        execution_time_ms=execution_time_ms,
    )
    return QueryResult(
        status=QueryStatus.SUCCESS,
        statements=[statement],
        query_id=None,
        total_execution_time_ms=execution_time_ms,
        is_cached=False,
    )
def _create_error_result(
    error_message: str,
    status: QueryStatus = QueryStatus.FAILED,
) -> QueryResult:
    """
    Build a ``QueryResult`` describing a query that did not succeed.

    The result has no statements, a zero total execution time, and carries
    ``error_message`` together with the given ``status`` (``FAILED`` unless
    the caller passes another status such as ``TIMED_OUT``).
    """
    failure = QueryResult(
        status=status,
        statements=[],
        query_id=None,
        total_execution_time_ms=0,
        is_cached=False,
        error_message=error_message,
    )
    return failure
def _mock_database(
id: int = 1,
database_name: str = "test_db",
@@ -58,26 +129,6 @@ def _mock_database(
database.id = id
database.database_name = database_name
database.allow_dml = allow_dml
# Mock raw connection context manager
mock_cursor = Mock()
mock_cursor.description = [
("id", "INTEGER", None, None, None, None, False),
("name", "VARCHAR", None, None, None, None, True),
]
mock_cursor.fetchmany.return_value = [(1, "test_name")]
mock_cursor.rowcount = 1
mock_conn = Mock()
mock_conn.cursor.return_value = mock_cursor
mock_conn.commit = Mock()
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_conn
mock_context.__exit__.return_value = None
database.get_raw_connection.return_value = mock_context
return database
@@ -91,8 +142,11 @@ class TestExecuteSql:
self, mock_db, mock_security_manager, mcp_server
):
"""Test basic SELECT query execution."""
# Setup mocks
mock_database = _mock_database()
mock_database.execute.return_value = _create_select_result(
rows=[{"id": 1, "name": "test_name"}],
columns=["id", "name"],
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -107,25 +161,41 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.row_count == 1
assert len(result.data.rows) == 1
assert result.data.rows[0]["id"] == 1
assert result.data.rows[0]["name"] == "test_name"
assert len(result.data.columns) == 2
assert result.data.columns[0].name == "id"
assert result.data.columns[0].type == "INTEGER"
assert result.data.execution_time > 0
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
assert data["error"] is None
assert data["row_count"] == 1
assert len(data["rows"]) == 1
assert data["rows"][0]["id"] == 1
assert data["rows"][0]["name"] == "test_name"
assert len(data["columns"]) == 2
assert data["columns"][0]["name"] == "id"
assert data["execution_time"] > 0
# Verify Database.execute() was called with correct QueryOptions
mock_database.execute.assert_called_once()
call_args = mock_database.execute.call_args
assert call_args[0][0] == request["sql"]
options = call_args[0][1]
assert options.limit == 10
# Caching is enabled by default (force_refresh=False means cache=None)
assert options.cache is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_with_parameters(
async def test_execute_sql_with_template_params(
self, mock_db, mock_security_manager, mcp_server
):
"""Test SQL execution with parameter substitution."""
"""Test SQL execution with Jinja2 template parameters."""
mock_database = _mock_database()
mock_database.execute.return_value = _create_select_result(
rows=[{"order_id": 1, "status": "active"}],
columns=["order_id", "status"],
original_sql="SELECT * FROM {{ table }} WHERE status = '{{ status }}'",
executed_sql="SELECT * FROM orders WHERE status = 'active'",
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -133,33 +203,42 @@ class TestExecuteSql:
request = {
"database_id": 1,
"sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT {limit}",
"parameters": {"table": "orders", "status": "active", "limit": "5"},
"sql": "SELECT * FROM {{ table }} WHERE status = '{{ status }}'",
"template_params": {"table": "orders", "status": "active"},
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
# Verify parameter substitution happened
mock_database.get_raw_connection.assert_called_once()
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
# Check that the SQL was formatted with parameters
executed_sql = cursor.execute.call_args[0][0]
assert "orders" in executed_sql
assert "active" in executed_sql
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
assert data["error"] is None
# Verify template_params were passed to QueryOptions
call_args = mock_database.execute.call_args
options = call_args[0][1]
assert options.template_params == {"table": "orders", "status": "active"}
# Verify statements contain both original and executed SQL
assert data["statements"] is not None
assert len(data["statements"]) == 1
assert "{{ table }}" in data["statements"][0]["original_sql"]
assert "orders" in data["statements"][0]["executed_sql"]
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_database_not_found(
self, mock_db, mock_security_manager, mcp_server
self,
mock_db,
mock_security_manager, # noqa: PT019
mcp_server,
):
"""Test error when database is not found."""
# mock_security_manager is patched but not used (error happens first)
del mock_security_manager # Silence unused variable warning
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
None
)
@@ -171,15 +250,10 @@ class TestExecuteSql:
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
with pytest.raises(ToolError, match="Database with ID 999 not found"):
await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Database with ID 999 not found" in result.data.error
assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR"
assert result.data.rows is None
@patch("superset.security_manager")
@patch("superset.security_manager", new_callable=MagicMock)
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_access_denied(
@@ -190,10 +264,7 @@ class TestExecuteSql:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
# Use Mock instead of AsyncMock for synchronous call
from unittest.mock import Mock
mock_security_manager.can_access_database = Mock(return_value=False)
mock_security_manager.can_access_database.return_value = False
request = {
"database_id": 1,
@@ -202,58 +273,27 @@ class TestExecuteSql:
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Access denied to database" in result.data.error
assert result.data.error_type == "SECURITY_ERROR"
with pytest.raises(ToolError, match="Access denied to database"):
await client.call_tool("execute_sql", {"request": request})
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_dml_not_allowed(
async def test_execute_sql_dml_success(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when DML operations are not allowed."""
mock_database = _mock_database(allow_dml=False)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "UPDATE users SET name = 'test' WHERE id = 1",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert result.data.error_type == "DML_NOT_ALLOWED"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_dml_allowed(
self, mock_db, mock_security_manager, mcp_server
):
"""Test successful DML execution when allowed."""
"""Test successful DML execution."""
mock_database = _mock_database(allow_dml=True)
dml_sql = "UPDATE users SET active = true WHERE last_login > '2024-01-01'"
mock_database.execute.return_value = _create_dml_result(
affected_rows=3,
original_sql=dml_sql,
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock cursor for DML operation
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.rowcount = 3 # 3 rows affected
request = {
"database_id": 1,
"sql": "UPDATE users SET active = true WHERE last_login > '2024-01-01'",
@@ -263,15 +303,13 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.affected_rows == 3
assert result.data.rows == [] # Empty rows for DML
assert result.data.row_count == 0
# Verify commit was called
(
mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once()
)
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
assert data["error"] is None
assert data["affected_rows"] == 3
assert data["rows"] is None # None for DML
assert data["row_count"] is None
@patch("superset.security_manager")
@patch("superset.db")
@@ -281,17 +319,15 @@ class TestExecuteSql:
):
"""Test query that returns no results."""
mock_database = _mock_database()
mock_database.execute.return_value = _create_select_result(
rows=[],
columns=["id", "name"],
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock empty results
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.fetchmany.return_value = []
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 999999",
@@ -301,77 +337,26 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.row_count == 0
assert len(result.data.rows) == 0
assert len(result.data.columns) == 2 # Column metadata still returned
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
assert data["error"] is None
assert data["row_count"] == 0
assert len(data["rows"]) == 0
assert len(data["columns"]) == 2 # Column metadata still returned
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_missing_parameter(
async def test_execute_sql_with_schema_and_catalog(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when required parameter is missing."""
"""Test SQL execution with schema and catalog specification."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
mock_database.execute.return_value = _create_select_result(
rows=[{"total": 100}],
columns=["total"],
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table_name} WHERE id = {user_id}",
"parameters": {"table_name": "users"}, # Missing user_id
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "user_id" in result.data.error # Error contains parameter name
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_empty_parameters_with_placeholders(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when empty parameters dict is provided but SQL has
placeholders."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table_name} LIMIT 5",
"parameters": {}, # Empty dict but SQL has {table_name}
"limit": 5,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Missing parameter: table_name" in result.data.error
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_with_schema(
self, mock_db, mock_security_manager, mcp_server
):
"""Test SQL execution with schema specification."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -381,28 +366,47 @@ class TestExecuteSql:
"database_id": 1,
"sql": "SELECT COUNT(*) as total FROM orders",
"schema": "sales",
"catalog": "prod_catalog",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
# Verify schema was passed to get_raw_connection
# Verify schema was passed
call_args = mock_database.get_raw_connection.call_args
assert call_args[1]["schema"] == "sales"
assert call_args[1]["catalog"] is None
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
# Verify schema and catalog were passed to QueryOptions
call_args = mock_database.execute.call_args
options = call_args[0][1]
assert options.schema == "sales"
assert options.catalog == "prod_catalog"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_limit_enforcement(
async def test_execute_sql_dry_run(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that LIMIT is added to SELECT queries without one."""
"""Test dry_run mode returns transformed SQL without executing."""
mock_database = _mock_database()
executed_sql = "SELECT * FROM users WHERE user_id IN (SELECT ...) LIMIT 100"
mock_database.execute.return_value = QueryResult(
status=QueryStatus.SUCCESS,
statements=[
StatementResult(
original_sql="SELECT * FROM {{ table }}",
executed_sql=executed_sql,
data=None,
row_count=0,
execution_time_ms=0,
)
],
query_id=None,
total_execution_time_ms=0,
is_cached=False,
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -410,25 +414,32 @@ class TestExecuteSql:
request = {
"database_id": 1,
"sql": "SELECT * FROM users", # No LIMIT
"limit": 50,
"sql": "SELECT * FROM {{ table }}",
"template_params": {"table": "users"},
"dry_run": True,
"limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
# Verify LIMIT was added
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
executed_sql = cursor.execute.call_args[0][0]
assert "LIMIT 50" in executed_sql
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is True
# Verify dry_run was passed
call_args = mock_database.execute.call_args
options = call_args[0][1]
assert options.dry_run is True
# Verify statements show transformed SQL
assert data["statements"] is not None
assert "{{ table }}" in data["statements"][0]["original_sql"]
assert "users" in data["statements"][0]["executed_sql"]
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_sql_injection_prevention(
async def test_execute_sql_timeout_error(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that SQL injection attempts are handled safely.
@@ -437,6 +448,10 @@ class TestExecuteSql:
before execution when DML is not allowed on the database.
"""
mock_database = _mock_database()
mock_database.execute.return_value = _create_error_result(
error_message="Query exceeded the timeout limit",
status=QueryStatus.TIMED_OUT,
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -444,19 +459,76 @@ class TestExecuteSql:
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
"sql": "SELECT * FROM large_table",
"timeout": 5,
"limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
# Use structured_content for dictionary access (Pydantic model responses)
data = result.structured_content
assert data["success"] is False
assert data["error"] == "Query exceeded the timeout limit"
assert data["error_type"] == "timed_out"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_multi_statement(
    self, mock_db, mock_security_manager, mcp_server
):
    """
    Test multi-statement SQL execution.

    ``Database.execute()`` is mocked to return a successful ``QueryResult``
    with two statements. The tool response must list both statements while
    exposing the first statement's rows through the top-level ``rows`` /
    ``row_count`` fields.
    """
    mock_database = _mock_database()
    mock_database.execute.return_value = QueryResult(
        status=QueryStatus.SUCCESS,
        statements=[
            StatementResult(
                original_sql="SELECT 1 as a",
                executed_sql="SELECT 1 as a",
                data=pd.DataFrame([{"a": 1}]),
                row_count=1,
                execution_time_ms=5.0,
            ),
            StatementResult(
                original_sql="SELECT 2 as b",
                executed_sql="SELECT 2 as b",
                data=pd.DataFrame([{"b": 2}]),
                row_count=1,
                execution_time_ms=3.0,
            ),
        ],
        query_id=None,
        total_execution_time_ms=8.0,
        is_cached=False,
    )
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
        mock_database
    )
    mock_security_manager.can_access_database.return_value = True

    request = {
        "database_id": 1,
        "sql": "SELECT 1 as a; SELECT 2 as b;",
        "limit": 10,
    }

    async with Client(mcp_server) as client:
        result = await client.call_tool("execute_sql", {"request": request})

    # NOTE: the failure assertions inherited from the former SQL-injection
    # test (success is False / DML_NOT_ALLOWED) were dropped here: they
    # contradicted the success expectations checked below.

    # Use structured_content for dictionary access (Pydantic model responses)
    data = result.structured_content
    assert data["success"] is True

    # Statements should contain both
    assert data["statements"] is not None
    assert len(data["statements"]) == 2
    assert data["statements"][0]["original_sql"] == "SELECT 1 as a"
    assert data["statements"][1]["original_sql"] == "SELECT 2 as b"

    # rows/columns should be from first statement for backward compat
    assert data["rows"] == [{"a": 1}]
    assert data["row_count"] == 1
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):
@@ -495,3 +567,39 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="less than or equal to 10000"):
await client.call_tool("execute_sql", {"request": request})
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_force_refresh(
    self, mock_db, mock_security_manager, mcp_server
):
    """
    Check that ``force_refresh`` in the request reaches ``Database.execute()``.

    The options object passed as the second positional argument must carry a
    ``cache`` entry whose ``force_refresh`` flag is set.
    """
    database = _mock_database()
    database.execute.return_value = _create_select_result(
        rows=[{"id": 1}],
        columns=["id"],
    )
    # Database lookup performed by the tool resolves to the mocked database.
    lookup = mock_db.session.query.return_value.filter_by.return_value.first
    lookup.return_value = database
    mock_security_manager.can_access_database.return_value = True

    payload = {
        "database_id": 1,
        "sql": "SELECT id FROM users",
        "limit": 10,
        "force_refresh": True,
    }

    async with Client(mcp_server) as client:
        response = await client.call_tool("execute_sql", {"request": payload})

    body = response.structured_content
    assert body["success"] is True

    # The cache settings handed to execute() must request a refresh.
    query_options = database.execute.call_args.args[1]
    assert query_options.cache is not None
    assert query_options.cache.force_refresh is True