feat: implement RLS in sqlglot (#33524)

This commit is contained in:
Beto Dealmeida
2025-05-28 09:10:45 -04:00
committed by GitHub
parent e205846845
commit 0abe6eed89
2 changed files with 826 additions and 20 deletions

View File

@@ -18,13 +18,14 @@
import pytest
from sqlglot import Dialects
from sqlglot import Dialects, parse_one
from superset.exceptions import SupersetParseError
from superset.sql.parse import (
extract_tables_from_statement,
KustoKQLStatement,
LimitMethod,
RLSMethod,
split_kql,
SQLGLOT_DIALECTS,
SQLScript,
@@ -303,11 +304,13 @@ def test_format_no_dialect() -> None:
"""
assert (
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format()
== """SELECT
== """
SELECT
col
FROM t
WHERE
NOT col IN (1, 2)"""
NOT col IN (1, 2)
""".strip()
)
@@ -1118,7 +1121,8 @@ FROM some_table) AS anon_1
WHERE anon_1.a > 1 AND anon_1.b = 2
"""
optimized = """SELECT
optimized = """
SELECT
anon_1.a,
anon_1.b
FROM (
@@ -1131,9 +1135,11 @@ FROM (
some_table.a > 1 AND some_table.b = 2
) AS anon_1
WHERE
TRUE AND TRUE"""
TRUE AND TRUE
""".strip()
not_optimized = """SELECT
not_optimized = """
SELECT
anon_1.a,
anon_1.b
FROM (
@@ -1144,7 +1150,8 @@ FROM (
FROM some_table
) AS anon_1
WHERE
anon_1.a > 1 AND anon_1.b = 2"""
anon_1.a > 1 AND anon_1.b = 2
""".strip()
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
assert SQLStatement(sql, "dremio").optimize().format() == not_optimized
@@ -1195,9 +1202,11 @@ def test_firebolt_old() -> None:
sql = "SELECT * FROM t1 UNNEST(col1 AS foo)"
assert (
SQLStatement(sql, "firebolt").format()
== """SELECT
== """
SELECT
*
FROM t1 UNNEST(col1 AS foo)"""
FROM t1 UNNEST(col1 AS foo)
""".strip()
)
@@ -1216,9 +1225,11 @@ def test_firebolt_old_escape_string() -> None:
# but they normalize to ''
assert (
SQLStatement(sql, "firebolt").format()
== """SELECT
== """
SELECT
'foo''bar',
'foo''bar'"""
'foo''bar'
""".strip()
)
@@ -1410,7 +1421,8 @@ select TOP 100 * from currency
"mssql",
1000,
LimitMethod.FORCE_LIMIT,
"""WITH abc AS (
"""
WITH abc AS (
SELECT
*
FROM test
@@ -1422,7 +1434,8 @@ select TOP 100 * from currency
SELECT
TOP 1000
*
FROM currency""",
FROM currency
""".strip(),
),
(
"SELECT DISTINCT x from tbl",
@@ -1457,10 +1470,12 @@ FROM currency""",
"postgresql",
1000,
LimitMethod.FORCE_LIMIT,
"""SELECT
"""
SELECT
*
FROM birth_names /* SOME COMMENT WITH LIMIT 555 */
LIMIT 1000""",
LIMIT 1000
""".strip(),
),
(
"SELECT * FROM birth_names LIMIT 555",
@@ -1602,7 +1617,8 @@ UNION ALL
SELECT * FROM currency_2
""",
"postgresql",
"""WITH currency AS (
"""
WITH currency AS (
SELECT
'INR' AS cur
), currency_2 AS (
@@ -1616,7 +1632,8 @@ SELECT * FROM currency_2
SELECT
*
FROM currency_2
)""",
)
""".strip(),
),
],
)
@@ -1625,3 +1642,608 @@ def test_as_cte(sql: str, engine: str, expected: str) -> None:
Test that we can covert select to CTE.
"""
assert SQLStatement(sql, engine).as_cte().format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS t
""".strip(),
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS t
WHERE
bar = 'baz'
""".strip(),
),
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM (
SELECT
*
FROM schema1.some_table
WHERE
id = 42
) AS t
""".strip(),
),
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema2"): "id = 42"},
"SELECT\n t.foo\nFROM schema1.some_table AS t",
),
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM (
SELECT
*
FROM catalog1.schema1.some_table
WHERE
id = 42
) AS t
""".strip(),
),
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog2"): "id = 42"},
"SELECT\n t.foo\nFROM catalog1.schema1.some_table AS t",
),
(
"SELECT * FROM some_table WHERE 1=1",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS some_table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM table
WHERE
id = 42
) AS table
WHERE
1 = 1
""".strip(),
),
(
'SELECT * FROM "table" WHERE 1=1',
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM "table"
WHERE
id = 42
) AS "table"
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM other_table WHERE 1=1",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM other_table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN (
SELECT
*
FROM other_table
WHERE
id = 42
) AS other_table
ON table.id = other_table.id
""".strip(),
),
(
'SELECT * FROM "table" JOIN other_table ON "table".id = other_table.id',
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM "table"
WHERE
id = 42
) AS "table"
JOIN other_table
ON "table".id = other_table.id
""".strip(),
),
(
"SELECT * FROM (SELECT * FROM some_table)",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS some_table
)
""".strip(),
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM table
WHERE
id = 42
) AS table
UNION ALL
SELECT
*
FROM other_table
""".strip(),
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
UNION ALL
SELECT
*
FROM (
SELECT
*
FROM other_table
WHERE
id = 42
) AS other_table
""".strip(),
),
(
"SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col",
{Table("tbl_a", "schema1", "catalog1"): "id = 42"},
"""
SELECT
a.*,
b.*
FROM (
SELECT
*
FROM tbl_a
WHERE
id = 42
) AS a
INNER JOIN tbl_b AS b
ON a.col = b.col
""".strip(),
),
(
"SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col",
{Table("tbl_a", "schema1", "catalog1"): "id = 42"},
"""
SELECT
a.*,
b.*
FROM (
SELECT
*
FROM tbl_a
WHERE
id = 42
) AS a
INNER JOIN tbl_b AS b
ON a.col = b.col
""".strip(),
),
],
)
def test_rls_subquery_transformer(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Test `RLSAsSubqueryTransformer`.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
RLSMethod.AS_SUBQUERY,
)
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42
""".strip(),
),
(
"SELECT t.foo FROM schema2.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM schema2.some_table AS t
""".strip(),
),
(
"SELECT t.foo FROM catalog2.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM catalog2.schema1.some_table AS t
""".strip(),
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42 AND (
bar = 'baz'
)
""".strip(),
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42 AND (
bar = 'baz' OR foo = 'qux'
)
""".strip(),
),
(
"SELECT * FROM some_table WHERE 1=1",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM some_table
WHERE
some_table.id = 42 AND (
1 = 1
)
""".strip(),
),
(
"SELECT * FROM some_table WHERE TRUE OR FALSE",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM some_table
WHERE
some_table.id = 42 AND (
TRUE OR FALSE
)
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1
)
""".strip(),
),
(
'SELECT * FROM "table" WHERE 1=1',
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM "table"
WHERE
"table".id = 42 AND (
1 = 1
)
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM other_table WHERE 1=1",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM other_table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM table",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42
""".strip(),
),
(
"SELECT * FROM some_table",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM some_table
WHERE
some_table.id = 42
""".strip(),
),
(
"SELECT * FROM table ORDER BY id",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42
ORDER BY
id
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1 AND table.id=42",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1 AND table.id = 42
)
""".strip(),
),
(
"""
SELECT * FROM table
JOIN other_table
ON table.id = other_table.id
AND other_table.id=42
""",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id AND other_table.id = 42
)
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1 AND id=42",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1 AND id = 42
)
""".strip(),
),
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id
)
""".strip(),
),
(
"""
SELECT *
FROM table
JOIN other_table
ON table.id = other_table.id
WHERE 1=1
""",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id
)
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM (SELECT * FROM other_table)",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM other_table
WHERE
other_table.id = 42
)
""".strip(),
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42
UNION ALL
SELECT
*
FROM other_table
""".strip(),
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
UNION ALL
SELECT
*
FROM other_table
WHERE
other_table.id = 42
""".strip(),
),
],
)
def test_rls_predicate_transformer(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Test `RLSPredicateTransformer`.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
RLSMethod.AS_PREDICATE,
)
assert statement.format() == expected