From b794b192d146c4cd00712ba152dffb3e6002fca9 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 24 Jul 2025 16:40:56 -0400 Subject: [PATCH] fix: return 422 on invalid SQL (#34303) --- superset/commands/dataset/create.py | 11 +- superset/connectors/sqla/utils.py | 8 +- superset/sql/parse.py | 2 +- tests/integration_tests/datasets/api_tests.py | 30 ++++ .../commands/dataset/test_create.py | 170 ++++++++++++++++++ .../unit_tests/connectors/sqla/test_utils.py | 123 +++++++++++++ 6 files changed, 340 insertions(+), 4 deletions(-) create mode 100644 tests/unit_tests/commands/dataset/test_create.py create mode 100644 tests/unit_tests/connectors/sqla/test_utils.py diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py index cbfe1bfc9e3..d7eab2b7bb3 100644 --- a/superset/commands/dataset/create.py +++ b/superset/commands/dataset/create.py @@ -31,7 +31,7 @@ from superset.commands.dataset.exceptions import ( TableNotFoundValidationError, ) from superset.daos.dataset import DatasetDAO -from superset.exceptions import SupersetSecurityException +from superset.exceptions import SupersetParseError, SupersetSecurityException from superset.extensions import security_manager from superset.sql.parse import Table from superset.utils.decorators import on_error, transaction @@ -51,7 +51,7 @@ class CreateDatasetCommand(CreateMixin, BaseCommand): dataset.fetch_metadata() return dataset - def validate(self) -> None: + def validate(self) -> None: # noqa: C901 exceptions: list[ValidationError] = [] database_id = self._properties["database"] catalog = self._properties.get("catalog") @@ -95,6 +95,13 @@ class CreateDatasetCommand(CreateMixin, BaseCommand): ) except SupersetSecurityException as ex: exceptions.append(DatasetDataAccessIsNotAllowed(ex.error.message)) + except SupersetParseError as ex: + exceptions.append( + ValidationError( + f"Invalid SQL: {ex.error.message}", + field_name="sql", + ) + ) try: owners = self.populate_owners(owner_ids) self._properties["owners"] = owners diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 172103a9af4..6c0d2a82606 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -34,6 +34,7 @@ from superset.constants import LRU_CACHE_MAX_SIZE from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( SupersetGenericDBErrorException, + SupersetParseError, SupersetSecurityException, ) from superset.models.core import Database @@ -105,7 +106,12 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]: sql = dataset.get_template_processor().process_template( dataset.sql, **dataset.template_params_dict ) - parsed_script = SQLScript(sql, engine=db_engine_spec.engine) + try: + parsed_script = SQLScript(sql, engine=db_engine_spec.engine) + except SupersetParseError as ex: + raise SupersetGenericDBErrorException( + message=_("Invalid SQL: %(error)s", error=ex.error.message), + ) from ex if parsed_script.has_mutation(): raise SupersetSecurityException( SupersetError( diff --git a/superset/sql/parse.py b/superset/sql/parse.py index ed8f6e4c270..6af483aed6e 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -644,7 +644,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): # depending on the dialect (Oracle, MS SQL) the `ALTER` is parsed as a # command, not an expression - check at root level if isinstance(self._parsed, exp.Command) and self._parsed.name == "ALTER": - return True + return True # pragma: no cover # Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see # https://www.postgresql.org/docs/current/sql-explain.html diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index dcb4b808b22..2906a6612ec 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -974,6 +974,36 @@ class TestDatasetApi(SupersetTestCase): assert rv.status_code == 422 assert data == {"message": "Dataset could not be created."} + @patch("superset.commands.dataset.create.security_manager.raise_for_access") + def test_create_dataset_with_invalid_sql_validation(self, mock_raise_for_access): + """ + Dataset API: Test create dataset with invalid SQL during validation returns 422 + """ + from superset.exceptions import SupersetParseError + + # Mock raise_for_access to throw SupersetParseError during validation + mock_raise_for_access.side_effect = SupersetParseError( + sql="SELECT FROM WHERE AND", + engine="postgresql", + message="Invalid SQL syntax", + ) + + self.login(ADMIN_USERNAME) + examples_db = get_example_database() + dataset_data = { + "database": examples_db.id, + "schema": "", + "table_name": "invalid_sql_table", + "sql": "SELECT FROM WHERE AND", + } + uri = "api/v1/dataset/" + rv = self.client.post(uri, json=dataset_data) + data = json.loads(rv.data.decode("utf-8")) + # The error is caught during validation and returns 422 + assert rv.status_code == 422 + assert "sql" in data["message"] + assert "Invalid SQL:" in data["message"]["sql"][0] + def test_update_dataset_preserve_ownership(self): """ Dataset API: Test update dataset preserves owner list (if un-changed) diff --git a/tests/unit_tests/commands/dataset/test_create.py b/tests/unit_tests/commands/dataset/test_create.py new file mode 100644 index 00000000000..4202cb3ff2f --- /dev/null +++ b/tests/unit_tests/commands/dataset/test_create.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import Mock, patch + +import pytest +from marshmallow import ValidationError + +from superset.commands.dataset.create import CreateDatasetCommand +from superset.commands.dataset.exceptions import DatasetInvalidError +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetParseError +from superset.models.core import Database + + +def test_create_dataset_invalid_sql_parse_error(): + """Test that invalid SQL returns a 4xx error when caught as SupersetParseError.""" + mock_database = Mock(spec=Database) + mock_database.id = 1 + mock_database.db_engine_spec.engine = "postgresql" + mock_database.get_default_catalog.return_value = None + + with patch( + "superset.commands.dataset.create.DatasetDAO.get_database_by_id", + return_value=mock_database, + ): + with patch( + "superset.commands.dataset.create.DatasetDAO.validate_uniqueness", + return_value=True, + ): + with patch( + "superset.commands.dataset.create.security_manager.raise_for_access", + side_effect=SupersetParseError( + sql="SELECT INVALID SQL SYNTAX", + engine="postgresql", + message="Invalid SQL syntax: unexpected token 'INVALID'", + ), + ): + with patch( + "superset.commands.dataset.create.CreateDatasetCommand.populate_owners", + return_value=[], + ): + command = CreateDatasetCommand( + { + "database": 1, + "table_name": "test_virtual_dataset", + "sql": "SELECT INVALID SQL SYNTAX", + } + ) + + with pytest.raises(DatasetInvalidError) as exc_info: + command.validate() + + # Verify the exception contains the correct validation error + validation_errors = exc_info.value._exceptions + assert len(validation_errors) == 1 + assert isinstance(validation_errors[0], ValidationError) + assert validation_errors[0].field_name == "sql" + assert "Invalid SQL:" in str(validation_errors[0].messages[0]) + assert "unexpected token 'INVALID'" in str( + validation_errors[0].messages[0] + ) + + +def test_create_dataset_valid_sql_with_access_error(): + """ + Test that security exceptions work correctly + """ + mock_database = Mock(spec=Database) + mock_database.id = 1 + mock_database.db_engine_spec.engine = "postgresql" + mock_database.get_default_catalog.return_value = None + + from superset.exceptions import SupersetSecurityException + + with patch( + "superset.commands.dataset.create.DatasetDAO.get_database_by_id", + return_value=mock_database, + ): + with patch( + "superset.commands.dataset.create.DatasetDAO.validate_uniqueness", + return_value=True, + ): + with patch( + "superset.commands.dataset.create.security_manager.raise_for_access", + side_effect=SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + message="User does not have access to table 'secret_table'", + level=ErrorLevel.ERROR, + ) + ), + ): + with patch( + "superset.commands.dataset.create.CreateDatasetCommand.populate_owners", + return_value=[], + ): + command = CreateDatasetCommand( + { + "database": 1, + "table_name": "test_virtual_dataset", + "sql": "SELECT * FROM secret_table", + } + ) + + with pytest.raises(DatasetInvalidError) as exc_info: + command.validate() + + # Verify the security error is handled correctly (existing behavior) + validation_errors = exc_info.value._exceptions + assert len(validation_errors) == 1 + # This should be a DatasetDataAccessIsNotAllowed error + from superset.commands.dataset.exceptions import ( + DatasetDataAccessIsNotAllowed, + ) + + assert isinstance( + validation_errors[0], DatasetDataAccessIsNotAllowed + ) + assert validation_errors[0].field_name == "sql" + assert "User does not have access to table 'secret_table'" in str( + validation_errors[0].messages[0] + ) + + +def test_create_dataset_physical_table_no_parse_error(): + """Test that physical tables (no SQL) don't trigger parsing.""" + mock_database = Mock(spec=Database) + mock_database.id = 1 + mock_database.get_default_catalog.return_value = None + + with patch( + "superset.commands.dataset.create.DatasetDAO.get_database_by_id", + return_value=mock_database, + ): + with patch( + "superset.commands.dataset.create.DatasetDAO.validate_uniqueness", + return_value=True, + ): + with patch( + "superset.commands.dataset.create.DatasetDAO.validate_table_exists", + return_value=True, + ): + with patch( + "superset.commands.dataset.create.CreateDatasetCommand.populate_owners", + return_value=[], + ): + command = CreateDatasetCommand( + { + "database": 1, + "table_name": "physical_table", + # No SQL provided - this is a physical table + } + ) + + # Should not raise any parsing errors + command.validate() diff --git a/tests/unit_tests/connectors/sqla/test_utils.py b/tests/unit_tests/connectors/sqla/test_utils.py new file mode 100644 index 00000000000..be3e3662e19 --- /dev/null +++ b/tests/unit_tests/connectors/sqla/test_utils.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import Mock, patch + +import pytest + +from superset.connectors.sqla.models import SqlaTable +from superset.connectors.sqla.utils import get_virtual_table_metadata +from superset.errors import SupersetErrorType +from superset.exceptions import ( + SupersetGenericDBErrorException, + SupersetSecurityException, +) +from superset.models.core import Database + + +def test_get_virtual_table_metadata_invalid_sql(): + """Test that invalid SQL in virtual table raises proper exception.""" + mock_dataset = Mock(spec=SqlaTable) + mock_database = Mock(spec=Database) + mock_dataset.database = mock_database + mock_dataset.sql = "SELECT INVALID SYNTAX FROM" + mock_database.db_engine_spec.engine = "postgresql" + + # Mock template processor + mock_template_processor = Mock() + mock_template_processor.process_template.return_value = "SELECT INVALID SYNTAX FROM" + mock_dataset.get_template_processor.return_value = mock_template_processor + mock_dataset.template_params_dict = {} + + with pytest.raises(SupersetGenericDBErrorException) as exc_info: + get_virtual_table_metadata(mock_dataset) + + # Check that the error message includes the parsing error + assert "Invalid SQL:" in str(exc_info.value.message) + + +def test_get_virtual_table_metadata_empty_sql(): + """Test that empty SQL raises appropriate error.""" + mock_dataset = Mock(spec=SqlaTable) + mock_dataset.sql = None + + with pytest.raises(SupersetGenericDBErrorException) as exc_info: + get_virtual_table_metadata(mock_dataset) + + assert "Virtual dataset query cannot be empty" in str(exc_info.value.message) + + +def test_get_virtual_table_metadata_mutation_not_allowed(): + """Test that SQL with mutations raises security error.""" + mock_dataset = Mock(spec=SqlaTable) + mock_database = Mock(spec=Database) + mock_dataset.database = mock_database + mock_dataset.sql = "DELETE FROM users" + mock_database.db_engine_spec.engine = "postgresql" + + # Mock template processor + mock_template_processor = Mock() + mock_template_processor.process_template.return_value = "DELETE FROM users" + mock_dataset.get_template_processor.return_value = mock_template_processor + mock_dataset.template_params_dict = {} + + # Mock SQLScript to simulate mutation detection + with patch("superset.connectors.sqla.utils.SQLScript") as mock_sqlscript_class: + mock_script = Mock() + mock_script.has_mutation.return_value = True + mock_sqlscript_class.return_value = mock_script + + with pytest.raises(SupersetSecurityException) as exc_info: + get_virtual_table_metadata(mock_dataset) + + assert ( + exc_info.value.error.error_type + == SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR + ) + assert "Only `SELECT` statements are allowed" in exc_info.value.error.message + + +def test_get_virtual_table_metadata_multiple_statements_not_allowed(): + """Test that multiple SQL statements raise security error.""" + mock_dataset = Mock(spec=SqlaTable) + mock_database = Mock(spec=Database) + mock_dataset.database = mock_database + mock_dataset.sql = "SELECT * FROM table1; SELECT * FROM table2" + mock_database.db_engine_spec.engine = "postgresql" + + # Mock template processor + mock_template_processor = Mock() + mock_template_processor.process_template.return_value = ( + "SELECT * FROM table1; SELECT * FROM table2" + ) + mock_dataset.get_template_processor.return_value = mock_template_processor + mock_dataset.template_params_dict = {} + + # Mock SQLScript to simulate multiple statements + with patch("superset.connectors.sqla.utils.SQLScript") as mock_sqlscript_class: + mock_script = Mock() + mock_script.has_mutation.return_value = False + mock_script.statements = [Mock(), Mock()] # Two statements + mock_sqlscript_class.return_value = mock_script + + with pytest.raises(SupersetSecurityException) as exc_info: + get_virtual_table_metadata(mock_dataset) + + assert ( + exc_info.value.error.error_type + == SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR + ) + assert "Only single queries supported" in exc_info.value.error.message