feat: helper functions for RLS (#19055)

* feat: helper functions for RLS

* Add function to inject RLS

* Add UNION tests

* Add tests for schema

* Add more tests; cleanup

* has_table_query via tree traversal

* Wrap existing predicate in parentheses

* Clean up logic

* Improve table matching
This commit is contained in:
Beto Dealmeida
2022-03-11 14:47:11 -08:00
committed by GitHub
parent c337491d0e
commit 8234395466
2 changed files with 447 additions and 2 deletions

View File

@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name, too-many-lines
import unittest
from typing import Set
@@ -25,6 +25,10 @@ import sqlparse
from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
add_table_name,
has_table_query,
insert_rls,
matches_table_name,
ParsedQuery,
strip_comments_from_sql,
Table,
@@ -1111,7 +1115,8 @@ def test_sqlparse_formatting():
"""
assert sqlparse.format(
"SELECT extract(HOUR from from_unixtime(hour_ts) AT TIME ZONE 'America/Los_Angeles') from table",
"SELECT extract(HOUR from from_unixtime(hour_ts) "
"AT TIME ZONE 'America/Los_Angeles') from table",
reindent=True,
) == (
"SELECT extract(HOUR\n from from_unixtime(hour_ts) "
@@ -1189,3 +1194,241 @@ def test_sqlparse_issue_652():
stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0]
assert len(stmt.tokens) == 5
assert str(stmt.tokens[0]) == "foo = '\\'"
@pytest.mark.parametrize(
    "sql,expected",
    [
        # Statements (or fragments) that read from a real table.
        ("SELECT * FROM table", True),
        ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
        ("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
        # Plain expressions and table-less subqueries.
        ("COUNT(*)", False),
        ("SELECT a FROM (SELECT 1 AS a)", False),
        ("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
        ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
        ("SELECT * FROM other_table", True),
        # A FROM keyword inside a function call is not a table reference.
        ("extract(HOUR from from_unixtime(hour_ts)", False),
    ],
)
def test_has_table_query(sql: str, expected: bool) -> None:
    """
    Check whether ``has_table_query`` detects that a statement reads from a table.

    The helper exists so that ad-hoc metrics cannot query unauthorized tables and
    thereby bypass row-level security.
    """
    parsed = sqlparse.parse(sql)[0]
    result = has_table_query(parsed)
    assert result == expected
@pytest.mark.parametrize(
    "sql,table,rls,expected",
    [
        # Simplest case: the RLS predicate (some_table.id=42) is appended to a WHERE
        # clause that already exists.
        (
            "SELECT * FROM some_table WHERE 1=1",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42",
        ),
        # Existing predicates MUST be wrapped in parentheses because AND has higher
        # precedence than OR. If the RLS is `1=0` and no parentheses were added, a
        # user could bypass it with `WHERE TRUE OR FALSE`, since
        # `WHERE TRUE OR FALSE AND 1=0` evaluates as `WHERE TRUE OR (FALSE AND 1=0)`.
        (
            "SELECT * FROM some_table WHERE TRUE OR FALSE",
            "some_table",
            "1=0",
            "SELECT * FROM some_table WHERE ( TRUE OR FALSE) AND 1=0",
        ),
        # "table" is a reserved word here; sqlparse is too aggressive when classifying
        # reserved words, so unquoted ones have to be supported as table names.
        (
            "SELECT * FROM table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM table WHERE ( 1=1) AND table.id=42",
        ),
        # Queries that do not read from the associated table are left untouched.
        (
            "SELECT * FROM table WHERE 1=1",
            "other_table",
            "id=42",
            "SELECT * FROM table WHERE 1=1",
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM other_table WHERE 1=1",
        ),
        # Without a pre-existing WHERE clause, one is created.
        (
            "SELECT * FROM table",
            "table",
            "id=42",
            "SELECT * FROM table WHERE table.id=42",
        ),
        (
            "SELECT * FROM some_table",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE some_table.id=42",
        ),
        (
            "SELECT * FROM table ORDER BY id",
            "table",
            "id=42",
            "SELECT * FROM table WHERE table.id=42 ORDER BY id",
        ),
        (
            "SELECT * FROM some_table;",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE some_table.id=42 ;",
        ),
        (
            "SELECT * FROM some_table ;",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE some_table.id=42 ;",
        ),
        (
            "SELECT * FROM some_table ",
            "some_table",
            "id=42",
            "SELECT * FROM some_table WHERE some_table.id=42",
        ),
        # To be conservative the RLS is added even when it is already present. It
        # should not affect the query, and that is simpler than checking whether the
        # RLS is already there (it could sit inside an OR clause, for example).
        (
            "SELECT * FROM table WHERE 1=1 AND table.id=42",
            "table",
            "id=42",
            "SELECT * FROM table WHERE ( 1=1 AND table.id=42) AND table.id=42",
        ),
        (
            (
                "SELECT * FROM table JOIN other_table ON "
                "table.id = other_table.id AND other_table.id=42"
            ),
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN other_table ON other_table.id=42 "
                "AND ( table.id = other_table.id AND other_table.id=42 )"
            ),
        ),
        (
            "SELECT * FROM table WHERE 1=1 AND id=42",
            "table",
            "id=42",
            "SELECT * FROM table WHERE ( 1=1 AND id=42) AND table.id=42",
        ),
        # With joins the RLS goes into the ON clause: it is easier, and it avoids
        # leaking information about the number of rows on OUTER JOINs.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN other_table ON other_table.id=42 "
                "AND ( table.id = other_table.id )"
            ),
        ),
        (
            (
                "SELECT * FROM table JOIN other_table ON table.id = other_table.id "
                "WHERE 1=1"
            ),
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN other_table ON other_table.id=42 "
                "AND ( table.id = other_table.id ) WHERE 1=1"
            ),
        ),
        # Subqueries are handled as well.
        (
            "SELECT * FROM (SELECT * FROM other_table)",
            "other_table",
            "id=42",
            "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42 )",
        ),
        # So is UNION: only the branch reading the table gets the predicate.
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "table",
            "id=42",
            "SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table",
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table UNION ALL "
                "SELECT * FROM other_table WHERE other_table.id=42"
            ),
        ),
        # Comparing fully qualified names (eg, schema.table) with simple names
        # (eg, table) is also conservative: the schema is assumed to be the same,
        # since the default schema is unknown.
        (
            "SELECT * FROM schema.table_name",
            "table_name",
            "id=42",
            "SELECT * FROM schema.table_name WHERE table_name.id=42",
        ),
        (
            "SELECT * FROM schema.table_name",
            "schema.table_name",
            "id=42",
            "SELECT * FROM schema.table_name WHERE schema.table_name.id=42",
        ),
        (
            "SELECT * FROM table_name",
            "schema.table_name",
            "id=42",
            "SELECT * FROM table_name WHERE schema.table_name.id=42",
        ),
    ],
)
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
    """
    Check that ``insert_rls`` injects the RLS condition for ``table`` into ``sql``.

    Both the rewritten statement and the expectation are stripped of surrounding
    whitespace before being compared.
    """
    parsed = sqlparse.parse(sql)[0]
    predicate = sqlparse.parse(rls)[0]
    rewritten = str(insert_rls(parsed, table, predicate))
    assert rewritten.strip() == expected.strip()
@pytest.mark.parametrize(
    "rls,table,expected",
    [
        # An unqualified column gets the table name as a prefix.
        ("id=42", "users", "users.id=42"),
        # Columns that are already qualified stay as they are.
        ("users.id=42", "users", "users.id=42"),
        ("schema.users.id=42", "users", "schema.users.id=42"),
        # A predicate without any column is returned unchanged.
        ("false", "users", "false"),
    ],
)
def test_add_table_name(rls: str, table: str, expected: str) -> None:
    """
    Check that ``add_table_name`` qualifies the columns of an RLS predicate.

    The parsed predicate is modified in place, so its string form is compared
    after the call.
    """
    predicate = sqlparse.parse(rls)[0]
    add_table_name(predicate, table)
    rendered = str(predicate)
    assert rendered == expected
@pytest.mark.parametrize(
    "candidate,table,expected",
    [
        ("table", "table", True),
        # A schema-qualified name on either side still matches the bare name.
        ("schema.table", "table", True),
        ("table", "schema.table", True),
        # Quoted identifiers, including ones containing a space or a period.
        ('schema."my table"', '"my table"', True),
        ('schema."my.table"', '"my.table"', True),
    ],
)
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
    """
    Check that ``matches_table_name`` compares a parsed token with a table name.

    The first token of the parsed ``candidate`` is what gets compared to ``table``.
    """
    statement = sqlparse.parse(candidate)[0]
    first_token = statement.tokens[0]
    assert matches_table_name(first_token, table) == expected