mirror of
https://github.com/apache/superset.git
synced 2026-04-07 10:31:50 +00:00
fix(sqlglot): adhoc expressions (#35482)
This commit is contained in:
@@ -1467,7 +1467,7 @@ class SqlaTable(
|
||||
|
||||
if not processed:
|
||||
try:
|
||||
expression = self._process_sql_expression(
|
||||
expression = self._process_select_expression(
|
||||
expression=expression,
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
@@ -1502,7 +1502,7 @@ class SqlaTable(
|
||||
"""
|
||||
label = utils.get_column_name(col)
|
||||
try:
|
||||
expression = self._process_sql_expression(
|
||||
expression = self._process_select_expression(
|
||||
expression=col["sqlExpression"],
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
|
||||
@@ -871,6 +871,40 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
raise QueryObjectValidationError(ex.message) from ex
|
||||
return expression
|
||||
|
||||
def _process_select_expression(
|
||||
self,
|
||||
expression: Optional[str],
|
||||
database_id: int,
|
||||
engine: str,
|
||||
schema: str,
|
||||
template_processor: Optional[BaseTemplateProcessor],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Validate and process an adhoc expression used as a column or metric.
|
||||
|
||||
This requires prefixing the expression with a dummy SELECT statement, so it can
|
||||
be properly parsed and validated.
|
||||
"""
|
||||
if expression:
|
||||
expression = f"SELECT {expression}"
|
||||
|
||||
if processed := self._process_sql_expression(
|
||||
expression=expression,
|
||||
database_id=database_id,
|
||||
engine=engine,
|
||||
schema=schema,
|
||||
template_processor=template_processor,
|
||||
):
|
||||
prefix, expression = re.split(
|
||||
r"SELECT\s+",
|
||||
processed,
|
||||
maxsplit=1,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
return expression.strip()
|
||||
|
||||
return None
|
||||
|
||||
def _process_orderby_expression(
|
||||
self,
|
||||
expression: Optional[str],
|
||||
@@ -1200,7 +1234,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
expression = metric.get("sqlExpression")
|
||||
|
||||
if not processed:
|
||||
expression = self._process_sql_expression(
|
||||
expression = self._process_select_expression(
|
||||
expression=metric["sqlExpression"],
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
@@ -1888,12 +1922,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
template_processor=template_processor,
|
||||
)
|
||||
else:
|
||||
selected = validate_adhoc_subquery(
|
||||
selected,
|
||||
self.database,
|
||||
self.catalog,
|
||||
self.schema,
|
||||
self.database.db_engine_spec.engine,
|
||||
selected = self._process_select_expression(
|
||||
expression=selected,
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
outer = literal_column(f"({selected})")
|
||||
outer = self.make_sqla_column_compatible(outer, selected)
|
||||
@@ -1917,12 +1951,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
_sql = quote(selected)
|
||||
_column_label = selected
|
||||
|
||||
selected = validate_adhoc_subquery(
|
||||
_sql,
|
||||
self.database,
|
||||
self.catalog,
|
||||
self.schema,
|
||||
self.database.db_engine_spec.engine,
|
||||
selected = self._process_select_expression(
|
||||
expression=_sql,
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
||||
select_exprs.append(
|
||||
@@ -2196,7 +2230,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
if extras:
|
||||
where = extras.get("where")
|
||||
if where:
|
||||
where = self._process_sql_expression(
|
||||
where = self._process_select_expression(
|
||||
expression=where,
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
@@ -2206,7 +2240,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
where_clause_and += [self.text(where)]
|
||||
having = extras.get("having")
|
||||
if having:
|
||||
having = self._process_sql_expression(
|
||||
having = self._process_select_expression(
|
||||
expression=having,
|
||||
database_id=self.database_id,
|
||||
engine=self.database.backend,
|
||||
|
||||
@@ -798,3 +798,330 @@ def test_process_orderby_expression_with_template_processor(
|
||||
assert call_args["template_processor"] is template_processor
|
||||
|
||||
assert result == "processed_column DESC"
|
||||
|
||||
|
||||
def test_process_select_expression_basic(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test basic SELECT expression processing.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
# Mock _process_sql_expression to return a processed SELECT statement
|
||||
mocker.patch.object(
|
||||
table,
|
||||
"_process_sql_expression",
|
||||
return_value="SELECT COUNT(*)",
|
||||
)
|
||||
|
||||
result = table._process_select_expression(
|
||||
expression="COUNT(*)",
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=None,
|
||||
)
|
||||
|
||||
assert result == "COUNT(*)"
|
||||
|
||||
|
||||
def test_process_select_expression_with_case_insensitive_select(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test SELECT expression processing with case-insensitive matching.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
# Mock with lowercase "select"
|
||||
mocker.patch.object(
|
||||
table,
|
||||
"_process_sql_expression",
|
||||
return_value="select column_name",
|
||||
)
|
||||
|
||||
result = table._process_select_expression(
|
||||
expression="column_name",
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=None,
|
||||
)
|
||||
|
||||
assert result == "column_name"
|
||||
|
||||
|
||||
def test_process_select_expression_complex(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test SELECT expression with complex expressions.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
complex_select = "CASE WHEN status = 'active' THEN 1 ELSE 0 END"
|
||||
mocker.patch.object(
|
||||
table,
|
||||
"_process_sql_expression",
|
||||
return_value=f"SELECT {complex_select}",
|
||||
)
|
||||
|
||||
result = table._process_select_expression(
|
||||
expression=complex_select,
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=None,
|
||||
)
|
||||
|
||||
assert result == complex_select
|
||||
|
||||
|
||||
def test_process_select_expression_none(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test SELECT expression processing with None expression.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
# Mock should return None when input is None
|
||||
mocker.patch.object(
|
||||
table,
|
||||
"_process_sql_expression",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
result = table._process_select_expression(
|
||||
expression=None,
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_process_select_expression_empty_string(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test SELECT expression processing with empty string.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
# Mock should return None for empty string
|
||||
mocker.patch.object(
|
||||
table,
|
||||
"_process_sql_expression",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
result = table._process_select_expression(
|
||||
expression="",
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_process_select_expression_strips_whitespace(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test that SELECT expression processing strips leading/trailing whitespace.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
# Mock with extra whitespace after SELECT
|
||||
mocker.patch.object(
|
||||
table,
|
||||
"_process_sql_expression",
|
||||
return_value="SELECT column_name ",
|
||||
)
|
||||
|
||||
result = table._process_select_expression(
|
||||
expression="column_name",
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=None,
|
||||
)
|
||||
|
||||
assert result == "column_name"
|
||||
|
||||
|
||||
def test_process_select_expression_with_template_processor(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test SELECT expression with template processor.
|
||||
"""
|
||||
from unittest.mock import Mock
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
# Create a mock template processor
|
||||
template_processor = Mock()
|
||||
|
||||
# Mock the _process_sql_expression to verify it receives the prefixed expression
|
||||
mock_process = mocker.patch.object(
|
||||
table,
|
||||
"_process_sql_expression",
|
||||
return_value="SELECT processed_expression",
|
||||
)
|
||||
|
||||
result = table._process_select_expression(
|
||||
expression="some_expression",
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=template_processor,
|
||||
)
|
||||
|
||||
# Verify _process_sql_expression was called with SELECT prefix
|
||||
mock_process.assert_called_once()
|
||||
call_args = mock_process.call_args[1]
|
||||
assert call_args["expression"] == "SELECT some_expression"
|
||||
assert call_args["template_processor"] is template_processor
|
||||
|
||||
assert result == "processed_expression"
|
||||
|
||||
|
||||
def test_process_select_expression_distinct_column(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test SELECT expression with DISTINCT keyword (e.g., "distinct owners").
|
||||
|
||||
This test ensures that expressions like "distinct owners" used in adhoc
|
||||
metrics or columns are properly parsed and validated.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
# Mock _process_sql_expression to return a processed SELECT with DISTINCT
|
||||
mocker.patch.object(
|
||||
table,
|
||||
"_process_sql_expression",
|
||||
return_value="SELECT DISTINCT owners",
|
||||
)
|
||||
|
||||
result = table._process_select_expression(
|
||||
expression="distinct owners",
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=None,
|
||||
)
|
||||
|
||||
assert result == "DISTINCT owners"
|
||||
|
||||
|
||||
def test_process_select_expression_end_to_end(database: Database) -> None:
|
||||
"""
|
||||
End-to-end test that verifies the regex split works with real sqlglot processing.
|
||||
|
||||
This test does NOT mock _process_sql_expression, allowing the full flow
|
||||
through sqlglot parsing and validation to ensure the regex extraction works.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="t",
|
||||
)
|
||||
|
||||
# Test various real-world expressions
|
||||
test_cases = [
|
||||
# (input, expected_output)
|
||||
("COUNT(*)", "COUNT(*)"),
|
||||
("DISTINCT owners", "DISTINCT owners"),
|
||||
("column_name", "column_name"),
|
||||
(
|
||||
"CASE WHEN status = 'active' THEN 1 ELSE 0 END",
|
||||
"CASE WHEN status = 'active' THEN 1 ELSE 0 END",
|
||||
),
|
||||
("SUM(amount) / COUNT(*)", "SUM(amount) / COUNT(*)"),
|
||||
("UPPER(name)", "UPPER(name)"),
|
||||
]
|
||||
|
||||
for expression, expected in test_cases:
|
||||
result = table._process_select_expression(
|
||||
expression=expression,
|
||||
database_id=database.id,
|
||||
engine="sqlite",
|
||||
schema="",
|
||||
template_processor=None,
|
||||
)
|
||||
# sqlglot may normalize the SQL slightly, so we check the result exists
|
||||
# and doesn't contain the SELECT prefix
|
||||
assert result is not None, f"Failed to process: {expression}"
|
||||
assert not result.upper().startswith("SELECT"), (
|
||||
f"Result still has SELECT prefix: {result}"
|
||||
)
|
||||
# The result should contain the core expression (case-insensitive check)
|
||||
assert expected.replace(" ", "").lower() in result.replace(" ", "").lower(), (
|
||||
f"Expected '{expected}' to be in result '{result}' for input '{expression}'"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user