diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 46ae374fd57..f6f28f63850 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -58,7 +58,7 @@ from sqlalchemy import and_, Column, or_, UniqueConstraint from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import Mapper, validates -from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause +from sqlalchemy.sql.elements import ColumnElement, Grouping, literal_column, TextClause from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy_utils import UUIDType @@ -2980,7 +2980,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods where_clause_and: list[ColumnElement] = [] having_clause_and: list[ColumnElement] = [] - for flt in filter: # type: ignore + for flt in filter or []: if not all(flt.get(s) for s in ["col", "op"]): continue flt_col = flt["col"] @@ -3221,7 +3221,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods schema=self.schema, template_processor=template_processor, ) - where_clause_and += [self.text(where)] + where_clause_and += [Grouping(self.text(where))] having = extras.get("having") if having: having = self._process_select_expression( @@ -3231,7 +3231,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods schema=self.schema, template_processor=template_processor, ) - having_clause_and += [self.text(having)] + having_clause_and += [Grouping(self.text(having))] if apply_fetch_values_predicate and self.fetch_values_predicate: qry = qry.where( diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py index e5a177285e8..c0354844737 100644 --- a/tests/unit_tests/models/helpers_test.py +++ b/tests/unit_tests/models/helpers_test.py @@ -1743,3 +1743,116 @@ def test_orderby_adhoc_column(database: Database) -> None: # Verify the SQL contains the expression from the adhoc column sql = str(result.sqla_query) assert "ORDER BY" in sql.upper() + + +def test_extras_where_is_parenthesized( + database: Database, +) -> None: + """ + Test that extras.where is wrapped in parentheses when composed with other + filters. + + Without parentheses, an extras.where containing OR operators combined + with other filters via AND could produce unexpected evaluation order due + to SQL operator precedence (AND binds tighter than OR). Wrapping in + parentheses ensures the expression is treated as a single logical unit. + """ + from unittest.mock import patch + + from sqlalchemy import text as sa_text + + from superset.connectors.sqla.models import SqlaTable, TableColumn + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + columns=[ + TableColumn(column_name="a", type="INTEGER"), + TableColumn(column_name="b", type="TEXT"), + ], + ) + + with ( + patch.object( + table, + "get_sqla_row_level_filters", + return_value=[sa_text("(b = 'restricted')")], + ), + patch.object( + table, + "_process_select_expression", + return_value="1 = 1 OR 1 = 1", + ), + ): + sqla_query = table.get_sqla_query( + columns=["a"], + extras={"where": "1=1 OR 1=1"}, + is_timeseries=False, + metrics=[], + ) + + with database.get_sqla_engine() as engine: + sql = str( + sqla_query.sqla_query.compile( + dialect=engine.dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + + assert "(1 = 1 OR 1 = 1)" in sql, ( + f"extras.where should be wrapped in parentheses. Generated SQL: {sql}" + ) + + assert "b = 'restricted'" in sql, ( + f"Additional filters should be present in query. Generated SQL: {sql}" + ) + + +def test_extras_having_is_parenthesized( + database: Database, +) -> None: + """ + Test that extras.having is wrapped in parentheses when composed with + other HAVING filters, to ensure correct evaluation order. + """ + from unittest.mock import patch + + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + columns=[ + TableColumn(column_name="a", type="INTEGER"), + TableColumn(column_name="b", type="TEXT"), + ], + metrics=[ + SqlMetric(metric_name="cnt", expression="COUNT(*)"), + ], + ) + + with patch.object( + table, + "_process_select_expression", + return_value="COUNT(*) > 0 OR 1 = 1", + ): + sqla_query = table.get_sqla_query( + groupby=["b"], + metrics=["cnt"], + extras={"having": "COUNT(*) > 0 OR 1=1"}, + is_timeseries=False, + ) + + with database.get_sqla_engine() as engine: + sql = str( + sqla_query.sqla_query.compile( + dialect=engine.dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + + assert "(COUNT(*) > 0 OR 1 = 1)" in sql, ( + f"extras.having should be wrapped in parentheses. Generated SQL: {sql}" + )