Mirror of https://github.com/apache/superset.git (synced 2026-04-07 18:35:15 +00:00)
feat: Add BaseDAO improvements and test reorganization (#35018)
Co-authored-by: bito-code-review[bot] <188872107+bito-code-review[bot]@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
@@ -84,6 +84,7 @@ dependencies = [
    "pgsanity",
    "Pillow>=11.0.0, <12",
    "polyline>=2.0.0, <3.0",
    "pydantic>=2.8.0",
    "pyparsing>=3.0.6, <4",
    "python-dateutil",
    "python-dotenv", # optional dependencies for Flask but required for Superset, see https://flask.palletsprojects.com/en/stable/installation/#optional-dependencies
@@ -6,6 +6,8 @@ alembic==1.15.2
    # via flask-migrate
amqp==5.3.1
    # via kombu
annotated-types==0.7.0
    # via pydantic
apispec==6.6.1
    # via
    #   -r requirements/base.in
@@ -115,7 +117,9 @@ flask==2.3.3
    #   flask-sqlalchemy
    #   flask-wtf
flask-appbuilder==5.0.0
    # via apache-superset (pyproject.toml)
    # via
    #   apache-superset (pyproject.toml)
    #   apache-superset-core
flask-babel==3.1.0
    # via flask-appbuilder
flask-caching==2.3.1
@@ -294,6 +298,10 @@ pyasn1-modules==0.4.2
    # via google-auth
pycparser==2.22
    # via cffi
pydantic==2.11.7
    # via apache-superset (pyproject.toml)
pydantic-core==2.33.2
    # via pydantic
pygments==2.19.1
    # via rich
pyjwt==2.10.1
@@ -404,10 +412,15 @@ typing-extensions==4.14.0
    #   alembic
    #   cattrs
    #   limits
    #   pydantic
    #   pydantic-core
    #   pyopenssl
    #   referencing
    #   selenium
    #   shillelagh
    #   typing-inspection
typing-inspection==0.4.1
    # via pydantic
tzdata==2025.2
    # via
    #   kombu
@@ -18,6 +18,10 @@ amqp==5.3.1
    # via
    #   -c requirements/base-constraint.txt
    #   kombu
annotated-types==0.7.0
    # via
    #   -c requirements/base-constraint.txt
    #   pydantic
apispec==6.6.1
    # via
    #   -c requirements/base-constraint.txt
@@ -212,6 +216,7 @@ flask-appbuilder==5.0.0
    # via
    #   -c requirements/base-constraint.txt
    #   apache-superset
    #   apache-superset-core
flask-babel==3.1.0
    # via
    #   -c requirements/base-constraint.txt
@@ -631,6 +636,14 @@ pycparser==2.22
    # via
    #   -c requirements/base-constraint.txt
    #   cffi
pydantic==2.11.7
    # via
    #   -c requirements/base-constraint.txt
    #   apache-superset
pydantic-core==2.33.2
    # via
    #   -c requirements/base-constraint.txt
    #   pydantic
pydata-google-auth==1.9.0
    # via pandas-gbq
pydruid==0.6.9
@@ -874,10 +887,17 @@ typing-extensions==4.14.0
    #   apache-superset
    #   cattrs
    #   limits
    #   pydantic
    #   pydantic-core
    #   pyopenssl
    #   referencing
    #   selenium
    #   shillelagh
    #   typing-inspection
typing-inspection==0.4.1
    # via
    #   -c requirements/base-constraint.txt
    #   pydantic
tzdata==2025.2
    # via
    #   -c requirements/base-constraint.txt
@@ -16,22 +16,141 @@
# under the License.
from __future__ import annotations

from typing import Any, Generic, get_args, TypeVar
import logging
import uuid as uuid_lib
from enum import Enum
from typing import (
    Any,
    Dict,
    Generic,
    get_args,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
)

import sqlalchemy as sa
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_sqlalchemy import BaseQuery
from pydantic import BaseModel, Field
from sqlalchemy import asc, cast, desc, or_, Text
from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import ColumnProperty, joinedload, RelationshipProperty

from superset.daos.exceptions import (
    DAOFindFailedError,
)
from superset.extensions import db

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=Model)


class ColumnOperatorEnum(str, Enum):
    eq = "eq"
    ne = "ne"
    sw = "sw"
    ew = "ew"
    in_ = "in"
    nin = "nin"
    gt = "gt"
    gte = "gte"
    lt = "lt"
    lte = "lte"
    like = "like"
    ilike = "ilike"
    is_null = "is_null"
    is_not_null = "is_not_null"

    def apply(self, column: Any, value: Any) -> Any:
        op_func = operator_map.get(self)
        if not op_func:
            raise ValueError("Unsupported operator: %s" % self)
        return op_func(column, value)


# Define operator_map as a module-level dict after the enum is defined
operator_map: Dict[ColumnOperatorEnum, Any] = {
    ColumnOperatorEnum.eq: lambda col, val: col == val,
    ColumnOperatorEnum.ne: lambda col, val: col != val,
    ColumnOperatorEnum.sw: lambda col, val: col.like(f"{val}%"),
    ColumnOperatorEnum.ew: lambda col, val: col.like(f"%{val}"),
    ColumnOperatorEnum.in_: lambda col, val: col.in_(
        val if isinstance(val, (list, tuple)) else [val]
    ),
    ColumnOperatorEnum.nin: lambda col, val: ~col.in_(
        val if isinstance(val, (list, tuple)) else [val]
    ),
    ColumnOperatorEnum.gt: lambda col, val: col > val,
    ColumnOperatorEnum.gte: lambda col, val: col >= val,
    ColumnOperatorEnum.lt: lambda col, val: col < val,
    ColumnOperatorEnum.lte: lambda col, val: col <= val,
    ColumnOperatorEnum.like: lambda col, val: col.like(f"%{val}%"),
    ColumnOperatorEnum.ilike: lambda col, val: col.ilike(f"%{val}%"),
    ColumnOperatorEnum.is_null: lambda col, _: col.is_(None),
    ColumnOperatorEnum.is_not_null: lambda col, _: col.isnot(None),
}

# Map SQLAlchemy types to supported operators
TYPE_OPERATOR_MAP = {
    "string": [
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.sw,
        ColumnOperatorEnum.ew,
        ColumnOperatorEnum.in_,
        ColumnOperatorEnum.nin,
        ColumnOperatorEnum.like,
        ColumnOperatorEnum.ilike,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    ],
    "boolean": [
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    ],
    "number": [
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.gt,
        ColumnOperatorEnum.gte,
        ColumnOperatorEnum.lt,
        ColumnOperatorEnum.lte,
        ColumnOperatorEnum.in_,
        ColumnOperatorEnum.nin,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    ],
    "datetime": [
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.gt,
        ColumnOperatorEnum.gte,
        ColumnOperatorEnum.lt,
        ColumnOperatorEnum.lte,
        ColumnOperatorEnum.in_,
        ColumnOperatorEnum.nin,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    ],
}


class ColumnOperator(BaseModel):
    col: str = Field(..., description="Column name to filter on")
    opr: ColumnOperatorEnum = Field(..., description="Operator")
    value: Any = Field(None, description="Value for the filter")

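For orientation, a minimal illustrative sketch (not part of the diff) of how a ColumnOperator spec is turned into a SQLAlchemy criterion by ColumnOperatorEnum.apply; it assumes the Slice model and its slice_name string column, which exist elsewhere in Superset:

# Illustrative sketch, not part of this commit.
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.models.slice import Slice

# A filter spec as an API/MCP client might send it.
op = ColumnOperator(col="slice_name", opr=ColumnOperatorEnum.ilike, value="sales")

# ColumnOperatorEnum.apply looks up the matching lambda in operator_map and
# returns a SQLAlchemy expression, roughly:
#   lower(slices.slice_name) LIKE lower('%sales%')
criterion = op.opr.apply(getattr(Slice, op.col), op.value)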
class BaseDAO(Generic[T]):
    """
    Base DAO, implement base CRUD sqlalchemy operations
@@ -83,80 +202,170 @@ class BaseDAO(Generic[T]):
        return None

    @classmethod
    def find_by_id(
        cls,
        model_id: str | int,
        skip_base_filter: bool = False,
    ) -> T | None:
    def _apply_base_filter(
        cls, query: Any, skip_base_filter: bool = False, data_model: Any = None
    ) -> Any:
        """
        Find a model by id, if defined applies `base_filter`
        Apply the base_filter to the query if it exists and skip_base_filter is False.
        """
        query = db.session.query(cls.model_cls)
        if cls.base_filter and not skip_base_filter:
            data_model = SQLAInterface(cls.model_cls, db.session)
            if data_model is None:
                data_model = SQLAInterface(cls.model_cls, db.session)
            query = cls.base_filter(  # pylint: disable=not-callable
                cls.id_column_name, data_model
            ).apply(query, None)
        id_column = getattr(cls.model_cls, cls.id_column_name)
        return query

    @classmethod
    def _convert_value_for_column(cls, column: Any, value: Any) -> Any:
        """
        Convert a value to the appropriate type for a given SQLAlchemy column.

        Args:
            column: SQLAlchemy column object
            value: Value to convert

        Returns:
            Converted value or None if conversion fails
        """
        if (
            hasattr(column.type, "python_type")
            and column.type.python_type == uuid_lib.UUID
        ):
            if isinstance(value, str):
                try:
                    return uuid_lib.UUID(value)
                except (ValueError, AttributeError):
                    return None
        return value

    @classmethod
    def _find_by_column(
        cls,
        column_name: str,
        value: str | int,
        skip_base_filter: bool = False,
    ) -> T | None:
        """
        Private method to find a model by any column value.

        Args:
            column_name: Name of the column to search by
            value: Value to search for
            skip_base_filter: Whether to skip base filtering

        Returns:
            Model instance or None if not found
        """
        query = db.session.query(cls.model_cls)
        query = cls._apply_base_filter(query, skip_base_filter)

        if not hasattr(cls.model_cls, column_name):
            return None

        column = getattr(cls.model_cls, column_name)
        converted_value = cls._convert_value_for_column(column, value)
        if converted_value is None:
            return None

        try:
            return query.filter(id_column == model_id).one_or_none()
            return query.filter(column == converted_value).one_or_none()
        except StatementError:
            # can happen if int is passed instead of a string or similar
            return None

    @classmethod
    def find_by_id(
        cls,
        model_id: str | int,
        skip_base_filter: bool = False,
        id_column: str | None = None,
    ) -> T | None:
        """
        Find a model by ID using specified or default ID column.

        Args:
            model_id: ID value to search for
            skip_base_filter: Whether to skip base filtering
            id_column: Column name to use (defaults to cls.id_column_name)

        Returns:
            Model instance or None if not found
        """
        column = id_column or cls.id_column_name
        return cls._find_by_column(column, model_id, skip_base_filter)

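A hedged usage sketch of the new id_column parameter (illustrative only; the UUID value is hypothetical, and it assumes Slice exposes a uuid column, as it does elsewhere in Superset):

# Illustrative sketch, not part of this commit: looking a chart up by its
# uuid column instead of the default id column. _convert_value_for_column
# coerces the string into a uuid.UUID before filtering.
from superset.daos.chart import ChartDAO

chart = ChartDAO.find_by_id(
    "3e1f5ba0-7a6c-4f4d-9c2e-1a2b3c4d5e6f",  # hypothetical UUID string
    id_column="uuid",
)
if chart is None:
    print("no chart matches that UUID, or it was excluded by base_filter")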
    @classmethod
    def find_by_ids(
        cls,
        model_ids: list[str] | list[int],
        model_ids: Sequence[str | int],
        skip_base_filter: bool = False,
        id_column: str | None = None,
    ) -> list[T]:
        """
        Find a List of models by a list of ids, if defined applies `base_filter`

        :param model_ids: List of IDs to find
        :param skip_base_filter: If true, skip applying the base filter
        :param id_column: Optional column name to use for ID lookup
            (defaults to id_column_name)
        """
        id_col = getattr(cls.model_cls, cls.id_column_name, None)
        column = id_column or cls.id_column_name
        id_col = getattr(cls.model_cls, column, None)
        if id_col is None or not model_ids:
            return []
        query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
        if cls.base_filter and not skip_base_filter:
            data_model = SQLAInterface(cls.model_cls, db.session)
            query = cls.base_filter(  # pylint: disable=not-callable
                cls.id_column_name, data_model
            ).apply(query, None)

        # Convert IDs to appropriate types based on column type
        converted_ids: list[str | int | uuid_lib.UUID] = []
        for id_val in model_ids:
            converted_value = cls._convert_value_for_column(id_col, id_val)
            if converted_value is not None:
                # Only add successfully converted values
                converted_ids.append(converted_value)
            else:
                # Log warning for failed conversions
                logger.warning(
                    "Failed to convert ID '%s' for column %s.%s",
                    id_val,
                    cls.model_cls.__name__ if cls.model_cls else "Unknown",
                    column,
                )

        # If no valid IDs after conversion, return empty list
        if not converted_ids:
            return []

        query = db.session.query(cls.model_cls).filter(id_col.in_(converted_ids))
        query = cls._apply_base_filter(query, skip_base_filter)

        try:
            results = query.all()
        except SQLAlchemyError as ex:
            model_name = cls.model_cls.__name__ if cls.model_cls else "Unknown"
            raise DAOFindFailedError(
                f"Failed to find {model_name} with ids: {model_ids}"
                "Failed to find %s with ids: %s" % (model_name, model_ids)
            ) from ex

        return results

    @classmethod
    def find_all(cls) -> list[T]:
    def find_all(cls, skip_base_filter: bool = False) -> list[T]:
        """
        Get all that fit the `base_filter`
        """
        query = db.session.query(cls.model_cls)
        if cls.base_filter:
            data_model = SQLAInterface(cls.model_cls, db.session)
            query = cls.base_filter(  # pylint: disable=not-callable
                cls.id_column_name, data_model
            ).apply(query, None)
        query = cls._apply_base_filter(query, skip_base_filter)
        return query.all()

    @classmethod
    def find_one_or_none(cls, **filter_by: Any) -> T | None:
    def find_one_or_none(
        cls, skip_base_filter: bool = False, **filter_by: Any
    ) -> T | None:
        """
        Get the first that fit the `base_filter`
        """
        query = db.session.query(cls.model_cls)
        if cls.base_filter:
            data_model = SQLAInterface(cls.model_cls, db.session)
            query = cls.base_filter(  # pylint: disable=not-callable
                cls.id_column_name, data_model
            ).apply(query, None)
        query = cls._apply_base_filter(query, skip_base_filter)
        return query.filter_by(**filter_by).one_or_none()

    @classmethod
@@ -251,3 +460,229 @@ class BaseDAO(Generic[T]):
            cls.id_column_name, data_model
        ).apply(query, None)
        return query.filter_by(**filter_by).all()

    @classmethod
    def apply_column_operators(
        cls, query: Any, column_operators: Optional[List[ColumnOperator]] = None
    ) -> Any:
        """
        Apply column operators (list of ColumnOperator) to the query using
        ColumnOperatorEnum logic. Raises ValueError if a filter references a
        non-existent column.
        """
        if not column_operators:
            return query
        for c in column_operators:
            if not isinstance(c, ColumnOperator):
                continue
            col, opr, value = c.col, c.opr, c.value
            if not col or not hasattr(cls.model_cls, col):
                model_name = cls.model_cls.__name__ if cls.model_cls else "Unknown"
                logging.error(
                    "Invalid filter: column '%s' does not exist on %s", col, model_name
                )
                raise ValueError(
                    "Invalid filter: column '%s' does not exist on %s"
                    % (col, model_name)
                )
            column = getattr(cls.model_cls, col)
            try:
                # Always use ColumnOperatorEnum's apply method
                operator_enum = ColumnOperatorEnum(opr)
                query = query.filter(operator_enum.apply(column, value))
            except Exception as e:
                logging.error("Error applying filter on column '%s': %s", col, e)
                raise
        return query

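A usage sketch of apply_column_operators (illustrative, not part of the diff; it assumes the Dashboard model's published and dashboard_title columns):

# Illustrative sketch, not part of this commit: applying several operators
# to a query. Unknown column names raise ValueError instead of being ignored.
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.daos.dashboard import DashboardDAO
from superset.extensions import db
from superset.models.dashboard import Dashboard

query = db.session.query(Dashboard)
query = DashboardDAO.apply_column_operators(
    query,
    [
        ColumnOperator(col="published", opr=ColumnOperatorEnum.eq, value=True),
        ColumnOperator(col="dashboard_title", opr=ColumnOperatorEnum.like, value="Sales"),
    ],
)
dashboards = query.all()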
    @classmethod
    def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
        """
        Returns a dict mapping filterable columns (including hybrid/computed fields if
        present) to their supported operators. Used by MCP tools to dynamically expose
        filter options. Custom fields supported by the DAO but not present on the model
        should be documented here.
        """

        mapper = inspect(cls.model_cls)
        columns = {c.key: c for c in mapper.columns}
        # Add hybrid properties
        hybrids = {
            name: attr
            for name, attr in vars(cls.model_cls).items()
            if isinstance(attr, hybrid_property)
        }
        # You may add custom fields here, e.g.:
        # custom_fields = {"tags": ["eq", "in_", "like"], ...}
        custom_fields: Dict[str, List[str]] = {}

        filterable: Dict[str, Any] = {}
        for name, col in columns.items():
            if isinstance(col.type, (sa.String, sa.Text)):
                filterable[name] = TYPE_OPERATOR_MAP["string"]
            elif isinstance(col.type, (sa.Boolean,)):
                filterable[name] = TYPE_OPERATOR_MAP["boolean"]
            elif isinstance(col.type, (sa.Integer, sa.Float, sa.Numeric)):
                filterable[name] = TYPE_OPERATOR_MAP["number"]
            elif isinstance(col.type, (sa.DateTime, sa.Date, sa.Time)):
                filterable[name] = TYPE_OPERATOR_MAP["datetime"]
            else:
                # Fallback to eq/ne/null
                filterable[name] = [
                    ColumnOperatorEnum.eq,
                    ColumnOperatorEnum.ne,
                    ColumnOperatorEnum.is_null,
                    ColumnOperatorEnum.is_not_null,
                ]
        # Add hybrid properties as string fields by default
        for name in hybrids:
            filterable[name] = TYPE_OPERATOR_MAP["string"]
        # Add custom fields
        filterable.update(custom_fields)

        # Convert enum values to strings for the return type
        result: Dict[str, List[str]] = {}
        for key, operators in filterable.items():
            if isinstance(operators, list):
                # Convert enums to strings
                result[key] = [
                    op.value if isinstance(op, ColumnOperatorEnum) else op
                    for op in operators
                ]
            else:
                result[key] = operators

        return result

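A sketch of the shape of the returned mapping (illustrative; the exact keys depend on the model, and the entries shown here are assumptions based on TYPE_OPERATOR_MAP above and CHART_CUSTOM_FIELDS later in this diff):

# Illustrative sketch, not part of this commit: the mapping contains plain
# strings, suitable for serializing to an MCP/API client.
from superset.daos.chart import ChartDAO

filterable = ChartDAO.get_filterable_columns_and_operators()
# Hypothetical result shape:
# {
#     "slice_name": ["eq", "ne", "sw", "ew", "in", "nin", "like", "ilike",
#                    "is_null", "is_not_null"],
#     "id": ["eq", "ne", "gt", "gte", "lt", "lte", "in", "nin",
#            "is_null", "is_not_null"],
#     "viz_type": ["eq", "in_", "like"],  # from CHART_CUSTOM_FIELDS below
#     ...
# }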
    @classmethod
    def _build_query(
        cls,
        column_operators: Optional[List[ColumnOperator]] = None,
        search: Optional[str] = None,
        search_columns: Optional[List[str]] = None,
        custom_filters: Optional[Dict[str, BaseFilter]] = None,
        skip_base_filter: bool = False,
        data_model: Optional[SQLAInterface] = None,
    ) -> Any:
        """
        Build a SQLAlchemy query with base filter, column operators, search, and
        custom filters.
        """
        if data_model is None:
            data_model = SQLAInterface(cls.model_cls, db.session)
        query = data_model.session.query(cls.model_cls)
        query = cls._apply_base_filter(
            query, skip_base_filter=skip_base_filter, data_model=data_model
        )
        if search and search_columns:
            search_filters = []
            for column_name in search_columns:
                if hasattr(cls.model_cls, column_name):
                    column = getattr(cls.model_cls, column_name)
                    search_filters.append(cast(column, Text).ilike(f"%{search}%"))
            if search_filters:
                query = query.filter(or_(*search_filters))
        if custom_filters:
            for filter_class in custom_filters.values():
                query = filter_class.apply(query, None)
        if column_operators:
            query = cls.apply_column_operators(query, column_operators)
        return query

    @classmethod
    def list(  # noqa: C901
        cls,
        column_operators: Optional[List[ColumnOperator]] = None,
        order_column: str = "changed_on",
        order_direction: str = "desc",
        page: int = 0,
        page_size: int = 100,
        search: Optional[str] = None,
        search_columns: Optional[List[str]] = None,
        custom_filters: Optional[Dict[str, BaseFilter]] = None,
        columns: Optional[List[str]] = None,
    ) -> Tuple[List[Any], int]:
        """
        Generic list method for filtered, sorted, and paginated results.
        If columns is specified, returns a list of tuples (one per row),
        otherwise returns model instances.
        """
        data_model = SQLAInterface(cls.model_cls, db.session)

        column_attrs = []
        relationship_loads = []
        if columns is None:
            columns = []
        for name in columns:
            attr = getattr(cls.model_cls, name, None)
            if attr is None:
                continue
            prop = getattr(attr, "property", None)
            if isinstance(prop, ColumnProperty):
                column_attrs.append(attr)
            elif isinstance(prop, RelationshipProperty):
                relationship_loads.append(joinedload(attr))
            # Ignore properties and other non-queryable attributes

        if relationship_loads:
            # If any relationships are requested, query the full model
            # but don't add the joins yet - we'll add them after counting
            query = data_model.session.query(cls.model_cls)
        elif column_attrs:
            # Only columns requested
            query = data_model.session.query(*column_attrs)
        else:
            # Fallback: query the full model
            query = data_model.session.query(cls.model_cls)
        query = cls._apply_base_filter(query, data_model=data_model)
        if search and search_columns:
            search_filters = []
            for column_name in search_columns:
                if hasattr(cls.model_cls, column_name):
                    column = getattr(cls.model_cls, column_name)
                    search_filters.append(cast(column, Text).ilike(f"%{search}%"))
            if search_filters:
                query = query.filter(or_(*search_filters))
        if custom_filters:
            for filter_class in custom_filters.values():
                query = filter_class.apply(query, None)
        if column_operators:
            query = cls.apply_column_operators(query, column_operators)

        # Count before adding relationship joins to avoid inflated counts
        # with one-to-many or many-to-many relationships
        total_count = query.count()

        # Add relationship joins after counting
        if relationship_loads:
            for loader in relationship_loads:
                query = query.options(loader)

        if hasattr(cls.model_cls, order_column):
            column = getattr(cls.model_cls, order_column)
            if order_direction.lower() == "desc":
                query = query.order_by(desc(column))
            else:
                query = query.order_by(asc(column))
        page = page
        page_size = max(page_size, 1)
        query = query.offset(page * page_size).limit(page_size)
        items = query.all()
        # If columns are specified, SQLAlchemy returns Row objects (not tuples or
        # model instances)
        return items, total_count

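A usage sketch of the new list() API (illustrative only; the parameter values are assumptions):

# Illustrative sketch, not part of this commit: filtered, searched, ordered,
# paginated listing. Returns (items, total_count); the count is taken before
# relationship joins are added, so it is not inflated by one-to-many joins.
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.daos.dashboard import DashboardDAO

dashboards, total = DashboardDAO.list(
    column_operators=[
        ColumnOperator(col="published", opr=ColumnOperatorEnum.eq, value=True),
    ],
    search="sales",
    search_columns=["dashboard_title"],  # cast to TEXT and matched with ILIKE
    order_column="changed_on",
    order_direction="desc",
    page=0,
    page_size=20,
)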
    @classmethod
    def count(
        cls,
        column_operators: Optional[List[ColumnOperator]] = None,
        skip_base_filter: bool = False,
    ) -> int:
        """
        Count the number of records for the model, optionally filtered by column
        operators.
        """
        query = cls._build_query(
            column_operators=column_operators, skip_base_filter=skip_base_filter
        )
        return query.count()

@@ -18,7 +18,7 @@ from __future__ import annotations

import logging
from datetime import datetime
from typing import TYPE_CHECKING
from typing import Dict, List, TYPE_CHECKING

from flask_appbuilder.models.sqla.interface import SQLAInterface

@@ -35,10 +35,23 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

# Custom filterable fields for charts
CHART_CUSTOM_FIELDS = {
    "viz_type": ["eq", "in_", "like"],
    "datasource_name": ["eq", "in_", "like"],
}


class ChartDAO(BaseDAO[Slice]):
    base_filter = ChartFilter

    @classmethod
    def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
        filterable = super().get_filterable_columns_and_operators()
        # Add custom fields for charts
        filterable.update(CHART_CUSTOM_FIELDS)
        return filterable

    @staticmethod
    def get_by_id_or_uuid(id_or_uuid: str) -> Slice:
        query = db.session.query(Slice).filter(id_or_uuid_filter(id_or_uuid))

@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any
from typing import Any, Dict, List

from flask import g
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -45,10 +45,24 @@ from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes

logger = logging.getLogger(__name__)

# Custom filterable fields for dashboards
DASHBOARD_CUSTOM_FIELDS = {
    "tags": ["eq", "in_", "like"],
    "owners": ["eq", "in_"],
    "published": ["eq"],
}


class DashboardDAO(BaseDAO[Dashboard]):
    base_filter = DashboardAccessFilter

    @classmethod
    def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
        filterable = super().get_filterable_columns_and_operators()
        # Add custom fields for dashboards
        filterable.update(DASHBOARD_CUSTOM_FIELDS)
        return filterable

    @classmethod
    def get_by_id_or_slug(cls, id_or_slug: int | str) -> Dashboard:
        if is_uuid(id_or_slug):

@@ -18,7 +18,7 @@ from __future__ import annotations

import logging
from datetime import datetime
from typing import Any
from typing import Any, Dict, List

import dateutil.parser
from sqlalchemy.exc import SQLAlchemyError
@@ -35,8 +35,18 @@ from superset.views.base import DatasourceFilter

logger = logging.getLogger(__name__)

# Custom filterable fields for datasets
DATASET_CUSTOM_FIELDS: dict[str, list[str]] = {}


class DatasetDAO(BaseDAO[SqlaTable]):
    """
    DAO for datasets. Supports filtering on model fields, hybrid properties, and custom
    fields:
    - tags: list of tags (eq, in_, like)
    - owner: user id (eq, in_)
    """

    base_filter = DatasourceFilter

    @staticmethod
@@ -351,6 +361,13 @@ class DatasetDAO(BaseDAO[SqlaTable]):
            .one_or_none()
        )

    @classmethod
    def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
        filterable = super().get_filterable_columns_and_operators()
        # Add custom fields
        filterable.update(DATASET_CUSTOM_FIELDS)
        return filterable


class DatasetColumnDAO(BaseDAO[TableColumn]):
    pass

1613  tests/integration_tests/dao/base_dao_test.py  (new file; diff suppressed because it is too large)
 143  tests/integration_tests/dao/conftest.py  (new file)
@@ -0,0 +1,143 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Fixtures for DAO integration tests.

This module provides fixtures that replicate the unit test behavior by using
an in-memory SQLite database for each test to ensure data isolation and avoid
conflicts between test runs.

Key features:
- In-memory SQLite database created per test
- Proper Flask-SQLAlchemy session patching
- Security manager session handling
- Automatic cleanup after each test
"""

from typing import Generator
from unittest.mock import patch

import pytest
from flask import Flask
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from superset.extensions import db
from tests.integration_tests.test_app import app as superset_app


@pytest.fixture(scope="module", autouse=True)
def setup_sample_data() -> None:
    """
    Override parent conftest setup_sample_data to prevent loading sample data.

    This prevents the parent conftest from loading CSS templates and other
    sample data that could interfere with DAO integration tests.
    """
    pass


@pytest.fixture
def app() -> Flask:
    """Get the Superset Flask application instance."""
    return superset_app


@pytest.fixture
def app_context(app: Flask) -> Generator[Session, None, None]:
    """
    Create an in-memory SQLite database for each test.

    This fixture replicates the unit test behavior by providing a fresh
    in-memory database for each test, ensuring complete data isolation
    and avoiding conflicts between test runs.

    Args:
        app: Flask application instance

    Yields:
        Session: SQLAlchemy session connected to in-memory database
    """
    # Create in-memory SQLite engine with StaticPool to avoid connection issues
    engine = create_engine(
        "sqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Create session bound to in-memory database
    session_factory = sessionmaker(bind=engine)
    session = session_factory()

    # Make session compatible with Flask-SQLAlchemy expectations
    session.remove = lambda: None
    session.get_bind = lambda *args, **kwargs: engine

    with app.app_context():
        # Patch db.session to use our in-memory session
        with patch.object(db, "session", session):
            # Import models to ensure they're registered
            from flask_appbuilder.security.sqla.models import User as FABUser

            # Create all tables in the in-memory database
            # Flask-AppBuilder models use a different metadata object
            # We need to create tables from both metadata objects

            # First create Flask-AppBuilder tables (User, Role, etc.)
            FABUser.metadata.create_all(engine)

            # Then create Superset-specific tables
            db.metadata.create_all(engine)

            try:
                yield session
            finally:
                # Clean up: rollback any pending transactions
                session.rollback()
                session.close()
                engine.dispose()


@pytest.fixture
def user_with_data(app_context: Session) -> Session:
    """
    Create a test user in the database.

    Some DAO tests expect a user with specific attributes to exist.
    This fixture creates that user and returns the database session.

    Args:
        app_context: Database session from app_context fixture

    Returns:
        Session: The same database session with test user created
    """
    # Create test user with expected attributes
    user = User(
        username="testuser",
        first_name="Test",
        last_name="User",
        email="testuser@example.com",
        active=True,
    )
    db.session.add(user)
    db.session.commit()

    return app_context
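An illustrative sketch (not part of the new conftest) of a test that consumes these fixtures; the test name is hypothetical:

# Illustrative sketch, not part of this commit: app_context patches db.session
# to an in-memory SQLite session, and user_with_data adds a "testuser" row
# before the test body runs.
from flask_appbuilder.security.sqla.models import User

from superset.extensions import db


def test_testuser_exists(user_with_data) -> None:  # hypothetical test
    user = db.session.query(User).filter_by(username="testuser").one_or_none()
    assert user is not None
    assert user.email == "testuser@example.com"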
@@ -14,12 +14,19 @@
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Unit tests for BaseDAO functionality using mocks and no database operations.
"""

from unittest.mock import Mock, patch

import pytest
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.declarative import declarative_base

from superset.daos.base import BaseDAO
from superset.daos.base import BaseDAO, ColumnOperatorEnum
from superset.daos.exceptions import DAOFindFailedError


@@ -37,6 +44,107 @@ class TestDAOWithNoneModel(BaseDAO[MockModel]):
    model_cls = None


# =============================================================================
# Unit Tests - These tests use mocks and don't touch the database
# =============================================================================


def test_column_operator_enum_apply_method() -> None:  # noqa: C901
    """
    Test that the apply method works correctly for each operator.
    This verifies the actual SQL generation for each operator.
    """
    Base_test = declarative_base()  # noqa: N806

    class TestModel(Base_test):  # type: ignore
        __tablename__ = "test_model"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        age = Column(Integer)
        active = Column(Boolean)

    # Test each operator's apply method
    test_cases = [
        # (operator, column, value, expected_sql_fragment)
        (ColumnOperatorEnum.eq, TestModel.name, "test", "test_model.name = 'test'"),
        (ColumnOperatorEnum.ne, TestModel.name, "test", "test_model.name != 'test'"),
        (ColumnOperatorEnum.sw, TestModel.name, "test", "test_model.name LIKE 'test%'"),
        (ColumnOperatorEnum.ew, TestModel.name, "test", "test_model.name LIKE '%test'"),
        (ColumnOperatorEnum.in_, TestModel.id, [1, 2, 3], "test_model.id IN (1, 2, 3)"),
        (
            ColumnOperatorEnum.nin,
            TestModel.id,
            [1, 2, 3],
            "test_model.id NOT IN (1, 2, 3)",
        ),
        (ColumnOperatorEnum.gt, TestModel.age, 25, "test_model.age > 25"),
        (ColumnOperatorEnum.gte, TestModel.age, 25, "test_model.age >= 25"),
        (ColumnOperatorEnum.lt, TestModel.age, 25, "test_model.age < 25"),
        (ColumnOperatorEnum.lte, TestModel.age, 25, "test_model.age <= 25"),
        (
            ColumnOperatorEnum.like,
            TestModel.name,
            "test",
            "test_model.name LIKE '%test%'",
        ),
        (
            ColumnOperatorEnum.ilike,
            TestModel.name,
            "test",
            "lower(test_model.name) LIKE lower('%test%')",
        ),
        (ColumnOperatorEnum.is_null, TestModel.name, None, "test_model.name IS NULL"),
        (
            ColumnOperatorEnum.is_not_null,
            TestModel.name,
            None,
            "test_model.name IS NOT NULL",
        ),
    ]

    for operator, column, value, expected_sql_fragment in test_cases:
        # Apply the operator
        result = operator.apply(column, value)

        # Convert to string to check SQL generation
        sql_str = str(result.compile(compile_kwargs={"literal_binds": True}))

        # Normalize whitespace for comparison
        normalized_sql = " ".join(sql_str.split())
        normalized_expected = " ".join(expected_sql_fragment.split())

        # Assert the exact SQL fragment is present
        assert normalized_expected in normalized_sql, (
            f"Expected SQL fragment '{expected_sql_fragment}' not found in generated "
            f"SQL: '{sql_str}' "
            f"for operator {operator.name}"
        )

    # Test that all operators are covered
    all_operators = set(ColumnOperatorEnum)
    tested_operators = {
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.sw,
        ColumnOperatorEnum.ew,
        ColumnOperatorEnum.in_,
        ColumnOperatorEnum.nin,
        ColumnOperatorEnum.gt,
        ColumnOperatorEnum.gte,
        ColumnOperatorEnum.lt,
        ColumnOperatorEnum.lte,
        ColumnOperatorEnum.like,
        ColumnOperatorEnum.ilike,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    }

    # Ensure we've tested all operators
    assert tested_operators == all_operators, (
        f"Missing operators: {all_operators - tested_operators}"
    )


def test_find_by_ids_sqlalchemy_error_with_model_cls():
    """Test SQLAlchemyError in find_by_ids shows proper model name
    when model_cls is set"""
@@ -137,6 +245,6 @@ def test_find_by_ids_none_id_column():
    with patch("superset.daos.base.getattr") as mock_getattr:
        mock_getattr.return_value = None

        results = TestDAO.find_by_ids([1, 2])
        results = TestDAO.find_by_ids([1, 2, 3])

        assert results == []