mirror of
https://github.com/apache/superset.git
synced 2026-06-19 14:39:20 +00:00
Compare commits
4 Commits
fix-report
...
fix/saniti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca2cb0f289 | ||
|
|
bbfa73e277 | ||
|
|
b03b723928 | ||
|
|
390488c75f |
@@ -1686,15 +1686,34 @@ def process_jinja_sql(
|
||||
|
||||
def sanitize_clause(clause: str, engine: str) -> str:
|
||||
"""
|
||||
Make sure the SQL clause is valid.
|
||||
Validate a SQL clause and return it, unchanged wherever it is safe to do so.
|
||||
|
||||
The clause is parsed to ensure it is a single, well-formed statement. We
|
||||
intentionally return the *original* text rather than a re-rendered version:
|
||||
round-tripping user SQL through SQLGlot's dialect generator can silently
|
||||
alter semantics. For example, the Postgres dialect (borrowed by several
|
||||
engines) rewrites ``ROUND(AVG(x), n)`` to ``ROUND(CAST(AVG(x) AS DECIMAL),
|
||||
n)``, which rounds the value to an integer before the explicit ``ROUND`` on
|
||||
engines whose unqualified ``DECIMAL`` defaults to scale 0 (see #36113).
|
||||
|
||||
Comments are one exception: a trailing line comment can comment out
|
||||
surrounding SQL once the clause is embedded into a larger query (e.g.
|
||||
wrapped in parentheses), so any clause that contains comments is re-rendered
|
||||
to normalize them into a safe form. A trailing statement terminator is
|
||||
likewise stripped, since callers embed the clause inside a larger fragment
|
||||
(``WHERE (...)``) where a stray ``;`` would produce invalid SQL.
|
||||
"""
|
||||
try:
|
||||
statement = SQLStatement(clause, engine)
|
||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
parsed = statement._parsed # pylint: disable=protected-access
|
||||
if not any(node.comments for node in parsed.walk()):
|
||||
return clause.rstrip().rstrip(";").rstrip()
|
||||
|
||||
from sqlglot.dialects.dialect import Dialect
|
||||
|
||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
return Dialect.get_or_raise(dialect).generate(
|
||||
statement._parsed, # pylint: disable=protected-access
|
||||
parsed,
|
||||
copy=True,
|
||||
comments=True,
|
||||
pretty=False,
|
||||
|
||||
@@ -792,7 +792,7 @@ def test_get_samples_with_multiple_filters(
|
||||
assert "2000-01-02" in rv.json["result"]["query"]
|
||||
assert "2000-01-04" in rv.json["result"]["query"]
|
||||
assert "col3 = 1.2" in rv.json["result"]["query"]
|
||||
assert "col4 IS NULL" in rv.json["result"]["query"]
|
||||
assert "col4 is null" in rv.json["result"]["query"]
|
||||
assert "col2 = 'c'" in rv.json["result"]["query"]
|
||||
|
||||
|
||||
|
||||
@@ -199,8 +199,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
assert "'foo_P1D'" in query
|
||||
# assert dataset saved metric
|
||||
assert "count('bar_P1D')" in query
|
||||
# assert adhoc metric
|
||||
assert "SUM(CASE WHEN user = 'user_abc' THEN 1 ELSE 0 END)" in query
|
||||
# assert adhoc metric (sanitize_clause preserves the user's SQL verbatim)
|
||||
assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query
|
||||
# Cleanup
|
||||
db.session.delete(table)
|
||||
db.session.commit()
|
||||
|
||||
@@ -2897,7 +2897,8 @@ def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None:
|
||||
"sql, expected, engine",
|
||||
[
|
||||
("col = 1", "col = 1", "base"),
|
||||
("1=\t\n1", "1 = 1", "base"),
|
||||
# Comment-free clauses are returned verbatim (no semantic round-trip).
|
||||
("1=\t\n1", "1=\t\n1", "base"),
|
||||
("(col = 1)", "(col = 1)", "base"), # Compact format without newlines
|
||||
(
|
||||
"(col1 = 1) AND (col2 = 2)",
|
||||
@@ -2921,6 +2922,10 @@ def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None:
|
||||
), # Block comments preserved
|
||||
("col = 'col1 = 1) AND (col2 = 2'", "col = 'col1 = 1) AND (col2 = 2'", "base"),
|
||||
("col = 'select 1; select 2'", "col = 'select 1; select 2'", "base"),
|
||||
# Trailing statement terminators are stripped so the clause stays valid
|
||||
# once embedded inside a larger fragment (e.g. ``WHERE (...)``).
|
||||
("col = 1;", "col = 1", "base"),
|
||||
("col = 1 ; ", "col = 1", "base"),
|
||||
("col = 'abc -- comment'", "col = 'abc -- comment'", "base"),
|
||||
("col1 = 1) AND (col2 = 2)", QueryClauseValidationException, "base"),
|
||||
("(col1 = 1) AND (col2 = 2", QueryClauseValidationException, "base"),
|
||||
@@ -2940,6 +2945,31 @@ def test_sanitize_clause(sql: str, expected: str | Exception, engine: str) -> No
|
||||
sanitize_clause(sql, engine)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"engine",
|
||||
["postgresql", "redshift", "cockroachdb", "netezza", "hana", "base", "mysql"],
|
||||
)
|
||||
def test_sanitize_clause_preserves_aggregation_semantics(engine: str) -> None:
|
||||
"""
|
||||
Regression test for https://github.com/apache/superset/issues/36113.
|
||||
|
||||
`sanitize_clause` must not silently rewrite a user-authored expression. The
|
||||
Postgres SQLGlot dialect (which several engines borrow) rewrites
|
||||
``ROUND(AVG(x), n)`` to ``ROUND(CAST(AVG(x) AS DECIMAL), n)`` at generation
|
||||
time. On engines whose unqualified ``DECIMAL`` defaults to scale 0 (e.g.
|
||||
Redshift, Netezza) the injected cast rounds the aggregate to an integer
|
||||
*before* the explicit ``ROUND``, producing wrong results.
|
||||
|
||||
The clause must be returned unchanged regardless of the engine dialect.
|
||||
"""
|
||||
clause = "ROUND(AVG(col), 4)"
|
||||
sanitized = sanitize_clause(clause, engine)
|
||||
assert "CAST" not in sanitized.upper(), (
|
||||
f"sanitize_clause injected a cast for engine {engine!r}: {sanitized!r}"
|
||||
)
|
||||
assert sanitized == clause
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"engine",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user