feat: safer insert RLS (#20323)

This commit is contained in:
Beto Dealmeida
2023-11-08 22:52:25 -05:00
committed by GitHub
parent 90e210892b
commit 2bd611916d
5 changed files with 324 additions and 15 deletions

View File

@@ -44,6 +44,7 @@ from sqlparse.tokens import (
Punctuation,
String,
Whitespace,
Wildcard,
)
from sqlparse.utils import imt
@@ -660,18 +661,29 @@ def get_rls_for_table(
return None
rls = sqlparse.parse(predicate)[0]
add_table_name(rls, str(dataset))
add_table_name(rls, table.table)
return rls
def insert_rls(
def insert_rls_as_subquery(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
The RLS predicate is applied as subquery replacing the original table:
before: SELECT * FROM some_table WHERE 1=1
after: SELECT * FROM (
SELECT * FROM some_table WHERE some_table.id=42
) AS some_table
WHERE 1=1
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
databases.
"""
rls: Optional[TokenList] = None
state = InsertRLSState.SCANNING
@@ -679,7 +691,98 @@ def insert_rls(
# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, database_id, default_schema)
token_list.tokens[i] = insert_rls_as_subquery(
token,
database_id,
default_schema,
)
# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE
# Found identifier/keyword after FROM/JOIN, test for table
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
rls = get_rls_for_table(token, database_id, default_schema)
if rls:
# replace table with subquery
subquery_alias = (
token.tokens[-1].value
if isinstance(token, Identifier)
else token.value
)
i = token_list.tokens.index(token)
# strip alias from table name
if isinstance(token, Identifier) and token.has_alias():
whitespace_index = token.token_next_by(t=Whitespace)[0]
token.tokens = token.tokens[:whitespace_index]
token_list.tokens[i] = Identifier(
[
Parenthesis(
[
Token(Punctuation, "("),
Token(DML, "SELECT"),
Token(Whitespace, " "),
Token(Wildcard, "*"),
Token(Whitespace, " "),
Token(Keyword, "FROM"),
Token(Whitespace, " "),
token,
Token(Whitespace, " "),
Where(
[
Token(Keyword, "WHERE"),
Token(Whitespace, " "),
rls,
]
),
Token(Punctuation, ")"),
]
),
Token(Whitespace, " "),
Token(Keyword, "AS"),
Token(Whitespace, " "),
Identifier([Token(Name, subquery_alias)]),
]
)
state = InsertRLSState.SCANNING
# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING
return token_list
def insert_rls_in_predicate(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
The RLS predicate is ``AND``ed to any existing predicates:
before: SELECT * FROM some_table WHERE 1=1
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
"""
rls: Optional[TokenList] = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls_in_predicate(
token,
database_id,
default_schema,
)
# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):