mirror of
https://github.com/apache/superset.git
synced 2026-05-29 20:29:34 +00:00
Address comments
This commit is contained in:
@@ -93,38 +93,19 @@ def _splice_priority(text: str) -> int:
|
||||
return 1 if text != ")" else 0
|
||||
|
||||
|
||||
def _before_whitespace(sql: str, offset: int) -> int:
|
||||
"""Back up past any whitespace immediately before *offset*."""
|
||||
while offset > 0 and sql[offset - 1] in (" ", "\t", "\n", "\r"):
|
||||
offset -= 1
|
||||
return offset
|
||||
|
||||
|
||||
def _before_trivia(sql: str, offset: int) -> int:
|
||||
def _after_previous_token(tokens: list[Token], index: int) -> int:
|
||||
"""
|
||||
Back up past whitespace and adjacent comments immediately before *offset*.
|
||||
Return the offset immediately after the token preceding *index*.
|
||||
|
||||
This ensures insertion points land before inline/block comments that appear
|
||||
between `FROM`/`WHERE` and the next clause keyword.
|
||||
The sqlglot tokenizer strips comments and whitespace from the token stream,
|
||||
so the previous token's ``end + 1`` is the splice point that lands right
|
||||
after the last real SQL content — naturally skipping any intervening
|
||||
comments or whitespace, and never confusing ``--`` or ``/*`` inside string
|
||||
literals for real comment delimiters.
|
||||
"""
|
||||
while True:
|
||||
offset = _before_whitespace(sql, offset)
|
||||
|
||||
# Inline comment ending at offset, eg: "... -- comment\nGROUP BY".
|
||||
line_start = sql.rfind("\n", 0, offset) + 1
|
||||
inline_comment_start = sql.rfind("--", line_start, offset)
|
||||
if inline_comment_start != -1:
|
||||
offset = inline_comment_start
|
||||
continue
|
||||
|
||||
# Block comment ending at offset, eg: "... /* comment */GROUP BY".
|
||||
if offset >= 2 and sql[offset - 2 : offset] == "*/":
|
||||
block_comment_start = sql.rfind("/*", 0, offset - 2)
|
||||
if block_comment_start != -1:
|
||||
offset = block_comment_start
|
||||
continue
|
||||
|
||||
return offset
|
||||
if index <= 0:
|
||||
return 0
|
||||
return tokens[index - 1].end + 1
|
||||
|
||||
|
||||
def _table_from_node(
|
||||
@@ -355,7 +336,6 @@ def _scan_until_scope_boundary(
|
||||
|
||||
|
||||
def _find_condition_end(
|
||||
sql: str,
|
||||
tokens: list[Token],
|
||||
start_index: int,
|
||||
*,
|
||||
@@ -371,14 +351,14 @@ def _find_condition_end(
|
||||
depth += 1
|
||||
elif tok.token_type == TokenType.R_PAREN:
|
||||
if depth == 0:
|
||||
return _before_trivia(sql, tok.start)
|
||||
return prev_end + 1
|
||||
depth -= 1
|
||||
elif depth == 0 and (
|
||||
(stop_at_join and tok.token_type == TokenType.WHERE)
|
||||
or tok.token_type in _CLAUSE_ENDS
|
||||
or (stop_at_join and tok.token_type in _JOIN_STARTS)
|
||||
):
|
||||
return _before_trivia(sql, tok.start)
|
||||
return prev_end + 1
|
||||
prev_end = tok.end
|
||||
return prev_end + 1
|
||||
|
||||
@@ -398,14 +378,14 @@ def _find_where_splice(
|
||||
if idx + 1 >= len(tokens):
|
||||
return [(tokens[idx].end + 1, f" {pred_sql}")]
|
||||
body_start = tokens[idx + 1].start
|
||||
body_end = _find_condition_end(sql, tokens, idx, stop_at_join=False)
|
||||
body_end = _find_condition_end(tokens, idx, stop_at_join=False)
|
||||
return [
|
||||
(body_start, f"{pred_sql} AND ("),
|
||||
(body_end, ")"),
|
||||
]
|
||||
|
||||
if kind == "boundary" and idx is not None:
|
||||
return [(_before_trivia(sql, tokens[idx].start), f" WHERE {pred_sql}")]
|
||||
return [(_after_previous_token(tokens, idx), f" WHERE {pred_sql}")]
|
||||
|
||||
return [(len(sql), f" WHERE {pred_sql}")]
|
||||
|
||||
@@ -426,14 +406,14 @@ def _find_join_splice(
|
||||
if on_index + 1 >= len(tokens):
|
||||
return [(tokens[on_index].end + 1, f" {pred_sql}")]
|
||||
body_start = tokens[on_index + 1].start
|
||||
body_end = _find_condition_end(sql, tokens, on_index, stop_at_join=True)
|
||||
body_end = _find_condition_end(tokens, on_index, stop_at_join=True)
|
||||
return [
|
||||
(body_start, f"{pred_sql} AND ("),
|
||||
(body_end, ")"),
|
||||
]
|
||||
|
||||
if boundary_index is not None:
|
||||
return [(_before_trivia(sql, tokens[boundary_index].start), f" ON {pred_sql}")]
|
||||
return [(_after_previous_token(tokens, boundary_index), f" ON {pred_sql}")]
|
||||
|
||||
return [(len(sql), f" ON {pred_sql}")]
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from sqlglot import Dialect, exp
|
||||
|
||||
from superset.sql.parse import SQLStatement, Table
|
||||
from superset.sql.rls_splice import (
|
||||
_before_trivia,
|
||||
_classify_source_predicate,
|
||||
_find_condition_end,
|
||||
_find_join_splice,
|
||||
@@ -68,10 +67,24 @@ def test_apply_rls_splice_ignores_empty_predicates() -> None:
|
||||
assert apply_rls_splice(sql, None, None, {Table("foo"): []}) == sql
|
||||
|
||||
|
||||
def test_before_trivia_handles_unmatched_block_comment_suffix() -> None:
|
||||
sql = "SELECT */GROUP BY x"
|
||||
offset = sql.index("GROUP")
|
||||
assert _before_trivia(sql, offset) == offset
|
||||
def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None:
|
||||
"""
|
||||
Regression: the splice point must not be confused by ``--`` appearing
|
||||
inside a string literal. Earlier ``rfind("--", ...)`` logic mistook this
|
||||
for an inline comment and inserted the predicate inside the quoted text.
|
||||
"""
|
||||
sql = "SELECT * FROM some_table WHERE note = '--x' GROUP BY id"
|
||||
spliced = apply_rls_splice(
|
||||
sql,
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["some_table.tenant_id = 42"]},
|
||||
dialect="postgres",
|
||||
)
|
||||
assert spliced == (
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"AND (note = '--x') GROUP BY id"
|
||||
)
|
||||
|
||||
|
||||
def test_table_end_returns_none_without_metadata() -> None:
|
||||
@@ -122,7 +135,7 @@ def test_find_condition_end_handles_subquery_closing_paren() -> None:
|
||||
sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
|
||||
tokens = _tokenize(sql)
|
||||
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
|
||||
end = _find_condition_end(sql, tokens, where_index, stop_at_join=False)
|
||||
end = _find_condition_end(tokens, where_index, stop_at_join=False)
|
||||
assert sql[end] == ")"
|
||||
|
||||
|
||||
@@ -130,7 +143,7 @@ def test_find_condition_end_handles_parenthesized_expression() -> None:
|
||||
sql = "SELECT * FROM t WHERE (a = 1)"
|
||||
tokens = _tokenize(sql)
|
||||
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
|
||||
end = _find_condition_end(sql, tokens, where_index, stop_at_join=False)
|
||||
end = _find_condition_end(tokens, where_index, stop_at_join=False)
|
||||
assert end == len(sql)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user