Compare commits

...

20 Commits

Author SHA1 Message Date
Beto Dealmeida  b9ab0ced77  Fix order  2026-01-26 18:47:44 -05:00
Beto Dealmeida  bfbb68c3c8  WIP  2025-12-15 10:26:52 -05:00
Beto Dealmeida  b437421a8e  Fix filter  2025-12-11 15:06:13 -05:00
Beto Dealmeida  e253bd2fb3  Fix mapping  2025-12-11 10:58:17 -05:00
Beto Dealmeida  bfb7048e42  Frontend  2025-12-11 10:32:06 -05:00
Beto Dealmeida  2833b69ca0  WIP  2025-12-10 16:26:45 -05:00
Beto Dealmeida  6e17714a19  WIP  2025-12-10 13:54:19 -05:00
Beto Dealmeida  8a0aaa42ec  feat: semantic layer implementation (Snowflake)  2025-12-05 15:32:01 -05:00
Beto Dealmeida  af479a9d99  More cleanup  2025-12-03 10:30:38 -05:00
Beto Dealmeida  77f60f42e6  More cleanup  2025-12-02 12:13:49 -05:00
Beto Dealmeida  f0121a166e  chore: improve types  2025-12-01 17:33:11 -05:00
Beto Dealmeida  0c4b0cb9b9  Cleanup code  2025-12-01 11:16:06 -05:00
Beto Dealmeida  a36bbf8ffd  Small fixes  2025-11-26 16:15:08 -05:00
Beto Dealmeida  99525c1ce9  Fix errors  2025-11-26 13:21:48 -05:00
Beto Dealmeida  889e9bbade  Fix lint/tests  2025-11-25 16:59:17 -05:00
Beto Dealmeida  b809a990ee  Fix pylint  2025-11-25 11:50:04 -05:00
Beto Dealmeida  9c7fcbf548  Fix tests  2025-11-25 11:48:46 -05:00
Beto Dealmeida  046aabee73  Fix lint  2025-11-24 16:51:54 -05:00
Beto Dealmeida  b672c7b853  Remove AI artifacts  2025-11-24 14:41:30 -05:00
Beto Dealmeida  ea33d797a7  feat: explorable protocol  2025-11-24 14:39:25 -05:00
41 changed files with 4392 additions and 200 deletions

View File

@@ -53,7 +53,7 @@ extension-pkg-whitelist=pyarrow
[MESSAGES CONTROL]
disable=all
enable=disallowed-json-import,disallowed-sql-import,consider-using-transaction
enable=json-import,disallowed-sql-import,consider-using-transaction
[REPORTS]

View File

@@ -30,7 +30,9 @@ with open(PACKAGE_JSON) as package_file:
def get_git_sha() -> str:
try:
output = subprocess.check_output(["git", "rev-parse", "HEAD"]) # noqa: S603, S607
output = subprocess.check_output(
["git", "rev-parse", "HEAD"]
) # noqa: S603, S607
return output.decode().strip()
except Exception: # pylint: disable=broad-except
return ""
@@ -58,6 +60,9 @@ setup(
include_package_data=True,
zip_safe=False,
entry_points={
"superset.semantic_layers": [
"snowflake = superset.semantic_layers.snowflake:SnowflakeSemanticLayer"
],
"console_scripts": ["superset=superset.cli.main:superset"],
# the `postgres` and `postgres+psycopg2://` schemes were removed in SQLAlchemy 1.4 # noqa: E501
# add an alias here to prevent breaking existing databases
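The `superset.semantic_layers` entry point group registered above lets third-party packages expose semantic layer implementations without modifying Superset itself. A minimal sketch of how such plugins could be discovered at runtime on Python 3.10+, using standard `importlib.metadata`; the helper name below is hypothetical and only the group name comes from the diff:

```python
from importlib.metadata import entry_points


def load_semantic_layers() -> dict[str, type]:
    """Map entry point names (e.g. "snowflake") to their registered classes."""
    return {
        ep.name: ep.load()
        for ep in entry_points(group="superset.semantic_layers")
    }
```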

View File

@@ -19,16 +19,27 @@
import { DatasourceType } from './types/Datasource';
const DATASOURCE_TYPE_MAP: Record<string, DatasourceType> = {
table: DatasourceType.Table,
query: DatasourceType.Query,
dataset: DatasourceType.Dataset,
sl_table: DatasourceType.SlTable,
saved_query: DatasourceType.SavedQuery,
semantic_view: DatasourceType.SemanticView,
};
export default class DatasourceKey {
readonly id: number;
readonly id: number | string;
readonly type: DatasourceType;
constructor(key: string) {
const [idStr, typeStr] = key.split('__');
this.id = parseInt(idStr, 10);
this.type = DatasourceType.Table; // default to SqlaTable model
this.type = typeStr === 'query' ? DatasourceType.Query : this.type;
// Only parse as integer if the entire string is numeric
// (parseInt would incorrectly parse "85d3139f..." as 85)
const isNumeric = /^\d+$/.test(idStr);
this.id = isNumeric ? parseInt(idStr, 10) : idStr;
this.type = DATASOURCE_TYPE_MAP[typeStr] ?? DatasourceType.Table;
}
public toString() {

View File

@@ -26,6 +26,7 @@ export enum DatasourceType {
Dataset = 'dataset',
SlTable = 'sl_table',
SavedQuery = 'saved_query',
SemanticView = 'semantic_view',
}
export interface Currency {
@@ -37,7 +38,7 @@ export interface Currency {
* Datasource metadata.
*/
export interface Datasource {
id: number;
id: number | string;
name: string;
type: DatasourceType;
columns: Column[];

View File

@@ -156,7 +156,7 @@ export interface QueryObject
export interface QueryContext {
datasource: {
id: number;
id: number | string;
type: DatasourceType;
};
/** Force refresh of all queries */

View File

@@ -90,11 +90,16 @@ const ModalFooter = ({ formData, closeModal }: ModalFooterProps) => {
findPermission('can_explore', 'Superset', state.user?.roles),
);
const [datasource_id, datasource_type] = formData.datasource.split('__');
const [datasourceIdStr, datasource_type] = formData.datasource.split('__');
// Try to parse as integer, fall back to string (UUID) if NaN
const parsedDatasourceId = parseInt(datasourceIdStr, 10);
const datasource_id = Number.isNaN(parsedDatasourceId)
? datasourceIdStr
: parsedDatasourceId;
useEffect(() => {
// short circuit if the user is embedded as explore is not available
if (isEmbedded()) return;
postFormData(Number(datasource_id), datasource_type, formData, 0)
postFormData(datasource_id, datasource_type, formData, 0)
.then(key => {
setUrl(
`/explore/?form_data_key=${key}&dashboard_page_id=${dashboardPageId}`,

View File

@@ -272,7 +272,7 @@ export type Slice = {
changed_on: number;
changed_on_humanized: string;
modified: string;
datasource_id: number;
datasource_id: number | string;
datasource_type: DatasourceType;
datasource_url: string;
datasource_name: string;

View File

@@ -144,12 +144,14 @@ export const getSlicePayload = async (
...adhocFilters,
dashboards,
};
let datasourceId = 0;
let datasourceId: number | string = 0;
let datasourceType: DatasourceType = DatasourceType.Table;
if (formData.datasource) {
const [id, typeString] = formData.datasource.split('__');
datasourceId = parseInt(id, 10);
// Try to parse as integer, fall back to string (UUID) if NaN
const parsedId = parseInt(id, 10);
datasourceId = Number.isNaN(parsedId) ? id : parsedId;
const formattedTypeString =
typeString.charAt(0).toUpperCase() + typeString.slice(1);

View File

@@ -20,7 +20,7 @@ import { SupersetClient, JsonObject, JsonResponse } from '@superset-ui/core';
import { sanitizeFormData } from 'src/utils/sanitizeFormData';
type Payload = {
datasource_id: number;
datasource_id: number | string;
datasource_type: string;
form_data: string;
chart_id?: number;
@@ -36,7 +36,7 @@ const assembleEndpoint = (key?: string, tabId?: string) => {
};
const assemblePayload = (
datasourceId: number,
datasourceId: number | string,
datasourceType: string,
formData: JsonObject,
chartId?: number,
@@ -53,7 +53,7 @@ const assemblePayload = (
};
export const postFormData = (
datasourceId: number,
datasourceId: number | string,
datasourceType: string,
formData: JsonObject,
chartId?: number,
@@ -70,7 +70,7 @@ export const postFormData = (
}).then((r: JsonResponse) => r.json.key);
export const putFormData = (
datasourceId: number,
datasourceId: number | string,
datasourceType: string,
key: string,
formData: JsonObject,

View File

@@ -68,7 +68,7 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
query_obj = self._query_context.queries[0]
sql_query = datasource.get_query_str(query_obj.to_dict())
return sql_query, datasource.database
return sql_query, getattr(datasource, "database", None)
def _get_row_limit(self) -> int | None:
"""

View File

@@ -78,6 +78,13 @@ class ExportDatasetsCommand(ExportModelsCommand):
payload["version"] = EXPORT_VERSION
payload["database_uuid"] = str(model.database.uuid)
# Always set cache_timeout from the property to ensure correct value
payload["cache_timeout"] = model.cache_timeout
# SQLAlchemy returns column names as quoted_name objects which PyYAML cannot
# serialize. Convert all keys to regular strings to fix YAML serialization.
payload = {str(key): value for key, value in payload.items()}
file_content = yaml.safe_dump(payload, sort_keys=False)
return file_content
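To illustrate why the key conversion above is needed, a small sketch assuming SQLAlchemy and PyYAML are installed: `yaml.safe_dump` only has a representer registered for exact `str`, so a `quoted_name` key (a `str` subclass) raises `RepresenterError` until it is converted to a plain string.

```python
import yaml
from sqlalchemy.sql.elements import quoted_name

payload = {quoted_name("table_name", quote=True): "sales"}

try:
    yaml.safe_dump(payload)
except yaml.representer.RepresenterError:
    pass  # str subclasses have no registered representer

# Converting keys to plain strings, as the command now does, serializes cleanly.
yaml.safe_dump({str(key): value for key, value in payload.items()})
```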

View File

@@ -37,7 +37,7 @@ from superset.exceptions import SupersetException
from superset.explore.exceptions import WrongEndpointError
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.extensions import security_manager
from superset.superset_typing import BaseDatasourceData, QueryData
from superset.superset_typing import ExplorableData
from superset.utils import core as utils, json
from superset.views.utils import (
get_datasource_info,
@@ -124,7 +124,11 @@ class GetExploreCommand(BaseCommand, ABC):
security_manager.raise_for_access(datasource=datasource)
viz_type = form_data.get("viz_type")
if not viz_type and datasource and datasource.default_endpoint:
if (
not viz_type
and datasource
and getattr(datasource, "default_endpoint", None)
):
raise WrongEndpointError(redirect=datasource.default_endpoint)
form_data["datasource"] = (
@@ -136,7 +140,7 @@ class GetExploreCommand(BaseCommand, ABC):
utils.merge_extra_filters(form_data)
utils.merge_request_params(form_data, request.args)
datasource_data: BaseDatasourceData | QueryData = {
datasource_data: ExplorableData = {
"type": self._datasource_type or "unknown",
"name": datasource_name,
"columns": [],

View File

@@ -15,13 +15,13 @@
# specific language governing permissions and limitations
# under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union
@dataclass
class CommandParameters:
permalink_key: Optional[str]
form_data_key: Optional[str]
datasource_id: Optional[int]
datasource_id: Optional[Union[int, str]]
datasource_type: Optional[str]
slice_id: Optional[int]

View File

@@ -23,8 +23,8 @@ from flask_babel import _
from superset.common.chart_data import ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.connectors.sqla.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError, SupersetParseError
from superset.explorables.base import Explorable
from superset.utils.core import (
extract_column_dtype,
extract_dataframe_dtypes,
@@ -38,9 +38,7 @@ if TYPE_CHECKING:
from superset.common.query_object import QueryObject
def _get_datasource(
query_context: QueryContext, query_obj: QueryObject
) -> BaseDatasource:
def _get_datasource(query_context: QueryContext, query_obj: QueryObject) -> Explorable:
return query_obj.datasource or query_context.datasource
@@ -64,16 +62,9 @@ def _get_timegrains(
query_context: QueryContext, query_obj: QueryObject, _: bool
) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
return {
"data": [
{
"name": grain.name,
"function": grain.function,
"duration": grain.duration,
}
for grain in datasource.database.grains()
]
}
# Use the new get_time_grains() method from Explorable protocol
grains = datasource.get_time_grains()
return {"data": grains}
def _get_query(
@@ -158,7 +149,8 @@ def _get_samples(
qry_obj_cols = []
for o in datasource.columns:
if isinstance(o, dict):
qry_obj_cols.append(o.get("column_name"))
if column_name := o.get("column_name"):
qry_obj_cols.append(column_name)
else:
qry_obj_cols.append(o.column_name)
query_obj.columns = qry_obj_cols
@@ -180,7 +172,8 @@ def _get_drill_detail(
qry_obj_cols = []
for o in datasource.columns:
if isinstance(o, dict):
qry_obj_cols.append(o.get("column_name"))
if column_name := o.get("column_name"):
qry_obj_cols.append(column_name)
else:
qry_obj_cols.append(o.column_name)
query_obj.columns = qry_obj_cols

View File

@@ -24,11 +24,11 @@ import pandas as pd
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context_processor import QueryContextProcessor
from superset.common.query_object import QueryObject
from superset.explorables.base import Explorable
from superset.models.slice import Slice
from superset.utils.core import GenericDataType
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource
from superset.models.helpers import QueryResult
@@ -44,7 +44,7 @@ class QueryContext:
cache_type: ClassVar[str] = "df"
enforce_numerical_metrics: ClassVar[bool] = True
datasource: BaseDatasource
datasource: Explorable
slice_: Slice | None = None
queries: list[QueryObject]
form_data: dict[str, Any] | None
@@ -62,7 +62,7 @@ class QueryContext:
def __init__( # pylint: disable=too-many-arguments
self,
*,
datasource: BaseDatasource,
datasource: Explorable,
queries: list[QueryObject],
slice_: Slice | None,
form_data: dict[str, Any] | None,
@@ -99,15 +99,24 @@ class QueryContext:
return self._processor.get_payload(cache_query_context, force_cached)
def get_cache_timeout(self) -> int | None:
"""
Get the cache timeout for this query context.
Priority order:
1. Custom timeout set for this specific query
2. Chart-level timeout (if querying from a saved chart)
3. Datasource-level timeout (explorable handles its own fallback logic)
4. System default (None)
Note: Each explorable is responsible for its own internal fallback chain.
For example, BaseDatasource falls back to database.cache_timeout,
while semantic layers might fall back to their layer's default.
"""
if self.custom_cache_timeout is not None:
return self.custom_cache_timeout
if self.slice_ and self.slice_.cache_timeout is not None:
return self.slice_.cache_timeout
if self.datasource.cache_timeout is not None:
return self.datasource.cache_timeout
if hasattr(self.datasource, "database"):
return self.datasource.database.cache_timeout
return None
return self.datasource.cache_timeout
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
return self._processor.query_cache_key(query_obj, **kwargs)

View File

@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from typing import Any
from flask import current_app
@@ -26,13 +26,11 @@ from superset.common.query_object import QueryObject
from superset.common.query_object_factory import QueryObjectFactory
from superset.daos.chart import ChartDAO
from superset.daos.datasource import DatasourceDAO
from superset.explorables.base import Explorable
from superset.models.slice import Slice
from superset.superset_typing import Column
from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource
def create_query_object_factory() -> QueryObjectFactory:
return QueryObjectFactory(current_app.config, DatasourceDAO())
@@ -104,7 +102,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
cache_values=cache_values,
)
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
def _convert_to_model(self, datasource: DatasourceDict) -> Explorable:
return DatasourceDAO.get_datasource(
datasource_type=DatasourceType(datasource["type"]),
database_id_or_uuid=datasource["id"],
@@ -115,7 +113,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
def _process_query_object(
self,
datasource: BaseDatasource,
datasource: Explorable,
form_data: dict[str, Any] | None,
query_object: QueryObject,
) -> QueryObject:
@@ -201,7 +199,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
self,
query_object: QueryObject,
form_data: dict[str, Any] | None,
datasource: BaseDatasource,
datasource: Explorable,
) -> None:
temporal_columns = {
column["column_name"] if isinstance(column, dict) else column.column_name

View File

@@ -29,7 +29,6 @@ from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import CACHE_DISABLED_TIMEOUT, CacheRegion
from superset.daos.annotation_layer import AnnotationLayerDAO
from superset.daos.chart import ChartDAO
@@ -37,6 +36,7 @@ from superset.exceptions import (
QueryObjectValidationError,
SupersetException,
)
from superset.explorables.base import Explorable
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.superset_typing import AdhocColumn, AdhocMetric
@@ -70,7 +70,7 @@ class QueryContextProcessor:
"""
_query_context: QueryContext
_qc_datasource: BaseDatasource
_qc_datasource: Explorable
def __init__(self, query_context: QueryContext):
self._query_context = query_context

View File

@@ -85,6 +85,7 @@ from superset.exceptions import (
SupersetSecurityException,
SupersetSyntaxErrorException,
)
from superset.explorables.base import TimeGrainDict
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
@@ -105,7 +106,7 @@ from superset.sql.parse import Table
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
BaseDatasourceData,
ExplorableData,
Metric,
QueryObjectDict,
ResultSetColumnType,
@@ -198,7 +199,7 @@ class BaseDatasource(
is_featured = Column(Boolean, default=False) # TODO deprecating
filter_select_enabled = Column(Boolean, default=True)
offset = Column(Integer, default=0)
cache_timeout = Column(Integer)
_cache_timeout = Column("cache_timeout", Integer)
params = Column(String(1000))
perm = Column(String(1000))
schema_perm = Column(String(1000))
@@ -212,6 +213,78 @@ class BaseDatasource(
extra_import_fields = ["is_managed_externally", "external_url"]
@property
def cache_timeout(self) -> int | None:
"""
Get the cache timeout for this datasource.
Implements the Explorable protocol by handling the fallback chain:
1. Datasource-specific timeout (if set)
2. Database default timeout (if no datasource timeout)
3. None (use system default)
This allows each datasource to override caching, while falling back
to database-level defaults when appropriate.
"""
if self._cache_timeout is not None:
return self._cache_timeout
# database should always be set, but that's not true for v0 import
if self.database:
return self.database.cache_timeout
return None
@cache_timeout.setter
def cache_timeout(self, value: int | None) -> None:
"""Set the datasource-specific cache timeout."""
self._cache_timeout = value
def has_drill_by_columns(self, column_names: list[str]) -> bool:
"""
Check if the specified columns support drill-by operations.
For SQL datasources, drill-by is supported on columns that are marked
as groupable in the metadata. This allows users to navigate from
aggregated views to detailed data by grouping on these dimensions.
:param column_names: List of column names to check
:return: True if all columns support drill-by, False otherwise
"""
if not column_names:
return False
# Get all groupable column names for this datasource
drillable_columns = {
row[0]
for row in db.session.query(TableColumn.column_name)
.filter(TableColumn.table_id == self.id)
.filter(TableColumn.groupby)
.all()
}
# Check if all requested columns are drillable
return set(column_names).issubset(drillable_columns)
def get_time_grains(self) -> list[TimeGrainDict]:
"""
Get available time granularities from the database.
Implements the Explorable protocol by delegating to the database's
time grain definitions. Each database engine spec defines its own
set of supported time grains.
:return: List of time grain dictionaries with name, function, and duration
"""
return [
{
"name": grain.name,
"function": grain.function,
"duration": grain.duration,
}
for grain in (self.database.grains() or [])
]
@property
def kind(self) -> DatasourceKind:
return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL
@@ -363,7 +436,7 @@ class BaseDatasource(
return verb_map
@property
def data(self) -> BaseDatasourceData:
def data(self) -> ExplorableData:
"""Data representation of the datasource sent to the frontend"""
return {
# simple fields
@@ -1356,7 +1429,7 @@ class SqlaTable(
return [(g.duration, g.name) for g in self.database.grains() or []]
@property
def data(self) -> BaseDatasourceData:
def data(self) -> ExplorableData:
data_ = super().data
if self.type == "table":
data_["granularity_sqla"] = self.granularity_sqla

View File

@@ -28,6 +28,7 @@ from superset.daos.exceptions import (
DatasourceValueIsIncorrect,
)
from superset.models.sql_lab import Query, SavedQuery
from superset.semantic_layers.models import SemanticView
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
@@ -40,6 +41,7 @@ class DatasourceDAO(BaseDAO[Datasource]):
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,
DatasourceType.SAVEDQUERY: SavedQuery,
DatasourceType.SEMANTIC_VIEW: SemanticView,
}
@classmethod

View File

@@ -0,0 +1,152 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""DAOs for semantic layer models."""
from __future__ import annotations
from superset.daos.base import BaseDAO
from superset.extensions import db
from superset.semantic_layers.models import SemanticLayer, SemanticView
class SemanticLayerDAO(BaseDAO[SemanticLayer]):
"""
Data Access Object for SemanticLayer model.
"""
@staticmethod
def validate_uniqueness(name: str) -> bool:
"""
Validate that semantic layer name is unique.
:param name: Semantic layer name
:return: True if name is unique, False otherwise
"""
query = db.session.query(SemanticLayer).filter(SemanticLayer.name == name)
return not db.session.query(query.exists()).scalar()
@staticmethod
def validate_update_uniqueness(layer_uuid: str, name: str) -> bool:
"""
Validate that semantic layer name is unique for updates.
:param layer_uuid: UUID of the semantic layer being updated
:param name: New name to validate
:return: True if name is unique, False otherwise
"""
query = db.session.query(SemanticLayer).filter(
SemanticLayer.name == name,
SemanticLayer.uuid != layer_uuid,
)
return not db.session.query(query.exists()).scalar()
@staticmethod
def find_by_name(name: str) -> SemanticLayer | None:
"""
Find semantic layer by name.
:param name: Semantic layer name
:return: SemanticLayer instance or None
"""
return (
db.session.query(SemanticLayer)
.filter(SemanticLayer.name == name)
.one_or_none()
)
@classmethod
def get_semantic_views(cls, layer_uuid: str) -> list[SemanticView]:
"""
Get all semantic views for a semantic layer.
:param layer_uuid: UUID of the semantic layer
:return: List of SemanticView instances
"""
return (
db.session.query(SemanticView)
.filter(SemanticView.semantic_layer_uuid == layer_uuid)
.all()
)
class SemanticViewDAO(BaseDAO[SemanticView]):
"""Data Access Object for SemanticView model."""
@staticmethod
def find_by_semantic_layer(layer_uuid: str) -> list[SemanticView]:
"""
Find all views for a semantic layer.
:param layer_uuid: UUID of the semantic layer
:return: List of SemanticView instances
"""
return (
db.session.query(SemanticView)
.filter(SemanticView.semantic_layer_uuid == layer_uuid)
.all()
)
@staticmethod
def validate_uniqueness(name: str, layer_uuid: str) -> bool:
"""
Validate that view name is unique within semantic layer.
:param name: View name
:param layer_uuid: UUID of the semantic layer
:return: True if name is unique within layer, False otherwise
"""
query = db.session.query(SemanticView).filter(
SemanticView.name == name,
SemanticView.semantic_layer_uuid == layer_uuid,
)
return not db.session.query(query.exists()).scalar()
@staticmethod
def validate_update_uniqueness(view_uuid: str, name: str, layer_uuid: str) -> bool:
"""
Validate that view name is unique within semantic layer for updates.
:param view_uuid: UUID of the view being updated
:param name: New name to validate
:param layer_uuid: UUID of the semantic layer
:return: True if name is unique within layer, False otherwise
"""
query = db.session.query(SemanticView).filter(
SemanticView.name == name,
SemanticView.semantic_layer_uuid == layer_uuid,
SemanticView.uuid != view_uuid,
)
return not db.session.query(query.exists()).scalar()
@staticmethod
def find_by_name(name: str, layer_uuid: str) -> SemanticView | None:
"""
Find semantic view by name within a semantic layer.
:param name: View name
:param layer_uuid: UUID of the semantic layer
:return: SemanticView instance or None
"""
return (
db.session.query(SemanticView)
.filter(
SemanticView.name == name,
SemanticView.semantic_layer_uuid == layer_uuid,
)
.one_or_none()
)
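A brief usage sketch for the DAOs above, assuming an active application context; the import path and the layer name are placeholders inferred from the surrounding file layout, not confirmed by the diff.

```python
from superset.daos.semantic_layer import SemanticLayerDAO, SemanticViewDAO

# Reject duplicate names before creating a new layer.
if not SemanticLayerDAO.validate_uniqueness("Sales"):
    raise ValueError("A semantic layer named 'Sales' already exists")

layer = SemanticLayerDAO.find_by_name("Sales")
if layer:
    # Views are scoped to their parent layer by UUID.
    views = SemanticViewDAO.find_by_semantic_layer(str(layer.uuid))
```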

View File

@@ -0,0 +1,497 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Base protocol for explorable data sources in Superset.
An "explorable" is any data source that can be explored to create charts,
including SQL datasets, saved queries, and semantic layer views.
"""
from __future__ import annotations
from collections.abc import Hashable
from datetime import datetime
from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from superset.common.query_object import QueryObject
from superset.models.helpers import QueryResult
from superset.superset_typing import ExplorableData, QueryObjectDict
class TimeGrainDict(TypedDict):
"""
TypedDict for time grain options returned by get_time_grains.
Represents a time granularity option that can be used for grouping
temporal data. Each time grain specifies how to bucket timestamps.
Attributes:
name: Display name for the time grain (e.g., "Hour", "Day", "Week")
function: Implementation-specific expression for applying the grain.
For SQL datasources, this is typically a SQL expression template
like "DATE_TRUNC('hour', {col})".
duration: ISO 8601 duration string (e.g., "PT1H", "P1D", "P1W")
"""
name: str
function: str
duration: str | None
@runtime_checkable
class MetricMetadata(Protocol):
"""
Protocol for metric metadata objects.
Represents a metric that's available on an explorable data source.
Metrics contain SQL expressions or references to semantic layer measures.
Attributes:
metric_name: Unique identifier for the metric
expression: SQL expression or reference for calculating the metric
verbose_name: Human-readable name for display in the UI
description: Description of what the metric represents
d3format: D3 format string for formatting numeric values
currency: Currency configuration for the metric (JSON object)
warning_text: Warning message to display when using this metric
certified_by: Person or entity that certified this metric
certification_details: Details about the certification
"""
@property
def metric_name(self) -> str:
"""Unique identifier for the metric."""
@property
def expression(self) -> str:
"""SQL expression or reference for calculating the metric."""
@property
def verbose_name(self) -> str | None:
"""Human-readable name for display in the UI."""
@property
def description(self) -> str | None:
"""Description of what the metric represents."""
@property
def d3format(self) -> str | None:
"""D3 format string for formatting numeric values."""
@property
def currency(self) -> dict[str, Any] | None:
"""Currency configuration for the metric (JSON object)."""
@property
def warning_text(self) -> str | None:
"""Warning message to display when using this metric."""
@property
def certified_by(self) -> str | None:
"""Person or entity that certified this metric."""
@property
def certification_details(self) -> str | None:
"""Details about the certification."""
@runtime_checkable
class ColumnMetadata(Protocol):
"""
Protocol for column metadata objects.
Represents a column/dimension that's available on an explorable data source.
Used for grouping, filtering, and dimension-based analysis.
Attributes:
column_name: Unique identifier for the column
type: SQL data type of the column (e.g., 'VARCHAR', 'INTEGER', 'DATETIME')
is_dttm: Whether this column represents a date or time value
verbose_name: Human-readable name for display in the UI
description: Description of what the column represents
groupby: Whether this column is allowed for grouping/aggregation
filterable: Whether this column can be used in filters
expression: SQL expression if this is a calculated column
python_date_format: Python datetime format string for temporal columns
advanced_data_type: Advanced data type classification
extra: Additional metadata stored as JSON
"""
@property
def column_name(self) -> str:
"""Unique identifier for the column."""
@property
def type(self) -> str:
"""SQL data type of the column."""
@property
def is_dttm(self) -> bool:
"""Whether this column represents a date or time value."""
@property
def verbose_name(self) -> str | None:
"""Human-readable name for display in the UI."""
@property
def description(self) -> str | None:
"""Description of what the column represents."""
@property
def groupby(self) -> bool:
"""Whether this column is allowed for grouping/aggregation."""
@property
def filterable(self) -> bool:
"""Whether this column can be used in filters."""
@property
def expression(self) -> str | None:
"""SQL expression if this is a calculated column."""
@property
def python_date_format(self) -> str | None:
"""Python datetime format string for temporal columns."""
@property
def advanced_data_type(self) -> str | None:
"""Advanced data type classification."""
@property
def extra(self) -> str | None:
"""Additional metadata stored as JSON."""
@runtime_checkable
class Explorable(Protocol):
"""
Protocol for objects that can be explored to create charts.
This protocol defines the minimal interface required for a data source
to be visualizable in Superset. It is implemented by:
- BaseDatasource (SQL datasets and queries)
- SemanticView (semantic layer views)
- Future: Other data source types
The protocol focuses on the essential methods and properties needed
for query execution, caching, and security.
"""
# =========================================================================
# Core Query Interface
# =========================================================================
def get_query_result(self, query_object: QueryObject) -> QueryResult:
"""
Execute a query and return results.
This is the primary method for data retrieval. It takes a query
object describing what data to fetch (columns, metrics, filters, time range,
etc.) and returns a QueryResult containing a pandas DataFrame with the results.
:param query_object: QueryObject describing the query
:return: QueryResult containing:
- df: pandas DataFrame with query results
- query: string representation of the executed query
- duration: query execution time
- status: QueryStatus (SUCCESS/FAILED)
- error_message: error details if query failed
"""
def get_query_str(self, query_obj: QueryObjectDict) -> str:
"""
Get the query string without executing.
Returns a string representation of the query that would be executed
for the given query object. This is used for display in the UI
and debugging.
:param query_obj: Dictionary describing the query
:return: String representation of the query (SQL, GraphQL, etc.)
"""
# =========================================================================
# Identity & Metadata
# =========================================================================
@property
def uid(self) -> str:
"""
Unique identifier for this explorable.
Used as part of cache keys and for tracking. Should be stable
across application restarts but change when the explorable's
data or structure changes.
Format convention: "{type}_{id}" (e.g., "table_123", "semantic_view_abc")
:return: Unique identifier string
"""
@property
def type(self) -> str:
"""
Type discriminator for this explorable.
Identifies the kind of data source (e.g., 'table', 'query', 'semantic_view').
Used for routing and type-specific behavior.
:return: Type identifier string
"""
@property
def metrics(self) -> list[MetricMetadata]:
"""
List of metric metadata objects.
Each object should provide at minimum:
- metric_name: str - the metric's name
- expression: str - the metric's calculation expression
Used for validation, autocomplete, and query building.
:return: List of metric metadata objects
"""
# TODO: rename to dimensions
@property
def columns(self) -> list[ColumnMetadata]:
"""
List of column metadata objects.
Each object should provide at minimum:
- column_name: str - the column's name
- type: str - the column's data type
- is_dttm: bool - whether it's a datetime column
Used for validation, autocomplete, and query building.
:return: List of column metadata objects
"""
# TODO: remove and use columns instead
@property
def column_names(self) -> list[str]:
"""
List of available column names.
A simple list of all column names in the explorable.
Used for quick validation and filtering.
:return: List of column name strings
"""
@property
def data(self) -> ExplorableData:
"""
Full metadata representation sent to the frontend.
This property returns a dictionary containing all the metadata
needed by the Explore UI, including columns, metrics, and
other configuration.
Required keys in the returned dictionary:
- id: unique identifier (int or str)
- uid: unique string identifier
- name: display name
- type: explorable type ('table', 'query', 'semantic_view', etc.)
- columns: list of column metadata dicts (with column_name, type, etc.)
- metrics: list of metric metadata dicts (with metric_name, expression, etc.)
- database: database metadata dict (with id, backend, etc.)
Optional keys:
- description: human-readable description
- schema: schema name (if applicable)
- catalog: catalog name (if applicable)
- cache_timeout: default cache timeout
- offset: timezone offset
- owners: list of owner IDs
- verbose_map: dict mapping column/metric names to display names
:return: Dictionary with complete explorable metadata
"""
# =========================================================================
# Caching
# =========================================================================
@property
def cache_timeout(self) -> int | None:
"""
Default cache timeout in seconds.
Determines how long query results should be cached.
Returns None to use the system default cache timeout.
:return: Cache timeout in seconds, or None for system default
"""
@property
def changed_on(self) -> datetime | None:
"""
Last modification timestamp.
Used for cache invalidation - when this changes, cached
results for this explorable become invalid.
:return: Datetime of last modification, or None
"""
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
"""
Additional cache key components specific to this explorable.
Provides explorable-specific values to include in cache keys.
Used to ensure cache invalidation when the explorable's
underlying data or configuration changes in ways not captured
by uid or changed_on.
:param query_obj: The query being executed
:return: List of additional hashable values for cache key
"""
# =========================================================================
# Security
# =========================================================================
@property
def perm(self) -> str:
"""
Permission string for this explorable.
Used by the security manager to check if a user has access
to this data source. Format depends on the explorable type
(e.g., "[database].[schema].[table]" for SQL tables).
:return: Permission identifier string
"""
# =========================================================================
# Time/Date Handling
# =========================================================================
@property
def offset(self) -> int:
"""
Timezone offset for datetime columns.
Used to normalize datetime values to the user's timezone.
Returns 0 for UTC, or an offset in seconds.
:return: Timezone offset in seconds (0 for UTC)
"""
# =========================================================================
# Time Granularity
# =========================================================================
def get_time_grains(self) -> list[TimeGrainDict]:
"""
Get available time granularities for temporal grouping.
Returns a list of time grain options that can be used for grouping
temporal data. Each time grain specifies how to bucket timestamps
(e.g., by hour, day, week, month).
Each dictionary in the returned list should contain:
- name: str - Display name (e.g., "Hour", "Day", "Week")
- function: str - How to apply the grain (implementation-specific)
- duration: str - ISO 8601 duration string (e.g., "PT1H", "P1D", "P1W")
For SQL datasources, the function is typically a SQL expression template
like "DATE_TRUNC('hour', {col})". For semantic layers, it might be a
semantic layer-specific identifier like "hour" or "day".
Return an empty list if time grains are not supported or applicable.
Example return value:
```python
[
{
"name": "Second",
"function": "DATE_TRUNC('second', {col})",
"duration": "PT1S",
},
{
"name": "Minute",
"function": "DATE_TRUNC('minute', {col})",
"duration": "PT1M",
},
{
"name": "Hour",
"function": "DATE_TRUNC('hour', {col})",
"duration": "PT1H",
},
{
"name": "Day",
"function": "DATE_TRUNC('day', {col})",
"duration": "P1D",
},
]
```
:return: List of time grain dictionaries (empty list if not supported)
"""
# =========================================================================
# Drilling
# =========================================================================
def has_drill_by_columns(self, column_names: list[str]) -> bool:
"""
Check if the specified columns support drill-by operations.
Drill-by allows users to navigate from aggregated views to detailed
data by grouping on specific dimensions. This method determines whether
the given columns can be used for drill-by in the current datasource.
For SQL datasources, this typically checks if columns are marked as
groupable in the metadata. For semantic views, it checks against the
semantic layer's dimension definitions.
:param column_names: List of column names to check
:return: True if all columns support drill-by, False otherwise
"""
# =========================================================================
# Optional Properties
# =========================================================================
@property
def is_rls_supported(self) -> bool:
"""
Whether this explorable supports Row Level Security.
Row Level Security (RLS) allows filtering data based on user identity.
SQL-based datasources typically support this via SQL queries, while
semantic layers may handle security at the semantic layer level.
:return: True if RLS is supported, False otherwise
"""
@property
def query_language(self) -> str | None:
"""
Query language identifier for syntax highlighting.
Specifies the language used in queries for proper syntax highlighting
in the UI (e.g., 'sql', 'graphql', 'jsoniq').
:return: Language identifier string, or None if not applicable
"""

View File

@@ -109,7 +109,7 @@ class ExploreRestApi(BaseSupersetApi):
params = CommandParameters(
permalink_key=request.args.get("permalink_key", type=str),
form_data_key=request.args.get("form_data_key", type=str),
datasource_id=request.args.get("datasource_id", type=int),
datasource_id=request.args.get("datasource_id"),
datasource_type=request.args.get("datasource_type", type=str),
slice_id=request.args.get("slice_id", type=int),
)

View File

@@ -0,0 +1,126 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_semantic_layers_and_views
Revision ID: 33d7e0e21daa
Revises: x2s8ocx6rto6
Create Date: 2025-11-04 11:26:00.000000
"""
import uuid
import sqlalchemy as sa
from sqlalchemy_utils import UUIDType
from sqlalchemy_utils.types.json import JSONType
from superset.extensions import encrypted_field_factory
from superset.migrations.shared.utils import (
create_fks_for_table,
create_table,
drop_table,
)
# revision identifiers, used by Alembic.
revision = "33d7e0e21daa"
down_revision = "x2s8ocx6rto6"
def upgrade():
# Create semantic_layers table
create_table(
"semantic_layers",
sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("name", sa.String(length=250), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("type", sa.String(length=250), nullable=False),
sa.Column(
"configuration",
encrypted_field_factory.create(JSONType),
nullable=True,
),
sa.Column("cache_timeout", sa.Integer(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("uuid"),
)
# Create foreign key constraints for semantic_layers
create_fks_for_table(
"fk_semantic_layers_created_by_fk_ab_user",
"semantic_layers",
"ab_user",
["created_by_fk"],
["id"],
)
create_fks_for_table(
"fk_semantic_layers_changed_by_fk_ab_user",
"semantic_layers",
"ab_user",
["changed_by_fk"],
["id"],
)
# Create semantic_views table
create_table(
"semantic_views",
sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("name", sa.String(length=250), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column(
"configuration",
encrypted_field_factory.create(JSONType),
nullable=True,
),
sa.Column("cache_timeout", sa.Integer(), nullable=True),
sa.Column(
"semantic_layer_uuid",
UUIDType(binary=True),
sa.ForeignKey("semantic_layers.uuid", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("uuid"),
)
# Create foreign key constraints for semantic_views
create_fks_for_table(
"fk_semantic_views_created_by_fk_ab_user",
"semantic_views",
"ab_user",
["created_by_fk"],
["id"],
)
create_fks_for_table(
"fk_semantic_views_changed_by_fk_ab_user",
"semantic_views",
"ab_user",
["changed_by_fk"],
["id"],
)
def downgrade():
drop_table("semantic_views")
drop_table("semantic_layers")

View File

@@ -467,7 +467,9 @@ class ImportExportMixin(UUIDMixin):
if parent_ref:
parent_excludes = {c.name for c in parent_ref.local_columns}
dict_rep = {
c.name: getattr(self, c.name)
# Convert c.name to str to handle SQLAlchemy's quoted_name type
# which is not YAML-serializable
str(c.name): getattr(self, c.name)
for c in cls.__table__.columns # type: ignore
if (
c.name in export_fields
@@ -837,7 +839,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise NotImplementedError()
@property
def cache_timeout(self) -> int:
def cache_timeout(self) -> int | None:
raise NotImplementedError()
@property

View File

@@ -50,6 +50,7 @@ from superset_core.api.models import Query as CoreQuery, SavedQuery as CoreSaved
from superset import security_manager
from superset.exceptions import SupersetParseError, SupersetSecurityException
from superset.explorables.base import TimeGrainDict
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.models.helpers import (
AuditMixinNullable,
@@ -63,7 +64,7 @@ from superset.sql.parse import (
Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
from superset.superset_typing import QueryData, QueryObjectDict
from superset.superset_typing import ExplorableData, QueryObjectDict
from superset.utils import json
from superset.utils.core import (
get_column_name,
@@ -239,7 +240,7 @@ class Query(
return None
@property
def data(self) -> QueryData:
def data(self) -> ExplorableData:
"""Returns query data for the frontend"""
order_by_choices = []
for col in self.columns:
@@ -335,6 +336,32 @@ class Query(
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
return []
def get_time_grains(self) -> list[TimeGrainDict]:
"""
Get available time granularities from the database.
Delegates to the database's time grain definitions.
"""
return [
{
"name": grain.name,
"function": grain.function,
"duration": grain.duration,
}
for grain in (self.database.grains() or [])
]
def has_drill_by_columns(self, column_names: list[str]) -> bool:
"""
Check if the specified columns support drill-by operations.
For Query objects, all columns are considered drillable since they
come from ad-hoc SQL queries without predefined metadata.
"""
if not column_names:
return False
return set(column_names).issubset(set(self.column_names))
@property
def tracking_url(self) -> Optional[str]:
"""

View File

@@ -87,6 +87,7 @@ if TYPE_CHECKING:
RowLevelSecurityFilter,
SqlaTable,
)
from superset.explorables.base import Explorable
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@@ -540,24 +541,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
or (catalog_perm and self.can_access("catalog_access", catalog_perm))
)
def can_access_schema(self, datasource: "BaseDatasource") -> bool:
def can_access_schema(self, datasource: "BaseDatasource | Explorable") -> bool:
"""
Return True if the user can access the schema associated with specified
datasource, False otherwise.
For SQL datasources: Checks database → catalog → schema hierarchy
For other explorables: Only checks all_datasources permission
:param datasource: The datasource
:returns: Whether the user can access the datasource's schema
"""
from superset.connectors.sqla.models import BaseDatasource
return (
self.can_access_all_datasources()
or self.can_access_database(datasource.database)
or (
datasource.catalog
# Admin/superuser override
if self.can_access_all_datasources():
return True
# SQL-specific hierarchy checks
if isinstance(datasource, BaseDatasource):
# Database-level access grants all schemas
if self.can_access_database(datasource.database):
return True
# Catalog-level access grants all schemas in catalog
if (
hasattr(datasource, "catalog")
and datasource.catalog
and self.can_access_catalog(datasource.database, datasource.catalog)
)
or self.can_access("schema_access", datasource.schema_perm or "")
)
):
return True
# Schema-level permission (SQL only)
if self.can_access("schema_access", datasource.schema_perm or ""):
return True
# Non-SQL explorables don't have schema hierarchy
return False
def can_access_datasource(self, datasource: "BaseDatasource") -> bool:
"""
@@ -604,7 +624,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self,
form_data: dict[str, Any],
dashboard: "Dashboard",
datasource: "BaseDatasource",
datasource: "BaseDatasource | Explorable",
) -> bool:
"""
Return True if the form_data is performing a supported drill by operation,
@@ -612,10 +632,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
:param form_data: The form_data included in the request.
:param dashboard: The dashboard the user is drilling from.
:returns: Whether the user has drill byaccess.
:param datasource: The datasource being queried
:returns: Whether the user has drill by access.
"""
from superset.connectors.sqla.models import TableColumn
from superset.models.slice import Slice
return bool(
@@ -630,16 +650,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
and slc in dashboard.slices
and slc.datasource == datasource
and (dimensions := form_data.get("groupby"))
and (
drillable_columns := {
row[0]
for row in self.session.query(TableColumn.column_name)
.filter(TableColumn.table_id == datasource.id)
.filter(TableColumn.groupby)
.all()
}
)
and set(dimensions).issubset(drillable_columns)
and datasource.has_drill_by_columns(dimensions)
)
def can_access_dashboard(self, dashboard: "Dashboard") -> bool:
@@ -705,7 +716,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
@staticmethod
def get_datasource_access_error_msg(datasource: "BaseDatasource") -> str:
def get_datasource_access_error_msg(
datasource: "BaseDatasource | Explorable",
) -> str:
"""
Return the error message for the denied Superset datasource.
@@ -714,13 +727,13 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"""
return (
f"This endpoint requires the datasource {datasource.id}, "
f"This endpoint requires the datasource {datasource.data['id']}, "
"database or `all_datasource_access` permission"
)
@staticmethod
def get_datasource_access_link( # pylint: disable=unused-argument
datasource: "BaseDatasource",
datasource: "BaseDatasource | Explorable",
) -> Optional[str]:
"""
Return the link for the denied Superset datasource.
@@ -732,7 +745,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return get_conf().get("PERMISSION_INSTRUCTIONS_LINK")
def get_datasource_access_error_object( # pylint: disable=invalid-name
self, datasource: "BaseDatasource"
self, datasource: "BaseDatasource | Explorable"
) -> SupersetError:
"""
Return the error object for the denied Superset datasource.
@@ -746,8 +759,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
level=ErrorLevel.WARNING,
extra={
"link": self.get_datasource_access_link(datasource),
"datasource": datasource.id,
"datasource_name": datasource.name,
"datasource": datasource.data["id"],
"datasource_name": datasource.data["name"],
},
)
@@ -2280,8 +2293,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
dashboard: Optional["Dashboard"] = None,
chart: Optional["Slice"] = None,
database: Optional["Database"] = None,
datasource: Optional["BaseDatasource"] = None,
query: Optional["Query"] = None,
datasource: Optional["BaseDatasource | Explorable"] = None,
query: Optional["Query | Explorable"] = None,
query_context: Optional["QueryContext"] = None,
table: Optional["Table"] = None,
viz: Optional["BaseViz"] = None,
@@ -2326,7 +2339,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if database and table or query:
if query:
database = query.database
# Type narrow: only SQL Lab Query objects have .database attribute
if hasattr(query, "database"):
database = query.database
database = cast("Database", database)
default_catalog = database.get_default_catalog()
@@ -2334,7 +2349,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if self.can_access_database(database):
return
if query:
# Type narrow: this path only applies to SQL Lab Query objects
if query and hasattr(query, "sql") and hasattr(query, "catalog"):
# Getting the default schema for a query is hard. Users can select the
# schema in SQL Lab, but there's no guarantee that the query actually
# will run in that schema. Each DB engine spec needs to implement the
@@ -2342,8 +2358,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# If the DB engine spec doesn't implement the logic the schema is read
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
# inspector to read it.
from superset.models.sql_lab import Query
default_schema = database.get_default_schema_for_query(
query, template_params
cast(Query, query),
template_params,
)
tables = {
table_.qualify(
@@ -2455,7 +2474,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
and dashboard_.json_metadata
and (json_metadata := json.loads(dashboard_.json_metadata))
and any(
target.get("datasetId") == datasource.id
target.get("datasetId") == datasource.data["id"]
for fltr in json_metadata.get(
"native_filter_configuration",
[],
@@ -2560,7 +2579,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return super().get_user_roles(user)
def get_guest_rls_filters(
self, dataset: "BaseDatasource"
self, dataset: "BaseDatasource | Explorable"
) -> list[GuestTokenRlsRule]:
"""
Retrieves the row level security filters for the current user and the dataset,
@@ -2573,11 +2592,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
rule
for rule in guest_user.rls
if not rule.get("dataset")
or str(rule.get("dataset")) == str(dataset.id)
or str(rule.get("dataset")) == str(dataset.data["id"])
]
return []
def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]:
def get_rls_filters(self, table: "BaseDatasource | Explorable") -> list[SqlaQuery]:
"""
Retrieves the appropriate row level security filters for the current user and
the passed table.
@@ -2614,7 +2633,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
)
filter_tables = self.session.query(RLSFilterTables.c.rls_filter_id).filter(
RLSFilterTables.c.table_id == table.id
RLSFilterTables.c.table_id == table.data["id"]
)
query = (
self.session.query(
@@ -2640,7 +2659,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
return query.all()
def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]:
def get_rls_sorted(
self, table: "BaseDatasource | Explorable"
) -> list["RowLevelSecurityFilter"]:
"""
Retrieves a list of RLS filters sorted by ID for
the current user and the passed table.
@@ -2652,10 +2673,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
filters.sort(key=lambda f: f.id)
return filters
def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
def get_guest_rls_filters_str(
self, table: "BaseDatasource | Explorable"
) -> list[str]:
return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]
def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
def get_rls_cache_key(self, datasource: "Explorable | BaseDatasource") -> list[str]:
rls_clauses_with_group_key = []
if datasource.is_rls_supported:
rls_clauses_with_group_key = [

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,938 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Functions for mapping `QueryObject` to semantic layers.
These functions validate and convert a `QueryObject` into one or more `SemanticQuery`,
which are then passed to semantic layer implementations for execution, returning a
single dataframe.
"""
from datetime import datetime, timedelta
from time import time
from typing import Any, cast, Sequence, TypeGuard
import numpy as np
from superset.common.db_query_status import QueryStatus
from superset.common.query_object import QueryObject
from superset.common.utils.time_range_utils import get_since_until_from_query_object
from superset.connectors.sqla.models import BaseDatasource
from superset.models.helpers import QueryResult
from superset.semantic_layers.types import (
AdhocExpression,
AdhocFilter,
Day,
Dimension,
Filter,
FilterValues,
Grain,
GroupLimit,
Hour,
Metric,
Minute,
Month,
Operator,
OrderDirection,
OrderTuple,
PredicateType,
Quarter,
Second,
SemanticQuery,
SemanticResult,
SemanticViewFeature,
Week,
Year,
)
from superset.utils.core import (
FilterOperator,
QueryObjectFilterClause,
TIME_COMPARISON,
)
from superset.utils.date_parser import get_past_or_future
class ValidatedQueryObjectFilterClause(QueryObjectFilterClause):
"""
A validated QueryObject filter clause with a string column name.
The `col` in a `QueryObjectFilterClause` can be either a string (column name) or an
adhoc column, but we only support the former in semantic layers.
"""
# override to narrow the type; mypy complains about more restrictive typed dicts,
# but the alternative would be to redefine the object
col: str # type: ignore[misc]
op: str # type: ignore[misc]
class ValidatedQueryObject(QueryObject):
"""
A query object that has a datasource defined.
"""
datasource: BaseDatasource
# override to narrow the type; mypy complains about the assignment since the base type
# allows adhoc filters, but we only support validated filters here
filter: list[ValidatedQueryObjectFilterClause] # type: ignore[assignment]
series_columns: Sequence[str] # type: ignore[assignment]
series_limit_metric: str | None
def get_results(query_object: QueryObject) -> QueryResult:
"""
Run 1+ queries based on `QueryObject` and return the results.
:param query_object: The QueryObject containing query specifications
:return: QueryResult compatible with Superset's query interface
"""
if not validate_query_object(query_object):
raise ValueError("QueryObject must have a datasource defined.")
# Track execution time
start_time = time()
semantic_view = query_object.datasource.implementation
dispatcher = (
semantic_view.get_row_count
if query_object.is_rowcount
else semantic_view.get_dataframe
)
# Step 1: Convert QueryObject to list of SemanticQuery objects
# The first query is the main query, subsequent queries are for time offsets
queries = map_query_object(query_object)
# Step 2: Execute the main query (first in the list)
main_query = queries[0]
main_result = dispatcher(
metrics=main_query.metrics,
dimensions=main_query.dimensions,
filters=main_query.filters,
order=main_query.order,
limit=main_query.limit,
offset=main_query.offset,
group_limit=main_query.group_limit,
)
main_df = main_result.results
# Collect all requests (SQL queries, HTTP requests, etc.) for troubleshooting
all_requests = list(main_result.requests)
# If no time offsets, return the main result as-is
if not query_object.time_offsets or len(queries) <= 1:
semantic_result = SemanticResult(
requests=all_requests,
results=main_df,
)
duration = timedelta(seconds=time() - start_time)
return map_semantic_result_to_query_result(
semantic_result,
query_object,
duration,
)
# Get metric names from the main query
# These are the columns that will be renamed with offset suffixes
metric_names = [metric.name for metric in main_query.metrics]
# Join keys are all columns except metrics
# These will be used to match rows between main and offset DataFrames
join_keys = [col for col in main_df.columns if col not in metric_names]
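# For example (illustrative column names): if main_df has columns
# ["country", "revenue"] and "revenue" is the only metric, join_keys is ["country"].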
# Step 3 & 4: Execute each time offset query and join results
for offset_query, time_offset in zip(
queries[1:],
query_object.time_offsets,
strict=False,
):
# Execute the offset query
result = dispatcher(
metrics=offset_query.metrics,
dimensions=offset_query.dimensions,
filters=offset_query.filters,
order=offset_query.order,
limit=offset_query.limit,
offset=offset_query.offset,
group_limit=offset_query.group_limit,
)
# Add this query's requests to the collection
all_requests.extend(result.requests)
offset_df = result.results
# Handle empty results - add NaN columns directly instead of merging
# This avoids dtype mismatch issues with empty DataFrames
if offset_df.empty:
# Add offset metric columns with NaN values directly to main_df
for metric in metric_names:
offset_col_name = TIME_COMPARISON.join([metric, time_offset])
main_df[offset_col_name] = np.nan
else:
# Rename metric columns with time offset suffix
# Format: "{metric_name}__{time_offset}"
# Example: "revenue" -> "revenue__1 week ago"
offset_df = offset_df.rename(
columns={
metric: TIME_COMPARISON.join([metric, time_offset])
for metric in metric_names
}
)
# Step 5: Perform left join on dimension columns
# This preserves all rows from main_df and adds offset metrics
# where they match
main_df = main_df.merge(
offset_df,
on=join_keys,
how="left",
suffixes=("", "__duplicate"),
)
# Clean up any duplicate columns that might have been created
# (shouldn't happen with proper join keys, but defensive programming)
duplicate_cols = [
col for col in main_df.columns if col.endswith("__duplicate")
]
if duplicate_cols:
main_df = main_df.drop(columns=duplicate_cols)
# Convert final result to QueryResult
semantic_result = SemanticResult(requests=all_requests, results=main_df)
duration = timedelta(seconds=time() - start_time)
return map_semantic_result_to_query_result(
semantic_result,
query_object,
duration,
)
def map_semantic_result_to_query_result(
semantic_result: SemanticResult,
query_object: ValidatedQueryObject,
duration: timedelta,
) -> QueryResult:
"""
Convert a SemanticResult to a QueryResult.
:param semantic_result: Result from the semantic layer
:param query_object: Original QueryObject (for passthrough attributes)
:param duration: Time taken to execute the query
:return: QueryResult compatible with Superset's query interface
"""
# Get the query string from requests (typically one or more SQL queries)
query_str = ""
if semantic_result.requests:
# Join all requests for display (could be multiple for time comparisons)
query_str = "\n\n".join(
f"-- {req.type}\n{req.definition}" for req in semantic_result.requests
)
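# For illustration, a single Snowflake request would render roughly as:
#   -- snowflake
#   SELECT ... FROM SEMANTIC_VIEW(...)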
return QueryResult(
# Core data
df=semantic_result.results,
query=query_str,
duration=duration,
# Template filters - not applicable to semantic layers
# (semantic layers don't use Jinja templates)
applied_template_filters=None,
# Filter columns - not applicable to semantic layers
# (semantic layers handle filter validation internally)
applied_filter_columns=None,
rejected_filter_columns=None,
# Status - always success if we got here
# (errors would raise exceptions before reaching this point)
status=QueryStatus.SUCCESS,
error_message=None,
errors=None,
# Time range - pass through from original query_object
from_dttm=query_object.from_dttm,
to_dttm=query_object.to_dttm,
)
def _normalize_column(column: str | dict, dimension_names: set[str]) -> str:
"""
Normalize a column to its dimension name.
Columns can be either:
- A string (dimension name directly)
- A dict with isColumnReference=True and sqlExpression containing the dimension name
"""
if isinstance(column, str):
return column
if isinstance(column, dict):
# Handle column references (e.g., from time-series charts)
if column.get("isColumnReference") and column.get("sqlExpression"):
sql_expr = column["sqlExpression"]
if sql_expr in dimension_names:
return sql_expr
raise ValueError("Adhoc dimensions are not supported in Semantic Views.")
def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
"""
Convert a `QueryObject` into a list of `SemanticQuery`.
This function maps the `QueryObject` into query objects that focus less on
visualization and more on semantics.
"""
semantic_view = query_object.datasource.implementation
all_metrics = {metric.name: metric for metric in semantic_view.metrics}
all_dimensions = {
dimension.name: dimension for dimension in semantic_view.dimensions
}
# Normalize columns (may be dicts with isColumnReference=True for time-series)
dimension_names = set(all_dimensions.keys())
normalized_columns = {
_normalize_column(column, dimension_names) for column in query_object.columns
}
metrics = [all_metrics[metric] for metric in (query_object.metrics or [])]
grain = (
_convert_time_grain(query_object.extras["time_grain_sqla"])
if "time_grain_sqla" in query_object.extras
else None
)
dimensions = [
dimension
for dimension in semantic_view.dimensions
if dimension.name in normalized_columns
and (
# if a grain is specified, only include the time dimension if its grain
# matches the requested grain
grain is None
or dimension.name != query_object.granularity
or dimension.grain == grain
)
]
order = _get_order_from_query_object(query_object, all_metrics, all_dimensions)
limit = query_object.row_limit
offset = query_object.row_offset
group_limit = _get_group_limit_from_query_object(
query_object,
all_metrics,
all_dimensions,
)
queries = []
for time_offset in [None] + query_object.time_offsets:
filters = _get_filters_from_query_object(
query_object,
time_offset,
all_dimensions,
)
print(">>", filters)
queries.append(
SemanticQuery(
metrics=metrics,
dimensions=dimensions,
filters=filters,
order=order,
limit=limit,
offset=offset,
group_limit=group_limit,
)
)
return queries
def _get_filters_from_query_object(
query_object: ValidatedQueryObject,
time_offset: str | None,
all_dimensions: dict[str, Dimension],
) -> set[Filter | AdhocFilter]:
"""
Extract all filters from the query object, including time range filters.
This simplifies the complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm
by converting all time constraints into filters.
"""
filters: set[Filter | AdhocFilter] = set()
# 1. Add fetch values predicate if present
if (
query_object.apply_fetch_values_predicate
and query_object.datasource.fetch_values_predicate
):
filters.add(
AdhocFilter(
type=PredicateType.WHERE,
definition=query_object.datasource.fetch_values_predicate,
)
)
# 2. Add time range filter based on from_dttm/to_dttm
# For time offsets, this automatically calculates the shifted bounds
time_filters = _get_time_filter(query_object, time_offset, all_dimensions)
filters.update(time_filters)
# 3. Add filters from query_object.extras (WHERE and HAVING clauses)
extras_filters = _get_filters_from_extras(query_object.extras)
filters.update(extras_filters)
# 4. Add all other filters from query_object.filter
for filter_ in query_object.filter:
# Skip temporal range filters - we're using inner bounds instead
if (
filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value
and query_object.granularity
):
continue
if converted_filters := _convert_query_object_filter(filter_, all_dimensions):
filters.update(converted_filters)
return filters
def _get_filters_from_extras(extras: dict[str, Any]) -> set[AdhocFilter]:
"""
Extract filters from the extras dict.
The extras dict can contain various keys that affect query behavior:
Supported keys (converted to filters):
- "where": SQL WHERE clause expression (e.g., "customer_id > 100")
- "having": SQL HAVING clause expression (e.g., "SUM(sales) > 1000")
Other keys in extras (handled elsewhere in the mapper):
- "time_grain_sqla": Time granularity (e.g., "P1D", "PT1H")
Handled in _convert_time_grain() and used for dimension grain matching
Note: The WHERE and HAVING clauses from extras are SQL expressions that
are passed through as-is to the semantic layer as AdhocFilter objects.
"""
filters: set[AdhocFilter] = set()
# Add WHERE clause from extras
if where_clause := extras.get("where"):
filters.add(
AdhocFilter(
type=PredicateType.WHERE,
definition=where_clause,
)
)
# Add HAVING clause from extras
if having_clause := extras.get("having"):
filters.add(
AdhocFilter(
type=PredicateType.HAVING,
definition=having_clause,
)
)
return filters
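# For example, extras={"where": "customer_id > 100"} yields
# {AdhocFilter(type=PredicateType.WHERE, definition="customer_id > 100")}.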
def _get_time_filter(
query_object: ValidatedQueryObject,
time_offset: str | None,
all_dimensions: dict[str, Dimension],
) -> set[Filter]:
"""
Create a time range filter from the query object.
This handles both regular queries and time offset queries, simplifying the
complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm by using the
same time bounds for both the main query and series limit subqueries.
"""
filters: set[Filter] = set()
if not query_object.granularity:
return filters
time_dimension = all_dimensions.get(query_object.granularity)
if not time_dimension:
return filters
# Get the appropriate time bounds based on whether this is a time offset query
from_dttm, to_dttm = _get_time_bounds(query_object, time_offset)
if not from_dttm or not to_dttm:
return filters
# Create a filter with >= and < operators
return {
Filter(
type=PredicateType.WHERE,
column=time_dimension,
operator=Operator.GREATER_THAN_OR_EQUAL,
value=from_dttm,
),
Filter(
type=PredicateType.WHERE,
column=time_dimension,
operator=Operator.LESS_THAN,
value=to_dttm,
),
}
def _get_time_bounds(
query_object: ValidatedQueryObject,
time_offset: str | None,
) -> tuple[datetime | None, datetime | None]:
"""
Get the appropriate time bounds for the query.
For regular queries (time_offset is None), returns from_dttm/to_dttm.
For time offset queries, calculates the shifted bounds.
This simplifies the inner_from_dttm/inner_to_dttm complexity by using
the same bounds for both main queries and series limit subqueries.
"""
if time_offset is None:
# Main query: use from_dttm/to_dttm directly
return query_object.from_dttm, query_object.to_dttm
# Time offset query: calculate shifted bounds
# Use from_dttm/to_dttm if available, otherwise try to get from time_range
outer_from = query_object.from_dttm
outer_to = query_object.to_dttm
if not outer_from or not outer_to:
# Fall back to parsing time_range if from_dttm/to_dttm not set
outer_from, outer_to = get_since_until_from_query_object(query_object)
if not outer_from or not outer_to:
return None, None
# Apply the offset to both bounds
offset_from = get_past_or_future(time_offset, outer_from)
offset_to = get_past_or_future(time_offset, outer_to)
return offset_from, offset_to
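# Illustrative dates only: with from_dttm=2024-01-08 and to_dttm=2024-01-15, a
# time_offset of "1 week ago" produces the shifted bounds 2024-01-01..2024-01-08
# via get_past_or_future().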
def _convert_query_object_filter(
filter_: ValidatedQueryObjectFilterClause,
all_dimensions: dict[str, Dimension],
) -> set[Filter] | None:
"""
Convert a QueryObject filter dict to a semantic layer Filter or AdhocFilter.
"""
operator_str = filter_["op"]
# Handle simple column filters
col = filter_.get("col")
if col not in all_dimensions:
return None
dimension = all_dimensions[col]
val_str = filter_["val"]
value: FilterValues | set[FilterValues]
if val_str is None:
value = None
elif isinstance(val_str, (list, tuple)):
value = set(val_str)
else:
value = val_str
# Special case for temporal range
if operator_str == FilterOperator.TEMPORAL_RANGE.value:
# temporal range filters arrive as a single "start : end" string
start, end = value.split(" : ")
return {
Filter(
type=PredicateType.WHERE,
column=dimension,
operator=Operator.GREATER_THAN_OR_EQUAL,
value=start,
),
Filter(
type=PredicateType.WHERE,
column=dimension,
operator=Operator.LESS_THAN,
value=end,
),
}
# Map QueryObject operators to semantic layer operators
operator_mapping = {
FilterOperator.EQUALS.value: Operator.EQUALS,
FilterOperator.NOT_EQUALS.value: Operator.NOT_EQUALS,
FilterOperator.GREATER_THAN.value: Operator.GREATER_THAN,
FilterOperator.LESS_THAN.value: Operator.LESS_THAN,
FilterOperator.GREATER_THAN_OR_EQUALS.value: Operator.GREATER_THAN_OR_EQUAL,
FilterOperator.LESS_THAN_OR_EQUALS.value: Operator.LESS_THAN_OR_EQUAL,
FilterOperator.IN.value: Operator.IN,
FilterOperator.NOT_IN.value: Operator.NOT_IN,
FilterOperator.LIKE.value: Operator.LIKE,
FilterOperator.NOT_LIKE.value: Operator.NOT_LIKE,
FilterOperator.IS_NULL.value: Operator.IS_NULL,
FilterOperator.IS_NOT_NULL.value: Operator.IS_NOT_NULL,
}
operator = operator_mapping.get(operator_str)
if not operator:
# Unknown operator - skip this filter
return None
return {
Filter(
type=PredicateType.WHERE,
column=dimension,
operator=operator,
value=value,
)
}
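# For example (illustrative values): {"col": "country", "op": "IN", "val": ["US", "BR"]}
# becomes a single Filter(type=WHERE, column=<country dimension>, operator=Operator.IN,
# value={"US", "BR"}).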
def _get_order_from_query_object(
query_object: ValidatedQueryObject,
all_metrics: dict[str, Metric],
all_dimensions: dict[str, Dimension],
) -> list[OrderTuple]:
order: list[OrderTuple] = []
for element, ascending in query_object.orderby:
direction = OrderDirection.ASC if ascending else OrderDirection.DESC
# adhoc
if isinstance(element, dict):
if element["sqlExpression"] is not None:
order.append(
(
AdhocExpression(
id=element["label"] or element["sqlExpression"],
definition=element["sqlExpression"],
),
direction,
)
)
elif element in all_dimensions:
order.append((all_dimensions[element], direction))
elif element in all_metrics:
order.append((all_metrics[element], direction))
return order
def _get_group_limit_from_query_object(
query_object: ValidatedQueryObject,
all_metrics: dict[str, Metric],
all_dimensions: dict[str, Dimension],
) -> GroupLimit | None:
# no limit
if query_object.series_limit == 0 or not query_object.columns:
return None
dimensions = [all_dimensions[dim_id] for dim_id in query_object.series_columns]
top = query_object.series_limit
metric = (
all_metrics[query_object.series_limit_metric]
if query_object.series_limit_metric
else None
)
direction = OrderDirection.DESC if query_object.order_desc else OrderDirection.ASC
group_others = query_object.group_others_when_limit_reached
# Check if we need separate filters for the group limit subquery
# This happens when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm
group_limit_filters = _get_group_limit_filters(query_object, all_dimensions)
return GroupLimit(
dimensions=dimensions,
top=top,
metric=metric,
direction=direction,
group_others=group_others,
filters=group_limit_filters,
)
def _get_group_limit_filters(
query_object: ValidatedQueryObject,
all_dimensions: dict[str, Dimension],
) -> set[Filter | AdhocFilter] | None:
"""
Get separate filters for the group limit subquery if needed.
This is used when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm,
which happens during time comparison queries. The group limit subquery may need
different time bounds to determine the top N groups.
Returns None if the group limit should use the same filters as the main query.
"""
# Check if inner time bounds are explicitly set and differ from outer bounds
if (
query_object.inner_from_dttm is None
or query_object.inner_to_dttm is None
or (
query_object.inner_from_dttm == query_object.from_dttm
and query_object.inner_to_dttm == query_object.to_dttm
)
):
# No separate bounds needed - use the same filters as the main query
return None
# Create separate filters for the group limit subquery
filters: set[Filter | AdhocFilter] = set()
# Add time range filter using inner bounds
if query_object.granularity:
time_dimension = all_dimensions.get(query_object.granularity)
if (
time_dimension
and query_object.inner_from_dttm
and query_object.inner_to_dttm
):
filters.update(
{
Filter(
type=PredicateType.WHERE,
column=time_dimension,
operator=Operator.GREATER_THAN_OR_EQUAL,
value=query_object.inner_from_dttm,
),
Filter(
type=PredicateType.WHERE,
column=time_dimension,
operator=Operator.LESS_THAN,
value=query_object.inner_to_dttm,
),
}
)
# Add fetch values predicate if present
if (
query_object.apply_fetch_values_predicate
and query_object.datasource.fetch_values_predicate
):
filters.add(
AdhocFilter(
type=PredicateType.WHERE,
definition=query_object.datasource.fetch_values_predicate,
)
)
# Add filters from query_object.extras (WHERE and HAVING clauses)
extras_filters = _get_filters_from_extras(query_object.extras)
filters.update(extras_filters)
# Add all other non-temporal filters from query_object.filter
for filter_ in query_object.filter:
# Skip temporal range filters - we're using inner bounds instead
if (
filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value
and query_object.granularity
):
continue
if converted_filters := _convert_query_object_filter(filter_, all_dimensions):
filters.update(converted_filters)
return filters if filters else None
def _convert_time_grain(time_grain: str) -> Grain | None:
"""
Convert a time grain string from the query object to a Grain enum.
"""
mapping = {
grain.representation: grain
for grain in [
Second,
Minute,
Hour,
Day,
Week,
Month,
Quarter,
Year,
]
}
return mapping.get(time_grain)
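# For example, "P1D" maps to the Day grain and "PT1H" to Hour (assuming those are
# the grains' `representation` values); unknown strings return None.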
def validate_query_object(
query_object: QueryObject,
) -> TypeGuard[ValidatedQueryObject]:
"""
Validate that the `QueryObject` is compatible with the `SemanticView`.
If a given semantic view implementation supports features that are rejected here, we
should add an attribute to `SemanticViewImplementation` indicating support for them
and relax the corresponding check.
"""
if not query_object.datasource:
return False
query_object = cast(ValidatedQueryObject, query_object)
_validate_metrics(query_object)
_validate_dimensions(query_object)
_validate_filters(query_object)
_validate_granularity(query_object)
_validate_group_limit(query_object)
_validate_orderby(query_object)
return True
def _validate_metrics(query_object: ValidatedQueryObject) -> None:
"""
Make sure metrics are defined in the semantic view.
"""
semantic_view = query_object.datasource.implementation
if any(not isinstance(metric, str) for metric in (query_object.metrics or [])):
raise ValueError("Adhoc metrics are not supported in Semantic Views.")
metric_names = {metric.name for metric in semantic_view.metrics}
if not set(query_object.metrics or []) <= metric_names:
raise ValueError("All metrics must be defined in the Semantic View.")
def _validate_dimensions(query_object: ValidatedQueryObject) -> None:
"""
Make sure all dimensions are defined in the semantic view.
"""
semantic_view = query_object.datasource.implementation
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
# Normalize all columns to dimension names
normalized_columns = [
_normalize_column(column, dimension_names) for column in query_object.columns
]
if not set(normalized_columns) <= dimension_names:
raise ValueError("All dimensions must be defined in the Semantic View.")
def _validate_filters(query_object: ValidatedQueryObject) -> None:
"""
Make sure all filters are valid.
"""
for filter_ in query_object.filter:
if isinstance(filter_["col"], dict):
raise ValueError(
"Adhoc columns are not supported in Semantic View filters."
)
if not filter_.get("op"):
raise ValueError("All filters must have an operator defined.")
def _validate_granularity(query_object: ValidatedQueryObject) -> None:
"""
Make sure time column and time grain are valid.
"""
semantic_view = query_object.datasource.implementation
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
if time_column := query_object.granularity:
if time_column not in dimension_names:
raise ValueError(
"The time column must be defined in the Semantic View dimensions."
)
if time_grain := query_object.extras.get("time_grain_sqla"):
if not time_column:
raise ValueError(
"A time column must be specified when a time grain is provided."
)
supported_time_grains = {
dimension.grain
for dimension in semantic_view.dimensions
if dimension.name == time_column and dimension.grain
}
if _convert_time_grain(time_grain) not in supported_time_grains:
raise ValueError(
"The time grain is not supported for the time column in the "
"Semantic View."
)
def _validate_group_limit(query_object: ValidatedQueryObject) -> None:
"""
Validate group limit related features in the query object.
"""
semantic_view = query_object.datasource.implementation
# no limit
if query_object.series_limit == 0:
return
if (
query_object.series_columns
and SemanticViewFeature.GROUP_LIMIT not in semantic_view.features
):
raise ValueError("Group limit is not supported in this Semantic View.")
if any(not isinstance(col, str) for col in query_object.series_columns):
raise ValueError("Adhoc dimensions are not supported in series columns.")
metric_names = {metric.name for metric in semantic_view.metrics}
if query_object.series_limit_metric and (
not isinstance(query_object.series_limit_metric, str)
or query_object.series_limit_metric not in metric_names
):
raise ValueError(
"The series limit metric must be defined in the Semantic View."
)
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
if not set(query_object.series_columns) <= dimension_names:
raise ValueError("All series columns must be defined in the Semantic View.")
if (
query_object.group_others_when_limit_reached
and SemanticViewFeature.GROUP_OTHERS not in semantic_view.features
):
raise ValueError(
"Grouping others when limit is reached is not supported in this Semantic "
"View."
)
def _validate_orderby(query_object: ValidatedQueryObject) -> None:
"""
Validate order by elements in the query object.
"""
semantic_view = query_object.datasource.implementation
if (
any(not isinstance(element, str) for element, _ in query_object.orderby)
and SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY
not in semantic_view.features
):
raise ValueError(
"Adhoc expressions in order by are not supported in this Semantic View."
)
elements = {orderby[0] for orderby in query_object.orderby}
metric_names = {metric.name for metric in semantic_view.metrics}
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
if not elements <= metric_names | dimension_names:
raise ValueError("All order by elements must be defined in the Semantic View.")

View File

@@ -0,0 +1,381 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Semantic layer models."""
from __future__ import annotations
import uuid
from collections.abc import Hashable
from dataclasses import dataclass
from functools import cached_property
from importlib.metadata import entry_points
from typing import Any, TYPE_CHECKING
from flask_appbuilder import Model
from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship
from sqlalchemy_utils import UUIDType
from sqlalchemy_utils.types.json import JSONType
from superset.common.query_object import QueryObject
from superset.explorables.base import TimeGrainDict
from superset.extensions import encrypted_field_factory
from superset.models.helpers import AuditMixinNullable, QueryResult
from superset.semantic_layers.mapper import get_results
from superset.semantic_layers.types import (
BINARY,
BOOLEAN,
DATE,
DATETIME,
DECIMAL,
INTEGER,
INTERVAL,
NUMBER,
OBJECT,
SemanticLayerImplementation,
SemanticViewImplementation,
STRING,
TIME,
Type,
)
from superset.utils import json
from superset.utils.core import GenericDataType
if TYPE_CHECKING:
from superset.superset_typing import ExplorableData, QueryObjectDict
def get_column_type(semantic_type: Type) -> GenericDataType:
"""
Map semantic layer types to generic data types.
"""
if semantic_type in {DATE, DATETIME, TIME}:
return GenericDataType.TEMPORAL
if semantic_type in {INTEGER, NUMBER, DECIMAL, INTERVAL}:
return GenericDataType.NUMERIC
if semantic_type is BOOLEAN:
return GenericDataType.BOOLEAN
if semantic_type in {STRING, OBJECT, BINARY}:
return GenericDataType.STRING
return GenericDataType.STRING
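# For instance, a DATETIME dimension maps to GenericDataType.TEMPORAL and an INTEGER
# one to GenericDataType.NUMERIC; anything unrecognised falls back to STRING.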
@dataclass(frozen=True)
class MetricMetadata:
metric_name: str
expression: str
verbose_name: str | None = None
description: str | None = None
d3format: str | None = None
currency: dict[str, Any] | None = None
warning_text: str | None = None
certified_by: str | None = None
certification_details: str | None = None
@dataclass(frozen=True)
class ColumnMetadata:
column_name: str
type: str
is_dttm: bool
verbose_name: str | None = None
description: str | None = None
groupby: bool = True
filterable: bool = True
expression: str | None = None
python_date_format: str | None = None
advanced_data_type: str | None = None
extra: str | None = None
class SemanticLayer(AuditMixinNullable, Model):
"""
Semantic layer model.
A semantic layer provides an abstraction over data sources,
allowing users to query data through a semantic interface.
"""
__tablename__ = "semantic_layers"
uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4)
# Core fields
name = Column(String(250), nullable=False)
description = Column(Text, nullable=True)
type = Column(String(250), nullable=False) # snowflake, etc
configuration = Column(encrypted_field_factory.create(JSONType), default=dict)
cache_timeout = Column(Integer, nullable=True)
# Semantic views relationship
semantic_views: list[SemanticView] = relationship(
"SemanticView",
back_populates="semantic_layer",
cascade="all, delete-orphan",
passive_deletes=True,
)
def __repr__(self) -> str:
return self.name or str(self.uuid)
@cached_property
def implementation(
self,
) -> SemanticLayerImplementation[Any, SemanticViewImplementation]:
"""
Return semantic layer implementation.
"""
entry_point = next(
iter(
entry_points(
group="superset.semantic_layers",
name=self.type,
)
)
)
implementation_class = entry_point.load()
if not issubclass(implementation_class, SemanticLayerImplementation):
raise TypeError(
f"Entry point for semantic layer type '{self.type}' "
"must be a subclass of SemanticLayerImplementation"
)
return implementation_class.from_configuration(json.loads(self.configuration))
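# Discovery sketch (assumes the implementation was registered under the
# "superset.semantic_layers" entry-point group at install time, e.g. with the
# name "snowflake"):
#
#     from importlib.metadata import entry_points
#     (ep,) = entry_points(group="superset.semantic_layers", name="snowflake")
#     layer_cls = ep.load()  # a SemanticLayerImplementation subclass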
class SemanticView(AuditMixinNullable, Model):
"""
Semantic view model.
A semantic view represents a queryable view within a semantic layer.
"""
__tablename__ = "semantic_views"
uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4)
# Core fields
name = Column(String(250), nullable=False)
description = Column(Text, nullable=True)
configuration = Column(encrypted_field_factory.create(JSONType), default=dict)
cache_timeout = Column(Integer, nullable=True)
# Semantic layer relationship
semantic_layer_uuid = Column(
UUIDType(binary=True),
ForeignKey("semantic_layers.uuid", ondelete="CASCADE"),
nullable=False,
)
semantic_layer: SemanticLayer = relationship(
"SemanticLayer",
back_populates="semantic_views",
foreign_keys=[semantic_layer_uuid],
)
def __repr__(self) -> str:
return self.name or str(self.uuid)
@cached_property
def implementation(self) -> SemanticViewImplementation:
"""
Return semantic view implementation.
"""
return self.semantic_layer.implementation.get_semantic_view(
self.name,
json.loads(self.configuration),
)
# =========================================================================
# Explorable protocol implementation
# =========================================================================
def get_query_result(self, query_object: QueryObject) -> QueryResult:
return get_results(query_object)
def get_query_str(self, query_obj: QueryObjectDict) -> str:
return "Not implemented for semantic layers"
@property
def uid(self) -> str:
return self.implementation.uid()
@property
def type(self) -> str:
return "semantic_view"
@property
def metrics(self) -> list[MetricMetadata]:
return [
MetricMetadata(
metric_name=metric.name,
expression=metric.definition,
description=metric.description,
)
for metric in self.implementation.metrics
]
@property
def columns(self) -> list[ColumnMetadata]:
return [
ColumnMetadata(
column_name=dimension.name,
type=dimension.type.__name__,
is_dttm=dimension.type in {DATE, TIME, DATETIME},
description=dimension.description,
expression=dimension.definition,
extra=json.dumps({"grain": dimension.grain}),
)
for dimension in self.implementation.dimensions
]
@property
def column_names(self) -> list[str]:
return [dimension.name for dimension in self.implementation.dimensions]
@property
def data(self) -> ExplorableData:
return {
# core
"id": self.uuid.hex,
"uid": self.uid,
"type": "semantic_view",
"name": self.name,
"columns": [
{
"advanced_data_type": None,
"certification_details": None,
"certified_by": None,
"column_name": dimension.name,
"description": dimension.description,
"expression": dimension.definition,
"filterable": True,
"groupby": True,
"id": None,
"uuid": None,
"is_certified": False,
"is_dttm": dimension.type in {DATE, TIME, DATETIME},
"python_date_format": None,
"type": dimension.type.__name__,
"type_generic": get_column_type(dimension.type),
"verbose_name": None,
"warning_markdown": None,
}
for dimension in self.implementation.dimensions
],
"metrics": [
{
"certification_details": None,
"certified_by": None,
"d3format": None,
"description": metric.description,
"expression": metric.definition,
"id": None,
"uuid": None,
"is_certified": False,
"metric_name": metric.name,
"warning_markdown": None,
"warning_text": None,
"verbose_name": None,
}
for metric in self.implementation.metrics
],
"database": {},
# UI features
"verbose_map": {},
"order_by_choices": [],
"filter_select": True,
"filter_select_enabled": True,
"sql": None,
"select_star": None,
"owners": [],
"description": self.description,
"table_name": self.name,
"column_types": [
get_column_type(dimension.type)
for dimension in self.implementation.dimensions
],
"column_names": [
dimension.name for dimension in self.implementation.dimensions
],
# rare
"column_formats": {},
"datasource_name": self.name,
"perm": self.perm,
"offset": None,
"cache_timeout": self.cache_timeout,
"params": None,
# sql-specific
"schema": None,
"catalog": None,
"main_dttm_col": None,
"time_grain_sqla": [],
"granularity_sqla": [],
"fetch_values_predicate": None,
"template_params": None,
"is_sqllab_view": False,
"extra": None,
"always_filter_main_dttm": False,
"normalize_columns": False,
# TODO: populate owners once the semantic view model tracks them
# "owners": [owner.id for owner in self.owners],
"edit_url": "",
"default_endpoint": None,
"folders": [],
"health_check_message": None,
}
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
return []
@property
def perm(self) -> str:
return self.semantic_layer_uuid.hex + "::" + self.uuid.hex
@property
def offset(self) -> int:
# always return datetime as UTC
return 0
@property
def get_time_grains(self) -> list[TimeGrainDict]:
return [
{
"name": dimension.grain.name,
"function": "",
"duration": dimension.grain.representation,
}
for dimension in self.implementation.dimensions
if dimension.grain
]
def has_drill_by_columns(self, column_names: list[str]) -> bool:
dimension_names = {
dimension.name for dimension in self.implementation.dimensions
}
return all(column_name in dimension_names for column_name in column_names)
@property
def is_rls_supported(self) -> bool:
return False
@property
def query_language(self) -> str | None:
return None

View File

@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.semantic_layer import SnowflakeSemanticLayer
from superset.semantic_layers.snowflake.semantic_view import SnowflakeSemanticView
__all__ = [
"SnowflakeConfiguration",
"SnowflakeSemanticLayer",
"SnowflakeSemanticView",
]

View File

@@ -0,0 +1,130 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Literal, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator, SecretStr
class UserPasswordAuth(BaseModel):
"""
Username and password authentication.
"""
model_config = ConfigDict(title="Username and password")
auth_type: Literal["user_password"] = "user_password"
username: str = Field(description="The username to authenticate as.")
password: SecretStr = Field(
description="The password to authenticate with.",
repr=False,
)
class PrivateKeyAuth(BaseModel):
"""
Private key authentication.
"""
model_config = ConfigDict(title="Private key")
auth_type: Literal["private_key"] = "private_key"
private_key: SecretStr = Field(
description="The private key to authenticate with, in PEM format.",
repr=False,
)
private_key_password: SecretStr = Field(
description="The password to decrypt the private key with.",
repr=False,
)
class SnowflakeConfiguration(BaseModel):
"""
Parameters needed to connect to Snowflake.
"""
# account is the only required parameter
account_identifier: str = Field(
description="The Snowflake account identifier.",
json_schema_extra={"examples": ["abc12345"]},
)
role: str | None = Field(
default=None,
description="The default role to use.",
json_schema_extra={"examples": ["myrole"]},
)
warehouse: str | None = Field(
default=None,
description="The default warehouse to use.",
json_schema_extra={"examples": ["testwh"]},
)
auth: Union[UserPasswordAuth, PrivateKeyAuth] = Field(
discriminator="auth_type",
description="Authentication method",
)
# database and schema can be optionally provided; if not provided the user
# will be able to browse databases/schemas
database: str | None = Field(
default=None,
description="The default database to use.",
json_schema_extra={
"examples": ["testdb"],
"x-dynamic": True,
"x-dependsOn": ["account_identifier", "auth"],
},
)
allow_changing_database: bool = Field(
default=False,
description="Allow changing the default database.",
)
schema_: str | None = Field(
default=None,
description="The default schema to use.",
json_schema_extra={
"examples": ["public"],
"x-dynamic": True,
"x-dependsOn": ["account_identifier", "auth", "database"],
},
# `schema` is an attribute of `BaseModel` so it needs to be aliased
alias="schema",
)
allow_changing_schema: bool = Field(
default=False,
description="Allow changing the default schema.",
)
@model_validator(mode="after")
def validate_database_schema_settings(self) -> SnowflakeConfiguration:
"""
Validate that if database or schema is not specified, the corresponding
allow_changing flag must be true.
"""
if not self.database and not self.allow_changing_database:
raise ValueError(
"If no database is specified, allow_changing_database must be true"
)
if not self.schema_ and not self.allow_changing_schema:
raise ValueError(
"If no schema is specified, allow_changing_schema must be true"
)
return self
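# A minimal, illustrative configuration (placeholder values taken from the field
# examples above):
#
#     SnowflakeConfiguration.model_validate(
#         {
#             "account_identifier": "abc12345",
#             "auth": {"auth_type": "user_password", "username": "u", "password": "p"},
#             "database": "testdb",
#             "schema": "public",
#         }
#     )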

View File

@@ -0,0 +1,269 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from textwrap import dedent
from typing import Any, Literal
from pydantic import create_model, Field
from snowflake.connector import connect
from snowflake.connector.connection import SnowflakeConnection
from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.semantic_view import SnowflakeSemanticView
from superset.semantic_layers.snowflake.utils import get_connection_parameters
from superset.semantic_layers.types import (
SemanticLayerImplementation,
)
class SnowflakeSemanticLayer(
SemanticLayerImplementation[SnowflakeConfiguration, SnowflakeSemanticView]
):
id = "snowflake"
name = "Snowflake Semantic Layer"
description = "Connect to semantic views stored in Snowflake."
@classmethod
def from_configuration(
cls,
configuration: dict[str, Any],
) -> SnowflakeSemanticLayer:
"""
Create a SnowflakeSemanticLayer from a configuration dictionary.
"""
config = SnowflakeConfiguration.model_validate(configuration)
return cls(config)
@classmethod
def get_configuration_schema(
cls,
configuration: SnowflakeConfiguration | None = None,
) -> dict[str, Any]:
"""
Get the JSON schema for the configuration needed to add the semantic layer.
A partial configuration can be sent to improve the schema. For example,
providing account and auth will allow the schema to provide a list of
databases; providing a database will allow the schema to provide a list of
schemas.
Note that database and schema can both be left empty when the semantic layer is
added to Superset; the user will then have to provide them when loading
semantic views.
"""
schema = SnowflakeConfiguration.model_json_schema()
properties = schema["properties"]
if configuration is None:
# set these to empty; they will be populated when a partial configuration is
# passed
properties["database"]["enum"] = []
properties["schema"]["enum"] = []
return schema
connection_parameters = get_connection_parameters(configuration)
with connect(**connection_parameters) as connection:
if all(
getattr(configuration, dependency)
for dependency in properties["database"].get("x-dependsOn", [])
):
options = cls._fetch_databases(connection)
properties["database"]["enum"] = list(options)
if (
all(
getattr(configuration, dependency)
for dependency in properties["schema"].get("x-dependsOn", [])
)
and configuration.database
):
options = cls._fetch_schemas(connection, configuration.database)
properties["schema"]["enum"] = list(options)
return schema
@classmethod
def get_runtime_schema(
cls,
configuration: SnowflakeConfiguration,
runtime_data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Get the JSON schema for the runtime parameters needed to load semantic views.
The schema can be enriched with actual values when `runtime_data` is provided,
enabling dynamic schema updates (e.g., populating schema dropdown after
database is selected).
"""
fields: dict[str, tuple[Any, Field]] = {}
# update configuration with runtime data, for example, to select a schema after
# the database has been selected
configuration = configuration.model_copy(update=runtime_data)
connection_parameters = get_connection_parameters(configuration)
with connect(**connection_parameters) as connection:
if not configuration.database or configuration.allow_changing_database:
options = cls._fetch_databases(connection)
fields["database"] = (
Literal[*options],
Field(description="The default database to use."),
)
if not configuration.schema_ or configuration.allow_changing_schema:
if configuration.database:
options = cls._fetch_schemas(connection, configuration.database)
fields["schema_"] = (
Literal[*options],
Field(
description="The default schema to use.",
alias="schema",
json_schema_extra=(
{
"x-dynamic": True,
"x-dependsOn": ["database"],
}
if "database" in fields
else {}
),
),
)
else:
# Database not provided yet, add schema as empty
# (will be populated dynamically)
fields["schema_"] = (
str | None,
Field(
default=None,
description="The default schema to use.",
alias="schema",
json_schema_extra={
"x-dynamic": True,
"x-dependsOn": ["database"],
},
),
)
return create_model("RuntimeParameters", **fields).model_json_schema()
@classmethod
def _fetch_databases(cls, connection: SnowflakeConnection) -> set[str]:
"""
Fetch the list of databases available in the Snowflake account.
We use `SHOW DATABASES` instead of querying the information schema since it
allows retrieving the list of databases without having to specify a database
when connecting.
"""
cursor = connection.cursor()
cursor.execute("SHOW DATABASES")
return {row[1] for row in cursor}
@classmethod
def _fetch_schemas(
cls,
connection: SnowflakeConnection,
database: str | None,
) -> set[str]:
"""
Fetch the list of schemas available in a given database.
The connection should already have the database set in its context.
"""
if not database:
return set()
cursor = connection.cursor()
query = dedent(
"""
SELECT SCHEMA_NAME
FROM INFORMATION_SCHEMA.SCHEMATA
WHERE CATALOG_NAME = ?
"""
).strip()
return {row[0] for row in cursor.execute(query, (database,))}
def __init__(self, configuration: SnowflakeConfiguration):
self.configuration = configuration
def get_semantic_views(
self,
runtime_configuration: dict[str, Any],
) -> set[SnowflakeSemanticView]:
"""
Get the semantic views available in the semantic layer.
"""
# Avoid circular import
from superset.semantic_layers.snowflake.semantic_view import (
SnowflakeSemanticView,
)
# create a new configuration with the runtime parameters
configuration = self.configuration.model_copy(update=runtime_configuration)
connection_parameters = get_connection_parameters(configuration)
with connect(**connection_parameters) as connection:
cursor = connection.cursor()
query = dedent(
"""
SHOW SEMANTIC VIEWS
->> SELECT "name" FROM $1;
"""
).strip()
views = {
SnowflakeSemanticView(row[0], configuration)
for row in cursor.execute(query)
}
return views
def get_semantic_view(
self,
name: str,
additional_configuration: dict[str, Any],
) -> SnowflakeSemanticView:
"""
Get a specific semantic view by name.
"""
# Avoid circular import
from superset.semantic_layers.snowflake.semantic_view import (
SnowflakeSemanticView,
)
# create a new configuration with the additional parameters
configuration = self.configuration.model_copy(update=additional_configuration)
# check that the semantic view exists
connection_parameters = get_connection_parameters(configuration)
with connect(**connection_parameters) as connection:
cursor = connection.cursor()
query = dedent(
"""
SHOW SEMANTIC VIEWS
->> SELECT "name" FROM $1 WHERE "name" = ?;
"""
).strip()
cursor.execute(query, (name,))
rows = cursor.fetchall()
if not rows:
raise ValueError(f'Semantic view "{name}" does not exist.')
return SnowflakeSemanticView(name, configuration)

View File

@@ -0,0 +1,873 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: S608
from __future__ import annotations
import itertools
import re
from collections import defaultdict
from textwrap import dedent
from pandas import DataFrame
from snowflake.connector import connect, DictCursor
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect
from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.utils import (
get_connection_parameters,
substitute_parameters,
validate_order_by,
)
from superset.semantic_layers.types import (
AdhocExpression,
AdhocFilter,
BINARY,
BOOLEAN,
DATE,
DATETIME,
DECIMAL,
Dimension,
Filter,
FilterValues,
GroupLimit,
INTEGER,
Metric,
NUMBER,
OBJECT,
Operator,
OrderDirection,
OrderTuple,
PredicateType,
SemanticRequest,
SemanticResult,
SemanticViewFeature,
SemanticViewImplementation,
STRING,
TIME,
Type,
)
REQUEST_TYPE = "snowflake"
class SnowflakeSemanticView(SemanticViewImplementation):
features = frozenset(
{
SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY,
SemanticViewFeature.GROUP_LIMIT,
SemanticViewFeature.GROUP_OTHERS,
}
)
def __init__(self, name: str, configuration: SnowflakeConfiguration):
self.configuration = configuration
self.name = name
self._quote = SnowflakeDialect().identifier_preparer.quote
self.dimensions = self.get_dimensions()
self.metrics = self.get_metrics()
def uid(self) -> str:
return ".".join(
self._quote(part)
for part in (
self.configuration.database,
self.configuration.schema_,
self.name,
)
)
def get_dimensions(self) -> set[Dimension]:
"""
Get the dimensions defined in the semantic view.
Even though Snowflake supports `SHOW SEMANTIC DIMENSIONS IN my_semantic_view`,
it doesn't return the expression of dimensions, so we use a slightly more
complicated query to get all the information we need in one go.
"""
dimensions: set[Dimension] = set()
query = dedent(
f"""
DESC SEMANTIC VIEW {self.uid()}
->> SELECT "object_name", "property", "property_value"
FROM $1
WHERE
"object_kind" = 'DIMENSION' AND
"property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
"""
).strip()
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
cursor = connection.cursor(DictCursor)
rows = cursor.execute(query).fetchall()
for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
attributes = defaultdict(set)
for row in group:
attributes[row["property"]].add(row["property_value"])
table = next(iter(attributes["TABLE"]))
id_ = table + "." + name
type_ = self._get_type(next(iter(attributes["DATA_TYPE"])))
description = next(iter(attributes["COMMENT"]), None)
definition = next(iter(attributes["EXPRESSION"]), None)
dimensions.add(Dimension(id_, name, type_, description, definition))
return dimensions
def get_metrics(self) -> set[Metric]:
"""
Get the metrics defined in the semantic view.
"""
metrics: set[Metric] = set()
query = dedent(
f"""
DESC SEMANTIC VIEW {self.uid()}
->> SELECT "object_name", "property", "property_value"
FROM $1
WHERE
"object_kind" = 'METRIC' AND
"property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
"""
).strip()
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
cursor = connection.cursor(DictCursor)
rows = cursor.execute(query).fetchall()
for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
attributes = defaultdict(set)
for row in group:
attributes[row["property"]].add(row["property_value"])
table = next(iter(attributes["TABLE"]))
id_ = table + "." + name
type_ = self._get_type(next(iter(attributes["DATA_TYPE"])))
description = next(iter(attributes["COMMENT"]), None)
definition = next(iter(attributes["EXPRESSION"]), None)
metrics.add(Metric(id_, name, type_, definition, description))
return metrics
def _get_type(self, snowflake_type: str | None) -> type[Type]:
"""
Return the semantic type corresponding to a Snowflake type.
"""
if snowflake_type is None:
return STRING
type_map = {
STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"},
INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"},
DECIMAL: {r"NUMBER\(10,\s?2\)$"},
NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"},
BOOLEAN: {"BOOLEAN$"},
DATE: {"DATE$"},
DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP__NTZ$"},
TIME: {"TIME$"},
OBJECT: {"OBJECT$"},
BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"},
}
for semantic_type, patterns in type_map.items():
if any(
re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns
):
return semantic_type
return STRING
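# For example, "NUMBER(38, 0)" maps to INTEGER, "VARCHAR(255)" to STRING and
# "TIMESTAMP_TZ" to DATETIME; unknown types fall back to STRING.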
def _build_predicates(
self,
filters: list[Filter | AdhocFilter],
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Convert a set of filters to a single `AND`ed predicate.
Caller should check the types of filters beforehand, as this method does not
differentiate between `WHERE` and `HAVING` predicates.
"""
if not filters:
return "", ()
# convert filters into predicates with associated parameters; adhoc filters are
# already SQL strings, so their definitions are used as-is
unary_operators = {Operator.IS_NULL, Operator.IS_NOT_NULL}
predicates: list[str] = []
parameters: list[FilterValues] = []
for filter_ in filters or set():
if isinstance(filter_, AdhocFilter):
predicates.append(f"({filter_.definition})")
else:
predicates.append(f"({self._build_native_filter(filter_)})")
if filter_.operator not in unary_operators:
parameters.extend(
[filter_.value]
if not isinstance(filter_.value, (set, frozenset))
else filter_.value
)
return " AND ".join(predicates), tuple(parameters)
def get_values(
self,
dimension: Dimension,
filters: set[Filter | AdhocFilter] | None = None,
) -> SemanticResult:
"""
Return distinct values for a dimension.
"""
where_clause, parameters = self._build_predicates(
sorted(
filter_
for filter_ in (filters or [])
if filter_.type == PredicateType.WHERE
)
)
query = dedent(
f"""
SELECT {self._quote(dimension.name)}
FROM SEMANTIC_VIEW(
{self.uid()}
DIMENSIONS {dimension.id}
{"WHERE " + where_clause if where_clause else ""}
)
"""
).strip()
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
df = connection.cursor().execute(query, parameters).fetch_pandas_all()
return SemanticResult(
requests=[
SemanticRequest(
REQUEST_TYPE,
substitute_parameters(query, parameters),
)
],
results=df,
)
def _build_native_filter(self, filter_: Filter) -> str:
"""
Convert a Filter to a SQL predicate string with `?` placeholders for its values.
"""
column = filter_.column
operator = filter_.operator
value = filter_.value
column_name = self._quote(column.name)
# Handle IS NULL and IS NOT NULL operators (no value needed)
if operator in {Operator.IS_NULL, Operator.IS_NOT_NULL}:
return f"{column_name} {operator.value}"
# Handle IN and NOT IN operators (set values)
if operator in {Operator.IN, Operator.NOT_IN}:
parameter_count = len(value) if isinstance(value, (set, frozenset)) else 1
formatted_values = ", ".join("?" for _ in range(parameter_count))
return f"{column_name} {operator.value} ({formatted_values})"
return f"{column_name} {operator.value} ?"
def get_dataframe(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
*,
group_limit: GroupLimit | None = None,
) -> SemanticResult:
"""
Execute a query and return the results as a (wrapped) Pandas DataFrame.
"""
if not metrics and not dimensions:
return SemanticResult(requests=[], results=DataFrame())
query, parameters = self._get_query(
metrics,
dimensions,
filters,
order,
limit,
offset,
group_limit,
)
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
df = connection.cursor().execute(query, parameters).fetch_pandas_all()
# map column names to dimension/metric names instead of IDs
mapping = {
**{dimension.id: dimension.name for dimension in dimensions},
**{metric.id: metric.name for metric in metrics},
}
df.rename(columns=mapping, inplace=True)
return SemanticResult(
requests=[
SemanticRequest(
REQUEST_TYPE,
substitute_parameters(query, parameters),
)
],
results=df,
)
def get_row_count(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
*,
group_limit: GroupLimit | None = None,
) -> SemanticResult:
"""
Execute a query and return the number of rows the result would have.
"""
if not metrics and not dimensions:
return SemanticResult(
requests=[],
results=DataFrame([[0]], columns=["COUNT"]),
)
query, parameters = self._get_query(
metrics,
dimensions,
filters,
order,
limit,
offset,
group_limit,
)
query = f"SELECT COUNT(*) FROM ({query}) AS subquery"
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
count = connection.cursor().execute(query, parameters).fetchone()[0]
return SemanticResult(
requests=[
SemanticRequest(
REQUEST_TYPE,
substitute_parameters(query, parameters),
)
],
results=df,
)
def _get_query(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
group_limit: GroupLimit | None = None,
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Build a query to fetch data from the semantic view.
This also returns the parameters needed to run `cursor.execute()`, passed
separately to prevent SQL injection.
"""
if limit is None and offset is not None:
raise ValueError("Offset cannot be set without limit")
filters = filters or set()
where_clause, where_parameters = self._build_predicates(
# sort to ensure deterministic order for parameters
sorted(filter_ for filter_ in filters if filter_.type == PredicateType.WHERE)
)
# having clauses are not supported, since there's no GROUP BY
if any(filter_.type == PredicateType.HAVING for filter_ in filters):
raise ValueError("HAVING filters are not supported")
if group_limit:
query, cte_parameters = self._build_query_with_group_limit(
metrics,
dimensions,
where_clause,
order,
limit,
offset,
group_limit,
)
# Combine parameters: CTE params first, then main query params
all_parameters = cte_parameters + where_parameters
else:
query = self._build_simple_query(
metrics,
dimensions,
where_clause,
order,
limit,
offset,
)
all_parameters = where_parameters
return query, all_parameters
def _alias_element(self, element: Metric | Dimension) -> str:
"""
Generate an aliased column expression for a metric or dimension.
"""
return f"{element.id} AS {self._quote(element.id)}"
def _build_order_clause(
self,
order: list[OrderTuple] | None = None,
) -> str:
"""
Build the ORDER BY clause from a list of (element, direction) tuples.
Note that for adhoc expressions, Superset will still add `ASC` or `DESC` to the
end, which means adhoc expressions can contain multiple columns as long as the
last one has no direction specified.
This is fine:
gender ASC, COUNT(*)
But this is not:
gender ASC, COUNT(*) DESC
The latter will produce a query that looks like this:
... ORDER BY gender ASC, COUNT(*) DESC DESC
"""
if not order:
return ""
def build_element(element: Metric | Dimension | AdhocExpression) -> str:
if isinstance(element, AdhocExpression):
validate_order_by(element.definition)
return element.definition
return self._quote(element.id)
return ", ".join(
f"{build_element(element)} {direction.value}"
for element, direction in order
)
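# Illustrative output (hypothetical elements): a dimension ordered ascending
# followed by a metric ordered descending renders as
#
#     "gender" ASC, "orders.count" DESC
#
# whereas an AdhocExpression contributes its raw definition (validated first by
# validate_order_by) before the direction.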
def _get_temporal_dimension(
self,
dimensions: list[Dimension],
) -> Dimension | None:
"""
Find the first temporal dimension in the list.
Returns the first dimension with a temporal type (DATE, DATETIME, TIME),
or None if no temporal dimension is found.
"""
temporal_types = {DATE, DATETIME, TIME}
for dimension in dimensions:
if dimension.type in temporal_types:
return dimension
return None
def _get_default_order(
self,
dimensions: list[Dimension],
order: list[OrderTuple] | None,
) -> list[OrderTuple] | None:
"""
Get the order to use, prepending temporal sort if needed.
If there's a temporal dimension in the query and it's not already
in the order, prepends an ascending sort by that dimension.
This ensures time-series data is always sorted chronologically first.
"""
temporal_dimension = self._get_temporal_dimension(dimensions)
if not temporal_dimension:
return order
# Check if temporal dimension is already in the order
if order:
for element, _ in order:
if isinstance(element, Dimension) and element.id == temporal_dimension.id:
return order
# Prepend temporal dimension to existing order
return [(temporal_dimension, OrderDirection.ASC)] + list(order)
# No order specified, use temporal dimension
return [(temporal_dimension, OrderDirection.ASC)]
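# Illustrative behaviour (hypothetical dimensions): with a DATE dimension
# "order_date" in the query and no explicit order on it, the chronological sort
# is prepended:
#
#     _get_default_order([order_date, region], [(region, OrderDirection.DESC)])
#     # -> [(order_date, OrderDirection.ASC), (region, OrderDirection.DESC)]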
def _build_simple_query(
self,
metrics: list[Metric],
dimensions: list[Dimension],
where_clause: str,
order: list[OrderTuple] | None,
limit: int | None,
offset: int | None,
) -> str:
"""
Build a query without group limiting.
"""
dimension_arguments = ", ".join(
self._alias_element(dimension) for dimension in dimensions
)
metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)
# Use default temporal ordering if no explicit order is provided
effective_order = self._get_default_order(dimensions, order)
order_clause = self._build_order_clause(effective_order)
return dedent(
f"""
SELECT * FROM SEMANTIC_VIEW(
{self.uid()}
{"DIMENSIONS " + dimension_arguments if dimension_arguments else ""}
{"METRICS " + metric_arguments if metric_arguments else ""}
{"WHERE " + where_clause if where_clause else ""}
)
{"ORDER BY " + order_clause if order_clause else ""}
{"LIMIT " + str(limit) if limit is not None else ""}
{"OFFSET " + str(offset) if offset is not None else ""}
"""
).strip()
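# Illustrative shape of the generated SQL (assuming uid() returns the fully
# qualified view name and _quote() double-quotes identifiers; all names and
# values are hypothetical):
#
#     SELECT * FROM SEMANTIC_VIEW(
#         MY_DB.MY_SCHEMA.SALES_VIEW
#         DIMENSIONS customer.region AS "customer.region"
#         METRICS orders.total_sales AS "orders.total_sales"
#         WHERE "order_date" >= ?
#     )
#     LIMIT 1000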
def _build_top_groups_cte(
self,
group_limit: GroupLimit,
where_clause: str,
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Build a CTE that finds the top N combinations of limited dimensions.
If group_limit.filters is set, it uses those filters instead of the main
query's where clause. This allows using different time bounds for finding top
groups vs showing data.
Returns:
Tuple of (CTE SQL, parameters for the CTE)
"""
limited_dimension_arguments = ", ".join(
self._alias_element(dimension) for dimension in group_limit.dimensions
)
limited_dimension_names = ", ".join(
self._quote(dimension.id) for dimension in group_limit.dimensions
)
# Use separate filters for the group limit if provided; otherwise reuse the
# main query's WHERE clause
if group_limit.filters is not None:
group_where_clause, group_where_params = self._build_predicates(
sorted(
filter_
for filter_ in group_limit.filters
if filter_.type == PredicateType.WHERE
)
)
if any(
filter_.type == PredicateType.HAVING for filter_ in group_limit.filters
):
raise ValueError(
"HAVING filters are not supported in group limit filters"
)
cte_params = group_where_params
else:
group_where_clause = where_clause
cte_params = () # No additional params - using main query params
# Build METRICS clause and ORDER BY based on whether metric is provided
if group_limit.metric is not None:
metrics_clause = (
f"METRICS {group_limit.metric.id}"
f" AS {self._quote(group_limit.metric.id)}"
)
order_by_clause = (
f"{self._quote(group_limit.metric.id)} {group_limit.direction.value}"
)
else:
# No metric provided - order by first dimension
metrics_clause = ""
order_by_clause = (
f"{self._quote(group_limit.dimensions[0].id)} "
f"{group_limit.direction.value}"
)
# Build SEMANTIC_VIEW arguments
semantic_view_args = [
f"DIMENSIONS {limited_dimension_arguments}",
]
if metrics_clause:
semantic_view_args.append(metrics_clause)
if group_where_clause:
semantic_view_args.append(f"WHERE {group_where_clause}")
semantic_view_args_str = "\n ".join(semantic_view_args)
# Add trailing blank line if there's no WHERE clause
# This matches the original template behavior
if not group_where_clause:
semantic_view_args_str += "\n"
cte_sql = dedent(
f"""
WITH top_groups AS (
SELECT {limited_dimension_names}
FROM SEMANTIC_VIEW(
{self.uid()}
{semantic_view_args_str}
)
ORDER BY
{order_by_clause}
LIMIT {group_limit.top}
)
"""
).strip()
return cte_sql, cte_params
def _build_group_filter(self, group_limit: GroupLimit) -> str:
"""
Build a WHERE filter that restricts results to top N groups.
"""
if len(group_limit.dimensions) == 1:
dimension_id = self._quote(group_limit.dimensions[0].id)
return f"{dimension_id} IN (SELECT {dimension_id} FROM top_groups)"
# Multi-column IN clause
dimension_tuple = ", ".join(
self._quote(dim.id) for dim in group_limit.dimensions
)
return f"({dimension_tuple}) IN (SELECT {dimension_tuple} FROM top_groups)"
def _build_case_expression(
self,
dimension: Dimension,
group_condition: str,
) -> str:
"""
Build a CASE expression that replaces non-top values with 'Other'.
Args:
dimension: The dimension to build the CASE for
group_condition: The condition to check if value is in top groups
(e.g., "staff_id IN (SELECT staff_id FROM top_groups)")
Returns:
SQL CASE expression
"""
dimension_id = self._quote(dimension.id)
return f"""CASE
WHEN {group_condition} THEN {dimension_id}
ELSE CAST('Other' AS VARCHAR)
END"""
def _build_query_with_others(
self,
metrics: list[Metric],
dimensions: list[Dimension],
where_clause: str,
order: list[OrderTuple] | None,
limit: int | None,
offset: int | None,
group_limit: GroupLimit,
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Build a query that groups non-top N values as 'Other'.
This uses a three-stage approach:
1. CTE to find top N groups
2. Subquery with CASE expressions to replace non-top values with 'Other'
3. Outer query to re-aggregate with the new grouping
Returns:
Tuple of (SQL query, CTE parameters)
"""
top_groups_cte, cte_params = self._build_top_groups_cte(
group_limit,
where_clause,
)
# Determine which dimensions are limited vs non-limited
limited_dimension_ids = {dim.id for dim in group_limit.dimensions}
non_limited_dimensions = [
dim for dim in dimensions if dim.id not in limited_dimension_ids
]
# Build the group condition for CASE expressions
if len(group_limit.dimensions) == 1:
dimension_id = self._quote(group_limit.dimensions[0].id)
group_condition = (
f"{dimension_id} IN (SELECT {dimension_id} FROM top_groups)"
)
else:
dimension_tuple = ", ".join(
self._quote(dim.id) for dim in group_limit.dimensions
)
group_condition = (
f"({dimension_tuple}) IN (SELECT {dimension_tuple} FROM top_groups)"
)
# Build CASE expressions for limited dimensions
case_expressions = []
case_expressions_for_groupby = []
for dim in group_limit.dimensions:
case_expr = self._build_case_expression(dim, group_condition)
alias = self._quote(dim.id)
case_expressions.append(f"{case_expr} AS {alias}")
# Store the full CASE expression for GROUP BY (not just alias)
case_expressions_for_groupby.append(case_expr)
# Build SELECT for non-limited dimensions (pass through)
non_limited_selects = [
f"{self._quote(dim.id)} AS {self._quote(dim.id)}"
for dim in non_limited_dimensions
]
# Build metric aggregations
metric_aggregations = [
f"SUM({self._quote(metric.id)}) AS {self._quote(metric.id)}"
for metric in metrics
]
# Build the subquery that gets raw data from SEMANTIC_VIEW
dimension_arguments = ", ".join(
self._alias_element(dimension) for dimension in dimensions
)
metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)
subquery = dedent(
f"""
raw_data AS (
SELECT * FROM SEMANTIC_VIEW(
{self.uid()}
DIMENSIONS {dimension_arguments}
METRICS {metric_arguments}
{"WHERE " + where_clause if where_clause else ""}
)
)
"""
).strip()
# Build GROUP BY clause (full CASE expressions + non-limited dimensions)
# We need to repeat the full CASE expressions, not use aliases, because
# Snowflake may interpret the alias as the original column reference
group_by_columns = case_expressions_for_groupby + [
self._quote(dim.id) for dim in non_limited_dimensions
]
group_by_clause = ", ".join(group_by_columns)
# Build final SELECT columns
select_columns = case_expressions + non_limited_selects + metric_aggregations
select_clause = ",\n ".join(select_columns)
# Build ORDER BY clause (need to reference the aliased columns)
# Use default temporal ordering if no explicit order is provided
effective_order = self._get_default_order(dimensions, order)
order_clause = self._build_order_clause(effective_order)
query = dedent(
f"""
{top_groups_cte},
{subquery}
SELECT
{select_clause}
FROM raw_data
GROUP BY {group_by_clause}
{"ORDER BY " + order_clause if order_clause else ""}
{"LIMIT " + str(limit) if limit is not None else ""}
{"OFFSET " + str(offset) if offset is not None else ""}
"""
).strip()
return query, cte_params
def _build_query_with_group_limit(
self,
metrics: list[Metric],
dimensions: list[Dimension],
where_clause: str,
order: list[OrderTuple] | None,
limit: int | None,
offset: int | None,
group_limit: GroupLimit,
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Build a query with group limiting (top N groups).
If group_others is True, groups non-top values as 'Other'.
Otherwise, filters to show only top N groups.
Returns:
Tuple of (SQL query, CTE parameters)
"""
if group_limit.group_others:
return self._build_query_with_others(
metrics,
dimensions,
where_clause,
order,
limit,
offset,
group_limit,
)
# Standard group limiting: just filter to top N groups
# We can't use CTE references inside SEMANTIC_VIEW(), so we wrap it
dimension_arguments = ", ".join(
self._alias_element(dimension) for dimension in dimensions
)
metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)
# Use default temporal ordering if no explicit order is provided
effective_order = self._get_default_order(dimensions, order)
order_clause = self._build_order_clause(effective_order)
top_groups_cte, cte_params = self._build_top_groups_cte(
group_limit,
where_clause,
)
group_filter = self._build_group_filter(group_limit)
query = dedent(
f"""
{top_groups_cte}
SELECT * FROM SEMANTIC_VIEW(
{self.uid()}
{"DIMENSIONS " + dimension_arguments if dimension_arguments else ""}
{"METRICS " + metric_arguments if metric_arguments else ""}
{"WHERE " + where_clause if where_clause else ""}
) AS subquery
WHERE {group_filter}
{"ORDER BY " + order_clause if order_clause else ""}
{"LIMIT " + str(limit) if limit is not None else ""}
{"OFFSET " + str(offset) if offset is not None else ""}
"""
).strip()
return query, cte_params
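# Illustrative shape of the group-limited query when group_others is False
# (hypothetical names): the CTE picks the top N groups, and the outer query
# filters the SEMANTIC_VIEW results to those groups:
#
#     WITH top_groups AS (
#         SELECT "staff_id"
#         FROM SEMANTIC_VIEW(
#             MY_DB.MY_SCHEMA.SALES_VIEW
#             DIMENSIONS staff_id AS "staff_id"
#             METRICS orders.count AS "orders.count"
#         )
#         ORDER BY
#             "orders.count" DESC
#         LIMIT 5
#     )
#     SELECT * FROM SEMANTIC_VIEW(...) AS subquery
#     WHERE "staff_id" IN (SELECT "staff_id" FROM top_groups)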
__repr__ = uid

View File

@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: S608
from __future__ import annotations
from typing import Any, Sequence
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from superset.exceptions import SupersetParseError
from superset.semantic_layers.snowflake.schemas import (
PrivateKeyAuth,
SnowflakeConfiguration,
UserPasswordAuth,
)
from superset.sql.parse import SQLStatement
def substitute_parameters(query: str, parameters: Sequence[Any] | None) -> str:
"""
Substitute parameters into a templated query.
This is used to inline bound query parameters so that we can return the executed
query for logging/auditing purposes. With Snowflake the binding happens on the
server, so the only way to get the true executed query would be to query the
database, which is inefficient.
"""
if not parameters:
return query
result = query
for parameter in parameters:
if parameter is None:
replacement = "NULL"
elif isinstance(parameter, bool):
# Check bool before int/float since bool is a subclass of int
replacement = str(parameter).upper()
elif isinstance(parameter, (int, float)):
replacement = str(parameter)
else:
# String - escape single quotes
quoted = str(parameter).replace("'", "''")
replacement = f"'{quoted}'"
result = result.replace("?", replacement, 1)
return result
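# Illustrative usage (hypothetical query and values), e.g. when logging the
# executed statement:
#
#     substitute_parameters(
#         "SELECT * FROM t WHERE name = ? AND n IN (?, ?) AND active = ?",
#         ["O'Brien", 1, None, True],
#     )
#     # -> "SELECT * FROM t WHERE name = 'O''Brien' AND n IN (1, NULL) AND active = TRUE"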
def validate_order_by(definition: str) -> None:
"""
Validate that an ORDER BY expression is safe to use.
Note that `definition` could contain multiple expressions separated by commas.
"""
try:
# this ensures that we have a single statement, preventing SQL injection via a
# semicolon in the order by clause
SQLStatement(f"SELECT 1 ORDER BY {definition}", "snowflake")
except SupersetParseError as ex:
raise ValueError("Invalid ORDER BY expression") from ex
def get_connection_parameters(configuration: SnowflakeConfiguration) -> dict[str, Any]:
"""
Convert the configuration to connection parameters for the Snowflake connector.
"""
params = {
"account": configuration.account_identifier,
"application": "Apache Superset",
"paramstyle": "qmark",
"insecure_mode": True,
}
if configuration.role:
params["role"] = configuration.role
if configuration.warehouse:
params["warehouse"] = configuration.warehouse
if configuration.database:
params["database"] = configuration.database
if configuration.schema_:
params["schema"] = configuration.schema_
auth = configuration.auth
if isinstance(auth, UserPasswordAuth):
params["user"] = auth.username
params["password"] = auth.password.get_secret_value()
elif isinstance(auth, PrivateKeyAuth):
pem_private_key = serialization.load_pem_private_key(
auth.private_key.get_secret_value().encode(),
password=(
auth.private_key_password.get_secret_value().encode()
if auth.private_key_password
else None
),
backend=default_backend(),
)
params["private_key"] = pem_private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
else:
raise ValueError("Unsupported authentication method")
return params
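# Illustrative result (hypothetical values) for a user/password configuration
# with only a warehouse set:
#
#     {
#         "account": "myorg-myaccount",
#         "application": "Apache Superset",
#         "paramstyle": "qmark",
#         "insecure_mode": True,
#         "warehouse": "COMPUTE_WH",
#         "user": "reporting",
#         "password": "********",
#     }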

View File

@@ -0,0 +1,497 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import enum
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from functools import total_ordering
from typing import Any, Protocol, runtime_checkable, TypeVar
from pandas import DataFrame
from pydantic import BaseModel
__all__ = [
"BINARY",
"BOOLEAN",
"DATE",
"DATETIME",
"DECIMAL",
"Day",
"Dimension",
"Hour",
"INTEGER",
"INTERVAL",
"Minute",
"Month",
"NUMBER",
"OBJECT",
"Quarter",
"Second",
"STRING",
"TIME",
"Week",
"Year",
]
class Type:
"""
Base class for types.
"""
class INTEGER(Type):
"""
Represents an integer type.
"""
class NUMBER(Type):
"""
Represents a number type.
"""
class DECIMAL(Type):
"""
Represents a decimal type.
"""
class STRING(Type):
"""
Represents a string type.
"""
class BOOLEAN(Type):
"""
Represents a boolean type.
"""
class DATE(Type):
"""
Represents a date type.
"""
class TIME(Type):
"""
Represents a time type.
"""
class DATETIME(DATE, TIME):
"""
Represents a datetime type.
"""
class INTERVAL(Type):
"""
Represents an interval type.
"""
class OBJECT(Type):
"""
Represents an object type.
"""
class BINARY(Type):
"""
Represents a binary type.
"""
@dataclass(frozen=True)
@total_ordering
class Grain:
"""
Base class for time and date grains with comparison support.
Attributes:
name: Human-readable name of the grain (e.g., "Second")
representation: ISO 8601 representation (e.g., "PT1S")
value: Time period as a timedelta
"""
name: str
representation: str
value: timedelta
def __eq__(self, other: object) -> bool:
if isinstance(other, Grain):
return self.value == other.value
return NotImplemented
def __lt__(self, other: object) -> bool:
if isinstance(other, Grain):
return self.value < other.value
return NotImplemented
def __hash__(self) -> int:
return hash((self.name, self.representation, self.value))
class Second(Grain):
name = "Second"
representation = "PT1S"
value = timedelta(seconds=1)
class Minute(Grain):
name = "Minute"
representation = "PT1M"
value = timedelta(minutes=1)
class Hour(Grain):
name = "Hour"
representation = "PT1H"
value = timedelta(hours=1)
class Day(Grain):
name = "Day"
representation = "P1D"
value = timedelta(days=1)
class Week(Grain):
name = "Week"
representation = "P1W"
value = timedelta(weeks=1)
class Month(Grain):
name = "Month"
representation = "P1M"
value = timedelta(days=30)
class Quarter(Grain):
name = "Quarter"
representation = "P3M"
value = timedelta(days=90)
class Year(Grain):
name = "Year"
representation = "P1Y"
value = timedelta(days=365)
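# Illustrative comparison: total_ordering sorts grains by their timedelta value,
# so a finer grain compares as "less than" a coarser one. The instances below are
# constructed directly from the base dataclass:
#
#     hour = Grain("Hour", "PT1H", timedelta(hours=1))
#     day = Grain("Day", "P1D", timedelta(days=1))
#     assert hour < day and day >= hour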
@dataclass(frozen=True)
class Dimension:
id: str
name: str
type: type[Type]
definition: str | None = None
description: str | None = None
grain: Grain | None = None
@dataclass(frozen=True)
class Metric:
id: str
name: str
type: type[Type]
definition: str | None
description: str | None = None
@dataclass(frozen=True)
class AdhocExpression:
id: str
definition: str
class Operator(str, enum.Enum):
EQUALS = "="
NOT_EQUALS = "!="
GREATER_THAN = ">"
LESS_THAN = "<"
GREATER_THAN_OR_EQUAL = ">="
LESS_THAN_OR_EQUAL = "<="
IN = "IN"
NOT_IN = "NOT IN"
LIKE = "LIKE"
NOT_LIKE = "NOT LIKE"
IS_NULL = "IS NULL"
IS_NOT_NULL = "IS NOT NULL"
FilterValues = str | int | float | bool | datetime | date | time | timedelta | None
class PredicateType(enum.Enum):
WHERE = "WHERE"
HAVING = "HAVING"
@dataclass(frozen=True, order=True)
class Filter:
type: PredicateType
column: Dimension | Metric
operator: Operator
value: FilterValues | set[FilterValues]
@dataclass(frozen=True, order=True)
class AdhocFilter:
type: PredicateType
definition: str
class OrderDirection(enum.Enum):
ASC = "ASC"
DESC = "DESC"
OrderTuple = tuple[Metric | Dimension | AdhocExpression, OrderDirection]
@dataclass(frozen=True)
class GroupLimit:
"""
Limit query to top/bottom N combinations of specified dimensions.
The `filters` parameter allows specifying separate filter constraints for the
group limit subquery. This is useful when you want to determine the top N groups
using different criteria (e.g., a different time range) than the main query.
For example, you might want to find the top 10 products by sales over the last
30 days, but then show daily sales for those products over the last 7 days.
"""
dimensions: list[Dimension]
top: int
metric: Metric | None
direction: OrderDirection = OrderDirection.DESC
group_others: bool = False
filters: set[Filter | AdhocFilter] | None = None
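# A hedged sketch of the docstring's scenario (all names hypothetical): find the
# top 10 products by sales over the last 30 days, independent of the main
# query's 7-day filter:
#
#     product = Dimension(id="product", name="Product", type=STRING)
#     sales = Metric(id="sales", name="Sales", type=NUMBER, definition=None)
#     last_30_days = AdhocFilter(
#         type=PredicateType.WHERE,
#         definition="order_date >= DATEADD('day', -30, CURRENT_DATE)",
#     )
#     group_limit = GroupLimit(
#         dimensions=[product],
#         top=10,
#         metric=sales,
#         filters={last_30_days},
#     )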
@dataclass(frozen=True)
class SemanticRequest:
"""
Represents a request made to obtain semantic results.
This could be a SQL query, an HTTP request, etc.
"""
type: str
definition: str
@dataclass(frozen=True)
class SemanticResult:
"""
Represents the results of a semantic query.
This includes any requests (SQL queries, HTTP requests) that were performed to
obtain the results, to help with troubleshooting.
"""
requests: list[SemanticRequest]
results: DataFrame
@dataclass(frozen=True)
class SemanticQuery:
"""
Represents a semantic query.
"""
metrics: list[Metric]
dimensions: list[Dimension]
filters: set[Filter | AdhocFilter] | None = None
order: list[OrderTuple] | None = None
limit: int | None = None
offset: int | None = None
group_limit: GroupLimit | None = None
class SemanticViewFeature(enum.Enum):
"""
Custom features supported by semantic layers.
"""
ADHOC_EXPRESSIONS_IN_ORDERBY = "ADHOC_EXPRESSIONS_IN_ORDERBY"
GROUP_LIMIT = "GROUP_LIMIT"
GROUP_OTHERS = "GROUP_OTHERS"
ConfigT = TypeVar("ConfigT", bound=BaseModel, contravariant=True)
SemanticViewT = TypeVar("SemanticViewT", bound="SemanticViewImplementation")
@runtime_checkable
class SemanticLayerImplementation(Protocol[ConfigT, SemanticViewT]):
"""
A protocol for semantic layers.
"""
@classmethod
def from_configuration(
cls,
configuration: dict[str, Any],
) -> SemanticLayerImplementation[ConfigT, SemanticViewT]:
"""
Create a semantic layer from its configuration.
"""
@classmethod
def get_configuration_schema(
cls,
configuration: ConfigT | None = None,
) -> dict[str, Any]:
"""
Get the JSON schema for the configuration needed to add the semantic layer.
A partial configuration `configuration` can be sent to improve the schema,
allowing for progressive validation and better UX. For example, a semantic
layer might require:
- auth information
- a database
If the user provides the auth information, a client can send the partial
configuration to this method, and the resulting JSON schema would include
the list of databases the user has access to, allowing a dropdown to be
populated.
The Snowflake semantic layer has an example implementation of this method, where
database and schema names are populated based on the provided connection info.
"""
@classmethod
def get_runtime_schema(
cls,
configuration: ConfigT,
runtime_data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Get the JSON schema for the runtime parameters needed to load semantic views.
This returns the schema needed to connect to a semantic view given the
configuration for the semantic layer. For example, a semantic layer might
be configured by:
- auth information
- an optional database
If the user does not provide a database when creating the semantic layer, the
runtime schema would require the database name to be provided before loading any
semantic views. This allows users to create semantic layers that connect to a
specific database (or project, account, etc.), or that allow users to select it
at query time.
The Snowflake semantic layer has an example implementation of this method, where
database and schema names are required if they were not provided in the initial
configuration.
"""
def get_semantic_views(
self,
runtime_configuration: dict[str, Any],
) -> set[SemanticViewT]:
"""
Get the semantic views available in the semantic layer.
The runtime configuration can provide information like a given project or
schema, used to restrict the semantic views returned.
"""
def get_semantic_view(
self,
name: str,
additional_configuration: dict[str, Any],
) -> SemanticViewT:
"""
Get a specific semantic view by its name and additional configuration.
"""
@runtime_checkable
class SemanticViewImplementation(Protocol):
"""
A protocol for semantic views.
"""
features: frozenset[SemanticViewFeature]
def uid(self) -> str:
"""
Returns a unique identifier for the semantic view.
"""
def get_dimensions(self) -> set[Dimension]:
"""
Get the dimensions defined in the semantic view.
"""
def get_metrics(self) -> set[Metric]:
"""
Get the metrics defined in the semantic view.
"""
def get_values(
self,
dimension: Dimension,
filters: set[Filter | AdhocFilter] | None = None,
) -> SemanticResult:
"""
Return distinct values for a dimension.
"""
def get_dataframe(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
*,
group_limit: GroupLimit | None = None,
) -> SemanticResult:
"""
Execute a semantic query and return the results as a DataFrame.
"""
def get_row_count(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
*,
group_limit: GroupLimit | None = None,
) -> SemanticResult:
"""
Execute a query and return the number of rows the result would have.
"""

View File

@@ -57,6 +57,46 @@ class AdhocMetric(TypedDict, total=False):
sqlExpression: str | None
class DatasetColumnData(TypedDict, total=False):
"""Type for column metadata in ExplorableData datasets."""
advanced_data_type: str | None
certification_details: str | None
certified_by: str | None
column_name: str
description: str | None
expression: str | None
filterable: bool
groupby: bool
id: int
uuid: str | None
is_certified: bool
is_dttm: bool
python_date_format: str | None
type: str
type_generic: NotRequired["GenericDataType" | None]
verbose_name: str | None
warning_markdown: str | None
class DatasetMetricData(TypedDict, total=False):
"""Type for metric metadata in ExplorableData datasets."""
certification_details: str | None
certified_by: str | None
currency: NotRequired[dict[str, Any]]
d3format: str | None
description: str | None
expression: str
id: int
uuid: str | None
is_certified: bool
metric_name: str
warning_markdown: str | None
warning_text: str | None
verbose_name: str | None
class AdhocColumn(TypedDict, total=False):
hasCustomLabel: bool | None
label: str
@@ -195,15 +235,19 @@ class QueryObjectDict(TypedDict, total=False):
timeseries_limit_metric: Metric | None
class BaseDatasourceData(TypedDict, total=False):
class ExplorableData(TypedDict, total=False):
"""
TypedDict for datasource data returned to the frontend.
TypedDict for explorable data returned to the frontend.
This represents the structure of the dictionary returned from BaseDatasource.data
property. It provides datasource information to the frontend for visualization
and querying.
This represents the structure of the dictionary returned from the `data` property
of any Explorable (BaseDatasource, Query, etc.). It provides datasource/query
information to the frontend for visualization and querying.
Core fields from BaseDatasource.data:
All fields are optional (total=False) since different explorable types provide
different subsets of these fields. Query objects provide a minimal subset while
SqlaTable provides the full set.
Core fields:
id: Unique identifier for the datasource
uid: Unique identifier including type (e.g., "1__table")
column_formats: D3 format strings for columns
@@ -268,8 +312,8 @@ class BaseDatasourceData(TypedDict, total=False):
perm: str | None
edit_url: str
sql: str | None
columns: list[dict[str, Any]]
metrics: list[dict[str, Any]]
columns: list["DatasetColumnData"]
metrics: list["DatasetMetricData"]
folders: Any # JSON field, can be list or dict
order_by_choices: list[tuple[str, str]]
owners: list[int] | list[dict[str, Any]] # Can be either format
@@ -277,8 +321,8 @@ class BaseDatasourceData(TypedDict, total=False):
select_star: str | None
# Additional fields from SqlaTable and data_for_slices
column_types: list[Any]
column_names: set[str] | set[Any]
column_types: list["GenericDataType"]
column_names: set[str]
granularity_sqla: list[tuple[Any, Any]]
time_grain_sqla: list[tuple[Any, Any]]
main_dttm_col: str | None
@@ -291,46 +335,6 @@ class BaseDatasourceData(TypedDict, total=False):
normalize_columns: bool
class QueryData(TypedDict, total=False):
"""
TypedDict for SQL Lab query data returned to the frontend.
This represents the structure of the dictionary returned from Query.data property
in SQL Lab. It provides query information to the frontend for execution and display.
Fields:
time_grain_sqla: Available time grains for this database
filter_select: Whether filter select is enabled
name: Query tab name
columns: List of column definitions
metrics: List of metrics (always empty for queries)
id: Query ID
type: Object type (always "query")
sql: SQL query text
owners: List of owner information
database: Database connection details
order_by_choices: Available ordering options
catalog: Catalog name if applicable
schema: Schema name if applicable
verbose_map: Mapping of column names to verbose names (empty for queries)
"""
time_grain_sqla: list[tuple[Any, Any]]
filter_select: bool
name: str | None
columns: list[dict[str, Any]]
metrics: list[Any]
id: int
type: str
sql: str | None
owners: list[dict[str, Any]]
database: dict[str, Any]
order_by_choices: list[tuple[str, str]]
catalog: str | None
schema: str | None
verbose_map: dict[str, str]
VizData: TypeAlias = list[Any] | dict[Any, Any] | None
VizPayload: TypeAlias = dict[str, Any]

View File

@@ -61,6 +61,7 @@ def _adjust_string_with_rls(
"""
Add the RLS filters to the unique string based on current executor.
"""
user = (
security_manager.find_user(executor)
or security_manager.get_current_guest_user_if_guest()
@@ -70,11 +71,7 @@ def _adjust_string_with_rls(
stringified_rls = ""
with override_user(user):
for datasource in datasources:
if (
datasource
and hasattr(datasource, "is_rls_supported")
and datasource.is_rls_supported
):
if datasource and getattr(datasource, "is_rls_supported", False):
rls_filters = datasource.get_sqla_row_level_filters()
if len(rls_filters) > 0:

View File

@@ -96,6 +96,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
from superset.explorables.base import Explorable
from superset.sql.parse import sanitize_clause
from superset.superset_typing import (
AdhocColumn,
@@ -114,9 +115,8 @@ from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str
from superset.utils.pandas import detect_datetime_format
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource, TableColumn
from superset.connectors.sqla.models import TableColumn
from superset.models.core import Database
from superset.models.sql_lab import Query
logging.getLogger("MARKDOWN").setLevel(logging.INFO)
logger = logging.getLogger(__name__)
@@ -200,6 +200,7 @@ class DatasourceType(StrEnum):
QUERY = "query"
SAVEDQUERY = "saved_query"
VIEW = "view"
SEMANTIC_VIEW = "semantic_view"
class LoggerLevel(StrEnum):
@@ -1656,7 +1657,7 @@ def map_sql_type_to_inferred_type(sql_type: Optional[str]) -> str:
return "string" # If no match is found, return "string" as default
def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query) -> str:
def get_metric_type_from_column(column: Any, datasource: Explorable) -> str:
"""
Determine the metric type from a given column in a datasource.
@@ -1698,7 +1699,7 @@ def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query)
def extract_dataframe_dtypes(
df: pd.DataFrame,
datasource: BaseDatasource | Query | None = None,
datasource: Explorable | None = None,
) -> list[GenericDataType]:
"""Serialize pandas/numpy dtypes to generic types"""
@@ -1718,7 +1719,8 @@ def extract_dataframe_dtypes(
if datasource:
for column in datasource.columns:
if isinstance(column, dict):
columns_by_name[column.get("column_name")] = column
if column_name := column.get("column_name"):
columns_by_name[column_name] = column
else:
columns_by_name[column.column_name] = column
@@ -1768,11 +1770,13 @@ def is_test() -> bool:
def get_time_filter_status(
datasource: BaseDatasource,
datasource: Explorable,
applied_time_extras: dict[str, str],
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
temporal_columns: set[Any] = {
col.column_name for col in datasource.columns if col.is_dttm
(col.column_name if hasattr(col, "column_name") else col.get("column_name"))
for col in datasource.columns
if (col.is_dttm if hasattr(col, "is_dttm") else col.get("is_dttm"))
}
applied: list[dict[str, str]] = []
rejected: list[dict[str, str]] = []

View File

@@ -78,9 +78,8 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.user_attributes import UserAttribute
from superset.superset_typing import (
BaseDatasourceData,
ExplorableData,
FlaskResponse,
QueryData,
)
from superset.tasks.utils import get_current_user
from superset.utils import core as utils, json
@@ -531,14 +530,14 @@ class Superset(BaseSupersetView):
)
standalone_mode = ReservedUrlParameters.is_standalone_mode()
force = request.args.get("force") in {"force", "1", "true"}
dummy_datasource_data: BaseDatasourceData = {
dummy_datasource_data: ExplorableData = {
"type": datasource_type or "unknown",
"name": datasource_name,
"columns": [],
"metrics": [],
"database": {"id": 0, "backend": ""},
}
datasource_data: BaseDatasourceData | QueryData
datasource_data: ExplorableData
try:
datasource_data = datasource.data if datasource else dummy_datasource_data
except (SupersetException, SQLAlchemyError):

View File

@@ -20,6 +20,7 @@ from collections import defaultdict
from functools import wraps
from typing import Any, Callable, DefaultDict, Optional, Union
from urllib import parse
from uuid import UUID
import msgpack
import pyarrow as pa
@@ -46,10 +47,9 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.superset_typing import (
BaseDatasourceData,
ExplorableData,
FlaskResponse,
FormData,
QueryData,
)
from superset.utils import json
from superset.utils.core import DatasourceType
@@ -92,12 +92,10 @@ def redirect_to_login(next_target: str | None = None) -> FlaskResponse:
def sanitize_datasource_data(
datasource_data: BaseDatasourceData | QueryData,
datasource_data: ExplorableData,
) -> dict[str, Any]:
"""
Sanitize datasource data by removing sensitive database parameters.
Accepts TypedDict types (BaseDatasourceData, QueryData).
"""
if datasource_data:
datasource_database = datasource_data.get("database")
@@ -275,8 +273,10 @@ def add_sqllab_custom_filters(form_data: dict[Any, Any]) -> Any:
def get_datasource_info(
datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData
) -> tuple[int, Optional[str]]:
datasource_id: Optional[int],
datasource_type: Optional[str],
form_data: FormData,
) -> tuple[int | UUID, Optional[str]]:
"""
Compatibility layer for handling of datasource info
@@ -303,7 +303,11 @@ def get_datasource_info(
_("The dataset associated with this chart no longer exists")
)
datasource_id = int(datasource_id)
if datasource_id.isdigit():
datasource_id = int(datasource_id)
else:
datasource_id = UUID(datasource_id)
return datasource_id, datasource_type

View File

@@ -39,32 +39,26 @@ processor = QueryContextProcessor(
)
# Bind ExploreMixin methods to datasource for testing
processor._qc_datasource.add_offset_join_column = (
ExploreMixin.add_offset_join_column.__get__(processor._qc_datasource)
# Type annotation needed because _qc_datasource is typed as Explorable in protocol
_datasource: BaseDatasource = processor._qc_datasource # type: ignore
_datasource.add_offset_join_column = ExploreMixin.add_offset_join_column.__get__(
_datasource
)
processor._qc_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__(
processor._qc_datasource
_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__(_datasource)
_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__(_datasource)
_datasource._determine_join_keys = ExploreMixin._determine_join_keys.__get__(
_datasource
)
processor._qc_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__(
processor._qc_datasource
)
processor._qc_datasource._determine_join_keys = (
ExploreMixin._determine_join_keys.__get__(processor._qc_datasource)
)
processor._qc_datasource._perform_join = ExploreMixin._perform_join.__get__(
processor._qc_datasource
)
processor._qc_datasource._apply_cleanup_logic = (
ExploreMixin._apply_cleanup_logic.__get__(processor._qc_datasource)
_datasource._perform_join = ExploreMixin._perform_join.__get__(_datasource)
_datasource._apply_cleanup_logic = ExploreMixin._apply_cleanup_logic.__get__(
_datasource
)
# Static methods don't need binding - assign directly
processor._qc_datasource.generate_join_column = ExploreMixin.generate_join_column
processor._qc_datasource.is_valid_date_range_static = (
ExploreMixin.is_valid_date_range_static
)
_datasource.generate_join_column = ExploreMixin.generate_join_column
_datasource.is_valid_date_range_static = ExploreMixin.is_valid_date_range_static
# Convenience reference for backward compatibility in tests
query_context_processor = processor._qc_datasource
query_context_processor = _datasource
@fixture