feat: Explorable protocol (#36245)

This commit is contained in:
Beto Dealmeida
2025-12-04 13:18:34 -05:00
committed by GitHub
parent c36ac53445
commit 16e6452b8c
14 changed files with 563 additions and 109 deletions

View File

@@ -53,7 +53,7 @@ extension-pkg-whitelist=pyarrow
[MESSAGES CONTROL]
disable=all
enable=disallowed-json-import,disallowed-sql-import,consider-using-transaction
enable=json-import,disallowed-sql-import,consider-using-transaction
[REPORTS]

View File

@@ -68,7 +68,7 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
query_obj = self._query_context.queries[0]
sql_query = datasource.get_query_str(query_obj.to_dict())
return sql_query, datasource.database
return sql_query, getattr(datasource, "database", None)
def _get_row_limit(self) -> int | None:
"""

View File

@@ -77,6 +77,13 @@ class ExportDatasetsCommand(ExportModelsCommand):
payload["version"] = EXPORT_VERSION
payload["database_uuid"] = str(model.database.uuid)
# Always set cache_timeout from the property to ensure correct value
payload["cache_timeout"] = model.cache_timeout
# SQLAlchemy returns column names as quoted_name objects which PyYAML cannot
# serialize. Convert all keys to regular strings to fix YAML serialization.
payload = {str(key): value for key, value in payload.items()}
file_content = yaml.safe_dump(payload, sort_keys=False)
return file_content

View File

@@ -23,8 +23,8 @@ from flask_babel import _
from superset.common.chart_data import ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.connectors.sqla.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError, SupersetParseError
from superset.explorables.base import Explorable
from superset.utils.core import (
extract_column_dtype,
extract_dataframe_dtypes,
@@ -38,9 +38,7 @@ if TYPE_CHECKING:
from superset.common.query_object import QueryObject
def _get_datasource(
query_context: QueryContext, query_obj: QueryObject
) -> BaseDatasource:
def _get_datasource(query_context: QueryContext, query_obj: QueryObject) -> Explorable:
return query_obj.datasource or query_context.datasource
@@ -64,16 +62,9 @@ def _get_timegrains(
query_context: QueryContext, query_obj: QueryObject, _: bool
) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
return {
"data": [
{
"name": grain.name,
"function": grain.function,
"duration": grain.duration,
}
for grain in datasource.database.grains()
]
}
# Use the new get_time_grains() method from Explorable protocol
grains = datasource.get_time_grains()
return {"data": grains}
def _get_query(
@@ -158,7 +149,8 @@ def _get_samples(
qry_obj_cols = []
for o in datasource.columns:
if isinstance(o, dict):
qry_obj_cols.append(o.get("column_name"))
if column_name := o.get("column_name"):
qry_obj_cols.append(column_name)
else:
qry_obj_cols.append(o.column_name)
query_obj.columns = qry_obj_cols
@@ -180,7 +172,8 @@ def _get_drill_detail(
qry_obj_cols = []
for o in datasource.columns:
if isinstance(o, dict):
qry_obj_cols.append(o.get("column_name"))
if column_name := o.get("column_name"):
qry_obj_cols.append(column_name)
else:
qry_obj_cols.append(o.column_name)
query_obj.columns = qry_obj_cols

View File

@@ -24,11 +24,11 @@ import pandas as pd
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context_processor import QueryContextProcessor
from superset.common.query_object import QueryObject
from superset.explorables.base import Explorable
from superset.models.slice import Slice
from superset.utils.core import GenericDataType
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource
from superset.models.helpers import QueryResult
@@ -44,7 +44,7 @@ class QueryContext:
cache_type: ClassVar[str] = "df"
enforce_numerical_metrics: ClassVar[bool] = True
datasource: BaseDatasource
datasource: Explorable
slice_: Slice | None = None
queries: list[QueryObject]
form_data: dict[str, Any] | None
@@ -62,7 +62,7 @@ class QueryContext:
def __init__( # pylint: disable=too-many-arguments
self,
*,
datasource: BaseDatasource,
datasource: Explorable,
queries: list[QueryObject],
slice_: Slice | None,
form_data: dict[str, Any] | None,
@@ -99,15 +99,24 @@ class QueryContext:
return self._processor.get_payload(cache_query_context, force_cached)
def get_cache_timeout(self) -> int | None:
"""
Get the cache timeout for this query context.
Priority order:
1. Custom timeout set for this specific query
2. Chart-level timeout (if querying from a saved chart)
3. Datasource-level timeout (explorable handles its own fallback logic)
4. System default (None)
Note: Each explorable is responsible for its own internal fallback chain.
For example, BaseDatasource falls back to database.cache_timeout,
while semantic layers might fall back to their layer's default.
"""
if self.custom_cache_timeout is not None:
return self.custom_cache_timeout
if self.slice_ and self.slice_.cache_timeout is not None:
return self.slice_.cache_timeout
if self.datasource.cache_timeout is not None:
return self.datasource.cache_timeout
if hasattr(self.datasource, "database"):
return self.datasource.database.cache_timeout
return None
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
return self._processor.query_cache_key(query_obj, **kwargs)

View File

@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from typing import Any
from flask import current_app
@@ -26,13 +26,11 @@ from superset.common.query_object import QueryObject
from superset.common.query_object_factory import QueryObjectFactory
from superset.daos.chart import ChartDAO
from superset.daos.datasource import DatasourceDAO
from superset.explorables.base import Explorable
from superset.models.slice import Slice
from superset.superset_typing import Column
from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource
def create_query_object_factory() -> QueryObjectFactory:
return QueryObjectFactory(current_app.config, DatasourceDAO())
@@ -104,7 +102,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
cache_values=cache_values,
)
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
def _convert_to_model(self, datasource: DatasourceDict) -> Explorable:
return DatasourceDAO.get_datasource(
datasource_type=DatasourceType(datasource["type"]),
database_id_or_uuid=datasource["id"],
@@ -115,7 +113,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
def _process_query_object(
self,
datasource: BaseDatasource,
datasource: Explorable,
form_data: dict[str, Any] | None,
query_object: QueryObject,
) -> QueryObject:
@@ -201,7 +199,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
self,
query_object: QueryObject,
form_data: dict[str, Any] | None,
datasource: BaseDatasource,
datasource: Explorable,
) -> None:
temporal_columns = {
column["column_name"] if isinstance(column, dict) else column.column_name

View File

@@ -29,7 +29,6 @@ from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import CACHE_DISABLED_TIMEOUT, CacheRegion
from superset.daos.annotation_layer import AnnotationLayerDAO
from superset.daos.chart import ChartDAO
@@ -37,6 +36,7 @@ from superset.exceptions import (
QueryObjectValidationError,
SupersetException,
)
from superset.explorables.base import Explorable
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.superset_typing import AdhocColumn, AdhocMetric
@@ -70,7 +70,7 @@ class QueryContextProcessor:
"""
_query_context: QueryContext
_qc_datasource: BaseDatasource
_qc_datasource: Explorable
def __init__(self, query_context: QueryContext):
self._query_context = query_context

View File

@@ -198,7 +198,7 @@ class BaseDatasource(
is_featured = Column(Boolean, default=False) # TODO deprecating
filter_select_enabled = Column(Boolean, default=True)
offset = Column(Integer, default=0)
cache_timeout = Column(Integer)
_cache_timeout = Column("cache_timeout", Integer)
params = Column(String(1000))
perm = Column(String(1000))
schema_perm = Column(String(1000))
@@ -212,6 +212,78 @@ class BaseDatasource(
extra_import_fields = ["is_managed_externally", "external_url"]
@property
def cache_timeout(self) -> int | None:
"""
Get the cache timeout for this datasource.
Implements the Explorable protocol by handling the fallback chain:
1. Datasource-specific timeout (if set)
2. Database default timeout (if no datasource timeout)
3. None (use system default)
This allows each datasource to override caching, while falling back
to database-level defaults when appropriate.
"""
if self._cache_timeout is not None:
return self._cache_timeout
# database should always be set, but that's not true for v0 import
if self.database:
return self.database.cache_timeout
return None
@cache_timeout.setter
def cache_timeout(self, value: int | None) -> None:
"""Set the datasource-specific cache timeout."""
self._cache_timeout = value
def has_drill_by_columns(self, column_names: list[str]) -> bool:
    """
    Return True when every column in *column_names* supports drill-by.

    Drill-by lets users navigate from aggregated views to detailed data by
    grouping on a dimension; for SQL datasources that means the column is
    flagged as groupable in the metadata.

    :param column_names: column names to validate
    :return: True only if the list is non-empty and all names are groupable
    """
    if not column_names:
        return False

    # Collect the groupable column names registered for this datasource.
    rows = (
        db.session.query(TableColumn.column_name)
        .filter(TableColumn.table_id == self.id)
        .filter(TableColumn.groupby)
        .all()
    )
    groupable = {name for (name,) in rows}

    # Every requested column must appear in the groupable set.
    return all(name in groupable for name in column_names)
def get_time_grains(self) -> list[dict[str, Any]]:
    """
    Return the time granularities supported by the underlying database.

    Implements the Explorable protocol by translating the database engine
    spec's grain definitions into plain dictionaries.

    :return: list of dicts with ``name``, ``function`` and ``duration`` keys
    """
    grains = self.database.grains() or []
    result: list[dict[str, Any]] = []
    for grain in grains:
        result.append(
            {
                "name": grain.name,
                "function": grain.function,
                "duration": grain.duration,
            }
        )
    return result
@property
def kind(self) -> DatasourceKind:
return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL

View File

@@ -0,0 +1,353 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Base protocol for explorable data sources in Superset.
An "explorable" is any data source that can be explored to create charts,
including SQL datasets, saved queries, and semantic layer views.
"""
from __future__ import annotations
from collections.abc import Hashable
from datetime import datetime
from typing import Any, Protocol, runtime_checkable
from superset.common.query_object import QueryObject
from superset.models.helpers import QueryResult
from superset.superset_typing import QueryObjectDict
@runtime_checkable
class Explorable(Protocol):
    """
    Protocol for objects that can be explored to create charts.

    This protocol defines the minimal interface required for a data source
    to be visualizable in Superset. It is implemented by:

    - BaseDatasource (SQL datasets and queries)
    - SemanticView (semantic layer views)
    - Future: Other data source types

    The protocol focuses on the essential methods and properties needed
    for query execution, caching, and security.
    """

    # =========================================================================
    # Core Query Interface
    # =========================================================================

    def get_query_result(self, query_object: QueryObject) -> QueryResult:
        """
        Execute a query and return results.

        This is the primary method for data retrieval. It takes a query
        object describing what data to fetch (columns, metrics, filters,
        time range, etc.) and returns a QueryResult containing a pandas
        DataFrame with the results.

        :param query_object: QueryObject describing the query
        :return: QueryResult containing:

            - df: pandas DataFrame with query results
            - query: string representation of the executed query
            - duration: query execution time
            - status: QueryStatus (SUCCESS/FAILED)
            - error_message: error details if query failed
        """

    def get_query_str(self, query_obj: QueryObjectDict) -> str:
        """
        Get the query string without executing.

        Returns a string representation of the query that would be executed
        for the given query object. This is used for display in the UI
        and debugging.

        :param query_obj: Dictionary describing the query
        :return: String representation of the query (SQL, GraphQL, etc.)
        """

    # =========================================================================
    # Identity & Metadata
    # =========================================================================

    @property
    def uid(self) -> str:
        """
        Unique identifier for this explorable.

        Used as part of cache keys and for tracking. Should be stable
        across application restarts but change when the explorable's
        data or structure changes.

        Format convention: "{type}_{id}" (e.g., "table_123",
        "semantic_view_abc")

        :return: Unique identifier string
        """

    @property
    def type(self) -> str:
        """
        Type discriminator for this explorable.

        Identifies the kind of data source (e.g., 'table', 'query',
        'semantic_view'). Used for routing and type-specific behavior.

        :return: Type identifier string
        """

    @property
    def metrics(self) -> list[Any]:
        """
        List of metric metadata objects.

        Each object should provide at minimum:

        - metric_name: str - the metric's name
        - expression: str - the metric's calculation expression

        Used for validation, autocomplete, and query building.

        :return: List of metric metadata objects
        """

    # TODO: rename to dimensions
    @property
    def columns(self) -> list[Any]:
        """
        List of column metadata objects.

        Each object should provide at minimum:

        - column_name: str - the column's name
        - type: str - the column's data type
        - is_dttm: bool - whether it's a datetime column

        Used for validation, autocomplete, and query building.

        :return: List of column metadata objects
        """

    # TODO: remove and use columns instead
    @property
    def column_names(self) -> list[str]:
        """
        List of available column names.

        A simple list of all column names in the explorable.
        Used for quick validation and filtering.

        :return: List of column name strings
        """

    # TODO: use TypedDict for return type
    @property
    def data(self) -> dict[str, Any]:
        """
        Full metadata representation sent to the frontend.

        This property returns a dictionary containing all the metadata
        needed by the Explore UI, including columns, metrics, and
        other configuration.

        Required keys in the returned dictionary:

        - id: unique identifier (int or str)
        - uid: unique string identifier
        - name: display name
        - type: explorable type ('table', 'query', 'semantic_view', etc.)
        - columns: list of column metadata dicts (with column_name, type, etc.)
        - metrics: list of metric metadata dicts (with metric_name,
          expression, etc.)
        - database: database metadata dict (with id, backend, etc.)

        Optional keys:

        - description: human-readable description
        - schema: schema name (if applicable)
        - catalog: catalog name (if applicable)
        - cache_timeout: default cache timeout
        - offset: timezone offset
        - owners: list of owner IDs
        - verbose_map: dict mapping column/metric names to display names

        :return: Dictionary with complete explorable metadata
        """

    # =========================================================================
    # Caching
    # =========================================================================

    @property
    def cache_timeout(self) -> int | None:
        """
        Default cache timeout in seconds.

        Determines how long query results should be cached.
        Returns None to use the system default cache timeout.

        :return: Cache timeout in seconds, or None for system default
        """

    @property
    def changed_on(self) -> datetime | None:
        """
        Last modification timestamp.

        Used for cache invalidation - when this changes, cached
        results for this explorable become invalid.

        :return: Datetime of last modification, or None
        """

    def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
        """
        Additional cache key components specific to this explorable.

        Provides explorable-specific values to include in cache keys.
        Used to ensure cache invalidation when the explorable's
        underlying data or configuration changes in ways not captured
        by uid or changed_on.

        :param query_obj: The query being executed
        :return: List of additional hashable values for cache key
        """

    # =========================================================================
    # Security
    # =========================================================================

    @property
    def perm(self) -> str:
        """
        Permission string for this explorable.

        Used by the security manager to check if a user has access
        to this data source. Format depends on the explorable type
        (e.g., "[database].[schema].[table]" for SQL tables).

        :return: Permission identifier string
        """

    # =========================================================================
    # Time/Date Handling
    # =========================================================================

    @property
    def offset(self) -> int:
        """
        Timezone offset for datetime columns.

        Used to normalize datetime values to the user's timezone.
        Returns 0 for UTC, or an offset in seconds.

        :return: Timezone offset in seconds (0 for UTC)
        """

    # =========================================================================
    # Time Granularity
    # =========================================================================

    def get_time_grains(self) -> list[dict[str, Any]]:
        """
        Get available time granularities for temporal grouping.

        Returns a list of time grain options that can be used for grouping
        temporal data. Each time grain specifies how to bucket timestamps
        (e.g., by hour, day, week, month).

        Each dictionary in the returned list should contain:

        - name: str - Display name (e.g., "Hour", "Day", "Week")
        - function: str - How to apply the grain (implementation-specific)
        - duration: str - ISO 8601 duration string (e.g., "PT1H", "P1D", "P1W")

        For SQL datasources, the function is typically a SQL expression
        template like "DATE_TRUNC('hour', {col})". For semantic layers, it
        might be a semantic layer-specific identifier like "hour" or "day".

        Return an empty list if time grains are not supported or applicable.

        Example return value:

        ```python
        [
            {
                "name": "Second",
                "function": "DATE_TRUNC('second', {col})",
                "duration": "PT1S",
            },
            {
                "name": "Minute",
                "function": "DATE_TRUNC('minute', {col})",
                "duration": "PT1M",
            },
            {
                "name": "Hour",
                "function": "DATE_TRUNC('hour', {col})",
                "duration": "PT1H",
            },
            {
                "name": "Day",
                "function": "DATE_TRUNC('day', {col})",
                "duration": "P1D",
            },
        ]
        ```

        :return: List of time grain dictionaries (empty list if not supported)
        """

    # =========================================================================
    # Drilling
    # =========================================================================

    def has_drill_by_columns(self, column_names: list[str]) -> bool:
        """
        Check if the specified columns support drill-by operations.

        Drill-by allows users to navigate from aggregated views to detailed
        data by grouping on specific dimensions. This method determines
        whether the given columns can be used for drill-by in the current
        datasource.

        For SQL datasources, this typically checks if columns are marked as
        groupable in the metadata. For semantic views, it checks against the
        semantic layer's dimension definitions.

        :param column_names: List of column names to check
        :return: True if all columns support drill-by, False otherwise
        """

    # =========================================================================
    # Optional Properties
    # =========================================================================

    @property
    def is_rls_supported(self) -> bool:
        """
        Whether this explorable supports Row Level Security.

        Row Level Security (RLS) allows filtering data based on user identity.
        SQL-based datasources typically support this via SQL queries, while
        semantic layers may handle security at the semantic layer level.

        :return: True if RLS is supported, False otherwise
        """

    @property
    def query_language(self) -> str | None:
        """
        Query language identifier for syntax highlighting.

        Specifies the language used in queries for proper syntax highlighting
        in the UI (e.g., 'sql', 'graphql', 'jsoniq').

        :return: Language identifier string, or None if not applicable
        """

View File

@@ -467,7 +467,9 @@ class ImportExportMixin(UUIDMixin):
if parent_ref:
parent_excludes = {c.name for c in parent_ref.local_columns}
dict_rep = {
c.name: getattr(self, c.name)
# Convert c.name to str to handle SQLAlchemy's quoted_name type
# which is not YAML-serializable
str(c.name): getattr(self, c.name)
for c in cls.__table__.columns # type: ignore
if (
c.name in export_fields
@@ -837,7 +839,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise NotImplementedError()
@property
def cache_timeout(self) -> int:
def cache_timeout(self) -> int | None:
raise NotImplementedError()
@property

View File

@@ -87,6 +87,7 @@ if TYPE_CHECKING:
RowLevelSecurityFilter,
SqlaTable,
)
from superset.explorables.base import Explorable
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@@ -540,24 +541,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
or (catalog_perm and self.can_access("catalog_access", catalog_perm))
)
def can_access_schema(self, datasource: "BaseDatasource") -> bool:
def can_access_schema(self, datasource: "BaseDatasource | Explorable") -> bool:
"""
Return True if the user can access the schema associated with specified
datasource, False otherwise.
For SQL datasources: Checks database → catalog → schema hierarchy
For other explorables: Only checks all_datasources permission
:param datasource: The datasource
:returns: Whether the user can access the datasource's schema
"""
from superset.connectors.sqla.models import BaseDatasource
return (
self.can_access_all_datasources()
or self.can_access_database(datasource.database)
or (
datasource.catalog
# Admin/superuser override
if self.can_access_all_datasources():
return True
# SQL-specific hierarchy checks
if isinstance(datasource, BaseDatasource):
# Database-level access grants all schemas
if self.can_access_database(datasource.database):
return True
# Catalog-level access grants all schemas in catalog
if (
hasattr(datasource, "catalog")
and datasource.catalog
and self.can_access_catalog(datasource.database, datasource.catalog)
)
or self.can_access("schema_access", datasource.schema_perm or "")
)
):
return True
# Schema-level permission (SQL only)
if self.can_access("schema_access", datasource.schema_perm or ""):
return True
# Non-SQL explorables don't have schema hierarchy
return False
def can_access_datasource(self, datasource: "BaseDatasource") -> bool:
"""
@@ -604,7 +624,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self,
form_data: dict[str, Any],
dashboard: "Dashboard",
datasource: "BaseDatasource",
datasource: "BaseDatasource | Explorable",
) -> bool:
"""
Return True if the form_data is performing a supported drill by operation,
@@ -612,10 +632,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
:param form_data: The form_data included in the request.
:param dashboard: The dashboard the user is drilling from.
:param datasource: The datasource being queried
:returns: Whether the user has drill by access.
"""
from superset.connectors.sqla.models import TableColumn
from superset.models.slice import Slice
return bool(
@@ -630,16 +650,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
and slc in dashboard.slices
and slc.datasource == datasource
and (dimensions := form_data.get("groupby"))
and (
drillable_columns := {
row[0]
for row in self.session.query(TableColumn.column_name)
.filter(TableColumn.table_id == datasource.id)
.filter(TableColumn.groupby)
.all()
}
)
and set(dimensions).issubset(drillable_columns)
and datasource.has_drill_by_columns(dimensions)
)
def can_access_dashboard(self, dashboard: "Dashboard") -> bool:
@@ -705,7 +716,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
@staticmethod
def get_datasource_access_error_msg(datasource: "BaseDatasource") -> str:
def get_datasource_access_error_msg(
datasource: "BaseDatasource | Explorable",
) -> str:
"""
Return the error message for the denied Superset datasource.
@@ -714,13 +727,13 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"""
return (
f"This endpoint requires the datasource {datasource.id}, "
f"This endpoint requires the datasource {datasource.data['id']}, "
"database or `all_datasource_access` permission"
)
@staticmethod
def get_datasource_access_link( # pylint: disable=unused-argument
datasource: "BaseDatasource",
datasource: "BaseDatasource | Explorable",
) -> Optional[str]:
"""
Return the link for the denied Superset datasource.
@@ -732,7 +745,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return get_conf().get("PERMISSION_INSTRUCTIONS_LINK")
def get_datasource_access_error_object( # pylint: disable=invalid-name
self, datasource: "BaseDatasource"
self, datasource: "BaseDatasource | Explorable"
) -> SupersetError:
"""
Return the error object for the denied Superset datasource.
@@ -746,8 +759,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
level=ErrorLevel.WARNING,
extra={
"link": self.get_datasource_access_link(datasource),
"datasource": datasource.id,
"datasource_name": datasource.name,
"datasource": datasource.data["id"],
"datasource_name": datasource.data["name"],
},
)
@@ -2280,8 +2293,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
dashboard: Optional["Dashboard"] = None,
chart: Optional["Slice"] = None,
database: Optional["Database"] = None,
datasource: Optional["BaseDatasource"] = None,
query: Optional["Query"] = None,
datasource: Optional["BaseDatasource | Explorable"] = None,
query: Optional["Query | Explorable"] = None,
query_context: Optional["QueryContext"] = None,
table: Optional["Table"] = None,
viz: Optional["BaseViz"] = None,
@@ -2326,6 +2339,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if database and table or query:
if query:
# Type narrow: only SQL Lab Query objects have .database attribute
if hasattr(query, "database"):
database = query.database
database = cast("Database", database)
@@ -2334,7 +2349,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if self.can_access_database(database):
return
if query:
# Type narrow: this path only applies to SQL Lab Query objects
if query and hasattr(query, "sql") and hasattr(query, "catalog"):
# Getting the default schema for a query is hard. Users can select the
# schema in SQL Lab, but there's no guarantee that the query actually
# will run in that schema. Each DB engine spec needs to implement the
@@ -2342,8 +2358,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# If the DB engine spec doesn't implement the logic the schema is read
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
# inspector to read it.
from superset.models.sql_lab import Query
default_schema = database.get_default_schema_for_query(
query, template_params
cast(Query, query),
template_params,
)
tables = {
table_.qualify(
@@ -2455,7 +2474,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
and dashboard_.json_metadata
and (json_metadata := json.loads(dashboard_.json_metadata))
and any(
target.get("datasetId") == datasource.id
target.get("datasetId") == datasource.data["id"]
for fltr in json_metadata.get(
"native_filter_configuration",
[],
@@ -2560,7 +2579,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return super().get_user_roles(user)
def get_guest_rls_filters(
self, dataset: "BaseDatasource"
self, dataset: "BaseDatasource | Explorable"
) -> list[GuestTokenRlsRule]:
"""
Retrieves the row level security filters for the current user and the dataset,
@@ -2573,11 +2592,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
rule
for rule in guest_user.rls
if not rule.get("dataset")
or str(rule.get("dataset")) == str(dataset.id)
or str(rule.get("dataset")) == str(dataset.data["id"])
]
return []
def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]:
def get_rls_filters(self, table: "BaseDatasource | Explorable") -> list[SqlaQuery]:
"""
Retrieves the appropriate row level security filters for the current user and
the passed table.
@@ -2614,7 +2633,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
)
filter_tables = self.session.query(RLSFilterTables.c.rls_filter_id).filter(
RLSFilterTables.c.table_id == table.id
RLSFilterTables.c.table_id == table.data["id"]
)
query = (
self.session.query(
@@ -2640,7 +2659,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
return query.all()
def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]:
def get_rls_sorted(
self, table: "BaseDatasource | Explorable"
) -> list["RowLevelSecurityFilter"]:
"""
Retrieves a list RLS filters sorted by ID for
the current user and the passed table.
@@ -2652,10 +2673,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
filters.sort(key=lambda f: f.id)
return filters
def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
def get_guest_rls_filters_str(
self, table: "BaseDatasource | Explorable"
) -> list[str]:
return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]
def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
def get_rls_cache_key(self, datasource: "Explorable | BaseDatasource") -> list[str]:
rls_clauses_with_group_key = []
if datasource.is_rls_supported:
rls_clauses_with_group_key = [

View File

@@ -61,6 +61,7 @@ def _adjust_string_with_rls(
"""
Add the RLS filters to the unique string based on current executor.
"""
user = (
security_manager.find_user(executor)
or security_manager.get_current_guest_user_if_guest()
@@ -70,11 +71,7 @@ def _adjust_string_with_rls(
stringified_rls = ""
with override_user(user):
for datasource in datasources:
if (
datasource
and hasattr(datasource, "is_rls_supported")
and datasource.is_rls_supported
):
if datasource and getattr(datasource, "is_rls_supported", False):
rls_filters = datasource.get_sqla_row_level_filters()
if len(rls_filters) > 0:

View File

@@ -115,6 +115,7 @@ from superset.utils.pandas import detect_datetime_format
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource, TableColumn
from superset.explorables.base import Explorable
from superset.models.core import Database
from superset.models.sql_lab import Query
@@ -1656,7 +1657,9 @@ def map_sql_type_to_inferred_type(sql_type: Optional[str]) -> str:
return "string" # If no match is found, return "string" as default
def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query) -> str:
def get_metric_type_from_column(
column: Any, datasource: BaseDatasource | Explorable | Query
) -> str:
"""
Determine the metric type from a given column in a datasource.
@@ -1698,7 +1701,7 @@ def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query)
def extract_dataframe_dtypes(
df: pd.DataFrame,
datasource: BaseDatasource | Query | None = None,
datasource: BaseDatasource | Explorable | Query | None = None,
) -> list[GenericDataType]:
"""Serialize pandas/numpy dtypes to generic types"""
@@ -1718,7 +1721,8 @@ def extract_dataframe_dtypes(
if datasource:
for column in datasource.columns:
if isinstance(column, dict):
columns_by_name[column.get("column_name")] = column
if column_name := column.get("column_name"):
columns_by_name[column_name] = column
else:
columns_by_name[column.column_name] = column
@@ -1768,11 +1772,13 @@ def is_test() -> bool:
def get_time_filter_status(
datasource: BaseDatasource,
datasource: BaseDatasource | Explorable,
applied_time_extras: dict[str, str],
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
temporal_columns: set[Any] = {
col.column_name for col in datasource.columns if col.is_dttm
(col.column_name if hasattr(col, "column_name") else col.get("column_name"))
for col in datasource.columns
if (col.is_dttm if hasattr(col, "is_dttm") else col.get("is_dttm"))
}
applied: list[dict[str, str]] = []
rejected: list[dict[str, str]] = []

View File

@@ -39,32 +39,26 @@ processor = QueryContextProcessor(
)
# Bind ExploreMixin methods to datasource for testing
processor._qc_datasource.add_offset_join_column = (
ExploreMixin.add_offset_join_column.__get__(processor._qc_datasource)
# Type annotation needed because _qc_datasource is typed as Explorable in protocol
_datasource: BaseDatasource = processor._qc_datasource # type: ignore
_datasource.add_offset_join_column = ExploreMixin.add_offset_join_column.__get__(
_datasource
)
processor._qc_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__(
processor._qc_datasource
_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__(_datasource)
_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__(_datasource)
_datasource._determine_join_keys = ExploreMixin._determine_join_keys.__get__(
_datasource
)
processor._qc_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__(
processor._qc_datasource
)
processor._qc_datasource._determine_join_keys = (
ExploreMixin._determine_join_keys.__get__(processor._qc_datasource)
)
processor._qc_datasource._perform_join = ExploreMixin._perform_join.__get__(
processor._qc_datasource
)
processor._qc_datasource._apply_cleanup_logic = (
ExploreMixin._apply_cleanup_logic.__get__(processor._qc_datasource)
_datasource._perform_join = ExploreMixin._perform_join.__get__(_datasource)
_datasource._apply_cleanup_logic = ExploreMixin._apply_cleanup_logic.__get__(
_datasource
)
# Static methods don't need binding - assign directly
processor._qc_datasource.generate_join_column = ExploreMixin.generate_join_column
processor._qc_datasource.is_valid_date_range_static = (
ExploreMixin.is_valid_date_range_static
)
_datasource.generate_join_column = ExploreMixin.generate_join_column
_datasource.is_valid_date_range_static = ExploreMixin.is_valid_date_range_static
# Convenience reference for backward compatibility in tests
query_context_processor = processor._qc_datasource
query_context_processor = _datasource
@fixture