mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
feat: WHAT IF - backend
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
@@ -187,10 +187,13 @@ def test_query_datasources_by_permissions_with_catalog_schema(
|
||||
["[my_db].[db1].[schema1]", "[my_other_db].[schema]"], # type: ignore
|
||||
)
|
||||
clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
|
||||
assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == (
|
||||
"tables.perm IN ('[my_db].[table1](id:1)') OR "
|
||||
"tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR " # noqa: E501
|
||||
"tables.catalog_perm IN ('[my_db].[db1]')"
|
||||
assert (
|
||||
str(clause.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||
== (
|
||||
"tables.perm IN ('[my_db].[table1](id:1)') OR "
|
||||
"tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR " # noqa: E501
|
||||
"tables.catalog_perm IN ('[my_db].[db1]')"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -763,9 +766,9 @@ def test_get_sqla_table_quoting_for_cross_catalog(
|
||||
# The compiled SQL should contain each part quoted separately
|
||||
assert expected_in_sql in compiled, f"Expected {expected_in_sql} in SQL: {compiled}"
|
||||
# Should NOT have the entire identifier quoted as one string
|
||||
assert not_expected_in_sql not in compiled, (
|
||||
f"Should not have {not_expected_in_sql} in SQL: {compiled}"
|
||||
)
|
||||
assert (
|
||||
not_expected_in_sql not in compiled
|
||||
), f"Should not have {not_expected_in_sql} in SQL: {compiled}"
|
||||
|
||||
|
||||
def test_get_sqla_table_without_cross_catalog_ignores_catalog(
|
||||
@@ -842,3 +845,398 @@ def test_quoted_name_prevents_double_quoting(mocker: MockerFixture) -> None:
|
||||
# Should have each part quoted separately:
|
||||
# GOOD: "MY_DB"."MY_SCHEMA"."MY_TABLE"
|
||||
assert '"MY_DB"."MY_SCHEMA"."MY_TABLE"' in compiled
|
||||
|
||||
|
||||
def test_apply_what_if_transform_single_modification(mocker: MockerFixture) -> None:
    """
    Check the SQL produced when exactly one column carries a what-if multiplier.

    ``ad_spend`` is scaled by 1.1; the compiled statement must mention both the
    ``__what_if`` alias and the ``ad_spend * 1.1`` expression.
    """
    mock_database = mocker.MagicMock()
    mock_database.db_engine_spec.engine = "sqlite"
    sqlite_engine = create_engine("sqlite://")

    # Dataset exposing three plain columns
    dataset = SqlaTable(
        table_name="sales",
        database=mock_database,
        columns=[
            TableColumn(column_name=name) for name in ("date", "ad_spend", "revenue")
        ],
    )

    # Wrap the base table with the what-if transformation
    base_table = dataset.get_sqla_table()
    scenario = {
        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
        "needed_columns": {"date", "ad_spend", "revenue"},
    }
    transformed = dataset._apply_what_if_transform(base_table, scenario)

    # Render with literal binds so the multiplier shows up inline
    statement = select(transformed)
    sql_text = str(
        statement.compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    # Alias of the generated subquery, then the scaled column expression
    for fragment in ("__what_if", "ad_spend * 1.1"):
        assert fragment in sql_text
|
||||
|
||||
|
||||
def test_apply_what_if_transform_multiple_modifications(mocker: MockerFixture) -> None:
    """
    Check the SQL produced when two columns carry what-if multipliers.

    ``ad_spend`` (x1.1) and ``revenue`` (x0.95) must both appear scaled, while
    ``conversions`` -- absent from ``needed_columns`` -- must not appear at all.
    """
    mock_database = mocker.MagicMock()
    mock_database.db_engine_spec.engine = "sqlite"
    sqlite_engine = create_engine("sqlite://")

    dataset = SqlaTable(
        table_name="sales",
        database=mock_database,
        columns=[
            TableColumn(column_name=name)
            for name in ("date", "ad_spend", "revenue", "conversions")
        ],
    )

    base_table = dataset.get_sqla_table()
    ad_spend_change = {"column": "ad_spend", "multiplier": 1.1}
    revenue_change = {"column": "revenue", "multiplier": 0.95}
    scenario = {
        "modifications": [ad_spend_change, revenue_change],
        "needed_columns": {"date", "ad_spend", "revenue"},
    }
    transformed = dataset._apply_what_if_transform(base_table, scenario)

    statement = select(transformed)
    sql_text = str(
        statement.compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    # Both multipliers are rendered inline
    for scaled_expression in ("ad_spend * 1.1", "revenue * 0.95"):
        assert scaled_expression in sql_text
    # A column outside needed_columns is left out of the statement
    assert "conversions" not in sql_text
|
||||
|
||||
|
||||
def test_apply_what_if_transform_no_modifications(mocker: MockerFixture) -> None:
    """
    With an empty ``modifications`` list the very same source object comes back.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[TableColumn(column_name=name) for name in ("date", "ad_spend")],
    )

    base_table = dataset.get_sqla_table()
    scenario = {"modifications": [], "needed_columns": {"date", "ad_spend"}}
    outcome = dataset._apply_what_if_transform(base_table, scenario)

    # Identity check: no wrapping subquery is expected here
    assert outcome is base_table
|
||||
|
||||
|
||||
def test_apply_what_if_transform_only_needed_columns(mocker: MockerFixture) -> None:
    """
    Check that the transformed query is limited to needed and modified columns.

    Only ``col1`` is needed and only ``ad_spend`` is modified, so the compiled
    SQL must mention those two and none of ``col2``, ``col3`` or ``col5``.
    """
    mock_database = mocker.MagicMock()
    mock_database.db_engine_spec.engine = "sqlite"
    sqlite_engine = create_engine("sqlite://")

    # Dataset with more columns than the query will use
    all_names = ("col1", "col2", "col3", "ad_spend", "col5")
    dataset = SqlaTable(
        table_name="sales",
        database=mock_database,
        columns=[TableColumn(column_name=name) for name in all_names],
    )

    base_table = dataset.get_sqla_table()
    scenario = {
        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
        "needed_columns": {"col1"},
    }
    transformed = dataset._apply_what_if_transform(base_table, scenario)

    statement = select(transformed)
    sql_text = str(
        statement.compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    # Present: the needed column and the modified column
    for expected in ("col1", "ad_spend"):
        assert expected in sql_text
    # Absent: everything else declared on the dataset
    for unexpected in ("col2", "col3", "col5"):
        assert unexpected not in sql_text
|
||||
|
||||
|
||||
def test_apply_what_if_transform_nonexistent_column(mocker: MockerFixture) -> None:
    """
    Check that a modification naming an unknown column does not reach the SQL.

    The dataset only has ``date`` and ``revenue``; the modification targets
    ``nonexistent_column``. The compiled SQL must still contain the two real
    columns and must not mention the unknown one.
    """
    mock_database = mocker.MagicMock()
    mock_database.db_engine_spec.engine = "sqlite"
    sqlite_engine = create_engine("sqlite://")

    dataset = SqlaTable(
        table_name="sales",
        database=mock_database,
        columns=[TableColumn(column_name=name) for name in ("date", "revenue")],
    )

    base_table = dataset.get_sqla_table()
    # The targeted column is not declared on the dataset
    scenario = {
        "modifications": [{"column": "nonexistent_column", "multiplier": 1.1}],
        "needed_columns": {"date", "revenue"},
    }
    transformed = dataset._apply_what_if_transform(base_table, scenario)

    statement = select(transformed)
    sql_text = str(
        statement.compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    for real_column in ("date", "revenue"):
        assert real_column in sql_text
    assert "nonexistent_column" not in sql_text
|
||||
|
||||
|
||||
def test_collect_needed_columns(mocker: MockerFixture) -> None:
    """
    Check that every column referenced by the query parameters is collected.

    Columns are referenced through ``columns``, ``groupby``, a SIMPLE adhoc
    metric, a filter, ``orderby`` and ``granularity``; each referenced name must
    be a member of the returned collection.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[
            TableColumn(column_name=name)
            for name in ("date", "region", "ad_spend", "revenue")
        ],
    )

    # One reference per kind of query parameter
    sum_ad_spend = {
        "expressionType": "SIMPLE",
        "column": {"column_name": "ad_spend"},
        "aggregate": "SUM",
    }
    revenue_filter = {"col": "revenue", "op": ">", "val": 100}
    collected = dataset._collect_needed_columns(
        columns=["date", "region"],
        groupby=["date"],
        metrics=[sum_ad_spend],
        filter=[revenue_filter],
        orderby=[("date", True)],
        granularity="date",
    )

    # date: columns/groupby/orderby/granularity; region: columns;
    # ad_spend: metrics; revenue: filter
    for referenced in ("date", "region", "ad_spend", "revenue"):
        assert referenced in collected
|
||||
|
||||
|
||||
def test_collect_needed_columns_empty(mocker: MockerFixture) -> None:
    """
    Check that passing ``None`` for every parameter yields an empty set.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[TableColumn(column_name="col1")],
    )

    parameter_names = (
        "columns",
        "groupby",
        "metrics",
        "filter",
        "orderby",
        "granularity",
    )
    collected = dataset._collect_needed_columns(**dict.fromkeys(parameter_names))

    assert collected == set()
|
||||
|
||||
|
||||
def test_collect_needed_columns_returns_none_for_sql_metrics(
    mocker: MockerFixture,
) -> None:
    """
    Check that a SQL-type adhoc metric makes the result ``None``.

    This test treats ``None`` as the signal that all columns should be included.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[TableColumn(column_name="col1")],
    )

    # Free-form SQL: the referenced column is not declared anywhere else
    sql_metric = {
        "expressionType": "SQL",
        "sqlExpression": "SUM(hidden_column)",
    }
    collected = dataset._collect_needed_columns(
        columns=["date"],
        metrics=[sql_metric],
    )

    assert collected is None
|
||||
|
||||
|
||||
def test_collect_needed_columns_returns_none_for_saved_metrics(
    mocker: MockerFixture,
) -> None:
    """
    Check that a saved metric, given as a plain string, makes the result ``None``.

    This test treats ``None`` as the signal that all columns should be included.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[TableColumn(column_name="col1")],
    )

    # The metric is referenced by name only
    saved_metrics = ["saved_metric_name"]
    collected = dataset._collect_needed_columns(
        columns=["date"],
        metrics=saved_metrics,
    )

    assert collected is None
|
||||
|
||||
|
||||
def test_apply_what_if_transform_all_columns_when_needed_none(
    mocker: MockerFixture,
) -> None:
    """
    Check that ``needed_columns=None`` keeps every dataset column in the SQL.
    """
    mock_database = mocker.MagicMock()
    mock_database.db_engine_spec.engine = "sqlite"
    sqlite_engine = create_engine("sqlite://")

    dataset = SqlaTable(
        table_name="sales",
        database=mock_database,
        columns=[
            TableColumn(column_name=name)
            for name in ("col1", "col2", "ad_spend", "col4")
        ],
    )

    base_table = dataset.get_sqla_table()
    # No column restriction is supplied alongside the single modification
    scenario = {
        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
        "needed_columns": None,
    }
    transformed = dataset._apply_what_if_transform(base_table, scenario)

    statement = select(transformed)
    sql_text = str(
        statement.compile(sqlite_engine, compile_kwargs={"literal_binds": True})
    )

    # Every declared column must be mentioned
    for declared in ("col1", "col2", "ad_spend", "col4"):
        assert declared in sql_text
|
||||
|
||||
|
||||
def test_collect_needed_columns_returns_none_for_adhoc_columns(
    mocker: MockerFixture,
) -> None:
    """
    Check that an adhoc SQL-expression entry in ``columns`` makes the result ``None``.

    This test treats ``None`` as the signal that all columns should be included.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[TableColumn(column_name="col1")],
    )

    # A plain column name mixed with an adhoc SQL-expression column
    adhoc_column = {
        "label": "custom_col",
        "sqlExpression": "CONCAT(first_name, last_name)",
    }
    collected = dataset._collect_needed_columns(
        columns=["date", adhoc_column],
        metrics=[],
    )

    assert collected is None
|
||||
|
||||
|
||||
def test_collect_needed_columns_returns_none_for_adhoc_groupby(
    mocker: MockerFixture,
) -> None:
    """
    Check that an adhoc SQL-expression entry in ``groupby`` makes the result ``None``.
    """
    dataset = SqlaTable(
        table_name="sales",
        database=mocker.MagicMock(),
        columns=[TableColumn(column_name="col1")],
    )

    # The only groupby entry is an adhoc SQL expression
    adhoc_groupby = {"label": "year", "sqlExpression": "EXTRACT(YEAR FROM date)"}
    collected = dataset._collect_needed_columns(
        groupby=[adhoc_groupby],
        metrics=[],
    )

    assert collected is None
|
||||
|
||||
Reference in New Issue
Block a user