backend for filters

This commit is contained in:
Kamil Gabryjelski
2025-12-17 13:28:54 +01:00
parent 5d6a697e32
commit 0a6bba1a14
2 changed files with 496 additions and 9 deletions

View File

@@ -1529,6 +1529,72 @@ class SqlaTable(
from_clause = self._apply_what_if_transform(from_clause, what_if)
return from_clause, cte
def _build_what_if_filter_condition(
self,
filters: list[dict[str, Any]],
) -> ColumnElement | None:
"""
Build a SQLAlchemy condition from a list of what-if filters.
Supports operators: ==, !=, >, <, >=, <=, IN, NOT IN, TEMPORAL_RANGE
:param filters: List of filter dicts with 'col', 'op', and 'val' keys
:returns: Combined SQLAlchemy condition (ANDed together), or None if no valid filters
"""
from superset.common.utils.time_range_utils import (
get_since_until_from_time_range,
)
from superset.utils.core import FilterOperator
conditions: list[ColumnElement] = []
available_columns = {col.column_name for col in self.columns}
for flt in filters:
col_name = flt.get("col")
op = flt.get("op")
val = flt.get("val")
# Skip if column doesn't exist in datasource
if col_name not in available_columns:
continue
sqla_col = sa.column(col_name)
if op == FilterOperator.EQUALS:
conditions.append(sqla_col == val)
elif op == FilterOperator.NOT_EQUALS:
conditions.append(sqla_col != val)
elif op == FilterOperator.GREATER_THAN:
conditions.append(sqla_col > val)
elif op == FilterOperator.LESS_THAN:
conditions.append(sqla_col < val)
elif op == FilterOperator.GREATER_THAN_OR_EQUALS:
conditions.append(sqla_col >= val)
elif op == FilterOperator.LESS_THAN_OR_EQUALS:
conditions.append(sqla_col <= val)
elif op == FilterOperator.IN:
if isinstance(val, list):
conditions.append(sqla_col.in_(val))
elif op == FilterOperator.NOT_IN:
if isinstance(val, list):
conditions.append(~sqla_col.in_(val))
elif op == FilterOperator.TEMPORAL_RANGE:
# Parse time range string like "2024-01-01 : 2024-03-31" or "Last week"
if isinstance(val, str):
since, until = get_since_until_from_time_range(time_range=val)
time_conditions = []
if since:
time_conditions.append(sqla_col >= sa.literal(since))
if until:
time_conditions.append(sqla_col < sa.literal(until))
if time_conditions:
conditions.append(and_(*time_conditions))
if not conditions:
return None
return and_(*conditions)
def _apply_what_if_transform(
self,
source: TableClause | Alias,
@@ -1547,22 +1613,29 @@ class SqlaTable(
if not modifications:
return source # type: ignore
# Build a dict of column -> multiplier
mod_map = {m["column"]: m["multiplier"] for m in modifications}
# Build a dict of column -> modification config (including filters)
mod_map = {m["column"]: m for m in modifications}
# Get columns needed by the query + modified columns
# None means we need all columns (e.g., for complex SQL metrics)
needed_columns: set[str] | None = what_if.get("needed_columns")
modified_column_names = set(mod_map.keys())
# Collect columns used in filters
filter_columns: set[str] = set()
for mod in modifications:
for flt in mod.get("filters", []):
if col_name := flt.get("col"):
filter_columns.add(col_name)
# Determine which columns to select
available_columns = {col.column_name for col in self.columns}
if needed_columns is None:
# Use all available columns
columns_to_select = available_columns
else:
# Use only needed columns + modified columns
columns_to_select = needed_columns | modified_column_names
# Use only needed columns + modified columns + filter columns
columns_to_select = needed_columns | modified_column_names | filter_columns
# Build select list with only needed columns
select_columns = []
@@ -1573,11 +1646,26 @@ class SqlaTable(
continue
if col_name in mod_map:
# Apply transformation: column * multiplier AS column
multiplier = mod_map[col_name]
transformed = (sa.column(col_name) * sa.literal(multiplier)).label(
col_name
)
mod = mod_map[col_name]
multiplier = mod["multiplier"]
filters = mod.get("filters", [])
col_ref = sa.column(col_name)
if filters:
# Build conditional transformation with CASE WHEN
condition = self._build_what_if_filter_condition(filters)
if condition is not None:
transformed = sa.case(
(condition, col_ref * sa.literal(multiplier)),
else_=col_ref,
).label(col_name)
else:
# No valid filter conditions, apply unconditionally
transformed = (col_ref * sa.literal(multiplier)).label(col_name)
else:
# No filters, apply transformation to all rows
transformed = (col_ref * sa.literal(multiplier)).label(col_name)
select_columns.append(transformed)
else:
select_columns.append(sa.column(col_name))

View File

@@ -1240,3 +1240,402 @@ def test_collect_needed_columns_returns_none_for_adhoc_groupby(
)
assert needed is None
def test_apply_what_if_transform_with_single_filter(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform generates CASE WHEN for filtered modifications.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="date"),
TableColumn(column_name="product"),
TableColumn(column_name="ad_spend"),
TableColumn(column_name="revenue"),
],
)
source = table.get_sqla_table()
# Apply what-if with filter: only modify ad_spend where product = 'Widget'
what_if = {
"modifications": [
{
"column": "ad_spend",
"multiplier": 1.2,
"filters": [{"col": "product", "op": "==", "val": "Widget"}],
}
],
"needed_columns": {"date", "ad_spend", "revenue"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should have CASE WHEN with the filter condition
assert "CASE WHEN" in compiled
assert "product" in compiled
assert "'Widget'" in compiled
assert "ad_spend * 1.2" in compiled
assert "__what_if" in compiled
def test_apply_what_if_transform_with_multiple_filters(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform ANDs multiple filter conditions together.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="date"),
TableColumn(column_name="product"),
TableColumn(column_name="region"),
TableColumn(column_name="ad_spend"),
],
)
source = table.get_sqla_table()
# Multiple filters: product = 'Widget' AND region = 'US'
what_if = {
"modifications": [
{
"column": "ad_spend",
"multiplier": 1.5,
"filters": [
{"col": "product", "op": "==", "val": "Widget"},
{"col": "region", "op": "==", "val": "US"},
],
}
],
"needed_columns": {"date", "ad_spend"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should have both conditions ANDed
assert "CASE WHEN" in compiled
assert "product" in compiled
assert "'Widget'" in compiled
assert "region" in compiled
assert "'US'" in compiled
assert "AND" in compiled
assert "ad_spend * 1.5" in compiled
def test_apply_what_if_transform_with_in_operator(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform handles IN operator correctly.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="product"),
TableColumn(column_name="ad_spend"),
],
)
source = table.get_sqla_table()
# Filter: product IN ['Widget', 'Gadget']
what_if = {
"modifications": [
{
"column": "ad_spend",
"multiplier": 1.1,
"filters": [
{"col": "product", "op": "IN", "val": ["Widget", "Gadget"]}
],
}
],
"needed_columns": {"ad_spend"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should have IN clause
assert "CASE WHEN" in compiled
assert "product IN" in compiled
assert "'Widget'" in compiled
assert "'Gadget'" in compiled
def test_apply_what_if_transform_filter_columns_included(mocker: MockerFixture) -> None:
"""
Test that filter columns are included in the subquery even if not in needed_columns.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="date"),
TableColumn(column_name="product"),
TableColumn(column_name="ad_spend"),
],
)
source = table.get_sqla_table()
# needed_columns doesn't include 'product', but it's used in filter
what_if = {
"modifications": [
{
"column": "ad_spend",
"multiplier": 1.2,
"filters": [{"col": "product", "op": "==", "val": "Widget"}],
}
],
"needed_columns": {"date", "ad_spend"}, # product NOT included
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# product should still be in the SELECT list because it's used in filter
assert "product" in compiled
assert "CASE WHEN" in compiled
def test_apply_what_if_transform_comparison_operators(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform handles comparison operators (>, <, >=, <=).
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="quantity"),
TableColumn(column_name="ad_spend"),
],
)
source = table.get_sqla_table()
# Filter: quantity >= 100
what_if = {
"modifications": [
{
"column": "ad_spend",
"multiplier": 1.3,
"filters": [{"col": "quantity", "op": ">=", "val": 100}],
}
],
"needed_columns": {"ad_spend"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
assert "CASE WHEN" in compiled
assert "quantity >= 100" in compiled
assert "ad_spend * 1.3" in compiled
def test_build_what_if_filter_condition_skips_nonexistent_columns(
mocker: MockerFixture,
) -> None:
"""
Test that _build_what_if_filter_condition skips filters for non-existent columns.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="product"),
TableColumn(column_name="ad_spend"),
],
)
# Filter references non-existent column
filters = [
{"col": "nonexistent", "op": "==", "val": "test"},
{"col": "product", "op": "==", "val": "Widget"},
]
condition = table._build_what_if_filter_condition(filters)
# Should still return a condition (just for product)
assert condition is not None
compiled = str(condition)
assert "product" in compiled
assert "nonexistent" not in compiled
def test_build_what_if_filter_condition_returns_none_for_all_invalid(
mocker: MockerFixture,
) -> None:
"""
Test that _build_what_if_filter_condition returns None if all filters are invalid.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="product"),
],
)
# All filters reference non-existent columns
filters = [
{"col": "nonexistent1", "op": "==", "val": "test"},
{"col": "nonexistent2", "op": "==", "val": "test"},
]
condition = table._build_what_if_filter_condition(filters)
assert condition is None
def test_apply_what_if_transform_with_temporal_range_filter(
mocker: MockerFixture,
) -> None:
"""
Test that _apply_what_if_transform handles TEMPORAL_RANGE filter correctly.
"""
from datetime import datetime
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="order_date"),
TableColumn(column_name="ad_spend"),
],
)
# Mock get_since_until_from_time_range to avoid Flask app context requirement
mocker.patch(
"superset.common.utils.time_range_utils.get_since_until_from_time_range",
return_value=(datetime(2024, 1, 1), datetime(2024, 3, 31)),
)
source = table.get_sqla_table()
what_if = {
"modifications": [
{
"column": "ad_spend",
"multiplier": 1.2,
"filters": [
{
"col": "order_date",
"op": "TEMPORAL_RANGE",
"val": "2024-01-01 : 2024-03-31",
}
],
}
],
"needed_columns": {"ad_spend"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should have CASE WHEN with time range conditions
assert "CASE WHEN" in compiled
assert "order_date >=" in compiled
assert "order_date <" in compiled
assert "ad_spend * 1.2" in compiled
def test_apply_what_if_transform_with_combined_filters(
mocker: MockerFixture,
) -> None:
"""
Test that _apply_what_if_transform handles combined product + time range filters.
"""
from datetime import datetime
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="order_date"),
TableColumn(column_name="product"),
TableColumn(column_name="ad_spend"),
],
)
# Mock get_since_until_from_time_range to avoid Flask app context requirement
mocker.patch(
"superset.common.utils.time_range_utils.get_since_until_from_time_range",
return_value=(datetime(2024, 1, 1), datetime(2024, 4, 1)),
)
source = table.get_sqla_table()
# Combined filter: product = 'Widget' AND order_date in Q1 2024
what_if = {
"modifications": [
{
"column": "ad_spend",
"multiplier": 1.5,
"filters": [
{"col": "product", "op": "==", "val": "Widget"},
{
"col": "order_date",
"op": "TEMPORAL_RANGE",
"val": "2024-01-01 : 2024-04-01",
},
],
}
],
"needed_columns": {"ad_spend"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should have CASE WHEN with all conditions ANDed together
assert "CASE WHEN" in compiled
assert "product" in compiled
assert "'Widget'" in compiled
assert "order_date >=" in compiled
assert "order_date <" in compiled
assert "AND" in compiled
assert "ad_spend * 1.5" in compiled
# Both filter columns should be in the SELECT list
assert "order_date" in compiled