From cd52193869157fd0a035eadf4438428bb0e4dea7 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Mon, 7 Jul 2025 20:36:48 +1000 Subject: [PATCH] update: add helper for applying base filter --- superset/daos/base.py | 50 +++++++++++++------------------ tests/unit_tests/dao/base_test.py | 11 ++----- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/superset/daos/base.py b/superset/daos/base.py index 11b67ef49b1..664d023d887 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -50,6 +50,19 @@ class BaseDAO(Generic[T]): cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member )[0] + @classmethod + def _apply_base_filter(cls, query, skip_base_filter: bool = False, data_model=None): + """ + Apply the base_filter to the query if it exists and skip_base_filter is False. + """ + if cls.base_filter and not skip_base_filter: + if data_model is None: + data_model = SQLAInterface(cls.model_cls, db.session) + query = cls.base_filter( # pylint: disable=not-callable + cls.id_column_name, data_model + ).apply(query, None) + return query + @classmethod def find_by_id( cls, @@ -60,11 +73,7 @@ class BaseDAO(Generic[T]): Find a model by id, if defined applies `base_filter` """ query = db.session.query(cls.model_cls) - if cls.base_filter and not skip_base_filter: - data_model = SQLAInterface(cls.model_cls, db.session) - query = cls.base_filter( # pylint: disable=not-callable - cls.id_column_name, data_model - ).apply(query, None) + query = cls._apply_base_filter(query, skip_base_filter) id_column = getattr(cls.model_cls, cls.id_column_name) try: return query.filter(id_column == model_id).one_or_none() @@ -85,11 +94,7 @@ class BaseDAO(Generic[T]): if id_col is None: return [] query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids)) - if cls.base_filter and not skip_base_filter: - data_model = SQLAInterface(cls.model_cls, db.session) - query = cls.base_filter( # pylint: disable=not-callable - cls.id_column_name, data_model - ).apply(query, None) + query = cls._apply_base_filter(query, skip_base_filter) return query.all() @classmethod @@ -98,11 +103,7 @@ class BaseDAO(Generic[T]): Get all that fit the `base_filter` """ query = db.session.query(cls.model_cls) - if cls.base_filter: - data_model = SQLAInterface(cls.model_cls, db.session) - query = cls.base_filter( # pylint: disable=not-callable - cls.id_column_name, data_model - ).apply(query, None) + query = cls._apply_base_filter(query) return query.all() @classmethod @@ -111,11 +112,7 @@ class BaseDAO(Generic[T]): Get the first that fit the `base_filter` """ query = db.session.query(cls.model_cls) - if cls.base_filter: - data_model = SQLAInterface(cls.model_cls, db.session) - query = cls.base_filter( # pylint: disable=not-callable - cls.id_column_name, data_model - ).apply(query, None) + query = cls._apply_base_filter(query) return query.filter_by(**filter_by).one_or_none() @classmethod @@ -215,6 +212,7 @@ class BaseDAO(Generic[T]): :param custom_filters: Dictionary of custom FAB filter classes to apply :return: Tuple of (items, total_count) """ + # Create SQLAInterface instance for FAB-compatible query generation data_model = SQLAInterface(cls.model_cls, db.session) @@ -222,10 +220,7 @@ class BaseDAO(Generic[T]): query = data_model.session.query(cls.model_cls) # Apply base filter if defined - if cls.base_filter: - query = cls.base_filter( # pylint: disable=not-callable - cls.id_column_name, data_model - ).apply(query, None) + query = cls._apply_base_filter(query, data_model=data_model) # Apply search filter if search and search_columns: @@ -292,12 +287,9 @@ class BaseDAO(Generic[T]): :param filters: Dictionary of column_name: value to filter by :return: Number of records matching the filter """ + data_model = SQLAInterface(cls.model_cls, db.session) query = db.session.query(cls.model_cls) - if cls.base_filter and not skip_base_filter: - data_model = SQLAInterface(cls.model_cls, db.session) - query = cls.base_filter( # pylint: disable=not-callable - cls.id_column_name, data_model - ).apply(query, None) + query = cls._apply_base_filter(query, skip_base_filter=skip_base_filter, data_model=data_model) if filters: for column_name, value in filters.items(): diff --git a/tests/unit_tests/dao/base_test.py b/tests/unit_tests/dao/base_test.py index 15f27a8871b..348d1d9dbe3 100644 --- a/tests/unit_tests/dao/base_test.py +++ b/tests/unit_tests/dao/base_test.py @@ -16,11 +16,9 @@ # under the License. from sqlalchemy.orm.session import Session -from flask_appbuilder.models.filters import BaseFilter -import pytest + def test_base_dao_list_returns_results(user_with_data: Session) -> None: - """Test that BaseDAO.list returns results for the model.""" from superset.daos.user import UserDAO results, total = UserDAO.list() assert total >= 1 @@ -28,7 +26,6 @@ def test_base_dao_list_returns_results(user_with_data: Session) -> None: def test_base_dao_list_with_filters(user_with_data: Session) -> None: - """Test that BaseDAO.list applies filters correctly.""" from superset.daos.user import UserDAO results, total = UserDAO.list(filters={"username": "testuser"}) assert total >= 1 @@ -36,7 +33,6 @@ def test_base_dao_list_with_filters(user_with_data: Session) -> None: def test_base_dao_list_with_non_matching_filter(user_with_data: Session) -> None: - """Test that BaseDAO.list returns empty for non-matching filters.""" from superset.daos.user import UserDAO results, total = UserDAO.list(filters={"username": "doesnotexist"}) assert total == 0 @@ -44,14 +40,12 @@ def test_base_dao_list_with_non_matching_filter(user_with_data: Session) -> None def test_base_dao_count_returns_value(user_with_data: Session) -> None: - """Test that BaseDAO.count returns a count for the model.""" from superset.daos.user import UserDAO count = UserDAO.count() assert count >= 1 def test_base_dao_count_with_filters(user_with_data: Session) -> None: - """Test that BaseDAO.count applies filters correctly.""" from superset.daos.user import UserDAO count = UserDAO.count(filters={"username": "testuser"}) assert count >= 1 @@ -60,7 +54,6 @@ def test_base_dao_count_with_filters(user_with_data: Session) -> None: def test_base_dao_list_and_count_skip_base_filter(user_with_data: Session) -> None: - """Test that skip_base_filter argument works for list and count.""" from superset.daos.user import UserDAO results, total = UserDAO.list() results_skip, total_skip = UserDAO.list() @@ -138,6 +131,7 @@ def test_base_dao_list_custom_filter(user_with_data: Session) -> None: from superset.daos.user import UserDAO from flask_appbuilder.security.sqla.models import User from flask_appbuilder.models.sqla.interface import SQLAInterface + from flask_appbuilder.models.filters import BaseFilter datamodel = SQLAInterface(User, user_with_data) class EmailDomainFilter(BaseFilter): def apply(self, query, value): @@ -157,6 +151,7 @@ def test_base_dao_list_custom_filter(user_with_data: Session) -> None: def test_base_dao_list_base_filter(user_with_data: Session) -> None: from superset.daos.user import UserDAO from flask_appbuilder.security.sqla.models import User + from flask_appbuilder.models.filters import BaseFilter class OnlyActiveFilter(BaseFilter): def apply(self, query, value): return query.filter(User.active == True)