diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py b/superset/mcp_service/sql_lab/tool/execute_sql.py
index b5f01e8afa7..1e3dfddf866 100644
--- a/superset/mcp_service/sql_lab/tool/execute_sql.py
+++ b/superset/mcp_service/sql_lab/tool/execute_sql.py
@@ -50,6 +50,7 @@ from superset.mcp_service.utils.oauth2_utils import (
     build_oauth2_redirect_message,
     OAUTH2_CONFIG_ERROR_MESSAGE,
 )
+from superset.sql.parse import SQLScript
 
 logger = logging.getLogger(__name__)
 
@@ -117,7 +118,53 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
                     error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR.value,
                 )
 
-        # 2. Build QueryOptions and execute query
+        # 2. Block destructive DDL (DROP, TRUNCATE, ALTER)
+        # Fail-closed: if rendering or parsing fails, block the query rather
+        # than allowing potentially destructive SQL to bypass the check.
+        # Render Jinja2 templates first so templated SQL can be parsed.
+        with event_logger.log_context(action="mcp.execute_sql.ddl_check"):
+            try:
+                sql_to_check = request.sql
+                if request.template_params:
+                    from superset.jinja_context import get_template_processor
+
+                    tp = get_template_processor(database=database)
+                    sql_to_check = tp.process_template(
+                        request.sql, **request.template_params
+                    )
+
+                script = SQLScript(sql_to_check, database.db_engine_spec.engine)
+                has_destructive_ddl = script.has_destructive()
+            except Exception as parse_err:
+                await ctx.error(
+                    f"DDL pre-check failed to parse SQL, blocking query: {parse_err}"
+                )
+                return ExecuteSqlResponse(
+                    success=False,
+                    error=(
+                        "SQL could not be parsed for security validation. "
+                        "Please check your SQL syntax and try again."
+                    ),
+                    error_type=SupersetErrorType.INVALID_SQL_ERROR.value,
+                )
+
+            # Handled outside the ``try`` so that an error raised while
+            # reporting the block is not mislabelled as a SQL parse failure.
+            if has_destructive_ddl:
+                await ctx.error(
+                    f"Destructive DDL blocked: sql_preview={sql_preview!r}"
+                )
+                return ExecuteSqlResponse(
+                    success=False,
+                    error=(
+                        "Destructive DDL statements (DROP, TRUNCATE, ALTER) "
+                        "are not allowed through MCP. Use the Superset SQL "
+                        "Lab UI for administrative database operations."
+                    ),
+                    error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR.value,
+                )
+
+        # 3. Build QueryOptions and execute query
         cache_opts = CacheOptions(force_refresh=True) if request.force_refresh else None
         options = QueryOptions(
             catalog=request.catalog,
@@ -129,11 +176,11 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
             cache=cache_opts,
         )
 
-        # 3. Execute query
+        # 4. Execute query
         with event_logger.log_context(action="mcp.execute_sql.query_execution"):
             result = database.execute(request.sql, options)
 
-        # 4. Convert to MCP response format
+        # 5. Convert to MCP response format
         with event_logger.log_context(action="mcp.execute_sql.response_conversion"):
             response = _convert_to_response(result)
 
diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index aab3a61c50e..77475322d43 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -439,6 +439,14 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         """
         raise NotImplementedError()
 
+    def is_destructive(self) -> bool:
+        """
+        Check if the statement is destructive DDL (DROP, TRUNCATE, ALTER).
+
+        :return: True if the statement is destructive DDL.
+        """
+        raise NotImplementedError()
+
     def optimize(self) -> BaseSQLStatement[InternalRepresentation]:
         """
         Return optimized statement.
@@ -719,6 +727,30 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
 
         return False
 
+    def is_destructive(self) -> bool:
+        """
+        Check if the statement is destructive DDL (DROP, TRUNCATE, ALTER).
+
+        Unlike ``is_mutating()``, this excludes non-destructive DML
+        (INSERT, UPDATE, DELETE, MERGE) and CREATE.
+
+        Statements that sqlglot cannot fully parse may come back as an
+        opaque ``exp.Command``; those are classified by their leading
+        keyword, compared case-insensitively.
+
+        :return: True if the statement is destructive DDL.
+        """
+        # ``find`` takes several node types, so one traversal is enough.
+        if self._parsed.find(exp.Drop, exp.TruncateTable, exp.Alter) is not None:
+            return True
+
+        # Fallback for DROP/TRUNCATE/ALTER parsed as Command (e.g., ALTER in
+        # the Oracle and MS SQL dialects).
+        if isinstance(self._parsed, exp.Command):  # pragma: no cover
+            return self._parsed.name.upper() in {"ALTER", "DROP", "TRUNCATE"}
+
+        return False
+
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL statement.
@@ -1175,6 +1207,20 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         """
         return self._parsed.startswith(".") and not self._parsed.startswith(".show")
 
+    def is_destructive(self) -> bool:
+        """
+        Check if the statement is destructive DDL.
+
+        Kusto KQL uses dot-commands for management operations. The ones
+        treated as destructive start with ``.drop``, ``.alter``, ``.clear``
+        (the closest equivalent of TRUNCATE), or ``.purge``.
+
+        :return: True if the statement is destructive DDL.
+        """
+        destructive_prefixes = (".drop", ".alter", ".clear", ".purge")
+        # ``str.startswith`` accepts a tuple of prefixes.
+        return self._parsed.lower().startswith(destructive_prefixes)
+
     def optimize(self) -> KustoKQLStatement:
         """
         Return optimized statement.
@@ -1321,6 +1367,14 @@ class SQLScript:
         """
         return any(statement.is_mutating() for statement in self.statements)
 
+    def has_destructive(self) -> bool:
+        """
+        Check if the script contains destructive DDL (DROP, TRUNCATE, ALTER).
+
+        :return: True if any statement is destructive DDL.
+        """
+        return any(statement.is_destructive() for statement in self.statements)
+
     def optimize(self) -> SQLScript:
         """
         Return optimized script.
diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
index de67736ac55..e495183281b 100644
--- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
+++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
@@ -1317,3 +1317,153 @@ class TestColumnInfoIsNullable:
             {"name": "c", "type": "int", "is_nullable": "UNKNOWN"}
         )
         assert col.is_nullable is None
+
+
+class TestDestructiveDDLBlocking:
+    """Tests for destructive DDL blocking in execute_sql."""
+
+    @pytest.fixture
+    def ddl_mocks(self):
+        """Common mock wiring for DDL blocking tests."""
+        with (
+            patch("superset.db") as mock_db,
+            patch("superset.security_manager") as mock_sm,
+        ):
+            mock_database = _mock_database()
+            mock_database.db_engine_spec.engine = "postgresql"
+            query_chain = mock_db.session.query.return_value
+            query_chain.filter_by.return_value.first.return_value = mock_database
+            mock_sm.can_access_database.return_value = True
+            yield mock_database
+
+    @pytest.mark.asyncio
+    async def test_drop_table_blocked(self, ddl_mocks, mcp_server):
+        """DROP TABLE is blocked before reaching the executor."""
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "execute_sql",
+                {"request": {"database_id": 1, "sql": "DROP TABLE birth_names"}},
+            )
+            data = result.structured_content
+            assert data["success"] is False
+            assert "Destructive DDL" in data["error"]
+            assert "DROP" in data["error"]
+            ddl_mocks.execute.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_truncate_blocked(self, ddl_mocks, mcp_server):
+        """TRUNCATE TABLE is blocked before reaching the executor."""
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "execute_sql",
+                {"request": {"database_id": 1, "sql": "TRUNCATE TABLE birth_names"}},
+            )
+            data = result.structured_content
+            assert data["success"] is False
+            assert "Destructive DDL" in data["error"]
+            ddl_mocks.execute.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_alter_table_blocked(self, ddl_mocks, mcp_server):
+        """ALTER TABLE is blocked before reaching the executor."""
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "execute_sql",
+                {
+                    "request": {
+                        "database_id": 1,
+                        "sql": "ALTER TABLE birth_names ADD COLUMN x INT",
+                    }
+                },
+            )
+            data = result.structured_content
+            assert data["success"] is False
+            assert "Destructive DDL" in data["error"]
+            ddl_mocks.execute.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_drop_in_multi_statement_blocked(self, ddl_mocks, mcp_server):
+        """DROP TABLE hidden in a multi-statement query is blocked."""
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "execute_sql",
+                {
+                    "request": {
+                        "database_id": 1,
+                        "sql": "DROP TABLE birth_names; SELECT 1",
+                    }
+                },
+            )
+            data = result.structured_content
+            assert data["success"] is False
+            assert "Destructive DDL" in data["error"]
+            ddl_mocks.execute.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_select_allowed(self, ddl_mocks, mcp_server):
+        """SELECT queries pass through the DDL check."""
+        ddl_mocks.execute.return_value = _create_select_result(
+            rows=[{"x": 1}], columns=["x"], original_sql="SELECT 1"
+        )
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "execute_sql",
+                {"request": {"database_id": 1, "sql": "SELECT 1"}},
+            )
+            data = result.structured_content
+            assert data["success"] is True
+            ddl_mocks.execute.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_insert_allowed(self, ddl_mocks, mcp_server):
+        """INSERT queries pass through the DDL check (DML is allowed)."""
+        ddl_mocks.execute.return_value = _create_dml_result(
+            affected_rows=1,
+            original_sql="INSERT INTO t VALUES (1)",
+        )
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "execute_sql",
+                {"request": {"database_id": 1, "sql": "INSERT INTO t VALUES (1)"}},
+            )
+            data = result.structured_content
+            assert data["success"] is True
+            ddl_mocks.execute.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_parse_failure_blocks_query(self, ddl_mocks, mcp_server):
+        """When SQL parsing fails, the query is blocked (fail-closed)."""
+        import sys
+
+        execute_sql_mod = sys.modules["superset.mcp_service.sql_lab.tool.execute_sql"]
+        with patch.object(
+            execute_sql_mod,
+            "SQLScript",
+            side_effect=Exception("parse error"),
+        ):
+            async with Client(mcp_server) as client:
+                result = await client.call_tool(
+                    "execute_sql",
+                    {"request": {"database_id": 1, "sql": "DROP TABLE birth_names"}},
+                )
+                data = result.structured_content
+                assert data["success"] is False
+                assert "could not be parsed" in data["error"]
+                ddl_mocks.execute.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_drop_table_blocked_mysql(self, ddl_mocks, mcp_server):
+        """DROP TABLE is blocked for non-PostgreSQL dialects too."""
+        ddl_mocks.db_engine_spec.engine = "mysql"
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "execute_sql",
+                {"request": {"database_id": 1, "sql": "DROP TABLE users"}},
+            )
+            data = result.structured_content
+            assert data["success"] is False
+            assert "Destructive DDL" in data["error"]
+            ddl_mocks.execute.assert_not_called()
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index a2361ae5abc..bbf4ec2e527 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -1327,6 +1327,73 @@ def test_is_mutating_anonymous_block(sql: str, expected: bool) -> None:
     assert SQLStatement(sql, "postgresql").is_mutating() == expected
 
 
+@pytest.mark.parametrize(
+    "sql, expected",
+    [
+        ("SELECT 1", False),
+        ("INSERT INTO t VALUES (1)", False),
+        ("UPDATE t SET x = 1", False),
+        ("DELETE FROM t", False),
+        ("MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN DELETE", False),
+        ("CREATE TABLE t (id INT)", False),
+        ("DROP TABLE t", True),
+        ("DROP TABLE IF EXISTS t", True),
+        ("DROP VIEW v", True),
+        ("TRUNCATE TABLE t", True),
+        ("ALTER TABLE t ADD COLUMN x INT", True),
+        ("ALTER TABLE t DROP COLUMN x", True),
+        ("drop table t", True),
+        ("SELECT 'DROP TABLE t' AS msg", False),
+    ],
+)
+def test_is_destructive(sql: str, expected: bool) -> None:
+    """
+    Test that ``is_destructive`` detects DROP, TRUNCATE, and ALTER
+    but not SELECT, INSERT, UPDATE, DELETE, MERGE, or CREATE.
+
+    Keywords are matched regardless of case, and keywords that only
+    appear inside string literals are ignored.
+    """
+    assert SQLStatement(sql, "postgresql").is_destructive() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, expected",
+    [
+        ("SELECT 1; INSERT INTO t VALUES (1)", False),
+        ("SELECT 1; DROP TABLE t", True),
+        ("SELECT 1; TRUNCATE TABLE t", True),
+        ("CREATE TABLE t (id INT); ALTER TABLE t ADD COLUMN x INT", True),
+        ("SELECT 1; SELECT 'DROP TABLE t'", False),
+    ],
+)
+def test_has_destructive(sql: str, expected: bool) -> None:
+    """
+    Test that ``has_destructive`` on SQLScript detects destructive DDL
+    across multiple statements.
+    """
+    assert SQLScript(sql, "postgresql").has_destructive() == expected
+
+
+@pytest.mark.parametrize(
+    "kql, expected",
+    [
+        (".drop table T", True),
+        (".alter table T (col:string)", True),
+        (".show tables", False),
+        ("T | count", False),
+        (".DROP table T", True),
+    ],
+)
+def test_kusto_is_destructive(kql: str, expected: bool) -> None:
+    """
+    Test ``is_destructive`` on KustoKQLStatement.
+    """
+    from superset.sql.parse import KustoKQLStatement
+
+    assert KustoKQLStatement(kql, "kustokql").is_destructive() == expected
+
+
 def test_optimize() -> None:
     """
     Test that the `optimize` method works as expected.