fix(mcp): use SQLScript for all SQL parsing in execute_sql (#36599)

This commit is contained in:
Amin Ghadersohi
2025-12-20 23:52:56 -05:00
committed by GitHub
parent c026ae2ce7
commit e3e6b0e18b
3 changed files with 180 additions and 39 deletions

View File

@@ -0,0 +1,137 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for MCP SQL Lab utility functions."""
from unittest.mock import MagicMock
import pytest
from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit
from superset.sql.parse import SQLScript
class TestSQLScriptMutationDetection:
"""Tests for SQLScript.has_mutation() used in query type detection."""
def test_simple_select_no_mutation(self):
"""Test simple SELECT query is not a mutation."""
script = SQLScript("SELECT * FROM table1", "sqlite")
assert script.has_mutation() is False
def test_cte_query_no_mutation(self):
"""Test CTE (WITH clause) query is not a mutation."""
cte_sql = """
WITH cte_name AS (
SELECT * FROM table1
)
SELECT * FROM cte_name
"""
script = SQLScript(cte_sql, "sqlite")
assert script.has_mutation() is False
def test_recursive_cte_no_mutation(self):
"""Test recursive CTE is not a mutation."""
recursive_cte = """
WITH RECURSIVE cte AS (
SELECT 1 AS n
UNION ALL
SELECT n + 1 FROM cte WHERE n < 10
)
SELECT n FROM cte
"""
script = SQLScript(recursive_cte, "sqlite")
assert script.has_mutation() is False
def test_multiple_ctes_no_mutation(self):
"""Test query with multiple CTEs is not a mutation."""
multiple_ctes = """
WITH
cte1 AS (SELECT 1 as a),
cte2 AS (SELECT 2 as b)
SELECT * FROM cte1, cte2
"""
script = SQLScript(multiple_ctes, "sqlite")
assert script.has_mutation() is False
def test_insert_is_mutation(self):
"""Test INSERT query is a mutation."""
script = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite")
assert script.has_mutation() is True
def test_update_is_mutation(self):
"""Test UPDATE query is a mutation."""
script = SQLScript("UPDATE table1 SET col = 1", "sqlite")
assert script.has_mutation() is True
def test_delete_is_mutation(self):
"""Test DELETE query is a mutation."""
script = SQLScript("DELETE FROM table1", "sqlite")
assert script.has_mutation() is True
def test_create_is_mutation(self):
"""Test CREATE query is a mutation."""
script = SQLScript("CREATE TABLE table1 (id INT)", "sqlite")
assert script.has_mutation() is True
class TestApplyLimit:
"""Tests for _apply_limit function using SQLScript."""
@pytest.fixture
def mock_database(self):
"""Create a mock database with sqlite engine spec."""
db = MagicMock()
db.db_engine_spec.engine = "sqlite"
return db
def test_adds_limit_to_select(self, mock_database):
"""Test LIMIT is added to SELECT query."""
result = _apply_limit("SELECT * FROM table1", 100, mock_database)
assert "LIMIT 100" in result
def test_adds_limit_to_cte(self, mock_database):
"""Test LIMIT is added to CTE query."""
cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte"
result = _apply_limit(cte_sql, 50, mock_database)
assert "LIMIT 50" in result
def test_preserves_existing_limit(self, mock_database):
"""Test existing LIMIT is not modified."""
sql = "SELECT * FROM table1 LIMIT 10"
result = _apply_limit(sql, 100, mock_database)
assert "LIMIT 10" in result
assert "LIMIT 100" not in result
def test_preserves_existing_limit_in_cte(self, mock_database):
"""Test existing LIMIT in CTE query is not modified."""
cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5"
result = _apply_limit(cte_sql, 100, mock_database)
assert "LIMIT 5" in result
assert "LIMIT 100" not in result
def test_no_limit_on_insert(self, mock_database):
"""Test LIMIT is not added to INSERT query."""
sql = "INSERT INTO table1 VALUES (1)"
result = _apply_limit(sql, 100, mock_database)
assert result == sql
def test_no_limit_on_update(self, mock_database):
"""Test LIMIT is not added to UPDATE query."""
sql = "UPDATE table1 SET col = 1"
result = _apply_limit(sql, 100, mock_database)
assert result == sql

View File

@@ -431,19 +431,17 @@ class TestExecuteSql:
async def test_execute_sql_sql_injection_prevention(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that SQL injection attempts are handled safely."""
"""Test that SQL injection attempts are handled safely.
SQLScript detects the DROP TABLE as a mutation and blocks it
before execution when DML is not allowed on the database.
"""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock execute to raise an exception
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.execute.side_effect = Exception("Syntax error")
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
@@ -453,10 +451,12 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
# SQLScript correctly detects DROP TABLE as a mutation
# and blocks it before execution (improved security)
assert result.data.success is False
assert result.data.error is not None
assert "Syntax error" in result.data.error # Contains actual error
assert result.data.error_type == "EXECUTION_ERROR"
assert "DML" in result.data.error or "mutates" in result.data.error
assert result.data.error_type == "DML_NOT_ALLOWED"
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):