feat: Add BaseDAO improvements and test reorganization (#35018)

Co-authored-by: bito-code-review[bot] <188872107+bito-code-review[bot]@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2025-09-16 21:15:16 -04:00
committed by GitHub
parent 05c6a1bf20
commit dced2f8564
10 changed files with 2415 additions and 38 deletions

View File

@@ -84,6 +84,7 @@ dependencies = [
"pgsanity",
"Pillow>=11.0.0, <12",
"polyline>=2.0.0, <3.0",
"pydantic>=2.8.0",
"pyparsing>=3.0.6, <4",
"python-dateutil",
"python-dotenv", # optional dependencies for Flask but required for Superset, see https://flask.palletsprojects.com/en/stable/installation/#optional-dependencies

View File

@@ -6,6 +6,8 @@ alembic==1.15.2
# via flask-migrate
amqp==5.3.1
# via kombu
annotated-types==0.7.0
# via pydantic
apispec==6.6.1
# via
# -r requirements/base.in
@@ -115,7 +117,9 @@ flask==2.3.3
# flask-sqlalchemy
# flask-wtf
flask-appbuilder==5.0.0
# via apache-superset (pyproject.toml)
# via
# apache-superset (pyproject.toml)
# apache-superset-core
flask-babel==3.1.0
# via flask-appbuilder
flask-caching==2.3.1
@@ -294,6 +298,10 @@ pyasn1-modules==0.4.2
# via google-auth
pycparser==2.22
# via cffi
pydantic==2.11.7
# via apache-superset (pyproject.toml)
pydantic-core==2.33.2
# via pydantic
pygments==2.19.1
# via rich
pyjwt==2.10.1
@@ -404,10 +412,15 @@ typing-extensions==4.14.0
# alembic
# cattrs
# limits
# pydantic
# pydantic-core
# pyopenssl
# referencing
# selenium
# shillelagh
# typing-inspection
typing-inspection==0.4.1
# via pydantic
tzdata==2025.2
# via
# kombu

View File

@@ -18,6 +18,10 @@ amqp==5.3.1
# via
# -c requirements/base-constraint.txt
# kombu
annotated-types==0.7.0
# via
# -c requirements/base-constraint.txt
# pydantic
apispec==6.6.1
# via
# -c requirements/base-constraint.txt
@@ -212,6 +216,7 @@ flask-appbuilder==5.0.0
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
flask-babel==3.1.0
# via
# -c requirements/base-constraint.txt
@@ -631,6 +636,14 @@ pycparser==2.22
# via
# -c requirements/base-constraint.txt
# cffi
pydantic==2.11.7
# via
# -c requirements/base-constraint.txt
# apache-superset
pydantic-core==2.33.2
# via
# -c requirements/base-constraint.txt
# pydantic
pydata-google-auth==1.9.0
# via pandas-gbq
pydruid==0.6.9
@@ -874,10 +887,17 @@ typing-extensions==4.14.0
# apache-superset
# cattrs
# limits
# pydantic
# pydantic-core
# pyopenssl
# referencing
# selenium
# shillelagh
# typing-inspection
typing-inspection==0.4.1
# via
# -c requirements/base-constraint.txt
# pydantic
tzdata==2025.2
# via
# -c requirements/base-constraint.txt

View File

@@ -16,22 +16,141 @@
# under the License.
from __future__ import annotations
from typing import Any, Generic, get_args, TypeVar
import logging
import uuid as uuid_lib
from enum import Enum
from typing import (
Any,
Dict,
Generic,
get_args,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
import sqlalchemy as sa
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_sqlalchemy import BaseQuery
from pydantic import BaseModel, Field
from sqlalchemy import asc, cast, desc, or_, Text
from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import ColumnProperty, joinedload, RelationshipProperty
from superset.daos.exceptions import (
DAOFindFailedError,
)
from superset.extensions import db
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Model)
class ColumnOperatorEnum(str, Enum):
    """Operators supported for structured column filtering in BaseDAO."""

    eq = "eq"
    ne = "ne"
    sw = "sw"
    ew = "ew"
    in_ = "in"
    nin = "nin"
    gt = "gt"
    gte = "gte"
    lt = "lt"
    lte = "lte"
    like = "like"
    ilike = "ilike"
    is_null = "is_null"
    is_not_null = "is_not_null"

    def apply(self, column: Any, value: Any) -> Any:
        """Return the SQLAlchemy criterion for this operator on `column`."""
        try:
            handler = operator_map[self]
        except KeyError:
            raise ValueError("Unsupported operator: %s" % self)
        return handler(column, value)


def _ensure_list(value: Any) -> Any:
    # Membership operators need a collection; wrap bare scalars.
    return value if isinstance(value, (list, tuple)) else [value]


# Define operator_map as a module-level dict after the enum is defined
operator_map: Dict[ColumnOperatorEnum, Any] = {
    ColumnOperatorEnum.eq: lambda col, val: col == val,
    ColumnOperatorEnum.ne: lambda col, val: col != val,
    ColumnOperatorEnum.sw: lambda col, val: col.like(f"{val}%"),
    ColumnOperatorEnum.ew: lambda col, val: col.like(f"%{val}"),
    ColumnOperatorEnum.in_: lambda col, val: col.in_(_ensure_list(val)),
    ColumnOperatorEnum.nin: lambda col, val: ~col.in_(_ensure_list(val)),
    ColumnOperatorEnum.gt: lambda col, val: col > val,
    ColumnOperatorEnum.gte: lambda col, val: col >= val,
    ColumnOperatorEnum.lt: lambda col, val: col < val,
    ColumnOperatorEnum.lte: lambda col, val: col <= val,
    ColumnOperatorEnum.like: lambda col, val: col.like(f"%{val}%"),
    ColumnOperatorEnum.ilike: lambda col, val: col.ilike(f"%{val}%"),
    # Null checks take no operand; the value argument is ignored.
    ColumnOperatorEnum.is_null: lambda col, _: col.is_(None),
    ColumnOperatorEnum.is_not_null: lambda col, _: col.isnot(None),
}
# Map SQLAlchemy types to supported operators
TYPE_OPERATOR_MAP = {
    # Text columns: equality, prefix/suffix (sw/ew), membership,
    # LIKE/ILIKE pattern matching, and null checks.
    "string": [
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.sw,
        ColumnOperatorEnum.ew,
        ColumnOperatorEnum.in_,
        ColumnOperatorEnum.nin,
        ColumnOperatorEnum.like,
        ColumnOperatorEnum.ilike,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    ],
    # Booleans: only equality and null checks are meaningful.
    "boolean": [
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    ],
    # Numeric columns add range comparisons (gt/gte/lt/lte).
    "number": [
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.gt,
        ColumnOperatorEnum.gte,
        ColumnOperatorEnum.lt,
        ColumnOperatorEnum.lte,
        ColumnOperatorEnum.in_,
        ColumnOperatorEnum.nin,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    ],
    # Datetimes support the same operator set as numbers.
    "datetime": [
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.gt,
        ColumnOperatorEnum.gte,
        ColumnOperatorEnum.lt,
        ColumnOperatorEnum.lte,
        ColumnOperatorEnum.in_,
        ColumnOperatorEnum.nin,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    ],
}
class ColumnOperator(BaseModel):
    # Declarative filter spec: apply operator `opr` to model column `col`
    # with operand `value` (see BaseDAO.apply_column_operators).
    col: str = Field(..., description="Column name to filter on")
    opr: ColumnOperatorEnum = Field(..., description="Operator")
    # Optional operand; ignored by is_null/is_not_null operators.
    value: Any = Field(None, description="Value for the filter")
class BaseDAO(Generic[T]):
"""
Base DAO, implement base CRUD sqlalchemy operations
@@ -83,80 +202,170 @@ class BaseDAO(Generic[T]):
return None
@classmethod
def find_by_id(
cls,
model_id: str | int,
skip_base_filter: bool = False,
) -> T | None:
def _apply_base_filter(
cls, query: Any, skip_base_filter: bool = False, data_model: Any = None
) -> Any:
"""
Find a model by id, if defined applies `base_filter`
Apply the base_filter to the query if it exists and skip_base_filter is False.
"""
query = db.session.query(cls.model_cls)
if cls.base_filter and not skip_base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
if data_model is None:
data_model = SQLAInterface(cls.model_cls, db.session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
id_column = getattr(cls.model_cls, cls.id_column_name)
return query
@classmethod
def _convert_value_for_column(cls, column: Any, value: Any) -> Any:
    """
    Convert a value to the appropriate type for a given SQLAlchemy column.

    Currently only coerces string values to ``uuid.UUID`` for UUID-typed
    columns; all other values are returned unchanged.

    Args:
        column: SQLAlchemy column object
        value: Value to convert

    Returns:
        Converted value or None if conversion fails
    """
    # TypeEngine.python_type is a property that may raise
    # NotImplementedError for types with no natural Python equivalent;
    # hasattr() only swallows AttributeError, so guard explicitly.
    try:
        python_type = column.type.python_type
    except (AttributeError, NotImplementedError):
        return value
    if python_type == uuid_lib.UUID and isinstance(value, str):
        try:
            return uuid_lib.UUID(value)
        except (ValueError, AttributeError):
            # Malformed UUID string: signal failure to the caller.
            return None
    return value
@classmethod
def _find_by_column(
cls,
column_name: str,
value: str | int,
skip_base_filter: bool = False,
) -> T | None:
"""
Private method to find a model by any column value.
Args:
column_name: Name of the column to search by
value: Value to search for
skip_base_filter: Whether to skip base filtering
Returns:
Model instance or None if not found
"""
query = db.session.query(cls.model_cls)
query = cls._apply_base_filter(query, skip_base_filter)
if not hasattr(cls.model_cls, column_name):
return None
column = getattr(cls.model_cls, column_name)
converted_value = cls._convert_value_for_column(column, value)
if converted_value is None:
return None
try:
return query.filter(id_column == model_id).one_or_none()
return query.filter(column == converted_value).one_or_none()
except StatementError:
# can happen if int is passed instead of a string or similar
return None
@classmethod
def find_by_id(
    cls,
    model_id: str | int,
    skip_base_filter: bool = False,
    id_column: str | None = None,
) -> T | None:
    """
    Look up a single model instance by an ID value.

    Args:
        model_id: ID value to search for
        skip_base_filter: Whether to skip base filtering
        id_column: Column name to use (defaults to cls.id_column_name)

    Returns:
        Model instance or None if not found
    """
    # Fall back to the DAO's configured ID column when none is given.
    lookup_column = id_column or cls.id_column_name
    return cls._find_by_column(lookup_column, model_id, skip_base_filter)
@classmethod
def find_by_ids(
cls,
model_ids: list[str] | list[int],
model_ids: Sequence[str | int],
skip_base_filter: bool = False,
id_column: str | None = None,
) -> list[T]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
:param model_ids: List of IDs to find
:param skip_base_filter: If true, skip applying the base filter
:param id_column: Optional column name to use for ID lookup
(defaults to id_column_name)
"""
id_col = getattr(cls.model_cls, cls.id_column_name, None)
column = id_column or cls.id_column_name
id_col = getattr(cls.model_cls, column, None)
if id_col is None or not model_ids:
return []
query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
if cls.base_filter and not skip_base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
# Convert IDs to appropriate types based on column type
converted_ids: list[str | int | uuid_lib.UUID] = []
for id_val in model_ids:
converted_value = cls._convert_value_for_column(id_col, id_val)
if converted_value is not None:
# Only add successfully converted values
converted_ids.append(converted_value)
else:
# Log warning for failed conversions
logger.warning(
"Failed to convert ID '%s' for column %s.%s",
id_val,
cls.model_cls.__name__ if cls.model_cls else "Unknown",
column,
)
# If no valid IDs after conversion, return empty list
if not converted_ids:
return []
query = db.session.query(cls.model_cls).filter(id_col.in_(converted_ids))
query = cls._apply_base_filter(query, skip_base_filter)
try:
results = query.all()
except SQLAlchemyError as ex:
model_name = cls.model_cls.__name__ if cls.model_cls else "Unknown"
raise DAOFindFailedError(
f"Failed to find {model_name} with ids: {model_ids}"
"Failed to find %s with ids: %s" % (model_name, model_ids)
) from ex
return results
@classmethod
def find_all(cls) -> list[T]:
def find_all(cls, skip_base_filter: bool = False) -> list[T]:
"""
Get all that fit the `base_filter`
"""
query = db.session.query(cls.model_cls)
if cls.base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
query = cls._apply_base_filter(query, skip_base_filter)
return query.all()
@classmethod
def find_one_or_none(cls, **filter_by: Any) -> T | None:
def find_one_or_none(
cls, skip_base_filter: bool = False, **filter_by: Any
) -> T | None:
"""
Get the first that fit the `base_filter`
"""
query = db.session.query(cls.model_cls)
if cls.base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
query = cls._apply_base_filter(query, skip_base_filter)
return query.filter_by(**filter_by).one_or_none()
@classmethod
@@ -251,3 +460,229 @@ class BaseDAO(Generic[T]):
cls.id_column_name, data_model
).apply(query, None)
return query.filter_by(**filter_by).all()
@classmethod
def apply_column_operators(
    cls, query: Any, column_operators: Optional[List[ColumnOperator]] = None
) -> Any:
    """
    Apply column operators (list of ColumnOperator) to the query using
    ColumnOperatorEnum logic.

    Entries that are not ColumnOperator instances are silently skipped.

    Raises:
        ValueError: If a filter references a non-existent column.
    """
    if not column_operators:
        return query
    for column_operator in column_operators:
        if not isinstance(column_operator, ColumnOperator):
            continue
        col, opr, value = (
            column_operator.col,
            column_operator.opr,
            column_operator.value,
        )
        if not col or not hasattr(cls.model_cls, col):
            model_name = cls.model_cls.__name__ if cls.model_cls else "Unknown"
            # Use the module-level logger (not the root logger) so messages
            # carry this module's name and honor its configuration.
            logger.error(
                "Invalid filter: column '%s' does not exist on %s", col, model_name
            )
            raise ValueError(
                "Invalid filter: column '%s' does not exist on %s"
                % (col, model_name)
            )
        column = getattr(cls.model_cls, col)
        try:
            # Always use ColumnOperatorEnum's apply method
            operator_enum = ColumnOperatorEnum(opr)
            query = query.filter(operator_enum.apply(column, value))
        except Exception as e:
            logger.error("Error applying filter on column '%s': %s", col, e)
            raise
    return query
@classmethod
def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
    """
    Returns a dict mapping filterable columns (including hybrid/computed fields if
    present) to their supported operators. Used by MCP tools to dynamically expose
    filter options. Custom fields supported by the DAO but not present on the model
    should be documented here.
    """
    mapper = inspect(cls.model_cls)
    columns = {c.key: c for c in mapper.columns}
    # Add hybrid properties
    # NOTE(review): vars() only sees attributes defined directly on the class;
    # hybrid properties inherited from a base class are missed — confirm.
    hybrids = {
        name: attr
        for name, attr in vars(cls.model_cls).items()
        if isinstance(attr, hybrid_property)
    }
    # You may add custom fields here, e.g.:
    # custom_fields = {"tags": ["eq", "in_", "like"], ...}
    custom_fields: Dict[str, List[str]] = {}
    filterable: Dict[str, Any] = {}
    # Dispatch on the SQLAlchemy column type to pick the operator set.
    for name, col in columns.items():
        if isinstance(col.type, (sa.String, sa.Text)):
            filterable[name] = TYPE_OPERATOR_MAP["string"]
        elif isinstance(col.type, (sa.Boolean,)):
            filterable[name] = TYPE_OPERATOR_MAP["boolean"]
        elif isinstance(col.type, (sa.Integer, sa.Float, sa.Numeric)):
            filterable[name] = TYPE_OPERATOR_MAP["number"]
        elif isinstance(col.type, (sa.DateTime, sa.Date, sa.Time)):
            filterable[name] = TYPE_OPERATOR_MAP["datetime"]
        else:
            # Fallback to eq/ne/null
            filterable[name] = [
                ColumnOperatorEnum.eq,
                ColumnOperatorEnum.ne,
                ColumnOperatorEnum.is_null,
                ColumnOperatorEnum.is_not_null,
            ]
    # Add hybrid properties as string fields by default
    for name in hybrids:
        filterable[name] = TYPE_OPERATOR_MAP["string"]
    # Add custom fields
    filterable.update(custom_fields)
    # Convert enum values to strings for the return type
    result: Dict[str, List[str]] = {}
    for key, operators in filterable.items():
        if isinstance(operators, list):
            # Convert enums to strings
            result[key] = [
                op.value if isinstance(op, ColumnOperatorEnum) else op
                for op in operators
            ]
        else:
            result[key] = operators
    return result
@classmethod
def _build_query(
    cls,
    column_operators: Optional[List[ColumnOperator]] = None,
    search: Optional[str] = None,
    search_columns: Optional[List[str]] = None,
    custom_filters: Optional[Dict[str, BaseFilter]] = None,
    skip_base_filter: bool = False,
    data_model: Optional[SQLAInterface] = None,
) -> Any:
    """
    Build a SQLAlchemy query combining the base filter, free-text search,
    custom filters, and column operators.
    """
    interface = (
        data_model
        if data_model is not None
        else SQLAInterface(cls.model_cls, db.session)
    )
    query = interface.session.query(cls.model_cls)
    query = cls._apply_base_filter(
        query, skip_base_filter=skip_base_filter, data_model=interface
    )
    # Free-text search: OR together an ILIKE clause per existing column.
    if search and search_columns:
        clauses = [
            cast(getattr(cls.model_cls, name), Text).ilike(f"%{search}%")
            for name in search_columns
            if hasattr(cls.model_cls, name)
        ]
        if clauses:
            query = query.filter(or_(*clauses))
    if custom_filters:
        for custom_filter in custom_filters.values():
            query = custom_filter.apply(query, None)
    if column_operators:
        query = cls.apply_column_operators(query, column_operators)
    return query
@classmethod
def list(  # noqa: C901
    cls,
    column_operators: Optional[List[ColumnOperator]] = None,
    order_column: str = "changed_on",
    order_direction: str = "desc",
    page: int = 0,
    page_size: int = 100,
    search: Optional[str] = None,
    search_columns: Optional[List[str]] = None,
    custom_filters: Optional[Dict[str, BaseFilter]] = None,
    columns: Optional[List[str]] = None,
) -> Tuple[List[Any], int]:
    """
    Generic list method for filtered, sorted, and paginated results.

    If columns is specified, returns a list of Row objects (one per row),
    otherwise returns model instances.

    Args:
        column_operators: Structured column filters to apply.
        order_column: Attribute to sort by (skipped if not on the model).
        order_direction: "asc" or "desc" (default "desc").
        page: Zero-based page index; negative values are treated as 0.
        page_size: Rows per page (minimum 1).
        search: Free-text term matched against search_columns via ILIKE.
        search_columns: Column names to include in the free-text search.
        custom_filters: FAB BaseFilter instances applied to the query.
        columns: Column/relationship names to fetch; relationships are
            eager-loaded with joinedload.

    Returns:
        Tuple of (items, total count before pagination).
    """
    data_model = SQLAInterface(cls.model_cls, db.session)
    column_attrs = []
    relationship_loads = []
    if columns is None:
        columns = []
    for name in columns:
        attr = getattr(cls.model_cls, name, None)
        if attr is None:
            continue
        prop = getattr(attr, "property", None)
        if isinstance(prop, ColumnProperty):
            column_attrs.append(attr)
        elif isinstance(prop, RelationshipProperty):
            relationship_loads.append(joinedload(attr))
        # Ignore properties and other non-queryable attributes
    if relationship_loads:
        # If any relationships are requested, query the full model
        # but don't add the joins yet - we'll add them after counting
        query = data_model.session.query(cls.model_cls)
    elif column_attrs:
        # Only columns requested
        query = data_model.session.query(*column_attrs)
    else:
        # Fallback: query the full model
        query = data_model.session.query(cls.model_cls)
    query = cls._apply_base_filter(query, data_model=data_model)
    if search and search_columns:
        search_filters = []
        for column_name in search_columns:
            if hasattr(cls.model_cls, column_name):
                column = getattr(cls.model_cls, column_name)
                search_filters.append(cast(column, Text).ilike(f"%{search}%"))
        if search_filters:
            query = query.filter(or_(*search_filters))
    if custom_filters:
        for filter_class in custom_filters.values():
            query = filter_class.apply(query, None)
    if column_operators:
        query = cls.apply_column_operators(query, column_operators)
    # Count before adding relationship joins to avoid inflated counts
    # with one-to-many or many-to-many relationships
    total_count = query.count()
    # Add relationship joins after counting
    if relationship_loads:
        for loader in relationship_loads:
            query = query.options(loader)
    if hasattr(cls.model_cls, order_column):
        column = getattr(cls.model_cls, order_column)
        if order_direction.lower() == "desc":
            query = query.order_by(desc(column))
        else:
            query = query.order_by(asc(column))
    # NOTE(review): when order_column is absent, pagination proceeds without
    # a deterministic ordering, so page contents may vary between calls.
    # Clamp pagination inputs: a negative page would produce a negative
    # OFFSET (an error on most backends), and page_size must be >= 1.
    page = max(page, 0)
    page_size = max(page_size, 1)
    query = query.offset(page * page_size).limit(page_size)
    items = query.all()
    # If columns are specified, SQLAlchemy returns Row objects (not tuples or
    # model instances)
    return items, total_count
@classmethod
def count(
    cls,
    column_operators: Optional[List[ColumnOperator]] = None,
    skip_base_filter: bool = False,
) -> int:
    """
    Return the number of records for the model, optionally restricted by
    column operators and the DAO's base filter.
    """
    # Delegate query construction so counting and listing share filter logic.
    return cls._build_query(
        column_operators=column_operators,
        skip_base_filter=skip_base_filter,
    ).count()

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
import logging
from datetime import datetime
from typing import TYPE_CHECKING
from typing import Dict, List, TYPE_CHECKING
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -35,10 +35,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Custom filterable fields for charts
# Advertised via ChartDAO.get_filterable_columns_and_operators.
# NOTE(review): "in_" does not match ColumnOperatorEnum.in_.value ("in");
# confirm which token clients are expected to send.
CHART_CUSTOM_FIELDS = {
    "viz_type": ["eq", "in_", "like"],
    "datasource_name": ["eq", "in_", "like"],
}
class ChartDAO(BaseDAO[Slice]):
base_filter = ChartFilter
@classmethod
def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
    """Return the base filterable mapping extended with chart fields."""
    mapping = super().get_filterable_columns_and_operators()
    mapping.update(CHART_CUSTOM_FIELDS)
    return mapping
@staticmethod
def get_by_id_or_uuid(id_or_uuid: str) -> Slice:
query = db.session.query(Slice).filter(id_or_uuid_filter(id_or_uuid))

View File

@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any
from typing import Any, Dict, List
from flask import g
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -45,10 +45,24 @@ from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
logger = logging.getLogger(__name__)
# Custom filterable fields for dashboards
# Advertised via DashboardDAO.get_filterable_columns_and_operators.
# NOTE(review): "in_" does not match ColumnOperatorEnum.in_.value ("in");
# confirm which token clients are expected to send.
DASHBOARD_CUSTOM_FIELDS = {
    "tags": ["eq", "in_", "like"],
    "owners": ["eq", "in_"],
    "published": ["eq"],
}
class DashboardDAO(BaseDAO[Dashboard]):
base_filter = DashboardAccessFilter
@classmethod
def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
    """Return the base filterable mapping extended with dashboard fields."""
    mapping = super().get_filterable_columns_and_operators()
    mapping.update(DASHBOARD_CUSTOM_FIELDS)
    return mapping
@classmethod
def get_by_id_or_slug(cls, id_or_slug: int | str) -> Dashboard:
if is_uuid(id_or_slug):

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
import logging
from datetime import datetime
from typing import Any
from typing import Any, Dict, List
import dateutil.parser
from sqlalchemy.exc import SQLAlchemyError
@@ -35,8 +35,18 @@ from superset.views.base import DatasourceFilter
logger = logging.getLogger(__name__)
# Custom filterable fields for datasets
# NOTE(review): DatasetDAO's docstring advertises "tags" and "owner" custom
# fields, but none are registered here — confirm which is correct.
DATASET_CUSTOM_FIELDS: dict[str, list[str]] = {}
class DatasetDAO(BaseDAO[SqlaTable]):
"""
DAO for datasets. Supports filtering on model fields, hybrid properties, and custom
fields:
- tags: list of tags (eq, in_, like)
- owner: user id (eq, in_)
"""
base_filter = DatasourceFilter
@staticmethod
@@ -351,6 +361,13 @@ class DatasetDAO(BaseDAO[SqlaTable]):
.one_or_none()
)
@classmethod
def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
    """Return the base filterable mapping extended with dataset fields."""
    mapping = super().get_filterable_columns_and_operators()
    mapping.update(DATASET_CUSTOM_FIELDS)
    return mapping
class DatasetColumnDAO(BaseDAO[TableColumn]):
    # Plain BaseDAO specialization; inherits all generic CRUD behavior.
    pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,143 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Fixtures for DAO integration tests.
This module provides fixtures that replicate the unit test behavior by using
an in-memory SQLite database for each test to ensure data isolation and avoid
conflicts between test runs.
Key features:
- In-memory SQLite database created per test
- Proper Flask-SQLAlchemy session patching
- Security manager session handling
- Automatic cleanup after each test
"""
from typing import Generator
from unittest.mock import patch
import pytest
from flask import Flask
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from superset.extensions import db
from tests.integration_tests.test_app import app as superset_app
@pytest.fixture(scope="module", autouse=True)
def setup_sample_data() -> None:
    """
    Override parent conftest setup_sample_data to prevent loading sample data.

    This prevents the parent conftest from loading CSS templates and other
    sample data that could interfere with DAO integration tests.
    """
    # Intentionally a no-op: the module-scoped autouse fixture shadows the
    # parent conftest's fixture of the same name.
    pass
@pytest.fixture
def app() -> Flask:
    """Get the Superset Flask application instance."""
    # Reuses the shared test app; per-test DB isolation is handled elsewhere.
    return superset_app
@pytest.fixture
def app_context(app: Flask) -> Generator[Session, None, None]:
    """
    Create an in-memory SQLite database for each test.

    This fixture replicates the unit test behavior by providing a fresh
    in-memory database for each test, ensuring complete data isolation
    and avoiding conflicts between test runs.

    Args:
        app: Flask application instance

    Yields:
        Session: SQLAlchemy session connected to in-memory database
    """
    # Create in-memory SQLite engine with StaticPool to avoid connection issues
    engine = create_engine(
        "sqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    # Create session bound to in-memory database
    session_factory = sessionmaker(bind=engine)
    session = session_factory()
    # Make session compatible with Flask-SQLAlchemy expectations
    # NOTE(review): these monkeypatched shims assume code under test only
    # calls session.remove()/get_bind() in Flask-SQLAlchemy's scoped-session
    # style — confirm no test relies on real scoped-session semantics.
    session.remove = lambda: None
    session.get_bind = lambda *args, **kwargs: engine
    with app.app_context():
        # Patch db.session to use our in-memory session
        with patch.object(db, "session", session):
            # Import models to ensure they're registered
            from flask_appbuilder.security.sqla.models import User as FABUser

            # Create all tables in the in-memory database
            # Flask-AppBuilder models use a different metadata object
            # We need to create tables from both metadata objects
            # First create Flask-AppBuilder tables (User, Role, etc.)
            FABUser.metadata.create_all(engine)
            # Then create Superset-specific tables
            db.metadata.create_all(engine)
            try:
                yield session
            finally:
                # Clean up: rollback any pending transactions
                session.rollback()
                session.close()
                engine.dispose()
@pytest.fixture
def user_with_data(app_context: Session) -> Session:
    """
    Create a test user in the database.

    Some DAO tests expect a user with specific attributes to exist.
    This fixture creates that user and returns the database session.

    Args:
        app_context: Database session from app_context fixture

    Returns:
        Session: The same database session with test user created
    """
    # Create test user with expected attributes
    user = User(
        username="testuser",
        first_name="Test",
        last_name="User",
        email="testuser@example.com",
        active=True,
    )
    # db.session is patched to the in-memory session by app_context, so this
    # writes to the per-test database.
    db.session.add(user)
    db.session.commit()
    return app_context

View File

@@ -14,12 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for BaseDAO functionality using mocks and no database operations.
"""
from unittest.mock import Mock, patch
import pytest
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.declarative import declarative_base
from superset.daos.base import BaseDAO
from superset.daos.base import BaseDAO, ColumnOperatorEnum
from superset.daos.exceptions import DAOFindFailedError
@@ -37,6 +44,107 @@ class TestDAOWithNoneModel(BaseDAO[MockModel]):
model_cls = None
# =============================================================================
# Unit Tests - These tests use mocks and don't touch the database
# =============================================================================
def test_column_operator_enum_apply_method() -> None:  # noqa: C901
    """
    Test that the apply method works correctly for each operator.

    This verifies the actual SQL generation for each operator by compiling
    each criterion with literal binds and checking the SQL text.
    """
    Base_test = declarative_base()  # noqa: N806

    # Minimal model with one column per operator category (string/int/bool).
    class TestModel(Base_test):  # type: ignore
        __tablename__ = "test_model"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        age = Column(Integer)
        active = Column(Boolean)

    # Test each operator's apply method
    test_cases = [
        # (operator, column, value, expected_sql_fragment)
        (ColumnOperatorEnum.eq, TestModel.name, "test", "test_model.name = 'test'"),
        (ColumnOperatorEnum.ne, TestModel.name, "test", "test_model.name != 'test'"),
        (ColumnOperatorEnum.sw, TestModel.name, "test", "test_model.name LIKE 'test%'"),
        (ColumnOperatorEnum.ew, TestModel.name, "test", "test_model.name LIKE '%test'"),
        (ColumnOperatorEnum.in_, TestModel.id, [1, 2, 3], "test_model.id IN (1, 2, 3)"),
        (
            ColumnOperatorEnum.nin,
            TestModel.id,
            [1, 2, 3],
            "test_model.id NOT IN (1, 2, 3)",
        ),
        (ColumnOperatorEnum.gt, TestModel.age, 25, "test_model.age > 25"),
        (ColumnOperatorEnum.gte, TestModel.age, 25, "test_model.age >= 25"),
        (ColumnOperatorEnum.lt, TestModel.age, 25, "test_model.age < 25"),
        (ColumnOperatorEnum.lte, TestModel.age, 25, "test_model.age <= 25"),
        (
            ColumnOperatorEnum.like,
            TestModel.name,
            "test",
            "test_model.name LIKE '%test%'",
        ),
        (
            ColumnOperatorEnum.ilike,
            TestModel.name,
            "test",
            "lower(test_model.name) LIKE lower('%test%')",
        ),
        (ColumnOperatorEnum.is_null, TestModel.name, None, "test_model.name IS NULL"),
        (
            ColumnOperatorEnum.is_not_null,
            TestModel.name,
            None,
            "test_model.name IS NOT NULL",
        ),
    ]
    for operator, column, value, expected_sql_fragment in test_cases:
        # Apply the operator
        result = operator.apply(column, value)
        # Convert to string to check SQL generation
        # (literal_binds inlines parameter values into the SQL text)
        sql_str = str(result.compile(compile_kwargs={"literal_binds": True}))
        # Normalize whitespace for comparison
        normalized_sql = " ".join(sql_str.split())
        normalized_expected = " ".join(expected_sql_fragment.split())
        # Assert the exact SQL fragment is present
        assert normalized_expected in normalized_sql, (
            f"Expected SQL fragment '{expected_sql_fragment}' not found in generated "
            f"SQL: '{sql_str}' "
            f"for operator {operator.name}"
        )
    # Test that all operators are covered
    all_operators = set(ColumnOperatorEnum)
    tested_operators = {
        ColumnOperatorEnum.eq,
        ColumnOperatorEnum.ne,
        ColumnOperatorEnum.sw,
        ColumnOperatorEnum.ew,
        ColumnOperatorEnum.in_,
        ColumnOperatorEnum.nin,
        ColumnOperatorEnum.gt,
        ColumnOperatorEnum.gte,
        ColumnOperatorEnum.lt,
        ColumnOperatorEnum.lte,
        ColumnOperatorEnum.like,
        ColumnOperatorEnum.ilike,
        ColumnOperatorEnum.is_null,
        ColumnOperatorEnum.is_not_null,
    }
    # Ensure we've tested all operators
    assert tested_operators == all_operators, (
        f"Missing operators: {all_operators - tested_operators}"
    )
def test_find_by_ids_sqlalchemy_error_with_model_cls():
"""Test SQLAlchemyError in find_by_ids shows proper model name
when model_cls is set"""
@@ -137,6 +245,6 @@ def test_find_by_ids_none_id_column():
with patch("superset.daos.base.getattr") as mock_getattr:
mock_getattr.return_value = None
results = TestDAO.find_by_ids([1, 2])
results = TestDAO.find_by_ids([1, 2, 3])
assert results == []