From 0f970025206d048b8874acd8485ae381d0cd75f2 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 16 Jul 2025 02:31:22 +1000 Subject: [PATCH] update: Refactor BaseDAO: enhance column operator logic and expand test coverage - Improved the BaseDAO class to robustly handle column operator logic, ensuring all supported operators (eq, ne, sw, ew, in, nin, gt, gte, lt, lte, like, ilike, is_null, is_not_null) are consistently applied via ColumnOperatorEnum. - Refactored the apply_column_operators and list methods for clarity and reliability, including better handling of columns, relationships, and search. - Removed 1 index base page handing from list --- superset/daos/base.py | 70 +++++++++++++++++++++++++------ tests/unit_tests/dao/base_test.py | 39 +++++++++++++++++ 2 files changed, 96 insertions(+), 13 deletions(-) diff --git a/superset/daos/base.py b/superset/daos/base.py index f85959128c0..41fb4275eed 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -21,8 +21,10 @@ from typing import Any, Dict, Generic, get_args, List, Optional, Tuple, TypeVar from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla.interface import SQLAInterface -from sqlalchemy import asc, desc, or_ +from sqlalchemy import asc, desc, or_, cast, Text from sqlalchemy.exc import StatementError +from sqlalchemy.orm import joinedload, RelationshipProperty, ColumnProperty +from sqlalchemy.inspection import inspect from superset.extensions import db @@ -234,7 +236,7 @@ class BaseDAO(Generic[T]): @classmethod def apply_column_operators(cls, query, column_operators: Optional[List[ColumnOperator]] = None): """ - Apply column operators (list of ColumnOperator) to the query. + Apply column operators (list of ColumnOperator) to the query using ColumnOperatorEnum logic. """ if not column_operators: return query @@ -248,7 +250,8 @@ class BaseDAO(Generic[T]): continue column = getattr(cls.model_cls, col) try: - query = query.filter(opr.apply(column, value)) + # Always use ColumnOperatorEnum's apply method + query = query.filter(ColumnOperatorEnum(opr).apply(column, value)) except Exception: continue # Optionally log or raise return query @@ -275,7 +278,7 @@ class BaseDAO(Generic[T]): for column_name in search_columns: if hasattr(cls.model_cls, column_name): column = getattr(cls.model_cls, column_name) - search_filters.append(column.ilike(f"%{search}%")) + search_filters.append(cast(column, Text).ilike(f"%{search}%")) if search_filters: query = query.filter(or_(*search_filters)) if custom_filters: @@ -296,18 +299,59 @@ class BaseDAO(Generic[T]): search: Optional[str] = None, search_columns: Optional[List[str]] = None, custom_filters: Optional[Dict[str, BaseFilter]] = None, - ) -> Tuple[List[T], int]: + columns: Optional[List[str]] = None, + ) -> Tuple[List[Any], int]: """ Generic list method for filtered, sorted, and paginated results. + If columns is specified, returns a list of tuples (one per row), + otherwise returns model instances. """ data_model = SQLAInterface(cls.model_cls, db.session) - query = cls._build_query( - column_operators=column_operators, - search=search, - search_columns=search_columns, - custom_filters=custom_filters, - data_model=data_model, - ) + # Separate columns and relationships + mapper = inspect(cls.model_cls) + column_names = set(c.key for c in mapper.columns) + relationship_names = set(r.key for r in mapper.relationships) + + column_attrs = [] + relationship_loads = [] + if columns is None: + columns = [] + for name in columns: + attr = getattr(cls.model_cls, name, None) + if attr is None: + continue + prop = getattr(attr, 'property', None) + if isinstance(prop, ColumnProperty): + column_attrs.append(attr) + elif isinstance(prop, RelationshipProperty): + relationship_loads.append(joinedload(attr)) + # Ignore properties and other non-queryable attributes + + if relationship_loads: + # If any relationships are requested, query the full model and joinedload relationships + query = data_model.session.query(cls.model_cls) + for loader in relationship_loads: + query = query.options(loader) + elif column_attrs: + # Only columns requested + query = data_model.session.query(*column_attrs) + else: + # Fallback: query the full model + query = data_model.session.query(cls.model_cls) + query = cls._apply_base_filter(query, data_model=data_model) + if search and search_columns: + search_filters = [] + for column_name in search_columns: + if hasattr(cls.model_cls, column_name): + column = getattr(cls.model_cls, column_name) + search_filters.append(cast(column, Text).ilike(f"%{search}%")) + if search_filters: + query = query.filter(or_(*search_filters)) + if custom_filters: + for filter_class in custom_filters.values(): + query = filter_class.apply(query, None) + if column_operators: + query = cls.apply_column_operators(query, column_operators) total_count = query.count() if hasattr(cls.model_cls, order_column): column = getattr(cls.model_cls, order_column) @@ -315,7 +359,7 @@ class BaseDAO(Generic[T]): query = query.order_by(desc(column)) else: query = query.order_by(asc(column)) - page = max(page, 0) + page = page page_size = max(page_size, 1) query = query.offset(page * page_size).limit(page_size) items = query.all() diff --git a/tests/unit_tests/dao/base_test.py b/tests/unit_tests/dao/base_test.py index f8c3546536e..410aeecb66c 100644 --- a/tests/unit_tests/dao/base_test.py +++ b/tests/unit_tests/dao/base_test.py @@ -286,3 +286,42 @@ def test_base_dao_column_operator_is_not_null(user_with_data: Session) -> None: results, _ = UserDAO.list(column_operators=[ColumnOperator(col="last_login", opr=ColumnOperatorEnum.is_not_null)]) assert any(u.username == "notnulluser" for u in results) assert all(u.last_login is not None for u in results) + +def test_base_dao_list_with_select_columns(user_with_data: Session) -> None: + from superset.daos.user import UserDAO + # Add a user to ensure at least one exists + from flask_appbuilder.security.sqla.models import User + user_with_data.add(User(id=900, username="coluser", first_name="Col", last_name="User", email="coluser@example.com", active=True)) + user_with_data.commit() + # Request only username and email columns + results, total = UserDAO.list(columns=["username", "email"]) + assert total >= 1 + # Should return Row objects with correct values + found = False + for row in results: + # SQLAlchemy Row supports both index and key access + if row[0] == "coluser" and row[1] == "coluser@example.com": + found = True + assert found + # Request only id column + results, total = UserDAO.list(columns=["id"]) + found = False + for row in results: + if row[0] == 900: + found = True + assert found + +def test_base_dao_list_with_default_columns(user_with_data: Session) -> None: + from superset.daos.user import UserDAO + from flask_appbuilder.security.sqla.models import User + user_with_data.add(User(id=901, username="defaultuser", first_name="Default", last_name="User", email="defaultuser@example.com", active=True)) + user_with_data.commit() + results, total = UserDAO.list() + assert total >= 1 + # Should return model instances + found = False + for user in results: + if hasattr(user, "id") and hasattr(user, "username") and hasattr(user, "email"): + if user.id == 901 and user.username == "defaultuser" and user.email == "defaultuser@example.com": + found = True + assert found