diff --git a/superset/commands/chart/data/streaming_export_command.py b/superset/commands/chart/data/streaming_export_command.py
index b6ec3a36698..28764c3ae81 100644
--- a/superset/commands/chart/data/streaming_export_command.py
+++ b/superset/commands/chart/data/streaming_export_command.py
@@ -68,7 +68,14 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
         query_obj = self._query_context.queries[0]
         sql_query = datasource.get_query_str(query_obj.to_dict())

-        return sql_query, datasource.database
+        # Chart export is SQL-specific, so we check for BaseDatasource
+        from superset.connectors.sqla.models import BaseDatasource
+
+        database = (
+            datasource.database if isinstance(datasource, BaseDatasource) else None
+        )
+
+        return sql_query, database

     def _get_row_limit(self) -> int | None:
         """
diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index ef7485842ef..d59804ff613 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -38,9 +38,7 @@ if TYPE_CHECKING:
     from superset.common.query_object import QueryObject


-def _get_datasource(
-    query_context: QueryContext, query_obj: QueryObject
-) -> Explorable:
+def _get_datasource(query_context: QueryContext, query_obj: QueryObject) -> Explorable:
     return query_obj.datasource or query_context.datasource


diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py
index 52884365ffd..7b61cf6ea94 100644
--- a/superset/common/query_context_factory.py
+++ b/superset/common/query_context_factory.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations

-from typing import Any, TYPE_CHECKING
+from typing import Any

 from flask import current_app

diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 615984e13dd..eec2275a23b 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -30,13 +30,13 @@ from superset.common.query_actions import get_query_results
 from superset.common.utils.query_cache_manager import QueryCacheManager
 from superset.common.utils.time_range_utils import get_since_until_from_time_range
 from superset.constants import CACHE_DISABLED_TIMEOUT, CacheRegion
-from superset.explorables.base import Explorable
 from superset.daos.annotation_layer import AnnotationLayerDAO
 from superset.daos.chart import ChartDAO
 from superset.exceptions import (
     QueryObjectValidationError,
     SupersetException,
 )
+from superset.explorables.base import Explorable
 from superset.extensions import cache_manager, security_manager
 from superset.models.helpers import QueryResult
 from superset.superset_typing import AdhocColumn, AdhocMetric
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 1cdce4f78b3..3c0f45feab5 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -234,6 +234,51 @@ class BaseDatasource(
         """Set the datasource-specific cache timeout."""
         self._cache_timeout = value

+    def has_drill_by_columns(self, column_names: list[str]) -> bool:
+        """
+        Check if the specified columns support drill-by operations.
+
+        For SQL datasources, drill-by is supported on columns that are marked
+        as groupable in the metadata. This allows users to navigate from
+        aggregated views to detailed data by grouping on these dimensions.
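+
+        Example (illustrative only; assumes ``country`` and ``region`` are
+        groupable columns on this datasource and ``order_id`` is not)::
+
+            datasource.has_drill_by_columns(["country", "region"])  # True
+            datasource.has_drill_by_columns(["country", "order_id"])  # False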
+
+        :param column_names: List of column names to check
+        :return: True if all columns support drill-by, False otherwise
+        """
+        if not column_names:
+            return False
+
+        # Get all groupable column names for this datasource
+        drillable_columns = {
+            row[0]
+            for row in db.session.query(TableColumn.column_name)
+            .filter(TableColumn.table_id == self.id)
+            .filter(TableColumn.groupby)
+            .all()
+        }
+
+        # Check if all requested columns are drillable
+        return set(column_names).issubset(drillable_columns)
+
+    def get_time_grains(self) -> list[dict[str, Any]]:
+        """
+        Get available time granularities from the database.
+
+        Implements the Explorable protocol by delegating to the database's
+        time grain definitions. Each database engine spec defines its own
+        set of supported time grains.
+
+        :return: List of time grain dictionaries with name, function, and duration
+        """
+        return [
+            {
+                "name": grain.name,
+                "function": grain.function,
+                "duration": grain.duration,
+            }
+            for grain in (self.database.grains() or [])
+        ]
+
     @property
     def kind(self) -> DatasourceKind:
         return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL
@@ -887,25 +932,6 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
     def database(self) -> Database:
         return self.table.database if self.table else self._database  # type: ignore

-    def get_time_grains(self) -> list[dict[str, Any]]:
-        """
-        Get available time granularities from the database.
-
-        Implements the Explorable protocol by delegating to the database's
-        time grain definitions. Each database engine spec defines its own
-        set of supported time grains.
-
-        :return: List of time grain dictionaries with name, function, and duration
-        """
-        return [
-            {
-                "name": grain.name,
-                "function": grain.function,
-                "duration": grain.duration,
-            }
-            for grain in (self.database.grains() or [])
-        ]
-
     @property
     def db_engine_spec(self) -> builtins.type[BaseEngineSpec]:
         return self.database.db_engine_spec
diff --git a/superset/explorables/base.py b/superset/explorables/base.py
index 716964fbd4f..eb1d678d565 100644
--- a/superset/explorables/base.py
+++ b/superset/explorables/base.py
@@ -262,10 +262,26 @@ class Explorable(Protocol):
         Example return value:
         ```python
         [
-            {"name": "Second", "function": "DATE_TRUNC('second', {col})", "duration": "PT1S"},
-            {"name": "Minute", "function": "DATE_TRUNC('minute', {col})", "duration": "PT1M"},
-            {"name": "Hour", "function": "DATE_TRUNC('hour', {col})", "duration": "PT1H"},
-            {"name": "Day", "function": "DATE_TRUNC('day', {col})", "duration": "P1D"},
+            {
+                "name": "Second",
+                "function": "DATE_TRUNC('second', {col})",
+                "duration": "PT1S",
+            },
+            {
+                "name": "Minute",
+                "function": "DATE_TRUNC('minute', {col})",
+                "duration": "PT1M",
+            },
+            {
+                "name": "Hour",
+                "function": "DATE_TRUNC('hour', {col})",
+                "duration": "PT1H",
+            },
+            {
+                "name": "Day",
+                "function": "DATE_TRUNC('day', {col})",
+                "duration": "P1D",
+            },
         ]
         ```

@@ -276,17 +292,42 @@ class Explorable(Protocol):
     # Optional Properties
     # =========================================================================

+    # =========================================================================
+    # Required Methods
+    # =========================================================================
+
+    def has_drill_by_columns(self, column_names: list[str]) -> bool:
+        """
+        Check if the specified columns support drill-by operations.
+
+        Drill-by allows users to navigate from aggregated views to detailed
+        data by grouping on specific dimensions. This method determines whether
+        the given columns can be used for drill-by in the current datasource.
+
+        For SQL datasources, this typically checks if columns are marked as
+        groupable in the metadata. For semantic views, it checks against the
+        semantic layer's dimension definitions.
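+
+        Example (illustrative only; a semantic-layer implementation would
+        accept exactly the names its model defines as dimensions):
+        ```python
+        explorable.has_drill_by_columns(["customer_region"])  # a dimension
+        explorable.has_drill_by_columns(["random_name"])  # not a dimension
+        ```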
+
+        :param column_names: List of column names to check
+        :return: True if all columns support drill-by, False otherwise
+        """
+
+    # =========================================================================
+    # Optional Properties
+    # =========================================================================
+
     @property
     def is_rls_supported(self) -> bool:
         """
         Whether this explorable supports Row Level Security.

         Row Level Security (RLS) allows filtering data based on user identity.
-        SQL-based datasources typically support this, while semantic layers
-        may handle security at a different level.
+        SQL-based datasources typically support this via SQL queries, while
+        semantic layers may handle security at the semantic layer level.

         :return: True if RLS is supported, False otherwise
         """
+        return False

     @property
     def query_language(self) -> str | None:
diff --git a/superset/security/manager.py b/superset/security/manager.py
index dc41a5b99bb..6ffef7fcdda 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -632,10 +632,10 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         :param form_data: The form_data included in the request.
         :param dashboard: The dashboard the user is drilling from.
-        :returns: Whether the user has drill byaccess.
+        :param datasource: The datasource being queried.
+        :returns: Whether the user has drill by access.
         """

-        from superset.connectors.sqla.models import TableColumn
         from superset.models.slice import Slice

         return bool(
@@ -650,17 +650,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             and slc in dashboard.slices
             and slc.datasource == datasource
             and (dimensions := form_data.get("groupby"))
-            and (datasource_id := getattr(datasource, "id", None))
-            and (
-                drillable_columns := {
-                    row[0]
-                    for row in self.session.query(TableColumn.column_name)
-                    .filter(TableColumn.table_id == datasource_id)
-                    .filter(TableColumn.groupby)
-                    .all()
-                }
-            )
-            and set(dimensions).issubset(drillable_columns)
+            and datasource.has_drill_by_columns(dimensions)
         )

     def can_access_dashboard(self, dashboard: "Dashboard") -> bool:
@@ -726,7 +716,9 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         )

     @staticmethod
-    def get_datasource_access_error_msg(datasource: "BaseDatasource | Explorable") -> str:
+    def get_datasource_access_error_msg(
+        datasource: "BaseDatasource | Explorable",
+    ) -> str:
         """
         Return the error message for the denied Superset datasource.
@@ -734,7 +726,11 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         :param datasource: The denied Superset datasource
         :returns: The error message
         """
-        datasource_id = getattr(datasource, "id", datasource.data.get("id") if hasattr(datasource, "data") else None)
+        datasource_id = getattr(
+            datasource,
+            "id",
+            datasource.data.get("id") if hasattr(datasource, "data") else None,
+        )
         return (
             f"This endpoint requires the datasource {datasource_id}, "
             "database or `all_datasource_access` permission"
@@ -768,8 +764,18 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             level=ErrorLevel.WARNING,
             extra={
                 "link": self.get_datasource_access_link(datasource),
-                "datasource": getattr(datasource, "id", datasource.data.get("id") if hasattr(datasource, "data") else None),
-                "datasource_name": getattr(datasource, "name", datasource.data.get("name") if hasattr(datasource, "data") else None),
+                "datasource": getattr(
+                    datasource,
+                    "id",
+                    datasource.data.get("id") if hasattr(datasource, "data") else None,
+                ),
+                "datasource_name": getattr(
+                    datasource,
+                    "name",
+                    datasource.data.get("name")
+                    if hasattr(datasource, "data")
+                    else None,
+                ),
             },
         )
@@ -2367,8 +2373,11 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         # If the DB engine spec doesn't implement the logic the schema is read
         # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
         # inspector to read it.
+        from superset.models.sql_lab import Query
+
         default_schema = database.get_default_schema_for_query(
-            query, template_params  # type: ignore[arg-type]
+            cast(Query, query),
+            template_params,
         )
         tables = {
             table_.qualify(
@@ -2480,7 +2489,14 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                     and dashboard_.json_metadata
                     and (json_metadata := json.loads(dashboard_.json_metadata))
                     and any(
-                        target.get("datasetId") == getattr(datasource, "id", datasource.data.get("id") if hasattr(datasource, "data") else None)
+                        target.get("datasetId")
+                        == getattr(
+                            datasource,
+                            "id",
+                            datasource.data.get("id")
+                            if hasattr(datasource, "data")
+                            else None,
+                        )
                         for fltr in json_metadata.get(
                             "native_filter_configuration",
                             [],
@@ -2594,7 +2610,11 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         :return: A list of filters
         """
         if guest_user := self.get_current_guest_user_if_guest():
-            dataset_id = getattr(dataset, "id", dataset.data.get("id") if hasattr(dataset, "data") else None)
+            dataset_id = getattr(
+                dataset,
+                "id",
+                dataset.data.get("id") if hasattr(dataset, "data") else None,
+            )
             return [
                 rule
                 for rule in guest_user.rls
@@ -2639,7 +2659,9 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             )
             .filter(RLSFilterRoles.c.role_id.in_(user_roles))
         )
-        table_id = getattr(table, "id", table.data.get("id") if hasattr(table, "data") else None)
+        table_id = getattr(
+            table, "id", table.data.get("id") if hasattr(table, "data") else None
+        )
         filter_tables = self.session.query(RLSFilterTables.c.rls_filter_id).filter(
             RLSFilterTables.c.table_id == table_id
         )
@@ -2667,7 +2689,9 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         )
         return query.all()

-    def get_rls_sorted(self, table: "BaseDatasource | Explorable") -> list["RowLevelSecurityFilter"]:
+    def get_rls_sorted(
+        self, table: "BaseDatasource | Explorable"
+    ) -> list["RowLevelSecurityFilter"]:
         """
         Retrieves a list RLS filters sorted by ID for the current user and the
         passed table.
@@ -2679,12 +2703,14 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         filters.sort(key=lambda f: f.id)
         return filters

-    def get_guest_rls_filters_str(self, table: "BaseDatasource | Explorable") -> list[str]:
+    def get_guest_rls_filters_str(
+        self, table: "BaseDatasource | Explorable"
+    ) -> list[str]:
         return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]

     def get_rls_cache_key(self, datasource: "Explorable | BaseDatasource") -> list[str]:
         rls_clauses_with_group_key = []
-        if hasattr(datasource, "is_rls_supported") and datasource.is_rls_supported:
+        if datasource.is_rls_supported:
             rls_clauses_with_group_key = [
                 f"{f.clause}-{f.group_key or ''}"
                 for f in self.get_rls_sorted(datasource)
diff --git a/superset/thumbnails/digest.py b/superset/thumbnails/digest.py
index 31a179fd93c..41389e15944 100644
--- a/superset/thumbnails/digest.py
+++ b/superset/thumbnails/digest.py
@@ -61,6 +61,8 @@ def _adjust_string_with_rls(
     """
     Add the RLS filters to the unique string based on current executor.
     """
+    from superset.connectors.sqla.models import BaseDatasource
+
     user = (
         security_manager.find_user(executor)
         or security_manager.get_current_guest_user_if_guest()
@@ -72,7 +74,7 @@ def _adjust_string_with_rls(
     for datasource in datasources:
         if (
             datasource
-            and hasattr(datasource, "is_rls_supported")
+            and isinstance(datasource, BaseDatasource)
             and datasource.is_rls_supported
         ):
             rls_filters = datasource.get_sqla_row_level_filters()
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 7ff6021f060..bf0b97bbb47 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1657,7 +1657,9 @@ def map_sql_type_to_inferred_type(sql_type: Optional[str]) -> str:
     return "string"  # If no match is found, return "string" as default


-def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Explorable | Query) -> str:
+def get_metric_type_from_column(
+    column: Any, datasource: BaseDatasource | Explorable | Query
+) -> str:
     """
     Determine the metric type from a given column in a datasource.
@@ -1778,7 +1780,9 @@ def get_time_filter_status( applied_time_extras: dict[str, str], ) -> tuple[list[dict[str, str]], list[dict[str, str]]]: temporal_columns: set[Any] = { - (col.column_name if hasattr(col, "column_name") else col.get("column_name")) for col in datasource.columns if (col.is_dttm if hasattr(col, "is_dttm") else col.get("is_dttm")) + (col.column_name if hasattr(col, "column_name") else col.get("column_name")) + for col in datasource.columns + if (col.is_dttm if hasattr(col, "is_dttm") else col.get("is_dttm")) } applied: list[dict[str, str]] = [] rejected: list[dict[str, str]] = [] diff --git a/tests/unit_tests/common/test_time_shifts.py b/tests/unit_tests/common/test_time_shifts.py index f65b9d93eeb..08b6478c482 100644 --- a/tests/unit_tests/common/test_time_shifts.py +++ b/tests/unit_tests/common/test_time_shifts.py @@ -39,32 +39,26 @@ processor = QueryContextProcessor( ) # Bind ExploreMixin methods to datasource for testing -processor._qc_datasource.add_offset_join_column = ( - ExploreMixin.add_offset_join_column.__get__(processor._qc_datasource) +# Type annotation needed because _qc_datasource is typed as Explorable in protocol +_datasource: BaseDatasource = processor._qc_datasource # type: ignore +_datasource.add_offset_join_column = ExploreMixin.add_offset_join_column.__get__( + _datasource ) -processor._qc_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__( - processor._qc_datasource +_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__(_datasource) +_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__(_datasource) +_datasource._determine_join_keys = ExploreMixin._determine_join_keys.__get__( + _datasource ) -processor._qc_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__( - processor._qc_datasource -) -processor._qc_datasource._determine_join_keys = ( - ExploreMixin._determine_join_keys.__get__(processor._qc_datasource) -) -processor._qc_datasource._perform_join = ExploreMixin._perform_join.__get__( - processor._qc_datasource -) -processor._qc_datasource._apply_cleanup_logic = ( - ExploreMixin._apply_cleanup_logic.__get__(processor._qc_datasource) +_datasource._perform_join = ExploreMixin._perform_join.__get__(_datasource) +_datasource._apply_cleanup_logic = ExploreMixin._apply_cleanup_logic.__get__( + _datasource ) # Static methods don't need binding - assign directly -processor._qc_datasource.generate_join_column = ExploreMixin.generate_join_column -processor._qc_datasource.is_valid_date_range_static = ( - ExploreMixin.is_valid_date_range_static -) +_datasource.generate_join_column = ExploreMixin.generate_join_column +_datasource.is_valid_date_range_static = ExploreMixin.is_valid_date_range_static # Convenience reference for backward compatibility in tests -query_context_processor = processor._qc_datasource +query_context_processor = _datasource @fixture
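
For reviewers: a minimal sketch of a non-SQL explorable that satisfies the protocol surface this change touches. It is illustrative only and not part of the diff; the class name and the `_dimensions` registry are hypothetical, while the method and property names come from the `superset/explorables/base.py` changes above.

```python
from typing import Any


class SemanticViewExplorable:
    """Hypothetical semantic-layer explorable (structural Explorable)."""

    # Stand-in for a semantic model's dimension registry.
    _dimensions = {"customer_region", "product_category"}

    def has_drill_by_columns(self, column_names: list[str]) -> bool:
        # Mirrors the BaseDatasource semantics above: an empty list is not
        # drillable, and every requested column must be a known dimension.
        return bool(column_names) and set(column_names) <= self._dimensions

    def get_time_grains(self) -> list[dict[str, Any]]:
        # A semantic layer supplies its own grains rather than delegating
        # to a database engine spec.
        return [
            {
                "name": "Day",
                "function": "DATE_TRUNC('day', {col})",
                "duration": "P1D",
            }
        ]

    @property
    def is_rls_supported(self) -> bool:
        # Matches the protocol's new default: SQL-level RLS does not apply,
        # which is why get_rls_cache_key can now read this attribute without
        # a hasattr() guard.
        return False
```

Because `get_rls_cache_key` and `can_drill_by` now call these members directly instead of duck-typing with `hasattr`/`getattr`, any object passed as an `Explorable` must provide them; for the sketch above, `has_drill_by_columns(["customer_region"])` returns `True` and the RLS cache-key contribution stays empty.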