# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Celery task for async SQL execution.

This module provides the Celery task for executing SQL queries asynchronously.
It is used by SQLExecutor.execute_async() to run queries in the background.
"""

from __future__ import annotations

import dataclasses
import logging
import uuid
from typing import Any

import msgpack
from celery.exceptions import SoftTimeLimitExceeded
from flask import current_app as app, has_app_context
from flask_babel import gettext as __

from superset import (
    db,
    results_backend,
    security_manager,
)
from superset.common.db_query_status import QueryStatus
from superset.constants import QUERY_CANCEL_KEY
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
    SupersetErrorException,
    SupersetErrorsException,
)
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql.execution.executor import execute_sql_with_cursor
from superset.sql.parse import SQLScript
from superset.sqllab.utils import write_ipc_buffer
from superset.utils import json
from superset.utils.core import override_user, zlib_compress
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing

logger = logging.getLogger(__name__)

BYTES_IN_MB = 1024 * 1024


def _get_query(query_id: int) -> Query:
    """Get the query by ID."""
    return db.session.query(Query).filter_by(id=query_id).one()


def _handle_query_error(
    ex: Exception,
    query: Query,
    payload: dict[str, Any] | None = None,
    prefix_message: str = "",
) -> dict[str, Any]:
    """Handle error while processing the SQL query."""
    payload = payload or {}
    msg = f"{prefix_message} {str(ex)}".strip()
    query.error_message = msg
    query.tmp_table_name = None
    # Preserve TIMED_OUT status if already set (from SoftTimeLimitExceeded handler)
    if query.status != QueryStatus.TIMED_OUT:
        query.status = QueryStatus.FAILED

    if not query.end_time:
        query.end_time = now_as_float()

    # Extract DB-specific errors
    if isinstance(ex, SupersetErrorException):
        errors = [ex.error]
    elif isinstance(ex, SupersetErrorsException):
        errors = ex.errors
    else:
        errors = query.database.db_engine_spec.extract_errors(
            str(ex), database_name=query.database.unique_name
        )

    errors_payload = [dataclasses.asdict(error) for error in errors]
    if errors:
        query.set_extra_json_key("errors", errors_payload)

    db.session.commit()  # pylint: disable=consider-using-transaction
    payload.update(
        {"status": query.status.value, "error": msg, "errors": errors_payload}
    )
    if troubleshooting_link := app.config.get("TROUBLESHOOTING_LINK"):
        payload["link"] = troubleshooting_link
    return payload


def _serialize_payload(payload: dict[Any, Any]) -> bytes:
    """Serialize payload for storage based on RESULTS_BACKEND_USE_MSGPACK config."""
    from superset import results_backend_use_msgpack

    if results_backend_use_msgpack:
        return msgpack.dumps(payload, default=json.json_iso_dttm_ser, use_bin_type=True)
    return json.dumps(payload, default=json.json_iso_dttm_ser, ignore_nan=True).encode(
        "utf-8"
    )
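
# Illustrative sketch (not part of the original module): both branches of
# _serialize_payload() above return bytes, so callers such as
# _store_results_in_backend() can compress and store the result uniformly.
# Assuming a trivial payload and the default serializers, the output is roughly:
#
#     _serialize_payload({"status": "success"})
#     # msgpack path -> b'\x81\xa6status\xa7success'
#     # JSON path    -> b'{"status": "success"}'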


def _prepare_statement_blocks(
    rendered_query: str,
    db_engine_spec: Any,
) -> tuple[SQLScript, list[str]]:
    """
    Parse SQL and build statement blocks for execution.

    Some databases (like BigQuery and Kusto) do not persist state across multiple
    statements if they're run separately (especially when using `NullPool`), so we run
    the query as a single block when the database engine spec requires it.
    """
    parsed_script = SQLScript(rendered_query, engine=db_engine_spec.engine)

    # Build statement blocks for execution
    if db_engine_spec.run_multiple_statements_as_one:
        blocks = [parsed_script.format(comments=db_engine_spec.allows_sql_comments)]
    else:
        blocks = [
            statement.format(comments=db_engine_spec.allows_sql_comments)
            for statement in parsed_script.statements
        ]

    return parsed_script, blocks
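
# Illustrative sketch (not part of the original module): for a hypothetical
# two-statement script such as "SET @x = 1; SELECT @x;", _prepare_statement_blocks()
# above yields a single block when run_multiple_statements_as_one is set (e.g. for
# BigQuery or Kusto) and one block per statement otherwise; the exact text of each
# block depends on SQLScript.format() and the engine spec's allows_sql_comments flag.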


def _finalize_successful_query(
    query: Query,
    original_script: SQLScript,
    execution_results: list[tuple[str, SupersetResultSet | None, float, int]],
    payload: dict[str, Any],
    total_execution_time_ms: float,
) -> None:
    """Update query metadata and payload after successful execution."""
    # Calculate total rows across all statements
    total_rows = 0
    statements_data: list[dict[str, Any]] = []

    # Get original statement strings
    original_sqls = [stmt.format() for stmt in original_script.statements]

    for orig_sql, (exec_sql, result_set, exec_time, rowcount) in zip(
        original_sqls, execution_results, strict=True
    ):
        if result_set is not None:
            # SELECT statement
            total_rows += result_set.size
            data, columns = _serialize_result_set(result_set)
            statements_data.append(
                {
                    "original_sql": orig_sql,
                    "executed_sql": exec_sql,
                    "data": data,
                    "columns": columns,
                    "row_count": result_set.size,
                    "execution_time_ms": exec_time,
                }
            )
        else:
            # DML statement - no data, just row count
            statements_data.append(
                {
                    "original_sql": orig_sql,
                    "executed_sql": exec_sql,
                    "data": None,
                    "columns": [],
                    "row_count": rowcount,
                    "execution_time_ms": exec_time,
                }
            )

    query.rows = total_rows
    query.progress = 100
    query.set_extra_json_key("progress", None)
    # Store columns from last statement (for compatibility)
    if execution_results and execution_results[-1][1] is not None:
        query.set_extra_json_key("columns", execution_results[-1][1].columns)
    query.end_time = now_as_float()

    payload.update(
        {
            "status": QueryStatus.SUCCESS.value,
            "statements": statements_data,
            "total_execution_time_ms": total_execution_time_ms,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS.value


def _store_results_in_backend(
    query: Query,
    payload: dict[str, Any],
    database: Any,
) -> None:
    """Store query results in the results backend."""
    key = str(uuid.uuid4())
    payload["query"]["resultsKey"] = key
    logger.info(
        "Query %s: Storing results in results backend, key: %s",
        str(query.id),
        key,
    )
    stats_logger = app.config["STATS_LOGGER"]
    with stats_timing("sqllab.query.results_backend_write", stats_logger):
        with stats_timing(
            "sqllab.query.results_backend_write_serialization", stats_logger
        ):
            serialized_payload = _serialize_payload(payload)

        # Check payload size limit
        if sql_lab_payload_max_mb := app.config.get("SQLLAB_PAYLOAD_MAX_MB"):
            serialized_payload_size = len(serialized_payload)
            max_bytes = sql_lab_payload_max_mb * BYTES_IN_MB

            if serialized_payload_size > max_bytes:
                logger.info("Result size exceeds the allowed limit.")
                raise SupersetErrorException(
                    SupersetError(
                        message=(
                            f"Result size "
                            f"({serialized_payload_size / BYTES_IN_MB:.2f} MB) "
                            f"exceeds the allowed limit of "
                            f"{sql_lab_payload_max_mb} MB."
                        ),
                        error_type=SupersetErrorType.RESULT_TOO_LARGE_ERROR,
                        level=ErrorLevel.ERROR,
                    )
                )

        cache_timeout = database.cache_timeout
        if cache_timeout is None:
            cache_timeout = app.config["CACHE_DEFAULT_TIMEOUT"]

        compressed = zlib_compress(serialized_payload)
        logger.debug("*** serialized payload size: %i", len(serialized_payload))
        logger.debug("*** compressed payload size: %i", len(compressed))

        write_success = results_backend.set(key, compressed, cache_timeout)
        if not write_success:
            logger.error(
                "Query %s: Failed to store results in backend, key: %s",
                str(query.id),
                key,
            )
            stats_logger.incr("sqllab.results_backend.write_failure")
            query.results_key = None
            query.status = QueryStatus.FAILED
            query.error_message = (
                "Failed to store query results in the results backend. "
                "Please try again or contact your administrator."
            )
            db.session.commit()  # pylint: disable=consider-using-transaction
            raise SupersetErrorException(
                SupersetError(
                    message=__("Failed to store query results. Please try again."),
                    error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
                    level=ErrorLevel.ERROR,
                )
            )
        else:
            query.results_key = key
            logger.info(
                "Query %s: Successfully stored results in backend, key: %s",
                str(query.id),
                key,
            )
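
# Illustrative note (not part of the original module): with BYTES_IN_MB = 1024 * 1024,
# a hypothetical SQLLAB_PAYLOAD_MAX_MB = 10 caps the serialized payload at
# 10 * 1024 * 1024 = 10,485,760 bytes; a larger payload raises
# RESULT_TOO_LARGE_ERROR in _store_results_in_backend() before anything is written
# to the results backend.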


def _serialize_result_set(
    result_set: SupersetResultSet,
) -> tuple[bytes | list[Any], list[Any]]:
    """
    Serialize result set based on RESULTS_BACKEND_USE_MSGPACK config.

    When msgpack is enabled, uses Apache Arrow IPC format for efficiency.
    Otherwise, falls back to JSON-serializable records.

    :param result_set: Query result set to serialize
    :returns: Tuple of (serialized_data, columns)
    """
    from superset import results_backend_use_msgpack
    from superset.dataframe import df_to_records

    if results_backend_use_msgpack:
        if has_app_context():
            stats_logger = app.config["STATS_LOGGER"]
            with stats_timing(
                "sqllab.query.results_backend_pa_serialization", stats_logger
            ):
                data: bytes | list[Any] = write_ipc_buffer(
                    result_set.pa_table
                ).to_pybytes()
        else:
            data = write_ipc_buffer(result_set.pa_table).to_pybytes()
    else:
        df = result_set.to_pandas_df()
        data = df_to_records(df) or []

    return (data, result_set.columns)
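
# Illustrative sketch (not part of the original module): the two branches of
# _serialize_result_set() above produce different shapes for the same result set:
#
#     data, columns = _serialize_result_set(result_set)
#     # msgpack/Arrow path -> data is bytes (an Arrow IPC stream of result_set.pa_table)
#     # JSON path          -> data is a list of row dicts, e.g. [{"name": "alice"}, ...]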


@celery_app.task(name="query_execution.execute_sql")
def execute_sql_task(
    query_id: int,
    rendered_query: str,
    username: str | None = None,
    start_time: float | None = None,
) -> dict[str, Any] | None:
    """
    Execute SQL query asynchronously via Celery.

    This task is used by SQLExecutor.execute_async() to run queries
    in background workers with full feature support.

    :param query_id: ID of the Query model
    :param rendered_query: Pre-rendered SQL query to execute
    :param username: Username for context override
    :param start_time: Query start time for timing metrics
    :returns: Query result payload or None
    """
    with app.test_request_context():
        with override_user(security_manager.find_user(username)):
            try:
                return _execute_sql_statements(
                    query_id,
                    rendered_query,
                    start_time=start_time,
                )
            except Exception as ex:
                logger.debug("Query %d: %s", query_id, ex)
                stats_logger = app.config["STATS_LOGGER"]
                stats_logger.incr("error_sqllab_unhandled")
                query = _get_query(query_id=query_id)
                return _handle_query_error(ex, query)
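
# Illustrative sketch (not part of the original module): a caller such as
# SQLExecutor.execute_async() would typically enqueue the task above with
# something like the following, where get_username() stands in for whatever
# helper resolves the current user:
#
#     execute_sql_task.delay(
#         query.id,
#         rendered_query,
#         username=get_username(),
#         start_time=now_as_float(),
#     )
#
# and later read the stored payload back from the results backend via
# query.results_key.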


def _make_check_stopped_fn(query: Query) -> Any:
    """Create a function to check if query was stopped."""

    def check_stopped() -> bool:
        db.session.refresh(query)
        return query.status == QueryStatus.STOPPED

    return check_stopped


def _make_execute_fn(query: Query, db_engine_spec: Any) -> Any:
    """Create an execute function with stats timing."""

    def execute_with_stats(cursor: Any, sql: str) -> None:
        query.executed_sql = sql
        stats_logger = app.config["STATS_LOGGER"]
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            db_engine_spec.execute_with_cursor(cursor, sql, query)

    return execute_with_stats


def _make_log_query_fn(database: Any) -> Any:
    """Create a query logging function."""

    def log_query(sql: str, schema: str | None) -> None:
        if log_query_fn := app.config.get("QUERY_LOGGER"):
            log_query_fn(
                database.sqlalchemy_uri,
                sql,
                schema,
                __name__,
                security_manager,
                None,
            )

    return log_query


def _execute_sql_statements(
    query_id: int,
    rendered_query: str,
    start_time: float | None,
) -> dict[str, Any] | None:
    """Execute SQL statements and store results."""
    if start_time:
        stats_logger = app.config["STATS_LOGGER"]
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = _get_query(query_id=query_id)
    payload: dict[str, Any] = {"query_id": query_id}
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    logger.info("Query %s: Set query to 'running'", str(query_id))
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    execution_start_time = now_as_float()
    db.session.commit()  # pylint: disable=consider-using-transaction

    # Parse original SQL (from user) to preserve before transformations
    original_script = SQLScript(query.sql, engine=db_engine_spec.engine)

    # Parse transformed SQL (with RLS, limits, etc.)
    parsed_script, blocks = _prepare_statement_blocks(rendered_query, db_engine_spec)

    with database.get_raw_connection(
        catalog=query.catalog,
        schema=query.schema,
    ) as conn:
        cursor = conn.cursor()

        cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
        if cancel_query_id is not None:
            query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
            db.session.commit()  # pylint: disable=consider-using-transaction

        try:
            execution_results = execute_sql_with_cursor(
                database=database,
                cursor=cursor,
                statements=blocks,
                query=query,
                log_query_fn=_make_log_query_fn(database),
                check_stopped_fn=_make_check_stopped_fn(query),
                execute_fn=_make_execute_fn(query, db_engine_spec),
            )
        except SoftTimeLimitExceeded as ex:
            query.status = QueryStatus.TIMED_OUT
            logger.warning("Query %d: Time limit exceeded", query.id)
            timeout_sec = app.config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
            raise SupersetErrorException(
                SupersetError(
                    message=__(
                        "The query was killed after %(sqllab_timeout)s seconds. "
                        "It might be too complex, or the database might be "
                        "under heavy load.",
                        sqllab_timeout=timeout_sec,
                    ),
                    error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                    level=ErrorLevel.ERROR,
                )
            ) from ex

        # Check if stopped
        if not execution_results:
            payload.update({"status": QueryStatus.STOPPED.value})
            return payload

        # Commit for mutations
        if parsed_script.has_mutation() or query.select_as_cta:
            conn.commit()  # pylint: disable=consider-using-transaction

    total_execution_time_ms = (now_as_float() - execution_start_time) * 1000
    _finalize_successful_query(
        query, original_script, execution_results, payload, total_execution_time_ms
    )

    if results_backend:
        _store_results_in_backend(query, payload, database)

    if query.status != QueryStatus.FAILED:
        query.status = QueryStatus.SUCCESS
    db.session.commit()  # pylint: disable=consider-using-transaction

    return payload