# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=import-outside-toplevel

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING
from unittest.mock import patch

import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine, text
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import StaticPool

if TYPE_CHECKING:
    from superset.models.core import Database


@pytest.fixture
def database(mocker: MockerFixture, session: Session) -> Database:
    """
    Provide a `Database` model wired to a seeded in-memory SQLite engine.

    The engine holds a table `t (a INTEGER, b TEXT)` containing the rows
    `(1, 'Alice')` and `(NULL, 'Bob')`. The model's `get_sqla_engine` is
    patched so that every caller receives that very engine.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database

    SqlaTable.metadata.create_all(session.get_bind())

    sqlite_engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    db = Database(database_name="db", sqlalchemy_uri="sqlite://")

    # Seed the fixture table through the raw DBAPI connection.
    raw = sqlite_engine.raw_connection()
    for statement in (
        "CREATE TABLE t (a INTEGER, b TEXT)",
        "INSERT INTO t VALUES (1, 'Alice')",
        "INSERT INTO t VALUES (NULL, 'Bob')",
    ):
        raw.execute(statement)
    raw.commit()

    # The SQLite database lives in memory, so the model must always hand
    # back the engine on which the table was created.
    @contextmanager
    def _shared_engine():
        yield sqlite_engine

    mocker.patch.object(db, "get_sqla_engine", new=_shared_engine)

    return db


def test_values_for_column(database: Database) -> None:
    """
    Check the plain `values_for_column` result for a nullable column.

    A NULL must come back as `None` rather than `np.nan`, because NaN cannot
    be serialized to JSON.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    column_a = TableColumn(column_name="a")
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[column_a],
    )

    values = table.values_for_column("a")
    assert values == [1, None]


def test_values_for_column_with_rls(database: Database) -> None:
    """
    Check `values_for_column` when a row level security filter is applied.
    """
    from sqlalchemy.sql.elements import TextClause

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    column_a = TableColumn(column_name="a")
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[column_a],
    )

    rls_filters = [TextClause("a = 1")]
    with patch.object(
        table,
        "get_sqla_row_level_filters",
        return_value=rls_filters,
    ):
        values = table.values_for_column("a")
        assert values == [1]


def test_values_for_column_with_rls_no_values(database: Database) -> None:
    """
    Check `values_for_column` when the RLS filter matches no rows at all.
    """
    from sqlalchemy.sql.elements import TextClause

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    column_a = TableColumn(column_name="a")
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[column_a],
    )

    rls_filters = [TextClause("a = 2")]
    with patch.object(
        table,
        "get_sqla_row_level_filters",
        return_value=rls_filters,
    ):
        values = table.values_for_column("a")
        assert values == []


def test_values_for_column_calculated(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Check that `values_for_column` works for a calculated column.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    calculated = TableColumn(
        column_name="starts_with_A",
        expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
    )
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[calculated],
    )

    values = table.values_for_column("starts_with_A")
    assert values == ["yes", "nope"]


def test_values_for_column_double_percents(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Check percent handling when the dialect uses `double_percents`.

    The SQL passed to `mutate_sql_based_on_config` is expected to carry `%%`,
    while the statement handed to `read_sql_query` must carry a single `%`.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    with database.get_sqla_engine() as engine:
        engine.dialect.identifier_preparer._double_percents = "pyformat"

    calculated = TableColumn(
        column_name="starts_with_A",
        expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
    )
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[calculated],
    )

    # Pass the SQL through untouched so the recorded argument can be inspected.
    mutate_mock = mocker.patch.object(
        database,
        "mutate_sql_based_on_config",
        side_effect=lambda sql: sql,
    )
    pd_mock = mocker.patch("superset.models.helpers.pd")

    table.values_for_column("starts_with_A")

    # the SQL originally had doubled percent signs
    mutate_mock.assert_called_with(
        "SELECT DISTINCT CASE WHEN b LIKE 'A%%' THEN 'yes' ELSE 'nope' END "
        "AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"
    )

    # the final query has single percent signs
    with database.get_sqla_engine() as engine:
        expected_sql = text(
            "SELECT DISTINCT CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END "
            "AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"
        )
        read_kwargs = pd_mock.read_sql_query.call_args.kwargs
        called_sql = read_kwargs["sql"]
        called_conn = read_kwargs["con"]

        assert called_sql.compare(expected_sql) is True
        assert called_conn == engine


def test_filter_by_verbose_name_resolves_to_column(
    database: Database,
) -> None:
    """
    Check that a filter addressed by `verbose_name` hits the real column.

    When a filter's "col" value equals a column's verbose_name (e.g. the label
    emitted by "Drill to detail by"), it must resolve to that column and yield
    a WHERE clause on the underlying column_name.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    labelled_column = TableColumn(
        column_name="country_code",
        verbose_name="Country",
        type="TEXT",
    )
    plain_column = TableColumn(column_name="b", type="TEXT")
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[labelled_column, plain_column],
    )

    # The filter uses the verbose label, just as "Drill to detail by" does.
    verbose_filter = {"col": "Country", "op": "==", "val": "US"}
    sqla_query = table.get_sqla_query(
        columns=["b"],
        filter=[verbose_filter],
        is_timeseries=False,
        row_limit=10,
    )

    with database.get_sqla_engine() as engine:
        compiled = sqla_query.sqla_query.compile(
            dialect=engine.dialect,
            compile_kwargs={"literal_binds": True},
        )
        sql = str(compiled)

    # The filter should have become a WHERE clause on the real column.
    assert "WHERE" in sql, f"Expected WHERE clause, got SQL: {sql}"
    assert "country_code" in sql, f"Expected filter on 'country_code', got SQL: {sql}"
    assert "'US'" in sql, f"Expected filter value 'US', got SQL: {sql}"