Address comments

This commit is contained in:
Beto Dealmeida
2026-05-12 10:17:50 -04:00
parent 439130db54
commit 0ce43abe5b
2 changed files with 36 additions and 43 deletions

View File

@@ -21,7 +21,6 @@ from sqlglot import Dialect, exp
from superset.sql.parse import SQLStatement, Table
from superset.sql.rls_splice import (
_before_trivia,
_classify_source_predicate,
_find_condition_end,
_find_join_splice,
@@ -68,10 +67,24 @@ def test_apply_rls_splice_ignores_empty_predicates() -> None:
assert apply_rls_splice(sql, None, None, {Table("foo"): []}) == sql
def test_before_trivia_handles_unmatched_block_comment_suffix() -> None:
sql = "SELECT */GROUP BY x"
offset = sql.index("GROUP")
assert _before_trivia(sql, offset) == offset
def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None:
"""
Regression: the splice point must not be confused by ``--`` appearing
inside a string literal. Earlier ``rfind("--", ...)`` logic mistook this
for an inline comment and inserted the predicate inside the quoted text.
"""
sql = "SELECT * FROM some_table WHERE note = '--x' GROUP BY id"
spliced = apply_rls_splice(
sql,
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
dialect="postgres",
)
assert spliced == (
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"AND (note = '--x') GROUP BY id"
)
def test_table_end_returns_none_without_metadata() -> None:
@@ -122,7 +135,7 @@ def test_find_condition_end_handles_subquery_closing_paren() -> None:
sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
tokens = _tokenize(sql)
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
end = _find_condition_end(sql, tokens, where_index, stop_at_join=False)
end = _find_condition_end(tokens, where_index, stop_at_join=False)
assert sql[end] == ")"
@@ -130,7 +143,7 @@ def test_find_condition_end_handles_parenthesized_expression() -> None:
sql = "SELECT * FROM t WHERE (a = 1)"
tokens = _tokenize(sql)
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
end = _find_condition_end(sql, tokens, where_index, stop_at_join=False)
end = _find_condition_end(tokens, where_index, stop_at_join=False)
assert end == len(sql)