Compare commits

...

4 Commits

Author SHA1 Message Date
Beto Dealmeida
4a749e3f31 Better type inference 2026-06-07 12:59:44 -04:00
Beto Dealmeida
439bcf8d79 parity 2026-06-04 20:59:15 -04:00
Beto Dealmeida
ce7753b631 Add tests 2026-06-04 19:46:54 -04:00
Beto Dealmeida
b9492f477b WIP 2026-06-04 19:31:57 -04:00
28 changed files with 12777 additions and 47 deletions

View File

@@ -2529,6 +2529,13 @@ except ImportError:
LOCAL_EXTENSIONS: list[str] = []
EXTENSIONS_PATH: str | None = None
# When True, dataset queries are routed through the dataset semantic-layer
# extension (``superset/semantic_layers/extension``) instead of the legacy
# ``get_sqla_query`` path. The semantic view builds the SQL via sqlglot and
# the mapper handles the QueryObject → SemanticQuery translation. Falls back
# to the legacy path on any error.
USE_DATASET_SEMANTIC_VIEW: bool = False
# Default polling interval for tasks (seconds)
TASK_ABORT_POLLING_DEFAULT_INTERVAL = 10

View File

@@ -19,10 +19,12 @@ from __future__ import annotations
import builtins
import logging
import sys
from collections import defaultdict
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, cast, Optional, Union
import pandas as pd
@@ -71,6 +73,7 @@ from superset_core.common.models import Dataset as CoreDataset
from superset import db, is_feature_enabled, security_manager
from superset.commands.dataset.exceptions import DatasetNotFoundError
from superset.common.db_query_status import QueryStatus
from superset.common.query_object import QueryObject
from superset.connectors.sqla.utils import (
get_columns_description,
get_physical_table_metadata,
@@ -2139,6 +2142,67 @@ class SqlaTable(
"""Returns a text clause using ExploreMixin implementation"""
return ExploreMixin.text(self, clause)
@property
def implementation(self) -> Any:
"""
Expose the dataset as a ``SemanticView`` instance.
The dataset semantic-layer extension (``superset/semantic_layers/
extension``) wraps a ``SqlaTable`` in a ``DatasetSemanticView``,
translating the dataset's columns and metrics into semantic dimensions
and metrics. This property exists so the same mapper-based execution
path used by stored semantic views (``mapper.get_results``) can drive
plain dataset queries when ``USE_DATASET_SEMANTIC_VIEW`` is enabled.
"""
# The extension's backend lives outside the regular Python path; add it
# the first time we need it. The ``@semantic_layer`` decorator has
# already been monkey-patched by ``inject_semantic_layer_implementations``
# at app init, so importing here is safe at runtime.
extension_src = (
Path(__file__).resolve().parents[1]
/ "semantic_layers"
/ "extension"
/ "backend"
/ "src"
)
extension_src_str = str(extension_src)
if extension_src_str not in sys.path:
sys.path.insert(0, extension_src_str)
from preset_io.dataset_semantic_layer import DatasetSemanticView
return DatasetSemanticView(self)
def get_query_result(self, query_object: QueryObject) -> QueryResult:
"""
Route dataset queries through the dataset semantic-layer extension
when ``USE_DATASET_SEMANTIC_VIEW`` is enabled, otherwise fall back to
the legacy ``ExploreMixin`` path.
The mapper in ``superset.semantic_layers.mapper`` handles the
``QueryObject`` → ``SemanticQuery`` translation (filters, time range,
group limit, ordering, time offsets), so all we add here is the
feature-flag gate and a safety fallback for cases the semantic view
does not yet support.
"""
if current_app.config.get("USE_DATASET_SEMANTIC_VIEW"):
from superset.semantic_layers.mapper import get_results
try:
# ``mapper.get_results`` reads ``query_object.datasource.implementation``;
# ensure we are the datasource on the in-memory QueryObject.
if query_object.datasource is None:
query_object.datasource = self
return get_results(query_object)
except Exception: # pylint: disable=broad-except
logger.warning(
"Semantic-view execution failed for dataset %s; falling "
"back to the legacy query path.",
self.table_name,
exc_info=True,
)
return ExploreMixin.get_query_result(self, query_object)
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)

View File

@@ -0,0 +1,197 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Convert frontend "adhoc" metric / column dictionaries into semantic-layer
``Metric`` / ``Dimension`` objects.
Adhoc metrics and columns are user-defined SQL expressions that don't
correspond to a saved metric or physical column on the dataset. The legacy
``get_sqla_query`` path supports them through ``adhoc_metric_to_sqla`` /
``adhoc_column_to_sqla``; the semantic-layer path needs an equivalent so
that flipping ``USE_DATASET_SEMANTIC_VIEW`` doesn't regress charts that
use them.
The strategy is light: synthesize a ``Metric(definition=<sql>)`` /
``Dimension(definition=<sql>)`` and let the view's existing
``definition``-parsing path handle the rest. Jinja templating + safety
checks are delegated to the dataset's own
``_process_select_expression`` (already battle-tested).
"""
from __future__ import annotations
from typing import Any, TYPE_CHECKING
import pyarrow as pa
from superset_core.semantic_layers.types import (
AggregationType,
Dimension,
Metric,
)
from superset.exceptions import QueryObjectValidationError
from superset.semantic_layers.arrow_inference import infer_arrow_type
from superset.utils.core import AdhocMetricExpressionType
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
from superset.jinja_context import BaseTemplateProcessor
# Map Superset's aggregate strings to the semantic-layer ``AggregationType``.
# Anything unmapped (median, percentile, …) falls back to OTHER, which the
# view treats as a non-rollup-safe metric.
_AGGREGATE_MAP: dict[str, AggregationType] = {
"COUNT": AggregationType.COUNT,
"SUM": AggregationType.SUM,
"AVG": AggregationType.AVG,
"MIN": AggregationType.MIN,
"MAX": AggregationType.MAX,
"COUNT_DISTINCT": AggregationType.COUNT_DISTINCT,
}
def _process_adhoc_sql(
dataset: "SqlaTable",
sql_expression: str,
template_processor: "BaseTemplateProcessor | None",
) -> str:
"""
Run Jinja + safety validation on a free-form SQL expression.
Delegates to the dataset's ``_process_select_expression`` so adhoc
columns and metrics go through the same Jinja rendering, subquery
rules, and clause sanitization as the legacy path.
"""
processed = dataset._process_select_expression( # noqa: SLF001
expression=sql_expression,
database_id=dataset.database_id,
engine=dataset.database.db_engine_spec.engine,
schema=dataset.schema or "",
template_processor=template_processor,
)
if not processed:
raise QueryObjectValidationError(
"Adhoc SQL expression resolved to an empty string."
)
return processed
def _simple_metric_definition(adhoc: dict[str, Any], dataset: "SqlaTable") -> str:
"""
Build the SQL for a SIMPLE adhoc metric: ``AGGREGATE(column)``.
Uses the database's identifier quoter so the column name is rendered
in the right dialect.
"""
aggregate = (adhoc.get("aggregate") or "").upper()
column = adhoc.get("column") or {}
column_name = column.get("column_name")
if not aggregate or not column_name:
raise QueryObjectValidationError(
"SIMPLE adhoc metric requires both an aggregate and a column."
)
if aggregate == "COUNT_DISTINCT":
return f"COUNT(DISTINCT {dataset.quote_identifier(column_name)})"
return f"{aggregate}({dataset.quote_identifier(column_name)})"
def adhoc_metric_to_semantic_metric(
adhoc: dict[str, Any],
dataset: "SqlaTable",
template_processor: "BaseTemplateProcessor | None" = None,
) -> Metric:
"""
Convert an ``AdhocMetric`` dict to a semantic-layer ``Metric``.
SIMPLE metrics build ``f"{aggregate}({column})"`` and tag the
aggregation type so the view can preserve rollup semantics. SQL
metrics use the user-provided ``sqlExpression`` (after Jinja +
safety validation) and report ``AggregationType.OTHER`` because the
expression is opaque.
Adhoc dtype is set to ``pa.null()`` — the view aliases the metric
by id and downstream chart code coerces at render time.
"""
label = adhoc.get("label")
if not label:
raise QueryObjectValidationError("Adhoc metric is missing a ``label``.")
expression_type = adhoc.get("expressionType")
if expression_type == AdhocMetricExpressionType.SIMPLE.value:
definition = _simple_metric_definition(adhoc, dataset)
# Templating is irrelevant for SIMPLE metrics (no user-supplied SQL),
# so the safety pass is skipped here.
aggregate = (adhoc.get("aggregate") or "").upper()
aggregation = _AGGREGATE_MAP.get(aggregate, AggregationType.OTHER)
elif expression_type == AdhocMetricExpressionType.SQL.value:
sql_expression = adhoc.get("sqlExpression")
if not sql_expression:
raise QueryObjectValidationError(
"SQL adhoc metric is missing ``sqlExpression``."
)
definition = _process_adhoc_sql(dataset, sql_expression, template_processor)
aggregation = AggregationType.OTHER
else:
raise QueryObjectValidationError(
f"Unknown adhoc metric expressionType: {expression_type!r}"
)
return Metric(
id=label,
name=label,
type=infer_arrow_type(definition, dataset.database.db_engine_spec.engine),
definition=definition,
aggregation=aggregation,
)
def adhoc_column_to_semantic_dimension(
adhoc: dict[str, Any],
dataset: "SqlaTable",
dimensions_by_name: dict[str, Dimension],
template_processor: "BaseTemplateProcessor | None" = None,
) -> Dimension:
"""
Convert an ``AdhocColumn`` dict to a semantic-layer ``Dimension``.
When the adhoc is a thin wrapper around a real column
(``isColumnReference=True`` and ``sqlExpression`` matches a known
dimension name), the existing ``Dimension`` is returned so that
metadata such as type and grain are preserved. Otherwise a fresh
``Dimension`` with ``definition=<rendered SQL>`` is synthesized.
"""
label = adhoc.get("label")
if not label:
raise QueryObjectValidationError("Adhoc column is missing a ``label``.")
sql_expression = adhoc.get("sqlExpression")
if not sql_expression:
raise QueryObjectValidationError(
"Adhoc column is missing ``sqlExpression``."
)
if adhoc.get("isColumnReference") and sql_expression in dimensions_by_name:
return dimensions_by_name[sql_expression]
definition = _process_adhoc_sql(dataset, sql_expression, template_processor)
return Dimension(
id=label,
name=label,
type=infer_arrow_type(definition, dataset.database.db_engine_spec.engine),
definition=definition,
)

View File

@@ -0,0 +1,37 @@
# Dependencies
node_modules/
# Build outputs from the frontend bundler. The top-level dist/ is checked in
# because this extension ships as an in-built bundle loaded via LOCAL_EXTENSIONS.
frontend/dist/
*.supx
# Python
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
.eggs/
*.egg
.venv/
venv/
env/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
# Logs
*.log
npm-debug.log*
# Environment
.env
.env.local

View File

@@ -0,0 +1,4 @@
[report]
exclude_lines =
pragma: no cover
if TYPE_CHECKING:

View File

@@ -0,0 +1,21 @@
"""
Root conftest that patches the semantic_layer decorator before test collection.
The decorator in superset_core.semantic_layers.decorators is a placeholder that
raises NotImplementedError outside of a running Superset instance. We replace it
with a no-op passthrough so the decorated classes can be imported during tests.
"""
import superset_core.semantic_layers.decorators as _dec
def _noop_semantic_layer(**kwargs):
"""Return the class unchanged."""
def wrapper(cls):
return cls
return wrapper
_dec.semantic_layer = _noop_semantic_layer

View File

@@ -0,0 +1,13 @@
[project]
name = "preset_io-dataset_semantic_layer"
version = "0.1.0"
license = "All Rights Reserved"
dependencies = [
"sqlglot",
]
[tool.apache_superset_extensions.build]
include = [
"src/preset_io/dataset_semantic_layer/**/*.py",
]
exclude = []

View File

@@ -0,0 +1,9 @@
from .layer import DatasetSemanticLayer
from .schemas import DatasetConfiguration
from .view import DatasetSemanticView
__all__ = [
"DatasetConfiguration",
"DatasetSemanticLayer",
"DatasetSemanticView",
]

View File

@@ -0,0 +1,3 @@
# The @semantic_layer decorator on DatasetSemanticLayer handles registration
# automatically. This import triggers the decorator at extension load time.
from .layer import DatasetSemanticLayer # noqa: F401

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
from typing import Any
from superset_core.semantic_layers.config import build_configuration_schema
from superset_core.semantic_layers.decorators import semantic_layer
from superset_core.semantic_layers.layer import SemanticLayer
from .schemas import DatasetConfiguration
from .utils import dataset_label, get_dataset_by_id, list_datasets
from .view import DatasetSemanticView
@semantic_layer(
id="dataset",
name="Dataset Semantic Layer",
description=(
"Expose any Superset dataset as a semantic view. Metrics and columns "
"defined on the dataset become the semantic metrics and dimensions."
),
)
class DatasetSemanticLayer(SemanticLayer[DatasetConfiguration, DatasetSemanticView]):
configuration_class = DatasetConfiguration
@classmethod
def from_configuration(
cls,
configuration: dict[str, Any],
) -> "DatasetSemanticLayer":
config = DatasetConfiguration.model_validate(configuration or {})
return cls(config)
@classmethod
def get_configuration_schema(
cls,
configuration: DatasetConfiguration | None = None,
) -> dict[str, Any]:
"""
No configuration is needed to add this semantic layer — the layer wraps
whatever datasets are already registered in Superset.
"""
return build_configuration_schema(DatasetConfiguration, configuration)
@classmethod
def get_runtime_schema(
cls,
configuration: DatasetConfiguration,
runtime_data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
When adding a semantic view, the user picks a dataset from a dropdown.
Each existing dataset can be wrapped into a semantic view.
"""
datasets = list_datasets()
ids = [str(dataset.id) for dataset in datasets]
labels = [dataset_label(dataset) for dataset in datasets]
return {
"type": "object",
"required": ["dataset_id"],
"x-singleView": True,
"properties": {
"dataset_id": {
"type": "string",
"title": "Dataset",
"description": "The Superset dataset to expose as a semantic view.",
"enum": ids,
"x-enumNames": labels,
},
},
}
def __init__(self, configuration: DatasetConfiguration) -> None:
self.configuration = configuration
def get_semantic_views(
self,
runtime_configuration: dict[str, Any],
) -> set[DatasetSemanticView]:
dataset_id = runtime_configuration.get("dataset_id")
if not dataset_id:
return set()
dataset = get_dataset_by_id(int(dataset_id))
return {DatasetSemanticView(dataset)}
def get_semantic_view(
self,
name: str,
additional_configuration: dict[str, Any],
) -> DatasetSemanticView:
dataset_id = additional_configuration.get("dataset_id")
if not dataset_id:
raise ValueError("dataset_id is required to load a semantic view")
dataset = get_dataset_by_id(int(dataset_id))
return DatasetSemanticView(dataset)

View File

@@ -0,0 +1,14 @@
from __future__ import annotations
from pydantic import BaseModel, ConfigDict
class DatasetConfiguration(BaseModel):
"""
Configuration for the dataset semantic layer.
The layer wraps Superset datasets directly, so it does not need any
configuration beyond what Superset already knows about each dataset.
"""
model_config = ConfigDict(title="Dataset semantic layer", extra="ignore")

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import pyarrow as pa
if TYPE_CHECKING:
import pandas as pd
from superset.connectors.sqla.models import SqlaTable
def list_datasets() -> list["SqlaTable"]:
"""
Return every ``SqlaTable`` the current user can access, sorted by name.
Read inside the function so the module can be imported in unit tests that
do not have a Flask app context — these helpers are only called at runtime.
"""
from superset import db
from superset.connectors.sqla.models import SqlaTable
# ``no_autoflush`` prevents SQLAlchemy from flushing pending session state
# while we iterate over results — flushing during iteration of the
# session's pending deque has been observed to raise "deque mutated
# during iteration" in this codebase.
with db.session.no_autoflush:
datasets = db.session.query(SqlaTable).all()
return sorted(datasets, key=lambda d: d.table_name)
def get_dataset_by_id(dataset_id: int) -> "SqlaTable":
"""
Look up a dataset by primary key, eagerly materialising its columns and
metrics so downstream code can iterate them without triggering further
lazy loads (and without risking autoflush mid-iteration).
"""
from sqlalchemy.orm import selectinload
from superset import db
from superset.connectors.sqla.models import SqlaTable
with db.session.no_autoflush:
dataset = (
db.session.query(SqlaTable)
.options(
selectinload(SqlaTable.columns),
selectinload(SqlaTable.metrics),
)
.filter_by(id=dataset_id)
.one_or_none()
)
if dataset is None:
raise ValueError(f"Dataset with id {dataset_id} does not exist.")
# Force materialisation inside the no_autoflush block so later access
# never lazy-loads while another flush is in progress.
list(dataset.columns)
list(dataset.metrics)
return dataset
def dataset_label(dataset: "SqlaTable") -> str:
"""
Human-readable label for a dataset combining schema and table name.
"""
parts = [
part
for part in (dataset.catalog, dataset.schema, dataset.table_name)
if part
]
return ".".join(parts) if parts else dataset.table_name
def df_to_arrow(df: "pd.DataFrame") -> pa.Table:
"""
Convert a pandas DataFrame to a pyarrow Table, falling back gracefully when
the DataFrame is empty.
"""
if df is None or df.empty:
return pa.table({col: [] for col in (df.columns if df is not None else [])})
return pa.Table.from_pandas(df, preserve_index=False)
def coerce_literal(value: Any) -> Any:
"""
Coerce filter literal values to native Python types sqlglot can serialise.
"""
if isinstance(value, (set, frozenset)):
return [coerce_literal(v) for v in value]
return value

View File

@@ -0,0 +1,840 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pyarrow as pa
import sqlglot
from sqlglot import expressions as sqlglot_exp
from superset_core.semantic_layers.types import (
AdhocExpression,
AggregationType,
Dimension,
Filter,
GroupLimit,
Metric,
Operator,
OrderDirection,
PredicateType,
SemanticQuery,
SemanticRequest,
SemanticResult,
)
from superset_core.semantic_layers.view import SemanticView, SemanticViewFeature
from .utils import coerce_literal, df_to_arrow
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
REQUEST_TYPE = "sql"
# Map Superset's GenericDataType (an IntEnum value) to a pyarrow type. Kept
# narrow on purpose — the dataset model only tracks a coarse generic type, so
# anything more specific belongs to the underlying database and is left as
# UTF-8 at the semantic layer boundary.
_GENERIC_TYPE_TO_ARROW: dict[int, pa.DataType] = {
0: pa.float64(), # NUMERIC
1: pa.utf8(), # STRING
2: pa.timestamp("us"), # TEMPORAL
3: pa.bool_(), # BOOLEAN
}
# Top-level aggregate AST node → AggregationType. Compound expressions like
# ``SUM(a) + SUM(b)`` fall through to OTHER because they cannot be safely
# rolled up.
_AGG_NODE_MAP: dict[type[sqlglot_exp.Expression], AggregationType] = {
sqlglot_exp.Sum: AggregationType.SUM,
sqlglot_exp.Min: AggregationType.MIN,
sqlglot_exp.Max: AggregationType.MAX,
sqlglot_exp.Avg: AggregationType.AVG,
}
_OPERATOR_TO_SQLGLOT: dict[Operator, type[sqlglot_exp.Expression]] = {
Operator.EQUALS: sqlglot_exp.EQ,
Operator.NOT_EQUALS: sqlglot_exp.NEQ,
Operator.GREATER_THAN: sqlglot_exp.GT,
Operator.LESS_THAN: sqlglot_exp.LT,
Operator.GREATER_THAN_OR_EQUAL: sqlglot_exp.GTE,
Operator.LESS_THAN_OR_EQUAL: sqlglot_exp.LTE,
Operator.LIKE: sqlglot_exp.Like,
}
class DatasetSemanticView(SemanticView):
features = frozenset(
{
SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY,
SemanticViewFeature.GROUP_LIMIT,
SemanticViewFeature.GROUP_OTHERS,
}
)
def __init__(self, dataset: "SqlaTable") -> None:
self.dataset = dataset
self.name = dataset.table_name
self.dimensions = self.get_dimensions()
self.metrics = self.get_metrics()
# ------------------------------------------------------------------
# Identity & dialect
# ------------------------------------------------------------------
def uid(self) -> str:
return f"dataset:{self.dataset.id}"
@property
def _sqlglot_dialect(self) -> str | None:
# Imported lazily — ``superset.sql.parse`` pulls in a heavy dependency
# graph that is not needed for unit tests of pure AST building.
from superset.sql.parse import SQLGLOT_DIALECTS
engine = self.dataset.database.db_engine_spec.engine
dialect = SQLGLOT_DIALECTS.get(engine)
return dialect.value if dialect is not None else None
# ------------------------------------------------------------------
# Metadata
# ------------------------------------------------------------------
def get_dimensions(self) -> set[Dimension]:
dimensions: set[Dimension] = set()
for column in self.dataset.columns:
if not column.groupby:
continue
dimensions.add(
Dimension(
id=column.column_name,
name=column.verbose_name or column.column_name,
type=self._column_arrow_type(column),
definition=column.expression or None,
description=column.description or None,
)
)
return dimensions
def get_metrics(self) -> set[Metric]:
from superset.semantic_layers.arrow_inference import infer_arrow_type
engine = self.dataset.database.db_engine_spec.engine
metrics: set[Metric] = set()
for metric in self.dataset.metrics:
metrics.add(
Metric(
id=metric.metric_name,
name=metric.verbose_name or metric.metric_name,
type=infer_arrow_type(metric.expression, engine),
definition=metric.expression,
description=metric.description or None,
aggregation=self._aggregation_from_expression(metric.expression),
)
)
return metrics
def get_compatible_metrics(
self,
selected_metrics: set[Metric],
selected_dimensions: set[Dimension],
) -> set[Metric]:
return self.metrics
def get_compatible_dimensions(
self,
selected_metrics: set[Metric],
selected_dimensions: set[Dimension],
) -> set[Dimension]:
return self.dimensions
def _column_arrow_type(self, column: "TableColumn") -> pa.DataType:
generic = column.type_generic
if generic is None:
return pa.utf8()
return _GENERIC_TYPE_TO_ARROW.get(int(generic), pa.utf8())
def _aggregation_from_expression(
self,
expression: str | None,
) -> AggregationType:
if not expression:
return AggregationType.OTHER
try:
parsed = sqlglot.parse_one(expression, dialect=self._sqlglot_dialect)
except sqlglot.errors.ParseError:
return AggregationType.OTHER
if isinstance(parsed, sqlglot_exp.Count):
if isinstance(parsed.this, sqlglot_exp.Distinct):
return AggregationType.COUNT_DISTINCT
return AggregationType.COUNT
for node_type, aggregation in _AGG_NODE_MAP.items():
if isinstance(parsed, node_type):
return aggregation
return AggregationType.OTHER
# ------------------------------------------------------------------
# AST building
# ------------------------------------------------------------------
def _dimension_column(self, dimension: Dimension) -> "TableColumn":
# Callers (currently only ``_dimension_expression``) check membership
# before invoking this helper, so the match is guaranteed.
for column in self.dataset.columns:
if column.column_name == dimension.id:
return column
raise ValueError( # pragma: no cover - guarded by caller
f'Dimension "{dimension.id}" is not part of this dataset.'
)
def _metric_column(self, metric: Metric) -> "SqlMetric":
for sql_metric in self.dataset.metrics:
if sql_metric.metric_name == metric.id:
return sql_metric
raise ValueError(f'Metric "{metric.id}" is not part of this dataset.')
def _dimension_expression(self, dimension: Dimension) -> sqlglot_exp.Expression:
# Synthesised adhoc dimensions don't correspond to a dataset column;
# their ``definition`` carries the SQL we should emit.
dataset_column_names = {col.column_name for col in self.dataset.columns}
if dimension.id not in dataset_column_names:
if not dimension.definition:
raise ValueError(
f'Dimension "{dimension.id}" is not part of this dataset.'
)
base: sqlglot_exp.Expression = sqlglot.parse_one(
dimension.definition,
dialect=self._sqlglot_dialect,
)
return self._apply_grain(dimension, base)
column = self._dimension_column(dimension)
if column.expression:
base = sqlglot.parse_one(
column.expression,
dialect=self._sqlglot_dialect,
)
else:
base = sqlglot_exp.column(column.column_name, quoted=True)
return self._apply_grain(dimension, base)
def _apply_grain(
self,
dimension: Dimension,
base_expr: sqlglot_exp.Expression,
) -> sqlglot_exp.Expression:
"""
Wrap ``base_expr`` in the engine spec's time-grain template when the
dimension carries a grain. Engines that don't model the requested
grain return the unwrapped expression — same fallback the legacy path
uses.
"""
if dimension.grain is None:
return base_expr
engine_spec = self.dataset.database.db_engine_spec
template = engine_spec.get_time_grain_expressions().get(
dimension.grain.representation
)
if not template:
return base_expr
# ``{func}`` / ``{type}`` placeholders (BigQuery, MS SQL) require the
# engine spec's full ``get_timestamp_expr`` machinery to resolve.
if "{func}" in template or "{type}" in template:
import sqlalchemy as sa
sa_col = sa.column(
base_expr.sql(dialect=self._sqlglot_dialect),
is_literal=True,
)
tse = engine_spec.get_timestamp_expr(
sa_col,
None,
dimension.grain.representation,
)
sql_str = str(
tse.compile(
dialect=self.dataset.database.get_dialect(),
compile_kwargs={"literal_binds": True},
)
)
return sqlglot.parse_one(sql_str, dialect=self._sqlglot_dialect)
col_sql = base_expr.sql(dialect=self._sqlglot_dialect)
grain_sql = template.replace("{col}", col_sql)
return sqlglot.parse_one(grain_sql, dialect=self._sqlglot_dialect)
def _metric_expression(self, metric: Metric) -> sqlglot_exp.Expression:
sql_metric = self._metric_column(metric)
return sqlglot.parse_one(sql_metric.expression, dialect=self._sqlglot_dialect)
def _source_table(self) -> sqlglot_exp.Expression:
"""
Build the FROM clause. Physical datasets reference the table directly;
virtual datasets wrap their SQL as a subquery so the rest of the AST
can treat the source like a table.
For virtual datasets, RLS rules from the underlying tables referenced
in ``dataset.sql`` are injected into the inner SQL. When the engine
spec asks for ``AS_SUBQUERY``-style RLS, the dataset's own RLS
predicates wrap the source in an extra ``SELECT * FROM … WHERE`` so
the rest of the AST is unaffected.
"""
dataset = self.dataset
if dataset.sql:
inner_sql = self._rls_apply_to_virtual_sql() or dataset.sql
inner = sqlglot.parse_one(inner_sql, dialect=self._sqlglot_dialect)
source: sqlglot_exp.Expression = sqlglot_exp.Subquery(
this=inner, alias="virtual_dataset"
)
else:
parts = [
sqlglot_exp.to_identifier(part, quoted=True)
for part in (dataset.catalog, dataset.schema, dataset.table_name)
if part
]
if len(parts) == 1:
source = sqlglot_exp.Table(this=parts[0])
elif len(parts) == 2:
source = sqlglot_exp.Table(this=parts[1], db=parts[0])
else:
source = sqlglot_exp.Table(this=parts[2], db=parts[1], catalog=parts[0])
if self._rls_should_wrap_subquery():
rls_clauses = self._rls_predicates()
if rls_clauses:
source = self._wrap_with_rls(source, rls_clauses)
return source
# ------------------------------------------------------------------
# RLS
# ------------------------------------------------------------------
def _rls_predicates(self) -> list[str]:
"""Outer RLS predicates for the dataset, rendered as SQL strings."""
from superset.semantic_layers.rls import render_rls_predicates
return render_rls_predicates(self.dataset)
def _rls_apply_to_virtual_sql(self) -> str | None:
from superset.semantic_layers.rls import apply_rls_to_virtual_sql
return apply_rls_to_virtual_sql(self.dataset)
def _rls_should_wrap_subquery(self) -> bool:
from superset.semantic_layers.rls import get_rls_method
from superset.sql.parse import RLSMethod
return get_rls_method(self.dataset) == RLSMethod.AS_SUBQUERY
def _wrap_with_rls(
self,
source: sqlglot_exp.Expression,
rls_clauses: list[str],
) -> sqlglot_exp.Subquery:
"""Wrap ``source`` in ``SELECT * FROM source WHERE <AND-ed RLS>``."""
parsed = [
sqlglot.parse_one(clause, dialect=self._sqlglot_dialect)
for clause in rls_clauses
]
combined = parsed[0]
for predicate in parsed[1:]:
combined = sqlglot_exp.And(this=combined, expression=predicate)
inner = (
sqlglot_exp.Select()
.select(sqlglot_exp.Star())
.from_(source)
.where(combined)
)
return sqlglot_exp.Subquery(this=inner, alias="rls_wrapped")
def _filter_predicate(self, filter_: Filter) -> sqlglot_exp.Expression:
if filter_.operator == Operator.ADHOC:
if not isinstance(filter_.value, str):
raise ValueError("Adhoc filter value must be a SQL string")
return sqlglot.parse_one(filter_.value, dialect=self._sqlglot_dialect)
if filter_.column is None:
raise ValueError("Native filters require a column")
if isinstance(filter_.column, Dimension):
column_expr = self._dimension_expression(filter_.column)
else:
column_expr = self._metric_expression(filter_.column)
operator = filter_.operator
if operator == Operator.IS_NULL:
return sqlglot_exp.Is(this=column_expr, expression=sqlglot_exp.Null())
if operator == Operator.IS_NOT_NULL:
return sqlglot_exp.Not(
this=sqlglot_exp.Is(this=column_expr, expression=sqlglot_exp.Null())
)
if operator in (Operator.IN, Operator.NOT_IN):
values = coerce_literal(filter_.value)
if not isinstance(values, list):
values = [values]
expressions = [sqlglot_exp.convert(v) for v in values]
in_expr = sqlglot_exp.In(this=column_expr, expressions=expressions)
return (
sqlglot_exp.Not(this=in_expr)
if operator == Operator.NOT_IN
else in_expr
)
if operator == Operator.NOT_LIKE:
return sqlglot_exp.Not(
this=sqlglot_exp.Like(
this=column_expr,
expression=sqlglot_exp.convert(filter_.value),
)
)
sqlglot_op_cls = _OPERATOR_TO_SQLGLOT.get(operator)
if sqlglot_op_cls is None:
raise ValueError(f"Unsupported operator: {operator}")
return sqlglot_op_cls(
this=column_expr,
expression=sqlglot_exp.convert(filter_.value),
)
def _combine_predicates(
self,
filters: set[Filter],
) -> sqlglot_exp.Expression | None:
if not filters:
return None
# Sort to make the resulting SQL deterministic. Filter's default
# tuple ordering eventually compares ``Dimension`` instances, which
# are not orderable, so use a stable string key instead.
ordered = sorted(
filters,
key=lambda f: (
f.type.value,
f.column.id if f.column is not None else "",
f.operator.value,
repr(f.value),
),
)
combined = self._filter_predicate(ordered[0])
for filter_ in ordered[1:]:
combined = sqlglot_exp.And(
this=combined,
expression=self._filter_predicate(filter_),
)
return combined
def _order_expression(
self,
element: Metric | Dimension | AdhocExpression,
direction: OrderDirection,
) -> sqlglot_exp.Ordered:
if isinstance(element, AdhocExpression):
inner = sqlglot.parse_one(
element.definition,
dialect=self._sqlglot_dialect,
)
else:
inner = sqlglot_exp.column(element.id, quoted=True)
return sqlglot_exp.Ordered(this=inner, desc=direction == OrderDirection.DESC)
def _build_ast(self, query: SemanticQuery) -> sqlglot_exp.Select:
if query.limit is None and query.offset is not None:
raise ValueError("Offset cannot be set without limit")
filters = query.filters or set()
where_filters = {f for f in filters if f.type == PredicateType.WHERE}
having_filters = {f for f in filters if f.type == PredicateType.HAVING}
# Add the dataset's outer RLS predicates as ADHOC where-filters when
# the engine wants AS_PREDICATE-style RLS. AS_SUBQUERY-style RLS is
# handled in ``_source_table`` instead.
if not self._rls_should_wrap_subquery():
for clause in self._rls_predicates():
where_filters = where_filters | {
Filter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value=clause,
)
}
if query.group_limit is not None and query.group_limit.group_others:
return self._build_with_others(query, where_filters, having_filters)
if query.group_limit is not None:
return self._build_with_group_limit(query, where_filters, having_filters)
return self._build_plain(query, where_filters, having_filters)
def _build_plain(
self,
query: SemanticQuery,
where_filters: set[Filter],
having_filters: set[Filter],
) -> sqlglot_exp.Select:
projections: list[sqlglot_exp.Expression] = []
for dimension in query.dimensions:
projections.append(
sqlglot_exp.alias_(
self._dimension_expression(dimension),
dimension.id,
quoted=True,
)
)
for metric in query.metrics:
projections.append(
sqlglot_exp.alias_(
self._metric_expression(metric),
metric.id,
quoted=True,
)
)
select = (
sqlglot_exp.Select()
.select(*projections, append=False)
.from_(self._source_table())
)
if where_predicate := self._combine_predicates(where_filters):
select = select.where(where_predicate)
# GROUP BY when aggregating: any metric is an aggregate, so all
# non-aggregate dimensions need to be in GROUP BY. For grain-bucketed
# dimensions we GROUP BY the full bucketed expression because dialects
# disagree on whether SELECT aliases are visible in GROUP BY.
if query.metrics and query.dimensions:
group_by_columns = [
(
self._dimension_expression(dimension)
if dimension.grain is not None
else sqlglot_exp.column(dimension.id, quoted=True)
)
for dimension in query.dimensions
]
select = select.group_by(*group_by_columns)
if having_predicate := self._combine_predicates(having_filters):
select = select.having(having_predicate)
if query.order:
ordered = [
self._order_expression(element, direction)
for element, direction in query.order
]
select = select.order_by(*ordered)
if query.limit is not None:
select = select.limit(query.limit)
if query.offset is not None:
select = select.offset(query.offset)
return select
# ------------------------------------------------------------------
# Group limit (top N) helpers
# ------------------------------------------------------------------
def _top_groups_cte_select(
self,
group_limit: GroupLimit,
main_where_filters: set[Filter],
) -> sqlglot_exp.Select:
"""
Build the SELECT that powers the ``top_groups`` CTE.
The CTE projects only the limited dimensions. The ordering metric, when
present, is evaluated inline in ORDER BY so it does not leak into the
CTE's column list.
"""
if group_limit.filters is not None:
if any(f.type == PredicateType.HAVING for f in group_limit.filters):
raise ValueError(
"HAVING filters are not supported in group_limit.filters"
)
cte_where_filters = {
f for f in group_limit.filters if f.type == PredicateType.WHERE
}
else:
cte_where_filters = main_where_filters
dim_projections = [
sqlglot_exp.alias_(
self._dimension_expression(dim),
dim.id,
quoted=True,
)
for dim in group_limit.dimensions
]
select = (
sqlglot_exp.Select()
.select(*dim_projections, append=False)
.from_(self._source_table())
)
if where_predicate := self._combine_predicates(cte_where_filters):
select = select.where(where_predicate)
# When ordering by a metric we need an aggregation; GROUP BY the
# limited dimensions so each combination collapses to a single row.
# Use full expressions for grain-bucketed dims (alias resolution in
# GROUP BY isn't portable).
if group_limit.metric is not None:
select = select.group_by(
*[
(
self._dimension_expression(dim)
if dim.grain is not None
else sqlglot_exp.column(dim.id, quoted=True)
)
for dim in group_limit.dimensions
]
)
order_expr: sqlglot_exp.Expression = self._metric_expression(
group_limit.metric
)
else:
order_expr = sqlglot_exp.column(
group_limit.dimensions[0].id,
quoted=True,
)
select = select.order_by(
sqlglot_exp.Ordered(
this=order_expr,
desc=group_limit.direction == OrderDirection.DESC,
)
)
return select.limit(group_limit.top)
def _top_groups_in_predicate(
self,
group_limit: GroupLimit,
) -> sqlglot_exp.Expression:
"""
Build a predicate restricting the limited dimensions to the rows of the
``top_groups`` CTE. Single dimension uses scalar IN; multiple use a
row-tuple IN. The subquery is wrapped in ``Subquery`` so sqlglot emits
the surrounding parentheses.
"""
cte_table = sqlglot_exp.to_identifier("top_groups")
dim_columns = [
sqlglot_exp.column(dim.id, quoted=True)
for dim in group_limit.dimensions
]
subquery = sqlglot_exp.Subquery(
this=(
sqlglot_exp.Select()
.select(*dim_columns)
.from_(cte_table)
)
)
if len(dim_columns) == 1:
return sqlglot_exp.In(this=dim_columns[0], query=subquery)
return sqlglot_exp.In(
this=sqlglot_exp.Tuple(expressions=dim_columns),
query=subquery,
)
def _build_with_group_limit(
self,
query: SemanticQuery,
where_filters: set[Filter],
having_filters: set[Filter],
) -> sqlglot_exp.Select:
"""
Restrict the main query to rows whose limited-dimension values appear
in the ``top_groups`` CTE.
"""
select = self._build_plain(query, where_filters, having_filters)
select = select.where(self._top_groups_in_predicate(query.group_limit))
cte_select = self._top_groups_cte_select(query.group_limit, where_filters)
return select.with_("top_groups", as_=cte_select)
def _build_with_others(
self,
query: SemanticQuery,
where_filters: set[Filter],
having_filters: set[Filter],
) -> sqlglot_exp.Select:
"""
Bucket non-top values into ``'Other'`` for the limited dimensions and
re-aggregate against the bucketed groups. Non-additive metrics stay
correct because we re-evaluate the original metric expression instead
of summing previously aggregated rows.
"""
group_limit = query.group_limit
limited_dim_ids = {dim.id for dim in group_limit.dimensions}
in_predicate = self._top_groups_in_predicate(group_limit)
def dim_expr(dim: Dimension) -> sqlglot_exp.Expression:
"""Underlying expression for a dimension, with CASE for limited ones."""
base = self._dimension_expression(dim)
if dim.id not in limited_dim_ids:
return base
return sqlglot_exp.Case(
ifs=[sqlglot_exp.If(this=in_predicate.copy(), true=base)],
default=sqlglot_exp.Literal.string("Other"),
)
projections: list[sqlglot_exp.Expression] = []
for dim in query.dimensions:
projections.append(sqlglot_exp.alias_(dim_expr(dim), dim.id, quoted=True))
for metric in query.metrics:
projections.append(
sqlglot_exp.alias_(
self._metric_expression(metric),
metric.id,
quoted=True,
)
)
select = (
sqlglot_exp.Select()
.select(*projections, append=False)
.from_(self._source_table())
)
if where_predicate := self._combine_predicates(where_filters):
select = select.where(where_predicate)
# GROUP BY the full expressions (not the aliases). Dialects disagree on
# whether GROUP BY can reference a SELECT alias, and using the full
# expression here side-steps that ambiguity for the CASE columns.
if query.metrics and query.dimensions:
select = select.group_by(
*[dim_expr(dim) for dim in query.dimensions]
)
if having_predicate := self._combine_predicates(having_filters):
select = select.having(having_predicate)
if query.order:
ordered = [
self._order_expression(element, direction)
for element, direction in query.order
]
select = select.order_by(*ordered)
if query.limit is not None:
select = select.limit(query.limit)
if query.offset is not None:
select = select.offset(query.offset)
cte_select = self._top_groups_cte_select(group_limit, where_filters)
return select.with_("top_groups", as_=cte_select)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def compile_query(self, query: SemanticQuery) -> str:
"""Return the SQL for ``query`` in the dataset's database dialect."""
return self._build_ast(query).sql(dialect=self._sqlglot_dialect)
def get_values(
self,
dimension: Dimension,
filters: set[Filter] | None = None,
) -> SemanticResult:
where_filters = {
f for f in (filters or set()) if f.type == PredicateType.WHERE
}
# Add outer RLS predicates as ADHOC filters; AS_SUBQUERY-mode RLS is
# wrapped into ``_source_table`` already.
if not self._rls_should_wrap_subquery():
for clause in self._rls_predicates():
where_filters.add(
Filter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value=clause,
)
)
where = self._combine_predicates(where_filters)
select = (
sqlglot_exp.Select()
.distinct()
.select(
sqlglot_exp.alias_(
self._dimension_expression(dimension),
dimension.id,
quoted=True,
)
)
.from_(self._source_table())
)
if where is not None:
select = select.where(where)
sql = select.sql(dialect=self._sqlglot_dialect)
df = self.dataset.database.get_df(
sql,
catalog=self.dataset.catalog,
schema=self.dataset.schema,
)
return SemanticResult(
requests=[SemanticRequest(REQUEST_TYPE, sql)],
results=df_to_arrow(df),
)
def get_table(self, query: SemanticQuery) -> SemanticResult:
if not query.metrics and not query.dimensions:
return SemanticResult(requests=[], results=pa.table({}))
sql = self.compile_query(query)
df = self.dataset.database.get_df(
sql,
catalog=self.dataset.catalog,
schema=self.dataset.schema,
)
# Map alias columns back to dimension/metric *names* so the result uses
# human-friendly labels instead of internal IDs.
rename: dict[str, str] = {}
for dimension in query.dimensions:
rename[dimension.id] = dimension.name
for metric in query.metrics:
rename[metric.id] = metric.name
if df is not None and not df.empty and rename:
df = df.rename(columns=rename)
return SemanticResult(
requests=[SemanticRequest(REQUEST_TYPE, sql)],
results=df_to_arrow(df),
)
def get_row_count(self, query: SemanticQuery) -> SemanticResult:
if not query.metrics and not query.dimensions:
return SemanticResult(requests=[], results=pa.table({"COUNT": [0]}))
inner = self._build_ast(query)
count_select = (
sqlglot_exp.Select()
.select(
sqlglot_exp.alias_(
sqlglot_exp.Count(this=sqlglot_exp.Star()),
"COUNT",
quoted=True,
)
)
.from_(sqlglot_exp.Subquery(this=inner, alias="subquery"))
)
sql = count_select.sql(dialect=self._sqlglot_dialect)
df = self.dataset.database.get_df(
sql,
catalog=self.dataset.catalog,
schema=self.dataset.schema,
)
return SemanticResult(
requests=[SemanticRequest(REQUEST_TYPE, sql)],
results=df_to_arrow(df),
)
__repr__ = uid

View File

@@ -0,0 +1,80 @@
# flake8: noqa: E501
from unittest.mock import MagicMock
import pytest
from preset_io.dataset_semantic_layer import (
DatasetConfiguration,
DatasetSemanticLayer,
DatasetSemanticView,
)
@pytest.fixture
def dataset() -> MagicMock:
db = MagicMock()
db.db_engine_spec.engine = "postgresql"
ds = MagicMock(spec=[])
ds.id = 1
ds.table_name = "orders"
ds.schema = "public"
ds.catalog = None
ds.sql = None
ds.database = db
ds.columns = []
ds.metrics = []
return ds
def test_empty_configuration_schema() -> None:
schema = DatasetSemanticLayer.get_configuration_schema()
# No required fields — adding the layer takes zero input.
assert schema.get("required", []) == []
assert schema["properties"] == {}
def test_runtime_schema_exposes_dataset_dropdown(mocker, dataset: MagicMock) -> None:
mocker.patch(
"preset_io.dataset_semantic_layer.layer.list_datasets",
return_value=[dataset],
)
schema = DatasetSemanticLayer.get_runtime_schema(DatasetConfiguration())
assert schema["required"] == ["dataset_id"]
field = schema["properties"]["dataset_id"]
assert field["enum"] == ["1"]
assert field["x-enumNames"] == ["public.orders"]
def test_get_semantic_view_requires_dataset_id() -> None:
layer = DatasetSemanticLayer.from_configuration({})
with pytest.raises(ValueError, match="dataset_id"):
layer.get_semantic_view("orders", {})
def test_get_semantic_view_returns_dataset_view(mocker, dataset: MagicMock) -> None:
mocker.patch(
"preset_io.dataset_semantic_layer.layer.get_dataset_by_id",
return_value=dataset,
)
layer = DatasetSemanticLayer.from_configuration({})
view = layer.get_semantic_view("orders", {"dataset_id": "1"})
assert isinstance(view, DatasetSemanticView)
assert view.uid() == "dataset:1"
def test_get_semantic_views_returns_empty_when_no_id() -> None:
layer = DatasetSemanticLayer.from_configuration({})
assert layer.get_semantic_views({}) == set()
def test_get_semantic_views_resolves_dataset(mocker, dataset: MagicMock) -> None:
mocker.patch(
"preset_io.dataset_semantic_layer.layer.get_dataset_by_id",
return_value=dataset,
)
layer = DatasetSemanticLayer.from_configuration({})
views = layer.get_semantic_views({"dataset_id": "1"})
assert len(views) == 1
assert next(iter(views)).uid() == "dataset:1"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,157 @@
# flake8: noqa: E501
import sys
import types
from unittest.mock import MagicMock
import pandas as pd
import pyarrow as pa
import pytest
from preset_io.dataset_semantic_layer.utils import (
coerce_literal,
dataset_label,
df_to_arrow,
)
# ---------------------------------------------------------------------------
# Pure helpers
# ---------------------------------------------------------------------------
def test_coerce_literal_passes_through_scalars() -> None:
assert coerce_literal("x") == "x"
assert coerce_literal(42) == 42
assert coerce_literal(None) is None
def test_coerce_literal_unwraps_sets() -> None:
result = coerce_literal({"a", "b"})
assert isinstance(result, list)
assert set(result) == {"a", "b"}
def test_dataset_label_combines_parts() -> None:
ds = MagicMock(spec=[])
ds.catalog = "warehouse"
ds.schema = "public"
ds.table_name = "orders"
assert dataset_label(ds) == "warehouse.public.orders"
def test_dataset_label_skips_blank_parts() -> None:
ds = MagicMock(spec=[])
ds.catalog = None
ds.schema = None
ds.table_name = "orders"
assert dataset_label(ds) == "orders"
def test_df_to_arrow_with_data() -> None:
df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
table = df_to_arrow(df)
assert table.num_rows == 2
assert table.column_names == ["a", "b"]
def test_df_to_arrow_with_empty_df() -> None:
df = pd.DataFrame(columns=["a", "b"])
table = df_to_arrow(df)
assert table.num_rows == 0
assert table.column_names == ["a", "b"]
def test_df_to_arrow_with_none() -> None:
table = df_to_arrow(None)
assert table.num_rows == 0
# ---------------------------------------------------------------------------
# Session-bound helpers — superset is stubbed so we don't require a Flask app
# ---------------------------------------------------------------------------
@pytest.fixture
def fake_superset(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, type]:
"""
Provide a minimal stand-in for ``superset.db`` and
``superset.connectors.sqla.models.SqlaTable`` so list/get helpers run
without a real Flask app or database.
"""
class FakeSqlaTable: # noqa: D401 — marker class only
"""Sentinel used as the queried entity."""
columns: list = []
metrics: list = []
session = MagicMock()
db_module = types.ModuleType("superset")
db_module.db = MagicMock()
db_module.db.session = session
sqla_models = types.ModuleType("superset.connectors.sqla.models")
sqla_models.SqlaTable = FakeSqlaTable
connectors_module = types.ModuleType("superset.connectors")
connectors_sqla = types.ModuleType("superset.connectors.sqla")
monkeypatch.setitem(sys.modules, "superset", db_module)
monkeypatch.setitem(sys.modules, "superset.connectors", connectors_module)
monkeypatch.setitem(sys.modules, "superset.connectors.sqla", connectors_sqla)
monkeypatch.setitem(sys.modules, "superset.connectors.sqla.models", sqla_models)
# Also stub sqlalchemy.orm.selectinload for get_dataset_by_id.
return session, FakeSqlaTable
def test_list_datasets_returns_sorted(fake_superset) -> None:
from preset_io.dataset_semantic_layer.utils import list_datasets
session, _ = fake_superset
a, b, c = MagicMock(), MagicMock(), MagicMock()
a.table_name = "z_users"
b.table_name = "a_orders"
c.table_name = "m_payments"
session.no_autoflush.__enter__.return_value = None
session.no_autoflush.__exit__.return_value = False
session.query.return_value.all.return_value = [a, b, c]
result = list_datasets()
assert [d.table_name for d in result] == ["a_orders", "m_payments", "z_users"]
def test_get_dataset_by_id_raises_when_missing(fake_superset) -> None:
from preset_io.dataset_semantic_layer.utils import get_dataset_by_id
session, _ = fake_superset
session.no_autoflush.__enter__.return_value = None
session.no_autoflush.__exit__.return_value = False
(
session.query.return_value.options.return_value.filter_by.return_value.one_or_none.return_value
) = None
with pytest.raises(ValueError, match="Dataset with id 42 does not exist."):
get_dataset_by_id(42)
def test_get_dataset_by_id_materialises_relationships(fake_superset) -> None:
from preset_io.dataset_semantic_layer.utils import get_dataset_by_id
session, _ = fake_superset
dataset = MagicMock()
# Use real lists so iteration in get_dataset_by_id doesn't error.
dataset.columns = ["col1", "col2"]
dataset.metrics = ["m1"]
session.no_autoflush.__enter__.return_value = None
session.no_autoflush.__exit__.return_value = False
(
session.query.return_value.options.return_value.filter_by.return_value.one_or_none.return_value
) = dataset
result = get_dataset_by_id(1)
assert result is dataset

View File

@@ -0,0 +1,8 @@
{
"publisher": "preset-io",
"name": "dataset-semantic-layer",
"displayName": "Dataset Semantic Layer",
"version": "0.1.0",
"license": "All Rights Reserved",
"permissions": []
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
{
"name": "dataset_semantic_layer",
"version": "0.1.0",
"main": "dist/main.js",
"types": "dist/publicAPI.d.ts",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1",
"start": "webpack serve --mode development",
"build": "webpack --stats-error-details --mode production"
},
"keywords": [],
"private": true,
"author": "",
"license": "All Rights Reserved",
"description": "",
"peerDependencies": {
"@apache-superset/core": "0.1.0-rc2",
"react": "^17.0.2",
"react-dom": "^17.0.2"
},
"devDependencies": {
"@babel/preset-react": "^7.26.3",
"@babel/preset-typescript": "^7.26.0",
"@types/react": "^19.0.10",
"copy-webpack-plugin": "^13.0.0",
"install": "^0.13.0",
"npm": "^11.1.0",
"ts-loader": "^9.5.2",
"typescript": "^5.8.2",
"webpack": "^5.98.0",
"webpack-cli": "^6.0.1",
"webpack-dev-server": "^5.2.0"
}
}

View File

@@ -0,0 +1,17 @@
import React from "react";
import { views } from "@apache-superset/core";
const viewDisposable = views.registerView(
{ id: "dataset.semantic-layer", name: "Dataset Semantic Layer" },
"sqllab.panels",
() => <p>Dataset Semantic Layer</p>
);
export const activate = () => {
console.log("Dataset Semantic Layer extension activated");
};
export const deactivate = () => {
viewDisposable.dispose();
console.log("Dataset Semantic Layer extension deactivated");
};

View File

@@ -0,0 +1,13 @@
{
"compilerOptions": {
"target": "es5",
"module": "esnext",
"moduleResolution": "node10",
"jsx": "react",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true
},
"include": ["src"]
}

View File

@@ -0,0 +1,67 @@
const path = require("path");
const { ModuleFederationPlugin } = require("webpack").container;
const packageConfig = require("./package");
module.exports = (env, argv) => {
const isProd = argv.mode === "production";
return {
entry: isProd ? {} : "./src/index.tsx",
mode: isProd ? "production" : "development",
devServer: {
port: 3000,
headers: {
"Access-Control-Allow-Origin": "*",
},
},
output: {
clean: true,
filename: isProd ? undefined : "[name].[contenthash].js",
chunkFilename: "[name].[contenthash].js",
path: path.resolve(__dirname, "dist"),
publicPath: `/api/v1/extensions/${packageConfig.name}/`,
},
resolve: {
extensions: [".ts", ".tsx", ".js", ".jsx"],
},
externalsType: "window",
externals: {
"@apache-superset/core": "superset",
},
module: {
rules: [
{
test: /\.tsx?$/,
use: "ts-loader",
exclude: /node_modules/,
},
],
},
plugins: [
new ModuleFederationPlugin({
name: "dataset_semantic_layer",
filename: "remoteEntry.[contenthash].js",
exposes: {
"./index": "./src/index.tsx",
},
shared: {
react: {
singleton: true,
requiredVersion: packageConfig.peerDependencies.react,
import: false,
},
"react-dom": {
singleton: true,
requiredVersion: packageConfig.peerDependencies["react-dom"],
import: false,
},
antd: {
singleton: true,
requiredVersion: packageConfig.peerDependencies["antd"],
import: false,
},
},
}),
],
};
};

View File

@@ -24,6 +24,7 @@ single dataframe.
"""
import dataclasses
from datetime import date, datetime, time, timedelta, tzinfo
from time import time as current_time
from typing import Any, cast, Sequence, TypeGuard
@@ -56,7 +57,11 @@ from superset.common.utils.time_range_utils import get_since_until_from_query_ob
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import NO_TIME_RANGE
from superset.models.helpers import QueryResult
from superset.superset_typing import AdhocColumn
from superset.semantic_layers.adhoc import (
adhoc_column_to_semantic_dimension,
adhoc_metric_to_semantic_metric,
)
from superset.superset_typing import AdhocColumn, AdhocMetric
from superset.utils.core import (
FilterOperator,
QueryObjectFilterClause,
@@ -258,16 +263,92 @@ def _normalize_column(column: str | AdhocColumn, dimension_names: set[str]) -> s
- A string (dimension name directly)
- An AdhocColumn with isColumnReference=True and sqlExpression containing the
dimension name
Used by callers that only care about the resolved *name* (e.g. validating
that a granularity column matches a real dimension). Synthetic adhocs
(non-column-reference) carry no dimension name and should be routed
through :func:`_resolve_dimension` instead.
"""
if isinstance(column, str):
return column
# Handle column references (e.g., from time-series charts)
# Column reference adhocs unwrap to their underlying column name.
if column.get("isColumnReference") and (sql_expr := column.get("sqlExpression")):
if sql_expr in dimension_names:
return sql_expr
raise ValueError("Adhoc dimensions are not supported in Semantic Views.")
# Synthetic adhoc: the label *is* the resolved name in semantic-layer
# terms (we use it as the Dimension id/name when materialising). Falling
# back here keeps validators happy with adhoc dicts and defers actual
# resolution to ``_resolve_dimension``.
if label := column.get("label"):
return label
raise ValueError("Adhoc column is missing a ``label``.")
def _resolve_dimension(
column: str | AdhocColumn,
dataset: BaseDatasource,
dimensions_by_name: dict[str, Dimension],
template_processor: Any | None,
) -> Dimension:
"""
Resolve a column entry (string name or adhoc dict) to a ``Dimension``.
Strings look up a saved dimension by name. Adhoc dicts marked
``isColumnReference=True`` re-use the matching saved dimension so its
metadata (type, grain) is preserved. Anything else is synthesised via
:func:`adhoc_column_to_semantic_dimension`.
"""
if isinstance(column, str):
if column not in dimensions_by_name:
raise ValueError(
f"Dimension {column!r} is not defined in the Semantic View."
)
return dimensions_by_name[column]
return adhoc_column_to_semantic_dimension(
column,
dataset,
dimensions_by_name,
template_processor,
)
def _resolve_metric(
metric: str | AdhocMetric,
dataset: BaseDatasource,
metrics_by_name: dict[str, Metric],
template_processor: Any | None,
) -> Metric:
"""
Resolve a metric entry (string name or adhoc dict) to a ``Metric``.
"""
if isinstance(metric, str):
if metric not in metrics_by_name:
raise ValueError(
f"Metric {metric!r} is not defined in the Semantic View."
)
return metrics_by_name[metric]
return adhoc_metric_to_semantic_metric(metric, dataset, template_processor)
def _get_template_processor(dataset: BaseDatasource) -> Any | None:
"""
Build a template processor for adhoc Jinja rendering, or ``None`` when
the datasource doesn't expose one (e.g. a cube-mode semantic view, where
adhocs aren't supported anyway).
"""
get_template_processor = getattr(dataset, "get_template_processor", None)
if get_template_processor is None:
return None
return get_template_processor()
def _stamp_grain(dimension: Dimension, grain: Grain) -> Dimension:
"""Return a copy of ``dimension`` with the supplied grain set."""
return dataclasses.replace(dimension, grain=grain)
def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
@@ -278,33 +359,54 @@ def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
visualization and more on semantics.
"""
semantic_view = query_object.datasource.implementation
dataset = query_object.datasource
all_metrics = {metric.name: metric for metric in semantic_view.metrics}
all_dimensions = {
dimension.name: dimension for dimension in semantic_view.dimensions
}
# Normalize columns (may be dicts with isColumnReference=True for time-series)
dimension_names = set(all_dimensions.keys())
normalized_columns = {
_normalize_column(column, dimension_names) for column in query_object.columns
}
template_processor = _get_template_processor(dataset)
metrics = [all_metrics[metric] for metric in (query_object.metrics or [])]
metrics = [
_resolve_metric(metric, dataset, all_metrics, template_processor)
for metric in (query_object.metrics or [])
]
# Resolve each requested column, preserving order and deduplicating by id.
seen_dim_ids: set[str] = set()
dimensions: list[Dimension] = []
for column in query_object.columns:
dim = _resolve_dimension(
column,
dataset,
all_dimensions,
template_processor,
)
if dim.id in seen_dim_ids:
continue
seen_dim_ids.add(dim.id)
dimensions.append(dim)
grain = _convert_time_grain(query_object.extras.get("time_grain_sqla"))
dimensions = [
dimension
for dimension in semantic_view.dimensions
if dimension.name in normalized_columns
and (
# if a grain is specified, only include the time dimension if its grain
# matches the requested grain
grain is None
or dimension.name != query_object.granularity
or dimension.grain == grain
)
]
if grain is not None and query_object.granularity:
# Apply the requested grain to the granularity dimension. For cube-mode
# views that pre-declare grain on their dimensions, keep the existing
# "only include the matching-grain version" semantic. For dataset-mode
# views (dimensions have grain=None), the view applies the grain via
# the engine spec at SQL build time, so we just stamp the grain on
# the dimension here.
dimensions = [
(
_stamp_grain(dim, grain)
if dim.name == query_object.granularity and dim.grain is None
else dim
)
for dim in dimensions
if dim.name != query_object.granularity
or dim.grain is None
or dim.grain == grain
]
order = _get_order_from_query_object(query_object, all_metrics, all_dimensions)
limit = query_object.row_limit
@@ -928,32 +1030,53 @@ def validate_query_object(
def _validate_metrics(query_object: ValidatedQueryObject) -> None:
"""
Make sure metrics are defined in the semantic view.
Make sure metrics are defined in the semantic view or are valid adhocs.
Validation of adhoc metrics is delegated to ``_resolve_metric`` — it
raises if the SIMPLE shape is malformed or the SQL fails Jinja /
safety checks.
"""
semantic_view = query_object.datasource.implementation
dataset = query_object.datasource
metrics_by_name = {metric.name: metric for metric in semantic_view.metrics}
template_processor = _get_template_processor(dataset)
if any(not isinstance(metric, str) for metric in (query_object.metrics or [])):
raise ValueError("Adhoc metrics are not supported in Semantic Views.")
labels: list[str] = []
for metric in query_object.metrics or []:
resolved = _resolve_metric(
metric, dataset, metrics_by_name, template_processor
)
labels.append(resolved.id)
metric_names = {metric.name for metric in semantic_view.metrics}
if not set(query_object.metrics or []) <= metric_names:
raise ValueError("All metrics must be defined in the Semantic View.")
if len(labels) != len(set(labels)):
raise ValueError(
"Duplicate metric labels are not supported in Semantic Views."
)
def _validate_dimensions(query_object: ValidatedQueryObject) -> None:
"""
Make sure all dimensions are defined in the semantic view.
Make sure all dimensions are defined in the semantic view or are valid
adhocs. Synthesized adhoc columns are validated via ``_resolve_dimension``.
"""
semantic_view = query_object.datasource.implementation
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
dataset = query_object.datasource
dimensions_by_name = {
dimension.name: dimension for dimension in semantic_view.dimensions
}
template_processor = _get_template_processor(dataset)
# Normalize all columns to dimension names
normalized_columns = [
_normalize_column(column, dimension_names) for column in query_object.columns
]
labels: list[str] = []
for column in query_object.columns:
resolved = _resolve_dimension(
column, dataset, dimensions_by_name, template_processor
)
labels.append(resolved.id)
if not set(normalized_columns) <= dimension_names:
raise ValueError("All dimensions must be defined in the Semantic View.")
if len(labels) != len(set(labels)):
raise ValueError(
"Duplicate column labels are not supported in Semantic Views."
)
def _validate_filters(query_object: ValidatedQueryObject) -> None:

View File

@@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
RLS helpers used by the dataset semantic-view path.
The legacy ``get_sqla_query`` applies row-level security in two places:
* The dataset's own RLS rules are AND-ed into the outer ``WHERE`` (or
the source table is wrapped in a subquery when the engine spec
returns ``RLSMethod.AS_SUBQUERY``).
* For virtual datasets, RLS rules from the *underlying* tables
referenced by ``dataset.sql`` are injected via
:func:`superset.utils.rls.apply_rls`.
This module exposes both flows as pure functions returning SQL strings,
which the ``DatasetSemanticView`` plugs into its sqlglot AST.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from superset.sql.parse import RLSMethod, SQLScript
from superset.utils.rls import apply_rls
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
logger = logging.getLogger(__name__)
def get_rls_method(dataset: "SqlaTable") -> RLSMethod:
"""Return the engine's preferred RLS injection method."""
return dataset.database.db_engine_spec.get_rls_method()
def render_rls_predicates(dataset: "SqlaTable") -> list[str]:
"""
Render the dataset's outer RLS predicates as SQL strings.
The clauses are already Jinja-rendered by
:meth:`SqlaTable.get_sqla_row_level_filters`; we then compile them
against the database's dialect with ``literal_binds=True`` so the
resulting strings are safe to splice into a sqlglot AST.
Returns an empty list when the dataset has no applicable RLS rules
for the current user.
"""
text_clauses = dataset.get_sqla_row_level_filters()
if not text_clauses:
return []
dialect = dataset.database.get_dialect()
return [
str(clause.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
for clause in text_clauses
]
def apply_rls_to_virtual_sql(dataset: "SqlaTable") -> str | None:
"""
Rewrite a virtual dataset's inner SQL to enforce RLS on the
underlying tables it references.
Returns the rewritten SQL string when RLS predicates were injected,
or ``None`` when no rules applied (caller should use the original
``dataset.sql`` in that case). Returns ``None`` for physical
datasets too.
"""
if not dataset.sql:
return None
engine = dataset.database.db_engine_spec.engine
try:
parsed_script = SQLScript(dataset.sql, engine=engine)
except Exception: # pylint: disable=broad-except
# Mirror the legacy behavior — failure to parse should not block
# the query; the caller will fall back to the original SQL and
# RLS will simply not be applied to inner tables.
logger.warning(
"Failed to parse virtual dataset SQL for RLS application",
exc_info=True,
)
return None
default_schema = dataset.database.get_default_schema(dataset.catalog)
rls_applied = False
try:
for statement in parsed_script.statements:
if apply_rls(
dataset.database,
dataset.catalog,
dataset.schema or default_schema or "",
statement,
exclude_dataset_id=dataset.id,
):
rls_applied = True
except Exception: # pylint: disable=broad-except
logger.warning(
"Failed to apply RLS to virtual dataset inner SQL",
exc_info=True,
)
return None
return parsed_script.format() if rls_applied else None

View File

@@ -0,0 +1,270 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
from __future__ import annotations
from unittest.mock import MagicMock
import pyarrow as pa
import pytest
from superset_core.semantic_layers.types import AggregationType, Dimension
from superset.exceptions import QueryObjectValidationError
from superset.semantic_layers.adhoc import (
adhoc_column_to_semantic_dimension,
adhoc_metric_to_semantic_metric,
)
@pytest.fixture
def dataset() -> MagicMock:
ds = MagicMock()
ds.database_id = 1
ds.schema = "public"
ds.database.db_engine_spec.engine = "postgresql"
# Identifier quoter — return double-quoted name.
ds.quote_identifier = lambda name: f'"{name}"'
# ``_process_select_expression`` echoes back the SQL after "rendering".
ds._process_select_expression = (
lambda expression, **kwargs: expression # noqa: U100
)
return ds
# ---------------------------------------------------------------------------
# Adhoc metric
# ---------------------------------------------------------------------------
def test_simple_metric_sum(dataset: MagicMock) -> None:
metric = adhoc_metric_to_semantic_metric(
{
"expressionType": "SIMPLE",
"label": "total_amount",
"aggregate": "SUM",
"column": {"column_name": "amount"},
},
dataset,
)
assert metric.id == "total_amount"
assert metric.name == "total_amount"
assert metric.definition == 'SUM("amount")'
assert metric.aggregation == AggregationType.SUM
# SUM(...) infers to float64.
assert metric.type == pa.float64()
def test_simple_metric_count_distinct(dataset: MagicMock) -> None:
metric = adhoc_metric_to_semantic_metric(
{
"expressionType": "SIMPLE",
"label": "unique_customers",
"aggregate": "COUNT_DISTINCT",
"column": {"column_name": "customer_id"},
},
dataset,
)
assert metric.definition == 'COUNT(DISTINCT "customer_id")'
assert metric.aggregation == AggregationType.COUNT_DISTINCT
def test_simple_metric_unknown_aggregate_falls_back_to_other(
dataset: MagicMock,
) -> None:
metric = adhoc_metric_to_semantic_metric(
{
"expressionType": "SIMPLE",
"label": "median_x",
"aggregate": "MEDIAN",
"column": {"column_name": "x"},
},
dataset,
)
# MEDIAN isn't in the rollup-safe map.
assert metric.aggregation == AggregationType.OTHER
def test_simple_metric_missing_aggregate_raises(dataset: MagicMock) -> None:
with pytest.raises(QueryObjectValidationError, match="aggregate and a column"):
adhoc_metric_to_semantic_metric(
{
"expressionType": "SIMPLE",
"label": "bad",
"column": {"column_name": "amount"},
},
dataset,
)
def test_simple_metric_missing_column_raises(dataset: MagicMock) -> None:
with pytest.raises(QueryObjectValidationError, match="aggregate and a column"):
adhoc_metric_to_semantic_metric(
{
"expressionType": "SIMPLE",
"label": "bad",
"aggregate": "SUM",
"column": {},
},
dataset,
)
def test_sql_metric(dataset: MagicMock) -> None:
metric = adhoc_metric_to_semantic_metric(
{
"expressionType": "SQL",
"label": "profit_margin",
"sqlExpression": "SUM(revenue - cost) / SUM(revenue)",
},
dataset,
)
assert metric.definition == "SUM(revenue - cost) / SUM(revenue)"
assert metric.aggregation == AggregationType.OTHER
def test_sql_metric_missing_expression_raises(dataset: MagicMock) -> None:
with pytest.raises(QueryObjectValidationError, match="sqlExpression"):
adhoc_metric_to_semantic_metric(
{"expressionType": "SQL", "label": "bad"},
dataset,
)
def test_metric_missing_label_raises(dataset: MagicMock) -> None:
with pytest.raises(QueryObjectValidationError, match="missing a ``label``"):
adhoc_metric_to_semantic_metric(
{"expressionType": "SIMPLE", "aggregate": "SUM"},
dataset,
)
def test_metric_unknown_expression_type_raises(dataset: MagicMock) -> None:
with pytest.raises(QueryObjectValidationError, match="Unknown adhoc metric"):
adhoc_metric_to_semantic_metric(
{"label": "weird", "expressionType": "MYSTERY"},
dataset,
)
def test_sql_metric_jinja_applied(dataset: MagicMock) -> None:
# ``_process_select_expression`` is where Jinja and safety validation
# live in the dataset model. We verify the helper is invoked.
dataset._process_select_expression = MagicMock(return_value="user_id = 42")
metric = adhoc_metric_to_semantic_metric(
{
"expressionType": "SQL",
"label": "rendered",
"sqlExpression": "user_id = {{ current_user_id() }}",
},
dataset,
template_processor=MagicMock(),
)
assert metric.definition == "user_id = 42"
dataset._process_select_expression.assert_called_once()
def test_sql_metric_empty_processed_raises(dataset: MagicMock) -> None:
dataset._process_select_expression = MagicMock(return_value=None)
with pytest.raises(QueryObjectValidationError, match="empty string"):
adhoc_metric_to_semantic_metric(
{
"expressionType": "SQL",
"label": "bad",
"sqlExpression": "{{ '' }}",
},
dataset,
template_processor=MagicMock(),
)
# ---------------------------------------------------------------------------
# Adhoc column
# ---------------------------------------------------------------------------
def test_adhoc_column_reference_uses_existing_dimension(dataset: MagicMock) -> None:
existing = Dimension(id="country", name="country", type=pa.utf8())
result = adhoc_column_to_semantic_dimension(
{
"label": "country",
"sqlExpression": "country",
"isColumnReference": True,
},
dataset,
{"country": existing},
)
assert result is existing
def test_adhoc_column_synthesises_dimension(dataset: MagicMock) -> None:
dataset._process_select_expression = MagicMock(return_value="UPPER(country)")
result = adhoc_column_to_semantic_dimension(
{
"label": "upper_country",
"sqlExpression": "UPPER(country)",
},
dataset,
{},
template_processor=MagicMock(),
)
assert result.id == "upper_country"
assert result.name == "upper_country"
assert result.definition == "UPPER(country)"
# UPPER(...) infers to utf8.
assert result.type == pa.utf8()
def test_adhoc_column_missing_label_raises(dataset: MagicMock) -> None:
with pytest.raises(QueryObjectValidationError, match="``label``"):
adhoc_column_to_semantic_dimension(
{"sqlExpression": "x"},
dataset,
{},
)
def test_adhoc_column_missing_sql_raises(dataset: MagicMock) -> None:
with pytest.raises(QueryObjectValidationError, match="``sqlExpression``"):
adhoc_column_to_semantic_dimension(
{"label": "x"},
dataset,
{},
)
def test_adhoc_column_reference_falls_back_when_not_matching(
dataset: MagicMock,
) -> None:
"""
A column-reference adhoc whose sqlExpression doesn't match an existing
dimension is treated as a synthesized adhoc.
"""
dataset._process_select_expression = MagicMock(return_value="ghost")
result = adhoc_column_to_semantic_dimension(
{
"label": "spooky",
"sqlExpression": "ghost",
"isColumnReference": True,
},
dataset,
{"country": Dimension(id="country", name="country", type=pa.utf8())},
template_processor=MagicMock(),
)
assert result.id == "spooky"
assert result.definition == "ghost"

View File

@@ -50,9 +50,14 @@ from superset.semantic_layers.mapper import (
_get_group_limit_filters,
_get_group_limit_from_query_object,
_get_order_from_query_object,
_get_template_processor,
_get_time_bounds,
_get_time_filter,
_normalize_column,
_resolve_dimension,
_resolve_metric,
_stamp_grain,
_validate_dimensions,
_validate_filters,
_validate_granularity,
_validate_group_limit,
@@ -1035,7 +1040,10 @@ def test_validate_query_object_undefined_metric_error(
columns=["order_date"],
)
with pytest.raises(ValueError, match="All metrics must be defined"):
with pytest.raises(
ValueError,
match="Metric 'undefined_metric' is not defined in the Semantic View",
):
validate_query_object(query_object)
@@ -1051,7 +1059,10 @@ def test_validate_query_object_undefined_dimension_error(
columns=["undefined_dimension"],
)
with pytest.raises(ValueError, match="All dimensions must be defined"):
with pytest.raises(
ValueError,
match="Dimension 'undefined_dimension' is not defined in the Semantic View",
):
validate_query_object(query_object)
@@ -1800,28 +1811,31 @@ def test_get_results_empty_requests(
def test_normalize_column_adhoc_not_in_dimensions() -> None:
"""
Test _normalize_column raises error for AdhocColumn with sqlExpression not in dims.
Adhoc columns whose sqlExpression doesn't match an existing dimension fall
back to using the label as the resolved name. Actual SQL synthesis is the
job of _resolve_dimension; _normalize_column only surfaces a name.
"""
dimension_names = {"category", "region"}
adhoc_column: AdhocColumn = {
"label": "custom_dim",
"isColumnReference": True,
"sqlExpression": "unknown_dimension",
}
with pytest.raises(ValueError, match="Adhoc dimensions are not supported"):
_normalize_column(adhoc_column, dimension_names)
assert _normalize_column(adhoc_column, dimension_names) == "custom_dim"
def test_normalize_column_adhoc_missing_sql_expression() -> None:
def test_normalize_column_adhoc_missing_label_raises() -> None:
"""
Test _normalize_column raises error for AdhocColumn without sqlExpression.
When neither a matching column reference nor a label is provided there's
no resolvable name and _normalize_column raises.
"""
dimension_names = {"category", "region"}
adhoc_column: AdhocColumn = {
"isColumnReference": True,
}
with pytest.raises(ValueError, match="Adhoc dimensions are not supported"):
with pytest.raises(ValueError, match="Adhoc column is missing"):
_normalize_column(adhoc_column, dimension_names)
@@ -2265,11 +2279,12 @@ def test_validate_query_object_no_datasource() -> None:
assert result is False
def test_validate_metrics_adhoc_error(
def test_validate_metrics_adhoc_with_bad_shape_raises(
mocker: MockerFixture,
) -> None:
"""
Test validation error for adhoc metrics.
Adhoc metrics are now supported, but invalid shapes (missing
expressionType, etc.) still raise via the adhoc resolver.
"""
mock_datasource = mocker.Mock()
category_dim = Dimension("category", "category", pa.utf8(), "category", "Category")
@@ -2279,13 +2294,15 @@ def test_validate_metrics_adhoc_error(
mock_datasource.implementation.dimensions = {category_dim}
mock_datasource.implementation.metrics = {sales_metric}
# Strip the template processor; we just want to verify the shape check.
mock_datasource.get_template_processor.return_value = None
# Manually create a query object with an adhoc metric
query_object = mocker.Mock()
query_object.datasource = mock_datasource
# Missing expressionType — the resolver doesn't know how to interpret this.
query_object.metrics = [{"label": "adhoc", "sqlExpression": "SUM(x)"}]
with pytest.raises(ValueError, match="Adhoc metrics are not supported"):
with pytest.raises(Exception, match="expressionType"):
_validate_metrics(query_object)
@@ -3111,3 +3128,171 @@ def test_coerce_time_invalid_string_raises() -> None:
def test_coerce_time_rejects_other_types() -> None:
with pytest.raises(ValueError, match="Invalid time value"):
_coerce_scalar_filter_value(123, _dim(pa.time64("us")))
# ---------------------------------------------------------------------------
# Adhoc + grain resolver coverage
# ---------------------------------------------------------------------------
def test_get_template_processor_returns_none_when_unsupported() -> None:
"""A bare object without ``get_template_processor`` returns None."""
class Plain:
pass
assert _get_template_processor(Plain()) is None
def test_stamp_grain_returns_new_dimension_with_grain() -> None:
dim = Dimension(id="dt", name="dt", type=pa.timestamp("us"))
stamped = _stamp_grain(dim, Grains.DAY)
assert stamped.grain == Grains.DAY
# Original is unchanged (frozen dataclass).
assert dim.grain is None
def _adhoc_dataset() -> MagicMock:
ds = MagicMock()
ds.database_id = 1
ds.schema = "public"
ds.database.db_engine_spec.engine = "postgresql"
ds.quote_identifier = lambda name: f'"{name}"'
ds._process_select_expression = lambda expression, **kwargs: expression # noqa: U100
ds.get_template_processor.return_value = None
return ds
def test_resolve_metric_for_adhoc_simple_dict() -> None:
ds = _adhoc_dataset()
metric = _resolve_metric(
{
"expressionType": "SIMPLE",
"label": "total",
"aggregate": "SUM",
"column": {"column_name": "x"},
},
ds,
{},
None,
)
assert metric.id == "total"
assert metric.definition == 'SUM("x")'
def test_resolve_dimension_for_adhoc_dict() -> None:
ds = _adhoc_dataset()
dim = _resolve_dimension(
{"label": "calc", "sqlExpression": "UPPER(country)"},
ds,
{},
None,
)
assert dim.id == "calc"
assert dim.definition == "UPPER(country)"
def test_validate_metrics_duplicate_label_raises() -> None:
ds = _adhoc_dataset()
view = MagicMock()
view.metrics = []
view.dimensions = []
ds.implementation = view
query = MagicMock()
query.datasource = ds
query.metrics = [
{
"expressionType": "SIMPLE",
"label": "dup",
"aggregate": "SUM",
"column": {"column_name": "x"},
},
{
"expressionType": "SIMPLE",
"label": "dup",
"aggregate": "SUM",
"column": {"column_name": "y"},
},
]
with pytest.raises(ValueError, match="Duplicate metric labels"):
_validate_metrics(query)
def test_validate_dimensions_duplicate_label_raises() -> None:
ds = _adhoc_dataset()
view = MagicMock()
view.metrics = []
view.dimensions = []
ds.implementation = view
query = MagicMock()
query.datasource = ds
query.columns = [
{"label": "dup", "sqlExpression": "x"},
{"label": "dup", "sqlExpression": "y"},
]
with pytest.raises(ValueError, match="Duplicate column labels"):
_validate_dimensions(query)
def test_map_query_object_dedups_dimensions_by_id(mocker: MockerFixture) -> None:
"""Two requested columns resolving to the same dim id collapse to one."""
ds = _adhoc_dataset()
existing = Dimension(id="country", name="country", type=pa.utf8())
view = MagicMock()
view.metrics = []
view.dimensions = [existing]
ds.implementation = view
ds.fetch_values_predicate = None
query = ValidatedQueryObject(
datasource=ds,
metrics=[],
# Same dim referenced twice — once as a string, once as a column ref.
columns=[
"country",
{"label": "country", "sqlExpression": "country", "isColumnReference": True},
],
)
sem_queries = map_query_object(query)
assert len(sem_queries[0].dimensions) == 1
def test_map_query_object_stamps_grain_on_granularity_dim(
mocker: MockerFixture,
) -> None:
"""A grain request stamps the grain onto the matching dimension."""
ds = _adhoc_dataset()
dt_dim = Dimension(id="dt", name="dt", type=pa.timestamp("us"))
view = MagicMock()
view.metrics = []
view.dimensions = [dt_dim]
ds.implementation = view
ds.fetch_values_predicate = None
query = ValidatedQueryObject(
datasource=ds,
metrics=[],
columns=["dt"],
granularity="dt",
extras={"time_grain_sqla": "P1D"},
)
sem_queries = map_query_object(query)
dim = sem_queries[0].dimensions[0]
assert dim.grain == Grains.DAY
def test_normalize_column_adhoc_label_only() -> None:
"""Adhoc with no isColumnReference falls back to its label."""
assert (
_normalize_column({"label": "calc", "sqlExpression": "x"}, set()) == "calc"
)
def test_normalize_column_string_passthrough() -> None:
assert _normalize_column("category", {"category", "region"}) == "category"

View File

@@ -0,0 +1,166 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.semantic_layers.rls import (
apply_rls_to_virtual_sql,
get_rls_method,
render_rls_predicates,
)
from superset.sql.parse import RLSMethod
def _make_dataset(
*,
sql: str | None = None,
rls_clauses: list | None = None,
) -> MagicMock:
ds = MagicMock()
ds.id = 1
ds.catalog = None
ds.schema = "public"
ds.sql = sql
ds.database.db_engine_spec.engine = "postgresql"
ds.database.get_default_schema.return_value = "public"
ds.get_sqla_row_level_filters.return_value = rls_clauses or []
return ds
def test_get_rls_method_delegates_to_engine_spec() -> None:
ds = _make_dataset()
ds.database.db_engine_spec.get_rls_method.return_value = RLSMethod.AS_SUBQUERY
assert get_rls_method(ds) == RLSMethod.AS_SUBQUERY
def test_render_rls_predicates_empty_when_no_rules() -> None:
ds = _make_dataset()
assert render_rls_predicates(ds) == []
def test_render_rls_predicates_compiles_each_clause() -> None:
clause1 = MagicMock()
clause1.compile.return_value = "tenant_id = 7"
clause2 = MagicMock()
clause2.compile.return_value = "deleted = false"
ds = _make_dataset(rls_clauses=[clause1, clause2])
result = render_rls_predicates(ds)
assert result == ["tenant_id = 7", "deleted = false"]
# Each clause is compiled with literal_binds so values are inlined.
for clause in (clause1, clause2):
kwargs = clause.compile.call_args.kwargs
assert kwargs["compile_kwargs"] == {"literal_binds": True}
def test_apply_rls_to_virtual_sql_returns_none_for_physical(mocker: MockerFixture) -> None:
ds = _make_dataset()
assert apply_rls_to_virtual_sql(ds) is None
def test_apply_rls_to_virtual_sql_returns_none_when_no_rls_applied(
mocker: MockerFixture,
) -> None:
ds = _make_dataset(sql="SELECT * FROM raw_orders")
# apply_rls returns False → no predicates were injected.
mocker.patch("superset.semantic_layers.rls.apply_rls", return_value=False)
assert apply_rls_to_virtual_sql(ds) is None
def test_apply_rls_to_virtual_sql_returns_rewritten_when_applied(
mocker: MockerFixture,
) -> None:
ds = _make_dataset(sql="SELECT * FROM raw_orders")
fake_script = MagicMock()
fake_statement = MagicMock()
fake_script.statements = [fake_statement]
fake_script.format.return_value = "SELECT * FROM raw_orders WHERE x = 1"
mocker.patch(
"superset.semantic_layers.rls.SQLScript",
return_value=fake_script,
)
mocker.patch("superset.semantic_layers.rls.apply_rls", return_value=True)
result = apply_rls_to_virtual_sql(ds)
assert result == "SELECT * FROM raw_orders WHERE x = 1"
def test_apply_rls_to_virtual_sql_swallows_parse_error(mocker: MockerFixture) -> None:
ds = _make_dataset(sql="totally invalid sql")
mocker.patch(
"superset.semantic_layers.rls.SQLScript",
side_effect=Exception("parse error"),
)
assert apply_rls_to_virtual_sql(ds) is None
def test_apply_rls_to_virtual_sql_swallows_apply_error(mocker: MockerFixture) -> None:
ds = _make_dataset(sql="SELECT * FROM raw_orders")
fake_script = MagicMock()
fake_script.statements = [MagicMock()]
mocker.patch(
"superset.semantic_layers.rls.SQLScript",
return_value=fake_script,
)
mocker.patch(
"superset.semantic_layers.rls.apply_rls",
side_effect=Exception("boom"),
)
assert apply_rls_to_virtual_sql(ds) is None
def test_apply_rls_to_virtual_sql_uses_default_schema_when_dataset_schema_missing(
mocker: MockerFixture,
) -> None:
ds = _make_dataset(sql="SELECT * FROM raw_orders")
ds.schema = None
fake_script = MagicMock()
fake_statement = MagicMock()
fake_script.statements = [fake_statement]
mocker.patch(
"superset.semantic_layers.rls.SQLScript",
return_value=fake_script,
)
apply_rls_mock = mocker.patch(
"superset.semantic_layers.rls.apply_rls",
return_value=False,
)
apply_rls_to_virtual_sql(ds)
# default_schema was "public" from get_default_schema; that should be the
# schema arg passed to apply_rls.
args, kwargs = apply_rls_mock.call_args
assert args[2] == "public" or kwargs.get("schema") == "public"
def test_render_rls_predicates_uses_dialect_for_compile() -> None:
clause = MagicMock()
clause.compile.return_value = "x = 1"
ds = _make_dataset(rls_clauses=[clause])
render_rls_predicates(ds)
kwargs = clause.compile.call_args.kwargs
# Verify the dialect from the database is passed (not asserting identity,
# just that ``dialect=`` got a value).
assert "dialect" in kwargs