From af2d3babecc150dfebe39ae72213675bf99ddb1f Mon Sep 17 00:00:00 2001
From: Beto Dealmeida
Date: Fri, 8 May 2026 13:49:37 -0400
Subject: [PATCH] Harden RLS predicate splicing and keep predicates across
 LIMIT rewrites

- Re-parse the spliced SQL before mutating the AST in `set_limit_value`, so
  predicates injected in splice mode are not lost.
- Treat WINDOW, QUALIFY, CLUSTER BY, DISTRIBUTE BY, SORT BY, CONNECT BY and
  START WITH as clause boundaries.
- Insert predicates before comments that precede a clause boundary, without
  mistaking comment markers inside quoted strings for comments.

---
 superset/sql/parse.py               |   7 ++-
 superset/sql/rls_splice.py          |  53 ++++++++++++--
 tests/unit_tests/sql/parse_tests.py | 100 ++++++++++++++++++++++++++++
 3 files changed, 155 insertions(+), 5 deletions(-)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index a45943ea463..d944dce5a2f 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -873,7 +873,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
         """
         # AST mutation invalidates any cached verbatim SQL (e.g. from splice).
-        self._raw_sql = None
+        # If we already have a rewritten SQL string, re-parse it first so further
+        # AST mutations (like LIMIT injection) preserve prior text-based rewrites.
+        if self._raw_sql is not None:
+            self._parsed = self._parse_statement(self._raw_sql, self.engine)
+            self._source_sql = self._raw_sql
+        self._raw_sql = None
         if method == LimitMethod.FORCE_LIMIT:
             self._parsed.args["limit"] = exp.Limit(
                 expression=exp.Literal(this=str(limit), is_string=False)
diff --git a/superset/sql/rls_splice.py b/superset/sql/rls_splice.py
index 0166a13fe49..48576f7ac71 100644
--- a/superset/sql/rls_splice.py
+++ b/superset/sql/rls_splice.py
@@ -62,8 +62,15 @@ _CLAUSE_ENDS = {
     TokenType.GROUP_BY,
     TokenType.HAVING,
     TokenType.ORDER_BY,
+    TokenType.WINDOW,
+    TokenType.QUALIFY,
     TokenType.LIMIT,
     TokenType.FETCH,
+    TokenType.CLUSTER_BY,
+    TokenType.DISTRIBUTE_BY,
+    TokenType.SORT_BY,
+    TokenType.CONNECT_BY,
+    TokenType.START_WITH,
     TokenType.UNION,
     TokenType.INTERSECT,
     TokenType.EXCEPT,
@@ -77,6 +84,44 @@ def _before_whitespace(sql: str, offset: int) -> int:
     return offset
 
 
+def _before_trivia(sql: str, offset: int) -> int:
+    """
+    Back up past whitespace and comments immediately before *offset*.
+
+    This ensures insertion points land before inline/block comments that appear
+    between `FROM`/`WHERE` and the next clause keyword.
+
+    The SQL is scanned forward from the start so that comment markers inside
+    quoted strings or identifiers (eg, `'a--b'`) are never mistaken for
+    comments; backing up into a literal would bury the injected predicate
+    inside it. Backslash escapes, dollar-quoted strings and dialect-specific
+    comment markers such as `#` are not recognized.
+    """
+    end = min(offset, len(sql))
+    code_end = 0  # offset just past the last character that is real SQL
+    pos = 0
+    while pos < end:
+        char = sql[pos]
+        if char in "'\"`":
+            # Quoted run; a doubled quote simply starts another adjacent run.
+            close = sql.find(char, pos + 1, end)
+            pos = end if close == -1 else close + 1
+            code_end = pos
+        elif sql.startswith("--", pos):
+            # Inline comment, eg: "... -- comment\nGROUP BY".
+            newline = sql.find("\n", pos, end)
+            pos = end if newline == -1 else newline + 1
+        elif sql.startswith("/*", pos):
+            # Block comment, eg: "... /* comment */GROUP BY".
+            close = sql.find("*/", pos + 2, end)
+            pos = end if close == -1 else close + 2
+        else:
+            pos += 1
+            if not char.isspace():
+                code_end = pos
+    return code_end
+
+
 def _table_from_node(
     node: exp.Table,
     catalog: str | None,
@@ -217,7 +262,7 @@ def _find_splice_point(
         if tok.token_type == TokenType.R_PAREN:
             if depth == 0:
                 # Closing paren of our subquery — insert just before it.
-                offset = _before_whitespace(sql, tok.start)
+                offset = _before_trivia(sql, tok.start)
                 text = f" AND {pred_sql}" if has_where else f" WHERE {pred_sql}"
                 return (offset, text)
             depth -= 1
@@ -231,7 +276,7 @@ def _find_splice_point(
 
         if not has_where and tok.token_type in _CLAUSE_ENDS:
             # Insert WHERE before this clause keyword.
-            return (_before_whitespace(sql, tok.start), f" WHERE {pred_sql}")
+            return (_before_trivia(sql, tok.start), f" WHERE {pred_sql}")
 
     # No clause boundary found — append at end of SQL.
     text = f" AND {pred_sql}" if has_where else f" WHERE {pred_sql}"
@@ -255,9 +300,9 @@ def _find_after_where(
             depth += 1
         elif tok.token_type == TokenType.R_PAREN:
             if depth == 0:
-                return (_before_whitespace(sql, tok.start), f" AND {pred_sql}")
+                return (_before_trivia(sql, tok.start), f" AND {pred_sql}")
             depth -= 1
         elif depth == 0 and tok.token_type in _CLAUSE_ENDS:
-            return (_before_whitespace(sql, tok.start), f" AND {pred_sql}")
+            return (_before_trivia(sql, tok.start), f" AND {pred_sql}")
         prev_end = tok.end
     return (prev_end + 1, f" AND {pred_sql}")
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index 7835840b04a..79a9cec6953 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -1704,6 +1704,21 @@ def test_set_limit_value(
     assert statement.format() == expected
 
 
+def test_set_limit_value_after_splice_reparses_from_raw_sql() -> None:
+    """
+    When a statement has cached verbatim SQL from splice-mode rewrites, setting
+    limit should reparse that SQL before mutating the AST.
+    """
+    statement = SQLStatement("SELECT * FROM some_table", "postgresql")
+    statement._raw_sql = "SELECT * FROM some_table WHERE tenant_id = 42"
+
+    statement.set_limit_value(10, LimitMethod.FORCE_LIMIT)
+    formatted = statement.format()
+
+    assert "tenant_id = 42" in formatted
+    assert "LIMIT 10" in formatted
+
+
 @pytest.mark.parametrize(
     "kql, limit, expected",
     [
@@ -2707,6 +2722,91 @@ def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
     )
 
 
+@pytest.mark.parametrize(
+    "sql, expected",
+    [
+        (
+            "SELECT * FROM some_table -- hi\nGROUP BY id",
+            "SELECT * FROM some_table WHERE tenant_id = 42 -- hi\nGROUP BY id",
+        ),
+        (
+            "SELECT * FROM some_table /* inline */ GROUP BY id",
+            "SELECT * FROM some_table WHERE tenant_id = 42 /* inline */ GROUP BY id",
+        ),
+        (
+            # "--" inside a string literal is not a comment.
+            "SELECT '--' AS dash FROM some_table GROUP BY 1",
+            "SELECT '--' AS dash FROM some_table WHERE tenant_id = 42 GROUP BY 1",
+        ),
+    ],
+)
+def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -> None:
+    """
+    Splice mode should insert predicates before comments that precede the next
+    clause boundary, so comments do not swallow the injected SQL.
+    """
+    statement = SQLStatement(sql, engine="postgresql")
+    statement.apply_rls(
+        None,
+        None,
+        {Table("some_table"): ["tenant_id = 42"]},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    assert statement.format() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, engine, expected",
+    [
+        (
+            "SELECT * FROM some_table QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
+            "snowflake",
+            "SELECT * FROM some_table WHERE tenant_id = 42 QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
+        ),
+        (
+            "SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
+            "postgresql",
+            "SELECT sum(v) OVER () FROM some_table WHERE tenant_id = 42 WINDOW w AS (PARTITION BY id)",
+        ),
+    ],
+)
+def test_rls_predicate_splice_handles_additional_clause_boundaries(
+    sql: str,
+    engine: str,
+    expected: str,
+) -> None:
+    """
+    Splice mode should insert WHERE before clause types that can legally follow
+    FROM/WHERE (for example QUALIFY and WINDOW).
+    """
+    statement = SQLStatement(sql, engine=engine)
+    statement.apply_rls(
+        None,
+        None,
+        {Table("some_table"): ["tenant_id = 42"]},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    assert statement.format() == expected
+
+
+def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
+    """
+    LIMIT rewrites after splice-mode RLS should retain injected predicates.
+    """
+    statement = SQLStatement("SELECT * FROM some_table", engine="postgresql")
+    statement.apply_rls(
+        None,
+        None,
+        {Table("some_table"): ["tenant_id = 42"]},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    statement.set_limit_value(101, LimitMethod.FORCE_LIMIT)
+
+    formatted = statement.format()
+    assert "tenant_id = 42" in formatted
+    assert "LIMIT 101" in formatted
+
+
 @pytest.mark.parametrize(
     "sql, table, expected",
     [