mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
refactor: Migrates the MCP execute_sql tool to use the SQL execution API (#36739)
Co-authored-by: codeant-ai-for-open-source[bot] <244253245+codeant-ai-for-open-source[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c0bcf28947
commit
6b25d0663e
@@ -1,137 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for MCP SQL Lab utility functions."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit
|
||||
from superset.sql.parse import SQLScript
|
||||
|
||||
|
||||
class TestSQLScriptMutationDetection:
|
||||
"""Tests for SQLScript.has_mutation() used in query type detection."""
|
||||
|
||||
def test_simple_select_no_mutation(self):
|
||||
"""Test simple SELECT query is not a mutation."""
|
||||
script = SQLScript("SELECT * FROM table1", "sqlite")
|
||||
assert script.has_mutation() is False
|
||||
|
||||
def test_cte_query_no_mutation(self):
|
||||
"""Test CTE (WITH clause) query is not a mutation."""
|
||||
cte_sql = """
|
||||
WITH cte_name AS (
|
||||
SELECT * FROM table1
|
||||
)
|
||||
SELECT * FROM cte_name
|
||||
"""
|
||||
script = SQLScript(cte_sql, "sqlite")
|
||||
assert script.has_mutation() is False
|
||||
|
||||
def test_recursive_cte_no_mutation(self):
|
||||
"""Test recursive CTE is not a mutation."""
|
||||
recursive_cte = """
|
||||
WITH RECURSIVE cte AS (
|
||||
SELECT 1 AS n
|
||||
UNION ALL
|
||||
SELECT n + 1 FROM cte WHERE n < 10
|
||||
)
|
||||
SELECT n FROM cte
|
||||
"""
|
||||
script = SQLScript(recursive_cte, "sqlite")
|
||||
assert script.has_mutation() is False
|
||||
|
||||
def test_multiple_ctes_no_mutation(self):
|
||||
"""Test query with multiple CTEs is not a mutation."""
|
||||
multiple_ctes = """
|
||||
WITH
|
||||
cte1 AS (SELECT 1 as a),
|
||||
cte2 AS (SELECT 2 as b)
|
||||
SELECT * FROM cte1, cte2
|
||||
"""
|
||||
script = SQLScript(multiple_ctes, "sqlite")
|
||||
assert script.has_mutation() is False
|
||||
|
||||
def test_insert_is_mutation(self):
|
||||
"""Test INSERT query is a mutation."""
|
||||
script = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite")
|
||||
assert script.has_mutation() is True
|
||||
|
||||
def test_update_is_mutation(self):
|
||||
"""Test UPDATE query is a mutation."""
|
||||
script = SQLScript("UPDATE table1 SET col = 1", "sqlite")
|
||||
assert script.has_mutation() is True
|
||||
|
||||
def test_delete_is_mutation(self):
|
||||
"""Test DELETE query is a mutation."""
|
||||
script = SQLScript("DELETE FROM table1", "sqlite")
|
||||
assert script.has_mutation() is True
|
||||
|
||||
def test_create_is_mutation(self):
|
||||
"""Test CREATE query is a mutation."""
|
||||
script = SQLScript("CREATE TABLE table1 (id INT)", "sqlite")
|
||||
assert script.has_mutation() is True
|
||||
|
||||
|
||||
class TestApplyLimit:
|
||||
"""Tests for _apply_limit function using SQLScript."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database(self):
|
||||
"""Create a mock database with sqlite engine spec."""
|
||||
db = MagicMock()
|
||||
db.db_engine_spec.engine = "sqlite"
|
||||
return db
|
||||
|
||||
def test_adds_limit_to_select(self, mock_database):
|
||||
"""Test LIMIT is added to SELECT query."""
|
||||
result = _apply_limit("SELECT * FROM table1", 100, mock_database)
|
||||
assert "LIMIT 100" in result
|
||||
|
||||
def test_adds_limit_to_cte(self, mock_database):
|
||||
"""Test LIMIT is added to CTE query."""
|
||||
cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte"
|
||||
result = _apply_limit(cte_sql, 50, mock_database)
|
||||
assert "LIMIT 50" in result
|
||||
|
||||
def test_preserves_existing_limit(self, mock_database):
|
||||
"""Test existing LIMIT is not modified."""
|
||||
sql = "SELECT * FROM table1 LIMIT 10"
|
||||
result = _apply_limit(sql, 100, mock_database)
|
||||
assert "LIMIT 10" in result
|
||||
assert "LIMIT 100" not in result
|
||||
|
||||
def test_preserves_existing_limit_in_cte(self, mock_database):
|
||||
"""Test existing LIMIT in CTE query is not modified."""
|
||||
cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5"
|
||||
result = _apply_limit(cte_sql, 100, mock_database)
|
||||
assert "LIMIT 5" in result
|
||||
assert "LIMIT 100" not in result
|
||||
|
||||
def test_no_limit_on_insert(self, mock_database):
|
||||
"""Test LIMIT is not added to INSERT query."""
|
||||
sql = "INSERT INTO table1 VALUES (1)"
|
||||
result = _apply_limit(sql, 100, mock_database)
|
||||
assert result == sql
|
||||
|
||||
def test_no_limit_on_update(self, mock_database):
|
||||
"""Test LIMIT is not added to UPDATE query."""
|
||||
sql = "UPDATE table1 SET col = 1"
|
||||
result = _apply_limit(sql, 100, mock_database)
|
||||
assert result == sql
|
||||
@@ -17,14 +17,20 @@
|
||||
|
||||
"""
|
||||
Unit tests for execute_sql MCP tool
|
||||
|
||||
These tests mock Database.execute() to test the MCP tool's parameter mapping
|
||||
and response conversion logic.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
from superset_core.api.types import QueryResult, QueryStatus, StatementResult
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
|
||||
@@ -48,6 +54,71 @@ def mock_auth():
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _create_select_result(
|
||||
rows: list[dict[str, Any]],
|
||||
columns: list[str],
|
||||
original_sql: str = "SELECT * FROM users",
|
||||
executed_sql: str | None = None,
|
||||
execution_time_ms: float = 10.0,
|
||||
) -> QueryResult:
|
||||
"""Create a mock QueryResult for SELECT queries."""
|
||||
df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=columns)
|
||||
return QueryResult(
|
||||
status=QueryStatus.SUCCESS,
|
||||
statements=[
|
||||
StatementResult(
|
||||
original_sql=original_sql,
|
||||
executed_sql=executed_sql or original_sql,
|
||||
data=df,
|
||||
row_count=len(df),
|
||||
execution_time_ms=execution_time_ms,
|
||||
)
|
||||
],
|
||||
query_id=None,
|
||||
total_execution_time_ms=execution_time_ms,
|
||||
is_cached=False,
|
||||
)
|
||||
|
||||
|
||||
def _create_dml_result(
|
||||
affected_rows: int,
|
||||
original_sql: str = "UPDATE users SET active = true",
|
||||
executed_sql: str | None = None,
|
||||
execution_time_ms: float = 5.0,
|
||||
) -> QueryResult:
|
||||
"""Create a mock QueryResult for DML queries."""
|
||||
return QueryResult(
|
||||
status=QueryStatus.SUCCESS,
|
||||
statements=[
|
||||
StatementResult(
|
||||
original_sql=original_sql,
|
||||
executed_sql=executed_sql or original_sql,
|
||||
data=None,
|
||||
row_count=affected_rows,
|
||||
execution_time_ms=execution_time_ms,
|
||||
)
|
||||
],
|
||||
query_id=None,
|
||||
total_execution_time_ms=execution_time_ms,
|
||||
is_cached=False,
|
||||
)
|
||||
|
||||
|
||||
def _create_error_result(
|
||||
error_message: str,
|
||||
status: QueryStatus = QueryStatus.FAILED,
|
||||
) -> QueryResult:
|
||||
"""Create a mock QueryResult for failed queries."""
|
||||
return QueryResult(
|
||||
status=status,
|
||||
statements=[],
|
||||
query_id=None,
|
||||
total_execution_time_ms=0,
|
||||
is_cached=False,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
|
||||
def _mock_database(
|
||||
id: int = 1,
|
||||
database_name: str = "test_db",
|
||||
@@ -58,26 +129,6 @@ def _mock_database(
|
||||
database.id = id
|
||||
database.database_name = database_name
|
||||
database.allow_dml = allow_dml
|
||||
|
||||
# Mock raw connection context manager
|
||||
mock_cursor = Mock()
|
||||
mock_cursor.description = [
|
||||
("id", "INTEGER", None, None, None, None, False),
|
||||
("name", "VARCHAR", None, None, None, None, True),
|
||||
]
|
||||
mock_cursor.fetchmany.return_value = [(1, "test_name")]
|
||||
mock_cursor.rowcount = 1
|
||||
|
||||
mock_conn = Mock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_conn.commit = Mock()
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_conn
|
||||
mock_context.__exit__.return_value = None
|
||||
|
||||
database.get_raw_connection.return_value = mock_context
|
||||
|
||||
return database
|
||||
|
||||
|
||||
@@ -91,8 +142,11 @@ class TestExecuteSql:
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test basic SELECT query execution."""
|
||||
# Setup mocks
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.return_value = _create_select_result(
|
||||
rows=[{"id": 1, "name": "test_name"}],
|
||||
columns=["id", "name"],
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
@@ -107,25 +161,41 @@ class TestExecuteSql:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
assert result.data.row_count == 1
|
||||
assert len(result.data.rows) == 1
|
||||
assert result.data.rows[0]["id"] == 1
|
||||
assert result.data.rows[0]["name"] == "test_name"
|
||||
assert len(result.data.columns) == 2
|
||||
assert result.data.columns[0].name == "id"
|
||||
assert result.data.columns[0].type == "INTEGER"
|
||||
assert result.data.execution_time > 0
|
||||
# Use structured_content for dictionary access (Pydantic model responses)
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
assert data["error"] is None
|
||||
assert data["row_count"] == 1
|
||||
assert len(data["rows"]) == 1
|
||||
assert data["rows"][0]["id"] == 1
|
||||
assert data["rows"][0]["name"] == "test_name"
|
||||
assert len(data["columns"]) == 2
|
||||
assert data["columns"][0]["name"] == "id"
|
||||
assert data["execution_time"] > 0
|
||||
|
||||
# Verify Database.execute() was called with correct QueryOptions
|
||||
mock_database.execute.assert_called_once()
|
||||
call_args = mock_database.execute.call_args
|
||||
assert call_args[0][0] == request["sql"]
|
||||
options = call_args[0][1]
|
||||
assert options.limit == 10
|
||||
# Caching is enabled by default (force_refresh=False means cache=None)
|
||||
assert options.cache is None
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_with_parameters(
|
||||
async def test_execute_sql_with_template_params(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test SQL execution with parameter substitution."""
|
||||
"""Test SQL execution with Jinja2 template parameters."""
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.return_value = _create_select_result(
|
||||
rows=[{"order_id": 1, "status": "active"}],
|
||||
columns=["order_id", "status"],
|
||||
original_sql="SELECT * FROM {{ table }} WHERE status = '{{ status }}'",
|
||||
executed_sql="SELECT * FROM orders WHERE status = 'active'",
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
@@ -133,33 +203,42 @@ class TestExecuteSql:
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT {limit}",
|
||||
"parameters": {"table": "orders", "status": "active", "limit": "5"},
|
||||
"sql": "SELECT * FROM {{ table }} WHERE status = '{{ status }}'",
|
||||
"template_params": {"table": "orders", "status": "active"},
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
# Verify parameter substitution happened
|
||||
mock_database.get_raw_connection.assert_called_once()
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
# Check that the SQL was formatted with parameters
|
||||
executed_sql = cursor.execute.call_args[0][0]
|
||||
assert "orders" in executed_sql
|
||||
assert "active" in executed_sql
|
||||
# Use structured_content for dictionary access (Pydantic model responses)
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
assert data["error"] is None
|
||||
|
||||
# Verify template_params were passed to QueryOptions
|
||||
call_args = mock_database.execute.call_args
|
||||
options = call_args[0][1]
|
||||
assert options.template_params == {"table": "orders", "status": "active"}
|
||||
|
||||
# Verify statements contain both original and executed SQL
|
||||
assert data["statements"] is not None
|
||||
assert len(data["statements"]) == 1
|
||||
assert "{{ table }}" in data["statements"][0]["original_sql"]
|
||||
assert "orders" in data["statements"][0]["executed_sql"]
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_database_not_found(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
self,
|
||||
mock_db,
|
||||
mock_security_manager, # noqa: PT019
|
||||
mcp_server,
|
||||
):
|
||||
"""Test error when database is not found."""
|
||||
# mock_security_manager is patched but not used (error happens first)
|
||||
del mock_security_manager # Silence unused variable warning
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
None
|
||||
)
|
||||
@@ -171,15 +250,10 @@ class TestExecuteSql:
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
with pytest.raises(ToolError, match="Database with ID 999 not found"):
|
||||
await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "Database with ID 999 not found" in result.data.error
|
||||
assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR"
|
||||
assert result.data.rows is None
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.security_manager", new_callable=MagicMock)
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_access_denied(
|
||||
@@ -190,10 +264,7 @@ class TestExecuteSql:
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
# Use Mock instead of AsyncMock for synchronous call
|
||||
from unittest.mock import Mock
|
||||
|
||||
mock_security_manager.can_access_database = Mock(return_value=False)
|
||||
mock_security_manager.can_access_database.return_value = False
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
@@ -202,58 +273,27 @@ class TestExecuteSql:
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "Access denied to database" in result.data.error
|
||||
assert result.data.error_type == "SECURITY_ERROR"
|
||||
with pytest.raises(ToolError, match="Access denied to database"):
|
||||
await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_dml_not_allowed(
|
||||
async def test_execute_sql_dml_success(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test error when DML operations are not allowed."""
|
||||
mock_database = _mock_database(allow_dml=False)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "UPDATE users SET name = 'test' WHERE id = 1",
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert result.data.error_type == "DML_NOT_ALLOWED"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_dml_allowed(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test successful DML execution when allowed."""
|
||||
"""Test successful DML execution."""
|
||||
mock_database = _mock_database(allow_dml=True)
|
||||
dml_sql = "UPDATE users SET active = true WHERE last_login > '2024-01-01'"
|
||||
mock_database.execute.return_value = _create_dml_result(
|
||||
affected_rows=3,
|
||||
original_sql=dml_sql,
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
# Mock cursor for DML operation
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
cursor.rowcount = 3 # 3 rows affected
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "UPDATE users SET active = true WHERE last_login > '2024-01-01'",
|
||||
@@ -263,15 +303,13 @@ class TestExecuteSql:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
assert result.data.affected_rows == 3
|
||||
assert result.data.rows == [] # Empty rows for DML
|
||||
assert result.data.row_count == 0
|
||||
# Verify commit was called
|
||||
(
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once()
|
||||
)
|
||||
# Use structured_content for dictionary access (Pydantic model responses)
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
assert data["error"] is None
|
||||
assert data["affected_rows"] == 3
|
||||
assert data["rows"] is None # None for DML
|
||||
assert data["row_count"] is None
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@@ -281,17 +319,15 @@ class TestExecuteSql:
|
||||
):
|
||||
"""Test query that returns no results."""
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.return_value = _create_select_result(
|
||||
rows=[],
|
||||
columns=["id", "name"],
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
# Mock empty results
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
cursor.fetchmany.return_value = []
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM users WHERE id = 999999",
|
||||
@@ -301,77 +337,26 @@ class TestExecuteSql:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
assert result.data.row_count == 0
|
||||
assert len(result.data.rows) == 0
|
||||
assert len(result.data.columns) == 2 # Column metadata still returned
|
||||
# Use structured_content for dictionary access (Pydantic model responses)
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
assert data["error"] is None
|
||||
assert data["row_count"] == 0
|
||||
assert len(data["rows"]) == 0
|
||||
assert len(data["columns"]) == 2 # Column metadata still returned
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_missing_parameter(
|
||||
async def test_execute_sql_with_schema_and_catalog(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test error when required parameter is missing."""
|
||||
"""Test SQL execution with schema and catalog specification."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
mock_database.execute.return_value = _create_select_result(
|
||||
rows=[{"total": 100}],
|
||||
columns=["total"],
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM {table_name} WHERE id = {user_id}",
|
||||
"parameters": {"table_name": "users"}, # Missing user_id
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "user_id" in result.data.error # Error contains parameter name
|
||||
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_empty_parameters_with_placeholders(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test error when empty parameters dict is provided but SQL has
|
||||
placeholders."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM {table_name} LIMIT 5",
|
||||
"parameters": {}, # Empty dict but SQL has {table_name}
|
||||
"limit": 5,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "Missing parameter: table_name" in result.data.error
|
||||
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_with_schema(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test SQL execution with schema specification."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
@@ -381,28 +366,47 @@ class TestExecuteSql:
|
||||
"database_id": 1,
|
||||
"sql": "SELECT COUNT(*) as total FROM orders",
|
||||
"schema": "sales",
|
||||
"catalog": "prod_catalog",
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
# Verify schema was passed to get_raw_connection
|
||||
# Verify schema was passed
|
||||
call_args = mock_database.get_raw_connection.call_args
|
||||
assert call_args[1]["schema"] == "sales"
|
||||
assert call_args[1]["catalog"] is None
|
||||
# Use structured_content for dictionary access (Pydantic model responses)
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
|
||||
# Verify schema and catalog were passed to QueryOptions
|
||||
call_args = mock_database.execute.call_args
|
||||
options = call_args[0][1]
|
||||
assert options.schema == "sales"
|
||||
assert options.catalog == "prod_catalog"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_limit_enforcement(
|
||||
async def test_execute_sql_dry_run(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that LIMIT is added to SELECT queries without one."""
|
||||
"""Test dry_run mode returns transformed SQL without executing."""
|
||||
mock_database = _mock_database()
|
||||
executed_sql = "SELECT * FROM users WHERE user_id IN (SELECT ...) LIMIT 100"
|
||||
mock_database.execute.return_value = QueryResult(
|
||||
status=QueryStatus.SUCCESS,
|
||||
statements=[
|
||||
StatementResult(
|
||||
original_sql="SELECT * FROM {{ table }}",
|
||||
executed_sql=executed_sql,
|
||||
data=None,
|
||||
row_count=0,
|
||||
execution_time_ms=0,
|
||||
)
|
||||
],
|
||||
query_id=None,
|
||||
total_execution_time_ms=0,
|
||||
is_cached=False,
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
@@ -410,25 +414,32 @@ class TestExecuteSql:
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM users", # No LIMIT
|
||||
"limit": 50,
|
||||
"sql": "SELECT * FROM {{ table }}",
|
||||
"template_params": {"table": "users"},
|
||||
"dry_run": True,
|
||||
"limit": 100,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
# Verify LIMIT was added
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
executed_sql = cursor.execute.call_args[0][0]
|
||||
assert "LIMIT 50" in executed_sql
|
||||
# Use structured_content for dictionary access (Pydantic model responses)
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
# Verify dry_run was passed
|
||||
call_args = mock_database.execute.call_args
|
||||
options = call_args[0][1]
|
||||
assert options.dry_run is True
|
||||
|
||||
# Verify statements show transformed SQL
|
||||
assert data["statements"] is not None
|
||||
assert "{{ table }}" in data["statements"][0]["original_sql"]
|
||||
assert "users" in data["statements"][0]["executed_sql"]
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_sql_injection_prevention(
|
||||
async def test_execute_sql_timeout_error(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that SQL injection attempts are handled safely.
|
||||
@@ -437,6 +448,10 @@ class TestExecuteSql:
|
||||
before execution when DML is not allowed on the database.
|
||||
"""
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.return_value = _create_error_result(
|
||||
error_message="Query exceeded the timeout limit",
|
||||
status=QueryStatus.TIMED_OUT,
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
@@ -444,19 +459,76 @@ class TestExecuteSql:
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
|
||||
"sql": "SELECT * FROM large_table",
|
||||
"timeout": 5,
|
||||
"limit": 100,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
# Use structured_content for dictionary access (Pydantic model responses)
|
||||
data = result.structured_content
|
||||
assert data["success"] is False
|
||||
assert data["error"] == "Query exceeded the timeout limit"
|
||||
assert data["error_type"] == "timed_out"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_multi_statement(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test multi-statement SQL execution."""
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.return_value = QueryResult(
|
||||
status=QueryStatus.SUCCESS,
|
||||
statements=[
|
||||
StatementResult(
|
||||
original_sql="SELECT 1 as a",
|
||||
executed_sql="SELECT 1 as a",
|
||||
data=pd.DataFrame([{"a": 1}]),
|
||||
row_count=1,
|
||||
execution_time_ms=5.0,
|
||||
),
|
||||
StatementResult(
|
||||
original_sql="SELECT 2 as b",
|
||||
executed_sql="SELECT 2 as b",
|
||||
data=pd.DataFrame([{"b": 2}]),
|
||||
row_count=1,
|
||||
execution_time_ms=3.0,
|
||||
),
|
||||
],
|
||||
query_id=None,
|
||||
total_execution_time_ms=8.0,
|
||||
is_cached=False,
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT 1 as a; SELECT 2 as b;",
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
# SQLScript correctly detects DROP TABLE as a mutation
|
||||
# and blocks it before execution (improved security)
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "DML" in result.data.error or "mutates" in result.data.error
|
||||
assert result.data.error_type == "DML_NOT_ALLOWED"
|
||||
# Use structured_content for dictionary access (Pydantic model responses)
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
# Statements should contain both
|
||||
assert data["statements"] is not None
|
||||
assert len(data["statements"]) == 2
|
||||
assert data["statements"][0]["original_sql"] == "SELECT 1 as a"
|
||||
assert data["statements"][1]["original_sql"] == "SELECT 2 as b"
|
||||
|
||||
# rows/columns should be from first statement for backward compat
|
||||
assert data["rows"] == [{"a": 1}]
|
||||
assert data["row_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_empty_query_validation(self, mcp_server):
|
||||
@@ -495,3 +567,39 @@ class TestExecuteSql:
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="less than or equal to 10000"):
|
||||
await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_force_refresh(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test force_refresh bypasses cache."""
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.return_value = _create_select_result(
|
||||
rows=[{"id": 1}],
|
||||
columns=["id"],
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT id FROM users",
|
||||
"limit": 10,
|
||||
"force_refresh": True,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
|
||||
# Verify force_refresh was passed to CacheOptions
|
||||
call_args = mock_database.execute.call_args
|
||||
options = call_args[0][1]
|
||||
assert options.cache is not None
|
||||
assert options.cache.force_refresh is True
|
||||
|
||||
Reference in New Issue
Block a user