mirror of
https://github.com/apache/superset.git
synced 2026-05-11 02:45:46 +00:00
434 lines
16 KiB
Python
434 lines
16 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, Generic, get_args, List, Optional, Tuple, TypeVar
|
|
|
|
from flask_appbuilder.models.filters import BaseFilter
|
|
from flask_appbuilder.models.sqla import Model
|
|
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
|
from sqlalchemy import asc, desc, or_, cast, Text
|
|
from sqlalchemy.exc import StatementError
|
|
from sqlalchemy.orm import joinedload, RelationshipProperty, ColumnProperty
|
|
from sqlalchemy.inspection import inspect
|
|
|
|
from superset.extensions import db
|
|
|
|
from enum import Enum
|
|
from pydantic import BaseModel, Field
|
|
import logging
|
|
|
|
# Type variable for the model class a DAO works with; bound to the Flask-AppBuilder
# SQLAlchemy ``Model`` so that ``BaseDAO[SomeModel]`` only accepts model classes.
T = TypeVar("T", bound=Model)
|
|
|
|
class ColumnOperatorEnum(str, Enum):
    """
    Filter operators that can be applied to a column expression.

    The value of each member is the operator name used in filter payloads. Note
    that the member ``in_`` has the value ``"in"`` (``in`` is a Python keyword), so
    ``ColumnOperatorEnum("in")`` resolves to ``ColumnOperatorEnum.in_``.
    """

    eq = "eq"
    ne = "ne"
    sw = "sw"
    ew = "ew"
    in_ = "in"
    nin = "nin"
    gt = "gt"
    gte = "gte"
    lt = "lt"
    lte = "lte"
    like = "like"
    ilike = "ilike"
    is_null = "is_null"
    is_not_null = "is_not_null"

    @classmethod
    def operator_map(cls) -> dict[ColumnOperatorEnum, Any]:
        """
        Return the mapping from each operator to a ``(column, value)`` callable.

        Every callable returns the result of applying the operator to ``column``;
        ``is_null`` and ``is_not_null`` ignore ``value``.
        """

        def as_collection(val: Any) -> Any:
            # Lists and tuples are passed through untouched. Sets are expanded into
            # a list of their elements; any other value (including a string) is
            # treated as a single candidate.
            if isinstance(val, (list, tuple)):
                return val
            if isinstance(val, (set, frozenset)):
                return [*val]
            return [val]

        return {
            cls.eq: lambda col, val: col == val,
            cls.ne: lambda col, val: col != val,
            # The value is interpolated as-is, so LIKE wildcards it contains
            # ("%", "_") keep their special meaning.
            cls.sw: lambda col, val: col.like(f"{val}%"),
            cls.ew: lambda col, val: col.like(f"%{val}"),
            cls.in_: lambda col, val: col.in_(as_collection(val)),
            cls.nin: lambda col, val: ~col.in_(as_collection(val)),
            cls.gt: lambda col, val: col > val,
            cls.gte: lambda col, val: col >= val,
            cls.lt: lambda col, val: col < val,
            cls.lte: lambda col, val: col <= val,
            cls.like: lambda col, val: col.like(val),
            cls.ilike: lambda col, val: col.ilike(val),
            cls.is_null: lambda col, _: col.is_(None),
            cls.is_not_null: lambda col, _: col.isnot(None),
        }

    def apply(self, column: Any, value: Any) -> Any:
        """
        Apply this operator to ``column`` with ``value``.

        :param column: The column expression to filter on
        :param value: The operand; ignored by ``is_null`` / ``is_not_null``
        :returns: Whatever the operator's callable produces for ``column``
        :raises ValueError: If the operator has no entry in ``operator_map``
        """
        op_func = self.operator_map().get(self)
        if not op_func:
            raise ValueError(f"Unsupported operator: {self}")
        return op_func(column, value)
|
|
|
|
class ColumnOperator(BaseModel):
    """
    A single column filter: which column, which operator, and the operand.

    ``col`` and ``opr`` are required; ``value`` defaults to ``None``.
    """

    col: str = Field(default=..., description="Column name to filter on")
    opr: ColumnOperatorEnum = Field(default=..., description="Operator")
    value: Any = Field(default=None, description="Value for the filter")
|
|
|
|
class BaseDAO(Generic[T]):
    """
    Base DAO, implement base CRUD sqlalchemy operations.

    Concrete DAOs subclass ``BaseDAO[SomeModel]``; ``__init_subclass__`` reads the
    generic argument and stores it on ``model_cls``.
    """

    model_cls: type[Model] | None = None
    """
    Child classes need to state the Model class so they don't need to implement basic
    create, update and delete methods
    """
    base_filter: BaseFilter | None = None
    """
    Child classes can register base filtering to be applied to all filter methods
    """
    id_column_name = "id"

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Set ``model_cls`` from the generic argument of the ``BaseDAO[...]`` base.

        All original bases are scanned (not only the first one) so that a mixin
        listed before ``BaseDAO[...]`` no longer raises ``IndexError``. Unbound
        type variables are skipped, leaving the inherited ``model_cls`` in place.
        """
        super().__init_subclass__(**kwargs)
        for base in getattr(cls, "__orig_bases__", ()):
            origin = getattr(base, "__origin__", None)
            if not (isinstance(origin, type) and issubclass(origin, BaseDAO)):
                continue
            type_args = get_args(base)
            if type_args and not isinstance(type_args[0], TypeVar):
                cls.model_cls = type_args[0]
                break

    @classmethod
    def _apply_base_filter(cls, query, skip_base_filter: bool = False, data_model=None):
        """
        Apply the base_filter to the query if it exists and skip_base_filter is False.

        :param query: The query to restrict
        :param skip_base_filter: When true the query is returned untouched
        :param data_model: Interface handed to the filter; a new ``SQLAInterface``
            for ``model_cls`` is created when omitted
        :returns: The (possibly filtered) query
        """
        if cls.base_filter and not skip_base_filter:
            if data_model is None:
                data_model = SQLAInterface(cls.model_cls, db.session)
            query = cls.base_filter(  # pylint: disable=not-callable
                cls.id_column_name, data_model
            ).apply(query, None)
        return query

    @classmethod
    def find_by_id(
        cls,
        model_id: str | int,
        skip_base_filter: bool = False,
    ) -> T | None:
        """
        Find a model by id, if defined applies `base_filter`

        :param model_id: Value compared against the ``id_column_name`` column
        :param skip_base_filter: Bypass the ``base_filter`` when true
        :returns: The matching model, or ``None`` when nothing matches or the
            statement cannot be executed with the given id
        """
        query = db.session.query(cls.model_cls)
        query = cls._apply_base_filter(query, skip_base_filter)
        id_column = getattr(cls.model_cls, cls.id_column_name)
        try:
            return query.filter(id_column == model_id).one_or_none()
        except StatementError:
            # can happen if int is passed instead of a string or similar
            return None

    @classmethod
    def find_by_ids(
        cls,
        model_ids: list[str] | list[int],
        skip_base_filter: bool = False,
    ) -> list[T]:
        """
        Find a List of models by a list of ids, if defined applies `base_filter`

        :param model_ids: The ids to look up
        :param skip_base_filter: Bypass the ``base_filter`` when true
        :returns: The matching models; empty when ``model_ids`` is empty or the
            model has no ``id_column_name`` attribute
        """
        id_col = getattr(cls.model_cls, cls.id_column_name, None)
        # An empty id list can never match, so skip the round trip to the database.
        if id_col is None or not model_ids:
            return []
        query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
        query = cls._apply_base_filter(query, skip_base_filter)
        return query.all()

    @classmethod
    def find_all(cls) -> list[T]:
        """
        Get all that fit the `base_filter`
        """
        query = db.session.query(cls.model_cls)
        query = cls._apply_base_filter(query)
        return query.all()

    @classmethod
    def find_one_or_none(cls, **filter_by: Any) -> T | None:
        """
        Get the single model that fits the `base_filter` and the given criteria.

        The lookup uses ``one_or_none()``, so it is the query that decides what
        happens when more than one row matches; this method does not pick "the
        first" match.

        :param filter_by: Keyword criteria forwarded to ``Query.filter_by``
        :returns: The matching model or ``None``
        """
        query = db.session.query(cls.model_cls)
        query = cls._apply_base_filter(query)
        return query.filter_by(**filter_by).one_or_none()

    @classmethod
    def create(
        cls,
        item: T | None = None,
        attributes: dict[str, Any] | None = None,
    ) -> T:
        """
        Create an object from the specified item and/or attributes.

        The object is added to the session; committing is left to the caller.

        :param item: The object to create
        :param attributes: The attributes associated with the object to create
        :returns: The object that was added to the session
        """

        # Compare against None explicitly: the truthiness of a model instance is
        # not a reliable "was an item supplied" test.
        if item is None:
            item = cls.model_cls()  # type: ignore  # pylint: disable=not-callable

        if attributes:
            for key, value in attributes.items():
                setattr(item, key, value)

        db.session.add(item)
        return item  # type: ignore

    @classmethod
    def update(
        cls,
        item: T | None = None,
        attributes: dict[str, Any] | None = None,
    ) -> T:
        """
        Update an object from the specified item and/or attributes.

        :param item: The object to update
        :param attributes: The attributes associated with the object to update
        :returns: ``item`` itself when it already belongs to the session, otherwise
            the instance returned by ``Session.merge``
        """

        if item is None:
            item = cls.model_cls()  # type: ignore  # pylint: disable=not-callable

        if attributes:
            for key, value in attributes.items():
                setattr(item, key, value)

        if item not in db.session:
            return db.session.merge(item)

        return item  # type: ignore

    @classmethod
    def delete(cls, items: list[T]) -> None:
        """
        Delete the specified items including their associated relationships.

        Note that bulk deletion via `delete` is not invoked in the base class as this
        does not dispatch the ORM `after_delete` event which may be required to augment
        additional records loosely defined via implicit relationships. Instead ORM
        objects are deleted one-by-one via `Session.delete`.

        Subclasses may invoke bulk deletion but are responsible for instrumenting any
        post-deletion logic.

        :param items: The items to delete
        :see: https://docs.sqlalchemy.org/en/latest/orm/queryguide/dml.html
        """

        for item in items:
            db.session.delete(item)

    @classmethod
    def apply_column_operators(
        cls, query, column_operators: Optional[List[ColumnOperator]] = None
    ):
        """
        Apply column operators (list of ColumnOperator) to the query using
        ColumnOperatorEnum logic.

        Entries that are not ``ColumnOperator`` instances are skipped (with a
        warning in the log) rather than applied.

        :param query: The query to restrict
        :param column_operators: The filters to apply, each added via ``filter``
        :returns: The filtered query
        :raises ValueError: If a filter references a non-existent column
        """
        if not column_operators:
            return query
        for column_operator in column_operators:
            if not isinstance(column_operator, ColumnOperator):
                # Kept as best effort, but leave a trace: a silently dropped filter
                # widens the result set.
                logging.warning(
                    "Ignoring filter that is not a ColumnOperator: %r",
                    column_operator,
                )
                continue
            col = column_operator.col
            if not col or not hasattr(cls.model_cls, col):
                model_name = getattr(cls.model_cls, "__name__", str(cls.model_cls))
                message = (
                    f"Invalid filter: column '{col}' does not exist on {model_name}"
                )
                logging.error(message)
                raise ValueError(message)
            column = getattr(cls.model_cls, col)
            try:
                # Always use ColumnOperatorEnum's apply method
                query = query.filter(
                    ColumnOperatorEnum(column_operator.opr).apply(
                        column, column_operator.value
                    )
                )
            except Exception as e:
                logging.error(f"Error applying filter on column '{col}': {e}")
                raise
        return query

    @classmethod
    def get_filterable_columns_and_operators(cls) -> dict:
        """
        Returns a dict mapping filterable columns (including hybrid/computed fields
        if present) to their supported operators. Used by MCP tools to dynamically
        expose filter options.

        Operator names are the ``ColumnOperatorEnum`` *values* (for example ``"in"``,
        not the member name ``"in_"``), i.e. exactly what ``ColumnOperator.opr``
        accepts. Every column gets its own list, so callers may mutate an entry
        without affecting the others.

        Custom fields supported by the DAO but not present on the model should be
        documented here.

        :returns: ``{column_name: [operator_name, ...]}``
        """
        from sqlalchemy.ext.hybrid import hybrid_property
        import sqlalchemy as sa

        op = ColumnOperatorEnum
        basic_ops = (op.eq, op.ne, op.is_null, op.is_not_null)
        string_ops = (
            op.eq, op.ne, op.sw, op.ew, op.in_, op.nin,
            op.like, op.ilike, op.is_null, op.is_not_null,
        )
        ordered_ops = (
            op.eq, op.ne, op.gt, op.gte, op.lt, op.lte,
            op.in_, op.nin, op.is_null, op.is_not_null,
        )
        # Map SQLAlchemy types to supported operators; the first match wins.
        type_rules = (
            ((sa.String, sa.Text), string_ops),
            ((sa.Boolean,), basic_ops),
            ((sa.Integer, sa.Float, sa.Numeric), ordered_ops),
            ((sa.DateTime, sa.Date, sa.Time), ordered_ops),
        )

        def operator_names(operators):
            # Build a fresh list on every call so no two columns share one object.
            return [operator.value for operator in operators]

        filterable = {}
        for column in inspect(cls.model_cls).columns:
            operators = next(
                (ops for types, ops in type_rules if isinstance(column.type, types)),
                basic_ops,  # Fallback to eq/ne/null
            )
            filterable[column.key] = operator_names(operators)

        # Add hybrid properties as string fields by default. Only attributes
        # defined directly on the model class are inspected.
        for name, attr in vars(cls.model_cls).items():
            if isinstance(attr, hybrid_property):
                filterable[name] = operator_names(string_ops)

        # You may add custom fields here, e.g.:
        # custom_fields = {"tags": ["eq", "in", "like"], ...}
        custom_fields: dict = {}
        filterable.update(custom_fields)
        return filterable

    @classmethod
    def _apply_search_and_filters(
        cls,
        query,
        column_operators: Optional[List[ColumnOperator]] = None,
        search: Optional[str] = None,
        search_columns: Optional[List[str]] = None,
        custom_filters: Optional[Dict[str, BaseFilter]] = None,
    ):
        """
        Apply free-text search, custom filters and column operators to ``query``.

        :param query: The query to restrict
        :param column_operators: Filters passed to ``apply_column_operators``
        :param search: Text matched case-insensitively as a substring
        :param search_columns: Model attributes searched; names the model does not
            have are ignored
        :param custom_filters: Filters whose ``apply(query, None)`` is invoked
        :returns: The restricted query
        """
        if search and search_columns:
            pattern = f"%{search}%"
            # Cast to Text so non-string columns can be searched as well.
            search_filters = [
                cast(getattr(cls.model_cls, column_name), Text).ilike(pattern)
                for column_name in search_columns
                if hasattr(cls.model_cls, column_name)
            ]
            if search_filters:
                query = query.filter(or_(*search_filters))
        if custom_filters:
            for custom_filter in custom_filters.values():
                query = custom_filter.apply(query, None)
        if column_operators:
            query = cls.apply_column_operators(query, column_operators)
        return query

    @classmethod
    def _build_query(
        cls,
        column_operators: Optional[List[ColumnOperator]] = None,
        search: Optional[str] = None,
        search_columns: Optional[List[str]] = None,
        custom_filters: Optional[Dict[str, BaseFilter]] = None,
        skip_base_filter: bool = False,
        data_model: Optional[SQLAInterface] = None,
    ):
        """
        Build a SQLAlchemy query with base filter, column operators, search, and
        custom filters.

        :param column_operators: Filters passed to ``apply_column_operators``
        :param search: Text matched case-insensitively as a substring
        :param search_columns: Model attributes the search text is matched against
        :param custom_filters: Filters whose ``apply(query, None)`` is invoked
        :param skip_base_filter: Bypass the ``base_filter`` when true
        :param data_model: Interface supplying the session; created when omitted
        :returns: The query over ``model_cls``
        """
        if data_model is None:
            data_model = SQLAInterface(cls.model_cls, db.session)
        query = data_model.session.query(cls.model_cls)
        query = cls._apply_base_filter(
            query, skip_base_filter=skip_base_filter, data_model=data_model
        )
        return cls._apply_search_and_filters(
            query,
            column_operators=column_operators,
            search=search,
            search_columns=search_columns,
            custom_filters=custom_filters,
        )

    @classmethod
    def list(
        cls,
        column_operators: Optional[List[ColumnOperator]] = None,
        order_column: str = "changed_on",
        order_direction: str = "desc",
        page: int = 0,
        page_size: int = 100,
        search: Optional[str] = None,
        search_columns: Optional[List[str]] = None,
        custom_filters: Optional[Dict[str, BaseFilter]] = None,
        columns: Optional[List[str]] = None,
    ) -> Tuple[List[Any], int]:
        """
        Generic list method for filtered, sorted, and paginated results.

        What is selected depends on ``columns``:

        - if any requested name is a relationship, full model instances are
          returned with those relationships eagerly loaded via ``joinedload``;
        - otherwise, if any requested name is a mapped column, only those columns
          are selected and each item is a row holding their values;
        - otherwise full model instances are returned.

        Requested names that are neither mapped columns nor relationships (plain
        properties, unknown names) are ignored.

        :param column_operators: Filters passed to ``apply_column_operators``
        :param order_column: Attribute to sort by; no ordering is applied when the
            model has no such attribute
        :param order_direction: ``"desc"`` (case-insensitive) sorts descending, any
            other value ascending
        :param page: Zero-based page index; negative values are treated as 0
        :param page_size: Rows per page; values below 1 are treated as 1
        :param search: Text matched case-insensitively as a substring
        :param search_columns: Model attributes the search text is matched against
        :param custom_filters: Filters whose ``apply(query, None)`` is invoked
        :param columns: Names of columns and/or relationships to load
        :returns: ``(items, total_count)`` where ``total_count`` is the number of
            matching rows before pagination
        """
        data_model = SQLAInterface(cls.model_cls, db.session)

        # Separate columns and relationships
        column_attrs = []
        relationship_loads = []
        for name in columns or []:
            attr = getattr(cls.model_cls, name, None)
            prop = getattr(attr, "property", None)
            if isinstance(prop, ColumnProperty):
                column_attrs.append(attr)
            elif isinstance(prop, RelationshipProperty):
                relationship_loads.append(joinedload(attr))
            # Ignore properties and other non-queryable attributes

        if column_attrs and not relationship_loads:
            # Only columns requested
            query = data_model.session.query(*column_attrs)
        else:
            # Relationships requested (or nothing usable): query the full model
            # and eagerly load whatever relationships were asked for.
            query = data_model.session.query(cls.model_cls)
            if relationship_loads:
                query = query.options(*relationship_loads)

        query = cls._apply_base_filter(query, data_model=data_model)
        query = cls._apply_search_and_filters(
            query,
            column_operators=column_operators,
            search=search,
            search_columns=search_columns,
            custom_filters=custom_filters,
        )

        # Count before ordering and pagination so the total covers every match.
        total_count = query.count()

        if hasattr(cls.model_cls, order_column):
            order_attr = getattr(cls.model_cls, order_column)
            direction = desc if order_direction.lower() == "desc" else asc
            query = query.order_by(direction(order_attr))

        # Clamp both values: a negative page would produce a negative OFFSET.
        page = max(page, 0)
        page_size = max(page_size, 1)
        items = query.offset(page * page_size).limit(page_size).all()
        return items, total_count

    @classmethod
    def count(
        cls,
        column_operators: Optional[List[ColumnOperator]] = None,
        skip_base_filter: bool = False,
    ) -> int:
        """
        Count the number of records for the model, optionally filtered by column
        operators.

        :param column_operators: Filters passed to ``apply_column_operators``
        :param skip_base_filter: Bypass the ``base_filter`` when true
        :returns: The number of matching rows
        """
        query = cls._build_query(
            column_operators=column_operators, skip_base_filter=skip_base_filter
        )
        return query.count()
|