diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 0d2bdd3a5d9..1f15b2834f9 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -95,6 +95,7 @@ class HiveEngineSpec(PrestoEngineSpec): allows_hidden_orderby_agg = False supports_dynamic_schema = True + supports_cross_catalog_queries = False # When running `SHOW FUNCTIONS`, what is the name of the column with the # function names? diff --git a/superset/exceptions.py b/superset/exceptions.py index c6105b44654..0ca6f0c70be 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -432,3 +432,54 @@ class TableNotFoundException(SupersetErrorException): level=ErrorLevel.ERROR, ) ) + + +class SupersetDMLNotAllowedException(SupersetErrorException): + def __init__(self) -> None: + error = SupersetError( + message=_( + "This database does not allow for DDL/DML, but the query mutates " + "data. Please contact your administrator for more assistance." + ), + error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR, + level=ErrorLevel.ERROR, + ) + super().__init__(error) + + +class SupersetInvalidCTASException(SupersetErrorException): + def __init__(self) -> None: + error = SupersetError( + message=_( + "CTAS (create table as select) can only be run with a query where " + "the last statement is a SELECT. Please make sure your query has " + "a SELECT as its last statement. Then, try running your query again." + ), + error_type=SupersetErrorType.INVALID_CTAS_QUERY_ERROR, + level=ErrorLevel.ERROR, + ) + super().__init__(error) + + +class SupersetInvalidCVASException(SupersetErrorException): + def __init__(self) -> None: + error = SupersetError( + message=_( + "CVAS (create view as select) can only be run with a query with " + "a single SELECT statement. Please make sure your query has only " + "a SELECT statement. Then, try running your query again." 
+ ), + error_type=SupersetErrorType.INVALID_CVAS_QUERY_ERROR, + level=ErrorLevel.ERROR, + ) + super().__init__(error) + + +class SupersetResultsBackendNotConfigureException(SupersetErrorException): + def __init__(self) -> None: + error = SupersetError( + message=_("Results backend is not configured."), + error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR, + level=ErrorLevel.ERROR, + ) + super().__init__(error) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 4d443423d18..f73334a1e99 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -56,7 +56,8 @@ from superset.models.helpers import ( ExtraJSONMixin, ImportExportMixin, ) -from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table +from superset.sql.parse import CTASMethod +from superset.sql_parse import extract_tables_from_jinja_sql, Table from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import json from superset.utils.core import ( @@ -128,7 +129,7 @@ class Query( ) select_as_cta = Column(Boolean) select_as_cta_used = Column(Boolean, default=False) - ctas_method = Column(String(16), default=CtasMethod.TABLE) + ctas_method = Column(String(16), default=CTASMethod.TABLE.name) progress = Column(Integer, default=0) # 1..100 # # of rows in the result set or rows modified. diff --git a/superset/sql/parse.py b/superset/sql/parse.py index fbda84570e3..0a0f4b3e5c1 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -439,22 +439,35 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() - def as_cte(self, alias: str = "__cte") -> SQLStatement: + def as_cte(self, alias: str = "__cte") -> BaseSQLStatement[InternalRepresentation]: """ Rewrite the statement as a CTE. :param alias: The alias to use for the CTE. - :return: A new SQLStatement with the CTE. + :return: A new BaseSQLStatement[InternalRepresentation] with the CTE. 
""" raise NotImplementedError() - def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement: + def as_create_table( + self, + table: Table, + method: CTASMethod, + ) -> BaseSQLStatement[InternalRepresentation]: """ Rewrite the statement as a `CREATE TABLE AS` statement. :param table: The table to create. :param method: The method to use for creating the table. - :return: A new SQLStatement with the CTE. + :return: A new BaseSQLStatement[InternalRepresentation] with the CTE. + """ + raise NotImplementedError() + + def parse_predicate(self, predicate: str) -> InternalRepresentation: + """ + Parse a predicate string into an AST. + + :param predicate: The predicate to parse. + :return: The parsed predicate. """ raise NotImplementedError() @@ -790,6 +803,15 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): return SQLStatement(ast=create_table, engine=self.engine) + def parse_predicate(self, predicate: str) -> exp.Expression: + """ + Parse a predicate string into an AST. + + :param predicate: The predicate to parse. + :return: The parsed predicate. + """ + return sqlglot.parse_one(predicate, dialect=self._dialect) + def apply_rls( self, catalog: str | None, @@ -804,6 +826,9 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): :param schema: The default schema for non-qualified table names :param method: The method to use for applying the rules. """ + if not predicates: + return + transformers = { RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer, RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer, @@ -1128,6 +1153,15 @@ class KustoKQLStatement(BaseSQLStatement[str]): self._parsed = "".join(val for _, val in tokens) + def parse_predicate(self, predicate: str) -> str: + """ + Parse a predicate string into an AST. + + :param predicate: The predicate to parse. + :return: The parsed predicate. 
+ """ + return predicate + class SQLScript: """ diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 34e648f8fd2..08a694c9489 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -19,16 +19,18 @@ import dataclasses import logging import sys import uuid +from collections import defaultdict from contextlib import closing from datetime import datetime from sys import getsizeof -from typing import Any, cast, Optional, Union +from typing import Any, cast, Optional, TypeVar, Union import backoff import msgpack from celery.exceptions import SoftTimeLimitExceeded from flask import current_app from flask_babel import gettext as __ +from sqlalchemy import and_, or_ from superset import ( app, @@ -39,27 +41,25 @@ from superset import ( security_manager, ) from superset.common.db_query_status import QueryStatus +from superset.connectors.sqla.models import SqlaTable from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY from superset.dataframe import df_to_records from superset.db_engine_specs import BaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( OAuth2RedirectError, + SupersetDMLNotAllowedException, SupersetErrorException, SupersetErrorsException, - SupersetParseError, + SupersetInvalidCTASException, + SupersetInvalidCVASException, + SupersetResultsBackendNotConfigureException, ) from superset.extensions import celery_app, event_logger from superset.models.core import Database from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet -from superset.sql.parse import SQLScript, SQLStatement, Table -from superset.sql_parse import ( - CtasMethod, - insert_rls_as_subquery, - insert_rls_in_predicate, - ParsedQuery, -) +from superset.sql.parse import BaseSQLStatement, CTASMethod, RLSMethod, SQLScript, Table from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.utils import write_ipc_buffer from superset.utils import 
json @@ -197,101 +197,149 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query) -def execute_sql_statement( # pylint: disable=too-many-statements, too-many-locals # noqa: C901 - sql_statement: str, +def apply_rls(query: Query, parsed_statement: BaseSQLStatement[Any]) -> None: + """ + Modify statement inplace to ensure RLS rules are applied. + """ + database = query.database + + # we need the default schema to fully qualify the table names + default_schema = database.get_default_schema_for_query(query) + + # There are two ways to insert RLS: either replacing the table with a subquery + # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is + # safer, but not supported in all databases. + method = ( + RLSMethod.AS_SUBQUERY + if database.db_engine_spec.allows_subqueries + and database.db_engine_spec.allows_alias_in_select + else RLSMethod.AS_PREDICATE + ) + + # collect all RLS predicates for all tables in the query + predicates: dict[Table, list[Any]] = defaultdict(list) + for table in parsed_statement.tables: + # fully qualify table + table = Table( + table.table, + table.schema or default_schema, + table.catalog or query.catalog, + ) + + if table_predicates := get_predicates_for_table( + table, + database, + query.catalog == database.get_default_catalog(), + ): + predicates[table].extend( + parsed_statement.parse_predicate(predicate) + for predicate in table_predicates + ) + + parsed_statement.apply_rls(query.catalog, default_schema, predicates, method) + + +def get_predicates_for_table( + table: Table, + database: Database, + is_default_catalog: bool, +) -> list[str]: + """ + Get the RLS predicates for a table. + + This is used to inject RLS rules into SQL statements run in SQL Lab. Note that the + table must be fully qualified, with catalog (null if the DB doesn't support) and + schema. 
+ """ + # if the dataset in the RLS has null catalog, match it when using the default + # catalog + catalog_predicate = SqlaTable.catalog == table.catalog + if table.catalog and is_default_catalog: + catalog_predicate = or_( + catalog_predicate, + SqlaTable.catalog.is_(None), + ) + + dataset = ( + db.session.query(SqlaTable) + .filter( + and_( + SqlaTable.database_id == database.id, + catalog_predicate, + SqlaTable.schema == table.schema, + SqlaTable.table_name == table.table, + ) + ) + .one_or_none() + ) + if not dataset: + return [] + + return [ + str( + predicate.compile( + dialect=database.get_dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + for predicate in dataset.get_sqla_row_level_filters() + ] + + +S = TypeVar("S", bound=BaseSQLStatement[Any]) + + +def apply_ctas(query: Query, parsed_statement: S) -> S: + """ + Apply CTAS/CVAS. + """ + if not query.tmp_table_name: + start_dttm = datetime.fromtimestamp(query.start_time) + prefix = f"tmp_{query.user_id}_table" + query.tmp_table_name = start_dttm.strftime(f"{prefix}_%Y_%m_%d_%H_%M_%S") + + catalog = ( + query.catalog + if query.database.db_engine_spec.supports_cross_catalog_queries + else None + ) + table = Table(query.tmp_table_name, query.tmp_schema_name, catalog) + method = CTASMethod[query.ctas_method.upper()] + + return parsed_statement.as_create_table(table, method) # type: ignore[return-value] + + +def apply_limit(query: Query, parsed_statement: BaseSQLStatement[Any]) -> None: + """ + Apply limit to the SQL statement. 
# Skip the limit for mutating statements, and for CTAs when SQLLAB_CTAS_NO_LIMIT is set
- increased_limit = None if query.limit is None else query.limit + 1 - - if not database.allow_dml: - errors = [] - try: - parsed_statement = SQLStatement( - statement=sql_statement, - engine=db_engine_spec.engine, - ) - disallowed = parsed_statement.is_mutating() - except SupersetParseError as ex: - # if we fail to parse the query, disallow by default - disallowed = True - errors.append(ex.error) - - if disallowed: - errors.append( - SupersetError( - message=__( - "This database does not allow for DDL/DML, and the query " - "could not be parsed to confirm it is a read-only query. Please " # noqa: E501 - "contact your administrator for more assistance." - ), - error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR, - level=ErrorLevel.ERROR, - ) - ) - raise SupersetErrorsException(errors) - - original_sql = sql - if apply_ctas: - if not query.tmp_table_name: - start_dttm = datetime.fromtimestamp(query.start_time) - query.tmp_table_name = ( - f"tmp_{query.user_id}_table_{start_dttm.strftime('%Y_%m_%d_%H_%M_%S')}" - ) - sql = parsed_query.as_create_table( - query.tmp_table_name, - schema_name=query.tmp_schema_name, - method=query.ctas_method, - ) - query.select_as_cta_used = True - - # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true - if not SQLScript(original_sql, db_engine_spec.engine).has_mutation() and not ( - query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT - ): - if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW): - query.limit = SQL_MAX_ROW - sql = apply_limit_if_exists(database, increased_limit, query, sql) - - # Hook to allow environment-specific mutation (usually comments) to the SQL - sql = database.mutate_sql_based_on_config(sql) try: - query.executed_sql = sql if log_query: log_query( query.database.sqlalchemy_uri, @@ -308,7 +356,7 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca object_ref=__name__, ): with stats_timing("sqllab.query.time_executing_query", stats_logger): - 
db_engine_spec.execute_with_cursor(cursor, sql, query) + db_engine_spec.execute_with_cursor(cursor, query.executed_sql, query) with stats_timing("sqllab.query.time_fetching_results", stats_logger): logger.debug( @@ -316,6 +364,7 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca query.id, str(query.to_dict()), ) + increased_limit = None if query.limit is None else query.limit + 1 data = db_engine_spec.fetch_data(cursor, increased_limit) if query.limit is None or len(data) <= query.limit: query.limiting_factor = LimitingFactor.NOT_LIMITED @@ -356,19 +405,6 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca return SupersetResultSet(data, cursor_description, db_engine_spec) -def apply_limit_if_exists( - database: Database, increased_limit: Optional[int], query: Query, sql: str -) -> str: - if query.limit and increased_limit: - # We are fetching one more than the requested limit in order - # to test whether there are more rows than the limit. According to the DB - # Engine support it will choose top or limit parse - # Later, the extra row will be dropped before sending - # the results back to the user. 
- sql = database.apply_limit_to_sql(sql, increased_limit, force=True) - return sql - - def _serialize_payload( payload: dict[Any, Any], use_msgpack: Optional[bool] = False ) -> Union[bytes, str]: @@ -434,67 +470,52 @@ def execute_sql_statements( # noqa: C901 db_engine_spec.patch() if database.allow_run_async and not results_backend: - raise SupersetErrorException( - SupersetError( - message=__("Results backend is not configured."), - error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR, - level=ErrorLevel.ERROR, - ) - ) - - # Breaking down into multiple statements - parsed_query = ParsedQuery( - rendered_query, - engine=db_engine_spec.engine, - ) - if not db_engine_spec.run_multiple_statements_as_one: - statements = parsed_query.get_statements() - logger.info( - "Query %s: Executing %i statement(s)", str(query_id), len(statements) - ) - else: - statements = [rendered_query] - logger.info("Query %s: Executing query as a single statement", str(query_id)) + raise SupersetResultsBackendNotConfigureException() logger.info("Query %s: Set query to 'running'", str(query_id)) query.status = QueryStatus.RUNNING query.start_running_time = now_as_float() db.session.commit() - # Should we create a table or view from the select? - if ( - query.select_as_cta - and query.ctas_method == CtasMethod.TABLE - and not parsed_query.is_valid_ctas() - ): - raise SupersetErrorException( - SupersetError( - message=__( - "CTAS (create table as select) can only be run with a query where " - "the last statement is a SELECT. Please make sure your query has " - "a SELECT as its last statement. Then, try running your query " - "again." 
- ), - error_type=SupersetErrorType.INVALID_CTAS_QUERY_ERROR, - level=ErrorLevel.ERROR, - ) - ) - if ( - query.select_as_cta - and query.ctas_method == CtasMethod.VIEW - and not parsed_query.is_valid_cvas() - ): - raise SupersetErrorException( - SupersetError( - message=__( - "CVAS (create view as select) can only be run with a query with " - "a single SELECT statement. Please make sure your query has only " - "a SELECT statement. Then, try running your query again." - ), - error_type=SupersetErrorType.INVALID_CVAS_QUERY_ERROR, - level=ErrorLevel.ERROR, - ) + parsed_script = SQLScript(rendered_query, engine=db_engine_spec.engine) + + if parsed_script.has_mutation() and not database.allow_dml: + raise SupersetDMLNotAllowedException() + + if is_feature_enabled("RLS_IN_SQLLAB"): + for statement in parsed_script.statements: + apply_rls(query, statement) + + if query.select_as_cta: + # CTAS is valid when the last statement is a SELECT, while CVAS is valid when + # there is only a single statement which must be a SELECT. + if ( + query.ctas_method == CTASMethod.TABLE.name + and not parsed_script.is_valid_ctas() + ): + raise SupersetInvalidCTASException() + if ( + query.ctas_method == CTASMethod.VIEW.name + and not parsed_script.is_valid_cvas() + ): + raise SupersetInvalidCVASException() + + parsed_script.statements[-1] = apply_ctas( # type: ignore + query, + parsed_script.statements[-1], ) + query.select_as_cta_used = True + + for statement in parsed_script.statements: + apply_limit(query, statement) + + # some databases (like BigQuery and Kusto) do not persist state across mmultiple + # statements if they're run separately (especially when using `NullPool`), so we run + # the query as a single block. 
+ if db_engine_spec.run_multiple_statements_as_one: + blocks = [parsed_script.format()] + else: + blocks = [statement.format() for statement in parsed_script.statements] with database.get_raw_connection( catalog=query.catalog, @@ -504,40 +525,35 @@ def execute_sql_statements( # noqa: C901 # Sharing a single connection and cursor across the # execution of all statements (if many) cursor = conn.cursor() + cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query) if cancel_query_id is not None: query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id) db.session.commit() - statement_count = len(statements) - for i, statement in enumerate(statements): + + block_count = len(blocks) + for i, block in enumerate(blocks): # Check if stopped db.session.refresh(query) if query.status == QueryStatus.STOPPED: payload.update({"status": query.status}) return payload - # For CTAS we create the table only on the last statement - apply_ctas = query.select_as_cta and ( - query.ctas_method == CtasMethod.VIEW - or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1) - ) + # Run statement msg = __( - "Running statement %(statement_num)s out of %(statement_count)s", - statement_num=i + 1, - statement_count=statement_count, + "Running block %(block_num)s out of %(block_count)s", + block_num=i + 1, + block_count=block_count, ) logger.info("Query %s: %s", str(query_id), msg) query.set_extra_json_key("progress", msg) db.session.commit() - try: - result_set = execute_sql_statement( - statement, - query, - cursor, - log_params, - apply_ctas, - ) + # Hook to allow environment-specific mutation (usually comments) to the SQL + query.executed_sql = database.mutate_sql_based_on_config(block) + + try: + result_set = execute_query(query, cursor, log_params) except SqlLabQueryStoppedException: payload.update({"status": QueryStatus.STOPPED}) return payload @@ -545,22 +561,18 @@ def execute_sql_statements( # noqa: C901 msg = str(ex) prefix_message = ( __( - "Statement 
%(statement_num)s out of %(statement_count)s", - statement_num=i + 1, - statement_count=statement_count, + "Block %(block_num)s out of %(block_count)s", + block_num=i + 1, + block_count=block_count, ) - if statement_count > 1 + if block_count > 1 else "" ) payload = handle_query_error(ex, query, payload, prefix_message) return payload # Commit the connection so CTA queries will create the table and any DML. - should_commit = ( - SQLScript(rendered_query, db_engine_spec.engine).has_mutation() - or apply_ctas - ) - if should_commit: + if parsed_script.has_mutation() or query.select_as_cta: conn.commit() # Success, updating the query entry in database diff --git a/superset/sqllab/sqllab_execution_context.py b/superset/sqllab/sqllab_execution_context.py index ab0f91bbf30..0e579ede9b6 100644 --- a/superset/sqllab/sqllab_execution_context.py +++ b/superset/sqllab/sqllab_execution_context.py @@ -26,7 +26,7 @@ from sqlalchemy.orm.exc import DetachedInstanceError from superset import is_feature_enabled from superset.models.sql_lab import Query -from superset.sql_parse import CtasMethod +from superset.sql.parse import CTASMethod from superset.utils import core as utils, json from superset.utils.core import apply_max_row_limit, get_user_id from superset.utils.dates import now_as_float @@ -148,6 +148,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes def create_query(self) -> Query: start_time = now_as_float() + ctas = cast(CreateTableAsSelect, self.create_table_as_select) if self.select_as_cta: return Query( database_id=self.database_id, @@ -155,14 +156,14 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes catalog=self.catalog, schema=self.schema, select_as_cta=True, - ctas_method=self.create_table_as_select.ctas_method, # type: ignore + ctas_method=ctas.ctas_method.name, start_time=start_time, tab_name=self.tab_name, status=self.status, limit=self.limit, sql_editor_id=self.sql_editor_id, - 
tmp_table_name=self.create_table_as_select.target_table_name, # type: ignore - tmp_schema_name=self.create_table_as_select.target_schema_name, # type: ignore + tmp_table_name=ctas.target_table_name, + tmp_schema_name=ctas.target_schema_name, user_id=self.user_id, client_id=self.client_id_or_short_id, ) @@ -190,12 +191,12 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes class CreateTableAsSelect: # pylint: disable=too-few-public-methods - ctas_method: CtasMethod + ctas_method: CTASMethod target_schema_name: str | None target_table_name: str def __init__( - self, ctas_method: CtasMethod, target_schema_name: str, target_table_name: str + self, ctas_method: CTASMethod, target_schema_name: str, target_table_name: str ): self.ctas_method = ctas_method self.target_schema_name = target_schema_name @@ -203,7 +204,7 @@ class CreateTableAsSelect: # pylint: disable=too-few-public-methods @staticmethod def create_from(query_params: dict[str, Any]) -> CreateTableAsSelect: - ctas_method = query_params.get("ctas_method", CtasMethod.TABLE) + ctas_method = CTASMethod[query_params.get("ctas_method", "table").upper()] schema = cast(str, query_params.get("schema")) tmp_table_name = cast(str, query_params.get("tmp_table_name")) return CreateTableAsSelect(ctas_method, schema, tmp_table_name) diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 5ddda0c5912..1cf91ca2bbd 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -36,7 +36,7 @@ from sqlalchemy.dialects.mysql import dialect from tests.integration_tests.constants import ADMIN_USERNAME from tests.integration_tests.test_app import app, login -from superset.sql_parse import CtasMethod +from superset.sql.parse import CTASMethod from superset import db, security_manager from superset.connectors.sqla.models import BaseDatasource, SqlaTable from superset.models import core as models @@ -387,7 +387,7 @@ class 
SupersetTestCase(TestCase): select_as_cta=False, tmp_table_name=None, schema=None, - ctas_method=CtasMethod.TABLE, + ctas_method=CTASMethod.TABLE, template_params="{}", ): if username: @@ -400,7 +400,7 @@ class SupersetTestCase(TestCase): "client_id": client_id, "queryLimit": query_limit, "sql_editor_id": sql_editor_id, - "ctas_method": ctas_method, + "ctas_method": ctas_method.name, "templateParams": template_params, } if tmp_table_name: diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 050f1e1dc26..b517c20d2de 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -39,7 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import ErrorLevel, SupersetErrorType from superset.extensions import celery_app from superset.models.sql_lab import Query -from superset.sql_parse import ParsedQuery, CtasMethod +from superset.sql.parse import CTASMethod from superset.utils.core import backend from superset.utils.database import get_example_database from tests.integration_tests.conftest import CTAS_SCHEMA_NAME @@ -76,13 +76,19 @@ def setup_sqllab(): db.session.query(Query).delete() db.session.commit() for tbl in TMP_TABLES: - drop_table_if_exists(f"{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE) - drop_table_if_exists(f"{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW) drop_table_if_exists( - f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE + f"{tbl}_{CTASMethod.TABLE.name.lower()}", CTASMethod.TABLE ) drop_table_if_exists( - f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW + f"{tbl}_{CTASMethod.VIEW.name.lower()}", CTASMethod.VIEW + ) + drop_table_if_exists( + f"{CTAS_SCHEMA_NAME}.{tbl}_{CTASMethod.TABLE.name.lower()}", + CTASMethod.TABLE, + ) + drop_table_if_exists( + f"{CTAS_SCHEMA_NAME}.{tbl}_{CTASMethod.VIEW.name.lower()}", + CTASMethod.VIEW, ) @@ -90,7 +96,7 @@ def run_sql( test_client, sql, 
cta=False, - ctas_method=CtasMethod.TABLE, + ctas_method=CTASMethod.TABLE, tmp_table="tmp", async_=False, ): @@ -104,14 +110,14 @@ def run_sql( select_as_cta=cta, tmp_table_name=tmp_table, client_id="".join(random.choice(string.ascii_lowercase) for i in range(5)), # noqa: S311 - ctas_method=ctas_method, + ctas_method=ctas_method.name, ), ).json -def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None: +def drop_table_if_exists(table_name: str, table_type: CTASMethod) -> None: """Drop table if it exists, works on any DB""" - sql = f"DROP {table_type} IF EXISTS {table_name}" + sql = f"DROP {table_type.name} IF EXISTS {table_name}" database = get_example_database() with database.get_sqla_engine() as engine: engine.execute(sql) @@ -124,10 +130,10 @@ def quote_f(value: Optional[str]): return inspector.engine.dialect.identifier_preparer.quote_identifier(value) -def cta_result(ctas_method: CtasMethod): +def cta_result(ctas_method: CTASMethod): if backend() != "presto": return [], [] - if ctas_method == CtasMethod.TABLE: + if ctas_method == CTASMethod.TABLE: return [{"rows": 1}], [{"name": "rows", "type": "BIGINT", "is_dttm": False}] return [{"result": True}], [{"name": "result", "type": "BOOLEAN", "is_dttm": False}] @@ -143,13 +149,13 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None): @pytest.mark.usefixtures("login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) +@pytest.mark.parametrize("ctas_method", [CTASMethod.TABLE, CTASMethod.VIEW]) def test_run_sync_query_dont_exist(test_client, ctas_method): examples_db = get_example_database() engine_name = examples_db.db_engine_spec.engine_name sql_dont_exist = "SELECT name FROM table_dont_exist" result = run_sql(test_client, sql_dont_exist, cta=True, ctas_method=ctas_method) - if backend() == "sqlite" and ctas_method == CtasMethod.VIEW: + if backend() == "sqlite" and ctas_method == CTASMethod.VIEW: assert QueryStatus.SUCCESS == 
result["status"], result elif backend() == "presto": assert ( @@ -188,9 +194,9 @@ def test_run_sync_query_dont_exist(test_client, ctas_method): @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_sync_query_cta(test_client, ctas_method): - tmp_table_name = f"{TEST_SYNC}_{ctas_method.lower()}" +@pytest.mark.parametrize("ctas_method", [CTASMethod.TABLE, CTASMethod.VIEW]) +def test_run_sync_query_cta(test_client, ctas_method: CTASMethod) -> None: + tmp_table_name = f"{TEST_SYNC}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method ) @@ -218,16 +224,44 @@ def test_run_sync_query_cta_no_data(test_client): @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) +@pytest.mark.parametrize( + "ctas_method, expected", + [ + ( + CTASMethod.TABLE, + """ +CREATE TABLE sqllab_test_db.test_sync_cta_table AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ( + CTASMethod.VIEW, + """ +CREATE VIEW sqllab_test_db.test_sync_cta_view AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ], +) @mock.patch( # noqa: PT008 "superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) -def test_run_sync_query_cta_config(test_client, ctas_method): +def test_run_sync_query_cta_config( + test_client, + ctas_method: CTASMethod, + expected: str, +) -> None: if backend() == "sqlite": # sqlite doesn't support schemas return - tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.lower()}" + tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, cta=True, ctas_method=ctas_method, tmp_table=tmp_table_name ) @@ -235,10 +269,7 @@ def test_run_sync_query_cta_config(test_client, ctas_method): assert cta_result(ctas_method) == 
(result["data"], result["columns"]) query = get_query_by_id(result["query"]["serverId"]) - assert ( - f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}" - == query.executed_sql - ) + assert query.executed_sql == expected assert query.select_sql == get_select_star( tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME ) @@ -249,16 +280,44 @@ def test_run_sync_query_cta_config(test_client, ctas_method): @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) +@pytest.mark.parametrize( + "ctas_method, expected", + [ + ( + CTASMethod.TABLE, + """ +CREATE TABLE sqllab_test_db.test_async_cta_config_table AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ( + CTASMethod.VIEW, + """ +CREATE VIEW sqllab_test_db.test_async_cta_config_view AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ], +) @mock.patch( # noqa: PT008 "superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) -def test_run_async_query_cta_config(test_client, ctas_method): +def test_run_async_query_cta_config( + test_client, + ctas_method: CTASMethod, + expected: str, +) -> None: if backend() == "sqlite": # sqlite doesn't support schemas return - tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.lower()}" + tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, @@ -275,18 +334,43 @@ def test_run_async_query_cta_config(test_client, ctas_method): get_select_star(tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME) == query.select_sql ) - assert ( - f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}" - == query.executed_sql - ) + assert query.executed_sql == expected delete_tmp_view_or_table(f"{CTAS_SCHEMA_NAME}.{tmp_table_name}", ctas_method) @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") 
-@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_async_cta_query(test_client, ctas_method): - table_name = f"{TEST_ASYNC_CTA}_{ctas_method.lower()}" +@pytest.mark.parametrize( + "ctas_method, expected", + [ + ( + CTASMethod.TABLE, + """ +CREATE TABLE test_async_cta_table AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ( + CTASMethod.VIEW, + """ +CREATE VIEW test_async_cta_view AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ], +) +def test_run_async_cta_query( + test_client, + ctas_method: CTASMethod, + expected: str, +) -> None: + table_name = f"{TEST_ASYNC_CTA}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, @@ -301,7 +385,7 @@ def test_run_async_cta_query(test_client, ctas_method): assert QueryStatus.SUCCESS == query.status assert get_select_star(table_name, query.limit) in query.select_sql - assert f"CREATE {ctas_method} {table_name} AS \n{QUERY}" == query.executed_sql + assert query.executed_sql == expected assert QUERY == query.sql assert query.rows == (1 if backend() == "presto" else 0) assert query.select_as_cta @@ -311,9 +395,37 @@ def test_run_async_cta_query(test_client, ctas_method): @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_async_cta_query_with_lower_limit(test_client, ctas_method): - tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.lower()}" +@pytest.mark.parametrize( + "ctas_method, expected", + [ + ( + CTASMethod.TABLE, + """ +CREATE TABLE test_async_lower_limit_table AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ( + CTASMethod.VIEW, + """ +CREATE VIEW test_async_lower_limit_view AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ], +) +def test_run_async_cta_query_with_lower_limit( + test_client, + ctas_method: CTASMethod, + expected: str, +) -> None: + tmp_table = 
f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, @@ -332,7 +444,7 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method): else get_select_star(tmp_table, query.limit) ) - assert f"CREATE {ctas_method} {tmp_table} AS \n{QUERY}" == query.executed_sql + assert query.executed_sql == expected assert QUERY == query.sql assert query.rows == (1 if backend() == "presto" else 0) @@ -442,28 +554,6 @@ def test_msgpack_payload_serialization(): assert isinstance(serialized, bytes) -def test_create_table_as(): - q = ParsedQuery("SELECT * FROM outer_space;") - - assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp") - assert ( - "DROP TABLE IF EXISTS tmp;\nCREATE TABLE tmp AS \nSELECT * FROM outer_space" - == q.as_create_table("tmp", overwrite=True) - ) - - # now without a semicolon - q = ParsedQuery("SELECT * FROM outer_space") - assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp") - - # now a multi-line query - multi_line_query = "SELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'" - q = ParsedQuery(multi_line_query) - assert ( - "CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'" - == q.as_create_table("tmp") - ) - - def test_in_app_context(): @celery_app.task(bind=True) def my_task(self): @@ -484,8 +574,8 @@ def test_in_app_context(): ) -def delete_tmp_view_or_table(name: str, db_object_type: str): - db.get_engine().execute(f"DROP {db_object_type} IF EXISTS {name}") +def delete_tmp_view_or_table(name: str, ctas_method: CTASMethod): + db.get_engine().execute(f"DROP {ctas_method.name} IF EXISTS {name}") def wait_for_success(result): diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 514053f7d1e..99c347d95f9 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -31,16 +31,15 @@ from superset.db_engine_specs import 
BaseEngineSpec from superset.db_engine_specs.hive import HiveEngineSpec from superset.db_engine_specs.presto import PrestoEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import SupersetErrorException +from superset.exceptions import SupersetErrorException, SupersetInvalidCVASException from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet from superset.sqllab.limiting_factor import LimitingFactor +from superset.sql.parse import CTASMethod from superset.sql_lab import ( cancel_query, execute_sql_statements, - apply_limit_if_exists, ) -from superset.sql_parse import CtasMethod from superset.utils.core import backend from superset.utils import json from superset.utils.json import datetime_to_epoch # noqa: F401 @@ -132,31 +131,13 @@ class TestSqlLab(SupersetTestCase): self.login(ADMIN_USERNAME) data = self.run_sql("DELETE FROM birth_names", "1") - assert data == { - "errors": [ - { - "message": ( - "This database does not allow for DDL/DML, and the query " - "could not be parsed to confirm it is a read-only query. Please " # noqa: E501 - "contact your administrator for more assistance." 
- ), - "error_type": SupersetErrorType.DML_NOT_ALLOWED_ERROR, - "level": ErrorLevel.ERROR, - "extra": { - "issue_codes": [ - { - "code": 1022, - "message": "Issue 1022 - Database does not allow data manipulation.", # noqa: E501 - } - ] - }, - } - ] - } + assert ( + data["errors"][0]["error_type"] == SupersetErrorType.DML_NOT_ALLOWED_ERROR + ) - @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + @parameterized.expand([CTASMethod.TABLE, CTASMethod.VIEW]) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_sql_json_cta_dynamic_db(self, ctas_method): + def test_sql_json_cta_dynamic_db(self, ctas_method: CTASMethod) -> None: examples_db = get_example_database() if examples_db.backend == "sqlite": # sqlite doesn't support database creation @@ -170,7 +151,7 @@ class TestSqlLab(SupersetTestCase): examples_db.allow_ctas = True # enable cta self.login(ADMIN_USERNAME) - tmp_table_name = f"test_target_{ctas_method.lower()}" + tmp_table_name = f"test_target_{ctas_method.name.lower()}" self.run_sql( "SELECT * FROM birth_names", "1", @@ -195,7 +176,9 @@ class TestSqlLab(SupersetTestCase): ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True # cleanup - engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}") + engine.execute( + f"DROP {ctas_method.name} admin_database.{tmp_table_name}" + ) examples_db.allow_ctas = old_allow_ctas db.session.commit() @@ -608,10 +591,10 @@ class TestSqlLab(SupersetTestCase): @mock.patch("superset.sql_lab.db") @mock.patch("superset.sql_lab.get_query") - @mock.patch("superset.sql_lab.execute_sql_statement") + @mock.patch("superset.sql_lab.execute_query") def test_execute_sql_statements( self, - mock_execute_sql_statement, + mock_execute_query, mock_get_query, mock_db, ): @@ -623,7 +606,7 @@ class TestSqlLab(SupersetTestCase): """ ) mock_db = mock.MagicMock() # noqa: F841 - mock_query = mock.MagicMock() + mock_query = mock.MagicMock(select_as_cta=False) 
mock_query.database.allow_run_async = False mock_cursor = mock.MagicMock() mock_query.database.get_raw_connection().__enter__().cursor.return_value = ( @@ -641,30 +624,20 @@ class TestSqlLab(SupersetTestCase): expand_data=False, log_params=None, ) - mock_execute_sql_statement.assert_has_calls( + mock_execute_query.assert_has_calls( [ - mock.call( - "-- comment\nSET @value = 42", - mock_query, - mock_cursor, - None, - False, - ), - mock.call( - "SELECT /*+ hint */ @value AS foo", - mock_query, - mock_cursor, - None, - False, - ), + mock.call(mock_query, mock_cursor, None), + mock.call(mock_query, mock_cursor, None), ] ) @mock.patch("superset.sql_lab.results_backend", None) @mock.patch("superset.sql_lab.get_query") - @mock.patch("superset.sql_lab.execute_sql_statement") + @mock.patch("superset.sql_lab.execute_query") def test_execute_sql_statements_no_results_backend( - self, mock_execute_sql_statement, mock_get_query + self, + mock_execute_query, + mock_get_query, ): sql = dedent( """ @@ -712,10 +685,10 @@ class TestSqlLab(SupersetTestCase): @mock.patch("superset.sql_lab.db") @mock.patch("superset.sql_lab.get_query") - @mock.patch("superset.sql_lab.execute_sql_statement") + @mock.patch("superset.sql_lab.execute_query") def test_execute_sql_statements_ctas( self, - mock_execute_sql_statement, + mock_execute_query, mock_get_query, mock_db, ): @@ -727,7 +700,13 @@ class TestSqlLab(SupersetTestCase): """ ) mock_db = mock.MagicMock() # noqa: F841 - mock_query = mock.MagicMock() + mock_query = mock.MagicMock( + select_as_cta=True, + ctas_method=CTASMethod.TABLE.name, + tmp_table_name="table", + tmp_schema_name="schema", + catalog="catalog", + ) mock_query.database.allow_run_async = False mock_cursor = mock.MagicMock() mock_query.database.get_raw_connection().__enter__().cursor.return_value = ( @@ -738,7 +717,7 @@ class TestSqlLab(SupersetTestCase): # set the query to CTAS mock_query.select_as_cta = True - mock_query.ctas_method = CtasMethod.TABLE + mock_query.ctas_method 
= CTASMethod.TABLE.name execute_sql_statements( query_id=1, @@ -749,22 +728,10 @@ class TestSqlLab(SupersetTestCase): expand_data=False, log_params=None, ) - mock_execute_sql_statement.assert_has_calls( + mock_execute_query.assert_has_calls( [ - mock.call( - "-- comment\nSET @value = 42", - mock_query, - mock_cursor, - None, - False, - ), - mock.call( - "SELECT /*+ hint */ @value AS foo", - mock_query, - mock_cursor, - None, - True, # apply_ctas - ), + mock.call(mock_query, mock_cursor, None), + mock.call(mock_query, mock_cursor, None), ] ) @@ -795,7 +762,7 @@ class TestSqlLab(SupersetTestCase): ) # try invalid CVAS - mock_query.ctas_method = CtasMethod.VIEW + mock_query.ctas_method = CTASMethod.VIEW.name sql = dedent( """ -- comment @@ -803,7 +770,7 @@ class TestSqlLab(SupersetTestCase): SELECT /*+ hint */ @value AS foo; """ ) - with pytest.raises(SupersetErrorException) as excinfo: + with pytest.raises(SupersetInvalidCVASException) as excinfo: execute_sql_statements( query_id=1, rendered_query=sql, @@ -870,29 +837,6 @@ class TestSqlLab(SupersetTestCase): ] } - def test_apply_limit_if_exists_when_incremented_limit_is_none(self): - sql = """ - SET @value = 42; - SELECT @value AS foo; - """ - database = get_example_database() - mock_query = mock.MagicMock() - mock_query.limit = 300 - final_sql = apply_limit_if_exists(database, None, mock_query, sql) - - assert final_sql == sql - - def test_apply_limit_if_exists_when_increased_limit(self): - sql = """ - SET @value = 42; - SELECT @value AS foo; - """ - database = get_example_database() - mock_query = mock.MagicMock() - mock_query.limit = 300 - final_sql = apply_limit_if_exists(database, 1000, mock_query, sql) - assert "LIMIT 1000" in final_sql - @pytest.mark.parametrize("spec", [HiveEngineSpec, PrestoEngineSpec]) def test_cancel_query_implicit(spec: BaseEngineSpec) -> None: diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index dd04e38dcb6..d5ca4e8b100 100644 --- 
a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -21,105 +21,46 @@ from unittest import mock from uuid import UUID import pytest -import sqlparse from freezegun import freeze_time from pytest_mock import MockerFixture -from sqlalchemy.orm.session import Session -from superset import db from superset.common.db_query_status import QueryStatus +from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import OAuth2Error, SupersetErrorException from superset.models.core import Database -from superset.sql_lab import execute_sql_statements, get_sql_results -from superset.utils.core import override_user +from superset.sql.parse import SQLStatement, Table +from superset.sql_lab import ( + apply_rls, + execute_query, + execute_sql_statements, + get_predicates_for_table, + get_sql_results, +) from tests.unit_tests.models.core_test import oauth2_client_info -def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: +def test_execute_query(mocker: MockerFixture, app: None) -> None: """ Simple test for `execute_sql_statement`. 
""" - from superset.sql_lab import execute_sql_statement - - sql_statement = "SELECT 42 AS answer" - query = mocker.MagicMock() + query.executed_sql = "SELECT 42 AS answer" + query.limit = 1 - query.select_as_cta_used = False database = query.database database.allow_dml = False - database.apply_limit_to_sql.return_value = "SELECT 42 AS answer LIMIT 2" - database.mutate_sql_based_on_config.return_value = "SELECT 42 AS answer LIMIT 2" db_engine_spec = database.db_engine_spec db_engine_spec.fetch_data.return_value = [(42,)] cursor = mocker.MagicMock() SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") # noqa: N806 - execute_sql_statement( - sql_statement, - query, - cursor=cursor, - log_params={}, - apply_ctas=False, - ) + execute_query(query, cursor=cursor, log_params={}) - database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True) db_engine_spec.execute_with_cursor.assert_called_with( cursor, - "SELECT 42 AS answer LIMIT 2", - query, - ) - SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) - - -def test_execute_sql_statement_with_rls( - mocker: MockerFixture, -) -> None: - """ - Test for `execute_sql_statement` when an RLS rule is in place. 
- """ - from superset.sql_lab import execute_sql_statement - - sql_statement = "SELECT * FROM sales" - sql_statement_with_rls = f"{sql_statement} WHERE organization_id=42" - sql_statement_with_rls_and_limit = f"{sql_statement_with_rls} LIMIT 101" - - query = mocker.MagicMock() - query.limit = 100 - query.select_as_cta_used = False - database = query.database - database.allow_dml = False - database.apply_limit_to_sql.return_value = sql_statement_with_rls_and_limit - database.mutate_sql_based_on_config.return_value = sql_statement_with_rls_and_limit - db_engine_spec = database.db_engine_spec - db_engine_spec.fetch_data.return_value = [(42,)] - - cursor = mocker.MagicMock() - SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") # noqa: N806 - mocker.patch( - "superset.sql_lab.insert_rls_as_subquery", - return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0], - ) - mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) - - execute_sql_statement( - sql_statement, - query, - cursor=cursor, - log_params={}, - apply_ctas=False, - ) - - database.apply_limit_to_sql.assert_called_with( - "SELECT * FROM sales WHERE organization_id=42", - 101, - force=True, - ) - db_engine_spec.execute_with_cursor.assert_called_with( - cursor, - "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", + "SELECT 42 AS answer", query, ) SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) @@ -232,122 +173,6 @@ def test_execute_sql_statement_within_payload_limit(mocker: MockerFixture) -> No ) -def test_sql_lab_insert_rls_as_subquery( - mocker: MockerFixture, - session: Session, -) -> None: - """ - Integration test for `insert_rls_as_subquery`. 
- """ - from flask_appbuilder.security.sqla.models import Role, User - - from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable - from superset.models.core import Database - from superset.models.sql_lab import Query - from superset.security.manager import SupersetSecurityManager - from superset.sql_lab import execute_sql_statement - from superset.utils.core import RowLevelSecurityFilterType - - engine = db.session.connection().engine - Query.metadata.create_all(engine) # pylint: disable=no-member - - connection = engine.raw_connection() - connection.execute("CREATE TABLE t (c INTEGER)") - for i in range(10): - connection.execute("INSERT INTO t VALUES (?)", (i,)) - - cursor = connection.cursor() - - query = Query( - sql="SELECT c FROM t", - client_id="abcde", - database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"), - schema=None, - limit=5, - select_as_cta_used=False, - ) - db.session.add(query) - db.session.commit() - - admin = User( - first_name="Alice", - last_name="Doe", - email="adoe@example.org", - username="admin", - roles=[Role(name="Admin")], - ) - - # first without RLS - with override_user(admin): - superset_result_set = execute_sql_statement( - sql_statement=query.sql, - query=query, - cursor=cursor, - log_params=None, - apply_ctas=False, - ) - assert ( - superset_result_set.to_pandas_df().to_markdown() - == """ -| | c | -|---:|----:| -| 0 | 0 | -| 1 | 1 | -| 2 | 2 | -| 3 | 3 | -| 4 | 4 |""".strip() - ) - assert query.executed_sql == "SELECT\n c\nFROM t\nLIMIT 6" - - # now with RLS - rls = RowLevelSecurityFilter( - name="sqllab_rls1", - filter_type=RowLevelSecurityFilterType.REGULAR, - tables=[SqlaTable(database_id=1, schema=None, table_name="t")], - roles=[admin.roles[0]], - group_key=None, - clause="c > 5", - ) - db.session.add(rls) - db.session.flush() - mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin) - mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) - - with 
override_user(admin): - superset_result_set = execute_sql_statement( - sql_statement=query.sql, - query=query, - cursor=cursor, - log_params=None, - apply_ctas=False, - ) - assert ( - superset_result_set.to_pandas_df().to_markdown() - == """ -| | c | -|---:|----:| -| 0 | 6 | -| 1 | 7 | -| 2 | 8 | -| 3 | 9 |""".strip() - ) - assert ( - query.executed_sql - == """SELECT - c -FROM ( - SELECT - * - FROM t - WHERE - ( - t.c > 5 - ) -) AS t -LIMIT 6""" - ) - - @freeze_time("2021-04-01T00:00:00Z") def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: """ @@ -377,8 +202,7 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: "OAuth2 required" ) - query = mocker.MagicMock() - query.database = database + query = mocker.MagicMock(select_as_cta=False, database=database) mocker.patch("superset.sql_lab.get_query", return_value=query) payload = get_sql_results(query_id=1, rendered_query="SELECT 1") @@ -398,3 +222,67 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: } ], } + + +def test_apply_rls(mocker: MockerFixture) -> None: + """ + Test the ``apply_rls`` helper function. 
+ """ + database = mocker.MagicMock() + database.get_default_schema_for_query.return_value = "public" + database.get_default_catalog.return_value = "examples" + database.db_engine_spec = PostgresEngineSpec + query = mocker.MagicMock(database=database, catalog="examples") + get_predicates_for_table = mocker.patch( + "superset.sql_lab.get_predicates_for_table", + side_effect=[["c1 = 1"], ["c2 = 2"]], + ) + + parsed_statement = SQLStatement("SELECT * FROM t1, t2", "postgresql") + parsed_statement.tables = sorted(parsed_statement.tables, key=lambda x: x.table) # type: ignore + + apply_rls(query, parsed_statement) + + get_predicates_for_table.assert_has_calls( + [ + mocker.call(Table("t1", "public", "examples"), database, True), + mocker.call(Table("t2", "public", "examples"), database, True), + ] + ) + + assert ( + parsed_statement.format() + == """ +SELECT + * +FROM ( + SELECT + * + FROM t1 + WHERE + c1 = 1 +) AS t1, ( + SELECT + * + FROM t2 + WHERE + c2 = 2 +) AS t2 + """.strip() + ) + + +def test_get_predicates_for_table(mocker: MockerFixture) -> None: + """ + Test the ``get_predicates_for_table`` helper function. + """ + database = mocker.MagicMock() + dataset = mocker.MagicMock() + predicate = mocker.MagicMock() + predicate.compile.return_value = "c1 = 1" + dataset.get_sqla_row_level_filters.return_value = [predicate] + db = mocker.patch("superset.sql_lab.db") + db.session.query().filter().one_or_none.return_value = dataset + + table = Table("t1", "public", "examples") + assert get_predicates_for_table(table, database, True) == ["c1 = 1"]