Compare commits

...

1 Commit

Author SHA1 Message Date
Beto Dealmeida
7e7b9f84aa WIP 2024-07-03 15:44:01 -04:00
7 changed files with 600 additions and 154 deletions

View File

@@ -172,7 +172,9 @@ class DatasourceKind(StrEnum):
PHYSICAL = "physical"
class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
class BaseDatasource(
AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods
"""A common interface to objects that are queryable
(tables and datasources)"""

View File

@@ -49,7 +49,7 @@ CONNECTION_HOST_DOWN_REGEX = re.compile(
class MssqlEngineSpec(BaseEngineSpec):
engine = "mssql"
engine_name = "Microsoft SQL Server"
limit_method = LimitMethod.WRAP_SQL
limit_method = LimitMethod.FORCE_LIMIT
max_column_name_length = 128
allows_cte_in_subquery = False
allow_limit_clause = False

View File

@@ -23,7 +23,7 @@ class TeradataEngineSpec(BaseEngineSpec):
engine = "teradatasql"
engine_name = "Teradata"
limit_method = LimitMethod.WRAP_SQL
limit_method = LimitMethod.FORCE_LIMIT
max_column_name_length = 30 # since 14.10 this is 128
allow_limit_clause = False
select_keywords = {"SELECT", "SEL"}

View File

@@ -17,6 +17,8 @@
# pylint: disable=too-many-lines
"""a collection of model-related helper classes and functions"""
from __future__ import annotations
import builtins
import dataclasses
import logging
@@ -806,7 +808,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def get_sqla_row_level_filters(
self,
template_processor: Optional[BaseTemplateProcessor] = None,
template_processor: BaseTemplateProcessor | None = None,
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the

View File

@@ -53,9 +53,8 @@ from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import (
CtasMethod,
insert_rls_as_subquery,
insert_rls_in_predicate,
ParsedQuery,
SQLStatement,
Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
@@ -205,67 +204,49 @@ def execute_sql_statement( # pylint: disable=too-many-statements
database: Database = query.database
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
# safer, but not supported in all databases.
insert_rls = (
insert_rls_as_subquery
if database.db_engine_spec.allows_subqueries
and database.db_engine_spec.allows_alias_in_select
else insert_rls_in_predicate
)
default_schema = database.get_default_schema_for_query(query)
parsed_statement = parsed_statement.apply_rls(query.catalog, default_schema)
# Insert any applicable RLS predicates
parsed_query = ParsedQuery(
str(
insert_rls(
parsed_query._parsed[0], # pylint: disable=protected-access
database.id,
query.schema,
)
),
engine=db_engine_spec.engine,
)
sql = parsed_query.stripped()
# This is a test to see if the query is being
# limited by either the dropdown or the sql.
# We are testing to see if more rows exist than the limit.
increased_limit = None if query.limit is None else query.limit + 1
if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
if parsed_statement.is_dml() and not database.allow_dml:
raise SupersetErrorException(
SupersetError(
message=__("Only SELECT statements are allowed against this database."),
message=__("DML statements are not allowed in this database."),
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
level=ErrorLevel.ERROR,
)
)
if apply_ctas:
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
query.tmp_table_name = (
f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
)
sql = parsed_query.as_create_table(
query.tmp_table_name,
schema_name=query.tmp_schema_name,
parsed_statement = parsed_statement.as_create_table(
Table(query.tmp_table_name, query.tmp_schema_name, query.catalog),
method=query.ctas_method,
)
query.select_as_cta_used = True
increased_limit = None if query.limit is None else query.limit + 1
# Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
if db_engine_spec.is_select_query(parsed_query) and not (
if parsed_statement.is_select() and not (
query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
):
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
query.limit = SQL_MAX_ROW
sql = apply_limit_if_exists(database, increased_limit, query, sql)
if query.limit:
# Increase limit by one so we can test if there are more rows when the
# database returns exactly the number of rows requested by the user.
parsed_statement = parsed_statement.apply_limit(increased_limit)
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = parsed_statement.format(strip=True)
sql = database.mutate_sql_based_on_config(sql)
try:
query.executed_sql = sql
@@ -333,19 +314,6 @@ def execute_sql_statement( # pylint: disable=too-many-statements
return SupersetResultSet(data, cursor_description, db_engine_spec)
def apply_limit_if_exists(
database: Database, increased_limit: Optional[int], query: Query, sql: str
) -> str:
if query.limit and increased_limit:
# We are fetching one more than the requested limit in order
# to test whether there are more rows than the limit. According to the DB
# Engine support it will choose top or limit parse
# Later, the extra row will be dropped before sending
# the results back to the user.
sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
return sql
def _serialize_payload(
payload: dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]:

View File

@@ -32,7 +32,7 @@ import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -308,7 +308,7 @@ def extract_tables_from_statement(
return set()
try:
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
except ParseError:
return set()
sources = pseudo_query.find_all(exp.Table)
@@ -433,7 +433,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def format(self, comments: bool = True) -> str:
def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Format the statement, optionally omitting comments.
"""
@@ -451,10 +451,93 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def apply_rls(
    self,
    catalog: str | None,
    schema: str | None,
) -> InternalRepresentation:
    """
    Apply Row Level Security (RLS) to the statement.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :return: The statement representation with RLS applied

    NOTE(review): the concrete `SQLStatement.apply_rls` returns a
    `SQLStatement`, not the internal representation — confirm the intended
    return type of this abstract method.
    """
    raise NotImplementedError()
def __str__(self) -> str:
    # Delegate to `format()` with its default arguments.
    return self.format()
class RLSAsPredicate:
    """
    Apply Row Level Security rules as a predicate.

    This transformer will apply any RLS predicates to the relevant tables. For
    example, given the RLS rule:

        table: some_table
        clause: id = 42

    If a user subject to the rule runs the following query:

        SELECT foo FROM some_table WHERE bar = 'baz'

    The query will be modified to:

        SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42

    This approach is probably less secure than using subqueries, so it's only
    used for databases without support for subqueries.
    """

    def __init__(self, rules: dict[Table, str]) -> None:
        # Map of table -> RLS predicate (a SQL string).
        self.rules = rules

    def __call__(self, node: exp.Expression) -> exp.Expression:
        # Only SELECT nodes are rewritten; all other nodes pass through.
        if not isinstance(node, exp.Select):
            return node

        # NOTE(review): `find` returns only the first `Table` under the
        # SELECT, so additional tables (joins, comma-separated FROM lists)
        # are not checked against the rules — confirm this is intended.
        table_node = node.find(exp.Table)
        if not table_node:
            return node

        table = Table(
            str(table_node.this),
            str(table_node.db) if table_node.db else None,
            str(table_node.catalog) if table_node.catalog else None,
        )
        if predicate := self.rules.get(table):
            # AND the RLS predicate together with any existing WHERE clause.
            if where := node.args.get("where"):
                predicate = exp.And(this=predicate, expression=where.this)
            node.set("where", exp.Where(this=predicate))

        return node
class RLSAsSubquery:
    """
    Apply Row Level Security rules by replacing tables with subqueries.

    A table subject to a rule is replaced by a subquery that selects from the
    table with the RLS predicate in its WHERE clause, preserving the original
    alias.
    """

    def __init__(self, rules: dict[Table, str]) -> None:
        # Map of table -> RLS predicate (a SQL string).
        self.rules = rules

    def __call__(self, node: exp.Expression) -> exp.Expression:
        # Only table nodes are rewritten; all other nodes pass through.
        if not isinstance(node, exp.Table):
            return node

        table = Table(
            str(node.this),
            str(node.db) if node.db else None,
            str(node.catalog) if node.catalog else None,
        )
        if predicate := self.rules.get(table):
            # Move the alias from the table onto the wrapping subquery.
            alias = node.alias
            node.set("alias", None)
            # NOTE(review): this returns a string from a sqlglot transformer —
            # confirm `Expression.transform` parses it back into a node.
            return f"(SELECT * FROM {node} WHERE {predicate}) AS {alias}"

        return node
class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
A SQL statement.
@@ -521,12 +604,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
dialect = SQLGLOT_DIALECTS.get(engine)
return extract_tables_from_statement(parsed, dialect)
def format(self, comments: bool = True) -> str:
def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
output = write.generate(
self._parsed,
copy=False,
comments=comments,
pretty=True,
)
return output.strip(" \t\r\n;") if strip else output
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -543,6 +633,186 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
for eq in set_item.find_all(exp.EQ)
}
def apply_rls(
    self,
    catalog: str | None,
    schema: str | None,
) -> SQLStatement:
    """
    Apply Row Level Security (RLS) to the statement.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :return: A new ``SQLStatement`` with any applicable RLS rules applied
    """
    # pylint: disable=import-outside-toplevel
    from superset.db_engine_specs import load_engine_specs

    statement = self._parsed.copy()

    # Collect all RLS rules relevant to tables referenced in the statement.
    rules = {}
    for table in self.tables:
        # NOTE(review): `_get_rls_for_table` declares a `database` parameter
        # that is not passed here — confirm where the database should come
        # from (WIP).
        if rls := self._get_rls_for_table(table, catalog, schema):
            rules[table] = rls

    if not rules:
        # No rules apply. Wrap the copy in `SQLStatement` so callers always
        # get the declared return type (the original returned the raw
        # sqlglot expression on this path).
        return SQLStatement(statement, self.engine)

    # There are two ways to insert RLS: replacing the table with a subquery
    # that has the RLS (safer), or appending the RLS to the WHERE clause.
    # The former is only possible for engines that support subqueries.
    use_subquery = all(
        engine_spec.allows_subqueries
        for engine_spec in load_engine_specs()
        if engine_spec.engine == self.engine
    )
    transformer = RLSAsSubquery(rules) if use_subquery else RLSAsPredicate(rules)

    return SQLStatement(statement.transform(transformer), self.engine)
def _get_rls_for_table(
    self,
    database: Database,
    table: Table,
    catalog: str | None,
    schema: str | None,
) -> exp.Expression | None:
    """
    Get the RLS predicate for a table, if any.

    :param database: The database where the SQL will run
    :param table: The table to get the RLS for
    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :return: The RLS predicate as a sqlglot expression, or ``None``
    """
    # pylint: disable=import-outside-toplevel
    from superset import db
    from superset.connectors.sqla.models import SqlaTable

    # Note: `.one_or_none()` must be called on the query, not inside
    # `filter(...)` (the original chained it onto `and_()`). The
    # parentheses around the `or` fallbacks matter too: `==` binds
    # tighter than `or`, so without them the default catalog/schema
    # would never be used.
    dataset = (
        db.session.query(SqlaTable)
        .filter(
            and_(
                SqlaTable.database_id == database.id,
                SqlaTable.catalog == (table.catalog or catalog),
                SqlaTable.schema == (table.schema or schema),
                SqlaTable.table_name == table.table,
            )
        )
        .one_or_none()
    )
    if not dataset:
        return None

    filters = dataset.get_sqla_row_level_filters()
    if not filters:
        return None

    # Compile the SQLAlchemy predicate to a SQL string and re-parse it with
    # sqlglot so it can be spliced into the statement.
    rls = and_(*filters).compile(
        dialect=database.get_dialect(),
        compile_kwargs={"literal_binds": True},
    )
    return sqlglot.parse_one(str(rls), dialect=self._dialect)
def is_dml(self) -> bool:
    """
    Check if the statement is DML.

    :return: True if the statement is DML
    """
    # Node types that modify data or schema.
    mutating_nodes = (
        exp.Insert,
        exp.Update,
        exp.Delete,
        exp.Merge,
        exp.Create,
        exp.Alter,
        exp.Drop,
        exp.TruncateTable,
    )
    return any(
        isinstance(node, mutating_nodes) for node in self._parsed.walk()
    )
def as_create_table(self, table: Table, method: CtasMethod) -> SQLStatement:
    """
    Convert the statement to a CREATE TABLE statement.

    :param table: The target table for the CTAS
    :param method: The CTAS method (e.g. table or view)
    :return: A new ``SQLStatement`` wrapping the original query
    """
    create_table = exp.Create(
        # `table` is a `Table` object; `parse_one` needs SQL text, so
        # convert it to its string form first (the original passed the
        # object directly).
        this=sqlglot.parse_one(str(table), into=exp.Table),
        kind=method.value,
        expression=self._parsed.copy(),
    )
    return SQLStatement(create_table, self.engine)
def is_select(self) -> bool:
    """
    Check if the statement is a SELECT statement.

    :return: True if the top-level parsed expression is a ``SELECT``
    """
    return isinstance(self._parsed, exp.Select)
def apply_limit(self, limit: int, force: bool = False) -> SQLStatement:
    """
    Apply a limit to the SQL.

    There are 3 strategies to limit queries, defined in the DB engine spec:

        1. `FORCE_LIMIT`: a limit is added to the query, or the existing one
           is replaced. This is the most efficient, since the database will
           produce at most the number of rows that Superset will display.
        2. `WRAP_SQL`: the query is wrapped in a subquery, and the limit is
           applied to the outer query. This might be inefficient, since the
           database optimizer might not be able to push the limit down to
           the inner query.
        3. `FETCH_MANY`: no limit is applied, but only `LIMIT` rows are
           fetched from the database. This is the least efficient, unless
           the database computes rows as they are read by the cursor, which
           is unlikely.

    :param limit: The limit to apply
    :param force: Apply limit even when a lower one is present
    :return: The SQL with the limit applied
    """
    # pylint: disable=import-outside-toplevel
    from superset.db_engine_specs import load_engine_specs
    from superset.db_engine_specs.base import LimitMethod

    # Limit methods declared by all engine specs matching this engine.
    methods = {
        engine_spec.limit_method
        for engine_spec in load_engine_specs()
        if engine_spec.engine == self.engine
    }
    if not methods:
        methods = {LimitMethod.FETCH_MANY}

    # When multiple methods are supported, we prefer the more generic one --
    # usually less efficient.
    preference = [
        LimitMethod.FETCH_MANY,
        LimitMethod.WRAP_SQL,
        LimitMethod.FORCE_LIMIT,
    ]
    method = sorted(methods, key=preference.index)[0]

    # FETCH_MANY limits at the cursor level, so the SQL is left unchanged;
    # non-SELECT statements are never limited.
    if not self.is_select() or method == LimitMethod.FETCH_MANY:
        return SQLStatement(self._parsed.copy(), self.engine)

    if method == LimitMethod.WRAP_SQL:
        # Wrap the original query in `SELECT * FROM (...) AS inner_qry`
        # and apply the limit to the outer query.
        limited = exp.Select(
            expressions=[exp.Star()],
            from_=exp.Subquery(subquery=self._parsed.copy(), alias="inner_qry"),
            limit=exp.Literal.number(limit),
        )
        return SQLStatement(limited, self.engine)

    # FORCE_LIMIT: read any existing limit so a larger one is not applied
    # unless `force` is set.
    current_limit: int | None = None
    for node in self._parsed.find_all(exp.Limit):
        current_limit = int(node.expression.this)
        break

    if force or current_limit is None or limit < current_limit:
        return SQLStatement(self._parsed.limit(limit), self.engine)

    return SQLStatement(self._parsed.copy(), self.engine)
class KQLSplitState(enum.Enum):
"""
@@ -666,11 +936,11 @@ class KustoKQLStatement(BaseSQLStatement[str]):
)
return set()
def format(self, comments: bool = True) -> str:
def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Pretty-format the SQL statement.
"""
return self._parsed
return self._parsed.strip(" \t\r\n;") if strip else self._parsed
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -712,7 +982,10 @@ class SQLScript:
"""
Pretty-format the SQL query.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)
return (
";\n".join(statement.format(comments) for statement in self.statements)
+ ";"
)
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -792,7 +1065,7 @@ class ParsedQuery:
Note: this uses sqlglot, since it's better at catching more edge cases.
"""
try:
statements = parse(self.stripped(), dialect=self._dialect)
statements = sqlglot.parse(self.stripped(), dialect=self._dialect)
except SqlglotError as ex:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
@@ -846,7 +1119,7 @@ class ParsedQuery:
return set()
try:
pseudo_query = parse_one(
pseudo_query = sqlglot.parse_one(
f"SELECT {literal.this}",
dialect=self._dialect,
)

View File

@@ -20,6 +20,7 @@ from typing import Optional
from unittest.mock import Mock
import pytest
import sqlglot
import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy import text
@@ -41,6 +42,8 @@ from superset.sql_parse import (
insert_rls_in_predicate,
KustoKQLStatement,
ParsedQuery,
RLSAsPredicate,
RLSAsSubquery,
sanitize_clause,
split_kql,
SQLScript,
@@ -119,8 +122,9 @@ def test_extract_tables_subselect() -> None:
"""
Test that tables inside subselects are parsed correctly.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT sub.*
FROM (
SELECT *
@@ -129,10 +133,13 @@ FROM (
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
) == {Table("t1", "s1"), Table("t2", "s2")}
)
== {Table("t1", "s1"), Table("t2", "s2")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT sub.*
FROM (
SELECT *
@@ -141,10 +148,13 @@ FROM (
) sub
WHERE sub.resolution = 'NONE'
"""
) == {Table("t1", "s1")}
)
== {Table("t1", "s1")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT * FROM t1
WHERE s11 > ANY (
SELECT COUNT(*) /* no hint */ FROM t2
@@ -156,7 +166,9 @@ WHERE s11 > ANY (
)
)
"""
) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
)
== {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
)
def test_extract_tables_select_in_expression() -> None:
@@ -227,24 +239,30 @@ def test_extract_tables_select_array() -> None:
"""
Test that queries selecting arrays work as expected.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
) == {Table("t1")}
)
== {Table("t1")}
)
def test_extract_tables_select_if() -> None:
"""
Test that queries with an ``IF`` work as expected.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
) == {Table("t1")}
)
== {Table("t1")}
)
def test_extract_tables_with_catalog() -> None:
@@ -312,29 +330,38 @@ def test_extract_tables_where_subquery() -> None:
"""
Test that tables in a ``WHERE`` subquery are parsed correctly.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
"""
) == {Table("t1"), Table("t2")}
)
== {Table("t1"), Table("t2")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
"""
) == {Table("t1"), Table("t2")}
)
== {Table("t1"), Table("t2")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT name
FROM t1
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
) == {Table("t1"), Table("t2")}
)
== {Table("t1"), Table("t2")}
)
def test_extract_tables_describe() -> None:
@@ -348,12 +375,15 @@ def test_extract_tables_show_partitions() -> None:
"""
Test ``SHOW PARTITIONS``.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC
"""
) == {Table("orders")}
)
== {Table("orders")}
)
def test_extract_tables_join() -> None:
@@ -365,8 +395,9 @@ def test_extract_tables_join() -> None:
Table("t2"),
}
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
JOIN (
@@ -377,10 +408,13 @@ JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
@@ -391,10 +425,13 @@ LEFT INNER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
@@ -405,10 +442,13 @@ RIGHT OUTER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
@@ -419,15 +459,18 @@ FULL OUTER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
def test_extract_tables_semi_join() -> None:
"""
Test ``LEFT SEMI JOIN``.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
@@ -438,15 +481,18 @@ LEFT SEMI JOIN (
) b
ON a.data = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
def test_extract_tables_combinations() -> None:
"""
Test a complex case with nested queries.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT * FROM t1
WHERE s11 > ANY (
SELECT * FROM t1 UNION ALL SELECT * FROM (
@@ -460,10 +506,13 @@ WHERE s11 > ANY (
)
)
"""
) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
)
== {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT * FROM (
SELECT * FROM (
SELECT * FROM (
@@ -472,45 +521,56 @@ SELECT * FROM (
) AS S2
) AS S3
"""
) == {Table("EmployeeS")}
)
== {Table("EmployeeS")}
)
def test_extract_tables_with() -> None:
"""
Test ``WITH``.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
"""
) == {Table("t1"), Table("t2"), Table("t3")}
)
== {Table("t1"), Table("t2"), Table("t3")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
"""
) == {Table("t1")}
)
== {Table("t1")}
)
def test_extract_tables_reusing_aliases() -> None:
"""
Test that the parser follows aliases.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
) == {Table("src")}
)
== {Table("src")}
)
# weird query with circular dependency
assert (
@@ -547,8 +607,9 @@ def test_extract_tables_complex() -> None:
"""
Test a few complex queries.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT
@@ -569,23 +630,29 @@ FROM (
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
) == {
Table("my_l_table"),
Table("my_b_table"),
Table("my_t_table"),
Table("inner_table"),
}
)
== {
Table("my_l_table"),
Table("my_b_table"),
Table("my_t_table"),
Table("inner_table"),
}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
) == {Table("table_a"), Table("table_b"), Table("table_c")}
)
== {Table("table_a"), Table("table_b"), Table("table_c")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT somecol AS somecol
FROM (
WITH bla AS (
@@ -629,51 +696,63 @@ FROM (
LIMIT 50000
)
"""
) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
)
== {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
)
def test_extract_tables_mixed_from_clause() -> None:
"""
Test that the parser handles a ``FROM`` clause with table and subselect.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
) == {Table("table_a"), Table("table_b"), Table("table_c")}
)
== {Table("table_a"), Table("table_b"), Table("table_c")}
)
def test_extract_tables_nested_select() -> None:
"""
Test that the parser handles selects inside functions.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
""",
"mysql",
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
""",
"mysql",
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
def test_extract_tables_complex_cte_with_prefix() -> None:
"""
Test that the parser handles CTEs with prefixes.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
@@ -685,21 +764,26 @@ FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
) == {Table("SalesOrderHeader")}
)
== {Table("SalesOrderHeader")}
)
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
"""
Test that aliases that are keywords are parsed correctly.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
) == {Table("foo")}
)
== {Table("foo")}
)
def test_update() -> None:
@@ -1841,7 +1925,7 @@ def test_sqlquery() -> None:
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
assert len(script.statements) == 2
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
assert script.format() == "SELECT\n 1;\nSELECT\n 2;"
assert script.statements[0].format() == "SELECT\n 1"
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
@@ -2058,3 +2142,120 @@ on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
]
@pytest.mark.parametrize(
    "sql,rules,expected",
    [
        # Unqualified table: replaced by a subquery applying the rule.
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table"): "id = 42"},
            "SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t",
        ),
        # The outer query's WHERE clause is preserved.
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
            {Table("some_table"): "id = 42"},
            (
                "SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t "
                "WHERE bar = 'baz'"
            ),
        ),
        # Schema-qualified rule matching the query's schema.
        (
            "SELECT t.foo FROM schema1.some_table AS t",
            {Table("some_table", "schema1"): "id = 42"},
            "SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t",
        ),
        # Schema mismatch: rule does not apply, query unchanged.
        (
            "SELECT t.foo FROM schema1.some_table AS t",
            {Table("some_table", "schema2"): "id = 42"},
            "SELECT t.foo FROM schema1.some_table AS t",
        ),
        # Catalog-qualified rule matching the query's catalog.
        (
            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            "SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t",
        ),
        # Catalog mismatch: rule does not apply, query unchanged.
        (
            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog2"): "id = 42"},
            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
        ),
    ],
)
def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None:
    """
    Test the `RLSAsSubquery` transformer.
    """
    statement = sqlglot.parse_one(sql)
    transformer = RLSAsSubquery(rules)
    assert str(statement.transform(transformer)) == expected
@pytest.mark.parametrize(
    "sql,rules,expected",
    [
        # No existing WHERE: the RLS predicate becomes the WHERE clause.
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table"): "id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE id = 42",
        ),
        # Existing WHERE: the RLS predicate is ANDed in front of it.
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
            {Table("some_table"): "id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'",
        ),
    ],
)
def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None:
    """
    Test the `RLSAsPredicate` transformer.
    """
    statement = sqlglot.parse_one(sql)
    transformer = RLSAsPredicate(rules)
    assert str(statement.transform(transformer)) == expected
@pytest.mark.parametrize(
    "sql,engine,limit,force,expected",
    [
        # A lower limit replaces the existing TOP.
        (
            "SELECT TOP 10 * FROM Customers",
            "teradatasql",
            5,
            False,
            "SELECT\nTOP 5\n *\nFROM Customers",
        ),
        # A higher limit without `force` keeps the existing TOP.
        (
            "SELECT TOP 10 * FROM Customers",
            "teradatasql",
            15,
            False,
            "SELECT\nTOP 10\n *\nFROM Customers",
        ),
        # A higher limit with `force` replaces the existing TOP.
        (
            "SELECT TOP 10 * FROM Customers",
            "teradatasql",
            15,
            True,
            "SELECT\nTOP 15\n *\nFROM Customers",
        ),
        # Same behavior for MSSQL.
        (
            "SELECT TOP 10 * FROM Customers",
            "mssql",
            15,
            True,
            "SELECT\nTOP 15\n *\nFROM Customers",
        ),
    ],
)
def test_apply_limit(
    sql: str,
    engine: str,
    limit: int,
    force: bool,
    expected: str,
) -> None:
    """
    Test the `apply_limit` function.
    """
    assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected