feat: WHAT IF - backend

This commit is contained in:
Kamil Gabryjelski
2025-12-16 14:31:09 +01:00
parent 6f8052b828
commit 4dab58f8c0
3 changed files with 582 additions and 12 deletions

View File

@@ -1516,11 +1516,79 @@ class SqlaTable(
def get_from_clause(
    self,
    template_processor: BaseTemplateProcessor | None = None,
    what_if: dict[str, Any] | None = None,
) -> tuple[TableClause | Alias, str | None]:
    """
    Return the FROM target for this dataset, optionally wrapped for what-if.

    :param template_processor: Forwarded to the parent implementation for
        virtual datasets; unused for physical tables
    :param what_if: Optional what-if specification; when truthy the FROM
        target is passed through ``_apply_what_if_transform``
    :returns: Tuple of (table or aliased subquery, optional CTE string)
    """
    if not self.is_virtual:
        tbl = self.get_sqla_table()
        if what_if:
            tbl = self._apply_what_if_transform(tbl, what_if)
        return tbl, None

    # The parent is called with what_if=None so the what-if wrapping is
    # applied exactly once, here, on top of whatever the parent returns.
    from_clause, cte = super().get_from_clause(template_processor, what_if=None)
    if what_if:
        from_clause = self._apply_what_if_transform(from_clause, what_if)
    return from_clause, cte
def _apply_what_if_transform(
    self,
    source: TableClause | Alias,
    what_if: dict[str, Any],
) -> Alias:
    """
    Wrap the source table/subquery with a subquery that applies
    column transformations for what-if analysis.

    Each modified column is projected as ``column * multiplier`` under its
    original name, so the outer query can keep referring to it unchanged.

    :param source: Original table or subquery to transform
    :param what_if: Dict containing a 'modifications' list of
        ``{"column": ..., "multiplier": ...}`` entries and an optional
        'needed_columns' collection of column names required by the query;
        a missing or None 'needed_columns' means every datasource column
    :returns: Subquery aliased ``__what_if`` with the transformations applied,
        or ``source`` itself when there are no modifications or no
        selectable columns
    :raises KeyError: If a modification lacks 'column' or 'multiplier'
    """
    modifications = what_if.get("modifications", [])
    if not modifications:
        return source  # type: ignore

    # column name -> multiplier
    mod_map = {m["column"]: m["multiplier"] for m in modifications}

    # NOTE(review): self.columns may also list entries that are not physical
    # columns of `source` (e.g. computed ones) - confirm they are selectable.
    available_columns = {col.column_name for col in self.columns}

    needed_columns = what_if.get("needed_columns")
    if needed_columns is None:
        # None means the query's column usage could not be determined,
        # so every known column is exposed.
        candidates = available_columns
    else:
        # Modified columns are always included; names unknown to the
        # datasource are dropped.
        candidates = set(needed_columns).union(mod_map) & available_columns

    # Iterate in sorted order: set iteration order of strings changes with
    # hash randomization, which would make the generated SQL text differ
    # from one process to the next.
    select_columns = [
        (sa.column(name) * sa.literal(mod_map[name])).label(name)
        if name in mod_map
        else sa.column(name)
        for name in sorted(candidates)
    ]

    if not select_columns:
        # Nothing to project: leave the source untouched.
        return source  # type: ignore

    return sa.select(*select_columns).select_from(source).alias("__what_if")
def adhoc_metric_to_sqla(
self,

View File

@@ -2016,7 +2016,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return self.db_engine_spec.get_text_clause(clause)
def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
self,
template_processor: Optional[BaseTemplateProcessor] = None,
what_if: Optional[dict[str, Any]] = None,
) -> tuple[Union[TableClause, Alias], Optional[str]]:
"""
Return where to select the columns and metrics from. Either a physical table
@@ -2060,6 +2062,96 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return from_clause, cte
def _collect_needed_columns(  # noqa: C901
    self,
    columns: Optional[list[Column]] = None,
    groupby: Optional[list[Column]] = None,
    metrics: Optional[list[Metric]] = None,
    filter: Optional[list[utils.QueryObjectFilterClause]] = None,
    orderby: Optional[list[OrderBy]] = None,
    granularity: Optional[str] = None,
) -> Optional[set[str]]:
    """
    Collect all column names needed by the query for what-if transformation.

    This allows the what-if subquery to select only the necessary columns
    instead of every column of the datasource.

    :param columns: Selected columns (names or adhoc column dicts)
    :param groupby: Group-by columns (names or adhoc column dicts)
    :param metrics: Metrics (saved metric names or adhoc metric dicts)
    :param filter: Filter clauses; the 'col' entry of each is inspected
    :param orderby: Sequence of (column-or-metric, ...) entries; only the
        first element of each entry is inspected
    :param granularity: Name of the time column, if any
    :returns: Set of column names that are referenced by the query,
        or None if all columns should be included because a reference
        cannot be resolved (saved metric, SQL metric, or SQL-expression
        column)
    """
    needed: set[str] = set()

    def add_column(col: Any) -> bool:
        """Record a column reference; return False when it is a SQL expression."""
        if isinstance(col, str):
            needed.add(col)
        elif isinstance(col, dict):
            if col.get("sqlExpression"):
                # Adhoc column with SQL expression - can't determine columns
                return False
            if col.get("column_name"):
                needed.add(col["column_name"])
        return True

    def add_metric(metric: Any) -> bool:
        """Record a SIMPLE adhoc metric's column; return False when opaque."""
        if isinstance(metric, str):
            # Saved metric - its expression is not available here
            return False
        if isinstance(metric, dict):
            if metric.get("expressionType") == "SQL":
                # Free-form SQL - referenced columns are not parsed
                return False
            metric_column = metric.get("column")
            if isinstance(metric_column, dict):
                col_name = metric_column.get("column_name")
                if isinstance(col_name, str):
                    needed.add(col_name)
        return True

    # Time column
    if granularity:
        needed.add(granularity)

    # Dimensions and group-by columns share the same shape
    for col in (*(columns or []), *(groupby or [])):
        if not add_column(col):
            return None

    for metric in metrics or []:
        if not add_metric(metric):
            return None

    for flt in filter or []:
        if not add_column(flt.get("col")):
            return None

    for order_item in orderby or []:
        if not (isinstance(order_item, (list, tuple)) and order_item):
            continue
        target = order_item[0]
        # An ORDER BY target can be an adhoc metric rather than a column.
        # Treating it as a column would silently drop the column a SIMPLE
        # metric aggregates, so dispatch on the adhoc-metric marker key.
        # NOTE(review): a plain string here could also be a saved metric
        # name; it is recorded as a column name as before - confirm.
        if isinstance(target, dict) and "expressionType" in target:
            resolved = add_metric(target)
        else:
            resolved = add_column(target)
        if not resolved:
            return None

    return needed
def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
@@ -2879,7 +2971,19 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
# Process FROM clause early to populate removed_filters from virtual dataset
# templates before we decide whether to add time filters
tbl, cte = self.get_from_clause(template_processor)
what_if = extras.get("what_if") if extras else None
if what_if:
# Collect columns needed by the query for efficient what-if transformation
what_if = dict(what_if) # Copy to avoid mutating original
what_if["needed_columns"] = self._collect_needed_columns(
columns=columns,
groupby=groupby,
metrics=metrics,
filter=filter,
orderby=orderby,
granularity=granularity,
)
tbl, cte = self.get_from_clause(template_processor, what_if=what_if)
if granularity:
if granularity not in columns_by_name or not dttm_col:

View File

@@ -18,7 +18,7 @@
import pandas as pd
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from sqlalchemy import create_engine, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.session import Session
@@ -187,10 +187,13 @@ def test_query_datasources_by_permissions_with_catalog_schema(
["[my_db].[db1].[schema1]", "[my_other_db].[schema]"], # type: ignore
)
clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == (
"tables.perm IN ('[my_db].[table1](id:1)') OR "
"tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR " # noqa: E501
"tables.catalog_perm IN ('[my_db].[db1]')"
assert (
str(clause.compile(engine, compile_kwargs={"literal_binds": True}))
== (
"tables.perm IN ('[my_db].[table1](id:1)') OR "
"tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR " # noqa: E501
"tables.catalog_perm IN ('[my_db].[db1]')"
)
)
@@ -763,9 +766,9 @@ def test_get_sqla_table_quoting_for_cross_catalog(
# The compiled SQL should contain each part quoted separately
assert expected_in_sql in compiled, f"Expected {expected_in_sql} in SQL: {compiled}"
# Should NOT have the entire identifier quoted as one string
assert not_expected_in_sql not in compiled, (
f"Should not have {not_expected_in_sql} in SQL: {compiled}"
)
assert (
not_expected_in_sql not in compiled
), f"Should not have {not_expected_in_sql} in SQL: {compiled}"
def test_get_sqla_table_without_cross_catalog_ignores_catalog(
@@ -842,3 +845,398 @@ def test_quoted_name_prevents_double_quoting(mocker: MockerFixture) -> None:
# Should have each part quoted separately:
# GOOD: "MY_DB"."MY_SCHEMA"."MY_TABLE"
assert '"MY_DB"."MY_SCHEMA"."MY_TABLE"' in compiled
def test_apply_what_if_transform_single_modification(mocker: MockerFixture) -> None:
    """
    A single modification yields the aliased subquery with that column scaled.
    """
    sqlite_engine = create_engine("sqlite://")
    mock_db = mocker.MagicMock()
    mock_db.db_engine_spec.engine = "sqlite"

    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[
            TableColumn(column_name=name) for name in ("date", "ad_spend", "revenue")
        ],
    )
    base = dataset.get_sqla_table()

    # Scale ad_spend while keeping all three columns in the projection.
    scenario = {
        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
        "needed_columns": {"date", "ad_spend", "revenue"},
    }
    wrapped = dataset._apply_what_if_transform(base, scenario)

    sql = str(
        select(wrapped).compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    # The subquery alias and the multiplication must both be rendered.
    assert "__what_if" in sql
    assert "ad_spend * 1.1" in sql
def test_apply_what_if_transform_multiple_modifications(mocker: MockerFixture) -> None:
    """
    Every listed modification is applied and unneeded columns are left out.
    """
    sqlite_engine = create_engine("sqlite://")
    mock_db = mocker.MagicMock()
    mock_db.db_engine_spec.engine = "sqlite"

    column_names = ("date", "ad_spend", "revenue", "conversions")
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name=name) for name in column_names],
    )
    base = dataset.get_sqla_table()

    scenario = {
        "modifications": [
            {"column": "ad_spend", "multiplier": 1.1},
            {"column": "revenue", "multiplier": 0.95},
        ],
        "needed_columns": {"date", "ad_spend", "revenue"},
    }
    wrapped = dataset._apply_what_if_transform(base, scenario)

    sql = str(
        select(wrapped).compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    for fragment in ("ad_spend * 1.1", "revenue * 0.95"):
        assert fragment in sql
    # "conversions" is neither needed nor modified, so it must not appear.
    assert "conversions" not in sql
def test_apply_what_if_transform_no_modifications(mocker: MockerFixture) -> None:
    """
    With an empty modifications list the very same source object comes back.
    """
    mock_db = mocker.MagicMock()
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name=name) for name in ("date", "ad_spend")],
    )
    base = dataset.get_sqla_table()

    scenario = {"modifications": [], "needed_columns": {"date", "ad_spend"}}
    outcome = dataset._apply_what_if_transform(base, scenario)

    # Identity, not equality: no wrapping subquery may be created.
    assert outcome is base
def test_apply_what_if_transform_only_needed_columns(mocker: MockerFixture) -> None:
    """
    Only the needed columns and the modified columns end up in the subquery.
    """
    sqlite_engine = create_engine("sqlite://")
    mock_db = mocker.MagicMock()
    mock_db.db_engine_spec.engine = "sqlite"

    # A wider table than the query actually touches.
    column_names = ("col1", "col2", "col3", "ad_spend", "col5")
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name=name) for name in column_names],
    )
    base = dataset.get_sqla_table()

    # The query needs col1 only; ad_spend comes in because it is modified.
    scenario = {
        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
        "needed_columns": {"col1"},
    }
    wrapped = dataset._apply_what_if_transform(base, scenario)

    sql = str(
        select(wrapped).compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    for expected in ("col1", "ad_spend"):
        assert expected in sql
    for unexpected in ("col2", "col3", "col5"):
        assert unexpected not in sql
def test_apply_what_if_transform_nonexistent_column(mocker: MockerFixture) -> None:
    """
    A modification naming an unknown column is dropped instead of failing.
    """
    sqlite_engine = create_engine("sqlite://")
    mock_db = mocker.MagicMock()
    mock_db.db_engine_spec.engine = "sqlite"

    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name=name) for name in ("date", "revenue")],
    )
    base = dataset.get_sqla_table()

    # The modified column is not part of the datasource.
    scenario = {
        "modifications": [{"column": "nonexistent_column", "multiplier": 1.1}],
        "needed_columns": {"date", "revenue"},
    }
    wrapped = dataset._apply_what_if_transform(base, scenario)

    sql = str(
        select(wrapped).compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    # Known columns survive; the unknown one never reaches the SQL.
    for expected in ("date", "revenue"):
        assert expected in sql
    assert "nonexistent_column" not in sql
def test_collect_needed_columns(mocker: MockerFixture) -> None:
    """
    Test that _collect_needed_columns extracts columns from query parameters.

    Every kind of input (columns, groupby, SIMPLE adhoc metric, filter,
    orderby, granularity) contributes, and nothing else is collected.
    """
    database = mocker.MagicMock()
    table = SqlaTable(
        table_name="sales",
        database=database,
        columns=[
            TableColumn(column_name="date"),
            TableColumn(column_name="region"),
            TableColumn(column_name="ad_spend"),
            TableColumn(column_name="revenue"),
        ],
    )

    needed = table._collect_needed_columns(
        columns=["date", "region"],
        groupby=["date"],
        metrics=[
            {
                "expressionType": "SIMPLE",
                "column": {"column_name": "ad_spend"},
                "aggregate": "SUM",
            }
        ],
        filter=[{"col": "revenue", "op": ">", "val": 100}],
        orderby=[("date", True)],
        granularity="date",
    )

    # Exact comparison: membership checks alone would not notice spurious
    # extra columns or an unexpected None result.
    assert needed == {
        "date",  # from columns, groupby, orderby, granularity
        "region",  # from columns
        "ad_spend",  # from metrics
        "revenue",  # from filter
    }
def test_collect_needed_columns_empty(mocker: MockerFixture) -> None:
    """
    When every parameter is None the result is an empty set, not None.
    """
    mock_db = mocker.MagicMock()
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name="col1")],
    )

    # Pass each parameter explicitly as None rather than relying on defaults.
    all_none = dict.fromkeys(
        ("columns", "groupby", "metrics", "filter", "orderby", "granularity")
    )
    collected = dataset._collect_needed_columns(**all_none)

    assert collected == set()
def test_collect_needed_columns_returns_none_for_sql_metrics(
    mocker: MockerFixture,
) -> None:
    """
    A SQL-type adhoc metric makes _collect_needed_columns return None,
    which signals that all columns should be included.
    """
    mock_db = mocker.MagicMock()
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name="col1")],
    )

    # The columns referenced inside a free-form SQL metric are unknown.
    sql_metric = {
        "expressionType": "SQL",
        "sqlExpression": "SUM(hidden_column)",
    }
    collected = dataset._collect_needed_columns(
        columns=["date"],
        metrics=[sql_metric],
    )

    assert collected is None
def test_collect_needed_columns_returns_none_for_saved_metrics(
    mocker: MockerFixture,
) -> None:
    """
    A saved metric (given as a string) makes _collect_needed_columns return
    None, which signals that all columns should be included.
    """
    mock_db = mocker.MagicMock()
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name="col1")],
    )

    # A metric referenced by name only: its columns cannot be determined.
    collected = dataset._collect_needed_columns(
        columns=["date"],
        metrics=["saved_metric_name"],
    )

    assert collected is None
def test_apply_what_if_transform_all_columns_when_needed_none(
    mocker: MockerFixture,
) -> None:
    """
    With needed_columns set to None every datasource column is selected.
    """
    sqlite_engine = create_engine("sqlite://")
    mock_db = mocker.MagicMock()
    mock_db.db_engine_spec.engine = "sqlite"

    column_names = ("col1", "col2", "ad_spend", "col4")
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name=name) for name in column_names],
    )
    base = dataset.get_sqla_table()

    # None (as opposed to a set) asks for the full column list.
    scenario = {
        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
        "needed_columns": None,
    }
    wrapped = dataset._apply_what_if_transform(base, scenario)

    sql = str(
        select(wrapped).compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    for expected in ("col1", "col2", "ad_spend", "col4"):
        assert expected in sql
def test_collect_needed_columns_returns_none_for_adhoc_columns(
    mocker: MockerFixture,
) -> None:
    """
    An adhoc column carrying a SQL expression makes _collect_needed_columns
    return None, which signals that all columns should be included.
    """
    mock_db = mocker.MagicMock()
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name="col1")],
    )

    # A plain column name mixed with a SQL-expression column.
    adhoc_column = {
        "label": "custom_col",
        "sqlExpression": "CONCAT(first_name, last_name)",
    }
    collected = dataset._collect_needed_columns(
        columns=["date", adhoc_column],
        metrics=[],
    )

    assert collected is None
def test_collect_needed_columns_returns_none_for_adhoc_groupby(
    mocker: MockerFixture,
) -> None:
    """
    An adhoc SQL-expression column in groupby makes the result None.
    """
    mock_db = mocker.MagicMock()
    dataset = SqlaTable(
        table_name="sales",
        database=mock_db,
        columns=[TableColumn(column_name="col1")],
    )

    # Group by an expression rather than a named column.
    adhoc_groupby = {"label": "year", "sqlExpression": "EXTRACT(YEAR FROM date)"}
    collected = dataset._collect_needed_columns(
        groupby=[adhoc_groupby],
        metrics=[],
    )

    assert collected is None