chore: catch sqlalchemy error (#34768)

This commit is contained in:
Elizabeth Thompson
2025-08-20 18:00:06 -07:00
committed by GitHub
parent b45141b2a1
commit 009b99bfbb
3 changed files with 190 additions and 3 deletions

View File

@@ -21,8 +21,11 @@ from typing import Any, Generic, get_args, TypeVar
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import StatementError
from sqlalchemy.exc import SQLAlchemyError, StatementError
from superset.daos.exceptions import (
DAOFindFailedError,
)
from superset.extensions import db
T = TypeVar("T", bound=Model)
@@ -81,7 +84,7 @@ class BaseDAO(Generic[T]):
Find a List of models by a list of ids, if defined applies `base_filter`
"""
id_col = getattr(cls.model_cls, cls.id_column_name, None)
if id_col is None:
if id_col is None or not model_ids:
return []
query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
if cls.base_filter and not skip_base_filter:
@@ -89,7 +92,16 @@ class BaseDAO(Generic[T]):
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
return query.all()
try:
results = query.all()
except SQLAlchemyError as ex:
model_name = cls.model_cls.__name__ if cls.model_cls else "Unknown"
raise DAOFindFailedError(
f"Failed to find {model_name} with ids: {model_ids}"
) from ex
return results
@classmethod
def find_all(cls) -> list[T]:

View File

@@ -23,6 +23,39 @@ class DAOException(SupersetException):
"""
class DAOFindFailedError(DAOException):
"""
DAO Find failed
"""
status = 400
message = "Find failed"
class DAOCreateFailedError(DAOException):
"""
DAO Create failed
"""
message = "Create failed"
class DAOUpdateFailedError(DAOException):
"""
DAO Update failed
"""
message = "Update failed"
class DAODeleteFailedError(DAOException):
"""
DAO Delete failed
"""
message = "Delete failed"
class DatasourceTypeNotSupportedError(DAOException):
"""
DAO datasource query source type is not supported

View File

@@ -0,0 +1,142 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.exc import SQLAlchemyError
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAOFindFailedError
class MockModel:
def __init__(self, id=1, name="test"):
self.id = id
self.name = name
class TestDAO(BaseDAO[MockModel]):
model_cls = MockModel
class TestDAOWithNoneModel(BaseDAO[MockModel]):
model_cls = None
def test_find_by_ids_sqlalchemy_error_with_model_cls():
"""Test SQLAlchemyError in find_by_ids shows proper model name
when model_cls is set"""
with (
patch("superset.daos.base.db") as mock_db,
patch("superset.daos.base.getattr") as mock_getattr,
):
mock_session = Mock()
mock_db.session = mock_session
# Mock the id column to have an in_ method
mock_id_col = Mock()
mock_id_col.in_.return_value = Mock()
mock_getattr.return_value = mock_id_col
mock_query = Mock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.all.side_effect = SQLAlchemyError("Database error")
with pytest.raises(DAOFindFailedError) as exc_info:
TestDAO.find_by_ids([1, 2])
assert "Failed to find MockModel with ids: [1, 2]" in str(exc_info.value)
def test_find_by_ids_sqlalchemy_error_with_none_model_cls():
"""Test SQLAlchemyError in find_by_ids shows 'Unknown' when model_cls is None"""
with (
patch("superset.daos.base.db") as mock_db,
patch("superset.daos.base.getattr") as mock_getattr,
):
mock_session = Mock()
mock_db.session = mock_session
# Mock the id column to have an in_ method but return from a None model_cls
mock_id_col = Mock()
mock_id_col.in_.return_value = Mock()
mock_getattr.return_value = mock_id_col
mock_query = Mock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.all.side_effect = SQLAlchemyError("Database error")
# Set model_cls to None but allow method to proceed past guard clause
with patch.object(TestDAOWithNoneModel, "model_cls", None):
with pytest.raises(DAOFindFailedError) as exc_info:
TestDAOWithNoneModel.find_by_ids([1, 2])
assert "Failed to find Unknown with ids: [1, 2]" in str(exc_info.value)
def test_find_by_ids_successful_execution():
"""Test that find_by_ids works normally when no SQLAlchemyError occurs"""
with (
patch("superset.daos.base.db") as mock_db,
patch("superset.daos.base.getattr") as mock_getattr,
):
mock_session = Mock()
mock_db.session = mock_session
# Mock the id column to have an in_ method
mock_id_col = Mock()
mock_id_col.in_.return_value = Mock()
mock_getattr.return_value = mock_id_col
mock_query = Mock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
expected_results = [MockModel(1, "test1"), MockModel(2, "test2")]
mock_query.all.return_value = expected_results
results = TestDAO.find_by_ids([1, 2])
assert results == expected_results
mock_query.all.assert_called_once()
def test_find_by_ids_empty_list():
"""Test that find_by_ids returns empty list when model_ids is empty"""
with patch("superset.daos.base.getattr") as mock_getattr:
mock_getattr.return_value = None
results = TestDAO.find_by_ids([])
assert results == []
def test_find_by_ids_none_id_column():
"""Test that find_by_ids returns empty list when id column doesn't exist"""
with patch("superset.daos.base.getattr") as mock_getattr:
mock_getattr.return_value = None
results = TestDAO.find_by_ids([1, 2])
assert results == []