diff --git a/superset/sql/rls_splice.py b/superset/sql/rls_splice.py index 2a26a9920d9..94fe789b36d 100644 --- a/superset/sql/rls_splice.py +++ b/superset/sql/rls_splice.py @@ -234,10 +234,12 @@ def _splices_for_scope( if source_type == "from": from_predicates.append(pred_sql) from_table_ends.append(table_end) - elif source_type == "join": - join_splice = _find_join_splice(sql, tokens, table_end, pred_sql) - if join_splice: - join_splices.extend(join_splice) + continue + + join_splice = _find_join_splice(sql, tokens, table_end, pred_sql) + if join_splice: + join_splices.extend(join_splice) + continue if not from_predicates: return join_splices diff --git a/tests/unit_tests/sql/rls_splice_unit_tests.py b/tests/unit_tests/sql/rls_splice_unit_tests.py new file mode 100644 index 00000000000..16c3ced0076 --- /dev/null +++ b/tests/unit_tests/sql/rls_splice_unit_tests.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from typing import Protocol

import sqlglot
from sqlglot import Dialect, exp

from superset.sql.parse import SQLStatement, Table
from superset.sql.rls_splice import (
    _before_trivia,
    _classify_source_predicate,
    _find_condition_end,
    _find_join_splice,
    _find_where_splice,
    _scan_join_clause,
    _scan_until_scope_boundary,
    _splices_for_scope,
    _table_end,
    apply_rls_splice,
)


class _MonkeyPatch(Protocol):
    """Structural type for the one ``monkeypatch`` method these tests call.

    The fixture is injected by pytest; typing it as ``object`` would make every
    ``.setattr`` call below a static type error.
    """

    def setattr(self, target: str, value: object, /) -> None: ...


def _tokenize(sql: str) -> list[sqlglot.tokens.Token]:
    """Tokenize ``sql`` with the dialect ``Dialect.get_or_raise(None)`` resolves to."""
    return list(Dialect.get_or_raise(None).tokenize(sql))


def _token_index(tokens: list[sqlglot.tokens.Token], token_type: object) -> int:
    """Return the index of the first token whose ``token_type`` matches.

    Raises ``StopIteration`` when no token matches.
    """
    return next(i for i, token in enumerate(tokens) if token.token_type == token_type)


def _token_by_text(
    tokens: list[sqlglot.tokens.Token], text: str
) -> sqlglot.tokens.Token:
    """Return the first token whose ``text`` equals ``text``.

    Raises ``StopIteration`` when no token matches.
    """
    return next(token for token in tokens if token.text == text)


def test_split_source_returns_none_result_when_tokenize_fails(
    monkeypatch: _MonkeyPatch,
) -> None:
    """``_split_source`` returns ``[None, None]`` when tokenizing raises."""

    class _BrokenDialect:
        @staticmethod
        def tokenize(_: str) -> list[sqlglot.tokens.Token]:
            raise sqlglot.errors.SqlglotError("boom")

    # Every dialect lookup made through superset.sql.parse now yields the
    # dialect whose tokenizer always fails.
    monkeypatch.setattr(
        "superset.sql.parse.Dialect.get_or_raise",
        lambda _: _BrokenDialect(),
    )
    assert SQLStatement._split_source("SELECT 1", "postgresql", 2) == [None, None]


def test_apply_rls_splice_ignores_empty_predicates() -> None:
    """A table mapped to an empty predicate list leaves the SQL unchanged."""
    sql = "SELECT 1"
    assert apply_rls_splice(sql, None, None, {Table("foo"): []}) == sql


def test_before_trivia_handles_unmatched_block_comment_suffix() -> None:
    """A ``*/`` right before the offset, with no opening ``/*``, moves nothing."""
    sql = "SELECT */GROUP BY x"
    offset = sql.index("GROUP")
    assert _before_trivia(sql, offset) == offset


def test_table_end_returns_none_without_metadata() -> None:
    """``_table_end`` returns ``None`` for a table built without any ``meta``."""
    source = exp.Table(this=exp.Identifier(this="foo"))
    assert _table_end(source) is None


def test_classify_source_predicate_returns_none_without_table_metadata() -> None:
    """A FROM-parented table without ``meta`` classifies as ``("none", None, None)``."""
    source = exp.Table(this=exp.Identifier(this="foo"))
    # Constructed only for its side effect: sqlglot sets ``source.parent`` when
    # a node is built around it.
    exp.From(this=source)
    result = _classify_source_predicate(
        source,
        {Table("foo"): ["id = 1"]},
        None,
        None,
        None,
    )
    assert result == ("none", None, None)


def test_classify_source_predicate_returns_none_for_unsupported_parent() -> None:
    """A table with ``meta["end"]`` set but an ``Alias`` parent classifies as none."""
    source = exp.Table(this=exp.Identifier(this="foo"))
    source.this.meta["end"] = 3
    # Constructed only to make ``Alias`` the parent of ``source``.
    exp.Alias(this=source, alias=exp.Identifier(this="alias"))
    result = _classify_source_predicate(
        source,
        {Table("foo"): ["id = 1"]},
        None,
        None,
        None,
    )
    assert result == ("none", None, None)


def test_scan_until_scope_boundary_tracks_parenthesis_depth() -> None:
    """A balanced parenthesised WHERE condition scans through to ``("eof", None)``."""
    sql = "SELECT * FROM t WHERE (a = 1)"
    tokens = _tokenize(sql)
    where_token = _token_by_text(tokens, "WHERE")
    # NOTE(review): this passes the token's ``.start`` attribute, whereas the
    # ``_find_condition_end`` tests pass a token index -- confirm which of the
    # two ``_scan_until_scope_boundary`` expects.
    result = _scan_until_scope_boundary(
        tokens, where_token.start, stop_at_join=False
    )
    assert result == ("eof", None)


def test_find_condition_end_handles_subquery_closing_paren() -> None:
    """Inside a subquery the condition ends at the subquery's closing ``)``."""
    sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
    tokens = _tokenize(sql)
    where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
    end = _find_condition_end(sql, tokens, where_index, stop_at_join=False)
    assert sql[end] == ")"


def test_find_condition_end_handles_parenthesized_expression() -> None:
    """A parenthesised condition at the end of the SQL ends at ``len(sql)``."""
    sql = "SELECT * FROM t WHERE (a = 1)"
    tokens = _tokenize(sql)
    where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
    end = _find_condition_end(sql, tokens, where_index, stop_at_join=False)
    assert end == len(sql)


def test_find_where_splice_handles_trailing_where_keyword() -> None:
    """With a dangling ``WHERE`` the predicate is appended at the end of the SQL."""
    sql = "SELECT * FROM t WHERE"
    tokens = _tokenize(sql)
    splices = _find_where_splice(sql, tokens, anchor=0, pred_sql="t.id = 1")
    assert splices == [(len(sql), " t.id = 1")]


def test_find_join_splice_handles_trailing_on_keyword() -> None:
    """With a dangling ``ON`` the predicate is appended at the end of the SQL."""
    sql = "SELECT * FROM a JOIN b ON"
    tokens = _tokenize(sql)
    b_token = _token_by_text(tokens, "b")
    splices = _find_join_splice(sql, tokens, b_token.end, "b.id = 1")
    assert splices == [(len(sql), " b.id = 1")]


def test_find_join_splice_inserts_on_before_where_boundary() -> None:
    """A join without ``ON`` gets an ``ON`` clause spliced just before ``WHERE``."""
    sql = "SELECT * FROM a JOIN b WHERE x = 1"
    tokens = _tokenize(sql)
    b_token = _token_by_text(tokens, "b")
    splices = _find_join_splice(sql, tokens, b_token.end, "b.id = 1")
    assert splices == [(sql.index("WHERE") - 1, " ON b.id = 1")]


def test_scan_join_clause_covers_nested_parentheses_and_join_boundary() -> None:
    """A parenthesised ``ON`` condition is skipped and the next ``JOIN`` is the boundary."""
    sql = "SELECT * FROM a JOIN b ON (a.id = b.id) JOIN c ON 1 = 1"
    tokens = _tokenize(sql)
    b_token = _token_by_text(tokens, "b")
    on_index, boundary_index = _scan_join_clause(tokens, b_token.end)
    assert on_index is not None
    assert boundary_index is not None
    assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.JOIN


def test_scan_join_clause_stops_at_outer_closing_paren() -> None:
    """Inside a subquery the join clause is bounded by the subquery's ``)``."""
    sql = "SELECT * FROM (SELECT * FROM a JOIN b) sub"
    tokens = _tokenize(sql)
    b_token = _token_by_text(tokens, "b")
    _, boundary_index = _scan_join_clause(tokens, b_token.end)
    assert boundary_index is not None
    assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.R_PAREN


def test_splices_for_scope_handles_empty_join_splice_result(
    monkeypatch: _MonkeyPatch,
) -> None:
    """A join source whose splice lookup yields nothing produces no splices."""

    class _Scope:
        sources = {"x": object()}

    sql = "SELECT 1"
    tokens = _tokenize(sql)
    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate",
        lambda *args, **kwargs: ("join", 0, "x.id = 1"),
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [],
    )
    assert (
        _splices_for_scope(
            sql,
            tokens,
            _Scope(),
            {Table("x"): ["x.id = 1"]},
            None,
            None,
            None,
        )
        == []
    )


def test_splices_for_scope_combines_join_and_from_splices(
    monkeypatch: _MonkeyPatch,
) -> None:
    """Join splices are returned ahead of the WHERE splice for FROM sources."""

    class _Scope:
        sources = {"f": object(), "j": object()}

    sql = "SELECT 1"
    tokens = _tokenize(sql)
    # One canned classification per source, consumed in iteration order.
    calls = [("from", 3, "f.id = 1"), ("join", 6, "j.id = 2")]

    def _fake_classify(*args: object, **kwargs: object) -> tuple[str, int, str]:
        return calls.pop(0)

    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate",
        _fake_classify,
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [(50, " ON j.id = 2")],
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_where_splice",
        lambda *args, **kwargs: [(20, " WHERE f.id = 1")],
    )

    assert _splices_for_scope(
        sql,
        tokens,
        _Scope(),
        {Table("f"): ["id = 1"], Table("j"): ["id = 2"]},
        None,
        None,
        None,
    ) == [(50, " ON j.id = 2"), (20, " WHERE f.id = 1")]


def test_splices_for_scope_join_then_next_source(monkeypatch: _MonkeyPatch) -> None:
    """After a join source with no splice, the loop proceeds to the next source."""

    class _Scope:
        sources = {"j": object(), "f": object()}

    sql = "SELECT 1"
    tokens = _tokenize(sql)
    calls = [("join", 3, "j.id = 2"), ("none", None, None)]

    def _fake_classify(
        *args: object, **kwargs: object
    ) -> tuple[str, int | None, str | None]:
        return calls.pop(0)

    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate",
        _fake_classify,
    )
    # NOTE(review): ``_find_join_splice`` is stubbed out, so this test does not
    # show whether the real helper is reached for the ``("none", None, None)``
    # classification, nor how it would cope with ``None`` arguments -- confirm
    # in ``_splices_for_scope``.
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [],
    )

    assert _splices_for_scope(
        sql,
        tokens,
        _Scope(),
        {Table("j"): ["id = 2"]},
        None,
        None,
        None,
    ) == []