feat(where_in): Support returning None if filter_values return None (#32731)

This commit is contained in:
Vitor Avila
2025-03-18 16:18:51 -03:00
committed by GitHub
parent 710af87faf
commit 850801f510
2 changed files with 27 additions and 3 deletions

View File

@@ -519,7 +519,12 @@ class WhereInMacro: # pylint: disable=too-few-public-methods
def __init__(self, dialect: Dialect):
self.dialect = dialect
def __call__(self, values: list[Any], mark: Optional[str] = None) -> str:
def __call__(
self,
values: list[Any],
mark: Optional[str] = None,
default_to_none: bool = False,
) -> str | None:
"""
Given a list of values, build a parenthesis list suitable for an IN expression.
@@ -528,6 +533,10 @@ class WhereInMacro: # pylint: disable=too-few-public-methods
>>> where_in([1, "Joe's", 3])
(1, 'Joe''s', 3)
The `default_to_none` parameter is used to determine the return value when the
list of values is empty:
- If `default_to_none` is `False` (default), the return value is ().
- If `default_to_none` is `True`, the return value is `None`.
"""
binds = [bindparam(f"value_{i}", value) for i, value in enumerate(values)]
string_representations = [
@@ -539,9 +548,11 @@ class WhereInMacro: # pylint: disable=too-few-public-methods
for bind in binds
]
joined_values = ", ".join(string_representations)
result = f"({joined_values})"
result = (
f"({joined_values})" if (joined_values or not default_to_none) else None
)
if mark:
if mark and result:
result += (
"\n-- WARNING: the `mark` parameter was removed from the `where_in` "
"macro for security reasons\n"

View File

@@ -416,6 +416,19 @@ def test_where_in() -> None:
assert where_in(["O'Malley's"]) == "('O''Malley''s')"
def test_where_in_empty_list() -> None:
"""
Test the ``where_in`` Jinja2 filter when it receives an
empty list.
"""
where_in = WhereInMacro(mysql.dialect())
# By default, the filter should return empty parenthesis (as a string)
assert where_in([]) == "()"
# With the default_to_none parameter set to True, it should return None
assert where_in([], default_to_none=True) is None
def test_dataset_macro(mocker: MockerFixture) -> None:
"""
Test the ``dataset_macro`` macro.