# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from typing import Any, cast, Optional
from urllib import parse

from flask import current_app as app, request, Response
from flask_appbuilder import permission_name
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import ValidationError
from werkzeug.utils import secure_filename

from superset import is_feature_enabled
from superset.commands.sql_lab.estimate import QueryEstimationCommand
from superset.commands.sql_lab.execute import CommandResult, ExecuteSqlCommand
from superset.commands.sql_lab.export import SqlResultExportCommand
from superset.commands.sql_lab.results import SqlExecutionResultsCommand
from superset.commands.sql_lab.streaming_export_command import (
    StreamingSqlResultExportCommand,
)
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP
from superset.daos.database import DatabaseDAO
from superset.daos.query import QueryDAO
from superset.extensions import event_logger
from superset.jinja_context import get_template_processor
from superset.models.sql_lab import Query
from superset.sql.parse import SQLScript
from superset.sql_lab import get_sql_results
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import (
    QueryIsForbiddenToAccessException,
    SqlLabException,
)
from superset.sqllab.execution_context_convertor import ExecutionContextConvertor
from superset.sqllab.query_render import SqlQueryRenderImpl
from superset.sqllab.schemas import (
    EstimateQueryCostSchema,
    ExecutePayloadSchema,
    FormatQueryPayloadSchema,
    QueryExecutionResponseSchema,
    sql_lab_get_results_schema,
    SQLLabBootstrapSchema,
)
from superset.sqllab.sql_json_executer import (
    ASynchronousSqlJsonExecutor,
    SqlJsonExecutor,
    SynchronousSqlJsonExecutor,
)
from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
from superset.sqllab.utils import bootstrap_sqllab_data
from superset.sqllab.validators import CanAccessQueryValidatorImpl
from superset.superset_typing import FlaskResponse
from superset.utils import core as utils, json
from superset.views.base import CsvResponse, generate_download_headers, json_success
from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics

logger = logging.getLogger(__name__)


class SqlLabRestApi(BaseSupersetApi):
    """REST API endpoints backing the SQL Lab frontend (bootstrap data,
    query execution, cost estimation, formatting, results retrieval and
    CSV export)."""

    method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
    datamodel = SQLAInterface(Query)

    resource_name = "sqllab"
    allow_browser_login = True
    class_permission_name = "SQLLab"

    estimate_model_schema = EstimateQueryCostSchema()
    execute_model_schema = ExecutePayloadSchema()
    format_model_schema = FormatQueryPayloadSchema()

    apispec_parameter_schemas = {
        "sql_lab_get_results_schema": sql_lab_get_results_schema,
    }
    openapi_spec_tag = "SQL Lab"
    openapi_spec_component_schemas = (
        EstimateQueryCostSchema,
        ExecutePayloadSchema,
        FormatQueryPayloadSchema,
        QueryExecutionResponseSchema,
        SQLLabBootstrapSchema,
    )

    @expose("/", methods=("GET",))
    @protect()
    @safe
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
        log_to_statsd=False,
    )
    def get(self) -> Response:
        """Get the bootstrap data for SqlLab
        ---
        get:
          summary: Get the bootstrap data for SqlLab page
          description: >-
            Assembles SQLLab bootstrap data (active_tab, databases, queries,
            tab_state_ids) in a single endpoint. The data can be assembled
            from the current user's id.
          responses:
            200:
              description: Returns the initial bootstrap data for SqlLab
              content:
                application/json:
                  schema:
                    $ref: '#/components/schemas/SQLLabBootstrapSchema'
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            500:
              $ref: '#/components/responses/500'
        """
        user_id = utils.get_user_id()
        # TODO: Replace with a command class once fully migrated to SPA
        result = bootstrap_sqllab_data(user_id)

        return json_success(
            json.dumps(
                {"result": result},
                default=json.json_iso_dttm_ser,
                ignore_nan=True,
            ),
            200,
        )

    @expose("/estimate/", methods=("POST",))
    @protect()
    @statsd_metrics
    @requires_json
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
        f".estimate_query_cost",
        log_to_statsd=False,
    )
    def estimate_query_cost(self) -> Response:
        """Estimate the SQL query execution cost.
        ---
        post:
          summary: Estimate the SQL query execution cost
          requestBody:
            description: SQL query and params
            required: true
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/EstimateQueryCostSchema'
          responses:
            200:
              description: Query estimation result
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      result:
                        type: object
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            model = self.estimate_model_schema.load(request.json)
        except ValidationError as error:
            return self.response_400(message=error.messages)

        command = QueryEstimationCommand(model)
        result = command.run()
        return self.response(200, result=result)

    @expose("/format_sql/", methods=("POST",))
    @statsd_metrics
    @protect()
    @permission_name("read")
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.format",
        log_to_statsd=False,
    )
    def format_sql(self) -> FlaskResponse:
        """Format the SQL query.
        ---
        post:
          summary: Format SQL code
          requestBody:
            description: SQL query
            required: true
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/FormatQueryPayloadSchema'
          responses:
            200:
              description: Format SQL result
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      result:
                        type: string
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            model = self.format_model_schema.load(request.json)
            sql = model["sql"]
            template_params = model.get("template_params")
            database_id = model.get("database_id")

            # Process Jinja templates if template_params and database_id are provided
            if template_params and database_id is not None:
                database = DatabaseDAO.find_by_id(database_id)
                if database:
                    try:
                        # template_params may arrive either as a JSON-encoded
                        # string or as an already-parsed mapping
                        template_params = (
                            json.loads(template_params)
                            if isinstance(template_params, str)
                            else template_params
                        )
                        if template_params:
                            template_processor = get_template_processor(
                                database=database
                            )
                            sql = template_processor.process_template(
                                sql, **template_params
                            )
                    except json.JSONDecodeError:
                        # Best effort: malformed params should not block formatting
                        logger.warning(
                            "Invalid template parameter %s. Skipping processing",
                            str(template_params),
                        )

            result = SQLScript(sql, model.get("engine")).format()
            return self.response(200, result=result)
        except ValidationError as error:
            return self.response_400(message=error.messages)

    # NOTE: route previously read "/export//" with no URL parameter, which
    # could never bind the required ``client_id`` argument.
    @expose("/export/<string:client_id>/")
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.export_csv",
        log_to_statsd=False,
    )
    def export_csv(self, client_id: str) -> CsvResponse:
        """Export the SQL query results to a CSV.
        ---
        get:
          summary: Export the SQL query results to a CSV
          parameters:
          - in: path
            schema:
              type: integer
            name: client_id
            description: The SQL query result identifier
          responses:
            200:
              description: SQL query results
              content:
                text/csv:
                  schema:
                    type: string
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            500:
              $ref: '#/components/responses/500'
        """
        result = SqlResultExportCommand(client_id=client_id).run()

        query, data, row_count = result["query"], result["data"], result["count"]

        quoted_csv_name = parse.quote(query.name)
        response = CsvResponse(
            data, headers=generate_download_headers("csv", quoted_csv_name)
        )
        event_info = {
            "event_type": "data_export",
            "client_id": client_id,
            "row_count": row_count,
            "database": query.database.name,
            "catalog": query.catalog,
            "schema": query.schema,
            "sql": query.sql,
            "exported_format": "csv",
        }
        event_rep = repr(event_info)
        logger.debug(
            "CSV exported: %s", event_rep, extra={"superset_event": event_info}
        )
        return response

    @expose("/export_streaming/", methods=("POST",))
    @protect()
    @permission_name("read")
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
        f".export_streaming_csv",
        log_to_statsd=False,
    )
    def export_streaming_csv(self) -> Response:
        """Export SQL query results using streaming for large datasets.
        ---
        post:
          summary: Export SQL query results to CSV with streaming
          requestBody:
            description: Export parameters
            required: true
            content:
              application/x-www-form-urlencoded:
                schema:
                  type: object
                  properties:
                    client_id:
                      type: string
                      description: The SQL query result identifier
                    filename:
                      type: string
                      description: Optional filename for the export
                    expected_rows:
                      type: integer
                      description: Optional expected row count for progress tracking
          responses:
            200:
              description: Streaming CSV export
              content:
                text/csv:
                  schema:
                    type: string
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            500:
              $ref: '#/components/responses/500'
        """
        # Extract parameters from form data
        client_id = request.form.get("client_id")
        filename = request.form.get("filename")

        if not client_id:
            return self.response_400(message="client_id is required")

        expected_rows = None
        if expected_rows_str := request.form.get("expected_rows"):
            try:
                expected_rows = int(expected_rows_str)
            except (ValueError, TypeError):
                # Progress tracking is optional; a bad value is ignored
                logger.warning("Invalid expected_rows value: %s", expected_rows_str)

        return self._create_streaming_csv_response(client_id, filename, expected_rows)

    def _create_streaming_csv_response(
        self,
        client_id: str,
        filename: str | None = None,
        expected_rows: int | None = None,
    ) -> Response:
        """Create a streaming CSV response for large SQL Lab result sets.

        :param client_id: SQL Lab query identifier to export
        :param filename: optional download filename; generated when absent
        :param expected_rows: optional hint, currently unused here
        """
        # Execute streaming command
        # TODO: Make chunk size configurable via SUPERSET_CONFIG
        chunk_size = 1024
        command = StreamingSqlResultExportCommand(client_id, chunk_size)
        command.validate()

        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = secure_filename(f"sqllab_{client_id}_{timestamp}.csv")

        # Get the callable that returns the generator
        csv_generator_callable = command.run()

        # Get encoding from config
        encoding = app.config.get("CSV_EXPORT", {}).get("encoding", "utf-8")

        # Create response with streaming headers.
        # BUG FIX: the Content-Disposition header previously hard-coded a
        # placeholder instead of interpolating the computed ``filename``.
        response = Response(
            csv_generator_callable(),  # Call the callable to get generator
            mimetype=f"text/csv; charset={encoding}",
            headers={
                "Content-Disposition": f'attachment; filename="{filename}"',
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",  # Disable nginx buffering
            },
            direct_passthrough=False,  # Flask must iterate generator
        )

        # Force chunked transfer encoding
        response.implicit_sequence_conversion = False

        logger.info(
            "SQL Lab streaming CSV export started: client_id=%s, filename=%s",
            client_id,
            filename,
        )

        return response

    @expose("/results/")
    @protect()
    @statsd_metrics
    @rison(sql_lab_get_results_schema)
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get_results",
        log_to_statsd=False,
    )
    def get_results(self, **kwargs: Any) -> FlaskResponse:
        """Get the result of a SQL query execution.
        ---
        get:
          summary: Get the result of a SQL query execution
          parameters:
          - in: query
            name: q
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/sql_lab_get_results_schema'
          responses:
            200:
              description: SQL query execution result
              content:
                application/json:
                  schema:
                    $ref: '#/components/schemas/QueryExecutionResponseSchema'
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            410:
              $ref: '#/components/responses/410'
            500:
              $ref: '#/components/responses/500'
        """
        params = kwargs["rison"]
        key = params.get("key")
        rows = params.get("rows")
        result = SqlExecutionResultsCommand(key=key, rows=rows).run()
        # Using pessimistic json serialization since some database drivers can return
        # unserializeable types at times
        payload = json.dumps(
            result,
            default=json.pessimistic_json_iso_dttm_ser,
            ignore_nan=True,
        )
        return json_success(payload, 200)

    @expose("/execute/", methods=("POST",))
    @protect()
    @statsd_metrics
    @requires_json
    @event_logger.log_this_with_context(
        # NOTE(review): action label says ".get_results", which collides with the
        # /results/ endpoint's label — looks like a copy-paste; left unchanged so
        # emitted telemetry action names stay stable. Confirm before renaming.
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get_results",
        log_to_statsd=False,
    )
    def execute_sql_query(self) -> FlaskResponse:
        """Execute a SQL query.
        ---
        post:
          summary: Execute a SQL query
          requestBody:
            description: SQL query and params
            required: true
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/ExecutePayloadSchema'
          responses:
            200:
              description: Query execution result
              content:
                application/json:
                  schema:
                    $ref: '#/components/schemas/QueryExecutionResponseSchema'
            202:
              description: Query execution result, query still running
              content:
                application/json:
                  schema:
                    $ref: '#/components/schemas/QueryExecutionResponseSchema'
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            self.execute_model_schema.load(request.json)
        except ValidationError as error:
            return self.response_400(message=error.messages)

        try:
            log_params = {
                "user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
            }
            execution_context = SqlJsonExecutionContext(request.json)
            command = self._create_sql_json_command(execution_context, log_params)
            command_result: CommandResult = command.run()

            # 202 signals an async query that is still running
            response_status = (
                202
                if command_result["status"] == SqlJsonExecutionStatus.QUERY_IS_RUNNING
                else 200
            )
            # return the execution result without special encoding
            return json_success(command_result["payload"], response_status)
        except SqlLabException as ex:
            payload = {"errors": [ex.to_dict()]}

            response_status = (
                403 if isinstance(ex, QueryIsForbiddenToAccessException) else ex.status
            )
            return self.response(response_status, **payload)

    @staticmethod
    def _create_sql_json_command(
        execution_context: SqlJsonExecutionContext, log_params: Optional[dict[str, Any]]
    ) -> ExecuteSqlCommand:
        """Assemble an ExecuteSqlCommand with all of its collaborators."""
        query_dao = QueryDAO()
        sql_json_executor = SqlLabRestApi._create_sql_json_executor(
            execution_context, query_dao
        )
        execution_context_convertor = ExecutionContextConvertor()
        execution_context_convertor.set_max_row_in_display(
            int(app.config.get("DISPLAY_MAX_ROW"))
        )
        return ExecuteSqlCommand(
            execution_context,
            query_dao,
            DatabaseDAO(),
            CanAccessQueryValidatorImpl(),
            SqlQueryRenderImpl(get_template_processor),
            sql_json_executor,
            execution_context_convertor,
            app.config["SQLLAB_CTAS_NO_LIMIT"],
            log_params,
        )

    @staticmethod
    def _create_sql_json_executor(
        execution_context: SqlJsonExecutionContext, query_dao: QueryDAO
    ) -> SqlJsonExecutor:
        """Pick the async or sync executor based on the execution context."""
        sql_json_executor: SqlJsonExecutor
        if execution_context.is_run_asynchronous():
            sql_json_executor = ASynchronousSqlJsonExecutor(query_dao, get_sql_results)
        else:
            sql_json_executor = SynchronousSqlJsonExecutor(
                query_dao,
                get_sql_results,
                app.config.get("SQLLAB_TIMEOUT"),
                is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE"),
            )
        return sql_json_executor