diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ba07a5d148b..b0e71a7fa80 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1430,6 +1430,7 @@ class SqlaTable( metric: AdhocMetric, columns_by_name: dict[str, TableColumn], template_processor: BaseTemplateProcessor | None = None, + processed: bool = False, ) -> ColumnElement: """ Turn an adhoc metric into a sqlalchemy column. @@ -1437,6 +1438,7 @@ class SqlaTable( :param dict metric: Adhoc metric definition :param dict columns_by_name: Columns for the current table :param template_processor: template_processor instance + :param bool processed: Whether the sqlExpression has already been processed :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -1455,16 +1457,20 @@ class SqlaTable( sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) elif expression_type == utils.AdhocMetricExpressionType.SQL: - try: - expression = self._process_sql_expression( - expression=metric["sqlExpression"], - database_id=self.database_id, - engine=self.database.backend, - schema=self.schema, - template_processor=template_processor, - ) - except SupersetSecurityException as ex: - raise QueryObjectValidationError(ex.message) from ex + expression = metric.get("sqlExpression") + + if not processed: + try: + expression = self._process_sql_expression( + expression=expression, + database_id=self.database_id, + engine=self.database.backend, + schema=self.schema, + template_processor=template_processor, + ) + except SupersetSecurityException as ex: + raise QueryObjectValidationError(ex.message) from ex + sqla_metric = literal_column(expression) else: raise QueryObjectValidationError("Adhoc metric expressionType is invalid") diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 6e6ef22bde9..2bb0c85c538 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -871,6 +871,40 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise QueryObjectValidationError(ex.message) from ex return expression + def _process_orderby_expression( + self, + expression: Optional[str], + database_id: int, + engine: str, + schema: str, + template_processor: Optional[BaseTemplateProcessor], + ) -> Optional[str]: + """ + Validate and process an ORDER BY clause expression. + + This requires prefixing the expression with a dummy SELECT statement, so it can + be properly parsed and validated. + """ + if expression: + expression = f"SELECT 1 ORDER BY {expression}" + + if processed := self._process_sql_expression( + expression=expression, + database_id=database_id, + engine=engine, + schema=schema, + template_processor=template_processor, + ): + prefix, expression = re.split( + r"ORDER\s+BY", + processed, + maxsplit=1, + flags=re.IGNORECASE, + ) + return expression.strip() + + return None + def make_sqla_column_compatible( self, sqla_col: ColumnElement, label: Optional[str] = None ) -> ColumnElement: @@ -1139,6 +1173,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods metric: AdhocMetric, columns_by_name: dict[str, "TableColumn"], # pylint: disable=unused-argument template_processor: Optional[BaseTemplateProcessor] = None, + processed: bool = False, ) -> ColumnElement: """ Turn an adhoc metric into a sqlalchemy column. @@ -1146,6 +1181,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods :param dict metric: Adhoc metric definition :param dict columns_by_name: Columns for the current table :param template_processor: template_processor instance + :param bool processed: Whether the sqlExpression has already been processed :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -1158,13 +1194,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods sqla_column = sa.column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) elif expression_type == utils.AdhocMetricExpressionType.SQL: - expression = self._process_sql_expression( - expression=metric["sqlExpression"], - database_id=self.database_id, - engine=self.database.backend, - schema=self.schema, - template_processor=template_processor, - ) + expression = metric.get("sqlExpression") + + if not processed: + expression = self._process_sql_expression( + expression=metric["sqlExpression"], + database_id=self.database_id, + engine=self.database.backend, + schema=self.schema, + template_processor=template_processor, + ) + sqla_metric = literal_column(expression) else: raise QueryObjectValidationError("Adhoc metric expressionType is invalid") @@ -1779,7 +1819,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if isinstance(col, dict): col = cast(AdhocMetric, col) if col.get("sqlExpression"): - col["sqlExpression"] = self._process_sql_expression( + col["sqlExpression"] = self._process_orderby_expression( expression=col["sqlExpression"], database_id=self.database_id, engine=self.database.backend, @@ -1788,9 +1828,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists - col = self.adhoc_metric_to_sqla(col, columns_by_name) - # if the adhoc metric has been defined before - # use the existing instance. + col = self.adhoc_metric_to_sqla( + col, + columns_by_name, + processed=True, + ) + # use the existing instance, if possible col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in metrics_exprs_by_label: diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py index 80430ab0881..95862c29544 100644 --- a/tests/unit_tests/models/helpers_test.py +++ b/tests/unit_tests/models/helpers_test.py @@ -556,3 +556,245 @@ def test_apply_series_others_grouping_no_label_in_groupby(database: Database) -> assert "category" in result_groupby_columns # The GROUP BY expression should be different from the SELECT expression # because only SELECT gets make_sqla_column_compatible applied + + +def test_process_orderby_expression_basic( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test basic ORDER BY expression processing. + """ + from superset.connectors.sqla.models import SqlaTable + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + ) + + # Mock _process_sql_expression to return a processed SELECT statement + mocker.patch.object( + table, + "_process_sql_expression", + return_value="SELECT 1 ORDER BY column_name DESC", + ) + + result = table._process_orderby_expression( + expression="column_name DESC", + database_id=database.id, + engine="sqlite", + schema="", + template_processor=None, + ) + + assert result == "column_name DESC" + + +def test_process_orderby_expression_with_case_insensitive_order_by( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test ORDER BY expression processing with case-insensitive matching. + """ + from superset.connectors.sqla.models import SqlaTable + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + ) + + # Mock with lowercase "order by" + mocker.patch.object( + table, + "_process_sql_expression", + return_value="SELECT 1 order by column_name ASC", + ) + + result = table._process_orderby_expression( + expression="column_name ASC", + database_id=database.id, + engine="sqlite", + schema="", + template_processor=None, + ) + + assert result == "column_name ASC" + + +def test_process_orderby_expression_complex( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test ORDER BY expression with complex expressions. + """ + from superset.connectors.sqla.models import SqlaTable + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + ) + + complex_orderby = "CASE WHEN status = 'active' THEN 1 ELSE 2 END, name DESC" + mocker.patch.object( + table, + "_process_sql_expression", + return_value=f"SELECT 1 ORDER BY {complex_orderby}", + ) + + result = table._process_orderby_expression( + expression=complex_orderby, + database_id=database.id, + engine="sqlite", + schema="", + template_processor=None, + ) + + assert result == complex_orderby + + +def test_process_orderby_expression_none( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test ORDER BY expression processing with None expression. + """ + from superset.connectors.sqla.models import SqlaTable + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + ) + + # Mock should return None when input is None + mocker.patch.object( + table, + "_process_sql_expression", + return_value=None, + ) + + result = table._process_orderby_expression( + expression=None, + database_id=database.id, + engine="sqlite", + schema="", + template_processor=None, + ) + + assert result is None + + +def test_process_orderby_expression_empty_string( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test ORDER BY expression processing with empty string. + """ + from superset.connectors.sqla.models import SqlaTable + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + ) + + # Mock should return None for empty string + mocker.patch.object( + table, + "_process_sql_expression", + return_value=None, + ) + + result = table._process_orderby_expression( + expression="", + database_id=database.id, + engine="sqlite", + schema="", + template_processor=None, + ) + + assert result is None + + +def test_process_orderby_expression_strips_whitespace( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test that ORDER BY expression processing strips leading/trailing whitespace. + """ + from superset.connectors.sqla.models import SqlaTable + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + ) + + # Mock with extra whitespace after ORDER BY + mocker.patch.object( + table, + "_process_sql_expression", + return_value="SELECT 1 ORDER BY column_name DESC ", + ) + + result = table._process_orderby_expression( + expression="column_name DESC", + database_id=database.id, + engine="sqlite", + schema="", + template_processor=None, + ) + + assert result == "column_name DESC" + + +def test_process_orderby_expression_with_template_processor( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test ORDER BY expression with template processor. + """ + from unittest.mock import Mock + + from superset.connectors.sqla.models import SqlaTable + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + ) + + # Create a mock template processor + template_processor = Mock() + + # Mock the _process_sql_expression to verify it receives the prefixed expression + mock_process = mocker.patch.object( + table, + "_process_sql_expression", + return_value="SELECT 1 ORDER BY processed_column DESC", + ) + + result = table._process_orderby_expression( + expression="column_name DESC", + database_id=database.id, + engine="sqlite", + schema="", + template_processor=template_processor, + ) + + # Verify _process_sql_expression was called with SELECT prefix + mock_process.assert_called_once() + call_args = mock_process.call_args[1] + assert call_args["expression"] == "SELECT 1 ORDER BY column_name DESC" + assert call_args["template_processor"] is template_processor + + assert result == "processed_column DESC"