mirror of
https://github.com/apache/superset.git
synced 2026-07-05 14:25:32 +00:00
Compare commits
4 Commits
chore/ci/s
...
builtin-sl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a749e3f31 | ||
|
|
439bcf8d79 | ||
|
|
ce7753b631 | ||
|
|
b9492f477b |
@@ -2529,6 +2529,13 @@ except ImportError:
|
|||||||
LOCAL_EXTENSIONS: list[str] = []
|
LOCAL_EXTENSIONS: list[str] = []
|
||||||
EXTENSIONS_PATH: str | None = None
|
EXTENSIONS_PATH: str | None = None
|
||||||
|
|
||||||
|
# When True, dataset queries are routed through the dataset semantic-layer
|
||||||
|
# extension (``superset/semantic_layers/extension``) instead of the legacy
|
||||||
|
# ``get_sqla_query`` path. The semantic view builds the SQL via sqlglot and
|
||||||
|
# the mapper handles the QueryObject → SemanticQuery translation. Falls back
|
||||||
|
# to the legacy path on any error.
|
||||||
|
USE_DATASET_SEMANTIC_VIEW: bool = False
|
||||||
|
|
||||||
# Default polling interval for tasks (seconds)
|
# Default polling interval for tasks (seconds)
|
||||||
TASK_ABORT_POLLING_DEFAULT_INTERVAL = 10
|
TASK_ABORT_POLLING_DEFAULT_INTERVAL = 10
|
||||||
|
|
||||||
|
|||||||
@@ -19,10 +19,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Hashable
|
from collections.abc import Hashable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, cast, Optional, Union
|
from typing import Any, Callable, cast, Optional, Union
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -71,6 +73,7 @@ from superset_core.common.models import Dataset as CoreDataset
|
|||||||
from superset import db, is_feature_enabled, security_manager
|
from superset import db, is_feature_enabled, security_manager
|
||||||
from superset.commands.dataset.exceptions import DatasetNotFoundError
|
from superset.commands.dataset.exceptions import DatasetNotFoundError
|
||||||
from superset.common.db_query_status import QueryStatus
|
from superset.common.db_query_status import QueryStatus
|
||||||
|
from superset.common.query_object import QueryObject
|
||||||
from superset.connectors.sqla.utils import (
|
from superset.connectors.sqla.utils import (
|
||||||
get_columns_description,
|
get_columns_description,
|
||||||
get_physical_table_metadata,
|
get_physical_table_metadata,
|
||||||
@@ -2139,6 +2142,67 @@ class SqlaTable(
|
|||||||
"""Returns a text clause using ExploreMixin implementation"""
|
"""Returns a text clause using ExploreMixin implementation"""
|
||||||
return ExploreMixin.text(self, clause)
|
return ExploreMixin.text(self, clause)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def implementation(self) -> Any:
|
||||||
|
"""
|
||||||
|
Expose the dataset as a ``SemanticView`` instance.
|
||||||
|
|
||||||
|
The dataset semantic-layer extension (``superset/semantic_layers/
|
||||||
|
extension``) wraps a ``SqlaTable`` in a ``DatasetSemanticView``,
|
||||||
|
translating the dataset's columns and metrics into semantic dimensions
|
||||||
|
and metrics. This property exists so the same mapper-based execution
|
||||||
|
path used by stored semantic views (``mapper.get_results``) can drive
|
||||||
|
plain dataset queries when ``USE_DATASET_SEMANTIC_VIEW`` is enabled.
|
||||||
|
"""
|
||||||
|
# The extension's backend lives outside the regular Python path; add it
|
||||||
|
# the first time we need it. The ``@semantic_layer`` decorator has
|
||||||
|
# already been monkey-patched by ``inject_semantic_layer_implementations``
|
||||||
|
# at app init, so importing here is safe at runtime.
|
||||||
|
extension_src = (
|
||||||
|
Path(__file__).resolve().parents[1]
|
||||||
|
/ "semantic_layers"
|
||||||
|
/ "extension"
|
||||||
|
/ "backend"
|
||||||
|
/ "src"
|
||||||
|
)
|
||||||
|
extension_src_str = str(extension_src)
|
||||||
|
if extension_src_str not in sys.path:
|
||||||
|
sys.path.insert(0, extension_src_str)
|
||||||
|
|
||||||
|
from preset_io.dataset_semantic_layer import DatasetSemanticView
|
||||||
|
|
||||||
|
return DatasetSemanticView(self)
|
||||||
|
|
||||||
|
def get_query_result(self, query_object: QueryObject) -> QueryResult:
|
||||||
|
"""
|
||||||
|
Route dataset queries through the dataset semantic-layer extension
|
||||||
|
when ``USE_DATASET_SEMANTIC_VIEW`` is enabled, otherwise fall back to
|
||||||
|
the legacy ``ExploreMixin`` path.
|
||||||
|
|
||||||
|
The mapper in ``superset.semantic_layers.mapper`` handles the
|
||||||
|
``QueryObject`` → ``SemanticQuery`` translation (filters, time range,
|
||||||
|
group limit, ordering, time offsets), so all we add here is the
|
||||||
|
feature-flag gate and a safety fallback for cases the semantic view
|
||||||
|
does not yet support.
|
||||||
|
"""
|
||||||
|
if current_app.config.get("USE_DATASET_SEMANTIC_VIEW"):
|
||||||
|
from superset.semantic_layers.mapper import get_results
|
||||||
|
|
||||||
|
try:
|
||||||
|
# ``mapper.get_results`` reads ``query_object.datasource.implementation``;
|
||||||
|
# ensure we are the datasource on the in-memory QueryObject.
|
||||||
|
if query_object.datasource is None:
|
||||||
|
query_object.datasource = self
|
||||||
|
return get_results(query_object)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
logger.warning(
|
||||||
|
"Semantic-view execution failed for dataset %s; falling "
|
||||||
|
"back to the legacy query path.",
|
||||||
|
self.table_name,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return ExploreMixin.get_query_result(self, query_object)
|
||||||
|
|
||||||
|
|
||||||
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
|
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
|
||||||
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
|
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
|
||||||
|
|||||||
197
superset/semantic_layers/adhoc.py
Normal file
197
superset/semantic_layers/adhoc.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Convert frontend "adhoc" metric / column dictionaries into semantic-layer
|
||||||
|
``Metric`` / ``Dimension`` objects.
|
||||||
|
|
||||||
|
Adhoc metrics and columns are user-defined SQL expressions that don't
|
||||||
|
correspond to a saved metric or physical column on the dataset. The legacy
|
||||||
|
``get_sqla_query`` path supports them through ``adhoc_metric_to_sqla`` /
|
||||||
|
``adhoc_column_to_sqla``; the semantic-layer path needs an equivalent so
|
||||||
|
that flipping ``USE_DATASET_SEMANTIC_VIEW`` doesn't regress charts that
|
||||||
|
use them.
|
||||||
|
|
||||||
|
The strategy is light: synthesize a ``Metric(definition=<sql>)`` /
|
||||||
|
``Dimension(definition=<sql>)`` and let the view's existing
|
||||||
|
``definition``-parsing path handle the rest. Jinja templating + safety
|
||||||
|
checks are delegated to the dataset's own
|
||||||
|
``_process_select_expression`` (already battle-tested).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
from superset_core.semantic_layers.types import (
|
||||||
|
AggregationType,
|
||||||
|
Dimension,
|
||||||
|
Metric,
|
||||||
|
)
|
||||||
|
|
||||||
|
from superset.exceptions import QueryObjectValidationError
|
||||||
|
from superset.semantic_layers.arrow_inference import infer_arrow_type
|
||||||
|
from superset.utils.core import AdhocMetricExpressionType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
from superset.jinja_context import BaseTemplateProcessor
|
||||||
|
|
||||||
|
|
||||||
|
# Map Superset's aggregate strings to the semantic-layer ``AggregationType``.
|
||||||
|
# Anything unmapped (median, percentile, …) falls back to OTHER, which the
|
||||||
|
# view treats as a non-rollup-safe metric.
|
||||||
|
_AGGREGATE_MAP: dict[str, AggregationType] = {
|
||||||
|
"COUNT": AggregationType.COUNT,
|
||||||
|
"SUM": AggregationType.SUM,
|
||||||
|
"AVG": AggregationType.AVG,
|
||||||
|
"MIN": AggregationType.MIN,
|
||||||
|
"MAX": AggregationType.MAX,
|
||||||
|
"COUNT_DISTINCT": AggregationType.COUNT_DISTINCT,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _process_adhoc_sql(
|
||||||
|
dataset: "SqlaTable",
|
||||||
|
sql_expression: str,
|
||||||
|
template_processor: "BaseTemplateProcessor | None",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Run Jinja + safety validation on a free-form SQL expression.
|
||||||
|
|
||||||
|
Delegates to the dataset's ``_process_select_expression`` so adhoc
|
||||||
|
columns and metrics go through the same Jinja rendering, subquery
|
||||||
|
rules, and clause sanitization as the legacy path.
|
||||||
|
"""
|
||||||
|
processed = dataset._process_select_expression( # noqa: SLF001
|
||||||
|
expression=sql_expression,
|
||||||
|
database_id=dataset.database_id,
|
||||||
|
engine=dataset.database.db_engine_spec.engine,
|
||||||
|
schema=dataset.schema or "",
|
||||||
|
template_processor=template_processor,
|
||||||
|
)
|
||||||
|
if not processed:
|
||||||
|
raise QueryObjectValidationError(
|
||||||
|
"Adhoc SQL expression resolved to an empty string."
|
||||||
|
)
|
||||||
|
return processed
|
||||||
|
|
||||||
|
|
||||||
|
def _simple_metric_definition(adhoc: dict[str, Any], dataset: "SqlaTable") -> str:
|
||||||
|
"""
|
||||||
|
Build the SQL for a SIMPLE adhoc metric: ``AGGREGATE(column)``.
|
||||||
|
|
||||||
|
Uses the database's identifier quoter so the column name is rendered
|
||||||
|
in the right dialect.
|
||||||
|
"""
|
||||||
|
aggregate = (adhoc.get("aggregate") or "").upper()
|
||||||
|
column = adhoc.get("column") or {}
|
||||||
|
column_name = column.get("column_name")
|
||||||
|
if not aggregate or not column_name:
|
||||||
|
raise QueryObjectValidationError(
|
||||||
|
"SIMPLE adhoc metric requires both an aggregate and a column."
|
||||||
|
)
|
||||||
|
if aggregate == "COUNT_DISTINCT":
|
||||||
|
return f"COUNT(DISTINCT {dataset.quote_identifier(column_name)})"
|
||||||
|
return f"{aggregate}({dataset.quote_identifier(column_name)})"
|
||||||
|
|
||||||
|
|
||||||
|
def adhoc_metric_to_semantic_metric(
|
||||||
|
adhoc: dict[str, Any],
|
||||||
|
dataset: "SqlaTable",
|
||||||
|
template_processor: "BaseTemplateProcessor | None" = None,
|
||||||
|
) -> Metric:
|
||||||
|
"""
|
||||||
|
Convert an ``AdhocMetric`` dict to a semantic-layer ``Metric``.
|
||||||
|
|
||||||
|
SIMPLE metrics build ``f"{aggregate}({column})"`` and tag the
|
||||||
|
aggregation type so the view can preserve rollup semantics. SQL
|
||||||
|
metrics use the user-provided ``sqlExpression`` (after Jinja +
|
||||||
|
safety validation) and report ``AggregationType.OTHER`` because the
|
||||||
|
expression is opaque.
|
||||||
|
|
||||||
|
Adhoc dtype is set to ``pa.null()`` — the view aliases the metric
|
||||||
|
by id and downstream chart code coerces at render time.
|
||||||
|
"""
|
||||||
|
label = adhoc.get("label")
|
||||||
|
if not label:
|
||||||
|
raise QueryObjectValidationError("Adhoc metric is missing a ``label``.")
|
||||||
|
|
||||||
|
expression_type = adhoc.get("expressionType")
|
||||||
|
if expression_type == AdhocMetricExpressionType.SIMPLE.value:
|
||||||
|
definition = _simple_metric_definition(adhoc, dataset)
|
||||||
|
# Templating is irrelevant for SIMPLE metrics (no user-supplied SQL),
|
||||||
|
# so the safety pass is skipped here.
|
||||||
|
aggregate = (adhoc.get("aggregate") or "").upper()
|
||||||
|
aggregation = _AGGREGATE_MAP.get(aggregate, AggregationType.OTHER)
|
||||||
|
elif expression_type == AdhocMetricExpressionType.SQL.value:
|
||||||
|
sql_expression = adhoc.get("sqlExpression")
|
||||||
|
if not sql_expression:
|
||||||
|
raise QueryObjectValidationError(
|
||||||
|
"SQL adhoc metric is missing ``sqlExpression``."
|
||||||
|
)
|
||||||
|
definition = _process_adhoc_sql(dataset, sql_expression, template_processor)
|
||||||
|
aggregation = AggregationType.OTHER
|
||||||
|
else:
|
||||||
|
raise QueryObjectValidationError(
|
||||||
|
f"Unknown adhoc metric expressionType: {expression_type!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return Metric(
|
||||||
|
id=label,
|
||||||
|
name=label,
|
||||||
|
type=infer_arrow_type(definition, dataset.database.db_engine_spec.engine),
|
||||||
|
definition=definition,
|
||||||
|
aggregation=aggregation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def adhoc_column_to_semantic_dimension(
|
||||||
|
adhoc: dict[str, Any],
|
||||||
|
dataset: "SqlaTable",
|
||||||
|
dimensions_by_name: dict[str, Dimension],
|
||||||
|
template_processor: "BaseTemplateProcessor | None" = None,
|
||||||
|
) -> Dimension:
|
||||||
|
"""
|
||||||
|
Convert an ``AdhocColumn`` dict to a semantic-layer ``Dimension``.
|
||||||
|
|
||||||
|
When the adhoc is a thin wrapper around a real column
|
||||||
|
(``isColumnReference=True`` and ``sqlExpression`` matches a known
|
||||||
|
dimension name), the existing ``Dimension`` is returned so that
|
||||||
|
metadata such as type and grain are preserved. Otherwise a fresh
|
||||||
|
``Dimension`` with ``definition=<rendered SQL>`` is synthesized.
|
||||||
|
"""
|
||||||
|
label = adhoc.get("label")
|
||||||
|
if not label:
|
||||||
|
raise QueryObjectValidationError("Adhoc column is missing a ``label``.")
|
||||||
|
sql_expression = adhoc.get("sqlExpression")
|
||||||
|
if not sql_expression:
|
||||||
|
raise QueryObjectValidationError(
|
||||||
|
"Adhoc column is missing ``sqlExpression``."
|
||||||
|
)
|
||||||
|
|
||||||
|
if adhoc.get("isColumnReference") and sql_expression in dimensions_by_name:
|
||||||
|
return dimensions_by_name[sql_expression]
|
||||||
|
|
||||||
|
definition = _process_adhoc_sql(dataset, sql_expression, template_processor)
|
||||||
|
return Dimension(
|
||||||
|
id=label,
|
||||||
|
name=label,
|
||||||
|
type=infer_arrow_type(definition, dataset.database.db_engine_spec.engine),
|
||||||
|
definition=definition,
|
||||||
|
)
|
||||||
37
superset/semantic_layers/extension/.gitignore
vendored
Normal file
37
superset/semantic_layers/extension/.gitignore
vendored
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Dependencies
|
||||||
|
node_modules/
|
||||||
|
|
||||||
|
# Build outputs from the frontend bundler. The top-level dist/ is checked in
|
||||||
|
# because this extension ships as an in-built bundle loaded via LOCAL_EXTENSIONS.
|
||||||
|
frontend/dist/
|
||||||
|
*.supx
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.egg-info/
|
||||||
|
.eggs/
|
||||||
|
*.egg
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
4
superset/semantic_layers/extension/backend/.coveragerc
Normal file
4
superset/semantic_layers/extension/backend/.coveragerc
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
[report]
|
||||||
|
exclude_lines =
|
||||||
|
pragma: no cover
|
||||||
|
if TYPE_CHECKING:
|
||||||
21
superset/semantic_layers/extension/backend/conftest.py
Normal file
21
superset/semantic_layers/extension/backend/conftest.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
Root conftest that patches the semantic_layer decorator before test collection.
|
||||||
|
|
||||||
|
The decorator in superset_core.semantic_layers.decorators is a placeholder that
|
||||||
|
raises NotImplementedError outside of a running Superset instance. We replace it
|
||||||
|
with a no-op passthrough so the decorated classes can be imported during tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import superset_core.semantic_layers.decorators as _dec
|
||||||
|
|
||||||
|
|
||||||
|
def _noop_semantic_layer(**kwargs):
|
||||||
|
"""Return the class unchanged."""
|
||||||
|
|
||||||
|
def wrapper(cls):
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
_dec.semantic_layer = _noop_semantic_layer
|
||||||
13
superset/semantic_layers/extension/backend/pyproject.toml
Normal file
13
superset/semantic_layers/extension/backend/pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
[project]
|
||||||
|
name = "preset_io-dataset_semantic_layer"
|
||||||
|
version = "0.1.0"
|
||||||
|
license = "All Rights Reserved"
|
||||||
|
dependencies = [
|
||||||
|
"sqlglot",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.apache_superset_extensions.build]
|
||||||
|
include = [
|
||||||
|
"src/preset_io/dataset_semantic_layer/**/*.py",
|
||||||
|
]
|
||||||
|
exclude = []
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
from .layer import DatasetSemanticLayer
|
||||||
|
from .schemas import DatasetConfiguration
|
||||||
|
from .view import DatasetSemanticView
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DatasetConfiguration",
|
||||||
|
"DatasetSemanticLayer",
|
||||||
|
"DatasetSemanticView",
|
||||||
|
]
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
# The @semantic_layer decorator on DatasetSemanticLayer handles registration
|
||||||
|
# automatically. This import triggers the decorator at extension load time.
|
||||||
|
from .layer import DatasetSemanticLayer # noqa: F401
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from superset_core.semantic_layers.config import build_configuration_schema
|
||||||
|
from superset_core.semantic_layers.decorators import semantic_layer
|
||||||
|
from superset_core.semantic_layers.layer import SemanticLayer
|
||||||
|
|
||||||
|
from .schemas import DatasetConfiguration
|
||||||
|
from .utils import dataset_label, get_dataset_by_id, list_datasets
|
||||||
|
from .view import DatasetSemanticView
|
||||||
|
|
||||||
|
|
||||||
|
@semantic_layer(
|
||||||
|
id="dataset",
|
||||||
|
name="Dataset Semantic Layer",
|
||||||
|
description=(
|
||||||
|
"Expose any Superset dataset as a semantic view. Metrics and columns "
|
||||||
|
"defined on the dataset become the semantic metrics and dimensions."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
class DatasetSemanticLayer(SemanticLayer[DatasetConfiguration, DatasetSemanticView]):
|
||||||
|
configuration_class = DatasetConfiguration
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_configuration(
|
||||||
|
cls,
|
||||||
|
configuration: dict[str, Any],
|
||||||
|
) -> "DatasetSemanticLayer":
|
||||||
|
config = DatasetConfiguration.model_validate(configuration or {})
|
||||||
|
return cls(config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_configuration_schema(
|
||||||
|
cls,
|
||||||
|
configuration: DatasetConfiguration | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
No configuration is needed to add this semantic layer — the layer wraps
|
||||||
|
whatever datasets are already registered in Superset.
|
||||||
|
"""
|
||||||
|
return build_configuration_schema(DatasetConfiguration, configuration)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_runtime_schema(
|
||||||
|
cls,
|
||||||
|
configuration: DatasetConfiguration,
|
||||||
|
runtime_data: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
When adding a semantic view, the user picks a dataset from a dropdown.
|
||||||
|
Each existing dataset can be wrapped into a semantic view.
|
||||||
|
"""
|
||||||
|
datasets = list_datasets()
|
||||||
|
ids = [str(dataset.id) for dataset in datasets]
|
||||||
|
labels = [dataset_label(dataset) for dataset in datasets]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["dataset_id"],
|
||||||
|
"x-singleView": True,
|
||||||
|
"properties": {
|
||||||
|
"dataset_id": {
|
||||||
|
"type": "string",
|
||||||
|
"title": "Dataset",
|
||||||
|
"description": "The Superset dataset to expose as a semantic view.",
|
||||||
|
"enum": ids,
|
||||||
|
"x-enumNames": labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, configuration: DatasetConfiguration) -> None:
|
||||||
|
self.configuration = configuration
|
||||||
|
|
||||||
|
def get_semantic_views(
|
||||||
|
self,
|
||||||
|
runtime_configuration: dict[str, Any],
|
||||||
|
) -> set[DatasetSemanticView]:
|
||||||
|
dataset_id = runtime_configuration.get("dataset_id")
|
||||||
|
if not dataset_id:
|
||||||
|
return set()
|
||||||
|
dataset = get_dataset_by_id(int(dataset_id))
|
||||||
|
return {DatasetSemanticView(dataset)}
|
||||||
|
|
||||||
|
def get_semantic_view(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
additional_configuration: dict[str, Any],
|
||||||
|
) -> DatasetSemanticView:
|
||||||
|
dataset_id = additional_configuration.get("dataset_id")
|
||||||
|
if not dataset_id:
|
||||||
|
raise ValueError("dataset_id is required to load a semantic view")
|
||||||
|
dataset = get_dataset_by_id(int(dataset_id))
|
||||||
|
return DatasetSemanticView(dataset)
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetConfiguration(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration for the dataset semantic layer.
|
||||||
|
|
||||||
|
The layer wraps Superset datasets directly, so it does not need any
|
||||||
|
configuration beyond what Superset already knows about each dataset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(title="Dataset semantic layer", extra="ignore")
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import pandas as pd
|
||||||
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
|
||||||
|
|
||||||
|
def list_datasets() -> list["SqlaTable"]:
|
||||||
|
"""
|
||||||
|
Return every ``SqlaTable`` the current user can access, sorted by name.
|
||||||
|
|
||||||
|
Read inside the function so the module can be imported in unit tests that
|
||||||
|
do not have a Flask app context — these helpers are only called at runtime.
|
||||||
|
"""
|
||||||
|
from superset import db
|
||||||
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
|
||||||
|
# ``no_autoflush`` prevents SQLAlchemy from flushing pending session state
|
||||||
|
# while we iterate over results — flushing during iteration of the
|
||||||
|
# session's pending deque has been observed to raise "deque mutated
|
||||||
|
# during iteration" in this codebase.
|
||||||
|
with db.session.no_autoflush:
|
||||||
|
datasets = db.session.query(SqlaTable).all()
|
||||||
|
return sorted(datasets, key=lambda d: d.table_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_by_id(dataset_id: int) -> "SqlaTable":
|
||||||
|
"""
|
||||||
|
Look up a dataset by primary key, eagerly materialising its columns and
|
||||||
|
metrics so downstream code can iterate them without triggering further
|
||||||
|
lazy loads (and without risking autoflush mid-iteration).
|
||||||
|
"""
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from superset import db
|
||||||
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
|
||||||
|
with db.session.no_autoflush:
|
||||||
|
dataset = (
|
||||||
|
db.session.query(SqlaTable)
|
||||||
|
.options(
|
||||||
|
selectinload(SqlaTable.columns),
|
||||||
|
selectinload(SqlaTable.metrics),
|
||||||
|
)
|
||||||
|
.filter_by(id=dataset_id)
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
if dataset is None:
|
||||||
|
raise ValueError(f"Dataset with id {dataset_id} does not exist.")
|
||||||
|
# Force materialisation inside the no_autoflush block so later access
|
||||||
|
# never lazy-loads while another flush is in progress.
|
||||||
|
list(dataset.columns)
|
||||||
|
list(dataset.metrics)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_label(dataset: "SqlaTable") -> str:
|
||||||
|
"""
|
||||||
|
Human-readable label for a dataset combining schema and table name.
|
||||||
|
"""
|
||||||
|
parts = [
|
||||||
|
part
|
||||||
|
for part in (dataset.catalog, dataset.schema, dataset.table_name)
|
||||||
|
if part
|
||||||
|
]
|
||||||
|
return ".".join(parts) if parts else dataset.table_name
|
||||||
|
|
||||||
|
|
||||||
|
def df_to_arrow(df: "pd.DataFrame") -> pa.Table:
|
||||||
|
"""
|
||||||
|
Convert a pandas DataFrame to a pyarrow Table, falling back gracefully when
|
||||||
|
the DataFrame is empty.
|
||||||
|
"""
|
||||||
|
if df is None or df.empty:
|
||||||
|
return pa.table({col: [] for col in (df.columns if df is not None else [])})
|
||||||
|
return pa.Table.from_pandas(df, preserve_index=False)
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_literal(value: Any) -> Any:
|
||||||
|
"""
|
||||||
|
Coerce filter literal values to native Python types sqlglot can serialise.
|
||||||
|
"""
|
||||||
|
if isinstance(value, (set, frozenset)):
|
||||||
|
return [coerce_literal(v) for v in value]
|
||||||
|
return value
|
||||||
@@ -0,0 +1,840 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
import sqlglot
|
||||||
|
from sqlglot import expressions as sqlglot_exp
|
||||||
|
from superset_core.semantic_layers.types import (
|
||||||
|
AdhocExpression,
|
||||||
|
AggregationType,
|
||||||
|
Dimension,
|
||||||
|
Filter,
|
||||||
|
GroupLimit,
|
||||||
|
Metric,
|
||||||
|
Operator,
|
||||||
|
OrderDirection,
|
||||||
|
PredicateType,
|
||||||
|
SemanticQuery,
|
||||||
|
SemanticRequest,
|
||||||
|
SemanticResult,
|
||||||
|
)
|
||||||
|
from superset_core.semantic_layers.view import SemanticView, SemanticViewFeature
|
||||||
|
|
||||||
|
from .utils import coerce_literal, df_to_arrow
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||||
|
|
||||||
|
REQUEST_TYPE = "sql"
|
||||||
|
|
||||||
|
# Map Superset's GenericDataType (an IntEnum value) to a pyarrow type. Kept
|
||||||
|
# narrow on purpose — the dataset model only tracks a coarse generic type, so
|
||||||
|
# anything more specific belongs to the underlying database and is left as
|
||||||
|
# UTF-8 at the semantic layer boundary.
|
||||||
|
_GENERIC_TYPE_TO_ARROW: dict[int, pa.DataType] = {
|
||||||
|
0: pa.float64(), # NUMERIC
|
||||||
|
1: pa.utf8(), # STRING
|
||||||
|
2: pa.timestamp("us"), # TEMPORAL
|
||||||
|
3: pa.bool_(), # BOOLEAN
|
||||||
|
}
|
||||||
|
|
||||||
|
# Top-level aggregate AST node → AggregationType. Compound expressions like
|
||||||
|
# ``SUM(a) + SUM(b)`` fall through to OTHER because they cannot be safely
|
||||||
|
# rolled up.
|
||||||
|
_AGG_NODE_MAP: dict[type[sqlglot_exp.Expression], AggregationType] = {
|
||||||
|
sqlglot_exp.Sum: AggregationType.SUM,
|
||||||
|
sqlglot_exp.Min: AggregationType.MIN,
|
||||||
|
sqlglot_exp.Max: AggregationType.MAX,
|
||||||
|
sqlglot_exp.Avg: AggregationType.AVG,
|
||||||
|
}
|
||||||
|
|
||||||
|
_OPERATOR_TO_SQLGLOT: dict[Operator, type[sqlglot_exp.Expression]] = {
|
||||||
|
Operator.EQUALS: sqlglot_exp.EQ,
|
||||||
|
Operator.NOT_EQUALS: sqlglot_exp.NEQ,
|
||||||
|
Operator.GREATER_THAN: sqlglot_exp.GT,
|
||||||
|
Operator.LESS_THAN: sqlglot_exp.LT,
|
||||||
|
Operator.GREATER_THAN_OR_EQUAL: sqlglot_exp.GTE,
|
||||||
|
Operator.LESS_THAN_OR_EQUAL: sqlglot_exp.LTE,
|
||||||
|
Operator.LIKE: sqlglot_exp.Like,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetSemanticView(SemanticView):
|
||||||
|
features = frozenset(
|
||||||
|
{
|
||||||
|
SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY,
|
||||||
|
SemanticViewFeature.GROUP_LIMIT,
|
||||||
|
SemanticViewFeature.GROUP_OTHERS,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, dataset: "SqlaTable") -> None:
|
||||||
|
self.dataset = dataset
|
||||||
|
self.name = dataset.table_name
|
||||||
|
|
||||||
|
self.dimensions = self.get_dimensions()
|
||||||
|
self.metrics = self.get_metrics()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Identity & dialect
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def uid(self) -> str:
|
||||||
|
return f"dataset:{self.dataset.id}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _sqlglot_dialect(self) -> str | None:
|
||||||
|
# Imported lazily — ``superset.sql.parse`` pulls in a heavy dependency
|
||||||
|
# graph that is not needed for unit tests of pure AST building.
|
||||||
|
from superset.sql.parse import SQLGLOT_DIALECTS
|
||||||
|
|
||||||
|
engine = self.dataset.database.db_engine_spec.engine
|
||||||
|
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||||
|
return dialect.value if dialect is not None else None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Metadata
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_dimensions(self) -> set[Dimension]:
|
||||||
|
dimensions: set[Dimension] = set()
|
||||||
|
for column in self.dataset.columns:
|
||||||
|
if not column.groupby:
|
||||||
|
continue
|
||||||
|
dimensions.add(
|
||||||
|
Dimension(
|
||||||
|
id=column.column_name,
|
||||||
|
name=column.verbose_name or column.column_name,
|
||||||
|
type=self._column_arrow_type(column),
|
||||||
|
definition=column.expression or None,
|
||||||
|
description=column.description or None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return dimensions
|
||||||
|
|
||||||
|
def get_metrics(self) -> set[Metric]:
|
||||||
|
from superset.semantic_layers.arrow_inference import infer_arrow_type
|
||||||
|
|
||||||
|
engine = self.dataset.database.db_engine_spec.engine
|
||||||
|
metrics: set[Metric] = set()
|
||||||
|
for metric in self.dataset.metrics:
|
||||||
|
metrics.add(
|
||||||
|
Metric(
|
||||||
|
id=metric.metric_name,
|
||||||
|
name=metric.verbose_name or metric.metric_name,
|
||||||
|
type=infer_arrow_type(metric.expression, engine),
|
||||||
|
definition=metric.expression,
|
||||||
|
description=metric.description or None,
|
||||||
|
aggregation=self._aggregation_from_expression(metric.expression),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def get_compatible_metrics(
|
||||||
|
self,
|
||||||
|
selected_metrics: set[Metric],
|
||||||
|
selected_dimensions: set[Dimension],
|
||||||
|
) -> set[Metric]:
|
||||||
|
return self.metrics
|
||||||
|
|
||||||
|
def get_compatible_dimensions(
|
||||||
|
self,
|
||||||
|
selected_metrics: set[Metric],
|
||||||
|
selected_dimensions: set[Dimension],
|
||||||
|
) -> set[Dimension]:
|
||||||
|
return self.dimensions
|
||||||
|
|
||||||
|
def _column_arrow_type(self, column: "TableColumn") -> pa.DataType:
|
||||||
|
generic = column.type_generic
|
||||||
|
if generic is None:
|
||||||
|
return pa.utf8()
|
||||||
|
return _GENERIC_TYPE_TO_ARROW.get(int(generic), pa.utf8())
|
||||||
|
|
||||||
|
def _aggregation_from_expression(
|
||||||
|
self,
|
||||||
|
expression: str | None,
|
||||||
|
) -> AggregationType:
|
||||||
|
if not expression:
|
||||||
|
return AggregationType.OTHER
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = sqlglot.parse_one(expression, dialect=self._sqlglot_dialect)
|
||||||
|
except sqlglot.errors.ParseError:
|
||||||
|
return AggregationType.OTHER
|
||||||
|
|
||||||
|
if isinstance(parsed, sqlglot_exp.Count):
|
||||||
|
if isinstance(parsed.this, sqlglot_exp.Distinct):
|
||||||
|
return AggregationType.COUNT_DISTINCT
|
||||||
|
return AggregationType.COUNT
|
||||||
|
|
||||||
|
for node_type, aggregation in _AGG_NODE_MAP.items():
|
||||||
|
if isinstance(parsed, node_type):
|
||||||
|
return aggregation
|
||||||
|
|
||||||
|
return AggregationType.OTHER
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# AST building
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _dimension_column(self, dimension: Dimension) -> "TableColumn":
|
||||||
|
# Callers (currently only ``_dimension_expression``) check membership
|
||||||
|
# before invoking this helper, so the match is guaranteed.
|
||||||
|
for column in self.dataset.columns:
|
||||||
|
if column.column_name == dimension.id:
|
||||||
|
return column
|
||||||
|
raise ValueError( # pragma: no cover - guarded by caller
|
||||||
|
f'Dimension "{dimension.id}" is not part of this dataset.'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _metric_column(self, metric: Metric) -> "SqlMetric":
|
||||||
|
for sql_metric in self.dataset.metrics:
|
||||||
|
if sql_metric.metric_name == metric.id:
|
||||||
|
return sql_metric
|
||||||
|
raise ValueError(f'Metric "{metric.id}" is not part of this dataset.')
|
||||||
|
|
||||||
|
def _dimension_expression(self, dimension: Dimension) -> sqlglot_exp.Expression:
|
||||||
|
# Synthesised adhoc dimensions don't correspond to a dataset column;
|
||||||
|
# their ``definition`` carries the SQL we should emit.
|
||||||
|
dataset_column_names = {col.column_name for col in self.dataset.columns}
|
||||||
|
if dimension.id not in dataset_column_names:
|
||||||
|
if not dimension.definition:
|
||||||
|
raise ValueError(
|
||||||
|
f'Dimension "{dimension.id}" is not part of this dataset.'
|
||||||
|
)
|
||||||
|
base: sqlglot_exp.Expression = sqlglot.parse_one(
|
||||||
|
dimension.definition,
|
||||||
|
dialect=self._sqlglot_dialect,
|
||||||
|
)
|
||||||
|
return self._apply_grain(dimension, base)
|
||||||
|
|
||||||
|
column = self._dimension_column(dimension)
|
||||||
|
if column.expression:
|
||||||
|
base = sqlglot.parse_one(
|
||||||
|
column.expression,
|
||||||
|
dialect=self._sqlglot_dialect,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
base = sqlglot_exp.column(column.column_name, quoted=True)
|
||||||
|
return self._apply_grain(dimension, base)
|
||||||
|
|
||||||
|
def _apply_grain(
|
||||||
|
self,
|
||||||
|
dimension: Dimension,
|
||||||
|
base_expr: sqlglot_exp.Expression,
|
||||||
|
) -> sqlglot_exp.Expression:
|
||||||
|
"""
|
||||||
|
Wrap ``base_expr`` in the engine spec's time-grain template when the
|
||||||
|
dimension carries a grain. Engines that don't model the requested
|
||||||
|
grain return the unwrapped expression — same fallback the legacy path
|
||||||
|
uses.
|
||||||
|
"""
|
||||||
|
if dimension.grain is None:
|
||||||
|
return base_expr
|
||||||
|
|
||||||
|
engine_spec = self.dataset.database.db_engine_spec
|
||||||
|
template = engine_spec.get_time_grain_expressions().get(
|
||||||
|
dimension.grain.representation
|
||||||
|
)
|
||||||
|
if not template:
|
||||||
|
return base_expr
|
||||||
|
|
||||||
|
# ``{func}`` / ``{type}`` placeholders (BigQuery, MS SQL) require the
|
||||||
|
# engine spec's full ``get_timestamp_expr`` machinery to resolve.
|
||||||
|
if "{func}" in template or "{type}" in template:
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
sa_col = sa.column(
|
||||||
|
base_expr.sql(dialect=self._sqlglot_dialect),
|
||||||
|
is_literal=True,
|
||||||
|
)
|
||||||
|
tse = engine_spec.get_timestamp_expr(
|
||||||
|
sa_col,
|
||||||
|
None,
|
||||||
|
dimension.grain.representation,
|
||||||
|
)
|
||||||
|
sql_str = str(
|
||||||
|
tse.compile(
|
||||||
|
dialect=self.dataset.database.get_dialect(),
|
||||||
|
compile_kwargs={"literal_binds": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return sqlglot.parse_one(sql_str, dialect=self._sqlglot_dialect)
|
||||||
|
|
||||||
|
col_sql = base_expr.sql(dialect=self._sqlglot_dialect)
|
||||||
|
grain_sql = template.replace("{col}", col_sql)
|
||||||
|
return sqlglot.parse_one(grain_sql, dialect=self._sqlglot_dialect)
|
||||||
|
|
||||||
|
def _metric_expression(self, metric: Metric) -> sqlglot_exp.Expression:
|
||||||
|
sql_metric = self._metric_column(metric)
|
||||||
|
return sqlglot.parse_one(sql_metric.expression, dialect=self._sqlglot_dialect)
|
||||||
|
|
||||||
|
def _source_table(self) -> sqlglot_exp.Expression:
|
||||||
|
"""
|
||||||
|
Build the FROM clause. Physical datasets reference the table directly;
|
||||||
|
virtual datasets wrap their SQL as a subquery so the rest of the AST
|
||||||
|
can treat the source like a table.
|
||||||
|
|
||||||
|
For virtual datasets, RLS rules from the underlying tables referenced
|
||||||
|
in ``dataset.sql`` are injected into the inner SQL. When the engine
|
||||||
|
spec asks for ``AS_SUBQUERY``-style RLS, the dataset's own RLS
|
||||||
|
predicates wrap the source in an extra ``SELECT * FROM … WHERE`` so
|
||||||
|
the rest of the AST is unaffected.
|
||||||
|
"""
|
||||||
|
dataset = self.dataset
|
||||||
|
|
||||||
|
if dataset.sql:
|
||||||
|
inner_sql = self._rls_apply_to_virtual_sql() or dataset.sql
|
||||||
|
inner = sqlglot.parse_one(inner_sql, dialect=self._sqlglot_dialect)
|
||||||
|
source: sqlglot_exp.Expression = sqlglot_exp.Subquery(
|
||||||
|
this=inner, alias="virtual_dataset"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts = [
|
||||||
|
sqlglot_exp.to_identifier(part, quoted=True)
|
||||||
|
for part in (dataset.catalog, dataset.schema, dataset.table_name)
|
||||||
|
if part
|
||||||
|
]
|
||||||
|
if len(parts) == 1:
|
||||||
|
source = sqlglot_exp.Table(this=parts[0])
|
||||||
|
elif len(parts) == 2:
|
||||||
|
source = sqlglot_exp.Table(this=parts[1], db=parts[0])
|
||||||
|
else:
|
||||||
|
source = sqlglot_exp.Table(this=parts[2], db=parts[1], catalog=parts[0])
|
||||||
|
|
||||||
|
if self._rls_should_wrap_subquery():
|
||||||
|
rls_clauses = self._rls_predicates()
|
||||||
|
if rls_clauses:
|
||||||
|
source = self._wrap_with_rls(source, rls_clauses)
|
||||||
|
|
||||||
|
return source
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# RLS
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _rls_predicates(self) -> list[str]:
|
||||||
|
"""Outer RLS predicates for the dataset, rendered as SQL strings."""
|
||||||
|
from superset.semantic_layers.rls import render_rls_predicates
|
||||||
|
|
||||||
|
return render_rls_predicates(self.dataset)
|
||||||
|
|
||||||
|
def _rls_apply_to_virtual_sql(self) -> str | None:
|
||||||
|
from superset.semantic_layers.rls import apply_rls_to_virtual_sql
|
||||||
|
|
||||||
|
return apply_rls_to_virtual_sql(self.dataset)
|
||||||
|
|
||||||
|
def _rls_should_wrap_subquery(self) -> bool:
|
||||||
|
from superset.semantic_layers.rls import get_rls_method
|
||||||
|
from superset.sql.parse import RLSMethod
|
||||||
|
|
||||||
|
return get_rls_method(self.dataset) == RLSMethod.AS_SUBQUERY
|
||||||
|
|
||||||
|
def _wrap_with_rls(
|
||||||
|
self,
|
||||||
|
source: sqlglot_exp.Expression,
|
||||||
|
rls_clauses: list[str],
|
||||||
|
) -> sqlglot_exp.Subquery:
|
||||||
|
"""Wrap ``source`` in ``SELECT * FROM source WHERE <AND-ed RLS>``."""
|
||||||
|
parsed = [
|
||||||
|
sqlglot.parse_one(clause, dialect=self._sqlglot_dialect)
|
||||||
|
for clause in rls_clauses
|
||||||
|
]
|
||||||
|
combined = parsed[0]
|
||||||
|
for predicate in parsed[1:]:
|
||||||
|
combined = sqlglot_exp.And(this=combined, expression=predicate)
|
||||||
|
inner = (
|
||||||
|
sqlglot_exp.Select()
|
||||||
|
.select(sqlglot_exp.Star())
|
||||||
|
.from_(source)
|
||||||
|
.where(combined)
|
||||||
|
)
|
||||||
|
return sqlglot_exp.Subquery(this=inner, alias="rls_wrapped")
|
||||||
|
|
||||||
|
def _filter_predicate(self, filter_: Filter) -> sqlglot_exp.Expression:
|
||||||
|
if filter_.operator == Operator.ADHOC:
|
||||||
|
if not isinstance(filter_.value, str):
|
||||||
|
raise ValueError("Adhoc filter value must be a SQL string")
|
||||||
|
return sqlglot.parse_one(filter_.value, dialect=self._sqlglot_dialect)
|
||||||
|
|
||||||
|
if filter_.column is None:
|
||||||
|
raise ValueError("Native filters require a column")
|
||||||
|
|
||||||
|
if isinstance(filter_.column, Dimension):
|
||||||
|
column_expr = self._dimension_expression(filter_.column)
|
||||||
|
else:
|
||||||
|
column_expr = self._metric_expression(filter_.column)
|
||||||
|
|
||||||
|
operator = filter_.operator
|
||||||
|
|
||||||
|
if operator == Operator.IS_NULL:
|
||||||
|
return sqlglot_exp.Is(this=column_expr, expression=sqlglot_exp.Null())
|
||||||
|
if operator == Operator.IS_NOT_NULL:
|
||||||
|
return sqlglot_exp.Not(
|
||||||
|
this=sqlglot_exp.Is(this=column_expr, expression=sqlglot_exp.Null())
|
||||||
|
)
|
||||||
|
if operator in (Operator.IN, Operator.NOT_IN):
|
||||||
|
values = coerce_literal(filter_.value)
|
||||||
|
if not isinstance(values, list):
|
||||||
|
values = [values]
|
||||||
|
expressions = [sqlglot_exp.convert(v) for v in values]
|
||||||
|
in_expr = sqlglot_exp.In(this=column_expr, expressions=expressions)
|
||||||
|
return (
|
||||||
|
sqlglot_exp.Not(this=in_expr)
|
||||||
|
if operator == Operator.NOT_IN
|
||||||
|
else in_expr
|
||||||
|
)
|
||||||
|
if operator == Operator.NOT_LIKE:
|
||||||
|
return sqlglot_exp.Not(
|
||||||
|
this=sqlglot_exp.Like(
|
||||||
|
this=column_expr,
|
||||||
|
expression=sqlglot_exp.convert(filter_.value),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
sqlglot_op_cls = _OPERATOR_TO_SQLGLOT.get(operator)
|
||||||
|
if sqlglot_op_cls is None:
|
||||||
|
raise ValueError(f"Unsupported operator: {operator}")
|
||||||
|
return sqlglot_op_cls(
|
||||||
|
this=column_expr,
|
||||||
|
expression=sqlglot_exp.convert(filter_.value),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _combine_predicates(
|
||||||
|
self,
|
||||||
|
filters: set[Filter],
|
||||||
|
) -> sqlglot_exp.Expression | None:
|
||||||
|
if not filters:
|
||||||
|
return None
|
||||||
|
# Sort to make the resulting SQL deterministic. Filter's default
|
||||||
|
# tuple ordering eventually compares ``Dimension`` instances, which
|
||||||
|
# are not orderable, so use a stable string key instead.
|
||||||
|
ordered = sorted(
|
||||||
|
filters,
|
||||||
|
key=lambda f: (
|
||||||
|
f.type.value,
|
||||||
|
f.column.id if f.column is not None else "",
|
||||||
|
f.operator.value,
|
||||||
|
repr(f.value),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
combined = self._filter_predicate(ordered[0])
|
||||||
|
for filter_ in ordered[1:]:
|
||||||
|
combined = sqlglot_exp.And(
|
||||||
|
this=combined,
|
||||||
|
expression=self._filter_predicate(filter_),
|
||||||
|
)
|
||||||
|
return combined
|
||||||
|
|
||||||
|
def _order_expression(
|
||||||
|
self,
|
||||||
|
element: Metric | Dimension | AdhocExpression,
|
||||||
|
direction: OrderDirection,
|
||||||
|
) -> sqlglot_exp.Ordered:
|
||||||
|
if isinstance(element, AdhocExpression):
|
||||||
|
inner = sqlglot.parse_one(
|
||||||
|
element.definition,
|
||||||
|
dialect=self._sqlglot_dialect,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inner = sqlglot_exp.column(element.id, quoted=True)
|
||||||
|
return sqlglot_exp.Ordered(this=inner, desc=direction == OrderDirection.DESC)
|
||||||
|
|
||||||
|
def _build_ast(self, query: SemanticQuery) -> sqlglot_exp.Select:
|
||||||
|
if query.limit is None and query.offset is not None:
|
||||||
|
raise ValueError("Offset cannot be set without limit")
|
||||||
|
|
||||||
|
filters = query.filters or set()
|
||||||
|
where_filters = {f for f in filters if f.type == PredicateType.WHERE}
|
||||||
|
having_filters = {f for f in filters if f.type == PredicateType.HAVING}
|
||||||
|
|
||||||
|
# Add the dataset's outer RLS predicates as ADHOC where-filters when
|
||||||
|
# the engine wants AS_PREDICATE-style RLS. AS_SUBQUERY-style RLS is
|
||||||
|
# handled in ``_source_table`` instead.
|
||||||
|
if not self._rls_should_wrap_subquery():
|
||||||
|
for clause in self._rls_predicates():
|
||||||
|
where_filters = where_filters | {
|
||||||
|
Filter(
|
||||||
|
type=PredicateType.WHERE,
|
||||||
|
column=None,
|
||||||
|
operator=Operator.ADHOC,
|
||||||
|
value=clause,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if query.group_limit is not None and query.group_limit.group_others:
|
||||||
|
return self._build_with_others(query, where_filters, having_filters)
|
||||||
|
if query.group_limit is not None:
|
||||||
|
return self._build_with_group_limit(query, where_filters, having_filters)
|
||||||
|
return self._build_plain(query, where_filters, having_filters)
|
||||||
|
|
||||||
|
def _build_plain(
|
||||||
|
self,
|
||||||
|
query: SemanticQuery,
|
||||||
|
where_filters: set[Filter],
|
||||||
|
having_filters: set[Filter],
|
||||||
|
) -> sqlglot_exp.Select:
|
||||||
|
projections: list[sqlglot_exp.Expression] = []
|
||||||
|
for dimension in query.dimensions:
|
||||||
|
projections.append(
|
||||||
|
sqlglot_exp.alias_(
|
||||||
|
self._dimension_expression(dimension),
|
||||||
|
dimension.id,
|
||||||
|
quoted=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for metric in query.metrics:
|
||||||
|
projections.append(
|
||||||
|
sqlglot_exp.alias_(
|
||||||
|
self._metric_expression(metric),
|
||||||
|
metric.id,
|
||||||
|
quoted=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
select = (
|
||||||
|
sqlglot_exp.Select()
|
||||||
|
.select(*projections, append=False)
|
||||||
|
.from_(self._source_table())
|
||||||
|
)
|
||||||
|
|
||||||
|
if where_predicate := self._combine_predicates(where_filters):
|
||||||
|
select = select.where(where_predicate)
|
||||||
|
|
||||||
|
# GROUP BY when aggregating: any metric is an aggregate, so all
|
||||||
|
# non-aggregate dimensions need to be in GROUP BY. For grain-bucketed
|
||||||
|
# dimensions we GROUP BY the full bucketed expression because dialects
|
||||||
|
# disagree on whether SELECT aliases are visible in GROUP BY.
|
||||||
|
if query.metrics and query.dimensions:
|
||||||
|
group_by_columns = [
|
||||||
|
(
|
||||||
|
self._dimension_expression(dimension)
|
||||||
|
if dimension.grain is not None
|
||||||
|
else sqlglot_exp.column(dimension.id, quoted=True)
|
||||||
|
)
|
||||||
|
for dimension in query.dimensions
|
||||||
|
]
|
||||||
|
select = select.group_by(*group_by_columns)
|
||||||
|
|
||||||
|
if having_predicate := self._combine_predicates(having_filters):
|
||||||
|
select = select.having(having_predicate)
|
||||||
|
|
||||||
|
if query.order:
|
||||||
|
ordered = [
|
||||||
|
self._order_expression(element, direction)
|
||||||
|
for element, direction in query.order
|
||||||
|
]
|
||||||
|
select = select.order_by(*ordered)
|
||||||
|
|
||||||
|
if query.limit is not None:
|
||||||
|
select = select.limit(query.limit)
|
||||||
|
if query.offset is not None:
|
||||||
|
select = select.offset(query.offset)
|
||||||
|
|
||||||
|
return select
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Group limit (top N) helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _top_groups_cte_select(
|
||||||
|
self,
|
||||||
|
group_limit: GroupLimit,
|
||||||
|
main_where_filters: set[Filter],
|
||||||
|
) -> sqlglot_exp.Select:
|
||||||
|
"""
|
||||||
|
Build the SELECT that powers the ``top_groups`` CTE.
|
||||||
|
|
||||||
|
The CTE projects only the limited dimensions. The ordering metric, when
|
||||||
|
present, is evaluated inline in ORDER BY so it does not leak into the
|
||||||
|
CTE's column list.
|
||||||
|
"""
|
||||||
|
if group_limit.filters is not None:
|
||||||
|
if any(f.type == PredicateType.HAVING for f in group_limit.filters):
|
||||||
|
raise ValueError(
|
||||||
|
"HAVING filters are not supported in group_limit.filters"
|
||||||
|
)
|
||||||
|
cte_where_filters = {
|
||||||
|
f for f in group_limit.filters if f.type == PredicateType.WHERE
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
cte_where_filters = main_where_filters
|
||||||
|
|
||||||
|
dim_projections = [
|
||||||
|
sqlglot_exp.alias_(
|
||||||
|
self._dimension_expression(dim),
|
||||||
|
dim.id,
|
||||||
|
quoted=True,
|
||||||
|
)
|
||||||
|
for dim in group_limit.dimensions
|
||||||
|
]
|
||||||
|
select = (
|
||||||
|
sqlglot_exp.Select()
|
||||||
|
.select(*dim_projections, append=False)
|
||||||
|
.from_(self._source_table())
|
||||||
|
)
|
||||||
|
|
||||||
|
if where_predicate := self._combine_predicates(cte_where_filters):
|
||||||
|
select = select.where(where_predicate)
|
||||||
|
|
||||||
|
# When ordering by a metric we need an aggregation; GROUP BY the
|
||||||
|
# limited dimensions so each combination collapses to a single row.
|
||||||
|
# Use full expressions for grain-bucketed dims (alias resolution in
|
||||||
|
# GROUP BY isn't portable).
|
||||||
|
if group_limit.metric is not None:
|
||||||
|
select = select.group_by(
|
||||||
|
*[
|
||||||
|
(
|
||||||
|
self._dimension_expression(dim)
|
||||||
|
if dim.grain is not None
|
||||||
|
else sqlglot_exp.column(dim.id, quoted=True)
|
||||||
|
)
|
||||||
|
for dim in group_limit.dimensions
|
||||||
|
]
|
||||||
|
)
|
||||||
|
order_expr: sqlglot_exp.Expression = self._metric_expression(
|
||||||
|
group_limit.metric
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
order_expr = sqlglot_exp.column(
|
||||||
|
group_limit.dimensions[0].id,
|
||||||
|
quoted=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
select = select.order_by(
|
||||||
|
sqlglot_exp.Ordered(
|
||||||
|
this=order_expr,
|
||||||
|
desc=group_limit.direction == OrderDirection.DESC,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return select.limit(group_limit.top)
|
||||||
|
|
||||||
|
def _top_groups_in_predicate(
|
||||||
|
self,
|
||||||
|
group_limit: GroupLimit,
|
||||||
|
) -> sqlglot_exp.Expression:
|
||||||
|
"""
|
||||||
|
Build a predicate restricting the limited dimensions to the rows of the
|
||||||
|
``top_groups`` CTE. Single dimension uses scalar IN; multiple use a
|
||||||
|
row-tuple IN. The subquery is wrapped in ``Subquery`` so sqlglot emits
|
||||||
|
the surrounding parentheses.
|
||||||
|
"""
|
||||||
|
cte_table = sqlglot_exp.to_identifier("top_groups")
|
||||||
|
dim_columns = [
|
||||||
|
sqlglot_exp.column(dim.id, quoted=True)
|
||||||
|
for dim in group_limit.dimensions
|
||||||
|
]
|
||||||
|
subquery = sqlglot_exp.Subquery(
|
||||||
|
this=(
|
||||||
|
sqlglot_exp.Select()
|
||||||
|
.select(*dim_columns)
|
||||||
|
.from_(cte_table)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(dim_columns) == 1:
|
||||||
|
return sqlglot_exp.In(this=dim_columns[0], query=subquery)
|
||||||
|
return sqlglot_exp.In(
|
||||||
|
this=sqlglot_exp.Tuple(expressions=dim_columns),
|
||||||
|
query=subquery,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_with_group_limit(
|
||||||
|
self,
|
||||||
|
query: SemanticQuery,
|
||||||
|
where_filters: set[Filter],
|
||||||
|
having_filters: set[Filter],
|
||||||
|
) -> sqlglot_exp.Select:
|
||||||
|
"""
|
||||||
|
Restrict the main query to rows whose limited-dimension values appear
|
||||||
|
in the ``top_groups`` CTE.
|
||||||
|
"""
|
||||||
|
select = self._build_plain(query, where_filters, having_filters)
|
||||||
|
select = select.where(self._top_groups_in_predicate(query.group_limit))
|
||||||
|
|
||||||
|
cte_select = self._top_groups_cte_select(query.group_limit, where_filters)
|
||||||
|
return select.with_("top_groups", as_=cte_select)
|
||||||
|
|
||||||
|
def _build_with_others(
|
||||||
|
self,
|
||||||
|
query: SemanticQuery,
|
||||||
|
where_filters: set[Filter],
|
||||||
|
having_filters: set[Filter],
|
||||||
|
) -> sqlglot_exp.Select:
|
||||||
|
"""
|
||||||
|
Bucket non-top values into ``'Other'`` for the limited dimensions and
|
||||||
|
re-aggregate against the bucketed groups. Non-additive metrics stay
|
||||||
|
correct because we re-evaluate the original metric expression instead
|
||||||
|
of summing previously aggregated rows.
|
||||||
|
"""
|
||||||
|
group_limit = query.group_limit
|
||||||
|
limited_dim_ids = {dim.id for dim in group_limit.dimensions}
|
||||||
|
in_predicate = self._top_groups_in_predicate(group_limit)
|
||||||
|
|
||||||
|
def dim_expr(dim: Dimension) -> sqlglot_exp.Expression:
|
||||||
|
"""Underlying expression for a dimension, with CASE for limited ones."""
|
||||||
|
base = self._dimension_expression(dim)
|
||||||
|
if dim.id not in limited_dim_ids:
|
||||||
|
return base
|
||||||
|
return sqlglot_exp.Case(
|
||||||
|
ifs=[sqlglot_exp.If(this=in_predicate.copy(), true=base)],
|
||||||
|
default=sqlglot_exp.Literal.string("Other"),
|
||||||
|
)
|
||||||
|
|
||||||
|
projections: list[sqlglot_exp.Expression] = []
|
||||||
|
for dim in query.dimensions:
|
||||||
|
projections.append(sqlglot_exp.alias_(dim_expr(dim), dim.id, quoted=True))
|
||||||
|
for metric in query.metrics:
|
||||||
|
projections.append(
|
||||||
|
sqlglot_exp.alias_(
|
||||||
|
self._metric_expression(metric),
|
||||||
|
metric.id,
|
||||||
|
quoted=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
select = (
|
||||||
|
sqlglot_exp.Select()
|
||||||
|
.select(*projections, append=False)
|
||||||
|
.from_(self._source_table())
|
||||||
|
)
|
||||||
|
|
||||||
|
if where_predicate := self._combine_predicates(where_filters):
|
||||||
|
select = select.where(where_predicate)
|
||||||
|
|
||||||
|
# GROUP BY the full expressions (not the aliases). Dialects disagree on
|
||||||
|
# whether GROUP BY can reference a SELECT alias, and using the full
|
||||||
|
# expression here side-steps that ambiguity for the CASE columns.
|
||||||
|
if query.metrics and query.dimensions:
|
||||||
|
select = select.group_by(
|
||||||
|
*[dim_expr(dim) for dim in query.dimensions]
|
||||||
|
)
|
||||||
|
|
||||||
|
if having_predicate := self._combine_predicates(having_filters):
|
||||||
|
select = select.having(having_predicate)
|
||||||
|
|
||||||
|
if query.order:
|
||||||
|
ordered = [
|
||||||
|
self._order_expression(element, direction)
|
||||||
|
for element, direction in query.order
|
||||||
|
]
|
||||||
|
select = select.order_by(*ordered)
|
||||||
|
|
||||||
|
if query.limit is not None:
|
||||||
|
select = select.limit(query.limit)
|
||||||
|
if query.offset is not None:
|
||||||
|
select = select.offset(query.offset)
|
||||||
|
|
||||||
|
cte_select = self._top_groups_cte_select(group_limit, where_filters)
|
||||||
|
return select.with_("top_groups", as_=cte_select)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def compile_query(self, query: SemanticQuery) -> str:
|
||||||
|
"""Return the SQL for ``query`` in the dataset's database dialect."""
|
||||||
|
return self._build_ast(query).sql(dialect=self._sqlglot_dialect)
|
||||||
|
|
||||||
|
def get_values(
|
||||||
|
self,
|
||||||
|
dimension: Dimension,
|
||||||
|
filters: set[Filter] | None = None,
|
||||||
|
) -> SemanticResult:
|
||||||
|
where_filters = {
|
||||||
|
f for f in (filters or set()) if f.type == PredicateType.WHERE
|
||||||
|
}
|
||||||
|
# Add outer RLS predicates as ADHOC filters; AS_SUBQUERY-mode RLS is
|
||||||
|
# wrapped into ``_source_table`` already.
|
||||||
|
if not self._rls_should_wrap_subquery():
|
||||||
|
for clause in self._rls_predicates():
|
||||||
|
where_filters.add(
|
||||||
|
Filter(
|
||||||
|
type=PredicateType.WHERE,
|
||||||
|
column=None,
|
||||||
|
operator=Operator.ADHOC,
|
||||||
|
value=clause,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
where = self._combine_predicates(where_filters)
|
||||||
|
select = (
|
||||||
|
sqlglot_exp.Select()
|
||||||
|
.distinct()
|
||||||
|
.select(
|
||||||
|
sqlglot_exp.alias_(
|
||||||
|
self._dimension_expression(dimension),
|
||||||
|
dimension.id,
|
||||||
|
quoted=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.from_(self._source_table())
|
||||||
|
)
|
||||||
|
if where is not None:
|
||||||
|
select = select.where(where)
|
||||||
|
|
||||||
|
sql = select.sql(dialect=self._sqlglot_dialect)
|
||||||
|
df = self.dataset.database.get_df(
|
||||||
|
sql,
|
||||||
|
catalog=self.dataset.catalog,
|
||||||
|
schema=self.dataset.schema,
|
||||||
|
)
|
||||||
|
return SemanticResult(
|
||||||
|
requests=[SemanticRequest(REQUEST_TYPE, sql)],
|
||||||
|
results=df_to_arrow(df),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_table(self, query: SemanticQuery) -> SemanticResult:
|
||||||
|
if not query.metrics and not query.dimensions:
|
||||||
|
return SemanticResult(requests=[], results=pa.table({}))
|
||||||
|
|
||||||
|
sql = self.compile_query(query)
|
||||||
|
df = self.dataset.database.get_df(
|
||||||
|
sql,
|
||||||
|
catalog=self.dataset.catalog,
|
||||||
|
schema=self.dataset.schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map alias columns back to dimension/metric *names* so the result uses
|
||||||
|
# human-friendly labels instead of internal IDs.
|
||||||
|
rename: dict[str, str] = {}
|
||||||
|
for dimension in query.dimensions:
|
||||||
|
rename[dimension.id] = dimension.name
|
||||||
|
for metric in query.metrics:
|
||||||
|
rename[metric.id] = metric.name
|
||||||
|
if df is not None and not df.empty and rename:
|
||||||
|
df = df.rename(columns=rename)
|
||||||
|
|
||||||
|
return SemanticResult(
|
||||||
|
requests=[SemanticRequest(REQUEST_TYPE, sql)],
|
||||||
|
results=df_to_arrow(df),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_row_count(self, query: SemanticQuery) -> SemanticResult:
|
||||||
|
if not query.metrics and not query.dimensions:
|
||||||
|
return SemanticResult(requests=[], results=pa.table({"COUNT": [0]}))
|
||||||
|
|
||||||
|
inner = self._build_ast(query)
|
||||||
|
count_select = (
|
||||||
|
sqlglot_exp.Select()
|
||||||
|
.select(
|
||||||
|
sqlglot_exp.alias_(
|
||||||
|
sqlglot_exp.Count(this=sqlglot_exp.Star()),
|
||||||
|
"COUNT",
|
||||||
|
quoted=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.from_(sqlglot_exp.Subquery(this=inner, alias="subquery"))
|
||||||
|
)
|
||||||
|
sql = count_select.sql(dialect=self._sqlglot_dialect)
|
||||||
|
df = self.dataset.database.get_df(
|
||||||
|
sql,
|
||||||
|
catalog=self.dataset.catalog,
|
||||||
|
schema=self.dataset.schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
return SemanticResult(
|
||||||
|
requests=[SemanticRequest(REQUEST_TYPE, sql)],
|
||||||
|
results=df_to_arrow(df),
|
||||||
|
)
|
||||||
|
|
||||||
|
__repr__ = uid
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
# flake8: noqa: E501
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from preset_io.dataset_semantic_layer import (
|
||||||
|
DatasetConfiguration,
|
||||||
|
DatasetSemanticLayer,
|
||||||
|
DatasetSemanticView,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset() -> MagicMock:
|
||||||
|
db = MagicMock()
|
||||||
|
db.db_engine_spec.engine = "postgresql"
|
||||||
|
ds = MagicMock(spec=[])
|
||||||
|
ds.id = 1
|
||||||
|
ds.table_name = "orders"
|
||||||
|
ds.schema = "public"
|
||||||
|
ds.catalog = None
|
||||||
|
ds.sql = None
|
||||||
|
ds.database = db
|
||||||
|
ds.columns = []
|
||||||
|
ds.metrics = []
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_configuration_schema() -> None:
|
||||||
|
schema = DatasetSemanticLayer.get_configuration_schema()
|
||||||
|
# No required fields — adding the layer takes zero input.
|
||||||
|
assert schema.get("required", []) == []
|
||||||
|
assert schema["properties"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_schema_exposes_dataset_dropdown(mocker, dataset: MagicMock) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"preset_io.dataset_semantic_layer.layer.list_datasets",
|
||||||
|
return_value=[dataset],
|
||||||
|
)
|
||||||
|
schema = DatasetSemanticLayer.get_runtime_schema(DatasetConfiguration())
|
||||||
|
|
||||||
|
assert schema["required"] == ["dataset_id"]
|
||||||
|
field = schema["properties"]["dataset_id"]
|
||||||
|
assert field["enum"] == ["1"]
|
||||||
|
assert field["x-enumNames"] == ["public.orders"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_semantic_view_requires_dataset_id() -> None:
|
||||||
|
layer = DatasetSemanticLayer.from_configuration({})
|
||||||
|
with pytest.raises(ValueError, match="dataset_id"):
|
||||||
|
layer.get_semantic_view("orders", {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_semantic_view_returns_dataset_view(mocker, dataset: MagicMock) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"preset_io.dataset_semantic_layer.layer.get_dataset_by_id",
|
||||||
|
return_value=dataset,
|
||||||
|
)
|
||||||
|
layer = DatasetSemanticLayer.from_configuration({})
|
||||||
|
view = layer.get_semantic_view("orders", {"dataset_id": "1"})
|
||||||
|
|
||||||
|
assert isinstance(view, DatasetSemanticView)
|
||||||
|
assert view.uid() == "dataset:1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_semantic_views_returns_empty_when_no_id() -> None:
|
||||||
|
layer = DatasetSemanticLayer.from_configuration({})
|
||||||
|
assert layer.get_semantic_views({}) == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_semantic_views_resolves_dataset(mocker, dataset: MagicMock) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"preset_io.dataset_semantic_layer.layer.get_dataset_by_id",
|
||||||
|
return_value=dataset,
|
||||||
|
)
|
||||||
|
layer = DatasetSemanticLayer.from_configuration({})
|
||||||
|
views = layer.get_semantic_views({"dataset_id": "1"})
|
||||||
|
assert len(views) == 1
|
||||||
|
assert next(iter(views)).uid() == "dataset:1"
|
||||||
File diff suppressed because it is too large
Load Diff
157
superset/semantic_layers/extension/backend/tests/test_utils.py
Normal file
157
superset/semantic_layers/extension/backend/tests/test_utils.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
# flake8: noqa: E501
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from preset_io.dataset_semantic_layer.utils import (
|
||||||
|
coerce_literal,
|
||||||
|
dataset_label,
|
||||||
|
df_to_arrow,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pure helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_coerce_literal_passes_through_scalars() -> None:
|
||||||
|
assert coerce_literal("x") == "x"
|
||||||
|
assert coerce_literal(42) == 42
|
||||||
|
assert coerce_literal(None) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_coerce_literal_unwraps_sets() -> None:
|
||||||
|
result = coerce_literal({"a", "b"})
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert set(result) == {"a", "b"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_label_combines_parts() -> None:
|
||||||
|
ds = MagicMock(spec=[])
|
||||||
|
ds.catalog = "warehouse"
|
||||||
|
ds.schema = "public"
|
||||||
|
ds.table_name = "orders"
|
||||||
|
assert dataset_label(ds) == "warehouse.public.orders"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_label_skips_blank_parts() -> None:
|
||||||
|
ds = MagicMock(spec=[])
|
||||||
|
ds.catalog = None
|
||||||
|
ds.schema = None
|
||||||
|
ds.table_name = "orders"
|
||||||
|
assert dataset_label(ds) == "orders"
|
||||||
|
|
||||||
|
|
||||||
|
def test_df_to_arrow_with_data() -> None:
|
||||||
|
df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
|
||||||
|
table = df_to_arrow(df)
|
||||||
|
assert table.num_rows == 2
|
||||||
|
assert table.column_names == ["a", "b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_df_to_arrow_with_empty_df() -> None:
|
||||||
|
df = pd.DataFrame(columns=["a", "b"])
|
||||||
|
table = df_to_arrow(df)
|
||||||
|
assert table.num_rows == 0
|
||||||
|
assert table.column_names == ["a", "b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_df_to_arrow_with_none() -> None:
|
||||||
|
table = df_to_arrow(None)
|
||||||
|
assert table.num_rows == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Session-bound helpers — superset is stubbed so we don't require a Flask app
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_superset(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, type]:
|
||||||
|
"""
|
||||||
|
Provide a minimal stand-in for ``superset.db`` and
|
||||||
|
``superset.connectors.sqla.models.SqlaTable`` so list/get helpers run
|
||||||
|
without a real Flask app or database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class FakeSqlaTable: # noqa: D401 — marker class only
|
||||||
|
"""Sentinel used as the queried entity."""
|
||||||
|
|
||||||
|
columns: list = []
|
||||||
|
metrics: list = []
|
||||||
|
|
||||||
|
session = MagicMock()
|
||||||
|
db_module = types.ModuleType("superset")
|
||||||
|
db_module.db = MagicMock()
|
||||||
|
db_module.db.session = session
|
||||||
|
|
||||||
|
sqla_models = types.ModuleType("superset.connectors.sqla.models")
|
||||||
|
sqla_models.SqlaTable = FakeSqlaTable
|
||||||
|
connectors_module = types.ModuleType("superset.connectors")
|
||||||
|
connectors_sqla = types.ModuleType("superset.connectors.sqla")
|
||||||
|
|
||||||
|
monkeypatch.setitem(sys.modules, "superset", db_module)
|
||||||
|
monkeypatch.setitem(sys.modules, "superset.connectors", connectors_module)
|
||||||
|
monkeypatch.setitem(sys.modules, "superset.connectors.sqla", connectors_sqla)
|
||||||
|
monkeypatch.setitem(sys.modules, "superset.connectors.sqla.models", sqla_models)
|
||||||
|
|
||||||
|
# Also stub sqlalchemy.orm.selectinload for get_dataset_by_id.
|
||||||
|
return session, FakeSqlaTable
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_datasets_returns_sorted(fake_superset) -> None:
|
||||||
|
from preset_io.dataset_semantic_layer.utils import list_datasets
|
||||||
|
|
||||||
|
session, _ = fake_superset
|
||||||
|
|
||||||
|
a, b, c = MagicMock(), MagicMock(), MagicMock()
|
||||||
|
a.table_name = "z_users"
|
||||||
|
b.table_name = "a_orders"
|
||||||
|
c.table_name = "m_payments"
|
||||||
|
session.no_autoflush.__enter__.return_value = None
|
||||||
|
session.no_autoflush.__exit__.return_value = False
|
||||||
|
session.query.return_value.all.return_value = [a, b, c]
|
||||||
|
|
||||||
|
result = list_datasets()
|
||||||
|
assert [d.table_name for d in result] == ["a_orders", "m_payments", "z_users"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_dataset_by_id_raises_when_missing(fake_superset) -> None:
|
||||||
|
from preset_io.dataset_semantic_layer.utils import get_dataset_by_id
|
||||||
|
|
||||||
|
session, _ = fake_superset
|
||||||
|
|
||||||
|
session.no_autoflush.__enter__.return_value = None
|
||||||
|
session.no_autoflush.__exit__.return_value = False
|
||||||
|
(
|
||||||
|
session.query.return_value.options.return_value.filter_by.return_value.one_or_none.return_value
|
||||||
|
) = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Dataset with id 42 does not exist."):
|
||||||
|
get_dataset_by_id(42)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_dataset_by_id_materialises_relationships(fake_superset) -> None:
|
||||||
|
from preset_io.dataset_semantic_layer.utils import get_dataset_by_id
|
||||||
|
|
||||||
|
session, _ = fake_superset
|
||||||
|
|
||||||
|
dataset = MagicMock()
|
||||||
|
# Use real lists so iteration in get_dataset_by_id doesn't error.
|
||||||
|
dataset.columns = ["col1", "col2"]
|
||||||
|
dataset.metrics = ["m1"]
|
||||||
|
|
||||||
|
session.no_autoflush.__enter__.return_value = None
|
||||||
|
session.no_autoflush.__exit__.return_value = False
|
||||||
|
(
|
||||||
|
session.query.return_value.options.return_value.filter_by.return_value.one_or_none.return_value
|
||||||
|
) = dataset
|
||||||
|
|
||||||
|
result = get_dataset_by_id(1)
|
||||||
|
assert result is dataset
|
||||||
8
superset/semantic_layers/extension/extension.json
Normal file
8
superset/semantic_layers/extension/extension.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"publisher": "preset-io",
|
||||||
|
"name": "dataset-semantic-layer",
|
||||||
|
"displayName": "Dataset Semantic Layer",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"license": "All Rights Reserved",
|
||||||
|
"permissions": []
|
||||||
|
}
|
||||||
8958
superset/semantic_layers/extension/frontend/package-lock.json
generated
Normal file
8958
superset/semantic_layers/extension/frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
34
superset/semantic_layers/extension/frontend/package.json
Normal file
34
superset/semantic_layers/extension/frontend/package.json
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
"name": "dataset_semantic_layer",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"main": "dist/main.js",
|
||||||
|
"types": "dist/publicAPI.d.ts",
|
||||||
|
"scripts": {
|
||||||
|
"test": "echo \"Error: no test specified\" && exit 1",
|
||||||
|
"start": "webpack serve --mode development",
|
||||||
|
"build": "webpack --stats-error-details --mode production"
|
||||||
|
},
|
||||||
|
"keywords": [],
|
||||||
|
"private": true,
|
||||||
|
"author": "",
|
||||||
|
"license": "All Rights Reserved",
|
||||||
|
"description": "",
|
||||||
|
"peerDependencies": {
|
||||||
|
"@apache-superset/core": "0.1.0-rc2",
|
||||||
|
"react": "^17.0.2",
|
||||||
|
"react-dom": "^17.0.2"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@babel/preset-react": "^7.26.3",
|
||||||
|
"@babel/preset-typescript": "^7.26.0",
|
||||||
|
"@types/react": "^19.0.10",
|
||||||
|
"copy-webpack-plugin": "^13.0.0",
|
||||||
|
"install": "^0.13.0",
|
||||||
|
"npm": "^11.1.0",
|
||||||
|
"ts-loader": "^9.5.2",
|
||||||
|
"typescript": "^5.8.2",
|
||||||
|
"webpack": "^5.98.0",
|
||||||
|
"webpack-cli": "^6.0.1",
|
||||||
|
"webpack-dev-server": "^5.2.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
17
superset/semantic_layers/extension/frontend/src/index.tsx
Normal file
17
superset/semantic_layers/extension/frontend/src/index.tsx
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import React from "react";
|
||||||
|
import { views } from "@apache-superset/core";
|
||||||
|
|
||||||
|
const viewDisposable = views.registerView(
|
||||||
|
{ id: "dataset.semantic-layer", name: "Dataset Semantic Layer" },
|
||||||
|
"sqllab.panels",
|
||||||
|
() => <p>Dataset Semantic Layer</p>
|
||||||
|
);
|
||||||
|
|
||||||
|
export const activate = () => {
|
||||||
|
console.log("Dataset Semantic Layer extension activated");
|
||||||
|
};
|
||||||
|
|
||||||
|
export const deactivate = () => {
|
||||||
|
viewDisposable.dispose();
|
||||||
|
console.log("Dataset Semantic Layer extension deactivated");
|
||||||
|
};
|
||||||
13
superset/semantic_layers/extension/frontend/tsconfig.json
Normal file
13
superset/semantic_layers/extension/frontend/tsconfig.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "es5",
|
||||||
|
"module": "esnext",
|
||||||
|
"moduleResolution": "node10",
|
||||||
|
"jsx": "react",
|
||||||
|
"strict": true,
|
||||||
|
"esModuleInterop": true,
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"forceConsistentCasingInFileNames": true
|
||||||
|
},
|
||||||
|
"include": ["src"]
|
||||||
|
}
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
const path = require("path");
|
||||||
|
const { ModuleFederationPlugin } = require("webpack").container;
|
||||||
|
const packageConfig = require("./package");
|
||||||
|
|
||||||
|
module.exports = (env, argv) => {
|
||||||
|
const isProd = argv.mode === "production";
|
||||||
|
|
||||||
|
return {
|
||||||
|
entry: isProd ? {} : "./src/index.tsx",
|
||||||
|
mode: isProd ? "production" : "development",
|
||||||
|
devServer: {
|
||||||
|
port: 3000,
|
||||||
|
headers: {
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
output: {
|
||||||
|
clean: true,
|
||||||
|
filename: isProd ? undefined : "[name].[contenthash].js",
|
||||||
|
chunkFilename: "[name].[contenthash].js",
|
||||||
|
path: path.resolve(__dirname, "dist"),
|
||||||
|
publicPath: `/api/v1/extensions/${packageConfig.name}/`,
|
||||||
|
},
|
||||||
|
resolve: {
|
||||||
|
extensions: [".ts", ".tsx", ".js", ".jsx"],
|
||||||
|
},
|
||||||
|
externalsType: "window",
|
||||||
|
externals: {
|
||||||
|
"@apache-superset/core": "superset",
|
||||||
|
},
|
||||||
|
module: {
|
||||||
|
rules: [
|
||||||
|
{
|
||||||
|
test: /\.tsx?$/,
|
||||||
|
use: "ts-loader",
|
||||||
|
exclude: /node_modules/,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
plugins: [
|
||||||
|
new ModuleFederationPlugin({
|
||||||
|
name: "dataset_semantic_layer",
|
||||||
|
filename: "remoteEntry.[contenthash].js",
|
||||||
|
exposes: {
|
||||||
|
"./index": "./src/index.tsx",
|
||||||
|
},
|
||||||
|
shared: {
|
||||||
|
react: {
|
||||||
|
singleton: true,
|
||||||
|
requiredVersion: packageConfig.peerDependencies.react,
|
||||||
|
import: false,
|
||||||
|
},
|
||||||
|
"react-dom": {
|
||||||
|
singleton: true,
|
||||||
|
requiredVersion: packageConfig.peerDependencies["react-dom"],
|
||||||
|
import: false,
|
||||||
|
},
|
||||||
|
antd: {
|
||||||
|
singleton: true,
|
||||||
|
requiredVersion: packageConfig.peerDependencies["antd"],
|
||||||
|
import: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -24,6 +24,7 @@ single dataframe.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
from datetime import date, datetime, time, timedelta, tzinfo
|
from datetime import date, datetime, time, timedelta, tzinfo
|
||||||
from time import time as current_time
|
from time import time as current_time
|
||||||
from typing import Any, cast, Sequence, TypeGuard
|
from typing import Any, cast, Sequence, TypeGuard
|
||||||
@@ -56,7 +57,11 @@ from superset.common.utils.time_range_utils import get_since_until_from_query_ob
|
|||||||
from superset.connectors.sqla.models import BaseDatasource
|
from superset.connectors.sqla.models import BaseDatasource
|
||||||
from superset.constants import NO_TIME_RANGE
|
from superset.constants import NO_TIME_RANGE
|
||||||
from superset.models.helpers import QueryResult
|
from superset.models.helpers import QueryResult
|
||||||
from superset.superset_typing import AdhocColumn
|
from superset.semantic_layers.adhoc import (
|
||||||
|
adhoc_column_to_semantic_dimension,
|
||||||
|
adhoc_metric_to_semantic_metric,
|
||||||
|
)
|
||||||
|
from superset.superset_typing import AdhocColumn, AdhocMetric
|
||||||
from superset.utils.core import (
|
from superset.utils.core import (
|
||||||
FilterOperator,
|
FilterOperator,
|
||||||
QueryObjectFilterClause,
|
QueryObjectFilterClause,
|
||||||
@@ -258,16 +263,92 @@ def _normalize_column(column: str | AdhocColumn, dimension_names: set[str]) -> s
|
|||||||
- A string (dimension name directly)
|
- A string (dimension name directly)
|
||||||
- An AdhocColumn with isColumnReference=True and sqlExpression containing the
|
- An AdhocColumn with isColumnReference=True and sqlExpression containing the
|
||||||
dimension name
|
dimension name
|
||||||
|
|
||||||
|
Used by callers that only care about the resolved *name* (e.g. validating
|
||||||
|
that a granularity column matches a real dimension). Synthetic adhocs
|
||||||
|
(non-column-reference) carry no dimension name and should be routed
|
||||||
|
through :func:`_resolve_dimension` instead.
|
||||||
"""
|
"""
|
||||||
if isinstance(column, str):
|
if isinstance(column, str):
|
||||||
return column
|
return column
|
||||||
|
|
||||||
# Handle column references (e.g., from time-series charts)
|
# Column reference adhocs unwrap to their underlying column name.
|
||||||
if column.get("isColumnReference") and (sql_expr := column.get("sqlExpression")):
|
if column.get("isColumnReference") and (sql_expr := column.get("sqlExpression")):
|
||||||
if sql_expr in dimension_names:
|
if sql_expr in dimension_names:
|
||||||
return sql_expr
|
return sql_expr
|
||||||
|
|
||||||
raise ValueError("Adhoc dimensions are not supported in Semantic Views.")
|
# Synthetic adhoc: the label *is* the resolved name in semantic-layer
|
||||||
|
# terms (we use it as the Dimension id/name when materialising). Falling
|
||||||
|
# back here keeps validators happy with adhoc dicts and defers actual
|
||||||
|
# resolution to ``_resolve_dimension``.
|
||||||
|
if label := column.get("label"):
|
||||||
|
return label
|
||||||
|
|
||||||
|
raise ValueError("Adhoc column is missing a ``label``.")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_dimension(
|
||||||
|
column: str | AdhocColumn,
|
||||||
|
dataset: BaseDatasource,
|
||||||
|
dimensions_by_name: dict[str, Dimension],
|
||||||
|
template_processor: Any | None,
|
||||||
|
) -> Dimension:
|
||||||
|
"""
|
||||||
|
Resolve a column entry (string name or adhoc dict) to a ``Dimension``.
|
||||||
|
|
||||||
|
Strings look up a saved dimension by name. Adhoc dicts marked
|
||||||
|
``isColumnReference=True`` re-use the matching saved dimension so its
|
||||||
|
metadata (type, grain) is preserved. Anything else is synthesised via
|
||||||
|
:func:`adhoc_column_to_semantic_dimension`.
|
||||||
|
"""
|
||||||
|
if isinstance(column, str):
|
||||||
|
if column not in dimensions_by_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dimension {column!r} is not defined in the Semantic View."
|
||||||
|
)
|
||||||
|
return dimensions_by_name[column]
|
||||||
|
|
||||||
|
return adhoc_column_to_semantic_dimension(
|
||||||
|
column,
|
||||||
|
dataset,
|
||||||
|
dimensions_by_name,
|
||||||
|
template_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_metric(
|
||||||
|
metric: str | AdhocMetric,
|
||||||
|
dataset: BaseDatasource,
|
||||||
|
metrics_by_name: dict[str, Metric],
|
||||||
|
template_processor: Any | None,
|
||||||
|
) -> Metric:
|
||||||
|
"""
|
||||||
|
Resolve a metric entry (string name or adhoc dict) to a ``Metric``.
|
||||||
|
"""
|
||||||
|
if isinstance(metric, str):
|
||||||
|
if metric not in metrics_by_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Metric {metric!r} is not defined in the Semantic View."
|
||||||
|
)
|
||||||
|
return metrics_by_name[metric]
|
||||||
|
return adhoc_metric_to_semantic_metric(metric, dataset, template_processor)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_template_processor(dataset: BaseDatasource) -> Any | None:
|
||||||
|
"""
|
||||||
|
Build a template processor for adhoc Jinja rendering, or ``None`` when
|
||||||
|
the datasource doesn't expose one (e.g. a cube-mode semantic view, where
|
||||||
|
adhocs aren't supported anyway).
|
||||||
|
"""
|
||||||
|
get_template_processor = getattr(dataset, "get_template_processor", None)
|
||||||
|
if get_template_processor is None:
|
||||||
|
return None
|
||||||
|
return get_template_processor()
|
||||||
|
|
||||||
|
|
||||||
|
def _stamp_grain(dimension: Dimension, grain: Grain) -> Dimension:
|
||||||
|
"""Return a copy of ``dimension`` with the supplied grain set."""
|
||||||
|
return dataclasses.replace(dimension, grain=grain)
|
||||||
|
|
||||||
|
|
||||||
def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
|
def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
|
||||||
@@ -278,33 +359,54 @@ def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
|
|||||||
visualization and more on semantics.
|
visualization and more on semantics.
|
||||||
"""
|
"""
|
||||||
semantic_view = query_object.datasource.implementation
|
semantic_view = query_object.datasource.implementation
|
||||||
|
dataset = query_object.datasource
|
||||||
|
|
||||||
all_metrics = {metric.name: metric for metric in semantic_view.metrics}
|
all_metrics = {metric.name: metric for metric in semantic_view.metrics}
|
||||||
all_dimensions = {
|
all_dimensions = {
|
||||||
dimension.name: dimension for dimension in semantic_view.dimensions
|
dimension.name: dimension for dimension in semantic_view.dimensions
|
||||||
}
|
}
|
||||||
|
|
||||||
# Normalize columns (may be dicts with isColumnReference=True for time-series)
|
template_processor = _get_template_processor(dataset)
|
||||||
dimension_names = set(all_dimensions.keys())
|
|
||||||
normalized_columns = {
|
|
||||||
_normalize_column(column, dimension_names) for column in query_object.columns
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics = [all_metrics[metric] for metric in (query_object.metrics or [])]
|
metrics = [
|
||||||
|
_resolve_metric(metric, dataset, all_metrics, template_processor)
|
||||||
|
for metric in (query_object.metrics or [])
|
||||||
|
]
|
||||||
|
|
||||||
|
# Resolve each requested column, preserving order and deduplicating by id.
|
||||||
|
seen_dim_ids: set[str] = set()
|
||||||
|
dimensions: list[Dimension] = []
|
||||||
|
for column in query_object.columns:
|
||||||
|
dim = _resolve_dimension(
|
||||||
|
column,
|
||||||
|
dataset,
|
||||||
|
all_dimensions,
|
||||||
|
template_processor,
|
||||||
|
)
|
||||||
|
if dim.id in seen_dim_ids:
|
||||||
|
continue
|
||||||
|
seen_dim_ids.add(dim.id)
|
||||||
|
dimensions.append(dim)
|
||||||
|
|
||||||
grain = _convert_time_grain(query_object.extras.get("time_grain_sqla"))
|
grain = _convert_time_grain(query_object.extras.get("time_grain_sqla"))
|
||||||
dimensions = [
|
if grain is not None and query_object.granularity:
|
||||||
dimension
|
# Apply the requested grain to the granularity dimension. For cube-mode
|
||||||
for dimension in semantic_view.dimensions
|
# views that pre-declare grain on their dimensions, keep the existing
|
||||||
if dimension.name in normalized_columns
|
# "only include the matching-grain version" semantic. For dataset-mode
|
||||||
and (
|
# views (dimensions have grain=None), the view applies the grain via
|
||||||
# if a grain is specified, only include the time dimension if its grain
|
# the engine spec at SQL build time, so we just stamp the grain on
|
||||||
# matches the requested grain
|
# the dimension here.
|
||||||
grain is None
|
dimensions = [
|
||||||
or dimension.name != query_object.granularity
|
(
|
||||||
or dimension.grain == grain
|
_stamp_grain(dim, grain)
|
||||||
)
|
if dim.name == query_object.granularity and dim.grain is None
|
||||||
]
|
else dim
|
||||||
|
)
|
||||||
|
for dim in dimensions
|
||||||
|
if dim.name != query_object.granularity
|
||||||
|
or dim.grain is None
|
||||||
|
or dim.grain == grain
|
||||||
|
]
|
||||||
|
|
||||||
order = _get_order_from_query_object(query_object, all_metrics, all_dimensions)
|
order = _get_order_from_query_object(query_object, all_metrics, all_dimensions)
|
||||||
limit = query_object.row_limit
|
limit = query_object.row_limit
|
||||||
@@ -928,32 +1030,53 @@ def validate_query_object(
|
|||||||
|
|
||||||
def _validate_metrics(query_object: ValidatedQueryObject) -> None:
|
def _validate_metrics(query_object: ValidatedQueryObject) -> None:
|
||||||
"""
|
"""
|
||||||
Make sure metrics are defined in the semantic view.
|
Make sure metrics are defined in the semantic view or are valid adhocs.
|
||||||
|
|
||||||
|
Validation of adhoc metrics is delegated to ``_resolve_metric`` — it
|
||||||
|
raises if the SIMPLE shape is malformed or the SQL fails Jinja /
|
||||||
|
safety checks.
|
||||||
"""
|
"""
|
||||||
semantic_view = query_object.datasource.implementation
|
semantic_view = query_object.datasource.implementation
|
||||||
|
dataset = query_object.datasource
|
||||||
|
metrics_by_name = {metric.name: metric for metric in semantic_view.metrics}
|
||||||
|
template_processor = _get_template_processor(dataset)
|
||||||
|
|
||||||
if any(not isinstance(metric, str) for metric in (query_object.metrics or [])):
|
labels: list[str] = []
|
||||||
raise ValueError("Adhoc metrics are not supported in Semantic Views.")
|
for metric in query_object.metrics or []:
|
||||||
|
resolved = _resolve_metric(
|
||||||
|
metric, dataset, metrics_by_name, template_processor
|
||||||
|
)
|
||||||
|
labels.append(resolved.id)
|
||||||
|
|
||||||
metric_names = {metric.name for metric in semantic_view.metrics}
|
if len(labels) != len(set(labels)):
|
||||||
if not set(query_object.metrics or []) <= metric_names:
|
raise ValueError(
|
||||||
raise ValueError("All metrics must be defined in the Semantic View.")
|
"Duplicate metric labels are not supported in Semantic Views."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_dimensions(query_object: ValidatedQueryObject) -> None:
|
def _validate_dimensions(query_object: ValidatedQueryObject) -> None:
|
||||||
"""
|
"""
|
||||||
Make sure all dimensions are defined in the semantic view.
|
Make sure all dimensions are defined in the semantic view or are valid
|
||||||
|
adhocs. Synthesized adhoc columns are validated via ``_resolve_dimension``.
|
||||||
"""
|
"""
|
||||||
semantic_view = query_object.datasource.implementation
|
semantic_view = query_object.datasource.implementation
|
||||||
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
|
dataset = query_object.datasource
|
||||||
|
dimensions_by_name = {
|
||||||
|
dimension.name: dimension for dimension in semantic_view.dimensions
|
||||||
|
}
|
||||||
|
template_processor = _get_template_processor(dataset)
|
||||||
|
|
||||||
# Normalize all columns to dimension names
|
labels: list[str] = []
|
||||||
normalized_columns = [
|
for column in query_object.columns:
|
||||||
_normalize_column(column, dimension_names) for column in query_object.columns
|
resolved = _resolve_dimension(
|
||||||
]
|
column, dataset, dimensions_by_name, template_processor
|
||||||
|
)
|
||||||
|
labels.append(resolved.id)
|
||||||
|
|
||||||
if not set(normalized_columns) <= dimension_names:
|
if len(labels) != len(set(labels)):
|
||||||
raise ValueError("All dimensions must be defined in the Semantic View.")
|
raise ValueError(
|
||||||
|
"Duplicate column labels are not supported in Semantic Views."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_filters(query_object: ValidatedQueryObject) -> None:
|
def _validate_filters(query_object: ValidatedQueryObject) -> None:
|
||||||
|
|||||||
120
superset/semantic_layers/rls.py
Normal file
120
superset/semantic_layers/rls.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
RLS helpers used by the dataset semantic-view path.
|
||||||
|
|
||||||
|
The legacy ``get_sqla_query`` applies row-level security in two places:
|
||||||
|
|
||||||
|
* The dataset's own RLS rules are AND-ed into the outer ``WHERE`` (or
|
||||||
|
the source table is wrapped in a subquery when the engine spec
|
||||||
|
returns ``RLSMethod.AS_SUBQUERY``).
|
||||||
|
* For virtual datasets, RLS rules from the *underlying* tables
|
||||||
|
referenced by ``dataset.sql`` are injected via
|
||||||
|
:func:`superset.utils.rls.apply_rls`.
|
||||||
|
|
||||||
|
This module exposes both flows as pure functions returning SQL strings,
|
||||||
|
which the ``DatasetSemanticView`` plugs into its sqlglot AST.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from superset.sql.parse import RLSMethod, SQLScript
|
||||||
|
from superset.utils.rls import apply_rls
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rls_method(dataset: "SqlaTable") -> RLSMethod:
|
||||||
|
"""Return the engine's preferred RLS injection method."""
|
||||||
|
return dataset.database.db_engine_spec.get_rls_method()
|
||||||
|
|
||||||
|
|
||||||
|
def render_rls_predicates(dataset: "SqlaTable") -> list[str]:
|
||||||
|
"""
|
||||||
|
Render the dataset's outer RLS predicates as SQL strings.
|
||||||
|
|
||||||
|
The clauses are already Jinja-rendered by
|
||||||
|
:meth:`SqlaTable.get_sqla_row_level_filters`; we then compile them
|
||||||
|
against the database's dialect with ``literal_binds=True`` so the
|
||||||
|
resulting strings are safe to splice into a sqlglot AST.
|
||||||
|
|
||||||
|
Returns an empty list when the dataset has no applicable RLS rules
|
||||||
|
for the current user.
|
||||||
|
"""
|
||||||
|
text_clauses = dataset.get_sqla_row_level_filters()
|
||||||
|
if not text_clauses:
|
||||||
|
return []
|
||||||
|
dialect = dataset.database.get_dialect()
|
||||||
|
return [
|
||||||
|
str(clause.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||||
|
for clause in text_clauses
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rls_to_virtual_sql(dataset: "SqlaTable") -> str | None:
|
||||||
|
"""
|
||||||
|
Rewrite a virtual dataset's inner SQL to enforce RLS on the
|
||||||
|
underlying tables it references.
|
||||||
|
|
||||||
|
Returns the rewritten SQL string when RLS predicates were injected,
|
||||||
|
or ``None`` when no rules applied (caller should use the original
|
||||||
|
``dataset.sql`` in that case). Returns ``None`` for physical
|
||||||
|
datasets too.
|
||||||
|
"""
|
||||||
|
if not dataset.sql:
|
||||||
|
return None
|
||||||
|
|
||||||
|
engine = dataset.database.db_engine_spec.engine
|
||||||
|
try:
|
||||||
|
parsed_script = SQLScript(dataset.sql, engine=engine)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
# Mirror the legacy behavior — failure to parse should not block
|
||||||
|
# the query; the caller will fall back to the original SQL and
|
||||||
|
# RLS will simply not be applied to inner tables.
|
||||||
|
logger.warning(
|
||||||
|
"Failed to parse virtual dataset SQL for RLS application",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
default_schema = dataset.database.get_default_schema(dataset.catalog)
|
||||||
|
rls_applied = False
|
||||||
|
try:
|
||||||
|
for statement in parsed_script.statements:
|
||||||
|
if apply_rls(
|
||||||
|
dataset.database,
|
||||||
|
dataset.catalog,
|
||||||
|
dataset.schema or default_schema or "",
|
||||||
|
statement,
|
||||||
|
exclude_dataset_id=dataset.id,
|
||||||
|
):
|
||||||
|
rls_applied = True
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
logger.warning(
|
||||||
|
"Failed to apply RLS to virtual dataset inner SQL",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return parsed_script.format() if rls_applied else None
|
||||||
270
tests/unit_tests/semantic_layers/adhoc_test.py
Normal file
270
tests/unit_tests/semantic_layers/adhoc_test.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
import pytest
|
||||||
|
from superset_core.semantic_layers.types import AggregationType, Dimension
|
||||||
|
|
||||||
|
from superset.exceptions import QueryObjectValidationError
|
||||||
|
from superset.semantic_layers.adhoc import (
|
||||||
|
adhoc_column_to_semantic_dimension,
|
||||||
|
adhoc_metric_to_semantic_metric,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset() -> MagicMock:
|
||||||
|
ds = MagicMock()
|
||||||
|
ds.database_id = 1
|
||||||
|
ds.schema = "public"
|
||||||
|
ds.database.db_engine_spec.engine = "postgresql"
|
||||||
|
# Identifier quoter — return double-quoted name.
|
||||||
|
ds.quote_identifier = lambda name: f'"{name}"'
|
||||||
|
# ``_process_select_expression`` echoes back the SQL after "rendering".
|
||||||
|
ds._process_select_expression = (
|
||||||
|
lambda expression, **kwargs: expression # noqa: U100
|
||||||
|
)
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Adhoc metric
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_metric_sum(dataset: MagicMock) -> None:
|
||||||
|
metric = adhoc_metric_to_semantic_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SIMPLE",
|
||||||
|
"label": "total_amount",
|
||||||
|
"aggregate": "SUM",
|
||||||
|
"column": {"column_name": "amount"},
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
assert metric.id == "total_amount"
|
||||||
|
assert metric.name == "total_amount"
|
||||||
|
assert metric.definition == 'SUM("amount")'
|
||||||
|
assert metric.aggregation == AggregationType.SUM
|
||||||
|
# SUM(...) infers to float64.
|
||||||
|
assert metric.type == pa.float64()
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_metric_count_distinct(dataset: MagicMock) -> None:
|
||||||
|
metric = adhoc_metric_to_semantic_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SIMPLE",
|
||||||
|
"label": "unique_customers",
|
||||||
|
"aggregate": "COUNT_DISTINCT",
|
||||||
|
"column": {"column_name": "customer_id"},
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
assert metric.definition == 'COUNT(DISTINCT "customer_id")'
|
||||||
|
assert metric.aggregation == AggregationType.COUNT_DISTINCT
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_metric_unknown_aggregate_falls_back_to_other(
|
||||||
|
dataset: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
metric = adhoc_metric_to_semantic_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SIMPLE",
|
||||||
|
"label": "median_x",
|
||||||
|
"aggregate": "MEDIAN",
|
||||||
|
"column": {"column_name": "x"},
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
# MEDIAN isn't in the rollup-safe map.
|
||||||
|
assert metric.aggregation == AggregationType.OTHER
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_metric_missing_aggregate_raises(dataset: MagicMock) -> None:
|
||||||
|
with pytest.raises(QueryObjectValidationError, match="aggregate and a column"):
|
||||||
|
adhoc_metric_to_semantic_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SIMPLE",
|
||||||
|
"label": "bad",
|
||||||
|
"column": {"column_name": "amount"},
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_metric_missing_column_raises(dataset: MagicMock) -> None:
|
||||||
|
with pytest.raises(QueryObjectValidationError, match="aggregate and a column"):
|
||||||
|
adhoc_metric_to_semantic_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SIMPLE",
|
||||||
|
"label": "bad",
|
||||||
|
"aggregate": "SUM",
|
||||||
|
"column": {},
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sql_metric(dataset: MagicMock) -> None:
|
||||||
|
metric = adhoc_metric_to_semantic_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SQL",
|
||||||
|
"label": "profit_margin",
|
||||||
|
"sqlExpression": "SUM(revenue - cost) / SUM(revenue)",
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
assert metric.definition == "SUM(revenue - cost) / SUM(revenue)"
|
||||||
|
assert metric.aggregation == AggregationType.OTHER
|
||||||
|
|
||||||
|
|
||||||
|
def test_sql_metric_missing_expression_raises(dataset: MagicMock) -> None:
|
||||||
|
with pytest.raises(QueryObjectValidationError, match="sqlExpression"):
|
||||||
|
adhoc_metric_to_semantic_metric(
|
||||||
|
{"expressionType": "SQL", "label": "bad"},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric_missing_label_raises(dataset: MagicMock) -> None:
|
||||||
|
with pytest.raises(QueryObjectValidationError, match="missing a ``label``"):
|
||||||
|
adhoc_metric_to_semantic_metric(
|
||||||
|
{"expressionType": "SIMPLE", "aggregate": "SUM"},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_metric_unknown_expression_type_raises(dataset: MagicMock) -> None:
|
||||||
|
with pytest.raises(QueryObjectValidationError, match="Unknown adhoc metric"):
|
||||||
|
adhoc_metric_to_semantic_metric(
|
||||||
|
{"label": "weird", "expressionType": "MYSTERY"},
|
||||||
|
dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sql_metric_jinja_applied(dataset: MagicMock) -> None:
|
||||||
|
# ``_process_select_expression`` is where Jinja and safety validation
|
||||||
|
# live in the dataset model. We verify the helper is invoked.
|
||||||
|
dataset._process_select_expression = MagicMock(return_value="user_id = 42")
|
||||||
|
metric = adhoc_metric_to_semantic_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SQL",
|
||||||
|
"label": "rendered",
|
||||||
|
"sqlExpression": "user_id = {{ current_user_id() }}",
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
template_processor=MagicMock(),
|
||||||
|
)
|
||||||
|
assert metric.definition == "user_id = 42"
|
||||||
|
dataset._process_select_expression.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sql_metric_empty_processed_raises(dataset: MagicMock) -> None:
|
||||||
|
dataset._process_select_expression = MagicMock(return_value=None)
|
||||||
|
with pytest.raises(QueryObjectValidationError, match="empty string"):
|
||||||
|
adhoc_metric_to_semantic_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SQL",
|
||||||
|
"label": "bad",
|
||||||
|
"sqlExpression": "{{ '' }}",
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
template_processor=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Adhoc column
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_adhoc_column_reference_uses_existing_dimension(dataset: MagicMock) -> None:
|
||||||
|
existing = Dimension(id="country", name="country", type=pa.utf8())
|
||||||
|
result = adhoc_column_to_semantic_dimension(
|
||||||
|
{
|
||||||
|
"label": "country",
|
||||||
|
"sqlExpression": "country",
|
||||||
|
"isColumnReference": True,
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
{"country": existing},
|
||||||
|
)
|
||||||
|
assert result is existing
|
||||||
|
|
||||||
|
|
||||||
|
def test_adhoc_column_synthesises_dimension(dataset: MagicMock) -> None:
|
||||||
|
dataset._process_select_expression = MagicMock(return_value="UPPER(country)")
|
||||||
|
result = adhoc_column_to_semantic_dimension(
|
||||||
|
{
|
||||||
|
"label": "upper_country",
|
||||||
|
"sqlExpression": "UPPER(country)",
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
{},
|
||||||
|
template_processor=MagicMock(),
|
||||||
|
)
|
||||||
|
assert result.id == "upper_country"
|
||||||
|
assert result.name == "upper_country"
|
||||||
|
assert result.definition == "UPPER(country)"
|
||||||
|
# UPPER(...) infers to utf8.
|
||||||
|
assert result.type == pa.utf8()
|
||||||
|
|
||||||
|
|
||||||
|
def test_adhoc_column_missing_label_raises(dataset: MagicMock) -> None:
|
||||||
|
with pytest.raises(QueryObjectValidationError, match="``label``"):
|
||||||
|
adhoc_column_to_semantic_dimension(
|
||||||
|
{"sqlExpression": "x"},
|
||||||
|
dataset,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_adhoc_column_missing_sql_raises(dataset: MagicMock) -> None:
|
||||||
|
with pytest.raises(QueryObjectValidationError, match="``sqlExpression``"):
|
||||||
|
adhoc_column_to_semantic_dimension(
|
||||||
|
{"label": "x"},
|
||||||
|
dataset,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_adhoc_column_reference_falls_back_when_not_matching(
|
||||||
|
dataset: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
A column-reference adhoc whose sqlExpression doesn't match an existing
|
||||||
|
dimension is treated as a synthesized adhoc.
|
||||||
|
"""
|
||||||
|
dataset._process_select_expression = MagicMock(return_value="ghost")
|
||||||
|
result = adhoc_column_to_semantic_dimension(
|
||||||
|
{
|
||||||
|
"label": "spooky",
|
||||||
|
"sqlExpression": "ghost",
|
||||||
|
"isColumnReference": True,
|
||||||
|
},
|
||||||
|
dataset,
|
||||||
|
{"country": Dimension(id="country", name="country", type=pa.utf8())},
|
||||||
|
template_processor=MagicMock(),
|
||||||
|
)
|
||||||
|
assert result.id == "spooky"
|
||||||
|
assert result.definition == "ghost"
|
||||||
@@ -50,9 +50,14 @@ from superset.semantic_layers.mapper import (
|
|||||||
_get_group_limit_filters,
|
_get_group_limit_filters,
|
||||||
_get_group_limit_from_query_object,
|
_get_group_limit_from_query_object,
|
||||||
_get_order_from_query_object,
|
_get_order_from_query_object,
|
||||||
|
_get_template_processor,
|
||||||
_get_time_bounds,
|
_get_time_bounds,
|
||||||
_get_time_filter,
|
_get_time_filter,
|
||||||
_normalize_column,
|
_normalize_column,
|
||||||
|
_resolve_dimension,
|
||||||
|
_resolve_metric,
|
||||||
|
_stamp_grain,
|
||||||
|
_validate_dimensions,
|
||||||
_validate_filters,
|
_validate_filters,
|
||||||
_validate_granularity,
|
_validate_granularity,
|
||||||
_validate_group_limit,
|
_validate_group_limit,
|
||||||
@@ -1035,7 +1040,10 @@ def test_validate_query_object_undefined_metric_error(
|
|||||||
columns=["order_date"],
|
columns=["order_date"],
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="All metrics must be defined"):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Metric 'undefined_metric' is not defined in the Semantic View",
|
||||||
|
):
|
||||||
validate_query_object(query_object)
|
validate_query_object(query_object)
|
||||||
|
|
||||||
|
|
||||||
@@ -1051,7 +1059,10 @@ def test_validate_query_object_undefined_dimension_error(
|
|||||||
columns=["undefined_dimension"],
|
columns=["undefined_dimension"],
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="All dimensions must be defined"):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Dimension 'undefined_dimension' is not defined in the Semantic View",
|
||||||
|
):
|
||||||
validate_query_object(query_object)
|
validate_query_object(query_object)
|
||||||
|
|
||||||
|
|
||||||
@@ -1800,28 +1811,31 @@ def test_get_results_empty_requests(
|
|||||||
|
|
||||||
def test_normalize_column_adhoc_not_in_dimensions() -> None:
|
def test_normalize_column_adhoc_not_in_dimensions() -> None:
|
||||||
"""
|
"""
|
||||||
Test _normalize_column raises error for AdhocColumn with sqlExpression not in dims.
|
Adhoc columns whose sqlExpression doesn't match an existing dimension fall
|
||||||
|
back to using the label as the resolved name. Actual SQL synthesis is the
|
||||||
|
job of _resolve_dimension; _normalize_column only surfaces a name.
|
||||||
"""
|
"""
|
||||||
dimension_names = {"category", "region"}
|
dimension_names = {"category", "region"}
|
||||||
adhoc_column: AdhocColumn = {
|
adhoc_column: AdhocColumn = {
|
||||||
|
"label": "custom_dim",
|
||||||
"isColumnReference": True,
|
"isColumnReference": True,
|
||||||
"sqlExpression": "unknown_dimension",
|
"sqlExpression": "unknown_dimension",
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Adhoc dimensions are not supported"):
|
assert _normalize_column(adhoc_column, dimension_names) == "custom_dim"
|
||||||
_normalize_column(adhoc_column, dimension_names)
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_column_adhoc_missing_sql_expression() -> None:
|
def test_normalize_column_adhoc_missing_label_raises() -> None:
|
||||||
"""
|
"""
|
||||||
Test _normalize_column raises error for AdhocColumn without sqlExpression.
|
When neither a matching column reference nor a label is provided there's
|
||||||
|
no resolvable name and _normalize_column raises.
|
||||||
"""
|
"""
|
||||||
dimension_names = {"category", "region"}
|
dimension_names = {"category", "region"}
|
||||||
adhoc_column: AdhocColumn = {
|
adhoc_column: AdhocColumn = {
|
||||||
"isColumnReference": True,
|
"isColumnReference": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Adhoc dimensions are not supported"):
|
with pytest.raises(ValueError, match="Adhoc column is missing"):
|
||||||
_normalize_column(adhoc_column, dimension_names)
|
_normalize_column(adhoc_column, dimension_names)
|
||||||
|
|
||||||
|
|
||||||
@@ -2265,11 +2279,12 @@ def test_validate_query_object_no_datasource() -> None:
|
|||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
def test_validate_metrics_adhoc_error(
|
def test_validate_metrics_adhoc_with_bad_shape_raises(
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test validation error for adhoc metrics.
|
Adhoc metrics are now supported, but invalid shapes (missing
|
||||||
|
expressionType, etc.) still raise via the adhoc resolver.
|
||||||
"""
|
"""
|
||||||
mock_datasource = mocker.Mock()
|
mock_datasource = mocker.Mock()
|
||||||
category_dim = Dimension("category", "category", pa.utf8(), "category", "Category")
|
category_dim = Dimension("category", "category", pa.utf8(), "category", "Category")
|
||||||
@@ -2279,13 +2294,15 @@ def test_validate_metrics_adhoc_error(
|
|||||||
|
|
||||||
mock_datasource.implementation.dimensions = {category_dim}
|
mock_datasource.implementation.dimensions = {category_dim}
|
||||||
mock_datasource.implementation.metrics = {sales_metric}
|
mock_datasource.implementation.metrics = {sales_metric}
|
||||||
|
# Strip the template processor; we just want to verify the shape check.
|
||||||
|
mock_datasource.get_template_processor.return_value = None
|
||||||
|
|
||||||
# Manually create a query object with an adhoc metric
|
|
||||||
query_object = mocker.Mock()
|
query_object = mocker.Mock()
|
||||||
query_object.datasource = mock_datasource
|
query_object.datasource = mock_datasource
|
||||||
|
# Missing expressionType — the resolver doesn't know how to interpret this.
|
||||||
query_object.metrics = [{"label": "adhoc", "sqlExpression": "SUM(x)"}]
|
query_object.metrics = [{"label": "adhoc", "sqlExpression": "SUM(x)"}]
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Adhoc metrics are not supported"):
|
with pytest.raises(Exception, match="expressionType"):
|
||||||
_validate_metrics(query_object)
|
_validate_metrics(query_object)
|
||||||
|
|
||||||
|
|
||||||
@@ -3111,3 +3128,171 @@ def test_coerce_time_invalid_string_raises() -> None:
|
|||||||
def test_coerce_time_rejects_other_types() -> None:
|
def test_coerce_time_rejects_other_types() -> None:
|
||||||
with pytest.raises(ValueError, match="Invalid time value"):
|
with pytest.raises(ValueError, match="Invalid time value"):
|
||||||
_coerce_scalar_filter_value(123, _dim(pa.time64("us")))
|
_coerce_scalar_filter_value(123, _dim(pa.time64("us")))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Adhoc + grain resolver coverage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_template_processor_returns_none_when_unsupported() -> None:
|
||||||
|
"""A bare object without ``get_template_processor`` returns None."""
|
||||||
|
|
||||||
|
class Plain:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert _get_template_processor(Plain()) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_stamp_grain_returns_new_dimension_with_grain() -> None:
|
||||||
|
dim = Dimension(id="dt", name="dt", type=pa.timestamp("us"))
|
||||||
|
stamped = _stamp_grain(dim, Grains.DAY)
|
||||||
|
assert stamped.grain == Grains.DAY
|
||||||
|
# Original is unchanged (frozen dataclass).
|
||||||
|
assert dim.grain is None
|
||||||
|
|
||||||
|
|
||||||
|
def _adhoc_dataset() -> MagicMock:
|
||||||
|
ds = MagicMock()
|
||||||
|
ds.database_id = 1
|
||||||
|
ds.schema = "public"
|
||||||
|
ds.database.db_engine_spec.engine = "postgresql"
|
||||||
|
ds.quote_identifier = lambda name: f'"{name}"'
|
||||||
|
ds._process_select_expression = lambda expression, **kwargs: expression # noqa: U100
|
||||||
|
ds.get_template_processor.return_value = None
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_metric_for_adhoc_simple_dict() -> None:
|
||||||
|
ds = _adhoc_dataset()
|
||||||
|
metric = _resolve_metric(
|
||||||
|
{
|
||||||
|
"expressionType": "SIMPLE",
|
||||||
|
"label": "total",
|
||||||
|
"aggregate": "SUM",
|
||||||
|
"column": {"column_name": "x"},
|
||||||
|
},
|
||||||
|
ds,
|
||||||
|
{},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert metric.id == "total"
|
||||||
|
assert metric.definition == 'SUM("x")'
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_dimension_for_adhoc_dict() -> None:
|
||||||
|
ds = _adhoc_dataset()
|
||||||
|
dim = _resolve_dimension(
|
||||||
|
{"label": "calc", "sqlExpression": "UPPER(country)"},
|
||||||
|
ds,
|
||||||
|
{},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert dim.id == "calc"
|
||||||
|
assert dim.definition == "UPPER(country)"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_metrics_duplicate_label_raises() -> None:
|
||||||
|
ds = _adhoc_dataset()
|
||||||
|
view = MagicMock()
|
||||||
|
view.metrics = []
|
||||||
|
view.dimensions = []
|
||||||
|
ds.implementation = view
|
||||||
|
|
||||||
|
query = MagicMock()
|
||||||
|
query.datasource = ds
|
||||||
|
query.metrics = [
|
||||||
|
{
|
||||||
|
"expressionType": "SIMPLE",
|
||||||
|
"label": "dup",
|
||||||
|
"aggregate": "SUM",
|
||||||
|
"column": {"column_name": "x"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"expressionType": "SIMPLE",
|
||||||
|
"label": "dup",
|
||||||
|
"aggregate": "SUM",
|
||||||
|
"column": {"column_name": "y"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Duplicate metric labels"):
|
||||||
|
_validate_metrics(query)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_dimensions_duplicate_label_raises() -> None:
|
||||||
|
ds = _adhoc_dataset()
|
||||||
|
view = MagicMock()
|
||||||
|
view.metrics = []
|
||||||
|
view.dimensions = []
|
||||||
|
ds.implementation = view
|
||||||
|
|
||||||
|
query = MagicMock()
|
||||||
|
query.datasource = ds
|
||||||
|
query.columns = [
|
||||||
|
{"label": "dup", "sqlExpression": "x"},
|
||||||
|
{"label": "dup", "sqlExpression": "y"},
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Duplicate column labels"):
|
||||||
|
_validate_dimensions(query)
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_query_object_dedups_dimensions_by_id(mocker: MockerFixture) -> None:
|
||||||
|
"""Two requested columns resolving to the same dim id collapse to one."""
|
||||||
|
ds = _adhoc_dataset()
|
||||||
|
existing = Dimension(id="country", name="country", type=pa.utf8())
|
||||||
|
view = MagicMock()
|
||||||
|
view.metrics = []
|
||||||
|
view.dimensions = [existing]
|
||||||
|
ds.implementation = view
|
||||||
|
ds.fetch_values_predicate = None
|
||||||
|
|
||||||
|
query = ValidatedQueryObject(
|
||||||
|
datasource=ds,
|
||||||
|
metrics=[],
|
||||||
|
# Same dim referenced twice — once as a string, once as a column ref.
|
||||||
|
columns=[
|
||||||
|
"country",
|
||||||
|
{"label": "country", "sqlExpression": "country", "isColumnReference": True},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
sem_queries = map_query_object(query)
|
||||||
|
assert len(sem_queries[0].dimensions) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_query_object_stamps_grain_on_granularity_dim(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
"""A grain request stamps the grain onto the matching dimension."""
|
||||||
|
ds = _adhoc_dataset()
|
||||||
|
dt_dim = Dimension(id="dt", name="dt", type=pa.timestamp("us"))
|
||||||
|
view = MagicMock()
|
||||||
|
view.metrics = []
|
||||||
|
view.dimensions = [dt_dim]
|
||||||
|
ds.implementation = view
|
||||||
|
ds.fetch_values_predicate = None
|
||||||
|
|
||||||
|
query = ValidatedQueryObject(
|
||||||
|
datasource=ds,
|
||||||
|
metrics=[],
|
||||||
|
columns=["dt"],
|
||||||
|
granularity="dt",
|
||||||
|
extras={"time_grain_sqla": "P1D"},
|
||||||
|
)
|
||||||
|
|
||||||
|
sem_queries = map_query_object(query)
|
||||||
|
dim = sem_queries[0].dimensions[0]
|
||||||
|
assert dim.grain == Grains.DAY
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_column_adhoc_label_only() -> None:
|
||||||
|
"""Adhoc with no isColumnReference falls back to its label."""
|
||||||
|
assert (
|
||||||
|
_normalize_column({"label": "calc", "sqlExpression": "x"}, set()) == "calc"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_column_string_passthrough() -> None:
|
||||||
|
assert _normalize_column("category", {"category", "region"}) == "category"
|
||||||
|
|||||||
166
tests/unit_tests/semantic_layers/rls_test.py
Normal file
166
tests/unit_tests/semantic_layers/rls_test.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from superset.semantic_layers.rls import (
|
||||||
|
apply_rls_to_virtual_sql,
|
||||||
|
get_rls_method,
|
||||||
|
render_rls_predicates,
|
||||||
|
)
|
||||||
|
from superset.sql.parse import RLSMethod
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dataset(
|
||||||
|
*,
|
||||||
|
sql: str | None = None,
|
||||||
|
rls_clauses: list | None = None,
|
||||||
|
) -> MagicMock:
|
||||||
|
ds = MagicMock()
|
||||||
|
ds.id = 1
|
||||||
|
ds.catalog = None
|
||||||
|
ds.schema = "public"
|
||||||
|
ds.sql = sql
|
||||||
|
ds.database.db_engine_spec.engine = "postgresql"
|
||||||
|
ds.database.get_default_schema.return_value = "public"
|
||||||
|
ds.get_sqla_row_level_filters.return_value = rls_clauses or []
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rls_method_delegates_to_engine_spec() -> None:
|
||||||
|
ds = _make_dataset()
|
||||||
|
ds.database.db_engine_spec.get_rls_method.return_value = RLSMethod.AS_SUBQUERY
|
||||||
|
assert get_rls_method(ds) == RLSMethod.AS_SUBQUERY
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_rls_predicates_empty_when_no_rules() -> None:
|
||||||
|
ds = _make_dataset()
|
||||||
|
assert render_rls_predicates(ds) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_rls_predicates_compiles_each_clause() -> None:
|
||||||
|
clause1 = MagicMock()
|
||||||
|
clause1.compile.return_value = "tenant_id = 7"
|
||||||
|
clause2 = MagicMock()
|
||||||
|
clause2.compile.return_value = "deleted = false"
|
||||||
|
ds = _make_dataset(rls_clauses=[clause1, clause2])
|
||||||
|
|
||||||
|
result = render_rls_predicates(ds)
|
||||||
|
assert result == ["tenant_id = 7", "deleted = false"]
|
||||||
|
# Each clause is compiled with literal_binds so values are inlined.
|
||||||
|
for clause in (clause1, clause2):
|
||||||
|
kwargs = clause.compile.call_args.kwargs
|
||||||
|
assert kwargs["compile_kwargs"] == {"literal_binds": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rls_to_virtual_sql_returns_none_for_physical(mocker: MockerFixture) -> None:
|
||||||
|
ds = _make_dataset()
|
||||||
|
assert apply_rls_to_virtual_sql(ds) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rls_to_virtual_sql_returns_none_when_no_rls_applied(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
ds = _make_dataset(sql="SELECT * FROM raw_orders")
|
||||||
|
# apply_rls returns False → no predicates were injected.
|
||||||
|
mocker.patch("superset.semantic_layers.rls.apply_rls", return_value=False)
|
||||||
|
assert apply_rls_to_virtual_sql(ds) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rls_to_virtual_sql_returns_rewritten_when_applied(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
ds = _make_dataset(sql="SELECT * FROM raw_orders")
|
||||||
|
fake_script = MagicMock()
|
||||||
|
fake_statement = MagicMock()
|
||||||
|
fake_script.statements = [fake_statement]
|
||||||
|
fake_script.format.return_value = "SELECT * FROM raw_orders WHERE x = 1"
|
||||||
|
mocker.patch(
|
||||||
|
"superset.semantic_layers.rls.SQLScript",
|
||||||
|
return_value=fake_script,
|
||||||
|
)
|
||||||
|
mocker.patch("superset.semantic_layers.rls.apply_rls", return_value=True)
|
||||||
|
|
||||||
|
result = apply_rls_to_virtual_sql(ds)
|
||||||
|
assert result == "SELECT * FROM raw_orders WHERE x = 1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rls_to_virtual_sql_swallows_parse_error(mocker: MockerFixture) -> None:
|
||||||
|
ds = _make_dataset(sql="totally invalid sql")
|
||||||
|
mocker.patch(
|
||||||
|
"superset.semantic_layers.rls.SQLScript",
|
||||||
|
side_effect=Exception("parse error"),
|
||||||
|
)
|
||||||
|
assert apply_rls_to_virtual_sql(ds) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rls_to_virtual_sql_swallows_apply_error(mocker: MockerFixture) -> None:
|
||||||
|
ds = _make_dataset(sql="SELECT * FROM raw_orders")
|
||||||
|
fake_script = MagicMock()
|
||||||
|
fake_script.statements = [MagicMock()]
|
||||||
|
mocker.patch(
|
||||||
|
"superset.semantic_layers.rls.SQLScript",
|
||||||
|
return_value=fake_script,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"superset.semantic_layers.rls.apply_rls",
|
||||||
|
side_effect=Exception("boom"),
|
||||||
|
)
|
||||||
|
assert apply_rls_to_virtual_sql(ds) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rls_to_virtual_sql_uses_default_schema_when_dataset_schema_missing(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
ds = _make_dataset(sql="SELECT * FROM raw_orders")
|
||||||
|
ds.schema = None
|
||||||
|
fake_script = MagicMock()
|
||||||
|
fake_statement = MagicMock()
|
||||||
|
fake_script.statements = [fake_statement]
|
||||||
|
mocker.patch(
|
||||||
|
"superset.semantic_layers.rls.SQLScript",
|
||||||
|
return_value=fake_script,
|
||||||
|
)
|
||||||
|
apply_rls_mock = mocker.patch(
|
||||||
|
"superset.semantic_layers.rls.apply_rls",
|
||||||
|
return_value=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_rls_to_virtual_sql(ds)
|
||||||
|
# default_schema was "public" from get_default_schema; that should be the
|
||||||
|
# schema arg passed to apply_rls.
|
||||||
|
args, kwargs = apply_rls_mock.call_args
|
||||||
|
assert args[2] == "public" or kwargs.get("schema") == "public"
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_rls_predicates_uses_dialect_for_compile() -> None:
|
||||||
|
clause = MagicMock()
|
||||||
|
clause.compile.return_value = "x = 1"
|
||||||
|
ds = _make_dataset(rls_clauses=[clause])
|
||||||
|
|
||||||
|
render_rls_predicates(ds)
|
||||||
|
kwargs = clause.compile.call_args.kwargs
|
||||||
|
# Verify the dialect from the database is passed (not asserting identity,
|
||||||
|
# just that ``dialect=`` got a value).
|
||||||
|
assert "dialect" in kwargs
|
||||||
Reference in New Issue
Block a user