Files
superset2/tests/unit_tests/models/helpers_test.py
Luiz Otavio 5c3fe6f49b fix(drill-to-detail): drill to detail by correctly filtering by metric (#39766)
Co-authored-by: Michael S. Molina <michael.s.molina@gmail.com>
(cherry picked from commit df396aa6e9)
2026-04-30 08:51:47 -03:00

259 lines
7.6 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine, text
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import StaticPool
if TYPE_CHECKING:
from superset.models.core import Database
@pytest.fixture
def database(mocker: MockerFixture, session: Session) -> Database:
    """
    Provide a `Database` whose engine is a pre-populated in-memory SQLite DB.

    The SQLite engine holds a table `t (a INTEGER, b TEXT)` with the rows
    `(1, 'Alice')` and `(NULL, 'Bob')`. `get_sqla_engine` is patched on the
    returned object so that it yields that engine.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database

    SqlaTable.metadata.create_all(session.get_bind())

    sqlite_engine = create_engine(
        "sqlite://",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    db = Database(database_name="db", sqlalchemy_uri="sqlite://")

    # Seed the fixture table through the raw DBAPI connection.
    raw_conn = sqlite_engine.raw_connection()
    seed_statements = (
        "CREATE TABLE t (a INTEGER, b TEXT)",
        "INSERT INTO t VALUES (1, 'Alice')",
        "INSERT INTO t VALUES (NULL, 'Bob')",
    )
    for statement in seed_statements:
        raw_conn.execute(statement)
    raw_conn.commit()

    # The data lives in an in-memory SQLite database, so every caller must
    # receive the very engine on which the table was created.
    @contextmanager
    def mock_get_sqla_engine():
        yield sqlite_engine

    mocker.patch.object(db, "get_sqla_engine", new=mock_get_sqla_engine)

    return db
def test_values_for_column(database: Database) -> None:
    """
    Check the values returned by `values_for_column` for a plain column.

    A NULL in the column must come back as `None` rather than `np.nan`,
    because NaN cannot be serialized to JSON.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    column_a = TableColumn(column_name="a")
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[column_a],
    )

    values = dataset.values_for_column("a")

    assert values == [1, None]
def test_values_for_column_with_rls(database: Database) -> None:
    """
    Check `values_for_column` when a row level security filter applies.

    The dataset's `get_sqla_row_level_filters` is patched to return the
    clause `a = 1`, so only the value `1` is expected back.
    """
    from sqlalchemy.sql.elements import TextClause

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )
    rls_clauses = [TextClause("a = 1")]

    with patch.object(
        dataset,
        "get_sqla_row_level_filters",
        return_value=rls_clauses,
    ):
        values = dataset.values_for_column("a")
        assert values == [1]
def test_values_for_column_with_rls_no_values(database: Database) -> None:
    """
    Check `values_for_column` when the RLS filter matches no rows.

    The dataset's `get_sqla_row_level_filters` is patched to return the
    clause `a = 2`, and an empty list is expected back.
    """
    from sqlalchemy.sql.elements import TextClause

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )
    rls_clauses = [TextClause("a = 2")]

    with patch.object(
        dataset,
        "get_sqla_row_level_filters",
        return_value=rls_clauses,
    ):
        values = dataset.values_for_column("a")
        assert values == []
def test_values_for_column_calculated(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Check that `values_for_column` works for a calculated column.

    The column is defined by a SQL `CASE` expression over `b` instead of
    mapping to a physical column of the table.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    calculated_column = TableColumn(
        column_name="starts_with_A",
        expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
    )
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[calculated_column],
    )

    values = dataset.values_for_column("starts_with_A")

    assert values == ["yes", "nope"]
def test_values_for_column_double_percents(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Check how `values_for_column` handles `double_percents`.

    The SQL handed to `mutate_sql_based_on_config` is expected to contain
    `%%`, while the statement finally passed to `pd.read_sql_query` is
    expected to contain a single `%`.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    with database.get_sqla_engine() as engine:
        engine.dialect.identifier_preparer._double_percents = "pyformat"

    calculated_column = TableColumn(
        column_name="starts_with_A",
        expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
    )
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[calculated_column],
    )

    # Pass the SQL through unchanged while recording what it looked like.
    mutate_mock = mocker.patch.object(
        database,
        "mutate_sql_based_on_config",
        side_effect=lambda sql: sql,
    )
    pandas_mock = mocker.patch("superset.models.helpers.pd")

    dataset.values_for_column("starts_with_A")

    # Both expected statements share everything after the CASE expression.
    query_tail = "AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"

    # make sure the SQL originally had double percents
    mutate_mock.assert_called_with(
        "SELECT DISTINCT CASE WHEN b LIKE 'A%%' THEN 'yes' ELSE 'nope' END "
        + query_tail
    )

    # make sure final query has single percents
    expected_sql = text(
        "SELECT DISTINCT CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END "
        + query_tail
    )
    with database.get_sqla_engine() as engine:
        read_sql_kwargs = pandas_mock.read_sql_query.call_args.kwargs
        called_sql = read_sql_kwargs["sql"]
        called_conn = read_sql_kwargs["con"]

        assert called_sql.compare(expected_sql) is True
        assert called_conn == engine
def test_filter_by_verbose_name_resolves_to_column(
    database: Database,
) -> None:
    """
    Check that a filter naming a column by its verbose_name is resolved.

    "Drill to detail by" emits the column's label as the filter's "col"
    value. Such a filter must be matched to the column carrying that
    verbose_name and yield a WHERE clause on its underlying column_name.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    labelled_column = TableColumn(
        column_name="country_code",
        verbose_name="Country",
        type="TEXT",
    )
    plain_column = TableColumn(column_name="b", type="TEXT")
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[labelled_column, plain_column],
    )

    # The filter refers to the verbose label, as "Drill to detail by" does.
    label_filter = {"col": "Country", "op": "==", "val": "US"}
    query_result = dataset.get_sqla_query(
        columns=["b"],
        filter=[label_filter],
        is_timeseries=False,
        row_limit=10,
    )

    with database.get_sqla_engine() as engine:
        compiled = query_result.sqla_query.compile(
            dialect=engine.dialect,
            compile_kwargs={"literal_binds": True},
        )
        sql = str(compiled)

    # The filter should be translated to a WHERE clause on the real column.
    assert "WHERE" in sql, f"Expected WHERE clause, got SQL: {sql}"
    assert "country_code" in sql, f"Expected filter on 'country_code', got SQL: {sql}"
    assert "'US'" in sql, f"Expected filter value 'US', got SQL: {sql}"