diff --git a/superset/sql/rls_splice.py b/superset/sql/rls_splice.py index 94fe789b36d..b068305cb26 100644 --- a/superset/sql/rls_splice.py +++ b/superset/sql/rls_splice.py @@ -93,38 +93,19 @@ def _splice_priority(text: str) -> int: return 1 if text != ")" else 0 -def _before_whitespace(sql: str, offset: int) -> int: - """Back up past any whitespace immediately before *offset*.""" - while offset > 0 and sql[offset - 1] in (" ", "\t", "\n", "\r"): - offset -= 1 - return offset - - -def _before_trivia(sql: str, offset: int) -> int: +def _after_previous_token(tokens: list[Token], index: int) -> int: """ - Back up past whitespace and adjacent comments immediately before *offset*. + Return the offset immediately after the token preceding *index*. - This ensures insertion points land before inline/block comments that appear - between `FROM`/`WHERE` and the next clause keyword. + The sqlglot tokenizer strips comments and whitespace from the token stream, + so the previous token's ``end + 1`` is the splice point that lands right + after the last real SQL content — naturally skipping any intervening + comments or whitespace, and never confusing ``--`` or ``/*`` inside string + literals for real comment delimiters. """ - while True: - offset = _before_whitespace(sql, offset) - - # Inline comment ending at offset, eg: "... -- comment\nGROUP BY". - line_start = sql.rfind("\n", 0, offset) + 1 - inline_comment_start = sql.rfind("--", line_start, offset) - if inline_comment_start != -1: - offset = inline_comment_start - continue - - # Block comment ending at offset, eg: "... /* comment */GROUP BY". - if offset >= 2 and sql[offset - 2 : offset] == "*/": - block_comment_start = sql.rfind("/*", 0, offset - 2) - if block_comment_start != -1: - offset = block_comment_start - continue - - return offset + if index <= 0: + return 0 + return tokens[index - 1].end + 1 def _table_from_node( @@ -355,7 +336,6 @@ def _scan_until_scope_boundary( def _find_condition_end( - sql: str, tokens: list[Token], start_index: int, *, @@ -371,14 +351,14 @@ def _find_condition_end( depth += 1 elif tok.token_type == TokenType.R_PAREN: if depth == 0: - return _before_trivia(sql, tok.start) + return prev_end + 1 depth -= 1 elif depth == 0 and ( (stop_at_join and tok.token_type == TokenType.WHERE) or tok.token_type in _CLAUSE_ENDS or (stop_at_join and tok.token_type in _JOIN_STARTS) ): - return _before_trivia(sql, tok.start) + return prev_end + 1 prev_end = tok.end return prev_end + 1 @@ -398,14 +378,14 @@ def _find_where_splice( if idx + 1 >= len(tokens): return [(tokens[idx].end + 1, f" {pred_sql}")] body_start = tokens[idx + 1].start - body_end = _find_condition_end(sql, tokens, idx, stop_at_join=False) + body_end = _find_condition_end(tokens, idx, stop_at_join=False) return [ (body_start, f"{pred_sql} AND ("), (body_end, ")"), ] if kind == "boundary" and idx is not None: - return [(_before_trivia(sql, tokens[idx].start), f" WHERE {pred_sql}")] + return [(_after_previous_token(tokens, idx), f" WHERE {pred_sql}")] return [(len(sql), f" WHERE {pred_sql}")] @@ -426,14 +406,14 @@ def _find_join_splice( if on_index + 1 >= len(tokens): return [(tokens[on_index].end + 1, f" {pred_sql}")] body_start = tokens[on_index + 1].start - body_end = _find_condition_end(sql, tokens, on_index, stop_at_join=True) + body_end = _find_condition_end(tokens, on_index, stop_at_join=True) return [ (body_start, f"{pred_sql} AND ("), (body_end, ")"), ] if boundary_index is not None: - return [(_before_trivia(sql, tokens[boundary_index].start), f" ON {pred_sql}")] + return [(_after_previous_token(tokens, boundary_index), f" ON {pred_sql}")] return [(len(sql), f" ON {pred_sql}")] diff --git a/tests/unit_tests/sql/rls_splice_unit_tests.py b/tests/unit_tests/sql/rls_splice_unit_tests.py index 51f3a23afff..5b1119adeae 100644 --- a/tests/unit_tests/sql/rls_splice_unit_tests.py +++ b/tests/unit_tests/sql/rls_splice_unit_tests.py @@ -21,7 +21,6 @@ from sqlglot import Dialect, exp from superset.sql.parse import SQLStatement, Table from superset.sql.rls_splice import ( - _before_trivia, _classify_source_predicate, _find_condition_end, _find_join_splice, @@ -68,10 +67,24 @@ def test_apply_rls_splice_ignores_empty_predicates() -> None: assert apply_rls_splice(sql, None, None, {Table("foo"): []}) == sql -def test_before_trivia_handles_unmatched_block_comment_suffix() -> None: - sql = "SELECT */GROUP BY x" - offset = sql.index("GROUP") - assert _before_trivia(sql, offset) == offset +def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None: + """ + Regression: the splice point must not be confused by ``--`` appearing + inside a string literal. Earlier ``rfind("--", ...)`` logic mistook this + for an inline comment and inserted the predicate inside the quoted text. + """ + sql = "SELECT * FROM some_table WHERE note = '--x' GROUP BY id" + spliced = apply_rls_splice( + sql, + None, + None, + {Table("some_table"): ["some_table.tenant_id = 42"]}, + dialect="postgres", + ) + assert spliced == ( + "SELECT * FROM some_table WHERE some_table.tenant_id = 42 " + "AND (note = '--x') GROUP BY id" + ) def test_table_end_returns_none_without_metadata() -> None: @@ -122,7 +135,7 @@ def test_find_condition_end_handles_subquery_closing_paren() -> None: sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)" tokens = _tokenize(sql) where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE) - end = _find_condition_end(sql, tokens, where_index, stop_at_join=False) + end = _find_condition_end(tokens, where_index, stop_at_join=False) assert sql[end] == ")" @@ -130,7 +143,7 @@ def test_find_condition_end_handles_parenthesized_expression() -> None: sql = "SELECT * FROM t WHERE (a = 1)" tokens = _tokenize(sql) where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE) - end = _find_condition_end(sql, tokens, where_index, stop_at_join=False) + end = _find_condition_end(tokens, where_index, stop_at_join=False) assert end == len(sql)