From 8a0aaa42ec67aa5237f423c6b65f790df75c45c4 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 4 Dec 2025 10:37:57 -0500 Subject: [PATCH] feat: semantic layer implementation (Snowflake) --- superset/daos/semantic_layer.py | 152 +++ superset/explorables/base.py | 128 ++- ...7e0e21daa_add_semantic_layers_and_views.py | 124 +++ superset/semantic_layers/__init__.py | 16 + superset/semantic_layers/mapper.py | 869 ++++++++++++++++++ superset/semantic_layers/models.py | 378 ++++++++ .../semantic_layers/snowflake/__init__.py | 26 + superset/semantic_layers/snowflake/schemas.py | 130 +++ .../snowflake/semantic_layer.py | 236 +++++ .../snowflake/semantic_view.py | 817 ++++++++++++++++ superset/semantic_layers/snowflake/utils.py | 123 +++ superset/semantic_layers/types.py | 497 ++++++++++ superset/superset_typing.py | 48 +- 13 files changed, 3538 insertions(+), 6 deletions(-) create mode 100644 superset/daos/semantic_layer.py create mode 100644 superset/migrations/versions/2025-11-04_11-26_33d7e0e21daa_add_semantic_layers_and_views.py create mode 100644 superset/semantic_layers/__init__.py create mode 100644 superset/semantic_layers/mapper.py create mode 100644 superset/semantic_layers/models.py create mode 100644 superset/semantic_layers/snowflake/__init__.py create mode 100644 superset/semantic_layers/snowflake/schemas.py create mode 100644 superset/semantic_layers/snowflake/semantic_layer.py create mode 100644 superset/semantic_layers/snowflake/semantic_view.py create mode 100644 superset/semantic_layers/snowflake/utils.py create mode 100644 superset/semantic_layers/types.py diff --git a/superset/daos/semantic_layer.py b/superset/daos/semantic_layer.py new file mode 100644 index 00000000000..9c591e4a7a4 --- /dev/null +++ b/superset/daos/semantic_layer.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""DAOs for semantic layer models.""" + +from __future__ import annotations + +from superset.daos.base import BaseDAO +from superset.extensions import db +from superset.semantic_layers.models import SemanticLayer, SemanticView + + +class SemanticLayerDAO(BaseDAO[SemanticLayer]): + """ + Data Access Object for SemanticLayer model. + """ + + @staticmethod + def validate_uniqueness(name: str) -> bool: + """ + Validate that semantic layer name is unique. + + :param name: Semantic layer name + :return: True if name is unique, False otherwise + """ + query = db.session.query(SemanticLayer).filter(SemanticLayer.name == name) + return not db.session.query(query.exists()).scalar() + + @staticmethod + def validate_update_uniqueness(layer_uuid: str, name: str) -> bool: + """ + Validate that semantic layer name is unique for updates. + + :param layer_uuid: UUID of the semantic layer being updated + :param name: New name to validate + :return: True if name is unique, False otherwise + """ + query = db.session.query(SemanticLayer).filter( + SemanticLayer.name == name, + SemanticLayer.uuid != layer_uuid, + ) + return not db.session.query(query.exists()).scalar() + + @staticmethod + def find_by_name(name: str) -> SemanticLayer | None: + """ + Find semantic layer by name. 
+ + :param name: Semantic layer name + :return: SemanticLayer instance or None + """ + return ( + db.session.query(SemanticLayer) + .filter(SemanticLayer.name == name) + .one_or_none() + ) + + @classmethod + def get_semantic_views(cls, layer_uuid: str) -> list[SemanticView]: + """ + Get all semantic views for a semantic layer. + + :param layer_uuid: UUID of the semantic layer + :return: List of SemanticView instances + """ + return ( + db.session.query(SemanticView) + .filter(SemanticView.semantic_layer_uuid == layer_uuid) + .all() + ) + + +class SemanticViewDAO(BaseDAO[SemanticView]): + """Data Access Object for SemanticView model.""" + + @staticmethod + def find_by_semantic_layer(layer_uuid: str) -> list[SemanticView]: + """ + Find all views for a semantic layer. + + :param layer_uuid: UUID of the semantic layer + :return: List of SemanticView instances + """ + return ( + db.session.query(SemanticView) + .filter(SemanticView.semantic_layer_uuid == layer_uuid) + .all() + ) + + @staticmethod + def validate_uniqueness(name: str, layer_uuid: str) -> bool: + """ + Validate that view name is unique within semantic layer. + + :param name: View name + :param layer_uuid: UUID of the semantic layer + :return: True if name is unique within layer, False otherwise + """ + query = db.session.query(SemanticView).filter( + SemanticView.name == name, + SemanticView.semantic_layer_uuid == layer_uuid, + ) + return not db.session.query(query.exists()).scalar() + + @staticmethod + def validate_update_uniqueness(view_uuid: str, name: str, layer_uuid: str) -> bool: + """ + Validate that view name is unique within semantic layer for updates. 
+ + :param view_uuid: UUID of the view being updated + :param name: New name to validate + :param layer_uuid: UUID of the semantic layer + :return: True if name is unique within layer, False otherwise + """ + query = db.session.query(SemanticView).filter( + SemanticView.name == name, + SemanticView.semantic_layer_uuid == layer_uuid, + SemanticView.uuid != view_uuid, + ) + return not db.session.query(query.exists()).scalar() + + @staticmethod + def find_by_name(name: str, layer_uuid: str) -> SemanticView | None: + """ + Find semantic view by name within a semantic layer. + + :param name: View name + :param layer_uuid: UUID of the semantic layer + :return: SemanticView instance or None + """ + return ( + db.session.query(SemanticView) + .filter( + SemanticView.name == name, + SemanticView.semantic_layer_uuid == layer_uuid, + ) + .one_or_none() + ) diff --git a/superset/explorables/base.py b/superset/explorables/base.py index 2d534b72099..de69257a317 100644 --- a/superset/explorables/base.py +++ b/superset/explorables/base.py @@ -53,6 +53,130 @@ class TimeGrainDict(TypedDict): duration: str | None +@runtime_checkable +class MetricMetadata(Protocol): + """ + Protocol for metric metadata objects. + + Represents a metric that's available on an explorable data source. + Metrics contain SQL expressions or references to semantic layer measures. 
+ + Attributes: + metric_name: Unique identifier for the metric + expression: SQL expression or reference for calculating the metric + verbose_name: Human-readable name for display in the UI + description: Description of what the metric represents + d3format: D3 format string for formatting numeric values + currency: Currency configuration for the metric (JSON object) + warning_text: Warning message to display when using this metric + certified_by: Person or entity that certified this metric + certification_details: Details about the certification + """ + + @property + def metric_name(self) -> str: + """Unique identifier for the metric.""" + + @property + def expression(self) -> str: + """SQL expression or reference for calculating the metric.""" + + @property + def verbose_name(self) -> str | None: + """Human-readable name for display in the UI.""" + + @property + def description(self) -> str | None: + """Description of what the metric represents.""" + + @property + def d3format(self) -> str | None: + """D3 format string for formatting numeric values.""" + + @property + def currency(self) -> dict[str, Any] | None: + """Currency configuration for the metric (JSON object).""" + + @property + def warning_text(self) -> str | None: + """Warning message to display when using this metric.""" + + @property + def certified_by(self) -> str | None: + """Person or entity that certified this metric.""" + + @property + def certification_details(self) -> str | None: + """Details about the certification.""" + + +@runtime_checkable +class ColumnMetadata(Protocol): + """ + Protocol for column metadata objects. + + Represents a column/dimension that's available on an explorable data source. + Used for grouping, filtering, and dimension-based analysis. 
+ + Attributes: + column_name: Unique identifier for the column + type: SQL data type of the column (e.g., 'VARCHAR', 'INTEGER', 'DATETIME') + is_dttm: Whether this column represents a date or time value + verbose_name: Human-readable name for display in the UI + description: Description of what the column represents + groupby: Whether this column is allowed for grouping/aggregation + filterable: Whether this column can be used in filters + expression: SQL expression if this is a calculated column + python_date_format: Python datetime format string for temporal columns + advanced_data_type: Advanced data type classification + extra: Additional metadata stored as JSON + """ + + @property + def column_name(self) -> str: + """Unique identifier for the column.""" + + @property + def type(self) -> str: + """SQL data type of the column.""" + + @property + def is_dttm(self) -> bool: + """Whether this column represents a date or time value.""" + + @property + def verbose_name(self) -> str | None: + """Human-readable name for display in the UI.""" + + @property + def description(self) -> str | None: + """Description of what the column represents.""" + + @property + def groupby(self) -> bool: + """Whether this column is allowed for grouping/aggregation.""" + + @property + def filterable(self) -> bool: + """Whether this column can be used in filters.""" + + @property + def expression(self) -> str | None: + """SQL expression if this is a calculated column.""" + + @property + def python_date_format(self) -> str | None: + """Python datetime format string for temporal columns.""" + + @property + def advanced_data_type(self) -> str | None: + """Advanced data type classification.""" + + @property + def extra(self) -> str | None: + """Additional metadata stored as JSON.""" + + @runtime_checkable class Explorable(Protocol): """ @@ -132,7 +256,7 @@ class Explorable(Protocol): """ @property - def metrics(self) -> list[Any]: + def metrics(self) -> list[MetricMetadata]: """ List of 
metric metadata objects. @@ -147,7 +271,7 @@ class Explorable(Protocol): # TODO: rename to dimensions @property - def columns(self) -> list[Any]: + def columns(self) -> list[ColumnMetadata]: """ List of column metadata objects. diff --git a/superset/migrations/versions/2025-11-04_11-26_33d7e0e21daa_add_semantic_layers_and_views.py b/superset/migrations/versions/2025-11-04_11-26_33d7e0e21daa_add_semantic_layers_and_views.py new file mode 100644 index 00000000000..6cac30f3fdc --- /dev/null +++ b/superset/migrations/versions/2025-11-04_11-26_33d7e0e21daa_add_semantic_layers_and_views.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add_semantic_layers_and_views + +Revision ID: 33d7e0e21daa +Revises: x2s8ocx6rto6 +Create Date: 2025-11-04 11:26:00.000000 + +""" + +import uuid + +import sqlalchemy as sa +from sqlalchemy_utils import UUIDType + +from superset.extensions import encrypted_field_factory +from superset.migrations.shared.utils import ( + create_fks_for_table, + create_table, + drop_table, +) + +# revision identifiers, used by Alembic. 
+revision = "33d7e0e21daa" +down_revision = "x2s8ocx6rto6" + + +def upgrade(): + # Create semantic_layers table + create_table( + "semantic_layers", + sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False), + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("name", sa.String(length=250), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("type", sa.String(length=250), nullable=False), + sa.Column( + "configuration", + encrypted_field_factory.create(sa.Text), + nullable=True, + ), + sa.Column("cache_timeout", sa.Integer(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("uuid"), + ) + + # Create foreign key constraints for semantic_layers + create_fks_for_table( + "fk_semantic_layers_created_by_fk_ab_user", + "semantic_layers", + "ab_user", + ["created_by_fk"], + ["id"], + ) + + create_fks_for_table( + "fk_semantic_layers_changed_by_fk_ab_user", + "semantic_layers", + "ab_user", + ["changed_by_fk"], + ["id"], + ) + + # Create semantic_views table + create_table( + "semantic_views", + sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False), + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("name", sa.String(length=250), nullable=False), + sa.Column( + "configuration", + encrypted_field_factory.create(sa.Text), + nullable=True, + ), + sa.Column("cache_timeout", sa.Integer(), nullable=True), + sa.Column( + "semantic_layer_uuid", + UUIDType(binary=True), + sa.ForeignKey("semantic_layers.uuid", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("uuid"), + ) + + # Create foreign key constraints for 
semantic_views + create_fks_for_table( + "fk_semantic_views_created_by_fk_ab_user", + "semantic_views", + "ab_user", + ["created_by_fk"], + ["id"], + ) + + create_fks_for_table( + "fk_semantic_views_changed_by_fk_ab_user", + "semantic_views", + "ab_user", + ["changed_by_fk"], + ["id"], + ) + + +def downgrade(): + drop_table("semantic_views") + drop_table("semantic_layers") diff --git a/superset/semantic_layers/__init__.py b/superset/semantic_layers/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/superset/semantic_layers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/semantic_layers/mapper.py b/superset/semantic_layers/mapper.py new file mode 100644 index 00000000000..a9ee68f8f81 --- /dev/null +++ b/superset/semantic_layers/mapper.py @@ -0,0 +1,869 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Functions for mapping `QueryObject` to semantic layers. + +These functions validate and convert a `QueryObject` into one or more `SemanticQuery`, +which are then passed to semantic layer implementations for execution, returning a +single dataframe. + +""" + +from datetime import datetime, timedelta +from time import time +from typing import Any, cast, Sequence, TypeGuard + +import numpy as np + +from superset.common.db_query_status import QueryStatus +from superset.common.query_object import QueryObject +from superset.common.utils.time_range_utils import get_since_until_from_query_object +from superset.connectors.sqla.models import BaseDatasource +from superset.models.helpers import QueryResult +from superset.semantic_layers.types import ( + AdhocExpression, + AdhocFilter, + DateGrain, + Dimension, + Filter, + FilterValues, + GroupLimit, + Metric, + Operator, + OrderDirection, + OrderTuple, + PredicateType, + SemanticQuery, + SemanticResult, + SemanticViewFeature, + TimeGrain, +) +from superset.utils.core import ( + FilterOperator, + QueryObjectFilterClause, + TIME_COMPARISON, +) +from superset.utils.date_parser import get_past_or_future + + +class ValidatedQueryObjectFilterClause(QueryObjectFilterClause): + """ + A validated QueryObject filter clause with a string column name. + + The `col` in a `QueryObjectFilterClause` can be either a string (column name) or an + adhoc column, but we only support the former in semantic layers. 
+ """ + + # overwrite to narrow type; mypy complains about more restrictive typed dicts, + # but the alternative would be to redefine the object + col: str # type: ignore[misc] + op: str # type: ignore[misc] + + +class ValidatedQueryObject(QueryObject): + """ + A query object that has a datasource defined. + """ + + datasource: BaseDatasource + + # overwrite to narrow type; mypy complains about the assignment since the base type + # allows adhoc filters, but we only support validated filters here + filter: list[ValidatedQueryObjectFilterClause] # type: ignore[assignment] + series_columns: Sequence[str] # type: ignore[assignment] + series_limit_metric: str | None + + +def get_results(query_object: QueryObject) -> QueryResult: + """ + Run 1+ queries based on `QueryObject` and return the results. + + :param query_object: The QueryObject containing query specifications + :return: QueryResult compatible with Superset's query interface + """ + if not validate_query_object(query_object): + raise ValueError("QueryObject must have a datasource defined.") + + # Track execution time + start_time = time() + + semantic_view = query_object.datasource.implementation + dispatcher = ( + semantic_view.get_row_count + if query_object.is_rowcount + else semantic_view.get_dataframe + ) + + # Step 1: Convert QueryObject to list of SemanticQuery objects + # The first query is the main query, subsequent queries are for time offsets + queries = map_query_object(query_object) + + # Step 2: Execute the main query (first in the list) + main_query = queries[0] + main_result = dispatcher( + metrics=main_query.metrics, + dimensions=main_query.dimensions, + filters=main_query.filters, + order=main_query.order, + limit=main_query.limit, + offset=main_query.offset, + group_limit=main_query.group_limit, + ) + + main_df = main_result.results + + # Collect all requests (SQL queries, HTTP requests, etc.) 
for troubleshooting + all_requests = list(main_result.requests) + + # If no time offsets, return the main result as-is + if not query_object.time_offsets or len(queries) <= 1: + semantic_result = SemanticResult( + requests=all_requests, + results=main_df, + ) + duration = timedelta(seconds=time() - start_time) + return map_semantic_result_to_query_result( + semantic_result, + query_object, + duration, + ) + + # Get metric names from the main query + # These are the columns that will be renamed with offset suffixes + metric_names = [metric.name for metric in main_query.metrics] + + # Join keys are all columns except metrics + # These will be used to match rows between main and offset DataFrames + join_keys = [col for col in main_df.columns if col not in metric_names] + + # Step 3 & 4: Execute each time offset query and join results + for offset_query, time_offset in zip( + queries[1:], + query_object.time_offsets, + strict=False, + ): + # Execute the offset query + result = dispatcher( + metrics=offset_query.metrics, + dimensions=offset_query.dimensions, + filters=offset_query.filters, + order=offset_query.order, + limit=offset_query.limit, + offset=offset_query.offset, + group_limit=offset_query.group_limit, + ) + + # Add this query's requests to the collection + all_requests.extend(result.requests) + + offset_df = result.results + + # Handle empty results - add NaN columns directly instead of merging + # This avoids dtype mismatch issues with empty DataFrames + if offset_df.empty: + # Add offset metric columns with NaN values directly to main_df + for metric in metric_names: + offset_col_name = TIME_COMPARISON.join([metric, time_offset]) + main_df[offset_col_name] = np.nan + else: + # Rename metric columns with time offset suffix + # Format: "{metric_name}__{time_offset}" + # Example: "revenue" -> "revenue__1 week ago" + offset_df = offset_df.rename( + columns={ + metric: TIME_COMPARISON.join([metric, time_offset]) + for metric in metric_names + } + ) + + # Step 
5: Perform left join on dimension columns + # This preserves all rows from main_df and adds offset metrics + # where they match + main_df = main_df.merge( + offset_df, + on=join_keys, + how="left", + suffixes=("", "__duplicate"), + ) + + # Clean up any duplicate columns that might have been created + # (shouldn't happen with proper join keys, but defensive programming) + duplicate_cols = [ + col for col in main_df.columns if col.endswith("__duplicate") + ] + if duplicate_cols: + main_df = main_df.drop(columns=duplicate_cols) + + # Convert final result to QueryResult + semantic_result = SemanticResult(requests=all_requests, results=main_df) + duration = timedelta(seconds=time() - start_time) + return map_semantic_result_to_query_result( + semantic_result, + query_object, + duration, + ) + + +def map_semantic_result_to_query_result( + semantic_result: SemanticResult, + query_object: ValidatedQueryObject, + duration: timedelta, +) -> QueryResult: + """ + Convert a SemanticResult to a QueryResult. 
+ + :param semantic_result: Result from the semantic layer + :param query_object: Original QueryObject (for passthrough attributes) + :param duration: Time taken to execute the query + :return: QueryResult compatible with Superset's query interface + """ + # Get the query string from requests (typically one or more SQL queries) + query_str = "" + if semantic_result.requests: + # Join all requests for display (could be multiple for time comparisons) + query_str = "\n\n".join( + f"-- {req.type}\n{req.definition}" for req in semantic_result.requests + ) + + return QueryResult( + # Core data + df=semantic_result.results, + query=query_str, + duration=duration, + # Template filters - not applicable to semantic layers + # (semantic layers don't use Jinja templates) + applied_template_filters=None, + # Filter columns - not applicable to semantic layers + # (semantic layers handle filter validation internally) + applied_filter_columns=None, + rejected_filter_columns=None, + # Status - always success if we got here + # (errors would raise exceptions before reaching this point) + status=QueryStatus.SUCCESS, + error_message=None, + errors=None, + # Time range - pass through from original query_object + from_dttm=query_object.from_dttm, + to_dttm=query_object.to_dttm, + ) + + +def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]: + """ + Convert a `QueryObject` into a list of `SemanticQuery`. + + This function maps the `QueryObject` into query objects that focus less on + visualization and more on semantics. 
+ """ + semantic_view = query_object.datasource.implementation + + all_metrics = {metric.name: metric for metric in semantic_view.metrics} + all_dimensions = { + dimension.name: dimension for dimension in semantic_view.dimensions + } + + metrics = [all_metrics[metric] for metric in (query_object.metrics or [])] + + grain = ( + _convert_time_grain(query_object.extras["time_grain_sqla"]) + if "time_grain_sqla" in query_object.extras + else None + ) + dimensions = [ + dimension + for dimension in semantic_view.dimensions + if dimension.name in query_object.columns + and ( + # if a grain is specified, only include the time dimension if its grain + # matches the requested grain + grain is None + or dimension.name != query_object.granularity + or dimension.grain == grain + ) + ] + + order = _get_order_from_query_object(query_object, all_metrics, all_dimensions) + limit = query_object.row_limit + offset = query_object.row_offset + + group_limit = _get_group_limit_from_query_object( + query_object, + all_metrics, + all_dimensions, + ) + + queries = [] + for time_offset in [None] + query_object.time_offsets: + filters = _get_filters_from_query_object( + query_object, + time_offset, + all_dimensions, + ) + + queries.append( + SemanticQuery( + metrics=metrics, + dimensions=dimensions, + filters=filters, + order=order, + limit=limit, + offset=offset, + group_limit=group_limit, + ) + ) + + return queries + + +def _get_filters_from_query_object( + query_object: ValidatedQueryObject, + time_offset: str | None, + all_dimensions: dict[str, Dimension], +) -> set[Filter | AdhocFilter]: + """ + Extract all filters from the query object, including time range filters. + + This simplifies the complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm + by converting all time constraints into filters. + """ + filters: set[Filter | AdhocFilter] = set() + + # 1. 
Add fetch values predicate if present + if ( + query_object.apply_fetch_values_predicate + and query_object.datasource.fetch_values_predicate + ): + filters.add( + AdhocFilter( + type=PredicateType.WHERE, + definition=query_object.datasource.fetch_values_predicate, + ) + ) + + # 2. Add time range filter based on from_dttm/to_dttm + # For time offsets, this automatically calculates the shifted bounds + time_filters = _get_time_filter(query_object, time_offset, all_dimensions) + filters.update(time_filters) + + # 3. Add filters from query_object.extras (WHERE and HAVING clauses) + extras_filters = _get_filters_from_extras(query_object.extras) + filters.update(extras_filters) + + # 4. Add all other filters from query_object.filter + for filter_ in query_object.filter: + converted_filter = _convert_query_object_filter(filter_, all_dimensions) + if converted_filter: + filters.add(converted_filter) + + return filters + + +def _get_filters_from_extras(extras: dict[str, Any]) -> set[AdhocFilter]: + """ + Extract filters from the extras dict. + + The extras dict can contain various keys that affect query behavior: + + Supported keys (converted to filters): + - "where": SQL WHERE clause expression (e.g., "customer_id > 100") + - "having": SQL HAVING clause expression (e.g., "SUM(sales) > 1000") + + Other keys in extras (handled elsewhere in the mapper): + - "time_grain_sqla": Time granularity (e.g., "P1D", "PT1H") + Handled in _convert_time_grain() and used for dimension grain matching + + Note: The WHERE and HAVING clauses from extras are SQL expressions that + are passed through as-is to the semantic layer as AdhocFilter objects. 
+ """ + filters: set[AdhocFilter] = set() + + # Add WHERE clause from extras + if where_clause := extras.get("where"): + filters.add( + AdhocFilter( + type=PredicateType.WHERE, + definition=where_clause, + ) + ) + + # Add HAVING clause from extras + if having_clause := extras.get("having"): + filters.add( + AdhocFilter( + type=PredicateType.HAVING, + definition=having_clause, + ) + ) + + return filters + + +def _get_time_filter( + query_object: ValidatedQueryObject, + time_offset: str | None, + all_dimensions: dict[str, Dimension], +) -> set[Filter]: + """ + Create a time range filter from the query object. + + This handles both regular queries and time offset queries, simplifying the + complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm by using the + same time bounds for both the main query and series limit subqueries. + """ + filters: set[Filter] = set() + + if not query_object.granularity: + return filters + + time_dimension = all_dimensions.get(query_object.granularity) + if not time_dimension: + return filters + + # Get the appropriate time bounds based on whether this is a time offset query + from_dttm, to_dttm = _get_time_bounds(query_object, time_offset) + + if not from_dttm or not to_dttm: + return filters + + # Create a filter with >= and < operators + return { + Filter( + type=PredicateType.WHERE, + column=time_dimension, + operator=Operator.GREATER_THAN_OR_EQUAL, + value=from_dttm, + ), + Filter( + type=PredicateType.WHERE, + column=time_dimension, + operator=Operator.LESS_THAN, + value=to_dttm, + ), + } + + +def _get_time_bounds( + query_object: ValidatedQueryObject, + time_offset: str | None, +) -> tuple[datetime | None, datetime | None]: + """ + Get the appropriate time bounds for the query. + + For regular queries (time_offset is None), returns from_dttm/to_dttm. + For time offset queries, calculates the shifted bounds. 
+ + This simplifies the inner_from_dttm/inner_to_dttm complexity by using + the same bounds for both main queries and series limit subqueries (Option 1). + """ + if time_offset is None: + # Main query: use from_dttm/to_dttm directly + return query_object.from_dttm, query_object.to_dttm + + # Time offset query: calculate shifted bounds + # Use from_dttm/to_dttm if available, otherwise try to get from time_range + outer_from = query_object.from_dttm + outer_to = query_object.to_dttm + + if not outer_from or not outer_to: + # Fall back to parsing time_range if from_dttm/to_dttm not set + outer_from, outer_to = get_since_until_from_query_object(query_object) + + if not outer_from or not outer_to: + return None, None + + # Apply the offset to both bounds + offset_from = get_past_or_future(time_offset, outer_from) + offset_to = get_past_or_future(time_offset, outer_to) + + return offset_from, offset_to + + +def _convert_query_object_filter( + filter_: ValidatedQueryObjectFilterClause, + all_dimensions: dict[str, Dimension], +) -> Filter | AdhocFilter | None: + """ + Convert a QueryObject filter dict to a semantic layer Filter or AdhocFilter. 
+ """ + operator_str = filter_["op"] + + # Handle TEMPORAL_RANGE filters (these are already handled by _get_time_filter) + if operator_str == FilterOperator.TEMPORAL_RANGE.value: + # Skip - already handled in _get_time_filter + return None + + # Handle simple column filters + col = filter_.get("col") + if col not in all_dimensions: + return None + + dimension = all_dimensions[col] + + val_str = filter_["val"] + value: FilterValues | set[FilterValues] + if val_str is None: + value = None + elif isinstance(val_str, (list, tuple)): + value = set(val_str) + else: + value = val_str + + # Map QueryObject operators to semantic layer operators + operator_mapping = { + FilterOperator.EQUALS.value: Operator.EQUALS, + FilterOperator.NOT_EQUALS.value: Operator.NOT_EQUALS, + FilterOperator.GREATER_THAN.value: Operator.GREATER_THAN, + FilterOperator.LESS_THAN.value: Operator.LESS_THAN, + FilterOperator.GREATER_THAN_OR_EQUALS.value: Operator.GREATER_THAN_OR_EQUAL, + FilterOperator.LESS_THAN_OR_EQUALS.value: Operator.LESS_THAN_OR_EQUAL, + FilterOperator.IN.value: Operator.IN, + FilterOperator.NOT_IN.value: Operator.NOT_IN, + FilterOperator.LIKE.value: Operator.LIKE, + FilterOperator.NOT_LIKE.value: Operator.NOT_LIKE, + FilterOperator.IS_NULL.value: Operator.IS_NULL, + FilterOperator.IS_NOT_NULL.value: Operator.IS_NOT_NULL, + } + + operator = operator_mapping.get(operator_str) + if not operator: + # Unknown operator - create adhoc filter + return None + + return Filter( + type=PredicateType.WHERE, + column=dimension, + operator=operator, + value=value, + ) + + +def _get_order_from_query_object( + query_object: ValidatedQueryObject, + all_metrics: dict[str, Metric], + all_dimensions: dict[str, Dimension], +) -> list[OrderTuple]: + order: list[OrderTuple] = [] + for element, ascending in query_object.orderby: + direction = OrderDirection.ASC if ascending else OrderDirection.DESC + + # adhoc + if isinstance(element, dict): + if element["sqlExpression"] is not None: + order.append( + ( 
+ AdhocExpression( + id=element["label"] or element["sqlExpression"], + definition=element["sqlExpression"], + ), + direction, + ) + ) + elif element in all_dimensions: + order.append((all_dimensions[element], direction)) + elif element in all_metrics: + order.append((all_metrics[element], direction)) + + return order + + +def _get_group_limit_from_query_object( + query_object: ValidatedQueryObject, + all_metrics: dict[str, Metric], + all_dimensions: dict[str, Dimension], +) -> GroupLimit | None: + # no limit + if query_object.series_limit == 0 or not query_object.columns: + return None + + dimensions = [all_dimensions[dim_id] for dim_id in query_object.series_columns] + top = query_object.series_limit + metric = ( + all_metrics[query_object.series_limit_metric] + if query_object.series_limit_metric + else None + ) + direction = OrderDirection.DESC if query_object.order_desc else OrderDirection.ASC + group_others = query_object.group_others_when_limit_reached + + # Check if we need separate filters for the group limit subquery + # This happens when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm + group_limit_filters = _get_group_limit_filters(query_object, all_dimensions) + + return GroupLimit( + dimensions=dimensions, + top=top, + metric=metric, + direction=direction, + group_others=group_others, + filters=group_limit_filters, + ) + + +def _get_group_limit_filters( + query_object: ValidatedQueryObject, + all_dimensions: dict[str, Dimension], +) -> set[Filter | AdhocFilter] | None: + """ + Get separate filters for the group limit subquery if needed. + + This is used when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm, + which happens during time comparison queries. The group limit subquery may need + different time bounds to determine the top N groups. + + Returns None if the group limit should use the same filters as the main query. 
+ """ + # Check if inner time bounds are explicitly set and differ from outer bounds + if ( + query_object.inner_from_dttm is None + or query_object.inner_to_dttm is None + or ( + query_object.inner_from_dttm == query_object.from_dttm + and query_object.inner_to_dttm == query_object.to_dttm + ) + ): + # No separate bounds needed - use the same filters as the main query + return None + + # Create separate filters for the group limit subquery + filters: set[Filter | AdhocFilter] = set() + + # Add time range filter using inner bounds + if query_object.granularity: + time_dimension = all_dimensions.get(query_object.granularity) + if ( + time_dimension + and query_object.inner_from_dttm + and query_object.inner_to_dttm + ): + filters.update( + { + Filter( + type=PredicateType.WHERE, + column=time_dimension, + operator=Operator.GREATER_THAN_OR_EQUAL, + value=query_object.inner_from_dttm, + ), + Filter( + type=PredicateType.WHERE, + column=time_dimension, + operator=Operator.LESS_THAN, + value=query_object.inner_to_dttm, + ), + } + ) + + # Add fetch values predicate if present + if ( + query_object.apply_fetch_values_predicate + and query_object.datasource.fetch_values_predicate + ): + filters.add( + AdhocFilter( + type=PredicateType.WHERE, + definition=query_object.datasource.fetch_values_predicate, + ) + ) + + # Add filters from query_object.extras (WHERE and HAVING clauses) + extras_filters = _get_filters_from_extras(query_object.extras) + filters.update(extras_filters) + + # Add all other non-temporal filters from query_object.filter + for filter_ in query_object.filter: + # Skip temporal range filters - we're using inner bounds instead + if filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value: + continue + + converted_filter = _convert_query_object_filter(filter_, all_dimensions) + if converted_filter: + filters.add(converted_filter) + + return filters if filters else None + + +def _convert_time_grain(time_grain: str) -> TimeGrain | DateGrain | None: + """ + 
Convert a time grain string from the query object to a TimeGrain or DateGrain enum. + """ + if time_grain in TimeGrain.__members__: + return TimeGrain[time_grain] + + if time_grain in DateGrain.__members__: + return DateGrain[time_grain] + + return None + + +def validate_query_object( + query_object: QueryObject, +) -> TypeGuard[ValidatedQueryObject]: + """ + Validate that the `QueryObject` is compatible with the `SemanticView`. + + If some semantic view implementation supports these features we should add an + attribute to the `SemanticViewImplementation` to indicate support for them. + """ + if not query_object.datasource: + return False + + query_object = cast(ValidatedQueryObject, query_object) + + _validate_metrics(query_object) + _validate_dimensions(query_object) + _validate_filters(query_object) + _validate_granularity(query_object) + _validate_group_limit(query_object) + _validate_orderby(query_object) + + return True + + +def _validate_metrics(query_object: ValidatedQueryObject) -> None: + """ + Make sure metrics are defined in the semantic view. + """ + semantic_view = query_object.datasource.implementation + + if any(not isinstance(metric, str) for metric in (query_object.metrics or [])): + raise ValueError("Adhoc metrics are not supported in Semantic Views.") + + metric_names = {metric.name for metric in semantic_view.metrics} + if not set(query_object.metrics or []) <= metric_names: + raise ValueError("All metrics must be defined in the Semantic View.") + + +def _validate_dimensions(query_object: ValidatedQueryObject) -> None: + """ + Make sure all dimensions are defined in the semantic view. 
+ """ + semantic_view = query_object.datasource.implementation + + if any(not isinstance(column, str) for column in query_object.columns): + raise ValueError("Adhoc dimensions are not supported in Semantic Views.") + + dimension_names = {dimension.name for dimension in semantic_view.dimensions} + if not set(query_object.columns) <= dimension_names: + raise ValueError("All dimensions must be defined in the Semantic View.") + + +def _validate_filters(query_object: ValidatedQueryObject) -> None: + """ + Make sure all filters are valid. + """ + for filter_ in query_object.filter: + if isinstance(filter_["col"], dict): + raise ValueError( + "Adhoc columns are not supported in Semantic View filters." + ) + if not filter_.get("op"): + raise ValueError("All filters must have an operator defined.") + + +def _validate_granularity(query_object: ValidatedQueryObject) -> None: + """ + Make sure time column and time grain are valid. + """ + semantic_view = query_object.datasource.implementation + dimension_names = {dimension.name for dimension in semantic_view.dimensions} + + if time_column := query_object.granularity: + if time_column not in dimension_names: + raise ValueError( + "The time column must be defined in the Semantic View dimensions." + ) + + if time_grain := query_object.extras.get("time_grain_sqla"): + if not time_column: + raise ValueError( + "A time column must be specified when a time grain is provided." + ) + + supported_time_grains = { + dimension.grain + for dimension in semantic_view.dimensions + if dimension.name == time_column and dimension.grain + } + if _convert_time_grain(time_grain) not in supported_time_grains: + raise ValueError( + "The time grain is not supported for the time column in the " + "Semantic View." + ) + + +def _validate_group_limit(query_object: ValidatedQueryObject) -> None: + """ + Validate group limit related features in the query object. 
+ """ + semantic_view = query_object.datasource.implementation + + # no limit + if query_object.series_limit == 0: + return + + if ( + query_object.series_columns + and SemanticViewFeature.GROUP_LIMIT not in semantic_view.features + ): + raise ValueError("Group limit is not supported in this Semantic View.") + + if any(not isinstance(col, str) for col in query_object.series_columns): + raise ValueError("Adhoc dimensions are not supported in series columns.") + + metric_names = {metric.name for metric in semantic_view.metrics} + if query_object.series_limit_metric and ( + not isinstance(query_object.series_limit_metric, str) + or query_object.series_limit_metric not in metric_names + ): + raise ValueError( + "The series limit metric must be defined in the Semantic View." + ) + + dimension_names = {dimension.name for dimension in semantic_view.dimensions} + if not set(query_object.series_columns) <= dimension_names: + raise ValueError("All series columns must be defined in the Semantic View.") + + if ( + query_object.group_others_when_limit_reached + and SemanticViewFeature.GROUP_OTHERS not in semantic_view.features + ): + raise ValueError( + "Grouping others when limit is reached is not supported in this Semantic " + "View." + ) + + +def _validate_orderby(query_object: ValidatedQueryObject) -> None: + """ + Validate order by elements in the query object. + """ + semantic_view = query_object.datasource.implementation + + if ( + any(not isinstance(element, str) for element, _ in query_object.orderby) + and SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY + not in semantic_view.features + ): + raise ValueError( + "Adhoc expressions in order by are not supported in this Semantic View." 
+ ) + + elements = set(query_object.orderby) + metric_names = {metric.name for metric in semantic_view.metrics} + dimension_names = {dimension.name for dimension in semantic_view.dimensions} + if not elements <= metric_names | dimension_names: + raise ValueError("All order by elements must be defined in the Semantic View.") diff --git a/superset/semantic_layers/models.py b/superset/semantic_layers/models.py new file mode 100644 index 00000000000..0beba9c0264 --- /dev/null +++ b/superset/semantic_layers/models.py @@ -0,0 +1,378 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Semantic layer models.""" + +from __future__ import annotations + +import uuid +from collections.abc import Hashable +from dataclasses import dataclass +from functools import cached_property +from importlib.metadata import entry_points +from typing import Any, TYPE_CHECKING + +from flask_appbuilder import Model +from sqlalchemy import Column, ForeignKey, Integer, String, Text +from sqlalchemy.orm import relationship +from sqlalchemy_utils import UUIDType + +from superset.common.query_object import QueryObject +from superset.explorables.base import TimeGrainDict +from superset.extensions import encrypted_field_factory +from superset.models.helpers import AuditMixinNullable, QueryResult +from superset.semantic_layers.mapper import get_results +from superset.semantic_layers.types import ( + BINARY, + BOOLEAN, + DATE, + DATETIME, + DECIMAL, + INTEGER, + INTERVAL, + NUMBER, + OBJECT, + SemanticLayerImplementation, + SemanticViewImplementation, + STRING, + TIME, + Type, +) +from superset.utils import json +from superset.utils.core import GenericDataType + +if TYPE_CHECKING: + from superset.superset_typing import ExplorableData, QueryObjectDict + + +def get_column_type(semantic_type: Type) -> GenericDataType: + """ + Map semantic layer types to generic data types. 
+ """ + if semantic_type in {DATE, DATETIME, TIME}: + return GenericDataType.TEMPORAL + if semantic_type in {INTEGER, NUMBER, DECIMAL, INTERVAL}: + return GenericDataType.NUMERIC + if semantic_type is BOOLEAN: + return GenericDataType.BOOLEAN + if semantic_type in {STRING, OBJECT, BINARY}: + return GenericDataType.STRING + return GenericDataType.STRING + + +@dataclass(frozen=True) +class MetricMetadata: + metric_name: str + expression: str + verbose_name: str | None = None + description: str | None = None + d3format: str | None = None + currency: dict[str, Any] | None = None + warning_text: str | None = None + certified_by: str | None = None + certification_details: str | None = None + + +@dataclass(frozen=True) +class ColumnMetadata: + column_name: str + type: str + is_dttm: bool + verbose_name: str | None = None + description: str | None = None + groupby: bool = True + filterable: bool = True + expression: str | None = None + python_date_format: str | None = None + advanced_data_type: str | None = None + extra: str | None = None + + +class SemanticLayer(AuditMixinNullable, Model): + """ + Semantic layer model. + + A semantic layer provides an abstraction over data sources, + allowing users to query data through a semantic interface. 
+ """ + + __tablename__ = "semantic_layers" + + uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4) + + # Core fields + name = Column(String(250), nullable=False) + description = Column(Text, nullable=True) + type = Column(String(250), nullable=False) + + configuration = Column(encrypted_field_factory.create(Text), default="{}") + cache_timeout = Column(Integer, nullable=True) + + # Semantic views relationship + semantic_views: list[SemanticView] = relationship( + "SemanticView", + back_populates="semantic_layer", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + def __repr__(self) -> str: + return self.name or str(self.uuid) + + @cached_property + def implementation( + self, + ) -> SemanticLayerImplementation[Any, SemanticViewImplementation]: + """ + Return semantic layer implementation. + """ + entry_point = next( + iter( + entry_points( + group="superset.semantic_layers", + name=self.type, + ) + ) + ) + implementation_class = entry_point.load() + + if not issubclass(implementation_class, SemanticLayerImplementation): + raise TypeError( + f"Entry point for semantic layer type '{self.type}' " + "must be a subclass of SemanticLayerImplementation" + ) + + return implementation_class.from_configuration(self.configuration) + + +class SemanticView(AuditMixinNullable, Model): + """ + Semantic view model. + + A semantic view represents a queryable view within a semantic layer. 
+ """ + + __tablename__ = "semantic_views" + + uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4) + + # Core fields + name = Column(String(250), nullable=False) + + configuration = Column(encrypted_field_factory.create(Text), default="{}") + cache_timeout = Column(Integer, nullable=True) + + # Semantic layer relationship + semantic_layer_uuid = Column( + UUIDType(binary=True), + ForeignKey("semantic_layers.uuid", ondelete="CASCADE"), + nullable=False, + ) + semantic_layer: SemanticLayer = relationship( + "SemanticLayer", + back_populates="semantic_views", + foreign_keys=[semantic_layer_uuid], + ) + + def __repr__(self) -> str: + return self.name or str(self.uuid) + + @cached_property + def implementation(self) -> SemanticViewImplementation: + """ + Return semantic view implementation. + """ + return self.semantic_layer.implementation.get_semantic_view( + self.name, + self.configuration, + ) + + # ========================================================================= + # Explorable protocol implementation + # ========================================================================= + + def get_query_result(self, query_object: QueryObject) -> QueryResult: + return get_results(query_object) + + def get_query_str(self, query_obj: QueryObjectDict) -> str: + return "Not implemented for semantic layers" + + @property + def uid(self) -> str: + return self.implementation.uid() + + @property + def type(self) -> str: + return "semantic_view" + + @property + def metrics(self) -> list[MetricMetadata]: + return [ + MetricMetadata( + metric_name=metric.name, + expression=metric.definition, + description=metric.description, + ) + for metric in self.implementation.metrics + ] + + @property + def columns(self) -> list[ColumnMetadata]: + return [ + ColumnMetadata( + column_name=dimension.name, + type=dimension.type.__name__, + is_dttm=dimension.type in {DATE, TIME, DATETIME}, + description=dimension.description, + expression=dimension.definition, + 
extra=json.dumps({"grain": dimension.grain}), + ) + for dimension in self.implementation.dimensions + ] + + @property + def column_names(self) -> list[str]: + return [dimension.name for dimension in self.implementation.dimensions] + + @property + def data(self) -> ExplorableData: + return { + # core + "id": self.uuid.hex, + "uid": self.uid, + "type": "semantic_view", + "name": self.name, + "columns": [ + { + "advanced_data_type": None, + "certification_details": None, + "certified_by": None, + "column_name": dimension.name, + "description": dimension.description, + "expression": dimension.definition, + "filterable": True, + "groupby": True, + "id": None, + "uuid": dimension.uuid.hex, + "is_certified": False, + "is_dttm": dimension.type in {DATE, TIME, DATETIME}, + "python_date_format": None, + "type": dimension.type.__name__, + "type_generic": get_column_type(dimension.type), + "verbose_name": None, + "warning_markdown": None, + } + for dimension in self.implementation.dimensions + ], + "metrics": [ + { + "certification_details": None, + "certified_by": None, + "d3format": None, + "description": metric.description, + "expression": metric.definition, + "id": None, + "uuid": metric.uuid.hex, + "is_certified": False, + "metric_name": metric.name, + "warning_markdown": None, + "warning_text": None, + "verbose_name": None, + } + for metric in self.implementation.metrics + ], + "database": {}, + # UI features + "verbose_map": {}, + "order_by_choices": [], + "filter_select": True, + "filter_select_enabled": True, + "sql": None, + "select_star": None, + "owners": [owner.id for owner in self.owners], + "description": None, + "table_name": self.name, + "column_types": [ + get_column_type(dimension.type) + for dimension in self.implementation.dimensions + ], + "column_names": { + dimension.name for dimension in self.implementation.dimensions + }, + # rare + "column_formats": {}, + "datasource_name": self.name, + "perm": self.perm, + "offset": None, + "cache_timeout": 
self.cache_timeout, + "params": None, + # sql-specific + "schema": None, + "catalog": None, + "main_dttm_col": None, + "time_grain_sqla": [], + "granularity_sqla": [], + "fetch_values_predicate": None, + "template_params": None, + "is_sqllab_view": False, + "extra": None, + "always_filter_main_dttm": False, + "normalize_columns": False, + # TODO XXX + "edit_url": "", + "default_endpoint": None, + "folders": [], + "health_check_message": None, + } + + def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]: + return [] + + @property + def perm(self) -> str: + return self.semantic_layer_uuid.hex + "::" + self.uuid.hex + + @property + def offset(self) -> int: + # always return datetime as UTC + return 0 + + @property + def get_time_grains(self) -> list[TimeGrainDict]: + return [ + { + "name": dimension.grain.name, + "function": "", + "duration": dimension.grain.representation, + } + for dimension in self.implementation.dimensions + if dimension.grain + ] + + def has_drill_by_columns(self, column_names: list[str]) -> bool: + dimension_names = { + dimension.name for dimension in self.implementation.dimensions + } + return all(column_name in dimension_names for column_name in column_names) + + @property + def is_rls_supported(self) -> bool: + return False + + @property + def query_language(self) -> str | None: + return None diff --git a/superset/semantic_layers/snowflake/__init__.py b/superset/semantic_layers/snowflake/__init__.py new file mode 100644 index 00000000000..40b1a53fed8 --- /dev/null +++ b/superset/semantic_layers/snowflake/__init__.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration +from superset.semantic_layers.snowflake.semantic_layer import SnowflakeSemanticLayer +from superset.semantic_layers.snowflake.semantic_view import SnowflakeSemanticView + +__all__ = [ + "SnowflakeConfiguration", + "SnowflakeSemanticLayer", + "SnowflakeSemanticView", +] diff --git a/superset/semantic_layers/snowflake/schemas.py b/superset/semantic_layers/snowflake/schemas.py new file mode 100644 index 00000000000..012ff3392b7 --- /dev/null +++ b/superset/semantic_layers/snowflake/schemas.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from typing import Literal, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator, SecretStr + + +class UserPasswordAuth(BaseModel): + """ + Username and password authentication. + """ + + model_config = ConfigDict(title="Username and password") + + auth_type: Literal["user_password"] = "user_password" + username: str = Field(description="The username to authenticate as.") + password: SecretStr = Field( + description="The password to authenticate with.", + repr=False, + ) + + +class PrivateKeyAuth(BaseModel): + """ + Private key authentication. + """ + + model_config = ConfigDict(title="Private key") + + auth_type: Literal["private_key"] = "private_key" + private_key: SecretStr = Field( + description="The private key to authenticate with, in PEM format.", + repr=False, + ) + private_key_password: SecretStr = Field( + description="The password to decrypt the private key with.", + repr=False, + ) + + +class SnowflakeConfiguration(BaseModel): + """ + Parameters needed to connect to Snowflake. 
+ """ + + # account is the only required parameter + account_identifier: str = Field( + description="The Snowflake account identifier.", + json_schema_extra={"examples": ["abc12345"]}, + ) + + role: str | None = Field( + default=None, + description="The default role to use.", + json_schema_extra={"examples": ["myrole"]}, + ) + warehouse: str | None = Field( + default=None, + description="The default warehouse to use.", + json_schema_extra={"examples": ["testwh"]}, + ) + + auth: Union[UserPasswordAuth, PrivateKeyAuth] = Field( + discriminator="auth_type", + description="Authentication method", + ) + + # database and schema can be optionally provided; if not provided the user + # will be able to browse databases/schemas + database: str | None = Field( + default=None, + description="The default database to use.", + json_schema_extra={ + "examples": ["testdb"], + "x-dynamic": True, + "x-dependsOn": ["account_identifier", "auth"], + }, + ) + allow_changing_database: bool = Field( + default=False, + description="Allow changing the default database.", + ) + schema_: str | None = Field( + default=None, + description="The default schema to use.", + json_schema_extra={ + "examples": ["public"], + "x-dynamic": True, + "x-dependsOn": ["account_identifier", "auth", "database"], + }, + # `schema` is an attribute of `BaseModel` so it needs to be aliased + alias="schema", + ) + allow_changing_schema: bool = Field( + default=False, + description="Allow changing the default schema.", + ) + + @model_validator(mode="after") + def validate_database_schema_settings(self) -> SnowflakeConfiguration: + """ + Validate that if database or schema is not specified, the corresponding + allow_changing flag must be true. 
+ """ + if not self.database and not self.allow_changing_database: + raise ValueError( + "If no database is specified, allow_changing_database must be true" + ) + if not self.schema_ and not self.allow_changing_schema: + raise ValueError( + "If no schema is specified, allow_changing_schema must be true" + ) + return self diff --git a/superset/semantic_layers/snowflake/semantic_layer.py b/superset/semantic_layers/snowflake/semantic_layer.py new file mode 100644 index 00000000000..8f99f6706f1 --- /dev/null +++ b/superset/semantic_layers/snowflake/semantic_layer.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from textwrap import dedent +from typing import Any, Literal, TYPE_CHECKING + +from pydantic import create_model, Field +from snowflake.connector import connect +from snowflake.connector.connection import SnowflakeConnection + +from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration +from superset.semantic_layers.snowflake.utils import get_connection_parameters +from superset.semantic_layers.types import ( + SemanticLayerImplementation, +) + +if TYPE_CHECKING: + from superset.semantic_layers.snowflake.semantic_view import SnowflakeSemanticView + + +class SnowflakeSemanticLayer( + SemanticLayerImplementation[SnowflakeConfiguration, SnowflakeSemanticView] +): + id = "snowflake" + name = "Snowflake Semantic Layer" + description = "Connect to semantic views stored in Snowflake." + + @classmethod + def from_configuration( + cls, + configuration: dict[str, Any], + ) -> SnowflakeSemanticLayer: + """ + Create a SnowflakeSemanticLayer from a configuration dictionary. + """ + config = SnowflakeConfiguration.model_validate(configuration) + return cls(config) + + @classmethod + def get_configuration_schema( + cls, + configuration: SnowflakeConfiguration | None = None, + ) -> dict[str, Any]: + """ + Get the JSON schema for the configuration needed to add the semantic layer. + + A partial configuration can be sent to improve the schema. For example, + providing account and auth will allow the schema to provide a list of + databases; providing a database will allow the schema to provide a list of + schemas. + + Note that database and schema can both be left empty when the semantic layer is + added to Superset; the user will then have to provide them when loading + semantic views. 
+ """ + schema = SnowflakeConfiguration.model_json_schema() + properties = schema["properties"] + + if configuration is None: + # set these to empty; they will be populated when a partial configuration is + # passed + properties["database"]["enum"] = [] + properties["schema"]["enum"] = [] + + return schema + + connection_parameters = get_connection_parameters(configuration) + with connect(**connection_parameters) as connection: + if all( + getattr(configuration, dependency) + for dependency in properties["database"].get("x-dependsOn", []) + ): + options = cls._fetch_databases(connection) + properties["database"]["enum"] = list(options) + + if ( + all( + getattr(configuration, dependency) + for dependency in properties["schema"].get("x-dependsOn", []) + ) + and configuration.database + ): + options = cls._fetch_schemas(connection, configuration.database) + properties["schema"]["enum"] = list(options) + + return schema + + @classmethod + def get_runtime_schema( + cls, + configuration: SnowflakeConfiguration, + runtime_data: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Get the JSON schema for the runtime parameters needed to load semantic views. + + The schema can be enriched with actual values when `runtime_data` is provided, + enabling dynamic schema updates (e.g., populating schema dropdown after + database is selected). 
+ """ + fields: dict[str, tuple[Any, Field]] = {} + + # update configuration with runtime data, for example, to select a schema after + # the database has been selected + configuration = configuration.model_copy(update=runtime_data) + + connection_parameters = get_connection_parameters(configuration) + with connect(**connection_parameters) as connection: + if not configuration.database or configuration.allow_changing_database: + options = cls._fetch_databases(connection) + fields["database"] = ( + Literal[*options], + Field(description="The default database to use."), + ) + + if not configuration.schema_ or configuration.allow_changing_schema: + if configuration.database: + options = cls._fetch_schemas(connection, configuration.database) + fields["schema_"] = ( + Literal[*options], + Field( + description="The default schema to use.", + alias="schema", + json_schema_extra=( + { + "x-dynamic": True, + "x-dependsOn": ["database"], + } + if "database" in fields + else {} + ), + ), + ) + else: + # Database not provided yet, add schema as empty + # (will be populated dynamically) + fields["schema_"] = ( + str | None, + Field( + default=None, + description="The default schema to use.", + alias="schema", + json_schema_extra={ + "x-dynamic": True, + "x-dependsOn": ["database"], + }, + ), + ) + + return create_model("RuntimeParameters", **fields).model_json_schema() + + @classmethod + def _fetch_databases(cls, connection: SnowflakeConnection) -> set[str]: + """ + Fetch the list of databases available in the Snowflake account. + + We use `SHOW DATABASES` instead of querying the information schema since it + allows to retrieve the list of databases without having to specify a database + when connecting. 
+ """ + cursor = connection.cursor() + cursor.execute("SHOW DATABASES") + return {row[1] for row in cursor} + + @classmethod + def _fetch_schemas( + cls, + connection: SnowflakeConnection, + database: str | None, + ) -> set[str]: + """ + Fetch the list of schemas available in a given database. + + The connection should already have the database set in its context. + """ + if not database: + return set() + + cursor = connection.cursor() + query = dedent( + """ + SELECT SCHEMA_NAME + FROM INFORMATION_SCHEMA.SCHEMATA + WHERE CATALOG_NAME = ? + """ + ).strip() + return {row[0] for row in cursor.execute(query, (database,))} + + def __init__(self, configuration: SnowflakeConfiguration): + self.configuration = configuration + + def get_semantic_views( + self, + runtime_configuration: dict[str, Any], + ) -> set[SnowflakeSemanticView]: + """ + Get the semantic views available in the semantic layer. + """ + # Avoid circular import + from superset.semantic_layers.snowflake.semantic_view import ( + SnowflakeSemanticView, + ) + + # create a new configuration with the runtime parameters + configuration = self.configuration.model_copy(update=runtime_configuration) + + connection_parameters = get_connection_parameters(configuration) + with connect(**connection_parameters) as connection: + cursor = connection.cursor() + query = dedent( + """ + SHOW SEMANTIC VIEWS + ->> SELECT "name" FROM $1; + """ + ).strip() + views = { + SnowflakeSemanticView(row[0], configuration) + for row in cursor.execute(query) + } + return views diff --git a/superset/semantic_layers/snowflake/semantic_view.py b/superset/semantic_layers/snowflake/semantic_view.py new file mode 100644 index 00000000000..e40635ab5d7 --- /dev/null +++ b/superset/semantic_layers/snowflake/semantic_view.py @@ -0,0 +1,817 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# ruff: noqa: S608 + +from __future__ import annotations + +import itertools +import re +from collections import defaultdict +from textwrap import dedent + +from pandas import DataFrame +from snowflake.connector import connect, DictCursor +from snowflake.sqlalchemy.snowdialect import SnowflakeDialect + +from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration +from superset.semantic_layers.snowflake.utils import ( + get_connection_parameters, + substitute_parameters, + validate_order_by, +) +from superset.semantic_layers.types import ( + AdhocExpression, + AdhocFilter, + BINARY, + BOOLEAN, + DATE, + DATETIME, + DECIMAL, + Dimension, + Filter, + FilterValues, + GroupLimit, + INTEGER, + Metric, + NUMBER, + OBJECT, + Operator, + OrderTuple, + PredicateType, + SemanticRequest, + SemanticResult, + SemanticViewFeature, + SemanticViewImplementation, + STRING, + TIME, + Type, +) + +REQUEST_TYPE = "snowflake" + + +class SnowflakeSemanticView(SemanticViewImplementation): + features = frozenset( + { + SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY, + SemanticViewFeature.GROUP_LIMIT, + SemanticViewFeature.GROUP_OTHERS, + } + ) + + def __init__(self, name: str, configuration: SnowflakeConfiguration): + self.configuration = configuration + self.name = name + + self._quote = 
SnowflakeDialect().identifier_preparer.quote + + self.dimensions = self.get_dimensions() + self.metrics = self.get_metrics() + + def uid(self) -> str: + return ".".join( + self._quote(part) + for part in ( + self.configuration.database, + self.configuration.schema_, + self.name, + ) + ) + + def get_dimensions(self) -> set[Dimension]: + """ + Get the dimensions defined in the semantic view. + + Even though Snowflake supports `SHOW SEMANTIC DIMENSIONS IN my_semantic_view`, + it doesn't return the expression of dimensions, so we use a slightly more + complicated query to get all the information we need in one go. + """ + dimensions: set[Dimension] = set() + + query = dedent( + f""" + DESC SEMANTIC VIEW {self.uid()} + ->> SELECT "object_name", "property", "property_value" + FROM $1 + WHERE + "object_kind" = 'DIMENSION' AND + "property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE'); + """ + ).strip() + + connection_parameters = get_connection_parameters(self.configuration) + with connect(**connection_parameters) as connection: + cursor = connection.cursor(DictCursor) + rows = cursor.execute(query).fetchall() + + for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]): + attributes = defaultdict(set) + for row in group: + attributes[row["property"]].add(row["property_value"]) + + table = next(iter(attributes["TABLE"])) + id_ = table + "." + name + type_ = self._get_type(next(iter(attributes["DATA_TYPE"]))) + description = next(iter(attributes["COMMENT"]), None) + definition = next(iter(attributes["EXPRESSION"]), None) + + dimensions.add(Dimension(id_, name, type_, description, definition)) + + return dimensions + + def get_metrics(self) -> set[Metric]: + """ + Get the metrics defined in the semantic view. 
+ """ + metrics: set[Metric] = set() + + query = dedent( + f""" + DESC SEMANTIC VIEW {self.uid()} + ->> SELECT "object_name", "property", "property_value" + FROM $1 + WHERE + "object_kind" = 'METRIC' AND + "property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE'); + """ + ).strip() + + connection_parameters = get_connection_parameters(self.configuration) + with connect(**connection_parameters) as connection: + cursor = connection.cursor(DictCursor) + rows = cursor.execute(query).fetchall() + + for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]): + attributes = defaultdict(set) + for row in group: + attributes[row["property"]].add(row["property_value"]) + + table = next(iter(attributes["TABLE"])) + id_ = table + "." + name + type_ = self._get_type(next(iter(attributes["DATA_TYPE"]))) + description = next(iter(attributes["COMMENT"]), None) + definition = next(iter(attributes["EXPRESSION"]), None) + + metrics.add(Metric(id_, name, type_, definition, description)) + + return metrics + + def _get_type(self, snowflake_type: str | None) -> type[Type]: + """ + Return the semantic type corresponding to a Snowflake type. 
+ """ + if snowflake_type is None: + return STRING + + type_map = { + STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"}, + INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"}, + DECIMAL: {r"NUMBER\(10,\s?2\)$"}, + NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"}, + BOOLEAN: {"BOOLEAN$"}, + DATE: {"DATE$"}, + DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP__NTZ$"}, + TIME: {"TIME$"}, + OBJECT: {"OBJECT$"}, + BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"}, + } + for semantic_type, patterns in type_map.items(): + if any( + re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns + ): + return semantic_type + + return STRING + + def _build_predicates( + self, + filters: list[Filter | AdhocFilter], + ) -> tuple[str, tuple[FilterValues, ...]]: + """ + Convert a set of filters to a single `AND`ed predicate. + + Caller should check the types of filters beforehand, as this method does not + differentiate between `WHERE` and `HAVING` predicates. + """ + if not filters: + return "", () + + # convert filters predicate with associated parameters; native filters are + # already strings, so we keep them as-is + unary_operators = {Operator.IS_NULL, Operator.IS_NOT_NULL} + predicates: list[str] = [] + parameters: list[FilterValues] = [] + for filter_ in filters or set(): + if isinstance(filter_, AdhocFilter): + predicates.append(f"({filter_.definition})") + else: + predicates.append(f"({self._build_native_filter(filter_)})") + if filter_.operator not in unary_operators: + parameters.extend( + [filter_.value] + if not isinstance(filter_.value, (set, frozenset)) + else filter_.value + ) + + return " AND ".join(predicates), tuple(parameters) + + def get_values( + self, + dimension: Dimension, + filters: set[Filter | AdhocFilter] | None = None, + ) -> SemanticResult: + """ + Return distinct values for a dimension. 
+ """ + where_clause, parameters = self._build_predicates( + sorted( + filter_ + for filter_ in (filters or []) + if filter_.type == PredicateType.WHERE + ) + ) + query = dedent( + f""" + SELECT {self._quote(dimension.name)} + FROM SEMANTIC_VIEW( + {self.uid()} + DIMENSIONS {dimension.id} + {"WHERE " + where_clause if where_clause else ""} + ) + """ + ).strip() + connection_parameters = get_connection_parameters(self.configuration) + with connect(**connection_parameters) as connection: + df = connection.cursor().execute(query, parameters).fetch_pandas_all() + + return SemanticResult( + requests=[ + SemanticRequest( + REQUEST_TYPE, + substitute_parameters(query, parameters), + ) + ], + results=df, + ) + + def _build_native_filter(self, filter_: Filter) -> str: + """ + Convert a Filter to a AdhocFilter. + """ + column = filter_.column + operator = filter_.operator + value = filter_.value + + column_name = self._quote(column.name) + + # Handle IS NULL and IS NOT NULL operators (no value needed) + if operator in {Operator.IS_NULL, Operator.IS_NOT_NULL}: + return f"{column_name} {operator.value}" + + # Handle IN and NOT IN operators (set values) + if operator in {Operator.IN, Operator.NOT_IN}: + parameter_count = len(value) if isinstance(value, (set, frozenset)) else 1 + formatted_values = ", ".join("?" for _ in range(parameter_count)) + return f"{column_name} {operator.value} ({formatted_values})" + + return f"{column_name} {operator.value} ?" + + def get_dataframe( + self, + metrics: list[Metric], + dimensions: list[Dimension], + filters: set[Filter | AdhocFilter] | None = None, + order: list[OrderTuple] | None = None, + limit: int | None = None, + offset: int | None = None, + *, + group_limit: GroupLimit | None = None, + ) -> SemanticResult: + """ + Execute a query and return the results as a Pandas DataFrame. 
+ """ + if not metrics and not dimensions: + return DataFrame() + + query, parameters = self._get_query( + metrics, + dimensions, + filters, + order, + limit, + offset, + group_limit, + ) + connection_parameters = get_connection_parameters(self.configuration) + with connect(**connection_parameters) as connection: + df = connection.cursor().execute(query, parameters).fetch_pandas_all() + + return SemanticResult( + requests=[ + SemanticRequest( + REQUEST_TYPE, + substitute_parameters(query, parameters), + ) + ], + results=df, + ) + + def get_row_count( + self, + metrics: list[Metric], + dimensions: list[Dimension], + filters: set[Filter | AdhocFilter] | None = None, + order: list[OrderTuple] | None = None, + limit: int | None = None, + offset: int | None = None, + *, + group_limit: GroupLimit | None = None, + ) -> SemanticResult: + """ + Execute a query and return the number of rows the result would have. + """ + if not metrics and not dimensions: + return SemanticResult( + requests=[], + results=DataFrame([[0]], columns=["COUNT"]), + ) + + query, parameters = self._get_query( + metrics, + dimensions, + filters, + order, + limit, + offset, + group_limit, + ) + query = f"SELECT COUNT(*) FROM ({query}) AS subquery" + connection_parameters = get_connection_parameters(self.configuration) + with connect(**connection_parameters) as connection: + df = connection.cursor().execute(query, parameters).fechone()[0] + + return SemanticResult( + requests=[ + SemanticRequest( + REQUEST_TYPE, + substitute_parameters(query, parameters), + ) + ], + results=df, + ) + + def _get_query( + self, + metrics: list[Metric], + dimensions: list[Dimension], + filters: set[Filter | AdhocFilter] | None = None, + order: list[OrderTuple] | None = None, + limit: int | None = None, + offset: int | None = None, + group_limit: GroupLimit | None = None, + ) -> tuple[str, tuple[FilterValues, ...]]: + """ + Build a query to fetch data from the semantic view. 
+ + This also returns the parameters need to run `cursor.execute()`, passed + separately to prevent SQL injection. + """ + if limit is None and offset is not None: + raise ValueError("Offset cannot be set without limit") + + filters = filters or set() + where_clause, where_parameters = self._build_predicates( + sorted( + filter_ for filter_ in filters if filter_.type == PredicateType.WHERE + ) + ) + # having clauses are not supported, since there's no GROUP BY + if any(filter_.type == PredicateType.HAVING for filter_ in filters): + raise ValueError("HAVING filters are not supported") + + if group_limit: + query, cte_parameters = self._build_query_with_group_limit( + metrics, + dimensions, + where_clause, + order, + limit, + offset, + group_limit, + ) + # Combine parameters: CTE params first, then main query params + all_parameters = cte_parameters + where_parameters + else: + query = self._build_simple_query( + metrics, + dimensions, + where_clause, + order, + limit, + offset, + ) + all_parameters = where_parameters + + return query, all_parameters + + def _alias_element(self, element: Metric | Dimension) -> str: + """ + Generate an aliased column expression for a metric or dimension. + """ + return f"{element.id} AS {self._quote(element.id)}" + + def _build_order_clause( + self, + order: list[OrderTuple] | None = None, + ) -> str: + """ + Build the ORDER BY clause from a list of (element, direction) tuples. + + Note that for adhoc expressions, Superset will still add `ASC` or `DESC` to the + end, which means adhoc expressions can contain multiple columns as long as the + last one has no direction specified. + + This is fine: + + gender ASC, COUNT(*) + + But this is not + + gender ASC, COUNT(*) DESC + + The latter will produce a query that looks like this: + + ... 
ORDER BY gender ASC, COUNT(*) DESC DESC + + """ + if not order: + return "" + + def build_element(element: Metric | Dimension | AdhocExpression) -> str: + if isinstance(element, AdhocExpression): + validate_order_by(element.definition) + return element.definition + return self._quote(element.id) + + return ", ".join( + f"{build_element(element)} {direction.value}" + for element, direction in order + ) + + def _build_simple_query( + self, + metrics: list[Metric], + dimensions: list[Dimension], + where_clause: str, + order: list[OrderTuple] | None, + limit: int | None, + offset: int | None, + ) -> str: + """ + Build a query without group limiting. + """ + dimension_arguments = ", ".join( + self._alias_element(dimension) for dimension in dimensions + ) + metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics) + order_clause = self._build_order_clause(order) + + return dedent( + f""" + SELECT * FROM SEMANTIC_VIEW( + {self.uid()} + {"DIMENSIONS " + dimension_arguments if dimension_arguments else ""} + {"METRICS " + metric_arguments if metric_arguments else ""} + {"WHERE " + where_clause if where_clause else ""} + ) + {"ORDER BY " + order_clause if order_clause else ""} + {"LIMIT " + str(limit) if limit is not None else ""} + {"OFFSET " + str(offset) if offset is not None else ""} + """ + ).strip() + + def _build_top_groups_cte( + self, + group_limit: GroupLimit, + where_clause: str, + ) -> tuple[str, tuple[FilterValues, ...]]: + """ + Build a CTE that finds the top N combinations of limited dimensions. + + If group_limit.filters is set, it uses those filters instead of the main + query's where clause. This allows using different time bounds for finding top + groups vs showing data. 
+ + Returns: + Tuple of (CTE SQL, parameters for the CTE) + """ + limited_dimension_arguments = ", ".join( + self._alias_element(dimension) for dimension in group_limit.dimensions + ) + limited_dimension_names = ", ".join( + self._quote(dimension.id) for dimension in group_limit.dimensions + ) + + # Use separate filters for group limit if provided (Option 2) + # Otherwise use the same filters as the main query (Option 1) + if group_limit.filters is not None: + group_where_clause, group_where_params = self._build_predicates( + sorted( + filter_ + for filter_ in group_limit.filters + if filter_.type == PredicateType.WHERE + ) + ) + if any( + filter_.type == PredicateType.HAVING for filter_ in group_limit.filters + ): + raise ValueError( + "HAVING filters are not supported in group limit filters" + ) + cte_params = group_where_params + else: + group_where_clause = where_clause + cte_params = () # No additional params - using main query params + + # Build METRICS clause and ORDER BY based on whether metric is provided + if group_limit.metric is not None: + metrics_clause = ( + f"METRICS {group_limit.metric.id}" + f" AS {self._quote(group_limit.metric.id)}" + ) + order_by_clause = ( + f"{self._quote(group_limit.metric.id)} {group_limit.direction.value}" + ) + else: + # No metric provided - order by first dimension + metrics_clause = "" + order_by_clause = ( + f"{self._quote(group_limit.dimensions[0].id)} " + f"{group_limit.direction.value}" + ) + + # Build SEMANTIC_VIEW arguments + semantic_view_args = [ + f"DIMENSIONS {limited_dimension_arguments}", + ] + if metrics_clause: + semantic_view_args.append(metrics_clause) + if group_where_clause: + semantic_view_args.append(f"WHERE {group_where_clause}") + + semantic_view_args_str = "\n ".join(semantic_view_args) + + # Add trailing blank line if there's no WHERE clause + # This matches the original template behavior + if not group_where_clause: + semantic_view_args_str += "\n" + + cte_sql = dedent( + f""" + WITH top_groups 
AS ( + SELECT {limited_dimension_names} + FROM SEMANTIC_VIEW( + {self.uid()} + {semantic_view_args_str} + ) + ORDER BY + {order_by_clause} + LIMIT {group_limit.top} + ) + """ + ).strip() + + return cte_sql, cte_params + + def _build_group_filter(self, group_limit: GroupLimit) -> str: + """ + Build a WHERE filter that restricts results to top N groups. + """ + if len(group_limit.dimensions) == 1: + dimension_id = self._quote(group_limit.dimensions[0].id) + return f"{dimension_id} IN (SELECT {dimension_id} FROM top_groups)" + + # Multi-column IN clause + dimension_tuple = ", ".join( + self._quote(dim.id) for dim in group_limit.dimensions + ) + return f"({dimension_tuple}) IN (SELECT {dimension_tuple} FROM top_groups)" + + def _build_case_expression( + self, + dimension: Dimension, + group_condition: str, + ) -> str: + """ + Build a CASE expression that replaces non-top values with 'Other'. + + Args: + dimension: The dimension to build the CASE for + group_condition: The condition to check if value is in top groups + (e.g., "staff_id IN (SELECT staff_id FROM top_groups)") + + Returns: + SQL CASE expression + """ + dimension_id = self._quote(dimension.id) + return f"""CASE + WHEN {group_condition} THEN {dimension_id} + ELSE CAST('Other' AS VARCHAR) + END""" + + def _build_query_with_others( + self, + metrics: list[Metric], + dimensions: list[Dimension], + where_clause: str, + order: list[OrderTuple] | None, + limit: int | None, + offset: int | None, + group_limit: GroupLimit, + ) -> tuple[str, tuple[FilterValues, ...]]: + """ + Build a query that groups non-top N values as 'Other'. + + This uses a two-stage approach: + 1. CTE to find top N groups + 2. Subquery with CASE expressions to replace non-top values with 'Other' + 3. 
Outer query to re-aggregate with the new grouping + + Returns: + Tuple of (SQL query, CTE parameters) + """ + top_groups_cte, cte_params = self._build_top_groups_cte( + group_limit, + where_clause, + ) + + # Determine which dimensions are limited vs non-limited + limited_dimension_ids = {dim.id for dim in group_limit.dimensions} + non_limited_dimensions = [ + dim for dim in dimensions if dim.id not in limited_dimension_ids + ] + + # Build the group condition for CASE expressions + if len(group_limit.dimensions) == 1: + dimension_id = self._quote(group_limit.dimensions[0].id) + group_condition = ( + f"{dimension_id} IN (SELECT {dimension_id} FROM top_groups)" + ) + else: + dimension_tuple = ", ".join( + self._quote(dim.id) for dim in group_limit.dimensions + ) + group_condition = ( + f"({dimension_tuple}) IN (SELECT {dimension_tuple} FROM top_groups)" + ) + + # Build CASE expressions for limited dimensions + case_expressions = [] + case_expressions_for_groupby = [] + for dim in group_limit.dimensions: + case_expr = self._build_case_expression(dim, group_condition) + alias = self._quote(dim.id) + case_expressions.append(f"{case_expr} AS {alias}") + # Store the full CASE expression for GROUP BY (not just alias) + case_expressions_for_groupby.append(case_expr) + + # Build SELECT for non-limited dimensions (pass through) + non_limited_selects = [ + f"{self._quote(dim.id)} AS {self._quote(dim.id)}" + for dim in non_limited_dimensions + ] + + # Build metric aggregations + metric_aggregations = [ + f"SUM({self._quote(metric.id)}) AS {self._quote(metric.id)}" + for metric in metrics + ] + + # Build the subquery that gets raw data from SEMANTIC_VIEW + dimension_arguments = ", ".join( + self._alias_element(dimension) for dimension in dimensions + ) + metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics) + + subquery = dedent( + f""" + raw_data AS ( + SELECT * FROM SEMANTIC_VIEW( + {self.uid()} + DIMENSIONS {dimension_arguments} + METRICS 
{metric_arguments} + {"WHERE " + where_clause if where_clause else ""} + ) + ) + """ + ).strip() + + # Build GROUP BY clause (full CASE expressions + non-limited dimensions) + # We need to repeat the full CASE expressions, not use aliases, because + # Snowflake may interpret the alias as the original column reference + group_by_columns = case_expressions_for_groupby + [ + self._quote(dim.id) for dim in non_limited_dimensions + ] + group_by_clause = ", ".join(group_by_columns) + + # Build final SELECT columns + select_columns = case_expressions + non_limited_selects + metric_aggregations + select_clause = ",\n ".join(select_columns) + + # Build ORDER BY clause (need to reference the aliased columns) + order_clause = self._build_order_clause(order) + + query = dedent( + f""" + {top_groups_cte}, + {subquery} + SELECT + {select_clause} + FROM raw_data + GROUP BY {group_by_clause} + {"ORDER BY " + order_clause if order_clause else ""} + {"LIMIT " + str(limit) if limit is not None else ""} + {"OFFSET " + str(offset) if offset is not None else ""} + """ + ).strip() + + return query, cte_params + + def _build_query_with_group_limit( + self, + metrics: list[Metric], + dimensions: list[Dimension], + where_clause: str, + order: list[OrderTuple] | None, + limit: int | None, + offset: int | None, + group_limit: GroupLimit, + ) -> tuple[str, tuple[FilterValues, ...]]: + """ + Build a query with group limiting (top N groups). + + If group_others is True, groups non-top values as 'Other'. + Otherwise, filters to show only top N groups. 
+ + Returns: + Tuple of (SQL query, CTE parameters) + """ + if group_limit.group_others: + return self._build_query_with_others( + metrics, + dimensions, + where_clause, + order, + limit, + offset, + group_limit, + ) + + # Standard group limiting: just filter to top N groups + # We can't use CTE references inside SEMANTIC_VIEW(), so we wrap it + dimension_arguments = ", ".join( + self._alias_element(dimension) for dimension in dimensions + ) + metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics) + order_clause = self._build_order_clause(order) + + top_groups_cte, cte_params = self._build_top_groups_cte( + group_limit, + where_clause, + ) + group_filter = self._build_group_filter(group_limit) + + query = dedent( + f""" + {top_groups_cte} + SELECT * FROM SEMANTIC_VIEW( + {self.uid()} + {"DIMENSIONS " + dimension_arguments if dimension_arguments else ""} + {"METRICS " + metric_arguments if metric_arguments else ""} + {"WHERE " + where_clause if where_clause else ""} + ) AS subquery + WHERE {group_filter} + {"ORDER BY " + order_clause if order_clause else ""} + {"LIMIT " + str(limit) if limit is not None else ""} + {"OFFSET " + str(offset) if offset is not None else ""} + """ + ).strip() + + return query, cte_params + + __repr__ = uid diff --git a/superset/semantic_layers/snowflake/utils.py b/superset/semantic_layers/snowflake/utils.py new file mode 100644 index 00000000000..76251c0288a --- /dev/null +++ b/superset/semantic_layers/snowflake/utils.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# ruff: noqa: S608 + +from __future__ import annotations + +from typing import Any, Sequence + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + +from superset.exceptions import SupersetParseError +from superset.semantic_layers.snowflake.schemas import ( + PrivateKeyAuth, + SnowflakeConfiguration, + UserPasswordAuth, +) +from superset.sql.parse import SQLStatement + + +def substitute_parameters(query: str, parameters: Sequence[Any] | None) -> str: + """ + Substitute parametereters in templated query. + + This is used to convert bind query parameters so that we can return the executed + query for logging/auditing purposes. With Snowflake the binding happens on the + server, so the only way to get the true executed query would be to query the + database, which is innefficient. + """ + if not parameters: + return query + + result = query + for parameter in parameters: + if parameter is None: + replacement = "NULL" + elif isinstance(parameter, bool): + # Check bool before int/float since bool is a subclass of int + replacement = str(parameter).upper() + elif isinstance(parameter, (int, float)): + replacement = str(parameter) + else: + # String - escape single quotes + quoted = str(parameter).replace("'", "''") + replacement = f"'{quoted}'" + + result = result.replace("?", replacement, 1) + + return result + + +def validate_order_by(definition: str) -> None: + """ + Validate that an ORDER BY expression is safe to use. 

    Note that `definition` could contain multiple expressions separated by commas.
    """
    try:
        # this ensures that we have a single statement, preventing SQL injection via a
        # semicolon in the order by clause
        SQLStatement(f"SELECT 1 ORDER BY {definition}", "snowflake")
    except SupersetParseError as ex:
        raise ValueError("Invalid ORDER BY expression") from ex


def get_connection_parameters(configuration: SnowflakeConfiguration) -> dict[str, Any]:
    """
    Convert the configuration to connection parameters for the Snowflake connector.

    Supports username/password and PEM private-key authentication; raises
    ``ValueError`` for any other auth type on the configuration.
    """
    params = {
        "account": configuration.account_identifier,
        "application": "Apache Superset",
        # qmark style matches the `?` placeholders built by the query helpers
        "paramstyle": "qmark",
        # NOTE(review): insecure_mode=True disables TLS certificate validation on
        # the Snowflake connection — confirm this is intentional (e.g. not a
        # leftover from local debugging) before shipping.
        "insecure_mode": True,
    }

    # optional session-context parameters are only sent when configured
    if configuration.role:
        params["role"] = configuration.role
    if configuration.warehouse:
        params["warehouse"] = configuration.warehouse
    if configuration.database:
        params["database"] = configuration.database
    if configuration.schema_:
        params["schema"] = configuration.schema_

    auth = configuration.auth
    if isinstance(auth, UserPasswordAuth):
        params["user"] = auth.username
        params["password"] = auth.password.get_secret_value()
    elif isinstance(auth, PrivateKeyAuth):
        # load the PEM key (decrypting with the key password if one was given)
        # and re-serialize it as unencrypted DER/PKCS8, the in-memory format the
        # Snowflake connector expects for `private_key`
        pem_private_key = serialization.load_pem_private_key(
            auth.private_key.get_secret_value().encode(),
            password=(
                auth.private_key_password.get_secret_value().encode()
                if auth.private_key_password
                else None
            ),
            backend=default_backend(),
        )
        params["private_key"] = pem_private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
    else:
        raise ValueError("Unsupported authentication method")

    return params
diff --git a/superset/semantic_layers/types.py b/superset/semantic_layers/types.py
new file mode 100644
index 00000000000..826ee50666c
--- /dev/null
+++ b/superset/semantic_layers/types.py
@@ -0,0 +1,497 @@
# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import enum +from dataclasses import dataclass +from datetime import date, datetime, time, timedelta +from functools import total_ordering +from typing import Any, Protocol, runtime_checkable, TypeVar + +from pandas import DataFrame +from pydantic import BaseModel + +__all__ = [ + "BINARY", + "BOOLEAN", + "DATE", + "DATETIME", + "DECIMAL", + "Day", + "Dimension", + "Hour", + "INTEGER", + "INTERVAL", + "Minute", + "Month", + "NUMBER", + "OBJECT", + "Quarter", + "Second", + "STRING", + "TIME", + "Week", + "Year", +] + + +class Type: + """ + Base class for types. + """ + + +class INTEGER(Type): + """ + Represents an integer type. + """ + + +class NUMBER(Type): + """ + Represents a number type. + """ + + +class DECIMAL(Type): + """ + Represents a decimal type. + """ + + +class STRING(Type): + """ + Represents a string type. + """ + + +class BOOLEAN(Type): + """ + Represents a boolean type. + """ + + +class DATE(Type): + """ + Represents a date type. + """ + + +class TIME(Type): + """ + Represents a time type. + """ + + +class DATETIME(DATE, TIME): + """ + Represents a datetime type. + """ + + +class INTERVAL(Type): + """ + Represents an interval type. 
+ """ + + +class OBJECT(Type): + """ + Represents an object type. + """ + + +class BINARY(Type): + """ + Represents a binary type. + """ + + +@dataclass(frozen=True) +@total_ordering +class Grain: + """ + Base class for time and date grains with comparison support. + + Attributes: + name: Human-readable name of the grain (e.g., "Second") + representation: ISO 8601 representation (e.g., "PT1S") + value: Time period as a timedelta + """ + + name: str + representation: str + value: timedelta + + def __eq__(self, other: object) -> bool: + if isinstance(other, Grain): + return self.value == other.value + return NotImplemented + + def __lt__(self, other: object) -> bool: + if isinstance(other, Grain): + return self.value < other.value + return NotImplemented + + def __hash__(self) -> int: + return hash((self.name, self.representation, self.value)) + + +class Second(Grain): + name = "Second" + representation = "PT1S" + value = timedelta(seconds=1) + + +class Minute(Grain): + name = "Minute" + representation = "PT1M" + value = timedelta(minutes=1) + + +class Hour(Grain): + name = "Hour" + representation = "PT1H" + value = timedelta(hours=1) + + +class Day(Grain): + name = "Day" + representation = "P1D" + value = timedelta(days=1) + + +class Week(Grain): + name = "Week" + representation = "P1W" + value = timedelta(weeks=1) + + +class Month(Grain): + name = "Month" + representation = "P1M" + value = timedelta(days=30) + + +class Quarter(Grain): + name = "Quarter" + representation = "P3M" + value = timedelta(days=90) + + +class Year(Grain): + name = "Year" + representation = "P1Y" + value = timedelta(days=365) + + +@dataclass(frozen=True) +class Dimension: + id: str + name: str + type: type[Type] + + definition: str | None = None + description: str | None = None + grain: Grain | None = None + + +@dataclass(frozen=True) +class Metric: + id: str + name: str + type: type[Type] + + definition: str | None + description: str | None = None + + +@dataclass(frozen=True) +class 
AdhocExpression: + id: str + definition: str + + +class Operator(str, enum.Enum): + EQUALS = "=" + NOT_EQUALS = "!=" + GREATER_THAN = ">" + LESS_THAN = "<" + GREATER_THAN_OR_EQUAL = ">=" + LESS_THAN_OR_EQUAL = "<=" + IN = "IN" + NOT_IN = "NOT IN" + LIKE = "LIKE" + NOT_LIKE = "NOT LIKE" + IS_NULL = "IS NULL" + IS_NOT_NULL = "IS NOT NULL" + + +FilterValues = str | int | float | bool | datetime | date | time | timedelta | None + + +class PredicateType(enum.Enum): + WHERE = "WHERE" + HAVING = "HAVING" + + +@dataclass(frozen=True, order=True) +class Filter: + type: PredicateType + column: Dimension | Metric + operator: Operator + value: FilterValues | set[FilterValues] + + +@dataclass(frozen=True, order=True) +class AdhocFilter: + type: PredicateType + definition: str + + +class OrderDirection(enum.Enum): + ASC = "ASC" + DESC = "DESC" + + +OrderTuple = tuple[Metric | Dimension | AdhocExpression, OrderDirection] + + +@dataclass(frozen=True) +class GroupLimit: + """ + Limit query to top/bottom N combinations of specified dimensions. + + The `filters` parameter allows specifying separate filter constraints for the + group limit subquery. This is useful when you want to determine the top N groups + using different criteria (e.g., a different time range) than the main query. + + For example, you might want to find the top 10 products by sales over the last + 30 days, but then show daily sales for those products over the last 7 days. + """ + + dimensions: list[Dimension] + top: int + metric: Metric | None + direction: OrderDirection = OrderDirection.DESC + group_others: bool = False + filters: set[Filter | AdhocFilter] | None = None + + +@dataclass(frozen=True) +class SemanticRequest: + """ + Represents a request made to obtain semantic results. + + This could be a SQL query, an HTTP request, etc. + """ + + type: str + definition: str + + +@dataclass(frozen=True) +class SemanticResult: + """ + Represents the results of a semantic query. 
    This includes any requests (SQL queries, HTTP requests) that were performed in order
    to obtain the results, in order to help troubleshooting.
    """

    # The requests (e.g. SQL statements) that were executed to obtain the results.
    requests: list[SemanticRequest]
    # The tabular payload of the semantic query.
    results: DataFrame


@dataclass(frozen=True)
class SemanticQuery:
    """
    Represents a semantic query.

    A declarative description of what to fetch from a semantic view: which
    metrics and dimensions, plus optional predicates, ordering, pagination,
    and a top/bottom-N group constraint.
    """

    metrics: list[Metric]
    dimensions: list[Dimension]
    filters: set[Filter | AdhocFilter] | None = None
    order: list[OrderTuple] | None = None
    limit: int | None = None
    offset: int | None = None
    # Optional top/bottom-N constraint on dimension combinations.
    group_limit: GroupLimit | None = None


class SemanticViewFeature(enum.Enum):
    """
    Custom features supported by semantic layers.

    Implementations advertise these via ``SemanticViewImplementation.features``
    so callers can adapt queries to each backend's capabilities.
    """

    # Supports ad-hoc expressions in the ORDER BY clause.
    ADHOC_EXPRESSIONS_IN_ORDERBY = "ADHOC_EXPRESSIONS_IN_ORDERBY"
    # Supports limiting results to the top/bottom N dimension groups.
    GROUP_LIMIT = "GROUP_LIMIT"
    # Supports collapsing non-top groups into an "others" bucket.
    GROUP_OTHERS = "GROUP_OTHERS"


# Configuration model type for a given semantic layer implementation.
# NOTE(review): declared contravariant, yet it appears in the return type of
# ``from_configuration`` below — confirm the variance with a type checker.
ConfigT = TypeVar("ConfigT", bound=BaseModel, contravariant=True)
# The concrete semantic-view type produced by a given layer implementation.
SemanticViewT = TypeVar("SemanticViewT", bound="SemanticViewImplementation")


@runtime_checkable
class SemanticLayerImplementation(Protocol[ConfigT, SemanticViewT]):
    """
    A protocol for semantic layers.

    A semantic layer is a factory for semantic views: it is created from a
    stored configuration and exposes the views available at query time.
    """

    @classmethod
    def from_configuration(
        cls,
        configuration: dict[str, Any],
    ) -> SemanticLayerImplementation[ConfigT, SemanticViewT]:
        """
        Create a semantic layer from its configuration.
        """

    @classmethod
    def get_configuration_schema(
        cls,
        configuration: ConfigT | None = None,
    ) -> dict[str, Any]:
        """
        Get the JSON schema for the configuration needed to add the semantic layer.

        A partial configuration `configuration` can be sent to improve the schema,
        allowing for progressive validation and better UX. For example, a semantic
        layer might require:

        - auth information
        - a database

        If the user provides the auth information, a client can send the partial
        configuration to this method, and the resulting JSON schema would include
        the list of databases the user has access to, allowing a dropdown to be
        populated.

        The Snowflake semantic layer has an example implementation of this method, where
        database and schema names are populated based on the provided connection info.
        """

    @classmethod
    def get_runtime_schema(
        cls,
        configuration: ConfigT,
        runtime_data: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Get the JSON schema for the runtime parameters needed to load semantic views.

        This returns the schema needed to connect to a semantic view given the
        configuration for the semantic layer. For example, a semantic layer might
        be configured by:

        - auth information
        - an optional database

        If the user does not provide a database when creating the semantic layer, the
        runtime schema would require the database name to be provided before loading any
        semantic views. This allows users to create semantic layers that connect to a
        specific database (or project, account, etc.), or that allow users to select it
        at query time.

        The Snowflake semantic layer has an example implementation of this method, where
        database and schema names are required if they were not provided in the initial
        configuration.
        """

    def get_semantic_views(
        self,
        runtime_configuration: dict[str, Any],
    ) -> set[SemanticViewT]:
        """
        Get the semantic views available in the semantic layer.

        The runtime configuration can provide information like a given project or
        schema, used to restrict the semantic views returned.
        """

    def get_semantic_view(
        self,
        name: str,
        additional_configuration: dict[str, Any],
    ) -> SemanticViewT:
        """
        Get a specific semantic view by its name and additional configuration.
        """


@runtime_checkable
class SemanticViewImplementation(Protocol):
    """
    A protocol for semantic views.

    A semantic view exposes metrics and dimensions and can execute semantic
    queries, returning results together with the requests performed.
    """

    # The optional features (see ``SemanticViewFeature``) this view supports.
    features: frozenset[SemanticViewFeature]

    def uid(self) -> str:
        """
        Returns a unique identifier for the semantic view.
        """

    def get_dimensions(self) -> set[Dimension]:
        """
        Get the dimensions defined in the semantic view.
        """

    def get_metrics(self) -> set[Metric]:
        """
        Get the metrics defined in the semantic view.
        """

    def get_values(
        self,
        dimension: Dimension,
        filters: set[Filter | AdhocFilter] | None = None,
    ) -> SemanticResult:
        """
        Return distinct values for a dimension.
        """

    def get_dataframe(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        filters: set[Filter | AdhocFilter] | None = None,
        order: list[OrderTuple] | None = None,
        limit: int | None = None,
        offset: int | None = None,
        *,
        group_limit: GroupLimit | None = None,
    ) -> SemanticResult:
        """
        Execute a semantic query and return the results as a DataFrame.
        """

    def get_row_count(
        self,
        metrics: list[Metric],
        dimensions: list[Dimension],
        filters: set[Filter | AdhocFilter] | None = None,
        order: list[OrderTuple] | None = None,
        limit: int | None = None,
        offset: int | None = None,
        *,
        group_limit: GroupLimit | None = None,
    ) -> SemanticResult:
        """
        Execute a query and return the number of rows the result would have.
+ """ diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 5841dc245ee..45e122b2ae0 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -57,6 +57,46 @@ class AdhocMetric(TypedDict, total=False): sqlExpression: str | None +class DatasetColumnData(TypedDict, total=False): + """Type for column metadata in ExplorableData datasets.""" + + advanced_data_type: str | None + certification_details: str | None + certified_by: str | None + column_name: str + description: str | None + expression: str | None + filterable: bool + groupby: bool + id: int + uuid: str | None + is_certified: bool + is_dttm: bool + python_date_format: str | None + type: str + type_generic: NotRequired["GenericDataType" | None] + verbose_name: str | None + warning_markdown: str | None + + +class DatasetMetricData(TypedDict, total=False): + """Type for metric metadata in ExplorableData datasets.""" + + certification_details: str | None + certified_by: str | None + currency: NotRequired[dict[str, Any]] + d3format: str | None + description: str | None + expression: str + id: int + uuid: str | None + is_certified: bool + metric_name: str + warning_markdown: str | None + warning_text: str | None + verbose_name: str | None + + class AdhocColumn(TypedDict, total=False): hasCustomLabel: bool | None label: str @@ -272,8 +312,8 @@ class ExplorableData(TypedDict, total=False): perm: str | None edit_url: str sql: str | None - columns: list[dict[str, Any]] - metrics: list[dict[str, Any]] + columns: list["DatasetColumnData"] + metrics: list["DatasetMetricData"] folders: Any # JSON field, can be list or dict order_by_choices: list[tuple[str, str]] owners: list[int] | list[dict[str, Any]] # Can be either format @@ -281,8 +321,8 @@ class ExplorableData(TypedDict, total=False): select_star: str | None # Additional fields from SqlaTable and data_for_slices - column_types: list[Any] - column_names: set[str] | set[Any] + column_types: list["GenericDataType"] + column_names: 
set[str] granularity_sqla: list[tuple[Any, Any]] time_grain_sqla: list[tuple[Any, Any]] main_dttm_col: str | None