From 009b99bfbbbe693d498879c0c3e22efbdc5ed42c Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Wed, 20 Aug 2025 18:00:06 -0700 Subject: [PATCH] chore: catch sqlalchemy error (#34768) --- superset/daos/base.py | 18 +++- superset/daos/exceptions.py | 33 ++++++ tests/unit_tests/dao/base_dao_test.py | 142 ++++++++++++++++++++++++++ 3 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 tests/unit_tests/dao/base_dao_test.py diff --git a/superset/daos/base.py b/superset/daos/base.py index e393034062b..189d53d4e26 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -21,8 +21,11 @@ from typing import Any, Generic, get_args, TypeVar from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla.interface import SQLAInterface -from sqlalchemy.exc import StatementError +from sqlalchemy.exc import SQLAlchemyError, StatementError +from superset.daos.exceptions import ( + DAOFindFailedError, +) from superset.extensions import db T = TypeVar("T", bound=Model) @@ -81,7 +84,7 @@ class BaseDAO(Generic[T]): Find a List of models by a list of ids, if defined applies `base_filter` """ id_col = getattr(cls.model_cls, cls.id_column_name, None) - if id_col is None: + if id_col is None or not model_ids: return [] query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids)) if cls.base_filter and not skip_base_filter: @@ -89,7 +92,16 @@ class BaseDAO(Generic[T]): query = cls.base_filter( # pylint: disable=not-callable cls.id_column_name, data_model ).apply(query, None) - return query.all() + + try: + results = query.all() + except SQLAlchemyError as ex: + model_name = cls.model_cls.__name__ if cls.model_cls else "Unknown" + raise DAOFindFailedError( + f"Failed to find {model_name} with ids: {model_ids}" + ) from ex + + return results @classmethod def find_all(cls) -> list[T]: diff --git a/superset/daos/exceptions.py b/superset/daos/exceptions.py index ebd20fee631..1b9fdf606d9 100644 --- a/superset/daos/exceptions.py +++ b/superset/daos/exceptions.py @@ -23,6 +23,39 @@ class DAOException(SupersetException): """ +class DAOFindFailedError(DAOException): + """ + DAO Find failed + """ + + status = 400 + message = "Find failed" + + +class DAOCreateFailedError(DAOException): + """ + DAO Create failed + """ + + message = "Create failed" + + +class DAOUpdateFailedError(DAOException): + """ + DAO Update failed + """ + + message = "Update failed" + + +class DAODeleteFailedError(DAOException): + """ + DAO Delete failed + """ + + message = "Delete failed" + + class DatasourceTypeNotSupportedError(DAOException): """ DAO datasource query source type is not supported diff --git a/tests/unit_tests/dao/base_dao_test.py b/tests/unit_tests/dao/base_dao_test.py new file mode 100644 index 00000000000..fdf1417ec48 --- /dev/null +++ b/tests/unit_tests/dao/base_dao_test.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.exc import SQLAlchemyError + +from superset.daos.base import BaseDAO +from superset.daos.exceptions import DAOFindFailedError + + +class MockModel: + def __init__(self, id=1, name="test"): + self.id = id + self.name = name + + +class TestDAO(BaseDAO[MockModel]): + model_cls = MockModel + + +class TestDAOWithNoneModel(BaseDAO[MockModel]): + model_cls = None + + +def test_find_by_ids_sqlalchemy_error_with_model_cls(): + """Test SQLAlchemyError in find_by_ids shows proper model name + when model_cls is set""" + + with ( + patch("superset.daos.base.db") as mock_db, + patch("superset.daos.base.getattr") as mock_getattr, + ): + mock_session = Mock() + mock_db.session = mock_session + + # Mock the id column to have an in_ method + mock_id_col = Mock() + mock_id_col.in_.return_value = Mock() + mock_getattr.return_value = mock_id_col + + mock_query = Mock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.all.side_effect = SQLAlchemyError("Database error") + + with pytest.raises(DAOFindFailedError) as exc_info: + TestDAO.find_by_ids([1, 2]) + + assert "Failed to find MockModel with ids: [1, 2]" in str(exc_info.value) + + +def test_find_by_ids_sqlalchemy_error_with_none_model_cls(): + """Test SQLAlchemyError in find_by_ids shows 'Unknown' when model_cls is None""" + + with ( + patch("superset.daos.base.db") as mock_db, + patch("superset.daos.base.getattr") as mock_getattr, + ): + mock_session = Mock() + mock_db.session = mock_session + + # Mock the id column to have an in_ method but return from a None model_cls + mock_id_col = Mock() + mock_id_col.in_.return_value = Mock() + mock_getattr.return_value = mock_id_col + + mock_query = Mock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.all.side_effect = SQLAlchemyError("Database error") + + # Set model_cls to None but allow method to proceed past guard clause + with patch.object(TestDAOWithNoneModel, "model_cls", None): + with pytest.raises(DAOFindFailedError) as exc_info: + TestDAOWithNoneModel.find_by_ids([1, 2]) + + assert "Failed to find Unknown with ids: [1, 2]" in str(exc_info.value) + + +def test_find_by_ids_successful_execution(): + """Test that find_by_ids works normally when no SQLAlchemyError occurs""" + + with ( + patch("superset.daos.base.db") as mock_db, + patch("superset.daos.base.getattr") as mock_getattr, + ): + mock_session = Mock() + mock_db.session = mock_session + + # Mock the id column to have an in_ method + mock_id_col = Mock() + mock_id_col.in_.return_value = Mock() + mock_getattr.return_value = mock_id_col + + mock_query = Mock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + + expected_results = [MockModel(1, "test1"), MockModel(2, "test2")] + mock_query.all.return_value = expected_results + + results = TestDAO.find_by_ids([1, 2]) + + assert results == expected_results + mock_query.all.assert_called_once() + + +def test_find_by_ids_empty_list(): + """Test that find_by_ids returns empty list when model_ids is empty""" + + with patch("superset.daos.base.getattr") as mock_getattr: + mock_getattr.return_value = None + + results = TestDAO.find_by_ids([]) + + assert results == [] + + +def test_find_by_ids_none_id_column(): + """Test that find_by_ids returns empty list when id column doesn't exist""" + + with patch("superset.daos.base.getattr") as mock_getattr: + mock_getattr.return_value = None + + results = TestDAO.find_by_ids([1, 2]) + + assert results == []