Compare commits

...

8 Commits

Author SHA1 Message Date
Beto Dealmeida
23889a8d27 Fix lint 2026-05-14 18:19:38 -04:00
Beto Dealmeida
aaa87d79c2 Add indicator in Explore 2026-05-14 17:58:27 -04:00
Beto Dealmeida
520401e23d Fixes 2026-05-14 14:42:48 -04:00
Beto Dealmeida
dca18116ae Improvements 2026-05-14 11:22:29 -04:00
Beto Dealmeida
209b44522d Leverage additive metrics 2026-05-14 11:22:29 -04:00
Beto Dealmeida
861ed4aa45 feat(semantic layers): cache 2026-05-14 11:22:29 -04:00
Beto Dealmeida
137c9fca6d Improvements 2026-05-14 11:18:33 -04:00
Beto Dealmeida
671eed7863 feat(semantic layers): form for semantic layer with single semantic view 2026-05-14 11:02:33 -04:00
29 changed files with 2163 additions and 29 deletions

View File

@@ -92,6 +92,26 @@ class Dimension:
grain: Grain | None = None
class AggregationType(str, enum.Enum):
    """
    Aggregation function applied by a metric.

    Additivity (across an arbitrary set of grouping dimensions):

    * ``SUM``, ``COUNT``: fully additive — sub-group sums roll up via ``sum``.
    * ``MIN``, ``MAX``: roll up via ``min`` / ``max`` of sub-group values.
    * ``AVG``, ``COUNT_DISTINCT``, ``OTHER``: not safely roll-uppable from
      sub-aggregates without auxiliary data.
    """

    SUM = "SUM"
    COUNT = "COUNT"
    MIN = "MIN"
    MAX = "MAX"
    AVG = "AVG"
    COUNT_DISTINCT = "COUNT_DISTINCT"
    # Catch-all for aggregations not covered above; treated as non-additive.
    OTHER = "OTHER"
@dataclass(frozen=True)
class Metric:
id: str
@@ -100,6 +120,7 @@ class Metric:
definition: str
description: str | None = None
aggregation: AggregationType | None = None
@dataclass(frozen=True)

View File

@@ -80,7 +80,7 @@ const restrictedImportsRules = {
'no-jest-mock-console': {
name: 'jest-mock-console',
message: 'Please use native Jest spies, i.e. jest.spyOn(console, "warn")',
}
},
};
module.exports = {

View File

@@ -36,3 +36,14 @@ test('Rendering TooltipContent correctly - with timestep', () => {
.fromNow()}. Click to force-refresh`,
);
});
// A semantic smart-cache hit renders dedicated copy ("Loaded from semantic
// smart cache") instead of the generic query-cache message.
test('Rendering TooltipContent correctly - semantic cache', () => {
  render(
    <TooltipContent cacheSource="semantic" cachedTimestamp="01-01-2000" />,
  );
  expect(screen.getByTestId('tooltip-content')?.textContent).toBe(
    `Loaded from semantic smart cache ${extendedDayjs
      .utc('01-01-2000')
      .fromNow()}. Click to force-refresh`,
  );
});

View File

@@ -23,15 +23,28 @@ import { extendedDayjs } from '../../utils/dates';
interface Props {
cachedTimestamp?: string;
cacheSource?: 'query' | 'semantic';
}
export const TooltipContent: FC<Props> = ({ cachedTimestamp }) => {
export const TooltipContent: FC<Props> = ({
cachedTimestamp,
cacheSource = 'query',
}) => {
const loadedFromText =
cacheSource === 'semantic'
? t('Loaded from semantic smart cache')
: t('Loaded data cached');
const loadedFallbackText =
cacheSource === 'semantic'
? t('Loaded from semantic smart cache')
: t('Loaded from cache');
const cachedText = cachedTimestamp ? (
<span>
{t('Loaded data cached')}
{loadedFromText}
<b> {extendedDayjs.utc(cachedTimestamp).fromNow()}</b>
</span>
) : (
t('Loaded from cache')
loadedFallbackText
);
return (

View File

@@ -29,13 +29,19 @@ export const CachedLabel: FC<CacheLabelProps> = ({
className,
onClick,
cachedTimestamp,
cacheSource = 'query',
}) => {
const [hovered, setHovered] = useState(false);
const labelType = hovered ? 'info' : 'default';
return (
<Tooltip
title={<TooltipContent cachedTimestamp={cachedTimestamp} />}
title={
<TooltipContent
cachedTimestamp={cachedTimestamp}
cacheSource={cacheSource}
/>
}
id="cache-desc-tooltip"
>
<Label

View File

@@ -22,5 +22,6 @@ import type { MouseEventHandler } from 'react';
export interface CacheLabelProps {
  /** Click handler for the label (used to force-refresh the cached query). */
  onClick?: MouseEventHandler<HTMLElement>;
  /** Timestamp of the cached payload; parsed as UTC by the tooltip. */
  cachedTimestamp?: string;
  /** Which cache produced the hit; the tooltip copy differs per source. */
  cacheSource?: 'query' | 'semantic';
  className?: string;
}

View File

@@ -519,7 +519,8 @@ const Select = forwardRef(
handleSelectAll();
}}
>
{t('Select all')} {`(${formatNumber('SMART_NUMBER', bulkSelectCounts.selectable)})`}
{t('Select all')}{' '}
{`(${formatNumber('SMART_NUMBER', bulkSelectCounts.selectable)})`}
</Button>
<Button
type="link"
@@ -536,7 +537,8 @@ const Select = forwardRef(
handleDeselectAll();
}}
>
{t('Clear')} {`(${formatNumber('SMART_NUMBER', bulkSelectCounts.deselectable)})`}
{t('Clear')}{' '}
{`(${formatNumber('SMART_NUMBER', bulkSelectCounts.deselectable)})`}
</Button>
</StyledBulkActionsContainer>
),

View File

@@ -60,6 +60,7 @@ export interface ChartDataResponseResult {
coltypes: GenericDataType[];
error: string | null;
is_cached: boolean;
semantic_cache_hit?: boolean | null;
query: string;
rowcount: number;
sql_rowcount: number;

View File

@@ -182,10 +182,7 @@ testWithAssets(
// Now track POST /api/v1/chart/data requests around Clear All
const postsAfterClearAll: string[] = [];
const handler = (req: any) => {
if (
req.url().includes('/api/v1/chart/data') &&
req.method() === 'POST'
) {
if (req.url().includes('/api/v1/chart/data') && req.method() === 'POST') {
postsAfterClearAll.push(req.url());
}
};

View File

@@ -288,9 +288,7 @@ describe('BigNumberWithTrendline transformProps', () => {
height: 300,
queriesData: [
{
data: [
{ __timestamp: 1, value: 100 },
] as unknown as BigNumberDatum[],
data: [{ __timestamp: 1, value: 100 }] as unknown as BigNumberDatum[],
colnames: ['__timestamp', 'value'],
coltypes: ['TEMPORAL', 'NUMERIC'],
},

View File

@@ -122,7 +122,9 @@ describe('getChartIdsFromLayout', () => {
hash: '',
standalone: DashboardStandaloneMode.HideNav,
});
expect(url).toBe(`/dashboard/1/?standalone=${DashboardStandaloneMode.HideNav}`);
expect(url).toBe(
`/dashboard/1/?standalone=${DashboardStandaloneMode.HideNav}`,
);
});
test('should process native filters key', () => {

View File

@@ -66,6 +66,9 @@ export const ChartPills = forwardRef(
) => {
const isLoading = chartStatus === 'loading';
const firstQueryResponse = queriesResponse?.[0];
const isQueryCached = Boolean(firstQueryResponse?.is_cached);
const isSemanticCached = Boolean(firstQueryResponse?.semantic_cache_hit);
const isAnyCacheHit = isQueryCached || isSemanticCached;
// For table charts with server pagination, check second query for total count
const isTableChart =
@@ -100,10 +103,15 @@ export const ChartPills = forwardRef(
limit={Number(rowLimit ?? 0)}
/>
)}
{!isLoading && firstQueryResponse?.is_cached && (
{!isLoading && isAnyCacheHit && (
<CachedLabel
onClick={refreshCachedQuery}
cachedTimestamp={firstQueryResponse.cached_dttm}
cachedTimestamp={
isQueryCached
? firstQueryResponse?.cached_dttm
: firstQueryResponse?.queried_dttm
}
cacheSource={isSemanticCached ? 'semantic' : 'query'}
/>
)}
<Timer

View File

@@ -168,6 +168,23 @@ describe('ChartContainer', () => {
expect(screen.queryByText(/cached/i)).not.toBeInTheDocument();
});
// A semantic-layer smart-cache hit sets semantic_cache_hit=true while
// is_cached stays false; the "cached" pill must still be shown so users
// can force-refresh.
test('should show cached button for semantic smart cache hit', async () => {
  const props = createProps({
    chart: {
      chartStatus: 'rendered',
      queriesResponse: [
        {
          is_cached: false,
          semantic_cache_hit: true,
          queried_dttm: '2026-01-01',
        },
      ],
    },
  });
  render(<ChartContainer {...props} />, { useRedux: true });
  expect(await screen.findByText(/cached/i)).toBeInTheDocument();
});
test('hides gutter when collapsing data panel', async () => {
const props = createProps();
setItem(LocalStorageKeys.IsDatapanelOpen, true);

View File

@@ -188,7 +188,9 @@ function CollectionControl({
// Two items can collide when keyAccessor returns falsy and the index
// fallback is used — breaking dnd-kit reordering and React reconciliation.
// Assign a stable nanoid per item ref when no key is available.
const generatedIdsRef = useRef<WeakMap<CollectionItem, string>>(new WeakMap());
const generatedIdsRef = useRef<WeakMap<CollectionItem, string>>(
new WeakMap(),
);
const itemIds = useMemo(
() =>
value.map(item => {

View File

@@ -255,22 +255,105 @@ const EnumNamesRenderer = withJsonFormsControlProps(EnumNamesControl);
const enumNamesEntry = {
// Rank 5: higher than the default string renderer (23) so this fires
// whenever x-enumNames is present, regardless of the underlying type.
// Array-of-enum schemas are handled by ``multiEnumEntry`` below — this
// renderer only targets scalar string/number controls.
tester: rankWith(
5,
schemaMatches(s => {
const names = (s as Record<string, unknown>)['x-enumNames'];
return Array.isArray(names) && (names as unknown[]).length > 0;
}),
and(
schemaMatches(s => {
const names = (s as Record<string, unknown>)['x-enumNames'];
return Array.isArray(names) && (names as unknown[]).length > 0;
}),
schemaMatches(s => (s as Record<string, unknown>)?.type !== 'array'),
),
),
renderer: EnumNamesRenderer,
};
/**
* Renderer for ``{type: 'array', items: {enum: [...]}}`` schemas. Renders
* a single Antd Select with ``mode="multiple"`` (tag-style multi-select),
* matching the natural expectation of a "pick several from a list" control.
*
* Without this, the default ``PrimitiveArrayControl`` from the upstream
* library renders an "Add …" button that creates one single-select per
* element — visually wrong for an enum multi-select and unable to display
* ``items.x-enumNames`` labels.
*
* The renderer is dynamic-aware: when the host form is refreshing the
* schema (e.g. compatible options narrowing as the user picks), the Select
* shows a loading indicator without becoming disabled, so the user can
* continue editing while options refresh.
*/
function MultiEnumControl(props: ControlProps) {
const { refreshingSchema } = props.config ?? {};
const arraySchema = props.schema as Record<string, unknown>;
const itemsSchema =
(arraySchema.items as Record<string, unknown>) ??
({} as Record<string, unknown>);
const enumValues = (itemsSchema.enum as unknown[]) ?? [];
const enumNames =
(itemsSchema['x-enumNames'] as string[]) ?? enumValues.map(String);
const options = enumValues.map((value, index) => ({
value: value as string | number,
label: enumNames[index] ?? String(value),
}));
const value = Array.isArray(props.data) ? (props.data as unknown[]) : [];
const tooltip = (props.uischema?.options as Record<string, unknown>)
?.tooltip as string | undefined;
return (
<Form.Item label={props.label} tooltip={tooltip}>
<Select
mode="multiple"
value={value as (string | number)[]}
onChange={next => props.handleChange(props.path, next)}
options={options}
style={{ width: '100%' }}
disabled={!props.enabled}
loading={!!refreshingSchema}
allowClear
optionFilterProp="label"
placeholder={
(props.uischema?.options as Record<string, unknown>)
?.placeholderText as string | undefined
}
/>
</Form.Item>
);
}
const MultiEnumRenderer = withJsonFormsControlProps(MultiEnumControl);

const multiEnumEntry = {
  // Rank 35 outranks the upstream ``PrimitiveArrayRenderer`` (rank 30), so
  // an ``array`` schema whose items carry a non-empty ``enum`` renders as a
  // single Antd multi-select tag box instead of the "Add" repeater pattern.
  tester: rankWith(
    35,
    schemaMatches(candidate => {
      const record = candidate as Record<string, unknown>;
      if (record?.type !== 'array') return false;
      const items = record.items as Record<string, unknown> | undefined;
      if (!items || !Array.isArray(items.enum)) return false;
      return (items.enum as unknown[]).length > 0;
    }),
  ),
  renderer: MultiEnumRenderer,
};
// Custom entries augment the stock renderer set; precedence among them is
// decided by each entry's tester rank, not by array position.
export const renderers = [
  ...rendererRegistryEntries,
  passwordEntry,
  constEntry,
  readOnlyEntry,
  enumNamesEntry,
  multiEnumEntry,
  dynamicFieldEntry,
];

View File

@@ -254,7 +254,9 @@ export default function AddSemanticViewModal({
!schema?.properties ||
Object.keys(schema.properties).length === 0
) {
// No runtime config needed — fetch views right away
// Preserve top-level runtime metadata (e.g. x-singleView) even when
// there are no form fields, then fetch views right away.
applyRuntimeSchema(schema);
fetchViews(uuid, {}, gen);
} else {
applyRuntimeSchema(schema);
@@ -456,6 +458,32 @@ export default function AddSemanticViewModal({
const viewsDisabled =
loadingViews || (!loadingViews && availableViews.length === 0);
// When ``x-singleView: true`` the runtime form fully describes a single
// semantic view (e.g. a MetricFlow cube). Hide the picker and auto-select
// whatever ``get_semantic_views`` returned so the Add button can fire
// without an extra user click.
const singleViewMode =
(runtimeSchema as Record<string, unknown> | null)?.['x-singleView'] ===
true;
useEffect(() => {
  if (!singleViewMode) return;
  // Deterministically pick at most one addable view: alphabetically first
  // among those not already added.
  const namesToAdd = availableViews
    .filter(v => !v.already_added)
    .map(v => v.name)
    .sort((a, b) => a.localeCompare(b))
    .slice(0, 1);
  setSelectedViewNames(prev => {
    // Return the same reference when the selection is unchanged so this
    // state update does not trigger a re-render loop.
    if (
      prev.length === namesToAdd.length &&
      prev.every((n, i) => n === namesToAdd[i])
    ) {
      return prev;
    }
    return namesToAdd;
  });
}, [singleViewMode, availableViews]);
return (
<StandardModal
show={show}
@@ -511,8 +539,12 @@ export default function AddSemanticViewModal({
</>
)}
{/* Semantic Views — always visible once a layer is selected */}
{selectedLayerUuid && !loadingRuntime && (
{/* Semantic Views — always visible once a layer is selected, unless
the runtime schema declares ``x-singleView: true``: extensions
(e.g. MetricFlow cubes) whose runtime form fully describes a
single view set that flag so the picker disappears and the
view is auto-selected when ``get_semantic_views`` returns it. */}
{selectedLayerUuid && !loadingRuntime && !singleViewMode && (
<ModalFormField label={t('Semantic Views')}>
<Select
ariaLabel={t('Semantic views')}

View File

@@ -33,7 +33,12 @@ import { ensureAppRoot } from '../utils/pathUtils';
import type { DashboardInfo, DashboardLayoutState } from '../dashboard/types';
import type { QueryEditor } from '../SqlLab/types';
type LogEventSource = 'dashboard' | 'embedded_dashboard' | 'explore' | 'sqlLab' | 'slice';
type LogEventSource =
| 'dashboard'
| 'embedded_dashboard'
| 'explore'
| 'sqlLab'
| 'slice';
interface LogEventData {
source?: LogEventSource;

View File

@@ -1497,6 +1497,11 @@ class ChartDataResponseResult(Schema):
required=True,
allow_none=None,
)
semantic_cache_hit = fields.Boolean(
metadata={"description": "Whether the semantic layer smart cache was used"},
required=False,
allow_none=True,
)
query = fields.String(
metadata={
"description": "The executed query statement. May be absent when "

View File

@@ -82,6 +82,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
result_type,
datasource=datasource,
server_pagination=server_pagination,
force_query=force,
**query_obj,
),
)

View File

@@ -202,6 +202,7 @@ class QueryContextProcessor:
"annotation_data": cache.annotation_data,
"error": cache.error_message,
"is_cached": cache.is_cached,
"semantic_cache_hit": getattr(cache, "semantic_cache_hit", None),
"query": cache.query,
"status": cache.status,
"stacktrace": cache.stacktrace,

View File

@@ -101,6 +101,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
result_type: ChartDataResultType | None
row_limit: int | None
row_offset: int
force_query: bool
series_columns: list[Column]
series_limit: int
series_limit_metric: Metric | None
@@ -128,6 +129,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
post_processing: list[dict[str, Any] | None] | None = None,
row_limit: int | None = None,
row_offset: int | None = None,
force_query: bool = False,
series_columns: list[Column] | None = None,
series_limit: int = 0,
series_limit_metric: Metric | None = None,
@@ -152,6 +154,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self._set_post_processing(post_processing)
self.row_limit = row_limit
self.row_offset = row_offset or 0
self.force_query = force_query
self._init_series_columns(series_columns, metrics, is_timeseries)
self.series_limit = series_limit
self.series_limit_metric = series_limit_metric
@@ -404,6 +407,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
"post_processing": self.post_processing,
"row_limit": self.row_limit,
"row_offset": self.row_offset,
"force_query": self.force_query,
"series_columns": self.series_columns,
"series_limit": self.series_limit,
"series_limit_metric": self.series_limit_metric,

View File

@@ -69,6 +69,7 @@ class QueryCacheManager:
cache_value: dict[str, Any] | None = None,
sql_rowcount: int | None = None,
queried_dttm: str | None = None,
semantic_cache_hit: bool | None = None,
) -> None:
self.df = df
self.query = query
@@ -86,6 +87,7 @@ class QueryCacheManager:
self.cache_value = cache_value
self.sql_rowcount = sql_rowcount
self.queried_dttm = queried_dttm
self.semantic_cache_hit = semantic_cache_hit
# pylint: disable=too-many-arguments
def set_query_result(
@@ -110,6 +112,7 @@ class QueryCacheManager:
self.error_message = query_result.error_message
self.df = query_result.df
self.sql_rowcount = query_result.sql_rowcount
self.semantic_cache_hit = query_result.semantic_cache_hit
self.annotation_data = {} if annotation_data is None else annotation_data
self.queried_dttm = (
datetime.now(tz=timezone.utc).replace(microsecond=0).isoformat()
@@ -131,6 +134,7 @@ class QueryCacheManager:
"rejected_filter_columns": self.rejected_filter_columns,
"annotation_data": self.annotation_data,
"sql_rowcount": self.sql_rowcount,
"semantic_cache_hit": self.semantic_cache_hit,
"queried_dttm": self.queried_dttm,
"dttm": self.queried_dttm, # Backwards compatibility
}
@@ -186,6 +190,9 @@ class QueryCacheManager:
query_cache.is_loaded = True
query_cache.is_cached = cache_value is not None
query_cache.sql_rowcount = cache_value.get("sql_rowcount", None)
query_cache.semantic_cache_hit = cache_value.get(
"semantic_cache_hit", None
)
query_cache.cache_dttm = (
cache_value["dttm"] if cache_value is not None else None
)

View File

@@ -673,6 +673,7 @@ class QueryResult: # pylint: disable=too-few-public-methods
errors: Optional[list[dict[str, Any]]] = None,
from_dttm: Optional[datetime] = None,
to_dttm: Optional[datetime] = None,
semantic_cache_hit: Optional[bool] = None,
) -> None:
self.df = df
self.query = query
@@ -685,6 +686,7 @@ class QueryResult: # pylint: disable=too-few-public-methods
self.errors = errors or []
self.from_dttm = from_dttm
self.to_dttm = to_dttm
self.semantic_cache_hit = semantic_cache_hit
self.sql_rowcount = len(self.df.index) if not self.df.empty else 0

View File

@@ -0,0 +1,725 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Containment-aware cache for semantic view queries.
A broader cached result can satisfy a narrower new query: when the new query's
filters and limit are strictly more restrictive than a cached entry's, the cached
DataFrame is post-filtered and re-limited rather than re-executing the underlying
query.
See ``docs/`` and the plan file for the design rationale; the rules summary:
* Same metrics and dimensions (shape).
* Each cached filter must be implied by a new-query filter on the same column.
* New filters on columns with no cached constraint are applied post-fetch as
"leftovers" — provided the column is in the projection.
* Cached ``limit`` must be at least the new ``limit``; if a cached ``limit`` is
present, the orderings must match (otherwise the cached "top N" is not the
true top of the new query).
* ``ADHOC`` and ``HAVING`` filters require exact-set equality.
* ``offset != 0`` and mismatching ``group_limit`` skip the cache.
"""
from __future__ import annotations
import logging
import re
import time as _time
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from typing import Any, Iterable
import pandas as pd
import pyarrow as pa
from flask import current_app
from superset_core.semantic_layers.types import (
AdhocExpression,
AggregationType,
Dimension,
Filter,
Metric,
Operator,
OrderDirection,
OrderTuple,
PredicateType,
SemanticQuery,
SemanticRequest,
SemanticResult,
)
from superset.extensions import cache_manager
from superset.utils import json
from superset.utils.hashing import hash_from_str
from superset.utils.pandas_postprocessing.aggregate import aggregate
logger = logging.getLogger(__name__)
# Key namespaces: per-shape descriptor indexes vs. stored result payloads.
INDEX_KEY_PREFIX = "sv:idx:"
VALUE_KEY_PREFIX = "sv:val:"
# Cap on descriptors kept per shape index; oldest (by timestamp) are evicted.
MAX_ENTRIES_PER_SHAPE = 32

# Pandas rollup function per additive aggregation. COUNT maps to "sum"
# because sub-group counts roll up by summing.
_AGGREGATION_TO_PANDAS: dict[AggregationType, str] = {
    AggregationType.SUM: "sum",
    AggregationType.COUNT: "sum",
    AggregationType.MIN: "min",
    AggregationType.MAX: "max",
}
# Aggregations eligible for the projection (rollup) path.
ADDITIVE_AGGREGATIONS = frozenset(_AGGREGATION_TO_PANDAS)
@dataclass(frozen=True)
class ViewMeta:
    """Identity/freshness/TTL info pulled from the SemanticView ORM row."""

    # Stable view identifier; namespaces every cache key.
    uuid: str
    # Last-modified timestamp; baked into keys so edits invalidate old entries.
    changed_on_iso: str
    # Per-view TTL override; None falls back to the data-cache default.
    cache_timeout: int | None
@dataclass(frozen=True)
class CachedEntry:
    """Index descriptor for one cached query payload (see ``store_result``)."""

    # Full filter set of the cached query (WHERE + HAVING + adhoc).
    filters: frozenset[Filter]
    # ``id@grain`` keys of the cached grouping dimensions.
    dimension_keys: frozenset[str]
    # LIMIT of the cached query; None means unlimited (never truncated).
    limit: int | None
    # OFFSET of the cached query; any non-zero value disables reuse.
    offset: int
    # Canonical JSON of the ORDER BY clause ("" when unordered).
    order_key: str
    # Canonical JSON of the group-limit clause ("" when absent).
    group_limit_key: str
    # Cache key of the stored ``SemanticResult`` payload.
    value_key: str
    # Insertion time, used to evict the oldest descriptors per shape.
    timestamp: float = field(default_factory=_time.time)
# ---------------------------------------------------------------------------
# Public surface
# ---------------------------------------------------------------------------
def try_serve_from_cache(
    view_meta: ViewMeta,
    query: SemanticQuery,
) -> SemanticResult | None:
    """Return a cached ``SemanticResult`` that satisfies ``query`` if any.

    Scans every descriptor under the query's shape key, serves the first one
    whose ``can_satisfy`` check passes, and opportunistically prunes
    descriptors whose value payload has been evicted. All failures are
    swallowed: a broken cache must never break the query path.
    """
    try:
        cache = cache_manager.data_cache
        idx_key = shape_key(view_meta, query)
        entries: list[CachedEntry] | None = cache.get(idx_key)
        if not entries:
            return None
        pruned: list[CachedEntry] = []
        served: SemanticResult | None = None
        for entry in entries:
            if served is None:
                ok, leftovers, projection_needed = can_satisfy(entry, query)
                if ok:
                    payload = cache.get(entry.value_key)
                    if payload is None:
                        # value evicted but index entry survived; drop it
                        continue
                    if projection_needed and not _projection_input_complete(
                        entry, payload
                    ):
                        # Cached result may be truncated (top-N). Keep the index
                        # entry alive but skip reuse for projection.
                        pruned.append(entry)
                        continue
                    pruned.append(entry)
                    served = _apply_post_processing(
                        payload, query, leftovers, projection_needed
                    )
                    continue
            # keep entry; verify its value is still alive
            if cache.get(entry.value_key) is not None:
                pruned.append(entry)
        if len(pruned) != len(entries):
            # Write back only when pruning actually removed something.
            cache.set(idx_key, pruned, timeout=_timeout(view_meta))
        return served
    except Exception:  # pragma: no cover - defensive
        logger.warning("Semantic view cache lookup failed", exc_info=True)
        return None
def store_result(
    view_meta: ViewMeta,
    query: SemanticQuery,
    result: SemanticResult,
) -> None:
    """Persist ``result`` under a fresh value key and register a descriptor.

    Defensive: any cache failure is logged and swallowed so a broken cache
    never breaks the query path.
    """
    try:
        data_cache = cache_manager.data_cache
        ttl = _timeout(view_meta)

        # Write the payload first so the index never points at a missing value.
        val_key = value_key(view_meta, query)
        data_cache.set(val_key, result, timeout=ttl)

        # Register (or refresh) the descriptor under the shape index,
        # replacing any previous descriptor for the same value key.
        index_key = shape_key(view_meta, query)
        descriptors: list[CachedEntry] = [
            existing
            for existing in (data_cache.get(index_key) or [])
            if existing.value_key != val_key
        ]
        descriptors.append(
            CachedEntry(
                filters=frozenset(query.filters or set()),
                dimension_keys=frozenset(
                    _dimension_key(d) for d in query.dimensions
                ),
                limit=query.limit,
                offset=query.offset or 0,
                order_key=_order_key(query.order),
                group_limit_key=_group_limit_key(query.group_limit),
                value_key=val_key,
            )
        )
        # Cap the index size, keeping the most recent descriptors.
        if len(descriptors) > MAX_ENTRIES_PER_SHAPE:
            descriptors.sort(key=lambda d: d.timestamp)
            descriptors = descriptors[-MAX_ENTRIES_PER_SHAPE:]
        data_cache.set(index_key, descriptors, timeout=ttl)
    except Exception:  # pragma: no cover - defensive
        logger.warning("Semantic view cache store failed", exc_info=True)
# ---------------------------------------------------------------------------
# Keys
# ---------------------------------------------------------------------------
def shape_key(view_meta: ViewMeta, query: SemanticQuery) -> str:
    """Index key bucketing cache entries by metric set only.

    Dimensions live on each ``CachedEntry`` so broader (dimension-superset)
    entries can still be discovered through the projection path.
    """
    metric_ids = sorted(metric.id for metric in query.metrics)
    digest = hash_from_str(json.dumps({"m": metric_ids}, sort_keys=True))[:16]
    return f"{INDEX_KEY_PREFIX}{view_meta.uuid}:{view_meta.changed_on_iso}:{digest}"
def value_key(view_meta: ViewMeta, query: SemanticQuery) -> str:
    """Key for a stored payload, derived from the full canonical query form."""
    canonical = json.dumps(_canonicalize(query), sort_keys=True)
    digest = hash_from_str(canonical)[:32]
    return f"{VALUE_KEY_PREFIX}{view_meta.uuid}:{view_meta.changed_on_iso}:{digest}"
def _dimension_key(dim: Dimension) -> str:
grain = dim.grain.representation if dim.grain else "_"
return f"{dim.id}@{grain}"
def _canonicalize(query: SemanticQuery) -> dict[str, Any]:
    """Order-insensitive, JSON-ready description of every query facet."""
    return {
        "m": sorted(metric.id for metric in query.metrics),
        "d": sorted(_dimension_key(dim) for dim in query.dimensions),
        "f": sorted(_filter_to_jsonable(flt) for flt in (query.filters or [])),
        "o": _order_key(query.order),
        "l": query.limit,
        "off": query.offset or 0,
        "gl": _group_limit_key(query.group_limit),
    }
def _filter_to_jsonable(f: Filter) -> str:
    """Deterministic JSON string for one filter (used in canonical keys)."""
    payload = {
        "t": f.type.value,
        "c": None if f.column is None else f.column.id,
        "o": f.operator.value,
        "v": _value_to_jsonable(f.value),
    }
    return json.dumps(payload, sort_keys=True)
def _value_to_jsonable(value: Any) -> Any:
if isinstance(value, frozenset):
return sorted(_value_to_jsonable(v) for v in value)
if isinstance(value, (datetime, date, time)):
return value.isoformat()
if isinstance(value, timedelta):
return value.total_seconds()
return value
def _order_key(order: list[OrderTuple] | None) -> str:
    """Canonical JSON for an ORDER BY list; empty string when unordered."""
    if not order:
        return ""
    pairs = [(element.id, direction.value) for element, direction in order]
    return json.dumps(pairs)
def _orderable_id(element: Metric | Dimension | AdhocExpression) -> str:
    """Uniform id accessor for anything that can appear in ORDER BY."""
    return element.id
def _group_limit_key(group_limit: Any) -> str:
if group_limit is None:
return ""
return json.dumps(
{
"dims": sorted(d.id for d in group_limit.dimensions),
"top": group_limit.top,
"metric": group_limit.metric.id if group_limit.metric else None,
"direction": group_limit.direction.value,
"group_others": group_limit.group_others,
"filters": sorted(
_filter_to_jsonable(f) for f in (group_limit.filters or [])
),
},
sort_keys=True,
)
def _timeout(view_meta: ViewMeta) -> int | None:
    """Effective cache TTL: the per-view override, else the data-cache default."""
    if view_meta.cache_timeout is not None:
        return view_meta.cache_timeout
    data_cache_config = current_app.config.get("DATA_CACHE_CONFIG") or {}
    return data_cache_config.get("CACHE_DEFAULT_TIMEOUT")
# ---------------------------------------------------------------------------
# Containment
# ---------------------------------------------------------------------------
def can_satisfy(  # noqa: C901
    entry: CachedEntry,
    query: SemanticQuery,
) -> tuple[bool, set[Filter], bool]:
    """
    Return ``(reusable, leftover_filters, projection_needed)`` for ``entry`` vs
    ``query``. ``projection_needed`` is True when the cached entry has a strict
    superset of the new dimensions and a pandas rollup is required.

    BUGFIX: in the non-projection path, a limit-truncated cached "top N" is
    never reused when leftover post-filters exist — filtering a truncated
    result could silently drop qualifying rows ranked beyond N. (The
    projection path is guarded separately by ``_projection_input_complete``.)
    """
    # Dimension shape: equal → direct reuse; strict superset → rollup path.
    new_dim_keys = frozenset(_dimension_key(d) for d in query.dimensions)
    cached_dim_keys = entry.dimension_keys
    if cached_dim_keys == new_dim_keys:
        projection_needed = False
    elif cached_dim_keys > new_dim_keys:
        projection_needed = True
        if not _projection_allowed(entry, query):
            return False, set(), False
    else:
        return False, set(), False

    new_filters = frozenset(query.filters or set())
    c_adhoc, c_having, c_where = _split(entry.filters)
    n_adhoc, n_having, n_where = _split(new_filters)
    # ADHOC and HAVING predicates require exact-set equality.
    if c_adhoc != n_adhoc:
        return False, set(), False
    if c_having != n_having:
        return False, set(), False

    # Every cached WHERE must be implied by some new filter on the same column,
    # otherwise the cached result excluded rows the new query needs.
    c_by_col = _group_by_column(c_where)
    n_by_col = _group_by_column(n_where)
    for c_list in c_by_col.values():
        for c in c_list:
            n_list = n_by_col.get(_filter_col_id(c), [])
            if not any(_implies(n, c) for n in n_list):
                return False, set(), False

    # New filters not implied by a cached one become "leftovers" applied
    # post-fetch; column-less or adhoc leftovers cannot be applied in pandas.
    leftovers: set[Filter] = set()
    for col_id, n_list in n_by_col.items():
        c_list = c_by_col.get(col_id, [])
        for n in n_list:
            if not any(_implies(c, n) for c in c_list):
                if n.column is None or n.operator == Operator.ADHOC:
                    return False, set(), False
                leftovers.add(n)
    # Leftover filters are applied to the cached DataFrame BEFORE the optional
    # rollup, so their columns must be present in the cached projection.
    allowed_ids = _cached_column_ids(entry, query)
    for leftover in leftovers:
        if leftover.column is None or leftover.column.id not in allowed_ids:
            return False, set(), False

    # Pagination offsets are never reused.
    if entry.offset != 0 or (query.offset or 0) != 0:
        return False, set(), False

    if projection_needed:
        # Re-aggregation will re-order by ``query.order`` after rollup, so the
        # cached order is irrelevant. We do require the new order (if any) to
        # reference only surviving columns; otherwise sort would fail post-rollup.
        if not _order_uses_only(query.order, _projection_ids(query)):
            return False, set(), False
    else:
        if entry.limit is not None:
            # A possibly-truncated top-N cannot be post-filtered: qualifying
            # rows ranked beyond N would be missing from the served result.
            if leftovers:
                return False, set(), False
            if query.limit is None or query.limit > entry.limit:
                return False, set(), False
            if entry.order_key != _order_key(query.order):
                return False, set(), False

    # Group-limit semantics must match exactly (the projection gate already
    # forces both sides empty on the rollup path).
    if entry.group_limit_key != _group_limit_key(query.group_limit):
        return False, set(), False
    return True, leftovers, projection_needed
def _projection_allowed(
    entry: CachedEntry,
    query: SemanticQuery,
) -> bool:
    """
    Gates for the projection (rollup) path, above filter containment.

    Rollup is only sound when every metric can be re-aggregated from
    sub-group values, and when neither side carries group-limit semantics.
    """
    non_additive = any(
        metric.aggregation not in ADDITIVE_AGGREGATIONS
        for metric in query.metrics
    )
    if non_additive:
        return False
    if entry.group_limit_key or query.group_limit is not None:
        return False
    # Cached HAVING dropped sub-aggregate rows; the rolled-up totals would be
    # off. Conservative: skip the projection path when cached has any HAVING.
    cached_has_having = any(
        flt.type == PredicateType.HAVING for flt in entry.filters
    )
    return not cached_has_having
def _projection_input_complete(entry: CachedEntry, payload: SemanticResult) -> bool:
"""
True when a projection source is guaranteed not to be limit-truncated.
If a cached query had ``limit=N`` and returned exactly ``N`` rows, there might
be additional source rows that were cut off. We only reuse it for projection
when the payload row count is strictly less than ``N``.
"""
if entry.limit is None:
return True
return payload.results.num_rows < entry.limit
def _filter_col_id(f: Filter) -> str | None:
return f.column.id if f.column is not None else None
def _order_uses_only(
    order: list[OrderTuple] | None,
    allowed_ids: set[str],
) -> bool:
    """True when every ORDER BY element references an allowed column id."""
    if not order:
        return True
    for element, _direction in order:
        if element.id not in allowed_ids:
            return False
    return True
def _split(
    filters: Iterable[Filter],
) -> tuple[frozenset[Filter], frozenset[Filter], frozenset[Filter]]:
    """Partition filters into ``(adhoc, having, where)`` frozensets."""
    adhoc: set[Filter] = set()
    having: set[Filter] = set()
    where: set[Filter] = set()
    for candidate in filters:
        if candidate.operator == Operator.ADHOC:
            bucket = adhoc
        elif candidate.type == PredicateType.HAVING:
            bucket = having
        else:
            bucket = where
        bucket.add(candidate)
    return frozenset(adhoc), frozenset(having), frozenset(where)
def _group_by_column(filters: Iterable[Filter]) -> dict[str | None, list[Filter]]:
out: dict[str | None, list[Filter]] = {}
for f in filters:
col_id = f.column.id if f.column is not None else None
out.setdefault(col_id, []).append(f)
return out
def _projection_ids(query: SemanticQuery) -> set[str]:
return {d.id for d in query.dimensions} | {m.id for m in query.metrics}
def _cached_column_ids(entry: CachedEntry, query: SemanticQuery) -> set[str]:
"""Column ids available in the cached DataFrame (cached dims + shared metrics)."""
cached_dim_ids = {key.rsplit("@", 1)[0] for key in entry.dimension_keys}
return cached_dim_ids | {m.id for m in query.metrics}
# ---------------------------------------------------------------------------
# Pairwise implication
# ---------------------------------------------------------------------------
# pylint: disable=too-many-return-statements,too-many-branches
def _implies(new: Filter, cached: Filter) -> bool:  # noqa: C901
    """True iff every row matching ``new`` also matches ``cached``.
    Both filters are assumed to be on the same column (caller groups by column).
    The check is conservative: any pair it cannot reason about yields ``False``,
    which simply forces a cache miss rather than a wrong answer.
    """
    # Identical filters trivially imply each other (this is the only path for
    # LIKE / NOT_LIKE / ADHOC, which are otherwise opaque).
    if new == cached:
        return True
    nop, nval = new.operator, new.value
    cop, cval = cached.operator, cached.value
    # cached: IS NULL — implied only by IS NULL or an explicit "= NULL".
    if cop == Operator.IS_NULL:
        if nop == Operator.IS_NULL:
            return True
        if nop == Operator.EQUALS and nval is None:
            return True
        return False
    # cached: IS NOT NULL — implied by any predicate that cannot select NULLs.
    if cop == Operator.IS_NOT_NULL:
        if nop == Operator.IS_NOT_NULL:
            return True
        if nop == Operator.EQUALS:
            return nval is not None
        if nop in _RANGE_OPS:
            # Range comparisons exclude NULL rows by construction.
            return True
        if nop == Operator.IN:
            # IN excludes NULLs only when no member of the set is None.
            return isinstance(nval, frozenset) and all(v is not None for v in nval)
        return False
    # cached: = cval — implied by the same equality or IN over exactly {cval}.
    if cop == Operator.EQUALS:
        if nop == Operator.EQUALS:
            return nval == cval
        if nop == Operator.IN and isinstance(nval, frozenset):
            return nval == frozenset({cval})
        return False
    # cached: != cval — implied by predicates that can never select cval.
    if cop == Operator.NOT_EQUALS:
        if nop == Operator.NOT_EQUALS:
            return nval == cval
        if nop == Operator.EQUALS:
            return nval != cval
        if nop == Operator.IN and isinstance(nval, frozenset):
            return cval not in nval
        return False
    # cached: IN cval — implied by a subset IN, or equality to a member.
    if cop == Operator.IN and isinstance(cval, frozenset):
        if nop == Operator.IN and isinstance(nval, frozenset):
            return nval.issubset(cval)
        if nop == Operator.EQUALS:
            return nval in cval
        return False
    # cached: NOT IN cval — implied when the new predicate excludes all of cval.
    if cop == Operator.NOT_IN and isinstance(cval, frozenset):
        if nop == Operator.NOT_IN and isinstance(nval, frozenset):
            return cval.issubset(nval)
        if nop == Operator.NOT_EQUALS:
            # Single exclusion covers cval only when cval ⊆ {nval}.
            return cval.issubset({nval})
        if nop == Operator.EQUALS:
            return nval not in cval
        if nop == Operator.IN and isinstance(nval, frozenset):
            return cval.isdisjoint(nval)
        return False
    # cached: one-sided range bound — delegate to the range helper.
    if cop in _RANGE_OPS:
        return _implies_range(nop, nval, cop, cval)
    # LIKE / NOT_LIKE / ADHOC: only the exact-match path at the top.
    return False
# Operators expressing a one-sided scalar bound; see _implies_range.
_RANGE_OPS = frozenset(
    {
        Operator.GREATER_THAN,
        Operator.GREATER_THAN_OR_EQUAL,
        Operator.LESS_THAN,
        Operator.LESS_THAN_OR_EQUAL,
    }
)
def _implies_range(
    nop: Operator,
    nval: Any,
    cop: Operator,
    cval: Any,
) -> bool:
    """True iff the new predicate (``nop nval``) implies the cached one-sided
    range bound (``cop cval``).

    Same truth table as before, expressed compactly: within one direction the
    only strict case is a strict cached bound met by an inclusive new bound.
    """
    if isinstance(nval, frozenset):
        # An IN-set implies the bound when every member satisfies it.
        return nop == Operator.IN and all(_scalar_in_range(v, cop, cval) for v in nval)
    if nop == Operator.EQUALS:
        return _scalar_in_range(nval, cop, cval)
    if nop not in _RANGE_OPS or not _comparable(nval, cval):
        return False
    lower_ops = (Operator.GREATER_THAN, Operator.GREATER_THAN_OR_EQUAL)
    cached_is_lower = cop in lower_ops
    if cached_is_lower != (nop in lower_ops):
        # Mixed directions (one upper, one lower bound) never imply.
        return False
    if cached_is_lower:
        # cached: a > cval / a >= cval; new: a > nval / a >= nval.
        # Only ">= nval" vs strict "> cval" needs nval strictly above cval.
        if cop == Operator.GREATER_THAN and nop == Operator.GREATER_THAN_OR_EQUAL:
            return nval > cval
        return nval >= cval
    # Upper bounds: mirror image of the lower-bound rule.
    if cop == Operator.LESS_THAN and nop == Operator.LESS_THAN_OR_EQUAL:
        return nval < cval
    return nval <= cval
def _scalar_in_range(value: Any, cop: Operator, cval: Any) -> bool:
    """Whether scalar ``value`` satisfies the one-sided bound ``cop cval``."""
    if not _comparable(value, cval):
        return False
    if cop in (Operator.GREATER_THAN, Operator.GREATER_THAN_OR_EQUAL):
        return value > cval if cop == Operator.GREATER_THAN else value >= cval
    if cop in (Operator.LESS_THAN, Operator.LESS_THAN_OR_EQUAL):
        return value < cval if cop == Operator.LESS_THAN else value <= cval
    return False
def _comparable(a: Any, b: Any) -> bool:
    """Best-effort check that ``a`` and ``b`` can be safely order-compared."""
    if a is None or b is None:
        return False
    # bool is a subclass of int — allow bool-vs-bool only, never bool-vs-number.
    if isinstance(a, bool) or isinstance(b, bool):
        return isinstance(a, bool) and isinstance(b, bool)
    # Numbers and strings mix freely within their own families.
    if isinstance(a, (int, float)) and isinstance(b, (int, float)):
        return True
    if isinstance(a, str) and isinstance(b, str):
        return True
    # Temporal values: require b to be an instance of a's class (or subclass),
    # so e.g. a plain date never compares against a time.
    if isinstance(a, (datetime, date, time)) and isinstance(b, type(a)):
        return True
    # Reversed subclass direction, and additionally admits timedelta pairs.
    if isinstance(a, type(b)) and isinstance(a, (datetime, date, time, timedelta)):
        return True
    # Fallback: identical concrete types are assumed comparable.
    return type(a) == type(b)  # noqa: E721
# ---------------------------------------------------------------------------
# Post-processing
# ---------------------------------------------------------------------------
def _apply_post_processing(
    cached: SemanticResult,
    query: SemanticQuery,
    leftovers: set[Filter],
    projection_needed: bool,
) -> SemanticResult:
    """Apply leftover filters, projection (re-aggregation), order, and limit.

    Returns ``cached`` unchanged when there is nothing to do; otherwise builds
    a new result and appends a "cache" request marker describing the local
    work performed.
    """
    if not leftovers and not projection_needed and query.limit is None:
        return cached
    df = cached.results.to_pandas()
    if leftovers:
        # AND together the masks of all filters not covered by the cached query.
        mask = pd.Series(True, index=df.index)
        for f in leftovers:
            mask &= _mask_for(df, f)
        df = df[mask]
    note_def = "Served from semantic view smart cache (post-processed locally)"
    if projection_needed:
        # Re-aggregate down to the requested dimensions. Metrics without a
        # declared aggregation are skipped here — presumably can_satisfy
        # refuses projection in that case; confirm before relying on it.
        groupby = [d.name for d in query.dimensions]
        aggregates: dict[str, dict[str, str]] = {}
        for m in query.metrics:
            if m.aggregation is None:
                continue
            aggregates[m.name] = {
                "column": m.name,
                "operator": _AGGREGATION_TO_PANDAS[m.aggregation],
            }
        df = aggregate(df, groupby=groupby, aggregates=aggregates)
        note_def = "Served from semantic view smart cache (re-aggregated locally)"
    # Re-apply ordering (filtering/aggregation may have disturbed it), then limit.
    df = _apply_order(df, query.order)
    if query.limit is not None:
        df = df.head(query.limit)
    table = pa.Table.from_pandas(df, preserve_index=False)
    note = SemanticRequest(type="cache", definition=note_def)
    return SemanticResult(requests=list(cached.requests) + [note], results=table)
def _apply_order(
    df: pd.DataFrame,
    order: list[OrderTuple] | None,
) -> pd.DataFrame:
    """Sort ``df`` by whichever requested order columns it actually contains.

    Order elements whose column is absent from the frame are silently skipped;
    with nothing to sort by, the frame is returned untouched.
    """
    if not order:
        return df
    pairs = [
        (name, direction == OrderDirection.ASC)
        for element, direction in order
        if (name := _orderable_id_name(element)) in df.columns
    ]
    if not pairs:
        return df
    names, ascending = zip(*pairs)
    sorted_df = df.sort_values(by=list(names), ascending=list(ascending))
    return sorted_df.reset_index(drop=True)
def _orderable_id_name(element: Metric | Dimension | AdhocExpression) -> str:
    """DataFrame column name for an orderable element; falls back to its id
    when the element carries no ``name`` attribute."""
    if hasattr(element, "name"):
        return element.name
    return element.id
def _mask_for(df: pd.DataFrame, f: Filter) -> pd.Series:  # noqa: C901
    """Boolean mask of ``df`` rows that satisfy filter ``f``.

    Column-less filters and unknown operators yield an all-True mask, so the
    caller degrades to "no extra filtering" rather than dropping rows.
    """
    if f.column is None:
        return pd.Series(True, index=df.index)
    series = df[f.column.name]
    op = f.operator
    val = f.value
    if op == Operator.EQUALS:
        # "= NULL" is treated as an IS NULL test (pandas `== None` is never True).
        return series == val if val is not None else series.isna()
    if op == Operator.NOT_EQUALS:
        return series != val if val is not None else series.notna()
    if op == Operator.GREATER_THAN:
        return series > val
    if op == Operator.GREATER_THAN_OR_EQUAL:
        return series >= val
    if op == Operator.LESS_THAN:
        return series < val
    if op == Operator.LESS_THAN_OR_EQUAL:
        return series <= val
    if op == Operator.IN:
        # Accept either a frozenset of values or a single scalar.
        return series.isin(list(val) if isinstance(val, frozenset) else [val])
    if op == Operator.NOT_IN:
        return ~series.isin(list(val) if isinstance(val, frozenset) else [val])
    if op == Operator.IS_NULL:
        return series.isna()
    if op == Operator.IS_NOT_NULL:
        return series.notna()
    if op == Operator.LIKE:
        # The translated pattern is ^…$-anchored, so str.match suffices.
        return series.astype(str).str.match(_sql_like_to_regex(str(val)))
    if op == Operator.NOT_LIKE:
        return ~series.astype(str).str.match(_sql_like_to_regex(str(val)))
    return pd.Series(True, index=df.index)
def _sql_like_to_regex(pattern: str) -> str:
out = []
for ch in pattern:
if ch == "%":
out.append(".*")
elif ch == "_":
out.append(".")
else:
out.append(re.escape(ch))
return f"^{''.join(out)}$"

View File

@@ -26,7 +26,7 @@ single dataframe.
from datetime import datetime, timedelta
from time import time
from typing import Any, cast, Sequence, TypeGuard
from typing import Any, Callable, cast, Sequence, TypeGuard
import isodate
import numpy as np
@@ -55,6 +55,11 @@ from superset.common.utils.time_range_utils import get_since_until_from_query_ob
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import NO_TIME_RANGE
from superset.models.helpers import QueryResult
from superset.semantic_layers.cache import (
store_result,
try_serve_from_cache,
ViewMeta,
)
from superset.superset_typing import AdhocColumn
from superset.utils.core import (
FilterOperator,
@@ -112,13 +117,15 @@ def get_results(query_object: QueryObject) -> QueryResult:
else semantic_view.get_table
)
cached_dispatch = _make_cached_dispatch(query_object, dispatcher)
# Step 1: Convert QueryObject to list of SemanticQuery objects
# The first query is the main query, subsequent queries are for time offsets
queries = map_query_object(query_object)
# Step 2: Execute the main query (first in the list)
main_query = queries[0]
main_result = dispatcher(main_query)
main_result = cached_dispatch(main_query)
main_df = main_result.results.to_pandas()
@@ -149,7 +156,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
strict=False,
):
# Execute the offset query
result = dispatcher(offset_query)
result = cached_dispatch(offset_query)
# Add this query's requests to the collection
all_requests.extend(result.requests)
@@ -205,6 +212,37 @@ def get_results(query_object: QueryObject) -> QueryResult:
)
def _make_cached_dispatch(
query_object: ValidatedQueryObject,
dispatcher: Callable[[SemanticQuery], SemanticResult],
) -> Callable[[SemanticQuery], SemanticResult]:
"""
Wrap the semantic view dispatcher with a containment-aware cache.
Row-count queries bypass the cache. Cache failures are logged and the
dispatcher is called as if the cache were absent.
"""
if query_object.is_rowcount or query_object.force_query:
return dispatcher
view = query_object.datasource
changed_on = getattr(view, "changed_on", None)
view_meta = ViewMeta(
uuid=str(view.uuid),
changed_on_iso=changed_on.isoformat() if changed_on else "",
cache_timeout=getattr(view, "cache_timeout", None),
)
def cached_dispatch(query: SemanticQuery) -> SemanticResult:
if (hit := try_serve_from_cache(view_meta, query)) is not None:
return hit
result = dispatcher(query)
store_result(view_meta, query, result)
return result
return cached_dispatch
def map_semantic_result_to_query_result(
semantic_result: SemanticResult,
query_object: ValidatedQueryObject,
@@ -226,6 +264,8 @@ def map_semantic_result_to_query_result(
f"-- {req.type}\n{req.definition}" for req in semantic_result.requests
)
semantic_cache_hit = any(req.type == "cache" for req in semantic_result.requests)
return QueryResult(
# Core data
df=semantic_result.results.to_pandas(),
@@ -246,6 +286,7 @@ def map_semantic_result_to_query_result(
# Time range - pass through from original query_object
from_dttm=query_object.from_dttm,
to_dttm=query_object.to_dttm,
semantic_cache_hit=semantic_cache_hit,
)

View File

@@ -179,6 +179,7 @@ class QueryObjectDict(TypedDict, total=False):
orderby: List of order by clauses
row_limit: Maximum number of rows
row_offset: Number of rows to skip
force_query: Whether to bypass cache when executing
series_columns: Columns to use for series
series_limit: Maximum number of series
series_limit_metric: Metric to use for series limiting
@@ -215,6 +216,7 @@ class QueryObjectDict(TypedDict, total=False):
orderby: list[OrderBy]
row_limit: int | None
row_offset: int
force_query: bool
series_columns: list[Column]
series_limit: int
series_limit_metric: Metric | None

View File

@@ -0,0 +1,355 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""End-to-end test that exercises ``mapper.get_results`` with a live cache."""
from __future__ import annotations
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock
import pandas as pd
import pyarrow as pa
import pytest
from pytest_mock import MockerFixture
from superset_core.semantic_layers.types import (
AggregationType,
Dimension,
Metric,
SemanticRequest,
SemanticResult,
)
from superset.semantic_layers import cache as cache_module
from superset.semantic_layers.mapper import get_results, ValidatedQueryObject
class _InMemoryCache:
"""Minimal flask-caching compatible cache used to isolate tests."""
def __init__(self) -> None:
self._store: dict[str, Any] = {}
def get(self, key: str) -> Any:
return self._store.get(key)
def set(self, key: str, value: Any, timeout: int | None = None) -> bool:
self._store[key] = value
return True
def delete(self, key: str) -> bool:
return self._store.pop(key, None) is not None
@pytest.fixture
def fake_cache(mocker: MockerFixture) -> _InMemoryCache:
    # Swap the cache manager's data_cache property for an in-memory cache so
    # tests never touch the real backend; returns the fake for inspection.
    fake = _InMemoryCache()
    mocker.patch.object(
        type(cache_module.cache_manager),
        "data_cache",
        property(lambda self: fake),
    )
    return fake
@pytest.fixture
def view_implementation() -> Any:
    """SemanticView implementation stub with one metric and one dimension."""
    dim_a = Dimension(id="t.a", name="a", type=pa.int64())
    metric_x = Metric(id="t.x", name="x", type=pa.float64(), definition="sum(x)")
    impl = MagicMock()
    impl.metrics = {metric_x}
    impl.dimensions = {dim_a}
    impl.features = frozenset()
    impl.get_metrics = MagicMock(return_value={metric_x})
    impl.get_dimensions = MagicMock(return_value={dim_a})
    return impl
@pytest.fixture
def datasource(view_implementation: Any) -> MagicMock:
    # Datasource mock carrying the attributes the cache key derives from
    # (uuid, changed_on, cache_timeout); changed_on is bumped by tests to
    # exercise invalidation.
    ds = MagicMock()
    ds.implementation = view_implementation
    ds.uuid = "view-uuid-stable"
    ds.changed_on = datetime(2026, 1, 1, 12, 0, 0)
    ds.cache_timeout = 60
    ds.fetch_values_predicate = None
    return ds
def _result(rows: list[tuple[int, float]]) -> SemanticResult:
    # Build a SemanticResult with columns (a, x) from literal row tuples.
    df = pd.DataFrame(rows, columns=["a", "x"])
    return SemanticResult(
        requests=[SemanticRequest(type="SQL", definition="select a, x")],
        results=pa.Table.from_pandas(df, preserve_index=False),
    )
def _qo(
    datasource: MagicMock,
    filter_op: str | None = None,
    filter_val: Any = None,
    limit: int | None = None,
    force_query: bool = False,
) -> ValidatedQueryObject:
    # Query object over metric "x" grouped by "a", with an optional single
    # filter on column "a" and optional row limit / cache bypass.
    qo_filters: list[dict[str, Any]] = (
        [{"col": "a", "op": filter_op, "val": filter_val}] if filter_op else []
    )
    return ValidatedQueryObject(
        datasource=datasource,
        metrics=["x"],
        columns=["a"],
        filters=qo_filters,  # type: ignore[arg-type]
        row_limit=limit,
        force_query=force_query,
    )
def test_narrower_filter_reuses_cache(
    fake_cache: _InMemoryCache,
    view_implementation: Any,
    datasource: MagicMock,
) -> None:
    """A strictly narrower filter must be answered from the cached frame."""
    # The dispatcher returns rows already filtered by `a > 1` (in production it
    # would; here we hand-feed the result). The second query (a > 2) is a subset
    # and must be served from the cached DataFrame.
    cached = _result([(2, 2.0), (3, 3.0), (5, 5.0)])
    view_implementation.get_table = MagicMock(return_value=cached)
    first = get_results(_qo(datasource, ">", 1))
    assert view_implementation.get_table.call_count == 1
    assert sorted(first.df["a"].tolist()) == [2, 3, 5]
    second = get_results(_qo(datasource, ">", 2))
    assert view_implementation.get_table.call_count == 1  # cache hit
    assert sorted(second.df["a"].tolist()) == [3, 5]
def test_smaller_limit_reuses_cache(
    fake_cache: _InMemoryCache,
    view_implementation: Any,
    datasource: MagicMock,
) -> None:
    """An unlimited cached result satisfies any later limited query."""
    # First call has no limit; second asks for 2 rows — should be served from cache.
    full = _result([(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0)])
    view_implementation.get_table = MagicMock(return_value=full)
    get_results(_qo(datasource, limit=None))
    assert view_implementation.get_table.call_count == 1
    result = get_results(_qo(datasource, limit=2))
    assert view_implementation.get_table.call_count == 1  # cache hit
    assert len(result.df) == 2
def test_broader_filter_misses_cache(
    fake_cache: _InMemoryCache,
    view_implementation: Any,
    datasource: MagicMock,
) -> None:
    """A broader filter cannot be served from a narrower cached result."""
    view_implementation.get_table = MagicMock(
        side_effect=[
            _result([(2, 1.0), (3, 2.0)]),
            _result([(0, 1.0), (2, 2.0), (3, 3.0)]),
        ]
    )
    get_results(_qo(datasource, ">", 1))
    assert view_implementation.get_table.call_count == 1
    # Broader filter — must re-execute.
    get_results(_qo(datasource, ">", 0))
    assert view_implementation.get_table.call_count == 2
def test_changed_on_invalidates_cache(
    fake_cache: _InMemoryCache,
    view_implementation: Any,
    datasource: MagicMock,
) -> None:
    """Updating the view's changed_on timestamp must invalidate cached entries."""
    view_implementation.get_table = MagicMock(return_value=_result([(2, 1.0)]))
    get_results(_qo(datasource, ">", 1))
    assert view_implementation.get_table.call_count == 1
    # Bumping changed_on yields a different shape key — cache misses.
    datasource.changed_on = datetime(2026, 2, 1, 0, 0, 0)
    get_results(_qo(datasource, ">", 1))
    assert view_implementation.get_table.call_count == 2
def test_force_query_bypasses_semantic_cache(
    fake_cache: _InMemoryCache,
    view_implementation: Any,
    datasource: MagicMock,
) -> None:
    """force_query=True must skip the cache and hit the dispatcher again."""
    view_implementation.get_table = MagicMock(return_value=_result([(2, 1.0)]))
    get_results(_qo(datasource, ">", 1))
    assert view_implementation.get_table.call_count == 1
    get_results(_qo(datasource, ">", 1, force_query=True))
    assert view_implementation.get_table.call_count == 2
# ---------------------------------------------------------------------------
# Projection (v2) — dropping a dimension and re-aggregating
# ---------------------------------------------------------------------------
def _make_view(metric_aggregation: AggregationType | None) -> tuple[Any, MagicMock]:
    # Build an implementation stub with two string dimensions (b, c) and one
    # metric "x" whose aggregation type drives projection eligibility, plus a
    # datasource mock wired to it.
    dim_b = Dimension(id="t.b", name="b", type=pa.utf8())
    dim_c = Dimension(id="t.c", name="c", type=pa.utf8())
    metric_x = Metric(
        id="t.x",
        name="x",
        type=pa.float64(),
        definition="sum(x)",
        aggregation=metric_aggregation,
    )
    impl = MagicMock()
    impl.metrics = {metric_x}
    impl.dimensions = {dim_b, dim_c}
    impl.features = frozenset()
    impl.get_metrics = MagicMock(return_value={metric_x})
    impl.get_dimensions = MagicMock(return_value={dim_b, dim_c})
    ds = MagicMock()
    ds.implementation = impl
    ds.uuid = "proj-view"
    ds.changed_on = datetime(2026, 3, 1, 0, 0, 0)
    ds.cache_timeout = 60
    ds.fetch_values_predicate = None
    return impl, ds
def _qo_dims(ds: MagicMock, columns: list[str]) -> ValidatedQueryObject:
    # Query for metric "x" grouped by the given dimension columns, no filters.
    return ValidatedQueryObject(
        datasource=ds,
        metrics=["x"],
        columns=columns,  # type: ignore[arg-type]
        filters=[],
    )
def _result_bc(rows: list[tuple[str, str, float]]) -> SemanticResult:
    # SemanticResult with columns (b, c, x) from literal row tuples.
    df = pd.DataFrame(rows, columns=["b", "c", "x"])
    return SemanticResult(
        requests=[SemanticRequest(type="SQL", definition="select b,c,sum(x)")],
        results=pa.Table.from_pandas(df, preserve_index=False),
    )
def test_projection_reuses_cached_for_dropped_dim(
    fake_cache: _InMemoryCache,
) -> None:
    """Dropping a dimension re-aggregates the SUM metric from the cached frame."""
    impl, ds = _make_view(AggregationType.SUM)
    impl.get_table = MagicMock(
        return_value=_result_bc(
            [("b1", "c1", 5.0), ("b1", "c2", 3.0), ("b2", "c1", 4.0)]
        )
    )
    first = get_results(_qo_dims(ds, ["b", "c"]))
    assert impl.get_table.call_count == 1
    assert len(first.df) == 3
    second = get_results(_qo_dims(ds, ["b"]))
    assert impl.get_table.call_count == 1  # served via projection
    df = second.df.sort_values("b").reset_index(drop=True)
    assert df["b"].tolist() == ["b1", "b2"]
    # b1 rows (5.0 + 3.0) collapse to 8.0; b2 stays 4.0.
    assert df["x"].tolist() == [8.0, 4.0]
def test_projection_skipped_when_aggregation_unknown(
    fake_cache: _InMemoryCache,
) -> None:
    """Without a declared aggregation the cache cannot re-aggregate locally."""
    impl, ds = _make_view(None)  # metric has no aggregation declared
    impl.get_table = MagicMock(
        side_effect=[
            _result_bc([("b1", "c1", 5.0), ("b1", "c2", 3.0)]),
            _result_bc([("b1", "c1", 5.0)]),  # what the SV would compute for [b]
        ]
    )
    get_results(_qo_dims(ds, ["b", "c"]))
    assert impl.get_table.call_count == 1
    get_results(_qo_dims(ds, ["b"]))
    assert impl.get_table.call_count == 2  # cannot project, re-executed
def test_projection_skipped_for_avg(
    fake_cache: _InMemoryCache,
) -> None:
    """AVG is not additive, so sub-group results cannot be rolled up."""
    impl, ds = _make_view(AggregationType.AVG)
    impl.get_table = MagicMock(
        side_effect=[
            _result_bc([("b1", "c1", 5.0), ("b1", "c2", 3.0)]),
            _result_bc([("b1", "c1", 4.0)]),
        ]
    )
    get_results(_qo_dims(ds, ["b", "c"]))
    get_results(_qo_dims(ds, ["b"]))
    assert impl.get_table.call_count == 2
def test_projection_reuses_when_cached_limit_not_reached(
    fake_cache: _InMemoryCache,
) -> None:
    """An unlimited cached result is a complete projection source."""
    impl, ds = _make_view(AggregationType.SUM)
    impl.get_table = MagicMock(
        return_value=_result_bc(
            [("b1", "c1", 5.0), ("b1", "c2", 3.0), ("b2", "c1", 4.0)]
        )
    )
    first = get_results(_qo_dims(ds, ["b", "c"]))
    assert impl.get_table.call_count == 1
    assert len(first.df) == 3
    second = get_results(_qo_dims(ds, ["b"]))
    assert impl.get_table.call_count == 1  # served via projection
    df = second.df.sort_values("b").reset_index(drop=True)
    assert df["b"].tolist() == ["b1", "b2"]
    assert df["x"].tolist() == [8.0, 4.0]
def test_projection_skips_when_cached_limit_reached(
    fake_cache: _InMemoryCache,
) -> None:
    """A cached result that hit its row limit may be truncated — no projection."""
    impl, ds = _make_view(AggregationType.SUM)
    first_q = _qo_dims(ds, ["b", "c"])
    first_q.row_limit = 3
    second_q = _qo_dims(ds, ["b"])
    impl.get_table = MagicMock(
        side_effect=[
            _result_bc([("b1", "c1", 5.0), ("b1", "c2", 3.0), ("b2", "c1", 4.0)]),
            _result_bc([("b1", "c1", 8.0), ("b2", "c1", 4.0)]),
        ]
    )
    get_results(first_q)
    assert impl.get_table.call_count == 1
    get_results(second_q)
    assert impl.get_table.call_count == 2  # projection skipped; re-executed

View File

@@ -0,0 +1,757 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import datetime
from typing import Any
import pandas as pd
import pyarrow as pa
import pytest
from superset_core.semantic_layers.types import (
AggregationType,
Dimension,
Filter,
GroupLimit,
Metric,
Operator,
OrderDirection,
PredicateType,
SemanticQuery,
SemanticRequest,
SemanticResult,
)
from superset.semantic_layers.cache import (
_apply_post_processing,
_implies,
_projection_input_complete,
CachedEntry,
can_satisfy,
shape_key,
value_key,
ViewMeta,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def dim(id_: str, name: str | None = None) -> Dimension:
    # UTF-8 dimension factory; name defaults to the id.
    return Dimension(id=id_, name=name or id_, type=pa.utf8())
def met(
    id_: str,
    name: str | None = None,
    aggregation: AggregationType | None = None,
) -> Metric:
    # Float metric factory; aggregation controls projection eligibility.
    return Metric(
        id=id_,
        name=name or id_,
        type=pa.float64(),
        definition="x",
        aggregation=aggregation,
    )
# Shared fixtures reused across the tests below.
COL_A = dim("col.a", "a")
COL_B = dim("col.b", "b")
M_X = met("met.x", "x")
M_Y = met("met.y", "y")
VIEW = ViewMeta(uuid="view-1", changed_on_iso="2026-05-01T00:00:00", cache_timeout=None)
def where(column: Dimension | Metric | None, op: Operator, value: Any) -> Filter:
    # WHERE-type filter factory.
    return Filter(type=PredicateType.WHERE, column=column, operator=op, value=value)
def having(column: Metric, op: Operator, value: Any) -> Filter:
    # HAVING-type filter factory (post-aggregation predicate).
    return Filter(type=PredicateType.HAVING, column=column, operator=op, value=value)
def adhoc(definition: str, type_: PredicateType = PredicateType.WHERE) -> Filter:
    # Free-form SQL filter; column is None and the definition is the value.
    return Filter(type=type_, column=None, operator=Operator.ADHOC, value=definition)
def query(
    filters: set[Filter] | None = None,
    limit: int | None = None,
    order: Any = None,
    dimensions: list[Dimension] | None = None,
    metrics: list[Metric] | None = None,
) -> SemanticQuery:
    # SemanticQuery factory with [M_X] x [COL_A, COL_B] defaults.
    return SemanticQuery(
        metrics=metrics if metrics is not None else [M_X],
        dimensions=dimensions if dimensions is not None else [COL_A, COL_B],
        filters=filters,
        order=order,
        limit=limit,
    )
def entry_from(q: SemanticQuery, value_key_: str = "vk") -> CachedEntry:
    """Build the CachedEntry the cache would have stored for query ``q``."""
    # Local import of private key helpers keeps them out of the module's
    # public surface while reusing the production key derivation.
    from superset.semantic_layers.cache import (
        _dimension_key,
        _group_limit_key,
        _order_key,
    )
    return CachedEntry(
        filters=frozenset(q.filters or set()),
        dimension_keys=frozenset(_dimension_key(d) for d in q.dimensions),
        limit=q.limit,
        offset=q.offset or 0,
        order_key=_order_key(q.order),
        group_limit_key=_group_limit_key(q.group_limit),
        value_key=value_key_,
    )
# ---------------------------------------------------------------------------
# _implies: scalar range pairs
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
    "new_op,new_val,cached_op,cached_val,expected",
    [
        # narrower lower bound
        (Operator.GREATER_THAN, 20, Operator.GREATER_THAN, 10, True),
        (Operator.GREATER_THAN, 10, Operator.GREATER_THAN, 20, False),
        (Operator.GREATER_THAN_OR_EQUAL, 11, Operator.GREATER_THAN, 10, True),
        (Operator.GREATER_THAN_OR_EQUAL, 10, Operator.GREATER_THAN, 10, False),
        (Operator.GREATER_THAN, 10, Operator.GREATER_THAN_OR_EQUAL, 10, True),
        (Operator.GREATER_THAN, 9, Operator.GREATER_THAN_OR_EQUAL, 10, False),
        # narrower upper bound
        (Operator.LESS_THAN, 5, Operator.LESS_THAN, 10, True),
        (Operator.LESS_THAN_OR_EQUAL, 9, Operator.LESS_THAN, 10, True),
        (Operator.LESS_THAN_OR_EQUAL, 10, Operator.LESS_THAN, 10, False),
        # cross-direction — never implies
        (Operator.LESS_THAN, 5, Operator.GREATER_THAN, 10, False),
        (Operator.GREATER_THAN, 5, Operator.LESS_THAN, 10, False),
        # equals fits in range
        (Operator.EQUALS, 15, Operator.GREATER_THAN, 10, True),
        (Operator.EQUALS, 10, Operator.GREATER_THAN, 10, False),
        (Operator.EQUALS, 10, Operator.GREATER_THAN_OR_EQUAL, 10, True),
    ],
)
def test_implies_range(
    new_op: Operator,
    new_val: Any,
    cached_op: Operator,
    cached_val: Any,
    expected: bool,
) -> None:
    """Scalar range implication: new filter must select a subset of cached."""
    assert (
        _implies(where(COL_A, new_op, new_val), where(COL_A, cached_op, cached_val))
        is expected
    )
def test_implies_in_subset() -> None:
    """IN implies a cached IN only for subsets; EQUALS only for members."""
    cached = where(COL_A, Operator.IN, frozenset({"a", "b", "c"}))
    assert _implies(where(COL_A, Operator.IN, frozenset({"a", "b"})), cached) is True
    assert _implies(where(COL_A, Operator.IN, frozenset({"a", "d"})), cached) is False
    # equals to a value in the cached IN set
    assert _implies(where(COL_A, Operator.EQUALS, "b"), cached) is True
    assert _implies(where(COL_A, Operator.EQUALS, "z"), cached) is False
def test_implies_in_all_in_range() -> None:
    """An IN-set implies a range bound only when every member satisfies it."""
    cached = where(COL_A, Operator.GREATER_THAN, 10)
    assert _implies(where(COL_A, Operator.IN, frozenset({11, 12})), cached) is True
    assert _implies(where(COL_A, Operator.IN, frozenset({10, 12})), cached) is False
def test_implies_equals_exact() -> None:
    """Equality implies a cached equality only for the identical value."""
    cached = where(COL_A, Operator.EQUALS, 5)
    assert _implies(where(COL_A, Operator.EQUALS, 5), cached) is True
    assert _implies(where(COL_A, Operator.EQUALS, 6), cached) is False
def test_implies_is_not_null() -> None:
    """NULL-excluding predicates imply IS NOT NULL; IS NULL never does."""
    cached = where(COL_A, Operator.IS_NOT_NULL, None)
    assert _implies(where(COL_A, Operator.GREATER_THAN, 0), cached) is True
    assert _implies(where(COL_A, Operator.IS_NOT_NULL, None), cached) is True
    assert _implies(where(COL_A, Operator.IS_NULL, None), cached) is False
def test_implies_like_exact_match_only() -> None:
    """LIKE patterns are opaque: only byte-identical filters imply each other."""
    a = where(COL_A, Operator.LIKE, "foo%")
    b = where(COL_A, Operator.LIKE, "foo%")
    c = where(COL_A, Operator.LIKE, "bar%")
    assert _implies(a, b) is True
    assert _implies(c, b) is False
    assert _implies(where(COL_A, Operator.EQUALS, "fooz"), b) is False
# ---------------------------------------------------------------------------
# can_satisfy
# ---------------------------------------------------------------------------
def test_can_satisfy_empty_cached_returns_all_as_leftovers() -> None:
    """An unfiltered cached result satisfies any query; all new filters are
    returned as leftovers to apply locally."""
    cached_q = query(filters=None)
    new_q = query(filters={where(COL_A, Operator.GREATER_THAN, 5)})
    ok, leftovers, projection = can_satisfy(entry_from(cached_q), new_q)
    assert ok is True
    assert projection is False
    assert leftovers == {where(COL_A, Operator.GREATER_THAN, 5)}
def test_can_satisfy_narrower_filter() -> None:
    """A narrower filter is satisfiable; it is re-applied locally as a leftover."""
    cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)})
    new_q = query(filters={where(COL_A, Operator.GREATER_THAN, 2)})
    ok, leftovers, _ = can_satisfy(entry_from(cached_q), new_q)
    assert ok is True
    assert leftovers == {where(COL_A, Operator.GREATER_THAN, 2)}
def test_can_satisfy_broader_filter_fails() -> None:
    """A broader filter needs rows the cached result may lack — not satisfiable."""
    cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 2)})
    new_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)})
    ok, leftovers, _ = can_satisfy(entry_from(cached_q), new_q)
    assert ok is False
    assert leftovers == set()
def test_can_satisfy_missing_constraint_fails() -> None:
    """A cached filter absent from the new query means missing rows — fail."""
    cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)})
    new_q = query(filters=None)
    ok, _, _ = can_satisfy(entry_from(cached_q), new_q)
    assert ok is False
def test_can_satisfy_new_filter_on_extra_column() -> None:
    """Extra filters on projected columns are fine; both become leftovers."""
    cached_q = query(filters={where(COL_A, Operator.GREATER_THAN, 1)})
    new_q = query(
        filters={
            where(COL_A, Operator.GREATER_THAN, 2),
            where(COL_B, Operator.EQUALS, "x"),
        }
    )
    ok, leftovers, _ = can_satisfy(entry_from(cached_q), new_q)
    assert ok is True
    assert leftovers == {
        where(COL_A, Operator.GREATER_THAN, 2),
        where(COL_B, Operator.EQUALS, "x"),
    }
def test_can_satisfy_leftover_on_non_projected_column_fails() -> None:
    """Leftover filters can only be applied to columns present in the frame."""
    other = dim("col.other", "other")
    cached_q = query(filters=None)
    new_q = query(
        filters={where(other, Operator.EQUALS, "x")},
        dimensions=[COL_A, COL_B],
    )
    ok, _, _ = can_satisfy(entry_from(cached_q), new_q)
    assert ok is False
def test_can_satisfy_having_requires_exact_set() -> None:
    """HAVING filters are not subsumable — they must match exactly."""
    cached_q = query(filters={having(M_X, Operator.GREATER_THAN, 100)})
    same = query(filters={having(M_X, Operator.GREATER_THAN, 100)})
    tighter = query(filters={having(M_X, Operator.GREATER_THAN, 200)})
    ok_same, _, _ = can_satisfy(entry_from(cached_q), same)
    ok_tight, _, _ = can_satisfy(entry_from(cached_q), tighter)
    assert ok_same is True
    assert ok_tight is False
def test_can_satisfy_adhoc_requires_exact_set() -> None:
    """Adhoc SQL filters are opaque — they must match exactly."""
    cached_q = query(filters={adhoc("col_a > 1")})
    same = query(filters={adhoc("col_a > 1")})
    different = query(filters={adhoc("col_a > 2")})
    ok_same, _, _ = can_satisfy(entry_from(cached_q), same)
    ok_diff, _, _ = can_satisfy(entry_from(cached_q), different)
    assert ok_same is True
    assert ok_diff is False
# ---------------------------------------------------------------------------
# Limit / order / offset
# ---------------------------------------------------------------------------
def test_can_satisfy_unlimited_cached_satisfies_any_limit() -> None:
cached_q = query(filters=None, limit=None)
new_q = query(filters=None, limit=10)
ok, leftovers, _ = can_satisfy(entry_from(cached_q), new_q)
assert ok is True
assert leftovers == set()
def test_can_satisfy_smaller_limit_with_matching_order() -> None:
order = [(M_X, OrderDirection.DESC)]
cached_q = query(filters=None, limit=100, order=order)
new_q = query(filters=None, limit=10, order=order)
ok, _, _ = can_satisfy(entry_from(cached_q), new_q)
assert ok is True
def test_can_satisfy_smaller_limit_different_order_fails() -> None:
cached_q = query(filters=None, limit=100, order=[(M_X, OrderDirection.DESC)])
new_q = query(filters=None, limit=10, order=[(M_X, OrderDirection.ASC)])
ok, _, _ = can_satisfy(entry_from(cached_q), new_q)
assert ok is False
def test_can_satisfy_larger_limit_fails() -> None:
cached_q = query(filters=None, limit=10)
new_q = query(filters=None, limit=100)
ok, _, _ = can_satisfy(entry_from(cached_q), new_q)
assert ok is False
def test_can_satisfy_no_new_limit_when_cached_has_one_fails() -> None:
cached_q = query(filters=None, limit=100)
new_q = query(filters=None, limit=None)
ok, _, _ = can_satisfy(entry_from(cached_q), new_q)
assert ok is False
def test_can_satisfy_offset_never_reused() -> None:
    """A non-zero offset disqualifies reuse, even for identical queries."""
    cached = SemanticQuery(metrics=[M_X], dimensions=[COL_A], offset=5)
    requested = SemanticQuery(metrics=[M_X], dimensions=[COL_A], offset=5)
    satisfied, _, _ = can_satisfy(entry_from(cached), requested)
    assert satisfied is False
# ---------------------------------------------------------------------------
# Post-processing
# ---------------------------------------------------------------------------
def test_apply_post_processing_filters_and_limits() -> None:
    """Leftover filters and the limit are applied locally to cached rows."""
    frame = pd.DataFrame({"a": [1, 3, 5, 7, 9], "x": [10, 20, 30, 40, 50]})
    cached = SemanticResult(
        requests=[SemanticRequest(type="SQL", definition="select ...")],
        results=pa.Table.from_pandas(frame, preserve_index=False),
    )
    leftover = where(COL_A, Operator.GREATER_THAN, 2)
    new_q = query(filters={leftover}, limit=2)

    processed = _apply_post_processing(cached, new_q, {leftover}, False)

    # a > 2 keeps [3, 5, 7, 9]; limit=2 keeps the first two of those
    assert processed.results.to_pandas()["a"].tolist() == [3, 5]
    # the cache annotates the requests with a marker
    assert any(req.type == "cache" for req in processed.requests)
def test_apply_post_processing_no_leftovers_no_limit_returns_original() -> None:
    """With nothing to filter, limit, or roll up, the input passes straight through."""
    payload = SemanticResult(
        requests=[],
        results=pa.Table.from_pandas(
            pd.DataFrame({"a": [1, 2]}), preserve_index=False
        ),
    )
    out = _apply_post_processing(payload, query(filters=None, limit=None), set(), False)
    # same object reference is OK; we explicitly return the input
    assert out is payload
# ---------------------------------------------------------------------------
# Hash stability
# ---------------------------------------------------------------------------
def test_value_key_stable_across_metric_order() -> None:
    """Metric ordering must not change the value cache key."""
    forward = SemanticQuery(metrics=[M_X, M_Y], dimensions=[COL_A])
    swapped = SemanticQuery(metrics=[M_Y, M_X], dimensions=[COL_A])
    assert value_key(VIEW, forward) == value_key(VIEW, swapped)
def test_shape_key_stable_across_dimension_order() -> None:
    """Dimension ordering must not change the shape cache key."""
    forward = SemanticQuery(metrics=[M_X], dimensions=[COL_A, COL_B])
    swapped = SemanticQuery(metrics=[M_X], dimensions=[COL_B, COL_A])
    assert shape_key(VIEW, forward) == shape_key(VIEW, swapped)
def test_shape_key_changes_with_changed_on() -> None:
    """Updating the view's changed_on timestamp invalidates the shape key."""
    q = SemanticQuery(metrics=[M_X], dimensions=[COL_A])
    updated_view = ViewMeta(
        uuid=VIEW.uuid, changed_on_iso="2099-01-01", cache_timeout=None
    )
    assert shape_key(VIEW, q) != shape_key(updated_view, q)
def test_value_key_changes_with_filter_value() -> None:
    """Different filter literals must hash to different value keys."""

    def build(threshold: int) -> SemanticQuery:
        # identical query except for the filter's comparison value
        return SemanticQuery(
            metrics=[M_X],
            dimensions=[COL_A],
            filters={where(COL_A, Operator.GREATER_THAN, threshold)},
        )

    assert value_key(VIEW, build(1)) != value_key(VIEW, build(2))
def test_value_key_with_datetime_filter() -> None:
    """Datetime filter values can be hashed into the value key without error."""
    q = SemanticQuery(
        metrics=[M_X],
        dimensions=[COL_A],
        filters={where(COL_A, Operator.GREATER_THAN_OR_EQUAL, datetime(2025, 1, 1))},
    )
    # should not raise
    assert value_key(VIEW, q).startswith("sv:val:")
def test_shape_key_independent_of_dimensions() -> None:
    """Shape keys bucket by metric set only, so broader entries stay findable.

    The v2 shape key ignores dimensions: queries over different dimension
    sets share one shape bucket, which lets the projection path discover
    broader cached entries. Value keys remain dimension-sensitive.
    """
    broad = SemanticQuery(metrics=[M_X], dimensions=[COL_A, COL_B])
    narrow = SemanticQuery(metrics=[M_X], dimensions=[COL_A])
    assert shape_key(VIEW, broad) == shape_key(VIEW, narrow)
    # Value keys still differ.
    assert value_key(VIEW, broad) != value_key(VIEW, narrow)
# ---------------------------------------------------------------------------
# Projection (v2)
# ---------------------------------------------------------------------------
# Metric fixtures covering every aggregation class exercised by the
# projection tests below: SUM/COUNT roll up via sum, MIN/MAX via min/max,
# AVG is non-additive, and None means the aggregation is unknown (both of
# the latter must cause projection to be refused).
M_SUM = met("met.sum", "sum_x", aggregation=AggregationType.SUM)
M_COUNT = met("met.count", "count_x", aggregation=AggregationType.COUNT)
M_MIN = met("met.min", "min_x", aggregation=AggregationType.MIN)
M_MAX = met("met.max", "max_x", aggregation=AggregationType.MAX)
M_AVG = met("met.avg", "avg_x", aggregation=AggregationType.AVG)
M_UNKNOWN = met("met.unknown", "unknown_x", aggregation=None)
def _projection_query(
    metrics: list[Metric],
    new_dimensions: list[Dimension],
    cached_dimensions: list[Dimension],
    cached_filters: set[Filter] | None = None,
    cached_limit: int | None = None,
    new_filters: set[Filter] | None = None,
    new_limit: int | None = None,
    new_order: Any = None,
    new_group_limit: GroupLimit | None = None,
) -> tuple[CachedEntry, SemanticQuery]:
    """Build a (cached entry, new query) pair sharing the same metrics.

    The cached side is driven by the ``cached_*`` knobs and the incoming
    query by the ``new_*`` knobs, so each test can probe exactly one
    projection precondition at a time.
    """
    entry = entry_from(
        SemanticQuery(
            metrics=metrics,
            dimensions=cached_dimensions,
            filters=cached_filters,
            limit=cached_limit,
        )
    )
    requested = SemanticQuery(
        metrics=metrics,
        dimensions=new_dimensions,
        filters=new_filters,
        limit=new_limit,
        order=new_order,
        group_limit=new_group_limit,
    )
    return entry, requested
@pytest.mark.parametrize(
    "metric",
    [M_SUM, M_COUNT, M_MIN, M_MAX],
    # ids document the roll-up operator each aggregation maps to
    ids=["sum", "count", "min", "max"],
)
def test_can_satisfy_projection_each_additive_op(metric: Metric) -> None:
    """Each roll-uppable aggregation allows projecting onto fewer dimensions.

    The original parametrization also carried an ``operator`` argument that
    the body never used; it is preserved here only as the test ids.
    """
    entry, new_q = _projection_query(
        metrics=[metric],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
    )
    ok, leftovers, projection = can_satisfy(entry, new_q)
    assert ok is True
    assert projection is True
    assert leftovers == set()
def test_projection_rolls_up_sum() -> None:
    """Dropping a dimension re-aggregates a SUM metric with ``sum``."""
    # the cached entry itself is unused here: projection is assumed granted,
    # so bind it to ``_`` instead of a dead local
    _, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
    )
    cached_df = pd.DataFrame(
        {"a": ["x", "x", "y", "y"], "b": [1, 2, 1, 2], "sum_x": [10, 20, 30, 40]}
    )
    cached = SemanticResult(
        requests=[SemanticRequest(type="SQL", definition="select ...")],
        results=pa.Table.from_pandas(cached_df, preserve_index=False),
    )
    out = _apply_post_processing(cached, new_q, set(), True)
    out_df = out.results.to_pandas().sort_values("a").reset_index(drop=True)
    assert list(out_df["a"]) == ["x", "y"]
    # x: 10 + 20, y: 30 + 40
    assert list(out_df["sum_x"]) == [30, 70]
def test_projection_rolls_up_min_max_count() -> None:
    """MIN/MAX/COUNT metrics roll up via min/max/sum respectively."""
    # the cached entry is unused: projection is assumed granted
    _, new_q = _projection_query(
        metrics=[M_MIN, M_MAX, M_COUNT],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
    )
    cached_df = pd.DataFrame(
        {
            "a": ["x", "x", "y", "y"],
            "b": [1, 2, 1, 2],
            "min_x": [5, 2, 9, 8],
            "max_x": [50, 60, 70, 80],
            "count_x": [1, 1, 2, 3],
        }
    )
    cached = SemanticResult(
        requests=[],
        results=pa.Table.from_pandas(cached_df, preserve_index=False),
    )
    out = _apply_post_processing(cached, new_q, set(), True)
    df = out.results.to_pandas().sort_values("a").reset_index(drop=True)
    assert list(df["min_x"]) == [2, 8]  # min of the sub-group minima
    assert list(df["max_x"]) == [60, 80]  # max of the sub-group maxima
    assert list(df["count_x"]) == [2, 5]  # sub-group counts add up
def test_projection_drops_multiple_dims() -> None:
    """Rolling up may drop more than one dimension at once."""
    col_c = dim("col.c", "c")
    # the cached entry is unused: projection is assumed granted
    _, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B, col_c],
    )
    cached_df = pd.DataFrame(
        {
            "a": ["x", "x", "x", "y"],
            "b": [1, 1, 2, 1],
            "c": [10, 20, 10, 10],
            "sum_x": [1, 2, 3, 4],
        }
    )
    cached = SemanticResult(
        requests=[], results=pa.Table.from_pandas(cached_df, preserve_index=False)
    )
    out = _apply_post_processing(cached, new_q, set(), True)
    df = out.results.to_pandas().sort_values("a").reset_index(drop=True)
    # x: 1 + 2 + 3 across all (b, c) combinations; y: 4
    assert list(df["sum_x"]) == [6, 4]
def test_projection_with_leftover_filter_then_rollup() -> None:
    """Leftover WHERE filters are applied before rows are re-aggregated."""
    entry, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
        new_filters={where(COL_B, Operator.GREATER_THAN, 1)},
    )
    cached_df = pd.DataFrame(
        {"a": ["x", "x", "y"], "b": [1, 2, 2], "sum_x": [10, 20, 30]}
    )
    cached = SemanticResult(
        requests=[], results=pa.Table.from_pandas(cached_df, preserve_index=False)
    )

    ok, leftovers, projection = can_satisfy(entry, new_q)
    assert ok is True
    assert projection is True

    processed = _apply_post_processing(cached, new_q, leftovers, projection)
    df = processed.results.to_pandas().sort_values("a").reset_index(drop=True)
    # b > 1 removes the (x,1) row; x sums to 20, y to 30
    assert list(df["sum_x"]) == [20, 30]
def test_projection_with_order_and_limit() -> None:
    """Ordering and the limit are applied after the roll-up, not before."""
    # the cached entry is unused: projection is assumed granted
    _, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
        new_order=[(M_SUM, OrderDirection.DESC)],
        new_limit=1,
    )
    cached_df = pd.DataFrame(
        {"a": ["x", "x", "y"], "b": [1, 2, 1], "sum_x": [1, 2, 100]}
    )
    cached = SemanticResult(
        requests=[], results=pa.Table.from_pandas(cached_df, preserve_index=False)
    )
    out = _apply_post_processing(cached, new_q, set(), True)
    df = out.results.to_pandas()
    # x rolls up to 3, y to 100; DESC order + limit 1 keeps only y
    assert len(df) == 1
    assert df["a"].tolist() == ["y"]
    assert df["sum_x"].tolist() == [100]
def test_apply_post_processing_sorts_before_limit_for_non_projection() -> None:
    """Even without a roll-up, rows are ordered before the limit is applied."""
    cached = SemanticResult(
        requests=[],
        results=pa.Table.from_pandas(
            pd.DataFrame({"a": ["x", "y", "z"], "x": [1.0, 100.0, 50.0]}),
            preserve_index=False,
        ),
    )
    new_q = SemanticQuery(
        metrics=[M_X],
        dimensions=[COL_A],
        order=[(M_X, OrderDirection.DESC)],
        limit=2,
    )
    processed = _apply_post_processing(cached, new_q, set(), False)
    # the top two values by x descending, not the first two source rows
    assert processed.results.to_pandas()["x"].tolist() == [100.0, 50.0]
def test_projection_rejected_when_metric_aggregation_unknown() -> None:
    """Metrics with no declared aggregation cannot be rolled up safely."""
    entry, new_q = _projection_query(
        metrics=[M_UNKNOWN],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
    )
    satisfied, _, _ = can_satisfy(entry, new_q)
    assert satisfied is False
def test_projection_rejected_for_avg() -> None:
    """AVG is non-additive: sub-group averages cannot be re-averaged."""
    entry, new_q = _projection_query(
        metrics=[M_AVG],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
    )
    satisfied, _, _ = can_satisfy(entry, new_q)
    assert satisfied is False
def test_projection_with_cached_limit_defers_to_runtime_rowcount_check() -> None:
    """A limited cached entry may still project; completeness is checked later.

    ``can_satisfy`` grants the projection here; whether the cached page
    actually holds every row is verified at runtime against the row count.
    """
    entry, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
        cached_limit=10,
    )
    ok, leftovers, projection = can_satisfy(entry, new_q)
    assert (ok, leftovers, projection) == (True, set(), True)
def test_projection_input_complete_unlimited_cached() -> None:
    """An unlimited cached result is always a complete projection input."""
    unlimited = SemanticQuery(metrics=[M_SUM], dimensions=[COL_A, COL_B], limit=None)
    table = pa.Table.from_pydict({"a": ["x"], "b": [1], "sum_x": [1.0]})
    result = SemanticResult(requests=[], results=table)
    assert _projection_input_complete(entry_from(unlimited), result) is True
def test_projection_input_complete_limited_cached_short_page() -> None:
    """A limited cache that came back short of its limit holds every row."""
    limited = SemanticQuery(metrics=[M_SUM], dimensions=[COL_A, COL_B], limit=10)
    # 3 rows < limit of 10, so the page cannot have been truncated
    table = pa.Table.from_pydict(
        {
            "a": ["x", "y", "z"],
            "b": [1, 1, 1],
            "sum_x": [1.0, 2.0, 3.0],
        }
    )
    result = SemanticResult(requests=[], results=table)
    assert _projection_input_complete(entry_from(limited), result) is True
def test_projection_input_complete_limited_cached_full_page() -> None:
    """A page that exactly hits its limit may be truncated, so it is incomplete."""
    limited = SemanticQuery(metrics=[M_SUM], dimensions=[COL_A, COL_B], limit=3)
    # exactly 3 rows for limit=3: more rows may exist beyond the page
    table = pa.Table.from_pydict(
        {
            "a": ["x", "y", "z"],
            "b": [1, 1, 1],
            "sum_x": [1.0, 2.0, 3.0],
        }
    )
    result = SemanticResult(requests=[], results=table)
    assert _projection_input_complete(entry_from(limited), result) is False
def test_projection_rejected_when_cached_has_having() -> None:
    """A HAVING on the cached side blocks projection: groups were pre-filtered."""
    shared_having = having(M_SUM, Operator.GREATER_THAN, 10)
    entry, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
        cached_filters={shared_having},
        new_filters={shared_having},
    )
    satisfied, _, _ = can_satisfy(entry, new_q)
    assert satisfied is False
def test_projection_rejected_when_new_query_has_group_limit() -> None:
    """Group limits cannot be reproduced from rolled-up cached rows."""
    entry, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
        new_group_limit=GroupLimit(
            dimensions=[COL_A],
            top=2,
            metric=M_SUM,
            direction=OrderDirection.DESC,
        ),
    )
    satisfied, _, _ = can_satisfy(entry, new_q)
    assert satisfied is False
def test_projection_rejected_when_order_references_dropped_dim() -> None:
    """Ordering by a dimension that the roll-up removes is impossible."""
    entry, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
        new_order=[(COL_B, OrderDirection.ASC)],
    )
    satisfied, _, _ = can_satisfy(entry, new_q)
    assert satisfied is False
def test_projection_rejected_when_cached_has_filter_on_dropped_dim() -> None:
    """A cached WHERE on a dropped dimension means rows would be missing."""
    # cached restricts b; rolling up to [a] would miss rows we'd need
    entry, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A],
        cached_dimensions=[COL_A, COL_B],
        cached_filters={where(COL_B, Operator.GREATER_THAN, 5)},
    )
    satisfied, _, _ = can_satisfy(entry, new_q)
    assert satisfied is False
def test_projection_rejected_when_cached_dims_subset_not_superset() -> None:
    """Projection only works from finer to coarser grain, never the reverse."""
    # cached has just [a]; new wants [a, b] — finer-grained data unavailable
    entry, new_q = _projection_query(
        metrics=[M_SUM],
        new_dimensions=[COL_A, COL_B],
        cached_dimensions=[COL_A],
    )
    satisfied, _, _ = can_satisfy(entry, new_q)
    assert satisfied is False

View File

@@ -1251,6 +1251,41 @@ def test_get_results_without_time_offsets(
# Verify DataFrame matches main query result
pd.testing.assert_frame_equal(result.df, main_df)
assert result.semantic_cache_hit is False
def test_get_results_marks_semantic_cache_hit_from_requests(
    mock_datasource: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A "cache"-typed request in the semantic result flags the cache hit."""
    main_df = pd.DataFrame({"category": ["A"], "total_sales": [1.0]})
    requests = [
        SemanticRequest(type="SQL", definition="SELECT ..."),
        SemanticRequest(
            type="cache",
            definition=(
                "Served from semantic view smart cache (re-aggregated locally)"
            ),
        ),
    ]
    mock_datasource.implementation.get_table = mocker.Mock(
        return_value=SemanticResult(
            requests=requests,
            results=pa.Table.from_pandas(main_df),
        )
    )
    query_object = ValidatedQueryObject(
        datasource=mock_datasource,
        from_dttm=datetime(2025, 10, 15),
        to_dttm=datetime(2025, 10, 22),
        metrics=["total_sales"],
        columns=["category"],
        granularity="order_date",
    )
    assert get_results(query_object).semantic_cache_hit is True
def test_get_results_with_single_time_offset(