Files
superset2/superset/semantic_layers/cache.py
Beto Dealmeida dca18116ae Improvements
2026-05-14 11:22:29 -04:00

710 lines
24 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Containment-aware cache for semantic view queries.
A broader cached result can satisfy a narrower new query: when the new query's
filters and limit are strictly more restrictive than a cached entry's, the cached
DataFrame is post-filtered and re-limited rather than re-executing the underlying
query.
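See ``docs/`` and the plan file for the design rationale. Summary of the rules:
* Same metrics. Dimensions must either match exactly, or the cached entry may
  hold a strict superset of the new dimensions when every metric is additive
  (SUM/COUNT/MIN/MAX) and the cached result has no limit, no group limit and
  no HAVING filter; the cached rows are then re-aggregated locally.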
* Each cached filter must be implied by a new-query filter on the same column.
* New filters on columns with no cached constraint are applied post-fetch as
"leftovers" — provided the column is in the projection.
* Cached ``limit`` must be at least the new ``limit``; if a cached ``limit`` is
present, the orderings must match (otherwise the cached "top N" is not the
true top of the new query).
* ``ADHOC`` and ``HAVING`` filters require exact-set equality.
* ``offset != 0`` and mismatching ``group_limit`` skip the cache.
"""
from __future__ import annotations
import logging
import re
import time as _time
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from typing import Any, Iterable
import pandas as pd
import pyarrow as pa
from flask import current_app
from superset_core.semantic_layers.types import (
AdhocExpression,
AggregationType,
Dimension,
Filter,
Metric,
Operator,
OrderDirection,
OrderTuple,
PredicateType,
SemanticQuery,
SemanticRequest,
SemanticResult,
)
from superset.extensions import cache_manager
from superset.utils import json
from superset.utils.hashing import hash_from_str
from superset.utils.pandas_postprocessing.aggregate import aggregate
logger = logging.getLogger(__name__)
# Cache-key prefixes: ``sv:idx:`` keys hold the list of ``CachedEntry``
# descriptors for one metric set (see ``shape_key``); ``sv:val:`` keys hold the
# cached ``SemanticResult`` payloads (see ``value_key``).
INDEX_KEY_PREFIX = "sv:idx:"
VALUE_KEY_PREFIX = "sv:val:"
# Upper bound on descriptors kept under one index key; ``store_result`` drops
# the oldest ones (by ``CachedEntry.timestamp``) beyond this.
MAX_ENTRIES_PER_SHAPE = 32
# Operator used to re-combine each metric when cached rows are rolled up to
# fewer dimensions (passed as ``"operator"`` to ``aggregate``). Partial COUNTs
# are combined by summing them.
_AGGREGATION_TO_PANDAS: dict[AggregationType, str] = {
    AggregationType.SUM: "sum",
    AggregationType.COUNT: "sum",
    AggregationType.MIN: "min",
    AggregationType.MAX: "max",
}
# Only metrics whose aggregation is in this set may use the projection
# (re-aggregation) path; see ``_projection_allowed``.
ADDITIVE_AGGREGATIONS = frozenset(_AGGREGATION_TO_PANDAS)
@dataclass(frozen=True)
class ViewMeta:
    """Identity/freshness/TTL info pulled from the SemanticView ORM row."""

    # Identifier of the semantic view; embedded in every cache key.
    uuid: str
    # String form of the view's ``changed_on``; embedded in every cache key,
    # so a changed view no longer finds the keys written before the change.
    changed_on_iso: str
    # Per-view TTL handed to the cache backend as ``timeout``; when ``None``,
    # ``_timeout`` falls back to ``DATA_CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"]``.
    cache_timeout: int | None
@dataclass(frozen=True)
class CachedEntry:
    """
    Index descriptor for one cached result.

    A list of these is stored under a ``shape_key``; ``can_satisfy`` compares
    the fields against a new query, and the payload itself lives in the cache
    under ``value_key``.
    """

    # All filters (WHERE, HAVING and ADHOC) the cached query was run with.
    filters: frozenset[Filter]
    # ``_dimension_key`` strings ("<id>@<grain>") of the cached dimensions.
    dimension_keys: frozenset[str]
    # Row limit of the cached query; ``None`` means untruncated.
    limit: int | None
    # Row offset of the cached query; ``can_satisfy`` only reuses offset 0.
    offset: int
    # ``_order_key`` of the cached ordering ("" when unordered).
    order_key: str
    # ``_group_limit_key`` of the cached group limit ("" when none).
    group_limit_key: str
    # Cache key under which the ``SemanticResult`` payload is stored.
    value_key: str
    # Creation time from ``time.time()``; used to evict the oldest descriptors.
    timestamp: float = field(default_factory=_time.time)
# ---------------------------------------------------------------------------
# Public surface
# ---------------------------------------------------------------------------
def try_serve_from_cache(
view_meta: ViewMeta,
query: SemanticQuery,
) -> SemanticResult | None:
"""Return a cached ``SemanticResult`` that satisfies ``query`` if any."""
try:
cache = cache_manager.data_cache
idx_key = shape_key(view_meta, query)
entries: list[CachedEntry] | None = cache.get(idx_key)
if not entries:
return None
pruned: list[CachedEntry] = []
served: SemanticResult | None = None
for entry in entries:
if served is None:
ok, leftovers, projection_needed = can_satisfy(entry, query)
if ok:
payload = cache.get(entry.value_key)
if payload is None:
# value evicted but index entry survived; drop it
continue
pruned.append(entry)
served = _apply_post_processing(
payload, query, leftovers, projection_needed
)
continue
# keep entry; verify its value is still alive
if cache.get(entry.value_key) is not None:
pruned.append(entry)
if len(pruned) != len(entries):
cache.set(idx_key, pruned, timeout=_timeout(view_meta))
return served
except Exception: # pragma: no cover - defensive
logger.warning("Semantic view cache lookup failed", exc_info=True)
return None
def store_result(
    view_meta: ViewMeta,
    query: SemanticQuery,
    result: SemanticResult,
) -> None:
    """
    Cache ``result`` for ``query`` and register it in the metric-set index.

    The payload is written under ``value_key(view_meta, query)``. A
    ``CachedEntry`` describing the query is then added to the descriptor list
    under ``shape_key(view_meta, query)``, replacing any earlier descriptor
    with the same value key. Once the list exceeds ``MAX_ENTRIES_PER_SHAPE``,
    the oldest descriptors (by ``timestamp``) are discarded. Failures are
    logged and swallowed; caching is best effort.
    """
    try:
        data_cache = cache_manager.data_cache
        ttl = _timeout(view_meta)
        payload_key = value_key(view_meta, query)
        data_cache.set(payload_key, result, timeout=ttl)

        index_key = shape_key(view_meta, query)
        existing = data_cache.get(index_key) or []
        descriptor = CachedEntry(
            filters=frozenset(query.filters or set()),
            dimension_keys=frozenset(map(_dimension_key, query.dimensions)),
            limit=query.limit,
            offset=query.offset or 0,
            order_key=_order_key(query.order),
            group_limit_key=_group_limit_key(query.group_limit),
            value_key=payload_key,
        )
        updated = [old for old in existing if old.value_key != payload_key]
        updated.append(descriptor)
        if len(updated) > MAX_ENTRIES_PER_SHAPE:
            # Stable sort by age, then keep only the newest descriptors.
            updated.sort(key=lambda old: old.timestamp)
            del updated[:-MAX_ENTRIES_PER_SHAPE]
        data_cache.set(index_key, updated, timeout=ttl)
    except Exception:  # pragma: no cover - defensive
        logger.warning("Semantic view cache store failed", exc_info=True)
# ---------------------------------------------------------------------------
# Keys
# ---------------------------------------------------------------------------
def shape_key(view_meta: ViewMeta, query: SemanticQuery) -> str:
    """
    Index key under which descriptors for ``query``'s metric set are stored.

    Only the sorted metric ids are hashed. Dimensions are recorded on each
    ``CachedEntry`` instead, so entries with broader (superset) dimensions
    share the bucket and can be found for the projection path. The view uuid
    and ``changed_on_iso`` are part of the key.
    """
    metric_ids = sorted(metric.id for metric in query.metrics)
    digest = hash_from_str(json.dumps({"m": metric_ids}, sort_keys=True))[:16]
    prefix = f"{INDEX_KEY_PREFIX}{view_meta.uuid}:{view_meta.changed_on_iso}"
    return f"{prefix}:{digest}"
def value_key(view_meta: ViewMeta, query: SemanticQuery) -> str:
    """
    Cache key of the payload for exactly ``query`` on this view.

    Hashes the full canonical form of the query (``_canonicalize``), so any
    difference in metrics, dimensions, filters, order, limit, offset or group
    limit produces a different key.
    """
    canonical = json.dumps(_canonicalize(query), sort_keys=True)
    digest = hash_from_str(canonical)[:32]
    prefix = f"{VALUE_KEY_PREFIX}{view_meta.uuid}:{view_meta.changed_on_iso}"
    return f"{prefix}:{digest}"
def _dimension_key(dim: Dimension) -> str:
grain = dim.grain.representation if dim.grain else "_"
return f"{dim.id}@{grain}"
def _canonicalize(query: SemanticQuery) -> dict[str, Any]:
    """
    Order-independent, JSON-serializable description of ``query``.

    Metric ids, dimension keys and filter encodings are sorted so that
    logically identical queries canonicalize identically; this is what
    ``value_key`` hashes.
    """
    metric_ids = sorted(m.id for m in query.metrics)
    dimension_keys = sorted(map(_dimension_key, query.dimensions))
    filter_keys = sorted(map(_filter_to_jsonable, query.filters or []))
    canonical: dict[str, Any] = {"m": metric_ids, "d": dimension_keys, "f": filter_keys}
    canonical["o"] = _order_key(query.order)
    canonical["l"] = query.limit
    canonical["off"] = query.offset or 0
    canonical["gl"] = _group_limit_key(query.group_limit)
    return canonical
def _filter_to_jsonable(f: Filter) -> str:
    """Stable JSON string for ``f`` (type, column id, operator, value)."""
    encoded: dict[str, Any] = {}
    encoded["t"] = f.type.value
    encoded["c"] = None if f.column is None else f.column.id
    encoded["o"] = f.operator.value
    encoded["v"] = _value_to_jsonable(f.value)
    return json.dumps(encoded, sort_keys=True)
def _value_to_jsonable(value: Any) -> Any:
if isinstance(value, frozenset):
return sorted(_value_to_jsonable(v) for v in value)
if isinstance(value, (datetime, date, time)):
return value.isoformat()
if isinstance(value, timedelta):
return value.total_seconds()
return value
def _order_key(order: list[OrderTuple] | None) -> str:
if not order:
return ""
return json.dumps(
[(_orderable_id(element), direction.value) for element, direction in order]
)
def _orderable_id(element: Metric | Dimension | AdhocExpression) -> str:
return element.id
def _group_limit_key(group_limit: Any) -> str:
if group_limit is None:
return ""
return json.dumps(
{
"dims": sorted(d.id for d in group_limit.dimensions),
"top": group_limit.top,
"metric": group_limit.metric.id if group_limit.metric else None,
"direction": group_limit.direction.value,
"group_others": group_limit.group_others,
"filters": sorted(
_filter_to_jsonable(f) for f in (group_limit.filters or [])
),
},
sort_keys=True,
)
def _timeout(view_meta: ViewMeta) -> int | None:
if view_meta.cache_timeout is not None:
return view_meta.cache_timeout
config = current_app.config.get("DATA_CACHE_CONFIG") or {}
return config.get("CACHE_DEFAULT_TIMEOUT")
# ---------------------------------------------------------------------------
# Containment
# ---------------------------------------------------------------------------
def can_satisfy(  # noqa: C901
    entry: CachedEntry,
    query: SemanticQuery,
) -> tuple[bool, set[Filter], bool]:
    """
    Return ``(reusable, leftover_filters, projection_needed)`` for ``entry`` vs
    ``query``.

    ``leftover_filters`` are new WHERE filters that the cached filters do not
    already guarantee; they must be applied to the cached rows locally.
    ``projection_needed`` is True when the cached entry has a strict superset
    of the new dimensions and a pandas rollup is required. A rejected entry
    yields ``(False, set(), False)``.

    Outside the projection path, two extra rules apply:

    * A truncated cached result (``limit`` or group limit) is never combined
      with leftover filters. Rows that pass the narrower filter may lie beyond
      the cached cut-off.
    * The cached ordering must match whenever the new query is ordered (or the
      entry is truncated). The cached frame can be returned without any local
      re-sort.
    """
    new_dim_keys = frozenset(_dimension_key(d) for d in query.dimensions)
    cached_dim_keys = entry.dimension_keys
    if cached_dim_keys == new_dim_keys:
        projection_needed = False
    elif cached_dim_keys > new_dim_keys:
        projection_needed = True
        if not _projection_allowed(entry, query, new_dim_keys, cached_dim_keys):
            return False, set(), False
    else:
        return False, set(), False

    new_filters = frozenset(query.filters or set())
    c_adhoc, c_having, c_where = _split(entry.filters)
    n_adhoc, n_having, n_where = _split(new_filters)
    # ADHOC and HAVING predicates are only reused on exact-set equality.
    if c_adhoc != n_adhoc or c_having != n_having:
        return False, set(), False

    c_by_col = _group_by_column(c_where)
    n_by_col = _group_by_column(n_where)

    # Every cached WHERE filter must be implied by some new filter on the same
    # column, otherwise the cached rows may be missing rows of the new query.
    for c_list in c_by_col.values():
        for c in c_list:
            n_list = n_by_col.get(_filter_col_id(c), [])
            if not any(_implies(n, c) for n in n_list):
                return False, set(), False

    # New filters that no cached filter already guarantees are "leftovers"
    # and have to be re-applied to the cached rows.
    leftovers: set[Filter] = set()
    for col_id, n_list in n_by_col.items():
        c_list = c_by_col.get(col_id, [])
        for n in n_list:
            if any(_implies(c, n) for c in c_list):
                continue
            if n.column is None or n.operator == Operator.ADHOC:
                return False, set(), False
            leftovers.add(n)

    # Leftover filters are applied to the cached DataFrame BEFORE the optional
    # rollup, so their columns must be present in the cached projection.
    allowed_ids = _cached_column_ids(entry, query)
    for leftover in leftovers:
        if leftover.column is None or leftover.column.id not in allowed_ids:
            return False, set(), False

    if entry.offset != 0 or (query.offset or 0) != 0:
        return False, set(), False

    if projection_needed:
        # Re-aggregation will re-order by ``query.order`` after rollup, so the
        # cached order is irrelevant. We do require the new order (if any) to
        # reference only surviving columns; otherwise sort would fail post-rollup.
        if not _order_uses_only(query.order, _projection_ids(query)):
            return False, set(), False
        return True, leftovers, projection_needed

    cached_truncated = entry.limit is not None or bool(entry.group_limit_key)
    if cached_truncated and leftovers:
        # Filtering a truncated slice does not yield the truncated result of
        # the filtered query: qualifying rows past the cut-off are missing.
        return False, set(), False
    if entry.limit is not None and (
        query.limit is None or query.limit > entry.limit
    ):
        return False, set(), False
    if (entry.limit is not None or query.order) and entry.order_key != _order_key(
        query.order
    ):
        # A truncated entry is only the true top-N under the same ordering, and
        # a cached frame served untouched must already be in the requested order.
        return False, set(), False
    if entry.group_limit_key != _group_limit_key(query.group_limit):
        return False, set(), False
    return True, leftovers, projection_needed
def _projection_allowed(
entry: CachedEntry,
query: SemanticQuery,
new_dim_keys: frozenset[str],
cached_dim_keys: frozenset[str],
) -> bool:
"""
Gates for the projection path (above and beyond filter containment).
"""
if any(m.aggregation not in ADDITIVE_AGGREGATIONS for m in query.metrics):
return False
# Cached truncation makes the rollup unsafe (we're missing rows).
if entry.limit is not None:
return False
if entry.group_limit_key:
return False
if query.group_limit is not None:
return False
# Cached HAVING dropped sub-aggregate rows; the rolled-up totals would be
# off. Conservative: skip the projection path when cached has any HAVING.
if any(f.type == PredicateType.HAVING for f in entry.filters):
return False
return True
def _filter_col_id(f: Filter) -> str | None:
return f.column.id if f.column is not None else None
def _order_uses_only(
order: list[OrderTuple] | None,
allowed_ids: set[str],
) -> bool:
if not order:
return True
return all(_orderable_id(element) in allowed_ids for element, _ in order)
def _split(
filters: Iterable[Filter],
) -> tuple[frozenset[Filter], frozenset[Filter], frozenset[Filter]]:
adhoc: set[Filter] = set()
having: set[Filter] = set()
where: set[Filter] = set()
for f in filters:
if f.operator == Operator.ADHOC:
adhoc.add(f)
elif f.type == PredicateType.HAVING:
having.add(f)
else:
where.add(f)
return frozenset(adhoc), frozenset(having), frozenset(where)
def _group_by_column(filters: Iterable[Filter]) -> dict[str | None, list[Filter]]:
out: dict[str | None, list[Filter]] = {}
for f in filters:
col_id = f.column.id if f.column is not None else None
out.setdefault(col_id, []).append(f)
return out
def _projection_ids(query: SemanticQuery) -> set[str]:
return {d.id for d in query.dimensions} | {m.id for m in query.metrics}
def _cached_column_ids(entry: CachedEntry, query: SemanticQuery) -> set[str]:
"""Column ids available in the cached DataFrame (cached dims + shared metrics)."""
cached_dim_ids = {key.rsplit("@", 1)[0] for key in entry.dimension_keys}
return cached_dim_ids | {m.id for m in query.metrics}
# ---------------------------------------------------------------------------
# Pairwise implication
# ---------------------------------------------------------------------------
# pylint: disable=too-many-return-statements,too-many-branches
def _implies(new: Filter, cached: Filter) -> bool: # noqa: C901
"""True iff every row matching ``new`` also matches ``cached``.
Both filters are assumed to be on the same column (caller groups by column).
"""
if new == cached:
return True
nop, nval = new.operator, new.value
cop, cval = cached.operator, cached.value
if cop == Operator.IS_NULL:
if nop == Operator.IS_NULL:
return True
if nop == Operator.EQUALS and nval is None:
return True
return False
if cop == Operator.IS_NOT_NULL:
if nop == Operator.IS_NOT_NULL:
return True
if nop == Operator.EQUALS:
return nval is not None
if nop in _RANGE_OPS:
return True
if nop == Operator.IN:
return isinstance(nval, frozenset) and all(v is not None for v in nval)
return False
if cop == Operator.EQUALS:
if nop == Operator.EQUALS:
return nval == cval
if nop == Operator.IN and isinstance(nval, frozenset):
return nval == frozenset({cval})
return False
if cop == Operator.NOT_EQUALS:
if nop == Operator.NOT_EQUALS:
return nval == cval
if nop == Operator.EQUALS:
return nval != cval
if nop == Operator.IN and isinstance(nval, frozenset):
return cval not in nval
return False
if cop == Operator.IN and isinstance(cval, frozenset):
if nop == Operator.IN and isinstance(nval, frozenset):
return nval.issubset(cval)
if nop == Operator.EQUALS:
return nval in cval
return False
if cop == Operator.NOT_IN and isinstance(cval, frozenset):
if nop == Operator.NOT_IN and isinstance(nval, frozenset):
return cval.issubset(nval)
if nop == Operator.NOT_EQUALS:
return cval.issubset({nval})
if nop == Operator.EQUALS:
return nval not in cval
if nop == Operator.IN and isinstance(nval, frozenset):
return cval.isdisjoint(nval)
return False
if cop in _RANGE_OPS:
return _implies_range(nop, nval, cop, cval)
# LIKE / NOT_LIKE / ADHOC: only the exact-match path at the top.
return False
# Ordering comparisons handled by ``_implies_range``. ``_implies`` references
# this name even though it is defined later in the module; that works because
# the name is only looked up when ``_implies`` runs.
_RANGE_OPS = frozenset(
    {
        Operator.GREATER_THAN,
        Operator.GREATER_THAN_OR_EQUAL,
        Operator.LESS_THAN,
        Operator.LESS_THAN_OR_EQUAL,
    }
)
def _implies_range(  # noqa: C901
    nop: Operator,
    nval: Any,
    cop: Operator,
    cval: Any,
) -> bool:
    """
    Implication check when the cached filter is an ordering comparison.

    * A new ``IN`` set implies the cached bound when every member satisfies it.
    * A new ``EQUALS`` implies it when its value satisfies the bound.
    * A new bound implies it only when both bound the same side (both lower
      or both upper) and the new bound is at least as tight. The comparison
      has to be strict only when the cached bound is strict (``>``/``<``) and
      the new one is inclusive (``>=``/``<=``).
    """
    if isinstance(nval, frozenset):
        if nop != Operator.IN:
            return False
        return all(_scalar_in_range(member, cop, cval) for member in nval)
    if nop == Operator.EQUALS:
        return _scalar_in_range(nval, cop, cval)
    if nop not in _RANGE_OPS or cop not in _RANGE_OPS:
        return False
    if not _comparable(nval, cval):
        return False

    lower_ops = (Operator.GREATER_THAN, Operator.GREATER_THAN_OR_EQUAL)
    cached_lower = cop in lower_ops
    if cached_lower != (nop in lower_ops):
        return False

    strict = cop in (Operator.GREATER_THAN, Operator.LESS_THAN) and nop in (
        Operator.GREATER_THAN_OR_EQUAL,
        Operator.LESS_THAN_OR_EQUAL,
    )
    if cached_lower:
        return nval > cval if strict else nval >= cval
    return nval < cval if strict else nval <= cval
def _scalar_in_range(value: Any, cop: Operator, cval: Any) -> bool:
    """True when ``value <cop> cval`` holds for an ordering operator ``cop``.

    Returns False for non-comparable operands and non-ordering operators.
    """
    if not _comparable(value, cval):
        return False
    if cop in (Operator.GREATER_THAN, Operator.GREATER_THAN_OR_EQUAL):
        return value > cval if cop == Operator.GREATER_THAN else value >= cval
    if cop in (Operator.LESS_THAN, Operator.LESS_THAN_OR_EQUAL):
        return value < cval if cop == Operator.LESS_THAN else value <= cval
    return False
def _comparable(a: Any, b: Any) -> bool:
if a is None or b is None:
return False
if isinstance(a, bool) or isinstance(b, bool):
return isinstance(a, bool) and isinstance(b, bool)
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return True
if isinstance(a, str) and isinstance(b, str):
return True
if isinstance(a, (datetime, date, time)) and isinstance(b, type(a)):
return True
if isinstance(a, type(b)) and isinstance(a, (datetime, date, time, timedelta)):
return True
return type(a) == type(b) # noqa: E721
# ---------------------------------------------------------------------------
# Post-processing
# ---------------------------------------------------------------------------
def _apply_post_processing(
cached: SemanticResult,
query: SemanticQuery,
leftovers: set[Filter],
projection_needed: bool,
) -> SemanticResult:
"""Apply leftover filters, projection (re-aggregation), order, and limit."""
if not leftovers and not projection_needed and query.limit is None:
return cached
df = cached.results.to_pandas()
if leftovers:
mask = pd.Series(True, index=df.index)
for f in leftovers:
mask &= _mask_for(df, f)
df = df[mask]
note_def = "Served from semantic view smart cache (post-processed locally)"
if projection_needed:
groupby = [d.name for d in query.dimensions]
aggregates = {
m.name: {
"column": m.name,
"operator": _AGGREGATION_TO_PANDAS[m.aggregation],
}
for m in query.metrics
}
df = aggregate(df, groupby=groupby, aggregates=aggregates)
note_def = "Served from semantic view smart cache (re-aggregated locally)"
df = _apply_order(df, query.order)
if query.limit is not None:
df = df.head(query.limit)
table = pa.Table.from_pandas(df, preserve_index=False)
note = SemanticRequest(type="cache", definition=note_def)
return SemanticResult(requests=list(cached.requests) + [note], results=table)
def _apply_order(
df: pd.DataFrame,
order: list[OrderTuple] | None,
) -> pd.DataFrame:
if not order:
return df
available: list[tuple[str, bool]] = []
for element, direction in order:
col = _orderable_id_name(element)
if col in df.columns:
available.append((col, direction == OrderDirection.ASC))
if not available:
return df
cols = [col for col, _ in available]
asc = [a for _, a in available]
return df.sort_values(by=cols, ascending=asc).reset_index(drop=True)
def _orderable_id_name(element: Metric | Dimension | AdhocExpression) -> str:
return getattr(element, "name", element.id)
def _mask_for(df: pd.DataFrame, f: Filter) -> pd.Series: # noqa: C901
if f.column is None:
return pd.Series(True, index=df.index)
series = df[f.column.name]
op = f.operator
val = f.value
if op == Operator.EQUALS:
return series == val if val is not None else series.isna()
if op == Operator.NOT_EQUALS:
return series != val if val is not None else series.notna()
if op == Operator.GREATER_THAN:
return series > val
if op == Operator.GREATER_THAN_OR_EQUAL:
return series >= val
if op == Operator.LESS_THAN:
return series < val
if op == Operator.LESS_THAN_OR_EQUAL:
return series <= val
if op == Operator.IN:
return series.isin(list(val) if isinstance(val, frozenset) else [val])
if op == Operator.NOT_IN:
return ~series.isin(list(val) if isinstance(val, frozenset) else [val])
if op == Operator.IS_NULL:
return series.isna()
if op == Operator.IS_NOT_NULL:
return series.notna()
if op == Operator.LIKE:
return series.astype(str).str.match(_sql_like_to_regex(str(val)))
if op == Operator.NOT_LIKE:
return ~series.astype(str).str.match(_sql_like_to_regex(str(val)))
return pd.Series(True, index=df.index)
def _sql_like_to_regex(pattern: str) -> str:
out = []
for ch in pattern:
if ch == "%":
out.append(".*")
elif ch == "_":
out.append(".")
else:
out.append(re.escape(ch))
return f"^{''.join(out)}$"