mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat: safer insert RLS (#20323)
This commit is contained in:
@@ -44,6 +44,7 @@ from sqlparse.tokens import (
|
||||
Punctuation,
|
||||
String,
|
||||
Whitespace,
|
||||
Wildcard,
|
||||
)
|
||||
from sqlparse.utils import imt
|
||||
|
||||
@@ -660,18 +661,29 @@ def get_rls_for_table(
|
||||
return None
|
||||
|
||||
rls = sqlparse.parse(predicate)[0]
|
||||
add_table_name(rls, str(dataset))
|
||||
add_table_name(rls, table.table)
|
||||
|
||||
return rls
|
||||
|
||||
|
||||
def insert_rls(
|
||||
def insert_rls_as_subquery(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying any associated RLS predicates.
|
||||
|
||||
The RLS predicate is applied as subquery replacing the original table:
|
||||
|
||||
before: SELECT * FROM some_table WHERE 1=1
|
||||
after: SELECT * FROM (
|
||||
SELECT * FROM some_table WHERE some_table.id=42
|
||||
) AS some_table
|
||||
WHERE 1=1
|
||||
|
||||
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
|
||||
databases.
|
||||
"""
|
||||
rls: Optional[TokenList] = None
|
||||
state = InsertRLSState.SCANNING
|
||||
@@ -679,7 +691,98 @@ def insert_rls(
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList):
|
||||
i = token_list.tokens.index(token)
|
||||
token_list.tokens[i] = insert_rls(token, database_id, default_schema)
|
||||
token_list.tokens[i] = insert_rls_as_subquery(
|
||||
token,
|
||||
database_id,
|
||||
default_schema,
|
||||
)
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
state = InsertRLSState.SEEN_SOURCE
|
||||
|
||||
# Found identifier/keyword after FROM/JOIN, test for table
|
||||
elif state == InsertRLSState.SEEN_SOURCE and (
|
||||
isinstance(token, Identifier) or token.ttype == Keyword
|
||||
):
|
||||
rls = get_rls_for_table(token, database_id, default_schema)
|
||||
if rls:
|
||||
# replace table with subquery
|
||||
subquery_alias = (
|
||||
token.tokens[-1].value
|
||||
if isinstance(token, Identifier)
|
||||
else token.value
|
||||
)
|
||||
i = token_list.tokens.index(token)
|
||||
|
||||
# strip alias from table name
|
||||
if isinstance(token, Identifier) and token.has_alias():
|
||||
whitespace_index = token.token_next_by(t=Whitespace)[0]
|
||||
token.tokens = token.tokens[:whitespace_index]
|
||||
|
||||
token_list.tokens[i] = Identifier(
|
||||
[
|
||||
Parenthesis(
|
||||
[
|
||||
Token(Punctuation, "("),
|
||||
Token(DML, "SELECT"),
|
||||
Token(Whitespace, " "),
|
||||
Token(Wildcard, "*"),
|
||||
Token(Whitespace, " "),
|
||||
Token(Keyword, "FROM"),
|
||||
Token(Whitespace, " "),
|
||||
token,
|
||||
Token(Whitespace, " "),
|
||||
Where(
|
||||
[
|
||||
Token(Keyword, "WHERE"),
|
||||
Token(Whitespace, " "),
|
||||
rls,
|
||||
]
|
||||
),
|
||||
Token(Punctuation, ")"),
|
||||
]
|
||||
),
|
||||
Token(Whitespace, " "),
|
||||
Token(Keyword, "AS"),
|
||||
Token(Whitespace, " "),
|
||||
Identifier([Token(Name, subquery_alias)]),
|
||||
]
|
||||
)
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
# Found nothing, leaving source
|
||||
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
|
||||
state = InsertRLSState.SCANNING
|
||||
|
||||
return token_list
|
||||
|
||||
|
||||
def insert_rls_in_predicate(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying any associated RLS predicates.
|
||||
|
||||
The RLS predicate is ``AND``ed to any existing predicates:
|
||||
|
||||
before: SELECT * FROM some_table WHERE 1=1
|
||||
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
|
||||
|
||||
"""
|
||||
rls: Optional[TokenList] = None
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList):
|
||||
i = token_list.tokens.index(token)
|
||||
token_list.tokens[i] = insert_rls_in_predicate(
|
||||
token,
|
||||
database_id,
|
||||
default_schema,
|
||||
)
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
|
||||
Reference in New Issue
Block a user