feat: safer insert RLS (#20323)

This commit is contained in:
Beto Dealmeida
2023-11-08 22:52:25 -05:00
committed by GitHub
parent 90e210892b
commit 2bd611916d
5 changed files with 324 additions and 15 deletions

View File

@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
from typing import Optional
@@ -31,7 +31,8 @@ from superset.sql_parse import (
extract_table_references,
get_rls_for_table,
has_table_query,
insert_rls,
insert_rls_as_subquery,
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
@@ -1318,6 +1319,184 @@ def test_has_table_query(sql: str, expected: bool) -> None:
assert has_table_query(statement) == expected
@pytest.mark.parametrize(
    "sql,table,rls,expected",
    [
        # Simplest case: one table, one predicate.
        (
            "SELECT * FROM some_table WHERE 1=1",
            "some_table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM some_table WHERE some_table.id=42) "
                "AS some_table WHERE 1=1"
            ),
        ),
        # "table" is a reserved word. sqlparse is overly eager when classifying
        # reserved words, so unquoted ones must still be handled.
        (
            "SELECT * FROM table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM (SELECT * FROM table WHERE table.id=42) AS table WHERE 1=1",
        ),
        # Queries that do not read from the RLS table are left untouched.
        (
            "SELECT * FROM table WHERE 1=1",
            "other_table",
            "id=42",
            "SELECT * FROM table WHERE 1=1",
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM other_table WHERE 1=1",
        ),
        # The table appears on the right-hand side of a JOIN.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN "
                "(SELECT * FROM other_table WHERE other_table.id=42) AS other_table "
                "ON table.id = other_table.id"
            ),
        ),
        # The table is referenced inside a subquery.
        (
            "SELECT * FROM (SELECT * FROM other_table)",
            "other_table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM ("
                "SELECT * FROM other_table WHERE other_table.id=42"
                ") AS other_table)"
            ),
        ),
        # UNION: only the branch reading from the RLS table is rewritten.
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM table WHERE table.id=42) AS table "
                "UNION ALL SELECT * FROM other_table"
            ),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table UNION ALL SELECT * FROM ("
                "SELECT * FROM other_table WHERE other_table.id=42) AS other_table"
            ),
        ),
        # Fully qualified names (eg, schema.table) versus simple names (eg, table):
        # we stay conservative and assume the schema is the same, because the
        # default schema is not known here.
        (
            "SELECT * FROM schema.table_name",
            "table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM schema.table_name "
                "WHERE table_name.id=42) AS table_name"
            ),
        ),
        (
            "SELECT * FROM schema.table_name",
            "schema.table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM schema.table_name "
                "WHERE schema.table_name.id=42) AS table_name"
            ),
        ),
        (
            "SELECT * FROM table_name",
            "schema.table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM table_name WHERE "
                "schema.table_name.id=42) AS table_name"
            ),
        ),
        # Aliased tables, with and without the AS keyword.
        (
            "SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col",
            "tbl_a",
            "id=42",
            (
                "SELECT a.*, b.* FROM "
                "(SELECT * FROM tbl_a WHERE tbl_a.id=42) AS a "
                "INNER JOIN tbl_b AS b "
                "ON a.col = b.col"
            ),
        ),
        (
            "SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col",
            "tbl_a",
            "id=42",
            (
                "SELECT a.*, b.* FROM "
                "(SELECT * FROM tbl_a WHERE tbl_a.id=42) AS a "
                "INNER JOIN tbl_b b ON a.col = b.col"
            ),
        ),
    ],
)
def test_insert_rls_as_subquery(
    mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
) -> None:
    """
    Check that an RLS condition tied to a table is injected as a subquery.

    ``rls`` is parsed, qualified with ``table`` and served by a stand-in for
    ``superset.sql_parse.get_rls_for_table``; the statement produced by
    ``insert_rls_as_subquery`` is then compared (ignoring surrounding
    whitespace) against ``expected``.
    """
    condition = sqlparse.parse(rls)[0]
    add_table_name(condition, table)

    # pylint: disable=unused-argument
    def get_rls_for_table(
        candidate: Token,
        database_id: int,
        default_schema: str,
    ) -> Optional[TokenList]:
        """
        Hand back the RLS ``condition`` when ``candidate`` refers to ``table``.
        """
        # Bare tokens are wrapped so they can be parsed like any identifier.
        if not isinstance(candidate, Identifier):
            candidate = Identifier([Token(Name, candidate.value)])

        parsed_table = ParsedQuery.get_table(candidate)
        if not parsed_table:
            return None

        if parsed_table.schema:
            qualified_name = f"{parsed_table.schema}.{parsed_table.table}"
        else:
            qualified_name = parsed_table.table

        # Compare the dotted parts right-to-left; ``zip`` stops at the shorter
        # name, so a bare table name matches a schema-qualified one.
        candidate_parts = reversed(qualified_name.split("."))
        expected_parts = reversed(table.split("."))
        if any(left != right for left, right in zip(candidate_parts, expected_parts)):
            return None

        return condition

    mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)

    statement = sqlparse.parse(sql)[0]
    rewritten = insert_rls_as_subquery(
        token_list=statement, database_id=1, default_schema="my_schema"
    )
    assert str(rewritten).strip() == expected.strip()
@pytest.mark.parametrize(
"sql,table,rls,expected",
[
@@ -1492,7 +1671,7 @@ def test_has_table_query(sql: str, expected: bool) -> None:
),
],
)
def test_insert_rls(
def test_insert_rls_in_predicate(
mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
) -> None:
"""
@@ -1521,7 +1700,11 @@ def test_insert_rls(
statement = sqlparse.parse(sql)[0]
assert (
str(
insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
insert_rls_in_predicate(
token_list=statement,
database_id=1,
default_schema="my_schema",
)
).strip()
== expected.strip()
)