mirror of
https://github.com/apache/superset.git
synced 2026-05-07 17:04:58 +00:00
1286 lines
40 KiB
Python
1286 lines
40 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
from contextlib import nullcontext
|
|
from typing import Iterator
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from pandas import DataFrame
|
|
from pytest_mock import MockerFixture
|
|
|
|
from superset.semantic_layers.snowflake_ import (
|
|
get_connection_parameters,
|
|
SnowflakeConfiguration,
|
|
SnowflakeSemanticLayer,
|
|
SnowflakeSemanticView,
|
|
substitute_parameters,
|
|
validate_order_by,
|
|
)
|
|
from superset.semantic_layers.types import (
|
|
AdhocFilter,
|
|
DATE,
|
|
Dimension,
|
|
Filter,
|
|
FilterValues,
|
|
INTEGER,
|
|
Metric,
|
|
NUMBER,
|
|
Operator,
|
|
OrderDirection,
|
|
PredicateType,
|
|
SemanticRequest,
|
|
STRING,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def configuration() -> SnowflakeConfiguration:
    """
    Build the Snowflake configuration shared by the semantic view tests.

    The payload uses password authentication and allows changing both the
    database and the schema.
    """
    credentials = {
        "auth_type": "user_password",
        "username": "SNOWFLAKE_USER",
        "password": "SNOWFLAKE_PASSWORD",
    }
    payload = {
        "account_identifier": "abcdefg-hij01234",
        "role": "ACCOUNTADMIN",
        "warehouse": "COMPUTE_WH",
        "database": "SAMPLE_DATA",
        "schema": "TPCDS_SF10TCL",
        "auth": credentials,
        "allow_changing_database": True,
        "allow_changing_schema": True,
    }
    return SnowflakeConfiguration.model_validate(payload)
|
|
|
|
|
|
# These fixtures reproduce the semantic view from
|
|
# https://quickstarts.snowflake.com/guide/snowflake-semantic-view/index.html
|
|
@pytest.fixture
def dimension_rows() -> list[dict[str, str]]:
    """
    Rows describing the dimensions of the semantic view.

    Every dimension contributes three rows, in the order ``TABLE``,
    ``EXPRESSION`` and ``DATA_TYPE``; each row is a dict keyed by
    ``object_name``, ``property`` and ``property_value``.
    """
    # (object name, table, expression, data type)
    catalog = [
        ("BIRTHYEAR", "CUSTOMER", "C_BIRTH_YEAR", "NUMBER(38,0)"),
        ("COUNTRY", "CUSTOMER", "C_BIRTH_COUNTRY", "VARCHAR(20)"),
        ("C_CUSTOMER_SK", "CUSTOMER", "c_customer_sk", "NUMBER(38,0)"),
        ("DATE", "DATE", "D_DATE", "DATE"),
        ("D_DATE_SK", "DATE", "d_date_sk", "NUMBER(38,0)"),
        ("MONTH", "DATE", "D_MOY", "NUMBER(38,0)"),
        ("WEEK", "DATE", "D_WEEK_SEQ", "NUMBER(38,0)"),
        ("YEAR", "DATE", "D_YEAR", "NUMBER(38,0)"),
        ("CD_DEMO_SK", "DEMO", "cd_demo_sk", "NUMBER(38,0)"),
        ("CREDIT_RATING", "DEMO", "CD_CREDIT_RATING", "VARCHAR(10)"),
        ("MARITAL_STATUS", "DEMO", "CD_MARITAL_STATUS", "VARCHAR(1)"),
        ("BRAND", "ITEM", "I_BRAND", "VARCHAR(50)"),
        ("CATEGORY", "ITEM", "I_CATEGORY", "VARCHAR(50)"),
        ("CLASS", "ITEM", "I_CLASS", "VARCHAR(50)"),
        ("I_ITEM_SK", "ITEM", "i_item_sk", "NUMBER(38,0)"),
        ("MARKET", "STORE", "S_MARKET_ID", "NUMBER(38,0)"),
        ("SQUAREFOOTAGE", "STORE", "S_FLOOR_SPACE", "NUMBER(38,0)"),
        ("STATE", "STORE", "S_STATE", "VARCHAR(2)"),
        ("STORECOUNTRY", "STORE", "S_COUNTRY", "VARCHAR(20)"),
        ("S_STORE_SK", "STORE", "s_store_sk", "NUMBER(38,0)"),
        ("SS_CDEMO_SK", "STORESALES", "ss_cdemo_sk", "NUMBER(38,0)"),
        ("SS_CUSTOMER_SK", "STORESALES", "ss_customer_sk", "NUMBER(38,0)"),
        ("SS_ITEM_SK", "STORESALES", "ss_item_sk", "NUMBER(38,0)"),
        ("SS_SOLD_DATE_SK", "STORESALES", "ss_sold_date_sk", "NUMBER(38,0)"),
        ("SS_STORE_SK", "STORESALES", "ss_store_sk", "NUMBER(38,0)"),
    ]

    return [
        {"object_name": name, "property": prop, "property_value": value}
        for name, table, expression, data_type in catalog
        for prop, value in (
            ("TABLE", table),
            ("EXPRESSION", expression),
            ("DATA_TYPE", data_type),
        )
    ]
|
|
|
|
|
|
@pytest.fixture
def metric_rows() -> list[dict[str, str]]:
    """
    Rows describing the metrics of the semantic view.

    Every metric contributes three rows, in the order ``TABLE``,
    ``EXPRESSION`` and ``DATA_TYPE``; each row is a dict keyed by
    ``object_name``, ``property`` and ``property_value``.
    """
    # (object name, table, expression, data type)
    catalog = [
        ("TOTALCOST", "STORESALES", "SUM(item.cost)", "NUMBER(19,2)"),
        ("TOTALSALESPRICE", "STORESALES", "SUM(SS_SALES_PRICE)", "NUMBER(19,2)"),
        ("TOTALSALESQUANTITY", "STORESALES", "SUM(SS_QUANTITY)", "NUMBER(38,0)"),
    ]

    return [
        {"object_name": name, "property": prop, "property_value": value}
        for name, table, expression, data_type in catalog
        for prop, value in (
            ("TABLE", table),
            ("EXPRESSION", expression),
            ("DATA_TYPE", data_type),
        )
    ]
|
|
|
|
|
|
@pytest.fixture
def connection(mocker: MockerFixture) -> Iterator[MagicMock]:
    """
    Patch the Snowflake ``connect`` function used by the module under test.

    Yields the object obtained by entering the context manager that the
    patched ``connect`` returns.
    """
    patched_connect = mocker.patch("superset.semantic_layers.snowflake_.connect")
    with patched_connect() as mocked_connection:
        yield mocked_connection
|
|
|
|
|
|
@pytest.fixture
def semantic_view(
    mocker: MockerFixture,
    connection: MagicMock,
    configuration: SnowflakeConfiguration,
    dimension_rows: list[dict[str, str]],
    metric_rows: list[dict[str, str]],
) -> SnowflakeSemanticView:
    """
    Create a SnowflakeSemanticView backed by the mocked connection.

    The mocked cursor is primed so that successive ``fetchall`` calls return
    the dimension rows first and the metric rows second.
    """
    fetchall = connection.cursor().execute().fetchall
    fetchall.side_effect = [dimension_rows, metric_rows]

    view_name = "TPCDS_SEMANTIC_VIEW_SM"
    return SnowflakeSemanticView(configuration, view_name)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "query, parameters, expected",
    [
        # Nothing to substitute
        ("SELECT * FROM table", None, "SELECT * FROM table"),
        ("SELECT * FROM table", [], "SELECT * FROM table"),
        # None becomes NULL
        (
            "SELECT * FROM table WHERE id = ?",
            [None],
            "SELECT * FROM table WHERE id = NULL",
        ),
        # Integers
        (
            "SELECT * FROM table WHERE id = ?",
            [123],
            "SELECT * FROM table WHERE id = 123",
        ),
        (
            "SELECT * FROM table WHERE id = ? AND status = ?",
            [123, 456],
            "SELECT * FROM table WHERE id = 123 AND status = 456",
        ),
        # Floats
        (
            "SELECT * FROM table WHERE price = ?",
            [99.99],
            "SELECT * FROM table WHERE price = 99.99",
        ),
        (
            "SELECT * FROM table WHERE price BETWEEN ? AND ?",
            [10.5, 99.99],
            "SELECT * FROM table WHERE price BETWEEN 10.5 AND 99.99",
        ),
        # Booleans
        (
            "SELECT * FROM table WHERE active = ?",
            [True],
            "SELECT * FROM table WHERE active = TRUE",
        ),
        (
            "SELECT * FROM table WHERE active = ? AND deleted = ?",
            [True, False],
            "SELECT * FROM table WHERE active = TRUE AND deleted = FALSE",
        ),
        # Strings
        (
            "SELECT * FROM table WHERE name = ?",
            ["John"],
            "SELECT * FROM table WHERE name = 'John'",
        ),
        (
            "SELECT * FROM table WHERE name = ? OR name = ?",
            ["John", "Jane"],
            "SELECT * FROM table WHERE name = 'John' OR name = 'Jane'",
        ),
        # Strings containing single quotes (the quote is doubled)
        (
            "SELECT * FROM table WHERE name = ?",
            ["O'Brien"],
            "SELECT * FROM table WHERE name = 'O''Brien'",
        ),
        (
            "SELECT * FROM table WHERE text = ?",
            ["It's a test"],
            "SELECT * FROM table WHERE text = 'It''s a test'",
        ),
        # Several types in one query
        (
            (
                "SELECT * FROM table WHERE name = ? "
                "AND age = ? AND active = ? AND salary = ?"
            ),
            ["John", 30, True, 50000.5],
            (
                "SELECT * FROM table WHERE name = 'John' "
                "AND age = 30 AND active = TRUE AND salary = 50000.5"
            ),
        ),
        (
            "SELECT * FROM table WHERE col1 = ? AND col2 = ? AND col3 = ?",
            [None, "test", 42],
            "SELECT * FROM table WHERE col1 = NULL AND col2 = 'test' AND col3 = 42",
        ),
    ],
)
def test_substitute_parameters(
    query: str,
    parameters: list | None,
    expected: str,
) -> None:
    """
    Check that ``?`` placeholders are replaced by rendered parameter values.

    Covers missing/empty parameter lists, ``None``, integers, floats,
    booleans, plain and quote-containing strings, and mixed parameter lists.
    """
    rendered = substitute_parameters(query, parameters)

    assert rendered == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "definition, should_raise",
    [
        # Definitions that must be accepted
        *(
            (accepted, False)
            for accepted in (
                # Simple cases
                "column_name",
                "COUNT(*)",
                "SUM(amount)",
                "table.column",
                "schema.table.column",
                # With a direction
                "column_name ASC",
                "column_name DESC",
                "COUNT(*) DESC",
                "SUM(revenue) ASC",
                # With NULLS handling
                "column_name NULLS FIRST",
                "column_name NULLS LAST",
                "column_name ASC NULLS FIRST",
                "column_name DESC NULLS LAST",
                "COUNT(*) DESC NULLS FIRST",
                # Complex expressions
                "gender ASC, COUNT(*)",
                "gender ASC, COUNT(*) DESC",
                "col1 ASC, col2 DESC, col3",
                "CASE WHEN x > 0 THEN 1 ELSE 0 END",
                "CAST(column AS INTEGER)",
                "UPPER(name)",
                "CONCAT(first_name, ' ', last_name)",
                # Mixed complexity
                "table.column ASC NULLS FIRST, COUNT(*) DESC",
                "schema.table.col1, func(col2) DESC NULLS LAST",
            )
        ),
        # Definitions that must be rejected
        *(
            (rejected, True)
            for rejected in (
                # SQL injection attempts with semicolons
                "column_name; DROP TABLE users;",
                "column_name; DELETE FROM data; --",
                "name; UPDATE users SET admin=1; --",
                # SQL injection with multiple statements
                "col1; SELECT * FROM passwords",
                "col1; INSERT INTO logs VALUES(1)",
                # Incomplete syntax
                "column/*",
            )
        ),
    ],
)
def test_validate_order_by(definition: str, should_raise: bool) -> None:
    """
    Check that ORDER BY definitions are accepted or rejected as expected.

    Rejected definitions must raise ``ValueError`` with a message matching
    "Invalid ORDER BY"; accepted ones must not raise.
    """
    expectation = nullcontext()
    if should_raise:
        expectation = pytest.raises(ValueError, match="Invalid ORDER BY")

    with expectation:
        validate_order_by(definition)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "configuration, expected",
    [
        # Only the account and password authentication
        (
            {
                "account_identifier": "test_account",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_database": True,
                "allow_changing_schema": True,
            },
            {
                "account": "test_account",
                "application": "Apache Superset",
                "paramstyle": "qmark",
                "insecure_mode": True,
                "user": "test_user",
                "password": "test_password",
            },
        ),
        # Every optional field provided
        (
            {
                "account_identifier": "test_account",
                "role": "ACCOUNTADMIN",
                "warehouse": "COMPUTE_WH",
                "database": "TEST_DB",
                "schema": "PUBLIC",
                "auth": {
                    "auth_type": "user_password",
                    "username": "admin",
                    "password": "secret123",
                },
            },
            {
                "account": "test_account",
                "application": "Apache Superset",
                "paramstyle": "qmark",
                "insecure_mode": True,
                "role": "ACCOUNTADMIN",
                "warehouse": "COMPUTE_WH",
                "database": "TEST_DB",
                "schema": "PUBLIC",
                "user": "admin",
                "password": "secret123",
            },
        ),
        # A subset of the optional fields
        (
            {
                "account_identifier": "mycompany.us-east-1",
                "warehouse": "ETL_WH",
                "database": "ANALYTICS",
                "auth": {
                    "auth_type": "user_password",
                    "username": "analyst",
                    "password": "p@ssw0rd",
                },
                "allow_changing_schema": True,
            },
            {
                "account": "mycompany.us-east-1",
                "application": "Apache Superset",
                "paramstyle": "qmark",
                "insecure_mode": True,
                "warehouse": "ETL_WH",
                "database": "ANALYTICS",
                "user": "analyst",
                "password": "p@ssw0rd",
            },
        ),
    ],
)
def test_get_connection_parameters(
    configuration: dict,
    expected: dict,
) -> None:
    """
    Check the connection parameters derived from a configuration.

    Every expected key must be present with the expected value, and the
    result must not contain any key that is not expected.
    """
    result = get_connection_parameters(SnowflakeConfiguration(**configuration))

    # Per-key checks come first so a failure names the offending key
    for key in expected:
        value = expected[key]
        assert key in result, f"Expected key '{key}' not found in result"
        assert result[key] == value, f"Expected {key}={value}, got {result[key]}"

    # No extra keys beyond the expected ones
    assert set(result) == set(expected)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "configuration, databases, schemas, expected_db_enum, expected_schema_enum",
    [
        # No configuration: both enums are empty
        (None, None, None, [], []),
        # Account + auth: the database enum is populated
        (
            {
                "account_identifier": "test_account",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_database": True,
                "allow_changing_schema": True,
            },
            ["ANALYTICS_DB", "SALES_DB", "MARKETING_DB"],
            None,
            ["ANALYTICS_DB", "SALES_DB", "MARKETING_DB"],
            [],
        ),
        # Account + auth + database: both enums are populated
        (
            {
                "account_identifier": "test_account",
                "database": "ANALYTICS_DB",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_schema": True,
            },
            ["ANALYTICS_DB", "SALES_DB", "MARKETING_DB"],
            ["PUBLIC", "STAGING", "DEV"],
            ["ANALYTICS_DB", "SALES_DB", "MARKETING_DB"],
            ["PUBLIC", "STAGING", "DEV"],
        ),
        # Account + auth with a single database available
        (
            {
                "account_identifier": "prod_account",
                "auth": {
                    "auth_type": "user_password",
                    "username": "admin",
                    "password": "secret",
                },
                "allow_changing_database": True,
                "allow_changing_schema": True,
            },
            ["PRODUCTION"],
            None,
            ["PRODUCTION"],
            [],
        ),
    ],
)
def test_get_configuration_schema(
    configuration: dict | None,
    databases: list[str] | None,
    schemas: list[str] | None,
    expected_db_enum: list[str],
    expected_schema_enum: list[str],
) -> None:
    """
    Check the configuration schema and its database/schema enums.

    Without a configuration the schema is requested with no arguments and
    the enums are compared as lists. With a configuration, ``connect`` is
    patched to return a mocked connection and the enums are compared as sets.
    """
    if configuration is None:
        bare_schema = SnowflakeSemanticLayer.get_configuration_schema()

        assert "properties" in bare_schema
        assert "database" in bare_schema["properties"]
        assert "schema" in bare_schema["properties"]
        assert bare_schema["properties"]["database"]["enum"] == expected_db_enum
        assert bare_schema["properties"]["schema"]["enum"] == expected_schema_enum
        return

    config = SnowflakeConfiguration(**configuration)

    # Mocked connection whose cursor() hands back a controllable cursor
    mock_cursor = MagicMock()
    mock_connection = MagicMock()
    mock_connection.cursor.return_value = mock_cursor

    if databases:
        # Rows mimic a SHOW DATABASES result; the name is the second column
        database_rows = [
            (index, name, "", "", "", "", "") for index, name in enumerate(databases)
        ]
        mock_cursor.__iter__.return_value = iter(database_rows)

    if schemas:
        # Rows mimic a SELECT SCHEMA_NAME result: one-element tuples
        schema_rows = [(name,) for name in schemas]
        mock_cursor.execute.return_value = iter(schema_rows)

    with patch("superset.semantic_layers.snowflake_.connect") as mock_connect:
        mock_connect.return_value.__enter__.return_value = mock_connection

        json_schema = SnowflakeSemanticLayer.get_configuration_schema(config)

        mock_connect.assert_called_once()

        properties = json_schema["properties"] if "properties" in json_schema else None
        assert "properties" in json_schema
        assert "database" in properties
        assert "schema" in properties

        # Order is not guaranteed, so enums are compared as sets
        assert set(properties["database"]["enum"]) == set(expected_db_enum)

        if expected_schema_enum:
            actual_schema_enum = properties["schema"]["enum"]
        else:
            # The enum key may be absent or empty when no schemas are expected
            actual_schema_enum = properties["schema"].get("enum", [])
        assert set(actual_schema_enum) == set(expected_schema_enum)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "configuration, runtime_data, databases, schemas, expect_database, expect_schema",
    [
        # Database + schema configured, no changing allowed -> empty runtime schema
        (
            {
                "account_identifier": "test_account",
                "database": "ANALYTICS_DB",
                "schema": "PUBLIC",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_database": False,
                "allow_changing_schema": False,
            },
            None,
            None,
            None,
            False,
            False,
        ),
        # Database configured, schema not configured -> schema field shown
        (
            {
                "account_identifier": "test_account",
                "database": "ANALYTICS_DB",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_schema": True,
            },
            None,
            None,
            ["PUBLIC", "STAGING", "DEV"],
            False,
            True,
        ),
        # Database configured, allow_changing_schema=True -> schema field shown
        (
            {
                "account_identifier": "test_account",
                "database": "ANALYTICS_DB",
                "schema": "PUBLIC",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_schema": True,
            },
            None,
            None,
            ["PUBLIC", "STAGING", "DEV"],
            False,
            True,
        ),
        # Database not configured -> database field shown
        (
            {
                "account_identifier": "test_account",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_database": True,
                "allow_changing_schema": True,
            },
            None,
            ["ANALYTICS_DB", "SALES_DB"],
            None,
            True,
            True,
        ),
        # Database configured, allow_changing_database=True -> database field shown
        (
            {
                "account_identifier": "test_account",
                "database": "ANALYTICS_DB",
                "schema": "PUBLIC",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_database": True,
                "allow_changing_schema": False,
            },
            None,
            ["ANALYTICS_DB", "SALES_DB"],
            None,
            True,
            False,
        ),
        # Runtime data provides the database -> schemas listed for that database
        (
            {
                "account_identifier": "test_account",
                "auth": {
                    "auth_type": "user_password",
                    "username": "test_user",
                    "password": "test_password",
                },
                "allow_changing_database": True,
                "allow_changing_schema": True,
            },
            {"database": "SALES_DB"},
            ["ANALYTICS_DB", "SALES_DB"],
            ["SALES_SCHEMA", "CUSTOMER_SCHEMA"],
            True,
            True,
        ),
    ],
)
def test_get_runtime_schema(
    configuration: dict,
    runtime_data: dict | None,
    databases: list[str] | None,
    schemas: list[str] | None,
    expect_database: bool,
    expect_schema: bool,
) -> None:
    """
    Check which fields the runtime schema exposes for a configuration.

    Only fields the user is able to change should be present:

    - ``database`` when no database is configured or changing it is allowed
    - ``schema`` when no schema is configured or changing it is allowed
    """
    config = SnowflakeConfiguration(**configuration)

    # Mocked connection whose cursor() hands back a controllable cursor
    mock_cursor = MagicMock()
    mock_connection = MagicMock()
    mock_connection.cursor.return_value = mock_cursor

    if databases:
        # Rows mimic a SHOW DATABASES result; the name is the second column
        database_rows = [
            (index, name, "", "", "", "", "") for index, name in enumerate(databases)
        ]
        mock_cursor.__iter__.return_value = iter(database_rows)

    if schemas:
        # Rows mimic a SELECT SCHEMA_NAME result: one-element tuples
        schema_rows = [(name,) for name in schemas]
        mock_cursor.execute.return_value = iter(schema_rows)

    with patch("superset.semantic_layers.snowflake_.connect") as mock_connect:
        mock_connect.return_value.__enter__.return_value = mock_connection

        runtime_schema = SnowflakeSemanticLayer.get_runtime_schema(
            config,
            runtime_data,
        )

        mock_connect.assert_called_once()

        assert "properties" in runtime_schema

        # Database field: present only when expected
        if not expect_database:
            assert "database" not in runtime_schema["properties"]
        else:
            assert "database" in runtime_schema["properties"]
            if databases:
                db_enum = runtime_schema["properties"]["database"].get("enum", [])
                assert set(db_enum) == set(databases)

        # Schema field: present only when expected
        if not expect_schema:
            assert "schema" not in runtime_schema["properties"]
        else:
            assert "schema" in runtime_schema["properties"]
            # The enum is only checked when a database is known, either from
            # the configuration or from the runtime data
            database_known = configuration.get("database") or (
                runtime_data and runtime_data.get("database")
            )
            if schemas and database_known:
                schema_enum = runtime_schema["properties"]["schema"].get("enum", [])
                assert set(schema_enum) == set(schemas)
|
|
|
|
|
|
# Checks that the `semantic_view` fixture exposes the expected set of Dimension
# objects, and that the mocked cursor's `execute` received the DESC SEMANTIC VIEW
# statement filtered on "object_kind" = 'DIMENSION'.
# NOTE(review): this block is flattened here (indentation stripped, `|` separator
# lines interleaved). The interior whitespace of the SQL literal below presumably
# differs in the real file; confirm against upstream before relying on it.
def test_get_dimensions(
|
|
mocker: MockerFixture,
|
|
connection: MagicMock,
|
|
semantic_view: SnowflakeSemanticView,
|
|
) -> None:
|
|
"""
|
|
Test dimension retrieval and parsing from Snowflake semantic layer.
|
|
"""
|
|
# Each expected Dimension has `definition=None` and carries the EXPRESSION
# value from the `dimension_rows` fixture in `description`.
assert semantic_view.dimensions == {
|
|
Dimension(
|
|
id="CUSTOMER.C_CUSTOMER_SK",
|
|
name="C_CUSTOMER_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="c_customer_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORE.SQUAREFOOTAGE",
|
|
name="SQUAREFOOTAGE",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="S_FLOOR_SPACE",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="ITEM.BRAND",
|
|
name="BRAND",
|
|
type=STRING,
|
|
definition=None,
|
|
description="I_BRAND",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="ITEM.CATEGORY",
|
|
name="CATEGORY",
|
|
type=STRING,
|
|
definition=None,
|
|
description="I_CATEGORY",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORE.S_STORE_SK",
|
|
name="S_STORE_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="s_store_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORESALES.SS_CUSTOMER_SK",
|
|
name="SS_CUSTOMER_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="ss_customer_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="DATE.DATE",
|
|
name="DATE",
|
|
type=DATE,
|
|
definition=None,
|
|
description="D_DATE",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="DEMO.CD_DEMO_SK",
|
|
name="CD_DEMO_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="cd_demo_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="DATE.MONTH",
|
|
name="MONTH",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="D_MOY",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORE.MARKET",
|
|
name="MARKET",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="S_MARKET_ID",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORESALES.SS_ITEM_SK",
|
|
name="SS_ITEM_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="ss_item_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORE.STORECOUNTRY",
|
|
name="STORECOUNTRY",
|
|
type=STRING,
|
|
definition=None,
|
|
description="S_COUNTRY",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="ITEM.CLASS",
|
|
name="CLASS",
|
|
type=STRING,
|
|
definition=None,
|
|
description="I_CLASS",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="CUSTOMER.COUNTRY",
|
|
name="COUNTRY",
|
|
type=STRING,
|
|
definition=None,
|
|
description="C_BIRTH_COUNTRY",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="DEMO.CREDIT_RATING",
|
|
name="CREDIT_RATING",
|
|
type=STRING,
|
|
definition=None,
|
|
description="CD_CREDIT_RATING",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="DATE.WEEK",
|
|
name="WEEK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="D_WEEK_SEQ",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="DATE.D_DATE_SK",
|
|
name="D_DATE_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="d_date_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORESALES.SS_SOLD_DATE_SK",
|
|
name="SS_SOLD_DATE_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="ss_sold_date_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="CUSTOMER.BIRTHYEAR",
|
|
name="BIRTHYEAR",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="C_BIRTH_YEAR",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="DEMO.MARITAL_STATUS",
|
|
name="MARITAL_STATUS",
|
|
type=STRING,
|
|
definition=None,
|
|
description="CD_MARITAL_STATUS",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORESALES.SS_CDEMO_SK",
|
|
name="SS_CDEMO_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="ss_cdemo_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="DATE.YEAR",
|
|
name="YEAR",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="D_YEAR",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="ITEM.I_ITEM_SK",
|
|
name="I_ITEM_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="i_item_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORESALES.SS_STORE_SK",
|
|
name="SS_STORE_SK",
|
|
type=INTEGER,
|
|
definition=None,
|
|
description="ss_store_sk",
|
|
grain=None,
|
|
),
|
|
Dimension(
|
|
id="STORE.STATE",
|
|
name="STATE",
|
|
type=STRING,
|
|
definition=None,
|
|
description="S_STATE",
|
|
grain=None,
|
|
),
|
|
}
|
|
|
|
# The statement is compared by string equality after `.strip()`, so only the
# leading/trailing whitespace of the literal is ignored.
connection.cursor().execute.assert_any_call(
|
|
"""
|
|
DESC SEMANTIC VIEW "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
|
|
->> SELECT "object_name", "property", "property_value"
|
|
FROM $1
|
|
WHERE
|
|
"object_kind" = 'DIMENSION' AND
|
|
"property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
|
|
""".strip()
|
|
)
|
|
|
|
|
|
# Checks that the `semantic_view` fixture exposes the expected set of Metric
# objects, and that the mocked cursor's `execute` received the DESC SEMANTIC VIEW
# statement filtered on "object_kind" = 'METRIC'.
# NOTE(review): this block is flattened here (indentation stripped, `|` separator
# lines interleaved). The interior whitespace of the SQL literal below presumably
# differs in the real file; confirm against upstream before relying on it.
def test_get_metrics(
|
|
mocker: MockerFixture,
|
|
connection: MagicMock,
|
|
semantic_view: SnowflakeSemanticView,
|
|
) -> None:
|
|
"""
|
|
Test metric retrieval and parsing from Snowflake semantic layer.
|
|
"""
|
|
# Each expected Metric carries the EXPRESSION value from the `metric_rows`
# fixture in `definition` and has `description=None`.
assert semantic_view.metrics == {
|
|
Metric(
|
|
id="STORESALES.TOTALCOST",
|
|
name="TOTALCOST",
|
|
type=NUMBER,
|
|
definition="SUM(item.cost)",
|
|
description=None,
|
|
),
|
|
Metric(
|
|
id="STORESALES.TOTALSALESQUANTITY",
|
|
name="TOTALSALESQUANTITY",
|
|
type=INTEGER,
|
|
definition="SUM(SS_QUANTITY)",
|
|
description=None,
|
|
),
|
|
Metric(
|
|
id="STORESALES.TOTALSALESPRICE",
|
|
name="TOTALSALESPRICE",
|
|
type=NUMBER,
|
|
definition="SUM(SS_SALES_PRICE)",
|
|
description=None,
|
|
),
|
|
}
|
|
|
|
# The statement is compared by string equality after `.strip()`, so only the
# leading/trailing whitespace of the literal is ignored.
connection.cursor().execute.assert_any_call(
|
|
"""
|
|
DESC SEMANTIC VIEW "SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
|
|
->> SELECT "object_name", "property", "property_value"
|
|
FROM $1
|
|
WHERE
|
|
"object_kind" = 'METRIC' AND
|
|
"property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
|
|
""".strip()
|
|
)
|
|
|
|
|
|
# Checks `get_values` for a dimension without filters: the recorded request must
# match the expected SEMANTIC_VIEW query and the returned "CATEGORY" column must
# contain the mocked values, including the trailing None.
# NOTE(review): this block is flattened here (indentation stripped, `|` separator
# lines interleaved). The interior whitespace of the SQL literal below presumably
# differs in the real file; confirm against upstream before relying on it.
def test_get_values(
|
|
mocker: MockerFixture,
|
|
connection: MagicMock,
|
|
semantic_view: SnowflakeSemanticView,
|
|
) -> None:
|
|
# Data the mocked cursor returns from `fetch_pandas_all`.
connection.cursor().execute().fetch_pandas_all.return_value = DataFrame(
|
|
{
|
|
"CATEGORY": [
|
|
"Music",
|
|
"Women",
|
|
"Home",
|
|
"Children",
|
|
"Men",
|
|
"Electronics",
|
|
"Sports",
|
|
"Shoes",
|
|
"Jewelry",
|
|
"Books",
|
|
None,
|
|
]
|
|
}
|
|
)
|
|
|
|
# Dimension whose values are requested.
dimension = Dimension(
|
|
id="ITEM.CATEGORY",
|
|
name="CATEGORY",
|
|
type=STRING,
|
|
description=None,
|
|
definition="I_CATEGORY",
|
|
grain=None,
|
|
)
|
|
|
|
result = semantic_view.get_values(dimension)
|
|
|
|
# Exactly one request is expected; its definition is compared after `.strip()`.
assert result.requests == [
|
|
SemanticRequest(
|
|
type="snowflake",
|
|
definition="""
|
|
SELECT "CATEGORY"
|
|
FROM SEMANTIC_VIEW(
|
|
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
|
|
DIMENSIONS ITEM.CATEGORY
|
|
|
|
)
|
|
""".strip(),
|
|
)
|
|
]
|
|
assert result.results["CATEGORY"].tolist() == [
|
|
"Music",
|
|
"Women",
|
|
"Home",
|
|
"Children",
|
|
"Men",
|
|
"Electronics",
|
|
"Sports",
|
|
"Shoes",
|
|
"Jewelry",
|
|
"Books",
|
|
None,
|
|
]
|
|
|
|
|
|
# Checks `get_values` when filters are supplied: the recorded request must
# include a WHERE clause combining both filters, and the returned "CATEGORY"
# column must contain the mocked values.
# NOTE(review): this block is flattened here (indentation stripped, `|` separator
# lines interleaved). The interior whitespace of the SQL literal below presumably
# differs in the real file; confirm against upstream before relying on it.
def test_get_values_with_filters(
|
|
mocker: MockerFixture,
|
|
connection: MagicMock,
|
|
semantic_view: SnowflakeSemanticView,
|
|
) -> None:
|
|
# Data the mocked cursor returns from `fetch_pandas_all`.
connection.cursor().execute().fetch_pandas_all.return_value = DataFrame(
|
|
{
|
|
"CATEGORY": [
|
|
"Music",
|
|
"Women",
|
|
"Home",
|
|
"Children",
|
|
"Men",
|
|
"Electronics",
|
|
"Sports",
|
|
"Shoes",
|
|
"Jewelry",
|
|
]
|
|
}
|
|
)
|
|
|
|
# Dimension whose values are requested; the filters below target it as well.
dimension = Dimension(
|
|
id="ITEM.CATEGORY",
|
|
name="CATEGORY",
|
|
type=STRING,
|
|
description=None,
|
|
definition="I_CATEGORY",
|
|
grain=None,
|
|
)
|
|
# Two WHERE filters: exclude 'Books' and exclude NULLs.
filters = {
|
|
Filter(PredicateType.WHERE, dimension, Operator.NOT_EQUALS, "Books"),
|
|
Filter(PredicateType.WHERE, dimension, Operator.IS_NOT_NULL, None),
|
|
}
|
|
|
|
result = semantic_view.get_values(dimension, filters)
|
|
|
|
# Exactly one request is expected; its definition is compared after `.strip()`.
assert result.requests == [
|
|
SemanticRequest(
|
|
type="snowflake",
|
|
definition="""
|
|
SELECT "CATEGORY"
|
|
FROM SEMANTIC_VIEW(
|
|
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
|
|
DIMENSIONS ITEM.CATEGORY
|
|
WHERE ("CATEGORY" != 'Books') AND ("CATEGORY" IS NOT NULL)
|
|
)
|
|
""".strip(),
|
|
)
|
|
]
|
|
assert result.results["CATEGORY"].tolist() == [
|
|
"Music",
|
|
"Women",
|
|
"Home",
|
|
"Children",
|
|
"Men",
|
|
"Electronics",
|
|
"Sports",
|
|
"Shoes",
|
|
"Jewelry",
|
|
]
|
|
|
|
|
|
# Parametrized check of `SnowflakeSemanticView._get_query`: metric/dimension names
# are resolved against the `semantic_view` fixture and the returned value is
# compared with the tuple `(sql.strip(), parameters)`.
# NOTE(review): this block is flattened here (indentation stripped, `|` separator
# lines interleaved). The interior whitespace of the SQL literals below presumably
# differs in the real file; confirm against upstream before relying on it.
@pytest.mark.parametrize(
|
|
"metrics, dimensions, filters, order, limit, offset, sql, parameters",
|
|
[
|
|
# Metric only, with limit and offset.
# The expected WHERE clause lists Month before Year although the filter set is
# written Year first; presumably filters are rendered in a sorted order (confirm).
(
|
|
["TOTALSALESPRICE"],
|
|
[],
|
|
{
|
|
AdhocFilter(PredicateType.WHERE, "Year = '2002'"),
|
|
AdhocFilter(PredicateType.WHERE, "Month = '12'"),
|
|
},
|
|
None,
|
|
10,
|
|
10,
|
|
"""
|
|
SELECT * FROM SEMANTIC_VIEW(
|
|
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
|
|
|
|
METRICS STORESALES.TOTALSALESPRICE AS "STORESALES.TOTALSALESPRICE"
|
|
WHERE (Month = '12') AND (Year = '2002')
|
|
)
|
|
|
|
LIMIT 10
|
|
OFFSET 10
|
|
""",
|
|
(),
|
|
),
|
|
# Dimension only, with a limit and no offset.
(
|
|
[],
|
|
["CATEGORY"],
|
|
{
|
|
AdhocFilter(PredicateType.WHERE, "Year = '2002'"),
|
|
AdhocFilter(PredicateType.WHERE, "Month = '12'"),
|
|
},
|
|
None,
|
|
20,
|
|
None,
|
|
"""
|
|
SELECT * FROM SEMANTIC_VIEW(
|
|
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
|
|
DIMENSIONS ITEM.CATEGORY AS "ITEM.CATEGORY"
|
|
|
|
WHERE (Month = '12') AND (Year = '2002')
|
|
)
|
|
|
|
LIMIT 20
|
|
""",
|
|
(),
|
|
),
|
|
# Metric and dimension together, with explicit ordering, limit and offset.
(
|
|
["TOTALSALESPRICE"],
|
|
["CATEGORY"],
|
|
{
|
|
AdhocFilter(PredicateType.WHERE, "Year = '2002'"),
|
|
AdhocFilter(PredicateType.WHERE, "Month = '12'"),
|
|
},
|
|
[
|
|
("TOTALSALESPRICE", OrderDirection.DESC),
|
|
("CATEGORY", OrderDirection.ASC),
|
|
],
|
|
10,
|
|
10,
|
|
"""
|
|
SELECT * FROM SEMANTIC_VIEW(
|
|
"SAMPLE_DATA"."TPCDS_SF10TCL"."TPCDS_SEMANTIC_VIEW_SM"
|
|
DIMENSIONS ITEM.CATEGORY AS "ITEM.CATEGORY"
|
|
METRICS STORESALES.TOTALSALESPRICE AS "STORESALES.TOTALSALESPRICE"
|
|
WHERE (Month = '12') AND (Year = '2002')
|
|
)
|
|
ORDER BY "STORESALES.TOTALSALESPRICE" DESC, "ITEM.CATEGORY" ASC
|
|
LIMIT 10
|
|
OFFSET 10
|
|
""",
|
|
(),
|
|
),
|
|
],
|
|
)
|
|
def test_get_query(
|
|
semantic_view: SnowflakeSemanticView,
|
|
metrics: list[str],
|
|
dimensions: list[str],
|
|
filters: set[Filter | AdhocFilter] | None,
|
|
order: list[tuple[str, OrderDirection]] | None,
|
|
limit: int | None,
|
|
offset: int | None,
|
|
sql: str,
|
|
parameters: tuple[FilterValues, ...],
|
|
) -> None:
|
|
"""
|
|
Tests for query generation.
|
|
"""
|
|
# Index the fixture's metrics and dimensions by name so the parametrized
# cases can refer to them with plain strings.
metric_map = {metric.name: metric for metric in semantic_view.metrics}
|
|
dimension_map = {dim.name: dim for dim in semantic_view.dimensions}
|
|
|
|
assert semantic_view._get_query(
|
|
[metric_map[name] for name in metrics],
|
|
[dimension_map[name] for name in dimensions],
|
|
filters,
|
|
[
|
|
# An order key may name a metric or a dimension; metrics are tried first.
(metric_map.get(name) or dimension_map.get(name), direction)
|
|
for name, direction in (order or [])
|
|
],
|
|
limit,
|
|
offset,
|
|
) == (sql.strip(), parameters)
|