From 15b3c96f8e97148f63e036038727c9a3ea9e15d3 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Mon, 9 Feb 2026 08:12:06 -0500 Subject: [PATCH] fix(security): Add table blocklist and fix MCP SQL validation bypass (#37411) --- superset/config.py | 28 +++++++ superset/exceptions.py | 15 ++++ superset/sql/execution/executor.py | 41 ++++++++++- superset/sql/parse.py | 40 ++++++++++ superset/sql_lab.py | 15 ++++ .../unit_tests/sql/execution/test_executor.py | 73 +++++++++++++++++-- tests/unit_tests/sql/parse_tests.py | 33 +++++++++ 7 files changed, 238 insertions(+), 7 deletions(-) diff --git a/superset/config.py b/superset/config.py index 9731d80bd58..970c965d051 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1811,6 +1811,34 @@ DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = { }, } +# Per-engine blocklist of system catalog tables/views that should not be queried. +# Prevents information disclosure through system catalog access. +DISALLOWED_SQL_TABLES: dict[str, set[str]] = { + "postgresql": { + "pg_stat_activity", + "pg_roles", + "pg_shadow", + "pg_authid", + "pg_settings", + "pg_config", + "pg_hba_file_rules", + "pg_stat_ssl", + "pg_stat_replication", + "pg_stat_wal_receiver", + "pg_user", + }, + "mysql": { + "mysql.user", + "performance_schema.threads", + "performance_schema.processlist", + }, + "mssql": { + "sys.server_principals", + "sys.sql_logins", + "sys.configurations", + }, +} + # A function that intercepts the SQL to be executed and can alter it. # A common use case for this is around adding some sort of comment header to the SQL diff --git a/superset/exceptions.py b/superset/exceptions.py index 4967ffbfa81..fabbbe13347 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -399,6 +399,21 @@ class SupersetDisallowedSQLFunctionException(SupersetErrorException): ) +class SupersetDisallowedSQLTableException(SupersetErrorException): + """ + Disallowed table/view found in SQL statement + """ + + def __init__(self, tables: set[str]): + super().__init__( + SupersetError( + message=f"SQL statement references disallowed table(s): {tables}", + error_type=SupersetErrorType.SYNTAX_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + class CreateKeyValueDistributedLockFailedException(Exception): # noqa: N818 """ Exception to signalize failure to acquire lock. diff --git a/superset/sql/execution/executor.py b/superset/sql/execution/executor.py index b1f5e794049..e021920023d 100644 --- a/superset/sql/execution/executor.py +++ b/superset/sql/execution/executor.py @@ -451,10 +451,22 @@ class SQLExecutor: :raises SupersetSecurityException: If security checks fail """ # Check disallowed functions - if disallowed := self._check_disallowed_functions(script): + if disallowed_functions := self._check_disallowed_functions(script): raise SupersetSecurityException( SupersetError( - message=f"Disallowed SQL functions: {', '.join(disallowed)}", + message=( + f"Disallowed SQL functions: {', '.join(disallowed_functions)}" + ), + error_type=SupersetErrorType.INVALID_SQL_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Check disallowed tables + if disallowed_tables := self._check_disallowed_tables(script): + raise SupersetSecurityException( + SupersetError( + message=f"Disallowed SQL tables: {', '.join(disallowed_tables)}", error_type=SupersetErrorType.INVALID_SQL_ERROR, level=ErrorLevel.ERROR, ) @@ -684,6 +696,31 @@ class SQLExecutor: return found if found else None + def _check_disallowed_tables(self, script: SQLScript) -> set[str] | None: + """ + Check for disallowed SQL tables/views. + + :param script: Parsed SQL script + :returns: Set of disallowed tables found, or None if none found + """ + disallowed_config = app.config.get("DISALLOWED_SQL_TABLES", {}) + engine_name = self.database.db_engine_spec.engine + + # Get disallowed tables for this engine + engine_disallowed = disallowed_config.get(engine_name, set()) + if not engine_disallowed: + return None + + # Single-pass AST-based table detection + found: set[str] = set() + for statement in script.statements: + present = {table.table.lower() for table in statement.tables} + for table in engine_disallowed: + if table.lower() in present: + found.add(table) + + return found or None + def _apply_rls_to_script( self, script: SQLScript, catalog: str | None, schema: str | None ) -> None: diff --git a/superset/sql/parse.py b/superset/sql/parse.py index af9a740ec75..843962eb091 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -452,6 +452,15 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def check_tables_present(self, tables: set[str]) -> bool: + """ + Check if any of the given tables are present in the statement. + + :param tables: Set of table names to check for (case-insensitive) + :return: True if any of the tables are present + """ + raise NotImplementedError() + def get_limit_value(self) -> int | None: """ Get the limit value of the statement. @@ -766,6 +775,16 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): } return any(function.upper() in present for function in functions) + def check_tables_present(self, tables: set[str]) -> bool: + """ + Check if any of the given tables are present in the statement. + + :param tables: Set of table names to check for (case-insensitive) + :return: True if any of the tables are present + """ + present = {table.table.lower() for table in self.tables} + return any(table.lower() in present for table in tables) + def get_limit_value(self) -> int | None: """ Parse a SQL query and return the `LIMIT` or `TOP` value, if present. @@ -1172,6 +1191,16 @@ class KustoKQLStatement(BaseSQLStatement[str]): logger.warning("Kusto KQL doesn't support checking for functions present.") return False + def check_tables_present(self, tables: set[str]) -> bool: + """ + Check if any of the given tables are present in the statement. + + :param tables: Set of table names to check for (case-insensitive) + :return: True if any of the tables are present + """ + logger.warning("Kusto KQL doesn't support checking for tables present.") + return False + def get_limit_value(self) -> int | None: """ Get the limit value of the statement. @@ -1313,6 +1342,17 @@ class SQLScript: for statement in self.statements ) + def check_tables_present(self, tables: set[str]) -> bool: + """ + Check if any of the given tables are present in the script. + + :param tables: Set of table names to check for (case-insensitive) + :return: True if any of the tables are present + """ + return any( + statement.check_tables_present(tables) for statement in self.statements + ) + def is_valid_ctas(self) -> bool: """ Check if the script contains a valid CTAS statement. diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 8ce20da5bb6..5e170630812 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -45,6 +45,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( OAuth2RedirectError, SupersetDisallowedSQLFunctionException, + SupersetDisallowedSQLTableException, SupersetDMLNotAllowedException, SupersetErrorException, SupersetErrorsException, @@ -411,6 +412,20 @@ def execute_sql_statements( # noqa: C901 ): raise SupersetDisallowedSQLFunctionException(disallowed_functions) + disallowed_tables = app.config["DISALLOWED_SQL_TABLES"].get( + db_engine_spec.engine, + set(), + ) + if disallowed_tables and parsed_script.check_tables_present(disallowed_tables): + # Report only the tables actually found in the query + found_tables = set() + for statement in parsed_script.statements: + present = {table.table.lower() for table in statement.tables} + for table in disallowed_tables: + if table.lower() in present: + found_tables.add(table) + raise SupersetDisallowedSQLTableException(found_tables or disallowed_tables) + if parsed_script.has_mutation() and not database.allow_dml: raise SupersetDMLNotAllowedException() diff --git a/tests/unit_tests/sql/execution/test_executor.py b/tests/unit_tests/sql/execution/test_executor.py index 5168959e9ea..8cfd87e2ead 100644 --- a/tests/unit_tests/sql/execution/test_executor.py +++ b/tests/unit_tests/sql/execution/test_executor.py @@ -350,6 +350,64 @@ def test_execute_allowed_functions( assert result.status == QueryStatus.SUCCESS +def test_execute_disallowed_tables( + mocker: MockerFixture, database: Database, app_context: None +) -> None: + """Test that disallowed SQL tables are blocked.""" + mocker.patch.dict( + current_app.config, + { + "SQL_QUERY_MUTATOR": None, + "SQLLAB_TIMEOUT": 30, + "DISALLOWED_SQL_FUNCTIONS": {}, + "DISALLOWED_SQL_TABLES": {"sqlite": {"pg_stat_activity", "pg_roles"}}, + }, + ) + + result = database.execute("SELECT * FROM pg_stat_activity") + + assert result.status == QueryStatus.FAILED + assert result.error_message is not None + assert "Disallowed SQL tables: pg_stat_activity" == result.error_message + + +def test_execute_allowed_tables( + mocker: MockerFixture, database: Database, app_context: None +) -> None: + """Test that allowed SQL tables work normally.""" + mock_query_execution(mocker, database, return_data=[(1,)], column_names=["id"]) + mocker.patch.dict( + current_app.config, + { + "SQL_QUERY_MUTATOR": None, + "SQLLAB_TIMEOUT": 30, + "SQL_MAX_ROW": None, + "DISALLOWED_SQL_FUNCTIONS": {}, + "DISALLOWED_SQL_TABLES": {"sqlite": {"pg_stat_activity", "pg_roles"}}, + "QUERY_LOGGER": None, + }, + ) + + result = database.execute("SELECT * FROM users") + + assert result.status == QueryStatus.SUCCESS + + +def test_check_disallowed_tables_no_config( + mocker: MockerFixture, database: Database, app_context: None +) -> None: + """Test disallowed tables check when no config exists.""" + from superset.sql.execution.executor import SQLExecutor + + mocker.patch.dict(current_app.config, {"DISALLOWED_SQL_TABLES": {}}) + + executor = SQLExecutor(database) + script = MagicMock() + result = executor._check_disallowed_tables(script) + + assert result is None + + # ============================================================================= # Row-Level Security Tests # ============================================================================= @@ -1293,7 +1351,8 @@ def test_async_handle_get_result_query_not_found( query_result = result.get_result() assert query_result.status == QueryStatus.FAILED - assert "not found" in query_result.error_message.lower() # type: ignore[union-attr] + assert query_result.error_message is not None + assert "not found" in query_result.error_message.lower() def test_async_handle_get_result_pending( @@ -1434,7 +1493,8 @@ def test_async_handle_get_result_backend_load_error( query_result = result.get_result() assert query_result.status == QueryStatus.FAILED - assert "Error loading results" in query_result.error_message # type: ignore[operator] + assert query_result.error_message is not None + assert "Error loading results" in query_result.error_message def test_async_handle_get_result_no_results_key( @@ -1465,7 +1525,8 @@ def test_async_handle_get_result_no_results_key( query_result = result.get_result() assert query_result.status == QueryStatus.FAILED - assert "Results not available" in query_result.error_message # type: ignore[operator] + assert query_result.error_message is not None + assert "Results not available" in query_result.error_message def test_async_handle_get_status_query_not_found( @@ -1970,7 +2031,8 @@ def test_async_handle_get_result_with_empty_blob( # Should return failure when blob not found assert query_result.status == QueryStatus.FAILED - assert "Results not available" in query_result.error_message # type: ignore[operator] + assert query_result.error_message is not None + assert "Results not available" in query_result.error_message def test_async_handle_get_result_no_results_backend( @@ -2010,7 +2072,8 @@ def test_async_handle_get_result_no_results_backend( # Should return failure when no results backend assert query_result.status == QueryStatus.FAILED - assert "Results not available" in query_result.error_message # type: ignore[operator] + assert query_result.error_message is not None + assert "Results not available" in query_result.error_message def test_create_query_record_with_user( diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index f8e7251d808..b50718173cb 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -2942,6 +2942,39 @@ def test_check_functions_present(sql: str, engine: str, expected: bool) -> None: assert SQLScript(sql, engine).check_functions_present(functions) == expected +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM my_table", "postgresql", False), + ("SELECT * FROM pg_stat_activity", "postgresql", True), + ("SELECT * FROM PG_STAT_ACTIVITY", "postgresql", True), + ("SELECT * FROM pg_roles", "postgresql", True), + ( + "WITH cte AS (SELECT 1) SELECT * FROM cte", + "postgresql", + False, + ), + ( + "SELECT * FROM my_table; SELECT * FROM pg_settings", + "postgresql", + True, + ), + ( + "SELECT * FROM schema.pg_stat_activity", + "postgresql", + True, + ), + ("Table | limit 10", "kustokql", False), + ], +) +def test_check_tables_present(sql: str, engine: str, expected: bool) -> None: + """ + Check the `check_tables_present` method. + """ + tables = {"pg_stat_activity", "pg_roles", "pg_settings"} + assert SQLScript(sql, engine).check_tables_present(tables) == expected + + @pytest.mark.parametrize( "kql, expected", [