mirror of
https://github.com/apache/superset.git
synced 2026-05-09 09:55:19 +00:00
backend for filters
This commit is contained in:
@@ -1529,6 +1529,72 @@ class SqlaTable(
|
||||
from_clause = self._apply_what_if_transform(from_clause, what_if)
|
||||
return from_clause, cte
|
||||
|
||||
def _build_what_if_filter_condition(
    self,
    filters: list[dict[str, Any]],
) -> ColumnElement | None:
    """
    Build a SQLAlchemy condition from a list of what-if filters.

    Supports operators: ==, !=, >, <, >=, <=, IN, NOT IN, TEMPORAL_RANGE.
    Operator names are matched case-insensitively.

    A filter is silently skipped when:

    - its 'col' is not a string naming one of this datasource's columns,
    - its 'op' is missing, not a string, or not one of the operators above,
    - its 'val' has the wrong type for the operator (IN / NOT IN require a
      list, TEMPORAL_RANGE requires a string).

    :param filters: List of filter dicts with 'col', 'op', and 'val' keys
    :returns: All usable conditions ANDed together, or None if no filter
        produced a condition
    """
    from superset.common.utils.time_range_utils import (
        get_since_until_from_time_range,
    )
    from superset.utils.core import FilterOperator

    # The six binary comparisons share one shape, so dispatch on the operator
    # instead of spelling out one branch per operator.
    comparisons = {
        FilterOperator.EQUALS: lambda col, val: col == val,
        FilterOperator.NOT_EQUALS: lambda col, val: col != val,
        FilterOperator.GREATER_THAN: lambda col, val: col > val,
        FilterOperator.LESS_THAN: lambda col, val: col < val,
        FilterOperator.GREATER_THAN_OR_EQUALS: lambda col, val: col >= val,
        FilterOperator.LESS_THAN_OR_EQUALS: lambda col, val: col <= val,
    }

    conditions: list[ColumnElement] = []
    available_columns = {col.column_name for col in self.columns}

    for flt in filters:
        col_name = flt.get("col")
        raw_op = flt.get("op")
        val = flt.get("val")

        # Check the type first: a non-hashable 'col' (e.g. a dict) would
        # raise TypeError in the set-membership test below.
        if not isinstance(col_name, str) or col_name not in available_columns:
            continue

        # Normalise the operator so "in" and "IN" are treated alike; anything
        # that is not a known FilterOperator value is skipped.
        if not isinstance(raw_op, str):
            continue
        try:
            op = FilterOperator(raw_op.upper())
        except ValueError:
            continue

        sqla_col = sa.column(col_name)

        if op in comparisons:
            conditions.append(comparisons[op](sqla_col, val))
        elif op in (FilterOperator.IN, FilterOperator.NOT_IN):
            if isinstance(val, list):
                membership = sqla_col.in_(val)
                conditions.append(
                    membership if op == FilterOperator.IN else ~membership
                )
        elif op == FilterOperator.TEMPORAL_RANGE and isinstance(val, str):
            # Parse time range string like "2024-01-01 : 2024-03-31" or
            # "Last week" into a half-open [since, until) interval.
            since, until = get_since_until_from_time_range(time_range=val)
            time_conditions = []
            if since is not None:
                time_conditions.append(sqla_col >= sa.literal(since))
            if until is not None:
                time_conditions.append(sqla_col < sa.literal(until))
            if time_conditions:
                conditions.append(and_(*time_conditions))

    if not conditions:
        return None

    return and_(*conditions)
|
||||
|
||||
def _apply_what_if_transform(
|
||||
self,
|
||||
source: TableClause | Alias,
|
||||
@@ -1547,22 +1613,29 @@ class SqlaTable(
|
||||
if not modifications:
|
||||
return source # type: ignore
|
||||
|
||||
# Build a dict of column -> multiplier
|
||||
mod_map = {m["column"]: m["multiplier"] for m in modifications}
|
||||
# Build a dict of column -> modification config (including filters)
|
||||
mod_map = {m["column"]: m for m in modifications}
|
||||
|
||||
# Get columns needed by the query + modified columns
|
||||
# None means we need all columns (e.g., for complex SQL metrics)
|
||||
needed_columns: set[str] | None = what_if.get("needed_columns")
|
||||
modified_column_names = set(mod_map.keys())
|
||||
|
||||
# Collect columns used in filters
|
||||
filter_columns: set[str] = set()
|
||||
for mod in modifications:
|
||||
for flt in mod.get("filters", []):
|
||||
if col_name := flt.get("col"):
|
||||
filter_columns.add(col_name)
|
||||
|
||||
# Determine which columns to select
|
||||
available_columns = {col.column_name for col in self.columns}
|
||||
if needed_columns is None:
|
||||
# Use all available columns
|
||||
columns_to_select = available_columns
|
||||
else:
|
||||
# Use only needed columns + modified columns
|
||||
columns_to_select = needed_columns | modified_column_names
|
||||
# Use only needed columns + modified columns + filter columns
|
||||
columns_to_select = needed_columns | modified_column_names | filter_columns
|
||||
|
||||
# Build select list with only needed columns
|
||||
select_columns = []
|
||||
@@ -1573,11 +1646,26 @@ class SqlaTable(
|
||||
continue
|
||||
|
||||
if col_name in mod_map:
|
||||
# Apply transformation: column * multiplier AS column
|
||||
multiplier = mod_map[col_name]
|
||||
transformed = (sa.column(col_name) * sa.literal(multiplier)).label(
|
||||
col_name
|
||||
)
|
||||
mod = mod_map[col_name]
|
||||
multiplier = mod["multiplier"]
|
||||
filters = mod.get("filters", [])
|
||||
col_ref = sa.column(col_name)
|
||||
|
||||
if filters:
|
||||
# Build conditional transformation with CASE WHEN
|
||||
condition = self._build_what_if_filter_condition(filters)
|
||||
if condition is not None:
|
||||
transformed = sa.case(
|
||||
(condition, col_ref * sa.literal(multiplier)),
|
||||
else_=col_ref,
|
||||
).label(col_name)
|
||||
else:
|
||||
# No valid filter conditions, apply unconditionally
|
||||
transformed = (col_ref * sa.literal(multiplier)).label(col_name)
|
||||
else:
|
||||
# No filters, apply transformation to all rows
|
||||
transformed = (col_ref * sa.literal(multiplier)).label(col_name)
|
||||
|
||||
select_columns.append(transformed)
|
||||
else:
|
||||
select_columns.append(sa.column(col_name))
|
||||
|
||||
@@ -1240,3 +1240,402 @@ def test_collect_needed_columns_returns_none_for_adhoc_groupby(
|
||||
)
|
||||
|
||||
assert needed is None
|
||||
|
||||
|
||||
def test_apply_what_if_transform_with_single_filter(mocker: MockerFixture) -> None:
    """
    A modification carrying one filter must compile to a CASE WHEN expression.
    """
    db_mock = mocker.MagicMock()
    db_mock.db_engine_spec.engine = "sqlite"
    column_names = ("date", "product", "ad_spend", "revenue")
    dataset = SqlaTable(
        table_name="sales",
        database=db_mock,
        columns=[TableColumn(column_name=name) for name in column_names],
    )

    # Scale ad_spend by 1.2, but only on rows where product = 'Widget'.
    widget_only = {"col": "product", "op": "==", "val": "Widget"}
    scenario = {
        "modifications": [
            {"column": "ad_spend", "multiplier": 1.2, "filters": [widget_only]}
        ],
        "needed_columns": {"date", "ad_spend", "revenue"},
    }
    transformed = dataset._apply_what_if_transform(
        dataset.get_sqla_table(), scenario
    )

    sql_text = str(
        select(transformed).compile(
            create_engine("sqlite://"), compile_kwargs={"literal_binds": True}
        )
    )

    # The filter condition and the scaled column must both be rendered.
    expected_fragments = (
        "CASE WHEN",
        "product",
        "'Widget'",
        "ad_spend * 1.2",
        "__what_if",
    )
    for fragment in expected_fragments:
        assert fragment in sql_text
|
||||
|
||||
|
||||
def test_apply_what_if_transform_with_multiple_filters(mocker: MockerFixture) -> None:
    """
    Several filters on one modification must all be rendered, joined with AND.
    """
    db_mock = mocker.MagicMock()
    db_mock.db_engine_spec.engine = "sqlite"
    column_names = ("date", "product", "region", "ad_spend")
    dataset = SqlaTable(
        table_name="sales",
        database=db_mock,
        columns=[TableColumn(column_name=name) for name in column_names],
    )

    # ad_spend is scaled only where product = 'Widget' AND region = 'US'.
    modification = {
        "column": "ad_spend",
        "multiplier": 1.5,
        "filters": [
            {"col": "product", "op": "==", "val": "Widget"},
            {"col": "region", "op": "==", "val": "US"},
        ],
    }
    scenario = {
        "modifications": [modification],
        "needed_columns": {"date", "ad_spend"},
    }
    transformed = dataset._apply_what_if_transform(
        dataset.get_sqla_table(), scenario
    )

    sql_text = str(
        select(transformed).compile(
            create_engine("sqlite://"), compile_kwargs={"literal_binds": True}
        )
    )

    # Both predicates, the conjunction and the scaled column must be present.
    expected_fragments = (
        "CASE WHEN",
        "product",
        "'Widget'",
        "region",
        "'US'",
        "AND",
        "ad_spend * 1.5",
    )
    for fragment in expected_fragments:
        assert fragment in sql_text
|
||||
|
||||
|
||||
def test_apply_what_if_transform_with_in_operator(mocker: MockerFixture) -> None:
    """
    An IN filter must be rendered as an IN clause inside the CASE WHEN.
    """
    db_mock = mocker.MagicMock()
    db_mock.db_engine_spec.engine = "sqlite"
    dataset = SqlaTable(
        table_name="sales",
        database=db_mock,
        columns=[
            TableColumn(column_name=name) for name in ("product", "ad_spend")
        ],
    )

    # ad_spend is scaled only where product IN ('Widget', 'Gadget').
    product_in = {"col": "product", "op": "IN", "val": ["Widget", "Gadget"]}
    scenario = {
        "modifications": [
            {"column": "ad_spend", "multiplier": 1.1, "filters": [product_in]}
        ],
        "needed_columns": {"ad_spend"},
    }
    transformed = dataset._apply_what_if_transform(
        dataset.get_sqla_table(), scenario
    )

    sql_text = str(
        select(transformed).compile(
            create_engine("sqlite://"), compile_kwargs={"literal_binds": True}
        )
    )

    # The IN clause must list both values.
    for fragment in ("CASE WHEN", "product IN", "'Widget'", "'Gadget'"):
        assert fragment in sql_text
|
||||
|
||||
|
||||
def test_apply_what_if_transform_filter_columns_included(mocker: MockerFixture) -> None:
    """
    Test that filter columns are included in the subquery even if not in needed_columns.
    """
    engine = create_engine("sqlite://")
    database = mocker.MagicMock()
    database.db_engine_spec.engine = "sqlite"

    table = SqlaTable(
        table_name="sales",
        database=database,
        columns=[
            TableColumn(column_name="date"),
            TableColumn(column_name="product"),
            TableColumn(column_name="ad_spend"),
        ],
    )

    source = table.get_sqla_table()

    # needed_columns doesn't include 'product', but it's used in filter
    what_if = {
        "modifications": [
            {
                "column": "ad_spend",
                "multiplier": 1.2,
                "filters": [{"col": "product", "op": "==", "val": "Widget"}],
            }
        ],
        "needed_columns": {"date", "ad_spend"},  # product NOT included
    }
    result = table._apply_what_if_transform(source, what_if)

    query = select(result)
    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))

    assert "CASE WHEN" in compiled
    assert "product" in compiled
    # The CASE WHEN predicate alone already mentions "product" once, so a
    # plain substring check cannot tell whether the column was also selected.
    # Requiring a second occurrence proves it appears outside the predicate.
    assert compiled.count("product") > 1
|
||||
|
||||
|
||||
def test_apply_what_if_transform_comparison_operators(mocker: MockerFixture) -> None:
    """
    Test that _apply_what_if_transform handles comparison operators (>, <, >=, <=).
    """
    engine = create_engine("sqlite://")
    database = mocker.MagicMock()
    database.db_engine_spec.engine = "sqlite"

    table = SqlaTable(
        table_name="sales",
        database=database,
        columns=[
            TableColumn(column_name="quantity"),
            TableColumn(column_name="ad_spend"),
        ],
    )

    # Exercise every operator named in the docstring, not only ">=".
    for op in (">", "<", ">=", "<="):
        # Fresh source per operator so iterations cannot influence each other.
        source = table.get_sqla_table()

        # Filter: quantity <op> 100
        what_if = {
            "modifications": [
                {
                    "column": "ad_spend",
                    "multiplier": 1.3,
                    "filters": [{"col": "quantity", "op": op, "val": 100}],
                }
            ],
            "needed_columns": {"ad_spend"},
        }
        result = table._apply_what_if_transform(source, what_if)

        query = select(result)
        compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))

        assert "CASE WHEN" in compiled, op
        assert f"quantity {op} 100" in compiled, op
        assert "ad_spend * 1.3" in compiled, op
|
||||
|
||||
|
||||
def test_build_what_if_filter_condition_skips_nonexistent_columns(
    mocker: MockerFixture,
) -> None:
    """
    Filters on unknown columns are dropped while filters on known ones are kept.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[
            TableColumn(column_name=name) for name in ("product", "ad_spend")
        ],
    )

    unknown_column_filter = {"col": "nonexistent", "op": "==", "val": "test"}
    known_column_filter = {"col": "product", "op": "==", "val": "Widget"}

    condition = dataset._build_what_if_filter_condition(
        [unknown_column_filter, known_column_filter]
    )

    # Only the filter on the existing column contributes to the condition.
    assert condition is not None
    rendered = str(condition)
    assert "product" in rendered
    assert "nonexistent" not in rendered
|
||||
|
||||
|
||||
def test_build_what_if_filter_condition_returns_none_for_all_invalid(
    mocker: MockerFixture,
) -> None:
    """
    When every filter names an unknown column, no condition is produced.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[TableColumn(column_name="product")],
    )

    # Neither column exists on the dataset.
    unknown_only = [
        {"col": missing, "op": "==", "val": "test"}
        for missing in ("nonexistent1", "nonexistent2")
    ]

    assert dataset._build_what_if_filter_condition(unknown_only) is None
|
||||
|
||||
|
||||
def test_apply_what_if_transform_with_temporal_range_filter(
    mocker: MockerFixture,
) -> None:
    """
    A TEMPORAL_RANGE filter must be rendered as lower and upper bound checks.
    """
    from datetime import datetime

    db_mock = mocker.MagicMock()
    db_mock.db_engine_spec.engine = "sqlite"
    dataset = SqlaTable(
        table_name="sales",
        database=db_mock,
        columns=[
            TableColumn(column_name=name) for name in ("order_date", "ad_spend")
        ],
    )

    # Stub the time-range parser so the test needs no Flask app context.
    mocker.patch(
        "superset.common.utils.time_range_utils.get_since_until_from_time_range",
        return_value=(datetime(2024, 1, 1), datetime(2024, 3, 31)),
    )

    date_window = {
        "col": "order_date",
        "op": "TEMPORAL_RANGE",
        "val": "2024-01-01 : 2024-03-31",
    }
    scenario = {
        "modifications": [
            {"column": "ad_spend", "multiplier": 1.2, "filters": [date_window]}
        ],
        "needed_columns": {"ad_spend"},
    }
    transformed = dataset._apply_what_if_transform(
        dataset.get_sqla_table(), scenario
    )

    sql_text = str(
        select(transformed).compile(
            create_engine("sqlite://"), compile_kwargs={"literal_binds": True}
        )
    )

    # Both ends of the range must guard the scaled column.
    expected_fragments = (
        "CASE WHEN",
        "order_date >=",
        "order_date <",
        "ad_spend * 1.2",
    )
    for fragment in expected_fragments:
        assert fragment in sql_text
|
||||
|
||||
|
||||
def test_apply_what_if_transform_with_combined_filters(
    mocker: MockerFixture,
) -> None:
    """
    Test that _apply_what_if_transform handles combined product + time range filters.
    """
    from datetime import datetime

    engine = create_engine("sqlite://")
    database = mocker.MagicMock()
    database.db_engine_spec.engine = "sqlite"

    table = SqlaTable(
        table_name="sales",
        database=database,
        columns=[
            TableColumn(column_name="order_date"),
            TableColumn(column_name="product"),
            TableColumn(column_name="ad_spend"),
        ],
    )

    # Mock get_since_until_from_time_range to avoid Flask app context requirement
    mocker.patch(
        "superset.common.utils.time_range_utils.get_since_until_from_time_range",
        return_value=(datetime(2024, 1, 1), datetime(2024, 4, 1)),
    )

    source = table.get_sqla_table()

    # Combined filter: product = 'Widget' AND order_date in Q1 2024
    what_if = {
        "modifications": [
            {
                "column": "ad_spend",
                "multiplier": 1.5,
                "filters": [
                    {"col": "product", "op": "==", "val": "Widget"},
                    {
                        "col": "order_date",
                        "op": "TEMPORAL_RANGE",
                        "val": "2024-01-01 : 2024-04-01",
                    },
                ],
            }
        ],
        "needed_columns": {"ad_spend"},
    }
    result = table._apply_what_if_transform(source, what_if)

    query = select(result)
    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))

    # Should have CASE WHEN with all conditions ANDed together
    assert "CASE WHEN" in compiled
    assert "product" in compiled
    assert "'Widget'" in compiled
    assert "order_date >=" in compiled
    assert "order_date <" in compiled
    assert "AND" in compiled
    assert "ad_spend * 1.5" in compiled

    # Both filter columns should be in the SELECT list. The CASE WHEN predicate
    # already mentions product once and order_date twice (>= and <), so a bare
    # substring check proves nothing; require occurrences beyond the predicate.
    assert compiled.count("product") > 1
    assert compiled.count("order_date") > 2
|
||||
|
||||
Reference in New Issue
Block a user