Mirror of https://github.com/apache/superset.git, synced 2026-04-28 12:34:23 +00:00.

Compare commits: fdf19db5e6 ... semantic-l (20 commits)
| SHA1 |
|---|
| b9ab0ced77 |
| bfbb68c3c8 |
| b437421a8e |
| e253bd2fb3 |
| bfb7048e42 |
| 2833b69ca0 |
| 6e17714a19 |
| 8a0aaa42ec |
| af479a9d99 |
| 77f60f42e6 |
| f0121a166e |
| 0c4b0cb9b9 |
| a36bbf8ffd |
| 99525c1ce9 |
| 889e9bbade |
| b809a990ee |
| 9c7fcbf548 |
| 046aabee73 |
| b672c7b853 |
| ea33d797a7 |
```diff
@@ -53,7 +53,7 @@ extension-pkg-whitelist=pyarrow

 [MESSAGES CONTROL]
 disable=all
-enable=disallowed-json-import,disallowed-sql-import,consider-using-transaction
+enable=json-import,disallowed-sql-import,consider-using-transaction


 [REPORTS]
```
`setup.py`:

```diff
@@ -30,7 +30,9 @@ with open(PACKAGE_JSON) as package_file:

 def get_git_sha() -> str:
     try:
-        output = subprocess.check_output(["git", "rev-parse", "HEAD"])  # noqa: S603, S607
+        output = subprocess.check_output(
+            ["git", "rev-parse", "HEAD"]
+        )  # noqa: S603, S607
         return output.decode().strip()
     except Exception:  # pylint: disable=broad-except
         return ""
@@ -58,6 +60,9 @@ setup(
     include_package_data=True,
     zip_safe=False,
     entry_points={
+        "superset.semantic_layers": [
+            "snowflake = superset.semantic_layers.snowflake:SnowflakeSemanticLayer"
+        ],
         "console_scripts": ["superset=superset.cli.main:superset"],
         # the `postgres` and `postgres+psycopg2://` schemes were removed in SQLAlchemy 1.4  # noqa: E501
         # add an alias here to prevent breaking existing databases
```
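The new `superset.semantic_layers` entry-point group is how additional semantic-layer implementations can be registered alongside the bundled Snowflake one. A minimal discovery sketch using only the standard library (the loop body is illustrative, and assumes the Python 3.10+ `entry_points(group=...)` API; nothing here is from the diff except the group name):

```python
from importlib.metadata import entry_points

# Discover every semantic-layer implementation registered under the
# entry-point group declared in setup.py above.
for ep in entry_points(group="superset.semantic_layers"):
    layer_cls = ep.load()  # e.g. SnowflakeSemanticLayer
    print(f"found semantic layer {ep.name!r}: {layer_cls!r}")
```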
```diff
@@ -19,16 +19,27 @@

 import { DatasourceType } from './types/Datasource';

+const DATASOURCE_TYPE_MAP: Record<string, DatasourceType> = {
+  table: DatasourceType.Table,
+  query: DatasourceType.Query,
+  dataset: DatasourceType.Dataset,
+  sl_table: DatasourceType.SlTable,
+  saved_query: DatasourceType.SavedQuery,
+  semantic_view: DatasourceType.SemanticView,
+};
+
 export default class DatasourceKey {
-  readonly id: number;
+  readonly id: number | string;

   readonly type: DatasourceType;

   constructor(key: string) {
     const [idStr, typeStr] = key.split('__');
-    this.id = parseInt(idStr, 10);
-    this.type = DatasourceType.Table; // default to SqlaTable model
-    this.type = typeStr === 'query' ? DatasourceType.Query : this.type;
+    // Only parse as integer if the entire string is numeric
+    // (parseInt would incorrectly parse "85d3139f..." as 85)
+    const isNumeric = /^\d+$/.test(idStr);
+    this.id = isNumeric ? parseInt(idStr, 10) : idStr;
+    this.type = DATASOURCE_TYPE_MAP[typeStr] ?? DatasourceType.Table;
   }

   public toString() {
```
```diff
@@ -26,6 +26,7 @@ export enum DatasourceType {
   Dataset = 'dataset',
   SlTable = 'sl_table',
   SavedQuery = 'saved_query',
+  SemanticView = 'semantic_view',
 }

 export interface Currency {
@@ -37,7 +38,7 @@ export interface Currency {
  * Datasource metadata.
  */
 export interface Datasource {
-  id: number;
+  id: number | string;
   name: string;
   type: DatasourceType;
   columns: Column[];
```
```diff
@@ -156,7 +156,7 @@ export interface QueryObject

 export interface QueryContext {
   datasource: {
-    id: number;
+    id: number | string;
     type: DatasourceType;
   };
   /** Force refresh of all queries */
```
```diff
@@ -90,11 +90,16 @@ const ModalFooter = ({ formData, closeModal }: ModalFooterProps) => {
     findPermission('can_explore', 'Superset', state.user?.roles),
   );

-  const [datasource_id, datasource_type] = formData.datasource.split('__');
+  const [datasourceIdStr, datasource_type] = formData.datasource.split('__');
+  // Try to parse as integer, fall back to string (UUID) if NaN
+  const parsedDatasourceId = parseInt(datasourceIdStr, 10);
+  const datasource_id = Number.isNaN(parsedDatasourceId)
+    ? datasourceIdStr
+    : parsedDatasourceId;
   useEffect(() => {
     // short circuit if the user is embedded as explore is not available
     if (isEmbedded()) return;
-    postFormData(Number(datasource_id), datasource_type, formData, 0)
+    postFormData(datasource_id, datasource_type, formData, 0)
       .then(key => {
         setUrl(
           `/explore/?form_data_key=${key}&dashboard_page_id=${dashboardPageId}`,
```
```diff
@@ -272,7 +272,7 @@ export type Slice = {
   changed_on: number;
   changed_on_humanized: string;
   modified: string;
-  datasource_id: number;
+  datasource_id: number | string;
   datasource_type: DatasourceType;
   datasource_url: string;
   datasource_name: string;
```
```diff
@@ -144,12 +144,14 @@ export const getSlicePayload = async (
     ...adhocFilters,
     dashboards,
   };
-  let datasourceId = 0;
+  let datasourceId: number | string = 0;
   let datasourceType: DatasourceType = DatasourceType.Table;

   if (formData.datasource) {
     const [id, typeString] = formData.datasource.split('__');
-    datasourceId = parseInt(id, 10);
+    // Try to parse as integer, fall back to string (UUID) if NaN
+    const parsedId = parseInt(id, 10);
+    datasourceId = Number.isNaN(parsedId) ? id : parsedId;

     const formattedTypeString =
       typeString.charAt(0).toUpperCase() + typeString.slice(1);
```
```diff
@@ -20,7 +20,7 @@ import { SupersetClient, JsonObject, JsonResponse } from '@superset-ui/core';
 import { sanitizeFormData } from 'src/utils/sanitizeFormData';

 type Payload = {
-  datasource_id: number;
+  datasource_id: number | string;
   datasource_type: string;
   form_data: string;
   chart_id?: number;
@@ -36,7 +36,7 @@ const assembleEndpoint = (key?: string, tabId?: string) => {
 };

 const assemblePayload = (
-  datasourceId: number,
+  datasourceId: number | string,
   datasourceType: string,
   formData: JsonObject,
   chartId?: number,
@@ -53,7 +53,7 @@ const assemblePayload = (
 };

 export const postFormData = (
-  datasourceId: number,
+  datasourceId: number | string,
   datasourceType: string,
   formData: JsonObject,
   chartId?: number,
@@ -70,7 +70,7 @@ export const postFormData = (
 }).then((r: JsonResponse) => r.json.key);

 export const putFormData = (
-  datasourceId: number,
+  datasourceId: number | string,
   datasourceType: string,
   key: string,
   formData: JsonObject,
```
```diff
@@ -68,7 +68,7 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
         query_obj = self._query_context.queries[0]
         sql_query = datasource.get_query_str(query_obj.to_dict())

-        return sql_query, datasource.database
+        return sql_query, getattr(datasource, "database", None)

     def _get_row_limit(self) -> int | None:
         """
```
```diff
@@ -78,6 +78,13 @@ class ExportDatasetsCommand(ExportModelsCommand):
         payload["version"] = EXPORT_VERSION
         payload["database_uuid"] = str(model.database.uuid)

+        # Always set cache_timeout from the property to ensure correct value
+        payload["cache_timeout"] = model.cache_timeout
+
+        # SQLAlchemy returns column names as quoted_name objects which PyYAML cannot
+        # serialize. Convert all keys to regular strings to fix YAML serialization.
+        payload = {str(key): value for key, value in payload.items()}
+
         file_content = yaml.safe_dump(payload, sort_keys=False)
         return file_content
```
```diff
@@ -37,7 +37,7 @@ from superset.exceptions import SupersetException
 from superset.explore.exceptions import WrongEndpointError
 from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
 from superset.extensions import security_manager
-from superset.superset_typing import BaseDatasourceData, QueryData
+from superset.superset_typing import ExplorableData
 from superset.utils import core as utils, json
 from superset.views.utils import (
     get_datasource_info,
@@ -124,7 +124,11 @@ class GetExploreCommand(BaseCommand, ABC):
             security_manager.raise_for_access(datasource=datasource)

         viz_type = form_data.get("viz_type")
-        if not viz_type and datasource and datasource.default_endpoint:
+        if (
+            not viz_type
+            and datasource
+            and getattr(datasource, "default_endpoint", None)
+        ):
             raise WrongEndpointError(redirect=datasource.default_endpoint)

         form_data["datasource"] = (
@@ -136,7 +140,7 @@ class GetExploreCommand(BaseCommand, ABC):
         utils.merge_extra_filters(form_data)
         utils.merge_request_params(form_data, request.args)

-        datasource_data: BaseDatasourceData | QueryData = {
+        datasource_data: ExplorableData = {
             "type": self._datasource_type or "unknown",
             "name": datasource_name,
             "columns": [],
```
```diff
@@ -15,13 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union


 @dataclass
 class CommandParameters:
     permalink_key: Optional[str]
     form_data_key: Optional[str]
-    datasource_id: Optional[int]
+    datasource_id: Optional[Union[int, str]]
     datasource_type: Optional[str]
     slice_id: Optional[int]
```
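With `datasource_id` widened to `Optional[Union[int, str]]`, the backend mirrors the frontend's numeric-versus-UUID handling of `id__type` keys. A minimal sketch of that parsing logic (the helper name is hypothetical, not part of the diff):

```python
from typing import Union


def parse_datasource_key(key: str) -> tuple[Union[int, str], str]:
    """Hypothetical helper: split an 'id__type' form-data key,
    keeping UUID-style ids as strings."""
    id_str, _, datasource_type = key.partition("__")
    # Only treat fully numeric ids as integers; a parseInt-style prefix
    # parse would silently truncate "85d3139f..." to 85.
    datasource_id: Union[int, str] = int(id_str) if id_str.isdigit() else id_str
    return datasource_id, datasource_type


assert parse_datasource_key("123__table") == (123, "table")
assert parse_datasource_key("85d3139f__semantic_view") == ("85d3139f", "semantic_view")
```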
```diff
@@ -23,8 +23,8 @@ from flask_babel import _

 from superset.common.chart_data import ChartDataResultType
 from superset.common.db_query_status import QueryStatus
-from superset.connectors.sqla.models import BaseDatasource
 from superset.exceptions import QueryObjectValidationError, SupersetParseError
+from superset.explorables.base import Explorable
 from superset.utils.core import (
     extract_column_dtype,
     extract_dataframe_dtypes,
@@ -38,9 +38,7 @@ if TYPE_CHECKING:
     from superset.common.query_object import QueryObject


-def _get_datasource(
-    query_context: QueryContext, query_obj: QueryObject
-) -> BaseDatasource:
+def _get_datasource(query_context: QueryContext, query_obj: QueryObject) -> Explorable:
     return query_obj.datasource or query_context.datasource

```
```diff
@@ -64,16 +62,9 @@ def _get_timegrains(
     query_context: QueryContext, query_obj: QueryObject, _: bool
 ) -> dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
-    return {
-        "data": [
-            {
-                "name": grain.name,
-                "function": grain.function,
-                "duration": grain.duration,
-            }
-            for grain in datasource.database.grains()
-        ]
-    }
+    # Use the new get_time_grains() method from Explorable protocol
+    grains = datasource.get_time_grains()
+    return {"data": grains}


 def _get_query(
```
```diff
@@ -158,7 +149,8 @@ def _get_samples(
     qry_obj_cols = []
     for o in datasource.columns:
         if isinstance(o, dict):
-            qry_obj_cols.append(o.get("column_name"))
+            if column_name := o.get("column_name"):
+                qry_obj_cols.append(column_name)
         else:
             qry_obj_cols.append(o.column_name)
     query_obj.columns = qry_obj_cols
@@ -180,7 +172,8 @@ def _get_drill_detail(
     qry_obj_cols = []
     for o in datasource.columns:
         if isinstance(o, dict):
-            qry_obj_cols.append(o.get("column_name"))
+            if column_name := o.get("column_name"):
+                qry_obj_cols.append(column_name)
         else:
             qry_obj_cols.append(o.column_name)
     query_obj.columns = qry_obj_cols
```
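The walrus-operator guard in both hunks drops dict columns whose `column_name` is missing or empty instead of appending `None`. A tiny sketch of the behavioral difference (illustrative data, not from the diff):

```python
columns = [{"column_name": "ds"}, {"expression": "1"}, {"column_name": ""}]

# Old behavior: appends None/"" for entries without a usable name.
old = [c.get("column_name") for c in columns]  # ['ds', None, '']

# New behavior: the walrus guard keeps only truthy names.
new = [name for c in columns if (name := c.get("column_name"))]  # ['ds']
```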
```diff
@@ -24,11 +24,11 @@ import pandas as pd

 from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.query_context_processor import QueryContextProcessor
 from superset.common.query_object import QueryObject
+from superset.explorables.base import Explorable
 from superset.models.slice import Slice
 from superset.utils.core import GenericDataType

 if TYPE_CHECKING:
-    from superset.connectors.sqla.models import BaseDatasource
     from superset.models.helpers import QueryResult

@@ -44,7 +44,7 @@ class QueryContext:
     cache_type: ClassVar[str] = "df"
     enforce_numerical_metrics: ClassVar[bool] = True

-    datasource: BaseDatasource
+    datasource: Explorable
     slice_: Slice | None = None
     queries: list[QueryObject]
     form_data: dict[str, Any] | None
@@ -62,7 +62,7 @@ class QueryContext:
     def __init__(  # pylint: disable=too-many-arguments
         self,
         *,
-        datasource: BaseDatasource,
+        datasource: Explorable,
         queries: list[QueryObject],
         slice_: Slice | None,
         form_data: dict[str, Any] | None,
@@ -99,15 +99,24 @@ class QueryContext:
         return self._processor.get_payload(cache_query_context, force_cached)

     def get_cache_timeout(self) -> int | None:
+        """
+        Get the cache timeout for this query context.
+
+        Priority order:
+        1. Custom timeout set for this specific query
+        2. Chart-level timeout (if querying from a saved chart)
+        3. Datasource-level timeout (explorable handles its own fallback logic)
+        4. System default (None)
+
+        Note: Each explorable is responsible for its own internal fallback chain.
+        For example, BaseDatasource falls back to database.cache_timeout,
+        while semantic layers might fall back to their layer's default.
+        """
         if self.custom_cache_timeout is not None:
             return self.custom_cache_timeout
         if self.slice_ and self.slice_.cache_timeout is not None:
             return self.slice_.cache_timeout
-        if self.datasource.cache_timeout is not None:
-            return self.datasource.cache_timeout
-        if hasattr(self.datasource, "database"):
-            return self.datasource.database.cache_timeout
-        return None
+        return self.datasource.cache_timeout

     def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
         return self._processor.query_cache_key(query_obj, **kwargs)
```
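The docstring's priority order can be seen end-to-end in a plain-Python sketch (the stub classes are illustrative stand-ins for `QueryContext`, `Slice`, and an explorable, not Superset code):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubExplorable:
    cache_timeout: Optional[int] = None  # already includes its own fallback


@dataclass
class StubSlice:
    cache_timeout: Optional[int] = None


def resolve_cache_timeout(
    custom: Optional[int],
    slice_: Optional[StubSlice],
    datasource: StubExplorable,
) -> Optional[int]:
    # 1. per-query custom timeout wins
    if custom is not None:
        return custom
    # 2. then the saved chart's timeout
    if slice_ and slice_.cache_timeout is not None:
        return slice_.cache_timeout
    # 3. then whatever the explorable resolves internally
    # 4. None means "use the system default"
    return datasource.cache_timeout


assert resolve_cache_timeout(60, StubSlice(120), StubExplorable(300)) == 60
assert resolve_cache_timeout(None, StubSlice(120), StubExplorable(300)) == 120
assert resolve_cache_timeout(None, None, StubExplorable(300)) == 300
assert resolve_cache_timeout(None, None, StubExplorable()) is None
```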
```diff
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations

-from typing import Any, TYPE_CHECKING
+from typing import Any

 from flask import current_app

@@ -26,13 +26,11 @@ from superset.common.query_object import QueryObject
 from superset.common.query_object_factory import QueryObjectFactory
 from superset.daos.chart import ChartDAO
 from superset.daos.datasource import DatasourceDAO
+from superset.explorables.base import Explorable
 from superset.models.slice import Slice
 from superset.superset_typing import Column
 from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column

-if TYPE_CHECKING:
-    from superset.connectors.sqla.models import BaseDatasource
-

 def create_query_object_factory() -> QueryObjectFactory:
     return QueryObjectFactory(current_app.config, DatasourceDAO())
@@ -104,7 +102,7 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
             cache_values=cache_values,
         )

-    def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
+    def _convert_to_model(self, datasource: DatasourceDict) -> Explorable:
         return DatasourceDAO.get_datasource(
             datasource_type=DatasourceType(datasource["type"]),
             database_id_or_uuid=datasource["id"],
@@ -115,7 +113,7 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods

     def _process_query_object(
         self,
-        datasource: BaseDatasource,
+        datasource: Explorable,
         form_data: dict[str, Any] | None,
         query_object: QueryObject,
     ) -> QueryObject:
@@ -201,7 +199,7 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
         self,
         query_object: QueryObject,
         form_data: dict[str, Any] | None,
-        datasource: BaseDatasource,
+        datasource: Explorable,
     ) -> None:
         temporal_columns = {
             column["column_name"] if isinstance(column, dict) else column.column_name
```
```diff
@@ -29,7 +29,6 @@ from superset.common.db_query_status import QueryStatus
 from superset.common.query_actions import get_query_results
 from superset.common.utils.query_cache_manager import QueryCacheManager
 from superset.common.utils.time_range_utils import get_since_until_from_time_range
-from superset.connectors.sqla.models import BaseDatasource
 from superset.constants import CACHE_DISABLED_TIMEOUT, CacheRegion
 from superset.daos.annotation_layer import AnnotationLayerDAO
 from superset.daos.chart import ChartDAO
@@ -37,6 +36,7 @@ from superset.exceptions import (
     QueryObjectValidationError,
     SupersetException,
 )
+from superset.explorables.base import Explorable
 from superset.extensions import cache_manager, security_manager
 from superset.models.helpers import QueryResult
 from superset.superset_typing import AdhocColumn, AdhocMetric
@@ -70,7 +70,7 @@ class QueryContextProcessor:
     """

     _query_context: QueryContext
-    _qc_datasource: BaseDatasource
+    _qc_datasource: Explorable

     def __init__(self, query_context: QueryContext):
         self._query_context = query_context
```
```diff
@@ -85,6 +85,7 @@ from superset.exceptions import (
     SupersetSecurityException,
     SupersetSyntaxErrorException,
 )
+from superset.explorables.base import TimeGrainDict
 from superset.jinja_context import (
     BaseTemplateProcessor,
     ExtraCache,
@@ -105,7 +106,7 @@ from superset.sql.parse import Table
 from superset.superset_typing import (
     AdhocColumn,
     AdhocMetric,
-    BaseDatasourceData,
+    ExplorableData,
     Metric,
     QueryObjectDict,
     ResultSetColumnType,
@@ -198,7 +199,7 @@ class BaseDatasource(
     is_featured = Column(Boolean, default=False)  # TODO deprecating
     filter_select_enabled = Column(Boolean, default=True)
     offset = Column(Integer, default=0)
-    cache_timeout = Column(Integer)
+    _cache_timeout = Column("cache_timeout", Integer)
     params = Column(String(1000))
     perm = Column(String(1000))
     schema_perm = Column(String(1000))
@@ -212,6 +213,78 @@ class BaseDatasource(

     extra_import_fields = ["is_managed_externally", "external_url"]

+    @property
+    def cache_timeout(self) -> int | None:
+        """
+        Get the cache timeout for this datasource.
+
+        Implements the Explorable protocol by handling the fallback chain:
+        1. Datasource-specific timeout (if set)
+        2. Database default timeout (if no datasource timeout)
+        3. None (use system default)
+
+        This allows each datasource to override caching, while falling back
+        to database-level defaults when appropriate.
+        """
+        if self._cache_timeout is not None:
+            return self._cache_timeout
+
+        # database should always be set, but that's not true for v0 import
+        if self.database:
+            return self.database.cache_timeout
+
+        return None
+
+    @cache_timeout.setter
+    def cache_timeout(self, value: int | None) -> None:
+        """Set the datasource-specific cache timeout."""
+        self._cache_timeout = value
+
+    def has_drill_by_columns(self, column_names: list[str]) -> bool:
+        """
+        Check if the specified columns support drill-by operations.
+
+        For SQL datasources, drill-by is supported on columns that are marked
+        as groupable in the metadata. This allows users to navigate from
+        aggregated views to detailed data by grouping on these dimensions.
+
+        :param column_names: List of column names to check
+        :return: True if all columns support drill-by, False otherwise
+        """
+        if not column_names:
+            return False
+
+        # Get all groupable column names for this datasource
+        drillable_columns = {
+            row[0]
+            for row in db.session.query(TableColumn.column_name)
+            .filter(TableColumn.table_id == self.id)
+            .filter(TableColumn.groupby)
+            .all()
+        }
+
+        # Check if all requested columns are drillable
+        return set(column_names).issubset(drillable_columns)
+
+    def get_time_grains(self) -> list[TimeGrainDict]:
+        """
+        Get available time granularities from the database.
+
+        Implements the Explorable protocol by delegating to the database's
+        time grain definitions. Each database engine spec defines its own
+        set of supported time grains.
+
+        :return: List of time grain dictionaries with name, function, and duration
+        """
+        return [
+            {
+                "name": grain.name,
+                "function": grain.function,
+                "duration": grain.duration,
+            }
+            for grain in (self.database.grains() or [])
+        ]
+
     @property
     def kind(self) -> DatasourceKind:
         return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL
@@ -363,7 +436,7 @@ class BaseDatasource(
         return verb_map

     @property
-    def data(self) -> BaseDatasourceData:
+    def data(self) -> ExplorableData:
         """Data representation of the datasource sent to the frontend"""
         return {
             # simple fields
@@ -1356,7 +1429,7 @@ class SqlaTable(
         return [(g.duration, g.name) for g in self.database.grains() or []]

     @property
-    def data(self) -> BaseDatasourceData:
+    def data(self) -> ExplorableData:
         data_ = super().data
         if self.type == "table":
             data_["granularity_sqla"] = self.granularity_sqla
```
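Renaming the mapped attribute to `_cache_timeout` while keeping the column name `cache_timeout` lets a plain property own the fallback logic without a schema migration. A minimal self-contained sketch of the pattern (model and table names are illustrative, not Superset's):

```python
from sqlalchemy import Column, Integer
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Widget(Base):
    __tablename__ = "widgets"

    id = Column(Integer, primary_key=True)
    # Python attribute `_cache_timeout`; the database column stays `cache_timeout`.
    _cache_timeout = Column("cache_timeout", Integer)
    default_timeout = Column(Integer)  # stand-in for database.cache_timeout

    @property
    def cache_timeout(self):
        # Instance-level fallback chain: own value first, then the default.
        if self._cache_timeout is not None:
            return self._cache_timeout
        return self.default_timeout

    @cache_timeout.setter
    def cache_timeout(self, value):
        self._cache_timeout = value


w = Widget(default_timeout=300)
assert w.cache_timeout == 300  # falls back to the default
w.cache_timeout = 60
assert w.cache_timeout == 60   # explicit value wins
```

Note that a plain property covers instance access; filtering on it inside a SQL query would need SQLAlchemy's `hybrid_property` instead.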
```diff
@@ -28,6 +28,7 @@ from superset.daos.exceptions import (
     DatasourceValueIsIncorrect,
 )
 from superset.models.sql_lab import Query, SavedQuery
+from superset.semantic_layers.models import SemanticView
 from superset.utils.core import DatasourceType

 logger = logging.getLogger(__name__)
@@ -40,6 +41,7 @@ class DatasourceDAO(BaseDAO[Datasource]):
         DatasourceType.TABLE: SqlaTable,
         DatasourceType.QUERY: Query,
         DatasourceType.SAVEDQUERY: SavedQuery,
+        DatasourceType.SEMANTIC_VIEW: SemanticView,
     }

     @classmethod
```
`superset/daos/semantic_layer.py` (new file, 152 lines):

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""DAOs for semantic layer models."""

from __future__ import annotations

from superset.daos.base import BaseDAO
from superset.extensions import db
from superset.semantic_layers.models import SemanticLayer, SemanticView


class SemanticLayerDAO(BaseDAO[SemanticLayer]):
    """
    Data Access Object for SemanticLayer model.
    """

    @staticmethod
    def validate_uniqueness(name: str) -> bool:
        """
        Validate that semantic layer name is unique.

        :param name: Semantic layer name
        :return: True if name is unique, False otherwise
        """
        query = db.session.query(SemanticLayer).filter(SemanticLayer.name == name)
        return not db.session.query(query.exists()).scalar()

    @staticmethod
    def validate_update_uniqueness(layer_uuid: str, name: str) -> bool:
        """
        Validate that semantic layer name is unique for updates.

        :param layer_uuid: UUID of the semantic layer being updated
        :param name: New name to validate
        :return: True if name is unique, False otherwise
        """
        query = db.session.query(SemanticLayer).filter(
            SemanticLayer.name == name,
            SemanticLayer.uuid != layer_uuid,
        )
        return not db.session.query(query.exists()).scalar()

    @staticmethod
    def find_by_name(name: str) -> SemanticLayer | None:
        """
        Find semantic layer by name.

        :param name: Semantic layer name
        :return: SemanticLayer instance or None
        """
        return (
            db.session.query(SemanticLayer)
            .filter(SemanticLayer.name == name)
            .one_or_none()
        )

    @classmethod
    def get_semantic_views(cls, layer_uuid: str) -> list[SemanticView]:
        """
        Get all semantic views for a semantic layer.

        :param layer_uuid: UUID of the semantic layer
        :return: List of SemanticView instances
        """
        return (
            db.session.query(SemanticView)
            .filter(SemanticView.semantic_layer_uuid == layer_uuid)
            .all()
        )


class SemanticViewDAO(BaseDAO[SemanticView]):
    """Data Access Object for SemanticView model."""

    @staticmethod
    def find_by_semantic_layer(layer_uuid: str) -> list[SemanticView]:
        """
        Find all views for a semantic layer.

        :param layer_uuid: UUID of the semantic layer
        :return: List of SemanticView instances
        """
        return (
            db.session.query(SemanticView)
            .filter(SemanticView.semantic_layer_uuid == layer_uuid)
            .all()
        )

    @staticmethod
    def validate_uniqueness(name: str, layer_uuid: str) -> bool:
        """
        Validate that view name is unique within semantic layer.

        :param name: View name
        :param layer_uuid: UUID of the semantic layer
        :return: True if name is unique within layer, False otherwise
        """
        query = db.session.query(SemanticView).filter(
            SemanticView.name == name,
            SemanticView.semantic_layer_uuid == layer_uuid,
        )
        return not db.session.query(query.exists()).scalar()

    @staticmethod
    def validate_update_uniqueness(view_uuid: str, name: str, layer_uuid: str) -> bool:
        """
        Validate that view name is unique within semantic layer for updates.

        :param view_uuid: UUID of the view being updated
        :param name: New name to validate
        :param layer_uuid: UUID of the semantic layer
        :return: True if name is unique within layer, False otherwise
        """
        query = db.session.query(SemanticView).filter(
            SemanticView.name == name,
            SemanticView.semantic_layer_uuid == layer_uuid,
            SemanticView.uuid != view_uuid,
        )
        return not db.session.query(query.exists()).scalar()

    @staticmethod
    def find_by_name(name: str, layer_uuid: str) -> SemanticView | None:
        """
        Find semantic view by name within a semantic layer.

        :param name: View name
        :param layer_uuid: UUID of the semantic layer
        :return: SemanticView instance or None
        """
        return (
            db.session.query(SemanticView)
            .filter(
                SemanticView.name == name,
                SemanticView.semantic_layer_uuid == layer_uuid,
            )
            .one_or_none()
        )
```
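A sketch of how a create-command might use the uniqueness check before persisting. The caller and exception are hypothetical, and the `BaseDAO.create(attributes=...)` signature is an assumption based on how DAOs are used elsewhere in Superset:

```python
# Hypothetical caller; SemanticLayerDAO comes from the new DAO module above.
from superset.daos.semantic_layer import SemanticLayerDAO


def create_semantic_layer(name: str, **properties):
    # Reject duplicates up front rather than relying on a database error.
    if not SemanticLayerDAO.validate_uniqueness(name):
        raise ValueError(f"A semantic layer named {name!r} already exists")
    return SemanticLayerDAO.create(attributes={"name": name, **properties})
```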
`superset/explorables/base.py` (new file, 497 lines):

````python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Base protocol for explorable data sources in Superset.

An "explorable" is any data source that can be explored to create charts,
including SQL datasets, saved queries, and semantic layer views.
"""

from __future__ import annotations

from collections.abc import Hashable
from datetime import datetime
from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
    from superset.common.query_object import QueryObject
    from superset.models.helpers import QueryResult
    from superset.superset_typing import ExplorableData, QueryObjectDict


class TimeGrainDict(TypedDict):
    """
    TypedDict for time grain options returned by get_time_grains.

    Represents a time granularity option that can be used for grouping
    temporal data. Each time grain specifies how to bucket timestamps.

    Attributes:
        name: Display name for the time grain (e.g., "Hour", "Day", "Week")
        function: Implementation-specific expression for applying the grain.
            For SQL datasources, this is typically a SQL expression template
            like "DATE_TRUNC('hour', {col})".
        duration: ISO 8601 duration string (e.g., "PT1H", "P1D", "P1W")
    """

    name: str
    function: str
    duration: str | None


@runtime_checkable
class MetricMetadata(Protocol):
    """
    Protocol for metric metadata objects.

    Represents a metric that's available on an explorable data source.
    Metrics contain SQL expressions or references to semantic layer measures.

    Attributes:
        metric_name: Unique identifier for the metric
        expression: SQL expression or reference for calculating the metric
        verbose_name: Human-readable name for display in the UI
        description: Description of what the metric represents
        d3format: D3 format string for formatting numeric values
        currency: Currency configuration for the metric (JSON object)
        warning_text: Warning message to display when using this metric
        certified_by: Person or entity that certified this metric
        certification_details: Details about the certification
    """

    @property
    def metric_name(self) -> str:
        """Unique identifier for the metric."""

    @property
    def expression(self) -> str:
        """SQL expression or reference for calculating the metric."""

    @property
    def verbose_name(self) -> str | None:
        """Human-readable name for display in the UI."""

    @property
    def description(self) -> str | None:
        """Description of what the metric represents."""

    @property
    def d3format(self) -> str | None:
        """D3 format string for formatting numeric values."""

    @property
    def currency(self) -> dict[str, Any] | None:
        """Currency configuration for the metric (JSON object)."""

    @property
    def warning_text(self) -> str | None:
        """Warning message to display when using this metric."""

    @property
    def certified_by(self) -> str | None:
        """Person or entity that certified this metric."""

    @property
    def certification_details(self) -> str | None:
        """Details about the certification."""


@runtime_checkable
class ColumnMetadata(Protocol):
    """
    Protocol for column metadata objects.

    Represents a column/dimension that's available on an explorable data source.
    Used for grouping, filtering, and dimension-based analysis.

    Attributes:
        column_name: Unique identifier for the column
        type: SQL data type of the column (e.g., 'VARCHAR', 'INTEGER', 'DATETIME')
        is_dttm: Whether this column represents a date or time value
        verbose_name: Human-readable name for display in the UI
        description: Description of what the column represents
        groupby: Whether this column is allowed for grouping/aggregation
        filterable: Whether this column can be used in filters
        expression: SQL expression if this is a calculated column
        python_date_format: Python datetime format string for temporal columns
        advanced_data_type: Advanced data type classification
        extra: Additional metadata stored as JSON
    """

    @property
    def column_name(self) -> str:
        """Unique identifier for the column."""

    @property
    def type(self) -> str:
        """SQL data type of the column."""

    @property
    def is_dttm(self) -> bool:
        """Whether this column represents a date or time value."""

    @property
    def verbose_name(self) -> str | None:
        """Human-readable name for display in the UI."""

    @property
    def description(self) -> str | None:
        """Description of what the column represents."""

    @property
    def groupby(self) -> bool:
        """Whether this column is allowed for grouping/aggregation."""

    @property
    def filterable(self) -> bool:
        """Whether this column can be used in filters."""

    @property
    def expression(self) -> str | None:
        """SQL expression if this is a calculated column."""

    @property
    def python_date_format(self) -> str | None:
        """Python datetime format string for temporal columns."""

    @property
    def advanced_data_type(self) -> str | None:
        """Advanced data type classification."""

    @property
    def extra(self) -> str | None:
        """Additional metadata stored as JSON."""


@runtime_checkable
class Explorable(Protocol):
    """
    Protocol for objects that can be explored to create charts.

    This protocol defines the minimal interface required for a data source
    to be visualizable in Superset. It is implemented by:
    - BaseDatasource (SQL datasets and queries)
    - SemanticView (semantic layer views)
    - Future: Other data source types

    The protocol focuses on the essential methods and properties needed
    for query execution, caching, and security.
    """

    # =========================================================================
    # Core Query Interface
    # =========================================================================

    def get_query_result(self, query_object: QueryObject) -> QueryResult:
        """
        Execute a query and return results.

        This is the primary method for data retrieval. It takes a query
        object describing what data to fetch (columns, metrics, filters, time range,
        etc.) and returns a QueryResult containing a pandas DataFrame with the results.

        :param query_object: QueryObject describing the query

        :return: QueryResult containing:
            - df: pandas DataFrame with query results
            - query: string representation of the executed query
            - duration: query execution time
            - status: QueryStatus (SUCCESS/FAILED)
            - error_message: error details if query failed
        """

    def get_query_str(self, query_obj: QueryObjectDict) -> str:
        """
        Get the query string without executing.

        Returns a string representation of the query that would be executed
        for the given query object. This is used for display in the UI
        and debugging.

        :param query_obj: Dictionary describing the query
        :return: String representation of the query (SQL, GraphQL, etc.)
        """

    # =========================================================================
    # Identity & Metadata
    # =========================================================================

    @property
    def uid(self) -> str:
        """
        Unique identifier for this explorable.

        Used as part of cache keys and for tracking. Should be stable
        across application restarts but change when the explorable's
        data or structure changes.

        Format convention: "{type}_{id}" (e.g., "table_123", "semantic_view_abc")

        :return: Unique identifier string
        """

    @property
    def type(self) -> str:
        """
        Type discriminator for this explorable.

        Identifies the kind of data source (e.g., 'table', 'query', 'semantic_view').
        Used for routing and type-specific behavior.

        :return: Type identifier string
        """

    @property
    def metrics(self) -> list[MetricMetadata]:
        """
        List of metric metadata objects.

        Each object should provide at minimum:
        - metric_name: str - the metric's name
        - expression: str - the metric's calculation expression

        Used for validation, autocomplete, and query building.

        :return: List of metric metadata objects
        """

    # TODO: rename to dimensions
    @property
    def columns(self) -> list[ColumnMetadata]:
        """
        List of column metadata objects.

        Each object should provide at minimum:
        - column_name: str - the column's name
        - type: str - the column's data type
        - is_dttm: bool - whether it's a datetime column

        Used for validation, autocomplete, and query building.

        :return: List of column metadata objects
        """

    # TODO: remove and use columns instead
    @property
    def column_names(self) -> list[str]:
        """
        List of available column names.

        A simple list of all column names in the explorable.
        Used for quick validation and filtering.

        :return: List of column name strings
        """

    @property
    def data(self) -> ExplorableData:
        """
        Full metadata representation sent to the frontend.

        This property returns a dictionary containing all the metadata
        needed by the Explore UI, including columns, metrics, and
        other configuration.

        Required keys in the returned dictionary:
        - id: unique identifier (int or str)
        - uid: unique string identifier
        - name: display name
        - type: explorable type ('table', 'query', 'semantic_view', etc.)
        - columns: list of column metadata dicts (with column_name, type, etc.)
        - metrics: list of metric metadata dicts (with metric_name, expression, etc.)
        - database: database metadata dict (with id, backend, etc.)

        Optional keys:
        - description: human-readable description
        - schema: schema name (if applicable)
        - catalog: catalog name (if applicable)
        - cache_timeout: default cache timeout
        - offset: timezone offset
        - owners: list of owner IDs
        - verbose_map: dict mapping column/metric names to display names

        :return: Dictionary with complete explorable metadata
        """

    # =========================================================================
    # Caching
    # =========================================================================

    @property
    def cache_timeout(self) -> int | None:
        """
        Default cache timeout in seconds.

        Determines how long query results should be cached.
        Returns None to use the system default cache timeout.

        :return: Cache timeout in seconds, or None for system default
        """

    @property
    def changed_on(self) -> datetime | None:
        """
        Last modification timestamp.

        Used for cache invalidation - when this changes, cached
        results for this explorable become invalid.

        :return: Datetime of last modification, or None
        """

    def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
        """
        Additional cache key components specific to this explorable.

        Provides explorable-specific values to include in cache keys.
        Used to ensure cache invalidation when the explorable's
        underlying data or configuration changes in ways not captured
        by uid or changed_on.

        :param query_obj: The query being executed
        :return: List of additional hashable values for cache key
        """

    # =========================================================================
    # Security
    # =========================================================================

    @property
    def perm(self) -> str:
        """
        Permission string for this explorable.

        Used by the security manager to check if a user has access
        to this data source. Format depends on the explorable type
        (e.g., "[database].[schema].[table]" for SQL tables).

        :return: Permission identifier string
        """

    # =========================================================================
    # Time/Date Handling
    # =========================================================================

    @property
    def offset(self) -> int:
        """
        Timezone offset for datetime columns.

        Used to normalize datetime values to the user's timezone.
        Returns 0 for UTC, or an offset in seconds.

        :return: Timezone offset in seconds (0 for UTC)
        """

    # =========================================================================
    # Time Granularity
    # =========================================================================

    def get_time_grains(self) -> list[TimeGrainDict]:
        """
        Get available time granularities for temporal grouping.

        Returns a list of time grain options that can be used for grouping
        temporal data. Each time grain specifies how to bucket timestamps
        (e.g., by hour, day, week, month).

        Each dictionary in the returned list should contain:
        - name: str - Display name (e.g., "Hour", "Day", "Week")
        - function: str - How to apply the grain (implementation-specific)
        - duration: str - ISO 8601 duration string (e.g., "PT1H", "P1D", "P1W")

        For SQL datasources, the function is typically a SQL expression template
        like "DATE_TRUNC('hour', {col})". For semantic layers, it might be a
        semantic layer-specific identifier like "hour" or "day".

        Return an empty list if time grains are not supported or applicable.

        Example return value:
        ```python
        [
            {
                "name": "Second",
                "function": "DATE_TRUNC('second', {col})",
                "duration": "PT1S",
            },
            {
                "name": "Minute",
                "function": "DATE_TRUNC('minute', {col})",
                "duration": "PT1M",
            },
            {
                "name": "Hour",
                "function": "DATE_TRUNC('hour', {col})",
                "duration": "PT1H",
            },
            {
                "name": "Day",
                "function": "DATE_TRUNC('day', {col})",
                "duration": "P1D",
            },
        ]
        ```

        :return: List of time grain dictionaries (empty list if not supported)
        """

    # =========================================================================
    # Drilling
    # =========================================================================

    def has_drill_by_columns(self, column_names: list[str]) -> bool:
        """
        Check if the specified columns support drill-by operations.

        Drill-by allows users to navigate from aggregated views to detailed
        data by grouping on specific dimensions. This method determines whether
        the given columns can be used for drill-by in the current datasource.

        For SQL datasources, this typically checks if columns are marked as
        groupable in the metadata. For semantic views, it checks against the
        semantic layer's dimension definitions.

        :param column_names: List of column names to check
        :return: True if all columns support drill-by, False otherwise
        """

    # =========================================================================
    # Optional Properties
    # =========================================================================

    @property
    def is_rls_supported(self) -> bool:
        """
        Whether this explorable supports Row Level Security.

        Row Level Security (RLS) allows filtering data based on user identity.
        SQL-based datasources typically support this via SQL queries, while
        semantic layers may handle security at the semantic layer level.

        :return: True if RLS is supported, False otherwise
        """

    @property
    def query_language(self) -> str | None:
        """
        Query language identifier for syntax highlighting.

        Specifies the language used in queries for proper syntax highlighting
        in the UI (e.g., 'sql', 'graphql', 'jsoniq').

        :return: Language identifier string, or None if not applicable
        """
````
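Because the protocol is `@runtime_checkable`, any object exposing the right members satisfies `isinstance` checks structurally (member presence only, not signatures). A minimal toy sketch of a conforming class, with stubbed-out query execution; none of this is Superset code:

```python
# Illustrative only: a toy explorable that would pass the structural check.
# Real implementations (BaseDatasource, SemanticView) live in Superset itself.
class InMemoryExplorable:
    uid = "in_memory_1"
    type = "in_memory"
    metrics: list = []
    columns: list = []
    column_names: list = []
    data: dict = {}
    cache_timeout = None
    changed_on = None
    perm = "[in_memory].[demo]"
    offset = 0
    is_rls_supported = False
    query_language = None

    def get_query_result(self, query_object):
        raise NotImplementedError  # no real backend in this sketch

    def get_query_str(self, query_obj):
        return "SELECT 1"

    def get_extra_cache_keys(self, query_obj):
        return []

    def get_time_grains(self):
        return []  # empty list: time grains not supported

    def has_drill_by_columns(self, column_names):
        return False


# runtime_checkable protocols only verify that members exist, so
# isinstance(InMemoryExplorable(), Explorable) would typically be True here.
```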
```diff
@@ -109,7 +109,7 @@ class ExploreRestApi(BaseSupersetApi):
         params = CommandParameters(
             permalink_key=request.args.get("permalink_key", type=str),
             form_data_key=request.args.get("form_data_key", type=str),
-            datasource_id=request.args.get("datasource_id", type=int),
+            datasource_id=request.args.get("datasource_id"),
             datasource_type=request.args.get("datasource_type", type=str),
             slice_id=request.args.get("slice_id", type=int),
         )
```
New migration file (126 lines):

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_semantic_layers_and_views

Revision ID: 33d7e0e21daa
Revises: x2s8ocx6rto6
Create Date: 2025-11-04 11:26:00.000000

"""

import uuid

import sqlalchemy as sa
from sqlalchemy_utils import UUIDType
from sqlalchemy_utils.types.json import JSONType

from superset.extensions import encrypted_field_factory
from superset.migrations.shared.utils import (
    create_fks_for_table,
    create_table,
    drop_table,
)

# revision identifiers, used by Alembic.
revision = "33d7e0e21daa"
down_revision = "x2s8ocx6rto6"


def upgrade():
    # Create semantic_layers table
    create_table(
        "semantic_layers",
        sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("name", sa.String(length=250), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("type", sa.String(length=250), nullable=False),
        sa.Column(
            "configuration",
            encrypted_field_factory.create(JSONType),
            nullable=True,
        ),
        sa.Column("cache_timeout", sa.Integer(), nullable=True),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("uuid"),
    )

    # Create foreign key constraints for semantic_layers
    create_fks_for_table(
        "fk_semantic_layers_created_by_fk_ab_user",
        "semantic_layers",
        "ab_user",
        ["created_by_fk"],
        ["id"],
    )

    create_fks_for_table(
        "fk_semantic_layers_changed_by_fk_ab_user",
        "semantic_layers",
        "ab_user",
        ["changed_by_fk"],
        ["id"],
    )

    # Create semantic_views table
    create_table(
        "semantic_views",
        sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("name", sa.String(length=250), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column(
            "configuration",
            encrypted_field_factory.create(JSONType),
            nullable=True,
        ),
        sa.Column("cache_timeout", sa.Integer(), nullable=True),
        sa.Column(
            "semantic_layer_uuid",
            UUIDType(binary=True),
            sa.ForeignKey("semantic_layers.uuid", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("uuid"),
    )

    # Create foreign key constraints for semantic_views
    create_fks_for_table(
        "fk_semantic_views_created_by_fk_ab_user",
        "semantic_views",
        "ab_user",
        ["created_by_fk"],
        ["id"],
    )

    create_fks_for_table(
        "fk_semantic_views_changed_by_fk_ab_user",
        "semantic_views",
        "ab_user",
        ["changed_by_fk"],
        ["id"],
    )


def downgrade():
    drop_table("semantic_views")
    drop_table("semantic_layers")
```
```diff
@@ -467,7 +467,9 @@ class ImportExportMixin(UUIDMixin):
         if parent_ref:
             parent_excludes = {c.name for c in parent_ref.local_columns}
         dict_rep = {
-            c.name: getattr(self, c.name)
+            # Convert c.name to str to handle SQLAlchemy's quoted_name type
+            # which is not YAML-serializable
+            str(c.name): getattr(self, c.name)
             for c in cls.__table__.columns  # type: ignore
             if (
                 c.name in export_fields
```
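The `str()` coercion matters because SQLAlchemy represents some identifiers as `quoted_name`, a `str` subclass that PyYAML's safe dumper refuses to serialize. A small sketch of the failure and the fix (assuming PyYAML and SQLAlchemy are installed; values are illustrative):

```python
import yaml
from sqlalchemy.sql.elements import quoted_name

payload = {quoted_name("cache_timeout", quote=None): 60}

try:
    yaml.safe_dump(payload)
except yaml.representer.RepresenterError as exc:
    # The safe representer matches exact types, so the str subclass fails.
    print(exc)

# Coercing keys back to plain str makes the payload serializable again.
print(yaml.safe_dump({str(k): v for k, v in payload.items()}))
```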
```diff
@@ -837,7 +839,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         raise NotImplementedError()

     @property
-    def cache_timeout(self) -> int:
+    def cache_timeout(self) -> int | None:
         raise NotImplementedError()

     @property
```
```diff
@@ -50,6 +50,7 @@ from superset_core.api.models import Query as CoreQuery, SavedQuery as CoreSavedQuery

 from superset import security_manager
 from superset.exceptions import SupersetParseError, SupersetSecurityException
+from superset.explorables.base import TimeGrainDict
 from superset.jinja_context import BaseTemplateProcessor, get_template_processor
 from superset.models.helpers import (
     AuditMixinNullable,
@@ -63,7 +64,7 @@ from superset.sql.parse import (
     Table,
 )
 from superset.sqllab.limiting_factor import LimitingFactor
-from superset.superset_typing import QueryData, QueryObjectDict
+from superset.superset_typing import ExplorableData, QueryObjectDict
 from superset.utils import json
 from superset.utils.core import (
     get_column_name,
@@ -239,7 +240,7 @@ class Query(
         return None

     @property
-    def data(self) -> QueryData:
+    def data(self) -> ExplorableData:
         """Returns query data for the frontend"""
         order_by_choices = []
         for col in self.columns:
@@ -335,6 +336,32 @@ class Query(
     def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
         return []

+    def get_time_grains(self) -> list[TimeGrainDict]:
+        """
+        Get available time granularities from the database.
+
+        Delegates to the database's time grain definitions.
+        """
+        return [
+            {
+                "name": grain.name,
+                "function": grain.function,
+                "duration": grain.duration,
+            }
+            for grain in (self.database.grains() or [])
+        ]
+
+    def has_drill_by_columns(self, column_names: list[str]) -> bool:
+        """
+        Check if the specified columns support drill-by operations.
+
+        For Query objects, all columns are considered drillable since they
+        come from ad-hoc SQL queries without predefined metadata.
+        """
+        if not column_names:
+            return False
+        return set(column_names).issubset(set(self.column_names))
+
     @property
     def tracking_url(self) -> Optional[str]:
         """
```
```diff
@@ -87,6 +87,7 @@ if TYPE_CHECKING:
         RowLevelSecurityFilter,
         SqlaTable,
     )
+    from superset.explorables.base import Explorable
     from superset.models.core import Database
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
@@ -540,24 +541,43 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             or (catalog_perm and self.can_access("catalog_access", catalog_perm))
         )

-    def can_access_schema(self, datasource: "BaseDatasource") -> bool:
+    def can_access_schema(self, datasource: "BaseDatasource | Explorable") -> bool:
         """
         Return True if the user can access the schema associated with specified
         datasource, False otherwise.

+        For SQL datasources: Checks database → catalog → schema hierarchy
+        For other explorables: Only checks all_datasources permission
+
         :param datasource: The datasource
         :returns: Whether the user can access the datasource's schema
         """
+        from superset.connectors.sqla.models import BaseDatasource

-        return (
-            self.can_access_all_datasources()
-            or self.can_access_database(datasource.database)
-            or (
-                datasource.catalog
-                and self.can_access_catalog(datasource.database, datasource.catalog)
-            )
-            or self.can_access("schema_access", datasource.schema_perm or "")
-        )
+        # Admin/superuser override
+        if self.can_access_all_datasources():
+            return True
+
+        # SQL-specific hierarchy checks
+        if isinstance(datasource, BaseDatasource):
+            # Database-level access grants all schemas
+            if self.can_access_database(datasource.database):
+                return True
+
+            # Catalog-level access grants all schemas in catalog
+            if (
+                hasattr(datasource, "catalog")
+                and datasource.catalog
+                and self.can_access_catalog(datasource.database, datasource.catalog)
+            ):
+                return True
+
+            # Schema-level permission (SQL only)
+            if self.can_access("schema_access", datasource.schema_perm or ""):
+                return True
+
+        # Non-SQL explorables don't have schema hierarchy
+        return False

     def can_access_datasource(self, datasource: "BaseDatasource") -> bool:
         """
@@ -604,7 +624,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         self,
         form_data: dict[str, Any],
         dashboard: "Dashboard",
-        datasource: "BaseDatasource",
+        datasource: "BaseDatasource | Explorable",
     ) -> bool:
         """
         Return True if the form_data is performing a supported drill by operation,
@@ -612,10 +632,10 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods

         :param form_data: The form_data included in the request.
         :param dashboard: The dashboard the user is drilling from.
-        :returns: Whether the user has drill byaccess.
+        :param datasource: The datasource being queried
+        :returns: Whether the user has drill by access.
         """

-        from superset.connectors.sqla.models import TableColumn
         from superset.models.slice import Slice

         return bool(
@@ -630,16 +650,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             and slc in dashboard.slices
             and slc.datasource == datasource
             and (dimensions := form_data.get("groupby"))
-            and (
-                drillable_columns := {
-                    row[0]
-                    for row in self.session.query(TableColumn.column_name)
-                    .filter(TableColumn.table_id == datasource.id)
-                    .filter(TableColumn.groupby)
-                    .all()
-                }
-            )
-            and set(dimensions).issubset(drillable_columns)
+            and datasource.has_drill_by_columns(dimensions)
         )

     def can_access_dashboard(self, dashboard: "Dashboard") -> bool:
@@ -705,7 +716,9 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         )

     @staticmethod
-    def get_datasource_access_error_msg(datasource: "BaseDatasource") -> str:
+    def get_datasource_access_error_msg(
+        datasource: "BaseDatasource | Explorable",
+    ) -> str:
         """
         Return the error message for the denied Superset datasource.

@@ -714,13 +727,13 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         """

         return (
-            f"This endpoint requires the datasource {datasource.id}, "
+            f"This endpoint requires the datasource {datasource.data['id']}, "
             "database or `all_datasource_access` permission"
         )

     @staticmethod
     def get_datasource_access_link(  # pylint: disable=unused-argument
-        datasource: "BaseDatasource",
+        datasource: "BaseDatasource | Explorable",
     ) -> Optional[str]:
         """
         Return the link for the denied Superset datasource.
@@ -732,7 +745,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         return get_conf().get("PERMISSION_INSTRUCTIONS_LINK")

     def get_datasource_access_error_object(  # pylint: disable=invalid-name
-        self, datasource: "BaseDatasource"
+        self, datasource: "BaseDatasource | Explorable"
     ) -> SupersetError:
         """
         Return the error object for the denied Superset datasource.
@@ -746,8 +759,8 @@
```
|
||||
level=ErrorLevel.WARNING,
|
||||
extra={
|
||||
"link": self.get_datasource_access_link(datasource),
|
||||
"datasource": datasource.id,
|
||||
"datasource_name": datasource.name,
|
||||
"datasource": datasource.data["id"],
|
||||
"datasource_name": datasource.data["name"],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -2280,8 +2293,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
dashboard: Optional["Dashboard"] = None,
|
||||
chart: Optional["Slice"] = None,
|
||||
database: Optional["Database"] = None,
|
||||
datasource: Optional["BaseDatasource"] = None,
|
||||
query: Optional["Query"] = None,
|
||||
datasource: Optional["BaseDatasource | Explorable"] = None,
|
||||
query: Optional["Query | Explorable"] = None,
|
||||
query_context: Optional["QueryContext"] = None,
|
||||
table: Optional["Table"] = None,
|
||||
viz: Optional["BaseViz"] = None,
|
||||
@@ -2326,7 +2339,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
|
||||
if database and table or query:
|
||||
if query:
|
||||
database = query.database
|
||||
# Type narrow: only SQL Lab Query objects have .database attribute
|
||||
if hasattr(query, "database"):
|
||||
database = query.database
|
||||
|
||||
database = cast("Database", database)
|
||||
default_catalog = database.get_default_catalog()
|
||||
@@ -2334,7 +2349,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
if self.can_access_database(database):
|
||||
return
|
||||
|
||||
if query:
|
||||
# Type narrow: this path only applies to SQL Lab Query objects
|
||||
if query and hasattr(query, "sql") and hasattr(query, "catalog"):
|
||||
# Getting the default schema for a query is hard. Users can select the
|
||||
# schema in SQL Lab, but there's no guarantee that the query actually
|
||||
# will run in that schema. Each DB engine spec needs to implement the
|
||||
@@ -2342,8 +2358,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
# If the DB engine spec doesn't implement the logic the schema is read
|
||||
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
|
||||
# inspector to read it.
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
default_schema = database.get_default_schema_for_query(
|
||||
query, template_params
|
||||
cast(Query, query),
|
||||
template_params,
|
||||
)
|
||||
tables = {
|
||||
table_.qualify(
|
||||
@@ -2455,7 +2474,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
and dashboard_.json_metadata
|
||||
and (json_metadata := json.loads(dashboard_.json_metadata))
|
||||
and any(
|
||||
target.get("datasetId") == datasource.id
|
||||
target.get("datasetId") == datasource.data["id"]
|
||||
for fltr in json_metadata.get(
|
||||
"native_filter_configuration",
|
||||
[],
|
||||
@@ -2560,7 +2579,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
return super().get_user_roles(user)
|
||||
|
||||
def get_guest_rls_filters(
|
||||
self, dataset: "BaseDatasource"
|
||||
self, dataset: "BaseDatasource | Explorable"
|
||||
) -> list[GuestTokenRlsRule]:
|
||||
"""
|
||||
Retrieves the row level security filters for the current user and the dataset,
|
||||
@@ -2573,11 +2592,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
rule
|
||||
for rule in guest_user.rls
|
||||
if not rule.get("dataset")
|
||||
or str(rule.get("dataset")) == str(dataset.id)
|
||||
or str(rule.get("dataset")) == str(dataset.data["id"])
|
||||
]
|
||||
return []
|
||||
|
||||
def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]:
|
||||
def get_rls_filters(self, table: "BaseDatasource | Explorable") -> list[SqlaQuery]:
|
||||
"""
|
||||
Retrieves the appropriate row level security filters for the current user and
|
||||
the passed table.
|
||||
@@ -2614,7 +2633,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
|
||||
)
|
||||
filter_tables = self.session.query(RLSFilterTables.c.rls_filter_id).filter(
|
||||
RLSFilterTables.c.table_id == table.id
|
||||
RLSFilterTables.c.table_id == table.data["id"]
|
||||
)
|
||||
query = (
|
||||
self.session.query(
|
||||
@@ -2640,7 +2659,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
)
|
||||
return query.all()
|
||||
|
||||
def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]:
|
||||
def get_rls_sorted(
|
||||
self, table: "BaseDatasource | Explorable"
|
||||
) -> list["RowLevelSecurityFilter"]:
|
||||
"""
|
||||
Retrieves a list RLS filters sorted by ID for
|
||||
the current user and the passed table.
|
||||
@@ -2652,10 +2673,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
filters.sort(key=lambda f: f.id)
|
||||
return filters
|
||||
|
||||
def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
|
||||
def get_guest_rls_filters_str(
|
||||
self, table: "BaseDatasource | Explorable"
|
||||
) -> list[str]:
|
||||
return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]
|
||||
|
||||
def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
|
||||
def get_rls_cache_key(self, datasource: "Explorable | BaseDatasource") -> list[str]:
|
||||
rls_clauses_with_group_key = []
|
||||
if datasource.is_rls_supported:
|
||||
rls_clauses_with_group_key = [
|
||||

16
superset/semantic_layers/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

938
superset/semantic_layers/mapper.py
Normal file
@@ -0,0 +1,938 @@
"""
|
||||
Functions for mapping `QueryObject` to semantic layers.
|
||||
|
||||
These functions validate and convert a `QueryObject` into one or more `SemanticQuery`,
|
||||
which are then passed to semantic layer implementations for execution, returning a
|
||||
single dataframe.
|
||||
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from time import time
|
||||
from typing import Any, cast, Sequence, TypeGuard
|
||||
|
||||
import numpy as np
|
||||
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.common.utils.time_range_utils import get_since_until_from_query_object
|
||||
from superset.connectors.sqla.models import BaseDatasource
|
||||
from superset.models.helpers import QueryResult
|
||||
from superset.semantic_layers.types import (
|
||||
AdhocExpression,
|
||||
AdhocFilter,
|
||||
Day,
|
||||
Dimension,
|
||||
Filter,
|
||||
FilterValues,
|
||||
Grain,
|
||||
GroupLimit,
|
||||
Hour,
|
||||
Metric,
|
||||
Minute,
|
||||
Month,
|
||||
Operator,
|
||||
OrderDirection,
|
||||
OrderTuple,
|
||||
PredicateType,
|
||||
Quarter,
|
||||
Second,
|
||||
SemanticQuery,
|
||||
SemanticResult,
|
||||
SemanticViewFeature,
|
||||
Week,
|
||||
Year,
|
||||
)
|
||||
from superset.utils.core import (
|
||||
FilterOperator,
|
||||
QueryObjectFilterClause,
|
||||
TIME_COMPARISON,
|
||||
)
|
||||
from superset.utils.date_parser import get_past_or_future
|
||||
|
||||
|
||||
class ValidatedQueryObjectFilterClause(QueryObjectFilterClause):
|
||||
"""
|
||||
A validated QueryObject filter clause with a string column name.
|
||||
|
||||
The `col` in a `QueryObjectFilterClause` can be either a string (column name) or an
|
||||
adhoc column, but we only support the former in semantic layers.
|
||||
"""
|
||||
|
||||
# overwrite to narrow type; mypy complains about more restrictive typed dicts,
|
||||
# but the alternative would be to redefine the object
|
||||
col: str # type: ignore[misc]
|
||||
op: str # type: ignore[misc]
|
||||
|
||||
|
||||
class ValidatedQueryObject(QueryObject):
|
||||
"""
|
||||
A query object that has a datasource defined.
|
||||
"""
|
||||
|
||||
datasource: BaseDatasource
|
||||
|
||||
# overwrite to narrow type; mypy complains about the assignment since the base type
|
||||
# allows adhoc filters, but we only support validated filters here
|
||||
filter: list[ValidatedQueryObjectFilterClause] # type: ignore[assignment]
|
||||
series_columns: Sequence[str] # type: ignore[assignment]
|
||||
series_limit_metric: str | None
|
||||
|
||||
|
||||
def get_results(query_object: QueryObject) -> QueryResult:
|
||||
"""
|
||||
Run 1+ queries based on `QueryObject` and return the results.
|
||||
|
||||
:param query_object: The QueryObject containing query specifications
|
||||
:return: QueryResult compatible with Superset's query interface
|
||||
"""
|
||||
if not validate_query_object(query_object):
|
||||
raise ValueError("QueryObject must have a datasource defined.")
|
||||
|
||||
# Track execution time
|
||||
start_time = time()
|
||||
|
||||
semantic_view = query_object.datasource.implementation
|
||||
dispatcher = (
|
||||
semantic_view.get_row_count
|
||||
if query_object.is_rowcount
|
||||
else semantic_view.get_dataframe
|
||||
)
|
||||
|
||||
# Step 1: Convert QueryObject to list of SemanticQuery objects
|
||||
# The first query is the main query, subsequent queries are for time offsets
|
||||
queries = map_query_object(query_object)
|
||||
|
||||
# Step 2: Execute the main query (first in the list)
|
||||
main_query = queries[0]
|
||||
main_result = dispatcher(
|
||||
metrics=main_query.metrics,
|
||||
dimensions=main_query.dimensions,
|
||||
filters=main_query.filters,
|
||||
order=main_query.order,
|
||||
limit=main_query.limit,
|
||||
offset=main_query.offset,
|
||||
group_limit=main_query.group_limit,
|
||||
)
|
||||
|
||||
main_df = main_result.results
|
||||
|
||||
# Collect all requests (SQL queries, HTTP requests, etc.) for troubleshooting
|
||||
all_requests = list(main_result.requests)
|
||||
|
||||
# If no time offsets, return the main result as-is
|
||||
if not query_object.time_offsets or len(queries) <= 1:
|
||||
semantic_result = SemanticResult(
|
||||
requests=all_requests,
|
||||
results=main_df,
|
||||
)
|
||||
duration = timedelta(seconds=time() - start_time)
|
||||
return map_semantic_result_to_query_result(
|
||||
semantic_result,
|
||||
query_object,
|
||||
duration,
|
||||
)
|
||||
|
||||
# Get metric names from the main query
|
||||
# These are the columns that will be renamed with offset suffixes
|
||||
metric_names = [metric.name for metric in main_query.metrics]
|
||||
|
||||
# Join keys are all columns except metrics
|
||||
# These will be used to match rows between main and offset DataFrames
|
||||
join_keys = [col for col in main_df.columns if col not in metric_names]
|
||||
|
||||
# Step 3 & 4: Execute each time offset query and join results
|
||||
for offset_query, time_offset in zip(
|
||||
queries[1:],
|
||||
query_object.time_offsets,
|
||||
strict=False,
|
||||
):
|
||||
# Execute the offset query
|
||||
result = dispatcher(
|
||||
metrics=offset_query.metrics,
|
||||
dimensions=offset_query.dimensions,
|
||||
filters=offset_query.filters,
|
||||
order=offset_query.order,
|
||||
limit=offset_query.limit,
|
||||
offset=offset_query.offset,
|
||||
group_limit=offset_query.group_limit,
|
||||
)
|
||||
|
||||
# Add this query's requests to the collection
|
||||
all_requests.extend(result.requests)
|
||||
|
||||
offset_df = result.results
|
||||
|
||||
# Handle empty results - add NaN columns directly instead of merging
|
||||
# This avoids dtype mismatch issues with empty DataFrames
|
||||
if offset_df.empty:
|
||||
# Add offset metric columns with NaN values directly to main_df
|
||||
for metric in metric_names:
|
||||
offset_col_name = TIME_COMPARISON.join([metric, time_offset])
|
||||
main_df[offset_col_name] = np.nan
|
||||
else:
|
||||
# Rename metric columns with time offset suffix
|
||||
# Format: "{metric_name}__{time_offset}"
|
||||
# Example: "revenue" -> "revenue__1 week ago"
|
||||
offset_df = offset_df.rename(
|
||||
columns={
|
||||
metric: TIME_COMPARISON.join([metric, time_offset])
|
||||
for metric in metric_names
|
||||
}
|
||||
)
|
||||
|
||||
# Step 5: Perform left join on dimension columns
|
||||
# This preserves all rows from main_df and adds offset metrics
|
||||
# where they match
|
||||
main_df = main_df.merge(
|
||||
offset_df,
|
||||
on=join_keys,
|
||||
how="left",
|
||||
suffixes=("", "__duplicate"),
|
||||
)
|
||||
|
||||
# Clean up any duplicate columns that might have been created
|
||||
# (shouldn't happen with proper join keys, but defensive programming)
|
||||
duplicate_cols = [
|
||||
col for col in main_df.columns if col.endswith("__duplicate")
|
||||
]
|
||||
if duplicate_cols:
|
||||
main_df = main_df.drop(columns=duplicate_cols)
|
||||
|
||||
# Convert final result to QueryResult
|
||||
semantic_result = SemanticResult(requests=all_requests, results=main_df)
|
||||
duration = timedelta(seconds=time() - start_time)
|
||||
return map_semantic_result_to_query_result(
|
||||
semantic_result,
|
||||
query_object,
|
||||
duration,
|
||||
)
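
# Sketch of the joined output for metrics=["revenue"] and
# time_offsets=["1 week ago"] (hypothetical names): after the offset merge,
# main_df carries the dimension columns plus "revenue" and
# "revenue__1 week ago", with TIME_COMPARISON supplying the "__" separator.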


def map_semantic_result_to_query_result(
    semantic_result: SemanticResult,
    query_object: ValidatedQueryObject,
    duration: timedelta,
) -> QueryResult:
    """
    Convert a SemanticResult to a QueryResult.

    :param semantic_result: Result from the semantic layer
    :param query_object: Original QueryObject (for passthrough attributes)
    :param duration: Time taken to execute the query
    :return: QueryResult compatible with Superset's query interface
    """
    # Get the query string from requests (typically one or more SQL queries)
    query_str = ""
    if semantic_result.requests:
        # Join all requests for display (could be multiple for time comparisons)
        query_str = "\n\n".join(
            f"-- {req.type}\n{req.definition}" for req in semantic_result.requests
        )

    return QueryResult(
        # Core data
        df=semantic_result.results,
        query=query_str,
        duration=duration,
        # Template filters - not applicable to semantic layers
        # (semantic layers don't use Jinja templates)
        applied_template_filters=None,
        # Filter columns - not applicable to semantic layers
        # (semantic layers handle filter validation internally)
        applied_filter_columns=None,
        rejected_filter_columns=None,
        # Status - always success if we got here
        # (errors would raise exceptions before reaching this point)
        status=QueryStatus.SUCCESS,
        error_message=None,
        errors=None,
        # Time range - pass through from original query_object
        from_dttm=query_object.from_dttm,
        to_dttm=query_object.to_dttm,
    )


def _normalize_column(column: str | dict, dimension_names: set[str]) -> str:
    """
    Normalize a column to its dimension name.

    Columns can be either:
    - A string (dimension name directly)
    - A dict with isColumnReference=True and sqlExpression containing the dimension name
    """
    if isinstance(column, str):
        return column

    if isinstance(column, dict):
        # Handle column references (e.g., from time-series charts)
        if column.get("isColumnReference") and column.get("sqlExpression"):
            sql_expr = column["sqlExpression"]
            if sql_expr in dimension_names:
                return sql_expr

    raise ValueError("Adhoc dimensions are not supported in Semantic Views.")
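
# Sketch of the two accepted column shapes (hypothetical dimension name
# "region"):
#     _normalize_column("region", {"region"})  # -> "region"
#     _normalize_column(
#         {"isColumnReference": True, "sqlExpression": "region"}, {"region"}
#     )  # -> "region"
# Anything else raises ValueError.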


def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
    """
    Convert a `QueryObject` into a list of `SemanticQuery`.

    This function maps the `QueryObject` into query objects that focus less on
    visualization and more on semantics.
    """
    semantic_view = query_object.datasource.implementation

    all_metrics = {metric.name: metric for metric in semantic_view.metrics}
    all_dimensions = {
        dimension.name: dimension for dimension in semantic_view.dimensions
    }

    # Normalize columns (may be dicts with isColumnReference=True for time-series)
    dimension_names = set(all_dimensions.keys())
    normalized_columns = {
        _normalize_column(column, dimension_names) for column in query_object.columns
    }

    metrics = [all_metrics[metric] for metric in (query_object.metrics or [])]

    grain = (
        _convert_time_grain(query_object.extras["time_grain_sqla"])
        if "time_grain_sqla" in query_object.extras
        else None
    )
    dimensions = [
        dimension
        for dimension in semantic_view.dimensions
        if dimension.name in normalized_columns
        and (
            # if a grain is specified, only include the time dimension if its grain
            # matches the requested grain
            grain is None
            or dimension.name != query_object.granularity
            or dimension.grain == grain
        )
    ]

    order = _get_order_from_query_object(query_object, all_metrics, all_dimensions)
    limit = query_object.row_limit
    offset = query_object.row_offset

    group_limit = _get_group_limit_from_query_object(
        query_object,
        all_metrics,
        all_dimensions,
    )

    queries = []
    for time_offset in [None] + query_object.time_offsets:
        filters = _get_filters_from_query_object(
            query_object,
            time_offset,
            all_dimensions,
        )

        queries.append(
            SemanticQuery(
                metrics=metrics,
                dimensions=dimensions,
                filters=filters,
                order=order,
                limit=limit,
                offset=offset,
                group_limit=group_limit,
            )
        )

    return queries
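
# A minimal usage sketch (hypothetical offsets): with
# query_object.time_offsets == ["1 week ago"], the returned list holds the main
# query plus one offset query whose time filters are shifted back a week:
#     main_query, offset_query = map_query_object(query_object)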


def _get_filters_from_query_object(
    query_object: ValidatedQueryObject,
    time_offset: str | None,
    all_dimensions: dict[str, Dimension],
) -> set[Filter | AdhocFilter]:
    """
    Extract all filters from the query object, including time range filters.

    This simplifies the complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm
    by converting all time constraints into filters.
    """
    filters: set[Filter | AdhocFilter] = set()

    # 1. Add fetch values predicate if present
    if (
        query_object.apply_fetch_values_predicate
        and query_object.datasource.fetch_values_predicate
    ):
        filters.add(
            AdhocFilter(
                type=PredicateType.WHERE,
                definition=query_object.datasource.fetch_values_predicate,
            )
        )

    # 2. Add time range filter based on from_dttm/to_dttm
    # For time offsets, this automatically calculates the shifted bounds
    time_filters = _get_time_filter(query_object, time_offset, all_dimensions)
    filters.update(time_filters)

    # 3. Add filters from query_object.extras (WHERE and HAVING clauses)
    extras_filters = _get_filters_from_extras(query_object.extras)
    filters.update(extras_filters)

    # 4. Add all other filters from query_object.filter
    for filter_ in query_object.filter:
        # Skip temporal range filters - we're using inner bounds instead
        if (
            filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value
            and query_object.granularity
        ):
            continue

        if converted_filters := _convert_query_object_filter(filter_, all_dimensions):
            filters.update(converted_filters)

    return filters


def _get_filters_from_extras(extras: dict[str, Any]) -> set[AdhocFilter]:
    """
    Extract filters from the extras dict.

    The extras dict can contain various keys that affect query behavior:

    Supported keys (converted to filters):
    - "where": SQL WHERE clause expression (e.g., "customer_id > 100")
    - "having": SQL HAVING clause expression (e.g., "SUM(sales) > 1000")

    Other keys in extras (handled elsewhere in the mapper):
    - "time_grain_sqla": Time granularity (e.g., "P1D", "PT1H")
      Handled in _convert_time_grain() and used for dimension grain matching

    Note: The WHERE and HAVING clauses from extras are SQL expressions that
    are passed through as-is to the semantic layer as AdhocFilter objects.
    """
    filters: set[AdhocFilter] = set()

    # Add WHERE clause from extras
    if where_clause := extras.get("where"):
        filters.add(
            AdhocFilter(
                type=PredicateType.WHERE,
                definition=where_clause,
            )
        )

    # Add HAVING clause from extras
    if having_clause := extras.get("having"):
        filters.add(
            AdhocFilter(
                type=PredicateType.HAVING,
                definition=having_clause,
            )
        )

    return filters


def _get_time_filter(
    query_object: ValidatedQueryObject,
    time_offset: str | None,
    all_dimensions: dict[str, Dimension],
) -> set[Filter]:
    """
    Create a time range filter from the query object.

    This handles both regular queries and time offset queries, simplifying the
    complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm by using the
    same time bounds for both the main query and series limit subqueries.
    """
    filters: set[Filter] = set()

    if not query_object.granularity:
        return filters

    time_dimension = all_dimensions.get(query_object.granularity)
    if not time_dimension:
        return filters

    # Get the appropriate time bounds based on whether this is a time offset query
    from_dttm, to_dttm = _get_time_bounds(query_object, time_offset)

    if not from_dttm or not to_dttm:
        return filters

    # Create a filter with >= and < operators
    return {
        Filter(
            type=PredicateType.WHERE,
            column=time_dimension,
            operator=Operator.GREATER_THAN_OR_EQUAL,
            value=from_dttm,
        ),
        Filter(
            type=PredicateType.WHERE,
            column=time_dimension,
            operator=Operator.LESS_THAN,
            value=to_dttm,
        ),
    }


def _get_time_bounds(
    query_object: ValidatedQueryObject,
    time_offset: str | None,
) -> tuple[datetime | None, datetime | None]:
    """
    Get the appropriate time bounds for the query.

    For regular queries (time_offset is None), returns from_dttm/to_dttm.
    For time offset queries, calculates the shifted bounds.

    This simplifies the inner_from_dttm/inner_to_dttm complexity by using
    the same bounds for both main queries and series limit subqueries (Option 1).
    """
    if time_offset is None:
        # Main query: use from_dttm/to_dttm directly
        return query_object.from_dttm, query_object.to_dttm

    # Time offset query: calculate shifted bounds
    # Use from_dttm/to_dttm if available, otherwise try to get from time_range
    outer_from = query_object.from_dttm
    outer_to = query_object.to_dttm

    if not outer_from or not outer_to:
        # Fall back to parsing time_range if from_dttm/to_dttm not set
        outer_from, outer_to = get_since_until_from_query_object(query_object)

    if not outer_from or not outer_to:
        return None, None

    # Apply the offset to both bounds
    offset_from = get_past_or_future(time_offset, outer_from)
    offset_to = get_past_or_future(time_offset, outer_to)

    return offset_from, offset_to
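
# Worked example (hypothetical bounds): for from_dttm=2024-01-08,
# to_dttm=2024-01-15 and time_offset="1 week ago", both bounds shift back one
# week and the function returns (2024-01-01, 2024-01-08).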


def _convert_query_object_filter(
    filter_: ValidatedQueryObjectFilterClause,
    all_dimensions: dict[str, Dimension],
) -> set[Filter] | None:
    """
    Convert a QueryObject filter dict to a semantic layer Filter or AdhocFilter.
    """
    operator_str = filter_["op"]

    # Handle simple column filters
    col = filter_.get("col")
    if col not in all_dimensions:
        return None

    dimension = all_dimensions[col]

    val_str = filter_["val"]
    value: FilterValues | set[FilterValues]
    if val_str is None:
        value = None
    elif isinstance(val_str, (list, tuple)):
        value = set(val_str)
    else:
        value = val_str

    # Special case for temporal range
    if operator_str == FilterOperator.TEMPORAL_RANGE.value:
        # XXX
        start, end = value.split(" : ")
        return {
            Filter(
                type=PredicateType.WHERE,
                column=dimension,
                operator=Operator.GREATER_THAN_OR_EQUAL,
                value=start,
            ),
            Filter(
                type=PredicateType.WHERE,
                column=dimension,
                operator=Operator.LESS_THAN,
                value=end,
            ),
        }

    # Map QueryObject operators to semantic layer operators
    operator_mapping = {
        FilterOperator.EQUALS.value: Operator.EQUALS,
        FilterOperator.NOT_EQUALS.value: Operator.NOT_EQUALS,
        FilterOperator.GREATER_THAN.value: Operator.GREATER_THAN,
        FilterOperator.LESS_THAN.value: Operator.LESS_THAN,
        FilterOperator.GREATER_THAN_OR_EQUALS.value: Operator.GREATER_THAN_OR_EQUAL,
        FilterOperator.LESS_THAN_OR_EQUALS.value: Operator.LESS_THAN_OR_EQUAL,
        FilterOperator.IN.value: Operator.IN,
        FilterOperator.NOT_IN.value: Operator.NOT_IN,
        FilterOperator.LIKE.value: Operator.LIKE,
        FilterOperator.NOT_LIKE.value: Operator.NOT_LIKE,
        FilterOperator.IS_NULL.value: Operator.IS_NULL,
        FilterOperator.IS_NOT_NULL.value: Operator.IS_NOT_NULL,
    }

    operator = operator_mapping.get(operator_str)
    if not operator:
        # Unknown operator - create adhoc filter
        return None

    return {
        Filter(
            type=PredicateType.WHERE,
            column=dimension,
            operator=operator,
            value=value,
        )
    }


def _get_order_from_query_object(
    query_object: ValidatedQueryObject,
    all_metrics: dict[str, Metric],
    all_dimensions: dict[str, Dimension],
) -> list[OrderTuple]:
    order: list[OrderTuple] = []
    for element, ascending in query_object.orderby:
        direction = OrderDirection.ASC if ascending else OrderDirection.DESC

        # adhoc
        if isinstance(element, dict):
            if element["sqlExpression"] is not None:
                order.append(
                    (
                        AdhocExpression(
                            id=element["label"] or element["sqlExpression"],
                            definition=element["sqlExpression"],
                        ),
                        direction,
                    )
                )
        elif element in all_dimensions:
            order.append((all_dimensions[element], direction))
        elif element in all_metrics:
            order.append((all_metrics[element], direction))

    return order


def _get_group_limit_from_query_object(
    query_object: ValidatedQueryObject,
    all_metrics: dict[str, Metric],
    all_dimensions: dict[str, Dimension],
) -> GroupLimit | None:
    # no limit
    if query_object.series_limit == 0 or not query_object.columns:
        return None

    dimensions = [all_dimensions[dim_id] for dim_id in query_object.series_columns]
    top = query_object.series_limit
    metric = (
        all_metrics[query_object.series_limit_metric]
        if query_object.series_limit_metric
        else None
    )
    direction = OrderDirection.DESC if query_object.order_desc else OrderDirection.ASC
    group_others = query_object.group_others_when_limit_reached

    # Check if we need separate filters for the group limit subquery
    # This happens when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm
    group_limit_filters = _get_group_limit_filters(query_object, all_dimensions)

    return GroupLimit(
        dimensions=dimensions,
        top=top,
        metric=metric,
        direction=direction,
        group_others=group_others,
        filters=group_limit_filters,
    )


def _get_group_limit_filters(
    query_object: ValidatedQueryObject,
    all_dimensions: dict[str, Dimension],
) -> set[Filter | AdhocFilter] | None:
    """
    Get separate filters for the group limit subquery if needed.

    This is used when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm,
    which happens during time comparison queries. The group limit subquery may need
    different time bounds to determine the top N groups.

    Returns None if the group limit should use the same filters as the main query.
    """
    # Check if inner time bounds are explicitly set and differ from outer bounds
    if (
        query_object.inner_from_dttm is None
        or query_object.inner_to_dttm is None
        or (
            query_object.inner_from_dttm == query_object.from_dttm
            and query_object.inner_to_dttm == query_object.to_dttm
        )
    ):
        # No separate bounds needed - use the same filters as the main query
        return None

    # Create separate filters for the group limit subquery
    filters: set[Filter | AdhocFilter] = set()

    # Add time range filter using inner bounds
    if query_object.granularity:
        time_dimension = all_dimensions.get(query_object.granularity)
        if (
            time_dimension
            and query_object.inner_from_dttm
            and query_object.inner_to_dttm
        ):
            filters.update(
                {
                    Filter(
                        type=PredicateType.WHERE,
                        column=time_dimension,
                        operator=Operator.GREATER_THAN_OR_EQUAL,
                        value=query_object.inner_from_dttm,
                    ),
                    Filter(
                        type=PredicateType.WHERE,
                        column=time_dimension,
                        operator=Operator.LESS_THAN,
                        value=query_object.inner_to_dttm,
                    ),
                }
            )

    # Add fetch values predicate if present
    if (
        query_object.apply_fetch_values_predicate
        and query_object.datasource.fetch_values_predicate
    ):
        filters.add(
            AdhocFilter(
                type=PredicateType.WHERE,
                definition=query_object.datasource.fetch_values_predicate,
            )
        )

    # Add filters from query_object.extras (WHERE and HAVING clauses)
    extras_filters = _get_filters_from_extras(query_object.extras)
    filters.update(extras_filters)

    # Add all other non-temporal filters from query_object.filter
    for filter_ in query_object.filter:
        # Skip temporal range filters - we're using inner bounds instead
        if (
            filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value
            and query_object.granularity
        ):
            continue

        if converted_filters := _convert_query_object_filter(filter_, all_dimensions):
            filters.update(converted_filters)

    return filters if filters else None


def _convert_time_grain(time_grain: str) -> Grain | None:
    """
    Convert a time grain string from the query object to a Grain enum.
    """
    mapping = {
        grain.representation: grain
        for grain in [
            Second,
            Minute,
            Hour,
            Day,
            Week,
            Month,
            Quarter,
            Year,
        ]
    }

    return mapping.get(time_grain)
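
# Example lookups (assuming ISO 8601 representations such as
# Day.representation == "P1D"):
#     _convert_time_grain("P1D")  # -> Day
#     _convert_time_grain("P1W")  # -> Week
#     _convert_time_grain("foo")  # -> None (unsupported grain)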


def validate_query_object(
    query_object: QueryObject,
) -> TypeGuard[ValidatedQueryObject]:
    """
    Validate that the `QueryObject` is compatible with the `SemanticView`.

    If a semantic view implementation ever supports one of the features rejected
    below, we should add an attribute to the `SemanticViewImplementation` to
    indicate support for it.
    """
    if not query_object.datasource:
        return False

    query_object = cast(ValidatedQueryObject, query_object)

    _validate_metrics(query_object)
    _validate_dimensions(query_object)
    _validate_filters(query_object)
    _validate_granularity(query_object)
    _validate_group_limit(query_object)
    _validate_orderby(query_object)

    return True
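
# TypeGuard sketch: once validate_query_object(query_object) returns True, type
# checkers narrow the argument to ValidatedQueryObject, which is why
# get_results() can read query_object.datasource.implementation without an
# Optional check.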


def _validate_metrics(query_object: ValidatedQueryObject) -> None:
    """
    Make sure metrics are defined in the semantic view.
    """
    semantic_view = query_object.datasource.implementation

    if any(not isinstance(metric, str) for metric in (query_object.metrics or [])):
        raise ValueError("Adhoc metrics are not supported in Semantic Views.")

    metric_names = {metric.name for metric in semantic_view.metrics}
    if not set(query_object.metrics or []) <= metric_names:
        raise ValueError("All metrics must be defined in the Semantic View.")


def _validate_dimensions(query_object: ValidatedQueryObject) -> None:
    """
    Make sure all dimensions are defined in the semantic view.
    """
    semantic_view = query_object.datasource.implementation
    dimension_names = {dimension.name for dimension in semantic_view.dimensions}

    # Normalize all columns to dimension names
    normalized_columns = [
        _normalize_column(column, dimension_names) for column in query_object.columns
    ]

    if not set(normalized_columns) <= dimension_names:
        raise ValueError("All dimensions must be defined in the Semantic View.")


def _validate_filters(query_object: ValidatedQueryObject) -> None:
    """
    Make sure all filters are valid.
    """
    for filter_ in query_object.filter:
        if isinstance(filter_["col"], dict):
            raise ValueError(
                "Adhoc columns are not supported in Semantic View filters."
            )
        if not filter_.get("op"):
            raise ValueError("All filters must have an operator defined.")


def _validate_granularity(query_object: ValidatedQueryObject) -> None:
    """
    Make sure time column and time grain are valid.
    """
    semantic_view = query_object.datasource.implementation
    dimension_names = {dimension.name for dimension in semantic_view.dimensions}

    if time_column := query_object.granularity:
        if time_column not in dimension_names:
            raise ValueError(
                "The time column must be defined in the Semantic View dimensions."
            )

    if time_grain := query_object.extras.get("time_grain_sqla"):
        if not time_column:
            raise ValueError(
                "A time column must be specified when a time grain is provided."
            )

        supported_time_grains = {
            dimension.grain
            for dimension in semantic_view.dimensions
            if dimension.name == time_column and dimension.grain
        }
        if _convert_time_grain(time_grain) not in supported_time_grains:
            raise ValueError(
                "The time grain is not supported for the time column in the "
                "Semantic View."
            )


def _validate_group_limit(query_object: ValidatedQueryObject) -> None:
    """
    Validate group limit related features in the query object.
    """
    semantic_view = query_object.datasource.implementation

    # no limit
    if query_object.series_limit == 0:
        return

    if (
        query_object.series_columns
        and SemanticViewFeature.GROUP_LIMIT not in semantic_view.features
    ):
        raise ValueError("Group limit is not supported in this Semantic View.")

    if any(not isinstance(col, str) for col in query_object.series_columns):
        raise ValueError("Adhoc dimensions are not supported in series columns.")

    metric_names = {metric.name for metric in semantic_view.metrics}
    if query_object.series_limit_metric and (
        not isinstance(query_object.series_limit_metric, str)
        or query_object.series_limit_metric not in metric_names
    ):
        raise ValueError(
            "The series limit metric must be defined in the Semantic View."
        )

    dimension_names = {dimension.name for dimension in semantic_view.dimensions}
    if not set(query_object.series_columns) <= dimension_names:
        raise ValueError("All series columns must be defined in the Semantic View.")

    if (
        query_object.group_others_when_limit_reached
        and SemanticViewFeature.GROUP_OTHERS not in semantic_view.features
    ):
        raise ValueError(
            "Grouping others when limit is reached is not supported in this Semantic "
            "View."
        )


def _validate_orderby(query_object: ValidatedQueryObject) -> None:
    """
    Validate order by elements in the query object.
    """
    semantic_view = query_object.datasource.implementation

    if (
        any(not isinstance(element, str) for element, _ in query_object.orderby)
        and SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY
        not in semantic_view.features
    ):
        raise ValueError(
            "Adhoc expressions in order by are not supported in this Semantic View."
        )

    elements = {orderby[0] for orderby in query_object.orderby}
    metric_names = {metric.name for metric in semantic_view.metrics}
    dimension_names = {dimension.name for dimension in semantic_view.dimensions}
    if not elements <= metric_names | dimension_names:
        raise ValueError("All order by elements must be defined in the Semantic View.")

381
superset/semantic_layers/models.py
Normal file
@@ -0,0 +1,381 @@
"""Semantic layer models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from importlib.metadata import entry_points
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy_utils import UUIDType
|
||||
from sqlalchemy_utils.types.json import JSONType
|
||||
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.explorables.base import TimeGrainDict
|
||||
from superset.extensions import encrypted_field_factory
|
||||
from superset.models.helpers import AuditMixinNullable, QueryResult
|
||||
from superset.semantic_layers.mapper import get_results
|
||||
from superset.semantic_layers.types import (
|
||||
BINARY,
|
||||
BOOLEAN,
|
||||
DATE,
|
||||
DATETIME,
|
||||
DECIMAL,
|
||||
INTEGER,
|
||||
INTERVAL,
|
||||
NUMBER,
|
||||
OBJECT,
|
||||
SemanticLayerImplementation,
|
||||
SemanticViewImplementation,
|
||||
STRING,
|
||||
TIME,
|
||||
Type,
|
||||
)
|
||||
from superset.utils import json
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.superset_typing import ExplorableData, QueryObjectDict
|
||||
|
||||
|
||||
def get_column_type(semantic_type: Type) -> GenericDataType:
|
||||
"""
|
||||
Map semantic layer types to generic data types.
|
||||
"""
|
||||
if semantic_type in {DATE, DATETIME, TIME}:
|
||||
return GenericDataType.TEMPORAL
|
||||
if semantic_type in {INTEGER, NUMBER, DECIMAL, INTERVAL}:
|
||||
return GenericDataType.NUMERIC
|
||||
if semantic_type is BOOLEAN:
|
||||
return GenericDataType.BOOLEAN
|
||||
if semantic_type in {STRING, OBJECT, BINARY}:
|
||||
return GenericDataType.STRING
|
||||
return GenericDataType.STRING
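
# e.g. get_column_type(DATETIME) -> GenericDataType.TEMPORAL and
# get_column_type(DECIMAL) -> GenericDataType.NUMERIC; unrecognized types fall
# through to STRING.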


@dataclass(frozen=True)
class MetricMetadata:
    metric_name: str
    expression: str
    verbose_name: str | None = None
    description: str | None = None
    d3format: str | None = None
    currency: dict[str, Any] | None = None
    warning_text: str | None = None
    certified_by: str | None = None
    certification_details: str | None = None


@dataclass(frozen=True)
class ColumnMetadata:
    column_name: str
    type: str
    is_dttm: bool
    verbose_name: str | None = None
    description: str | None = None
    groupby: bool = True
    filterable: bool = True
    expression: str | None = None
    python_date_format: str | None = None
    advanced_data_type: str | None = None
    extra: str | None = None


class SemanticLayer(AuditMixinNullable, Model):
    """
    Semantic layer model.

    A semantic layer provides an abstraction over data sources,
    allowing users to query data through a semantic interface.
    """

    __tablename__ = "semantic_layers"

    uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4)

    # Core fields
    name = Column(String(250), nullable=False)
    description = Column(Text, nullable=True)
    type = Column(String(250), nullable=False)  # snowflake, etc

    configuration = Column(encrypted_field_factory.create(JSONType), default=dict)
    cache_timeout = Column(Integer, nullable=True)

    # Semantic views relationship
    semantic_views: list[SemanticView] = relationship(
        "SemanticView",
        back_populates="semantic_layer",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    def __repr__(self) -> str:
        return self.name or str(self.uuid)

    @cached_property
    def implementation(
        self,
    ) -> SemanticLayerImplementation[Any, SemanticViewImplementation]:
        """
        Return semantic layer implementation.
        """
        entry_point = next(
            iter(
                entry_points(
                    group="superset.semantic_layers",
                    name=self.type,
                )
            )
        )
        implementation_class = entry_point.load()

        if not issubclass(implementation_class, SemanticLayerImplementation):
            raise TypeError(
                f"Entry point for semantic layer type '{self.type}' "
                "must be a subclass of SemanticLayerImplementation"
            )

        return implementation_class.from_configuration(json.loads(self.configuration))
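
    # Registration sketch (hypothetical package and names): layer types resolve
    # through the "superset.semantic_layers" entry point group, e.g. in a
    # third-party pyproject.toml:
    #     [project.entry-points."superset.semantic_layers"]
    #     mylayer = "my_pkg.layer:MySemanticLayer"
    # after which SemanticLayer(type="mylayer").implementation loads the class.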


class SemanticView(AuditMixinNullable, Model):
    """
    Semantic view model.

    A semantic view represents a queryable view within a semantic layer.
    """

    __tablename__ = "semantic_views"

    uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4)

    # Core fields
    name = Column(String(250), nullable=False)
    description = Column(Text, nullable=True)

    configuration = Column(encrypted_field_factory.create(JSONType), default=dict)
    cache_timeout = Column(Integer, nullable=True)

    # Semantic layer relationship
    semantic_layer_uuid = Column(
        UUIDType(binary=True),
        ForeignKey("semantic_layers.uuid", ondelete="CASCADE"),
        nullable=False,
    )
    semantic_layer: SemanticLayer = relationship(
        "SemanticLayer",
        back_populates="semantic_views",
        foreign_keys=[semantic_layer_uuid],
    )

    def __repr__(self) -> str:
        return self.name or str(self.uuid)

    @cached_property
    def implementation(self) -> SemanticViewImplementation:
        """
        Return semantic view implementation.
        """
        return self.semantic_layer.implementation.get_semantic_view(
            self.name,
            json.loads(self.configuration),
        )

    # =========================================================================
    # Explorable protocol implementation
    # =========================================================================

    def get_query_result(self, query_object: QueryObject) -> QueryResult:
        return get_results(query_object)

    def get_query_str(self, query_obj: QueryObjectDict) -> str:
        return "Not implemented for semantic layers"

    @property
    def uid(self) -> str:
        return self.implementation.uid()

    @property
    def type(self) -> str:
        return "semantic_view"

    @property
    def metrics(self) -> list[MetricMetadata]:
        return [
            MetricMetadata(
                metric_name=metric.name,
                expression=metric.definition,
                description=metric.description,
            )
            for metric in self.implementation.metrics
        ]

    @property
    def columns(self) -> list[ColumnMetadata]:
        return [
            ColumnMetadata(
                column_name=dimension.name,
                type=dimension.type.__name__,
                is_dttm=dimension.type in {DATE, TIME, DATETIME},
                description=dimension.description,
                expression=dimension.definition,
                extra=json.dumps({"grain": dimension.grain}),
            )
            for dimension in self.implementation.dimensions
        ]

    @property
    def column_names(self) -> list[str]:
        return [dimension.name for dimension in self.implementation.dimensions]

    @property
    def data(self) -> ExplorableData:
        return {
            # core
            "id": self.uuid.hex,
            "uid": self.uid,
            "type": "semantic_view",
            "name": self.name,
            "columns": [
                {
                    "advanced_data_type": None,
                    "certification_details": None,
                    "certified_by": None,
                    "column_name": dimension.name,
                    "description": dimension.description,
                    "expression": dimension.definition,
                    "filterable": True,
                    "groupby": True,
                    "id": None,
                    "uuid": None,
                    "is_certified": False,
                    "is_dttm": dimension.type in {DATE, TIME, DATETIME},
                    "python_date_format": None,
                    "type": dimension.type.__name__,
                    "type_generic": get_column_type(dimension.type),
                    "verbose_name": None,
                    "warning_markdown": None,
                }
                for dimension in self.implementation.dimensions
            ],
            "metrics": [
                {
                    "certification_details": None,
                    "certified_by": None,
                    "d3format": None,
                    "description": metric.description,
                    "expression": metric.definition,
                    "id": None,
                    "uuid": None,
                    "is_certified": False,
                    "metric_name": metric.name,
                    "warning_markdown": None,
                    "warning_text": None,
                    "verbose_name": None,
                }
                for metric in self.implementation.metrics
            ],
            "database": {},
            # UI features
            "verbose_map": {},
            "order_by_choices": [],
            "filter_select": True,
            "filter_select_enabled": True,
            "sql": None,
            "select_star": None,
            "owners": [],
            "description": self.description,
            "table_name": self.name,
            "column_types": [
                get_column_type(dimension.type)
                for dimension in self.implementation.dimensions
            ],
            "column_names": [
                dimension.name for dimension in self.implementation.dimensions
            ],
            # rare
            "column_formats": {},
            "datasource_name": self.name,
            "perm": self.perm,
            "offset": None,
            "cache_timeout": self.cache_timeout,
            "params": None,
            # sql-specific
            "schema": None,
            "catalog": None,
            "main_dttm_col": None,
            "time_grain_sqla": [],
            "granularity_sqla": [],
            "fetch_values_predicate": None,
            "template_params": None,
            "is_sqllab_view": False,
            "extra": None,
            "always_filter_main_dttm": False,
            "normalize_columns": False,
            # TODO XXX
            # "owners": [owner.id for owner in self.owners],
            "edit_url": "",
            "default_endpoint": None,
            "folders": [],
            "health_check_message": None,
        }

    def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
        return []

    @property
    def perm(self) -> str:
        return self.semantic_layer_uuid.hex + "::" + self.uuid.hex
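        # e.g. "1f2e3d...::9a8b7c..." -- the layer UUID hex and the view UUID
        # hex joined by "::" (hypothetical values).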

    @property
    def offset(self) -> int:
        # always return datetime as UTC
        return 0

    def get_time_grains(self) -> list[TimeGrainDict]:
        return [
            {
                "name": dimension.grain.name,
                "function": "",
                "duration": dimension.grain.representation,
            }
            for dimension in self.implementation.dimensions
            if dimension.grain
        ]

    def has_drill_by_columns(self, column_names: list[str]) -> bool:
        dimension_names = {
            dimension.name for dimension in self.implementation.dimensions
        }
        return all(column_name in dimension_names for column_name in column_names)

    @property
    def is_rls_supported(self) -> bool:
        return False

    @property
    def query_language(self) -> str | None:
        return None
26
superset/semantic_layers/snowflake/__init__.py
Normal file
26
superset/semantic_layers/snowflake/__init__.py
Normal file
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.semantic_layer import SnowflakeSemanticLayer
from superset.semantic_layers.snowflake.semantic_view import SnowflakeSemanticView

__all__ = [
    "SnowflakeConfiguration",
    "SnowflakeSemanticLayer",
    "SnowflakeSemanticView",
]
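A minimal sketch of how a host application could discover this package, assuming it is registered under a `superset.semantic_layers` entry-point group (the group name is an assumption here, not something this file establishes):

from importlib.metadata import entry_points

# iterate over all semantic-layer plugins registered via entry points
for ep in entry_points(group="superset.semantic_layers"):
    layer_class = ep.load()  # e.g. SnowflakeSemanticLayer
    print(ep.name, layer_class.name)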
130
superset/semantic_layers/snowflake/schemas.py
Normal file
@@ -0,0 +1,130 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Literal, Union

from pydantic import BaseModel, ConfigDict, Field, model_validator, SecretStr


class UserPasswordAuth(BaseModel):
    """
    Username and password authentication.
    """

    model_config = ConfigDict(title="Username and password")

    auth_type: Literal["user_password"] = "user_password"
    username: str = Field(description="The username to authenticate as.")
    password: SecretStr = Field(
        description="The password to authenticate with.",
        repr=False,
    )


class PrivateKeyAuth(BaseModel):
    """
    Private key authentication.
    """

    model_config = ConfigDict(title="Private key")

    auth_type: Literal["private_key"] = "private_key"
    private_key: SecretStr = Field(
        description="The private key to authenticate with, in PEM format.",
        repr=False,
    )
    private_key_password: SecretStr = Field(
        description="The password to decrypt the private key with.",
        repr=False,
    )


class SnowflakeConfiguration(BaseModel):
    """
    Parameters needed to connect to Snowflake.
    """

    # account is the only required parameter
    account_identifier: str = Field(
        description="The Snowflake account identifier.",
        json_schema_extra={"examples": ["abc12345"]},
    )

    role: str | None = Field(
        default=None,
        description="The default role to use.",
        json_schema_extra={"examples": ["myrole"]},
    )
    warehouse: str | None = Field(
        default=None,
        description="The default warehouse to use.",
        json_schema_extra={"examples": ["testwh"]},
    )

    auth: Union[UserPasswordAuth, PrivateKeyAuth] = Field(
        discriminator="auth_type",
        description="Authentication method",
    )

    # database and schema can be optionally provided; if not provided the user
    # will be able to browse databases/schemas
    database: str | None = Field(
        default=None,
        description="The default database to use.",
        json_schema_extra={
            "examples": ["testdb"],
            "x-dynamic": True,
            "x-dependsOn": ["account_identifier", "auth"],
        },
    )
    allow_changing_database: bool = Field(
        default=False,
        description="Allow changing the default database.",
    )
    schema_: str | None = Field(
        default=None,
        description="The default schema to use.",
        json_schema_extra={
            "examples": ["public"],
            "x-dynamic": True,
            "x-dependsOn": ["account_identifier", "auth", "database"],
        },
        # `schema` is an attribute of `BaseModel` so it needs to be aliased
        alias="schema",
    )
    allow_changing_schema: bool = Field(
        default=False,
        description="Allow changing the default schema.",
    )

    @model_validator(mode="after")
    def validate_database_schema_settings(self) -> SnowflakeConfiguration:
        """
        Validate that if database or schema is not specified, the corresponding
        allow_changing flag must be true.
        """
        if not self.database and not self.allow_changing_database:
            raise ValueError(
                "If no database is specified, allow_changing_database must be true"
            )
        if not self.schema_ and not self.allow_changing_schema:
            raise ValueError(
                "If no schema is specified, allow_changing_schema must be true"
            )
        return self
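A short sketch of how the model above is meant to be consumed; all values are invented for illustration. The `auth` union is discriminated by `auth_type`, and the `schema` key is accepted through the field alias:

config = SnowflakeConfiguration.model_validate(
    {
        "account_identifier": "abc12345",
        "warehouse": "testwh",
        "auth": {
            "auth_type": "user_password",
            "username": "alice",
            "password": "s3cret",
        },
        "database": "testdb",
        "schema": "public",  # mapped to `schema_` through the alias
    }
)
assert config.schema_ == "public"
# omitting both `database` and `allow_changing_database` would raise a ValueError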
269
superset/semantic_layers/snowflake/semantic_layer.py
Normal file
@@ -0,0 +1,269 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from textwrap import dedent
from typing import Any, Literal

from pydantic import create_model, Field
from snowflake.connector import connect
from snowflake.connector.connection import SnowflakeConnection

from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.semantic_view import SnowflakeSemanticView
from superset.semantic_layers.snowflake.utils import get_connection_parameters
from superset.semantic_layers.types import (
    SemanticLayerImplementation,
)


class SnowflakeSemanticLayer(
    SemanticLayerImplementation[SnowflakeConfiguration, SnowflakeSemanticView]
):
    id = "snowflake"
    name = "Snowflake Semantic Layer"
    description = "Connect to semantic views stored in Snowflake."

    @classmethod
    def from_configuration(
        cls,
        configuration: dict[str, Any],
    ) -> SnowflakeSemanticLayer:
        """
        Create a SnowflakeSemanticLayer from a configuration dictionary.
        """
        config = SnowflakeConfiguration.model_validate(configuration)
        return cls(config)

    @classmethod
    def get_configuration_schema(
        cls,
        configuration: SnowflakeConfiguration | None = None,
    ) -> dict[str, Any]:
        """
        Get the JSON schema for the configuration needed to add the semantic layer.

        A partial configuration can be sent to improve the schema. For example,
        providing account and auth will allow the schema to provide a list of
        databases; providing a database will allow the schema to provide a list
        of schemas.

        Note that database and schema can both be left empty when the semantic
        layer is added to Superset; the user will then have to provide them when
        loading semantic views.
        """
        schema = SnowflakeConfiguration.model_json_schema()
        properties = schema["properties"]

        if configuration is None:
            # set these to empty; they will be populated when a partial
            # configuration is passed
            properties["database"]["enum"] = []
            properties["schema"]["enum"] = []

            return schema

        connection_parameters = get_connection_parameters(configuration)
        with connect(**connection_parameters) as connection:
            if all(
                getattr(configuration, dependency)
                for dependency in properties["database"].get("x-dependsOn", [])
            ):
                options = cls._fetch_databases(connection)
                properties["database"]["enum"] = list(options)

            if (
                all(
                    getattr(configuration, dependency)
                    for dependency in properties["schema"].get("x-dependsOn", [])
                )
                and configuration.database
            ):
                options = cls._fetch_schemas(connection, configuration.database)
                properties["schema"]["enum"] = list(options)

        return schema

    @classmethod
    def get_runtime_schema(
        cls,
        configuration: SnowflakeConfiguration,
        runtime_data: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Get the JSON schema for the runtime parameters needed to load semantic views.

        The schema can be enriched with actual values when `runtime_data` is
        provided, enabling dynamic schema updates (e.g., populating the schema
        dropdown after a database is selected).
        """
        fields: dict[str, tuple[Any, Field]] = {}

        # update the configuration with runtime data, for example, to select a
        # schema after the database has been selected
        configuration = configuration.model_copy(update=runtime_data)

        connection_parameters = get_connection_parameters(configuration)
        with connect(**connection_parameters) as connection:
            if not configuration.database or configuration.allow_changing_database:
                options = cls._fetch_databases(connection)
                fields["database"] = (
                    Literal[*options],
                    Field(description="The default database to use."),
                )

            if not configuration.schema_ or configuration.allow_changing_schema:
                if configuration.database:
                    options = cls._fetch_schemas(connection, configuration.database)
                    fields["schema_"] = (
                        Literal[*options],
                        Field(
                            description="The default schema to use.",
                            alias="schema",
                            json_schema_extra=(
                                {
                                    "x-dynamic": True,
                                    "x-dependsOn": ["database"],
                                }
                                if "database" in fields
                                else {}
                            ),
                        ),
                    )
                else:
                    # Database not provided yet, add schema as empty
                    # (will be populated dynamically)
                    fields["schema_"] = (
                        str | None,
                        Field(
                            default=None,
                            description="The default schema to use.",
                            alias="schema",
                            json_schema_extra={
                                "x-dynamic": True,
                                "x-dependsOn": ["database"],
                            },
                        ),
                    )

        return create_model("RuntimeParameters", **fields).model_json_schema()

    @classmethod
    def _fetch_databases(cls, connection: SnowflakeConnection) -> set[str]:
        """
        Fetch the list of databases available in the Snowflake account.

        We use `SHOW DATABASES` instead of querying the information schema since
        it allows us to retrieve the list of databases without having to specify
        a database when connecting.
        """
        cursor = connection.cursor()
        cursor.execute("SHOW DATABASES")
        return {row[1] for row in cursor}

    @classmethod
    def _fetch_schemas(
        cls,
        connection: SnowflakeConnection,
        database: str | None,
    ) -> set[str]:
        """
        Fetch the list of schemas available in a given database.

        The connection should already have the database set in its context.
        """
        if not database:
            return set()

        cursor = connection.cursor()
        query = dedent(
            """
            SELECT SCHEMA_NAME
            FROM INFORMATION_SCHEMA.SCHEMATA
            WHERE CATALOG_NAME = ?
            """
        ).strip()
        return {row[0] for row in cursor.execute(query, (database,))}

    def __init__(self, configuration: SnowflakeConfiguration):
        self.configuration = configuration

    def get_semantic_views(
        self,
        runtime_configuration: dict[str, Any],
    ) -> set[SnowflakeSemanticView]:
        """
        Get the semantic views available in the semantic layer.
        """
        # Avoid circular import
        from superset.semantic_layers.snowflake.semantic_view import (
            SnowflakeSemanticView,
        )

        # create a new configuration with the runtime parameters
        configuration = self.configuration.model_copy(update=runtime_configuration)

        connection_parameters = get_connection_parameters(configuration)
        with connect(**connection_parameters) as connection:
            cursor = connection.cursor()
            query = dedent(
                """
                SHOW SEMANTIC VIEWS
                ->> SELECT "name" FROM $1;
                """
            ).strip()
            views = {
                SnowflakeSemanticView(row[0], configuration)
                for row in cursor.execute(query)
            }

        return views

    def get_semantic_view(
        self,
        name: str,
        additional_configuration: dict[str, Any],
    ) -> SnowflakeSemanticView:
        """
        Get a specific semantic view by name.
        """
        # Avoid circular import
        from superset.semantic_layers.snowflake.semantic_view import (
            SnowflakeSemanticView,
        )

        # create a new configuration with the additional parameters
        configuration = self.configuration.model_copy(update=additional_configuration)

        # check that the semantic view exists
        connection_parameters = get_connection_parameters(configuration)
        with connect(**connection_parameters) as connection:
            cursor = connection.cursor()
            query = dedent(
                """
                SHOW SEMANTIC VIEWS
                ->> SELECT "name" FROM $1 WHERE "name" = ?;
                """
            ).strip()
            cursor.execute(query, (name,))
            rows = cursor.fetchall()
            if not rows:
                raise ValueError(f'Semantic view "{name}" does not exist.')

        return SnowflakeSemanticView(name, configuration)
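A hedged usage sketch (all connection values are placeholders): build the layer from a raw dict, then enumerate views, supplying database and schema at load time since they were left open in the configuration:

layer = SnowflakeSemanticLayer.from_configuration(
    {
        "account_identifier": "abc12345",
        "auth": {
            "auth_type": "user_password",
            "username": "alice",
            "password": "s3cret",
        },
        "allow_changing_database": True,
        "allow_changing_schema": True,
    }
)
# database/schema were not fixed above, so they are provided at query time;
# the keys are model field names, hence `schema_` rather than `schema`
views = layer.get_semantic_views({"database": "testdb", "schema_": "public"})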
873
superset/semantic_layers/snowflake/semantic_view.py
Normal file
@@ -0,0 +1,873 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# ruff: noqa: S608

from __future__ import annotations

import itertools
import re
from collections import defaultdict
from textwrap import dedent

from pandas import DataFrame
from snowflake.connector import connect, DictCursor
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect

from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.utils import (
    get_connection_parameters,
    substitute_parameters,
    validate_order_by,
)
from superset.semantic_layers.types import (
    AdhocExpression,
    AdhocFilter,
    BINARY,
    BOOLEAN,
    DATE,
    DATETIME,
    DECIMAL,
    Dimension,
    Filter,
    FilterValues,
    GroupLimit,
    INTEGER,
    Metric,
    NUMBER,
    OBJECT,
    Operator,
    OrderDirection,
    OrderTuple,
    PredicateType,
    SemanticRequest,
    SemanticResult,
    SemanticViewFeature,
    SemanticViewImplementation,
    STRING,
    TIME,
    Type,
)

REQUEST_TYPE = "snowflake"

class SnowflakeSemanticView(SemanticViewImplementation):
    features = frozenset(
        {
            SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY,
            SemanticViewFeature.GROUP_LIMIT,
            SemanticViewFeature.GROUP_OTHERS,
        }
    )

    def __init__(self, name: str, configuration: SnowflakeConfiguration):
        self.configuration = configuration
        self.name = name

        self._quote = SnowflakeDialect().identifier_preparer.quote

        self.dimensions = self.get_dimensions()
        self.metrics = self.get_metrics()

    def uid(self) -> str:
        return ".".join(
            self._quote(part)
            for part in (
                self.configuration.database,
                self.configuration.schema_,
                self.name,
            )
        )

    def get_dimensions(self) -> set[Dimension]:
        """
        Get the dimensions defined in the semantic view.

        Even though Snowflake supports `SHOW SEMANTIC DIMENSIONS IN my_semantic_view`,
        it doesn't return the expression of dimensions, so we use a slightly more
        complicated query to get all the information we need in one go.
        """
        dimensions: set[Dimension] = set()

        query = dedent(
            f"""
            DESC SEMANTIC VIEW {self.uid()}
            ->> SELECT "object_name", "property", "property_value"
                FROM $1
                WHERE
                    "object_kind" = 'DIMENSION' AND
                    "property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
            """
        ).strip()

        connection_parameters = get_connection_parameters(self.configuration)
        with connect(**connection_parameters) as connection:
            cursor = connection.cursor(DictCursor)
            rows = cursor.execute(query).fetchall()

        for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
            attributes = defaultdict(set)
            for row in group:
                attributes[row["property"]].add(row["property_value"])

            table = next(iter(attributes["TABLE"]))
            id_ = table + "." + name
            type_ = self._get_type(next(iter(attributes["DATA_TYPE"])))
            description = next(iter(attributes["COMMENT"]), None)
            definition = next(iter(attributes["EXPRESSION"]), None)

            dimensions.add(Dimension(id_, name, type_, definition, description))

        return dimensions

    def get_metrics(self) -> set[Metric]:
        """
        Get the metrics defined in the semantic view.
        """
        metrics: set[Metric] = set()

        query = dedent(
            f"""
            DESC SEMANTIC VIEW {self.uid()}
            ->> SELECT "object_name", "property", "property_value"
                FROM $1
                WHERE
                    "object_kind" = 'METRIC' AND
                    "property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
            """
        ).strip()

        connection_parameters = get_connection_parameters(self.configuration)
        with connect(**connection_parameters) as connection:
            cursor = connection.cursor(DictCursor)
            rows = cursor.execute(query).fetchall()

        for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
            attributes = defaultdict(set)
            for row in group:
                attributes[row["property"]].add(row["property_value"])

            table = next(iter(attributes["TABLE"]))
            id_ = table + "." + name
            type_ = self._get_type(next(iter(attributes["DATA_TYPE"])))
            description = next(iter(attributes["COMMENT"]), None)
            definition = next(iter(attributes["EXPRESSION"]), None)

            metrics.add(Metric(id_, name, type_, definition, description))

        return metrics

    def _get_type(self, snowflake_type: str | None) -> type[Type]:
        """
        Return the semantic type corresponding to a Snowflake type.
        """
        if snowflake_type is None:
            return STRING

        type_map = {
            STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"},
            INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"},
            DECIMAL: {r"NUMBER\(10,\s?2\)$"},
            NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"},
            BOOLEAN: {"BOOLEAN$"},
            DATE: {"DATE$"},
            DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP_NTZ$"},
            TIME: {"TIME$"},
            OBJECT: {"OBJECT$"},
            BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"},
        }
        for semantic_type, patterns in type_map.items():
            if any(
                re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns
            ):
                return semantic_type

        return STRING
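A quick illustration of the type mapping above. `DESC SEMANTIC VIEW` reports raw Snowflake type strings; these assertions assume an already-constructed view instance named `view`, and rely on the dict's insertion order putting the specific `NUMBER(38,0)` pattern ahead of the generic `NUMBER` one:

assert view._get_type("NUMBER(38,0)") is INTEGER   # matched before generic NUMBER
assert view._get_type("NUMBER(10,2)") is DECIMAL
assert view._get_type("VARCHAR(255)") is STRING
assert view._get_type("TIMESTAMP_NTZ") is DATETIME
assert view._get_type(None) is STRING              # unknown types fall back to STRING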
    def _build_predicates(
        self,
        filters: list[Filter | AdhocFilter],
    ) -> tuple[str, tuple[FilterValues, ...]]:
        """
        Convert a set of filters to a single `AND`ed predicate.

        The caller should check the types of the filters beforehand, as this method
        does not differentiate between `WHERE` and `HAVING` predicates.
        """
        if not filters:
            return "", ()

        # convert filters to predicates with associated parameters; adhoc filters
        # are already strings, so we keep them as-is
        unary_operators = {Operator.IS_NULL, Operator.IS_NOT_NULL}
        predicates: list[str] = []
        parameters: list[FilterValues] = []
        for filter_ in filters or set():
            if isinstance(filter_, AdhocFilter):
                predicates.append(f"({filter_.definition})")
            else:
                predicates.append(f"({self._build_native_filter(filter_)})")
                if filter_.operator not in unary_operators:
                    parameters.extend(
                        [filter_.value]
                        if not isinstance(filter_.value, (set, frozenset))
                        else filter_.value
                    )

        return " AND ".join(predicates), tuple(parameters)

    def get_values(
        self,
        dimension: Dimension,
        filters: set[Filter | AdhocFilter] | None = None,
    ) -> SemanticResult:
        """
        Return distinct values for a dimension.
        """
        where_clause, parameters = self._build_predicates(
            sorted(
                filter_
                for filter_ in (filters or [])
                if filter_.type == PredicateType.WHERE
            )
        )
        query = dedent(
            f"""
            SELECT {self._quote(dimension.name)}
            FROM SEMANTIC_VIEW(
                {self.uid()}
                DIMENSIONS {dimension.id}
                {"WHERE " + where_clause if where_clause else ""}
            )
            """
        ).strip()
        connection_parameters = get_connection_parameters(self.configuration)
        with connect(**connection_parameters) as connection:
            df = connection.cursor().execute(query, parameters).fetch_pandas_all()

        return SemanticResult(
            requests=[
                SemanticRequest(
                    REQUEST_TYPE,
                    substitute_parameters(query, parameters),
                )
            ],
            results=df,
        )

    def _build_native_filter(self, filter_: Filter) -> str:
        """
        Convert a native Filter to a SQL predicate string with `?` placeholders.
        """
        column = filter_.column
        operator = filter_.operator
        value = filter_.value

        column_name = self._quote(column.name)

        # Handle IS NULL and IS NOT NULL operators (no value needed)
        if operator in {Operator.IS_NULL, Operator.IS_NOT_NULL}:
            return f"{column_name} {operator.value}"

        # Handle IN and NOT IN operators (set values)
        if operator in {Operator.IN, Operator.NOT_IN}:
            parameter_count = len(value) if isinstance(value, (set, frozenset)) else 1
            formatted_values = ", ".join("?" for _ in range(parameter_count))
            return f"{column_name} {operator.value} ({formatted_values})"

        return f"{column_name} {operator.value} ?"
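For intuition, a sketch of what `_build_predicates` produces for a mixed set of filters; the dimension object and exact identifier quoting are illustrative:

f1 = Filter(PredicateType.WHERE, region_dim, Operator.IN, frozenset({"US", "CA"}))
f2 = AdhocFilter(PredicateType.WHERE, "order_date >= '2024-01-01'")
predicate, params = view._build_predicates([f1, f2])
# predicate -> "(region IN (?, ?)) AND (order_date >= '2024-01-01')"
# params    -> ("US", "CA")  -- set iteration order is not guaranteed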
    def get_dataframe(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        filters: set[Filter | AdhocFilter] | None = None,
        order: list[OrderTuple] | None = None,
        limit: int | None = None,
        offset: int | None = None,
        *,
        group_limit: GroupLimit | None = None,
    ) -> SemanticResult:
        """
        Execute a query and return the results as a (wrapped) Pandas DataFrame.
        """
        if not metrics and not dimensions:
            return SemanticResult(requests=[], results=DataFrame())

        query, parameters = self._get_query(
            metrics,
            dimensions,
            filters,
            order,
            limit,
            offset,
            group_limit,
        )
        connection_parameters = get_connection_parameters(self.configuration)
        with connect(**connection_parameters) as connection:
            df = connection.cursor().execute(query, parameters).fetch_pandas_all()

        # map column names to dimension/metric names instead of IDs
        mapping = {
            **{dimension.id: dimension.name for dimension in dimensions},
            **{metric.id: metric.name for metric in metrics},
        }
        df.rename(columns=mapping, inplace=True)

        return SemanticResult(
            requests=[
                SemanticRequest(
                    REQUEST_TYPE,
                    substitute_parameters(query, parameters),
                )
            ],
            results=df,
        )

    def get_row_count(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        filters: set[Filter | AdhocFilter] | None = None,
        order: list[OrderTuple] | None = None,
        limit: int | None = None,
        offset: int | None = None,
        *,
        group_limit: GroupLimit | None = None,
    ) -> SemanticResult:
        """
        Execute a query and return the number of rows the result would have.
        """
        if not metrics and not dimensions:
            return SemanticResult(
                requests=[],
                results=DataFrame([[0]], columns=["COUNT"]),
            )

        query, parameters = self._get_query(
            metrics,
            dimensions,
            filters,
            order,
            limit,
            offset,
            group_limit,
        )
        query = f"SELECT COUNT(*) FROM ({query}) AS subquery"
        connection_parameters = get_connection_parameters(self.configuration)
        with connect(**connection_parameters) as connection:
            count = connection.cursor().execute(query, parameters).fetchone()[0]

        return SemanticResult(
            requests=[
                SemanticRequest(
                    REQUEST_TYPE,
                    substitute_parameters(query, parameters),
                )
            ],
            results=DataFrame([[count]], columns=["COUNT"]),
        )
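A small consumption sketch; the metric and dimension objects are assumed to exist already:

result = view.get_dataframe([sales_metric], [region_dim], limit=100)
print(result.results.head())          # the data, with friendly column names
print(result.requests[0].definition)  # the SQL that was executed, for auditing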
    def _get_query(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        filters: set[Filter | AdhocFilter] | None = None,
        order: list[OrderTuple] | None = None,
        limit: int | None = None,
        offset: int | None = None,
        group_limit: GroupLimit | None = None,
    ) -> tuple[str, tuple[FilterValues, ...]]:
        """
        Build a query to fetch data from the semantic view.

        This also returns the parameters needed to run `cursor.execute()`, passed
        separately to prevent SQL injection.
        """
        if limit is None and offset is not None:
            raise ValueError("Offset cannot be set without limit")

        filters = filters or set()
        where_clause, where_parameters = self._build_predicates(
            # XXX sort to ensure deterministic order for parameters
            [filter_ for filter_ in filters if filter_.type == PredicateType.WHERE]
        )
        # HAVING clauses are not supported, since there's no GROUP BY
        if any(filter_.type == PredicateType.HAVING for filter_ in filters):
            raise ValueError("HAVING filters are not supported")

        if group_limit:
            query, cte_parameters = self._build_query_with_group_limit(
                metrics,
                dimensions,
                where_clause,
                order,
                limit,
                offset,
                group_limit,
            )
            # Combine parameters: CTE params first, then main query params
            all_parameters = cte_parameters + where_parameters
        else:
            query = self._build_simple_query(
                metrics,
                dimensions,
                where_clause,
                order,
                limit,
                offset,
            )
            all_parameters = where_parameters

        return query, all_parameters

    def _alias_element(self, element: Metric | Dimension) -> str:
        """
        Generate an aliased column expression for a metric or dimension.
        """
        return f"{element.id} AS {self._quote(element.id)}"

    def _build_order_clause(
        self,
        order: list[OrderTuple] | None = None,
    ) -> str:
        """
        Build the ORDER BY clause from a list of (element, direction) tuples.

        Note that for adhoc expressions, Superset will still add `ASC` or `DESC` to
        the end, which means adhoc expressions can contain multiple columns as long
        as the last one has no direction specified.

        This is fine:

            gender ASC, COUNT(*)

        But this is not:

            gender ASC, COUNT(*) DESC

        The latter will produce a query that looks like this:

            ... ORDER BY gender ASC, COUNT(*) DESC DESC
        """
        if not order:
            return ""

        def build_element(element: Metric | Dimension | AdhocExpression) -> str:
            if isinstance(element, AdhocExpression):
                validate_order_by(element.definition)
                return element.definition
            return self._quote(element.id)

        return ", ".join(
            f"{build_element(element)} {direction.value}"
            for element, direction in order
        )
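An illustration of the clause builder; the dimension object is assumed, and the exact quoting of its id depends on the dialect's identifier preparer. The adhoc expression is validated before being inlined:

order = [
    (gender_dim, OrderDirection.ASC),
    (AdhocExpression("cnt", "COUNT(*)"), OrderDirection.DESC),
]
view._build_order_clause(order)
# -> '"orders.gender" ASC, COUNT(*) DESC'  (approximate quoting)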
    def _get_temporal_dimension(
        self,
        dimensions: list[Dimension],
    ) -> Dimension | None:
        """
        Find the first temporal dimension in the list.

        Returns the first dimension with a temporal type (DATE, DATETIME, TIME),
        or None if no temporal dimension is found.
        """
        temporal_types = {DATE, DATETIME, TIME}
        for dimension in dimensions:
            if dimension.type in temporal_types:
                return dimension
        return None

    def _get_default_order(
        self,
        dimensions: list[Dimension],
        order: list[OrderTuple] | None,
    ) -> list[OrderTuple] | None:
        """
        Get the order to use, prepending a temporal sort if needed.

        If there's a temporal dimension in the query and it's not already in the
        order, prepends an ascending sort by that dimension. This ensures
        time-series data is always sorted chronologically first.
        """
        temporal_dimension = self._get_temporal_dimension(dimensions)
        if not temporal_dimension:
            return order

        # Check if the temporal dimension is already in the order
        if order:
            for element, _ in order:
                if (
                    isinstance(element, Dimension)
                    and element.id == temporal_dimension.id
                ):
                    return order
            # Prepend the temporal dimension to the existing order
            return [(temporal_dimension, OrderDirection.ASC)] + list(order)

        # No order specified, use the temporal dimension
        return [(temporal_dimension, OrderDirection.ASC)]

    def _build_simple_query(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        where_clause: str,
        order: list[OrderTuple] | None,
        limit: int | None,
        offset: int | None,
    ) -> str:
        """
        Build a query without group limiting.
        """
        dimension_arguments = ", ".join(
            self._alias_element(dimension) for dimension in dimensions
        )
        metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)
        # Use default temporal ordering if no explicit order is provided
        effective_order = self._get_default_order(dimensions, order)
        order_clause = self._build_order_clause(effective_order)

        return dedent(
            f"""
            SELECT * FROM SEMANTIC_VIEW(
                {self.uid()}
                {"DIMENSIONS " + dimension_arguments if dimension_arguments else ""}
                {"METRICS " + metric_arguments if metric_arguments else ""}
                {"WHERE " + where_clause if where_clause else ""}
            )
            {"ORDER BY " + order_clause if order_clause else ""}
            {"LIMIT " + str(limit) if limit is not None else ""}
            {"OFFSET " + str(offset) if offset is not None else ""}
            """
        ).strip()
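To make the template concrete, this is roughly what a rendered simple query looks like; all identifiers below are invented, and blank lines left by skipped clauses are kept by the template:

print(view._build_simple_query([sales], [region, order_date], where, None, 1000, None))
# SELECT * FROM SEMANTIC_VIEW(
#     "TESTDB"."PUBLIC"."ORDERS_SV"
#     DIMENSIONS orders.region AS "orders.region", orders.order_date AS "orders.order_date"
#     METRICS orders.total_sales AS "orders.total_sales"
#     WHERE ("orders.region" = ?)
# )
# ORDER BY "orders.order_date" ASC
# LIMIT 1000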
    def _build_top_groups_cte(
        self,
        group_limit: GroupLimit,
        where_clause: str,
    ) -> tuple[str, tuple[FilterValues, ...]]:
        """
        Build a CTE that finds the top N combinations of the limited dimensions.

        If group_limit.filters is set, it uses those filters instead of the main
        query's WHERE clause. This allows using different time bounds for finding
        the top groups vs. showing the data.

        Returns:
            Tuple of (CTE SQL, parameters for the CTE)
        """
        limited_dimension_arguments = ", ".join(
            self._alias_element(dimension) for dimension in group_limit.dimensions
        )
        limited_dimension_names = ", ".join(
            self._quote(dimension.id) for dimension in group_limit.dimensions
        )

        # Use separate filters for the group limit if provided;
        # otherwise use the same filters as the main query
        if group_limit.filters is not None:
            group_where_clause, group_where_params = self._build_predicates(
                sorted(
                    filter_
                    for filter_ in group_limit.filters
                    if filter_.type == PredicateType.WHERE
                )
            )
            if any(
                filter_.type == PredicateType.HAVING for filter_ in group_limit.filters
            ):
                raise ValueError(
                    "HAVING filters are not supported in group limit filters"
                )
            cte_params = group_where_params
        else:
            group_where_clause = where_clause
            cte_params = ()  # No additional params - using main query params

        # Build the METRICS clause and ORDER BY based on whether a metric is provided
        if group_limit.metric is not None:
            metrics_clause = (
                f"METRICS {group_limit.metric.id}"
                f" AS {self._quote(group_limit.metric.id)}"
            )
            order_by_clause = (
                f"{self._quote(group_limit.metric.id)} {group_limit.direction.value}"
            )
        else:
            # No metric provided - order by the first dimension
            metrics_clause = ""
            order_by_clause = (
                f"{self._quote(group_limit.dimensions[0].id)} "
                f"{group_limit.direction.value}"
            )

        # Build the SEMANTIC_VIEW arguments
        semantic_view_args = [
            f"DIMENSIONS {limited_dimension_arguments}",
        ]
        if metrics_clause:
            semantic_view_args.append(metrics_clause)
        if group_where_clause:
            semantic_view_args.append(f"WHERE {group_where_clause}")

        semantic_view_args_str = "\n    ".join(semantic_view_args)

        # Add a trailing blank line if there's no WHERE clause
        # This matches the original template behavior
        if not group_where_clause:
            semantic_view_args_str += "\n"

        cte_sql = dedent(
            f"""
            WITH top_groups AS (
                SELECT {limited_dimension_names}
                FROM SEMANTIC_VIEW(
                    {self.uid()}
                    {semantic_view_args_str}
                )
                ORDER BY
                    {order_by_clause}
                LIMIT {group_limit.top}
            )
            """
        ).strip()

        return cte_sql, cte_params

    def _build_group_filter(self, group_limit: GroupLimit) -> str:
        """
        Build a WHERE filter that restricts results to the top N groups.
        """
        if len(group_limit.dimensions) == 1:
            dimension_id = self._quote(group_limit.dimensions[0].id)
            return f"{dimension_id} IN (SELECT {dimension_id} FROM top_groups)"

        # Multi-column IN clause
        dimension_tuple = ", ".join(
            self._quote(dim.id) for dim in group_limit.dimensions
        )
        return f"({dimension_tuple}) IN (SELECT {dimension_tuple} FROM top_groups)"

    def _build_case_expression(
        self,
        dimension: Dimension,
        group_condition: str,
    ) -> str:
        """
        Build a CASE expression that replaces non-top values with 'Other'.

        Args:
            dimension: The dimension to build the CASE for
            group_condition: The condition to check if a value is in the top groups
                (e.g., "staff_id IN (SELECT staff_id FROM top_groups)")

        Returns:
            SQL CASE expression
        """
        dimension_id = self._quote(dimension.id)
        return f"""CASE
            WHEN {group_condition} THEN {dimension_id}
            ELSE CAST('Other' AS VARCHAR)
        END"""
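For example, assuming a single limited dimension whose quoted id is staff_id, the builder yields roughly:

cond = "staff_id IN (SELECT staff_id FROM top_groups)"
print(view._build_case_expression(staff_dim, cond))
# CASE
#     WHEN staff_id IN (SELECT staff_id FROM top_groups) THEN staff_id
#     ELSE CAST('Other' AS VARCHAR)
# END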
    def _build_query_with_others(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        where_clause: str,
        order: list[OrderTuple] | None,
        limit: int | None,
        offset: int | None,
        group_limit: GroupLimit,
    ) -> tuple[str, tuple[FilterValues, ...]]:
        """
        Build a query that groups non-top N values as 'Other'.

        This uses a three-stage approach:
        1. CTE to find the top N groups
        2. Subquery with CASE expressions to replace non-top values with 'Other'
        3. Outer query to re-aggregate with the new grouping

        Returns:
            Tuple of (SQL query, CTE parameters)
        """
        top_groups_cte, cte_params = self._build_top_groups_cte(
            group_limit,
            where_clause,
        )

        # Determine which dimensions are limited vs non-limited
        limited_dimension_ids = {dim.id for dim in group_limit.dimensions}
        non_limited_dimensions = [
            dim for dim in dimensions if dim.id not in limited_dimension_ids
        ]

        # Build the group condition for the CASE expressions
        if len(group_limit.dimensions) == 1:
            dimension_id = self._quote(group_limit.dimensions[0].id)
            group_condition = (
                f"{dimension_id} IN (SELECT {dimension_id} FROM top_groups)"
            )
        else:
            dimension_tuple = ", ".join(
                self._quote(dim.id) for dim in group_limit.dimensions
            )
            group_condition = (
                f"({dimension_tuple}) IN (SELECT {dimension_tuple} FROM top_groups)"
            )

        # Build CASE expressions for the limited dimensions
        case_expressions = []
        case_expressions_for_groupby = []
        for dim in group_limit.dimensions:
            case_expr = self._build_case_expression(dim, group_condition)
            alias = self._quote(dim.id)
            case_expressions.append(f"{case_expr} AS {alias}")
            # Store the full CASE expression for GROUP BY (not just the alias)
            case_expressions_for_groupby.append(case_expr)

        # Build the SELECT list for non-limited dimensions (pass through)
        non_limited_selects = [
            f"{self._quote(dim.id)} AS {self._quote(dim.id)}"
            for dim in non_limited_dimensions
        ]

        # Build the metric aggregations
        metric_aggregations = [
            f"SUM({self._quote(metric.id)}) AS {self._quote(metric.id)}"
            for metric in metrics
        ]

        # Build the subquery that gets the raw data from SEMANTIC_VIEW
        dimension_arguments = ", ".join(
            self._alias_element(dimension) for dimension in dimensions
        )
        metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)

        subquery = dedent(
            f"""
            raw_data AS (
                SELECT * FROM SEMANTIC_VIEW(
                    {self.uid()}
                    DIMENSIONS {dimension_arguments}
                    METRICS {metric_arguments}
                    {"WHERE " + where_clause if where_clause else ""}
                )
            )
            """
        ).strip()

        # Build the GROUP BY clause (full CASE expressions + non-limited dimensions)
        # We need to repeat the full CASE expressions, not use aliases, because
        # Snowflake may interpret the alias as the original column reference
        group_by_columns = case_expressions_for_groupby + [
            self._quote(dim.id) for dim in non_limited_dimensions
        ]
        group_by_clause = ", ".join(group_by_columns)

        # Build the final SELECT columns
        select_columns = case_expressions + non_limited_selects + metric_aggregations
        select_clause = ",\n    ".join(select_columns)

        # Build the ORDER BY clause (needs to reference the aliased columns)
        # Use default temporal ordering if no explicit order is provided
        effective_order = self._get_default_order(dimensions, order)
        order_clause = self._build_order_clause(effective_order)

        query = dedent(
            f"""
            {top_groups_cte},
            {subquery}
            SELECT
                {select_clause}
            FROM raw_data
            GROUP BY {group_by_clause}
            {"ORDER BY " + order_clause if order_clause else ""}
            {"LIMIT " + str(limit) if limit is not None else ""}
            {"OFFSET " + str(offset) if offset is not None else ""}
            """
        ).strip()

        return query, cte_params

    def _build_query_with_group_limit(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        where_clause: str,
        order: list[OrderTuple] | None,
        limit: int | None,
        offset: int | None,
        group_limit: GroupLimit,
    ) -> tuple[str, tuple[FilterValues, ...]]:
        """
        Build a query with group limiting (top N groups).

        If group_others is True, groups non-top values as 'Other'.
        Otherwise, filters to show only the top N groups.

        Returns:
            Tuple of (SQL query, CTE parameters)
        """
        if group_limit.group_others:
            return self._build_query_with_others(
                metrics,
                dimensions,
                where_clause,
                order,
                limit,
                offset,
                group_limit,
            )

        # Standard group limiting: just filter to the top N groups
        # We can't use CTE references inside SEMANTIC_VIEW(), so we wrap it
        dimension_arguments = ", ".join(
            self._alias_element(dimension) for dimension in dimensions
        )
        metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)
        # Use default temporal ordering if no explicit order is provided
        effective_order = self._get_default_order(dimensions, order)
        order_clause = self._build_order_clause(effective_order)

        top_groups_cte, cte_params = self._build_top_groups_cte(
            group_limit,
            where_clause,
        )
        group_filter = self._build_group_filter(group_limit)

        query = dedent(
            f"""
            {top_groups_cte}
            SELECT * FROM SEMANTIC_VIEW(
                {self.uid()}
                {"DIMENSIONS " + dimension_arguments if dimension_arguments else ""}
                {"METRICS " + metric_arguments if metric_arguments else ""}
                {"WHERE " + where_clause if where_clause else ""}
            ) AS subquery
            WHERE {group_filter}
            {"ORDER BY " + order_clause if order_clause else ""}
            {"LIMIT " + str(limit) if limit is not None else ""}
            {"OFFSET " + str(offset) if offset is not None else ""}
            """
        ).strip()

        return query, cte_params

    __repr__ = uid
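A usage sketch tying the group-limit pieces together; every object here is illustrative. Find the top 10 products by sales over a wider window, then chart a narrower window, folding everything else into 'Other':

gl = GroupLimit(
    dimensions=[product_dim],
    top=10,
    metric=sales_metric,
    direction=OrderDirection.DESC,
    group_others=True,
    filters={last_30_days_filter},  # used only to pick the top groups
)
result = view.get_dataframe(
    [sales_metric],
    [order_date_dim, product_dim],
    filters={last_7_days_filter},
    group_limit=gl,
)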
123
superset/semantic_layers/snowflake/utils.py
Normal file
@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# ruff: noqa: S608

from __future__ import annotations

from typing import Any, Sequence

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from superset.exceptions import SupersetParseError
from superset.semantic_layers.snowflake.schemas import (
    PrivateKeyAuth,
    SnowflakeConfiguration,
    UserPasswordAuth,
)
from superset.sql.parse import SQLStatement


def substitute_parameters(query: str, parameters: Sequence[Any] | None) -> str:
    """
    Substitute parameters in a templated query.

    This is used to convert bound query parameters so that we can return the
    executed query for logging/auditing purposes. With Snowflake the binding
    happens on the server, so the only way to get the true executed query would
    be to query the database, which is inefficient.
    """
    if not parameters:
        return query

    result = query
    for parameter in parameters:
        if parameter is None:
            replacement = "NULL"
        elif isinstance(parameter, bool):
            # Check bool before int/float since bool is a subclass of int
            replacement = str(parameter).upper()
        elif isinstance(parameter, (int, float)):
            replacement = str(parameter)
        else:
            # String - escape single quotes
            quoted = str(parameter).replace("'", "''")
            replacement = f"'{quoted}'"

        result = result.replace("?", replacement, 1)

    return result


def validate_order_by(definition: str) -> None:
    """
    Validate that an ORDER BY expression is safe to use.

    Note that `definition` could contain multiple expressions separated by commas.
    """
    try:
        # this ensures that we have a single statement, preventing SQL injection
        # via a semicolon in the ORDER BY clause
        SQLStatement(f"SELECT 1 ORDER BY {definition}", "snowflake")
    except SupersetParseError as ex:
        raise ValueError("Invalid ORDER BY expression") from ex


def get_connection_parameters(configuration: SnowflakeConfiguration) -> dict[str, Any]:
    """
    Convert the configuration to connection parameters for the Snowflake connector.
    """
    params = {
        "account": configuration.account_identifier,
        "application": "Apache Superset",
        "paramstyle": "qmark",
        "insecure_mode": True,
    }

    if configuration.role:
        params["role"] = configuration.role
    if configuration.warehouse:
        params["warehouse"] = configuration.warehouse
    if configuration.database:
        params["database"] = configuration.database
    if configuration.schema_:
        params["schema"] = configuration.schema_

    auth = configuration.auth
    if isinstance(auth, UserPasswordAuth):
        params["user"] = auth.username
        params["password"] = auth.password.get_secret_value()
    elif isinstance(auth, PrivateKeyAuth):
        pem_private_key = serialization.load_pem_private_key(
            auth.private_key.get_secret_value().encode(),
            password=(
                auth.private_key_password.get_secret_value().encode()
                if auth.private_key_password
                else None
            ),
            backend=default_backend(),
        )
        params["private_key"] = pem_private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
    else:
        raise ValueError("Unsupported authentication method")

    return params
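A behavior sketch for the substitution helper; the output is for display only, never for execution:

substitute_parameters(
    "SELECT * FROM t WHERE a = ? AND b = ? AND c = ? AND d = ?",
    ["O'Brien", True, 3.5, None],
)
# -> "SELECT * FROM t WHERE a = 'O''Brien' AND b = TRUE AND c = 3.5 AND d = NULL"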
497
superset/semantic_layers/types.py
Normal file
@@ -0,0 +1,497 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import enum
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from functools import total_ordering
from typing import Any, Protocol, runtime_checkable, TypeVar

from pandas import DataFrame
from pydantic import BaseModel

__all__ = [
    "BINARY",
    "BOOLEAN",
    "DATE",
    "DATETIME",
    "DECIMAL",
    "Day",
    "Dimension",
    "Hour",
    "INTEGER",
    "INTERVAL",
    "Minute",
    "Month",
    "NUMBER",
    "OBJECT",
    "Quarter",
    "Second",
    "STRING",
    "TIME",
    "Week",
    "Year",
]

class Type:
    """
    Base class for types.
    """


class INTEGER(Type):
    """
    Represents an integer type.
    """


class NUMBER(Type):
    """
    Represents a number type.
    """


class DECIMAL(Type):
    """
    Represents a decimal type.
    """


class STRING(Type):
    """
    Represents a string type.
    """


class BOOLEAN(Type):
    """
    Represents a boolean type.
    """


class DATE(Type):
    """
    Represents a date type.
    """


class TIME(Type):
    """
    Represents a time type.
    """


class DATETIME(DATE, TIME):
    """
    Represents a datetime type.
    """


class INTERVAL(Type):
    """
    Represents an interval type.
    """


class OBJECT(Type):
    """
    Represents an object type.
    """


class BINARY(Type):
    """
    Represents a binary type.
    """


@dataclass(frozen=True)
@total_ordering
class Grain:
    """
    Base class for time and date grains with comparison support.

    Attributes:
        name: Human-readable name of the grain (e.g., "Second")
        representation: ISO 8601 representation (e.g., "PT1S")
        value: Time period as a timedelta
    """

    name: str
    representation: str
    value: timedelta

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Grain):
            return self.value == other.value
        return NotImplemented

    def __lt__(self, other: object) -> bool:
        if isinstance(other, Grain):
            return self.value < other.value
        return NotImplemented

    def __hash__(self) -> int:
        return hash((self.name, self.representation, self.value))


class Second(Grain):
    name = "Second"
    representation = "PT1S"
    value = timedelta(seconds=1)


class Minute(Grain):
    name = "Minute"
    representation = "PT1M"
    value = timedelta(minutes=1)


class Hour(Grain):
    name = "Hour"
    representation = "PT1H"
    value = timedelta(hours=1)


class Day(Grain):
    name = "Day"
    representation = "P1D"
    value = timedelta(days=1)


class Week(Grain):
    name = "Week"
    representation = "P1W"
    value = timedelta(weeks=1)


class Month(Grain):
    name = "Month"
    representation = "P1M"
    value = timedelta(days=30)


class Quarter(Grain):
    name = "Quarter"
    representation = "P3M"
    value = timedelta(days=90)


class Year(Grain):
    name = "Year"
    representation = "P1Y"
    value = timedelta(days=365)

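Since Grain instances compare by their timedelta value, grains can be ordered and sorted. A hedged sketch; note the subclasses above define class-level defaults, so this example instantiates Grain directly:

day = Grain("Day", "P1D", timedelta(days=1))
month = Grain("Month", "P1M", timedelta(days=30))
assert day < month            # __lt__ compares the underlying timedeltas
assert sorted([month, day]) == [day, month]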
@dataclass(frozen=True)
class Dimension:
    id: str
    name: str
    type: type[Type]

    definition: str | None = None
    description: str | None = None
    grain: Grain | None = None


@dataclass(frozen=True)
class Metric:
    id: str
    name: str
    type: type[Type]

    definition: str | None
    description: str | None = None


@dataclass(frozen=True)
class AdhocExpression:
    id: str
    definition: str


class Operator(str, enum.Enum):
    EQUALS = "="
    NOT_EQUALS = "!="
    GREATER_THAN = ">"
    LESS_THAN = "<"
    GREATER_THAN_OR_EQUAL = ">="
    LESS_THAN_OR_EQUAL = "<="
    IN = "IN"
    NOT_IN = "NOT IN"
    LIKE = "LIKE"
    NOT_LIKE = "NOT LIKE"
    IS_NULL = "IS NULL"
    IS_NOT_NULL = "IS NOT NULL"


FilterValues = str | int | float | bool | datetime | date | time | timedelta | None


class PredicateType(enum.Enum):
    WHERE = "WHERE"
    HAVING = "HAVING"


@dataclass(frozen=True, order=True)
class Filter:
    type: PredicateType
    column: Dimension | Metric
    operator: Operator
    value: FilterValues | set[FilterValues]


@dataclass(frozen=True, order=True)
class AdhocFilter:
    type: PredicateType
    definition: str


class OrderDirection(enum.Enum):
    ASC = "ASC"
    DESC = "DESC"


OrderTuple = tuple[Metric | Dimension | AdhocExpression, OrderDirection]

@dataclass(frozen=True)
|
||||
class GroupLimit:
|
||||
"""
|
||||
Limit query to top/bottom N combinations of specified dimensions.
|
||||
|
||||
The `filters` parameter allows specifying separate filter constraints for the
|
||||
group limit subquery. This is useful when you want to determine the top N groups
|
||||
using different criteria (e.g., a different time range) than the main query.
|
||||
|
||||
For example, you might want to find the top 10 products by sales over the last
|
||||
30 days, but then show daily sales for those products over the last 7 days.
|
||||
"""
|
||||
|
||||
dimensions: list[Dimension]
|
||||
top: int
|
||||
metric: Metric | None
|
||||
direction: OrderDirection = OrderDirection.DESC
|
||||
group_others: bool = False
|
||||
filters: set[Filter | AdhocFilter] | None = None
|
||||
|
||||
|
||||


@dataclass(frozen=True)
class SemanticRequest:
    """
    Represents a request made to obtain semantic results.

    This could be a SQL query, an HTTP request, etc.
    """

    type: str
    definition: str


@dataclass(frozen=True)
class SemanticResult:
    """
    Represents the results of a semantic query.

    This includes any requests (SQL queries, HTTP requests, etc.) that were
    performed to obtain the results, to help with troubleshooting.
    """

    requests: list[SemanticRequest]
    results: DataFrame


@dataclass(frozen=True)
class SemanticQuery:
    """
    Represents a semantic query.
    """

    metrics: list[Metric]
    dimensions: list[Dimension]
    filters: set[Filter | AdhocFilter] | None = None
    order: list[OrderTuple] | None = None
    limit: int | None = None
    offset: int | None = None
    group_limit: GroupLimit | None = None


class SemanticViewFeature(enum.Enum):
    """
    Custom features supported by semantic layers.
    """

    ADHOC_EXPRESSIONS_IN_ORDERBY = "ADHOC_EXPRESSIONS_IN_ORDERBY"
    GROUP_LIMIT = "GROUP_LIMIT"
    GROUP_OTHERS = "GROUP_OTHERS"


ConfigT = TypeVar("ConfigT", bound=BaseModel, contravariant=True)
SemanticViewT = TypeVar("SemanticViewT", bound="SemanticViewImplementation")


@runtime_checkable
class SemanticLayerImplementation(Protocol[ConfigT, SemanticViewT]):
    """
    A protocol for semantic layers.
    """

    @classmethod
    def from_configuration(
        cls,
        configuration: dict[str, Any],
    ) -> SemanticLayerImplementation[ConfigT, SemanticViewT]:
        """
        Create a semantic layer from its configuration.
        """

    @classmethod
    def get_configuration_schema(
        cls,
        configuration: ConfigT | None = None,
    ) -> dict[str, Any]:
        """
        Get the JSON schema for the configuration needed to add the semantic layer.

        A partial configuration `configuration` can be sent to improve the schema,
        allowing for progressive validation and better UX. For example, a semantic
        layer might require:

        - auth information
        - a database

        If the user provides the auth information, a client can send the partial
        configuration to this method, and the resulting JSON schema would include
        the list of databases the user has access to, allowing a dropdown to be
        populated.

        The Snowflake semantic layer has an example implementation of this method,
        where database and schema names are populated based on the provided
        connection info.
        """
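
    # Illustrative flow for the progressive validation described above. This is
    # hypothetical client code; the configuration fields are assumptions.
    #
    #     schema = layer_cls.get_configuration_schema()         # generic schema
    #     partial = LayerConfig(auth={"token": "..."})          # user fills in auth
    #     schema = layer_cls.get_configuration_schema(partial)  # schema now lists
    #                                                           # reachable databases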

    @classmethod
    def get_runtime_schema(
        cls,
        configuration: ConfigT,
        runtime_data: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Get the JSON schema for the runtime parameters needed to load semantic views.

        This returns the schema needed to connect to a semantic view given the
        configuration for the semantic layer. For example, a semantic layer might
        be configured by:

        - auth information
        - an optional database

        If the user does not provide a database when creating the semantic layer,
        the runtime schema would require the database name to be provided before
        loading any semantic views. This allows users to create semantic layers
        that connect to a specific database (or project, account, etc.), or that
        allow users to select it at query time.

        The Snowflake semantic layer has an example implementation of this method,
        where database and schema names are required if they were not provided in
        the initial configuration.
        """
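
    # Sketch of the config-time vs. query-time split described above (the names
    # are assumptions): a layer created without a database defers that choice,
    # and the runtime schema then demands it before any views can be loaded.
    #
    #     layer = layer_cls.from_configuration({"auth": {...}})  # no database yet
    #     layer_cls.get_runtime_schema(config)   # -> schema requiring "database"
    #     layer.get_semantic_views({"database": "sales_db"})  # given at query time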

    def get_semantic_views(
        self,
        runtime_configuration: dict[str, Any],
    ) -> set[SemanticViewT]:
        """
        Get the semantic views available in the semantic layer.

        The runtime configuration can provide information like a given project or
        schema, used to restrict the semantic views returned.
        """

    def get_semantic_view(
        self,
        name: str,
        additional_configuration: dict[str, Any],
    ) -> SemanticViewT:
        """
        Get a specific semantic view by its name and additional configuration.
        """


@runtime_checkable
class SemanticViewImplementation(Protocol):
    """
    A protocol for semantic views.
    """

    features: frozenset[SemanticViewFeature]

    def uid(self) -> str:
        """
        Return a unique identifier for the semantic view.
        """

    def get_dimensions(self) -> set[Dimension]:
        """
        Get the dimensions defined in the semantic view.
        """

    def get_metrics(self) -> set[Metric]:
        """
        Get the metrics defined in the semantic view.
        """

    def get_values(
        self,
        dimension: Dimension,
        filters: set[Filter | AdhocFilter] | None = None,
    ) -> SemanticResult:
        """
        Return distinct values for a dimension.
        """

    def get_dataframe(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        filters: set[Filter | AdhocFilter] | None = None,
        order: list[OrderTuple] | None = None,
        limit: int | None = None,
        offset: int | None = None,
        *,
        group_limit: GroupLimit | None = None,
    ) -> SemanticResult:
        """
        Execute a semantic query and return the results as a DataFrame.
        """

    def get_row_count(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        filters: set[Filter | AdhocFilter] | None = None,
        order: list[OrderTuple] | None = None,
        limit: int | None = None,
        offset: int | None = None,
        *,
        group_limit: GroupLimit | None = None,
    ) -> SemanticResult:
        """
        Execute a query and return the number of rows the result would have.
        """
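
Taken together, the two protocols suggest a client flow along these lines. This is
a minimal sketch, not code from the module: `MySemanticLayer`, the configuration
keys, and the metric/dimension ids are all assumptions.

layer = MySemanticLayer.from_configuration({"auth": {"token": "..."}})

views = layer.get_semantic_views(runtime_configuration={"database": "analytics"})
view = next(iter(views))

metrics = {metric.id: metric for metric in view.get_metrics()}
dimensions = {dimension.id: dimension for dimension in view.get_dimensions()}

result = view.get_dataframe(
    metrics=[metrics["sales"]],
    dimensions=[dimensions["product"]],
    order=[(metrics["sales"], OrderDirection.DESC)],
    limit=10,
)
print(result.results)  # the pandas DataFrame with the query results
for request in result.requests:
    print(request.type, request.definition)  # e.g. the SQL issued to the backend
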
@@ -57,6 +57,46 @@ class AdhocMetric(TypedDict, total=False):
     sqlExpression: str | None


+class DatasetColumnData(TypedDict, total=False):
+    """Type for column metadata in ExplorableData datasets."""
+
+    advanced_data_type: str | None
+    certification_details: str | None
+    certified_by: str | None
+    column_name: str
+    description: str | None
+    expression: str | None
+    filterable: bool
+    groupby: bool
+    id: int
+    uuid: str | None
+    is_certified: bool
+    is_dttm: bool
+    python_date_format: str | None
+    type: str
+    type_generic: NotRequired["GenericDataType" | None]
+    verbose_name: str | None
+    warning_markdown: str | None
+
+
+class DatasetMetricData(TypedDict, total=False):
+    """Type for metric metadata in ExplorableData datasets."""
+
+    certification_details: str | None
+    certified_by: str | None
+    currency: NotRequired[dict[str, Any]]
+    d3format: str | None
+    description: str | None
+    expression: str
+    id: int
+    uuid: str | None
+    is_certified: bool
+    metric_name: str
+    warning_markdown: str | None
+    warning_text: str | None
+    verbose_name: str | None
+
+
 class AdhocColumn(TypedDict, total=False):
     hasCustomLabel: bool | None
     label: str
@@ -195,15 +235,19 @@ class QueryObjectDict(TypedDict, total=False):
     timeseries_limit_metric: Metric | None


-class BaseDatasourceData(TypedDict, total=False):
+class ExplorableData(TypedDict, total=False):
     """
-    TypedDict for datasource data returned to the frontend.
+    TypedDict for explorable data returned to the frontend.

-    This represents the structure of the dictionary returned from BaseDatasource.data
-    property. It provides datasource information to the frontend for visualization
-    and querying.
+    This represents the structure of the dictionary returned from the `data` property
+    of any Explorable (BaseDatasource, Query, etc.). It provides datasource/query
+    information to the frontend for visualization and querying.

-    Core fields from BaseDatasource.data:
+    All fields are optional (total=False) since different explorable types provide
+    different subsets of these fields. Query objects provide a minimal subset while
+    SqlaTable provides the full set.
+
+    Core fields:
         id: Unique identifier for the datasource
         uid: Unique identifier including type (e.g., "1__table")
         column_formats: D3 format strings for columns
@@ -268,8 +312,8 @@ class BaseDatasourceData(TypedDict, total=False):
     perm: str | None
     edit_url: str
     sql: str | None
-    columns: list[dict[str, Any]]
-    metrics: list[dict[str, Any]]
+    columns: list["DatasetColumnData"]
+    metrics: list["DatasetMetricData"]
     folders: Any  # JSON field, can be list or dict
     order_by_choices: list[tuple[str, str]]
     owners: list[int] | list[dict[str, Any]]  # Can be either format
@@ -277,8 +321,8 @@ class BaseDatasourceData(TypedDict, total=False):
     select_star: str | None

     # Additional fields from SqlaTable and data_for_slices
-    column_types: list[Any]
-    column_names: set[str] | set[Any]
+    column_types: list["GenericDataType"]
+    column_names: set[str]
     granularity_sqla: list[tuple[Any, Any]]
     time_grain_sqla: list[tuple[Any, Any]]
     main_dttm_col: str | None
@@ -291,46 +335,6 @@ class BaseDatasourceData(TypedDict, total=False):
     normalize_columns: bool


-class QueryData(TypedDict, total=False):
-    """
-    TypedDict for SQL Lab query data returned to the frontend.
-
-    This represents the structure of the dictionary returned from Query.data property
-    in SQL Lab. It provides query information to the frontend for execution and display.
-
-    Fields:
-        time_grain_sqla: Available time grains for this database
-        filter_select: Whether filter select is enabled
-        name: Query tab name
-        columns: List of column definitions
-        metrics: List of metrics (always empty for queries)
-        id: Query ID
-        type: Object type (always "query")
-        sql: SQL query text
-        owners: List of owner information
-        database: Database connection details
-        order_by_choices: Available ordering options
-        catalog: Catalog name if applicable
-        schema: Schema name if applicable
-        verbose_map: Mapping of column names to verbose names (empty for queries)
-    """
-
-    time_grain_sqla: list[tuple[Any, Any]]
-    filter_select: bool
-    name: str | None
-    columns: list[dict[str, Any]]
-    metrics: list[Any]
-    id: int
-    type: str
-    sql: str | None
-    owners: list[dict[str, Any]]
-    database: dict[str, Any]
-    order_by_choices: list[tuple[str, str]]
-    catalog: str | None
-    schema: str | None
-    verbose_map: dict[str, str]
-
-
 VizData: TypeAlias = list[Any] | dict[Any, Any] | None
 VizPayload: TypeAlias = dict[str, Any]
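
Because `ExplorableData` is declared with `total=False`, a Query-backed explorable
can return only the keys it actually has, while a dataset fills in many more. A
minimal sketch (the field values are made up):

query_payload: ExplorableData = {
    "id": 42,
    "type": "query",
    "columns": [],
    "metrics": [],
}

dataset_payload: ExplorableData = {
    "id": 7,
    "uid": "7__table",
    "type": "table",
    "columns": [{"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True}],
}
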
@@ -61,6 +61,7 @@ def _adjust_string_with_rls(
     """
     Add the RLS filters to the unique string based on current executor.
     """
+
     user = (
         security_manager.find_user(executor)
         or security_manager.get_current_guest_user_if_guest()
@@ -70,11 +71,7 @@ def _adjust_string_with_rls(
     stringified_rls = ""
     with override_user(user):
         for datasource in datasources:
-            if (
-                datasource
-                and hasattr(datasource, "is_rls_supported")
-                and datasource.is_rls_supported
-            ):
+            if datasource and getattr(datasource, "is_rls_supported", False):
                 rls_filters = datasource.get_sqla_row_level_filters()

                 if len(rls_filters) > 0:
|
||||
SupersetException,
|
||||
SupersetTimeoutException,
|
||||
)
|
||||
from superset.explorables.base import Explorable
|
||||
from superset.sql.parse import sanitize_clause
|
||||
from superset.superset_typing import (
|
||||
AdhocColumn,
|
||||
@@ -114,9 +115,8 @@ from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str
|
||||
from superset.utils.pandas import detect_datetime_format
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import BaseDatasource, TableColumn
|
||||
from superset.connectors.sqla.models import TableColumn
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
logging.getLogger("MARKDOWN").setLevel(logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -200,6 +200,7 @@ class DatasourceType(StrEnum):
     QUERY = "query"
     SAVEDQUERY = "saved_query"
     VIEW = "view"
+    SEMANTIC_VIEW = "semantic_view"


 class LoggerLevel(StrEnum):
@@ -1656,7 +1657,7 @@ def map_sql_type_to_inferred_type(sql_type: Optional[str]) -> str:
     return "string"  # If no match is found, return "string" as default


-def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query) -> str:
+def get_metric_type_from_column(column: Any, datasource: Explorable) -> str:
     """
     Determine the metric type from a given column in a datasource.

@@ -1698,7 +1699,7 @@ def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query)

 def extract_dataframe_dtypes(
     df: pd.DataFrame,
-    datasource: BaseDatasource | Query | None = None,
+    datasource: Explorable | None = None,
 ) -> list[GenericDataType]:
     """Serialize pandas/numpy dtypes to generic types"""

@@ -1718,7 +1719,8 @@ def extract_dataframe_dtypes(
     if datasource:
         for column in datasource.columns:
             if isinstance(column, dict):
-                columns_by_name[column.get("column_name")] = column
+                if column_name := column.get("column_name"):
+                    columns_by_name[column_name] = column
             else:
                 columns_by_name[column.column_name] = column

@@ -1768,11 +1770,13 @@ def is_test() -> bool:


 def get_time_filter_status(
-    datasource: BaseDatasource,
+    datasource: Explorable,
     applied_time_extras: dict[str, str],
 ) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
     temporal_columns: set[Any] = {
-        col.column_name for col in datasource.columns if col.is_dttm
+        (col.column_name if hasattr(col, "column_name") else col.get("column_name"))
+        for col in datasource.columns
+        if (col.is_dttm if hasattr(col, "is_dttm") else col.get("is_dttm"))
     }
     applied: list[dict[str, str]] = []
     rejected: list[dict[str, str]] = []
@@ -78,9 +78,8 @@ from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from superset.models.user_attributes import UserAttribute
 from superset.superset_typing import (
-    BaseDatasourceData,
+    ExplorableData,
     FlaskResponse,
-    QueryData,
 )
 from superset.tasks.utils import get_current_user
 from superset.utils import core as utils, json
@@ -531,14 +530,14 @@ class Superset(BaseSupersetView):
         )
         standalone_mode = ReservedUrlParameters.is_standalone_mode()
         force = request.args.get("force") in {"force", "1", "true"}
-        dummy_datasource_data: BaseDatasourceData = {
+        dummy_datasource_data: ExplorableData = {
             "type": datasource_type or "unknown",
             "name": datasource_name,
             "columns": [],
             "metrics": [],
             "database": {"id": 0, "backend": ""},
         }
-        datasource_data: BaseDatasourceData | QueryData
+        datasource_data: ExplorableData
         try:
             datasource_data = datasource.data if datasource else dummy_datasource_data
         except (SupersetException, SQLAlchemyError):
@@ -20,6 +20,7 @@ from collections import defaultdict
 from functools import wraps
 from typing import Any, Callable, DefaultDict, Optional, Union
 from urllib import parse
+from uuid import UUID

 import msgpack
 import pyarrow as pa
@@ -46,10 +47,9 @@ from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from superset.models.sql_lab import Query
 from superset.superset_typing import (
-    BaseDatasourceData,
+    ExplorableData,
     FlaskResponse,
     FormData,
-    QueryData,
 )
 from superset.utils import json
 from superset.utils.core import DatasourceType
@@ -92,12 +92,10 @@ def redirect_to_login(next_target: str | None = None) -> FlaskResponse:


 def sanitize_datasource_data(
-    datasource_data: BaseDatasourceData | QueryData,
+    datasource_data: ExplorableData,
 ) -> dict[str, Any]:
     """
     Sanitize datasource data by removing sensitive database parameters.
-
-    Accepts TypedDict types (BaseDatasourceData, QueryData).
     """
     if datasource_data:
         datasource_database = datasource_data.get("database")
@@ -275,8 +273,10 @@ def add_sqllab_custom_filters(form_data: dict[Any, Any]) -> Any:


 def get_datasource_info(
-    datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData
-) -> tuple[int, Optional[str]]:
+    datasource_id: Optional[int],
+    datasource_type: Optional[str],
+    form_data: FormData,
+) -> tuple[int | UUID, Optional[str]]:
     """
     Compatibility layer for handling of datasource info

@@ -303,7 +303,11 @@ def get_datasource_info(
             _("The dataset associated with this chart no longer exists")
         )

-    datasource_id = int(datasource_id)
+    if datasource_id.isdigit():
+        datasource_id = int(datasource_id)
+    else:
+        datasource_id = UUID(datasource_id)

     return datasource_id, datasource_type
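
With this branch, legacy numeric datasource ids and the UUID-style ids used by
newer explorables flow through the same code path. A small, self-contained
illustration (the id values are made up):

from uuid import UUID

for raw_id in ("42", "3f1d2c8e-5a6b-4c7d-8e9f-0a1b2c3d4e5f"):
    parsed = int(raw_id) if raw_id.isdigit() else UUID(raw_id)
    print(type(parsed).__name__, parsed)  # int 42, then UUID 3f1d2c8e-...
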
@@ -39,32 +39,26 @@ processor = QueryContextProcessor(
 )

 # Bind ExploreMixin methods to datasource for testing
-processor._qc_datasource.add_offset_join_column = (
-    ExploreMixin.add_offset_join_column.__get__(processor._qc_datasource)
+# Type annotation needed because _qc_datasource is typed as Explorable in protocol
+_datasource: BaseDatasource = processor._qc_datasource  # type: ignore
+_datasource.add_offset_join_column = ExploreMixin.add_offset_join_column.__get__(
+    _datasource
 )
-processor._qc_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__(
-    processor._qc_datasource
+_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__(_datasource)
+_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__(_datasource)
+_datasource._determine_join_keys = ExploreMixin._determine_join_keys.__get__(
+    _datasource
 )
-processor._qc_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__(
-    processor._qc_datasource
-)
-processor._qc_datasource._determine_join_keys = (
-    ExploreMixin._determine_join_keys.__get__(processor._qc_datasource)
-)
-processor._qc_datasource._perform_join = ExploreMixin._perform_join.__get__(
-    processor._qc_datasource
-)
-processor._qc_datasource._apply_cleanup_logic = (
-    ExploreMixin._apply_cleanup_logic.__get__(processor._qc_datasource)
+_datasource._perform_join = ExploreMixin._perform_join.__get__(_datasource)
+_datasource._apply_cleanup_logic = ExploreMixin._apply_cleanup_logic.__get__(
+    _datasource
 )
 # Static methods don't need binding - assign directly
-processor._qc_datasource.generate_join_column = ExploreMixin.generate_join_column
-processor._qc_datasource.is_valid_date_range_static = (
-    ExploreMixin.is_valid_date_range_static
-)
+_datasource.generate_join_column = ExploreMixin.generate_join_column
+_datasource.is_valid_date_range_static = ExploreMixin.is_valid_date_range_static

 # Convenience reference for backward compatibility in tests
-query_context_processor = processor._qc_datasource
+query_context_processor = _datasource


 @fixture
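
The `__get__` calls in this setup rely on Python's descriptor protocol: calling
`function.__get__(instance)` returns a method bound to that instance, which is why
the static methods are assigned directly instead. A standalone illustration of the
idiom:

class Greeter:
    pass


def greet(self) -> str:
    return f"hello from {type(self).__name__}"


greeter = Greeter()
greeter.greet = greet.__get__(greeter)  # bind the plain function to this instance
print(greeter.greet())  # "hello from Greeter"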