mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
fix(mcp): use SQLScript for all SQL parsing in execute_sql (#36599)
This commit is contained in:
137
tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
Normal file
137
tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for MCP SQL Lab utility functions."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit
|
||||
from superset.sql.parse import SQLScript
|
||||
|
||||
|
||||
class TestSQLScriptMutationDetection:
    """Checks on ``SQLScript.has_mutation()`` as used in query type detection.

    Each test builds a ``SQLScript`` from a SQL string, passing ``"sqlite"``
    as the second argument, and asserts the exact boolean returned by
    ``has_mutation()``: ``False`` for SELECT and CTE queries, ``True`` for
    INSERT, UPDATE, DELETE and CREATE statements.
    """

    def test_simple_select_no_mutation(self):
        """A plain SELECT must not be flagged as a mutation."""
        mutates = SQLScript("SELECT * FROM table1", "sqlite").has_mutation()
        assert mutates is False

    def test_cte_query_no_mutation(self):
        """A SELECT fed by a single WITH clause must not be flagged."""
        statement = """
        WITH cte_name AS (
            SELECT * FROM table1
        )
        SELECT * FROM cte_name
        """
        mutates = SQLScript(statement, "sqlite").has_mutation()
        assert mutates is False

    def test_recursive_cte_no_mutation(self):
        """A WITH RECURSIVE query must not be flagged as a mutation."""
        statement = """
        WITH RECURSIVE cte AS (
            SELECT 1 AS n
            UNION ALL
            SELECT n + 1 FROM cte WHERE n < 10
        )
        SELECT n FROM cte
        """
        mutates = SQLScript(statement, "sqlite").has_mutation()
        assert mutates is False

    def test_multiple_ctes_no_mutation(self):
        """A query that declares several CTEs must not be flagged."""
        statement = """
        WITH
            cte1 AS (SELECT 1 as a),
            cte2 AS (SELECT 2 as b)
        SELECT * FROM cte1, cte2
        """
        mutates = SQLScript(statement, "sqlite").has_mutation()
        assert mutates is False

    def test_insert_is_mutation(self):
        """An INSERT statement must be flagged as a mutation."""
        mutates = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite").has_mutation()
        assert mutates is True

    def test_update_is_mutation(self):
        """An UPDATE statement must be flagged as a mutation."""
        mutates = SQLScript("UPDATE table1 SET col = 1", "sqlite").has_mutation()
        assert mutates is True

    def test_delete_is_mutation(self):
        """A DELETE statement must be flagged as a mutation."""
        mutates = SQLScript("DELETE FROM table1", "sqlite").has_mutation()
        assert mutates is True

    def test_create_is_mutation(self):
        """A CREATE TABLE statement must be flagged as a mutation."""
        mutates = SQLScript("CREATE TABLE table1 (id INT)", "sqlite").has_mutation()
        assert mutates is True
|
||||
|
||||
|
||||
class TestApplyLimit:
    """Checks on ``_apply_limit``, called as ``_apply_limit(sql, limit, database)``.

    The tests assert that a LIMIT clause shows up in the result for SELECT
    and CTE queries that lack one, that a LIMIT already present is kept, and
    that INSERT and UPDATE statements come back unchanged.
    """

    @pytest.fixture
    def mock_database(self):
        """Provide a ``MagicMock`` database reporting ``"sqlite"`` as its engine."""
        database = MagicMock()
        database.db_engine_spec.engine = "sqlite"
        return database

    def test_adds_limit_to_select(self, mock_database):
        """A SELECT without a LIMIT gets the requested one."""
        limited = _apply_limit("SELECT * FROM table1", 100, mock_database)
        assert "LIMIT 100" in limited

    def test_adds_limit_to_cte(self, mock_database):
        """A CTE query without a LIMIT gets the requested one."""
        statement = "WITH cte AS (SELECT 1) SELECT * FROM cte"
        limited = _apply_limit(statement, 50, mock_database)
        assert "LIMIT 50" in limited

    def test_preserves_existing_limit(self, mock_database):
        """A LIMIT already on a SELECT is kept rather than replaced."""
        limited = _apply_limit("SELECT * FROM table1 LIMIT 10", 100, mock_database)
        assert "LIMIT 10" in limited
        # "LIMIT 10" is a prefix of "LIMIT 100", so this second check is what
        # actually rules out the larger limit having been applied.
        assert "LIMIT 100" not in limited

    def test_preserves_existing_limit_in_cte(self, mock_database):
        """A LIMIT already on a CTE query is kept rather than replaced."""
        statement = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5"
        limited = _apply_limit(statement, 100, mock_database)
        assert "LIMIT 5" in limited
        assert "LIMIT 100" not in limited

    def test_no_limit_on_insert(self, mock_database):
        """An INSERT statement is returned exactly as given."""
        statement = "INSERT INTO table1 VALUES (1)"
        assert _apply_limit(statement, 100, mock_database) == statement

    def test_no_limit_on_update(self, mock_database):
        """An UPDATE statement is returned exactly as given."""
        statement = "UPDATE table1 SET col = 1"
        assert _apply_limit(statement, 100, mock_database) == statement
|
||||
@@ -431,19 +431,17 @@ class TestExecuteSql:
|
||||
async def test_execute_sql_sql_injection_prevention(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that SQL injection attempts are handled safely."""
|
||||
"""Test that SQL injection attempts are handled safely.
|
||||
|
||||
SQLScript detects the DROP TABLE as a mutation and blocks it
|
||||
before execution when DML is not allowed on the database.
|
||||
"""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
# Mock execute to raise an exception
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
cursor.execute.side_effect = Exception("Syntax error")
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
|
||||
@@ -453,10 +451,12 @@ class TestExecuteSql:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
# SQLScript correctly detects DROP TABLE as a mutation
|
||||
# and blocks it before execution (improved security)
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "Syntax error" in result.data.error # Contains actual error
|
||||
assert result.data.error_type == "EXECUTION_ERROR"
|
||||
assert "DML" in result.data.error or "mutates" in result.data.error
|
||||
assert result.data.error_type == "DML_NOT_ALLOWED"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_empty_query_validation(self, mcp_server):
|
||||
|
||||
Reference in New Issue
Block a user