From 16e6452b8c3d2989d424b0f245c6d723d34673d2 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 4 Dec 2025 13:18:34 -0500 Subject: [PATCH] feat: Explorable protocol (#36245) --- .pylintrc | 2 +- .../chart/data/streaming_export_command.py | 2 +- superset/commands/dataset/export.py | 7 + superset/common/query_actions.py | 25 +- superset/common/query_context.py | 25 +- superset/common/query_context_factory.py | 12 +- superset/common/query_context_processor.py | 4 +- superset/connectors/sqla/models.py | 74 +++- superset/explorables/base.py | 353 ++++++++++++++++++ superset/models/helpers.py | 6 +- superset/security/manager.py | 105 ++++-- superset/thumbnails/digest.py | 7 +- superset/utils/core.py | 16 +- tests/unit_tests/common/test_time_shifts.py | 34 +- 14 files changed, 563 insertions(+), 109 deletions(-) create mode 100644 superset/explorables/base.py diff --git a/.pylintrc b/.pylintrc index 80b5ef5f668..010f0d16b47 100644 --- a/.pylintrc +++ b/.pylintrc @@ -53,7 +53,7 @@ extension-pkg-whitelist=pyarrow [MESSAGES CONTROL] disable=all -enable=disallowed-json-import,disallowed-sql-import,consider-using-transaction +enable=json-import,disallowed-sql-import,consider-using-transaction [REPORTS] diff --git a/superset/commands/chart/data/streaming_export_command.py b/superset/commands/chart/data/streaming_export_command.py index b6ec3a36698..7dec6bc1d41 100644 --- a/superset/commands/chart/data/streaming_export_command.py +++ b/superset/commands/chart/data/streaming_export_command.py @@ -68,7 +68,7 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand): query_obj = self._query_context.queries[0] sql_query = datasource.get_query_str(query_obj.to_dict()) - return sql_query, datasource.database + return sql_query, getattr(datasource, "database", None) def _get_row_limit(self) -> int | None: """ diff --git a/superset/commands/dataset/export.py b/superset/commands/dataset/export.py index 6c866b310cd..348f8152e69 100644 --- 
a/superset/commands/dataset/export.py +++ b/superset/commands/dataset/export.py @@ -77,6 +77,13 @@ class ExportDatasetsCommand(ExportModelsCommand): payload["version"] = EXPORT_VERSION payload["database_uuid"] = str(model.database.uuid) + # Always set cache_timeout from the property to ensure correct value + payload["cache_timeout"] = model.cache_timeout + + # SQLAlchemy returns column names as quoted_name objects which PyYAML cannot + # serialize. Convert all keys to regular strings to fix YAML serialization. + payload = {str(key): value for key, value in payload.items()} + file_content = yaml.safe_dump(payload, sort_keys=False) return file_content diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index c99bd2a431d..c7760ebd4ea 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -23,8 +23,8 @@ from flask_babel import _ from superset.common.chart_data import ChartDataResultType from superset.common.db_query_status import QueryStatus -from superset.connectors.sqla.models import BaseDatasource from superset.exceptions import QueryObjectValidationError, SupersetParseError +from superset.explorables.base import Explorable from superset.utils.core import ( extract_column_dtype, extract_dataframe_dtypes, @@ -38,9 +38,7 @@ if TYPE_CHECKING: from superset.common.query_object import QueryObject -def _get_datasource( - query_context: QueryContext, query_obj: QueryObject -) -> BaseDatasource: +def _get_datasource(query_context: QueryContext, query_obj: QueryObject) -> Explorable: return query_obj.datasource or query_context.datasource @@ -64,16 +62,9 @@ def _get_timegrains( query_context: QueryContext, query_obj: QueryObject, _: bool ) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) - return { - "data": [ - { - "name": grain.name, - "function": grain.function, - "duration": grain.duration, - } - for grain in datasource.database.grains() - ] - } + # Use the new get_time_grains() method 
from Explorable protocol + grains = datasource.get_time_grains() + return {"data": grains} def _get_query( @@ -158,7 +149,8 @@ def _get_samples( qry_obj_cols = [] for o in datasource.columns: if isinstance(o, dict): - qry_obj_cols.append(o.get("column_name")) + if column_name := o.get("column_name"): + qry_obj_cols.append(column_name) else: qry_obj_cols.append(o.column_name) query_obj.columns = qry_obj_cols @@ -180,7 +172,8 @@ def _get_drill_detail( qry_obj_cols = [] for o in datasource.columns: if isinstance(o, dict): - qry_obj_cols.append(o.get("column_name")) + if column_name := o.get("column_name"): + qry_obj_cols.append(column_name) else: qry_obj_cols.append(o.column_name) query_obj.columns = qry_obj_cols diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 400c4a95038..ed39aff078f 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -24,11 +24,11 @@ import pandas as pd from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.query_context_processor import QueryContextProcessor from superset.common.query_object import QueryObject +from superset.explorables.base import Explorable from superset.models.slice import Slice from superset.utils.core import GenericDataType if TYPE_CHECKING: - from superset.connectors.sqla.models import BaseDatasource from superset.models.helpers import QueryResult @@ -44,7 +44,7 @@ class QueryContext: cache_type: ClassVar[str] = "df" enforce_numerical_metrics: ClassVar[bool] = True - datasource: BaseDatasource + datasource: Explorable slice_: Slice | None = None queries: list[QueryObject] form_data: dict[str, Any] | None @@ -62,7 +62,7 @@ class QueryContext: def __init__( # pylint: disable=too-many-arguments self, *, - datasource: BaseDatasource, + datasource: Explorable, queries: list[QueryObject], slice_: Slice | None, form_data: dict[str, Any] | None, @@ -99,15 +99,24 @@ class QueryContext: return 
self._processor.get_payload(cache_query_context, force_cached) def get_cache_timeout(self) -> int | None: + """ + Get the cache timeout for this query context. + + Priority order: + 1. Custom timeout set for this specific query + 2. Chart-level timeout (if querying from a saved chart) + 3. Datasource-level timeout (explorable handles its own fallback logic) + 4. System default (None) + + Note: Each explorable is responsible for its own internal fallback chain. + For example, BaseDatasource falls back to database.cache_timeout, + while semantic layers might fall back to their layer's default. + """ if self.custom_cache_timeout is not None: return self.custom_cache_timeout if self.slice_ and self.slice_.cache_timeout is not None: return self.slice_.cache_timeout - if self.datasource.cache_timeout is not None: - return self.datasource.cache_timeout - if hasattr(self.datasource, "database"): - return self.datasource.database.cache_timeout - return None + return self.datasource.cache_timeout def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None: return self._processor.query_cache_key(query_obj, **kwargs) diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 0113886edaa..7b61cf6ea94 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any from flask import current_app @@ -26,13 +26,11 @@ from superset.common.query_object import QueryObject from superset.common.query_object_factory import QueryObjectFactory from superset.daos.chart import ChartDAO from superset.daos.datasource import DatasourceDAO +from superset.explorables.base import Explorable from superset.models.slice import Slice from superset.superset_typing import Column from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column -if TYPE_CHECKING: - from superset.connectors.sqla.models import BaseDatasource - def create_query_object_factory() -> QueryObjectFactory: return QueryObjectFactory(current_app.config, DatasourceDAO()) @@ -104,7 +102,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods cache_values=cache_values, ) - def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: + def _convert_to_model(self, datasource: DatasourceDict) -> Explorable: return DatasourceDAO.get_datasource( datasource_type=DatasourceType(datasource["type"]), database_id_or_uuid=datasource["id"], @@ -115,7 +113,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods def _process_query_object( self, - datasource: BaseDatasource, + datasource: Explorable, form_data: dict[str, Any] | None, query_object: QueryObject, ) -> QueryObject: @@ -201,7 +199,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods self, query_object: QueryObject, form_data: dict[str, Any] | None, - datasource: BaseDatasource, + datasource: Explorable, ) -> None: temporal_columns = { column["column_name"] if isinstance(column, dict) else column.column_name diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 4ccda926d0b..be448873fdd 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -29,7 +29,6 @@ 
from superset.common.db_query_status import QueryStatus from superset.common.query_actions import get_query_results from superset.common.utils.query_cache_manager import QueryCacheManager from superset.common.utils.time_range_utils import get_since_until_from_time_range -from superset.connectors.sqla.models import BaseDatasource from superset.constants import CACHE_DISABLED_TIMEOUT, CacheRegion from superset.daos.annotation_layer import AnnotationLayerDAO from superset.daos.chart import ChartDAO @@ -37,6 +36,7 @@ from superset.exceptions import ( QueryObjectValidationError, SupersetException, ) +from superset.explorables.base import Explorable from superset.extensions import cache_manager, security_manager from superset.models.helpers import QueryResult from superset.superset_typing import AdhocColumn, AdhocMetric @@ -70,7 +70,7 @@ class QueryContextProcessor: """ _query_context: QueryContext - _qc_datasource: BaseDatasource + _qc_datasource: Explorable def __init__(self, query_context: QueryContext): self._query_context = query_context diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 40d993a2f76..55d8c5fb709 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -198,7 +198,7 @@ class BaseDatasource( is_featured = Column(Boolean, default=False) # TODO deprecating filter_select_enabled = Column(Boolean, default=True) offset = Column(Integer, default=0) - cache_timeout = Column(Integer) + _cache_timeout = Column("cache_timeout", Integer) params = Column(String(1000)) perm = Column(String(1000)) schema_perm = Column(String(1000)) @@ -212,6 +212,78 @@ class BaseDatasource( extra_import_fields = ["is_managed_externally", "external_url"] + @property + def cache_timeout(self) -> int | None: + """ + Get the cache timeout for this datasource. + + Implements the Explorable protocol by handling the fallback chain: + 1. Datasource-specific timeout (if set) + 2. 
Database default timeout (if no datasource timeout) + 3. None (use system default) + + This allows each datasource to override caching, while falling back + to database-level defaults when appropriate. + """ + if self._cache_timeout is not None: + return self._cache_timeout + + # database should always be set, but that's not true for v0 import + if self.database: + return self.database.cache_timeout + + return None + + @cache_timeout.setter + def cache_timeout(self, value: int | None) -> None: + """Set the datasource-specific cache timeout.""" + self._cache_timeout = value + + def has_drill_by_columns(self, column_names: list[str]) -> bool: + """ + Check if the specified columns support drill-by operations. + + For SQL datasources, drill-by is supported on columns that are marked + as groupable in the metadata. This allows users to navigate from + aggregated views to detailed data by grouping on these dimensions. + + :param column_names: List of column names to check + :return: True if all columns support drill-by, False otherwise + """ + if not column_names: + return False + + # Get all groupable column names for this datasource + drillable_columns = { + row[0] + for row in db.session.query(TableColumn.column_name) + .filter(TableColumn.table_id == self.id) + .filter(TableColumn.groupby) + .all() + } + + # Check if all requested columns are drillable + return set(column_names).issubset(drillable_columns) + + def get_time_grains(self) -> list[dict[str, Any]]: + """ + Get available time granularities from the database. + + Implements the Explorable protocol by delegating to the database's + time grain definitions. Each database engine spec defines its own + set of supported time grains. 
+ + :return: List of time grain dictionaries with name, function, and duration + """ + return [ + { + "name": grain.name, + "function": grain.function, + "duration": grain.duration, + } + for grain in (self.database.grains() or []) + ] + @property def kind(self) -> DatasourceKind: return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL diff --git a/superset/explorables/base.py b/superset/explorables/base.py new file mode 100644 index 00000000000..d88de1a316b --- /dev/null +++ b/superset/explorables/base.py @@ -0,0 +1,353 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Base protocol for explorable data sources in Superset. + +An "explorable" is any data source that can be explored to create charts, +including SQL datasets, saved queries, and semantic layer views. +""" + +from __future__ import annotations + +from collections.abc import Hashable +from datetime import datetime +from typing import Any, Protocol, runtime_checkable + +from superset.common.query_object import QueryObject +from superset.models.helpers import QueryResult +from superset.superset_typing import QueryObjectDict + + +@runtime_checkable +class Explorable(Protocol): + """ + Protocol for objects that can be explored to create charts. 
+ + This protocol defines the minimal interface required for a data source + to be visualizable in Superset. It is implemented by: + - BaseDatasource (SQL datasets and queries) + - SemanticView (semantic layer views) + - Future: Other data source types + + The protocol focuses on the essential methods and properties needed + for query execution, caching, and security. + """ + + # ========================================================================= + # Core Query Interface + # ========================================================================= + + def get_query_result(self, query_object: QueryObject) -> QueryResult: + """ + Execute a query and return results. + + This is the primary method for data retrieval. It takes a query + object describing what data to fetch (columns, metrics, filters, time range, + etc.) and returns a QueryResult containing a pandas DataFrame with the results. + + :param query_object: QueryObject describing the query + + :return: QueryResult containing: + - df: pandas DataFrame with query results + - query: string representation of the executed query + - duration: query execution time + - status: QueryStatus (SUCCESS/FAILED) + - error_message: error details if query failed + """ + + def get_query_str(self, query_obj: QueryObjectDict) -> str: + """ + Get the query string without executing. + + Returns a string representation of the query that would be executed + for the given query object. This is used for display in the UI + and debugging. + + :param query_obj: Dictionary describing the query + :return: String representation of the query (SQL, GraphQL, etc.) + """ + + # ========================================================================= + # Identity & Metadata + # ========================================================================= + + @property + def uid(self) -> str: + """ + Unique identifier for this explorable. + + Used as part of cache keys and for tracking. 
Should be stable + across application restarts but change when the explorable's + data or structure changes. + + Format convention: "{type}_{id}" (e.g., "table_123", "semantic_view_abc") + + :return: Unique identifier string + """ + + @property + def type(self) -> str: + """ + Type discriminator for this explorable. + + Identifies the kind of data source (e.g., 'table', 'query', 'semantic_view'). + Used for routing and type-specific behavior. + + :return: Type identifier string + """ + + @property + def metrics(self) -> list[Any]: + """ + List of metric metadata objects. + + Each object should provide at minimum: + - metric_name: str - the metric's name + - expression: str - the metric's calculation expression + + Used for validation, autocomplete, and query building. + + :return: List of metric metadata objects + """ + + # TODO: rename to dimensions + @property + def columns(self) -> list[Any]: + """ + List of column metadata objects. + + Each object should provide at minimum: + - column_name: str - the column's name + - type: str - the column's data type + - is_dttm: bool - whether it's a datetime column + + Used for validation, autocomplete, and query building. + + :return: List of column metadata objects + """ + + # TODO: remove and use columns instead + @property + def column_names(self) -> list[str]: + """ + List of available column names. + + A simple list of all column names in the explorable. + Used for quick validation and filtering. + + :return: List of column name strings + """ + + # TODO: use TypedDict for return type + @property + def data(self) -> dict[str, Any]: + """ + Full metadata representation sent to the frontend. + + This property returns a dictionary containing all the metadata + needed by the Explore UI, including columns, metrics, and + other configuration. 
+ + Required keys in the returned dictionary: + - id: unique identifier (int or str) + - uid: unique string identifier + - name: display name + - type: explorable type ('table', 'query', 'semantic_view', etc.) + - columns: list of column metadata dicts (with column_name, type, etc.) + - metrics: list of metric metadata dicts (with metric_name, expression, etc.) + - database: database metadata dict (with id, backend, etc.) + + Optional keys: + - description: human-readable description + - schema: schema name (if applicable) + - catalog: catalog name (if applicable) + - cache_timeout: default cache timeout + - offset: timezone offset + - owners: list of owner IDs + - verbose_map: dict mapping column/metric names to display names + + :return: Dictionary with complete explorable metadata + """ + + # ========================================================================= + # Caching + # ========================================================================= + + @property + def cache_timeout(self) -> int | None: + """ + Default cache timeout in seconds. + + Determines how long query results should be cached. + Returns None to use the system default cache timeout. + + :return: Cache timeout in seconds, or None for system default + """ + + @property + def changed_on(self) -> datetime | None: + """ + Last modification timestamp. + + Used for cache invalidation - when this changes, cached + results for this explorable become invalid. + + :return: Datetime of last modification, or None + """ + + def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]: + """ + Additional cache key components specific to this explorable. + + Provides explorable-specific values to include in cache keys. + Used to ensure cache invalidation when the explorable's + underlying data or configuration changes in ways not captured + by uid or changed_on. 
+ + :param query_obj: The query being executed + :return: List of additional hashable values for cache key + """ + + # ========================================================================= + # Security + # ========================================================================= + + @property + def perm(self) -> str: + """ + Permission string for this explorable. + + Used by the security manager to check if a user has access + to this data source. Format depends on the explorable type + (e.g., "[database].[schema].[table]" for SQL tables). + + :return: Permission identifier string + """ + + # ========================================================================= + # Time/Date Handling + # ========================================================================= + + @property + def offset(self) -> int: + """ + Timezone offset for datetime columns. + + Used to normalize datetime values to the user's timezone. + Returns 0 for UTC, or an offset in seconds. + + :return: Timezone offset in seconds (0 for UTC) + """ + + # ========================================================================= + # Time Granularity + # ========================================================================= + + def get_time_grains(self) -> list[dict[str, Any]]: + """ + Get available time granularities for temporal grouping. + + Returns a list of time grain options that can be used for grouping + temporal data. Each time grain specifies how to bucket timestamps + (e.g., by hour, day, week, month). + + Each dictionary in the returned list should contain: + - name: str - Display name (e.g., "Hour", "Day", "Week") + - function: str - How to apply the grain (implementation-specific) + - duration: str - ISO 8601 duration string (e.g., "PT1H", "P1D", "P1W") + + For SQL datasources, the function is typically a SQL expression template + like "DATE_TRUNC('hour', {col})". For semantic layers, it might be a + semantic layer-specific identifier like "hour" or "day". 
+ + Return an empty list if time grains are not supported or applicable. + + Example return value: + ```python + [ + { + "name": "Second", + "function": "DATE_TRUNC('second', {col})", + "duration": "PT1S", + }, + { + "name": "Minute", + "function": "DATE_TRUNC('minute', {col})", + "duration": "PT1M", + }, + { + "name": "Hour", + "function": "DATE_TRUNC('hour', {col})", + "duration": "PT1H", + }, + { + "name": "Day", + "function": "DATE_TRUNC('day', {col})", + "duration": "P1D", + }, + ] + ``` + + :return: List of time grain dictionaries (empty list if not supported) + """ + + # ========================================================================= + # Drilling + # ========================================================================= + + def has_drill_by_columns(self, column_names: list[str]) -> bool: + """ + Check if the specified columns support drill-by operations. + + Drill-by allows users to navigate from aggregated views to detailed + data by grouping on specific dimensions. This method determines whether + the given columns can be used for drill-by in the current datasource. + + For SQL datasources, this typically checks if columns are marked as + groupable in the metadata. For semantic views, it checks against the + semantic layer's dimension definitions. + + :param column_names: List of column names to check + :return: True if all columns support drill-by, False otherwise + """ + + # ========================================================================= + # Optional Properties + # ========================================================================= + + @property + def is_rls_supported(self) -> bool: + """ + Whether this explorable supports Row Level Security. + + Row Level Security (RLS) allows filtering data based on user identity. + SQL-based datasources typically support this via SQL queries, while + semantic layers may handle security at the semantic layer level. 
+ + :return: True if RLS is supported, False otherwise + """ + + @property + def query_language(self) -> str | None: + """ + Query language identifier for syntax highlighting. + + Specifies the language used in queries for proper syntax highlighting + in the UI (e.g., 'sql', 'graphql', 'jsoniq'). + + :return: Language identifier string, or None if not applicable + """ diff --git a/superset/models/helpers.py b/superset/models/helpers.py index a336dfaf56c..a4fb9e3fea1 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -467,7 +467,9 @@ class ImportExportMixin(UUIDMixin): if parent_ref: parent_excludes = {c.name for c in parent_ref.local_columns} dict_rep = { - c.name: getattr(self, c.name) + # Convert c.name to str to handle SQLAlchemy's quoted_name type + # which is not YAML-serializable + str(c.name): getattr(self, c.name) for c in cls.__table__.columns # type: ignore if ( c.name in export_fields @@ -837,7 +839,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise NotImplementedError() @property - def cache_timeout(self) -> int: + def cache_timeout(self) -> int | None: raise NotImplementedError() @property diff --git a/superset/security/manager.py b/superset/security/manager.py index f173da512e2..ea739376eef 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -87,6 +87,7 @@ if TYPE_CHECKING: RowLevelSecurityFilter, SqlaTable, ) + from superset.explorables.base import Explorable from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice @@ -540,24 +541,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods or (catalog_perm and self.can_access("catalog_access", catalog_perm)) ) - def can_access_schema(self, datasource: "BaseDatasource") -> bool: + def can_access_schema(self, datasource: "BaseDatasource | Explorable") -> bool: """ Return True if the user can access the schema associated with specified 
datasource, False otherwise. + For SQL datasources: Checks database → catalog → schema hierarchy + For other explorables: Only checks all_datasources permission + :param datasource: The datasource :returns: Whether the user can access the datasource's schema """ + from superset.connectors.sqla.models import BaseDatasource - return ( - self.can_access_all_datasources() - or self.can_access_database(datasource.database) - or ( - datasource.catalog + # Admin/superuser override + if self.can_access_all_datasources(): + return True + + # SQL-specific hierarchy checks + if isinstance(datasource, BaseDatasource): + # Database-level access grants all schemas + if self.can_access_database(datasource.database): + return True + + # Catalog-level access grants all schemas in catalog + if ( + hasattr(datasource, "catalog") + and datasource.catalog and self.can_access_catalog(datasource.database, datasource.catalog) - ) - or self.can_access("schema_access", datasource.schema_perm or "") - ) + ): + return True + + # Schema-level permission (SQL only) + if self.can_access("schema_access", datasource.schema_perm or ""): + return True + + # Non-SQL explorables don't have schema hierarchy + return False def can_access_datasource(self, datasource: "BaseDatasource") -> bool: """ @@ -604,7 +624,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods self, form_data: dict[str, Any], dashboard: "Dashboard", - datasource: "BaseDatasource", + datasource: "BaseDatasource | Explorable", ) -> bool: """ Return True if the form_data is performing a supported drill by operation, @@ -612,10 +632,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :param form_data: The form_data included in the request. :param dashboard: The dashboard the user is drilling from. - :returns: Whether the user has drill byaccess. + :param datasource: The datasource being queried + :returns: Whether the user has drill by access. 
""" - from superset.connectors.sqla.models import TableColumn from superset.models.slice import Slice return bool( @@ -630,16 +650,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods and slc in dashboard.slices and slc.datasource == datasource and (dimensions := form_data.get("groupby")) - and ( - drillable_columns := { - row[0] - for row in self.session.query(TableColumn.column_name) - .filter(TableColumn.table_id == datasource.id) - .filter(TableColumn.groupby) - .all() - } - ) - and set(dimensions).issubset(drillable_columns) + and datasource.has_drill_by_columns(dimensions) ) def can_access_dashboard(self, dashboard: "Dashboard") -> bool: @@ -705,7 +716,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) @staticmethod - def get_datasource_access_error_msg(datasource: "BaseDatasource") -> str: + def get_datasource_access_error_msg( + datasource: "BaseDatasource | Explorable", + ) -> str: """ Return the error message for the denied Superset datasource. @@ -714,13 +727,13 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods """ return ( - f"This endpoint requires the datasource {datasource.id}, " + f"This endpoint requires the datasource {datasource.data['id']}, " "database or `all_datasource_access` permission" ) @staticmethod def get_datasource_access_link( # pylint: disable=unused-argument - datasource: "BaseDatasource", + datasource: "BaseDatasource | Explorable", ) -> Optional[str]: """ Return the link for the denied Superset datasource. @@ -732,7 +745,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return get_conf().get("PERMISSION_INSTRUCTIONS_LINK") def get_datasource_access_error_object( # pylint: disable=invalid-name - self, datasource: "BaseDatasource" + self, datasource: "BaseDatasource | Explorable" ) -> SupersetError: """ Return the error object for the denied Superset datasource. 
@@ -746,8 +759,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods level=ErrorLevel.WARNING, extra={ "link": self.get_datasource_access_link(datasource), - "datasource": datasource.id, - "datasource_name": datasource.name, + "datasource": datasource.data["id"], + "datasource_name": datasource.data["name"], }, ) @@ -2280,8 +2293,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods dashboard: Optional["Dashboard"] = None, chart: Optional["Slice"] = None, database: Optional["Database"] = None, - datasource: Optional["BaseDatasource"] = None, - query: Optional["Query"] = None, + datasource: Optional["BaseDatasource | Explorable"] = None, + query: Optional["Query | Explorable"] = None, query_context: Optional["QueryContext"] = None, table: Optional["Table"] = None, viz: Optional["BaseViz"] = None, @@ -2326,7 +2339,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if database and table or query: if query: - database = query.database + # Type narrow: only SQL Lab Query objects have .database attribute + if hasattr(query, "database"): + database = query.database database = cast("Database", database) default_catalog = database.get_default_catalog() @@ -2334,7 +2349,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if self.can_access_database(database): return - if query: + # Type narrow: this path only applies to SQL Lab Query objects + if query and hasattr(query, "sql") and hasattr(query, "catalog"): # Getting the default schema for a query is hard. Users can select the # schema in SQL Lab, but there's no guarantee that the query actually # will run in that schema. Each DB engine spec needs to implement the @@ -2342,8 +2358,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # If the DB engine spec doesn't implement the logic the schema is read # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy # inspector to read it. 
+ from superset.models.sql_lab import Query + default_schema = database.get_default_schema_for_query( - query, template_params + cast(Query, query), + template_params, ) tables = { table_.qualify( @@ -2455,7 +2474,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods and dashboard_.json_metadata and (json_metadata := json.loads(dashboard_.json_metadata)) and any( - target.get("datasetId") == datasource.id + target.get("datasetId") == datasource.data["id"] for fltr in json_metadata.get( "native_filter_configuration", [], @@ -2560,7 +2579,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return super().get_user_roles(user) def get_guest_rls_filters( - self, dataset: "BaseDatasource" + self, dataset: "BaseDatasource | Explorable" ) -> list[GuestTokenRlsRule]: """ Retrieves the row level security filters for the current user and the dataset, @@ -2573,11 +2592,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods rule for rule in guest_user.rls if not rule.get("dataset") - or str(rule.get("dataset")) == str(dataset.id) + or str(rule.get("dataset")) == str(dataset.data["id"]) ] return [] - def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]: + def get_rls_filters(self, table: "BaseDatasource | Explorable") -> list[SqlaQuery]: """ Retrieves the appropriate row level security filters for the current user and the passed table. 
@@ -2614,7 +2633,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods .filter(RLSFilterRoles.c.role_id.in_(user_roles)) ) filter_tables = self.session.query(RLSFilterTables.c.rls_filter_id).filter( - RLSFilterTables.c.table_id == table.id + RLSFilterTables.c.table_id == table.data["id"] ) query = ( self.session.query( @@ -2640,7 +2659,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) return query.all() - def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]: + def get_rls_sorted( + self, table: "BaseDatasource | Explorable" + ) -> list["RowLevelSecurityFilter"]: """ Retrieves a list RLS filters sorted by ID for the current user and the passed table. @@ -2652,10 +2673,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods filters.sort(key=lambda f: f.id) return filters - def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]: + def get_guest_rls_filters_str( + self, table: "BaseDatasource | Explorable" + ) -> list[str]: return [f.get("clause", "") for f in self.get_guest_rls_filters(table)] - def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]: + def get_rls_cache_key(self, datasource: "Explorable | BaseDatasource") -> list[str]: rls_clauses_with_group_key = [] if datasource.is_rls_supported: rls_clauses_with_group_key = [ diff --git a/superset/thumbnails/digest.py b/superset/thumbnails/digest.py index 31a179fd93c..1cdf8d3a642 100644 --- a/superset/thumbnails/digest.py +++ b/superset/thumbnails/digest.py @@ -61,6 +61,7 @@ def _adjust_string_with_rls( """ Add the RLS filters to the unique string based on current executor. 
""" + user = ( security_manager.find_user(executor) or security_manager.get_current_guest_user_if_guest() @@ -70,11 +71,7 @@ def _adjust_string_with_rls( stringified_rls = "" with override_user(user): for datasource in datasources: - if ( - datasource - and hasattr(datasource, "is_rls_supported") - and datasource.is_rls_supported - ): + if datasource and getattr(datasource, "is_rls_supported", False): rls_filters = datasource.get_sqla_row_level_filters() if len(rls_filters) > 0: diff --git a/superset/utils/core.py b/superset/utils/core.py index 8fd7e4750ed..60ace76fa23 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -115,6 +115,7 @@ from superset.utils.pandas import detect_datetime_format if TYPE_CHECKING: from superset.connectors.sqla.models import BaseDatasource, TableColumn + from superset.explorables.base import Explorable from superset.models.core import Database from superset.models.sql_lab import Query @@ -1656,7 +1657,9 @@ def map_sql_type_to_inferred_type(sql_type: Optional[str]) -> str: return "string" # If no match is found, return "string" as default -def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query) -> str: +def get_metric_type_from_column( + column: Any, datasource: BaseDatasource | Explorable | Query +) -> str: """ Determine the metric type from a given column in a datasource. 
@@ -1698,7 +1701,7 @@ def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query) def extract_dataframe_dtypes( df: pd.DataFrame, - datasource: BaseDatasource | Query | None = None, + datasource: BaseDatasource | Explorable | Query | None = None, ) -> list[GenericDataType]: """Serialize pandas/numpy dtypes to generic types""" @@ -1718,7 +1721,8 @@ def extract_dataframe_dtypes( if datasource: for column in datasource.columns: if isinstance(column, dict): - columns_by_name[column.get("column_name")] = column + if column_name := column.get("column_name"): + columns_by_name[column_name] = column else: columns_by_name[column.column_name] = column @@ -1768,11 +1772,13 @@ def is_test() -> bool: def get_time_filter_status( - datasource: BaseDatasource, + datasource: BaseDatasource | Explorable, applied_time_extras: dict[str, str], ) -> tuple[list[dict[str, str]], list[dict[str, str]]]: temporal_columns: set[Any] = { - col.column_name for col in datasource.columns if col.is_dttm + (col.column_name if hasattr(col, "column_name") else col.get("column_name")) + for col in datasource.columns + if (col.is_dttm if hasattr(col, "is_dttm") else col.get("is_dttm")) } applied: list[dict[str, str]] = [] rejected: list[dict[str, str]] = [] diff --git a/tests/unit_tests/common/test_time_shifts.py b/tests/unit_tests/common/test_time_shifts.py index f65b9d93eeb..08b6478c482 100644 --- a/tests/unit_tests/common/test_time_shifts.py +++ b/tests/unit_tests/common/test_time_shifts.py @@ -39,32 +39,26 @@ processor = QueryContextProcessor( ) # Bind ExploreMixin methods to datasource for testing -processor._qc_datasource.add_offset_join_column = ( - ExploreMixin.add_offset_join_column.__get__(processor._qc_datasource) +# Type annotation needed because _qc_datasource is typed as Explorable in protocol +_datasource: BaseDatasource = processor._qc_datasource # type: ignore +_datasource.add_offset_join_column = ExploreMixin.add_offset_join_column.__get__( + _datasource ) 
-processor._qc_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__( - processor._qc_datasource +_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__(_datasource) +_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__(_datasource) +_datasource._determine_join_keys = ExploreMixin._determine_join_keys.__get__( + _datasource ) -processor._qc_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__( - processor._qc_datasource -) -processor._qc_datasource._determine_join_keys = ( - ExploreMixin._determine_join_keys.__get__(processor._qc_datasource) -) -processor._qc_datasource._perform_join = ExploreMixin._perform_join.__get__( - processor._qc_datasource -) -processor._qc_datasource._apply_cleanup_logic = ( - ExploreMixin._apply_cleanup_logic.__get__(processor._qc_datasource) +_datasource._perform_join = ExploreMixin._perform_join.__get__(_datasource) +_datasource._apply_cleanup_logic = ExploreMixin._apply_cleanup_logic.__get__( + _datasource ) # Static methods don't need binding - assign directly -processor._qc_datasource.generate_join_column = ExploreMixin.generate_join_column -processor._qc_datasource.is_valid_date_range_static = ( - ExploreMixin.is_valid_date_range_static -) +_datasource.generate_join_column = ExploreMixin.generate_join_column +_datasource.is_valid_date_range_static = ExploreMixin.is_valid_date_range_static # Convenience reference for backward compatibility in tests -query_context_processor = processor._qc_datasource +query_context_processor = _datasource @fixture