fix(sqla): parenthesize extras where/having clauses in query generation (#38183)

Co-authored-by: Diego Pucci <diegopucci.me@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Enzo Martellucci
2026-03-09 10:05:55 +01:00
committed by GitHub
parent 9983e255f8
commit c7a1f57487
2 changed files with 117 additions and 4 deletions

View File

@@ -58,7 +58,7 @@ from sqlalchemy import and_, Column, or_, UniqueConstraint
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Mapper, validates
from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause
from sqlalchemy.sql.elements import ColumnElement, Grouping, literal_column, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
from sqlalchemy_utils import UUIDType
@@ -2980,7 +2980,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
where_clause_and: list[ColumnElement] = []
having_clause_and: list[ColumnElement] = []
for flt in filter: # type: ignore
for flt in filter or []:
if not all(flt.get(s) for s in ["col", "op"]):
continue
flt_col = flt["col"]
@@ -3221,7 +3221,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
schema=self.schema,
template_processor=template_processor,
)
where_clause_and += [self.text(where)]
where_clause_and += [Grouping(self.text(where))]
having = extras.get("having")
if having:
having = self._process_select_expression(
@@ -3231,7 +3231,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
schema=self.schema,
template_processor=template_processor,
)
having_clause_and += [self.text(having)]
having_clause_and += [Grouping(self.text(having))]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(

View File

@@ -1743,3 +1743,116 @@ def test_orderby_adhoc_column(database: Database) -> None:
# Verify the SQL contains the expression from the adhoc column
sql = str(result.sqla_query)
assert "ORDER BY" in sql.upper()
def test_extras_where_is_parenthesized(
database: Database,
) -> None:
"""
Test that extras.where is wrapped in parentheses when composed with other
filters.
Without parentheses, an extras.where containing OR operators combined
with other filters via AND could produce unexpected evaluation order due
to SQL operator precedence (AND binds tighter than OR). Wrapping in
parentheses ensures the expression is treated as a single logical unit.
"""
from unittest.mock import patch
from sqlalchemy import text as sa_text
from superset.connectors.sqla.models import SqlaTable, TableColumn
table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a", type="INTEGER"),
TableColumn(column_name="b", type="TEXT"),
],
)
with (
patch.object(
table,
"get_sqla_row_level_filters",
return_value=[sa_text("(b = 'restricted')")],
),
patch.object(
table,
"_process_select_expression",
return_value="1 = 1 OR 1 = 1",
),
):
sqla_query = table.get_sqla_query(
columns=["a"],
extras={"where": "1=1 OR 1=1"},
is_timeseries=False,
metrics=[],
)
with database.get_sqla_engine() as engine:
sql = str(
sqla_query.sqla_query.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)
assert "(1 = 1 OR 1 = 1)" in sql, (
f"extras.where should be wrapped in parentheses. Generated SQL: {sql}"
)
assert "b = 'restricted'" in sql, (
f"Additional filters should be present in query. Generated SQL: {sql}"
)
def test_extras_having_is_parenthesized(
database: Database,
) -> None:
"""
Test that extras.having is wrapped in parentheses when composed with
other HAVING filters, to ensure correct evaluation order.
"""
from unittest.mock import patch
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a", type="INTEGER"),
TableColumn(column_name="b", type="TEXT"),
],
metrics=[
SqlMetric(metric_name="cnt", expression="COUNT(*)"),
],
)
with patch.object(
table,
"_process_select_expression",
return_value="COUNT(*) > 0 OR 1 = 1",
):
sqla_query = table.get_sqla_query(
groupby=["b"],
metrics=["cnt"],
extras={"having": "COUNT(*) > 0 OR 1=1"},
is_timeseries=False,
)
with database.get_sqla_engine() as engine:
sql = str(
sqla_query.sqla_query.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)
assert "(COUNT(*) > 0 OR 1 = 1)" in sql, (
f"extras.having should be wrapped in parentheses. Generated SQL: {sql}"
)