mirror of
https://github.com/apache/superset.git
synced 2026-05-01 05:54:26 +00:00
4693 lines
129 KiB
Python
4693 lines
129 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
|
|
|
|
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from pytest_mock import MockerFixture
|
|
from sqlglot import Dialects, exp, parse_one
|
|
|
|
from superset.exceptions import QueryClauseValidationException, SupersetParseError
|
|
from superset.jinja_context import JinjaTemplateProcessor
|
|
from superset.sql.parse import (
|
|
apply_cls,
|
|
CLS_ACTION_PRECEDENCE,
|
|
CLS_HASH_FUNCTIONS,
|
|
CLSAction,
|
|
CLSTransformer,
|
|
CTASMethod,
|
|
extract_tables_from_statement,
|
|
JinjaSQLResult,
|
|
KQLTokenType,
|
|
KustoKQLStatement,
|
|
LimitMethod,
|
|
merge_cls_rules,
|
|
process_jinja_sql,
|
|
remove_quotes,
|
|
RLSMethod,
|
|
sanitize_clause,
|
|
split_kql,
|
|
SQLGLOT_DIALECTS,
|
|
SQLScript,
|
|
SQLStatement,
|
|
Table,
|
|
tokenize_kql,
|
|
)
|
|
from tests.integration_tests.conftest import with_feature_flags
|
|
|
|
|
|
def test_table() -> None:
    """
    Test the `Table` class and its string conversion.

    Special characters in the table, schema, or catalog name should be escaped correctly.
    """  # noqa: E501
    # (constructor args, expected string form) pairs; the last case checks
    # that ".", "/", and "\n" are percent-escaped in each name part.
    cases: list[tuple[tuple[str, ...], str]] = [
        (("tbname",), "tbname"),
        (("tbname", "schemaname"), "schemaname.tbname"),
        (
            ("tbname", "schemaname", "catalogname"),
            "catalogname.schemaname.tbname",
        ),
        (
            ("table.name", "schema/name", "catalog\nname"),
            "catalog%0Aname.schema%2Fname.table%2Ename",
        ),
    ]
    for args, expected in cases:
        assert str(Table(*args)) == expected
|
|
|
|
|
|
def test_table_qualify() -> None:
    """
    Test the `Table.qualify` method.

    The qualify method should add schema and/or catalog if not already set,
    but should not override existing values.
    """
    # Each section below exercises one combination of pre-existing vs. newly
    # supplied schema/catalog values.

    # Table with no schema or catalog
    table = Table("tbname")

    # Add schema only
    qualified = table.qualify(schema="schemaname")
    assert qualified.table == "tbname"
    assert qualified.schema == "schemaname"
    assert qualified.catalog is None
    assert str(qualified) == "schemaname.tbname"

    # Add catalog only
    qualified = table.qualify(catalog="catalogname")
    assert qualified.table == "tbname"
    assert qualified.schema is None
    assert qualified.catalog == "catalogname"
    assert str(qualified) == "catalogname.tbname"

    # Add both schema and catalog
    qualified = table.qualify(schema="schemaname", catalog="catalogname")
    assert qualified.table == "tbname"
    assert qualified.schema == "schemaname"
    assert qualified.catalog == "catalogname"
    assert str(qualified) == "catalogname.schemaname.tbname"

    # Table with existing schema - should not override
    table_with_schema = Table("tbname", "existingschema")
    qualified = table_with_schema.qualify(schema="newschema")
    assert qualified.schema == "existingschema"
    assert str(qualified) == "existingschema.tbname"

    # Table with existing catalog - should not override
    table_with_catalog = Table("tbname", catalog="existingcatalog")
    qualified = table_with_catalog.qualify(catalog="newcatalog")
    assert qualified.catalog == "existingcatalog"
    assert str(qualified) == "existingcatalog.tbname"

    # Table with existing schema and catalog - should not override
    fully_qualified = Table("tbname", "existingschema", "existingcatalog")
    qualified = fully_qualified.qualify(schema="newschema", catalog="newcatalog")
    assert qualified.schema == "existingschema"
    assert qualified.catalog == "existingcatalog"
    assert str(qualified) == "existingcatalog.existingschema.tbname"

    # Table with schema but no catalog - should add catalog only
    table_with_schema_only = Table("tbname", "existingschema")
    qualified = table_with_schema_only.qualify(
        schema="newschema", catalog="catalogname"
    )
    assert qualified.schema == "existingschema"
    assert qualified.catalog == "catalogname"
    assert str(qualified) == "catalogname.existingschema.tbname"

    # Table with catalog but no schema - should add schema only
    table_with_catalog_only = Table("tbname", catalog="existingcatalog")
    qualified = table_with_catalog_only.qualify(
        schema="schemaname", catalog="newcatalog"
    )
    assert qualified.schema == "schemaname"
    assert qualified.catalog == "existingcatalog"
    assert str(qualified) == "existingcatalog.schemaname.tbname"

    # Calling qualify with no arguments should return equivalent table
    qualified = table.qualify()
    assert qualified.table == table.table
    assert qualified.schema == table.schema
    assert qualified.catalog == table.catalog
|
|
|
|
|
|
def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]:
    """
    Helper function to extract tables from SQL.

    Splits the script into statements and accumulates every table referenced
    by any of them into a single set.
    """
    dialect = SQLGLOT_DIALECTS.get(engine)
    tables: set[Table] = set()
    for statement in SQLScript(sql, engine).statements:
        tables.update(extract_tables_from_statement(statement._parsed, dialect))
    return tables
|
|
|
|
|
|
def test_extract_tables_from_sql() -> None:
    """
    Test that referenced tables are parsed correctly from the SQL.
    """
    # aliases (implicit and with AS) must not leak into the result
    assert extract_tables_from_sql("SELECT * FROM tbname") == {Table("tbname")}
    assert extract_tables_from_sql("SELECT * FROM tbname foo") == {Table("tbname")}
    assert extract_tables_from_sql("SELECT * FROM tbname AS foo") == {Table("tbname")}

    # underscore
    assert extract_tables_from_sql("SELECT * FROM tb_name") == {Table("tb_name")}

    # quotes
    assert extract_tables_from_sql('SELECT * FROM "tbname"') == {Table("tbname")}

    # unicode
    assert extract_tables_from_sql('SELECT * FROM "tb_name" WHERE city = "Lübeck"') == {
        Table("tb_name")
    }

    # columns
    assert extract_tables_from_sql("SELECT field1, field2 FROM tb_name") == {
        Table("tb_name")
    }
    assert extract_tables_from_sql("SELECT t1.f1, t2.f2 FROM t1, t2") == {
        Table("t1"),
        Table("t2"),
    }

    # named table
    assert extract_tables_from_sql(
        "SELECT a.date, a.field FROM left_table a LIMIT 10"
    ) == {Table("left_table")}

    # a subquery alias that shadows the inner table name must still expose
    # the inner table
    assert extract_tables_from_sql(
        "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
    ) == {Table("forbidden_table")}

    assert extract_tables_from_sql(
        "select * from (select * from forbidden_table) forbidden_table"
    ) == {Table("forbidden_table")}
|
|
|
|
|
|
def test_extract_tables_subselect() -> None:
    """
    Test that tables inside subselects are parsed correctly.
    """
    # subselect in the FROM clause plus a second table in the outer query
    assert extract_tables_from_sql(
        """
        SELECT sub.*
        FROM (
            SELECT *
            FROM s1.t1
            WHERE day_of_week = 'Friday'
        ) sub, s2.t2
        WHERE sub.resolution = 'NONE'
        """
    ) == {Table("t1", "s1"), Table("t2", "s2")}

    # subselect only
    assert extract_tables_from_sql(
        """
        SELECT sub.*
        FROM (
            SELECT *
            FROM s1.t1
            WHERE day_of_week = 'Friday'
        ) sub
        WHERE sub.resolution = 'NONE'
        """
    ) == {Table("t1", "s1")}

    # deeply nested subqueries (ANY / NOT EXISTS / scalar subquery)
    assert extract_tables_from_sql(
        """
        SELECT * FROM t1
        WHERE s11 > ANY (
            SELECT COUNT(*) /* no hint */ FROM t2
            WHERE NOT EXISTS (
                SELECT * FROM t3
                WHERE ROW(5*t2.s1,77)=(
                    SELECT 50,11*s1 FROM t4
                )
            )
        )
        """
    ) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
|
|
|
|
|
|
def test_extract_tables_select_in_expression() -> None:
    """
    Test that parser works with `SELECT`s used as expressions.
    """
    # scalar subquery in the projection list, with and without an alias
    assert extract_tables_from_sql("SELECT f1, (SELECT count(1) FROM t2) FROM t1") == {
        Table("t1"),
        Table("t2"),
    }
    assert extract_tables_from_sql(
        "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1"
    ) == {
        Table("t1"),
        Table("t2"),
    }
|
|
|
|
|
|
def test_extract_tables_parenthesis() -> None:
    """
    Test that parenthesis are parsed correctly.

    A parenthesized arithmetic expression must not be confused with a subquery.
    """
    tables = extract_tables_from_sql("SELECT f1, (x + y) AS f2 FROM t1")
    assert tables == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_with_schema() -> None:
    """
    Test that schemas are parsed correctly.
    """
    assert extract_tables_from_sql("SELECT * FROM schemaname.tbname") == {
        Table("tbname", "schemaname")
    }
    # quoted identifiers, with and without aliases
    assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname"') == {
        Table("tbname", "schemaname")
    }
    assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname" foo') == {
        Table("tbname", "schemaname")
    }
    assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname" AS foo') == {
        Table("tbname", "schemaname")
    }
|
|
|
|
|
|
def test_extract_tables_union() -> None:
    """
    Test that `UNION` queries work as expected.
    """
    assert extract_tables_from_sql("SELECT * FROM t1 UNION SELECT * FROM t2") == {
        Table("t1"),
        Table("t2"),
    }
    assert extract_tables_from_sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2") == {
        Table("t1"),
        Table("t2"),
    }
    # INTERSECT is handled the same way as UNION
    assert extract_tables_from_sql(
        "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
    ) == {
        Table("t1"),
        Table("t2"),
    }
|
|
|
|
|
|
def test_extract_tables_select_from_values() -> None:
    """
    Test that selecting from values returns no tables.
    """
    extracted = extract_tables_from_sql("SELECT * FROM VALUES (13, 42)")
    assert extracted == set()
|
|
|
|
|
|
def test_extract_tables_select_array() -> None:
    """
    Test that queries selecting arrays work as expected.
    """
    # the ARRAY literal must not be mistaken for a table reference
    assert extract_tables_from_sql(
        """
        SELECT ARRAY[1, 2, 3] AS my_array
        FROM t1 LIMIT 10
        """
    ) == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_select_if() -> None:
    """
    Test that queries with an `IF` work as expected.
    """
    # IF() in the projection must not confuse table extraction
    assert extract_tables_from_sql(
        """
        SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
        FROM t1 LIMIT 10
        """
    ) == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_with_catalog() -> None:
    """
    Test that catalogs are parsed correctly.
    """
    # three-part name: catalog.schema.table
    assert extract_tables_from_sql("SELECT * FROM catalogname.schemaname.tbname") == {
        Table("tbname", "schemaname", "catalogname")
    }
|
|
|
|
|
|
def test_extract_tables_illdefined() -> None:
    """
    Test that ill-defined table references raise a `SupersetParseError`.
    """
    # trailing dot after a schema/catalog reference
    with pytest.raises(SupersetParseError) as excinfo:
        extract_tables_from_sql("SELECT * FROM schemaname.")
    assert str(excinfo.value) == "Error parsing near '.' at line 1:25"

    with pytest.raises(SupersetParseError) as excinfo:
        extract_tables_from_sql("SELECT * FROM catalogname.schemaname.")
    assert str(excinfo.value) == "Error parsing near '.' at line 1:37"

    with pytest.raises(SupersetParseError) as excinfo:
        extract_tables_from_sql("SELECT * FROM catalogname..")
    assert str(excinfo.value) == "Error parsing near '.' at line 1:27"

    # an unterminated quoted identifier fails with a generic message
    with pytest.raises(SupersetParseError) as excinfo:
        extract_tables_from_sql('SELECT * FROM "tbname')
    assert str(excinfo.value) == "Unable to parse script"

    # odd edge case that works
    assert extract_tables_from_sql("SELECT * FROM catalogname..tbname") == {
        Table(table="tbname", schema=None, catalog="catalogname")
    }
|
|
|
|
|
|
def test_extract_tables_show_tables_from() -> None:
    """
    Test `SHOW TABLES FROM`.

    `s1` here is a schema, not a table, so no tables are extracted.
    """
    assert (
        extract_tables_from_sql("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
    )
|
|
|
|
|
|
def test_format_show_tables() -> None:
    """
    Test format when `ast.sql()` raises an exception.
    """
    # keywords are normalized to uppercase; the string literal is preserved
    assert (
        SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
        == "SHOW TABLES FROM s1 LIKE '%order%'"
    )
|
|
|
|
|
|
def test_format_no_dialect() -> None:
    """
    Test format with an engine that has no corresponding dialect.

    Formatting should still produce pretty-printed SQL for "dremio".
    """
    assert (
        SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format()
        == """
SELECT
 col
FROM t
WHERE
 NOT col IN (1, 2)
""".strip()
    )
|
|
|
|
|
|
def test_split_no_dialect() -> None:
    """
    Test the statement split when the engine has no corresponding dialect.
    """
    sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
    statements = SQLScript(sql, "dremio").statements
    # the script is still split on semicolons and each piece formats cleanly
    assert len(statements) == 3
    assert statements[0].format() == "SELECT\n col\nFROM t\nWHERE\n NOT col IN (1, 2)"
    assert statements[1].format() == "SELECT\n *\nFROM t"
    assert statements[2].format() == "SELECT\n foo"
|
|
|
|
|
|
def test_extract_tables_show_columns_from() -> None:
    """
    Test `SHOW COLUMNS FROM`.

    Unlike `SHOW TABLES FROM`, the operand here is a table.
    """
    assert extract_tables_from_sql("SHOW COLUMNS FROM t1") == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_where_subquery() -> None:
    """
    Test that tables in a `WHERE` subquery are parsed correctly.
    """
    # scalar comparison subquery
    assert extract_tables_from_sql(
        """
        SELECT name
        FROM t1
        WHERE regionkey = (SELECT max(regionkey) FROM t2)
        """
    ) == {Table("t1"), Table("t2")}

    # IN subquery
    assert extract_tables_from_sql(
        """
        SELECT name
        FROM t1
        WHERE regionkey IN (SELECT regionkey FROM t2)
        """
    ) == {Table("t1"), Table("t2")}

    # correlated EXISTS subquery
    assert extract_tables_from_sql(
        """
        SELECT name
        FROM t1
        WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
        """
    ) == {Table("t1"), Table("t2")}
|
|
|
|
|
|
def test_extract_tables_describe() -> None:
    """
    Test `DESCRIBE`.
    """
    # DESCRIBE references its operand as a table
    assert extract_tables_from_sql("DESCRIBE t1") == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_show_partitions() -> None:
    """
    Test `SHOW PARTITIONS`.
    """
    # the partitioned table is extracted even with WHERE/ORDER BY clauses
    assert extract_tables_from_sql(
        """
        SHOW PARTITIONS FROM orders
        WHERE ds >= '2013-01-01' ORDER BY ds DESC
        """
    ) == {Table("orders")}
|
|
|
|
|
|
def test_extract_tables_join() -> None:
    """
    Test joins.

    Covers a plain JOIN plus joins against subqueries with several join kinds.
    """
    assert extract_tables_from_sql(
        "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
    ) == {
        Table("t1"),
        Table("t2"),
    }

    # JOIN against a subquery
    assert extract_tables_from_sql(
        """
        SELECT a.date, b.name
        FROM left_table a
        JOIN (
            SELECT
                CAST((b.year) as VARCHAR) date,
                name
            FROM right_table
        ) b
        ON a.date = b.date
        """
    ) == {Table("left_table"), Table("right_table")}

    # LEFT INNER JOIN (non-standard wording, still parsed)
    assert extract_tables_from_sql(
        """
        SELECT a.date, b.name
        FROM left_table a
        LEFT INNER JOIN (
            SELECT
                CAST((b.year) as VARCHAR) date,
                name
            FROM right_table
        ) b
        ON a.date = b.date
        """
    ) == {Table("left_table"), Table("right_table")}

    # RIGHT OUTER JOIN
    assert extract_tables_from_sql(
        """
        SELECT a.date, b.name
        FROM left_table a
        RIGHT OUTER JOIN (
            SELECT
                CAST((b.year) as VARCHAR) date,
                name
            FROM right_table
        ) b
        ON a.date = b.date
        """
    ) == {Table("left_table"), Table("right_table")}

    # FULL OUTER JOIN
    assert extract_tables_from_sql(
        """
        SELECT a.date, b.name
        FROM left_table a
        FULL OUTER JOIN (
            SELECT
                CAST((b.year) as VARCHAR) date,
                name
            FROM right_table
        ) b
        ON a.date = b.date
        """
    ) == {Table("left_table"), Table("right_table")}
|
|
|
|
|
|
def test_extract_tables_semi_join() -> None:
    """
    Test `LEFT SEMI JOIN`.
    """
    assert extract_tables_from_sql(
        """
        SELECT a.date, b.name
        FROM left_table a
        LEFT SEMI JOIN (
            SELECT
                CAST((b.year) as VARCHAR) date,
                name
            FROM right_table
        ) b
        ON a.data = b.date
        """
    ) == {Table("left_table"), Table("right_table")}
|
|
|
|
|
|
def test_extract_tables_combinations() -> None:
    """
    Test a complex case with nested queries.
    """
    # UNION inside ANY, join inside a derived table, correlated NOT EXISTS
    assert extract_tables_from_sql(
        """
        SELECT * FROM t1
        WHERE s11 > ANY (
            SELECT * FROM t1 UNION ALL SELECT * FROM (
                SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a
            ) tmp_join
            WHERE NOT EXISTS (
                SELECT * FROM t3
                WHERE ROW(5*t3.s1,77)=(
                    SELECT 50,11*s1 FROM t4
                )
            )
        )
        """
    ) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}

    # triple-nested derived tables resolve to the innermost real table
    assert extract_tables_from_sql(
        """
        SELECT * FROM (
            SELECT * FROM (
                SELECT * FROM (
                    SELECT * FROM EmployeeS
                ) AS S1
            ) AS S2
        ) AS S3
        """
    ) == {Table("EmployeeS")}
|
|
|
|
|
|
def test_extract_tables_with() -> None:
    """
    Test `WITH`.
    """
    # all CTE source tables are reported, even from unused CTEs
    assert extract_tables_from_sql(
        """
        WITH
        x AS (SELECT a FROM t1),
        y AS (SELECT a AS b FROM t2),
        z AS (SELECT b AS c FROM t3)
        SELECT c FROM z
        """
    ) == {Table("t1"), Table("t2"), Table("t3")}

    # CTEs that reference other CTEs are not reported as tables
    assert extract_tables_from_sql(
        """
        WITH
        x AS (SELECT a FROM t1),
        y AS (SELECT a AS b FROM x),
        z AS (SELECT b AS c FROM y)
        SELECT c FROM z
        """
    ) == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_reusing_aliases() -> None:
    """
    Test that the parser follows aliases.
    """
    # q2 is a CTE, so only its underlying table `src` is reported
    assert extract_tables_from_sql(
        """
        with q1 as ( select key from q2 where key = '5'),
        q2 as ( select key from src where key = '5')
        select * from (select key from q1) a
        """
    ) == {Table("src")}

    # weird query with circular dependency
    assert (
        extract_tables_from_sql(
            """
            with src as ( select key from q2 where key = '5'),
            q2 as ( select key from src where key = '5')
            select * from (select key from src) a
            """
        )
        == set()
    )
|
|
|
|
|
|
def test_extract_tables_multistatement() -> None:
    """
    Test that the parser works with multiple statements.
    """
    assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2") == {
        Table("t1"),
        Table("t2"),
    }
    # trailing semicolon is accepted
    assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2;") == {
        Table("t1"),
        Table("t2"),
    }
    # non-SELECT statements (Hive ADD JAR) are tolerated in the script
    assert extract_tables_from_sql(
        "ADD JAR file:///hive.jar; SELECT * FROM t1;",
        engine="hive",
    ) == {Table("t1")}
|
|
|
|
|
|
def test_extract_tables_complex() -> None:
    """
    Test a few complex queries.
    """
    # joins plus an IN subquery inside a derived table
    assert extract_tables_from_sql(
        """
        SELECT sum(m_examples) AS "sum__m_example"
        FROM (
            SELECT
                COUNT(DISTINCT id_userid) AS m_examples,
                some_more_info
            FROM my_b_table b
            JOIN my_t_table t ON b.ds=t.ds
            JOIN my_l_table l ON b.uid=l.uid
            WHERE
                b.rid IN (
                    SELECT other_col
                    FROM inner_table
                )
                AND l.bla IN ('x', 'y')
            GROUP BY 2
            ORDER BY 2 ASC
        ) AS "meh"
        ORDER BY "sum__m_example" DESC
        LIMIT 10;
        """
    ) == {
        Table("my_l_table"),
        Table("my_b_table"),
        Table("my_t_table"),
        Table("inner_table"),
    }

    # comma-separated FROM list with aliases
    assert extract_tables_from_sql(
        """
        SELECT *
        FROM table_a AS a, table_b AS b, table_c as c
        WHERE a.id = b.id and b.id = c.id
        """
    ) == {Table("table_a"), Table("table_b"), Table("table_c")}

    # CTEs nested inside a derived table, combined with UNION ALL branches
    assert extract_tables_from_sql(
        """
        SELECT somecol AS somecol
        FROM (
            WITH bla AS (
                SELECT col_a
                FROM a
                WHERE
                    1=1
                    AND column_of_choice NOT IN (
                        SELECT interesting_col
                        FROM b
                    )
            ),
            rb AS (
                SELECT yet_another_column
                FROM (
                    SELECT a
                    FROM c
                    GROUP BY the_other_col
                ) not_table
                LEFT JOIN bla foo
                ON foo.prop = not_table.bad_col0
                WHERE 1=1
                GROUP BY
                    not_table.bad_col1 ,
                    not_table.bad_col2 ,
                ORDER BY not_table.bad_col_3 DESC ,
                    not_table.bad_col4 ,
                    not_table.bad_col5
            )
            SELECT random_col
            FROM d
            WHERE 1=1
            UNION ALL SELECT even_more_cols
            FROM e
            WHERE 1=1
            UNION ALL SELECT lets_go_deeper
            FROM f
            WHERE 1=1
            GROUP BY last_col
            LIMIT 50000
        )
        """
    ) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
|
|
|
|
|
|
def test_extract_tables_mixed_from_clause() -> None:
    """
    Test that the parser handles a `FROM` clause with table and subselect.
    """
    assert extract_tables_from_sql(
        """
        SELECT *
        FROM table_a AS a, (select * from table_b) AS b, table_c as c
        WHERE a.id = b.id and b.id = c.id
        """
    ) == {Table("table_a"), Table("table_b"), Table("table_c")}
|
|
|
|
|
|
def test_extract_tables_nested_select() -> None:
    """
    Test that the parser handles selects inside functions.

    These are classic SQL-injection shapes (extractvalue/concat); the table
    buried in the function arguments must still be detected.
    """
    assert extract_tables_from_sql(
        """
        select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
        from INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
        """,
        "mysql",
    ) == {Table("COLUMNS", "INFORMATION_SCHEMA")}

    assert extract_tables_from_sql(
        """
        select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
        from INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
        """,
        "mysql",
    ) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
|
|
|
|
|
def test_extract_tables_complex_cte_with_prefix() -> None:
    """
    Test that the parser handles CTEs with prefixes.

    The `CTE__test` name (with explicit column list) must not be reported as
    a table; only its source `SalesOrderHeader` is.
    """
    assert extract_tables_from_sql(
        """
        WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
        AS (
            SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
            FROM SalesOrderHeader
            WHERE SalesPersonID IS NOT NULL
        )
        SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
        FROM CTE__test
        GROUP BY SalesYear, SalesPersonID
        ORDER BY SalesPersonID, SalesYear;
        """
    ) == {Table("SalesOrderHeader")}
|
|
|
|
|
|
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
    """
    Test that aliases that are keywords are parsed correctly.

    `match` is a reserved word in some dialects but a valid CTE name here.
    """
    assert extract_tables_from_sql(
        """
        WITH
        f AS (SELECT * FROM foo),
        match AS (SELECT * FROM f)
        SELECT * FROM match
        """
    ) == {Table("foo")}
|
|
|
|
|
|
def test_sqlscript() -> None:
    """
    Test the `SQLScript` class.
    """
    # splitting and formatting of a multi-statement script
    script = SQLScript("SELECT 1; SELECT 2;", "sqlite")

    assert len(script.statements) == 2
    assert script.format() == "SELECT\n 1;\nSELECT\n 2"
    assert script.statements[0].format() == "SELECT\n 1"

    # later SET statements win over earlier ones for the same key
    script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
    assert script.get_settings() == {"a": "2"}

    # Kusto `set` commands are exposed as boolean settings
    query = SQLScript(
        """set querytrace;
Events | take 100""",
        "kustokql",
    )
    assert query.get_settings() == {"querytrace": True}
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        # leading/trailing whitespace and semicolon are stripped
        (
            "  SELECT foo FROM tbl ; ",
            "postgresql",
            ["SELECT\n foo\nFROM tbl"],
        ),
        # two statements split on the semicolon
        (
            "SELECT foo FROM tbl1; SELECT bar FROM tbl2;",
            "postgresql",
            ["SELECT\n foo\nFROM tbl1", "SELECT\n bar\nFROM tbl2"],
        ),
        # KQL scripts are split too
        (
            "let foo = 1; tbl | where bar == foo",
            "kustokql",
            ["let foo = 1", "tbl | where bar == foo"],
        ),
        # trailing comment is kept, converted to block-comment form
        (
            "SELECT 1; -- extraneous comment",
            "postgresql",
            ["SELECT\n 1 /* extraneous comment */"],
        ),
        (
            "SHOW TABLES FROM s1 like '%order%';",
            "mysql",
            ["SHOW TABLES FROM s1 LIKE '%order%'"],
        ),
        # unknown engines fall back to generic splitting/formatting
        (
            "SELECT 1; SELECT 2; SELECT 3;",
            "unknown-engine",
            [
                "SELECT\n 1",
                "SELECT\n 2",
                "SELECT\n 3",
            ],
        ),
    ],
)
def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None:
    """
    Test that `SQLScript` splits scripts into the expected statements.
    """
    script = SQLScript(sql, engine)
    assert [statement.format() for statement in script.statements] == expected
|
|
|
|
|
|
def test_sqlstatement() -> None:
    """
    Test the `SQLStatement` class.
    """
    statement = SQLStatement(
        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
        "sqlite",
    )

    assert (
        statement.format()
        == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
    )
    # str() delegates to format()
    assert str(statement) == statement.format()

    assert statement.tables == {
        Table(table="table1", schema=None, catalog=None),
        Table(table="table2", schema=None, catalog=None),
    }

    # predicates are parsed into sqlglot expression trees
    assert statement.parse_predicate("a > 1") == exp.GT(
        this=exp.Column(this=exp.Identifier(this="a", quoted=False)),
        expression=exp.Literal(this="1", is_string=False),
    )

    statement = SQLStatement("SET a=1", "sqlite")
    assert statement.get_settings() == {"a": "1"}

    # constructing with neither a statement nor an AST is an error
    with pytest.raises(
        ValueError,
        match="Either statement or ast must be provided",
    ):
        SQLStatement()
|
|
|
|
|
|
def test_kustokqlstatement() -> None:
    """
    Test the `KustoKQLStatement` class.
    """
    statement = KustoKQLStatement("foo | take 100", "kustokql")

    assert statement.format() == "foo | take 100"
    assert str(statement) == statement.format()

    # doesn't support table extraction
    assert statement.tables == set()

    # optimize is a no-op
    assert statement.optimize().format() == "foo | take 100"

    # predicate parsing is also no-op
    assert statement.parse_predicate("a > 1") == "a > 1"

    # only KQL engines are accepted
    with pytest.raises(SupersetParseError, match="Invalid engine: invalid-engine"):
        KustoKQLStatement("foo | take 100", "invalid-engine")

    # multiple statements must go through `split_script` instead
    with pytest.raises(
        SupersetParseError,
        match="KustoKQLStatement should have exactly one statement",
    ):
        KustoKQLStatement("foo | take 1; bar | take 2", "kustokql")
|
|
|
|
|
|
def test_kustokqlstatement_split_script() -> None:
    """
    Test the `KustoKQLStatement` split method.

    Three `let` statements plus the final query give four statements; the
    semicolons inside the final multi-line query must not cause extra splits.
    """
    statements = KustoKQLStatement.split_script(
        """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
    cachedResult
    | project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
    totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
        "kustokql",
    )
    assert len(statements) == 4
|
|
|
|
|
|
def test_kustokqlstatement_with_program() -> None:
    """
    Test the `KustoKQLStatement` split method when the KQL has a program.

    Semicolons inside a ```-delimited program block must not split the
    statement.
    """
    statements = KustoKQLStatement.split_script(
        """
print program = ```
  public class Program {
    public static void Main() {
      System.Console.WriteLine("Hello!");
    }
  }```
""",
        "kustokql",
    )
    assert len(statements) == 1
|
|
|
|
|
|
def test_kustokqlstatement_with_set() -> None:
    """
    Test the `KustoKQLStatement` split method when the KQL has a set command.
    """
    statements = KustoKQLStatement.split_script(
        """
set querytrace;
Events | take 100
""",
        "kustokql",
    )
    # the `set` command and the query are separate statements
    assert len(statements) == 2
    assert statements[0].format() == "set querytrace"
    assert statements[1].format() == "Events | take 100"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "kql,statements",
    [
        # semicolons inside strings or ```programs``` must not split
        ('print banner=strcat("Hello", ", ", "World!")', 1),
        (r"print 'O\'Malley\'s'", 1),
        (r"print 'O\'Mal;ley\'s'", 1),
        ("print ```foo;\nbar;\nbaz;```\n", 1),
    ],
)
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
    """
    Test `KustoKQLStatement.split_script` with quoting/escaping edge cases.
    """
    assert len(KustoKQLStatement.split_script(kql, "kustokql")) == statements
|
|
|
|
|
|
@pytest.mark.parametrize(
    "kql, expected",
    [
        # leading/trailing semicolons produce no empty statements
        (";Table | take 5", ["Table | take 5"]),
        (";Table | take 5;", ["Table | take 5"]),
        # NOTE: the expected pieces are exact substrings of the input,
        # including their leading newlines and internal whitespace.
        (
            """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
            [
                """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day""",
                """
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp)""",
                """
let cachedResult = materialize(materializedScope)""",
                """
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
            ],
        ),
    ],
)
def test_split_kql(kql: str, expected: list[str]) -> None:
    """
    Test the `split_kql` function.
    """
    assert split_kql(kql) == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("engine", "sql", "expected"),
    [
        # sqlite
        ("sqlite", "SELECT 1", False),
        ("sqlite", "INSERT INTO foo VALUES (1)", True),
        ("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
        ("sqlite", "DELETE FROM foo WHERE id = 1", True),
        ("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
        ("sqlite", "DROP TABLE foo", True),
        ("sqlite", "EXPLAIN SELECT * FROM foo", False),
        ("sqlite", "PRAGMA table_info(foo)", False),
        # postgresql
        ("postgresql", "SELECT 1", False),
        ("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
        ("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
        ("postgresql", "DELETE FROM foo WHERE id = 1", True),
        ("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
        ("postgresql", "DROP TABLE foo", True),
        ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
        # EXPLAIN ANALYZE actually executes the inner statement
        ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
        ("postgresql", "SHOW search_path", False),
        ("postgresql", "SET search_path TO public", False),
        (
            "postgres",
            """
            with source as (
                select 1 as one
            )
            select * from source
            """,
            False,
        ),
        # trino
        ("trino", "SELECT 1", False),
        ("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
        ("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
        ("trino", "DELETE FROM foo WHERE id = 1", True),
        ("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
        ("trino", "DROP TABLE foo", True),
        ("trino", "EXPLAIN SELECT * FROM foo", False),
        ("trino", "SHOW SCHEMAS", False),
        ("trino", "SET SESSION optimization_level = '3'", False),
        # kustokql: only dot-commands that modify data are mutations
        ("kustokql", "tbl | limit 100", False),
        ("kustokql", "let foo = 1; tbl | where bar == foo", False),
        ("kustokql", ".show tables", False),
        ("kustokql", "print 1", False),
        ("kustokql", "set querytrace; Events | take 100", False),
        ("kustokql", ".drop table foo", True),
        ("kustokql", ".set-or-append table foo <| bar", True),
        # base engine spec
        ("base", "SHOW LOCKS test EXTENDED", False),
        ("base", "SET hivevar:desc='Legislators'", False),
        ("base", "UPDATE t1 SET col1 = NULL", True),
        ("base", "EXPLAIN SELECT 1", False),
        ("base", "SELECT 1", False),
        ("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
        ("base", "SHOW CATALOGS", False),
        ("base", "SHOW TABLES", False),
        # hive
        ("hive", "UPDATE t1 SET col1 = NULL", True),
        ("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
        ("hive", "SHOW LOCKS test EXTENDED", False),
        ("hive", "SET hivevar:desc='Legislators'", False),
        ("hive", "EXPLAIN SELECT 1", False),
        ("hive", "SELECT 1", False),
        ("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
        # presto
        ("presto", "SET hivevar:desc='Legislators'", False),
        ("presto", "UPDATE t1 SET col1 = NULL", True),
        ("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
        ("presto", "SHOW LOCKS test EXTENDED", False),
        ("presto", "EXPLAIN SELECT 1", False),
        ("presto", "SELECT 1", False),
        ("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
    ],
)
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
    """
    Test the `has_mutation` method.
    """
    assert SQLScript(sql, engine).has_mutation() == expected
|
|
|
|
|
|
def test_get_settings() -> None:
    """
    Test `get_settings` in some edge cases.

    Inline comments interleaved with the ``SET`` clause must not prevent the
    setting from being extracted, and trailing statements are ignored.
    """
    script_sql = """
set
-- this is a tricky comment
search_path -- another one
= bar;
SELECT * FROM some_table;
"""
    script = SQLScript(script_sql, "postgresql")
    assert script.get_settings() == {"search_path": "bar"}
|
|
|
|
|
|
@pytest.mark.parametrize(
    "app",
    [{"SQLGLOT_DIALECTS_EXTENSIONS": {"custom": Dialects.MYSQL}}],
    indirect=True,
)
def test_custom_dialect(app: None) -> None:
    """
    Test that custom dialects are loaded correctly.

    The ``app`` fixture is parametrized indirectly so the application is
    configured with ``SQLGLOT_DIALECTS_EXTENSIONS`` mapping the "custom"
    engine name to the MySQL sqlglot dialect.
    """
    assert SQLGLOT_DIALECTS.get("custom") == Dialects.MYSQL
|
|
|
|
|
|
@pytest.mark.parametrize(
    "engine",
    [
        "ascend",
        "awsathena",
        "base",
        "bigquery",
        "clickhouse",
        "clickhousedb",
        "cockroachdb",
        "couchbase",
        "crate",
        "databend",
        "databricks",
        "db2",
        "denodo",
        "dremio",
        "drill",
        "druid",
        "duckdb",
        "dynamodb",
        "elasticsearch",
        "exa",
        "firebird",
        "firebolt",
        "gsheets",
        "hana",
        "hive",
        "ibmi",
        "impala",
        "kustokql",
        "kustosql",
        "kylin",
        "mariadb",
        "motherduck",
        "mssql",
        "mysql",
        "netezza",
        "oceanbase",
        "ocient",
        "odelasticsearch",
        "oracle",
        "pinot",
        "postgresql",
        "presto",
        "pydoris",
        "redshift",
        "risingwave",
        "shillelagh",
        "snowflake",
        "solr",
        "sqlite",
        "starrocks",
        "superset",
        "teradatasql",
        "trino",
        "vertica",
    ],
)
@pytest.mark.parametrize(
    "sql, expected",
    [
        ("SELECT 1", False),
        ("with source as ( select 1 as one ) select * from source", False),
        ("ALTER TABLE foo ADD COLUMN bar INT", True),
    ],
)
def test_is_mutating(sql: str, engine: str, expected: bool) -> None:
    """
    Global tests for `is_mutating`, covering all supported engines.

    The cross-product of the two parametrizations runs each query against
    every engine, so every dialect must classify the same three statements
    identically (read-only SELECTs vs. a mutating ALTER).
    """
    assert SQLStatement(sql, engine).is_mutating() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, expected",
    [
        # block that clearly mutates (INSERT inside the body)
        (
            """
            DO $$
            BEGIN
                INSERT INTO public.users (name, real_name)
                VALUES ('SQLLab bypass DML', 'SQLLab bypass DML');
            END;
            $$;
            """,
            True,
        ),
        # block that only reads, but is still treated as mutating (see docstring)
        (
            """
            DO $$
            BEGIN
                IF (SELECT COUNT(*) FROM orders WHERE status = 'pending') > 100 THEN
                    RAISE NOTICE 'High pending order volume detected';
                END IF;
            END;
            $$;
            """,
            True,
        ),
    ],
)
def test_is_mutating_anonymous_block(sql: str, expected: bool) -> None:
    """
    Test for `is_mutating` with a Postgres anonymous block.

    Since we can't parse the PL/pgSQL inside the block we always assume it is mutating.
    """
    assert SQLStatement(sql, "postgresql").is_mutating() == expected
|
|
|
|
|
|
def test_optimize() -> None:
    """
    Test that the `optimize` method works as expected.

    The SQL optimization only works with engines that have a corresponding dialect.
    """
    sql = """
    SELECT anon_1.a, anon_1.b
    FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c
    FROM some_table) AS anon_1
    WHERE anon_1.a > 1 AND anon_1.b = 2
    """

    # predicate push-down: the outer filter is moved into the subquery
    optimized = """
SELECT
  anon_1.a,
  anon_1.b
FROM (
  SELECT
    some_table.a AS a,
    some_table.b AS b,
    some_table.c AS c
  FROM some_table
  WHERE
    some_table.a > 1 AND some_table.b = 2
) AS anon_1
WHERE
  TRUE AND TRUE
""".strip()

    # without an optimizer-capable dialect the query is only pretty-printed
    not_optimized = """
SELECT
  anon_1.a,
  anon_1.b
FROM (
  SELECT
    some_table.a AS a,
    some_table.b AS b,
    some_table.c AS c
  FROM some_table
) AS anon_1
WHERE
  anon_1.a > 1 AND anon_1.b = 2
""".strip()

    assert SQLStatement(sql, "sqlite").optimize().format() == optimized
    # NOTE(review): "crate" presumably has no sqlglot dialect here, which is
    # why the statement is reformatted but not optimized — confirm against
    # SQLGLOT_DIALECTS.
    assert SQLStatement(sql, "crate").optimize().format() == not_optimized

    # also works for scripts
    assert SQLScript(sql, "sqlite").optimize().format() == optimized
|
|
|
|
|
|
def test_firebolt() -> None:
    """
    Test that Firebolt 3rd party dialect is registered correctly.

    We need a custom dialect for Firebolt because it parses `NOT col IN (1, 2)` as
    `(NOT col) IN (1, 2)` instead of `NOT (col IN (1, 2))`, which will fail when `col`
    is not a boolean.

    Note that `NOT col = 1` works as expected in Firebolt, parsing as `NOT (col = 1)`.
    """
    # `NOT ... IN` must bind the negation around the whole membership test
    sql = "SELECT col NOT IN (1, 2) FROM tbl"
    assert (
        SQLStatement(sql, "firebolt").format()
        == """
SELECT
  NOT (
    col IN (1, 2)
  )
FROM tbl
""".strip()
    )

    # plain `NOT col = 1` keeps its usual precedence
    sql = "SELECT NOT col = 1 FROM tbl"
    assert (
        SQLStatement(sql, "firebolt").format()
        == """
SELECT
  NOT col = 1
FROM tbl
""".strip()
    )
|
|
|
|
|
|
def test_firebolt_old() -> None:
    """
    Test the dialect for the old Firebolt syntax.

    The old dialect is registered into the module-global ``SQLGLOT_DIALECTS``
    mapping for the duration of the test only; the previous registration is
    restored afterwards so the override cannot leak into other tests.
    """
    from superset.sql.dialects import FireboltOld
    from superset.sql.parse import SQLGLOT_DIALECTS

    original_dialect = SQLGLOT_DIALECTS.get("firebolt")
    SQLGLOT_DIALECTS["firebolt"] = FireboltOld
    try:
        # old Firebolt supports `UNNEST(... AS ...)` directly in the FROM clause
        sql = "SELECT * FROM t1 UNNEST(col1 AS foo)"
        assert (
            SQLStatement(sql, "firebolt").format()
            == """
SELECT
  *
FROM t1 UNNEST(col1 AS foo)
""".strip()
        )
    finally:
        # undo the global registry mutation
        if original_dialect is None:
            del SQLGLOT_DIALECTS["firebolt"]
        else:
            SQLGLOT_DIALECTS["firebolt"] = original_dialect
|
|
|
|
|
|
def test_firebolt_old_escape_string() -> None:
    """
    Test string-escape handling in the old Firebolt dialect.

    The old dialect is registered into the module-global ``SQLGLOT_DIALECTS``
    mapping for the duration of the test only; the previous registration is
    restored afterwards so the override cannot leak into other tests.
    """
    from superset.sql.dialects import FireboltOld
    from superset.sql.parse import SQLGLOT_DIALECTS

    original_dialect = SQLGLOT_DIALECTS.get("firebolt")
    SQLGLOT_DIALECTS["firebolt"] = FireboltOld
    try:
        # both '' and \' are valid escape sequences
        sql = r"SELECT 'foo''bar', 'foo\'bar'"

        # but they normalize to ''
        assert (
            SQLStatement(sql, "firebolt").format()
            == """
SELECT
  'foo''bar',
  'foo''bar'
""".strip()
        )
    finally:
        # undo the global registry mutation
        if original_dialect is None:
            del SQLGLOT_DIALECTS["firebolt"]
        else:
            SQLGLOT_DIALECTS["firebolt"] = original_dialect
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM users LIMIT 10", "postgresql", 10),
        # only the outermost LIMIT counts, not the one inside the CTE
        (
            """
            WITH cte_example AS (
                SELECT * FROM my_table
                LIMIT 100
            )
            SELECT * FROM cte_example
            LIMIT 10;
            """,
            "postgresql",
            10,
        ),
        ("SELECT * FROM users ORDER BY id DESC LIMIT 25", "postgresql", 25),
        ("SELECT * FROM users", "postgresql", None),
        # Teradata uses TOP instead of LIMIT
        ("SELECT TOP 5 name FROM employees", "teradatasql", 5),
        ("SELECT TOP (42) * FROM table_name", "teradatasql", 42),
        ("select * from table", "postgresql", None),
        ("select * from mytable limit 10", "postgresql", 10),
        # subquery limits are ignored; only the outer query's limit is returned
        (
            "select * from (select * from my_subquery limit 10) where col=1 limit 20",
            "postgresql",
            20,
        ),
        ("select * from (select * from my_subquery limit 10);", "postgresql", None),
        (
            "select * from (select * from my_subquery limit 20) where col=1 limit 20;",
            "postgresql",
            20,
        ),
        # `LIMIT offset, count` and `LIMIT count OFFSET offset` forms
        ("select * from mytable limit 20, 10", "postgresql", 10),
        ("select * from mytable limit 10 offset 20", "postgresql", 10),
        (
            """
            SELECT id, value, i
            FROM (SELECT * FROM my_table LIMIT 10),
            LATERAL generate_series(1, value) AS i;
            """,
            "postgresql",
            None,
        ),
        # not really valid SQL, but let's roll with it
        ("SELECT * FROM my_table LIMIT invalid", "postgresql", None),
    ],
)
def test_get_limit_value(sql: str, engine: str, expected: int | None) -> None:
    """
    Test the `get_limit_value` method.

    Returns the row limit of the outermost query, or ``None`` when the query
    has no (parseable) limit.
    """
    assert SQLStatement(sql, engine).get_limit_value() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "kql, expected",
    [
        ("StormEvents | take 10", 10),
        ("StormEvents | limit 20", 20),
        ("StormEvents | where State == 'FL' | summarize count()", None),
        # "limit 10" inside a string literal is not a limit operator
        ("StormEvents | where name has 'limit 10'", None),
        ("AnotherTable | take 5", 5),
        ("datatable(x:int) [1, 2, 3] | take 100", 100),
        (
            """
            Table1 | where msg contains 'abc;xyz'
            | limit 5
            """,
            5,
        ),
        # non-numeric argument yields no limit
        ("table | take five", None),
    ],
)
def test_get_kql_limit_value(kql: str, expected: int | None) -> None:
    """
    Test the `get_limit_value` method for Kusto KQL statements.

    Both the `take` and `limit` operators should be recognized; values inside
    string literals must be ignored.
    """
    assert KustoKQLStatement(kql, "kustokql").get_limit_value() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, limit, method, expected",
    [
        # FORCE_LIMIT adds a limit when none exists...
        (
            "SELECT * FROM t",
            "postgresql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n  *\nFROM t\nLIMIT 10",
        ),
        # ...and replaces an existing one
        (
            "SELECT * FROM t LIMIT 1000",
            "postgresql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n  *\nFROM t\nLIMIT 10",
        ),
        # MSSQL and Teradata use TOP
        (
            "SELECT * FROM t",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n  *\nFROM t",
        ),
        (
            "SELECT * FROM t",
            "teradatasql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n  *\nFROM t",
        ),
        # Oracle uses FETCH FIRST ... ROWS ONLY
        (
            "SELECT * FROM t",
            "oracle",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n  *\nFROM t\nFETCH FIRST 10 ROWS ONLY",
        ),
        # WRAP_SQL wraps the original query in an outer limited SELECT
        (
            "SELECT * FROM t",
            "db2",
            10,
            LimitMethod.WRAP_SQL,
            "SELECT\n  *\nFROM (\n  SELECT\n    *\n  FROM t\n)\nLIMIT 10",
        ),
        # Teradata's `SEL` shorthand is normalized to SELECT
        (
            "SEL TOP 1000 * FROM My_table",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n  *\nFROM My_table",
        ),
        (
            "SEL TOP 1000 * FROM My_table;",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n  *\nFROM My_table",
        ),
        (
            "SEL TOP 1000 * FROM My_table;",
            "teradatasql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 1000\n  *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n  *\nFROM My_table",
        ),
        # the new limit replaces TOP even when larger than the original
        (
            "SELECT TOP 1000 * FROM My_table;",
            "teradatasql",
            10000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10000\n  *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n  *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n  *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            10000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10000\n  *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 1000\n  *\nFROM My_table",
        ),
        # the limit is applied to the outer query, not inside the CTE
        (
            """
            with abc as (select * from test union select * from test1)
            select TOP 100 * from currency
            """,
            "mssql",
            1000,
            LimitMethod.FORCE_LIMIT,
            """
WITH abc AS (
  SELECT
    *
  FROM test
  UNION
  SELECT
    *
  FROM test1
)
SELECT
TOP 1000
  *
FROM currency
""".strip(),
        ),
        # DISTINCT is preserved when TOP is injected
        (
            "SELECT DISTINCT x from tbl",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT DISTINCT\nTOP 100\n  x\nFROM tbl",
        ),
        (
            "SELECT 1 as cnt",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n  1 AS cnt",
        ),
        (
            "select TOP 1000 * from abc where id=1",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n  *\nFROM abc\nWHERE\n  id = 1",
        ),
        # trailing `--` comments are normalized to /* */ and kept
        (
            "SELECT * FROM birth_names -- SOME COMMENT",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n  *\nFROM birth_names /* SOME COMMENT */\nLIMIT 1000",
        ),
        # "LIMIT" inside a comment is not a real limit
        (
            "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            """
SELECT
  *
FROM birth_names /* SOME COMMENT WITH LIMIT 555 */
LIMIT 1000
""".strip(),
        ),
        (
            "SELECT * FROM birth_names LIMIT 555",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n  *\nFROM birth_names\nLIMIT 1000",
        ),
        # FETCH_MANY leaves the SQL untouched (the limit is applied client-side)
        (
            "SELECT * FROM birth_names LIMIT 555",
            "postgresql",
            1000,
            LimitMethod.FETCH_MANY,
            "SELECT\n  *\nFROM birth_names\nLIMIT 555",
        ),
    ],
)
def test_set_limit_value(
    sql: str,
    engine: str,
    limit: int,
    method: LimitMethod,
    expected: str,
) -> None:
    """
    Test the `set_limit_value` method across engines and limit methods.
    """
    statement = SQLStatement(sql, engine)
    statement.set_limit_value(limit, method)
    assert statement.format() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "kql, limit, expected",
    [
        # existing take/limit operators are updated in place
        ("StormEvents | take 10", 100, "StormEvents | take 100"),
        ("StormEvents | limit 20", 10, "StormEvents | limit 10"),
        # when no limit exists a `take` is appended
        (
            "StormEvents | where State == 'FL' | summarize count()",
            10,
            "StormEvents | where State == 'FL' | summarize count() | take 10",
        ),
        # "limit 10" inside a string literal must not be touched
        (
            "StormEvents | where name has 'limit 10'",
            10,
            "StormEvents | where name has 'limit 10' | take 10",
        ),
        ("AnotherTable | take 5", 50, "AnotherTable | take 50"),
        (
            "datatable(x:int) [1, 2, 3] | take 100",
            10,
            "datatable(x:int) [1, 2, 3] | take 10",
        ),
        # internal whitespace of the statement is preserved when updating
        (
            """
            Table1 | where msg contains 'abc;xyz'
            | limit 5
            """,
            10,
            """Table1 | where msg contains 'abc;xyz'
            | limit 10""",
        ),
    ],
)
def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None:
    """
    Test the `set_limit_value` method for KustoKQLStatement.
    """
    statement = KustoKQLStatement(kql, "kustokql")
    statement.set_limit_value(limit)
    assert statement.format() == expected
|
|
|
|
|
|
@pytest.mark.parametrize("method", [LimitMethod.WRAP_SQL, LimitMethod.FETCH_MANY])
def test_set_kql_limit_value_invalid_method(method: LimitMethod) -> None:
    """
    Test that setting a limit value with an invalid method raises an error.

    KQL statements only support `LimitMethod.FORCE_LIMIT`; any other method
    should raise a `SupersetParseError`.
    """
    with pytest.raises(
        SupersetParseError,
        match="Kusto KQL only supports the FORCE_LIMIT method.",
    ):
        KustoKQLStatement("foo", "kustokql").set_limit_value(10, method)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT 1", "postgresql", False),
        ("SELECT 1 AS cnt", "postgresql", False),
        # a UNION of plain SELECTs is not a CTE
        (
            """
            SELECT 'INR' AS cur
            UNION
            SELECT 'USD' AS cur
            UNION
            SELECT 'EUR' AS cur
            """,
            "postgresql",
            False,
        ),
        ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True),
        # multiple independent CTEs
        (
            """
            WITH
            x AS (SELECT a FROM t1),
            y AS (SELECT a AS b FROM t2),
            z AS (SELECT b AS c FROM t3)
            SELECT c FROM z
            """,
            "postgresql",
            True,
        ),
        # chained CTEs referencing each other
        (
            """
            WITH
            x AS (SELECT a FROM t1),
            y AS (SELECT a AS b FROM x),
            z AS (SELECT b AS c FROM y)
            SELECT c FROM z
            """,
            "postgresql",
            True,
        ),
        # CTE with an explicit column list
        (
            """
            WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
            AS (
                SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
                FROM SalesOrderHeader
                WHERE SalesPersonID IS NOT NULL
            )
            SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
            FROM CTE__test
            GROUP BY SalesYear, SalesPersonID
            ORDER BY SalesPersonID, SalesYear;
            """,
            "postgresql",
            True,
        ),
    ],
)
def test_has_cte(sql: str, engine: str, expected: bool) -> None:
    """
    Test that the parser detects CTEs correctly.
    """
    assert SQLStatement(sql, engine).has_cte() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        # a plain SELECT is wrapped in a `__cte` CTE
        (
            "SELECT 1",
            "postgresql",
            "WITH __cte AS (\n  SELECT\n    1\n)",
        ),
        # existing CTEs are preserved and `__cte` is appended after them
        (
            """
            WITH currency AS (SELECT 'INR' AS cur),
            currency_2 AS (SELECT 'USD' AS cur)
            SELECT * FROM currency
            UNION ALL
            SELECT * FROM currency_2
            """,
            "postgresql",
            """
WITH currency AS (
  SELECT
    'INR' AS cur
), currency_2 AS (
  SELECT
    'USD' AS cur
), __cte AS (
  SELECT
    *
  FROM currency
  UNION ALL
  SELECT
    *
  FROM currency_2
)
""".strip(),
        ),
    ],
)
def test_as_cte(sql: str, engine: str, expected: str) -> None:
    """
    Test that we can convert a SELECT to a CTE.
    """
    assert SQLStatement(sql, engine).as_cte().format() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # matching table is replaced by a filtered subquery, keeping the alias
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM (
  SELECT
    *
  FROM some_table
  WHERE
    id = 42
) AS t
""".strip(),
        ),
        # no rules: query is only reformatted
        (
            "SELECT t.foo FROM some_table AS t",
            {},
            """
SELECT
  t.foo
FROM some_table AS t
""".strip(),
        ),
        # existing WHERE clause is kept on the outer query
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM (
  SELECT
    *
  FROM some_table
  WHERE
    id = 42
) AS t
WHERE
  bar = 'baz'
""".strip(),
        ),
        # schema-qualified reference matches the rule's schema
        (
            "SELECT t.foo FROM schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM (
  SELECT
    *
  FROM schema1.some_table
  WHERE
    id = 42
) AS t
""".strip(),
        ),
        # schema mismatch: rule is not applied
        (
            "SELECT t.foo FROM schema1.some_table AS t",
            {Table("some_table", "schema2"): "id = 42"},
            "SELECT\n  t.foo\nFROM schema1.some_table AS t",
        ),
        # catalog-qualified reference matches the rule's catalog
        (
            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM (
  SELECT
    *
  FROM catalog1.schema1.some_table
  WHERE
    id = 42
) AS t
""".strip(),
        ),
        # catalog mismatch: rule is not applied
        (
            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog2"): "id = 42"},
            "SELECT\n  t.foo\nFROM catalog1.schema1.some_table AS t",
        ),
        # unaliased tables get a quoted alias with the original name
        (
            "SELECT * FROM some_table WHERE 1=1",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM (
  SELECT
    *
  FROM some_table
  WHERE
    id = 42
) AS "some_table"
WHERE
  1 = 1
""".strip(),
        ),
        # works even when the table name is a keyword
        (
            "SELECT * FROM table WHERE 1=1",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM (
  SELECT
    *
  FROM table
  WHERE
    id = 42
) AS "table"
WHERE
  1 = 1
""".strip(),
        ),
        # quoted identifiers are matched and preserved
        (
            'SELECT * FROM "table" WHERE 1=1',
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM (
  SELECT
    *
  FROM "table"
  WHERE
    id = 42
) AS "table"
WHERE
  1 = 1
""".strip(),
        ),
        # rule for a different table: no change
        (
            "SELECT * FROM table WHERE 1=1",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
WHERE
  1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM other_table
WHERE
  1 = 1
""".strip(),
        ),
        # rule applied to the joined table only
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
JOIN (
  SELECT
    *
  FROM other_table
  WHERE
    id = 42
) AS "other_table"
  ON table.id = other_table.id
""".strip(),
        ),
        (
            'SELECT * FROM "table" JOIN other_table ON "table".id = other_table.id',
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM (
  SELECT
    *
  FROM "table"
  WHERE
    id = 42
) AS "table"
JOIN other_table
  ON "table".id = other_table.id
""".strip(),
        ),
        # rule applied inside a nested subquery
        (
            "SELECT * FROM (SELECT * FROM some_table)",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM (
  SELECT
    *
  FROM (
    SELECT
      *
    FROM some_table
    WHERE
      id = 42
  ) AS "some_table"
)
""".strip(),
        ),
        # rule applied only to the matching side of a UNION ALL
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM (
  SELECT
    *
  FROM table
  WHERE
    id = 42
) AS "table"
UNION ALL
SELECT
  *
FROM other_table
""".strip(),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
UNION ALL
SELECT
  *
FROM (
  SELECT
    *
  FROM other_table
  WHERE
    id = 42
) AS "other_table"
""".strip(),
        ),
        # explicit and implicit aliases are preserved on the subquery
        (
            "SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col",
            {Table("tbl_a", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  a.*,
  b.*
FROM (
  SELECT
    *
  FROM tbl_a
  WHERE
    id = 42
) AS a
INNER JOIN tbl_b AS b
  ON a.col = b.col
""".strip(),
        ),
        (
            "SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col",
            {Table("tbl_a", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  a.*,
  b.*
FROM (
  SELECT
    *
  FROM tbl_a
  WHERE
    id = 42
) AS a
INNER JOIN tbl_b AS b
  ON a.col = b.col
""".strip(),
        ),
        # schema-qualified unaliased table: alias uses the full name
        (
            "SELECT * FROM public.flights LIMIT 100",
            {Table("flights", "public", "catalog1"): "\"AIRLINE\" like 'A%'"},
            """
SELECT
  *
FROM (
  SELECT
    *
  FROM public.flights
  WHERE
    "AIRLINE" LIKE 'A%'
) AS "public.flights"
LIMIT 100
""".strip(),
        ),
    ],
)
def test_rls_subquery_transformer(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Test `RLSAsSubqueryTransformer`.

    Each matching table reference is replaced by a subquery that applies the
    RLS predicate, aliased back to the original name (or alias).
    """
    statement = SQLStatement(sql)
    statement.apply_rls(
        "catalog1",
        "schema1",
        {k: [parse_one(v)] for k, v in rules.items()},
        RLSMethod.AS_SUBQUERY,
    )
    assert statement.format() == expected
|
|
|
|
|
|
def test_rls_invalid_method(mocker: MockerFixture) -> None:
    """
    Test that an invalid RLS method raises an error.

    `apply_rls` only accepts members of `RLSMethod`; anything else should
    fail with a `ValueError` naming the bad method.
    """
    predicates = mocker.MagicMock()
    statement = SQLStatement("SELECT 1", "postgresql")

    with pytest.raises(ValueError, match="Invalid RLS method: invalid"):
        statement.apply_rls("catalog1", "schema1", predicates, "invalid")  # type: ignore
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # the predicate is added to the WHERE clause, qualified by the alias
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM some_table AS t
WHERE
  t.id = 42
""".strip(),
        ),
        # schema mismatch: no change
        (
            "SELECT t.foo FROM schema2.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM schema2.some_table AS t
""".strip(),
        ),
        # catalog mismatch: no change
        (
            "SELECT t.foo FROM catalog2.schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM catalog2.schema1.some_table AS t
""".strip(),
        ),
        # the original WHERE is parenthesized and ANDed with the predicate
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM some_table AS t
WHERE
  t.id = 42 AND (
    bar = 'baz'
  )
""".strip(),
        ),
        # parenthesization preserves the original OR semantics
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  t.foo
FROM some_table AS t
WHERE
  t.id = 42 AND (
    bar = 'baz' OR foo = 'qux'
  )
""".strip(),
        ),
        # unaliased tables are qualified by the table name
        (
            "SELECT * FROM some_table WHERE 1=1",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM some_table
WHERE
  some_table.id = 42 AND (
    1 = 1
  )
""".strip(),
        ),
        (
            "SELECT * FROM some_table WHERE TRUE OR FALSE",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM some_table
WHERE
  some_table.id = 42 AND (
    TRUE OR FALSE
  )
""".strip(),
        ),
        (
            "SELECT * FROM table WHERE 1=1",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
WHERE
  table.id = 42 AND (
    1 = 1
  )
""".strip(),
        ),
        # quoted identifiers are preserved in the qualification
        (
            'SELECT * FROM "table" WHERE 1=1',
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM "table"
WHERE
  "table".id = 42 AND (
    1 = 1
  )
""".strip(),
        ),
        # rule for a different table: no change
        (
            "SELECT * FROM table WHERE 1=1",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
WHERE
  1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM other_table
WHERE
  1 = 1
""".strip(),
        ),
        # a WHERE clause is created when none exists
        (
            "SELECT * FROM table",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
WHERE
  table.id = 42
""".strip(),
        ),
        (
            "SELECT * FROM some_table",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM some_table
WHERE
  some_table.id = 42
""".strip(),
        ),
        # ORDER BY stays after the injected WHERE
        (
            "SELECT * FROM table ORDER BY id",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
WHERE
  table.id = 42
ORDER BY
  id
""".strip(),
        ),
        # predicate is added even if the original WHERE already contains it
        (
            "SELECT * FROM table WHERE 1=1 AND table.id=42",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
WHERE
  table.id = 42 AND (
    1 = 1 AND table.id = 42
  )
""".strip(),
        ),
        # for joined tables the predicate goes into the ON clause
        (
            """
            SELECT * FROM table
            JOIN other_table
            ON table.id = other_table.id
            AND other_table.id=42
            """,
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
JOIN other_table
  ON other_table.id = 42 AND (
    table.id = other_table.id AND other_table.id = 42
  )
""".strip(),
        ),
        (
            "SELECT * FROM table WHERE 1=1 AND id=42",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
WHERE
  table.id = 42 AND (
    1 = 1 AND id = 42
  )
""".strip(),
        ),
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
JOIN other_table
  ON other_table.id = 42 AND (
    table.id = other_table.id
  )
""".strip(),
        ),
        # a JOIN without an ON clause gets one with just the predicate
        (
            "SELECT * FROM table JOIN other_table",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
JOIN other_table
  ON other_table.id = 42
""".strip(),
        ),
        (
            """
            SELECT *
            FROM table
            JOIN other_table
            ON table.id = other_table.id
            WHERE 1=1
            """,
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
JOIN other_table
  ON other_table.id = 42 AND (
    table.id = other_table.id
  )
WHERE
  1 = 1
""".strip(),
        ),
        # predicate applied inside a subquery
        (
            "SELECT * FROM (SELECT * FROM other_table)",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM (
  SELECT
    *
  FROM other_table
  WHERE
    other_table.id = 42
)
""".strip(),
        ),
        # predicate applied only to the matching side of a UNION ALL
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
WHERE
  table.id = 42
UNION ALL
SELECT
  *
FROM other_table
""".strip(),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
  *
FROM table
UNION ALL
SELECT
  *
FROM other_table
WHERE
  other_table.id = 42
""".strip(),
        ),
        # non-SELECT statements are left untouched (only reformatted)
        (
            "INSERT INTO some_table (col1, col2) VALUES (1, 2)",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
INSERT INTO some_table (
  col1,
  col2
)
VALUES
  (1, 2)
""".strip(),
        ),
    ],
)
def test_rls_predicate_transformer(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Test `RLSPredicateTransformer`.

    The RLS predicate is injected directly into WHERE/ON clauses, qualified
    by the table's alias (or name), instead of wrapping the table in a
    subquery.
    """
    statement = SQLStatement(sql)
    statement.apply_rls(
        "catalog1",
        "schema1",
        {k: [parse_one(v)] for k, v in rules.items()},
        RLSMethod.AS_PREDICATE,
    )
    assert statement.format() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, table, expected",
    [
        # bare table name
        (
            "SELECT * FROM some_table",
            Table("some_table"),
            """
CREATE TABLE some_table AS
SELECT
  *
FROM some_table
""".strip(),
        ),
        # fully qualified target table (catalog.schema.table)
        (
            "SELECT * FROM some_table",
            Table("some_table", "schema1", "catalog1"),
            """
CREATE TABLE catalog1.schema1.some_table AS
SELECT
  *
FROM some_table
""".strip(),
        ),
    ],
)
def test_as_create_table(sql: str, table: Table, expected: str) -> None:
    """
    Test the `as_create_table` method.
    """
    statement = SQLStatement(sql)
    create_table = statement.as_create_table(table, CTASMethod.TABLE)
    assert create_table.format() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM table", "postgresql", True),
        # comments around the SELECT are fine
        (
            """
            -- comment
            SELECT * FROM table
            -- comment 2
            """,
            "mysql",
            True,
        ),
        # a script is a valid CTAS as long as the *last* statement is a SELECT
        (
            """
            -- comment
            SET @value = 42;
            SELECT @value as foo;
            -- comment 2
            """,
            "mysql",
            True,
        ),
        # EXPLAIN is not a SELECT
        (
            """
            -- comment
            EXPLAIN SELECT * FROM table
            -- comment 2
            """,
            "mysql",
            False,
        ),
        # last statement is an INSERT, so not a valid CTAS
        (
            """
            SELECT * FROM table;
            INSERT INTO TABLE (foo) VALUES (42);
            """,
            "mysql",
            False,
        ),
    ],
)
def test_is_valid_ctas(sql: str, engine: str, expected: bool) -> None:
    """
    Test the `is_valid_ctas` method.
    """
    assert SQLScript(sql, engine).is_valid_ctas() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM table", "postgresql", True),
        # comments around the SELECT are fine
        (
            """
            -- comment
            SELECT * FROM table
            -- comment 2
            """,
            "mysql",
            True,
        ),
        # unlike CTAS, CVAS requires a *single* SELECT statement
        (
            """
            -- comment
            SET @value = 42;
            SELECT @value as foo;
            -- comment 2
            """,
            "mysql",
            False,
        ),
        (
            """
            -- comment
            SELECT value as foo;
            -- comment 2
            """,
            "mysql",
            True,
        ),
        (
            """
            SELECT * FROM table;
            INSERT INTO TABLE (foo) VALUES (42);
            """,
            "mysql",
            False,
        ),
    ],
)
def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None:
    """
    Test the `is_valid_cvas` method.
    """
    assert SQLScript(sql, engine).is_valid_cvas() == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, expected, engine",
    [
        ("col = 1", "col = 1", "base"),
        ("1=\t\n1", "1 = 1", "base"),
        ("(col = 1)", "(col = 1)", "base"),  # Compact format without newlines
        (
            "(col1 = 1) AND (col2 = 2)",
            "(col1 = 1) AND (col2 = 2)",
            "base",
        ),  # Compact format
        (
            "col = 'abc' -- comment",
            "col = 'abc'",
            "base",
        ),  # Comments removed for compact format
        # suspicious characters inside string literals are harmless
        ("col = 'col1 = 1) AND (col2 = 2'", "col = 'col1 = 1) AND (col2 = 2'", "base"),
        ("col = 'select 1; select 2'", "col = 'select 1; select 2'", "base"),
        ("col = 'abc -- comment'", "col = 'abc -- comment'", "base"),
        # unbalanced parentheses and statement injection are rejected
        ("col1 = 1) AND (col2 = 2)", QueryClauseValidationException, "base"),
        ("(col1 = 1) AND (col2 = 2", QueryClauseValidationException, "base"),
        ("col1 = 1) AND (col2 = 2", QueryClauseValidationException, "base"),
        ("(col1 = 1)) AND ((col2 = 2)", QueryClauseValidationException, "base"),
        ("TRUE; SELECT 1", QueryClauseValidationException, "base"),
    ],
)
def test_sanitize_clause(
    sql: str,
    expected: str | type[Exception],
    engine: str,
) -> None:
    """
    Test the `sanitize_clause` function.

    ``expected`` is either the sanitized clause, or the exception class that
    should be raised for an invalid clause.
    """
    if isinstance(expected, str):
        assert sanitize_clause(sql, engine) == expected
    else:
        with pytest.raises(expected):
            sanitize_clause(sql, engine)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "engine",
    [
        "hive",
        "presto",
        "trino",
    ],
)
@pytest.mark.parametrize(
    "macro, expected",
    [
        (
            "latest_partition('foo.bar')",
            {Table(table="bar", schema="foo")},
        ),
        (
            "latest_partition(' foo.bar ')",  # Non-atypical user error which works
            {Table(table="bar", schema="foo")},
        ),
        (
            "latest_partition('foo.%s'|format('bar'))",
            {Table(table="bar", schema="foo")},
        ),
        (
            "latest_sub_partition('foo.bar', baz='qux')",
            {Table(table="bar", schema="foo")},
        ),
        (
            "latest_partition('foo.%s'|format(str('bar')))",
            set(),
        ),
        (
            "latest_partition('foo.{}'.format('bar'))",
            set(),
        ),
    ],
)
def test_extract_tables_from_jinja_sql(
    mocker: MockerFixture,
    engine: str,
    macro: str,
    expected: set[Table],
) -> None:
    """
    Tables referenced inside engine-specific Jinja macros are extracted.
    """
    templated = f"'{{{{ {engine}.{macro} }}}}'"
    mock_database = mocker.MagicMock(backend=engine)
    result = process_jinja_sql(sql=templated, database=mock_database)
    assert result.tables == expected
|
|
|
|
|
|
@with_feature_flags(ENABLE_TEMPLATE_PROCESSING=False)
def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None:
    """
    Table extraction still works when templating is feature-flagged off.
    """
    database = mocker.MagicMock()
    database.db_engine_spec.engine = "mssql"

    result = process_jinja_sql(sql="SELECT 1 FROM t", database=database)
    assert result.tables == {Table("t")}
|
|
|
|
|
|
def test_extract_tables_from_jinja_sql_invalid_function(mocker: MockerFixture) -> None:
    """
    Custom (non-whitelisted) Jinja callables are still rendered before parsing.
    """
    database = mocker.MagicMock(backend="postgresql")

    # Register a custom global that resolves to a table name.
    processor = JinjaTemplateProcessor(database)
    processor.env.globals["my_table"] = lambda: "t"
    mocker.patch(
        "superset.jinja_context.get_template_processor",
        return_value=processor,
    )

    result = process_jinja_sql(
        sql="SELECT * FROM {{ my_table() }}",
        database=database,
    )
    assert result.tables == {Table("t")}
|
|
|
|
|
|
def test_process_jinja_sql_result_object_structure(mocker: MockerFixture) -> None:
    """
    `process_jinja_sql` returns a `JinjaSQLResult` exposing `script` and `tables`.
    """
    database = mocker.MagicMock()
    database.db_engine_spec.engine = "postgresql"

    result = process_jinja_sql(
        sql="SELECT id FROM users WHERE active = true",
        database=database,
    )

    # Result type and attribute presence.
    assert isinstance(result, JinjaSQLResult)
    assert hasattr(result, "script")
    assert hasattr(result, "tables")

    # `script` is a parsed SQLScript; `tables` is a set of Table objects.
    assert isinstance(result.script, SQLScript)
    assert isinstance(result.tables, set)
    assert result.tables == {Table("users")}

    # The parsed script round-trips the original query.
    rendered = result.script.format()
    assert "users" in rendered
    assert "active = TRUE" in rendered
|
|
|
|
|
|
def test_process_jinja_sql_template_params_parameter(mocker: MockerFixture) -> None:
    """
    The `template_params` keyword is accepted and passed through cleanly.
    """
    database = mocker.MagicMock()
    database.db_engine_spec.engine = "postgresql"

    processor = JinjaTemplateProcessor(database)
    mocker.patch(
        "superset.jinja_context.get_template_processor",
        return_value=processor,
    )

    result = process_jinja_sql(
        sql="SELECT * FROM table_name",
        database=database,
        template_params={"param1": "value1"},
    )

    # No error raised, and the table is still extracted.
    assert isinstance(result, JinjaSQLResult)
    assert result.tables == {Table("table_name")}
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM users", "postgresql", True),
        ("WITH cte AS (SELECT * FROM users) SELECT * FROM cte", "postgresql", True),
        ("CREATE TABLE users AS SELECT * FROM users", "postgresql", False),
        ("ALTER TABLE users ADD COLUMN age INT", "postgresql", False),
        ("SET @value = 42", "postgresql", False),
    ],
)
def test_sqlstatement_is_select(sql: str, engine: str, expected: bool) -> None:
    """
    `SQLStatement.is_select()` distinguishes SELECTs (incl. CTEs) from DML/DDL.
    """
    statement = SQLStatement(sql, engine)
    assert statement.is_select() is expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "kql, expected",
    [
        ("StormEvents | take 10", True),
        ("StormEvents | limit 20", True),
        ("StormEvents | where State == 'FL' | summarize count()", True),
        ("StormEvents | where name has 'limit 10'", True),
        ("AnotherTable | take 5", True),
        ("datatable(x:int) [1, 2, 3] | take 100", True),
        (".create table StormEvents (x:int)", False),
        (".ingest inline into table StormEvents <| StormEvents | take 10", False),
    ],
)
def test_kqlstatement_is_select(kql: str, expected: bool) -> None:
    """
    `KustoKQLStatement.is_select()`: queries are selects, control commands
    (leading dot) are not.
    """
    assert KustoKQLStatement(kql, "kustokql").is_select() is expected
|
|
|
|
|
|
def test_singlestore_engine_mapping() -> None:
    """
    Test the `singlestoredb` dialect is properly used.

    The `singlestoredb` engine name must map to a working sqlglot dialect so
    that parsing and formatting succeed.
    """
    sql = "SELECT COUNT(*) AS `COUNT(*)`"
    statement = SQLStatement(sql, engine="singlestoredb")
    assert statement.is_select()

    # Should parse and format without errors.
    formatted = statement.format()
    assert "COUNT(*)" in formatted
|
|
|
|
|
|
def test_remove_quotes() -> None:
    """
    `remove_quotes` strips matching quote pairs and leaves everything else alone.
    """
    # None passes through unchanged.
    assert remove_quotes(None) is None
    # Matching double, single, and backtick quotes are stripped.
    assert remove_quotes('"foo"') == "foo"
    assert remove_quotes("'foo'") == "foo"
    assert remove_quotes("`foo`") == "foo"
    # Mismatched quotes are left untouched.
    assert remove_quotes("'foo`") == "'foo`"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM table", "postgresql", False),
        ("SELECT VERSION()", "postgresql", True),
        ("SELECT query_to_xml()", "postgresql", True),
        ("WITH cte AS (SELECT * FROM table) SELECT * FROM cte", "postgresql", False),
        (
            """
SELECT *
FROM query_to_xml('SELECT * from some_table WHERE id = 42')
            """,
            "postgresql",
            True,
        ),
        ("Table | limit 10", "kustokql", False),
    ],
)
def test_check_functions_present(sql: str, engine: str, expected: bool) -> None:
    """
    `SQLScript.check_functions_present` detects disallowed function calls.
    """
    denylist = {"version", "query_to_xml"}
    assert SQLScript(sql, engine).check_functions_present(denylist) is expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "kql, expected",
    [
        (
            "StormEvents | take 10",
            [
                (KQLTokenType.WORD, "StormEvents"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.OTHER, "|"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.WORD, "take"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.NUMBER, "10"),
            ],
        ),
        ("'test'", [(KQLTokenType.STRING, "'test'")]),
        ("```test```", [(KQLTokenType.STRING, "```test```")]),
    ],
)
def test_tokenize_kql(kql: str, expected: list[tuple[KQLTokenType, str]]) -> None:
    """
    `tokenize_kql` splits KQL into typed (token-type, text) pairs.
    """
    tokens = tokenize_kql(kql)
    assert tokens == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("a = 1", "postgresql", False),
        ("(SELECT * FROM table)", "postgresql", True),
        ("SELECT * FROM table", "postgresql", False),
        ("SELECT * FROM (SELECT 1)", "postgresql", True),
        ("SELECT * FROM (SELECT 1) AS subquery", "postgresql", True),
        ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True),
        ("SELECT * FROM table WHERE EXISTS (SELECT 1)", "postgresql", True),
        ("SELECT * FROM table WHERE NOT EXISTS (SELECT 1)", "postgresql", True),
        (
            "SELECT * FROM table WHERE id IN (SELECT id FROM other_table)",
            "postgresql",
            True,
        ),
    ],
)
def test_has_subquery(sql: str, engine: str, expected: bool) -> None:
    """
    `SQLStatement.has_subquery` detects nested SELECTs, CTEs, and (NOT) EXISTS/IN.
    """
    assert SQLStatement(sql, engine).has_subquery() is expected
|
|
|
|
|
|
# =============================================================================
|
|
# Column-Level Security (CLS) Tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_cls_action_enum() -> None:
    """
    All expected `CLSAction` members are defined.
    """
    for member in (
        CLSAction.HASH,
        CLSAction.NULLIFY,
        CLSAction.HIDE,
        CLSAction.MASK,
    ):
        assert member is not None
|
|
|
|
|
|
def test_cls_hash_functions_mapping() -> None:
    """
    `CLS_HASH_FUNCTIONS` covers common dialects and a safe fallback.
    """
    # The None key is the fallback used for unknown dialects.
    assert None in CLS_HASH_FUNCTIONS
    assert CLS_HASH_FUNCTIONS[None] == "'[HASHED]'"

    # Popular dialects must have dedicated hash expressions.
    for dialect in (
        Dialects.POSTGRES,
        Dialects.MYSQL,
        Dialects.BIGQUERY,
        Dialects.SNOWFLAKE,
    ):
        assert dialect in CLS_HASH_FUNCTIONS

    # Every real hash pattern must embed a `{}` placeholder for the column.
    for dialect, pattern in CLS_HASH_FUNCTIONS.items():
        if dialect is not None and pattern != "'[HASHED]'":
            assert "{}" in pattern, f"Missing placeholder in {dialect} hash pattern"
|
|
|
|
|
|
def test_apply_cls_empty_rules() -> None:
    """
    With no CLS rules, `apply_cls` leaves the SQL untouched.
    """
    original = "SELECT id, name FROM users"
    assert apply_cls(original, {}, engine="postgresql") == original
|
|
|
|
|
|
def test_apply_cls_hash_action() -> None:
    """
    `CLSAction.HASH` wraps the protected column in the dialect's hash function.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls("SELECT ssn, name FROM users", cls_rules, engine="postgresql")

    expected = (
        "SELECT\n"
        ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n'
        ' "users"."name" AS "name"\n'
        'FROM "users" AS "users"'
    )
    assert rewritten == expected
|
|
|
|
|
|
def test_apply_cls_nullify_action() -> None:
    """
    `CLSAction.NULLIFY` replaces the protected column with NULL.
    """
    cls_rules = {Table("users"): {"salary": CLSAction.NULLIFY}}
    rewritten = apply_cls(
        "SELECT salary, name FROM users",
        cls_rules,
        engine="postgresql",
    )

    assert rewritten == (
        'SELECT\n NULL AS salary,\n "users"."name" AS "name"\nFROM "users" AS "users"'
    )
|
|
|
|
|
|
def test_apply_cls_hide_action() -> None:
    """
    `CLSAction.HIDE` drops the protected column from the projection entirely.
    """
    cls_rules = {Table("users"): {"password": CLSAction.HIDE}}
    rewritten = apply_cls(
        "SELECT password, name FROM users",
        cls_rules,
        engine="postgresql",
    )

    assert rewritten == 'SELECT\n "users"."name" AS "name"\nFROM "users" AS "users"'
|
|
|
|
|
|
def test_apply_cls_mask_action() -> None:
    """
    `CLSAction.MASK` rewrites the column as a NULL-preserving CASE expression.
    """
    cls_rules = {Table("users"): {"phone": CLSAction.MASK}}
    rewritten = apply_cls(
        "SELECT phone, name FROM users",
        cls_rules,
        engine="postgresql",
    )

    expected = (
        "SELECT\n"
        " CASE WHEN \"users\".\"phone\" IS NULL THEN NULL ELSE '****' END AS phone,\n"
        ' "users"."name" AS "name"\n'
        'FROM "users" AS "users"'
    )
    assert rewritten == expected
|
|
|
|
|
|
def test_apply_cls_mask_preserves_null() -> None:
    """
    `CLSAction.MASK` keeps NULL values NULL.

    MASK generates: CASE WHEN column IS NULL THEN NULL ELSE '****' END,
    keeping the distinction between "no value" (NULL) and "hidden value".
    """
    cls_rules = {Table("users"): {"email": CLSAction.MASK}}
    rewritten = apply_cls("SELECT email FROM users", cls_rules, engine="postgresql")

    # The generated CASE must branch on NULL and fall back to the mask.
    for fragment in ("CASE WHEN", "IS NULL THEN NULL", "ELSE '****'"):
        assert fragment in rewritten
|
|
|
|
|
|
def test_apply_cls_all_actions() -> None:
    """
    HASH, NULLIFY, HIDE, and MASK all apply together in one query.
    """
    cls_rules = {
        Table("employees"): {
            "ssn": CLSAction.HASH,
            "salary": CLSAction.NULLIFY,
            "password": CLSAction.HIDE,
            "phone": CLSAction.MASK,
        }
    }
    rewritten = apply_cls(
        "SELECT ssn, salary, password, phone, name FROM employees",
        cls_rules,
        engine="postgresql",
    )

    # `password` is hidden; every other rule column is transformed in place.
    expected = (
        "SELECT\n"
        ' MD5(CAST("employees"."ssn" AS TEXT)) AS ssn,\n'
        " NULL AS salary,\n"
        " CASE WHEN \"employees\".\"phone\" IS NULL THEN NULL ELSE '****' END AS phone,\n"
        ' "employees"."name" AS "name"\n'
        'FROM "employees" AS "employees"'
    )
    assert rewritten == expected
|
|
|
|
|
|
def test_apply_cls_qualified_columns() -> None:
    """
    Rules match columns written as `table.column`.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls(
        "SELECT users.ssn, users.name FROM users",
        cls_rules,
        engine="postgresql",
    )

    assert rewritten == (
        "SELECT\n"
        ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n'
        ' "users"."name" AS "name"\n'
        'FROM "users" AS "users"'
    )
|
|
|
|
|
|
def test_apply_cls_table_alias() -> None:
    """
    Rules follow table aliases back to the underlying table.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls(
        "SELECT u.ssn, u.name FROM users u",
        cls_rules,
        engine="postgresql",
    )

    assert rewritten == (
        "SELECT\n"
        ' MD5(CAST("u"."ssn" AS TEXT)) AS ssn,\n'
        ' "u"."name" AS "name"\n'
        'FROM "users" AS "u"'
    )
|
|
|
|
|
|
def test_apply_cls_join() -> None:
    """
    Per-table rules resolve correctly across a JOIN with aliases.
    """
    cls_rules = {
        Table("employees"): {"ssn": CLSAction.HASH},
        Table("salaries"): {"amount": CLSAction.NULLIFY},
    }
    sql = """
SELECT e.ssn, e.name, s.amount
FROM employees e
JOIN salaries s
ON e.id = s.employee_id
    """
    rewritten = apply_cls(sql, cls_rules, engine="postgresql")

    assert rewritten == (
        "SELECT\n"
        ' MD5(CAST("e"."ssn" AS TEXT)) AS ssn,\n'
        ' "e"."name" AS "name",\n'
        " NULL AS amount\n"
        'FROM "employees" AS "e"\n'
        'JOIN "salaries" AS "s"\n'
        ' ON "e"."id" = "s"."employee_id"'
    )
|
|
|
|
|
|
def test_apply_cls_case_insensitive() -> None:
    """
    Table and column names in rules match regardless of case.
    """
    # Rules are uppercase; the query is lowercase.
    cls_rules = {Table("USERS"): {"SSN": CLSAction.HASH}}
    rewritten = apply_cls("SELECT ssn, name FROM users", cls_rules, engine="postgresql")

    assert rewritten == (
        "SELECT\n"
        ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n'
        ' "users"."name" AS "name"\n'
        'FROM "users" AS "users"'
    )
|
|
|
|
|
|
def test_apply_cls_with_column_alias() -> None:
    """
    An existing `AS` alias survives the CLS transformation.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls(
        "SELECT ssn AS social_security, name FROM users",
        cls_rules,
        engine="postgresql",
    )

    assert rewritten == (
        "SELECT\n"
        ' MD5(CAST("users"."ssn" AS TEXT)) AS social_security,\n'
        ' "users"."name" AS "name"\n'
        'FROM "users" AS "users"'
    )
|
|
|
|
|
|
def test_apply_cls_no_matching_table() -> None:
    """
    Columns stay untouched when the queried table has no CLS rules.
    """
    cls_rules = {Table("other_table"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls("SELECT ssn, name FROM users", cls_rules, engine="postgresql")

    # No rule applies to `users`, so columns are only qualified, not rewritten.
    assert rewritten == (
        "SELECT\n"
        ' "users"."ssn" AS "ssn",\n'
        ' "users"."name" AS "name"\n'
        'FROM "users" AS "users"'
    )
|
|
|
|
|
|
def test_apply_cls_non_column_expressions() -> None:
    """
    Literals and aggregate functions are never rewritten by CLS.
    """
    cls_rules = {Table("users"): {"name": CLSAction.HASH}}
    rewritten = apply_cls(
        "SELECT 1 AS one, 'test' AS str, COUNT(*) AS cnt FROM users",
        cls_rules,
        engine="postgresql",
    )

    assert rewritten == (
        "SELECT\n"
        ' 1 AS "one",\n'
        " 'test' AS \"str\",\n"
        ' COUNT(*) AS "cnt"\n'
        'FROM "users" AS "users"'
    )
|
|
|
|
|
|
def test_apply_cls_with_schema() -> None:
    """
    A column schema lets CLS qualify unqualified columns across a JOIN.
    """
    cls_rules = {
        Table("employees"): {"ssn": CLSAction.HASH},
        Table("departments"): {"budget": CLSAction.NULLIFY},
    }
    column_schema = {
        "employees": {
            "id": "INT",
            "ssn": "VARCHAR",
            "name": "VARCHAR",
            "dept_id": "INT",
        },
        "departments": {"id": "INT", "name": "VARCHAR", "budget": "DECIMAL"},
    }
    sql = """
SELECT
ssn, name, budget
FROM employees e
JOIN departments d
ON e.dept_id = d.id
    """
    rewritten = apply_cls(sql, cls_rules, engine="postgresql", schema=column_schema)

    # `ssn` got hashed and `budget` got nullified despite being unqualified.
    assert "MD5" in rewritten
    assert "NULL" in rewritten
|
|
|
|
|
|
def test_apply_cls_different_dialects() -> None:
    """
    Each dialect gets its own hash function for `CLSAction.HASH`.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    sql = "SELECT ssn FROM users"

    expectations = {
        # PostgreSQL: MD5 over a TEXT cast.
        "postgresql": (
            'SELECT\n MD5(CAST("users"."ssn" AS TEXT)) AS ssn\nFROM "users" AS "users"'
        ),
        # MySQL: MD5 over a CHAR cast with backtick quoting.
        "mysql": (
            "SELECT\n MD5(CAST(`users`.`ssn` AS CHAR)) AS ssn\nFROM `users` AS `users`"
        ),
        # BigQuery: MD5 returns bytes, so it is hex-encoded.
        "bigquery": (
            "SELECT\n"
            " TO_HEX(MD5(CAST(`users`.`ssn` AS STRING))) AS ssn\n"
            "FROM `users` AS `users`"
        ),
    }
    for engine, expected in expectations.items():
        assert apply_cls(sql, cls_rules, engine=engine) == expected
|
|
|
|
|
|
def test_apply_cls_unknown_dialect_fallback() -> None:
    """
    An unknown engine falls back to the constant `'[HASHED]'` literal.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls(
        "SELECT users.ssn FROM users",
        cls_rules,
        engine="unknown_database",
    )

    assert rewritten == 'SELECT\n \'[HASHED]\' AS ssn\nFROM "users" AS "users"'
|
|
|
|
|
|
def test_apply_cls_select_star_warning(caplog: pytest.LogCaptureFixture) -> None:
    """
    Test CLS logs warning for SELECT * queries.

    CLS cannot expand `*` without a schema, so it must warn and leave the
    star projection intact rather than silently miss protected columns.
    """
    import logging

    rules = {Table("users"): {"ssn": CLSAction.HASH}}
    sql = "SELECT * FROM users"

    with caplog.at_level(logging.WARNING):
        result = apply_cls(sql, rules, engine="postgresql")

    # The message "CLS cannot fully process SELECT *" contains "SELECT *",
    # so the original `or` of the two substring checks was redundant.
    assert "SELECT *" in caplog.text
    assert "*" in result  # Star should be preserved
|
|
|
|
|
|
def test_sql_statement_apply_cls_method() -> None:
    """
    `SQLStatement.apply_cls` mutates the statement in place.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    statement = SQLStatement("SELECT ssn, name FROM users", engine="postgresql")
    statement.apply_cls(cls_rules)

    assert statement.format() == (
        "SELECT\n"
        ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n'
        ' "users"."name" AS "name"\n'
        'FROM "users" AS "users"'
    )
|
|
|
|
|
|
def test_sql_statement_apply_cls_empty_rules() -> None:
    """
    `SQLStatement.apply_cls({})` is a no-op (output is just reformatted).
    """
    statement = SQLStatement("SELECT ssn, name FROM users", engine="postgresql")
    statement.apply_cls({})

    assert statement.format() == "SELECT\n ssn,\n name\nFROM users"
|
|
|
|
|
|
def test_sql_statement_apply_cls_with_schema() -> None:
    """
    `SQLStatement.apply_cls` honors an explicit `schema` argument.
    """
    cls_rules = {Table("employees"): {"ssn": CLSAction.HASH}}
    column_schema = {"employees": {"id": "INT", "ssn": "VARCHAR", "name": "VARCHAR"}}

    statement = SQLStatement("SELECT ssn, name FROM employees", engine="postgresql")
    statement.apply_cls(cls_rules, schema=column_schema)

    assert statement.format() == (
        "SELECT\n"
        ' MD5(CAST("employees"."ssn" AS TEXT)) AS ssn,\n'
        ' "employees"."name" AS "name"\n'
        'FROM "employees" AS "employees"'
    )
|
|
|
|
|
|
def test_cls_transformer_normalize_rules() -> None:
    """
    `CLSTransformer` lowercases rule table and column names on construction.
    """
    mixed_case_rules = {Table("USERS"): {"SSN": CLSAction.HASH, "Name": CLSAction.MASK}}
    transformer = CLSTransformer(mixed_case_rules, Dialects.POSTGRES)

    lowered = Table("users")
    assert lowered in transformer.rules
    assert "ssn" in transformer.rules[lowered]
    assert "name" in transformer.rules[lowered]
|
|
|
|
|
|
def test_cls_transformer_get_action() -> None:
    """
    `CLSTransformer._get_action` resolves (table, column) to an action.
    """
    transformer = CLSTransformer({Table("users"): {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)

    # Exact and case-insensitive matches resolve to the rule's action.
    assert transformer._get_action("users", "ssn") == CLSAction.HASH
    assert transformer._get_action("USERS", "SSN") == CLSAction.HASH

    # Unknown column, unknown table, or missing table name: no action.
    assert transformer._get_action("users", "name") is None
    assert transformer._get_action("other", "ssn") is None
    assert transformer._get_action(None, "ssn") is None
|
|
|
|
|
|
def test_cls_transformer_extract_scope_tables() -> None:
    """
    `_extract_scope_tables` maps alias (or name) -> real table name.
    """
    transformer = CLSTransformer({Table("users"): {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)

    # Unaliased table maps to itself.
    scope = transformer._extract_scope_tables(parse_one("SELECT * FROM users"))
    assert scope.get("users") == "users"

    # Alias maps back to the underlying table.
    scope = transformer._extract_scope_tables(parse_one("SELECT * FROM users u"))
    assert scope.get("u") == "users"

    # Both sides of a JOIN are captured.
    scope = transformer._extract_scope_tables(
        parse_one("SELECT * FROM users u JOIN orders o ON u.id = o.user_id")
    )
    assert scope.get("u") == "users"
    assert scope.get("o") == "orders"
|
|
|
|
|
|
def test_cls_transformer_get_table_for_column_qualified() -> None:
    """
    `_get_table_for_column` resolves qualified columns through the scope map.
    """
    transformer = CLSTransformer({Table("users"): {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)
    scope = {"u": "users", "o": "orders"}

    # A known alias resolves to the underlying table.
    known = parse_one("u.ssn").find(exp.Column)
    assert transformer._get_table_for_column(known, scope) == "users"

    # An unknown qualifier is returned verbatim.
    unknown = parse_one("x.ssn").find(exp.Column)
    assert transformer._get_table_for_column(unknown, scope) == "x"
|
|
|
|
|
|
def test_cls_transformer_get_table_for_column_single_table() -> None:
    """
    With a single table in scope, unqualified columns are attributed to it.
    """
    transformer = CLSTransformer({Table("users"): {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)

    bare_column = parse_one("ssn").find(exp.Column)
    assert transformer._get_table_for_column(bare_column, {"users": "users"}) == "users"
|
|
|
|
|
|
def test_cls_transformer_get_table_for_column_multi_table_rules_match() -> None:
    """
    With several tables in scope, rules disambiguate an unqualified column.
    """
    transformer = CLSTransformer({Table("users"): {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)
    scope = {"users": "users", "orders": "orders"}

    # `ssn` only appears in the rules for `users`, so that's the match.
    bare_column = parse_one("ssn").find(exp.Column)
    assert transformer._get_table_for_column(bare_column, scope) == "users"
|
|
|
|
|
|
def test_cls_transformer_get_table_for_column_no_match() -> None:
    """
    No scope table and no rule mention the column -> resolution yields None.
    """
    transformer = CLSTransformer({Table("other"): {"col": CLSAction.HASH}}, Dialects.POSTGRES)
    scope = {"users": "users", "orders": "orders"}

    bare_column = parse_one("ssn").find(exp.Column)
    assert transformer._get_table_for_column(bare_column, scope) is None
|
|
|
|
|
|
def test_cls_transformer_get_column_alias() -> None:
    """
    `_get_column_alias` derives the output name for any projection node.
    """
    transformer = CLSTransformer({}, Dialects.POSTGRES)

    # Plain column: its own name.
    assert transformer._get_column_alias(parse_one("ssn").find(exp.Column)) == "ssn"

    # Aliased column: the alias wins.
    aliased = parse_one("ssn AS social").find(exp.Alias)
    assert transformer._get_column_alias(aliased) == "social"

    # Non-column expression: its SQL text.
    string_literal = parse_one("'test'").find(exp.Literal)
    assert transformer._get_column_alias(string_literal) == "'test'"
|
|
|
|
|
|
def test_cls_transformer_create_expressions() -> None:
    """
    The `_create_*_expression` helpers build correctly-aliased nodes.
    """
    transformer = CLSTransformer({}, Dialects.POSTGRES)

    # HASH: aliased expression keeping the column's output name.
    ssn_column = parse_one("ssn").find(exp.Column)
    hashed = transformer._create_hash_expression(ssn_column, "ssn")
    assert isinstance(hashed, exp.Alias)
    assert hashed.alias == "ssn"

    # NULLIFY: aliased NULL literal.
    nullified = transformer._create_null_expression("salary")
    assert isinstance(nullified, exp.Alias)
    assert nullified.alias == "salary"
    assert isinstance(nullified.this, exp.Null)

    # MASK: aliased CASE expression whose default branch is the mask string.
    phone_column = parse_one("phone").find(exp.Column)
    masked = transformer._create_mask_expression(phone_column, "phone")
    assert isinstance(masked, exp.Alias)
    assert masked.alias == "phone"
    assert isinstance(masked.this, exp.Case)
    default_branch = masked.this.args.get("default")
    assert isinstance(default_branch, exp.Literal)
    assert default_branch.this == "****"
|
|
|
|
|
|
def test_cls_transformer_call_non_select() -> None:
    """
    Test CLSTransformer.__call__ returns non-SELECT nodes unchanged.
    """
    rules = {Table("users"): {"ssn": CLSAction.HASH}}
    transformer = CLSTransformer(rules, Dialects.POSTGRES)

    # `parse_one("users")` yields a bare Column expression — not a SELECT —
    # so the transformer must pass it through untouched. (Renamed from
    # `table`, which wrongly suggested an exp.Table node.)
    column_node = parse_one("users").find(exp.Column)
    result = transformer(column_node)
    assert result == column_node
|
|
|
|
|
|
def test_cls_transformer_transform_expression_non_column() -> None:
    """
    `_transform_expression` passes non-column expressions through untouched.
    """
    transformer = CLSTransformer({Table("users"): {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)
    scope = {"users": "users"}

    # A string literal and a function call are both left unchanged.
    for expression in (parse_one("'test'"), parse_one("COUNT(*)")):
        assert transformer._transform_expression(expression, scope) == expression
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql,rules,engine,expected",
    [
        # Basic HASH
        (
            "SELECT t.id FROM t",
            {Table("t"): {"id": CLSAction.HASH}},
            "postgresql",
            'SELECT\n MD5(CAST("t"."id" AS TEXT)) AS id\nFROM "t" AS "t"',
        ),
        # Basic NULLIFY
        (
            "SELECT t.salary FROM t",
            {Table("t"): {"salary": CLSAction.NULLIFY}},
            "postgresql",
            'SELECT\n NULL AS salary\nFROM "t" AS "t"',
        ),
        # Basic HIDE
        (
            "SELECT t.secret, t.public FROM t",
            {Table("t"): {"secret": CLSAction.HIDE}},
            "postgresql",
            'SELECT\n "t"."public" AS "public"\nFROM "t" AS "t"',
        ),
        # Basic MASK (preserves NULLs)
        (
            "SELECT t.phone FROM t",
            {Table("t"): {"phone": CLSAction.MASK}},
            "postgresql",
            "SELECT\n CASE WHEN \"t\".\"phone\" IS NULL THEN NULL ELSE '****' END AS phone\nFROM \"t\" AS \"t\"",
        ),
        # Multiple tables with different rules
        (
            "SELECT a.ssn, b.amount FROM users a JOIN payments b ON a.id = b.user_id",
            {
                Table("users"): {"ssn": CLSAction.HASH},
                Table("payments"): {"amount": CLSAction.NULLIFY},
            },
            "postgresql",
            (
                "SELECT\n"
                ' MD5(CAST("a"."ssn" AS TEXT)) AS ssn,\n'
                " NULL AS amount\n"
                'FROM "users" AS "a"\n'
                'JOIN "payments" AS "b"\n'
                ' ON "a"."id" = "b"."user_id"'
            ),
        ),
        # Snowflake dialect
        (
            "SELECT t.col FROM t",
            {Table("t"): {"col": CLSAction.HASH}},
            "snowflake",
            'SELECT\n MD5(TO_CHAR("T"."COL")) AS COL\nFROM "T" AS "T"',
        ),
        # ClickHouse dialect
        (
            "SELECT t.col FROM t",
            {Table("t"): {"col": CLSAction.HASH}},
            "clickhouse",
            'SELECT\n MD5(toString("t"."col")) AS col\nFROM "t" AS "t"',
        ),
    ],
)
def test_apply_cls_parametrized(
    sql: str,
    rules: dict[Table, Any],
    engine: str,
    expected: str,
) -> None:
    """
    Exhaustive apply_cls cases: every action and several dialects.
    """
    assert apply_cls(sql, rules, engine=engine) == expected
|
|
|
|
|
|
def test_apply_cls_subquery() -> None:
    """
    CLS rules reach into derived-table subqueries.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls(
        "SELECT * FROM (SELECT ssn, name FROM users) AS subq",
        cls_rules,
        engine="postgresql",
    )

    # The hash is applied inside the subquery, not on the outer projection.
    assert rewritten == (
        "SELECT\n"
        ' "subq"."ssn" AS "ssn",\n'
        ' "subq"."name" AS "name"\n'
        "FROM (\n"
        " SELECT\n"
        ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n'
        ' "users"."name" AS "name"\n'
        ' FROM "users" AS "users"\n'
        ') AS "subq"'
    )
|
|
|
|
|
|
def test_apply_cls_cte() -> None:
    """
    CLS rules reach into common table expressions.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls(
        "WITH cte AS (SELECT ssn, name FROM users) SELECT * FROM cte",
        cls_rules,
        engine="postgresql",
    )

    # The hash is applied inside the CTE body.
    assert rewritten == (
        'WITH "cte" AS (\n'
        " SELECT\n"
        ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n'
        ' "users"."name" AS "name"\n'
        ' FROM "users" AS "users"\n'
        ")\n"
        "SELECT\n"
        ' "cte"."ssn" AS "ssn",\n'
        ' "cte"."name" AS "name"\n'
        'FROM "cte" AS "cte"'
    )
|
|
|
|
|
|
def test_apply_cls_union() -> None:
    """
    CLS rules apply independently to each branch of a UNION.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    rewritten = apply_cls(
        "SELECT ssn FROM users UNION SELECT ssn FROM archived_users",
        cls_rules,
        engine="postgresql",
    )

    # Only the `users` branch is hashed; `archived_users` has no rules.
    assert rewritten == (
        "SELECT\n"
        ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn\n'
        'FROM "users" AS "users"\n'
        "UNION\n"
        "SELECT\n"
        ' "archived_users"."ssn" AS "ssn"\n'
        'FROM "archived_users" AS "archived_users"'
    )
|
|
|
|
|
|
def test_cls_hide_all_columns() -> None:
    """
    Hiding every projected column leaves an empty SELECT list.
    """
    cls_rules = {Table("users"): {"id": CLSAction.HIDE, "name": CLSAction.HIDE}}
    rewritten = apply_cls("SELECT id, name FROM users", cls_rules, engine="postgresql")

    assert rewritten == 'SELECT\nFROM "users" AS "users"'
|
|
|
|
|
|
def test_cls_transformer_extract_scope_tables_no_from() -> None:
    """
    A SELECT with no FROM clause yields an empty scope map.
    """
    transformer = CLSTransformer({}, Dialects.POSTGRES)
    assert transformer._extract_scope_tables(parse_one("SELECT 1")) == {}
|
|
|
|
|
|
def test_cls_transformer_extract_scope_tables_no_joins() -> None:
|
|
"""
|
|
Test CLSTransformer._extract_scope_tables with FROM but no JOINs.
|
|
"""
|
|
transformer = CLSTransformer({}, Dialects.POSTGRES)
|
|
select = parse_one("SELECT * FROM users")
|
|
tables = transformer._extract_scope_tables(select)
|
|
assert "users" in tables
|
|
assert len(tables) == 1
|
|
|
|
|
|
def test_apply_cls_aliased_column_preserves_alias() -> None:
|
|
"""
|
|
Test that CLS preserves the alias when column has AS clause.
|
|
"""
|
|
rules = {Table("t"): {"col": CLSAction.HASH}}
|
|
sql = "SELECT t.col AS my_alias FROM t"
|
|
result = apply_cls(sql, rules, engine="postgresql")
|
|
|
|
assert result == (
|
|
'SELECT\n MD5(CAST("t"."col" AS TEXT)) AS my_alias\nFROM "t" AS "t"'
|
|
)
|
|
|
|
|
|
def test_cls_transformer_hash_pattern_fallback() -> None:
|
|
"""
|
|
Test CLSTransformer uses fallback hash pattern for unknown dialect.
|
|
"""
|
|
# Use None as dialect to trigger fallback
|
|
transformer = CLSTransformer({Table("t"): {"col": CLSAction.HASH}}, None)
|
|
assert transformer.hash_pattern == "'[HASHED]'"
|
|
|
|
|
|
# Tests for CLS predicate transformation
def test_apply_cls_where_clause_hash() -> None:
    """
    HASH must also be applied to columns referenced in WHERE predicates.

    Hashing filter conditions prevents information leakage: a query such
    as "WHERE role='CEO'" can only match if the caller already knows the
    hashed value.
    """
    output = apply_cls(
        "SELECT MAX(salary) FROM payroll WHERE role='CEO'",
        {Table("payroll"): {"role": CLSAction.HASH}},
        engine="postgresql",
    )

    # The filter column must be wrapped in the hash expression.
    assert "MD5" in output
    assert 'WHERE' in output
    assert "MD5(CAST" in output


def test_apply_cls_where_clause_nullify() -> None:
    """
    NULLIFY turns WHERE predicates on the column into FALSE so they
    cannot be used for filtering.
    """
    output = apply_cls(
        "SELECT name FROM payroll WHERE salary > 100000",
        {Table("payroll"): {"salary": CLSAction.NULLIFY}},
        engine="postgresql",
    )

    assert "FALSE" in output


def test_apply_cls_where_clause_mask() -> None:
    """
    MASK turns WHERE predicates on the column into FALSE so they
    cannot be used for filtering.
    """
    output = apply_cls(
        "SELECT name FROM users WHERE phone = '555-1234'",
        {Table("users"): {"phone": CLSAction.MASK}},
        engine="postgresql",
    )

    assert "FALSE" in output


def test_apply_cls_where_clause_hide() -> None:
    """
    HIDE turns WHERE predicates on the column into FALSE so they
    cannot be used for filtering.
    """
    output = apply_cls(
        "SELECT name FROM users WHERE secret_code = 'ADMIN'",
        {Table("users"): {"secret_code": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "FALSE" in output


def test_apply_cls_where_clause_multiple_conditions() -> None:
    """
    Each condition in a compound WHERE clause is transformed per its rule.
    """
    output = apply_cls(
        "SELECT name FROM users WHERE role = 'admin' AND salary > 50000",
        {Table("users"): {"role": CLSAction.HASH, "salary": CLSAction.NULLIFY}},
        engine="postgresql",
    )

    # The hashed condition keeps MD5; the nullified one collapses to FALSE.
    assert "MD5" in output
    assert "FALSE" in output
|
|
|
|
|
|
def test_apply_cls_join_on_clause_hash() -> None:
    """
    HASH must be applied to columns referenced in JOIN ... ON conditions.
    """
    query = """
    SELECT o.order_id
    FROM orders o
    JOIN users u ON o.customer_id = u.user_id
    """
    output = apply_cls(
        query,
        {Table("users"): {"user_id": CLSAction.HASH}},
        engine="postgresql",
    )

    # The join key on the protected table gets hashed.
    assert "MD5" in output


def test_apply_cls_cross_join_no_on_clause() -> None:
    """
    A CROSS JOIN (which has no ON clause) must not break CLS processing.
    """
    output = apply_cls(
        "SELECT u.ssn, p.name FROM users u CROSS JOIN products p",
        {Table("users"): {"ssn": CLSAction.HASH}},
        engine="postgresql",
    )

    # The projected SSN is hashed; the CROSS JOIN itself is untouched.
    assert "MD5" in output
    assert "CROSS JOIN" in output


def test_apply_cls_having_clause_hash() -> None:
    """
    HASH must be applied to columns referenced in HAVING predicates.
    """
    output = apply_cls(
        "SELECT COUNT(*) FROM sales GROUP BY region HAVING region = 'North'",
        {Table("sales"): {"region": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "HAVING" in output
    assert "MD5" in output
|
|
|
|
|
|
def test_apply_cls_group_by_hash() -> None:
    """
    HASH transforms columns referenced in GROUP BY.
    """
    output = apply_cls(
        "SELECT COUNT(*) FROM users GROUP BY department",
        {Table("users"): {"department": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "GROUP BY" in output
    assert "MD5" in output


def test_apply_cls_group_by_hide() -> None:
    """
    HIDE removes the column from GROUP BY; an empty GROUP BY is dropped.
    """
    output = apply_cls(
        "SELECT COUNT(*) FROM users GROUP BY ssn",
        {Table("users"): {"ssn": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "GROUP BY" not in output


def test_apply_cls_group_by_nullify() -> None:
    """
    NULLIFY removes the column from GROUP BY; an empty GROUP BY is dropped.
    """
    output = apply_cls(
        "SELECT COUNT(*) FROM users GROUP BY salary",
        {Table("users"): {"salary": CLSAction.NULLIFY}},
        engine="postgresql",
    )

    assert "GROUP BY" not in output


def test_apply_cls_group_by_mask() -> None:
    """
    MASK removes the column from GROUP BY; an empty GROUP BY is dropped.
    """
    output = apply_cls(
        "SELECT COUNT(*) FROM users GROUP BY phone",
        {Table("users"): {"phone": CLSAction.MASK}},
        engine="postgresql",
    )

    assert "GROUP BY" not in output


def test_apply_cls_group_by_multiple_columns() -> None:
    """
    Only the protected column is dropped from a multi-column GROUP BY.
    """
    output = apply_cls(
        "SELECT COUNT(*) FROM users GROUP BY department, ssn",
        {Table("users"): {"ssn": CLSAction.HIDE}},
        engine="postgresql",
    )

    # department survives; ssn is removed.
    assert "GROUP BY" in output
    assert "department" in output.lower()
    assert "ssn" not in output.lower()
|
|
|
|
|
|
def test_apply_cls_order_by_hash() -> None:
    """
    HASH transforms columns referenced in ORDER BY while keeping direction.
    """
    output = apply_cls(
        "SELECT name FROM users ORDER BY salary DESC",
        {Table("users"): {"salary": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "ORDER BY" in output
    assert "MD5" in output
    assert "DESC" in output


def test_apply_cls_order_by_hide() -> None:
    """
    HIDE removes the column from ORDER BY; an empty ORDER BY is dropped.
    """
    output = apply_cls(
        "SELECT name FROM users ORDER BY salary",
        {Table("users"): {"salary": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "ORDER BY" not in output


def test_apply_cls_order_by_nullify() -> None:
    """
    NULLIFY removes the column from ORDER BY; an empty ORDER BY is dropped.
    """
    output = apply_cls(
        "SELECT name FROM users ORDER BY salary",
        {Table("users"): {"salary": CLSAction.NULLIFY}},
        engine="postgresql",
    )

    assert "ORDER BY" not in output


def test_apply_cls_order_by_mask() -> None:
    """
    MASK removes the column from ORDER BY; an empty ORDER BY is dropped.
    """
    output = apply_cls(
        "SELECT name FROM users ORDER BY salary",
        {Table("users"): {"salary": CLSAction.MASK}},
        engine="postgresql",
    )

    assert "ORDER BY" not in output


def test_apply_cls_order_by_multiple_columns() -> None:
    """
    Only the protected column is dropped from a multi-column ORDER BY.
    """
    output = apply_cls(
        "SELECT name FROM users ORDER BY name, salary DESC",
        {Table("users"): {"salary": CLSAction.HIDE}},
        engine="postgresql",
    )

    # name survives the ordering; salary is stripped out.
    assert "ORDER BY" in output
    assert "name" in output.lower()
    assert "salary" not in output.lower()
|
|
|
|
|
|
def test_apply_cls_case_expression_hide() -> None:
    """
    HIDE replaces a column inside a CASE expression with NULL.
    """
    output = apply_cls(
        "SELECT name, CASE WHEN status = 'active' THEN 'yes' ELSE 'no' END as is_active FROM users",
        {Table("users"): {"status": CLSAction.HIDE}},
        engine="postgresql",
    )

    # The hidden column becomes NULL inside the CASE condition.
    assert "NULL = 'active'" in output
    lowered = output.lower()
    assert "status" not in lowered or "null" in lowered


def test_apply_cls_case_expression_hash() -> None:
    """
    HASH transforms a column inside a CASE expression.
    """
    output = apply_cls(
        "SELECT name, CASE WHEN status = 'active' THEN 'yes' ELSE 'no' END as is_active FROM users",
        {Table("users"): {"status": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "MD5" in output
    assert "CASE WHEN" in output


def test_apply_cls_function_argument_hide() -> None:
    """
    HIDE replaces a column used as a function argument with NULL.
    """
    output = apply_cls(
        "SELECT UPPER(email) as upper_email FROM users",
        {Table("users"): {"email": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "UPPER(NULL)" in output


def test_apply_cls_function_argument_hash() -> None:
    """
    HASH transforms a column used as a function argument.
    """
    output = apply_cls(
        "SELECT UPPER(email) as upper_email FROM users",
        {Table("users"): {"email": CLSAction.HASH}},
        engine="postgresql",
    )

    # The hash wraps the column before it reaches UPPER.
    assert "UPPER(MD5" in output
|
|
|
|
|
|
def test_apply_cls_window_partition_by_hide() -> None:
    """
    HIDE removes the protected column from a window PARTITION BY.
    """
    output = apply_cls(
        "SELECT name, ROW_NUMBER() OVER (PARTITION BY department ORDER BY name) as rn FROM employees",
        {Table("employees"): {"department": CLSAction.HIDE}},
        engine="postgresql",
    )

    # department is gone from the partition; the ORDER BY on name survives.
    assert "PARTITION BY" not in output or "department" not in output.lower()
    assert "ORDER BY" in output


def test_apply_cls_window_partition_by_hash() -> None:
    """
    HASH transforms the column inside a window PARTITION BY.
    """
    output = apply_cls(
        "SELECT name, ROW_NUMBER() OVER (PARTITION BY department ORDER BY name) as rn FROM employees",
        {Table("employees"): {"department": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "PARTITION BY" in output
    assert "MD5" in output


def test_apply_cls_window_order_by_hide() -> None:
    """
    HIDE removes the protected column from a window ORDER BY.
    """
    output = apply_cls(
        "SELECT name, RANK() OVER (ORDER BY salary DESC) as rank FROM employees",
        {Table("employees"): {"salary": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "salary" not in output.lower()


def test_apply_cls_window_order_by_hash() -> None:
    """
    HASH transforms the column inside a window ORDER BY.
    """
    output = apply_cls(
        "SELECT name, RANK() OVER (ORDER BY salary DESC) as rank FROM employees",
        {Table("employees"): {"salary": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "MD5" in output
    assert "DESC" in output


def test_apply_cls_window_partition_only_no_order() -> None:
    """
    A window with PARTITION BY but no ORDER BY is handled cleanly.
    Covers branch 847->836 (no window_order).
    """
    output = apply_cls(
        "SELECT name, COUNT(*) OVER (PARTITION BY department) as cnt FROM employees",
        {Table("employees"): {"department": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "PARTITION BY" in output
    assert "MD5" in output
    # The window spec itself must contain no ORDER BY.
    window_spec = output.split("OVER")[1].split(")")[0]
    assert "ORDER BY" not in window_spec


def test_apply_cls_window_partition_all_blocked() -> None:
    """
    PARTITION BY disappears entirely when every column in it is blocked.
    Covers branch 842->840 (_is_blocked returns True in partition loop).
    """
    output = apply_cls(
        "SELECT name, COUNT(*) OVER (PARTITION BY department) as cnt FROM employees",
        {Table("employees"): {"department": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "PARTITION BY" not in output
    assert "OVER ()" in output


def test_apply_cls_window_order_all_blocked() -> None:
    """
    Window ORDER BY disappears entirely when every column in it is blocked.
    Covers line 861 (window.set("order", None)) and branch 856->849.
    """
    output = apply_cls(
        "SELECT name, RANK() OVER (ORDER BY salary) as rank FROM employees",
        {Table("employees"): {"salary": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "ORDER BY" not in output
    assert "OVER ()" in output
|
|
|
|
|
|
def test_apply_cls_arithmetic_expression_hide() -> None:
    """
    HIDE replaces a column inside an arithmetic expression with NULL.
    """
    output = apply_cls(
        "SELECT name, price * 1.1 as price_with_tax FROM products",
        {Table("products"): {"price": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "NULL * 1.1" in output or "NULL" in output


def test_apply_cls_concat_function_hide() -> None:
    """
    HIDE replaces a column passed to CONCAT with NULL.
    """
    output = apply_cls(
        "SELECT CONCAT('SSN: ', ssn) as display FROM users",
        {Table("users"): {"ssn": CLSAction.HIDE}},
        engine="postgresql",
    )

    assert "NULL" in output


def test_apply_cls_where_no_rules_unchanged() -> None:
    """
    A WHERE clause with no matching rules is left semantically untouched.
    """
    output = apply_cls(
        "SELECT name FROM users WHERE active = true",
        {Table("other_table"): {"col": CLSAction.HASH}},
        engine="postgresql",
    )

    # Only column qualification happens; no CLS transforms apply.
    assert "WHERE" in output
    assert 'MD5' not in output
    assert 'FALSE' not in output
|
|
|
|
|
|
def test_apply_cls_table_with_schema() -> None:
    """
    A rule without a schema matches a schema-qualified table in the query.

    Rules with schema specified require the table in scope to match.
    Since queries often resolve to just the table name without schema,
    we match by table name when the rule doesn't specify schema.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    output = apply_cls("SELECT ssn FROM public.users", cls_rules, engine="postgresql")

    # Name-based matching applies the hash regardless of schema prefix.
    assert "MD5" in output


def test_apply_cls_table_key_matching() -> None:
    """
    Rules match by table name when no schema is present in the query scope.

    The scope_tables dict maps aliases to table names, so a plain table
    reference is matched directly by name.
    """
    cls_rules = {Table("users"): {"ssn": CLSAction.HASH}}
    output = apply_cls("SELECT ssn FROM users", cls_rules, engine="postgresql")

    assert "MD5" in output


def test_apply_cls_table_with_schema_rule() -> None:
    """
    A Table(table, schema) rule falls back to name matching when the query
    has no schema.
    """
    cls_rules = {Table("users", schema="public"): {"ssn": CLSAction.HASH}}
    output = apply_cls("SELECT ssn FROM users", cls_rules, engine="postgresql")

    assert "MD5" in output


def test_apply_cls_table_with_schema_case_insensitive() -> None:
    """
    Table(table, schema) rules match case-insensitively.
    """
    cls_rules = {Table("USERS", schema="PUBLIC"): {"SSN": CLSAction.HASH}}
    output = apply_cls("SELECT ssn FROM users", cls_rules, engine="postgresql")

    assert "MD5" in output


def test_apply_cls_table_with_schema_multiple_tables() -> None:
    """
    Table(table, schema) rules apply independently to each joined table.
    """
    cls_rules = {
        Table("users", schema="public"): {"ssn": CLSAction.HASH},
        Table("accounts", schema="finance"): {"balance": CLSAction.NULLIFY},
    }
    query = """
    SELECT u.ssn, a.balance
    FROM users u
    JOIN accounts a ON u.id = a.user_id
    """
    output = apply_cls(query, cls_rules, engine="postgresql")

    # Both tables' rules take effect.
    assert "MD5" in output
    assert "NULL" in output


def test_apply_cls_table_with_catalog_and_schema_rule() -> None:
    """
    A Table(table, schema, catalog) rule falls back to name matching when
    the query is unqualified.
    """
    cls_rules = {Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.HASH}}
    output = apply_cls("SELECT ssn FROM users", cls_rules, engine="postgresql")

    assert "MD5" in output


def test_apply_cls_table_with_catalog_and_schema_case_insensitive() -> None:
    """
    Table(table, schema, catalog) rules match case-insensitively.
    """
    cls_rules = {Table("USERS", schema="PUBLIC", catalog="MYDB"): {"SSN": CLSAction.MASK}}
    output = apply_cls("SELECT ssn FROM users", cls_rules, engine="postgresql")

    assert "'****'" in output
|
|
|
|
|
|
def test_apply_cls_table_with_catalog_schema_multiple_actions() -> None:
    """
    A fully-qualified rule set may combine every CLS action on one table.
    """
    cls_rules = {
        Table("employees", schema="hr", catalog="corp"): {
            "ssn": CLSAction.HASH,
            "salary": CLSAction.NULLIFY,
            "phone": CLSAction.MASK,
            "password": CLSAction.HIDE,
        }
    }
    output = apply_cls(
        "SELECT ssn, salary, phone, password, name FROM employees",
        cls_rules,
        engine="postgresql",
    )

    assert "MD5" in output  # HASH
    assert "NULL" in output  # NULLIFY
    assert "'****'" in output  # MASK
    # HIDE: password must not survive in the projection.
    projection = output.split("SELECT")[1].split("FROM")[0]
    assert "password" not in output.lower() or "password" not in projection


def test_apply_cls_table_with_schema_in_predicate() -> None:
    """
    Table(table, schema) rules also transform predicate columns.
    """
    output = apply_cls(
        "SELECT MAX(salary) FROM payroll WHERE role = 'CEO'",
        {Table("payroll", schema="hr"): {"role": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "MD5" in output
    assert "WHERE" in output


def test_apply_cls_table_with_catalog_schema_in_predicate() -> None:
    """
    Table(table, schema, catalog) rules also transform predicate columns.
    """
    output = apply_cls(
        "SELECT salary, name FROM payroll WHERE salary > 100000",
        {Table("payroll", schema="hr", catalog="corp"): {"salary": CLSAction.NULLIFY}},
        engine="postgresql",
    )

    # Projection becomes NULL; the filter collapses to FALSE.
    assert "NULL" in output
    assert "FALSE" in output


def test_apply_cls_table_schema_no_match_different_table() -> None:
    """
    Table(table, schema) rules never match a differently-named table.
    """
    output = apply_cls(
        "SELECT ssn FROM users",
        {Table("employees", schema="hr"): {"ssn": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "MD5" not in output


def test_apply_cls_table_catalog_schema_no_match_different_table() -> None:
    """
    Table(table, schema, catalog) rules never match a differently-named table.
    """
    output = apply_cls(
        "SELECT ssn FROM users",
        {Table("employees", schema="hr", catalog="corp"): {"ssn": CLSAction.HASH}},
        engine="postgresql",
    )

    assert "MD5" not in output


def test_apply_cls_mixed_table_rules() -> None:
    """
    Rules with and without schema/catalog qualification coexist and all apply.
    """
    cls_rules = {
        Table("users"): {"email": CLSAction.MASK},  # No schema/catalog
        Table("employees", schema="hr"): {"ssn": CLSAction.HASH},  # With schema
        Table("payroll", schema="finance", catalog="corp"): {"salary": CLSAction.NULLIFY},  # With both
    }
    query = """
    SELECT u.email, e.ssn, p.salary
    FROM users u
    JOIN employees e ON u.id = e.user_id
    JOIN payroll p ON e.id = p.employee_id
    """
    output = apply_cls(query, cls_rules, engine="postgresql")

    assert "'****'" in output  # MASK for email
    assert "MD5" in output  # HASH for ssn
    assert "NULL" in output  # NULLIFY for salary
|
|
|
|
|
|
def test_cls_transformer_normalize_rules_with_schema() -> None:
    """
    CLSTransformer lowercases Table keys that carry a schema.
    """
    transformer = CLSTransformer(
        {Table("USERS", schema="PUBLIC"): {"SSN": CLSAction.HASH}},
        Dialects.POSTGRES,
    )

    lowered_key = Table("users", schema="public")
    assert lowered_key in transformer.rules
    assert "ssn" in transformer.rules[lowered_key]


def test_cls_transformer_normalize_rules_with_catalog_and_schema() -> None:
    """
    CLSTransformer lowercases Table keys that carry both catalog and schema.
    """
    transformer = CLSTransformer(
        {Table("USERS", schema="PUBLIC", catalog="MYDB"): {"SSN": CLSAction.HASH}},
        Dialects.POSTGRES,
    )

    lowered_key = Table("users", schema="public", catalog="mydb")
    assert lowered_key in transformer.rules
    assert "ssn" in transformer.rules[lowered_key]


def test_cls_transformer_get_action_with_schema() -> None:
    """
    _get_action resolves Table(table, schema) rules by table name.
    """
    transformer = CLSTransformer(
        {Table("users", schema="public"): {"ssn": CLSAction.HASH}},
        Dialects.POSTGRES,
    )

    # Name-only fallback matching.
    assert transformer._get_action("users", "ssn") == CLSAction.HASH
    # Lookup is case-insensitive.
    assert transformer._get_action("USERS", "SSN") == CLSAction.HASH
    # A different table never matches.
    assert transformer._get_action("employees", "ssn") is None


def test_cls_transformer_get_action_with_catalog_and_schema() -> None:
    """
    _get_action resolves Table(table, schema, catalog) rules by table name.
    """
    transformer = CLSTransformer(
        {Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.HASH}},
        Dialects.POSTGRES,
    )

    # Name-only fallback matching.
    assert transformer._get_action("users", "ssn") == CLSAction.HASH
    # Lookup is case-insensitive.
    assert transformer._get_action("USERS", "SSN") == CLSAction.HASH
    # A different table never matches.
    assert transformer._get_action("employees", "ssn") is None
|
|
|
|
|
|
# Tests for merge_cls_rules and CLS_ACTION_PRECEDENCE
def test_cls_action_precedence() -> None:
    """
    The precedence ordering is HIDE > NULLIFY > MASK > HASH.
    """
    ordering = [CLSAction.HIDE, CLSAction.NULLIFY, CLSAction.MASK, CLSAction.HASH]
    # Each action outranks the next one down.
    for stricter, looser in zip(ordering, ordering[1:]):
        assert CLS_ACTION_PRECEDENCE[stricter] > CLS_ACTION_PRECEDENCE[looser]


def test_merge_cls_rules_empty() -> None:
    """
    Merging nothing yields an empty rule set.
    """
    assert merge_cls_rules() == {}


def test_merge_cls_rules_single() -> None:
    """
    Merging a single rule set returns it unchanged.
    """
    only = {Table("foo"): {"col1": CLSAction.HASH}}
    assert merge_cls_rules(only) == only


def test_merge_cls_rules_no_conflict() -> None:
    """
    Non-conflicting rules for the same table are combined.
    """
    first = {Table("foo"): {"col1": CLSAction.HASH}}
    second = {Table("foo"): {"col2": CLSAction.HIDE}}

    assert merge_cls_rules(first, second) == {
        Table("foo"): {"col1": CLSAction.HASH, "col2": CLSAction.HIDE}
    }
|
|
|
|
|
|
def test_merge_cls_rules_different_tables() -> None:
    """
    Rules for distinct tables are kept side by side.
    """
    first = {Table("foo"): {"col1": CLSAction.HASH}}
    second = {Table("bar"): {"col1": CLSAction.NULLIFY}}

    assert merge_cls_rules(first, second) == {
        Table("foo"): {"col1": CLSAction.HASH},
        Table("bar"): {"col1": CLSAction.NULLIFY},
    }


def test_merge_cls_rules_conflict_nullify_over_hash() -> None:
    """
    On conflict, NULLIFY (stricter) wins over HASH.
    """
    first = {Table("foo"): {"col1": CLSAction.HASH}}
    second = {Table("foo"): {"col1": CLSAction.NULLIFY}}

    assert merge_cls_rules(first, second) == {Table("foo"): {"col1": CLSAction.NULLIFY}}


def test_merge_cls_rules_conflict_hide_over_nullify() -> None:
    """
    On conflict, HIDE (stricter) wins over NULLIFY.
    """
    first = {Table("foo"): {"col1": CLSAction.NULLIFY}}
    second = {Table("foo"): {"col1": CLSAction.HIDE}}

    assert merge_cls_rules(first, second) == {Table("foo"): {"col1": CLSAction.HIDE}}


def test_merge_cls_rules_conflict_mask_over_hash() -> None:
    """
    On conflict, MASK (stricter) wins over HASH.
    """
    first = {Table("foo"): {"col1": CLSAction.HASH}}
    second = {Table("foo"): {"col1": CLSAction.MASK}}

    assert merge_cls_rules(first, second) == {Table("foo"): {"col1": CLSAction.MASK}}


def test_merge_cls_rules_conflict_nullify_over_mask() -> None:
    """
    On conflict, NULLIFY (stricter) wins over MASK.
    """
    first = {Table("foo"): {"col1": CLSAction.MASK}}
    second = {Table("foo"): {"col1": CLSAction.NULLIFY}}

    assert merge_cls_rules(first, second) == {Table("foo"): {"col1": CLSAction.NULLIFY}}
|
|
|
|
|
|
def test_merge_cls_rules_keeps_stricter_regardless_of_order() -> None:
    """
    The stricter action wins no matter which rule set comes first.
    """
    strict = {Table("foo"): {"col1": CLSAction.NULLIFY}}
    lax = {Table("foo"): {"col1": CLSAction.HASH}}
    expected = {Table("foo"): {"col1": CLSAction.NULLIFY}}

    # NULLIFY beats HASH in either argument order.
    assert merge_cls_rules(strict, lax) == expected
    assert merge_cls_rules(lax, strict) == expected


def test_merge_cls_rules_user_example() -> None:
    """
    The worked example from the requirements.

    Given:
        {Table("foo"): {"col1": CLSAction.NULLIFY}}
        {Table("foo"): {"col1": CLSAction.HASH, "col2": CLSAction.HIDE}}

    Should produce:
        {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}
    """
    first = {Table("foo"): {"col1": CLSAction.NULLIFY}}
    second = {Table("foo"): {"col1": CLSAction.HASH, "col2": CLSAction.HIDE}}

    assert merge_cls_rules(first, second) == {
        Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}
    }


def test_merge_cls_rules_multiple_rule_sets() -> None:
    """
    Merging more than two rule sets still keeps the strictest per column.
    """
    first = {Table("foo"): {"col1": CLSAction.HASH}}
    second = {Table("foo"): {"col1": CLSAction.MASK, "col2": CLSAction.HASH}}
    third = {Table("foo"): {"col1": CLSAction.NULLIFY, "col3": CLSAction.HIDE}}

    assert merge_cls_rules(first, second, third) == {
        Table("foo"): {
            "col1": CLSAction.NULLIFY,  # NULLIFY is strictest
            "col2": CLSAction.HASH,
            "col3": CLSAction.HIDE,
        }
    }
|
|
|
|
|
|
def test_merge_cls_rules_with_schema() -> None:
    """
    Merging works for Table keys that carry a schema.
    """
    first = {Table("users", schema="public"): {"ssn": CLSAction.HASH}}
    second = {Table("users", schema="public"): {"ssn": CLSAction.HIDE, "email": CLSAction.MASK}}

    assert merge_cls_rules(first, second) == {
        Table("users", schema="public"): {"ssn": CLSAction.HIDE, "email": CLSAction.MASK}
    }


def test_merge_cls_rules_with_catalog_and_schema() -> None:
    """
    Merging works for Table keys that carry both catalog and schema.
    """
    first = {Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.MASK}}
    second = {Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.NULLIFY}}

    assert merge_cls_rules(first, second) == {
        Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.NULLIFY}
    }


def test_merge_cls_rules_different_schemas_same_table() -> None:
    """
    The same table name under different schemas is treated as two tables.
    """
    first = {Table("users", schema="public"): {"ssn": CLSAction.HASH}}
    second = {Table("users", schema="private"): {"ssn": CLSAction.HIDE}}

    # No merging across schemas: both entries survive unchanged.
    assert merge_cls_rules(first, second) == {
        Table("users", schema="public"): {"ssn": CLSAction.HASH},
        Table("users", schema="private"): {"ssn": CLSAction.HIDE},
    }


def test_merge_cls_rules_complex_scenario() -> None:
    """
    A layered real-world scenario: org, department, and user rules merged.
    """
    # Organization-wide rules (less strict)
    org_rules = {
        Table("employees"): {"ssn": CLSAction.HASH, "salary": CLSAction.MASK},
        Table("customers"): {"email": CLSAction.MASK},
    }

    # Department-specific rules (stricter for certain columns)
    dept_rules = {
        Table("employees"): {"ssn": CLSAction.HIDE, "phone": CLSAction.MASK},
        Table("customers"): {"email": CLSAction.HASH, "credit_card": CLSAction.HIDE},
    }

    # User-specific rules (even stricter)
    user_rules = {
        Table("employees"): {"salary": CLSAction.HIDE},
    }

    assert merge_cls_rules(org_rules, dept_rules, user_rules) == {
        Table("employees"): {
            "ssn": CLSAction.HIDE,  # HIDE > HASH
            "salary": CLSAction.HIDE,  # HIDE > MASK
            "phone": CLSAction.MASK,
        },
        Table("customers"): {
            "email": CLSAction.MASK,  # MASK > HASH
            "credit_card": CLSAction.HIDE,
        },
    }
|