This commit is contained in:
Beto Dealmeida
2025-10-28 15:08:35 -04:00
parent e3dec47a5e
commit 29e335aa3e
2 changed files with 665 additions and 22 deletions

View File

@@ -16,33 +16,208 @@
# under the License.
from contextlib import nullcontext
from typing import Iterator
from unittest.mock import MagicMock, patch
import pytest
from pandas import DataFrame
from pytest_mock import MockerFixture
from superset.semantic_layers.snowflake_ import (
get_connection_parameters,
SnowflakeConfiguration,
SnowflakeSemanticLayer,
SnowflakeSemanticView,
substitute_parameters,
validate_order_by,
)
from superset.semantic_layers.types import (
AdhocFilter,
DATE,
Dimension,
Filter,
FilterValues,
INTEGER,
Metric,
NUMBER,
Operator,
PredicateType,
SemanticRequest,
STRING,
)
@pytest.fixture
def configuration() -> SnowflakeConfiguration:
return SnowflakeConfiguration.model_validate(
{
"account_identifier": "abcdefg-hij01234",
"role": "ACCOUNTADMIN",
"warehouse": "COMPUTE_WH",
"database": "SAMPLE_DATA",
"schema": "TPCDS_SF10TCL",
"auth": {
"auth_type": "user_password",
"username": "SNOWFLAKE_USER",
"password": "SNOWFLAKE_PASSWORD",
},
"allow_changing_database": True,
"allow_changing_schema": True,
}
)
# These fixtures reproduce the semantic view from
# https://quickstarts.snowflake.com/guide/snowflake-semantic-view/index.html
@pytest.fixture
def dimension_rows() -> list[dict[str, str]]:
return [
dict(
zip(
["object_name", "property", "property_value"],
row,
strict=False,
)
)
for row in [
("BIRTHYEAR", "TABLE", "CUSTOMER"),
("BIRTHYEAR", "EXPRESSION", "C_BIRTH_YEAR"),
("BIRTHYEAR", "DATA_TYPE", "NUMBER(38,0)"),
("COUNTRY", "TABLE", "CUSTOMER"),
("COUNTRY", "EXPRESSION", "C_BIRTH_COUNTRY"),
("COUNTRY", "DATA_TYPE", "VARCHAR(20)"),
("C_CUSTOMER_SK", "TABLE", "CUSTOMER"),
("C_CUSTOMER_SK", "EXPRESSION", "c_customer_sk"),
("C_CUSTOMER_SK", "DATA_TYPE", "NUMBER(38,0)"),
("DATE", "TABLE", "DATE"),
("DATE", "EXPRESSION", "D_DATE"),
("DATE", "DATA_TYPE", "DATE"),
("D_DATE_SK", "TABLE", "DATE"),
("D_DATE_SK", "EXPRESSION", "d_date_sk"),
("D_DATE_SK", "DATA_TYPE", "NUMBER(38,0)"),
("MONTH", "TABLE", "DATE"),
("MONTH", "EXPRESSION", "D_MOY"),
("MONTH", "DATA_TYPE", "NUMBER(38,0)"),
("WEEK", "TABLE", "DATE"),
("WEEK", "EXPRESSION", "D_WEEK_SEQ"),
("WEEK", "DATA_TYPE", "NUMBER(38,0)"),
("YEAR", "TABLE", "DATE"),
("YEAR", "EXPRESSION", "D_YEAR"),
("YEAR", "DATA_TYPE", "NUMBER(38,0)"),
("CD_DEMO_SK", "TABLE", "DEMO"),
("CD_DEMO_SK", "EXPRESSION", "cd_demo_sk"),
("CD_DEMO_SK", "DATA_TYPE", "NUMBER(38,0)"),
("CREDIT_RATING", "TABLE", "DEMO"),
("CREDIT_RATING", "EXPRESSION", "CD_CREDIT_RATING"),
("CREDIT_RATING", "DATA_TYPE", "VARCHAR(10)"),
("MARITAL_STATUS", "TABLE", "DEMO"),
("MARITAL_STATUS", "EXPRESSION", "CD_MARITAL_STATUS"),
("MARITAL_STATUS", "DATA_TYPE", "VARCHAR(1)"),
("BRAND", "TABLE", "ITEM"),
("BRAND", "EXPRESSION", "I_BRAND"),
("BRAND", "DATA_TYPE", "VARCHAR(50)"),
("CATEGORY", "TABLE", "ITEM"),
("CATEGORY", "EXPRESSION", "I_CATEGORY"),
("CATEGORY", "DATA_TYPE", "VARCHAR(50)"),
("CLASS", "TABLE", "ITEM"),
("CLASS", "EXPRESSION", "I_CLASS"),
("CLASS", "DATA_TYPE", "VARCHAR(50)"),
("I_ITEM_SK", "TABLE", "ITEM"),
("I_ITEM_SK", "EXPRESSION", "i_item_sk"),
("I_ITEM_SK", "DATA_TYPE", "NUMBER(38,0)"),
("MARKET", "TABLE", "STORE"),
("MARKET", "EXPRESSION", "S_MARKET_ID"),
("MARKET", "DATA_TYPE", "NUMBER(38,0)"),
("SQUAREFOOTAGE", "TABLE", "STORE"),
("SQUAREFOOTAGE", "EXPRESSION", "S_FLOOR_SPACE"),
("SQUAREFOOTAGE", "DATA_TYPE", "NUMBER(38,0)"),
("STATE", "TABLE", "STORE"),
("STATE", "EXPRESSION", "S_STATE"),
("STATE", "DATA_TYPE", "VARCHAR(2)"),
("STORECOUNTRY", "TABLE", "STORE"),
("STORECOUNTRY", "EXPRESSION", "S_COUNTRY"),
("STORECOUNTRY", "DATA_TYPE", "VARCHAR(20)"),
("S_STORE_SK", "TABLE", "STORE"),
("S_STORE_SK", "EXPRESSION", "s_store_sk"),
("S_STORE_SK", "DATA_TYPE", "NUMBER(38,0)"),
("SS_CDEMO_SK", "TABLE", "STORESALES"),
("SS_CDEMO_SK", "EXPRESSION", "ss_cdemo_sk"),
("SS_CDEMO_SK", "DATA_TYPE", "NUMBER(38,0)"),
("SS_CUSTOMER_SK", "TABLE", "STORESALES"),
("SS_CUSTOMER_SK", "EXPRESSION", "ss_customer_sk"),
("SS_CUSTOMER_SK", "DATA_TYPE", "NUMBER(38,0)"),
("SS_ITEM_SK", "TABLE", "STORESALES"),
("SS_ITEM_SK", "EXPRESSION", "ss_item_sk"),
("SS_ITEM_SK", "DATA_TYPE", "NUMBER(38,0)"),
("SS_SOLD_DATE_SK", "TABLE", "STORESALES"),
("SS_SOLD_DATE_SK", "EXPRESSION", "ss_sold_date_sk"),
("SS_SOLD_DATE_SK", "DATA_TYPE", "NUMBER(38,0)"),
("SS_STORE_SK", "TABLE", "STORESALES"),
("SS_STORE_SK", "EXPRESSION", "ss_store_sk"),
("SS_STORE_SK", "DATA_TYPE", "NUMBER(38,0)"),
]
]
@pytest.fixture
def metric_rows() -> list[dict[str, str]]:
return [
dict(
zip(
["object_name", "property", "property_value"],
row,
strict=False,
)
)
for row in [
("TOTALCOST", "TABLE", "STORESALES"),
("TOTALCOST", "EXPRESSION", "SUM(item.cost)"),
("TOTALCOST", "DATA_TYPE", "NUMBER(19,2)"),
("TOTALSALESPRICE", "TABLE", "STORESALES"),
("TOTALSALESPRICE", "EXPRESSION", "SUM(SS_SALES_PRICE)"),
("TOTALSALESPRICE", "DATA_TYPE", "NUMBER(19,2)"),
("TOTALSALESQUANTITY", "TABLE", "STORESALES"),
("TOTALSALESQUANTITY", "EXPRESSION", "SUM(SS_QUANTITY)"),
("TOTALSALESQUANTITY", "DATA_TYPE", "NUMBER(38,0)"),
]
]
@pytest.fixture
def connection(mocker: MockerFixture) -> Iterator[MagicMock]:
"""
Mock the Snowflake connect function to return a mock connection.
"""
connect = mocker.patch("superset.semantic_layers.snowflake_.connect")
with connect() as connection:
yield connection
@pytest.fixture
def semantic_view(
mocker: MockerFixture,
connection: MagicMock,
configuration: SnowflakeConfiguration,
dimension_rows: list[dict[str, str]],
metric_rows: list[dict[str, str]],
) -> SnowflakeSemanticView:
"""
Mock the SnowflakeSemanticView to return predefined dimensions and metrics.
"""
connection.cursor().execute().fetchall.side_effect = [
dimension_rows,
metric_rows,
]
return SnowflakeSemanticView(configuration, "TPCDS_SEMANTIC_VIEW_SM")
@pytest.mark.parametrize(
"query, parameters, expected",
[
# No parameters
(
"SELECT * FROM table",
None,
"SELECT * FROM table",
),
(
"SELECT * FROM table",
[],
"SELECT * FROM table",
),
("SELECT * FROM table", None, "SELECT * FROM table"),
("SELECT * FROM table", [], "SELECT * FROM table"),
# NULL values
(
"SELECT * FROM table WHERE id = ?",
@@ -610,3 +785,471 @@ def test_get_runtime_schema(
assert set(schema_enum) == set(schemas)
else:
assert "schema" not in schema["properties"]
def test_get_dimensions(
mocker: MockerFixture,
connection: MagicMock,
semantic_view: SnowflakeSemanticView,
) -> None:
"""
Test dimension retrieval and parsing from Snowflake semantic layer.
"""
assert semantic_view.dimensions == {
Dimension(
id="CUSTOMER.C_CUSTOMER_SK",
name="C_CUSTOMER_SK",
type=INTEGER,
definition=None,
description="c_customer_sk",
grain=None,
),
Dimension(
id="STORE.SQUAREFOOTAGE",
name="SQUAREFOOTAGE",
type=INTEGER,
definition=None,
description="S_FLOOR_SPACE",
grain=None,
),
Dimension(
id="ITEM.BRAND",
name="BRAND",
type=STRING,
definition=None,
description="I_BRAND",
grain=None,
),
Dimension(
id="ITEM.CATEGORY",
name="CATEGORY",
type=STRING,
definition=None,
description="I_CATEGORY",
grain=None,
),
Dimension(
id="STORE.S_STORE_SK",
name="S_STORE_SK",
type=INTEGER,
definition=None,
description="s_store_sk",
grain=None,
),
Dimension(
id="STORESALES.SS_CUSTOMER_SK",
name="SS_CUSTOMER_SK",
type=INTEGER,
definition=None,
description="ss_customer_sk",
grain=None,
),
Dimension(
id="DATE.DATE",
name="DATE",
type=DATE,
definition=None,
description="D_DATE",
grain=None,
),
Dimension(
id="DEMO.CD_DEMO_SK",
name="CD_DEMO_SK",
type=INTEGER,
definition=None,
description="cd_demo_sk",
grain=None,
),
Dimension(
id="DATE.MONTH",
name="MONTH",
type=INTEGER,
definition=None,
description="D_MOY",
grain=None,
),
Dimension(
id="STORE.MARKET",
name="MARKET",
type=INTEGER,
definition=None,
description="S_MARKET_ID",
grain=None,
),
Dimension(
id="STORESALES.SS_ITEM_SK",
name="SS_ITEM_SK",
type=INTEGER,
definition=None,
description="ss_item_sk",
grain=None,
),
Dimension(
id="STORE.STORECOUNTRY",
name="STORECOUNTRY",
type=STRING,
definition=None,
description="S_COUNTRY",
grain=None,
),
Dimension(
id="ITEM.CLASS",
name="CLASS",
type=STRING,
definition=None,
description="I_CLASS",
grain=None,
),
Dimension(
id="CUSTOMER.COUNTRY",
name="COUNTRY",
type=STRING,
definition=None,
description="C_BIRTH_COUNTRY",
grain=None,
),
Dimension(
id="DEMO.CREDIT_RATING",
name="CREDIT_RATING",
type=STRING,
definition=None,
description="CD_CREDIT_RATING",
grain=None,
),
Dimension(
id="DATE.WEEK",
name="WEEK",
type=INTEGER,
definition=None,
description="D_WEEK_SEQ",
grain=None,
),
Dimension(
id="DATE.D_DATE_SK",
name="D_DATE_SK",
type=INTEGER,
definition=None,
description="d_date_sk",
grain=None,
),
Dimension(
id="STORESALES.SS_SOLD_DATE_SK",
name="SS_SOLD_DATE_SK",
type=INTEGER,
definition=None,
description="ss_sold_date_sk",
grain=None,
),
Dimension(
id="CUSTOMER.BIRTHYEAR",
name="BIRTHYEAR",
type=INTEGER,
definition=None,
description="C_BIRTH_YEAR",
grain=None,
),
Dimension(
id="DEMO.MARITAL_STATUS",
name="MARITAL_STATUS",
type=STRING,
definition=None,
description="CD_MARITAL_STATUS",
grain=None,
),
Dimension(
id="STORESALES.SS_CDEMO_SK",
name="SS_CDEMO_SK",
type=INTEGER,
definition=None,
description="ss_cdemo_sk",
grain=None,
),
Dimension(
id="DATE.YEAR",
name="YEAR",
type=INTEGER,
definition=None,
description="D_YEAR",
grain=None,
),
Dimension(
id="ITEM.I_ITEM_SK",
name="I_ITEM_SK",
type=INTEGER,
definition=None,
description="i_item_sk",
grain=None,
),
Dimension(
id="STORESALES.SS_STORE_SK",
name="SS_STORE_SK",
type=INTEGER,
definition=None,
description="ss_store_sk",
grain=None,
),
Dimension(
id="STORE.STATE",
name="STATE",
type=STRING,
definition=None,
description="S_STATE",
grain=None,
),
}
connection.cursor().execute.assert_any_call(
"""
DESC SEMANTIC VIEW "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
->> SELECT "object_name", "property", "property_value"
FROM $1
WHERE
"object_kind" = 'DIMENSION' AND
"property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
""".strip()
)
def test_get_metrics(
mocker: MockerFixture,
connection: MagicMock,
semantic_view: SnowflakeSemanticView,
) -> None:
"""
Test metric retrieval and parsing from Snowflake semantic layer.
"""
assert semantic_view.metrics == {
Metric(
id="STORESALES.TOTALCOST",
name="TOTALCOST",
type=NUMBER,
definition="SUM(item.cost)",
description=None,
),
Metric(
id="STORESALES.TOTALSALESQUANTITY",
name="TOTALSALESQUANTITY",
type=INTEGER,
definition="SUM(SS_QUANTITY)",
description=None,
),
Metric(
id="STORESALES.TOTALSALESPRICE",
name="TOTALSALESPRICE",
type=NUMBER,
definition="SUM(SS_SALES_PRICE)",
description=None,
),
}
connection.cursor().execute.assert_any_call(
"""
DESC SEMANTIC VIEW "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
->> SELECT "object_name", "property", "property_value"
FROM $1
WHERE
"object_kind" = 'METRIC' AND
"property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
""".strip()
)
def test_get_values(
mocker: MockerFixture,
connection: MagicMock,
semantic_view: SnowflakeSemanticView,
) -> None:
connection.cursor().execute().fetch_pandas_all.return_value = DataFrame(
{
"CATEGORY": [
"Music",
"Women",
"Home",
"Children",
"Men",
"Electronics",
"Sports",
"Shoes",
"Jewelry",
"Books",
None,
]
}
)
dimension = Dimension(
id="ITEM.CATEGORY",
name="CATEGORY",
type=STRING,
description=None,
definition="I_CATEGORY",
grain=None,
)
result = semantic_view.get_values(dimension)
assert result.requests == [
SemanticRequest(
type="snowflake",
definition="""
SELECT "CATEGORY"
FROM SEMANTIC_VIEW(
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
DIMENSIONS ITEM.CATEGORY
)
""".strip(),
)
]
assert result.results["CATEGORY"].tolist() == [
"Music",
"Women",
"Home",
"Children",
"Men",
"Electronics",
"Sports",
"Shoes",
"Jewelry",
"Books",
None,
]
def test_get_values_with_filters(
mocker: MockerFixture,
connection: MagicMock,
semantic_view: SnowflakeSemanticView,
) -> None:
connection.cursor().execute().fetch_pandas_all.return_value = DataFrame(
{
"CATEGORY": [
"Music",
"Women",
"Home",
"Children",
"Men",
"Electronics",
"Sports",
"Shoes",
"Jewelry",
]
}
)
dimension = Dimension(
id="ITEM.CATEGORY",
name="CATEGORY",
type=STRING,
description=None,
definition="I_CATEGORY",
grain=None,
)
filters = {
Filter(PredicateType.WHERE, dimension, Operator.NOT_EQUALS, "Books"),
Filter(PredicateType.WHERE, dimension, Operator.IS_NOT_NULL, None),
}
result = semantic_view.get_values(dimension, filters)
assert result.requests == [
SemanticRequest(
type="snowflake",
definition="""
SELECT "CATEGORY"
FROM SEMANTIC_VIEW(
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
DIMENSIONS ITEM.CATEGORY
WHERE ("CATEGORY" != 'Books') AND ("CATEGORY" IS NOT NULL)
)
""".strip(),
)
]
assert result.results["CATEGORY"].tolist() == [
"Music",
"Women",
"Home",
"Children",
"Men",
"Electronics",
"Sports",
"Shoes",
"Jewelry",
]
@pytest.mark.parametrize(
"metrics, dimensions, filters, expected_sql, expected_parameters",
[
(
["TOTALSALESPRICE"],
[],
{
AdhocFilter(PredicateType.WHERE, "Year = '2002'"),
AdhocFilter(PredicateType.WHERE, "Month = '12'"),
},
"""
SELECT * FROM SEMANTIC_VIEW(
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
METRICS STORESALES.TOTALSALESPRICE AS "STORESALES.TOTALSALESPRICE"
WHERE (Month = '12') AND (Year = '2002')
)
""",
(),
),
(
[],
["CATEGORY"],
{
AdhocFilter(PredicateType.WHERE, "Year = '2002'"),
AdhocFilter(PredicateType.WHERE, "Month = '12'"),
},
"""
SELECT * FROM SEMANTIC_VIEW(
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
DIMENSIONS ITEM.CATEGORY AS "ITEM.CATEGORY"
WHERE (Month = '12') AND (Year = '2002')
)
""",
(),
),
(
["TOTALSALESPRICE"],
["CATEGORY"],
{
AdhocFilter(PredicateType.WHERE, "Year = '2002'"),
AdhocFilter(PredicateType.WHERE, "Month = '12'"),
},
"""
SELECT * FROM SEMANTIC_VIEW(
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
DIMENSIONS ITEM.CATEGORY AS "ITEM.CATEGORY"
METRICS STORESALES.TOTALSALESPRICE AS "STORESALES.TOTALSALESPRICE"
WHERE (Month = '12') AND (Year = '2002')
)
""",
(),
),
],
)
def test_get_query(
semantic_view: SnowflakeSemanticView,
metrics: list[str],
dimensions: list[str],
filters: set[Filter | AdhocFilter] | None,
expected_sql: str,
expected_parameters: tuple[FilterValues, ...],
) -> None:
"""
Tests for query generation.
"""
metric_map = {metric.name: metric for metric in semantic_view.metrics}
dimension_map = {dim.name: dim for dim in semantic_view.dimensions}
assert semantic_view._get_query(
[metric_map[name] for name in metrics],
[dimension_map[name] for name in dimensions],
filters,
) == (expected_sql.strip(), expected_parameters)