mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
chore: catch sqlalchemy error (#34768)
This commit is contained in:
committed by
GitHub
parent
b45141b2a1
commit
009b99bfbb
@@ -21,8 +21,11 @@ from typing import Any, Generic, get_args, TypeVar
|
||||
from flask_appbuilder.models.filters import BaseFilter
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from sqlalchemy.exc import StatementError
|
||||
from sqlalchemy.exc import SQLAlchemyError, StatementError
|
||||
|
||||
from superset.daos.exceptions import (
|
||||
DAOFindFailedError,
|
||||
)
|
||||
from superset.extensions import db
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
@@ -81,7 +84,7 @@ class BaseDAO(Generic[T]):
|
||||
Find a List of models by a list of ids, if defined applies `base_filter`
|
||||
"""
|
||||
id_col = getattr(cls.model_cls, cls.id_column_name, None)
|
||||
if id_col is None:
|
||||
if id_col is None or not model_ids:
|
||||
return []
|
||||
query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
|
||||
if cls.base_filter and not skip_base_filter:
|
||||
@@ -89,7 +92,16 @@ class BaseDAO(Generic[T]):
|
||||
query = cls.base_filter( # pylint: disable=not-callable
|
||||
cls.id_column_name, data_model
|
||||
).apply(query, None)
|
||||
return query.all()
|
||||
|
||||
try:
|
||||
results = query.all()
|
||||
except SQLAlchemyError as ex:
|
||||
model_name = cls.model_cls.__name__ if cls.model_cls else "Unknown"
|
||||
raise DAOFindFailedError(
|
||||
f"Failed to find {model_name} with ids: {model_ids}"
|
||||
) from ex
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def find_all(cls) -> list[T]:
|
||||
|
||||
@@ -23,6 +23,39 @@ class DAOException(SupersetException):
|
||||
"""
|
||||
|
||||
|
||||
class DAOFindFailedError(DAOException):
|
||||
"""
|
||||
DAO Find failed
|
||||
"""
|
||||
|
||||
status = 400
|
||||
message = "Find failed"
|
||||
|
||||
|
||||
class DAOCreateFailedError(DAOException):
|
||||
"""
|
||||
DAO Create failed
|
||||
"""
|
||||
|
||||
message = "Create failed"
|
||||
|
||||
|
||||
class DAOUpdateFailedError(DAOException):
|
||||
"""
|
||||
DAO Update failed
|
||||
"""
|
||||
|
||||
message = "Update failed"
|
||||
|
||||
|
||||
class DAODeleteFailedError(DAOException):
|
||||
"""
|
||||
DAO Delete failed
|
||||
"""
|
||||
|
||||
message = "Delete failed"
|
||||
|
||||
|
||||
class DatasourceTypeNotSupportedError(DAOException):
|
||||
"""
|
||||
DAO datasource query source type is not supported
|
||||
|
||||
142
tests/unit_tests/dao/base_dao_test.py
Normal file
142
tests/unit_tests/dao/base_dao_test.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.daos.base import BaseDAO
|
||||
from superset.daos.exceptions import DAOFindFailedError
|
||||
|
||||
|
||||
class MockModel:
|
||||
def __init__(self, id=1, name="test"):
|
||||
self.id = id
|
||||
self.name = name
|
||||
|
||||
|
||||
class TestDAO(BaseDAO[MockModel]):
|
||||
model_cls = MockModel
|
||||
|
||||
|
||||
class TestDAOWithNoneModel(BaseDAO[MockModel]):
|
||||
model_cls = None
|
||||
|
||||
|
||||
def test_find_by_ids_sqlalchemy_error_with_model_cls():
|
||||
"""Test SQLAlchemyError in find_by_ids shows proper model name
|
||||
when model_cls is set"""
|
||||
|
||||
with (
|
||||
patch("superset.daos.base.db") as mock_db,
|
||||
patch("superset.daos.base.getattr") as mock_getattr,
|
||||
):
|
||||
mock_session = Mock()
|
||||
mock_db.session = mock_session
|
||||
|
||||
# Mock the id column to have an in_ method
|
||||
mock_id_col = Mock()
|
||||
mock_id_col.in_.return_value = Mock()
|
||||
mock_getattr.return_value = mock_id_col
|
||||
|
||||
mock_query = Mock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.all.side_effect = SQLAlchemyError("Database error")
|
||||
|
||||
with pytest.raises(DAOFindFailedError) as exc_info:
|
||||
TestDAO.find_by_ids([1, 2])
|
||||
|
||||
assert "Failed to find MockModel with ids: [1, 2]" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_find_by_ids_sqlalchemy_error_with_none_model_cls():
|
||||
"""Test SQLAlchemyError in find_by_ids shows 'Unknown' when model_cls is None"""
|
||||
|
||||
with (
|
||||
patch("superset.daos.base.db") as mock_db,
|
||||
patch("superset.daos.base.getattr") as mock_getattr,
|
||||
):
|
||||
mock_session = Mock()
|
||||
mock_db.session = mock_session
|
||||
|
||||
# Mock the id column to have an in_ method but return from a None model_cls
|
||||
mock_id_col = Mock()
|
||||
mock_id_col.in_.return_value = Mock()
|
||||
mock_getattr.return_value = mock_id_col
|
||||
|
||||
mock_query = Mock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.all.side_effect = SQLAlchemyError("Database error")
|
||||
|
||||
# Set model_cls to None but allow method to proceed past guard clause
|
||||
with patch.object(TestDAOWithNoneModel, "model_cls", None):
|
||||
with pytest.raises(DAOFindFailedError) as exc_info:
|
||||
TestDAOWithNoneModel.find_by_ids([1, 2])
|
||||
|
||||
assert "Failed to find Unknown with ids: [1, 2]" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_find_by_ids_successful_execution():
|
||||
"""Test that find_by_ids works normally when no SQLAlchemyError occurs"""
|
||||
|
||||
with (
|
||||
patch("superset.daos.base.db") as mock_db,
|
||||
patch("superset.daos.base.getattr") as mock_getattr,
|
||||
):
|
||||
mock_session = Mock()
|
||||
mock_db.session = mock_session
|
||||
|
||||
# Mock the id column to have an in_ method
|
||||
mock_id_col = Mock()
|
||||
mock_id_col.in_.return_value = Mock()
|
||||
mock_getattr.return_value = mock_id_col
|
||||
|
||||
mock_query = Mock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
expected_results = [MockModel(1, "test1"), MockModel(2, "test2")]
|
||||
mock_query.all.return_value = expected_results
|
||||
|
||||
results = TestDAO.find_by_ids([1, 2])
|
||||
|
||||
assert results == expected_results
|
||||
mock_query.all.assert_called_once()
|
||||
|
||||
|
||||
def test_find_by_ids_empty_list():
|
||||
"""Test that find_by_ids returns empty list when model_ids is empty"""
|
||||
|
||||
with patch("superset.daos.base.getattr") as mock_getattr:
|
||||
mock_getattr.return_value = None
|
||||
|
||||
results = TestDAO.find_by_ids([])
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
def test_find_by_ids_none_id_column():
|
||||
"""Test that find_by_ids returns empty list when id column doesn't exist"""
|
||||
|
||||
with patch("superset.daos.base.getattr") as mock_getattr:
|
||||
mock_getattr.return_value = None
|
||||
|
||||
results = TestDAO.find_by_ids([1, 2])
|
||||
|
||||
assert results == []
|
||||
Reference in New Issue
Block a user