diff --git a/superset/semantic_layers/cache.py b/superset/semantic_layers/cache.py
new file mode 100644
index 00000000000..b69a8803cbf
--- /dev/null
+++ b/superset/semantic_layers/cache.py
@@ -0,0 +1,589 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Containment-aware cache for semantic view queries.
+
+A broader cached result can satisfy a narrower new query: when the new query's
+filters and limit are at least as restrictive as a cached entry's, the cached
+DataFrame is post-filtered and re-limited instead of re-executing the
+underlying query.
+
+See ``docs/`` and the plan file for the design rationale; in summary, the
+rules are:
+
+* Same metrics and dimensions (shape).
+* Each cached filter must be implied by a new-query filter on the same column.
+* New filters on columns with no cached constraint are applied post-fetch as
+  "leftovers", provided the column is in the projection.
+* Cached ``limit`` must be at least the new ``limit``, and a truncated cached
+  result is only reused when the filter sets are equivalent (no leftovers).
+  Whenever either query has a ``limit``, the orderings must match (otherwise
+  the retained rows are not the true top of the new query).
+* ``ADHOC`` and ``HAVING`` filters require exact-set equality.
+* ``offset != 0`` and mismatching ``group_limit`` skip the cache.
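+
+Illustration (``entry`` and ``new_query`` are ad hoc names, not identifiers
+defined in this module)::
+
+    # cached: region IN ('EU', 'US'), no limit
+    # new:    region = 'EU' AND amount > 10
+    ok, leftovers = can_satisfy(entry, new_query)
+    # ok is True; leftovers == {amount > 10}, which _apply_post_processing
+    # re-applies to the cached DataFrame (amount must be in the projection).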
+""" + +from __future__ import annotations + +import logging +import re +import time as _time +from dataclasses import dataclass, field +from datetime import date, datetime, time, timedelta +from typing import Any, Iterable + +import pandas as pd +import pyarrow as pa +from flask import current_app +from superset_core.semantic_layers.types import ( + AdhocExpression, + Dimension, + Filter, + Metric, + Operator, + OrderTuple, + PredicateType, + SemanticQuery, + SemanticRequest, + SemanticResult, +) + +from superset.extensions import cache_manager +from superset.utils import json +from superset.utils.hashing import hash_from_str + +logger = logging.getLogger(__name__) + +INDEX_KEY_PREFIX = "sv:idx:" +VALUE_KEY_PREFIX = "sv:val:" +MAX_ENTRIES_PER_SHAPE = 32 + + +@dataclass(frozen=True) +class ViewMeta: + """Identity/freshness/TTL info pulled from the SemanticView ORM row.""" + + uuid: str + changed_on_iso: str + cache_timeout: int | None + + +@dataclass(frozen=True) +class CachedEntry: + filters: frozenset[Filter] + limit: int | None + offset: int + order_key: str + group_limit_key: str + value_key: str + timestamp: float = field(default_factory=_time.time) + + +# --------------------------------------------------------------------------- +# Public surface +# --------------------------------------------------------------------------- + + +def try_serve_from_cache( + view_meta: ViewMeta, + query: SemanticQuery, +) -> SemanticResult | None: + """Return a cached ``SemanticResult`` that satisfies ``query`` if any.""" + try: + cache = cache_manager.data_cache + idx_key = shape_key(view_meta, query) + entries: list[CachedEntry] | None = cache.get(idx_key) + if not entries: + return None + + pruned: list[CachedEntry] = [] + served: SemanticResult | None = None + for entry in entries: + if served is None: + ok, leftovers = can_satisfy(entry, query) + if ok: + payload = cache.get(entry.value_key) + if payload is None: + # value evicted but index entry survived; drop it + continue + pruned.append(entry) + served = _apply_post_processing(payload, query, leftovers) + continue + # keep entry; verify its value is still alive + if cache.get(entry.value_key) is not None: + pruned.append(entry) + + if len(pruned) != len(entries): + cache.set(idx_key, pruned, timeout=_timeout(view_meta)) + return served + except Exception: # pragma: no cover - defensive + logger.warning("Semantic view cache lookup failed", exc_info=True) + return None + + +def store_result( + view_meta: ViewMeta, + query: SemanticQuery, + result: SemanticResult, +) -> None: + """Persist ``result`` under a fresh value key and register a descriptor.""" + try: + cache = cache_manager.data_cache + timeout = _timeout(view_meta) + vkey = value_key(view_meta, query) + cache.set(vkey, result, timeout=timeout) + + idx_key = shape_key(view_meta, query) + entries: list[CachedEntry] = list(cache.get(idx_key) or []) + entry = CachedEntry( + filters=frozenset(query.filters or set()), + limit=query.limit, + offset=query.offset or 0, + order_key=_order_key(query.order), + group_limit_key=_group_limit_key(query.group_limit), + value_key=vkey, + ) + entries = [e for e in entries if e.value_key != vkey] + entries.append(entry) + if len(entries) > MAX_ENTRIES_PER_SHAPE: + entries = sorted(entries, key=lambda e: e.timestamp)[ + -MAX_ENTRIES_PER_SHAPE: + ] + cache.set(idx_key, entries, timeout=timeout) + except Exception: # pragma: no cover - defensive + logger.warning("Semantic view cache store failed", exc_info=True) + + +# 
--------------------------------------------------------------------------- +# Keys +# --------------------------------------------------------------------------- + + +def shape_key(view_meta: ViewMeta, query: SemanticQuery) -> str: + shape = { + "m": sorted(m.id for m in query.metrics), + "d": sorted(_dimension_key(d) for d in query.dimensions), + } + digest = hash_from_str(json.dumps(shape, sort_keys=True))[:16] + return f"{INDEX_KEY_PREFIX}{view_meta.uuid}:{view_meta.changed_on_iso}:{digest}" + + +def value_key(view_meta: ViewMeta, query: SemanticQuery) -> str: + digest = hash_from_str(json.dumps(_canonicalize(query), sort_keys=True))[:32] + return f"{VALUE_KEY_PREFIX}{view_meta.uuid}:{view_meta.changed_on_iso}:{digest}" + + +def _dimension_key(dim: Dimension) -> str: + grain = dim.grain.representation if dim.grain else "_" + return f"{dim.id}@{grain}" + + +def _canonicalize(query: SemanticQuery) -> dict[str, Any]: + return { + "m": sorted(m.id for m in query.metrics), + "d": sorted(_dimension_key(d) for d in query.dimensions), + "f": sorted(_filter_to_jsonable(f) for f in (query.filters or [])), + "o": _order_key(query.order), + "l": query.limit, + "off": query.offset or 0, + "gl": _group_limit_key(query.group_limit), + } + + +def _filter_to_jsonable(f: Filter) -> str: + return json.dumps( + { + "t": f.type.value, + "c": f.column.id if f.column is not None else None, + "o": f.operator.value, + "v": _value_to_jsonable(f.value), + }, + sort_keys=True, + ) + + +def _value_to_jsonable(value: Any) -> Any: + if isinstance(value, frozenset): + return sorted(_value_to_jsonable(v) for v in value) + if isinstance(value, (datetime, date, time)): + return value.isoformat() + if isinstance(value, timedelta): + return value.total_seconds() + return value + + +def _order_key(order: list[OrderTuple] | None) -> str: + if not order: + return "" + return json.dumps( + [(_orderable_id(element), direction.value) for element, direction in order] + ) + + +def _orderable_id(element: Metric | Dimension | AdhocExpression) -> str: + return element.id + + +def _group_limit_key(group_limit: Any) -> str: + if group_limit is None: + return "" + return json.dumps( + { + "dims": sorted(d.id for d in group_limit.dimensions), + "top": group_limit.top, + "metric": group_limit.metric.id if group_limit.metric else None, + "direction": group_limit.direction.value, + "group_others": group_limit.group_others, + "filters": sorted( + _filter_to_jsonable(f) for f in (group_limit.filters or []) + ), + }, + sort_keys=True, + ) + + +def _timeout(view_meta: ViewMeta) -> int | None: + if view_meta.cache_timeout is not None: + return view_meta.cache_timeout + config = current_app.config.get("DATA_CACHE_CONFIG") or {} + return config.get("CACHE_DEFAULT_TIMEOUT") + + +# --------------------------------------------------------------------------- +# Containment +# --------------------------------------------------------------------------- + + +def can_satisfy( # noqa: C901 + entry: CachedEntry, + query: SemanticQuery, +) -> tuple[bool, set[Filter]]: + """Return ``(reusable, leftover_filters_to_apply)`` for ``entry`` vs ``query``.""" + new_filters = frozenset(query.filters or set()) + + c_adhoc, c_having, c_where = _split(entry.filters) + n_adhoc, n_having, n_where = _split(new_filters) + + if c_adhoc != n_adhoc: + return False, set() + if c_having != n_having: + return False, set() + + c_by_col = _group_by_column(c_where) + n_by_col = _group_by_column(n_where) + + for col_id, c_list in c_by_col.items(): + n_list = n_by_col.get(col_id, []) 
+        for c in c_list:
+            if not any(_implies(n, c) for n in n_list):
+                return False, set()
+
+    leftovers: set[Filter] = set()
+    for col_id, n_list in n_by_col.items():
+        c_list = c_by_col.get(col_id, [])
+        for n in n_list:
+            if not any(_implies(c, n) for c in c_list):
+                if n.column is None or n.operator == Operator.ADHOC:
+                    return False, set()
+                leftovers.add(n)
+
+    projection_ids = _projection_ids(query)
+    for leftover in leftovers:
+        if leftover.column is None or leftover.column.id not in projection_ids:
+            return False, set()
+
+    if entry.offset != 0 or (query.offset or 0) != 0:
+        return False, set()
+
+    if entry.limit is not None:
+        if leftovers:
+            # A truncated cached result cannot serve a narrower query: rows
+            # past the cached cut-off could rank inside the new query's top N.
+            return False, set()
+        if query.limit is None or query.limit > entry.limit:
+            return False, set()
+
+    if (entry.limit is not None or query.limit is not None) and (
+        entry.order_key != _order_key(query.order)
+    ):
+        # head(limit) post-processing only yields the true top N when the
+        # cached ordering matches the new query's ordering.
+        return False, set()
+
+    if entry.group_limit_key != _group_limit_key(query.group_limit):
+        return False, set()
+
+    return True, leftovers
+
+
+def _split(
+    filters: Iterable[Filter],
+) -> tuple[frozenset[Filter], frozenset[Filter], frozenset[Filter]]:
+    adhoc: set[Filter] = set()
+    having: set[Filter] = set()
+    where: set[Filter] = set()
+    for f in filters:
+        if f.operator == Operator.ADHOC:
+            adhoc.add(f)
+        elif f.type == PredicateType.HAVING:
+            having.add(f)
+        else:
+            where.add(f)
+    return frozenset(adhoc), frozenset(having), frozenset(where)
+
+
+def _group_by_column(filters: Iterable[Filter]) -> dict[str | None, list[Filter]]:
+    out: dict[str | None, list[Filter]] = {}
+    for f in filters:
+        col_id = f.column.id if f.column is not None else None
+        out.setdefault(col_id, []).append(f)
+    return out
+
+
+def _projection_ids(query: SemanticQuery) -> set[str]:
+    return {d.id for d in query.dimensions} | {m.id for m in query.metrics}
+
+
+# ---------------------------------------------------------------------------
+# Pairwise implication
+# ---------------------------------------------------------------------------
+
+
+# pylint: disable=too-many-return-statements,too-many-branches
+def _implies(new: Filter, cached: Filter) -> bool:  # noqa: C901
+    """True iff every row matching ``new`` also matches ``cached``.
+
+    Both filters are assumed to be on the same column (caller groups by column).
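+
+    Example (illustrative; ``col`` stands for the shared column)::
+
+        new = Filter(type=PredicateType.WHERE, column=col,
+                     operator=Operator.IN, value=frozenset({2, 3}))
+        cached = Filter(type=PredicateType.WHERE, column=col,
+                        operator=Operator.GREATER_THAN, value=1)
+        _implies(new, cached)  # True: every row with col in {2, 3} has col > 1
+        _implies(cached, new)  # False: col > 1 admits values outside {2, 3}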
+ """ + if new == cached: + return True + + nop, nval = new.operator, new.value + cop, cval = cached.operator, cached.value + + if cop == Operator.IS_NULL: + if nop == Operator.IS_NULL: + return True + if nop == Operator.EQUALS and nval is None: + return True + return False + + if cop == Operator.IS_NOT_NULL: + if nop == Operator.IS_NOT_NULL: + return True + if nop == Operator.EQUALS: + return nval is not None + if nop in _RANGE_OPS: + return True + if nop == Operator.IN: + return isinstance(nval, frozenset) and all(v is not None for v in nval) + return False + + if cop == Operator.EQUALS: + if nop == Operator.EQUALS: + return nval == cval + if nop == Operator.IN and isinstance(nval, frozenset): + return nval == frozenset({cval}) + return False + + if cop == Operator.NOT_EQUALS: + if nop == Operator.NOT_EQUALS: + return nval == cval + if nop == Operator.EQUALS: + return nval != cval + if nop == Operator.IN and isinstance(nval, frozenset): + return cval not in nval + return False + + if cop == Operator.IN and isinstance(cval, frozenset): + if nop == Operator.IN and isinstance(nval, frozenset): + return nval.issubset(cval) + if nop == Operator.EQUALS: + return nval in cval + return False + + if cop == Operator.NOT_IN and isinstance(cval, frozenset): + if nop == Operator.NOT_IN and isinstance(nval, frozenset): + return cval.issubset(nval) + if nop == Operator.NOT_EQUALS: + return cval.issubset({nval}) + if nop == Operator.EQUALS: + return nval not in cval + if nop == Operator.IN and isinstance(nval, frozenset): + return cval.isdisjoint(nval) + return False + + if cop in _RANGE_OPS: + return _implies_range(nop, nval, cop, cval) + + # LIKE / NOT_LIKE / ADHOC: only the exact-match path at the top. + return False + + +_RANGE_OPS = frozenset( + { + Operator.GREATER_THAN, + Operator.GREATER_THAN_OR_EQUAL, + Operator.LESS_THAN, + Operator.LESS_THAN_OR_EQUAL, + } +) + + +def _implies_range( # noqa: C901 + nop: Operator, + nval: Any, + cop: Operator, + cval: Any, +) -> bool: + if isinstance(nval, frozenset): + return nop == Operator.IN and all(_scalar_in_range(v, cop, cval) for v in nval) + if nop == Operator.EQUALS: + return _scalar_in_range(nval, cop, cval) + if nop not in _RANGE_OPS: + return False + if not _comparable(nval, cval): + return False + + # Same direction (both upper or both lower bounds) required. 
+ cached_is_lower = cop in (Operator.GREATER_THAN, Operator.GREATER_THAN_OR_EQUAL) + new_is_lower = nop in (Operator.GREATER_THAN, Operator.GREATER_THAN_OR_EQUAL) + if cached_is_lower != new_is_lower: + return False + + if cached_is_lower: + # cached: a > cval or a >= cval + # new: a > nval or a >= nval + # need rows(new) ⊆ rows(cached) + if cop == Operator.GREATER_THAN and nop == Operator.GREATER_THAN: + return nval >= cval + if cop == Operator.GREATER_THAN and nop == Operator.GREATER_THAN_OR_EQUAL: + return nval > cval + if cop == Operator.GREATER_THAN_OR_EQUAL and nop == Operator.GREATER_THAN: + return nval >= cval + if ( + cop == Operator.GREATER_THAN_OR_EQUAL + and nop == Operator.GREATER_THAN_OR_EQUAL + ): + return nval >= cval + return False + else: + if cop == Operator.LESS_THAN and nop == Operator.LESS_THAN: + return nval <= cval + if cop == Operator.LESS_THAN and nop == Operator.LESS_THAN_OR_EQUAL: + return nval < cval + if cop == Operator.LESS_THAN_OR_EQUAL and nop == Operator.LESS_THAN: + return nval <= cval + if cop == Operator.LESS_THAN_OR_EQUAL and nop == Operator.LESS_THAN_OR_EQUAL: + return nval <= cval + return False + + +def _scalar_in_range(value: Any, cop: Operator, cval: Any) -> bool: + if not _comparable(value, cval): + return False + if cop == Operator.GREATER_THAN: + return value > cval + if cop == Operator.GREATER_THAN_OR_EQUAL: + return value >= cval + if cop == Operator.LESS_THAN: + return value < cval + if cop == Operator.LESS_THAN_OR_EQUAL: + return value <= cval + return False + + +def _comparable(a: Any, b: Any) -> bool: + if a is None or b is None: + return False + if isinstance(a, bool) or isinstance(b, bool): + return isinstance(a, bool) and isinstance(b, bool) + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return True + if isinstance(a, str) and isinstance(b, str): + return True + if isinstance(a, (datetime, date, time)) and isinstance(b, type(a)): + return True + if isinstance(a, type(b)) and isinstance(a, (datetime, date, time, timedelta)): + return True + return type(a) == type(b) # noqa: E721 + + +# --------------------------------------------------------------------------- +# Post-processing +# --------------------------------------------------------------------------- + + +def _apply_post_processing( + cached: SemanticResult, + query: SemanticQuery, + leftovers: set[Filter], +) -> SemanticResult: + """Apply leftover filters and the new limit to a cached result.""" + if not leftovers and query.limit is None: + return cached + + df = cached.results.to_pandas() + if leftovers: + mask = pd.Series(True, index=df.index) + for f in leftovers: + mask &= _mask_for(df, f) + df = df[mask] + if query.limit is not None: + df = df.head(query.limit) + + table = pa.Table.from_pandas(df, preserve_index=False) + note = SemanticRequest( + type="cache", + definition="Served from semantic view smart cache (post-processed locally)", + ) + return SemanticResult(requests=list(cached.requests) + [note], results=table) + + +def _mask_for(df: pd.DataFrame, f: Filter) -> pd.Series: # noqa: C901 + if f.column is None: + return pd.Series(True, index=df.index) + series = df[f.column.name] + op = f.operator + val = f.value + if op == Operator.EQUALS: + return series == val if val is not None else series.isna() + if op == Operator.NOT_EQUALS: + return series != val if val is not None else series.notna() + if op == Operator.GREATER_THAN: + return series > val + if op == Operator.GREATER_THAN_OR_EQUAL: + return series >= val + if op == Operator.LESS_THAN: + 
return series < val + if op == Operator.LESS_THAN_OR_EQUAL: + return series <= val + if op == Operator.IN: + return series.isin(list(val) if isinstance(val, frozenset) else [val]) + if op == Operator.NOT_IN: + return ~series.isin(list(val) if isinstance(val, frozenset) else [val]) + if op == Operator.IS_NULL: + return series.isna() + if op == Operator.IS_NOT_NULL: + return series.notna() + if op == Operator.LIKE: + return series.astype(str).str.match(_sql_like_to_regex(str(val))) + if op == Operator.NOT_LIKE: + return ~series.astype(str).str.match(_sql_like_to_regex(str(val))) + return pd.Series(True, index=df.index) + + +def _sql_like_to_regex(pattern: str) -> str: + out = [] + for ch in pattern: + if ch == "%": + out.append(".*") + elif ch == "_": + out.append(".") + else: + out.append(re.escape(ch)) + return f"^{''.join(out)}$" diff --git a/superset/semantic_layers/mapper.py b/superset/semantic_layers/mapper.py index 23dea06124d..515a2d38fc9 100644 --- a/superset/semantic_layers/mapper.py +++ b/superset/semantic_layers/mapper.py @@ -26,7 +26,7 @@ single dataframe. from datetime import datetime, timedelta from time import time -from typing import Any, cast, Sequence, TypeGuard +from typing import Any, Callable, cast, Sequence, TypeGuard import isodate import numpy as np @@ -55,6 +55,11 @@ from superset.common.utils.time_range_utils import get_since_until_from_query_ob from superset.connectors.sqla.models import BaseDatasource from superset.constants import NO_TIME_RANGE from superset.models.helpers import QueryResult +from superset.semantic_layers.cache import ( + store_result, + try_serve_from_cache, + ViewMeta, +) from superset.superset_typing import AdhocColumn from superset.utils.core import ( FilterOperator, @@ -112,13 +117,15 @@ def get_results(query_object: QueryObject) -> QueryResult: else semantic_view.get_table ) + cached_dispatch = _make_cached_dispatch(query_object, dispatcher) + # Step 1: Convert QueryObject to list of SemanticQuery objects # The first query is the main query, subsequent queries are for time offsets queries = map_query_object(query_object) # Step 2: Execute the main query (first in the list) main_query = queries[0] - main_result = dispatcher(main_query) + main_result = cached_dispatch(main_query) main_df = main_result.results.to_pandas() @@ -149,7 +156,7 @@ def get_results(query_object: QueryObject) -> QueryResult: strict=False, ): # Execute the offset query - result = dispatcher(offset_query) + result = cached_dispatch(offset_query) # Add this query's requests to the collection all_requests.extend(result.requests) @@ -205,6 +212,37 @@ def get_results(query_object: QueryObject) -> QueryResult: ) +def _make_cached_dispatch( + query_object: ValidatedQueryObject, + dispatcher: Callable[[SemanticQuery], SemanticResult], +) -> Callable[[SemanticQuery], SemanticResult]: + """ + Wrap the semantic view dispatcher with a containment-aware cache. + + Row-count queries bypass the cache. Cache failures are logged and the + dispatcher is called as if the cache were absent. 
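+
+    A sketch of the wiring (as performed in ``get_results`` above)::
+
+        dispatch = _make_cached_dispatch(query_object, semantic_view.get_table)
+        result = dispatch(main_query)  # cache hit, or execute and store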
+ """ + if query_object.is_rowcount: + return dispatcher + + view = query_object.datasource + changed_on = getattr(view, "changed_on", None) + view_meta = ViewMeta( + uuid=str(view.uuid), + changed_on_iso=changed_on.isoformat() if changed_on else "", + cache_timeout=getattr(view, "cache_timeout", None), + ) + + def cached_dispatch(query: SemanticQuery) -> SemanticResult: + if (hit := try_serve_from_cache(view_meta, query)) is not None: + return hit + result = dispatcher(query) + store_result(view_meta, query, result) + return result + + return cached_dispatch + + def map_semantic_result_to_query_result( semantic_result: SemanticResult, query_object: ValidatedQueryObject, diff --git a/tests/unit_tests/semantic_layers/cache_integration_test.py b/tests/unit_tests/semantic_layers/cache_integration_test.py new file mode 100644 index 00000000000..87b5e659042 --- /dev/null +++ b/tests/unit_tests/semantic_layers/cache_integration_test.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""End-to-end test that exercises ``mapper.get_results`` with a live cache.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock + +import pandas as pd +import pyarrow as pa +import pytest +from pytest_mock import MockerFixture +from superset_core.semantic_layers.types import ( + Dimension, + Metric, + SemanticRequest, + SemanticResult, +) + +from superset.semantic_layers import cache as cache_module +from superset.semantic_layers.mapper import get_results, ValidatedQueryObject + + +class _InMemoryCache: + """Minimal flask-caching compatible cache used to isolate tests.""" + + def __init__(self) -> None: + self._store: dict[str, Any] = {} + + def get(self, key: str) -> Any: + return self._store.get(key) + + def set(self, key: str, value: Any, timeout: int | None = None) -> bool: + self._store[key] = value + return True + + def delete(self, key: str) -> bool: + return self._store.pop(key, None) is not None + + +@pytest.fixture +def fake_cache(mocker: MockerFixture) -> _InMemoryCache: + fake = _InMemoryCache() + mocker.patch.object( + type(cache_module.cache_manager), + "data_cache", + property(lambda self: fake), + ) + return fake + + +@pytest.fixture +def view_implementation() -> Any: + """SemanticView implementation stub with one metric and one dimension.""" + dim_a = Dimension(id="t.a", name="a", type=pa.int64()) + metric_x = Metric(id="t.x", name="x", type=pa.float64(), definition="sum(x)") + + impl = MagicMock() + impl.metrics = {metric_x} + impl.dimensions = {dim_a} + impl.features = frozenset() + impl.get_metrics = MagicMock(return_value={metric_x}) + impl.get_dimensions = MagicMock(return_value={dim_a}) + return impl + + +@pytest.fixture +def datasource(view_implementation: 
Any) -> MagicMock: + ds = MagicMock() + ds.implementation = view_implementation + ds.uuid = "view-uuid-stable" + ds.changed_on = datetime(2026, 1, 1, 12, 0, 0) + ds.cache_timeout = 60 + ds.fetch_values_predicate = None + return ds + + +def _result(rows: list[tuple[int, float]]) -> SemanticResult: + df = pd.DataFrame(rows, columns=["a", "x"]) + return SemanticResult( + requests=[SemanticRequest(type="SQL", definition="select a, x")], + results=pa.Table.from_pandas(df, preserve_index=False), + ) + + +def _qo( + datasource: MagicMock, + filter_op: str | None = None, + filter_val: Any = None, + limit: int | None = None, +) -> ValidatedQueryObject: + qo_filters: list[dict[str, Any]] = ( + [{"col": "a", "op": filter_op, "val": filter_val}] if filter_op else [] + ) + return ValidatedQueryObject( + datasource=datasource, + metrics=["x"], + columns=["a"], + filters=qo_filters, # type: ignore[arg-type] + row_limit=limit, + ) + + +def test_narrower_filter_reuses_cache( + fake_cache: _InMemoryCache, + view_implementation: Any, + datasource: MagicMock, +) -> None: + # The dispatcher returns rows already filtered by `a > 1` (in production it + # would; here we hand-feed the result). The second query (a > 2) is a subset + # and must be served from the cached DataFrame. + cached = _result([(2, 2.0), (3, 3.0), (5, 5.0)]) + view_implementation.get_table = MagicMock(return_value=cached) + + first = get_results(_qo(datasource, ">", 1)) + assert view_implementation.get_table.call_count == 1 + assert sorted(first.df["a"].tolist()) == [2, 3, 5] + + second = get_results(_qo(datasource, ">", 2)) + assert view_implementation.get_table.call_count == 1 # cache hit + assert sorted(second.df["a"].tolist()) == [3, 5] + + +def test_smaller_limit_reuses_cache( + fake_cache: _InMemoryCache, + view_implementation: Any, + datasource: MagicMock, +) -> None: + # First call has no limit; second asks for 2 rows — should be served from cache. + full = _result([(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0)]) + view_implementation.get_table = MagicMock(return_value=full) + + get_results(_qo(datasource, limit=None)) + assert view_implementation.get_table.call_count == 1 + + result = get_results(_qo(datasource, limit=2)) + assert view_implementation.get_table.call_count == 1 # cache hit + assert len(result.df) == 2 + + +def test_broader_filter_misses_cache( + fake_cache: _InMemoryCache, + view_implementation: Any, + datasource: MagicMock, +) -> None: + view_implementation.get_table = MagicMock( + side_effect=[ + _result([(2, 1.0), (3, 2.0)]), + _result([(0, 1.0), (2, 2.0), (3, 3.0)]), + ] + ) + + get_results(_qo(datasource, ">", 1)) + assert view_implementation.get_table.call_count == 1 + + # Broader filter — must re-execute. + get_results(_qo(datasource, ">", 0)) + assert view_implementation.get_table.call_count == 2 + + +def test_changed_on_invalidates_cache( + fake_cache: _InMemoryCache, + view_implementation: Any, + datasource: MagicMock, +) -> None: + view_implementation.get_table = MagicMock(return_value=_result([(2, 1.0)])) + + get_results(_qo(datasource, ">", 1)) + assert view_implementation.get_table.call_count == 1 + + # Bumping changed_on yields a different shape key — cache misses. 
+ datasource.changed_on = datetime(2026, 2, 1, 0, 0, 0) + get_results(_qo(datasource, ">", 1)) + assert view_implementation.get_table.call_count == 2 diff --git a/tests/unit_tests/semantic_layers/cache_test.py b/tests/unit_tests/semantic_layers/cache_test.py new file mode 100644 index 00000000000..bfe5cf948cc --- /dev/null +++ b/tests/unit_tests/semantic_layers/cache_test.py @@ -0,0 +1,396 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +import pandas as pd +import pyarrow as pa +import pytest +from superset_core.semantic_layers.types import ( + Dimension, + Filter, + Metric, + Operator, + OrderDirection, + PredicateType, + SemanticQuery, + SemanticRequest, + SemanticResult, +) + +from superset.semantic_layers.cache import ( + _apply_post_processing, + _implies, + CachedEntry, + can_satisfy, + shape_key, + value_key, + ViewMeta, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def dim(id_: str, name: str | None = None) -> Dimension: + return Dimension(id=id_, name=name or id_, type=pa.utf8()) + + +def met(id_: str, name: str | None = None) -> Metric: + return Metric(id=id_, name=name or id_, type=pa.float64(), definition="x") + + +COL_A = dim("col.a", "a") +COL_B = dim("col.b", "b") +M_X = met("met.x", "x") +M_Y = met("met.y", "y") + +VIEW = ViewMeta(uuid="view-1", changed_on_iso="2026-05-01T00:00:00", cache_timeout=None) + + +def where(column: Dimension | Metric | None, op: Operator, value: Any) -> Filter: + return Filter(type=PredicateType.WHERE, column=column, operator=op, value=value) + + +def having(column: Metric, op: Operator, value: Any) -> Filter: + return Filter(type=PredicateType.HAVING, column=column, operator=op, value=value) + + +def adhoc(definition: str, type_: PredicateType = PredicateType.WHERE) -> Filter: + return Filter(type=type_, column=None, operator=Operator.ADHOC, value=definition) + + +def query( + filters: set[Filter] | None = None, + limit: int | None = None, + order: Any = None, + dimensions: list[Dimension] | None = None, + metrics: list[Metric] | None = None, +) -> SemanticQuery: + return SemanticQuery( + metrics=metrics if metrics is not None else [M_X], + dimensions=dimensions if dimensions is not None else [COL_A, COL_B], + filters=filters, + order=order, + limit=limit, + ) + + +def entry_from(q: SemanticQuery, value_key_: str = "vk") -> CachedEntry: + from superset.semantic_layers.cache import _group_limit_key, _order_key + + return CachedEntry( + filters=frozenset(q.filters or set()), + limit=q.limit, + offset=q.offset or 0, + order_key=_order_key(q.order), + 
group_limit_key=_group_limit_key(q.group_limit), + value_key=value_key_, + ) + + +# --------------------------------------------------------------------------- +# _implies: scalar range pairs +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "new_op,new_val,cached_op,cached_val,expected", + [ + # narrower lower bound + (Operator.GREATER_THAN, 20, Operator.GREATER_THAN, 10, True), + (Operator.GREATER_THAN, 10, Operator.GREATER_THAN, 20, False), + (Operator.GREATER_THAN_OR_EQUAL, 11, Operator.GREATER_THAN, 10, True), + (Operator.GREATER_THAN_OR_EQUAL, 10, Operator.GREATER_THAN, 10, False), + (Operator.GREATER_THAN, 10, Operator.GREATER_THAN_OR_EQUAL, 10, True), + (Operator.GREATER_THAN, 9, Operator.GREATER_THAN_OR_EQUAL, 10, False), + # narrower upper bound + (Operator.LESS_THAN, 5, Operator.LESS_THAN, 10, True), + (Operator.LESS_THAN_OR_EQUAL, 9, Operator.LESS_THAN, 10, True), + (Operator.LESS_THAN_OR_EQUAL, 10, Operator.LESS_THAN, 10, False), + # cross-direction — never implies + (Operator.LESS_THAN, 5, Operator.GREATER_THAN, 10, False), + (Operator.GREATER_THAN, 5, Operator.LESS_THAN, 10, False), + # equals fits in range + (Operator.EQUALS, 15, Operator.GREATER_THAN, 10, True), + (Operator.EQUALS, 10, Operator.GREATER_THAN, 10, False), + (Operator.EQUALS, 10, Operator.GREATER_THAN_OR_EQUAL, 10, True), + ], +) +def test_implies_range( + new_op: Operator, + new_val: Any, + cached_op: Operator, + cached_val: Any, + expected: bool, +) -> None: + assert ( + _implies(where(COL_A, new_op, new_val), where(COL_A, cached_op, cached_val)) + is expected + ) + + +def test_implies_in_subset() -> None: + cached = where(COL_A, Operator.IN, frozenset({"a", "b", "c"})) + assert _implies(where(COL_A, Operator.IN, frozenset({"a", "b"})), cached) is True + assert _implies(where(COL_A, Operator.IN, frozenset({"a", "d"})), cached) is False + # equals to a value in the cached IN set + assert _implies(where(COL_A, Operator.EQUALS, "b"), cached) is True + assert _implies(where(COL_A, Operator.EQUALS, "z"), cached) is False + + +def test_implies_in_all_in_range() -> None: + cached = where(COL_A, Operator.GREATER_THAN, 10) + assert _implies(where(COL_A, Operator.IN, frozenset({11, 12})), cached) is True + assert _implies(where(COL_A, Operator.IN, frozenset({10, 12})), cached) is False + + +def test_implies_equals_exact() -> None: + cached = where(COL_A, Operator.EQUALS, 5) + assert _implies(where(COL_A, Operator.EQUALS, 5), cached) is True + assert _implies(where(COL_A, Operator.EQUALS, 6), cached) is False + + +def test_implies_is_not_null() -> None: + cached = where(COL_A, Operator.IS_NOT_NULL, None) + assert _implies(where(COL_A, Operator.GREATER_THAN, 0), cached) is True + assert _implies(where(COL_A, Operator.IS_NOT_NULL, None), cached) is True + assert _implies(where(COL_A, Operator.IS_NULL, None), cached) is False + + +def test_implies_like_exact_match_only() -> None: + a = where(COL_A, Operator.LIKE, "foo%") + b = where(COL_A, Operator.LIKE, "foo%") + c = where(COL_A, Operator.LIKE, "bar%") + assert _implies(a, b) is True + assert _implies(c, b) is False + assert _implies(where(COL_A, Operator.EQUALS, "fooz"), b) is False + + +# --------------------------------------------------------------------------- +# can_satisfy +# --------------------------------------------------------------------------- + + +def test_can_satisfy_empty_cached_returns_all_as_leftovers() -> None: + cached_q = query(filters=None) + new_q = query(filters={where(COL_A, 
Operator.GREATER_THAN, 5)}) + ok, leftovers = can_satisfy(entry_from(cached_q), new_q) + assert ok is True + assert leftovers == {where(COL_A, Operator.GREATER_THAN, 5)} + + +def test_can_satisfy_narrower_filter() -> None: + cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)}) + new_q = query(filters={where(COL_A, Operator.GREATER_THAN, 2)}) + ok, leftovers = can_satisfy(entry_from(cached_q), new_q) + assert ok is True + assert leftovers == {where(COL_A, Operator.GREATER_THAN, 2)} + + +def test_can_satisfy_broader_filter_fails() -> None: + cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 2)}) + new_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)}) + ok, leftovers = can_satisfy(entry_from(cached_q), new_q) + assert ok is False + assert leftovers == set() + + +def test_can_satisfy_missing_constraint_fails() -> None: + cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)}) + new_q = query(filters=None) + ok, _ = can_satisfy(entry_from(cached_q), new_q) + assert ok is False + + +def test_can_satisfy_new_filter_on_extra_column() -> None: + cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)}) + new_q = query( + filters={ + where(COL_A, Operator.GREATER_THAN, 2), + where(COL_B, Operator.EQUALS, "x"), + } + ) + ok, leftovers = can_satisfy(entry_from(cached_q), new_q) + assert ok is True + assert leftovers == { + where(COL_A, Operator.GREATER_THAN, 2), + where(COL_B, Operator.EQUALS, "x"), + } + + +def test_can_satisfy_leftover_on_non_projected_column_fails() -> None: + other = dim("col.other", "other") + cached_q = query(filters=None) + new_q = query( + filters={where(other, Operator.EQUALS, "x")}, + dimensions=[COL_A, COL_B], + ) + ok, _ = can_satisfy(entry_from(cached_q), new_q) + assert ok is False + + +def test_can_satisfy_having_requires_exact_set() -> None: + cached_q = query(filters={having(M_X, Operator.GREATER_THAN, 100)}) + same = query(filters={having(M_X, Operator.GREATER_THAN, 100)}) + tighter = query(filters={having(M_X, Operator.GREATER_THAN, 200)}) + ok_same, _ = can_satisfy(entry_from(cached_q), same) + ok_tight, _ = can_satisfy(entry_from(cached_q), tighter) + assert ok_same is True + assert ok_tight is False + + +def test_can_satisfy_adhoc_requires_exact_set() -> None: + cached_q = query(filters={adhoc("col_a > 1")}) + same = query(filters={adhoc("col_a > 1")}) + different = query(filters={adhoc("col_a > 2")}) + ok_same, _ = can_satisfy(entry_from(cached_q), same) + ok_diff, _ = can_satisfy(entry_from(cached_q), different) + assert ok_same is True + assert ok_diff is False + + +# --------------------------------------------------------------------------- +# Limit / order / offset +# --------------------------------------------------------------------------- + + +def test_can_satisfy_unlimited_cached_satisfies_any_limit() -> None: + cached_q = query(filters=None, limit=None) + new_q = query(filters=None, limit=10) + ok, leftovers = can_satisfy(entry_from(cached_q), new_q) + assert ok is True + assert leftovers == set() + + +def test_can_satisfy_smaller_limit_with_matching_order() -> None: + order = [(M_X, OrderDirection.DESC)] + cached_q = query(filters=None, limit=100, order=order) + new_q = query(filters=None, limit=10, order=order) + ok, _ = can_satisfy(entry_from(cached_q), new_q) + assert ok is True + + +def test_can_satisfy_smaller_limit_different_order_fails() -> None: + cached_q = query(filters=None, limit=100, order=[(M_X, OrderDirection.DESC)]) + new_q = query(filters=None, limit=10, order=[(M_X, 
OrderDirection.ASC)]) + ok, _ = can_satisfy(entry_from(cached_q), new_q) + assert ok is False + + +def test_can_satisfy_larger_limit_fails() -> None: + cached_q = query(filters=None, limit=10) + new_q = query(filters=None, limit=100) + ok, _ = can_satisfy(entry_from(cached_q), new_q) + assert ok is False + + +def test_can_satisfy_no_new_limit_when_cached_has_one_fails() -> None: + cached_q = query(filters=None, limit=100) + new_q = query(filters=None, limit=None) + ok, _ = can_satisfy(entry_from(cached_q), new_q) + assert ok is False + + +def test_can_satisfy_offset_never_reused() -> None: + cached_q = SemanticQuery(metrics=[M_X], dimensions=[COL_A], offset=5) + new_q = SemanticQuery(metrics=[M_X], dimensions=[COL_A], offset=5) + ok, _ = can_satisfy(entry_from(cached_q), new_q) + assert ok is False + + +# --------------------------------------------------------------------------- +# Post-processing +# --------------------------------------------------------------------------- + + +def test_apply_post_processing_filters_and_limits() -> None: + df = pd.DataFrame({"a": [1, 3, 5, 7, 9], "x": [10, 20, 30, 40, 50]}) + cached = SemanticResult( + requests=[SemanticRequest(type="SQL", definition="select ...")], + results=pa.Table.from_pandas(df, preserve_index=False), + ) + new_q = query( + filters={where(COL_A, Operator.GREATER_THAN, 2)}, + limit=2, + ) + result = _apply_post_processing( + cached, new_q, {where(COL_A, Operator.GREATER_THAN, 2)} + ) + result_df = result.results.to_pandas() + assert list(result_df["a"]) == [3, 5] + # the cache annotates the requests with a marker + assert any(req.type == "cache" for req in result.requests) + + +def test_apply_post_processing_no_leftovers_no_limit_returns_original() -> None: + df = pd.DataFrame({"a": [1, 2]}) + cached = SemanticResult( + requests=[], results=pa.Table.from_pandas(df, preserve_index=False) + ) + new_q = query(filters=None, limit=None) + out = _apply_post_processing(cached, new_q, set()) + # same object reference is OK; we explicitly return the input + assert out is cached + + +# --------------------------------------------------------------------------- +# Hash stability +# --------------------------------------------------------------------------- + + +def test_value_key_stable_across_metric_order() -> None: + q1 = SemanticQuery(metrics=[M_X, M_Y], dimensions=[COL_A]) + q2 = SemanticQuery(metrics=[M_Y, M_X], dimensions=[COL_A]) + assert value_key(VIEW, q1) == value_key(VIEW, q2) + + +def test_shape_key_stable_across_dimension_order() -> None: + q1 = SemanticQuery(metrics=[M_X], dimensions=[COL_A, COL_B]) + q2 = SemanticQuery(metrics=[M_X], dimensions=[COL_B, COL_A]) + assert shape_key(VIEW, q1) == shape_key(VIEW, q2) + + +def test_shape_key_changes_with_changed_on() -> None: + q = SemanticQuery(metrics=[M_X], dimensions=[COL_A]) + other = ViewMeta(uuid=VIEW.uuid, changed_on_iso="2099-01-01", cache_timeout=None) + assert shape_key(VIEW, q) != shape_key(other, q) + + +def test_value_key_changes_with_filter_value() -> None: + q1 = SemanticQuery( + metrics=[M_X], + dimensions=[COL_A], + filters={where(COL_A, Operator.GREATER_THAN, 1)}, + ) + q2 = SemanticQuery( + metrics=[M_X], + dimensions=[COL_A], + filters={where(COL_A, Operator.GREATER_THAN, 2)}, + ) + assert value_key(VIEW, q1) != value_key(VIEW, q2) + + +def test_value_key_with_datetime_filter() -> None: + f = where(COL_A, Operator.GREATER_THAN_OR_EQUAL, datetime(2025, 1, 1)) + q = SemanticQuery(metrics=[M_X], dimensions=[COL_A], filters={f}) + # should not raise + assert 
value_key(VIEW, q).startswith("sv:val:")
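+
+
+# ---------------------------------------------------------------------------
+# Containment edge cases around limits and local LIKE filtering
+# ---------------------------------------------------------------------------
+
+
+def test_can_satisfy_new_limit_requires_matching_order() -> None:
+    # A locally applied head(N) is only the true top N when the orderings
+    # match, even when the cached entry itself had no limit.
+    cached_q = query(filters=None, limit=None, order=None)
+    new_q = query(filters=None, limit=10, order=[(M_X, OrderDirection.DESC)])
+    ok, _ = can_satisfy(entry_from(cached_q), new_q)
+    assert ok is False
+
+
+def test_can_satisfy_truncated_cache_with_leftover_fails() -> None:
+    # A truncated cached result must not serve a narrower query: rows beyond
+    # the cached cut-off could rank inside the new query's top N.
+    cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)}, limit=100)
+    new_q = query(filters={where(COL_A, Operator.GREATER_THAN, 2)}, limit=10)
+    ok, _ = can_satisfy(entry_from(cached_q), new_q)
+    assert ok is False
+
+
+def test_mask_for_like_post_filter() -> None:
+    # Illustrative: leftover LIKE filters are applied locally by translating
+    # SQL wildcards to a regex (% -> .*, _ -> .). ``_mask_for`` is imported
+    # here to keep the module's top-level import list unchanged.
+    from superset.semantic_layers.cache import _mask_for
+
+    df = pd.DataFrame({"a": ["foo", "foobar", "bar"]})
+    mask = _mask_for(df, where(COL_A, Operator.LIKE, "foo%"))
+    assert list(df.loc[mask, "a"]) == ["foo", "foobar"]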