diff --git a/superset/semantic_layers/snowflake_.py b/superset/semantic_layers/snowflake_.py index 176cc36e3df..01832c34e21 100644 --- a/superset/semantic_layers/snowflake_.py +++ b/superset/semantic_layers/snowflake_.py @@ -414,7 +414,7 @@ class SnowflakeSemanticLayer: FROM INFORMATION_SCHEMA.SCHEMATA WHERE CATALOG_NAME = ? """ - ) + ).strip() return {row[0] for row in cursor.execute(query, (database,))} def __init__(self, configuration: SnowflakeConfiguration): @@ -440,7 +440,7 @@ class SnowflakeSemanticLayer: SHOW SEMANTIC VIEWS ->> SELECT "name" FROM $1; """ - ) + ).strip() return { SnowflakeSemanticView(configuration, row[0]) for row in cursor.execute(query) @@ -495,7 +495,7 @@ class SnowflakeSemanticView: "object_kind" = 'DIMENSION' AND "property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE'); """ - ) + ).strip() connection_parameters = get_connection_parameters(self.configuration) with connect(**connection_parameters) as connection: @@ -532,7 +532,7 @@ class SnowflakeSemanticView: "object_kind" = 'METRIC' AND "property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE'); """ - ) + ).strip() connection_parameters = get_connection_parameters(self.configuration) with connect(**connection_parameters) as connection: @@ -584,7 +584,7 @@ class SnowflakeSemanticView: def _build_predicates( self, filters: set[Filter | AdhocFilter], - ) -> tuple[str, tuple[FilterValues]]: + ) -> tuple[str, tuple[FilterValues, ...]]: """ Convert a set of filters to a single `AND`ed predicate. @@ -637,7 +637,7 @@ class SnowflakeSemanticView: {"WHERE " + where_clause if where_clause else ""} ) """ - ) + ).strip() connection_parameters = get_connection_parameters(self.configuration) with connect(**connection_parameters) as connection: df = connection.cursor().execute(query, parameters).fetch_pandas_all() @@ -765,7 +765,7 @@ class SnowflakeSemanticView: limit: int | None = None, offset: int | None = None, group_limit: GroupLimit | None = None, - ) -> tuple[str, tuple[FilterValues]]: + ) -> tuple[str, tuple[FilterValues, ...]]: """ Build a query to fetch data from the semantic view. @@ -883,7 +883,7 @@ class SnowflakeSemanticView: {"LIMIT " + str(limit) if limit is not None else ""} {"OFFSET " + str(offset) if offset is not None else ""} """ - ) + ).strip() def _build_top_groups_cte( self, @@ -948,7 +948,7 @@ class SnowflakeSemanticView: LIMIT {group_limit.top} ) """ - ) + ).strip() return cte_sql, cte_params @@ -1076,7 +1076,7 @@ class SnowflakeSemanticView: {"HAVING " + having_clause if having_clause else ""} ) """ - ) + ).strip() # Build GROUP BY clause (full CASE expressions + non-limited dimensions) # We need to repeat the full CASE expressions, not use aliases, because @@ -1105,7 +1105,7 @@ class SnowflakeSemanticView: {"LIMIT " + str(limit) if limit is not None else ""} {"OFFSET " + str(offset) if offset is not None else ""} """ - ) + ).strip() return query, cte_params @@ -1173,7 +1173,7 @@ class SnowflakeSemanticView: {"LIMIT " + str(limit) if limit is not None else ""} {"OFFSET " + str(offset) if offset is not None else ""} """ - ) + ).strip() return query, cte_params diff --git a/tests/unit_tests/semantic_layers/test_snowflake.py b/tests/unit_tests/semantic_layers/test_snowflake.py index 95c521b04ba..3d260ae63da 100644 --- a/tests/unit_tests/semantic_layers/test_snowflake.py +++ b/tests/unit_tests/semantic_layers/test_snowflake.py @@ -16,33 +16,208 @@ # under the License. from contextlib import nullcontext +from typing import Iterator from unittest.mock import MagicMock, patch import pytest +from pandas import DataFrame +from pytest_mock import MockerFixture from superset.semantic_layers.snowflake_ import ( get_connection_parameters, SnowflakeConfiguration, SnowflakeSemanticLayer, + SnowflakeSemanticView, substitute_parameters, validate_order_by, ) +from superset.semantic_layers.types import ( + AdhocFilter, + DATE, + Dimension, + Filter, + FilterValues, + INTEGER, + Metric, + NUMBER, + Operator, + PredicateType, + SemanticRequest, + STRING, +) + + +@pytest.fixture +def configuration() -> SnowflakeConfiguration: + return SnowflakeConfiguration.model_validate( + { + "account_identifier": "abcdefg-hij01234", + "role": "ACCOUNTADMIN", + "warehouse": "COMPUTE_WH", + "database": "SAMPLE_DATA", + "schema": "TPCDS_SF10TCL", + "auth": { + "auth_type": "user_password", + "username": "SNOWFLAKE_USER", + "password": "SNOWFLAKE_PASSWORD", + }, + "allow_changing_database": True, + "allow_changing_schema": True, + } + ) + + +# These fixtures reproduce the semantic view from +# https://quickstarts.snowflake.com/guide/snowflake-semantic-view/index.html +@pytest.fixture +def dimension_rows() -> list[dict[str, str]]: + return [ + dict( + zip( + ["object_name", "property", "property_value"], + row, + strict=False, + ) + ) + for row in [ + ("BIRTHYEAR", "TABLE", "CUSTOMER"), + ("BIRTHYEAR", "EXPRESSION", "C_BIRTH_YEAR"), + ("BIRTHYEAR", "DATA_TYPE", "NUMBER(38,0)"), + ("COUNTRY", "TABLE", "CUSTOMER"), + ("COUNTRY", "EXPRESSION", "C_BIRTH_COUNTRY"), + ("COUNTRY", "DATA_TYPE", "VARCHAR(20)"), + ("C_CUSTOMER_SK", "TABLE", "CUSTOMER"), + ("C_CUSTOMER_SK", "EXPRESSION", "c_customer_sk"), + ("C_CUSTOMER_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("DATE", "TABLE", "DATE"), + ("DATE", "EXPRESSION", "D_DATE"), + ("DATE", "DATA_TYPE", "DATE"), + ("D_DATE_SK", "TABLE", "DATE"), + ("D_DATE_SK", "EXPRESSION", "d_date_sk"), + ("D_DATE_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("MONTH", "TABLE", "DATE"), + ("MONTH", "EXPRESSION", "D_MOY"), + ("MONTH", "DATA_TYPE", "NUMBER(38,0)"), + ("WEEK", "TABLE", "DATE"), + ("WEEK", "EXPRESSION", "D_WEEK_SEQ"), + ("WEEK", "DATA_TYPE", "NUMBER(38,0)"), + ("YEAR", "TABLE", "DATE"), + ("YEAR", "EXPRESSION", "D_YEAR"), + ("YEAR", "DATA_TYPE", "NUMBER(38,0)"), + ("CD_DEMO_SK", "TABLE", "DEMO"), + ("CD_DEMO_SK", "EXPRESSION", "cd_demo_sk"), + ("CD_DEMO_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("CREDIT_RATING", "TABLE", "DEMO"), + ("CREDIT_RATING", "EXPRESSION", "CD_CREDIT_RATING"), + ("CREDIT_RATING", "DATA_TYPE", "VARCHAR(10)"), + ("MARITAL_STATUS", "TABLE", "DEMO"), + ("MARITAL_STATUS", "EXPRESSION", "CD_MARITAL_STATUS"), + ("MARITAL_STATUS", "DATA_TYPE", "VARCHAR(1)"), + ("BRAND", "TABLE", "ITEM"), + ("BRAND", "EXPRESSION", "I_BRAND"), + ("BRAND", "DATA_TYPE", "VARCHAR(50)"), + ("CATEGORY", "TABLE", "ITEM"), + ("CATEGORY", "EXPRESSION", "I_CATEGORY"), + ("CATEGORY", "DATA_TYPE", "VARCHAR(50)"), + ("CLASS", "TABLE", "ITEM"), + ("CLASS", "EXPRESSION", "I_CLASS"), + ("CLASS", "DATA_TYPE", "VARCHAR(50)"), + ("I_ITEM_SK", "TABLE", "ITEM"), + ("I_ITEM_SK", "EXPRESSION", "i_item_sk"), + ("I_ITEM_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("MARKET", "TABLE", "STORE"), + ("MARKET", "EXPRESSION", "S_MARKET_ID"), + ("MARKET", "DATA_TYPE", "NUMBER(38,0)"), + ("SQUAREFOOTAGE", "TABLE", "STORE"), + ("SQUAREFOOTAGE", "EXPRESSION", "S_FLOOR_SPACE"), + ("SQUAREFOOTAGE", "DATA_TYPE", "NUMBER(38,0)"), + ("STATE", "TABLE", "STORE"), + ("STATE", "EXPRESSION", "S_STATE"), + ("STATE", "DATA_TYPE", "VARCHAR(2)"), + ("STORECOUNTRY", "TABLE", "STORE"), + ("STORECOUNTRY", "EXPRESSION", "S_COUNTRY"), + ("STORECOUNTRY", "DATA_TYPE", "VARCHAR(20)"), + ("S_STORE_SK", "TABLE", "STORE"), + ("S_STORE_SK", "EXPRESSION", "s_store_sk"), + ("S_STORE_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("SS_CDEMO_SK", "TABLE", "STORESALES"), + ("SS_CDEMO_SK", "EXPRESSION", "ss_cdemo_sk"), + ("SS_CDEMO_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("SS_CUSTOMER_SK", "TABLE", "STORESALES"), + ("SS_CUSTOMER_SK", "EXPRESSION", "ss_customer_sk"), + ("SS_CUSTOMER_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("SS_ITEM_SK", "TABLE", "STORESALES"), + ("SS_ITEM_SK", "EXPRESSION", "ss_item_sk"), + ("SS_ITEM_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("SS_SOLD_DATE_SK", "TABLE", "STORESALES"), + ("SS_SOLD_DATE_SK", "EXPRESSION", "ss_sold_date_sk"), + ("SS_SOLD_DATE_SK", "DATA_TYPE", "NUMBER(38,0)"), + ("SS_STORE_SK", "TABLE", "STORESALES"), + ("SS_STORE_SK", "EXPRESSION", "ss_store_sk"), + ("SS_STORE_SK", "DATA_TYPE", "NUMBER(38,0)"), + ] + ] + + +@pytest.fixture +def metric_rows() -> list[dict[str, str]]: + return [ + dict( + zip( + ["object_name", "property", "property_value"], + row, + strict=False, + ) + ) + for row in [ + ("TOTALCOST", "TABLE", "STORESALES"), + ("TOTALCOST", "EXPRESSION", "SUM(item.cost)"), + ("TOTALCOST", "DATA_TYPE", "NUMBER(19,2)"), + ("TOTALSALESPRICE", "TABLE", "STORESALES"), + ("TOTALSALESPRICE", "EXPRESSION", "SUM(SS_SALES_PRICE)"), + ("TOTALSALESPRICE", "DATA_TYPE", "NUMBER(19,2)"), + ("TOTALSALESQUANTITY", "TABLE", "STORESALES"), + ("TOTALSALESQUANTITY", "EXPRESSION", "SUM(SS_QUANTITY)"), + ("TOTALSALESQUANTITY", "DATA_TYPE", "NUMBER(38,0)"), + ] + ] + + +@pytest.fixture +def connection(mocker: MockerFixture) -> Iterator[MagicMock]: + """ + Mock the Snowflake connect function to return a mock connection. + """ + connect = mocker.patch("superset.semantic_layers.snowflake_.connect") + with connect() as connection: + yield connection + + +@pytest.fixture +def semantic_view( + mocker: MockerFixture, + connection: MagicMock, + configuration: SnowflakeConfiguration, + dimension_rows: list[dict[str, str]], + metric_rows: list[dict[str, str]], +) -> SnowflakeSemanticView: + """ + Mock the SnowflakeSemanticView to return predefined dimensions and metrics. + """ + connection.cursor().execute().fetchall.side_effect = [ + dimension_rows, + metric_rows, + ] + + return SnowflakeSemanticView(configuration, "TPCDS_SEMANTIC_VIEW_SM") @pytest.mark.parametrize( "query, parameters, expected", [ # No parameters - ( - "SELECT * FROM table", - None, - "SELECT * FROM table", - ), - ( - "SELECT * FROM table", - [], - "SELECT * FROM table", - ), + ("SELECT * FROM table", None, "SELECT * FROM table"), + ("SELECT * FROM table", [], "SELECT * FROM table"), # NULL values ( "SELECT * FROM table WHERE id = ?", @@ -610,3 +785,471 @@ def test_get_runtime_schema( assert set(schema_enum) == set(schemas) else: assert "schema" not in schema["properties"] + + +def test_get_dimensions( + mocker: MockerFixture, + connection: MagicMock, + semantic_view: SnowflakeSemanticView, +) -> None: + """ + Test dimension retrieval and parsing from Snowflake semantic layer. + """ + assert semantic_view.dimensions == { + Dimension( + id="CUSTOMER.C_CUSTOMER_SK", + name="C_CUSTOMER_SK", + type=INTEGER, + definition=None, + description="c_customer_sk", + grain=None, + ), + Dimension( + id="STORE.SQUAREFOOTAGE", + name="SQUAREFOOTAGE", + type=INTEGER, + definition=None, + description="S_FLOOR_SPACE", + grain=None, + ), + Dimension( + id="ITEM.BRAND", + name="BRAND", + type=STRING, + definition=None, + description="I_BRAND", + grain=None, + ), + Dimension( + id="ITEM.CATEGORY", + name="CATEGORY", + type=STRING, + definition=None, + description="I_CATEGORY", + grain=None, + ), + Dimension( + id="STORE.S_STORE_SK", + name="S_STORE_SK", + type=INTEGER, + definition=None, + description="s_store_sk", + grain=None, + ), + Dimension( + id="STORESALES.SS_CUSTOMER_SK", + name="SS_CUSTOMER_SK", + type=INTEGER, + definition=None, + description="ss_customer_sk", + grain=None, + ), + Dimension( + id="DATE.DATE", + name="DATE", + type=DATE, + definition=None, + description="D_DATE", + grain=None, + ), + Dimension( + id="DEMO.CD_DEMO_SK", + name="CD_DEMO_SK", + type=INTEGER, + definition=None, + description="cd_demo_sk", + grain=None, + ), + Dimension( + id="DATE.MONTH", + name="MONTH", + type=INTEGER, + definition=None, + description="D_MOY", + grain=None, + ), + Dimension( + id="STORE.MARKET", + name="MARKET", + type=INTEGER, + definition=None, + description="S_MARKET_ID", + grain=None, + ), + Dimension( + id="STORESALES.SS_ITEM_SK", + name="SS_ITEM_SK", + type=INTEGER, + definition=None, + description="ss_item_sk", + grain=None, + ), + Dimension( + id="STORE.STORECOUNTRY", + name="STORECOUNTRY", + type=STRING, + definition=None, + description="S_COUNTRY", + grain=None, + ), + Dimension( + id="ITEM.CLASS", + name="CLASS", + type=STRING, + definition=None, + description="I_CLASS", + grain=None, + ), + Dimension( + id="CUSTOMER.COUNTRY", + name="COUNTRY", + type=STRING, + definition=None, + description="C_BIRTH_COUNTRY", + grain=None, + ), + Dimension( + id="DEMO.CREDIT_RATING", + name="CREDIT_RATING", + type=STRING, + definition=None, + description="CD_CREDIT_RATING", + grain=None, + ), + Dimension( + id="DATE.WEEK", + name="WEEK", + type=INTEGER, + definition=None, + description="D_WEEK_SEQ", + grain=None, + ), + Dimension( + id="DATE.D_DATE_SK", + name="D_DATE_SK", + type=INTEGER, + definition=None, + description="d_date_sk", + grain=None, + ), + Dimension( + id="STORESALES.SS_SOLD_DATE_SK", + name="SS_SOLD_DATE_SK", + type=INTEGER, + definition=None, + description="ss_sold_date_sk", + grain=None, + ), + Dimension( + id="CUSTOMER.BIRTHYEAR", + name="BIRTHYEAR", + type=INTEGER, + definition=None, + description="C_BIRTH_YEAR", + grain=None, + ), + Dimension( + id="DEMO.MARITAL_STATUS", + name="MARITAL_STATUS", + type=STRING, + definition=None, + description="CD_MARITAL_STATUS", + grain=None, + ), + Dimension( + id="STORESALES.SS_CDEMO_SK", + name="SS_CDEMO_SK", + type=INTEGER, + definition=None, + description="ss_cdemo_sk", + grain=None, + ), + Dimension( + id="DATE.YEAR", + name="YEAR", + type=INTEGER, + definition=None, + description="D_YEAR", + grain=None, + ), + Dimension( + id="ITEM.I_ITEM_SK", + name="I_ITEM_SK", + type=INTEGER, + definition=None, + description="i_item_sk", + grain=None, + ), + Dimension( + id="STORESALES.SS_STORE_SK", + name="SS_STORE_SK", + type=INTEGER, + definition=None, + description="ss_store_sk", + grain=None, + ), + Dimension( + id="STORE.STATE", + name="STATE", + type=STRING, + definition=None, + description="S_STATE", + grain=None, + ), + } + + connection.cursor().execute.assert_any_call( + """ +DESC SEMANTIC VIEW "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM" + ->> SELECT "object_name", "property", "property_value" + FROM $1 + WHERE + "object_kind" = 'DIMENSION' AND + "property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE'); + """.strip() + ) + + +def test_get_metrics( + mocker: MockerFixture, + connection: MagicMock, + semantic_view: SnowflakeSemanticView, +) -> None: + """ + Test metric retrieval and parsing from Snowflake semantic layer. + """ + assert semantic_view.metrics == { + Metric( + id="STORESALES.TOTALCOST", + name="TOTALCOST", + type=NUMBER, + definition="SUM(item.cost)", + description=None, + ), + Metric( + id="STORESALES.TOTALSALESQUANTITY", + name="TOTALSALESQUANTITY", + type=INTEGER, + definition="SUM(SS_QUANTITY)", + description=None, + ), + Metric( + id="STORESALES.TOTALSALESPRICE", + name="TOTALSALESPRICE", + type=NUMBER, + definition="SUM(SS_SALES_PRICE)", + description=None, + ), + } + + connection.cursor().execute.assert_any_call( + """ +DESC SEMANTIC VIEW "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM" + ->> SELECT "object_name", "property", "property_value" + FROM $1 + WHERE + "object_kind" = 'METRIC' AND + "property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE'); + """.strip() + ) + + +def test_get_values( + mocker: MockerFixture, + connection: MagicMock, + semantic_view: SnowflakeSemanticView, +) -> None: + connection.cursor().execute().fetch_pandas_all.return_value = DataFrame( + { + "CATEGORY": [ + "Music", + "Women", + "Home", + "Children", + "Men", + "Electronics", + "Sports", + "Shoes", + "Jewelry", + "Books", + None, + ] + } + ) + + dimension = Dimension( + id="ITEM.CATEGORY", + name="CATEGORY", + type=STRING, + description=None, + definition="I_CATEGORY", + grain=None, + ) + + result = semantic_view.get_values(dimension) + + assert result.requests == [ + SemanticRequest( + type="snowflake", + definition=""" +SELECT "CATEGORY" +FROM SEMANTIC_VIEW( + "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM" + DIMENSIONS ITEM.CATEGORY + +) + """.strip(), + ) + ] + assert result.results["CATEGORY"].tolist() == [ + "Music", + "Women", + "Home", + "Children", + "Men", + "Electronics", + "Sports", + "Shoes", + "Jewelry", + "Books", + None, + ] + + +def test_get_values_with_filters( + mocker: MockerFixture, + connection: MagicMock, + semantic_view: SnowflakeSemanticView, +) -> None: + connection.cursor().execute().fetch_pandas_all.return_value = DataFrame( + { + "CATEGORY": [ + "Music", + "Women", + "Home", + "Children", + "Men", + "Electronics", + "Sports", + "Shoes", + "Jewelry", + ] + } + ) + + dimension = Dimension( + id="ITEM.CATEGORY", + name="CATEGORY", + type=STRING, + description=None, + definition="I_CATEGORY", + grain=None, + ) + filters = { + Filter(PredicateType.WHERE, dimension, Operator.NOT_EQUALS, "Books"), + Filter(PredicateType.WHERE, dimension, Operator.IS_NOT_NULL, None), + } + + result = semantic_view.get_values(dimension, filters) + + assert result.requests == [ + SemanticRequest( + type="snowflake", + definition=""" +SELECT "CATEGORY" +FROM SEMANTIC_VIEW( + "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM" + DIMENSIONS ITEM.CATEGORY + WHERE ("CATEGORY" != 'Books') AND ("CATEGORY" IS NOT NULL) +) + """.strip(), + ) + ] + assert result.results["CATEGORY"].tolist() == [ + "Music", + "Women", + "Home", + "Children", + "Men", + "Electronics", + "Sports", + "Shoes", + "Jewelry", + ] + + +@pytest.mark.parametrize( + "metrics, dimensions, filters, expected_sql, expected_parameters", + [ + ( + ["TOTALSALESPRICE"], + [], + { + AdhocFilter(PredicateType.WHERE, "Year = '2002'"), + AdhocFilter(PredicateType.WHERE, "Month = '12'"), + }, + """ +SELECT * FROM SEMANTIC_VIEW( + "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM" + + METRICS STORESALES.TOTALSALESPRICE AS "STORESALES.TOTALSALESPRICE" + WHERE (Month = '12') AND (Year = '2002') +) + """, + (), + ), + ( + [], + ["CATEGORY"], + { + AdhocFilter(PredicateType.WHERE, "Year = '2002'"), + AdhocFilter(PredicateType.WHERE, "Month = '12'"), + }, + """ +SELECT * FROM SEMANTIC_VIEW( + "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM" + DIMENSIONS ITEM.CATEGORY AS "ITEM.CATEGORY" + + WHERE (Month = '12') AND (Year = '2002') +) + """, + (), + ), + ( + ["TOTALSALESPRICE"], + ["CATEGORY"], + { + AdhocFilter(PredicateType.WHERE, "Year = '2002'"), + AdhocFilter(PredicateType.WHERE, "Month = '12'"), + }, + """ +SELECT * FROM SEMANTIC_VIEW( + "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM" + DIMENSIONS ITEM.CATEGORY AS "ITEM.CATEGORY" + METRICS STORESALES.TOTALSALESPRICE AS "STORESALES.TOTALSALESPRICE" + WHERE (Month = '12') AND (Year = '2002') +) + """, + (), + ), + ], +) +def test_get_query( + semantic_view: SnowflakeSemanticView, + metrics: list[str], + dimensions: list[str], + filters: set[Filter | AdhocFilter] | None, + expected_sql: str, + expected_parameters: tuple[FilterValues, ...], +) -> None: + """ + Tests for query generation. + """ + metric_map = {metric.name: metric for metric in semantic_view.metrics} + dimension_map = {dim.name: dim for dim in semantic_view.dimensions} + + assert semantic_view._get_query( + [metric_map[name] for name in metrics], + [dimension_map[name] for name in dimensions], + filters, + ) == (expected_sql.strip(), expected_parameters)