mirror of
https://github.com/apache/superset.git
synced 2026-04-20 16:44:46 +00:00
feat: helper functions for RLS (#19055)
* feat: helper functions for RLS * Add function to inject RLS * Add UNION tests * Add tests for schema * Add more tests; cleanup * has_table_query via tree traversal * Wrap existing predicate in parenthesis * Clean up logic * Improve table matching
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
# pylint: disable=invalid-name, too-many-lines
|
||||
|
||||
import unittest
|
||||
from typing import Set
|
||||
@@ -25,6 +25,10 @@ import sqlparse
|
||||
|
||||
from superset.exceptions import QueryClauseValidationException
|
||||
from superset.sql_parse import (
|
||||
add_table_name,
|
||||
has_table_query,
|
||||
insert_rls,
|
||||
matches_table_name,
|
||||
ParsedQuery,
|
||||
strip_comments_from_sql,
|
||||
Table,
|
||||
@@ -1111,7 +1115,8 @@ def test_sqlparse_formatting():
|
||||
|
||||
"""
|
||||
assert sqlparse.format(
|
||||
"SELECT extract(HOUR from from_unixtime(hour_ts) AT TIME ZONE 'America/Los_Angeles') from table",
|
||||
"SELECT extract(HOUR from from_unixtime(hour_ts) "
|
||||
"AT TIME ZONE 'America/Los_Angeles') from table",
|
||||
reindent=True,
|
||||
) == (
|
||||
"SELECT extract(HOUR\n from from_unixtime(hour_ts) "
|
||||
@@ -1189,3 +1194,241 @@ def test_sqlparse_issue_652():
|
||||
stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0]
|
||||
assert len(stmt.tokens) == 5
|
||||
assert str(stmt.tokens[0]) == "foo = '\\'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,expected",
    [
        ("SELECT * FROM table", True),
        ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
        ("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
        ("COUNT(*)", False),
        ("SELECT a FROM (SELECT 1 AS a)", False),
        ("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
        ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
        ("SELECT * FROM other_table", True),
        ("extract(HOUR from from_unixtime(hour_ts)", False),
    ],
)
def test_has_table_query(sql: str, expected: bool) -> None:
    """
    Check that ``has_table_query`` reports whether a statement reads from a table.

    The check exists so that ad-hoc metrics cannot query unauthorized tables and
    thereby bypass row-level security.
    """
    # Only the first parsed statement is inspected; every case above is a single
    # statement or expression.
    parsed = sqlparse.parse(sql)
    first_statement = parsed[0]
    assert has_table_query(first_statement) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,table,rls,expected",
    [
        # Simplest case: the RLS predicate (some_table.id=42) is appended to a WHERE
        # clause that already exists.
        (
            "SELECT * FROM some_table WHERE 1=1",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42",
        ),
        # Existing predicates MUST be wrapped in parentheses because AND has higher
        # precedence than OR. If the RLS is `1=0` and we didn't add parentheses, a
        # user could bypass it by crafting a query with `WHERE TRUE OR FALSE`, since
        # `WHERE TRUE OR FALSE AND 1=0` evaluates to `WHERE TRUE OR (FALSE AND 1=0)`.
        (
            "SELECT * FROM some_table WHERE TRUE OR FALSE",
            "some_table",
            "1=0",
            "SELECT * FROM some_table WHERE ( TRUE OR FALSE) AND 1=0",
        ),
        # "table" is a reserved word; sqlparse is too aggressive when classifying
        # reserved words, so they must be supported as table names even unquoted.
        (
            "SELECT * FROM table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM table WHERE ( 1=1) AND table.id=42",
        ),
        # The RLS is applied only to queries that read from the associated table.
        (
            "SELECT * FROM table WHERE 1=1",
            "other_table",
            "id=42",
            "SELECT * FROM table WHERE 1=1",
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM other_table WHERE 1=1",
        ),
        # A WHERE clause is created when the query does not have one yet.
        (
            "SELECT * FROM table",
            "table",
            "id=42",
            "SELECT * FROM table WHERE table.id=42",
        ),
        (
            "SELECT * FROM some_table",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE some_table.id=42",
        ),
        (
            "SELECT * FROM table ORDER BY id",
            "table",
            "id=42",
            "SELECT * FROM table WHERE table.id=42 ORDER BY id",
        ),
        (
            "SELECT * FROM some_table;",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE some_table.id=42 ;",
        ),
        (
            "SELECT * FROM some_table ;",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE some_table.id=42 ;",
        ),
        (
            "SELECT * FROM some_table ",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE some_table.id=42",
        ),
        # To be conservative the RLS is added even when it is already present. It
        # should have no impact on the query, and it is easier than testing whether
        # the RLS is already there (it could be inside an OR clause, for example).
        (
            "SELECT * FROM table WHERE 1=1 AND table.id=42",
            "table",
            "id=42",
            "SELECT * FROM table WHERE ( 1=1 AND table.id=42) AND table.id=42",
        ),
        (
            (
                "SELECT * FROM table JOIN other_table ON "
                "table.id = other_table.id AND other_table.id=42"
            ),
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN other_table ON other_table.id=42 "
                "AND ( table.id = other_table.id AND other_table.id=42 )"
            ),
        ),
        (
            "SELECT * FROM table WHERE 1=1 AND id=42",
            "table",
            "id=42",
            "SELECT * FROM table WHERE ( 1=1 AND id=42) AND table.id=42",
        ),
        # For joins the RLS goes into the ON clause, since that is easier and avoids
        # leaking information about the number of rows on OUTER JOINs.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN other_table ON other_table.id=42 "
                "AND ( table.id = other_table.id )"
            ),
        ),
        (
            (
                "SELECT * FROM table JOIN other_table ON table.id = other_table.id "
                "WHERE 1=1"
            ),
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN other_table ON other_table.id=42 "
                "AND ( table.id = other_table.id ) WHERE 1=1"
            ),
        ),
        # Subqueries are handled as well.
        (
            "SELECT * FROM (SELECT * FROM other_table)",
            "other_table",
            "id=42",
            "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42 )",
        ),
        # And so is UNION.
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "table",
            "id=42",
            "SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table",
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table UNION ALL "
                "SELECT * FROM other_table WHERE other_table.id=42"
            ),
        ),
        # When a fully qualified name (eg, schema.table) is compared to a simple
        # name (eg, table) we are also conservative and assume the schema is the
        # same, because the default schema is not known.
        (
            "SELECT * FROM schema.table_name",
            "table_name",
            "id=42",
            "SELECT * FROM schema.table_name WHERE table_name.id=42",
        ),
        (
            "SELECT * FROM schema.table_name",
            "schema.table_name",
            "id=42",
            "SELECT * FROM schema.table_name WHERE schema.table_name.id=42",
        ),
        (
            "SELECT * FROM table_name",
            "schema.table_name",
            "id=42",
            "SELECT * FROM table_name WHERE schema.table_name.id=42",
        ),
    ],
)
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
    """
    Check that ``insert_rls`` injects the RLS condition for ``table`` into a query.

    Both the query and the RLS predicate are parsed with ``sqlparse`` first; the
    rewritten statement is compared to ``expected`` with surrounding whitespace
    stripped from both sides.
    """
    parsed_query = sqlparse.parse(sql)[0]
    predicate = sqlparse.parse(rls)[0]
    rewritten = str(insert_rls(parsed_query, table, predicate))
    assert rewritten.strip() == expected.strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "rls,table,expected",
    [
        ("id=42", "users", "users.id=42"),
        ("users.id=42", "users", "users.id=42"),
        ("schema.users.id=42", "users", "schema.users.id=42"),
        ("false", "users", "false"),
    ],
)
def test_add_table_name(rls: str, table: str, expected: str) -> None:
    """
    Check how ``add_table_name`` qualifies the columns of an RLS predicate.

    The return value of ``add_table_name`` is ignored: the assertion inspects the
    parsed predicate itself, so the test relies on it being modified in place.
    """
    predicate = sqlparse.parse(rls)[0]
    add_table_name(predicate, table)
    rendered = str(predicate)
    assert rendered == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "candidate,table,expected",
    [
        ("table", "table", True),
        ("schema.table", "table", True),
        ("table", "schema.table", True),
        ('schema."my table"', '"my table"', True),
        ('schema."my.table"', '"my.table"', True),
    ],
)
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
    """
    Check that ``matches_table_name`` matches a parsed token against a table name.

    ``candidate`` is parsed with ``sqlparse`` and only the first token of the first
    statement is handed to ``matches_table_name``.
    """
    statement = sqlparse.parse(candidate)[0]
    first_token = statement.tokens[0]
    assert matches_table_name(first_token, table) == expected
|
||||
|
||||
Reference in New Issue
Block a user