diff --git a/superset/mcp_service/sql_lab/sql_lab_utils.py b/superset/mcp_service/sql_lab/sql_lab_utils.py index 6844e26a497..10c543c7686 100644 --- a/superset/mcp_service/sql_lab/sql_lab_utils.py +++ b/superset/mcp_service/sql_lab/sql_lab_utils.py @@ -70,26 +70,22 @@ def validate_sql_query(sql: str, database: Any) -> None: SupersetDisallowedSQLFunctionException, SupersetDMLNotAllowedException, ) + from superset.sql.parse import SQLScript - # Simplified validation without complex parsing - sql_upper = sql.upper().strip() + # Use SQLScript for proper SQL parsing + script = SQLScript(sql, database.db_engine_spec.engine) # Check for DML operations if not allowed - dml_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE"] - if any(sql_upper.startswith(keyword) for keyword in dml_keywords): - if not database.allow_dml: - raise SupersetDMLNotAllowedException() + if script.has_mutation() and not database.allow_dml: + raise SupersetDMLNotAllowedException() # Check for disallowed functions from config disallowed_functions = app.config.get("DISALLOWED_SQL_FUNCTIONS", {}).get( - "sqlite", - set(), # Default to sqlite for now + database.db_engine_spec.engine, + set(), ) - if disallowed_functions: - sql_lower = sql.lower() - for func in disallowed_functions: - if f"{func.lower()}(" in sql_lower: - raise SupersetDisallowedSQLFunctionException(disallowed_functions) + if disallowed_functions and script.check_functions_present(disallowed_functions): + raise SupersetDisallowedSQLFunctionException(disallowed_functions) def execute_sql_query( @@ -110,8 +106,8 @@ def execute_sql_query( sql = _apply_parameters(sql, parameters) validate_sql_query(sql, database) - # Apply limit for SELECT queries - rendered_sql = _apply_limit(sql, limit) + # Apply limit for SELECT queries using SQLScript + rendered_sql = _apply_limit(sql, limit, database) # Execute and get results results = _execute_query(database, rendered_sql, schema, limit) @@ -156,12 +152,23 @@ def _apply_parameters(sql: str, parameters: dict[str, Any] | None) -> str: return sql -def _apply_limit(sql: str, limit: int) -> str: - """Apply limit to SELECT queries if not already present.""" - sql_lower = sql.lower().strip() - if sql_lower.startswith("select") and "limit" not in sql_lower: - return f"{sql.rstrip().rstrip(';')} LIMIT {limit}" - return sql +def _apply_limit(sql: str, limit: int, database: Any) -> str: + """Apply limit to SELECT queries using SQLScript for proper parsing.""" + from superset.sql.parse import LimitMethod, SQLScript + + script = SQLScript(sql, database.db_engine_spec.engine) + + # Only apply limit to non-mutating (SELECT-like) queries + if script.has_mutation(): + return sql + + # Apply limit to each statement in the script + for statement in script.statements: + # Only set limit if not already present + if statement.get_limit_value() is None: + statement.set_limit_value(limit, LimitMethod.FORCE_LIMIT) + + return script.format() def _execute_query( @@ -172,6 +179,7 @@ def _execute_query( ) -> dict[str, Any]: """Execute the query and process results.""" # Import inside function to avoid initialization issues + from superset.sql.parse import SQLScript from superset.utils.core import QuerySource results = { @@ -192,11 +200,12 @@ def _execute_query( cursor = conn.cursor() cursor.execute(sql) - # Process results based on query type - if _is_select_query(sql): - _process_select_results(cursor, results, limit) - else: + # Use SQLScript for proper SQL parsing to determine query type + script = SQLScript(sql, database.db_engine_spec.engine) + if script.has_mutation(): _process_dml_results(cursor, conn, results) + else: + _process_select_results(cursor, results, limit) except Exception as e: logger.error("Error executing SQL: %s", e) @@ -205,11 +214,6 @@ def _execute_query( return results -def _is_select_query(sql: str) -> bool: - """Check if SQL is a SELECT query.""" - return sql.lower().strip().startswith("select") - - def _process_select_results(cursor: Any, results: dict[str, Any], limit: int) -> None: """Process SELECT query results.""" # Fetch results diff --git a/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py b/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py new file mode 100644 index 00000000000..9a08b72a91d --- /dev/null +++ b/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Unit tests for MCP SQL Lab utility functions.""" + +from unittest.mock import MagicMock + +import pytest + +from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit +from superset.sql.parse import SQLScript + + +class TestSQLScriptMutationDetection: + """Tests for SQLScript.has_mutation() used in query type detection.""" + + def test_simple_select_no_mutation(self): + """Test simple SELECT query is not a mutation.""" + script = SQLScript("SELECT * FROM table1", "sqlite") + assert script.has_mutation() is False + + def test_cte_query_no_mutation(self): + """Test CTE (WITH clause) query is not a mutation.""" + cte_sql = """ + WITH cte_name AS ( + SELECT * FROM table1 + ) + SELECT * FROM cte_name + """ + script = SQLScript(cte_sql, "sqlite") + assert script.has_mutation() is False + + def test_recursive_cte_no_mutation(self): + """Test recursive CTE is not a mutation.""" + recursive_cte = """ + WITH RECURSIVE cte AS ( + SELECT 1 AS n + UNION ALL + SELECT n + 1 FROM cte WHERE n < 10 + ) + SELECT n FROM cte + """ + script = SQLScript(recursive_cte, "sqlite") + assert script.has_mutation() is False + + def test_multiple_ctes_no_mutation(self): + """Test query with multiple CTEs is not a mutation.""" + multiple_ctes = """ + WITH + cte1 AS (SELECT 1 as a), + cte2 AS (SELECT 2 as b) + SELECT * FROM cte1, cte2 + """ + script = SQLScript(multiple_ctes, "sqlite") + assert script.has_mutation() is False + + def test_insert_is_mutation(self): + """Test INSERT query is a mutation.""" + script = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite") + assert script.has_mutation() is True + + def test_update_is_mutation(self): + """Test UPDATE query is a mutation.""" + script = SQLScript("UPDATE table1 SET col = 1", "sqlite") + assert script.has_mutation() is True + + def test_delete_is_mutation(self): + """Test DELETE query is a mutation.""" + script = SQLScript("DELETE FROM table1", "sqlite") + assert script.has_mutation() is True + + def test_create_is_mutation(self): + """Test CREATE query is a mutation.""" + script = SQLScript("CREATE TABLE table1 (id INT)", "sqlite") + assert script.has_mutation() is True + + +class TestApplyLimit: + """Tests for _apply_limit function using SQLScript.""" + + @pytest.fixture + def mock_database(self): + """Create a mock database with sqlite engine spec.""" + db = MagicMock() + db.db_engine_spec.engine = "sqlite" + return db + + def test_adds_limit_to_select(self, mock_database): + """Test LIMIT is added to SELECT query.""" + result = _apply_limit("SELECT * FROM table1", 100, mock_database) + assert "LIMIT 100" in result + + def test_adds_limit_to_cte(self, mock_database): + """Test LIMIT is added to CTE query.""" + cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte" + result = _apply_limit(cte_sql, 50, mock_database) + assert "LIMIT 50" in result + + def test_preserves_existing_limit(self, mock_database): + """Test existing LIMIT is not modified.""" + sql = "SELECT * FROM table1 LIMIT 10" + result = _apply_limit(sql, 100, mock_database) + assert "LIMIT 10" in result + assert "LIMIT 100" not in result + + def test_preserves_existing_limit_in_cte(self, mock_database): + """Test existing LIMIT in CTE query is not modified.""" + cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5" + result = _apply_limit(cte_sql, 100, mock_database) + assert "LIMIT 5" in result + assert "LIMIT 100" not in result + + def test_no_limit_on_insert(self, mock_database): + """Test LIMIT is not added to INSERT query.""" + sql = "INSERT INTO table1 VALUES (1)" + result = _apply_limit(sql, 100, mock_database) + assert result == sql + + def test_no_limit_on_update(self, mock_database): + """Test LIMIT is not added to UPDATE query.""" + sql = "UPDATE table1 SET col = 1" + result = _apply_limit(sql, 100, mock_database) + assert result == sql diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py index bbf9e410b04..5b3012462db 100644 --- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py +++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py @@ -431,19 +431,17 @@ class TestExecuteSql: async def test_execute_sql_sql_injection_prevention( self, mock_db, mock_security_manager, mcp_server ): - """Test that SQL injection attempts are handled safely.""" + """Test that SQL injection attempts are handled safely. + + SQLScript detects the DROP TABLE as a mutation and blocks it + before execution when DML is not allowed on the database. + """ mock_database = _mock_database() mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( mock_database ) mock_security_manager.can_access_database.return_value = True - # Mock execute to raise an exception - cursor = ( # fmt: skip - mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value - ) - cursor.execute.side_effect = Exception("Syntax error") - request = { "database_id": 1, "sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--", @@ -453,10 +451,12 @@ class TestExecuteSql: async with Client(mcp_server) as client: result = await client.call_tool("execute_sql", {"request": request}) + # SQLScript correctly detects DROP TABLE as a mutation + # and blocks it before execution (improved security) assert result.data.success is False assert result.data.error is not None - assert "Syntax error" in result.data.error # Contains actual error - assert result.data.error_type == "EXECUTION_ERROR" + assert "DML" in result.data.error or "mutates" in result.data.error + assert result.data.error_type == "DML_NOT_ALLOWED" @pytest.mark.asyncio async def test_execute_sql_empty_query_validation(self, mcp_server):