Address comments

This commit is contained in:
Beto Dealmeida
2026-05-12 10:17:50 -04:00
parent 439130db54
commit 0ce43abe5b
2 changed files with 36 additions and 43 deletions

View File

@@ -93,38 +93,19 @@ def _splice_priority(text: str) -> int:
return 1 if text != ")" else 0
def _before_whitespace(sql: str, offset: int) -> int:
"""Back up past any whitespace immediately before *offset*."""
while offset > 0 and sql[offset - 1] in (" ", "\t", "\n", "\r"):
offset -= 1
return offset
def _before_trivia(sql: str, offset: int) -> int:
def _after_previous_token(tokens: list[Token], index: int) -> int:
"""
Back up past whitespace and adjacent comments immediately before *offset*.
Return the offset immediately after the token preceding *index*.
This ensures insertion points land before inline/block comments that appear
between `FROM`/`WHERE` and the next clause keyword.
The sqlglot tokenizer strips comments and whitespace from the token stream,
so the previous token's ``end + 1`` is the splice point that lands right
after the last real SQL content — naturally skipping any intervening
comments or whitespace, and never confusing ``--`` or ``/*`` inside string
literals for real comment delimiters.
"""
while True:
offset = _before_whitespace(sql, offset)
# Inline comment ending at offset, eg: "... -- comment\nGROUP BY".
line_start = sql.rfind("\n", 0, offset) + 1
inline_comment_start = sql.rfind("--", line_start, offset)
if inline_comment_start != -1:
offset = inline_comment_start
continue
# Block comment ending at offset, eg: "... /* comment */GROUP BY".
if offset >= 2 and sql[offset - 2 : offset] == "*/":
block_comment_start = sql.rfind("/*", 0, offset - 2)
if block_comment_start != -1:
offset = block_comment_start
continue
return offset
if index <= 0:
return 0
return tokens[index - 1].end + 1
def _table_from_node(
@@ -355,7 +336,6 @@ def _scan_until_scope_boundary(
def _find_condition_end(
sql: str,
tokens: list[Token],
start_index: int,
*,
@@ -371,14 +351,14 @@ def _find_condition_end(
depth += 1
elif tok.token_type == TokenType.R_PAREN:
if depth == 0:
return _before_trivia(sql, tok.start)
return prev_end + 1
depth -= 1
elif depth == 0 and (
(stop_at_join and tok.token_type == TokenType.WHERE)
or tok.token_type in _CLAUSE_ENDS
or (stop_at_join and tok.token_type in _JOIN_STARTS)
):
return _before_trivia(sql, tok.start)
return prev_end + 1
prev_end = tok.end
return prev_end + 1
@@ -398,14 +378,14 @@ def _find_where_splice(
if idx + 1 >= len(tokens):
return [(tokens[idx].end + 1, f" {pred_sql}")]
body_start = tokens[idx + 1].start
body_end = _find_condition_end(sql, tokens, idx, stop_at_join=False)
body_end = _find_condition_end(tokens, idx, stop_at_join=False)
return [
(body_start, f"{pred_sql} AND ("),
(body_end, ")"),
]
if kind == "boundary" and idx is not None:
return [(_before_trivia(sql, tokens[idx].start), f" WHERE {pred_sql}")]
return [(_after_previous_token(tokens, idx), f" WHERE {pred_sql}")]
return [(len(sql), f" WHERE {pred_sql}")]
@@ -426,14 +406,14 @@ def _find_join_splice(
if on_index + 1 >= len(tokens):
return [(tokens[on_index].end + 1, f" {pred_sql}")]
body_start = tokens[on_index + 1].start
body_end = _find_condition_end(sql, tokens, on_index, stop_at_join=True)
body_end = _find_condition_end(tokens, on_index, stop_at_join=True)
return [
(body_start, f"{pred_sql} AND ("),
(body_end, ")"),
]
if boundary_index is not None:
return [(_before_trivia(sql, tokens[boundary_index].start), f" ON {pred_sql}")]
return [(_after_previous_token(tokens, boundary_index), f" ON {pred_sql}")]
return [(len(sql), f" ON {pred_sql}")]

View File

@@ -21,7 +21,6 @@ from sqlglot import Dialect, exp
from superset.sql.parse import SQLStatement, Table
from superset.sql.rls_splice import (
_before_trivia,
_classify_source_predicate,
_find_condition_end,
_find_join_splice,
@@ -68,10 +67,24 @@ def test_apply_rls_splice_ignores_empty_predicates() -> None:
assert apply_rls_splice(sql, None, None, {Table("foo"): []}) == sql
def test_before_trivia_handles_unmatched_block_comment_suffix() -> None:
sql = "SELECT */GROUP BY x"
offset = sql.index("GROUP")
assert _before_trivia(sql, offset) == offset
def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None:
"""
Regression: the splice point must not be confused by ``--`` appearing
inside a string literal. Earlier ``rfind("--", ...)`` logic mistook this
for an inline comment and inserted the predicate inside the quoted text.
"""
sql = "SELECT * FROM some_table WHERE note = '--x' GROUP BY id"
spliced = apply_rls_splice(
sql,
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
dialect="postgres",
)
assert spliced == (
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"AND (note = '--x') GROUP BY id"
)
def test_table_end_returns_none_without_metadata() -> None:
@@ -122,7 +135,7 @@ def test_find_condition_end_handles_subquery_closing_paren() -> None:
sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
tokens = _tokenize(sql)
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
end = _find_condition_end(sql, tokens, where_index, stop_at_join=False)
end = _find_condition_end(tokens, where_index, stop_at_join=False)
assert sql[end] == ")"
@@ -130,7 +143,7 @@ def test_find_condition_end_handles_parenthesized_expression() -> None:
sql = "SELECT * FROM t WHERE (a = 1)"
tokens = _tokenize(sql)
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
end = _find_condition_end(sql, tokens, where_index, stop_at_join=False)
end = _find_condition_end(tokens, where_index, stop_at_join=False)
assert end == len(sql)