mirror of
https://github.com/apache/superset.git
synced 2026-05-21 15:55:10 +00:00
Improvements
This commit is contained in:
@@ -873,7 +873,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
|
||||
"""
|
||||
# AST mutation invalidates any cached verbatim SQL (e.g. from splice).
|
||||
self._raw_sql = None
|
||||
# If we already have a rewritten SQL string, re-parse it first so further
|
||||
# AST mutations (like LIMIT injection) preserve prior text-based rewrites.
|
||||
if self._raw_sql is not None:
|
||||
self._parsed = self._parse_statement(self._raw_sql, self.engine)
|
||||
self._source_sql = self._raw_sql
|
||||
self._raw_sql = None
|
||||
if method == LimitMethod.FORCE_LIMIT:
|
||||
self._parsed.args["limit"] = exp.Limit(
|
||||
expression=exp.Literal(this=str(limit), is_string=False)
|
||||
|
||||
@@ -62,8 +62,15 @@ _CLAUSE_ENDS = {
|
||||
TokenType.GROUP_BY,
|
||||
TokenType.HAVING,
|
||||
TokenType.ORDER_BY,
|
||||
TokenType.WINDOW,
|
||||
TokenType.QUALIFY,
|
||||
TokenType.LIMIT,
|
||||
TokenType.FETCH,
|
||||
TokenType.CLUSTER_BY,
|
||||
TokenType.DISTRIBUTE_BY,
|
||||
TokenType.SORT_BY,
|
||||
TokenType.CONNECT_BY,
|
||||
TokenType.START_WITH,
|
||||
TokenType.UNION,
|
||||
TokenType.INTERSECT,
|
||||
TokenType.EXCEPT,
|
||||
@@ -77,6 +84,33 @@ def _before_whitespace(sql: str, offset: int) -> int:
|
||||
return offset
|
||||
|
||||
|
||||
def _line_comment_start(sql: str, line_start: int, offset: int) -> int:
    """
    Find the first `--` in `sql[line_start:offset]` that starts a comment.

    The segment is scanned left to right while tracking single-quoted,
    double-quoted and backtick-quoted spans as well as `/* ... */` block
    comments that open and close inside the segment, so a `--` that is part of
    a literal (eg: `'--'`) or of a block comment (eg: `/* a -- b */`) is not
    reported.

    Only the text from *line_start* onwards is examined; a literal or block
    comment opened on an earlier line is not tracked. If the scan ends inside
    an unterminated quote or block comment, the simple tracking cannot be
    trusted (eg: dialect-specific escapes), and the function falls back to the
    plain right-most search for `--`.

    :param sql: Full SQL text.
    :param line_start: Index of the first character of the line being scanned.
    :param offset: Exclusive end of the scanned segment.
    :returns: Index of the comment's `--`, or `-1` if none was found.
    """
    open_quote: str | None = None
    pos = line_start
    while pos < offset:
        char = sql[pos]
        if open_quote is not None:
            # A doubled quote ('' or "") closes and immediately reopens the
            # span, which keeps the tracking correct without special casing.
            if char == open_quote:
                open_quote = None
        elif char in "'\"`":
            open_quote = char
        elif sql.startswith("--", pos, offset):
            return pos
        elif sql.startswith("/*", pos, offset):
            block_end = sql.find("*/", pos + 2, offset)
            if block_end == -1:
                # Unterminated block comment: tracking is unreliable.
                return sql.rfind("--", line_start, offset)
            pos = block_end + 2
            continue
        pos += 1

    if open_quote is not None:
        # Unterminated quote: tracking is unreliable.
        return sql.rfind("--", line_start, offset)
    return -1


def _before_trivia(sql: str, offset: int) -> int:
    """
    Back up past whitespace and adjacent comments immediately before *offset*.

    This ensures insertion points land before inline/block comments that appear
    between `FROM`/`WHERE` and the next clause keyword.

    A `--` is only treated as the start of an inline comment when it is outside
    quoted literals/identifiers and outside block comments on the same line;
    otherwise the insertion point could end up inside a string literal or
    inside a comment.

    :param sql: Full SQL text.
    :param offset: Index in *sql* to back up from.
    :returns: The adjusted offset, which is never greater than *offset* as
        long as `_before_whitespace` does not move forward.
    """
    while True:
        offset = _before_whitespace(sql, offset)

        # Inline comment ending at offset, eg: "... -- comment\nGROUP BY".
        line_start = sql.rfind("\n", 0, offset) + 1
        inline_comment_start = _line_comment_start(sql, line_start, offset)
        if inline_comment_start != -1:
            offset = inline_comment_start
            continue

        # Block comment ending at offset, eg: "... /* comment */GROUP BY".
        if offset >= 2 and sql[offset - 2 : offset] == "*/":
            block_comment_start = sql.rfind("/*", 0, offset - 2)
            if block_comment_start != -1:
                offset = block_comment_start
                continue

        return offset
|
||||
|
||||
|
||||
def _table_from_node(
|
||||
node: exp.Table,
|
||||
catalog: str | None,
|
||||
@@ -217,7 +251,7 @@ def _find_splice_point(
|
||||
if tok.token_type == TokenType.R_PAREN:
|
||||
if depth == 0:
|
||||
# Closing paren of our subquery — insert just before it.
|
||||
offset = _before_whitespace(sql, tok.start)
|
||||
offset = _before_trivia(sql, tok.start)
|
||||
text = f" AND {pred_sql}" if has_where else f" WHERE {pred_sql}"
|
||||
return (offset, text)
|
||||
depth -= 1
|
||||
@@ -231,7 +265,7 @@ def _find_splice_point(
|
||||
|
||||
if not has_where and tok.token_type in _CLAUSE_ENDS:
|
||||
# Insert WHERE before this clause keyword.
|
||||
return (_before_whitespace(sql, tok.start), f" WHERE {pred_sql}")
|
||||
return (_before_trivia(sql, tok.start), f" WHERE {pred_sql}")
|
||||
|
||||
# No clause boundary found — append at end of SQL.
|
||||
text = f" AND {pred_sql}" if has_where else f" WHERE {pred_sql}"
|
||||
@@ -255,9 +289,9 @@ def _find_after_where(
|
||||
depth += 1
|
||||
elif tok.token_type == TokenType.R_PAREN:
|
||||
if depth == 0:
|
||||
return (_before_whitespace(sql, tok.start), f" AND {pred_sql}")
|
||||
return (_before_trivia(sql, tok.start), f" AND {pred_sql}")
|
||||
depth -= 1
|
||||
elif depth == 0 and tok.token_type in _CLAUSE_ENDS:
|
||||
return (_before_whitespace(sql, tok.start), f" AND {pred_sql}")
|
||||
return (_before_trivia(sql, tok.start), f" AND {pred_sql}")
|
||||
prev_end = tok.end
|
||||
return (prev_end + 1, f" AND {pred_sql}")
|
||||
|
||||
@@ -1704,6 +1704,21 @@ def test_set_limit_value(
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
def test_set_limit_value_after_splice_reparses_from_raw_sql() -> None:
    """
    Setting a limit on a statement that carries cached verbatim SQL must keep
    the content of that cached SQL and add the limit to it.
    """
    stmt = SQLStatement("SELECT * FROM some_table", "postgresql")
    # Seed the cached SQL directly to stand in for an earlier text rewrite.
    stmt._raw_sql = "SELECT * FROM some_table WHERE tenant_id = 42"

    stmt.set_limit_value(10, LimitMethod.FORCE_LIMIT)

    rendered = stmt.format()
    for fragment in ("tenant_id = 42", "LIMIT 10"):
        assert fragment in rendered
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kql, limit, expected",
|
||||
[
|
||||
@@ -2707,6 +2722,86 @@ def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql, expected",
    [
        (
            "SELECT * FROM some_table -- hi\nGROUP BY id",
            "SELECT * FROM some_table WHERE tenant_id = 42 -- hi\nGROUP BY id",
        ),
        (
            "SELECT * FROM some_table /* inline */ GROUP BY id",
            "SELECT * FROM some_table WHERE tenant_id = 42 /* inline */ GROUP BY id",
        ),
    ],
)
def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -> None:
    """
    In splice mode the predicate must be placed ahead of any comment sitting
    in front of the next clause keyword; otherwise the comment would swallow
    the injected SQL.
    """
    predicates = {Table("some_table"): ["tenant_id = 42"]}
    stmt = SQLStatement(sql, engine="postgresql")

    stmt.apply_rls(None, None, predicates, RLSMethod.AS_PREDICATE_SPLICE)

    rendered = stmt.format()
    assert rendered == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        (
            "SELECT * FROM some_table QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
            "snowflake",
            "SELECT * FROM some_table WHERE tenant_id = 42 QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
        ),
        (
            "SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
            "postgresql",
            "SELECT sum(v) OVER () FROM some_table WHERE tenant_id = 42 WINDOW w AS (PARTITION BY id)",
        ),
    ],
)
def test_rls_predicate_splice_handles_additional_clause_boundaries(
    sql: str,
    engine: str,
    expected: str,
) -> None:
    """
    The spliced WHERE must land in front of clause keywords such as QUALIFY
    and WINDOW, which may directly follow FROM/WHERE.
    """
    predicates = {Table("some_table"): ["tenant_id = 42"]}
    stmt = SQLStatement(sql, engine=engine)

    stmt.apply_rls(None, None, predicates, RLSMethod.AS_PREDICATE_SPLICE)

    rendered = stmt.format()
    assert rendered == expected
|
||||
|
||||
|
||||
def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
    """
    Applying a LIMIT after splice-mode RLS must not drop the injected
    predicate.
    """
    predicates = {Table("some_table"): ["tenant_id = 42"]}
    stmt = SQLStatement("SELECT * FROM some_table", engine="postgresql")

    # First rewrite the statement with RLS, then force a LIMIT on top of it.
    stmt.apply_rls(None, None, predicates, RLSMethod.AS_PREDICATE_SPLICE)
    stmt.set_limit_value(101, LimitMethod.FORCE_LIMIT)

    rendered = stmt.format()
    for fragment in ("tenant_id = 42", "LIMIT 101"):
        assert fragment in rendered
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, table, expected",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user