# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # pylint: disable=import-outside-toplevel from __future__ import annotations from contextlib import contextmanager from typing import TYPE_CHECKING from unittest.mock import patch import pytest from pytest_mock import MockerFixture from sqlalchemy import create_engine from sqlalchemy.orm.session import Session from sqlalchemy.pool import StaticPool if TYPE_CHECKING: from superset.models.core import Database @pytest.fixture def database(mocker: MockerFixture, session: Session) -> Database: from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database SqlaTable.metadata.create_all(session.get_bind()) engine = create_engine( "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) database = Database(database_name="db", sqlalchemy_uri="sqlite://") connection = engine.raw_connection() connection.execute("CREATE TABLE t (a INTEGER, b TEXT)") connection.execute("INSERT INTO t VALUES (1, 'Alice')") connection.execute("INSERT INTO t VALUES (NULL, 'Bob')") connection.commit() # since we're using an in-memory SQLite database, make sure we always # return the same engine where the table was created @contextmanager def mock_get_sqla_engine(catalog=None, schema=None, **kwargs): yield engine mocker.patch.object( database, "get_sqla_engine", new=mock_get_sqla_engine, ) return database def test_values_for_column(database: Database) -> None: """ Test the `values_for_column` method. NULL values should be returned as `None`, not `np.nan`, since NaN cannot be serialized to JSON. """ import numpy as np import pandas as pd from superset.connectors.sqla.models import SqlaTable, TableColumn table = SqlaTable( database=database, schema=None, table_name="t", columns=[TableColumn(column_name="a")], ) # Mock pd.read_sql_query to return a dataframe with the expected values with patch( "pandas.read_sql_query", return_value=pd.DataFrame({"column_values": [1, np.nan]}), ): assert table.values_for_column("a") == [1, None] def test_values_for_column_with_rls(database: Database) -> None: """ Test the `values_for_column` method with RLS enabled. """ import pandas as pd from sqlalchemy.sql.elements import TextClause from superset.connectors.sqla.models import SqlaTable, TableColumn table = SqlaTable( database=database, schema=None, table_name="t", columns=[ TableColumn(column_name="a"), ], ) # Mock RLS filters and pd.read_sql_query with ( patch.object( table, "get_sqla_row_level_filters", return_value=[ TextClause("a = 1"), ], ), patch( "pandas.read_sql_query", return_value=pd.DataFrame({"column_values": [1]}), ), ): assert table.values_for_column("a") == [1] def test_values_for_column_with_rls_no_values(database: Database) -> None: """ Test the `values_for_column` method with RLS enabled and no values. """ import pandas as pd from sqlalchemy.sql.elements import TextClause from superset.connectors.sqla.models import SqlaTable, TableColumn table = SqlaTable( database=database, schema=None, table_name="t", columns=[ TableColumn(column_name="a"), ], ) # Mock RLS filters and pd.read_sql_query to return empty dataframe with ( patch.object( table, "get_sqla_row_level_filters", return_value=[ TextClause("a = 2"), ], ), patch( "pandas.read_sql_query", return_value=pd.DataFrame({"column_values": []}), ), ): assert table.values_for_column("a") == [] def test_values_for_column_calculated( mocker: MockerFixture, database: Database, ) -> None: """ Test that calculated columns work. """ import pandas as pd from superset.connectors.sqla.models import SqlaTable, TableColumn table = SqlaTable( database=database, schema=None, table_name="t", columns=[ TableColumn( column_name="starts_with_A", expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END", ) ], ) # Mock pd.read_sql_query to return expected values for calculated column with patch( "pandas.read_sql_query", return_value=pd.DataFrame({"column_values": ["yes", "nope"]}), ): assert table.values_for_column("starts_with_A") == ["yes", "nope"] def test_values_for_column_double_percents( mocker: MockerFixture, database: Database, ) -> None: """ Test the behavior of `double_percents`. """ import pandas as pd from superset.connectors.sqla.models import SqlaTable, TableColumn with database.get_sqla_engine() as engine: engine.dialect.identifier_preparer._double_percents = "pyformat" table = SqlaTable( database=database, schema=None, table_name="t", columns=[ TableColumn( column_name="starts_with_A", expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END", ) ], ) # Mock pd.read_sql_query to capture the SQL and return expected values read_sql_mock = mocker.patch( "pandas.read_sql_query", return_value=pd.DataFrame({"column_values": ["yes", "nope"]}), ) result = table.values_for_column("starts_with_A") # Verify the result assert result == ["yes", "nope"] # Verify read_sql_query was called read_sql_mock.assert_called_once() # Get the SQL that was passed to read_sql_query called_sql = str(read_sql_mock.call_args[1]["sql"]) # The SQL should have single percents (after replacement) assert "LIKE 'A%'" in called_sql assert "LIKE 'A%%'" not in called_sql def test_apply_series_others_grouping(database: Database) -> None: """ Test the `_apply_series_others_grouping` method. This method should replace series columns with CASE expressions that group remaining series into an "Others" category based on a condition. """ from unittest.mock import Mock from superset.connectors.sqla.models import SqlaTable, TableColumn # Create a mock table for testing table = SqlaTable( database=database, schema=None, table_name="test_table", columns=[ TableColumn(column_name="category", type="TEXT"), TableColumn(column_name="metric_col", type="INTEGER"), TableColumn(column_name="other_col", type="TEXT"), ], ) # Mock SELECT expressions category_expr = Mock() category_expr.name = "category" metric_expr = Mock() metric_expr.name = "metric_col" other_expr = Mock() other_expr.name = "other_col" select_exprs = [category_expr, metric_expr, other_expr] # Mock GROUP BY columns groupby_all_columns = { "category": category_expr, "other_col": other_expr, } # Define series columns (only category should be modified) groupby_series_columns = {"category": category_expr} # Create a condition factory that always returns True def always_true_condition(col_name: str, expr) -> bool: return True # Mock the make_sqla_column_compatible method def mock_make_compatible(expr, name=None): mock_result = Mock() mock_result.name = name return mock_result with patch.object( table, "make_sqla_column_compatible", side_effect=mock_make_compatible ): # Call the method result_select_exprs, result_groupby_columns = ( table._apply_series_others_grouping( select_exprs, groupby_all_columns, groupby_series_columns, always_true_condition, ) ) # Verify SELECT expressions assert len(result_select_exprs) == 3 # Category (series column) should be replaced with CASE expression category_result = result_select_exprs[0] assert category_result.name == "category" # Should be made compatible # Metric (non-series column) should remain unchanged assert result_select_exprs[1] == metric_expr # Other (non-series column) should remain unchanged assert result_select_exprs[2] == other_expr # Verify GROUP BY columns assert len(result_groupby_columns) == 2 # Category (series column) should be replaced with CASE expression assert "category" in result_groupby_columns category_groupby_result = result_groupby_columns["category"] # After our fix, GROUP BY expressions are NOT wrapped with # make_sqla_column_compatible, so it should be a raw CASE expression, # not a Mock with .name attribute. Verify it's different from the original assert category_groupby_result != category_expr # Other (non-series column) should remain unchanged assert result_groupby_columns["other_col"] == other_expr def test_apply_series_others_grouping_with_false_condition(database: Database) -> None: """ Test the `_apply_series_others_grouping` method with a condition that returns False. This should result in CASE expressions that always use "Others". """ from unittest.mock import Mock from superset.connectors.sqla.models import SqlaTable, TableColumn # Create a mock table for testing table = SqlaTable( database=database, schema=None, table_name="test_table", columns=[TableColumn(column_name="category", type="TEXT")], ) # Mock SELECT expressions category_expr = Mock() category_expr.name = "category" select_exprs = [category_expr] # Mock GROUP BY columns groupby_all_columns = {"category": category_expr} groupby_series_columns = {"category": category_expr} # Create a condition factory that always returns False def always_false_condition(col_name: str, expr) -> bool: return False # Mock the make_sqla_column_compatible method def mock_make_compatible(expr, name=None): mock_result = Mock() mock_result.name = name return mock_result with patch.object( table, "make_sqla_column_compatible", side_effect=mock_make_compatible ): # Call the method result_select_exprs, result_groupby_columns = ( table._apply_series_others_grouping( select_exprs, groupby_all_columns, groupby_series_columns, always_false_condition, ) ) # Verify that the expressions were replaced (we can't test SQL generation # in a unit test, but we can verify the structure changed) assert len(result_select_exprs) == 1 assert result_select_exprs[0].name == "category" assert len(result_groupby_columns) == 1 assert "category" in result_groupby_columns # GROUP BY expression should be a CASE expression, not the original assert result_groupby_columns["category"] != category_expr def test_apply_series_others_grouping_sql_compilation(database: Database) -> None: """ Test that the `_apply_series_others_grouping` method properly quotes the 'Others' literal in both SELECT and GROUP BY clauses. This test verifies the fix for the bug where 'Others' was not quoted in the GROUP BY clause, causing SQL syntax errors. """ import sqlalchemy as sa from superset.connectors.sqla.models import SqlaTable, TableColumn # Create a real table instance table = SqlaTable( database=database, schema=None, table_name="test_table", columns=[ TableColumn(column_name="name", type="TEXT"), TableColumn(column_name="value", type="INTEGER"), ], ) # Create real SQLAlchemy expressions name_col = sa.column("name") value_col = sa.column("value") select_exprs = [name_col, value_col] groupby_all_columns = {"name": name_col} groupby_series_columns = {"name": name_col} # Condition factory that checks if a subquery column is not null def condition_factory(col_name: str, expr): return sa.column("series_limit.name__").is_not(None) # Call the method result_select_exprs, result_groupby_columns = table._apply_series_others_grouping( select_exprs, groupby_all_columns, groupby_series_columns, condition_factory, ) # Get the database dialect from the actual database with database.get_sqla_engine() as engine: dialect = engine.dialect # Test SELECT expression compilation select_case_expr = result_select_exprs[0] select_sql = str( select_case_expr.compile( dialect=dialect, compile_kwargs={"literal_binds": True} ) ) # Test GROUP BY expression compilation groupby_case_expr = result_groupby_columns["name"] groupby_sql = str( groupby_case_expr.compile( dialect=dialect, compile_kwargs={"literal_binds": True} ) ) # Different databases may use different quote characters # PostgreSQL/MySQL use single quotes, some might use double quotes # The key is that Others should be quoted, not bare # Check that 'Others' appears with some form of quotes # and not as a bare identifier assert " Others " not in select_sql, "Found unquoted 'Others' in SELECT" assert " Others " not in groupby_sql, "Found unquoted 'Others' in GROUP BY" # Check for common quoting patterns has_single_quotes = "'Others'" in select_sql and "'Others'" in groupby_sql has_double_quotes = '"Others"' in select_sql and '"Others"' in groupby_sql assert has_single_quotes or has_double_quotes, ( "Others literal should be quoted with either single or double quotes" ) # Verify the structure of the generated SQL assert "CASE WHEN" in select_sql assert "CASE WHEN" in groupby_sql # Check that ELSE is followed by a quoted value assert "ELSE " in select_sql assert "ELSE " in groupby_sql # The key test is that GROUP BY expression doesn't have a label # while SELECT might or might not have one depending on the database # What matters is that GROUP BY should NOT have label assert " AS " not in groupby_sql # GROUP BY should NOT have label # Also verify that if SELECT has a label, it's different from GROUP BY if " AS " in select_sql: # If labeled, SELECT and GROUP BY should be different assert select_sql != groupby_sql def test_apply_series_others_grouping_no_label_in_groupby(database: Database) -> None: """ Test that GROUP BY expressions don't get wrapped with make_sqla_column_compatible. This is a specific test for the bug fix where make_sqla_column_compatible was causing issues with literal quoting in GROUP BY clauses. """ from unittest.mock import ANY, call, Mock, patch from superset.connectors.sqla.models import SqlaTable, TableColumn # Create a table instance table = SqlaTable( database=database, schema=None, table_name="test_table", columns=[TableColumn(column_name="category", type="TEXT")], ) # Mock expressions category_expr = Mock() category_expr.name = "category" select_exprs = [category_expr] groupby_all_columns = {"category": category_expr} groupby_series_columns = {"category": category_expr} def condition_factory(col_name: str, expr): return True # Track calls to make_sqla_column_compatible with patch.object( table, "make_sqla_column_compatible", side_effect=lambda expr, name: expr ) as mock_make_compatible: result_select_exprs, result_groupby_columns = ( table._apply_series_others_grouping( select_exprs, groupby_all_columns, groupby_series_columns, condition_factory, ) ) # Verify make_sqla_column_compatible was called for SELECT expressions # but NOT for GROUP BY expressions calls = mock_make_compatible.call_args_list # Should have exactly one call (for the SELECT expression) assert len(calls) == 1 # The call should be for the SELECT expression with the column name # Using unittest.mock.ANY to match any CASE expression assert calls[0] == call(ANY, "category") # Verify the GROUP BY expression was NOT passed through # make_sqla_column_compatible - it should be the raw CASE expression assert "category" in result_groupby_columns # The GROUP BY expression should be different from the SELECT expression # because only SELECT gets make_sqla_column_compatible applied