diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index a44642e437d..096b8482c59 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -153,6 +153,12 @@ def get_samples( # pylint: disable=too-many-arguments ) try: + # Enforce access control before fetching data. + # This prevents users with "can samples on Datasource" permission from + # reading samples from datasets they don't have access to. + samples_instance.raise_for_access() + count_star_instance.raise_for_access() + count_star_data = count_star_instance.get_payload()["queries"][0] if count_star_data.get("status") == QueryStatus.FAILED: diff --git a/tests/unit_tests/views/datasource/__init__.py b/tests/unit_tests/views/datasource/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/views/datasource/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for superset.views.datasource.utils module."""

from typing import Any, Callable
from unittest.mock import MagicMock, patch

import pytest

from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException

# Dotted path of the module under test; every patch target hangs off it.
UTILS_MODULE = "superset.views.datasource.utils"

# Pagination values handed back by the patched ``get_limit_clause``.
LIMIT_CLAUSE = {"row_offset": 0, "row_limit": 100}


def _access_denied(message: str) -> SupersetSecurityException:
    """Build the security exception used to simulate a denied access check."""
    return SupersetSecurityException(
        SupersetError(
            message=message,
            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
            level=ErrorLevel.WARNING,
        )
    )


def _recorder(
    events: list[str], label: str, result: Any = None
) -> Callable[..., Any]:
    """
    Return a mock ``side_effect`` that appends ``label`` to ``events`` and then
    returns ``result``, so tests can assert the relative order of calls made
    on different mocks.
    """

    def _side_effect(*_args: Any, **_kwargs: Any) -> Any:
        events.append(label)
        return result

    return _side_effect


def _run_get_samples(samples_context: MagicMock, count_context: MagicMock) -> Any:
    """
    Call ``get_samples()`` with the datasource lookup and the query context
    factory patched out.

    The factory's ``create()`` hands out ``samples_context`` on its first call
    and ``count_context`` on its second. The caller is responsible for patching
    ``get_limit_clause``.
    """
    datasource = MagicMock()
    datasource.type = "table"
    datasource.id = 1
    datasource.columns = []

    with (
        patch(
            f"{UTILS_MODULE}.DatasourceDAO.get_datasource",
            return_value=datasource,
        ),
        patch(f"{UTILS_MODULE}.QueryContextFactory") as factory_class,
    ):
        # Return different mock contexts for samples vs count queries
        factory_class.return_value.create.side_effect = [
            samples_context,
            count_context,
        ]

        # Imported lazily, while the patches are active.
        from superset.views.datasource.utils import get_samples

        return get_samples(
            datasource_type="table",
            datasource_id=1,
            force=False,
            page=1,
            per_page=100,
        )


@patch(f"{UTILS_MODULE}.get_limit_clause")
def test_get_samples_raises_security_exception_when_access_denied(
    mock_get_limit_clause: MagicMock,
) -> None:
    """
    Test that get_samples() enforces access control by calling raise_for_access().
    This verifies the fix for issue #31944 where users with "can samples on Datasource"
    permission could read samples from datasets they don't have access to.
    """
    mock_get_limit_clause.return_value = dict(LIMIT_CLAUSE)

    samples_context = MagicMock()
    count_context = MagicMock()

    # Simulate security exception when raise_for_access is called
    samples_context.raise_for_access.side_effect = _access_denied("Access denied")

    with pytest.raises(SupersetSecurityException) as exc_info:
        _run_get_samples(samples_context, count_context)

    assert exc_info.value.error.error_type == (
        SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR
    )
    # Verify raise_for_access was called on the samples context
    samples_context.raise_for_access.assert_called_once()
    # The denial must happen *before* any data is fetched; otherwise the
    # exception alone would not prove that nothing was read.
    samples_context.get_payload.assert_not_called()
    count_context.get_payload.assert_not_called()


@patch(f"{UTILS_MODULE}.get_limit_clause")
def test_get_samples_calls_raise_for_access_on_both_contexts(
    mock_get_limit_clause: MagicMock,
) -> None:
    """
    Test that get_samples() calls raise_for_access() on both the samples
    and count_star query contexts before fetching data.
    """
    mock_get_limit_clause.return_value = dict(LIMIT_CLAUSE)

    samples_context = MagicMock()
    count_context = MagicMock()

    # Every access check and payload fetch logs itself here, in call order.
    events: list[str] = []

    # Set up successful access check
    samples_context.raise_for_access.side_effect = _recorder(
        events, "samples.access"
    )
    count_context.raise_for_access.side_effect = _recorder(events, "count.access")

    # Set up successful payload responses
    count_context.get_payload.side_effect = _recorder(
        events,
        "count.payload",
        {"queries": [{"data": [{"COUNT(*)": 100}], "status": "success"}]},
    )
    samples_context.get_payload.side_effect = _recorder(
        events,
        "samples.payload",
        {
            "queries": [
                {
                    "data": [{"col1": "val1"}],
                    "status": "success",
                    "cache_key": "test_key",
                }
            ]
        },
    )

    result = _run_get_samples(samples_context, count_context)

    # Verify both contexts had raise_for_access called
    samples_context.raise_for_access.assert_called_once()
    count_context.raise_for_access.assert_called_once()

    # Each check ran exactly once, so if the first two recorded events are the
    # two checks, every payload fetch necessarily happened after them.
    assert sorted(events[:2]) == ["count.access", "samples.access"]

    # Verify the result contains expected data
    assert result["data"] == [{"col1": "val1"}]
    assert result["total_count"] == 100


@patch(f"{UTILS_MODULE}.get_limit_clause")
def test_get_samples_count_star_access_denied(
    mock_get_limit_clause: MagicMock,
) -> None:
    """
    Test that get_samples() raises security exception when access to count_star
    query context is denied.
    """
    mock_get_limit_clause.return_value = dict(LIMIT_CLAUSE)

    samples_context = MagicMock()
    count_context = MagicMock()

    # Samples context allows access
    samples_context.raise_for_access.return_value = None

    # Count context denies access
    count_context.raise_for_access.side_effect = _access_denied(
        "Access denied to count query"
    )

    with pytest.raises(SupersetSecurityException) as exc_info:
        _run_get_samples(samples_context, count_context)

    assert exc_info.value.error.error_type == (
        SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR
    )
    # Verify samples context was checked first
    samples_context.raise_for_access.assert_called_once()
    # Verify count context was also checked
    count_context.raise_for_access.assert_called_once()
    # A denial on either context must stop all data access, including on the
    # samples context that itself passed the check.
    samples_context.get_payload.assert_not_called()
    count_context.get_payload.assert_not_called()