# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=import-outside-toplevel

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING
from unittest.mock import patch

import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import StaticPool

if TYPE_CHECKING:
    from superset.models.core import Database


@pytest.fixture
def database(mocker: MockerFixture, session: Session) -> Database:
    """
    Return a ``Database`` wired to an in-memory SQLite engine holding a
    small two-column table ``t`` (one row with a NULL in column ``a``).
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database

    SqlaTable.metadata.create_all(session.get_bind())

    shared_engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    db = Database(database_name="db", sqlalchemy_uri="sqlite://")

    conn = shared_engine.raw_connection()
    for statement in (
        "CREATE TABLE t (a INTEGER, b TEXT)",
        "INSERT INTO t VALUES (1, 'Alice')",
        "INSERT INTO t VALUES (NULL, 'Bob')",
    ):
        conn.execute(statement)
    conn.commit()

    # The database is in-memory, so every caller must reuse the engine that
    # created the table above rather than opening a brand-new one.
    @contextmanager
    def fake_get_sqla_engine(catalog=None, schema=None, **kwargs):
        yield shared_engine

    mocker.patch.object(
        db,
        "get_sqla_engine",
        new=fake_get_sqla_engine,
    )

    return db

def test_values_for_column(database: Database) -> None:
    """
    Test the `values_for_column` method.

    NULL values must come back as `None` rather than `np.nan`, because NaN
    cannot be serialized to JSON.
    """
    import numpy as np
    import pandas as pd

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )

    # Short-circuit the actual query; the method's NaN -> None
    # post-processing is what this test exercises.
    fake_result = pd.DataFrame({"column_values": [1, np.nan]})
    with patch("pandas.read_sql_query", return_value=fake_result):
        assert table.values_for_column("a") == [1, None]

def test_values_for_column_with_rls(database: Database) -> None:
    """
    Test the `values_for_column` method with RLS enabled.
    """
    import pandas as pd
    from sqlalchemy.sql.elements import TextClause

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )

    # Apply a single RLS filter and stub out the query itself.
    rls_filters = [TextClause("a = 1")]
    with patch.object(
        table,
        "get_sqla_row_level_filters",
        return_value=rls_filters,
    ):
        with patch(
            "pandas.read_sql_query",
            return_value=pd.DataFrame({"column_values": [1]}),
        ):
            assert table.values_for_column("a") == [1]

def test_values_for_column_with_rls_no_values(database: Database) -> None:
    """
    Test the `values_for_column` method with RLS enabled and no values.
    """
    import pandas as pd
    from sqlalchemy.sql.elements import TextClause

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )

    # An RLS filter that matches nothing: the stubbed query returns an
    # empty dataframe and the method must yield an empty list.
    rls_filters = [TextClause("a = 2")]
    with patch.object(
        table,
        "get_sqla_row_level_filters",
        return_value=rls_filters,
    ):
        with patch(
            "pandas.read_sql_query",
            return_value=pd.DataFrame({"column_values": []}),
        ):
            assert table.values_for_column("a") == []

def test_values_for_column_calculated(database: Database) -> None:
    """
    Test that calculated columns work.

    A `TableColumn` with an `expression` (rather than a physical column)
    should be usable with `values_for_column`.
    """
    # NOTE: the unused `mocker: MockerFixture` parameter was removed; the
    # test only uses `unittest.mock.patch` directly.
    import pandas as pd

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[
            TableColumn(
                column_name="starts_with_A",
                expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
            )
        ],
    )

    # Mock pd.read_sql_query to return expected values for calculated column
    with patch(
        "pandas.read_sql_query",
        return_value=pd.DataFrame({"column_values": ["yes", "nope"]}),
    ):
        assert table.values_for_column("starts_with_A") == ["yes", "nope"]

def test_values_for_column_double_percents(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Test the behavior of `double_percents`.

    When the dialect doubles percent signs, `values_for_column` must undo
    the doubling before the SQL reaches `pandas.read_sql_query`.
    """
    import pandas as pd

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    with database.get_sqla_engine() as engine:
        engine.dialect.identifier_preparer._double_percents = "pyformat"

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[
            TableColumn(
                column_name="starts_with_A",
                expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
            )
        ],
    )

    # Capture the SQL handed to read_sql_query while stubbing its result.
    read_sql_mock = mocker.patch(
        "pandas.read_sql_query",
        return_value=pd.DataFrame({"column_values": ["yes", "nope"]}),
    )

    assert table.values_for_column("starts_with_A") == ["yes", "nope"]
    read_sql_mock.assert_called_once()

    compiled_sql = str(read_sql_mock.call_args[1]["sql"])
    # Doubled percents must have been collapsed back to single ones.
    assert "LIKE 'A%'" in compiled_sql
    assert "LIKE 'A%%'" not in compiled_sql

def test_apply_series_others_grouping(database: Database) -> None:
    """
    Test the `_apply_series_others_grouping` method.

    Series columns should be swapped for CASE expressions that bucket the
    remaining series under an "Others" category according to a condition;
    all other columns must pass through untouched.
    """
    from unittest.mock import Mock

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="test_table",
        columns=[
            TableColumn(column_name="category", type="TEXT"),
            TableColumn(column_name="metric_col", type="INTEGER"),
            TableColumn(column_name="other_col", type="TEXT"),
        ],
    )

    def named_mock(name: str) -> Mock:
        # Mock(name=...) sets the mock's repr name, so assign afterwards.
        expr = Mock()
        expr.name = name
        return expr

    category_expr = named_mock("category")
    metric_expr = named_mock("metric_col")
    other_expr = named_mock("other_col")

    select_exprs = [category_expr, metric_expr, other_expr]
    groupby_all_columns = {
        "category": category_expr,
        "other_col": other_expr,
    }
    # Only `category` is a series column and therefore eligible for rewriting.
    groupby_series_columns = {"category": category_expr}

    def always_true_condition(col_name: str, expr) -> bool:
        return True

    def fake_make_compatible(expr, name=None):
        labeled = Mock()
        labeled.name = name
        return labeled

    with patch.object(
        table, "make_sqla_column_compatible", side_effect=fake_make_compatible
    ):
        result_select_exprs, result_groupby_columns = (
            table._apply_series_others_grouping(
                select_exprs,
                groupby_all_columns,
                groupby_series_columns,
                always_true_condition,
            )
        )

    assert len(result_select_exprs) == 3

    # The series column is replaced with a CASE expression that has been
    # made column-compatible (hence it carries the column name)...
    assert result_select_exprs[0].name == "category"
    # ...while the non-series columns are passed through verbatim.
    assert result_select_exprs[1] == metric_expr
    assert result_select_exprs[2] == other_expr

    assert len(result_groupby_columns) == 2
    assert "category" in result_groupby_columns
    # GROUP BY expressions are NOT wrapped with make_sqla_column_compatible,
    # so the series entry is a raw CASE expression: all we can check here is
    # that it differs from the original column expression.
    assert result_groupby_columns["category"] != category_expr
    # Non-series GROUP BY entries remain unchanged.
    assert result_groupby_columns["other_col"] == other_expr

def test_apply_series_others_grouping_with_false_condition(database: Database) -> None:
    """
    Test the `_apply_series_others_grouping` method with a condition that returns False.

    This should result in CASE expressions that always use "Others".
    """
    from unittest.mock import Mock

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="test_table",
        columns=[TableColumn(column_name="category", type="TEXT")],
    )

    category_expr = Mock()
    category_expr.name = "category"

    def never_condition(col_name: str, expr) -> bool:
        # Force every row into the "Others" bucket.
        return False

    def fake_make_compatible(expr, name=None):
        labeled = Mock()
        labeled.name = name
        return labeled

    with patch.object(
        table, "make_sqla_column_compatible", side_effect=fake_make_compatible
    ):
        select_result, groupby_result = table._apply_series_others_grouping(
            [category_expr],
            {"category": category_expr},
            {"category": category_expr},
            never_condition,
        )

    # We cannot compile SQL in a unit test, but we can check that the
    # structure changed: the lone SELECT expression was rewritten and
    # re-labelled with the column name...
    assert len(select_result) == 1
    assert select_result[0].name == "category"

    # ...and the GROUP BY entry is now a CASE expression, no longer the
    # original column expression.
    assert len(groupby_result) == 1
    assert "category" in groupby_result
    assert groupby_result["category"] != category_expr

def test_apply_series_others_grouping_sql_compilation(database: Database) -> None:
    """
    Test that the `_apply_series_others_grouping` method properly quotes
    the 'Others' literal in both SELECT and GROUP BY clauses.

    Regression test for the bug where 'Others' was emitted unquoted in the
    GROUP BY clause, producing invalid SQL.
    """
    import sqlalchemy as sa

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="test_table",
        columns=[
            TableColumn(column_name="name", type="TEXT"),
            TableColumn(column_name="value", type="INTEGER"),
        ],
    )

    # Real SQLAlchemy column expressions so the result can be compiled.
    name_col = sa.column("name")
    value_col = sa.column("value")

    result_select_exprs, result_groupby_columns = table._apply_series_others_grouping(
        [name_col, value_col],
        {"name": name_col},
        {"name": name_col},
        # Condition: the matching subquery column is not NULL.
        lambda col_name, expr: sa.column("series_limit.name__").is_not(None),
    )

    with database.get_sqla_engine() as engine:
        dialect = engine.dialect

    def render(expr) -> str:
        # Inline literals so we can inspect how 'Others' is quoted.
        return str(
            expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
        )

    select_sql = render(result_select_exprs[0])
    groupby_sql = render(result_groupby_columns["name"])

    # 'Others' must never appear as a bare (unquoted) identifier.
    assert " Others " not in select_sql, "Found unquoted 'Others' in SELECT"
    assert " Others " not in groupby_sql, "Found unquoted 'Others' in GROUP BY"

    # Quote style varies by dialect, but both clauses must agree and use
    # either single or double quotes around the literal.
    has_single_quotes = "'Others'" in select_sql and "'Others'" in groupby_sql
    has_double_quotes = '"Others"' in select_sql and '"Others"' in groupby_sql
    assert has_single_quotes or has_double_quotes, (
        "Others literal should be quoted with either single or double quotes"
    )

    # Both expressions are CASE expressions with an ELSE branch.
    assert "CASE WHEN" in select_sql
    assert "CASE WHEN" in groupby_sql
    assert "ELSE " in select_sql
    assert "ELSE " in groupby_sql

    # The GROUP BY expression must never carry a label; SELECT may or may
    # not, depending on the dialect.
    assert " AS " not in groupby_sql
    # If SELECT is labelled, its SQL necessarily differs from GROUP BY's.
    if " AS " in select_sql:
        assert select_sql != groupby_sql

def test_apply_series_others_grouping_no_label_in_groupby(database: Database) -> None:
|
|
"""
|
|
Test that GROUP BY expressions don't get wrapped with make_sqla_column_compatible.
|
|
|
|
This is a specific test for the bug fix where make_sqla_column_compatible
|
|
was causing issues with literal quoting in GROUP BY clauses.
|
|
"""
|
|
from unittest.mock import ANY, call, Mock, patch
|
|
|
|
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
|
|
|
# Create a table instance
|
|
table = SqlaTable(
|
|
database=database,
|
|
schema=None,
|
|
table_name="test_table",
|
|
columns=[TableColumn(column_name="category", type="TEXT")],
|
|
)
|
|
|
|
# Mock expressions
|
|
category_expr = Mock()
|
|
category_expr.name = "category"
|
|
|
|
select_exprs = [category_expr]
|
|
groupby_all_columns = {"category": category_expr}
|
|
groupby_series_columns = {"category": category_expr}
|
|
|
|
def condition_factory(col_name: str, expr):
|
|
return True
|
|
|
|
# Track calls to make_sqla_column_compatible
|
|
with patch.object(
|
|
table, "make_sqla_column_compatible", side_effect=lambda expr, name: expr
|
|
) as mock_make_compatible:
|
|
result_select_exprs, result_groupby_columns = (
|
|
table._apply_series_others_grouping(
|
|
select_exprs,
|
|
groupby_all_columns,
|
|
groupby_series_columns,
|
|
condition_factory,
|
|
)
|
|
)
|
|
|
|
# Verify make_sqla_column_compatible was called for SELECT expressions
|
|
# but NOT for GROUP BY expressions
|
|
calls = mock_make_compatible.call_args_list
|
|
|
|
# Should have exactly one call (for the SELECT expression)
|
|
assert len(calls) == 1
|
|
|
|
# The call should be for the SELECT expression with the column name
|
|
# Using unittest.mock.ANY to match any CASE expression
|
|
assert calls[0] == call(ANY, "category")
|
|
|
|
# Verify the GROUP BY expression was NOT passed through
|
|
# make_sqla_column_compatible - it should be the raw CASE expression
|
|
assert "category" in result_groupby_columns
|
|
# The GROUP BY expression should be different from the SELECT expression
|
|
# because only SELECT gets make_sqla_column_compatible applied
|