mirror of
https://github.com/apache/superset.git
synced 2026-06-10 10:09:14 +00:00
Compare commits
4 Commits
ci/cypress
...
builtin-sl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a749e3f31 | ||
|
|
439bcf8d79 | ||
|
|
ce7753b631 | ||
|
|
b9492f477b |
@@ -2529,6 +2529,13 @@ except ImportError:
|
||||
LOCAL_EXTENSIONS: list[str] = []
|
||||
EXTENSIONS_PATH: str | None = None
|
||||
|
||||
# When True, dataset queries are routed through the dataset semantic-layer
|
||||
# extension (``superset/semantic_layers/extension``) instead of the legacy
|
||||
# ``get_sqla_query`` path. The semantic view builds the SQL via sqlglot and
|
||||
# the mapper handles the QueryObject → SemanticQuery translation. Falls back
|
||||
# to the legacy path on any error.
|
||||
USE_DATASET_SEMANTIC_VIEW: bool = False
|
||||
|
||||
# Default polling interval for tasks (seconds)
|
||||
TASK_ABORT_POLLING_DEFAULT_INTERVAL = 10
|
||||
|
||||
|
||||
@@ -19,10 +19,12 @@ from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
@@ -71,6 +73,7 @@ from superset_core.common.models import Dataset as CoreDataset
|
||||
from superset import db, is_feature_enabled, security_manager
|
||||
from superset.commands.dataset.exceptions import DatasetNotFoundError
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.connectors.sqla.utils import (
|
||||
get_columns_description,
|
||||
get_physical_table_metadata,
|
||||
@@ -2139,6 +2142,67 @@ class SqlaTable(
|
||||
"""Returns a text clause using ExploreMixin implementation"""
|
||||
return ExploreMixin.text(self, clause)
|
||||
|
||||
@property
def implementation(self) -> Any:
    """
    Return this dataset wrapped in a ``DatasetSemanticView``.

    The view class ships with the dataset semantic-layer extension, whose
    backend sources live under ``semantic_layers/extension/backend/src``
    (resolved relative to this file's grandparent directory) rather than on
    the regular import path. That directory is prepended to ``sys.path``,
    unless it is already listed, before the view class is imported. The
    same execution path used for stored semantic views can then serve
    plain dataset queries when ``USE_DATASET_SEMANTIC_VIEW`` is enabled.

    :returns: a new ``DatasetSemanticView`` built around ``self``
    """
    # Per the original author's note, the ``@semantic_layer`` decorator is
    # expected to have been patched by
    # ``inject_semantic_layer_implementations`` at app init, which is what
    # makes this deferred import safe at runtime.
    backend_src = str(
        Path(__file__)
        .resolve()
        .parents[1]
        .joinpath("semantic_layers", "extension", "backend", "src")
    )
    if backend_src not in sys.path:
        sys.path.insert(0, backend_src)

    from preset_io.dataset_semantic_layer import DatasetSemanticView

    return DatasetSemanticView(self)
|
||||
|
||||
def get_query_result(self, query_object: QueryObject) -> QueryResult:
    """
    Execute ``query_object`` against this dataset.

    With ``USE_DATASET_SEMANTIC_VIEW`` disabled the call goes straight to
    ``ExploreMixin.get_query_result``. With it enabled, the query is handed
    to ``superset.semantic_layers.mapper.get_results``; if that raises
    anything, a warning (with traceback) is logged and the legacy
    ``ExploreMixin`` path is used instead.

    :param query_object: the query to run; its ``datasource`` is set to
        ``self`` when it is ``None`` and the semantic path is taken
    :returns: the query result from whichever path completed
    """
    if not current_app.config.get("USE_DATASET_SEMANTIC_VIEW"):
        return ExploreMixin.get_query_result(self, query_object)

    from superset.semantic_layers.mapper import get_results

    try:
        # The mapper reads ``query_object.datasource.implementation``, so the
        # in-memory QueryObject must point back at this dataset.
        if query_object.datasource is None:
            query_object.datasource = self
        return get_results(query_object)
    except Exception:  # pylint: disable=broad-except
        # Deliberate best-effort: any failure falls through to the legacy path.
        logger.warning(
            "Semantic-view execution failed for dataset %s; falling "
            "back to the legacy query path.",
            self.table_name,
            exc_info=True,
        )

    return ExploreMixin.get_query_result(self, query_object)
|
||||
|
||||
|
||||
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
|
||||
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
|
||||
|
||||
197
superset/semantic_layers/adhoc.py
Normal file
197
superset/semantic_layers/adhoc.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Convert frontend "adhoc" metric / column dictionaries into semantic-layer
|
||||
``Metric`` / ``Dimension`` objects.
|
||||
|
||||
Adhoc metrics and columns are user-defined SQL expressions that don't
|
||||
correspond to a saved metric or physical column on the dataset. The legacy
|
||||
``get_sqla_query`` path supports them through ``adhoc_metric_to_sqla`` /
|
||||
``adhoc_column_to_sqla``; the semantic-layer path needs an equivalent so
|
||||
that flipping ``USE_DATASET_SEMANTIC_VIEW`` doesn't regress charts that
|
||||
use them.
|
||||
|
||||
The strategy is light: synthesize a ``Metric(definition=<sql>)`` /
|
||||
``Dimension(definition=<sql>)`` and let the view's existing
|
||||
``definition``-parsing path handle the rest. Jinja templating + safety
|
||||
checks are delegated to the dataset's own
|
||||
``_process_select_expression`` (already battle-tested).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import pyarrow as pa
|
||||
from superset_core.semantic_layers.types import (
|
||||
AggregationType,
|
||||
Dimension,
|
||||
Metric,
|
||||
)
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.semantic_layers.arrow_inference import infer_arrow_type
|
||||
from superset.utils.core import AdhocMetricExpressionType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.jinja_context import BaseTemplateProcessor
|
||||
|
||||
|
||||
# Map Superset's aggregate strings to the semantic-layer ``AggregationType``.
# Keys are upper-case because callers upper-case the incoming aggregate before
# the lookup. Anything unmapped (median, percentile, …) falls back to OTHER,
# which the view treats as a non-rollup-safe metric.
_AGGREGATE_MAP: dict[str, AggregationType] = {
    "COUNT": AggregationType.COUNT,
    "SUM": AggregationType.SUM,
    "AVG": AggregationType.AVG,
    "MIN": AggregationType.MIN,
    "MAX": AggregationType.MAX,
    "COUNT_DISTINCT": AggregationType.COUNT_DISTINCT,
}
|
||||
|
||||
|
||||
def _process_adhoc_sql(
|
||||
dataset: "SqlaTable",
|
||||
sql_expression: str,
|
||||
template_processor: "BaseTemplateProcessor | None",
|
||||
) -> str:
|
||||
"""
|
||||
Run Jinja + safety validation on a free-form SQL expression.
|
||||
|
||||
Delegates to the dataset's ``_process_select_expression`` so adhoc
|
||||
columns and metrics go through the same Jinja rendering, subquery
|
||||
rules, and clause sanitization as the legacy path.
|
||||
"""
|
||||
processed = dataset._process_select_expression( # noqa: SLF001
|
||||
expression=sql_expression,
|
||||
database_id=dataset.database_id,
|
||||
engine=dataset.database.db_engine_spec.engine,
|
||||
schema=dataset.schema or "",
|
||||
template_processor=template_processor,
|
||||
)
|
||||
if not processed:
|
||||
raise QueryObjectValidationError(
|
||||
"Adhoc SQL expression resolved to an empty string."
|
||||
)
|
||||
return processed
|
||||
|
||||
|
||||
def _simple_metric_definition(adhoc: dict[str, Any], dataset: "SqlaTable") -> str:
|
||||
"""
|
||||
Build the SQL for a SIMPLE adhoc metric: ``AGGREGATE(column)``.
|
||||
|
||||
Uses the database's identifier quoter so the column name is rendered
|
||||
in the right dialect.
|
||||
"""
|
||||
aggregate = (adhoc.get("aggregate") or "").upper()
|
||||
column = adhoc.get("column") or {}
|
||||
column_name = column.get("column_name")
|
||||
if not aggregate or not column_name:
|
||||
raise QueryObjectValidationError(
|
||||
"SIMPLE adhoc metric requires both an aggregate and a column."
|
||||
)
|
||||
if aggregate == "COUNT_DISTINCT":
|
||||
return f"COUNT(DISTINCT {dataset.quote_identifier(column_name)})"
|
||||
return f"{aggregate}({dataset.quote_identifier(column_name)})"
|
||||
|
||||
|
||||
def adhoc_metric_to_semantic_metric(
    adhoc: dict[str, Any],
    dataset: "SqlaTable",
    template_processor: "BaseTemplateProcessor | None" = None,
) -> Metric:
    """
    Convert an ``AdhocMetric`` dict to a semantic-layer ``Metric``.

    * SIMPLE metrics get their SQL from ``_simple_metric_definition`` and
      are tagged with the matching ``AggregationType`` (``OTHER`` when the
      aggregate is not in ``_AGGREGATE_MAP``).
    * SQL metrics use the ``sqlExpression`` after it has gone through
      ``_process_adhoc_sql`` and are always tagged ``AggregationType.OTHER``
      because the expression is opaque.

    The metric's ``label`` doubles as both id and name, and its Arrow type
    is whatever ``infer_arrow_type`` derives from the final definition and
    the dataset's engine name.

    :param adhoc: the adhoc metric dict sent by the frontend
    :param dataset: dataset the metric is evaluated against
    :param template_processor: optional template processor for SQL metrics
    :returns: the synthesized ``Metric``
    :raises QueryObjectValidationError: if the label or ``sqlExpression`` is
        missing, or the ``expressionType`` is not recognised
    """
    label = adhoc.get("label")
    if not label:
        raise QueryObjectValidationError("Adhoc metric is missing a ``label``.")

    expression_type = adhoc.get("expressionType")
    if expression_type == AdhocMetricExpressionType.SIMPLE.value:
        # SIMPLE metrics are assembled from an aggregate and a column name,
        # so the template/safety pass used for SQL metrics is not run here.
        definition = _simple_metric_definition(adhoc, dataset)
        aggregation = _AGGREGATE_MAP.get(
            (adhoc.get("aggregate") or "").upper(),
            AggregationType.OTHER,
        )
    elif expression_type == AdhocMetricExpressionType.SQL.value:
        raw_sql = adhoc.get("sqlExpression")
        if not raw_sql:
            raise QueryObjectValidationError(
                "SQL adhoc metric is missing ``sqlExpression``."
            )
        definition = _process_adhoc_sql(dataset, raw_sql, template_processor)
        aggregation = AggregationType.OTHER
    else:
        raise QueryObjectValidationError(
            f"Unknown adhoc metric expressionType: {expression_type!r}"
        )

    arrow_type = infer_arrow_type(definition, dataset.database.db_engine_spec.engine)
    return Metric(
        id=label,
        name=label,
        type=arrow_type,
        definition=definition,
        aggregation=aggregation,
    )
|
||||
|
||||
|
||||
def adhoc_column_to_semantic_dimension(
    adhoc: dict[str, Any],
    dataset: "SqlaTable",
    dimensions_by_name: dict[str, Dimension],
    template_processor: "BaseTemplateProcessor | None" = None,
) -> Dimension:
    """
    Convert an ``AdhocColumn`` dict to a semantic-layer ``Dimension``.

    If the adhoc column is flagged ``isColumnReference`` and its
    ``sqlExpression`` is a key of ``dimensions_by_name``, the already-known
    ``Dimension`` is returned unchanged so its metadata is kept. In every
    other case a new ``Dimension`` is built whose definition is the
    processed SQL and whose id and name are the adhoc ``label``.

    :param adhoc: the adhoc column dict sent by the frontend
    :param dataset: dataset the column is evaluated against
    :param dimensions_by_name: existing dimensions keyed by name
    :param template_processor: optional template processor for the SQL
    :returns: the matching or newly synthesized ``Dimension``
    :raises QueryObjectValidationError: if ``label`` or ``sqlExpression`` is
        missing
    """
    label = adhoc.get("label")
    if not label:
        raise QueryObjectValidationError("Adhoc column is missing a ``label``.")

    raw_sql = adhoc.get("sqlExpression")
    if not raw_sql:
        raise QueryObjectValidationError(
            "Adhoc column is missing ``sqlExpression``."
        )

    # Thin wrapper around a real column: reuse the existing dimension.
    if adhoc.get("isColumnReference"):
        if raw_sql in dimensions_by_name:
            return dimensions_by_name[raw_sql]

    definition = _process_adhoc_sql(dataset, raw_sql, template_processor)
    arrow_type = infer_arrow_type(definition, dataset.database.db_engine_spec.engine)
    return Dimension(
        id=label,
        name=label,
        type=arrow_type,
        definition=definition,
    )
|
||||
37
superset/semantic_layers/extension/.gitignore
vendored
Normal file
37
superset/semantic_layers/extension/.gitignore
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
# Dependencies
|
||||
node_modules/
|
||||
|
||||
# Build outputs from the frontend bundler. The top-level dist/ is checked in
|
||||
# because this extension ships as an in-built bundle loaded via LOCAL_EXTENSIONS.
|
||||
frontend/dist/
|
||||
*.supx
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.egg-info/
|
||||
.eggs/
|
||||
*.egg
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
4
superset/semantic_layers/extension/backend/.coveragerc
Normal file
4
superset/semantic_layers/extension/backend/.coveragerc
Normal file
@@ -0,0 +1,4 @@
|
||||
[report]
|
||||
exclude_lines =
|
||||
pragma: no cover
|
||||
if TYPE_CHECKING:
|
||||
21
superset/semantic_layers/extension/backend/conftest.py
Normal file
21
superset/semantic_layers/extension/backend/conftest.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Root conftest that patches the semantic_layer decorator before test collection.
|
||||
|
||||
The decorator in superset_core.semantic_layers.decorators is a placeholder that
|
||||
raises NotImplementedError outside of a running Superset instance. We replace it
|
||||
with a no-op passthrough so the decorated classes can be imported during tests.
|
||||
"""
|
||||
|
||||
import superset_core.semantic_layers.decorators as _dec
|
||||
|
||||
|
||||
def _noop_semantic_layer(**kwargs):
|
||||
"""Return the class unchanged."""
|
||||
|
||||
def wrapper(cls):
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
_dec.semantic_layer = _noop_semantic_layer
|
||||
13
superset/semantic_layers/extension/backend/pyproject.toml
Normal file
13
superset/semantic_layers/extension/backend/pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[project]
|
||||
name = "preset_io-dataset_semantic_layer"
|
||||
version = "0.1.0"
|
||||
license = "All Rights Reserved"
|
||||
dependencies = [
|
||||
"sqlglot",
|
||||
]
|
||||
|
||||
[tool.apache_superset_extensions.build]
|
||||
include = [
|
||||
"src/preset_io/dataset_semantic_layer/**/*.py",
|
||||
]
|
||||
exclude = []
|
||||
@@ -0,0 +1,9 @@
|
||||
from .layer import DatasetSemanticLayer
|
||||
from .schemas import DatasetConfiguration
|
||||
from .view import DatasetSemanticView
|
||||
|
||||
__all__ = [
|
||||
"DatasetConfiguration",
|
||||
"DatasetSemanticLayer",
|
||||
"DatasetSemanticView",
|
||||
]
|
||||
@@ -0,0 +1,3 @@
|
||||
# The @semantic_layer decorator on DatasetSemanticLayer handles registration
|
||||
# automatically. This import triggers the decorator at extension load time.
|
||||
from .layer import DatasetSemanticLayer # noqa: F401
|
||||
@@ -0,0 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from superset_core.semantic_layers.config import build_configuration_schema
|
||||
from superset_core.semantic_layers.decorators import semantic_layer
|
||||
from superset_core.semantic_layers.layer import SemanticLayer
|
||||
|
||||
from .schemas import DatasetConfiguration
|
||||
from .utils import dataset_label, get_dataset_by_id, list_datasets
|
||||
from .view import DatasetSemanticView
|
||||
|
||||
|
||||
@semantic_layer(
    id="dataset",
    name="Dataset Semantic Layer",
    description=(
        "Expose any Superset dataset as a semantic view. Metrics and columns "
        "defined on the dataset become the semantic metrics and dimensions."
    ),
)
class DatasetSemanticLayer(SemanticLayer[DatasetConfiguration, DatasetSemanticView]):
    """
    Semantic layer whose views are existing Superset datasets.

    The layer itself carries no meaningful configuration; a concrete view is
    selected through the ``dataset_id`` entry of the runtime configuration.
    """

    configuration_class = DatasetConfiguration

    @classmethod
    def from_configuration(
        cls,
        configuration: dict[str, Any],
    ) -> "DatasetSemanticLayer":
        """Validate ``configuration`` (an empty dict if falsy) and build the layer."""
        validated = DatasetConfiguration.model_validate(configuration or {})
        return cls(validated)

    @classmethod
    def get_configuration_schema(
        cls,
        configuration: DatasetConfiguration | None = None,
    ) -> dict[str, Any]:
        """
        Return the schema produced by ``build_configuration_schema`` for
        ``DatasetConfiguration``. No configuration is needed to add this
        semantic layer; it wraps whatever datasets Superset already has.
        """
        return build_configuration_schema(DatasetConfiguration, configuration)

    @classmethod
    def get_runtime_schema(
        cls,
        configuration: DatasetConfiguration,
        runtime_data: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Describe the form shown when a semantic view is added: a single
        required ``dataset_id`` dropdown listing the datasets returned by
        ``list_datasets`` (ids as strings, labels from ``dataset_label``).
        """
        available = list_datasets()
        dataset_property = {
            "type": "string",
            "title": "Dataset",
            "description": "The Superset dataset to expose as a semantic view.",
            "enum": [str(item.id) for item in available],
            "x-enumNames": [dataset_label(item) for item in available],
        }
        return {
            "type": "object",
            "required": ["dataset_id"],
            "x-singleView": True,
            "properties": {"dataset_id": dataset_property},
        }

    def __init__(self, configuration: DatasetConfiguration) -> None:
        """Store the validated configuration on the instance."""
        self.configuration = configuration

    def get_semantic_views(
        self,
        runtime_configuration: dict[str, Any],
    ) -> set[DatasetSemanticView]:
        """
        Return a one-element set with the view for the configured
        ``dataset_id``, or an empty set when no id is provided.
        """
        raw_id = runtime_configuration.get("dataset_id")
        if raw_id:
            return {DatasetSemanticView(get_dataset_by_id(int(raw_id)))}
        return set()

    def get_semantic_view(
        self,
        name: str,
        additional_configuration: dict[str, Any],
    ) -> DatasetSemanticView:
        """
        Load the view for ``additional_configuration["dataset_id"]``.

        ``name`` is not consulted; the dataset id alone selects the view.

        :raises ValueError: if ``dataset_id`` is missing or falsy
        """
        raw_id = additional_configuration.get("dataset_id")
        if not raw_id:
            raise ValueError("dataset_id is required to load a semantic view")
        return DatasetSemanticView(get_dataset_by_id(int(raw_id)))
|
||||
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class DatasetConfiguration(BaseModel):
    """
    Configuration for the dataset semantic layer.

    The layer wraps Superset datasets directly, so it does not need any
    configuration beyond what Superset already knows about each dataset.
    No fields are declared on this model.
    """

    # ``extra="ignore"`` presumably makes validation drop unknown keys instead
    # of rejecting them — confirm against the pydantic ``ConfigDict`` docs.
    model_config = ConfigDict(title="Dataset semantic layer", extra="ignore")
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
|
||||
def list_datasets() -> list["SqlaTable"]:
    """
    Return every ``SqlaTable`` row from the session, sorted by table name.

    NOTE(review): the query applies no per-user access filter; if the result
    is shown to end users, confirm that permissions are enforced elsewhere.

    Imports happen inside the function so the module can be imported in unit
    tests that have no Flask app context; the helper is only called at
    runtime.
    """
    from superset import db
    from superset.connectors.sqla.models import SqlaTable

    # ``no_autoflush`` keeps SQLAlchemy from flushing pending session state
    # while results are loaded; per the original note, flushing during
    # iteration has been seen to raise "deque mutated during iteration".
    with db.session.no_autoflush:
        rows = db.session.query(SqlaTable).all()

    rows.sort(key=lambda row: row.table_name)
    return rows
|
||||
|
||||
|
||||
def get_dataset_by_id(dataset_id: int) -> "SqlaTable":
    """
    Fetch one dataset by primary key with its columns and metrics loaded.

    Both relationships are requested with ``selectinload`` and then iterated
    once while autoflush is still disabled, so later code can walk them
    without triggering a lazy load.

    :param dataset_id: primary key of the dataset
    :returns: the matching ``SqlaTable``
    :raises ValueError: if no dataset has that id
    """
    from sqlalchemy.orm import selectinload

    from superset import db
    from superset.connectors.sqla.models import SqlaTable

    with db.session.no_autoflush:
        query = (
            db.session.query(SqlaTable)
            .options(
                selectinload(SqlaTable.columns),
                selectinload(SqlaTable.metrics),
            )
            .filter_by(id=dataset_id)
        )
        found = query.one_or_none()
        if found is None:
            raise ValueError(f"Dataset with id {dataset_id} does not exist.")

        # Materialise both collections before leaving the no_autoflush block.
        list(found.columns)
        list(found.metrics)

    return found
|
||||
|
||||
|
||||
def dataset_label(dataset: "SqlaTable") -> str:
    """
    Build a human-readable, dot-separated label for a dataset.

    The label joins the dataset's catalog, schema and table name in that
    order, skipping any part that is empty or ``None``.

    :param dataset: object exposing ``catalog``, ``schema`` and ``table_name``
    :returns: the joined label; an empty string when every part is missing,
        so the declared ``str`` return type always holds
    """
    return ".".join(
        part
        for part in (dataset.catalog, dataset.schema, dataset.table_name)
        if part
    )
|
||||
|
||||
|
||||
def df_to_arrow(df: "pd.DataFrame") -> pa.Table:
    """
    Convert a pandas DataFrame to a pyarrow Table.

    ``None`` yields a table built from an empty mapping, an empty DataFrame
    yields a table with one empty column per DataFrame column, and anything
    else goes through ``pa.Table.from_pandas`` without the index.
    """
    if df is None:
        return pa.table({})
    if df.empty:
        return pa.table({col: [] for col in df.columns})
    return pa.Table.from_pandas(df, preserve_index=False)
|
||||
|
||||
|
||||
def coerce_literal(value: Any) -> Any:
    """
    Turn set-like filter literals into plain lists.

    ``set`` and ``frozenset`` values become lists whose elements are coerced
    recursively; every other value is returned untouched.
    """
    if not isinstance(value, (set, frozenset)):
        return value
    return list(map(coerce_literal, value))
|
||||
@@ -0,0 +1,840 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pyarrow as pa
|
||||
import sqlglot
|
||||
from sqlglot import expressions as sqlglot_exp
|
||||
from superset_core.semantic_layers.types import (
|
||||
AdhocExpression,
|
||||
AggregationType,
|
||||
Dimension,
|
||||
Filter,
|
||||
GroupLimit,
|
||||
Metric,
|
||||
Operator,
|
||||
OrderDirection,
|
||||
PredicateType,
|
||||
SemanticQuery,
|
||||
SemanticRequest,
|
||||
SemanticResult,
|
||||
)
|
||||
from superset_core.semantic_layers.view import SemanticView, SemanticViewFeature
|
||||
|
||||
from .utils import coerce_literal, df_to_arrow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
|
||||
REQUEST_TYPE = "sql"
|
||||
|
||||
# Map Superset's GenericDataType (an IntEnum value) to a pyarrow type. Kept
# narrow on purpose — the dataset model only tracks a coarse generic type, so
# anything more specific belongs to the underlying database and is left as
# UTF-8 at the semantic layer boundary. Keys are the integer enum values;
# lookups that miss this table fall back to ``pa.utf8()`` at the call site.
_GENERIC_TYPE_TO_ARROW: dict[int, pa.DataType] = {
    0: pa.float64(),  # NUMERIC
    1: pa.utf8(),  # STRING
    2: pa.timestamp("us"),  # TEMPORAL
    3: pa.bool_(),  # BOOLEAN
}
|
||||
|
||||
# Top-level aggregate AST node → AggregationType. Compound expressions like
# ``SUM(a) + SUM(b)`` fall through to OTHER because they cannot be safely
# rolled up. ``Count`` is intentionally absent: it is checked separately in
# ``_aggregation_from_expression`` to tell COUNT from COUNT(DISTINCT ...).
_AGG_NODE_MAP: dict[type[sqlglot_exp.Expression], AggregationType] = {
    sqlglot_exp.Sum: AggregationType.SUM,
    sqlglot_exp.Min: AggregationType.MIN,
    sqlglot_exp.Max: AggregationType.MAX,
    sqlglot_exp.Avg: AggregationType.AVG,
}
|
||||
|
||||
# Binary filter operators and the sqlglot node class used to render each.
# NOTE(review): operators not listed here are presumably handled separately
# by the filter-building code — confirm against the rest of this module.
_OPERATOR_TO_SQLGLOT: dict[Operator, type[sqlglot_exp.Expression]] = {
    Operator.EQUALS: sqlglot_exp.EQ,
    Operator.NOT_EQUALS: sqlglot_exp.NEQ,
    Operator.GREATER_THAN: sqlglot_exp.GT,
    Operator.LESS_THAN: sqlglot_exp.LT,
    Operator.GREATER_THAN_OR_EQUAL: sqlglot_exp.GTE,
    Operator.LESS_THAN_OR_EQUAL: sqlglot_exp.LTE,
    Operator.LIKE: sqlglot_exp.Like,
}
|
||||
|
||||
|
||||
class DatasetSemanticView(SemanticView):
|
||||
features = frozenset(
|
||||
{
|
||||
SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY,
|
||||
SemanticViewFeature.GROUP_LIMIT,
|
||||
SemanticViewFeature.GROUP_OTHERS,
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, dataset: "SqlaTable") -> None:
    """
    Wrap ``dataset`` and precompute its semantic dimensions and metrics.

    :param dataset: the dataset exposed through this view
    """
    self.dataset = dataset
    self.name = dataset.table_name

    # Both helpers read ``self.dataset``, so it must be assigned first.
    self.dimensions = self.get_dimensions()
    self.metrics = self.get_metrics()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Identity & dialect
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def uid(self) -> str:
    """Return the identifier ``dataset:<id>`` for the wrapped dataset."""
    return "dataset:{}".format(self.dataset.id)
|
||||
|
||||
@property
def _sqlglot_dialect(self) -> str | None:
    """
    Return the sqlglot dialect name for the dataset's database engine.

    The engine name is looked up in ``SQLGLOT_DIALECTS``; ``None`` is
    returned when the engine has no entry (or maps to ``None``).
    """
    # Deferred import: per the original note, ``superset.sql.parse`` drags in
    # a heavy dependency graph that pure AST-building unit tests do not need.
    from superset.sql.parse import SQLGLOT_DIALECTS

    dialect = SQLGLOT_DIALECTS.get(self.dataset.database.db_engine_spec.engine)
    if dialect is None:
        return None
    return dialect.value
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Metadata
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_dimensions(self) -> set[Dimension]:
    """
    Build one ``Dimension`` per dataset column that has ``groupby`` set.

    The column name is the dimension id, the verbose name (when present) is
    the display name, and empty expressions/descriptions become ``None``.
    """
    return {
        Dimension(
            id=column.column_name,
            name=column.verbose_name or column.column_name,
            type=self._column_arrow_type(column),
            definition=column.expression or None,
            description=column.description or None,
        )
        for column in self.dataset.columns
        if column.groupby
    }
|
||||
|
||||
def get_metrics(self) -> set[Metric]:
    """
    Build one ``Metric`` per saved metric on the dataset.

    The metric name is the id, the verbose name (when present) is the
    display name, the Arrow type comes from ``infer_arrow_type`` on the
    expression and engine name, and the aggregation kind is derived from
    the expression by ``_aggregation_from_expression``.
    """
    from superset.semantic_layers.arrow_inference import infer_arrow_type

    engine = self.dataset.database.db_engine_spec.engine
    return {
        Metric(
            id=saved.metric_name,
            name=saved.verbose_name or saved.metric_name,
            type=infer_arrow_type(saved.expression, engine),
            definition=saved.expression,
            description=saved.description or None,
            aggregation=self._aggregation_from_expression(saved.expression),
        )
        for saved in self.dataset.metrics
    }
|
||||
|
||||
def get_compatible_metrics(
    self,
    selected_metrics: set[Metric],
    selected_dimensions: set[Dimension],
) -> set[Metric]:
    """
    Return the metrics that can be combined with the current selection.

    Both arguments are ignored: every metric of the view is always returned.
    """
    return self.metrics
|
||||
|
||||
def get_compatible_dimensions(
    self,
    selected_metrics: set[Metric],
    selected_dimensions: set[Dimension],
) -> set[Dimension]:
    """
    Return the dimensions that can be combined with the current selection.

    Both arguments are ignored: every dimension of the view is always
    returned.
    """
    return self.dimensions
|
||||
|
||||
def _column_arrow_type(self, column: "TableColumn") -> pa.DataType:
    """
    Map a column's generic type to a pyarrow type.

    Columns without a generic type, or with one missing from
    ``_GENERIC_TYPE_TO_ARROW``, are reported as UTF-8 strings.
    """
    generic = column.type_generic
    return (
        pa.utf8()
        if generic is None
        else _GENERIC_TYPE_TO_ARROW.get(int(generic), pa.utf8())
    )
|
||||
|
||||
def _aggregation_from_expression(
    self,
    expression: str | None,
) -> AggregationType:
    """
    Classify a metric expression by its top-level aggregate function.

    The expression is parsed with sqlglot using the dataset's dialect. Only
    the root node is inspected, so compound expressions are reported as
    ``OTHER``.

    :param expression: SQL text of the metric, possibly empty or ``None``
    :returns: ``COUNT`` / ``COUNT_DISTINCT`` for a top-level ``COUNT``, the
        mapped type for nodes listed in ``_AGG_NODE_MAP``, and ``OTHER`` for
        everything else, including empty or unparsable expressions
    """
    if not expression:
        return AggregationType.OTHER

    try:
        parsed = sqlglot.parse_one(expression, dialect=self._sqlglot_dialect)
    except sqlglot.errors.SqlglotError:
        # Catch the sqlglot base error so tokenizer failures (``TokenError``)
        # are treated like parse failures. This method runs for every saved
        # metric while the view is constructed, so one malformed expression
        # must not abort the whole view.
        return AggregationType.OTHER

    if isinstance(parsed, sqlglot_exp.Count):
        if isinstance(parsed.this, sqlglot_exp.Distinct):
            return AggregationType.COUNT_DISTINCT
        return AggregationType.COUNT

    for node_type, aggregation in _AGG_NODE_MAP.items():
        if isinstance(parsed, node_type):
            return aggregation

    return AggregationType.OTHER
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# AST building
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _dimension_column(self, dimension: Dimension) -> "TableColumn":
|
||||
# Callers (currently only ``_dimension_expression``) check membership
|
||||
# before invoking this helper, so the match is guaranteed.
|
||||
for column in self.dataset.columns:
|
||||
if column.column_name == dimension.id:
|
||||
return column
|
||||
raise ValueError( # pragma: no cover - guarded by caller
|
||||
f'Dimension "{dimension.id}" is not part of this dataset.'
|
||||
)
|
||||
|
||||
def _metric_column(self, metric: Metric) -> "SqlMetric":
|
||||
for sql_metric in self.dataset.metrics:
|
||||
if sql_metric.metric_name == metric.id:
|
||||
return sql_metric
|
||||
raise ValueError(f'Metric "{metric.id}" is not part of this dataset.')
|
||||
|
||||
def _dimension_expression(self, dimension: Dimension) -> sqlglot_exp.Expression:
    """
    Build the sqlglot expression that selects ``dimension``.

    * A dimension matching a dataset column uses that column's SQL
      expression when it has one, otherwise a quoted column reference.
    * A dimension with no matching column (a synthesized adhoc dimension)
      uses its own ``definition``.

    The result is then passed through ``_apply_grain``.

    :raises ValueError: if the dimension matches no column and carries no
        definition
    """
    column = next(
        (
            candidate
            for candidate in self.dataset.columns
            if candidate.column_name == dimension.id
        ),
        None,
    )

    if column is None:
        if not dimension.definition:
            raise ValueError(
                f'Dimension "{dimension.id}" is not part of this dataset.'
            )
        sql_text = dimension.definition
    else:
        sql_text = column.expression

    base: sqlglot_exp.Expression
    if sql_text:
        base = sqlglot.parse_one(sql_text, dialect=self._sqlglot_dialect)
    else:
        # Only reachable for a real column without an expression.
        base = sqlglot_exp.column(column.column_name, quoted=True)
    return self._apply_grain(dimension, base)
|
||||
|
||||
def _apply_grain(
    self,
    dimension: Dimension,
    base_expr: sqlglot_exp.Expression,
) -> sqlglot_exp.Expression:
    """
    Apply the dimension's time grain to ``base_expr``.

    The engine spec's time-grain template for the grain is looked up; when
    the dimension has no grain, or the engine spec has no template for it,
    ``base_expr`` is returned unchanged.
    """
    grain = dimension.grain
    if grain is None:
        return base_expr

    engine_spec = self.dataset.database.db_engine_spec
    grain_template = engine_spec.get_time_grain_expressions().get(
        grain.representation
    )
    if not grain_template:
        return base_expr

    dialect = self._sqlglot_dialect
    rendered_base = base_expr.sql(dialect=dialect)

    uses_engine_placeholders = (
        "{func}" in grain_template or "{type}" in grain_template
    )
    if not uses_engine_placeholders:
        # Plain template: substitute the column SQL and re-parse.
        return sqlglot.parse_one(
            grain_template.replace("{col}", rendered_base),
            dialect=dialect,
        )

    # ``{func}`` / ``{type}`` placeholders (BigQuery, MS SQL) can only be
    # resolved by the engine spec's ``get_timestamp_expr``, so go through
    # SQLAlchemy and parse the compiled SQL back into sqlglot.
    import sqlalchemy as sa

    literal_column = sa.column(rendered_base, is_literal=True)
    timestamp_expr = engine_spec.get_timestamp_expr(
        literal_column,
        None,
        grain.representation,
    )
    compiled = timestamp_expr.compile(
        dialect=self.dataset.database.get_dialect(),
        compile_kwargs={"literal_binds": True},
    )
    return sqlglot.parse_one(str(compiled), dialect=dialect)
|
||||
|
||||
def _metric_expression(self, metric: Metric) -> sqlglot_exp.Expression:
    """Parse the SQL expression of the dataset metric behind ``metric``."""
    metric_sql = self._metric_column(metric).expression
    return sqlglot.parse_one(metric_sql, dialect=self._sqlglot_dialect)
|
||||
|
||||
def _source_table(self) -> sqlglot_exp.Expression:
    """
    Build the FROM source for the dataset.

    A physical dataset becomes a quoted ``[catalog.][schema.]table``
    reference. A virtual dataset has its SQL (RLS-injected when
    ``_rls_apply_to_virtual_sql`` returns something) parsed and wrapped as
    the subquery ``virtual_dataset``.

    When ``_rls_should_wrap_subquery`` is true and the dataset has RLS
    predicates, the source is wrapped once more via ``_wrap_with_rls``.
    """
    dataset = self.dataset
    source: sqlglot_exp.Expression

    if not dataset.sql:
        # Only the parts that are set take part in the qualified name.
        identifiers = [
            sqlglot_exp.to_identifier(part, quoted=True)
            for part in (dataset.catalog, dataset.schema, dataset.table_name)
            if part
        ]
        count = len(identifiers)
        if count == 1:
            source = sqlglot_exp.Table(this=identifiers[-1])
        elif count == 2:
            source = sqlglot_exp.Table(this=identifiers[-1], db=identifiers[-2])
        else:
            source = sqlglot_exp.Table(
                this=identifiers[-1],
                db=identifiers[-2],
                catalog=identifiers[-3],
            )
    else:
        virtual_sql = self._rls_apply_to_virtual_sql() or dataset.sql
        parsed_sql = sqlglot.parse_one(virtual_sql, dialect=self._sqlglot_dialect)
        source = sqlglot_exp.Subquery(this=parsed_sql, alias="virtual_dataset")

    if self._rls_should_wrap_subquery():
        rls_clauses = self._rls_predicates()
        if rls_clauses:
            source = self._wrap_with_rls(source, rls_clauses)

    return source
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# RLS
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _rls_predicates(self) -> list[str]:
    """Return the dataset's outer RLS predicates rendered as SQL strings."""
    # Function-scope import, as in the other RLS helpers of this class.
    from superset.semantic_layers import rls

    return rls.render_rls_predicates(self.dataset)
|
||||
|
||||
def _rls_apply_to_virtual_sql(self) -> str | None:
    """
    Delegate to ``apply_rls_to_virtual_sql`` for this dataset.

    The caller treats a falsy result as "use the dataset SQL unchanged".
    """
    from superset.semantic_layers import rls

    return rls.apply_rls_to_virtual_sql(self.dataset)
|
||||
|
||||
def _rls_should_wrap_subquery(self) -> bool:
    """Return True when the dataset's RLS method is ``RLSMethod.AS_SUBQUERY``."""
    from superset.semantic_layers.rls import get_rls_method
    from superset.sql.parse import RLSMethod

    rls_method = get_rls_method(self.dataset)
    return rls_method == RLSMethod.AS_SUBQUERY
|
||||
|
||||
def _wrap_with_rls(
    self,
    source: sqlglot_exp.Expression,
    rls_clauses: list[str],
) -> sqlglot_exp.Subquery:
    """
    Wrap ``source`` in ``SELECT * FROM source WHERE <AND-ed RLS>``.

    Each clause is parsed in the dataset's dialect and the clauses are
    AND-ed together. The result is aliased ``rls_wrapped``.

    :param source: table or subquery to protect.
    :param rls_clauses: RLS predicates as SQL strings; must not be empty.
    :raises ValueError: if ``rls_clauses`` is empty.
    """
    if not rls_clauses:
        raise ValueError("At least one RLS clause is required to wrap the source")

    parsed = [
        sqlglot.parse_one(clause, dialect=self._sqlglot_dialect)
        for clause in rls_clauses
    ]
    # sqlglot renders the AST literally, so chaining raw ``And`` nodes would
    # emit ``a OR b AND c`` for the clauses ``a OR b`` and ``c``, letting a
    # row through on ``a`` alone. ``and_`` parenthesises connector operands
    # so each clause stays an independent restriction.
    combined = sqlglot_exp.and_(*parsed)

    inner = (
        sqlglot_exp.Select()
        .select(sqlglot_exp.Star())
        .from_(source)
        .where(combined)
    )
    return sqlglot_exp.Subquery(this=inner, alias="rls_wrapped")
|
||||
|
||||
def _filter_predicate(self, filter_: Filter) -> sqlglot_exp.Expression:
    """
    Translate a single ``Filter`` into a sqlglot predicate.

    ``ADHOC`` filters carry raw SQL in ``value`` and are parsed as-is. All
    other operators compare the filter's column (a dimension or a metric)
    against ``value``.

    :raises ValueError: for an adhoc filter whose value is not a string, a
        native filter without a column, or an unsupported operator.
    """
    operator = filter_.operator

    if operator == Operator.ADHOC:
        if not isinstance(filter_.value, str):
            raise ValueError("Adhoc filter value must be a SQL string")
        return sqlglot.parse_one(filter_.value, dialect=self._sqlglot_dialect)

    if filter_.column is None:
        raise ValueError("Native filters require a column")

    if isinstance(filter_.column, Dimension):
        column_expr = self._dimension_expression(filter_.column)
    else:
        column_expr = self._metric_expression(filter_.column)

    # sqlglot does not insert parentheses by precedence. A calculated column
    # such as ``a OR b`` or ``x = 1`` used directly as an operand would
    # render as ``a OR b = 'v'`` and bind incorrectly, so composite
    # expressions are parenthesised before being compared.
    if isinstance(
        column_expr,
        (sqlglot_exp.Binary, sqlglot_exp.Unary, sqlglot_exp.Predicate),
    ):
        column_expr = sqlglot_exp.Paren(this=column_expr)

    if operator == Operator.IS_NULL:
        return sqlglot_exp.Is(this=column_expr, expression=sqlglot_exp.Null())

    if operator == Operator.IS_NOT_NULL:
        return sqlglot_exp.Not(
            this=sqlglot_exp.Is(this=column_expr, expression=sqlglot_exp.Null())
        )

    if operator in (Operator.IN, Operator.NOT_IN):
        values = coerce_literal(filter_.value)
        if not isinstance(values, list):
            values = [values]
        in_expr = sqlglot_exp.In(
            this=column_expr,
            expressions=[sqlglot_exp.convert(v) for v in values],
        )
        if operator == Operator.NOT_IN:
            return sqlglot_exp.Not(this=in_expr)
        return in_expr

    if operator == Operator.NOT_LIKE:
        return sqlglot_exp.Not(
            this=sqlglot_exp.Like(
                this=column_expr,
                expression=sqlglot_exp.convert(filter_.value),
            )
        )

    # Remaining operators map one-to-one onto a sqlglot binary expression.
    sqlglot_op_cls = _OPERATOR_TO_SQLGLOT.get(operator)
    if sqlglot_op_cls is None:
        raise ValueError(f"Unsupported operator: {operator}")
    return sqlglot_op_cls(
        this=column_expr,
        expression=sqlglot_exp.convert(filter_.value),
    )
|
||||
|
||||
def _combine_predicates(
    self,
    filters: set[Filter],
) -> sqlglot_exp.Expression | None:
    """
    AND together the predicates of ``filters``.

    :param filters: filters to combine; may be empty.
    :returns: the combined predicate, or ``None`` when there are no filters.
    """
    if not filters:
        return None

    # Sort to make the resulting SQL deterministic. Filter's default
    # tuple ordering eventually compares ``Dimension`` instances, which
    # are not orderable, so use a stable string key instead.
    ordered = sorted(
        filters,
        key=lambda f: (
            f.type.value,
            f.column.id if f.column is not None else "",
            f.operator.value,
            repr(f.value),
        ),
    )
    predicates = [self._filter_predicate(filter_) for filter_ in ordered]

    # Adhoc filters (including RLS clauses) may contain a top-level ``OR``.
    # sqlglot renders raw ``And`` nodes without parentheses, which would turn
    # ``(a OR b) AND c`` into ``a OR b AND c``. ``and_`` wraps connector
    # operands in parentheses so every filter keeps its own scope.
    return sqlglot_exp.and_(*predicates)
|
||||
|
||||
def _order_expression(
    self,
    element: Metric | Dimension | AdhocExpression,
    direction: OrderDirection,
) -> sqlglot_exp.Ordered:
    """
    Build one ORDER BY term.

    Adhoc expressions are parsed from their ``definition``; metrics and
    dimensions are referenced through the quoted column named after their id.
    """
    descending = direction == OrderDirection.DESC
    if not isinstance(element, AdhocExpression):
        target = sqlglot_exp.column(element.id, quoted=True)
    else:
        target = sqlglot.parse_one(
            element.definition,
            dialect=self._sqlglot_dialect,
        )
    return sqlglot_exp.Ordered(this=target, desc=descending)
|
||||
|
||||
def _build_ast(self, query: SemanticQuery) -> sqlglot_exp.Select:
    """
    Build the SELECT AST for ``query``.

    Filters are split into WHERE and HAVING sets, the dataset's RLS
    predicates are added as adhoc WHERE filters unless RLS is applied as a
    subquery, and the build is dispatched on ``query.group_limit``.

    :raises ValueError: if an offset is given without a limit.
    """
    if query.offset is not None and query.limit is None:
        raise ValueError("Offset cannot be set without limit")

    where_filters: set[Filter] = set()
    having_filters: set[Filter] = set()
    for candidate in query.filters or set():
        if candidate.type == PredicateType.WHERE:
            where_filters.add(candidate)
        if candidate.type == PredicateType.HAVING:
            having_filters.add(candidate)

    # AS_SUBQUERY-style RLS is handled in ``_source_table``; otherwise the
    # outer RLS predicates join the WHERE clause as adhoc filters.
    if not self._rls_should_wrap_subquery():
        where_filters |= {
            Filter(
                type=PredicateType.WHERE,
                column=None,
                operator=Operator.ADHOC,
                value=clause,
            )
            for clause in self._rls_predicates()
        }

    group_limit = query.group_limit
    if group_limit is None:
        return self._build_plain(query, where_filters, having_filters)
    if group_limit.group_others:
        return self._build_with_others(query, where_filters, having_filters)
    return self._build_with_group_limit(query, where_filters, having_filters)
|
||||
|
||||
def _build_plain(
    self,
    query: SemanticQuery,
    where_filters: set[Filter],
    having_filters: set[Filter],
) -> sqlglot_exp.Select:
    """
    Build a SELECT without group-limit handling.

    Dimensions and metrics are projected under their ids, WHERE and HAVING
    predicates are attached, and ordering, limit and offset are applied when
    the query asks for them.
    """
    projections: list[sqlglot_exp.Expression] = [
        sqlglot_exp.alias_(
            self._dimension_expression(dimension),
            dimension.id,
            quoted=True,
        )
        for dimension in query.dimensions
    ]
    projections.extend(
        sqlglot_exp.alias_(
            self._metric_expression(metric),
            metric.id,
            quoted=True,
        )
        for metric in query.metrics
    )

    select = (
        sqlglot_exp.Select()
        .select(*projections, append=False)
        .from_(self._source_table())
    )

    if where_predicate := self._combine_predicates(where_filters):
        select = select.where(where_predicate)

    # Any metric is an aggregate, so every dimension must be grouped. Group
    # by the full dimension expression rather than its alias: dialects
    # disagree on whether SELECT aliases are visible in GROUP BY, and a
    # calculated or adhoc dimension has no underlying column of that name.
    # For a plain column the expression is the quoted column itself.
    if query.metrics and query.dimensions:
        select = select.group_by(
            *[
                self._dimension_expression(dimension)
                for dimension in query.dimensions
            ]
        )

    if having_predicate := self._combine_predicates(having_filters):
        select = select.having(having_predicate)

    if query.order:
        select = select.order_by(
            *[
                self._order_expression(element, direction)
                for element, direction in query.order
            ]
        )

    if query.limit is not None:
        select = select.limit(query.limit)
    if query.offset is not None:
        select = select.offset(query.offset)

    return select
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group limit (top N) helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _top_groups_cte_select(
    self,
    group_limit: GroupLimit,
    main_where_filters: set[Filter],
) -> sqlglot_exp.Select:
    """
    Build the SELECT that powers the ``top_groups`` CTE.

    The CTE projects only the limited dimensions, one row per distinct
    combination. The ordering metric, when present, is evaluated inline in
    ORDER BY so it does not leak into the CTE's column list; without a
    metric the groups are ordered by the first limited dimension.

    :param group_limit: the limit definition (dimensions, metric, top N).
    :param main_where_filters: WHERE filters of the main query, used when
        ``group_limit.filters`` is ``None``.
    :raises ValueError: if ``group_limit`` has no dimensions or its filters
        contain a HAVING filter.
    """
    if not group_limit.dimensions:
        raise ValueError("group_limit requires at least one dimension")

    if group_limit.filters is not None:
        if any(f.type == PredicateType.HAVING for f in group_limit.filters):
            raise ValueError(
                "HAVING filters are not supported in group_limit.filters"
            )
        cte_where_filters = {
            f for f in group_limit.filters if f.type == PredicateType.WHERE
        }
    else:
        cte_where_filters = main_where_filters

    dim_projections = [
        sqlglot_exp.alias_(
            self._dimension_expression(dim),
            dim.id,
            quoted=True,
        )
        for dim in group_limit.dimensions
    ]
    select = (
        sqlglot_exp.Select()
        .select(*dim_projections, append=False)
        .from_(self._source_table())
    )

    if where_predicate := self._combine_predicates(cte_where_filters):
        select = select.where(where_predicate)

    # Always collapse to one row per dimension combination. Without the
    # GROUP BY, ``LIMIT top`` would pick the top N *rows*, which can all
    # belong to the same group. Full expressions are used because alias
    # resolution in GROUP BY isn't portable.
    select = select.group_by(
        *[self._dimension_expression(dim) for dim in group_limit.dimensions]
    )

    order_expr: sqlglot_exp.Expression
    if group_limit.metric is not None:
        order_expr = self._metric_expression(group_limit.metric)
    else:
        order_expr = sqlglot_exp.column(
            group_limit.dimensions[0].id,
            quoted=True,
        )

    select = select.order_by(
        sqlglot_exp.Ordered(
            this=order_expr,
            desc=group_limit.direction == OrderDirection.DESC,
        )
    )
    return select.limit(group_limit.top)
|
||||
|
||||
def _top_groups_in_predicate(
    self,
    group_limit: GroupLimit,
) -> sqlglot_exp.Expression:
    """
    Build a predicate restricting the limited dimensions to the rows of the
    ``top_groups`` CTE. Single dimension uses scalar IN; multiple use a
    row-tuple IN. The subquery is wrapped in ``Subquery`` so sqlglot emits
    the surrounding parentheses.
    """
    # Columns as exposed by the CTE, which aliases each dimension to its id.
    cte_columns = [
        sqlglot_exp.column(dim.id, quoted=True)
        for dim in group_limit.dimensions
    ]
    subquery = sqlglot_exp.Subquery(
        this=(
            sqlglot_exp.Select()
            .select(*cte_columns)
            .from_(sqlglot_exp.to_identifier("top_groups"))
        )
    )

    # The predicate is used in WHERE and CASE, where SELECT aliases are not
    # visible. The left-hand side must therefore be the full dimension
    # expression; otherwise a grain-bucketed or calculated dimension would
    # compare the raw column (or a non-existent one) with the CTE values.
    outer_exprs = [
        self._dimension_expression(dim) for dim in group_limit.dimensions
    ]
    if len(outer_exprs) == 1:
        lhs = outer_exprs[0]
        # sqlglot does not parenthesise by precedence; keep a composite
        # expression intact as the operand of IN.
        if isinstance(lhs, sqlglot_exp.Binary):
            lhs = sqlglot_exp.Paren(this=lhs)
        return sqlglot_exp.In(this=lhs, query=subquery)
    return sqlglot_exp.In(
        this=sqlglot_exp.Tuple(expressions=outer_exprs),
        query=subquery,
    )
|
||||
|
||||
def _build_with_group_limit(
|
||||
self,
|
||||
query: SemanticQuery,
|
||||
where_filters: set[Filter],
|
||||
having_filters: set[Filter],
|
||||
) -> sqlglot_exp.Select:
|
||||
"""
|
||||
Restrict the main query to rows whose limited-dimension values appear
|
||||
in the ``top_groups`` CTE.
|
||||
"""
|
||||
select = self._build_plain(query, where_filters, having_filters)
|
||||
select = select.where(self._top_groups_in_predicate(query.group_limit))
|
||||
|
||||
cte_select = self._top_groups_cte_select(query.group_limit, where_filters)
|
||||
return select.with_("top_groups", as_=cte_select)
|
||||
|
||||
def _build_with_others(
    self,
    query: SemanticQuery,
    where_filters: set[Filter],
    having_filters: set[Filter],
) -> sqlglot_exp.Select:
    """
    Build the query variant that folds non-top groups into ``'Other'``.

    Limited dimensions are projected through a CASE that keeps the value
    when it belongs to the ``top_groups`` CTE and yields ``'Other'``
    otherwise. Metrics are computed from their original expressions over
    the bucketed groups, so non-additive metrics stay correct.
    """
    group_limit = query.group_limit
    bucketed_ids = {dim.id for dim in group_limit.dimensions}
    top_membership = self._top_groups_in_predicate(group_limit)

    def bucketed(dim: Dimension) -> sqlglot_exp.Expression:
        """Dimension expression, wrapped in the 'Other' CASE when limited."""
        expr = self._dimension_expression(dim)
        if dim.id in bucketed_ids:
            return sqlglot_exp.Case(
                ifs=[sqlglot_exp.If(this=top_membership.copy(), true=expr)],
                default=sqlglot_exp.Literal.string("Other"),
            )
        return expr

    projections: list[sqlglot_exp.Expression] = [
        sqlglot_exp.alias_(bucketed(dim), dim.id, quoted=True)
        for dim in query.dimensions
    ]
    projections.extend(
        sqlglot_exp.alias_(
            self._metric_expression(metric),
            metric.id,
            quoted=True,
        )
        for metric in query.metrics
    )

    select = (
        sqlglot_exp.Select()
        .select(*projections, append=False)
        .from_(self._source_table())
    )

    where_predicate = self._combine_predicates(where_filters)
    if where_predicate:
        select = select.where(where_predicate)

    # Group by the full (CASE) expressions, not the aliases: dialects differ
    # on whether GROUP BY may reference a SELECT alias.
    if query.metrics and query.dimensions:
        grouping = [bucketed(dim) for dim in query.dimensions]
        select = select.group_by(*grouping)

    having_predicate = self._combine_predicates(having_filters)
    if having_predicate:
        select = select.having(having_predicate)

    if query.order:
        order_terms = [
            self._order_expression(element, direction)
            for element, direction in query.order
        ]
        select = select.order_by(*order_terms)

    if query.limit is not None:
        select = select.limit(query.limit)
    if query.offset is not None:
        select = select.offset(query.offset)

    top_groups = self._top_groups_cte_select(group_limit, where_filters)
    return select.with_("top_groups", as_=top_groups)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def compile_query(self, query: SemanticQuery) -> str:
    """Render ``query`` as SQL text in the dataset's database dialect."""
    ast = self._build_ast(query)
    return ast.sql(dialect=self._sqlglot_dialect)
|
||||
|
||||
def get_values(
    self,
    dimension: Dimension,
    filters: set[Filter] | None = None,
) -> SemanticResult:
    """
    Fetch the distinct values of ``dimension``.

    Only WHERE-type filters are applied. The dataset's RLS predicates are
    added as adhoc filters unless RLS is applied as a subquery, in which
    case ``_source_table`` already wraps the source.
    """
    where_filters = {
        candidate
        for candidate in (filters or set())
        if candidate.type == PredicateType.WHERE
    }
    if not self._rls_should_wrap_subquery():
        where_filters.update(
            Filter(
                type=PredicateType.WHERE,
                column=None,
                operator=Operator.ADHOC,
                value=clause,
            )
            for clause in self._rls_predicates()
        )

    where = self._combine_predicates(where_filters)

    projection = sqlglot_exp.alias_(
        self._dimension_expression(dimension),
        dimension.id,
        quoted=True,
    )
    select = (
        sqlglot_exp.Select()
        .distinct()
        .select(projection)
        .from_(self._source_table())
    )
    if where is not None:
        select = select.where(where)

    dataset = self.dataset
    sql = select.sql(dialect=self._sqlglot_dialect)
    frame = dataset.database.get_df(
        sql,
        catalog=dataset.catalog,
        schema=dataset.schema,
    )
    return SemanticResult(
        requests=[SemanticRequest(REQUEST_TYPE, sql)],
        results=df_to_arrow(frame),
    )
|
||||
|
||||
def get_table(self, query: SemanticQuery) -> SemanticResult:
    """
    Run ``query`` and return its rows.

    A query with neither metrics nor dimensions returns an empty table
    without touching the database. Otherwise the compiled SQL is executed
    and the result columns, which are aliased by id in the SQL, are renamed
    to the dimension/metric names.
    """
    if not query.metrics and not query.dimensions:
        return SemanticResult(requests=[], results=pa.table({}))

    sql = self.compile_query(query)
    df = self.dataset.database.get_df(
        sql,
        catalog=self.dataset.catalog,
        schema=self.dataset.schema,
    )

    # Map alias columns back to dimension/metric *names* so the result uses
    # human-friendly labels instead of internal IDs.
    rename: dict[str, str] = {
        dimension.id: dimension.name for dimension in query.dimensions
    }
    rename.update({metric.id: metric.name for metric in query.metrics})

    # Rename empty frames too, so the result schema does not depend on
    # whether the query happened to return rows.
    if df is not None:
        df = df.rename(columns=rename)

    return SemanticResult(
        requests=[SemanticRequest(REQUEST_TYPE, sql)],
        results=df_to_arrow(df),
    )
|
||||
|
||||
def get_row_count(self, query: SemanticQuery) -> SemanticResult:
    """
    Count the rows ``query`` would return.

    The query's AST is wrapped as the subquery ``subquery`` under a
    ``COUNT(*)`` aliased ``COUNT``. A query with neither metrics nor
    dimensions yields a count of 0 without touching the database.
    """
    if not (query.metrics or query.dimensions):
        return SemanticResult(requests=[], results=pa.table({"COUNT": [0]}))

    counted = sqlglot_exp.Subquery(this=self._build_ast(query), alias="subquery")
    count_column = sqlglot_exp.alias_(
        sqlglot_exp.Count(this=sqlglot_exp.Star()),
        "COUNT",
        quoted=True,
    )
    count_select = sqlglot_exp.Select().select(count_column).from_(counted)

    dataset = self.dataset
    sql = count_select.sql(dialect=self._sqlglot_dialect)
    frame = dataset.database.get_df(
        sql,
        catalog=dataset.catalog,
        schema=dataset.schema,
    )
    return SemanticResult(
        requests=[SemanticRequest(REQUEST_TYPE, sql)],
        results=df_to_arrow(frame),
    )
|
||||
|
||||
__repr__ = uid
|
||||
@@ -0,0 +1,80 @@
|
||||
# flake8: noqa: E501
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from preset_io.dataset_semantic_layer import (
|
||||
DatasetConfiguration,
|
||||
DatasetSemanticLayer,
|
||||
DatasetSemanticView,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def dataset() -> MagicMock:
    """Mocked physical dataset ``public.orders`` (id 1) on a postgresql engine."""
    database = MagicMock()
    database.db_engine_spec.engine = "postgresql"

    # ``spec=[]`` means only the attributes assigned below exist on the mock.
    mock_dataset = MagicMock(spec=[])
    attributes = {
        "id": 1,
        "table_name": "orders",
        "schema": "public",
        "catalog": None,
        "sql": None,
        "database": database,
        "columns": [],
        "metrics": [],
    }
    for attr_name, attr_value in attributes.items():
        setattr(mock_dataset, attr_name, attr_value)
    return mock_dataset
|
||||
|
||||
|
||||
def test_empty_configuration_schema() -> None:
    """The configuration schema has no properties and no required fields."""
    config_schema = DatasetSemanticLayer.get_configuration_schema()
    # Adding the layer takes zero input, so nothing may be required.
    required_fields = config_schema.get("required", [])
    assert required_fields == []
    assert config_schema["properties"] == {}
|
||||
|
||||
|
||||
def test_runtime_schema_exposes_dataset_dropdown(mocker, dataset: MagicMock) -> None:
    """The runtime schema requires ``dataset_id`` and lists datasets as an enum."""
    mocker.patch(
        "preset_io.dataset_semantic_layer.layer.list_datasets",
        return_value=[dataset],
    )
    configuration = DatasetConfiguration()
    runtime_schema = DatasetSemanticLayer.get_runtime_schema(configuration)

    assert runtime_schema["required"] == ["dataset_id"]
    dropdown = runtime_schema["properties"]["dataset_id"]
    # Enum values are stringified ids; display names are schema-qualified.
    assert dropdown["enum"] == ["1"]
    assert dropdown["x-enumNames"] == ["public.orders"]
|
||||
|
||||
|
||||
def test_get_semantic_view_requires_dataset_id() -> None:
    """Requesting a view without ``dataset_id`` raises a ValueError naming it."""
    semantic_layer = DatasetSemanticLayer.from_configuration({})
    runtime_config: dict = {}
    with pytest.raises(ValueError, match="dataset_id"):
        semantic_layer.get_semantic_view("orders", runtime_config)
|
||||
|
||||
|
||||
def test_get_semantic_view_returns_dataset_view(mocker, dataset: MagicMock) -> None:
    """A resolved dataset id yields a ``DatasetSemanticView`` with uid ``dataset:1``."""
    mocker.patch(
        "preset_io.dataset_semantic_layer.layer.get_dataset_by_id",
        return_value=dataset,
    )
    semantic_layer = DatasetSemanticLayer.from_configuration({})
    semantic_view = semantic_layer.get_semantic_view("orders", {"dataset_id": "1"})

    assert isinstance(semantic_view, DatasetSemanticView)
    assert semantic_view.uid() == "dataset:1"
|
||||
|
||||
|
||||
def test_get_semantic_views_returns_empty_when_no_id() -> None:
    """Without a ``dataset_id`` the layer reports an empty set of views."""
    semantic_layer = DatasetSemanticLayer.from_configuration({})
    views = semantic_layer.get_semantic_views({})
    assert views == set()
|
||||
|
||||
|
||||
def test_get_semantic_views_resolves_dataset(mocker, dataset: MagicMock) -> None:
    """With a ``dataset_id`` exactly one view is returned, uid ``dataset:1``."""
    mocker.patch(
        "preset_io.dataset_semantic_layer.layer.get_dataset_by_id",
        return_value=dataset,
    )
    semantic_layer = DatasetSemanticLayer.from_configuration({})
    resolved_views = semantic_layer.get_semantic_views({"dataset_id": "1"})

    assert len(resolved_views) == 1
    only_view = next(iter(resolved_views))
    assert only_view.uid() == "dataset:1"
|
||||
File diff suppressed because it is too large
Load Diff
157
superset/semantic_layers/extension/backend/tests/test_utils.py
Normal file
157
superset/semantic_layers/extension/backend/tests/test_utils.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# flake8: noqa: E501
|
||||
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from preset_io.dataset_semantic_layer.utils import (
|
||||
coerce_literal,
|
||||
dataset_label,
|
||||
df_to_arrow,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_coerce_literal_passes_through_scalars() -> None:
    """String, int and ``None`` inputs come back unchanged."""
    text_result = coerce_literal("x")
    number_result = coerce_literal(42)
    none_result = coerce_literal(None)

    assert text_result == "x"
    assert number_result == 42
    assert none_result is None
|
||||
|
||||
|
||||
def test_coerce_literal_unwraps_sets() -> None:
    """A set input is returned as a list holding the same members."""
    members = {"a", "b"}
    coerced = coerce_literal({"a", "b"})
    assert isinstance(coerced, list)
    # Compare as sets: the test does not pin the order of the list.
    assert set(coerced) == members
|
||||
|
||||
|
||||
def test_dataset_label_combines_parts() -> None:
    """Catalog, schema and table name are joined with dots."""
    fully_qualified = MagicMock(spec=[])
    fully_qualified.table_name = "orders"
    fully_qualified.schema = "public"
    fully_qualified.catalog = "warehouse"

    label = dataset_label(fully_qualified)
    assert label == "warehouse.public.orders"
|
||||
|
||||
|
||||
def test_dataset_label_skips_blank_parts() -> None:
    """``None`` catalog and schema are left out of the label."""
    table_only = MagicMock(spec=[])
    table_only.table_name = "orders"
    table_only.schema = None
    table_only.catalog = None

    label = dataset_label(table_only)
    assert label == "orders"
|
||||
|
||||
|
||||
def test_df_to_arrow_with_data() -> None:
    """A populated frame converts with its row count and column order intact."""
    frame = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
    arrow_table = df_to_arrow(frame)
    assert arrow_table.num_rows == 2
    assert arrow_table.column_names == ["a", "b"]
|
||||
|
||||
|
||||
def test_df_to_arrow_with_empty_df() -> None:
    """An empty frame converts to zero rows but keeps its column names."""
    empty_frame = pd.DataFrame(columns=["a", "b"])
    arrow_table = df_to_arrow(empty_frame)
    assert arrow_table.num_rows == 0
    assert arrow_table.column_names == ["a", "b"]
|
||||
|
||||
|
||||
def test_df_to_arrow_with_none() -> None:
    """``None`` converts to a table with zero rows."""
    arrow_table = df_to_arrow(None)
    row_count = arrow_table.num_rows
    assert row_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session-bound helpers — superset is stubbed so we don't require a Flask app
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def fake_superset(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, type]:
    """
    Install stub ``superset`` modules in ``sys.modules``.

    ``superset.db.session`` becomes a ``MagicMock`` and
    ``superset.connectors.sqla.models.SqlaTable`` a bare sentinel class, so
    the list/get helpers run without a real Flask app or database.

    Returns the mocked session and the sentinel class.
    """

    class FakeSqlaTable:  # noqa: D401 — marker class only
        """Sentinel used as the queried entity."""

        columns: list = []
        metrics: list = []

    session = MagicMock()

    superset_stub = types.ModuleType("superset")
    superset_stub.db = MagicMock()
    superset_stub.db.session = session

    models_stub = types.ModuleType("superset.connectors.sqla.models")
    models_stub.SqlaTable = FakeSqlaTable

    # Every package on the dotted path gets an entry so the import resolves.
    stub_modules = {
        "superset": superset_stub,
        "superset.connectors": types.ModuleType("superset.connectors"),
        "superset.connectors.sqla": types.ModuleType("superset.connectors.sqla"),
        "superset.connectors.sqla.models": models_stub,
    }
    for module_name, module in stub_modules.items():
        monkeypatch.setitem(sys.modules, module_name, module)

    # NOTE: nothing from sqlalchemy (e.g. ``selectinload``) is stubbed here.
    return session, FakeSqlaTable
|
||||
|
||||
|
||||
def test_list_datasets_returns_sorted(fake_superset) -> None:
    """``list_datasets`` returns the queried datasets ordered by table name."""
    from preset_io.dataset_semantic_layer.utils import list_datasets

    session, _ = fake_superset

    unsorted_names = ["z_users", "a_orders", "m_payments"]
    datasets = []
    for table_name in unsorted_names:
        entry = MagicMock()
        entry.table_name = table_name
        datasets.append(entry)

    session.no_autoflush.__enter__.return_value = None
    session.no_autoflush.__exit__.return_value = False
    session.query.return_value.all.return_value = datasets

    listed = list_datasets()
    assert [d.table_name for d in listed] == ["a_orders", "m_payments", "z_users"]
|
||||
|
||||
|
||||
def test_get_dataset_by_id_raises_when_missing(fake_superset) -> None:
    """A lookup that finds no row raises ``ValueError`` naming the id."""
    from preset_io.dataset_semantic_layer.utils import get_dataset_by_id

    session, _ = fake_superset

    session.no_autoflush.__enter__.return_value = None
    session.no_autoflush.__exit__.return_value = False
    # query(...).options(...).filter_by(...).one_or_none() finds nothing.
    filtered_query = session.query.return_value.options.return_value.filter_by.return_value
    filtered_query.one_or_none.return_value = None

    with pytest.raises(ValueError, match="Dataset with id 42 does not exist."):
        get_dataset_by_id(42)
|
||||
|
||||
|
||||
def test_get_dataset_by_id_materialises_relationships(fake_superset) -> None:
    """A found dataset is returned as the very object the query produced."""
    from preset_io.dataset_semantic_layer.utils import get_dataset_by_id

    session, _ = fake_superset

    found = MagicMock()
    # Use real lists so iteration in get_dataset_by_id doesn't error.
    found.columns = ["col1", "col2"]
    found.metrics = ["m1"]

    session.no_autoflush.__enter__.return_value = None
    session.no_autoflush.__exit__.return_value = False
    filtered_query = session.query.return_value.options.return_value.filter_by.return_value
    filtered_query.one_or_none.return_value = found

    returned = get_dataset_by_id(1)
    assert returned is found
|
||||
8
superset/semantic_layers/extension/extension.json
Normal file
8
superset/semantic_layers/extension/extension.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"publisher": "preset-io",
|
||||
"name": "dataset-semantic-layer",
|
||||
"displayName": "Dataset Semantic Layer",
|
||||
"version": "0.1.0",
|
||||
"license": "All Rights Reserved",
|
||||
"permissions": []
|
||||
}
|
||||
8958
superset/semantic_layers/extension/frontend/package-lock.json
generated
Normal file
8958
superset/semantic_layers/extension/frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
34
superset/semantic_layers/extension/frontend/package.json
Normal file
34
superset/semantic_layers/extension/frontend/package.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"name": "dataset_semantic_layer",
|
||||
"version": "0.1.0",
|
||||
"main": "dist/main.js",
|
||||
"types": "dist/publicAPI.d.ts",
|
||||
"scripts": {
|
||||
"test": "echo \"Error: no test specified\" && exit 1",
|
||||
"start": "webpack serve --mode development",
|
||||
"build": "webpack --stats-error-details --mode production"
|
||||
},
|
||||
"keywords": [],
|
||||
"private": true,
|
||||
"author": "",
|
||||
"license": "All Rights Reserved",
|
||||
"description": "",
|
||||
"peerDependencies": {
|
||||
"@apache-superset/core": "0.1.0-rc2",
|
||||
"react": "^17.0.2",
|
||||
"react-dom": "^17.0.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/preset-react": "^7.26.3",
|
||||
"@babel/preset-typescript": "^7.26.0",
|
||||
"@types/react": "^19.0.10",
|
||||
"copy-webpack-plugin": "^13.0.0",
|
||||
"install": "^0.13.0",
|
||||
"npm": "^11.1.0",
|
||||
"ts-loader": "^9.5.2",
|
||||
"typescript": "^5.8.2",
|
||||
"webpack": "^5.98.0",
|
||||
"webpack-cli": "^6.0.1",
|
||||
"webpack-dev-server": "^5.2.0"
|
||||
}
|
||||
}
|
||||
17
superset/semantic_layers/extension/frontend/src/index.tsx
Normal file
17
superset/semantic_layers/extension/frontend/src/index.tsx
Normal file
@@ -0,0 +1,17 @@
|
||||
import React from "react";
|
||||
import { views } from "@apache-superset/core";
|
||||
|
||||
// Descriptor and renderer for the panel this extension contributes.
const panelDescriptor = {
  id: "dataset.semantic-layer",
  name: "Dataset Semantic Layer",
};
const renderPanel = () => <p>Dataset Semantic Layer</p>;

// Registered when the module is evaluated, at the "sqllab.panels" location;
// the returned disposable is kept so deactivate() can remove the view.
const viewDisposable = views.registerView(
  panelDescriptor,
  "sqllab.panels",
  renderPanel
);

/** Extension activation hook: only logs. */
export const activate = () => {
  console.log("Dataset Semantic Layer extension activated");
};

/** Extension deactivation hook: disposes the registered view, then logs. */
export const deactivate = () => {
  viewDisposable.dispose();
  console.log("Dataset Semantic Layer extension deactivated");
};
|
||||
13
superset/semantic_layers/extension/frontend/tsconfig.json
Normal file
13
superset/semantic_layers/extension/frontend/tsconfig.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "es5",
|
||||
"module": "esnext",
|
||||
"moduleResolution": "node10",
|
||||
"jsx": "react",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true
|
||||
},
|
||||
"include": ["src"]
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
const path = require("path");
|
||||
const { ModuleFederationPlugin } = require("webpack").container;
|
||||
const packageConfig = require("./package");
|
||||
|
||||
module.exports = (env, argv) => {
|
||||
const isProd = argv.mode === "production";
|
||||
|
||||
return {
|
||||
entry: isProd ? {} : "./src/index.tsx",
|
||||
mode: isProd ? "production" : "development",
|
||||
devServer: {
|
||||
port: 3000,
|
||||
headers: {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
},
|
||||
},
|
||||
output: {
|
||||
clean: true,
|
||||
filename: isProd ? undefined : "[name].[contenthash].js",
|
||||
chunkFilename: "[name].[contenthash].js",
|
||||
path: path.resolve(__dirname, "dist"),
|
||||
publicPath: `/api/v1/extensions/${packageConfig.name}/`,
|
||||
},
|
||||
resolve: {
|
||||
extensions: [".ts", ".tsx", ".js", ".jsx"],
|
||||
},
|
||||
externalsType: "window",
|
||||
externals: {
|
||||
"@apache-superset/core": "superset",
|
||||
},
|
||||
module: {
|
||||
rules: [
|
||||
{
|
||||
test: /\.tsx?$/,
|
||||
use: "ts-loader",
|
||||
exclude: /node_modules/,
|
||||
},
|
||||
],
|
||||
},
|
||||
plugins: [
|
||||
new ModuleFederationPlugin({
|
||||
name: "dataset_semantic_layer",
|
||||
filename: "remoteEntry.[contenthash].js",
|
||||
exposes: {
|
||||
"./index": "./src/index.tsx",
|
||||
},
|
||||
shared: {
|
||||
react: {
|
||||
singleton: true,
|
||||
requiredVersion: packageConfig.peerDependencies.react,
|
||||
import: false,
|
||||
},
|
||||
"react-dom": {
|
||||
singleton: true,
|
||||
requiredVersion: packageConfig.peerDependencies["react-dom"],
|
||||
import: false,
|
||||
},
|
||||
antd: {
|
||||
singleton: true,
|
||||
requiredVersion: packageConfig.peerDependencies["antd"],
|
||||
import: false,
|
||||
},
|
||||
},
|
||||
}),
|
||||
],
|
||||
};
|
||||
};
|
||||
@@ -24,6 +24,7 @@ single dataframe.
|
||||
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from datetime import date, datetime, time, timedelta, tzinfo
|
||||
from time import time as current_time
|
||||
from typing import Any, cast, Sequence, TypeGuard
|
||||
@@ -56,7 +57,11 @@ from superset.common.utils.time_range_utils import get_since_until_from_query_ob
|
||||
from superset.connectors.sqla.models import BaseDatasource
|
||||
from superset.constants import NO_TIME_RANGE
|
||||
from superset.models.helpers import QueryResult
|
||||
from superset.superset_typing import AdhocColumn
|
||||
from superset.semantic_layers.adhoc import (
|
||||
adhoc_column_to_semantic_dimension,
|
||||
adhoc_metric_to_semantic_metric,
|
||||
)
|
||||
from superset.superset_typing import AdhocColumn, AdhocMetric
|
||||
from superset.utils.core import (
|
||||
FilterOperator,
|
||||
QueryObjectFilterClause,
|
||||
@@ -258,16 +263,92 @@ def _normalize_column(column: str | AdhocColumn, dimension_names: set[str]) -> s
|
||||
- A string (dimension name directly)
|
||||
- An AdhocColumn with isColumnReference=True and sqlExpression containing the
|
||||
dimension name
|
||||
|
||||
Used by callers that only care about the resolved *name* (e.g. validating
|
||||
that a granularity column matches a real dimension). Synthetic adhocs
|
||||
(non-column-reference) carry no dimension name and should be routed
|
||||
through :func:`_resolve_dimension` instead.
|
||||
"""
|
||||
if isinstance(column, str):
|
||||
return column
|
||||
|
||||
# Handle column references (e.g., from time-series charts)
|
||||
# Column reference adhocs unwrap to their underlying column name.
|
||||
if column.get("isColumnReference") and (sql_expr := column.get("sqlExpression")):
|
||||
if sql_expr in dimension_names:
|
||||
return sql_expr
|
||||
|
||||
raise ValueError("Adhoc dimensions are not supported in Semantic Views.")
|
||||
# Synthetic adhoc: the label *is* the resolved name in semantic-layer
|
||||
# terms (we use it as the Dimension id/name when materialising). Falling
|
||||
# back here keeps validators happy with adhoc dicts and defers actual
|
||||
# resolution to ``_resolve_dimension``.
|
||||
if label := column.get("label"):
|
||||
return label
|
||||
|
||||
raise ValueError("Adhoc column is missing a ``label``.")
|
||||
|
||||
|
||||
def _resolve_dimension(
    column: str | AdhocColumn,
    dataset: BaseDatasource,
    dimensions_by_name: dict[str, Dimension],
    template_processor: Any | None,
) -> Dimension:
    """
    Turn a requested column (saved name or adhoc dict) into a ``Dimension``.

    A string is looked up in ``dimensions_by_name``; an unknown name raises
    ``ValueError``. Any non-string entry is handed to
    :func:`adhoc_column_to_semantic_dimension` together with the dataset, the
    saved dimensions and the template processor.
    """
    if not isinstance(column, str):
        return adhoc_column_to_semantic_dimension(
            column,
            dataset,
            dimensions_by_name,
            template_processor,
        )

    if column in dimensions_by_name:
        return dimensions_by_name[column]

    raise ValueError(f"Dimension {column!r} is not defined in the Semantic View.")
|
||||
|
||||
|
||||
def _resolve_metric(
    metric: str | AdhocMetric,
    dataset: BaseDatasource,
    metrics_by_name: dict[str, Metric],
    template_processor: Any | None,
) -> Metric:
    """
    Turn a requested metric (saved name or adhoc dict) into a ``Metric``.

    A string is looked up in ``metrics_by_name``; an unknown name raises
    ``ValueError``. Any non-string entry is handed to
    :func:`adhoc_metric_to_semantic_metric`.
    """
    if not isinstance(metric, str):
        return adhoc_metric_to_semantic_metric(metric, dataset, template_processor)

    if metric in metrics_by_name:
        return metrics_by_name[metric]

    raise ValueError(f"Metric {metric!r} is not defined in the Semantic View.")
|
||||
|
||||
|
||||
def _get_template_processor(dataset: BaseDatasource) -> Any | None:
    """
    Return ``dataset.get_template_processor()`` when the datasource has that
    attribute, and ``None`` otherwise (e.g. a cube-mode semantic view, where
    adhocs aren't supported anyway).
    """
    factory = getattr(dataset, "get_template_processor", None)
    return None if factory is None else factory()
|
||||
|
||||
|
||||
def _stamp_grain(dimension: Dimension, grain: Grain) -> Dimension:
    """Build a copy of ``dimension`` whose ``grain`` field is ``grain``."""
    # ``dataclasses.replace`` creates a new instance; the input is not mutated.
    regrained = dataclasses.replace(dimension, grain=grain)
    return regrained
|
||||
|
||||
|
||||
def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
|
||||
@@ -278,33 +359,54 @@ def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
|
||||
visualization and more on semantics.
|
||||
"""
|
||||
semantic_view = query_object.datasource.implementation
|
||||
dataset = query_object.datasource
|
||||
|
||||
all_metrics = {metric.name: metric for metric in semantic_view.metrics}
|
||||
all_dimensions = {
|
||||
dimension.name: dimension for dimension in semantic_view.dimensions
|
||||
}
|
||||
|
||||
# Normalize columns (may be dicts with isColumnReference=True for time-series)
|
||||
dimension_names = set(all_dimensions.keys())
|
||||
normalized_columns = {
|
||||
_normalize_column(column, dimension_names) for column in query_object.columns
|
||||
}
|
||||
template_processor = _get_template_processor(dataset)
|
||||
|
||||
metrics = [all_metrics[metric] for metric in (query_object.metrics or [])]
|
||||
metrics = [
|
||||
_resolve_metric(metric, dataset, all_metrics, template_processor)
|
||||
for metric in (query_object.metrics or [])
|
||||
]
|
||||
|
||||
# Resolve each requested column, preserving order and deduplicating by id.
|
||||
seen_dim_ids: set[str] = set()
|
||||
dimensions: list[Dimension] = []
|
||||
for column in query_object.columns:
|
||||
dim = _resolve_dimension(
|
||||
column,
|
||||
dataset,
|
||||
all_dimensions,
|
||||
template_processor,
|
||||
)
|
||||
if dim.id in seen_dim_ids:
|
||||
continue
|
||||
seen_dim_ids.add(dim.id)
|
||||
dimensions.append(dim)
|
||||
|
||||
grain = _convert_time_grain(query_object.extras.get("time_grain_sqla"))
|
||||
dimensions = [
|
||||
dimension
|
||||
for dimension in semantic_view.dimensions
|
||||
if dimension.name in normalized_columns
|
||||
and (
|
||||
# if a grain is specified, only include the time dimension if its grain
|
||||
# matches the requested grain
|
||||
grain is None
|
||||
or dimension.name != query_object.granularity
|
||||
or dimension.grain == grain
|
||||
)
|
||||
]
|
||||
if grain is not None and query_object.granularity:
|
||||
# Apply the requested grain to the granularity dimension. For cube-mode
|
||||
# views that pre-declare grain on their dimensions, keep the existing
|
||||
# "only include the matching-grain version" semantic. For dataset-mode
|
||||
# views (dimensions have grain=None), the view applies the grain via
|
||||
# the engine spec at SQL build time, so we just stamp the grain on
|
||||
# the dimension here.
|
||||
dimensions = [
|
||||
(
|
||||
_stamp_grain(dim, grain)
|
||||
if dim.name == query_object.granularity and dim.grain is None
|
||||
else dim
|
||||
)
|
||||
for dim in dimensions
|
||||
if dim.name != query_object.granularity
|
||||
or dim.grain is None
|
||||
or dim.grain == grain
|
||||
]
|
||||
|
||||
order = _get_order_from_query_object(query_object, all_metrics, all_dimensions)
|
||||
limit = query_object.row_limit
|
||||
@@ -928,32 +1030,53 @@ def validate_query_object(
|
||||
|
||||
def _validate_metrics(query_object: ValidatedQueryObject) -> None:
|
||||
"""
|
||||
Make sure metrics are defined in the semantic view.
|
||||
Make sure metrics are defined in the semantic view or are valid adhocs.
|
||||
|
||||
Validation of adhoc metrics is delegated to ``_resolve_metric`` — it
|
||||
raises if the SIMPLE shape is malformed or the SQL fails Jinja /
|
||||
safety checks.
|
||||
"""
|
||||
semantic_view = query_object.datasource.implementation
|
||||
dataset = query_object.datasource
|
||||
metrics_by_name = {metric.name: metric for metric in semantic_view.metrics}
|
||||
template_processor = _get_template_processor(dataset)
|
||||
|
||||
if any(not isinstance(metric, str) for metric in (query_object.metrics or [])):
|
||||
raise ValueError("Adhoc metrics are not supported in Semantic Views.")
|
||||
labels: list[str] = []
|
||||
for metric in query_object.metrics or []:
|
||||
resolved = _resolve_metric(
|
||||
metric, dataset, metrics_by_name, template_processor
|
||||
)
|
||||
labels.append(resolved.id)
|
||||
|
||||
metric_names = {metric.name for metric in semantic_view.metrics}
|
||||
if not set(query_object.metrics or []) <= metric_names:
|
||||
raise ValueError("All metrics must be defined in the Semantic View.")
|
||||
if len(labels) != len(set(labels)):
|
||||
raise ValueError(
|
||||
"Duplicate metric labels are not supported in Semantic Views."
|
||||
)
|
||||
|
||||
|
||||
def _validate_dimensions(query_object: ValidatedQueryObject) -> None:
|
||||
"""
|
||||
Make sure all dimensions are defined in the semantic view.
|
||||
Make sure all dimensions are defined in the semantic view or are valid
|
||||
adhocs. Synthesized adhoc columns are validated via ``_resolve_dimension``.
|
||||
"""
|
||||
semantic_view = query_object.datasource.implementation
|
||||
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
|
||||
dataset = query_object.datasource
|
||||
dimensions_by_name = {
|
||||
dimension.name: dimension for dimension in semantic_view.dimensions
|
||||
}
|
||||
template_processor = _get_template_processor(dataset)
|
||||
|
||||
# Normalize all columns to dimension names
|
||||
normalized_columns = [
|
||||
_normalize_column(column, dimension_names) for column in query_object.columns
|
||||
]
|
||||
labels: list[str] = []
|
||||
for column in query_object.columns:
|
||||
resolved = _resolve_dimension(
|
||||
column, dataset, dimensions_by_name, template_processor
|
||||
)
|
||||
labels.append(resolved.id)
|
||||
|
||||
if not set(normalized_columns) <= dimension_names:
|
||||
raise ValueError("All dimensions must be defined in the Semantic View.")
|
||||
if len(labels) != len(set(labels)):
|
||||
raise ValueError(
|
||||
"Duplicate column labels are not supported in Semantic Views."
|
||||
)
|
||||
|
||||
|
||||
def _validate_filters(query_object: ValidatedQueryObject) -> None:
|
||||
|
||||
120
superset/semantic_layers/rls.py
Normal file
120
superset/semantic_layers/rls.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
RLS helpers used by the dataset semantic-view path.
|
||||
|
||||
The legacy ``get_sqla_query`` applies row-level security in two places:
|
||||
|
||||
* The dataset's own RLS rules are AND-ed into the outer ``WHERE`` (or
|
||||
the source table is wrapped in a subquery when the engine spec
|
||||
returns ``RLSMethod.AS_SUBQUERY``).
|
||||
* For virtual datasets, RLS rules from the *underlying* tables
|
||||
referenced by ``dataset.sql`` are injected via
|
||||
:func:`superset.utils.rls.apply_rls`.
|
||||
|
||||
This module exposes both flows as pure functions returning SQL strings,
|
||||
which the ``DatasetSemanticView`` plugs into its sqlglot AST.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from superset.sql.parse import RLSMethod, SQLScript
|
||||
from superset.utils.rls import apply_rls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_rls_method(dataset: "SqlaTable") -> RLSMethod:
    """Ask the dataset's database engine spec which RLS method it uses."""
    engine_spec = dataset.database.db_engine_spec
    return engine_spec.get_rls_method()
|
||||
|
||||
|
||||
def render_rls_predicates(dataset: "SqlaTable") -> list[str]:
    """
    Compile the dataset's row-level-security clauses into SQL strings.

    The clauses come from :meth:`SqlaTable.get_sqla_row_level_filters`. Each
    one is compiled against the database's dialect with
    ``literal_binds=True`` and converted to ``str``.

    An empty list is returned when there are no clauses to render.
    """
    clauses = dataset.get_sqla_row_level_filters()
    if not clauses:
        return []

    dialect = dataset.database.get_dialect()
    rendered: list[str] = []
    for clause in clauses:
        compiled = clause.compile(
            dialect=dialect,
            compile_kwargs={"literal_binds": True},
        )
        rendered.append(str(compiled))
    return rendered
|
||||
|
||||
|
||||
def apply_rls_to_virtual_sql(dataset: "SqlaTable") -> str | None:
    """
    Inject RLS into the inner SQL of a virtual dataset.

    Every statement parsed from ``dataset.sql`` is passed to
    :func:`superset.utils.rls.apply_rls`. The re-formatted script is returned
    only when at least one call reported a change.

    ``None`` is returned when the dataset has no SQL (physical dataset), when
    parsing or RLS application raises (a warning is logged), or when no rule
    applied. In all of these cases the caller should use the original
    ``dataset.sql``.
    """
    if not dataset.sql:
        return None

    database = dataset.database
    try:
        script = SQLScript(dataset.sql, engine=database.db_engine_spec.engine)
    except Exception:  # pylint: disable=broad-except
        # Mirror the legacy behavior: a parse failure must not block the
        # query. The caller falls back to the original SQL, which means RLS
        # is not applied to the inner tables.
        logger.warning(
            "Failed to parse virtual dataset SQL for RLS application",
            exc_info=True,
        )
        return None

    default_schema = dataset.database.get_default_schema(dataset.catalog)
    try:
        # Materialise a list (rather than short-circuiting with ``any``) so
        # that ``apply_rls`` runs for every statement.
        outcomes = [
            apply_rls(
                dataset.database,
                dataset.catalog,
                dataset.schema or default_schema or "",
                statement,
                exclude_dataset_id=dataset.id,
            )
            for statement in script.statements
        ]
    except Exception:  # pylint: disable=broad-except
        logger.warning(
            "Failed to apply RLS to virtual dataset inner SQL",
            exc_info=True,
        )
        return None

    if any(outcomes):
        return script.format()
    return None
|
||||
270
tests/unit_tests/semantic_layers/adhoc_test.py
Normal file
270
tests/unit_tests/semantic_layers/adhoc_test.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from superset_core.semantic_layers.types import AggregationType, Dimension
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.semantic_layers.adhoc import (
|
||||
adhoc_column_to_semantic_dimension,
|
||||
adhoc_metric_to_semantic_metric,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def dataset() -> MagicMock:
    """Mocked dataset exposing only what the adhoc helpers are given here."""

    def _quote(name):
        # Identifier quoter — return double-quoted name.
        return f'"{name}"'

    def _echo(expression, **kwargs):  # noqa: U100
        # Echo the SQL back unchanged, standing in for "rendering".
        return expression

    stub = MagicMock()
    stub.database_id = 1
    stub.schema = "public"
    stub.database.db_engine_spec.engine = "postgresql"
    stub.quote_identifier = _quote
    stub._process_select_expression = _echo
    return stub
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adhoc metric
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_simple_metric_sum(dataset: MagicMock) -> None:
    """A SIMPLE SUM adhoc becomes a ``SUM("col")`` metric named after its label."""
    resolved = adhoc_metric_to_semantic_metric(
        {
            "expressionType": "SIMPLE",
            "label": "total_amount",
            "aggregate": "SUM",
            "column": {"column_name": "amount"},
        },
        dataset,
    )

    # The label doubles as both id and name.
    assert resolved.id == "total_amount"
    assert resolved.name == "total_amount"
    assert resolved.definition == 'SUM("amount")'
    assert resolved.aggregation == AggregationType.SUM
    # SUM(...) infers to float64.
    assert resolved.type == pa.float64()
|
||||
|
||||
|
||||
def test_simple_metric_count_distinct(dataset: MagicMock) -> None:
    """COUNT_DISTINCT renders as ``COUNT(DISTINCT ...)`` and keeps its aggregation."""
    resolved = adhoc_metric_to_semantic_metric(
        {
            "expressionType": "SIMPLE",
            "label": "unique_customers",
            "aggregate": "COUNT_DISTINCT",
            "column": {"column_name": "customer_id"},
        },
        dataset,
    )

    assert resolved.definition == 'COUNT(DISTINCT "customer_id")'
    assert resolved.aggregation == AggregationType.COUNT_DISTINCT
|
||||
|
||||
|
||||
def test_simple_metric_unknown_aggregate_falls_back_to_other(
    dataset: MagicMock,
) -> None:
    """An aggregate outside the rollup-safe map is reported as ``OTHER``."""
    resolved = adhoc_metric_to_semantic_metric(
        {
            "expressionType": "SIMPLE",
            "label": "median_x",
            "aggregate": "MEDIAN",
            "column": {"column_name": "x"},
        },
        dataset,
    )

    # MEDIAN isn't in the rollup-safe map.
    assert resolved.aggregation == AggregationType.OTHER
|
||||
|
||||
|
||||
def test_simple_metric_missing_aggregate_raises(dataset: MagicMock) -> None:
    """A SIMPLE adhoc without ``aggregate`` is rejected."""
    expected_failure = pytest.raises(
        QueryObjectValidationError,
        match="aggregate and a column",
    )
    with expected_failure:
        adhoc_metric_to_semantic_metric(
            {
                "expressionType": "SIMPLE",
                "label": "bad",
                "column": {"column_name": "amount"},
            },
            dataset,
        )
|
||||
|
||||
|
||||
def test_simple_metric_missing_column_raises(dataset: MagicMock) -> None:
    """A SIMPLE adhoc whose ``column`` dict is empty is rejected."""
    expected_failure = pytest.raises(
        QueryObjectValidationError,
        match="aggregate and a column",
    )
    with expected_failure:
        adhoc_metric_to_semantic_metric(
            {
                "expressionType": "SIMPLE",
                "label": "bad",
                "aggregate": "SUM",
                "column": {},
            },
            dataset,
        )
|
||||
|
||||
|
||||
def test_sql_metric(dataset: MagicMock) -> None:
    """A SQL adhoc keeps its expression as the definition, with ``OTHER`` aggregation."""
    resolved = adhoc_metric_to_semantic_metric(
        {
            "expressionType": "SQL",
            "label": "profit_margin",
            "sqlExpression": "SUM(revenue - cost) / SUM(revenue)",
        },
        dataset,
    )

    assert resolved.definition == "SUM(revenue - cost) / SUM(revenue)"
    assert resolved.aggregation == AggregationType.OTHER
|
||||
|
||||
|
||||
def test_sql_metric_missing_expression_raises(dataset: MagicMock) -> None:
    """A SQL adhoc without ``sqlExpression`` is rejected."""
    expected_failure = pytest.raises(QueryObjectValidationError, match="sqlExpression")
    with expected_failure:
        adhoc_metric_to_semantic_metric(
            {"expressionType": "SQL", "label": "bad"},
            dataset,
        )
|
||||
|
||||
|
||||
def test_metric_missing_label_raises(dataset: MagicMock) -> None:
    """An adhoc metric without a ``label`` is rejected."""
    expected_failure = pytest.raises(
        QueryObjectValidationError,
        match="missing a ``label``",
    )
    with expected_failure:
        adhoc_metric_to_semantic_metric(
            {"expressionType": "SIMPLE", "aggregate": "SUM"},
            dataset,
        )
|
||||
|
||||
|
||||
def test_metric_unknown_expression_type_raises(dataset: MagicMock) -> None:
    """An ``expressionType`` other than SIMPLE or SQL is rejected."""
    expected_failure = pytest.raises(
        QueryObjectValidationError,
        match="Unknown adhoc metric",
    )
    with expected_failure:
        adhoc_metric_to_semantic_metric(
            {"label": "weird", "expressionType": "MYSTERY"},
            dataset,
        )
|
||||
|
||||
|
||||
def test_sql_metric_jinja_applied(dataset: MagicMock) -> None:
    """The metric definition is whatever ``_process_select_expression`` returns."""
    # ``_process_select_expression`` is where Jinja and safety validation
    # live in the dataset model. We verify the helper is invoked.
    renderer = MagicMock(return_value="user_id = 42")
    dataset._process_select_expression = renderer

    resolved = adhoc_metric_to_semantic_metric(
        {
            "expressionType": "SQL",
            "label": "rendered",
            "sqlExpression": "user_id = {{ current_user_id() }}",
        },
        dataset,
        template_processor=MagicMock(),
    )

    assert resolved.definition == "user_id = 42"
    renderer.assert_called_once()
|
||||
|
||||
|
||||
def test_sql_metric_empty_processed_raises(dataset: MagicMock) -> None:
    """A SQL adhoc is rejected when ``_process_select_expression`` returns ``None``."""
    dataset._process_select_expression = MagicMock(return_value=None)

    expected_failure = pytest.raises(QueryObjectValidationError, match="empty string")
    with expected_failure:
        adhoc_metric_to_semantic_metric(
            {
                "expressionType": "SQL",
                "label": "bad",
                "sqlExpression": "{{ '' }}",
            },
            dataset,
            template_processor=MagicMock(),
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adhoc column
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_adhoc_column_reference_uses_existing_dimension(dataset: MagicMock) -> None:
    """A column-reference adhoc naming a saved dimension returns that same object."""
    saved = Dimension(id="country", name="country", type=pa.utf8())
    saved_dimensions = {"country": saved}

    resolved = adhoc_column_to_semantic_dimension(
        {
            "label": "country",
            "sqlExpression": "country",
            "isColumnReference": True,
        },
        dataset,
        saved_dimensions,
    )

    assert resolved is saved
|
||||
|
||||
|
||||
def test_adhoc_column_synthesises_dimension(dataset: MagicMock) -> None:
    """A free-form adhoc column becomes a new dimension named after its label."""
    dataset._process_select_expression = MagicMock(return_value="UPPER(country)")

    resolved = adhoc_column_to_semantic_dimension(
        {
            "label": "upper_country",
            "sqlExpression": "UPPER(country)",
        },
        dataset,
        {},
        template_processor=MagicMock(),
    )

    assert resolved.id == "upper_country"
    assert resolved.name == "upper_country"
    assert resolved.definition == "UPPER(country)"
    # UPPER(...) infers to utf8.
    assert resolved.type == pa.utf8()
|
||||
|
||||
|
||||
def test_adhoc_column_missing_label_raises(dataset: MagicMock) -> None:
    """An adhoc column without a ``label`` is rejected."""
    expected_failure = pytest.raises(QueryObjectValidationError, match="``label``")
    with expected_failure:
        adhoc_column_to_semantic_dimension(
            {"sqlExpression": "x"},
            dataset,
            {},
        )
|
||||
|
||||
|
||||
def test_adhoc_column_missing_sql_raises(dataset: MagicMock) -> None:
    """An adhoc column without a ``sqlExpression`` is rejected."""
    expected_failure = pytest.raises(
        QueryObjectValidationError,
        match="``sqlExpression``",
    )
    with expected_failure:
        adhoc_column_to_semantic_dimension(
            {"label": "x"},
            dataset,
            {},
        )
|
||||
|
||||
|
||||
def test_adhoc_column_reference_falls_back_when_not_matching(
    dataset: MagicMock,
) -> None:
    """
    When a column-reference adhoc names no saved dimension, it is synthesized
    like any other adhoc column.
    """
    dataset._process_select_expression = MagicMock(return_value="ghost")
    saved_dimensions = {
        "country": Dimension(id="country", name="country", type=pa.utf8()),
    }

    resolved = adhoc_column_to_semantic_dimension(
        {
            "label": "spooky",
            "sqlExpression": "ghost",
            "isColumnReference": True,
        },
        dataset,
        saved_dimensions,
        template_processor=MagicMock(),
    )

    assert resolved.id == "spooky"
    assert resolved.definition == "ghost"
|
||||
@@ -50,9 +50,14 @@ from superset.semantic_layers.mapper import (
|
||||
_get_group_limit_filters,
|
||||
_get_group_limit_from_query_object,
|
||||
_get_order_from_query_object,
|
||||
_get_template_processor,
|
||||
_get_time_bounds,
|
||||
_get_time_filter,
|
||||
_normalize_column,
|
||||
_resolve_dimension,
|
||||
_resolve_metric,
|
||||
_stamp_grain,
|
||||
_validate_dimensions,
|
||||
_validate_filters,
|
||||
_validate_granularity,
|
||||
_validate_group_limit,
|
||||
@@ -1035,7 +1040,10 @@ def test_validate_query_object_undefined_metric_error(
|
||||
columns=["order_date"],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="All metrics must be defined"):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Metric 'undefined_metric' is not defined in the Semantic View",
|
||||
):
|
||||
validate_query_object(query_object)
|
||||
|
||||
|
||||
@@ -1051,7 +1059,10 @@ def test_validate_query_object_undefined_dimension_error(
|
||||
columns=["undefined_dimension"],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="All dimensions must be defined"):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Dimension 'undefined_dimension' is not defined in the Semantic View",
|
||||
):
|
||||
validate_query_object(query_object)
|
||||
|
||||
|
||||
@@ -1800,28 +1811,31 @@ def test_get_results_empty_requests(
|
||||
|
||||
def test_normalize_column_adhoc_not_in_dimensions() -> None:
|
||||
"""
|
||||
Test _normalize_column raises error for AdhocColumn with sqlExpression not in dims.
|
||||
Adhoc columns whose sqlExpression doesn't match an existing dimension fall
|
||||
back to using the label as the resolved name. Actual SQL synthesis is the
|
||||
job of _resolve_dimension; _normalize_column only surfaces a name.
|
||||
"""
|
||||
dimension_names = {"category", "region"}
|
||||
adhoc_column: AdhocColumn = {
|
||||
"label": "custom_dim",
|
||||
"isColumnReference": True,
|
||||
"sqlExpression": "unknown_dimension",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Adhoc dimensions are not supported"):
|
||||
_normalize_column(adhoc_column, dimension_names)
|
||||
assert _normalize_column(adhoc_column, dimension_names) == "custom_dim"
|
||||
|
||||
|
||||
def test_normalize_column_adhoc_missing_sql_expression() -> None:
|
||||
def test_normalize_column_adhoc_missing_label_raises() -> None:
|
||||
"""
|
||||
Test _normalize_column raises error for AdhocColumn without sqlExpression.
|
||||
When neither a matching column reference nor a label is provided there's
|
||||
no resolvable name and _normalize_column raises.
|
||||
"""
|
||||
dimension_names = {"category", "region"}
|
||||
adhoc_column: AdhocColumn = {
|
||||
"isColumnReference": True,
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Adhoc dimensions are not supported"):
|
||||
with pytest.raises(ValueError, match="Adhoc column is missing"):
|
||||
_normalize_column(adhoc_column, dimension_names)
|
||||
|
||||
|
||||
@@ -2265,11 +2279,12 @@ def test_validate_query_object_no_datasource() -> None:
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_validate_metrics_adhoc_error(
|
||||
def test_validate_metrics_adhoc_with_bad_shape_raises(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test validation error for adhoc metrics.
|
||||
Adhoc metrics are now supported, but invalid shapes (missing
|
||||
expressionType, etc.) still raise via the adhoc resolver.
|
||||
"""
|
||||
mock_datasource = mocker.Mock()
|
||||
category_dim = Dimension("category", "category", pa.utf8(), "category", "Category")
|
||||
@@ -2279,13 +2294,15 @@ def test_validate_metrics_adhoc_error(
|
||||
|
||||
mock_datasource.implementation.dimensions = {category_dim}
|
||||
mock_datasource.implementation.metrics = {sales_metric}
|
||||
# Strip the template processor; we just want to verify the shape check.
|
||||
mock_datasource.get_template_processor.return_value = None
|
||||
|
||||
# Manually create a query object with an adhoc metric
|
||||
query_object = mocker.Mock()
|
||||
query_object.datasource = mock_datasource
|
||||
# Missing expressionType — the resolver doesn't know how to interpret this.
|
||||
query_object.metrics = [{"label": "adhoc", "sqlExpression": "SUM(x)"}]
|
||||
|
||||
with pytest.raises(ValueError, match="Adhoc metrics are not supported"):
|
||||
with pytest.raises(Exception, match="expressionType"):
|
||||
_validate_metrics(query_object)
|
||||
|
||||
|
||||
@@ -3111,3 +3128,171 @@ def test_coerce_time_invalid_string_raises() -> None:
|
||||
def test_coerce_time_rejects_other_types() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid time value"):
|
||||
_coerce_scalar_filter_value(123, _dim(pa.time64("us")))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adhoc + grain resolver coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_template_processor_returns_none_when_unsupported() -> None:
    """A bare object without ``get_template_processor`` returns None."""

    class Plain:
        """Empty class: deliberately exposes no ``get_template_processor``."""

    processor = _get_template_processor(Plain())
    assert processor is None
|
||||
|
||||
|
||||
def test_stamp_grain_returns_new_dimension_with_grain() -> None:
    """``_stamp_grain`` yields a dimension carrying the grain; the input keeps none."""
    original = Dimension(id="dt", name="dt", type=pa.timestamp("us"))
    with_grain = _stamp_grain(original, Grains.DAY)

    # The returned dimension has the requested grain ...
    assert with_grain.grain == Grains.DAY
    # ... while the dimension passed in is left untouched.
    assert original.grain is None
|
||||
|
||||
|
||||
def _adhoc_dataset() -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.database_id = 1
|
||||
ds.schema = "public"
|
||||
ds.database.db_engine_spec.engine = "postgresql"
|
||||
ds.quote_identifier = lambda name: f'"{name}"'
|
||||
ds._process_select_expression = lambda expression, **kwargs: expression # noqa: U100
|
||||
ds.get_template_processor.return_value = None
|
||||
return ds
|
||||
|
||||
|
||||
def test_resolve_metric_for_adhoc_simple_dict() -> None:
    """A SIMPLE adhoc metric resolves to an aggregate over the quoted column."""
    adhoc_metric = {
        "expressionType": "SIMPLE",
        "label": "total",
        "aggregate": "SUM",
        "column": {"column_name": "x"},
    }

    resolved = _resolve_metric(adhoc_metric, _adhoc_dataset(), {}, None)

    assert resolved.id == "total"
    assert resolved.definition == 'SUM("x")'
|
||||
|
||||
|
||||
def test_resolve_dimension_for_adhoc_dict() -> None:
    """An adhoc column dict resolves to a dimension named after its label."""
    adhoc_column = {"label": "calc", "sqlExpression": "UPPER(country)"}

    resolved = _resolve_dimension(adhoc_column, _adhoc_dataset(), {}, None)

    assert resolved.id == "calc"
    assert resolved.definition == "UPPER(country)"
|
||||
|
||||
|
||||
def test_validate_metrics_duplicate_label_raises() -> None:
    """Two adhoc metrics sharing one label are rejected by ``_validate_metrics``."""
    dataset = _adhoc_dataset()
    # The semantic view itself declares no metrics or dimensions.
    dataset.implementation = MagicMock(metrics=[], dimensions=[])

    # Same label "dup", different source columns.
    duplicated_metrics = [
        {
            "expressionType": "SIMPLE",
            "label": "dup",
            "aggregate": "SUM",
            "column": {"column_name": column_name},
        }
        for column_name in ("x", "y")
    ]
    query_object = MagicMock(datasource=dataset, metrics=duplicated_metrics)

    with pytest.raises(ValueError, match="Duplicate metric labels"):
        _validate_metrics(query_object)
|
||||
|
||||
|
||||
def test_validate_dimensions_duplicate_label_raises() -> None:
    """Two adhoc columns sharing one label are rejected by ``_validate_dimensions``."""
    dataset = _adhoc_dataset()
    # The semantic view itself declares no metrics or dimensions.
    dataset.implementation = MagicMock(metrics=[], dimensions=[])

    # Same label "dup", different SQL expressions.
    duplicated_columns = [
        {"label": "dup", "sqlExpression": expression} for expression in ("x", "y")
    ]
    query_object = MagicMock(datasource=dataset, columns=duplicated_columns)

    with pytest.raises(ValueError, match="Duplicate column labels"):
        _validate_dimensions(query_object)
|
||||
|
||||
|
||||
def test_map_query_object_dedups_dimensions_by_id(mocker: MockerFixture) -> None:
    """Two requested columns resolving to the same dim id collapse to one."""
    country_dim = Dimension(id="country", name="country", type=pa.utf8())

    dataset = _adhoc_dataset()
    dataset.implementation = MagicMock(metrics=[], dimensions=[country_dim])
    dataset.fetch_values_predicate = None

    # Same dim referenced twice — once as a string, once as a column ref.
    requested_columns = [
        "country",
        {"label": "country", "sqlExpression": "country", "isColumnReference": True},
    ]
    query_object = ValidatedQueryObject(
        datasource=dataset,
        metrics=[],
        columns=requested_columns,
    )

    semantic_queries = map_query_object(query_object)

    first_query = semantic_queries[0]
    assert len(first_query.dimensions) == 1
|
||||
|
||||
|
||||
def test_map_query_object_stamps_grain_on_granularity_dim(
    mocker: MockerFixture,
) -> None:
    """A grain request stamps the grain onto the matching dimension."""
    time_dim = Dimension(id="dt", name="dt", type=pa.timestamp("us"))

    dataset = _adhoc_dataset()
    dataset.implementation = MagicMock(metrics=[], dimensions=[time_dim])
    dataset.fetch_values_predicate = None

    # "dt" is both a requested column and the granularity column; P1D is the
    # requested time grain.
    query_object = ValidatedQueryObject(
        datasource=dataset,
        metrics=[],
        columns=["dt"],
        granularity="dt",
        extras={"time_grain_sqla": "P1D"},
    )

    semantic_queries = map_query_object(query_object)

    mapped_dim = semantic_queries[0].dimensions[0]
    assert mapped_dim.grain == Grains.DAY
|
||||
|
||||
|
||||
def test_normalize_column_adhoc_label_only() -> None:
    """Adhoc with no isColumnReference falls back to its label."""
    adhoc_column = {"label": "calc", "sqlExpression": "x"}
    normalized = _normalize_column(adhoc_column, set())
    assert normalized == "calc"
|
||||
|
||||
|
||||
def test_normalize_column_string_passthrough() -> None:
    """A plain string column name is returned unchanged."""
    known_dimensions = {"category", "region"}
    normalized = _normalize_column("category", known_dimensions)
    assert normalized == "category"
|
||||
|
||||
166
tests/unit_tests/semantic_layers/rls_test.py
Normal file
166
tests/unit_tests/semantic_layers/rls_test.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.semantic_layers.rls import (
|
||||
apply_rls_to_virtual_sql,
|
||||
get_rls_method,
|
||||
render_rls_predicates,
|
||||
)
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
sql: str | None = None,
|
||||
rls_clauses: list | None = None,
|
||||
) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = 1
|
||||
ds.catalog = None
|
||||
ds.schema = "public"
|
||||
ds.sql = sql
|
||||
ds.database.db_engine_spec.engine = "postgresql"
|
||||
ds.database.get_default_schema.return_value = "public"
|
||||
ds.get_sqla_row_level_filters.return_value = rls_clauses or []
|
||||
return ds
|
||||
|
||||
|
||||
def test_get_rls_method_delegates_to_engine_spec() -> None:
    """``get_rls_method`` returns whatever the engine spec reports."""
    dataset = _make_dataset()
    engine_spec = dataset.database.db_engine_spec
    engine_spec.get_rls_method.return_value = RLSMethod.AS_SUBQUERY

    method = get_rls_method(dataset)

    assert method == RLSMethod.AS_SUBQUERY
|
||||
|
||||
|
||||
def test_render_rls_predicates_empty_when_no_rules() -> None:
    """A dataset with no RLS clauses renders to an empty list."""
    dataset_without_rules = _make_dataset()
    rendered = render_rls_predicates(dataset_without_rules)
    assert rendered == []
|
||||
|
||||
|
||||
def test_render_rls_predicates_compiles_each_clause() -> None:
    """Every RLS clause is compiled, in order, with ``literal_binds`` enabled."""
    expected_sql = ["tenant_id = 7", "deleted = false"]

    # One mock clause per expected predicate; ``compile`` returns its SQL text.
    clauses = []
    for predicate in expected_sql:
        clause = MagicMock()
        clause.compile.return_value = predicate
        clauses.append(clause)

    dataset = _make_dataset(rls_clauses=clauses)

    rendered = render_rls_predicates(dataset)

    assert rendered == expected_sql
    # Each clause is compiled with literal_binds so values are inlined.
    for clause in clauses:
        compile_call = clause.compile.call_args
        assert compile_call.kwargs["compile_kwargs"] == {"literal_binds": True}
|
||||
|
||||
|
||||
def test_apply_rls_to_virtual_sql_returns_none_for_physical(mocker: MockerFixture) -> None:
    """A dataset whose ``sql`` is ``None`` yields ``None``."""
    physical_dataset = _make_dataset()
    rewritten = apply_rls_to_virtual_sql(physical_dataset)
    assert rewritten is None
|
||||
|
||||
|
||||
def test_apply_rls_to_virtual_sql_returns_none_when_no_rls_applied(
    mocker: MockerFixture,
) -> None:
    """When ``apply_rls`` reports nothing was injected, the result is ``None``."""
    # apply_rls returns False → no predicates were injected.
    mocker.patch("superset.semantic_layers.rls.apply_rls", return_value=False)
    virtual_dataset = _make_dataset(sql="SELECT * FROM raw_orders")

    rewritten = apply_rls_to_virtual_sql(virtual_dataset)

    assert rewritten is None
|
||||
|
||||
|
||||
def test_apply_rls_to_virtual_sql_returns_rewritten_when_applied(
    mocker: MockerFixture,
) -> None:
    """When ``apply_rls`` reports success, the formatted script is returned."""
    rewritten_sql = "SELECT * FROM raw_orders WHERE x = 1"
    virtual_dataset = _make_dataset(sql="SELECT * FROM raw_orders")

    # Stand-in for the parsed script: one statement, fixed formatted output.
    script = MagicMock()
    statement = MagicMock()
    script.statements = [statement]
    script.format.return_value = rewritten_sql

    mocker.patch("superset.semantic_layers.rls.SQLScript", return_value=script)
    mocker.patch("superset.semantic_layers.rls.apply_rls", return_value=True)

    outcome = apply_rls_to_virtual_sql(virtual_dataset)

    assert outcome == "SELECT * FROM raw_orders WHERE x = 1"
|
||||
|
||||
|
||||
def test_apply_rls_to_virtual_sql_swallows_parse_error(mocker: MockerFixture) -> None:
    """An exception raised while building the ``SQLScript`` results in ``None``."""
    parse_failure = Exception("parse error")
    mocker.patch(
        "superset.semantic_layers.rls.SQLScript",
        side_effect=parse_failure,
    )
    broken_dataset = _make_dataset(sql="totally invalid sql")

    outcome = apply_rls_to_virtual_sql(broken_dataset)

    assert outcome is None
|
||||
|
||||
|
||||
def test_apply_rls_to_virtual_sql_swallows_apply_error(mocker: MockerFixture) -> None:
    """An exception raised by ``apply_rls`` results in ``None``."""
    virtual_dataset = _make_dataset(sql="SELECT * FROM raw_orders")

    # Parsing succeeds with a single-statement stand-in script ...
    script = MagicMock()
    script.statements = [MagicMock()]
    mocker.patch("superset.semantic_layers.rls.SQLScript", return_value=script)
    # ... but applying RLS blows up.
    mocker.patch(
        "superset.semantic_layers.rls.apply_rls",
        side_effect=Exception("boom"),
    )

    outcome = apply_rls_to_virtual_sql(virtual_dataset)

    assert outcome is None
|
||||
|
||||
|
||||
def test_apply_rls_to_virtual_sql_uses_default_schema_when_dataset_schema_missing(
    mocker: MockerFixture,
) -> None:
    """With no dataset schema, the database default schema is given to ``apply_rls``."""
    ds = _make_dataset(sql="SELECT * FROM raw_orders")
    ds.schema = None
    fake_script = MagicMock()
    fake_statement = MagicMock()
    fake_script.statements = [fake_statement]
    mocker.patch(
        "superset.semantic_layers.rls.SQLScript",
        return_value=fake_script,
    )
    apply_rls_mock = mocker.patch(
        "superset.semantic_layers.rls.apply_rls",
        return_value=False,
    )

    apply_rls_to_virtual_sql(ds)

    # default_schema was "public" from get_default_schema; that should be the
    # schema arg passed to apply_rls.
    apply_rls_mock.assert_called_once()
    args, kwargs = apply_rls_mock.call_args
    # Accept the schema either as the third positional argument or as the
    # ``schema`` keyword. Guard the index so a keyword-only call is checked
    # via ``kwargs`` instead of failing with IndexError on ``args[2]``.
    positional_schema = args[2] if len(args) > 2 else None
    assert "public" in (positional_schema, kwargs.get("schema"))
|
||||
|
||||
|
||||
def test_render_rls_predicates_uses_dialect_for_compile() -> None:
    """``compile`` is invoked with a ``dialect=`` keyword argument."""
    rls_clause = MagicMock()
    rls_clause.compile.return_value = "x = 1"
    dataset = _make_dataset(rls_clauses=[rls_clause])

    render_rls_predicates(dataset)

    # Verify the dialect from the database is passed (not asserting identity,
    # just that ``dialect=`` got a value).
    compile_kwargs = rls_clause.compile.call_args.kwargs
    assert "dialect" in compile_kwargs
|
||||
Reference in New Issue
Block a user