From 1a1adcfabde0847f61b7721b4e15e85c0f4be380 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 30 Mar 2026 16:23:55 -0400 Subject: [PATCH] Move logic to commands --- superset/commands/datasource/__init__.py | 0 superset/commands/datasource/list.py | 156 +++++++++++++ superset/daos/datasource.py | 117 +++++++++- superset/datasource/api.py | 273 +---------------------- superset/datasource/schemas.py | 139 ++++++++++++ 5 files changed, 419 insertions(+), 266 deletions(-) create mode 100644 superset/commands/datasource/__init__.py create mode 100644 superset/commands/datasource/list.py create mode 100644 superset/datasource/schemas.py diff --git a/superset/commands/datasource/__init__.py b/superset/commands/datasource/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/superset/commands/datasource/list.py b/superset/commands/datasource/list.py new file mode 100644 index 00000000000..50ea765b4c3 --- /dev/null +++ b/superset/commands/datasource/list.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Command for the combined dataset + semantic view list endpoint.""" + +from __future__ import annotations + +import logging +from typing import Any, cast + +from sqlalchemy import union_all + +from superset.commands.base import BaseCommand +from superset.connectors.sqla.models import SqlaTable +from superset.daos.datasource import DatasourceDAO +from superset.datasource.schemas import DatasetListSchema, SemanticViewListSchema +from superset.semantic_layers.models import SemanticView + +logger = logging.getLogger(__name__) + +_dataset_schema = DatasetListSchema() +_semantic_view_schema = SemanticViewListSchema() + + +class GetCombinedDatasourceListCommand(BaseCommand): + """ + Fetch and serialize a paginated, combined list of datasets and semantic views. + + Callers are responsible for checking access permissions before constructing + this command and for passing the appropriate ``can_read_*`` flags. 
+ """ + + def __init__( + self, + args: dict[str, Any], + can_read_datasets: bool, + can_read_semantic_views: bool, + ) -> None: + self._args = args + self._can_read_datasets = can_read_datasets + self._can_read_semantic_views = can_read_semantic_views + + def run(self) -> dict[str, Any]: + self.validate() + + page = self._args.get("page", 0) + page_size = self._args.get("page_size", 25) + order_column = self._args.get("order_column", "changed_on") + order_direction = self._args.get("order_direction", "desc") + filters = self._args.get("filters", []) + + source_type, name_filter, sql_filter, type_filter = self._parse_filters(filters) + source_type = self._resolve_source_type(source_type, sql_filter, type_filter) + + ds_q = DatasourceDAO.build_dataset_query(name_filter, sql_filter) + sv_q = DatasourceDAO.build_semantic_view_query(name_filter) + + if source_type == "database": + combined = ds_q.subquery() + elif source_type == "semantic_layer": + combined = sv_q.subquery() + else: + combined = union_all(ds_q, sv_q).subquery() + + total_count, rows = DatasourceDAO.paginate_combined_query( + combined, order_column, order_direction, page, page_size + ) + + datasets_map = DatasourceDAO.fetch_datasets_by_ids( + [r.item_id for r in rows if r.source_type == "database"] + ) + sv_map = DatasourceDAO.fetch_semantic_views_by_ids( + [r.item_id for r in rows if r.source_type == "semantic_layer"] + ) + + result: list[dict[str, Any]] = [] + for row in rows: + if row.source_type == "database": + ds_obj = cast(SqlaTable | None, datasets_map.get(row.item_id)) + if ds_obj: + result.append(_dataset_schema.dump(ds_obj)) + else: + sv_obj = cast(SemanticView | None, sv_map.get(row.item_id)) + if sv_obj: + result.append(_semantic_view_schema.dump(sv_obj)) + + return {"count": total_count, "result": result} + + def validate(self) -> None: + pass # access checks are performed by the caller (API layer) + + def _resolve_source_type( + self, + source_type: str, + sql_filter: bool | None, + type_filter: str | None, + ) -> str: + """Narrow source_type based on access flags, sql filter, and type filter.""" + if not self._can_read_semantic_views: + return "database" + if not self._can_read_datasets: + return "semantic_layer" + # sql_filter (physical/virtual toggle) only applies to datasets + if sql_filter is not None: + return "database" + # Explicit semantic-view type filter + if type_filter == "semantic_view": + return "semantic_layer" + return source_type + + @staticmethod + def _parse_filters( + filters: list[dict[str, Any]], + ) -> tuple[str, str | None, bool | None, str | None]: + """ + Translate raw rison filter dicts into typed query parameters. 
+ + Returns: + source_type: "all" | "database" | "semantic_layer" + name_filter: substring to match against name/table_name + sql_filter: True → physical only, False → virtual only, None → both + type_filter: "semantic_view" when the caller wants only semantic views + """ + source_type = "all" + name_filter: str | None = None + sql_filter: bool | None = None + type_filter: str | None = None + + for f in filters: + col = f.get("col") + value = f.get("value") + + if col == "source_type": + source_type = value or "all" + elif col == "table_name" and f.get("opr") == "ct": + name_filter = value + elif col == "sql": + if value == "semantic_view": + type_filter = "semantic_view" + else: + sql_filter = value + + return source_type, name_filter, sql_filter, type_filter diff --git a/superset/daos/datasource.py b/superset/daos/datasource.py index 6a49d350e35..3d347332ce7 100644 --- a/superset/daos/datasource.py +++ b/superset/daos/datasource.py @@ -17,9 +17,14 @@ import logging import uuid -from typing import Union +from typing import Any, Union -from superset import db +from sqlalchemy import and_, func, literal, or_, select +from sqlalchemy.orm import joinedload +from sqlalchemy.sql import Select + +from superset import db, security_manager +from superset.connectors.sqla import models as sqla_models from superset.connectors.sqla.models import SqlaTable from superset.daos.base import BaseDAO from superset.daos.exceptions import ( @@ -30,6 +35,7 @@ from superset.daos.exceptions import ( from superset.models.sql_lab import Query, SavedQuery from superset.semantic_layers.models import SemanticView from superset.utils.core import DatasourceType +from superset.utils.filters import get_dataset_access_filters logger = logging.getLogger(__name__) @@ -80,3 +86,110 @@ class DatasourceDAO(BaseDAO[Datasource]): raise DatasourceNotFound() return datasource + + @staticmethod + def build_dataset_query( + name_filter: str | None, + sql_filter: bool | None, + ) -> Select: + """Build a SELECT for datasets, applying access and content filters.""" + ds_q = select( + SqlaTable.id.label("item_id"), + literal("database").label("source_type"), + SqlaTable.changed_on, + SqlaTable.table_name, + ).select_from(SqlaTable.__table__) + + if not security_manager.can_access_all_datasources(): + ds_q = ds_q.join( + sqla_models.Database, + sqla_models.Database.id == SqlaTable.database_id, + ) + ds_q = ds_q.where(get_dataset_access_filters(SqlaTable)) + + if name_filter: + ds_q = ds_q.where(SqlaTable.table_name.ilike(f"%{name_filter}%")) + + if sql_filter is not None: + if sql_filter: + ds_q = ds_q.where(or_(SqlaTable.sql.is_(None), SqlaTable.sql == "")) + else: + ds_q = ds_q.where(and_(SqlaTable.sql.isnot(None), SqlaTable.sql != "")) + + return ds_q + + @staticmethod + def build_semantic_view_query(name_filter: str | None) -> Select: + """Build a SELECT for semantic views, applying name filter.""" + sv_q = select( + SemanticView.id.label("item_id"), + literal("semantic_layer").label("source_type"), + SemanticView.changed_on, + SemanticView.name.label("table_name"), + ).select_from(SemanticView.__table__) + + if name_filter: + sv_q = sv_q.where(SemanticView.name.ilike(f"%{name_filter}%")) + + return sv_q + + @staticmethod + def paginate_combined_query( + combined: Any, + order_column: str, + order_direction: str, + page: int, + page_size: int, + ) -> tuple[int, list[Any]]: + """Count, sort, and paginate the combined dataset/semantic-view query.""" + sort_col_map = { + "changed_on_delta_humanized": "changed_on", + "table_name": 
"table_name", + } + sort_col_name = sort_col_map.get(order_column, "changed_on") + + total_count = ( + db.session.execute(select(func.count()).select_from(combined)).scalar() or 0 + ) + + sort_col = combined.c[sort_col_name] + ordered_col = sort_col.desc() if order_direction == "desc" else sort_col.asc() + + rows = db.session.execute( + select(combined.c.item_id, combined.c.source_type) + .order_by(ordered_col) + .offset(page * page_size) + .limit(page_size) + ).fetchall() + + return total_count, rows + + @staticmethod + def fetch_datasets_by_ids(ids: list[int]) -> dict[int, SqlaTable]: + """Fetch SqlaTable objects by id with relationships eager-loaded.""" + if not ids: + return {} + objs = ( + db.session.query(SqlaTable) + .options( + joinedload(SqlaTable.database), + joinedload(SqlaTable.owners), + joinedload(SqlaTable.changed_by), + ) + .filter(SqlaTable.id.in_(ids)) + .all() + ) + return {obj.id: obj for obj in objs} + + @staticmethod + def fetch_semantic_views_by_ids(ids: list[int]) -> dict[int, SemanticView]: + """Fetch SemanticView objects by id with relationships eager-loaded.""" + if not ids: + return {} + objs = ( + db.session.query(SemanticView) + .options(joinedload(SemanticView.changed_by)) + .filter(SemanticView.id.in_(ids)) + .all() + ) + return {obj.id: obj for obj in objs} diff --git a/superset/datasource/api.py b/superset/datasource/api.py index 0dcf9ac869b..6690d9d34f2 100644 --- a/superset/datasource/api.py +++ b/superset/datasource/api.py @@ -20,20 +20,15 @@ from typing import Any from flask import current_app as app, request from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api.schemas import get_list_schema -from sqlalchemy import and_, func, literal, or_, select, union_all -from sqlalchemy.orm import joinedload -from sqlalchemy.sql import Select -from superset import db, event_logger, is_feature_enabled, security_manager -from superset.connectors.sqla import models as sqla_models -from superset.connectors.sqla.models import BaseDatasource, SqlaTable +from superset import event_logger, is_feature_enabled, security_manager +from superset.commands.datasource.list import GetCombinedDatasourceListCommand +from superset.connectors.sqla.models import BaseDatasource from superset.daos.datasource import DatasourceDAO from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError from superset.exceptions import SupersetSecurityException -from superset.semantic_layers.models import SemanticView from superset.superset_typing import FlaskResponse from superset.utils.core import apply_max_row_limit, DatasourceType, SqlExpressionType -from superset.utils.filters import get_dataset_access_filters from superset.views.base_api import BaseSupersetApi, statsd_metrics logger = logging.getLogger(__name__) @@ -351,260 +346,10 @@ class DatasourceRestApi(BaseSupersetApi): if not can_read_datasets and not can_read_sv: return self.response(403, message="Access denied") - args = kwargs.get("rison", {}) - page = args.get("page", 0) - page_size = args.get("page_size", 25) - order_column = args.get("order_column", "changed_on") - order_direction = args.get("order_direction", "desc") - filters = args.get("filters", []) + result = GetCombinedDatasourceListCommand( + args=kwargs.get("rison", {}), + can_read_datasets=can_read_datasets, + can_read_semantic_views=can_read_sv, + ).run() - source_type, name_filter, sql_filter, type_filter = ( - self._parse_combined_list_filters(filters) - ) - - # If semantic layers feature flag is off or no SV 
access, only show datasets - if not can_read_sv: - source_type = "database" - - # If no dataset access, only show semantic views - if not can_read_datasets: - source_type = "semantic_layer" - - ds_q = self._build_dataset_query(name_filter, sql_filter) - - # Selecting Physical/Virtual implicitly means "database only" - if sql_filter is not None and source_type == "all": - source_type = "database" - - # Handle type_filter = "semantic_view" - if type_filter == "semantic_view": - source_type = "semantic_layer" - - sv_q = self._build_semantic_view_query(name_filter) - - # Build combined query based on source_type - if source_type == "database": - combined = ds_q.subquery() - elif source_type == "semantic_layer": - combined = sv_q.subquery() - else: - combined = union_all(ds_q, sv_q).subquery() - - total_count, rows = self._paginate_combined_query( - combined, order_column, order_direction, page, page_size - ) - - result = self._fetch_and_serialize_rows(rows) - - return self.response(200, count=total_count, result=result) - - @staticmethod - def _parse_combined_list_filters( - filters: list[dict[str, Any]], - ) -> tuple[str, str | None, bool | None, str | None]: - """Parse filters into source_type, name_filter, sql_filter, type_filter.""" - source_type = "all" - name_filter = None - sql_filter: bool | None = None - type_filter: str | None = None - for f in filters: - if f.get("col") == "source_type": - source_type = f.get("value", "all") - elif f.get("col") == "table_name" and f.get("opr") == "ct": - name_filter = f.get("value") - elif f.get("col") == "sql": - val = f.get("value") - if val == "semantic_view": - type_filter = "semantic_view" - else: - sql_filter = val - return source_type, name_filter, sql_filter, type_filter - - @staticmethod - def _build_dataset_query( - name_filter: str | None, - sql_filter: bool | None, - ) -> Select: - """Build the dataset subquery with filters.""" - ds_q = select( - SqlaTable.id.label("item_id"), - literal("database").label("source_type"), - SqlaTable.changed_on, - SqlaTable.table_name, - ).select_from(SqlaTable.__table__) - - if not security_manager.can_access_all_datasources(): - ds_q = ds_q.join( - sqla_models.Database, - sqla_models.Database.id == SqlaTable.database_id, - ) - ds_q = ds_q.where(get_dataset_access_filters(SqlaTable)) - - if name_filter: - ds_q = ds_q.where(SqlaTable.table_name.ilike(f"%{name_filter}%")) - - if sql_filter is not None: - if sql_filter: - ds_q = ds_q.where(or_(SqlaTable.sql.is_(None), SqlaTable.sql == "")) - else: - ds_q = ds_q.where(and_(SqlaTable.sql.isnot(None), SqlaTable.sql != "")) - - return ds_q - - @staticmethod - def _build_semantic_view_query(name_filter: str | None) -> Select: - """Build the semantic view subquery with filters.""" - sv_q = select( - SemanticView.id.label("item_id"), - literal("semantic_layer").label("source_type"), - SemanticView.changed_on, - SemanticView.name.label("table_name"), - ).select_from(SemanticView.__table__) - - if name_filter: - sv_q = sv_q.where(SemanticView.name.ilike(f"%{name_filter}%")) - - return sv_q - - @staticmethod - def _paginate_combined_query( - combined: Any, - order_column: str, - order_direction: str, - page: int, - page_size: int, - ) -> tuple[int, list[Any]]: - """Count, sort, and paginate the combined query.""" - sort_col_map = { - "changed_on_delta_humanized": "changed_on", - "table_name": "table_name", - } - sort_col_name = sort_col_map.get(order_column, "changed_on") - - total_count = ( - db.session.execute(select(func.count()).select_from(combined)).scalar() or 
0 - ) - - sort_col = combined.c[sort_col_name] - if order_direction == "desc": - sort_col = sort_col.desc() - else: - sort_col = sort_col.asc() - - paginated_q = ( - select(combined.c.item_id, combined.c.source_type) - .order_by(sort_col) - .offset(page * page_size) - .limit(page_size) - ) - rows = db.session.execute(paginated_q).fetchall() - return total_count, rows - - def _fetch_and_serialize_rows(self, rows: list[Any]) -> list[dict[str, Any]]: - """Fetch ORM objects and serialize rows in order.""" - dataset_ids = [r.item_id for r in rows if r.source_type == "database"] - sv_ids = [r.item_id for r in rows if r.source_type == "semantic_layer"] - - datasets_map: dict[int, SqlaTable] = {} - if dataset_ids: - ds_objs = ( - db.session.query(SqlaTable) - .options( - joinedload(SqlaTable.database), - joinedload(SqlaTable.owners), - joinedload(SqlaTable.changed_by), - ) - .filter(SqlaTable.id.in_(dataset_ids)) - .all() - ) - datasets_map = {obj.id: obj for obj in ds_objs} - - sv_map: dict[int, SemanticView] = {} - if sv_ids: - sv_objs = ( - db.session.query(SemanticView) - .options(joinedload(SemanticView.changed_by)) - .filter(SemanticView.id.in_(sv_ids)) - .all() - ) - sv_map = {obj.id: obj for obj in sv_objs} - - result = [] - for row in rows: - if row.source_type == "database": - obj = datasets_map.get(row.item_id) - if obj: - result.append(self._serialize_dataset(obj)) - else: - obj = sv_map.get(row.item_id) - if obj: - result.append(self._serialize_semantic_view(obj)) - - return result - - @staticmethod - def _serialize_dataset(obj: SqlaTable) -> dict[str, Any]: - changed_by = obj.changed_by - return { - "id": obj.id, - "uuid": str(obj.uuid), - "table_name": obj.table_name, - "kind": obj.kind, - "source_type": "database", - "description": obj.description, - "explore_url": obj.explore_url, - "database": { - "id": obj.database_id, - "database_name": obj.database.database_name, - } - if obj.database - else None, - "schema": obj.schema, - "sql": obj.sql, - "extra": obj.extra, - "owners": [ - { - "id": o.id, - "first_name": o.first_name, - "last_name": o.last_name, - } - for o in obj.owners - ], - "changed_by_name": obj.changed_by_name, - "changed_by": { - "first_name": changed_by.first_name, - "last_name": changed_by.last_name, - } - if changed_by - else None, - "changed_on_delta_humanized": obj.changed_on_delta_humanized(), - "changed_on_utc": obj.changed_on_utc(), - } - - @staticmethod - def _serialize_semantic_view(obj: SemanticView) -> dict[str, Any]: - changed_by = obj.changed_by - return { - "id": obj.id, - "uuid": str(obj.uuid), - "table_name": obj.name, - "kind": "semantic_view", - "source_type": "semantic_layer", - "description": obj.description, - "cache_timeout": obj.cache_timeout, - "explore_url": obj.explore_url, - "database": None, - "schema": None, - "sql": None, - "extra": None, - "owners": [], - "changed_by_name": obj.changed_by_name, - "changed_by": { - "first_name": changed_by.first_name, - "last_name": changed_by.last_name, - } - if changed_by - else None, - "changed_on_delta_humanized": obj.changed_on_delta_humanized(), - "changed_on_utc": obj.changed_on_utc(), - } + return self.response(200, **result) diff --git a/superset/datasource/schemas.py b/superset/datasource/schemas.py new file mode 100644 index 00000000000..bd9bf5326f9 --- /dev/null +++ b/superset/datasource/schemas.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Marshmallow schemas for the combined datasource list endpoint.""" + +from __future__ import annotations + +from marshmallow import fields, Schema + +from superset.connectors.sqla.models import SqlaTable +from superset.semantic_layers.models import SemanticView + + +class _ChangedBySchema(Schema): + first_name = fields.String() + last_name = fields.String() + + +class _OwnerSchema(Schema): + id = fields.Integer() + first_name = fields.String() + last_name = fields.String() + + +class _DatabaseSchema(Schema): + id = fields.Integer() + database_name = fields.String() + + +class DatasetListSchema(Schema): + """Serializes a SqlaTable ORM object for the combined list response.""" + + id = fields.Integer() + uuid = fields.Method("get_uuid") + table_name = fields.String() + kind = fields.String() + source_type = fields.Constant("database") + description = fields.String(allow_none=True) + explore_url = fields.String() + database = fields.Method("get_database") + schema = fields.String(allow_none=True) + sql = fields.String(allow_none=True) + extra = fields.String(allow_none=True) + owners = fields.Method("get_owners") + changed_by_name = fields.String() + changed_by = fields.Method("get_changed_by") + changed_on_delta_humanized = fields.Method("get_changed_on_delta_humanized") + changed_on_utc = fields.Method("get_changed_on_utc") + + def get_uuid(self, obj: SqlaTable) -> str: + return str(obj.uuid) + + def get_database(self, obj: SqlaTable) -> dict[str, object] | None: + if not obj.database: + return None + return _DatabaseSchema().dump( + {"id": obj.database_id, "database_name": obj.database.database_name} + ) + + def get_owners(self, obj: SqlaTable) -> list[dict[str, object]]: + return _OwnerSchema(many=True).dump( + [ + {"id": o.id, "first_name": o.first_name, "last_name": o.last_name} + for o in obj.owners + ] + ) + + def get_changed_by(self, obj: SqlaTable) -> dict[str, object] | None: + if not obj.changed_by: + return None + return _ChangedBySchema().dump( + { + "first_name": obj.changed_by.first_name, + "last_name": obj.changed_by.last_name, + } + ) + + def get_changed_on_delta_humanized(self, obj: SqlaTable) -> str: + return obj.changed_on_delta_humanized() + + def get_changed_on_utc(self, obj: SqlaTable) -> str: + return obj.changed_on_utc() + + +class SemanticViewListSchema(Schema): + """Serializes a SemanticView ORM object for the combined list response.""" + + id = fields.Integer() + uuid = fields.Method("get_uuid") + table_name = fields.Method("get_table_name") + kind = fields.Constant("semantic_view") + source_type = fields.Constant("semantic_layer") + description = fields.String(allow_none=True) + explore_url = fields.String() + database = fields.Constant(None) + schema = fields.Constant(None) + sql = fields.Constant(None) + extra = fields.Constant(None) + owners = fields.Constant([]) + changed_by_name 
= fields.String() + changed_by = fields.Method("get_changed_by") + changed_on_delta_humanized = fields.Method("get_changed_on_delta_humanized") + changed_on_utc = fields.Method("get_changed_on_utc") + cache_timeout = fields.Integer(allow_none=True) + + def get_uuid(self, obj: SemanticView) -> str: + return str(obj.uuid) + + def get_table_name(self, obj: SemanticView) -> str: + return obj.name + + def get_changed_by(self, obj: SemanticView) -> dict[str, object] | None: + if not obj.changed_by: + return None + return _ChangedBySchema().dump( + { + "first_name": obj.changed_by.first_name, + "last_name": obj.changed_by.last_name, + } + ) + + def get_changed_on_delta_humanized(self, obj: SemanticView) -> str: + return obj.changed_on_delta_humanized() + + def get_changed_on_utc(self, obj: SemanticView) -> str: + return obj.changed_on_utc()
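
For reference, a minimal sketch of how the new GetCombinedDatasourceListCommand can be driven directly, outside the REST layer (for example from a test shell). The argument shape mirrors the rison payload handled in superset/datasource/api.py; the filter values shown are illustrative, and the permission flags are assumed to have been resolved by the caller, as the class docstring requires:

    # Illustrative values only; requires an initialized Superset app context so
    # that db.session and security_manager are available to the DAO helpers.
    from superset.commands.datasource.list import GetCombinedDatasourceListCommand

    args = {
        "page": 0,
        "page_size": 25,
        # "changed_on_delta_humanized" is mapped to the "changed_on" sort column
        "order_column": "changed_on_delta_humanized",
        "order_direction": "desc",
        # "ct" (contains) on table_name becomes the name substring filter;
        # a {"col": "sql", "value": "semantic_view"} entry would instead restrict
        # the listing to semantic views.
        "filters": [{"col": "table_name", "opr": "ct", "value": "sales"}],
    }

    payload = GetCombinedDatasourceListCommand(
        args=args,
        can_read_datasets=True,        # resolved by the caller, per the docstring
        can_read_semantic_views=True,  # resolved by the caller, per the docstring
    ).run()

    # Same shape the combined list endpoint returns: {"count": ..., "result": [...]}
    print(payload["count"], [row["table_name"] for row in payload["result"]])

The returned payload matches the endpoint response: a total count plus rows serialized by DatasetListSchema or SemanticViewListSchema according to each row's source_type.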