diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c38a0085a53..05be7d352a3 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -172,7 +172,9 @@ class DatasourceKind(StrEnum): PHYSICAL = "physical" -class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class BaseDatasource( + AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """A common interface to objects that are queryable (tables and datasources)""" diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index d5cc86c859a..3521deeb51c 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -49,7 +49,7 @@ CONNECTION_HOST_DOWN_REGEX = re.compile( class MssqlEngineSpec(BaseEngineSpec): engine = "mssql" engine_name = "Microsoft SQL Server" - limit_method = LimitMethod.WRAP_SQL + limit_method = LimitMethod.FORCE_LIMIT max_column_name_length = 128 allows_cte_in_subquery = False allow_limit_clause = False diff --git a/superset/db_engine_specs/teradata.py b/superset/db_engine_specs/teradata.py index 887add24e90..910ac9461d6 100644 --- a/superset/db_engine_specs/teradata.py +++ b/superset/db_engine_specs/teradata.py @@ -23,7 +23,7 @@ class TeradataEngineSpec(BaseEngineSpec): engine = "teradatasql" engine_name = "Teradata" - limit_method = LimitMethod.WRAP_SQL + limit_method = LimitMethod.FORCE_LIMIT max_column_name_length = 30 # since 14.10 this is 128 allow_limit_clause = False select_keywords = {"SELECT", "SEL"} diff --git a/superset/models/helpers.py b/superset/models/helpers.py index b841426ff71..cd8b8d80825 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -17,6 +17,8 @@ # pylint: disable=too-many-lines """a collection of model-related helper classes and functions""" +from __future__ import annotations + import builtins import dataclasses import logging @@ -806,7 
+808,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def get_sqla_row_level_filters( self, - template_processor: Optional[BaseTemplateProcessor] = None, + template_processor: BaseTemplateProcessor | None = None, ) -> list[TextClause]: """ Return the appropriate row level security filters for this table and the diff --git a/superset/sql_lab.py b/superset/sql_lab.py index f20bff35c30..f032ce18824 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -53,9 +53,8 @@ from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet from superset.sql_parse import ( CtasMethod, - insert_rls_as_subquery, - insert_rls_in_predicate, ParsedQuery, + SQLStatement, Table, ) from superset.sqllab.limiting_factor import LimitingFactor @@ -205,67 +204,49 @@ def execute_sql_statement( # pylint: disable=too-many-statements database: Database = query.database db_engine_spec = database.db_engine_spec - parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine) + parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine) + if is_feature_enabled("RLS_IN_SQLLAB"): - # There are two ways to insert RLS: either replacing the table with a subquery - # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is - # safer, but not supported in all databases. 
-        insert_rls = (
-            insert_rls_as_subquery
-            if database.db_engine_spec.allows_subqueries
-            and database.db_engine_spec.allows_alias_in_select
-            else insert_rls_in_predicate
-        )
+        default_schema = database.get_default_schema_for_query(query)
+        parsed_statement = parsed_statement.apply_rls(
+            database, query.catalog, default_schema
+        )
 
-        # Insert any applicable RLS predicates
-        parsed_query = ParsedQuery(
-            str(
-                insert_rls(
-                    parsed_query._parsed[0],  # pylint: disable=protected-access
-                    database.id,
-                    query.schema,
-                )
-            ),
-            engine=db_engine_spec.engine,
-        )
-
-    sql = parsed_query.stripped()
-
-    # This is a test to see if the query is being
-    # limited by either the dropdown or the sql.
-    # We are testing to see if more rows exist than the limit.
-    increased_limit = None if query.limit is None else query.limit + 1
-
-    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
+    if parsed_statement.is_dml() and not database.allow_dml:
         raise SupersetErrorException(
             SupersetError(
-                message=__("Only SELECT statements are allowed against this database."),
+                message=__("DML statements are not allowed in this database."),
                 error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                 level=ErrorLevel.ERROR,
             )
         )
+
     if apply_ctas:
         if not query.tmp_table_name:
             start_dttm = datetime.fromtimestamp(query.start_time)
             query.tmp_table_name = (
                 f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
             )
-        sql = parsed_query.as_create_table(
-            query.tmp_table_name,
-            schema_name=query.tmp_schema_name,
+        parsed_statement = parsed_statement.as_create_table(
+            Table(query.tmp_table_name, query.tmp_schema_name, query.catalog),
             method=query.ctas_method,
         )
         query.select_as_cta_used = True
 
+    increased_limit = None if query.limit is None else query.limit + 1
+
     # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
-    if db_engine_spec.is_select_query(parsed_query) and not (
+    if parsed_statement.is_select() and not (
        query.select_as_cta_used and
SQLLAB_CTAS_NO_LIMIT
     ):
         if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
             query.limit = SQL_MAX_ROW
-        sql = apply_limit_if_exists(database, increased_limit, query, sql)
+
+        if query.limit:
+            # Increase limit by one so we can test if there are more rows when the
+            # database returns exactly the number of rows requested by the user.
+            parsed_statement = parsed_statement.apply_limit(increased_limit, force=True)
 
     # Hook to allow environment-specific mutation (usually comments) to the SQL
+    sql = parsed_statement.format(strip=True)
     sql = database.mutate_sql_based_on_config(sql)
     try:
         query.executed_sql = sql
@@ -333,19 +314,6 @@ def execute_sql_statement(  # pylint: disable=too-many-statements
     return SupersetResultSet(data, cursor_description, db_engine_spec)
 
 
-def apply_limit_if_exists(
-    database: Database, increased_limit: Optional[int], query: Query, sql: str
-) -> str:
-    if query.limit and increased_limit:
-        # We are fetching one more than the requested limit in order
-        # to test whether there are more rows than the limit. According to the DB
-        # Engine support it will choose top or limit parse
-        # Later, the extra row will be dropped before sending
-        # the results back to the user.
-        sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
-    return sql
-
-
 def _serialize_payload(
     payload: dict[Any, Any], use_msgpack: Optional[bool] = False
 ) -> Union[bytes, str]:
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 8046e8f74f6..4a041f5d3a9 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -32,7 +32,8 @@ import sqlparse
 from flask_babel import gettext as __
 from jinja2 import nodes
 from sqlalchemy import and_
-from sqlglot import exp, parse, parse_one
+import sqlglot
+from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect, Dialects
 from sqlglot.errors import ParseError, SqlglotError
 from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -308,7 +308,7 @@ def extract_tables_from_statement(
         return set()
 
     try:
-        pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+        pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
     except ParseError:
         return set()
     sources = pseudo_query.find_all(exp.Table)
@@ -433,7 +433,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         """
         raise NotImplementedError()
 
-    def format(self, comments: bool = True) -> str:
+    def format(self, comments: bool = True, strip: bool = False) -> str:
         """
         Format the statement, optionally ommitting comments.
         """
@@ -451,10 +451,94 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         """
         raise NotImplementedError()
 
+    def apply_rls(
+        self,
+        database: Database,
+        catalog: str | None,
+        schema: str | None,
+    ) -> InternalRepresentation:
+        """
+        Apply Row Level Security to the SQL.
+
+        :param database: The database where the SQL will run
+        :param catalog: The default catalog for non-qualified table names
+        :param schema: The default schema for non-qualified table names
+        :return: The SQL with RLS applied
+        """
+        raise NotImplementedError()
+
     def __str__(self) -> str:
         return self.format()
 
 
+class RLSAsPredicate:
+    """
+    Apply Row Level Security role as a predicate.
+ + This transformer will apply any RLS predicates to the relevant tables. For example, + given the RLS rule: + + table: some_table + clause: id = 42 + + If a user subject to the rule runs the following query: + + SELECT foo FROM some_table WHERE bar = 'baz' + + The query will be modified to: + + SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42 + + This approach is probably less secure than using subqueries, so it's only used for + databases without support for subqueries. + """ + + def __init__(self, rules: dict[Table, str]) -> None: + self.rules = rules + + def __call__(self, node: exp.Expression) -> exp.Expression: + if not isinstance(node, exp.Select): + return node + + table_node = node.find(exp.Table) + if not table_node: + return node + + table = Table( + str(table_node.this), + str(table_node.db) if table_node.db else None, + str(table_node.catalog) if table_node.catalog else None, + ) + if predicate := self.rules.get(table): + if where := node.args.get("where"): + predicate = exp.And(this=predicate, expression=where.this) + + node.set("where", exp.Where(this=predicate)) + + return node + + +class RLSAsSubquery: + def __init__(self, rules: dict[Table, str]) -> None: + self.rules = rules + + def __call__(self, node: exp.Expression) -> exp.Expression: + if not isinstance(node, exp.Table): + return node + + table = Table( + str(node.this), + str(node.db) if node.db else None, + str(node.catalog) if node.catalog else None, + ) + if predicate := self.rules.get(table): + alias = node.alias + node.set("alias", None) + return f"(SELECT * FROM {node} WHERE {predicate}) AS {alias}" + + return node + + class SQLStatement(BaseSQLStatement[exp.Expression]): """ A SQL statement. 
@@ -521,12 +604,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         dialect = SQLGLOT_DIALECTS.get(engine)
         return extract_tables_from_statement(parsed, dialect)
 
-    def format(self, comments: bool = True) -> str:
+    def format(self, comments: bool = True, strip: bool = False) -> str:
         """
         Pretty-format the SQL statement.
         """
         write = Dialect.get_or_raise(self._dialect)
-        return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
+        output = write.generate(
+            self._parsed,
+            copy=False,
+            comments=comments,
+            pretty=True,
+        )
+
+        return output.strip(" \t\r\n;") if strip else output
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -543,6 +633,188 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
             for eq in set_item.find_all(exp.EQ)
         }
 
+    def apply_rls(
+        self,
+        database: Database,
+        catalog: str | None,
+        schema: str | None,
+    ) -> SQLStatement:
+        """
+        Apply Row Level Security to the SQL.
+
+        :param database: The database where the SQL will run
+        :param catalog: The default catalog for non-qualified table names
+        :param schema: The default schema for non-qualified table names
+        :return: The SQL with RLS applied
+        """
+        from superset.db_engine_specs import load_engine_specs
+
+        statement = self._parsed.copy()
+
+        # collect all relevant RLS rules
+        rules = {}
+        for table in self.tables:
+            if rls := self._get_rls_for_table(database, table, catalog, schema):
+                rules[table] = rls
+
+        if not rules:
+            return SQLStatement(statement, self.engine)
+
+        use_subquery = all(
+            engine_spec.allows_subqueries
+            for engine_spec in load_engine_specs()
+            if engine_spec.engine == self.engine
+        )
+        transformer = RLSAsSubquery(rules) if use_subquery else RLSAsPredicate(rules)
+
+        return SQLStatement(statement.transform(transformer), self.engine)
+
+    def _get_rls_for_table(
+        self,
+        database: Database,
+        table: Table,
+        catalog: str | None,
+        schema: str | None,
+    ) -> exp.Expression | None:
+        """
+        Get the RLS for a table.
+
+        :param table: The table to get the RLS for
+        :param catalog: The default catalog for non-qualified table names
+        :param schema: The default schema for non-qualified table names
+        :return: The RLS for the table
+        """
+        # pylint: disable=import-outside-toplevel
+        from superset import db
+        from superset.connectors.sqla.models import SqlaTable
+
+        dataset = (
+            db.session.query(SqlaTable)
+            .filter(
+                and_(
+                    SqlaTable.database_id == database.id,
+                    SqlaTable.catalog == (table.catalog or catalog),
+                    SqlaTable.schema == (table.schema or schema),
+                    SqlaTable.table_name == table.table,
+                )
+            )
+            .one_or_none()
+        )
+        if not dataset:
+            return None
+
+        filters = dataset.get_sqla_row_level_filters()
+        if not filters:
+            return None
+
+        rls = and_(*filters).compile(
+            dialect=database.get_dialect(),
+            compile_kwargs={"literal_binds": True},
+        )
+
+        return sqlglot.parse_one(str(rls), dialect=self._dialect)
+
+    def is_dml(self) -> bool:
+        """
+        Check if the statement is DML.
+
+        :return: True if the statement is DML
+        """
+        for node in self._parsed.walk():
+            if isinstance(
+                node,
+                (
+                    exp.Insert,
+                    exp.Update,
+                    exp.Delete,
+                    exp.Merge,
+                    exp.Create,
+                    exp.Alter,
+                    exp.Drop,
+                    exp.TruncateTable,
+                ),
+            ):
+                return True
+
+        return False
+
+    def as_create_table(self, table: Table, method: CtasMethod) -> SQLStatement:
+        """
+        Convert the statement to a CREATE TABLE statement.
+        """
+        create_table = exp.Create(
+            this=sqlglot.parse_one(str(table), into=exp.Table),
+            kind=method.value,
+            expression=self._parsed.copy(),
+        )
+
+        return SQLStatement(create_table, self.engine)
+
+    def is_select(self) -> bool:
+        """
+        Check if the statement is a SELECT statement.
+
+        :return: True if the statement is a SELECT statement
+        """
+        return isinstance(self._parsed, exp.Select)
+
+    def apply_limit(self, limit: int, force: bool = False) -> SQLStatement:
+        """
+        Apply a limit to the SQL.
+
+        There are 3 strategies to limit queries, defined in the DB engine spec:
+
+        1.
`FORCE_LIMIT`: a limit is added to the query, or the existing one is + replaced. This is the most efficient, since the database will produce at + most the number of rows that Superset will display. + 2. `WRAP_SQL`: the query is wrapped in a subquery, and the limit is applied + to the outer query. This might be inneficient, since the database + optimizer might not be able to push the limit down to the inner query. + 3. `FETCH_MANY`: no limit is applied, but only `LIMIT` rows are fetched from + the database. This is the least efficient, unless the database computes + rows as they are read by the cursor, which is unlikely. + + :param limit: The limit to apply + :param force: Apply limit even when a lower one is present + :return: The SQL with the limit applied + """ + from superset.db_engine_specs import load_engine_specs + from superset.db_engine_specs.base import LimitMethod + + methods = { + engine_spec.limit_method + for engine_spec in load_engine_specs() + if engine_spec.engine == self.engine + } + if not methods: + methods = {LimitMethod.FETCH_MANY} + + # When multiple methods are supported, we prefer the more generic one -- + # usually less efficient. 
+ preference = [ + LimitMethod.FETCH_MANY, + LimitMethod.WRAP_SQL, + LimitMethod.FORCE_LIMIT, + ] + method = sorted(methods, key=preference.index)[0] + + if not self.is_select() or method == LimitMethod.FETCH_MANY: + return SQLStatement(self._parsed.copy(), self.engine) + + if method == LimitMethod.WRAP_SQL: + limited = exp.Select( + expressions=[exp.Star()], + from_=exp.Subquery(subquery=self._parsed.copy(), alias="inner_qry"), + limit=exp.Literal.number(limit), + ) + return SQLStatement(limited, self.engine) + + current_limit: int | None = None + for node in self._parsed.find_all(exp.Limit): + current_limit = int(node.expression.this) + break + + if force or current_limit is None or limit < current_limit: + return SQLStatement(self._parsed.limit(limit), self.engine) + + return SQLStatement(self._parsed.copy(), self.engine) + class KQLSplitState(enum.Enum): """ @@ -666,11 +936,11 @@ class KustoKQLStatement(BaseSQLStatement[str]): ) return set() - def format(self, comments: bool = True) -> str: + def format(self, comments: bool = True, strip: bool = False) -> str: """ Pretty-format the SQL statement. """ - return self._parsed + return self._parsed.strip(" \t\r\n;") if strip else self._parsed def get_settings(self) -> dict[str, str | bool]: """ @@ -712,7 +982,10 @@ class SQLScript: """ Pretty-format the SQL query. """ - return ";\n".join(statement.format(comments) for statement in self.statements) + return ( + ";\n".join(statement.format(comments) for statement in self.statements) + + ";" + ) def get_settings(self) -> dict[str, str | bool]: """ @@ -792,7 +1065,7 @@ class ParsedQuery: Note: this uses sqlglot, since it's better at catching more edge cases. 
""" try: - statements = parse(self.stripped(), dialect=self._dialect) + statements = sqlglot.parse(self.stripped(), dialect=self._dialect) except SqlglotError as ex: logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql) @@ -846,7 +1119,7 @@ class ParsedQuery: return set() try: - pseudo_query = parse_one( + pseudo_query = sqlglot.parse_one( f"SELECT {literal.this}", dialect=self._dialect, ) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 6259d6272db..204455c22d5 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -20,6 +20,7 @@ from typing import Optional from unittest.mock import Mock import pytest +import sqlglot import sqlparse from pytest_mock import MockerFixture from sqlalchemy import text @@ -41,6 +42,8 @@ from superset.sql_parse import ( insert_rls_in_predicate, KustoKQLStatement, ParsedQuery, + RLSAsPredicate, + RLSAsSubquery, sanitize_clause, split_kql, SQLScript, @@ -119,8 +122,9 @@ def test_extract_tables_subselect() -> None: """ Test that tables inside subselects are parsed correctly. 
""" - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT sub.* FROM ( SELECT * @@ -129,10 +133,13 @@ FROM ( ) sub, s2.t2 WHERE sub.resolution = 'NONE' """ - ) == {Table("t1", "s1"), Table("t2", "s2")} + ) + == {Table("t1", "s1"), Table("t2", "s2")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT sub.* FROM ( SELECT * @@ -141,10 +148,13 @@ FROM ( ) sub WHERE sub.resolution = 'NONE' """ - ) == {Table("t1", "s1")} + ) + == {Table("t1", "s1")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT * FROM t1 WHERE s11 > ANY ( SELECT COUNT(*) /* no hint */ FROM t2 @@ -156,7 +166,9 @@ WHERE s11 > ANY ( ) ) """ - ) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")} + ) + == {Table("t1"), Table("t2"), Table("t3"), Table("t4")} + ) def test_extract_tables_select_in_expression() -> None: @@ -227,24 +239,30 @@ def test_extract_tables_select_array() -> None: """ Test that queries selecting arrays work as expected. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT ARRAY[1, 2, 3] AS my_array FROM t1 LIMIT 10 """ - ) == {Table("t1")} + ) + == {Table("t1")} + ) def test_extract_tables_select_if() -> None: """ Test that queries with an ``IF`` work as expected. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL) FROM t1 LIMIT 10 """ - ) == {Table("t1")} + ) + == {Table("t1")} + ) def test_extract_tables_with_catalog() -> None: @@ -312,29 +330,38 @@ def test_extract_tables_where_subquery() -> None: """ Test that tables in a ``WHERE`` subquery are parsed correctly. 
""" - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT name FROM t1 WHERE regionkey = (SELECT max(regionkey) FROM t2) """ - ) == {Table("t1"), Table("t2")} + ) + == {Table("t1"), Table("t2")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT name FROM t1 WHERE regionkey IN (SELECT regionkey FROM t2) """ - ) == {Table("t1"), Table("t2")} + ) + == {Table("t1"), Table("t2")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT name FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey); """ - ) == {Table("t1"), Table("t2")} + ) + == {Table("t1"), Table("t2")} + ) def test_extract_tables_describe() -> None: @@ -348,12 +375,15 @@ def test_extract_tables_show_partitions() -> None: """ Test ``SHOW PARTITIONS``. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ SHOW PARTITIONS FROM orders WHERE ds >= '2013-01-01' ORDER BY ds DESC """ - ) == {Table("orders")} + ) + == {Table("orders")} + ) def test_extract_tables_join() -> None: @@ -365,8 +395,9 @@ def test_extract_tables_join() -> None: Table("t2"), } - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT a.date, b.name FROM left_table a JOIN ( @@ -377,10 +408,13 @@ JOIN ( ) b ON a.date = b.date """ - ) == {Table("left_table"), Table("right_table")} + ) + == {Table("left_table"), Table("right_table")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT a.date, b.name FROM left_table a LEFT INNER JOIN ( @@ -391,10 +425,13 @@ LEFT INNER JOIN ( ) b ON a.date = b.date """ - ) == {Table("left_table"), Table("right_table")} + ) + == {Table("left_table"), Table("right_table")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT a.date, b.name FROM left_table a RIGHT OUTER JOIN ( @@ -405,10 +442,13 @@ RIGHT OUTER JOIN ( ) b ON a.date = b.date """ - ) == {Table("left_table"), Table("right_table")} + ) + == {Table("left_table"), Table("right_table")} + ) - 
assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT a.date, b.name FROM left_table a FULL OUTER JOIN ( @@ -419,15 +459,18 @@ FULL OUTER JOIN ( ) b ON a.date = b.date """ - ) == {Table("left_table"), Table("right_table")} + ) + == {Table("left_table"), Table("right_table")} + ) def test_extract_tables_semi_join() -> None: """ Test ``LEFT SEMI JOIN``. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT a.date, b.name FROM left_table a LEFT SEMI JOIN ( @@ -438,15 +481,18 @@ LEFT SEMI JOIN ( ) b ON a.data = b.date """ - ) == {Table("left_table"), Table("right_table")} + ) + == {Table("left_table"), Table("right_table")} + ) def test_extract_tables_combinations() -> None: """ Test a complex case with nested queries. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT * FROM t1 WHERE s11 > ANY ( SELECT * FROM t1 UNION ALL SELECT * FROM ( @@ -460,10 +506,13 @@ WHERE s11 > ANY ( ) ) """ - ) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")} + ) + == {Table("t1"), Table("t3"), Table("t4"), Table("t6")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT * FROM ( SELECT * FROM ( SELECT * FROM ( @@ -472,45 +521,56 @@ SELECT * FROM ( ) AS S2 ) AS S3 """ - ) == {Table("EmployeeS")} + ) + == {Table("EmployeeS")} + ) def test_extract_tables_with() -> None: """ Test ``WITH``. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ WITH x AS (SELECT a FROM t1), y AS (SELECT a AS b FROM t2), z AS (SELECT b AS c FROM t3) SELECT c FROM z """ - ) == {Table("t1"), Table("t2"), Table("t3")} + ) + == {Table("t1"), Table("t2"), Table("t3")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ WITH x AS (SELECT a FROM t1), y AS (SELECT a AS b FROM x), z AS (SELECT b AS c FROM y) SELECT c FROM z """ - ) == {Table("t1")} + ) + == {Table("t1")} + ) def test_extract_tables_reusing_aliases() -> None: """ Test that the parser follows aliases. 
""" - assert extract_tables( - """ + assert ( + extract_tables( + """ with q1 as ( select key from q2 where key = '5'), q2 as ( select key from src where key = '5') select * from (select key from q1) a """ - ) == {Table("src")} + ) + == {Table("src")} + ) # weird query with circular dependency assert ( @@ -547,8 +607,9 @@ def test_extract_tables_complex() -> None: """ Test a few complex queries. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT sum(m_examples) AS "sum__m_example" FROM ( SELECT @@ -569,23 +630,29 @@ FROM ( ORDER BY "sum__m_example" DESC LIMIT 10; """ - ) == { - Table("my_l_table"), - Table("my_b_table"), - Table("my_t_table"), - Table("inner_table"), - } + ) + == { + Table("my_l_table"), + Table("my_b_table"), + Table("my_t_table"), + Table("inner_table"), + } + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT * FROM table_a AS a, table_b AS b, table_c as c WHERE a.id = b.id and b.id = c.id """ - ) == {Table("table_a"), Table("table_b"), Table("table_c")} + ) + == {Table("table_a"), Table("table_b"), Table("table_c")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT somecol AS somecol FROM ( WITH bla AS ( @@ -629,51 +696,63 @@ FROM ( LIMIT 50000 ) """ - ) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")} + ) + == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")} + ) def test_extract_tables_mixed_from_clause() -> None: """ Test that the parser handles a ``FROM`` clause with table and subselect. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ SELECT * FROM table_a AS a, (select * from table_b) AS b, table_c as c WHERE a.id = b.id and b.id = c.id """ - ) == {Table("table_a"), Table("table_b"), Table("table_c")} + ) + == {Table("table_a"), Table("table_b"), Table("table_c")} + ) def test_extract_tables_nested_select() -> None: """ Test that the parser handles selects inside functions. 
""" - assert extract_tables( - """ + assert ( + extract_tables( + """ select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME) from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA like "%bi%"),0x7e))); """, - "mysql", - ) == {Table("COLUMNS", "INFORMATION_SCHEMA")} + "mysql", + ) + == {Table("COLUMNS", "INFORMATION_SCHEMA")} + ) - assert extract_tables( - """ + assert ( + extract_tables( + """ select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME) from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME="bi_achievement_daily"),0x7e))); """, - "mysql", - ) == {Table("COLUMNS", "INFORMATION_SCHEMA")} + "mysql", + ) + == {Table("COLUMNS", "INFORMATION_SCHEMA")} + ) def test_extract_tables_complex_cte_with_prefix() -> None: """ Test that the parser handles CTEs with prefixes. """ - assert extract_tables( - """ + assert ( + extract_tables( + """ WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear) AS ( SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear @@ -685,21 +764,26 @@ FROM CTE__test GROUP BY SalesYear, SalesPersonID ORDER BY SalesPersonID, SalesYear; """ - ) == {Table("SalesOrderHeader")} + ) + == {Table("SalesOrderHeader")} + ) def test_extract_tables_identifier_list_with_keyword_as_alias() -> None: """ Test that aliases that are keywords are parsed correctly. 
""" - assert extract_tables( - """ + assert ( + extract_tables( + """ WITH f AS (SELECT * FROM foo), match AS (SELECT * FROM f) SELECT * FROM match """ - ) == {Table("foo")} + ) + == {Table("foo")} + ) def test_update() -> None: @@ -1841,7 +1925,7 @@ def test_sqlquery() -> None: script = SQLScript("SELECT 1; SELECT 2;", "sqlite") assert len(script.statements) == 2 - assert script.format() == "SELECT\n 1;\nSELECT\n 2" + assert script.format() == "SELECT\n 1;\nSELECT\n 2;" assert script.statements[0].format() == "SELECT\n 1" script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite") @@ -2058,3 +2142,120 @@ on $left.Day1 == $right.Day | project Day1, Day2, Percentage = count_*100.0/count_1 """, ] + + +@pytest.mark.parametrize( + "sql,rules,expected", + [ + ( + "SELECT t.foo FROM some_table AS t", + {Table("some_table"): "id = 42"}, + "SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t", + ), + ( + "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'", + {Table("some_table"): "id = 42"}, + ( + "SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t " + "WHERE bar = 'baz'" + ), + ), + ( + "SELECT t.foo FROM schema1.some_table AS t", + {Table("some_table", "schema1"): "id = 42"}, + "SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t", + ), + ( + "SELECT t.foo FROM schema1.some_table AS t", + {Table("some_table", "schema2"): "id = 42"}, + "SELECT t.foo FROM schema1.some_table AS t", + ), + ( + "SELECT t.foo FROM catalog1.schema1.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + "SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t", + ), + ( + "SELECT t.foo FROM catalog1.schema1.some_table AS t", + {Table("some_table", "schema1", "catalog2"): "id = 42"}, + "SELECT t.foo FROM catalog1.schema1.some_table AS t", + ), + ], +) +def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None: + """ + Test the `RLSAsSubquery` transformer. 
+ """ + statement = sqlglot.parse_one(sql) + transformer = RLSAsSubquery(rules) + assert str(statement.transform(transformer)) == expected + + +@pytest.mark.parametrize( + "sql,rules,expected", + [ + ( + "SELECT t.foo FROM some_table AS t", + {Table("some_table"): "id = 42"}, + "SELECT t.foo FROM some_table AS t WHERE id = 42", + ), + ( + "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'", + {Table("some_table"): "id = 42"}, + "SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'", + ), + ], +) +def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None: + """ + Test the `RLSAsPredicate` transformer. + """ + statement = sqlglot.parse_one(sql) + transformer = RLSAsPredicate(rules) + assert str(statement.transform(transformer)) == expected + + +@pytest.mark.parametrize( + "sql,engine,limit,force,expected", + [ + ( + "SELECT TOP 10 * FROM Customers", + "teradatasql", + 5, + False, + "SELECT\nTOP 5\n *\nFROM Customers", + ), + ( + "SELECT TOP 10 * FROM Customers", + "teradatasql", + 15, + False, + "SELECT\nTOP 10\n *\nFROM Customers", + ), + ( + "SELECT TOP 10 * FROM Customers", + "teradatasql", + 15, + True, + "SELECT\nTOP 15\n *\nFROM Customers", + ), + ( + "SELECT TOP 10 * FROM Customers", + "mssql", + 15, + True, + "SELECT\nTOP 15\n *\nFROM Customers", + ), + ], +) +def test_apply_limit( + sql: str, + engine: str, + limit: int, + force: bool, + expected: str, +) -> None: + """ + Test the `apply_limit` function. + """ + assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected