mirror of
https://github.com/apache/superset.git
synced 2026-04-20 00:24:38 +00:00
feat: helper functions for RLS (#19055)
* feat: helper functions for RLS * Add function to inject RLS * Add UNION tests * Add tests for schema * Add more tests; cleanup * has_table_query via tree traversal * Wrap existing predicate in parenthesis * Clean up logic * Improve table matching
This commit is contained in:
@@ -29,6 +29,7 @@ from sqlparse.sql import (
|
||||
remove_quotes,
|
||||
Token,
|
||||
TokenList,
|
||||
Where,
|
||||
)
|
||||
from sqlparse.tokens import (
|
||||
CTE,
|
||||
@@ -458,3 +459,204 @@ def validate_filter_clause(clause: str) -> None:
|
||||
)
|
||||
if open_parens > 0:
|
||||
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
|
||||
|
||||
|
||||
class InsertRLSState(str, Enum):
|
||||
"""
|
||||
State machine that scans for WHERE and ON clauses referencing tables.
|
||||
"""
|
||||
|
||||
SCANNING = "SCANNING"
|
||||
SEEN_SOURCE = "SEEN_SOURCE"
|
||||
FOUND_TABLE = "FOUND_TABLE"
|
||||
|
||||
|
||||
def has_table_query(token_list: TokenList) -> bool:
|
||||
"""
|
||||
Return if a stament has a query reading from a table.
|
||||
|
||||
>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
|
||||
False
|
||||
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
|
||||
True
|
||||
|
||||
Note that queries reading from constant values return false:
|
||||
|
||||
>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
|
||||
False
|
||||
|
||||
"""
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
|
||||
# # Recurse into child token list
|
||||
if isinstance(token, TokenList) and has_table_query(token):
|
||||
return True
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
state = InsertRLSState.SEEN_SOURCE
|
||||
|
||||
# Found identifier/keyword after FROM/JOIN
|
||||
elif state == InsertRLSState.SEEN_SOURCE and (
|
||||
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
|
||||
):
|
||||
return True
|
||||
|
||||
# Found nothing, leaving source
|
||||
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def add_table_name(rls: TokenList, table: str) -> None:
|
||||
"""
|
||||
Modify a RLS expression ensuring columns are fully qualified.
|
||||
"""
|
||||
tokens = rls.tokens[:]
|
||||
while tokens:
|
||||
token = tokens.pop(0)
|
||||
|
||||
if isinstance(token, Identifier) and token.get_parent_name() is None:
|
||||
token.tokens = [
|
||||
Token(Name, table),
|
||||
Token(Punctuation, "."),
|
||||
Token(Name, token.get_name()),
|
||||
]
|
||||
elif isinstance(token, TokenList):
|
||||
tokens.extend(token.tokens)
|
||||
|
||||
|
||||
def matches_table_name(candidate: Token, table: str) -> bool:
|
||||
"""
|
||||
Returns if the token represents a reference to the table.
|
||||
|
||||
Tables can be fully qualified with periods.
|
||||
|
||||
Note that in theory a table should be represented as an identifier, but due to
|
||||
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
|
||||
classified as a keyword.
|
||||
"""
|
||||
if not isinstance(candidate, Identifier):
|
||||
candidate = Identifier([Token(Name, candidate.value)])
|
||||
|
||||
target = sqlparse.parse(table)[0].tokens[0]
|
||||
if not isinstance(target, Identifier):
|
||||
target = Identifier([Token(Name, target.value)])
|
||||
|
||||
# match from right to left, splitting on the period, eg, schema.table == table
|
||||
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
|
||||
if left.value != right.value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying an RLS associated with a given table.
|
||||
"""
|
||||
# make sure the identifier has the table name
|
||||
add_table_name(rls, table)
|
||||
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList):
|
||||
i = token_list.tokens.index(token)
|
||||
token_list.tokens[i] = insert_rls(token, table, rls)
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
state = InsertRLSState.SEEN_SOURCE
|
||||
|
||||
# Found identifier/keyword after FROM/JOIN, test for table
|
||||
elif state == InsertRLSState.SEEN_SOURCE and (
|
||||
isinstance(token, Identifier) or token.ttype == Keyword
|
||||
):
|
||||
if matches_table_name(token, table):
|
||||
state = InsertRLSState.FOUND_TABLE
|
||||
|
||||
# Found WHERE clause, insert RLS. Note that we insert it even it already exists,
|
||||
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
|
||||
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
|
||||
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
|
||||
token.tokens.extend(
|
||||
[
|
||||
Token(Punctuation, ")"),
|
||||
Token(Whitespace, " "),
|
||||
Token(Keyword, "AND"),
|
||||
Token(Whitespace, " "),
|
||||
]
|
||||
+ rls.tokens
|
||||
)
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
# Found ON clause, insert RLS. The logic for ON is more complicated than the logic
|
||||
# for WHERE because in the former the comparisons are siblings, while on the
|
||||
# latter they are children.
|
||||
elif (
|
||||
state == InsertRLSState.FOUND_TABLE
|
||||
and token.ttype == Keyword
|
||||
and token.value.upper() == "ON"
|
||||
):
|
||||
tokens = [
|
||||
Token(Whitespace, " "),
|
||||
rls,
|
||||
Token(Whitespace, " "),
|
||||
Token(Keyword, "AND"),
|
||||
Token(Whitespace, " "),
|
||||
Token(Punctuation, "("),
|
||||
]
|
||||
i = token_list.tokens.index(token)
|
||||
token.parent.tokens[i + 1 : i + 1] = tokens
|
||||
i += len(tokens) + 2
|
||||
|
||||
# close parenthesis after last existing comparison
|
||||
j = 0
|
||||
for j, sibling in enumerate(token_list.tokens[i:]):
|
||||
# scan until we hit a non-comparison keyword (like ORDER BY) or a WHERE
|
||||
if (
|
||||
sibling.ttype == Keyword
|
||||
and not imt(
|
||||
sibling, m=[(Keyword, "AND"), (Keyword, "OR"), (Keyword, "NOT")]
|
||||
)
|
||||
or isinstance(sibling, Where)
|
||||
):
|
||||
j -= 1
|
||||
break
|
||||
token.parent.tokens[i + j + 1 : i + j + 1] = [
|
||||
Token(Whitespace, " "),
|
||||
Token(Punctuation, ")"),
|
||||
Token(Whitespace, " "),
|
||||
]
|
||||
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
# Found table but no WHERE clause found, insert one
|
||||
elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace:
|
||||
i = token_list.tokens.index(token)
|
||||
token_list.tokens[i:i] = [
|
||||
Token(Whitespace, " "),
|
||||
Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
|
||||
Token(Whitespace, " "),
|
||||
]
|
||||
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
# Found nothing, leaving source
|
||||
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
# found table at the end of the statement; append a WHERE clause
|
||||
if state == InsertRLSState.FOUND_TABLE:
|
||||
token_list.tokens.extend(
|
||||
[
|
||||
Token(Whitespace, " "),
|
||||
Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
|
||||
]
|
||||
)
|
||||
|
||||
return token_list
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
# pylint: disable=invalid-name, too-many-lines
|
||||
|
||||
import unittest
|
||||
from typing import Set
|
||||
@@ -25,6 +25,10 @@ import sqlparse
|
||||
|
||||
from superset.exceptions import QueryClauseValidationException
|
||||
from superset.sql_parse import (
|
||||
add_table_name,
|
||||
has_table_query,
|
||||
insert_rls,
|
||||
matches_table_name,
|
||||
ParsedQuery,
|
||||
strip_comments_from_sql,
|
||||
Table,
|
||||
@@ -1111,7 +1115,8 @@ def test_sqlparse_formatting():
|
||||
|
||||
"""
|
||||
assert sqlparse.format(
|
||||
"SELECT extract(HOUR from from_unixtime(hour_ts) AT TIME ZONE 'America/Los_Angeles') from table",
|
||||
"SELECT extract(HOUR from from_unixtime(hour_ts) "
|
||||
"AT TIME ZONE 'America/Los_Angeles') from table",
|
||||
reindent=True,
|
||||
) == (
|
||||
"SELECT extract(HOUR\n from from_unixtime(hour_ts) "
|
||||
@@ -1189,3 +1194,241 @@ def test_sqlparse_issue_652():
|
||||
stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0]
|
||||
assert len(stmt.tokens) == 5
|
||||
assert str(stmt.tokens[0]) == "foo = '\\'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql,expected",
|
||||
[
|
||||
("SELECT * FROM table", True),
|
||||
("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
|
||||
("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
|
||||
("COUNT(*)", False),
|
||||
("SELECT a FROM (SELECT 1 AS a)", False),
|
||||
("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
|
||||
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
|
||||
("SELECT * FROM other_table", True),
|
||||
("extract(HOUR from from_unixtime(hour_ts)", False),
|
||||
],
|
||||
)
|
||||
def test_has_table_query(sql: str, expected: bool) -> None:
|
||||
"""
|
||||
Test if a given statement queries a table.
|
||||
|
||||
This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
|
||||
row-level security.
|
||||
"""
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
assert has_table_query(statement) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql,table,rls,expected",
|
||||
[
|
||||
# Basic test: append RLS (some_table.id=42) to an existing WHERE clause.
|
||||
(
|
||||
"SELECT * FROM some_table WHERE 1=1",
|
||||
"some_table",
|
||||
"id=42",
|
||||
"SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42",
|
||||
),
|
||||
# Any existing predicates MUST to be wrapped in parenthesis because AND has higher
|
||||
# precedence than OR. If the RLS it `1=0` and we didn't add parenthesis a user
|
||||
# could bypass it by crafting a query with `WHERE TRUE OR FALSE`, since
|
||||
# `WHERE TRUE OR FALSE AND 1=0` evaluates to `WHERE TRUE OR (FALSE AND 1=0)`.
|
||||
(
|
||||
"SELECT * FROM some_table WHERE TRUE OR FALSE",
|
||||
"some_table",
|
||||
"1=0",
|
||||
"SELECT * FROM some_table WHERE ( TRUE OR FALSE) AND 1=0",
|
||||
),
|
||||
# Here "table" is a reserved word; since sqlparse is too aggressive when
|
||||
# characterizing reserved words we need to support them even when not quoted.
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1",
|
||||
"table",
|
||||
"id=42",
|
||||
"SELECT * FROM table WHERE ( 1=1) AND table.id=42",
|
||||
),
|
||||
# RLS is only applied to queries reading from the associated table.
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1",
|
||||
"other_table",
|
||||
"id=42",
|
||||
"SELECT * FROM table WHERE 1=1",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM other_table WHERE 1=1",
|
||||
"table",
|
||||
"id=42",
|
||||
"SELECT * FROM other_table WHERE 1=1",
|
||||
),
|
||||
# If there's no pre-existing WHERE clause we create one.
|
||||
(
|
||||
"SELECT * FROM table",
|
||||
"table",
|
||||
"id=42",
|
||||
"SELECT * FROM table WHERE table.id=42",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table",
|
||||
"some_table",
|
||||
"id=42",
|
||||
"SELECT * FROM some_table WHERE some_table.id=42",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table ORDER BY id",
|
||||
"table",
|
||||
"id=42",
|
||||
"SELECT * FROM table WHERE table.id=42 ORDER BY id",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table;",
|
||||
"some_table",
|
||||
"id=42",
|
||||
"SELECT * FROM some_table WHERE some_table.id=42 ;",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table ;",
|
||||
"some_table",
|
||||
"id=42",
|
||||
"SELECT * FROM some_table WHERE some_table.id=42 ;",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table ",
|
||||
"some_table",
|
||||
"id=42",
|
||||
"SELECT * FROM some_table WHERE some_table.id=42",
|
||||
),
|
||||
# We add the RLS even if it's already present, to be conservative. It should have
|
||||
# no impact on the query, and it's easier than testing if the RLS is already
|
||||
# present (it could be present in an OR clause, eg).
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1 AND table.id=42",
|
||||
"table",
|
||||
"id=42",
|
||||
"SELECT * FROM table WHERE ( 1=1 AND table.id=42) AND table.id=42",
|
||||
),
|
||||
(
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON "
|
||||
"table.id = other_table.id AND other_table.id=42"
|
||||
),
|
||||
"other_table",
|
||||
"id=42",
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON other_table.id=42 "
|
||||
"AND ( table.id = other_table.id AND other_table.id=42 )"
|
||||
),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1 AND id=42",
|
||||
"table",
|
||||
"id=42",
|
||||
"SELECT * FROM table WHERE ( 1=1 AND id=42) AND table.id=42",
|
||||
),
|
||||
# For joins we apply the RLS to the ON clause, since it's easier and prevents
|
||||
# leaking information about number of rows on OUTER JOINs.
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
|
||||
"other_table",
|
||||
"id=42",
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON other_table.id=42 "
|
||||
"AND ( table.id = other_table.id )"
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON table.id = other_table.id "
|
||||
"WHERE 1=1"
|
||||
),
|
||||
"other_table",
|
||||
"id=42",
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON other_table.id=42 "
|
||||
"AND ( table.id = other_table.id ) WHERE 1=1"
|
||||
),
|
||||
),
|
||||
# Subqueries also work, as expected.
|
||||
(
|
||||
"SELECT * FROM (SELECT * FROM other_table)",
|
||||
"other_table",
|
||||
"id=42",
|
||||
"SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42 )",
|
||||
),
|
||||
# As well as UNION.
|
||||
(
|
||||
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
|
||||
"table",
|
||||
"id=42",
|
||||
"SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
|
||||
"other_table",
|
||||
"id=42",
|
||||
(
|
||||
"SELECT * FROM table UNION ALL "
|
||||
"SELECT * FROM other_table WHERE other_table.id=42"
|
||||
),
|
||||
),
|
||||
# When comparing fully qualified table names (eg, schema.table) to simple names
|
||||
# (eg, table) we are also conservative, assuming the schema is the same, since
|
||||
# we don't have information on the default schema.
|
||||
(
|
||||
"SELECT * FROM schema.table_name",
|
||||
"table_name",
|
||||
"id=42",
|
||||
"SELECT * FROM schema.table_name WHERE table_name.id=42",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM schema.table_name",
|
||||
"schema.table_name",
|
||||
"id=42",
|
||||
"SELECT * FROM schema.table_name WHERE schema.table_name.id=42",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table_name",
|
||||
"schema.table_name",
|
||||
"id=42",
|
||||
"SELECT * FROM table_name WHERE schema.table_name.id=42",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
|
||||
"""
|
||||
Insert into a statement a given RLS condition associated with a table.
|
||||
"""
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
condition = sqlparse.parse(rls)[0]
|
||||
assert str(insert_rls(statement, table, condition)).strip() == expected.strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rls,table,expected",
|
||||
[
|
||||
("id=42", "users", "users.id=42"),
|
||||
("users.id=42", "users", "users.id=42"),
|
||||
("schema.users.id=42", "users", "schema.users.id=42"),
|
||||
("false", "users", "false"),
|
||||
],
|
||||
)
|
||||
def test_add_table_name(rls: str, table: str, expected: str) -> None:
|
||||
condition = sqlparse.parse(rls)[0]
|
||||
add_table_name(condition, table)
|
||||
assert str(condition) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"candidate,table,expected",
|
||||
[
|
||||
("table", "table", True),
|
||||
("schema.table", "table", True),
|
||||
("table", "schema.table", True),
|
||||
('schema."my table"', '"my table"', True),
|
||||
('schema."my.table"', '"my.table"', True),
|
||||
],
|
||||
)
|
||||
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
|
||||
token = sqlparse.parse(candidate)[0].tokens[0]
|
||||
assert matches_table_name(token, table) == expected
|
||||
|
||||
Reference in New Issue
Block a user