feat: WHAT IF - backend

This commit is contained in:
Kamil Gabryjelski
2025-12-16 14:31:09 +01:00
parent 6f8052b828
commit 4dab58f8c0
3 changed files with 582 additions and 12 deletions

View File

@@ -18,7 +18,7 @@
import pandas as pd
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from sqlalchemy import create_engine, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.session import Session
@@ -187,10 +187,13 @@ def test_query_datasources_by_permissions_with_catalog_schema(
["[my_db].[db1].[schema1]", "[my_other_db].[schema]"], # type: ignore
)
clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == (
"tables.perm IN ('[my_db].[table1](id:1)') OR "
"tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR " # noqa: E501
"tables.catalog_perm IN ('[my_db].[db1]')"
assert (
str(clause.compile(engine, compile_kwargs={"literal_binds": True}))
== (
"tables.perm IN ('[my_db].[table1](id:1)') OR "
"tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR " # noqa: E501
"tables.catalog_perm IN ('[my_db].[db1]')"
)
)
@@ -763,9 +766,9 @@ def test_get_sqla_table_quoting_for_cross_catalog(
# The compiled SQL should contain each part quoted separately
assert expected_in_sql in compiled, f"Expected {expected_in_sql} in SQL: {compiled}"
# Should NOT have the entire identifier quoted as one string
assert not_expected_in_sql not in compiled, (
f"Should not have {not_expected_in_sql} in SQL: {compiled}"
)
assert (
not_expected_in_sql not in compiled
), f"Should not have {not_expected_in_sql} in SQL: {compiled}"
def test_get_sqla_table_without_cross_catalog_ignores_catalog(
@@ -842,3 +845,398 @@ def test_quoted_name_prevents_double_quoting(mocker: MockerFixture) -> None:
# Should have each part quoted separately:
# GOOD: "MY_DB"."MY_SCHEMA"."MY_TABLE"
assert '"MY_DB"."MY_SCHEMA"."MY_TABLE"' in compiled
def test_apply_what_if_transform_single_modification(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform correctly transforms a single column.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
# Create table with columns
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="date"),
TableColumn(column_name="ad_spend"),
TableColumn(column_name="revenue"),
],
)
# Get the base table
source = table.get_sqla_table()
# Apply what-if transformation
what_if = {
"modifications": [{"column": "ad_spend", "multiplier": 1.1}],
"needed_columns": {"date", "ad_spend", "revenue"},
}
result = table._apply_what_if_transform(source, what_if)
# Compile to SQL and verify
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should have the subquery alias
assert "__what_if" in compiled
# Should have the multiplication
assert "ad_spend * 1.1" in compiled
def test_apply_what_if_transform_multiple_modifications(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform correctly transforms multiple columns.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="date"),
TableColumn(column_name="ad_spend"),
TableColumn(column_name="revenue"),
TableColumn(column_name="conversions"),
],
)
source = table.get_sqla_table()
what_if = {
"modifications": [
{"column": "ad_spend", "multiplier": 1.1},
{"column": "revenue", "multiplier": 0.95},
],
"needed_columns": {"date", "ad_spend", "revenue"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
assert "ad_spend * 1.1" in compiled
assert "revenue * 0.95" in compiled
# conversions should not be in the query since it's not in needed_columns
assert "conversions" not in compiled
def test_apply_what_if_transform_no_modifications(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform returns source unchanged when no modifications.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="date"),
TableColumn(column_name="ad_spend"),
],
)
source = table.get_sqla_table()
what_if = {
"modifications": [],
"needed_columns": {"date", "ad_spend"},
}
result = table._apply_what_if_transform(source, what_if)
# Should return source unchanged
assert result is source
def test_apply_what_if_transform_only_needed_columns(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform only includes needed
columns plus modified columns.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
# Create table with many columns
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="col1"),
TableColumn(column_name="col2"),
TableColumn(column_name="col3"),
TableColumn(column_name="ad_spend"),
TableColumn(column_name="col5"),
],
)
source = table.get_sqla_table()
# Only need col1, but modifying ad_spend
what_if = {
"modifications": [{"column": "ad_spend", "multiplier": 1.1}],
"needed_columns": {"col1"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should have col1 (needed) and ad_spend (modified)
assert "col1" in compiled
assert "ad_spend" in compiled
# Should NOT have other columns
assert "col2" not in compiled
assert "col3" not in compiled
assert "col5" not in compiled
def test_apply_what_if_transform_nonexistent_column(mocker: MockerFixture) -> None:
"""
Test that _apply_what_if_transform handles modifications for non-existent columns.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="date"),
TableColumn(column_name="revenue"),
],
)
source = table.get_sqla_table()
# Try to modify a column that doesn't exist
what_if = {
"modifications": [{"column": "nonexistent_column", "multiplier": 1.1}],
"needed_columns": {"date", "revenue"},
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should still work, just without the nonexistent column
assert "date" in compiled
assert "revenue" in compiled
assert "nonexistent_column" not in compiled
def test_collect_needed_columns(mocker: MockerFixture) -> None:
"""
Test that _collect_needed_columns extracts columns from query parameters.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="date"),
TableColumn(column_name="region"),
TableColumn(column_name="ad_spend"),
TableColumn(column_name="revenue"),
],
)
# Test with various query parameters
needed = table._collect_needed_columns(
columns=["date", "region"],
groupby=["date"],
metrics=[
{
"expressionType": "SIMPLE",
"column": {"column_name": "ad_spend"},
"aggregate": "SUM",
}
],
filter=[{"col": "revenue", "op": ">", "val": 100}],
orderby=[("date", True)],
granularity="date",
)
# Should include all referenced columns
assert "date" in needed # from columns, groupby, orderby, granularity
assert "region" in needed # from columns
assert "ad_spend" in needed # from metrics
assert "revenue" in needed # from filter
def test_collect_needed_columns_empty(mocker: MockerFixture) -> None:
"""
Test that _collect_needed_columns handles empty/None parameters.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[TableColumn(column_name="col1")],
)
needed = table._collect_needed_columns(
columns=None,
groupby=None,
metrics=None,
filter=None,
orderby=None,
granularity=None,
)
assert needed == set()
def test_collect_needed_columns_returns_none_for_sql_metrics(
mocker: MockerFixture,
) -> None:
"""
Test that _collect_needed_columns returns None for SQL-type adhoc metrics,
indicating all columns should be included.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[TableColumn(column_name="col1")],
)
# SQL-type adhoc metric - can't determine columns
needed = table._collect_needed_columns(
columns=["date"],
metrics=[
{
"expressionType": "SQL",
"sqlExpression": "SUM(hidden_column)",
}
],
)
# Should return None to signal all columns needed
assert needed is None
def test_collect_needed_columns_returns_none_for_saved_metrics(
mocker: MockerFixture,
) -> None:
"""
Test that _collect_needed_columns returns None for saved metrics (strings),
indicating all columns should be included.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[TableColumn(column_name="col1")],
)
# Saved metric (string) - can't determine columns
needed = table._collect_needed_columns(
columns=["date"],
metrics=["saved_metric_name"],
)
# Should return None to signal all columns needed
assert needed is None
def test_apply_what_if_transform_all_columns_when_needed_none(
mocker: MockerFixture,
) -> None:
"""
Test that _apply_what_if_transform includes all columns when needed_columns is None.
"""
engine = create_engine("sqlite://")
database = mocker.MagicMock()
database.db_engine_spec.engine = "sqlite"
table = SqlaTable(
table_name="sales",
database=database,
columns=[
TableColumn(column_name="col1"),
TableColumn(column_name="col2"),
TableColumn(column_name="ad_spend"),
TableColumn(column_name="col4"),
],
)
source = table.get_sqla_table()
# needed_columns is None - should include all columns
what_if = {
"modifications": [{"column": "ad_spend", "multiplier": 1.1}],
"needed_columns": None,
}
result = table._apply_what_if_transform(source, what_if)
query = select(result)
compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
# Should have all columns
assert "col1" in compiled
assert "col2" in compiled
assert "ad_spend" in compiled
assert "col4" in compiled
def test_collect_needed_columns_returns_none_for_adhoc_columns(
mocker: MockerFixture,
) -> None:
"""
Test that _collect_needed_columns returns None for adhoc columns
with SQL expressions, indicating all columns should be included.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[TableColumn(column_name="col1")],
)
# Adhoc column with SQL expression in columns list
needed = table._collect_needed_columns(
columns=[
"date",
{"label": "custom_col", "sqlExpression": "CONCAT(first_name, last_name)"},
],
metrics=[],
)
# Should return None to signal all columns needed
assert needed is None
def test_collect_needed_columns_returns_none_for_adhoc_groupby(
mocker: MockerFixture,
) -> None:
"""
Test that _collect_needed_columns returns None for adhoc columns in groupby.
"""
database = mocker.MagicMock()
table = SqlaTable(
table_name="sales",
database=database,
columns=[TableColumn(column_name="col1")],
)
# Adhoc column in groupby
needed = table._collect_needed_columns(
groupby=[
{"label": "year", "sqlExpression": "EXTRACT(YEAR FROM date)"},
],
metrics=[],
)
assert needed is None