Compare commits

...

1 Commits

Author SHA1 Message Date
Beto Dealmeida
7e7b9f84aa WIP 2024-07-03 15:44:01 -04:00
7 changed files with 600 additions and 154 deletions

View File

@@ -172,7 +172,9 @@ class DatasourceKind(StrEnum):
PHYSICAL = "physical" PHYSICAL = "physical"
class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods class BaseDatasource(
AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods
"""A common interface to objects that are queryable """A common interface to objects that are queryable
(tables and datasources)""" (tables and datasources)"""

View File

@@ -49,7 +49,7 @@ CONNECTION_HOST_DOWN_REGEX = re.compile(
class MssqlEngineSpec(BaseEngineSpec): class MssqlEngineSpec(BaseEngineSpec):
engine = "mssql" engine = "mssql"
engine_name = "Microsoft SQL Server" engine_name = "Microsoft SQL Server"
limit_method = LimitMethod.WRAP_SQL limit_method = LimitMethod.FORCE_LIMIT
max_column_name_length = 128 max_column_name_length = 128
allows_cte_in_subquery = False allows_cte_in_subquery = False
allow_limit_clause = False allow_limit_clause = False

View File

@@ -23,7 +23,7 @@ class TeradataEngineSpec(BaseEngineSpec):
engine = "teradatasql" engine = "teradatasql"
engine_name = "Teradata" engine_name = "Teradata"
limit_method = LimitMethod.WRAP_SQL limit_method = LimitMethod.FORCE_LIMIT
max_column_name_length = 30 # since 14.10 this is 128 max_column_name_length = 30 # since 14.10 this is 128
allow_limit_clause = False allow_limit_clause = False
select_keywords = {"SELECT", "SEL"} select_keywords = {"SELECT", "SEL"}

View File

@@ -17,6 +17,8 @@
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
"""a collection of model-related helper classes and functions""" """a collection of model-related helper classes and functions"""
from __future__ import annotations
import builtins import builtins
import dataclasses import dataclasses
import logging import logging
@@ -806,7 +808,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def get_sqla_row_level_filters( def get_sqla_row_level_filters(
self, self,
template_processor: Optional[BaseTemplateProcessor] = None, template_processor: BaseTemplateProcessor | None = None,
) -> list[TextClause]: ) -> list[TextClause]:
""" """
Return the appropriate row level security filters for this table and the Return the appropriate row level security filters for this table and the

View File

@@ -53,9 +53,8 @@ from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet from superset.result_set import SupersetResultSet
from superset.sql_parse import ( from superset.sql_parse import (
CtasMethod, CtasMethod,
insert_rls_as_subquery,
insert_rls_in_predicate,
ParsedQuery, ParsedQuery,
SQLStatement,
Table, Table,
) )
from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.limiting_factor import LimitingFactor
@@ -205,67 +204,49 @@ def execute_sql_statement( # pylint: disable=too-many-statements
database: Database = query.database database: Database = query.database
db_engine_spec = database.db_engine_spec db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine) parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"): if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery default_schema = database.get_default_schema_for_query(query)
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is parsed_statement = parsed_statement.apply_rls(query.catalog, default_schema)
# safer, but not supported in all databases.
insert_rls = (
insert_rls_as_subquery
if database.db_engine_spec.allows_subqueries
and database.db_engine_spec.allows_alias_in_select
else insert_rls_in_predicate
)
# Insert any applicable RLS predicates if parsed_statement.is_dml() and not database.allow_dml:
parsed_query = ParsedQuery(
str(
insert_rls(
parsed_query._parsed[0], # pylint: disable=protected-access
database.id,
query.schema,
)
),
engine=db_engine_spec.engine,
)
sql = parsed_query.stripped()
# This is a test to see if the query is being
# limited by either the dropdown or the sql.
# We are testing to see if more rows exist than the limit.
increased_limit = None if query.limit is None else query.limit + 1
if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
raise SupersetErrorException( raise SupersetErrorException(
SupersetError( SupersetError(
message=__("Only SELECT statements are allowed against this database."), message=__("DML statements are not allowed in this database."),
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR, error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
level=ErrorLevel.ERROR, level=ErrorLevel.ERROR,
) )
) )
if apply_ctas: if apply_ctas:
if not query.tmp_table_name: if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time) start_dttm = datetime.fromtimestamp(query.start_time)
query.tmp_table_name = ( query.tmp_table_name = (
f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}' f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
) )
sql = parsed_query.as_create_table( parsed_statement = parsed_statement.as_create_table(
query.tmp_table_name, Table(query.tmp_table_name, query.tmp_schema_name, query.catalog),
schema_name=query.tmp_schema_name,
method=query.ctas_method, method=query.ctas_method,
) )
query.select_as_cta_used = True query.select_as_cta_used = True
increased_limit = None if query.limit is None else query.limit + 1
# Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
if db_engine_spec.is_select_query(parsed_query) and not ( if parsed_statement.is_select() and not (
query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
): ):
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW): if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
query.limit = SQL_MAX_ROW query.limit = SQL_MAX_ROW
sql = apply_limit_if_exists(database, increased_limit, query, sql)
if query.limit:
# Increase limit by one so we can test if there are more rows when the
# database returns exactly the number of rows requested by the user.
parsed_statement = parsed_statement.apply_limit(increased_limit)
# Hook to allow environment-specific mutation (usually comments) to the SQL # Hook to allow environment-specific mutation (usually comments) to the SQL
sql = parsed_statement.format(strip=True)
sql = database.mutate_sql_based_on_config(sql) sql = database.mutate_sql_based_on_config(sql)
try: try:
query.executed_sql = sql query.executed_sql = sql
@@ -333,19 +314,6 @@ def execute_sql_statement( # pylint: disable=too-many-statements
return SupersetResultSet(data, cursor_description, db_engine_spec) return SupersetResultSet(data, cursor_description, db_engine_spec)
def apply_limit_if_exists(
database: Database, increased_limit: Optional[int], query: Query, sql: str
) -> str:
if query.limit and increased_limit:
# We are fetching one more than the requested limit in order
# to test whether there are more rows than the limit. According to the DB
# Engine support it will choose top or limit parse
# Later, the extra row will be dropped before sending
# the results back to the user.
sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
return sql
def _serialize_payload( def _serialize_payload(
payload: dict[Any, Any], use_msgpack: Optional[bool] = False payload: dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]: ) -> Union[bytes, str]:

View File

@@ -32,7 +32,7 @@ import sqlparse
from flask_babel import gettext as __ from flask_babel import gettext as __
from jinja2 import nodes from jinja2 import nodes
from sqlalchemy import and_ from sqlalchemy import and_
from sqlglot import exp, parse, parse_one from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -308,7 +308,7 @@ def extract_tables_from_statement(
return set() return set()
try: try:
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect) pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
except ParseError: except ParseError:
return set() return set()
sources = pseudo_query.find_all(exp.Table) sources = pseudo_query.find_all(exp.Table)
@@ -433,7 +433,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
""" """
raise NotImplementedError() raise NotImplementedError()
def format(self, comments: bool = True) -> str: def format(self, comments: bool = True, strip: bool = False) -> str:
""" """
Format the statement, optionally omitting comments. Format the statement, optionally omitting comments.
""" """
@@ -451,10 +451,93 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
""" """
raise NotImplementedError() raise NotImplementedError()
def apply_rls(
    self,
    catalog: str | None,
    schema: str | None,
) -> BaseSQLStatement[InternalRepresentation]:
    """
    Apply Row Level Security to the statement.

    This is the abstract hook; the base implementation always raises and
    concrete statement classes are expected to override it.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :return: A new statement with RLS applied
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
def __str__(self) -> str: def __str__(self) -> str:
return self.format() return self.format()
class RLSAsPredicate:
    """
    Apply Row Level Security rules as predicates.

    This transformer will apply any RLS predicates to the relevant tables. For
    example, given the RLS rule:

        table: some_table
        clause: id = 42

    If a user subject to the rule runs the following query:

        SELECT foo FROM some_table WHERE bar = 'baz'

    The query will be modified to:

        SELECT foo FROM some_table WHERE id = 42 AND bar = 'baz'

    This approach is probably less secure than using subqueries, so it's only used
    for databases without support for subqueries.

    NOTE(review): rule predicates are injected as given (not qualified with the
    table alias), so a rule on a joined table may reference an ambiguous column --
    confirm whether rules should be qualified before use.
    """

    def __init__(self, rules: dict[Table, str | exp.Expression]) -> None:
        """
        :param rules: Mapping from table to its RLS predicate, either as a SQL
            string or as an already parsed expression
        """
        self.rules = rules

    def __call__(self, node: exp.Expression) -> exp.Expression:
        """
        Add the RLS predicates of every table read directly by a ``SELECT``.

        :param node: The node currently visited by ``Expression.transform``
        :return: The same node, with its ``WHERE`` clause extended when needed
        """
        if not isinstance(node, exp.Select):
            return node

        predicates: list[exp.Expression] = []
        for table_node in node.find_all(exp.Table):
            # Only consider tables that belong to this SELECT; tables inside
            # nested subqueries are handled when their own SELECT is visited.
            if table_node.find_ancestor(exp.Select) is not node:
                continue

            # ``name``/``db``/``catalog`` return unquoted strings ("" if absent),
            # so quoted identifiers still match the rule keys.
            table = Table(
                table_node.name,
                table_node.db or None,
                table_node.catalog or None,
            )
            rule = self.rules.get(table)
            if rule is not None:
                predicates.append(exp.condition(rule))

        if not predicates:
            return node

        if where := node.args.get("where"):
            predicates.append(where.this)

        # ``exp.and_`` wraps connector operands (e.g. an existing ``a OR b``) in
        # parentheses, so the user's filter can never short-circuit the rule.
        node.set("where", exp.Where(this=exp.and_(*predicates)))
        return node
class RLSAsSubquery:
    """
    Apply Row Level Security rules by replacing tables with filtered subqueries.

    Given the RLS rule ``id = 42`` on ``some_table``, the query:

        SELECT t.foo FROM some_table AS t

    is rewritten to:

        SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t

    When the table has no alias, the subquery is aliased with the table name so
    that existing column references keep resolving.
    """

    def __init__(self, rules: dict[Table, str | exp.Expression]) -> None:
        """
        :param rules: Mapping from table to its RLS predicate, either as a SQL
            string or as an already parsed expression
        """
        self.rules = rules

    def __call__(self, node: exp.Expression) -> exp.Expression:
        """
        Replace a table that has an RLS rule with a filtered subquery.

        :param node: The node currently visited by ``Expression.transform``
        :return: A ``Subquery`` expression when a rule applies, otherwise ``node``
        """
        if not isinstance(node, exp.Table):
            return node

        # ``name``/``db``/``catalog`` return unquoted strings ("" if absent).
        table = Table(node.name, node.db or None, node.catalog or None)
        rule = self.rules.get(table)
        if rule is None:
            return node

        # Keep the user's alias; fall back to the table identifier itself.
        alias = node.args.get("alias")
        alias = alias.copy() if alias is not None else exp.TableAlias(
            this=node.this.copy()
        )

        # The inner reference must not carry the alias, which now belongs to the
        # subquery. Work on a copy so the visited tree is left untouched.
        inner_table = node.copy()
        inner_table.set("alias", None)

        filtered = exp.select("*").from_(inner_table).where(exp.condition(rule))
        return exp.Subquery(this=filtered, alias=alias)
class SQLStatement(BaseSQLStatement[exp.Expression]): class SQLStatement(BaseSQLStatement[exp.Expression]):
""" """
A SQL statement. A SQL statement.
@@ -521,12 +604,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
dialect = SQLGLOT_DIALECTS.get(engine) dialect = SQLGLOT_DIALECTS.get(engine)
return extract_tables_from_statement(parsed, dialect) return extract_tables_from_statement(parsed, dialect)
def format(self, comments: bool = True) -> str: def format(self, comments: bool = True, strip: bool = False) -> str:
""" """
Pretty-format the SQL statement. Pretty-format the SQL statement.
""" """
write = Dialect.get_or_raise(self._dialect) write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True) output = write.generate(
self._parsed,
copy=False,
comments=comments,
pretty=True,
)
return output.strip(" \t\r\n;") if strip else output
def get_settings(self) -> dict[str, str | bool]: def get_settings(self) -> dict[str, str | bool]:
""" """
@@ -543,6 +633,186 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
for eq in set_item.find_all(exp.EQ) for eq in set_item.find_all(exp.EQ)
} }
def apply_rls(
    self,
    catalog: str | None,
    schema: str | None,
    database: Database | None = None,
) -> SQLStatement:
    """
    Apply Row Level Security to the SQL.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :param database: The database where the SQL will run; required to look up
        the datasets (and hence the RLS rules) matching the referenced tables
    :return: A new statement with RLS applied
    :raises ValueError: if no database is given, since silently skipping RLS
        would be unsafe
    """
    if database is None:
        raise ValueError("A database is required to apply RLS to a statement")

    # collect all relevant RLS rules
    rules = {}
    for table in self.tables:
        rls = self._get_rls_for_table(database, table, catalog, schema)
        if rls is not None:
            rules[table] = rls

    if not rules:
        # Always hand back a statement object, never the raw expression.
        return SQLStatement(self._parsed.copy(), self.engine)

    # There are two ways to insert RLS: either replacing the table with a
    # subquery that has the RLS, or appending the RLS to the ``WHERE`` clause.
    # The former is safer, but not supported in all databases.
    engine_spec = database.db_engine_spec
    use_subquery = (
        engine_spec.allows_subqueries and engine_spec.allows_alias_in_select
    )
    transformer = RLSAsSubquery(rules) if use_subquery else RLSAsPredicate(rules)

    # ``transform`` works on a copy by default, so ``self`` is not modified.
    return SQLStatement(self._parsed.transform(transformer), self.engine)
def _get_rls_for_table(
    self,
    database: Database,
    table: Table,
    catalog: str | None,
    schema: str | None,
) -> exp.Expression | None:
    """
    Get the RLS for a table.

    :param database: The database the table belongs to
    :param table: The table to get the RLS for
    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :return: The combined RLS predicate for the table, or ``None`` when there is
        no matching dataset or the dataset has no RLS filters
    """
    # pylint: disable=import-outside-toplevel
    from superset import db
    from superset.connectors.sqla.models import SqlaTable

    dataset = (
        db.session.query(SqlaTable)
        .filter(
            and_(
                SqlaTable.database_id == database.id,
                # Parenthesised on purpose: ``==`` binds tighter than ``or``, so
                # the fallback to the default must be resolved first.
                SqlaTable.catalog == (table.catalog or catalog),
                SqlaTable.schema == (table.schema or schema),
                SqlaTable.table_name == table.table,
            )
        )
        .one_or_none()
    )
    if not dataset:
        return None

    filters = dataset.get_sqla_row_level_filters()
    if not filters:
        return None

    # Render the SQLAlchemy clauses to literal SQL so sqlglot can parse them.
    rls = and_(*filters).compile(
        dialect=database.get_dialect(),
        compile_kwargs={"literal_binds": True},
    )

    return sqlglot.parse_one(str(rls), dialect=self._dialect)
def is_dml(self) -> bool:
    """
    Tell whether the statement modifies data or schema.

    The whole tree is inspected, so a mutating node nested anywhere inside the
    statement is detected as well.

    :return: True if any node is one of the mutating expression types
    """
    mutating_types = (
        exp.Insert,
        exp.Update,
        exp.Delete,
        exp.Merge,
        exp.Create,
        exp.Alter,
        exp.Drop,
        exp.TruncateTable,
    )
    return any(isinstance(node, mutating_types) for node in self._parsed.walk())
def as_create_table(self, table: Table, method: CtasMethod) -> SQLStatement:
    """
    Convert the statement to a CREATE TABLE statement.

    :param table: The fully qualified target to create
    :param method: The CTAS method; its value is used as the ``CREATE`` kind
    :return: A new statement wrapping a copy of this one in a ``CREATE ... AS``
    """
    # Build the target from its parts instead of parsing text: ``parse_one``
    # expects a SQL string, not a ``Table`` object, and building the node
    # directly avoids any re-interpretation of the names.
    target = exp.table_(table.table, db=table.schema, catalog=table.catalog)
    create_table = exp.Create(
        this=target,
        kind=method.value,
        expression=self._parsed.copy(),
    )

    return SQLStatement(create_table, self.engine)
def is_select(self) -> bool:
    """
    Tell whether the root of the parsed statement is a ``SELECT``.

    :return: True if the statement is a SELECT statement
    """
    root = self._parsed
    return isinstance(root, exp.Select)
def apply_limit(self, limit: int, force: bool = False) -> SQLStatement:
    """
    Apply a limit to the SQL.

    There are 3 strategies to limit queries, defined in the DB engine spec:

        1. `FORCE_LIMIT`: a limit is added to the query, or the existing one is
            replaced. This is the most efficient, since the database will produce
            at most the number of rows that Superset will display.
        2. `WRAP_SQL`: the query is wrapped in a subquery, and the limit is applied
            to the outer query. This might be inefficient, since the database
            optimizer might not be able to push the limit down to the inner query.
        3. `FETCH_MANY`: no limit is applied, but only `LIMIT` rows are fetched
            from the database. This is the least efficient, unless the database
            computes rows as they are read by the cursor, which is unlikely.

    :param limit: The limit to apply
    :param force: Apply limit even when a lower one is present
    :return: A new statement with the limit applied; ``self`` is not modified
    """
    # pylint: disable=import-outside-toplevel
    from superset.db_engine_specs import load_engine_specs
    from superset.db_engine_specs.base import LimitMethod

    methods = {
        engine_spec.limit_method
        for engine_spec in load_engine_specs()
        if engine_spec.engine == self.engine
    } or {LimitMethod.FETCH_MANY}

    # When multiple methods are supported, we prefer the more generic one --
    # usually less efficient.
    preference = [
        LimitMethod.FETCH_MANY,
        LimitMethod.WRAP_SQL,
        LimitMethod.FORCE_LIMIT,
    ]
    method = min(methods, key=preference.index)

    if not self.is_select() or method == LimitMethod.FETCH_MANY:
        return SQLStatement(self._parsed.copy(), self.engine)

    if method == LimitMethod.WRAP_SQL:
        limited = (
            exp.select("*")
            .from_(self._parsed.copy().subquery("inner_qry"))
            .limit(limit)
        )
        return SQLStatement(limited, self.engine)

    # Only the limit of the outermost query matters; limits inside nested
    # subqueries must not be mistaken for it.
    current_limit: int | None = None
    limit_node = self._parsed.args.get("limit")
    if limit_node is not None:
        try:
            current_limit = int(limit_node.expression.this)
        except (AttributeError, TypeError, ValueError):
            # Non-literal limit (parameter, expression, ...): treat it as
            # unknown so that the requested limit is enforced.
            current_limit = None

    if force or current_limit is None or limit < current_limit:
        # ``limit()`` copies by default, so the original tree is left untouched.
        return SQLStatement(self._parsed.limit(limit), self.engine)

    return SQLStatement(self._parsed.copy(), self.engine)
class KQLSplitState(enum.Enum): class KQLSplitState(enum.Enum):
""" """
@@ -666,11 +936,11 @@ class KustoKQLStatement(BaseSQLStatement[str]):
) )
return set() return set()
def format(self, comments: bool = True) -> str: def format(self, comments: bool = True, strip: bool = False) -> str:
""" """
Pretty-format the SQL statement. Pretty-format the SQL statement.
""" """
return self._parsed return self._parsed.strip(" \t\r\n;") if strip else self._parsed
def get_settings(self) -> dict[str, str | bool]: def get_settings(self) -> dict[str, str | bool]:
""" """
@@ -712,7 +982,10 @@ class SQLScript:
""" """
Pretty-format the SQL query. Pretty-format the SQL query.
""" """
return ";\n".join(statement.format(comments) for statement in self.statements) return (
";\n".join(statement.format(comments) for statement in self.statements)
+ ";"
)
def get_settings(self) -> dict[str, str | bool]: def get_settings(self) -> dict[str, str | bool]:
""" """
@@ -792,7 +1065,7 @@ class ParsedQuery:
Note: this uses sqlglot, since it's better at catching more edge cases. Note: this uses sqlglot, since it's better at catching more edge cases.
""" """
try: try:
statements = parse(self.stripped(), dialect=self._dialect) statements = sqlglot.parse(self.stripped(), dialect=self._dialect)
except SqlglotError as ex: except SqlglotError as ex:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql) logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
@@ -846,7 +1119,7 @@ class ParsedQuery:
return set() return set()
try: try:
pseudo_query = parse_one( pseudo_query = sqlglot.parse_one(
f"SELECT {literal.this}", f"SELECT {literal.this}",
dialect=self._dialect, dialect=self._dialect,
) )

View File

@@ -20,6 +20,7 @@ from typing import Optional
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
import sqlglot
import sqlparse import sqlparse
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from sqlalchemy import text from sqlalchemy import text
@@ -41,6 +42,8 @@ from superset.sql_parse import (
insert_rls_in_predicate, insert_rls_in_predicate,
KustoKQLStatement, KustoKQLStatement,
ParsedQuery, ParsedQuery,
RLSAsPredicate,
RLSAsSubquery,
sanitize_clause, sanitize_clause,
split_kql, split_kql,
SQLScript, SQLScript,
@@ -119,7 +122,8 @@ def test_extract_tables_subselect() -> None:
""" """
Test that tables inside subselects are parsed correctly. Test that tables inside subselects are parsed correctly.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SELECT sub.* SELECT sub.*
FROM ( FROM (
@@ -129,9 +133,12 @@ FROM (
) sub, s2.t2 ) sub, s2.t2
WHERE sub.resolution = 'NONE' WHERE sub.resolution = 'NONE'
""" """
) == {Table("t1", "s1"), Table("t2", "s2")} )
== {Table("t1", "s1"), Table("t2", "s2")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT sub.* SELECT sub.*
FROM ( FROM (
@@ -141,9 +148,12 @@ FROM (
) sub ) sub
WHERE sub.resolution = 'NONE' WHERE sub.resolution = 'NONE'
""" """
) == {Table("t1", "s1")} )
== {Table("t1", "s1")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT * FROM t1 SELECT * FROM t1
WHERE s11 > ANY ( WHERE s11 > ANY (
@@ -156,7 +166,9 @@ WHERE s11 > ANY (
) )
) )
""" """
) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")} )
== {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
)
def test_extract_tables_select_in_expression() -> None: def test_extract_tables_select_in_expression() -> None:
@@ -227,24 +239,30 @@ def test_extract_tables_select_array() -> None:
""" """
Test that queries selecting arrays work as expected. Test that queries selecting arrays work as expected.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SELECT ARRAY[1, 2, 3] AS my_array SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10 FROM t1 LIMIT 10
""" """
) == {Table("t1")} )
== {Table("t1")}
)
def test_extract_tables_select_if() -> None: def test_extract_tables_select_if() -> None:
""" """
Test that queries with an ``IF`` work as expected. Test that queries with an ``IF`` work as expected.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL) SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10 FROM t1 LIMIT 10
""" """
) == {Table("t1")} )
== {Table("t1")}
)
def test_extract_tables_with_catalog() -> None: def test_extract_tables_with_catalog() -> None:
@@ -312,29 +330,38 @@ def test_extract_tables_where_subquery() -> None:
""" """
Test that tables in a ``WHERE`` subquery are parsed correctly. Test that tables in a ``WHERE`` subquery are parsed correctly.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SELECT name SELECT name
FROM t1 FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2) WHERE regionkey = (SELECT max(regionkey) FROM t2)
""" """
) == {Table("t1"), Table("t2")} )
== {Table("t1"), Table("t2")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT name SELECT name
FROM t1 FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2) WHERE regionkey IN (SELECT regionkey FROM t2)
""" """
) == {Table("t1"), Table("t2")} )
== {Table("t1"), Table("t2")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT name SELECT name
FROM t1 FROM t1
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey); WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
""" """
) == {Table("t1"), Table("t2")} )
== {Table("t1"), Table("t2")}
)
def test_extract_tables_describe() -> None: def test_extract_tables_describe() -> None:
@@ -348,12 +375,15 @@ def test_extract_tables_show_partitions() -> None:
""" """
Test ``SHOW PARTITIONS``. Test ``SHOW PARTITIONS``.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SHOW PARTITIONS FROM orders SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC WHERE ds >= '2013-01-01' ORDER BY ds DESC
""" """
) == {Table("orders")} )
== {Table("orders")}
)
def test_extract_tables_join() -> None: def test_extract_tables_join() -> None:
@@ -365,7 +395,8 @@ def test_extract_tables_join() -> None:
Table("t2"), Table("t2"),
} }
assert extract_tables( assert (
extract_tables(
""" """
SELECT a.date, b.name SELECT a.date, b.name
FROM left_table a FROM left_table a
@@ -377,9 +408,12 @@ JOIN (
) b ) b
ON a.date = b.date ON a.date = b.date
""" """
) == {Table("left_table"), Table("right_table")} )
== {Table("left_table"), Table("right_table")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT a.date, b.name SELECT a.date, b.name
FROM left_table a FROM left_table a
@@ -391,9 +425,12 @@ LEFT INNER JOIN (
) b ) b
ON a.date = b.date ON a.date = b.date
""" """
) == {Table("left_table"), Table("right_table")} )
== {Table("left_table"), Table("right_table")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT a.date, b.name SELECT a.date, b.name
FROM left_table a FROM left_table a
@@ -405,9 +442,12 @@ RIGHT OUTER JOIN (
) b ) b
ON a.date = b.date ON a.date = b.date
""" """
) == {Table("left_table"), Table("right_table")} )
== {Table("left_table"), Table("right_table")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT a.date, b.name SELECT a.date, b.name
FROM left_table a FROM left_table a
@@ -419,14 +459,17 @@ FULL OUTER JOIN (
) b ) b
ON a.date = b.date ON a.date = b.date
""" """
) == {Table("left_table"), Table("right_table")} )
== {Table("left_table"), Table("right_table")}
)
def test_extract_tables_semi_join() -> None: def test_extract_tables_semi_join() -> None:
""" """
Test ``LEFT SEMI JOIN``. Test ``LEFT SEMI JOIN``.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SELECT a.date, b.name SELECT a.date, b.name
FROM left_table a FROM left_table a
@@ -438,14 +481,17 @@ LEFT SEMI JOIN (
) b ) b
ON a.data = b.date ON a.data = b.date
""" """
) == {Table("left_table"), Table("right_table")} )
== {Table("left_table"), Table("right_table")}
)
def test_extract_tables_combinations() -> None: def test_extract_tables_combinations() -> None:
""" """
Test a complex case with nested queries. Test a complex case with nested queries.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SELECT * FROM t1 SELECT * FROM t1
WHERE s11 > ANY ( WHERE s11 > ANY (
@@ -460,9 +506,12 @@ WHERE s11 > ANY (
) )
) )
""" """
) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")} )
== {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT * FROM ( SELECT * FROM (
SELECT * FROM ( SELECT * FROM (
@@ -472,14 +521,17 @@ SELECT * FROM (
) AS S2 ) AS S2
) AS S3 ) AS S3
""" """
) == {Table("EmployeeS")} )
== {Table("EmployeeS")}
)
def test_extract_tables_with() -> None: def test_extract_tables_with() -> None:
""" """
Test ``WITH``. Test ``WITH``.
""" """
assert extract_tables( assert (
extract_tables(
""" """
WITH WITH
x AS (SELECT a FROM t1), x AS (SELECT a FROM t1),
@@ -487,9 +539,12 @@ WITH
z AS (SELECT b AS c FROM t3) z AS (SELECT b AS c FROM t3)
SELECT c FROM z SELECT c FROM z
""" """
) == {Table("t1"), Table("t2"), Table("t3")} )
== {Table("t1"), Table("t2"), Table("t3")}
)
assert extract_tables( assert (
extract_tables(
""" """
WITH WITH
x AS (SELECT a FROM t1), x AS (SELECT a FROM t1),
@@ -497,20 +552,25 @@ WITH
z AS (SELECT b AS c FROM y) z AS (SELECT b AS c FROM y)
SELECT c FROM z SELECT c FROM z
""" """
) == {Table("t1")} )
== {Table("t1")}
)
def test_extract_tables_reusing_aliases() -> None: def test_extract_tables_reusing_aliases() -> None:
""" """
Test that the parser follows aliases. Test that the parser follows aliases.
""" """
assert extract_tables( assert (
extract_tables(
""" """
with q1 as ( select key from q2 where key = '5'), with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5') q2 as ( select key from src where key = '5')
select * from (select key from q1) a select * from (select key from q1) a
""" """
) == {Table("src")} )
== {Table("src")}
)
# weird query with circular dependency # weird query with circular dependency
assert ( assert (
@@ -547,7 +607,8 @@ def test_extract_tables_complex() -> None:
""" """
Test a few complex queries. Test a few complex queries.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SELECT sum(m_examples) AS "sum__m_example" SELECT sum(m_examples) AS "sum__m_example"
FROM ( FROM (
@@ -569,22 +630,28 @@ FROM (
ORDER BY "sum__m_example" DESC ORDER BY "sum__m_example" DESC
LIMIT 10; LIMIT 10;
""" """
) == { )
== {
Table("my_l_table"), Table("my_l_table"),
Table("my_b_table"), Table("my_b_table"),
Table("my_t_table"), Table("my_t_table"),
Table("inner_table"), Table("inner_table"),
} }
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT * SELECT *
FROM table_a AS a, table_b AS b, table_c as c FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id WHERE a.id = b.id and b.id = c.id
""" """
) == {Table("table_a"), Table("table_b"), Table("table_c")} )
== {Table("table_a"), Table("table_b"), Table("table_c")}
)
assert extract_tables( assert (
extract_tables(
""" """
SELECT somecol AS somecol SELECT somecol AS somecol
FROM ( FROM (
@@ -629,50 +696,62 @@ FROM (
LIMIT 50000 LIMIT 50000
) )
""" """
) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")} )
== {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
)
def test_extract_tables_mixed_from_clause() -> None: def test_extract_tables_mixed_from_clause() -> None:
""" """
Test that the parser handles a ``FROM`` clause with table and subselect. Test that the parser handles a ``FROM`` clause with table and subselect.
""" """
assert extract_tables( assert (
extract_tables(
""" """
SELECT * SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id WHERE a.id = b.id and b.id = c.id
""" """
) == {Table("table_a"), Table("table_b"), Table("table_c")} )
== {Table("table_a"), Table("table_b"), Table("table_c")}
)
def test_extract_tables_nested_select() -> None: def test_extract_tables_nested_select() -> None:
""" """
Test that the parser handles selects inside functions. Test that the parser handles selects inside functions.
""" """
assert extract_tables( assert (
extract_tables(
""" """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME) select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e))); WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
""", """,
"mysql", "mysql",
) == {Table("COLUMNS", "INFORMATION_SCHEMA")} )
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
assert extract_tables( assert (
extract_tables(
""" """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME) select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e))); WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
""", """,
"mysql", "mysql",
) == {Table("COLUMNS", "INFORMATION_SCHEMA")} )
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
def test_extract_tables_complex_cte_with_prefix() -> None: def test_extract_tables_complex_cte_with_prefix() -> None:
""" """
Test that the parser handles CTEs with prefixes. Test that the parser handles CTEs with prefixes.
""" """
assert extract_tables( assert (
extract_tables(
""" """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear) WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS ( AS (
@@ -685,21 +764,26 @@ FROM CTE__test
GROUP BY SalesYear, SalesPersonID GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear; ORDER BY SalesPersonID, SalesYear;
""" """
) == {Table("SalesOrderHeader")} )
== {Table("SalesOrderHeader")}
)
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None: def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
""" """
Test that aliases that are keywords are parsed correctly. Test that aliases that are keywords are parsed correctly.
""" """
assert extract_tables( assert (
extract_tables(
""" """
WITH WITH
f AS (SELECT * FROM foo), f AS (SELECT * FROM foo),
match AS (SELECT * FROM f) match AS (SELECT * FROM f)
SELECT * FROM match SELECT * FROM match
""" """
) == {Table("foo")} )
== {Table("foo")}
)
def test_update() -> None: def test_update() -> None:
@@ -1841,7 +1925,7 @@ def test_sqlquery() -> None:
script = SQLScript("SELECT 1; SELECT 2;", "sqlite") script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
assert len(script.statements) == 2 assert len(script.statements) == 2
assert script.format() == "SELECT\n 1;\nSELECT\n 2" assert script.format() == "SELECT\n 1;\nSELECT\n 2;"
assert script.statements[0].format() == "SELECT\n 1" assert script.statements[0].format() == "SELECT\n 1"
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite") script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
@@ -2058,3 +2142,120 @@ on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1 | project Day1, Day2, Percentage = count_*100.0/count_1
""", """,
] ]
@pytest.mark.parametrize(
"sql,rules,expected",
[
# aliased table with a matching rule: the table becomes a filtered subquery
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t",
),
# an existing outer WHERE clause is left in place, outside the subquery
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table"): "id = 42"},
(
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t "
"WHERE bar = 'baz'"
),
),
# schema-qualified table, rule on the same schema: rule applies
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema1"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t",
),
# rule on a different schema: query is unchanged
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema2"): "id = 42"},
"SELECT t.foo FROM schema1.some_table AS t",
),
# catalog-qualified table, rule on the same catalog: rule applies
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t",
),
# rule on a different catalog: query is unchanged
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog2"): "id = 42"},
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
),
],
)
def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None:
"""
Test the `RLSAsSubquery` transformer.

Each case parses ``sql``, runs the transformer built from ``rules`` over the
tree, and compares the regenerated SQL with ``expected``.
"""
statement = sqlglot.parse_one(sql)
transformer = RLSAsSubquery(rules)
assert str(statement.transform(transformer)) == expected
@pytest.mark.parametrize(
"sql,rules,expected",
[
# no existing WHERE clause: the rule becomes the WHERE clause
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM some_table AS t WHERE id = 42",
),
# existing WHERE clause: the rule is AND-ed in front of the original filter
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'",
),
],
)
def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None:
"""
Test the `RLSAsPredicate` transformer.

Each case parses ``sql``, runs the transformer built from ``rules`` over the
tree, and compares the regenerated SQL with ``expected``.
"""
statement = sqlglot.parse_one(sql)
transformer = RLSAsPredicate(rules)
assert str(statement.transform(transformer)) == expected
@pytest.mark.parametrize(
"sql,engine,limit,force,expected",
[
# requested limit is lower than the existing TOP: it replaces it
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
5,
False,
"SELECT\nTOP 5\n *\nFROM Customers",
),
# requested limit is higher and not forced: the existing TOP is kept
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
15,
False,
"SELECT\nTOP 10\n *\nFROM Customers",
),
# requested limit is higher but forced: it replaces the existing TOP
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
15,
True,
"SELECT\nTOP 15\n *\nFROM Customers",
),
# same forced case for the mssql engine
(
"SELECT TOP 10 * FROM Customers",
"mssql",
15,
True,
"SELECT\nTOP 15\n *\nFROM Customers",
),
],
)
def test_apply_limit(
sql: str,
engine: str,
limit: int,
force: bool,
expected: str,
) -> None:
"""
Test the `apply_limit` function.

The statement is built for ``engine``, limited with ``limit``/``force``, and
its pretty-formatted output is compared with ``expected``.
"""
assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected