fix(sqlglot): adhoc expressions (#35482)

This commit is contained in:
Beto Dealmeida
2025-10-03 12:10:10 -04:00
committed by GitHub
parent 891f826143
commit 139b5ae20c
3 changed files with 378 additions and 17 deletions

View File

@@ -1467,7 +1467,7 @@ class SqlaTable(
if not processed:
try:
expression = self._process_sql_expression(
expression = self._process_select_expression(
expression=expression,
database_id=self.database_id,
engine=self.database.backend,
@@ -1502,7 +1502,7 @@ class SqlaTable(
"""
label = utils.get_column_name(col)
try:
expression = self._process_sql_expression(
expression = self._process_select_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,

View File

@@ -871,6 +871,40 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise QueryObjectValidationError(ex.message) from ex
return expression
def _process_select_expression(
    self,
    expression: Optional[str],
    database_id: int,
    engine: str,
    schema: str,
    template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
    """
    Validate and process an adhoc expression used as a column or metric.

    The expression is prefixed with a dummy ``SELECT`` so that it forms a complete
    statement that can be parsed and validated by ``_process_sql_expression``. The
    prefix is removed again before the result is returned.

    :param expression: adhoc SQL expression; a falsy value (``None`` or ``""``) is
        forwarded to ``_process_sql_expression`` without a prefix
    :param database_id: forwarded unchanged to ``_process_sql_expression``
    :param engine: forwarded unchanged to ``_process_sql_expression``
    :param schema: forwarded unchanged to ``_process_sql_expression``
    :param template_processor: forwarded unchanged to ``_process_sql_expression``
    :returns: the processed expression without the ``SELECT`` prefix and without
        surrounding whitespace, or ``None`` when ``_process_sql_expression``
        returns a falsy value
    """
    if expression:
        expression = f"SELECT {expression}"

    processed = self._process_sql_expression(
        expression=expression,
        database_id=database_id,
        engine=engine,
        schema=schema,
        template_processor=template_processor,
    )
    if not processed:
        return None

    # Split on the first ``SELECT`` only (case-insensitive), so that any nested
    # SELECT inside the expression itself is left untouched.
    parts = re.split(
        r"SELECT\s+",
        processed,
        maxsplit=1,
        flags=re.IGNORECASE,
    )

    # With a match, the last element is the text after the prefix. Without a
    # match ``re.split`` returns the whole string as a single element, so taking
    # the last element avoids the ValueError a two-name unpack would raise.
    return parts[-1].strip()
def _process_orderby_expression(
self,
expression: Optional[str],
@@ -1200,7 +1234,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
expression = metric.get("sqlExpression")
if not processed:
expression = self._process_sql_expression(
expression = self._process_select_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
@@ -1888,12 +1922,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
template_processor=template_processor,
)
else:
selected = validate_adhoc_subquery(
selected,
self.database,
self.catalog,
self.schema,
self.database.db_engine_spec.engine,
selected = self._process_select_expression(
expression=selected,
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
@@ -1917,12 +1951,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
_sql = quote(selected)
_column_label = selected
selected = validate_adhoc_subquery(
_sql,
self.database,
self.catalog,
self.schema,
self.database.db_engine_spec.engine,
selected = self._process_select_expression(
expression=_sql,
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
select_exprs.append(
@@ -2196,7 +2230,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if extras:
where = extras.get("where")
if where:
where = self._process_sql_expression(
where = self._process_select_expression(
expression=where,
database_id=self.database_id,
engine=self.database.backend,
@@ -2206,7 +2240,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
having = self._process_sql_expression(
having = self._process_select_expression(
expression=having,
database_id=self.database_id,
engine=self.database.backend,

View File

@@ -798,3 +798,330 @@ def test_process_orderby_expression_with_template_processor(
assert call_args["template_processor"] is template_processor
assert result == "processed_column DESC"
def test_process_select_expression_basic(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    The dummy ``SELECT`` prefix is removed from a simple aggregate expression.
    """
    from superset.connectors.sqla.models import SqlaTable

    dataset = SqlaTable(database=database, schema=None, table_name="t")

    # Stub the underlying processor so that it hands back a full SELECT statement.
    mocker.patch.object(
        dataset,
        "_process_sql_expression",
        return_value="SELECT COUNT(*)",
    )

    call_kwargs = {
        "expression": "COUNT(*)",
        "database_id": database.id,
        "engine": "sqlite",
        "schema": "",
        "template_processor": None,
    }
    stripped = dataset._process_select_expression(**call_kwargs)

    assert stripped == "COUNT(*)"
def test_process_select_expression_with_case_insensitive_select(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    A lowercase ``select`` prefix in the processed SQL is removed as well.
    """
    from superset.connectors.sqla.models import SqlaTable

    dataset = SqlaTable(database=database, schema=None, table_name="t")

    # The stubbed processor answers with a lowercase "select" keyword.
    mocker.patch.object(
        dataset,
        "_process_sql_expression",
        return_value="select column_name",
    )

    call_kwargs = {
        "expression": "column_name",
        "database_id": database.id,
        "engine": "sqlite",
        "schema": "",
        "template_processor": None,
    }
    stripped = dataset._process_select_expression(**call_kwargs)

    assert stripped == "column_name"
def test_process_select_expression_complex(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    A multi-keyword ``CASE`` expression comes back unchanged once the prefix is
    removed.
    """
    from superset.connectors.sqla.models import SqlaTable

    complex_select = "CASE WHEN status = 'active' THEN 1 ELSE 0 END"
    dataset = SqlaTable(database=database, schema=None, table_name="t")

    # The stubbed processor echoes the expression behind a SELECT prefix.
    mocker.patch.object(
        dataset,
        "_process_sql_expression",
        return_value=f"SELECT {complex_select}",
    )

    call_kwargs = {
        "expression": complex_select,
        "database_id": database.id,
        "engine": "sqlite",
        "schema": "",
        "template_processor": None,
    }
    stripped = dataset._process_select_expression(**call_kwargs)

    assert stripped == complex_select
def test_process_select_expression_none(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    A ``None`` expression yields ``None`` when the patched processor returns
    ``None``.
    """
    from superset.connectors.sqla.models import SqlaTable

    dataset = SqlaTable(database=database, schema=None, table_name="t")

    # The stubbed processor has nothing to return for a missing expression.
    mocker.patch.object(
        dataset,
        "_process_sql_expression",
        return_value=None,
    )

    call_kwargs = {
        "expression": None,
        "database_id": database.id,
        "engine": "sqlite",
        "schema": "",
        "template_processor": None,
    }
    outcome = dataset._process_select_expression(**call_kwargs)

    assert outcome is None
def test_process_select_expression_empty_string(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    An empty-string expression yields ``None`` when the patched processor
    returns ``None``.
    """
    from superset.connectors.sqla.models import SqlaTable

    dataset = SqlaTable(database=database, schema=None, table_name="t")

    # The stubbed processor has nothing to return for an empty expression.
    mocker.patch.object(
        dataset,
        "_process_sql_expression",
        return_value=None,
    )

    call_kwargs = {
        "expression": "",
        "database_id": database.id,
        "engine": "sqlite",
        "schema": "",
        "template_processor": None,
    }
    outcome = dataset._process_select_expression(**call_kwargs)

    assert outcome is None
def test_process_select_expression_strips_whitespace(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Whitespace surrounding the expression is dropped together with the prefix.
    """
    from superset.connectors.sqla.models import SqlaTable

    dataset = SqlaTable(database=database, schema=None, table_name="t")

    # The stubbed processor pads the statement with trailing whitespace.
    mocker.patch.object(
        dataset,
        "_process_sql_expression",
        return_value="SELECT column_name ",
    )

    call_kwargs = {
        "expression": "column_name",
        "database_id": database.id,
        "engine": "sqlite",
        "schema": "",
        "template_processor": None,
    }
    stripped = dataset._process_select_expression(**call_kwargs)

    assert stripped == "column_name"
def test_process_select_expression_with_template_processor(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    The template processor and the ``SELECT``-prefixed expression are both
    passed through to ``_process_sql_expression``.
    """
    from unittest.mock import Mock

    from superset.connectors.sqla.models import SqlaTable

    dataset = SqlaTable(database=database, schema=None, table_name="t")

    # Stand-in template processor; only its identity is checked below.
    template_processor = Mock()

    # Keep a handle on the patched processor so its call can be inspected.
    patched_processor = mocker.patch.object(
        dataset,
        "_process_sql_expression",
        return_value="SELECT processed_expression",
    )

    call_kwargs = {
        "expression": "some_expression",
        "database_id": database.id,
        "engine": "sqlite",
        "schema": "",
        "template_processor": template_processor,
    }
    stripped = dataset._process_select_expression(**call_kwargs)

    # Exactly one delegated call, carrying the prefixed expression and the very
    # same template processor object.
    patched_processor.assert_called_once()
    forwarded = patched_processor.call_args.kwargs
    assert forwarded["expression"] == "SELECT some_expression"
    assert forwarded["template_processor"] is template_processor

    assert stripped == "processed_expression"
def test_process_select_expression_distinct_column(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    A ``DISTINCT`` expression (e.g., "distinct owners") keeps its keyword.

    Only the dummy ``SELECT`` prefix is removed from the processed statement; the
    ``DISTINCT`` keyword that follows it is part of the returned expression.
    """
    from superset.connectors.sqla.models import SqlaTable

    dataset = SqlaTable(database=database, schema=None, table_name="t")

    # The stubbed processor answers with an upper-cased SELECT DISTINCT statement.
    mocker.patch.object(
        dataset,
        "_process_sql_expression",
        return_value="SELECT DISTINCT owners",
    )

    call_kwargs = {
        "expression": "distinct owners",
        "database_id": database.id,
        "engine": "sqlite",
        "schema": "",
        "template_processor": None,
    }
    stripped = dataset._process_select_expression(**call_kwargs)

    assert stripped == "DISTINCT owners"
def test_process_select_expression_end_to_end(database: Database) -> None:
    """
    Run several expressions through the unpatched processing path.

    ``_process_sql_expression`` is NOT mocked here, so each expression goes
    through the real processing before the ``SELECT`` prefix is split off.
    """
    from superset.connectors.sqla.models import SqlaTable

    dataset = SqlaTable(database=database, schema=None, table_name="t")

    case_expression = "CASE WHEN status = 'active' THEN 1 ELSE 0 END"

    # Pairs of (input, expected_output)
    cases = [
        ("COUNT(*)", "COUNT(*)"),
        ("DISTINCT owners", "DISTINCT owners"),
        ("column_name", "column_name"),
        (case_expression, case_expression),
        ("SUM(amount) / COUNT(*)", "SUM(amount) / COUNT(*)"),
        ("UPPER(name)", "UPPER(name)"),
    ]

    for expression, expected in cases:
        result = dataset._process_select_expression(
            expression=expression,
            database_id=database.id,
            engine="sqlite",
            schema="",
            template_processor=None,
        )

        # The processed SQL may be normalized slightly, so only require that a
        # result exists and that it no longer starts with the SELECT prefix.
        assert result is not None, f"Failed to process: {expression}"
        assert not result.upper().startswith("SELECT"), (
            f"Result still has SELECT prefix: {result}"
        )

        # Compare with spaces removed and case folded, so formatting differences
        # do not matter.
        compact_expected = expected.replace(" ", "").lower()
        compact_result = result.replace(" ", "").lower()
        assert compact_expected in compact_result, (
            f"Expected '{expected}' to be in result '{result}' for input '{expression}'"
        )