Move logic to commands

This commit is contained in:
Beto Dealmeida
2026-03-30 16:23:55 -04:00
parent bc9abd31c7
commit 1a1adcfabd
5 changed files with 419 additions and 266 deletions

View File

@@ -20,20 +20,15 @@ from typing import Any
from flask import current_app as app, request
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.api.schemas import get_list_schema
from sqlalchemy import and_, func, literal, or_, select, union_all
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import Select
from superset import db, event_logger, is_feature_enabled, security_manager
from superset.connectors.sqla import models as sqla_models
from superset.connectors.sqla.models import BaseDatasource, SqlaTable
from superset import event_logger, is_feature_enabled, security_manager
from superset.commands.datasource.list import GetCombinedDatasourceListCommand
from superset.connectors.sqla.models import BaseDatasource
from superset.daos.datasource import DatasourceDAO
from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.exceptions import SupersetSecurityException
from superset.semantic_layers.models import SemanticView
from superset.superset_typing import FlaskResponse
from superset.utils.core import apply_max_row_limit, DatasourceType, SqlExpressionType
from superset.utils.filters import get_dataset_access_filters
from superset.views.base_api import BaseSupersetApi, statsd_metrics
logger = logging.getLogger(__name__)
@@ -351,260 +346,10 @@ class DatasourceRestApi(BaseSupersetApi):
if not can_read_datasets and not can_read_sv:
return self.response(403, message="Access denied")
args = kwargs.get("rison", {})
page = args.get("page", 0)
page_size = args.get("page_size", 25)
order_column = args.get("order_column", "changed_on")
order_direction = args.get("order_direction", "desc")
filters = args.get("filters", [])
result = GetCombinedDatasourceListCommand(
args=kwargs.get("rison", {}),
can_read_datasets=can_read_datasets,
can_read_semantic_views=can_read_sv,
).run()
source_type, name_filter, sql_filter, type_filter = (
self._parse_combined_list_filters(filters)
)
# If semantic layers feature flag is off or no SV access, only show datasets
if not can_read_sv:
source_type = "database"
# If no dataset access, only show semantic views
if not can_read_datasets:
source_type = "semantic_layer"
ds_q = self._build_dataset_query(name_filter, sql_filter)
# Selecting Physical/Virtual implicitly means "database only"
if sql_filter is not None and source_type == "all":
source_type = "database"
# Handle type_filter = "semantic_view"
if type_filter == "semantic_view":
source_type = "semantic_layer"
sv_q = self._build_semantic_view_query(name_filter)
# Build combined query based on source_type
if source_type == "database":
combined = ds_q.subquery()
elif source_type == "semantic_layer":
combined = sv_q.subquery()
else:
combined = union_all(ds_q, sv_q).subquery()
total_count, rows = self._paginate_combined_query(
combined, order_column, order_direction, page, page_size
)
result = self._fetch_and_serialize_rows(rows)
return self.response(200, count=total_count, result=result)
@staticmethod
def _parse_combined_list_filters(
filters: list[dict[str, Any]],
) -> tuple[str, str | None, bool | None, str | None]:
"""Parse filters into source_type, name_filter, sql_filter, type_filter."""
source_type = "all"
name_filter = None
sql_filter: bool | None = None
type_filter: str | None = None
for f in filters:
if f.get("col") == "source_type":
source_type = f.get("value", "all")
elif f.get("col") == "table_name" and f.get("opr") == "ct":
name_filter = f.get("value")
elif f.get("col") == "sql":
val = f.get("value")
if val == "semantic_view":
type_filter = "semantic_view"
else:
sql_filter = val
return source_type, name_filter, sql_filter, type_filter
@staticmethod
def _build_dataset_query(
    name_filter: str | None,
    sql_filter: bool | None,
) -> Select:
    """Build the dataset SELECT feeding the combined datasource list.

    Produces ``(item_id, source_type, changed_on, table_name)`` rows
    tagged with the literal source_type ``"database"``, optionally
    filtered by a case-insensitive name substring and by
    physical/virtual status, and restricted by dataset access filters
    unless the user can access all datasources.
    """
    query = select(
        SqlaTable.id.label("item_id"),
        literal("database").label("source_type"),
        SqlaTable.changed_on,
        SqlaTable.table_name,
    ).select_from(SqlaTable.__table__)

    # Non-privileged users only see datasets they may read; the join to
    # Database is required by the access-filter expression.
    if not security_manager.can_access_all_datasources():
        query = query.join(
            sqla_models.Database,
            sqla_models.Database.id == SqlaTable.database_id,
        ).where(get_dataset_access_filters(SqlaTable))

    if name_filter:
        query = query.where(SqlaTable.table_name.ilike(f"%{name_filter}%"))

    if sql_filter is None:
        return query

    # True selects physical datasets (no SQL definition); False selects
    # virtual ones (a non-empty SQL definition).
    physical = or_(SqlaTable.sql.is_(None), SqlaTable.sql == "")
    virtual = and_(SqlaTable.sql.isnot(None), SqlaTable.sql != "")
    return query.where(physical if sql_filter else virtual)
@staticmethod
def _build_semantic_view_query(name_filter: str | None) -> Select:
    """Build the semantic-view SELECT feeding the combined list.

    Produces ``(item_id, source_type, changed_on, table_name)`` rows
    tagged with the literal source_type ``"semantic_layer"``; the view
    name is aliased to ``table_name`` so the query can be UNIONed with
    the dataset query. An optional case-insensitive substring filter is
    applied to the name.
    """
    query = select(
        SemanticView.id.label("item_id"),
        literal("semantic_layer").label("source_type"),
        SemanticView.changed_on,
        SemanticView.name.label("table_name"),
    ).select_from(SemanticView.__table__)
    if not name_filter:
        return query
    return query.where(SemanticView.name.ilike(f"%{name_filter}%"))
@staticmethod
def _paginate_combined_query(
    combined: Any,
    order_column: str,
    order_direction: str,
    page: int,
    page_size: int,
) -> tuple[int, list[Any]]:
    """Count, order, and page through the combined subquery.

    Returns ``(total_count, rows)`` where each row carries the
    ``item_id`` and ``source_type`` columns of the requested page.
    API-facing sort columns are mapped onto subquery columns; anything
    unrecognised falls back to sorting by ``changed_on``.
    """
    # Translate API sort names to the subquery's column names.
    column_aliases = {
        "changed_on_delta_humanized": "changed_on",
        "table_name": "table_name",
    }
    sort_column = combined.c[column_aliases.get(order_column, "changed_on")]
    ordering = (
        sort_column.desc() if order_direction == "desc" else sort_column.asc()
    )

    total_count = (
        db.session.execute(select(func.count()).select_from(combined)).scalar()
        or 0
    )

    page_query = (
        select(combined.c.item_id, combined.c.source_type)
        .order_by(ordering)
        .offset(page * page_size)
        .limit(page_size)
    )
    return total_count, db.session.execute(page_query).fetchall()
def _fetch_and_serialize_rows(self, rows: list[Any]) -> list[dict[str, Any]]:
    """Load the ORM objects behind paginated rows and serialize them.

    Datasets and semantic views are bulk-loaded in at most two queries
    (eager-loading the relationships the serializers touch), then the
    serialized dicts are emitted in the same order as ``rows``. Rows
    whose backing object no longer exists are skipped.
    """
    dataset_ids = [row.item_id for row in rows if row.source_type == "database"]
    view_ids = [row.item_id for row in rows if row.source_type == "semantic_layer"]

    datasets_by_id: dict[int, SqlaTable] = {}
    if dataset_ids:
        datasets_by_id = {
            dataset.id: dataset
            for dataset in db.session.query(SqlaTable)
            .options(
                joinedload(SqlaTable.database),
                joinedload(SqlaTable.owners),
                joinedload(SqlaTable.changed_by),
            )
            .filter(SqlaTable.id.in_(dataset_ids))
            .all()
        }

    views_by_id: dict[int, SemanticView] = {}
    if view_ids:
        views_by_id = {
            view.id: view
            for view in db.session.query(SemanticView)
            .options(joinedload(SemanticView.changed_by))
            .filter(SemanticView.id.in_(view_ids))
            .all()
        }

    serialized: list[dict[str, Any]] = []
    for row in rows:
        if row.source_type == "database":
            dataset = datasets_by_id.get(row.item_id)
            if dataset:
                serialized.append(self._serialize_dataset(dataset))
        else:
            view = views_by_id.get(row.item_id)
            if view:
                serialized.append(self._serialize_semantic_view(view))
    return serialized
@staticmethod
def _serialize_dataset(obj: SqlaTable) -> dict[str, Any]:
changed_by = obj.changed_by
return {
"id": obj.id,
"uuid": str(obj.uuid),
"table_name": obj.table_name,
"kind": obj.kind,
"source_type": "database",
"description": obj.description,
"explore_url": obj.explore_url,
"database": {
"id": obj.database_id,
"database_name": obj.database.database_name,
}
if obj.database
else None,
"schema": obj.schema,
"sql": obj.sql,
"extra": obj.extra,
"owners": [
{
"id": o.id,
"first_name": o.first_name,
"last_name": o.last_name,
}
for o in obj.owners
],
"changed_by_name": obj.changed_by_name,
"changed_by": {
"first_name": changed_by.first_name,
"last_name": changed_by.last_name,
}
if changed_by
else None,
"changed_on_delta_humanized": obj.changed_on_delta_humanized(),
"changed_on_utc": obj.changed_on_utc(),
}
@staticmethod
def _serialize_semantic_view(obj: SemanticView) -> dict[str, Any]:
changed_by = obj.changed_by
return {
"id": obj.id,
"uuid": str(obj.uuid),
"table_name": obj.name,
"kind": "semantic_view",
"source_type": "semantic_layer",
"description": obj.description,
"cache_timeout": obj.cache_timeout,
"explore_url": obj.explore_url,
"database": None,
"schema": None,
"sql": None,
"extra": None,
"owners": [],
"changed_by_name": obj.changed_by_name,
"changed_by": {
"first_name": changed_by.first_name,
"last_name": changed_by.last_name,
}
if changed_by
else None,
"changed_on_delta_humanized": obj.changed_on_delta_humanized(),
"changed_on_utc": obj.changed_on_utc(),
}
return self.response(200, **result)