Files
superset2/superset/tasks/context.py

674 lines
26 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Concrete TaskContext implementation for GTF"""
import logging
import threading
import time
import traceback
from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar
from flask import current_app
from superset_core.tasks.types import (
TaskContext as CoreTaskContext,
TaskProperties,
TaskStatus,
)
from superset.stats_logger import BaseStatsLogger
from superset.tasks.constants import ABORT_STATES
from superset.tasks.utils import progress_update
if TYPE_CHECKING:
from superset.models.tasks import Task
from superset.tasks.manager import AbortListener
# Module-wide logger; every TaskContext instance in this module logs through it.
logger: logging.Logger = logging.getLogger(__name__)
# Generic type variable. It is not referenced in the visible code of this
# module — presumably kept for generic helpers; confirm before removing.
T = TypeVar("T")
class TaskContext(CoreTaskContext):
    """
    Concrete implementation of TaskContext for the Global Async Task Framework.

    Provides write-only access to task state. Tasks use this context to update
    their progress and payload, and check for cancellation. Tasks should not
    need to read their own state - they are the source of state, not consumers.

    Several threads touch an instance: the task thread (``update_task``,
    ``on_abort``, ``on_cleanup``), the deferred-flush timer, the abort listener
    callback and the timeout timer. ``_throttle_lock`` guards the in-memory
    caches and the throttling bookkeeping; ``_abort_lock`` makes the
    "abort already handled / execution already completed" decision atomic so
    abort handlers run at most once.
    """

    # Type alias for handler failures: (handler_type, exception, stack_trace)
    HandlerFailure = tuple[str, Exception, str]

    def __init__(self, task: "Task") -> None:
        """
        Initialize TaskContext with a pre-fetched task entity.

        The task entity must be pre-fetched by the caller (executor) to ensure
        caching works correctly and to enforce the pattern of single initial fetch.

        :param task: Pre-fetched Task entity (required)
        """
        self._task_uuid = task.uuid
        self._cleanup_handlers: list[Callable[[], None]] = []
        self._abort_handlers: list[Callable[[], None]] = []
        self._abort_listener: "AbortListener | None" = None
        self._abort_detected = False
        self._abort_handlers_completed = False  # Track if all abort handlers finished
        self._execution_completed = False  # Set by executor after task work completes
        # Serializes abort detection across the listener, timeout and cleanup
        # paths so abort handlers can never be triggered twice.
        self._abort_lock = threading.Lock()

        # Collected handler failures for unified reporting
        self._handler_failures: list[TaskContext.HandlerFailure] = []

        # Timeout timer state
        self._timeout_timer: threading.Timer | None = None
        self._timeout_triggered = False

        # Throttling state for update_task()
        # These manage the minimum interval between DB writes
        self._last_db_write_time: float | None = None
        self._has_pending_updates: bool = False
        self._deferred_flush_timer: threading.Timer | None = None
        self._throttle_lock = threading.Lock()

        # Cached task entity - avoids repeated DB fetches.
        # Updated only by _refresh_task() when checking external state changes.
        self._task: "Task" = task

        # In-memory state caches - authoritative during execution
        # These are initialized from the task entity and updated locally
        # before being written to DB via targeted SQL updates.
        # We copy the dicts to avoid mutating the Task's cached instances.
        self._properties_cache: TaskProperties = cast(
            TaskProperties, {**task.properties_dict}
        )
        self._payload_cache: dict[str, Any] = {**task.payload_dict}

        # Store Flask app reference for background thread database access
        # Use _get_current_object() to get actual app, not proxy
        try:
            self._app = current_app._get_current_object()
            # Cache stats logger to avoid repeated config lookups
            self._stats_logger: BaseStatsLogger = current_app.config.get(
                "STATS_LOGGER", BaseStatsLogger()
            )
        except RuntimeError:
            # Handle case where app context isn't available (e.g., tests)
            self._app = None
            self._stats_logger = BaseStatsLogger()

    def _refresh_task(self) -> "Task":
        """
        Force refresh the task entity from the database.

        Use this method when you need to check for external state changes,
        such as whether the task has been aborted by a concurrent operation.

        This method:
        - Fetches fresh task entity from database
        - Updates the cached _task reference
        - Updates properties/payload caches from fresh data

        Note: the caches are replaced wholesale, so throttled updates that have
        not been flushed yet are replaced by the database state.

        :returns: Fresh task entity from database
        :raises ValueError: If task is not found
        """
        from superset.daos.tasks import TaskDAO

        fresh_task = TaskDAO.find_one_or_none(uuid=self._task_uuid)
        if not fresh_task:
            raise ValueError(f"Task {self._task_uuid} not found")

        self._task = fresh_task
        # Update caches from fresh data (copy to avoid mutating Task's cache)
        self._properties_cache = cast(TaskProperties, {**fresh_task.properties_dict})
        self._payload_cache = {**fresh_task.payload_dict}
        return self._task

    def update_task(
        self,
        progress: float | int | tuple[int, int] | None = None,
        payload: dict[str, object] | None = None,
    ) -> None:
        """
        Update task progress and/or payload atomically.

        All parameters are optional. Payload is merged with existing cached data.
        In-memory caches are always updated immediately, but DB writes are
        throttled according to TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL to prevent
        excessive database load from eager tasks.

        Progress can be specified in three ways:
        - float (0.0-1.0): Percentage only, e.g., 0.5 means 50%
        - int: Count only (total unknown), e.g., 42 means "42 items processed"
        - tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100"
          The percentage is automatically computed from count/total.

        :param progress: Progress value, or None to leave unchanged
        :param payload: Payload data to merge (dict), or None to leave unchanged
        """
        progress_props = None
        if progress is not None:
            progress_props = progress_update(progress)
            if not progress_props:
                # Invalid progress format - progress_update returns empty dict
                logger.warning(
                    "Invalid progress value for task %s: %s "
                    "(expected float, int, or tuple[int, int])",
                    self._task_uuid,
                    progress,
                )

        if not progress_props and payload is None:
            return

        # Get throttle interval from config
        throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]

        # Caches are mutated under the lock: the deferred-flush timer thread may
        # be serializing these same dicts in _write_to_db() concurrently.
        with self._throttle_lock:
            if progress_props:
                self._properties_cache.update(progress_props)
            if payload is not None:
                self._payload_cache.update(payload)

            # If throttling is disabled (0), write immediately
            if throttle_interval <= 0:
                self._write_to_db()
                return

            now = time.time()
            if self._last_db_write_time is None:
                # First update - write immediately
                self._write_to_db()
                self._last_db_write_time = now
            elif now - self._last_db_write_time >= throttle_interval:
                # Throttle window has passed - write immediately
                self._cancel_deferred_flush_timer()
                self._write_to_db()
                self._last_db_write_time = now
                self._has_pending_updates = False
            else:
                # Within throttle window - defer the write
                self._has_pending_updates = True
                self._stats_logger.incr("gtf.task.update_deferred")

                # Start deferred flush timer if not already running
                if self._deferred_flush_timer is None:
                    remaining_time = throttle_interval - (
                        now - self._last_db_write_time
                    )
                    self._deferred_flush_timer = threading.Timer(
                        remaining_time, self._deferred_flush
                    )
                    self._deferred_flush_timer.daemon = True
                    self._deferred_flush_timer.start()

    def _write_to_db(self) -> None:
        """
        Write current cached state to database.

        This method performs the actual DB write using InternalUpdateTaskCommand.
        It writes whatever is in the caches at the time of the call. Callers in
        this class that may race with the deferred-flush timer hold
        ``_throttle_lock`` while calling it.
        """
        from superset.commands.tasks.internal_update import InternalUpdateTaskCommand

        self._stats_logger.incr("gtf.task.update_write")
        InternalUpdateTaskCommand(
            task_uuid=self._task_uuid,
            properties=self._properties_cache,
            payload=self._payload_cache,
        ).run()

    def _deferred_flush(self) -> None:
        """
        Timer callback that flushes pending updates at end of throttle window.

        This ensures the UI never shows stale progress for longer than the
        throttle interval. Runs on the timer's own thread.
        """
        with self._throttle_lock:
            # Only drop the reference if it still points at this timer. A timer
            # that fired after being cancelled must not orphan a newer timer
            # (which would then be impossible to cancel).
            if self._deferred_flush_timer is threading.current_thread():
                self._deferred_flush_timer = None

            if not self._has_pending_updates:
                return

            try:
                # Need app context for DB operations in timer thread
                if self._app:
                    with self._app.app_context():
                        self._write_to_db()
                else:
                    self._write_to_db()
            except Exception:
                # Keep _has_pending_updates set so the next update_task() or
                # _run_cleanup() retries the write instead of losing it.
                logger.exception(
                    "Deferred progress flush failed for task %s", self._task_uuid
                )
                return

            self._last_db_write_time = time.time()
            self._has_pending_updates = False

    def _cancel_deferred_flush_timer(self) -> None:
        """Cancel the deferred flush timer if running (caller holds the lock)."""
        if self._deferred_flush_timer is not None:
            self._deferred_flush_timer.cancel()
            self._deferred_flush_timer = None

    def on_cleanup(self, handler: Callable[[], None]) -> Callable[[], None]:
        """
        Register a cleanup handler that runs when the task ends.

        Cleanup handlers are called when the task completes (success),
        fails with an error, or is aborted. Multiple handlers can be
        registered and will execute in LIFO order (last registered runs first).

        Can be used as a decorator:
            @ctx.on_cleanup
            def cleanup():
                logger.info("Task ended")

        Or called directly:
            ctx.on_cleanup(lambda: logger.info("Task ended"))

        :param handler: Cleanup function to register
        :returns: The handler (for decorator compatibility)
        """
        self._cleanup_handlers.append(handler)
        return handler

    def on_abort(self, handler: Callable[[], None]) -> Callable[[], None]:
        """
        Register abort handler with automatic background listening.

        When the first handler is registered:
        1. Sets is_abortable=true in the database (marks task as abortable)
        2. Background abort listener starts automatically (pub/sub or polling)

        The handler will be called automatically when an abort is detected.

        :param handler: Callback function to execute when abort is detected
        :returns: The handler (for decorator compatibility)

        Example:
            @ctx.on_abort
            def handle_abort():
                logger.info("Task was aborted!")
                cleanup_partial_work()

        Note:
            The handler executes in a background thread when abort is detected.
            The task code continues running unless the handler does something
            to stop it (e.g., raises an exception, modifies shared state, etc.)
        """
        is_first_handler = len(self._abort_handlers) == 0
        self._abort_handlers.append(handler)

        if is_first_handler:
            # Mark task as abortable in database
            self._set_abortable()
            # Auto-start abort listener when first handler is registered
            interval = current_app.config["TASK_ABORT_POLLING_DEFAULT_INTERVAL"]
            self._start_abort_listener(interval)

        return handler

    def _set_abortable(self) -> None:
        """Mark the task as abortable (abort handler has been registered)."""
        from superset.commands.tasks.internal_update import InternalUpdateTaskCommand

        # Update local cache and write to DB. Held under the throttle lock
        # because the deferred-flush thread may be serializing this dict.
        with self._throttle_lock:
            self._properties_cache["is_abortable"] = True
            InternalUpdateTaskCommand(
                task_uuid=self._task_uuid,
                properties=self._properties_cache,
            ).run()

    def _start_abort_listener(self, interval: float) -> None:
        """
        Start background abort listener via TaskManager.

        Uses Redis pub/sub if available, otherwise falls back to database polling.
        The implementation is encapsulated in TaskManager.
        """
        if self._abort_listener is not None:
            return  # Already listening

        from superset.tasks.manager import TaskManager

        self._abort_listener = TaskManager.listen_for_abort(
            task_uuid=self._task_uuid,
            callback=self._on_abort_detected,
            poll_interval=interval,
            app=self._app,
        )

    def _on_abort_detected(self) -> None:
        """
        Callback invoked by TaskManager when abort is detected.

        Also invoked by the timeout timer. Both may fire concurrently (the
        timeout path writes ABORTING status, which the listener may pick up
        too), so the check-and-set happens under ``_abort_lock`` and the
        registered abort handlers are triggered at most once.
        """
        with self._abort_lock:
            if self._abort_detected:
                return  # Already handled

            # Check if task execution has already completed (late abort race).
            # Executor sets _execution_completed after task work finishes.
            if self._execution_completed:
                logger.info(
                    "Abort detected for task %s but execution already completed",
                    self._task_uuid,
                )
                return

            self._abort_detected = True

        logger.info("Abort detected for task %s", self._task_uuid)
        self._trigger_abort_handlers()

    def mark_execution_completed(self) -> None:
        """
        Mark that the task's main execution has completed.

        Called by the executor after the task function returns (successfully
        or with an exception). This prevents late abort callbacks from running
        handlers when the task work has already finished. Cleanup handlers
        still run after this is set.
        """
        # Same lock as abort detection: either the abort wins (handlers run)
        # or completion wins (late abort is ignored) - never both halfway.
        with self._abort_lock:
            self._execution_completed = True

    def start_abort_polling(self, interval: float | None = None) -> None:
        """
        Start background abort listener.

        This method is kept for backwards compatibility. It now delegates
        to _start_abort_listener which uses TaskManager.

        :param interval: Polling interval in seconds (uses config default if None)
        """
        if interval is None:
            interval = current_app.config["TASK_ABORT_POLLING_DEFAULT_INTERVAL"]
        self._start_abort_listener(interval)

    def _trigger_abort_handlers(self) -> None:
        """
        Execute all registered abort handlers (called by polling thread or cleanup).

        Handlers run in LIFO order. All handlers are attempted even if some fail
        (best-effort cleanup). Failures are collected in self._handler_failures
        for unified reporting.

        Note: This method never writes to DB directly. All failures are collected
        and written by _run_cleanup() in the executor's finally block, ensuring
        abort and cleanup handler failures are combined into a single record.
        """
        for handler in reversed(self._abort_handlers):
            try:
                handler()
            except Exception as ex:
                stack_trace = traceback.format_exc()
                logger.exception(
                    "Abort handler failed for task %s: %s",
                    self._task_uuid,
                    str(ex),
                )
                self._handler_failures.append(("abort", ex, stack_trace))

        # Check if all abort handlers completed successfully
        abort_failures = [f for f in self._handler_failures if f[0] == "abort"]
        if not abort_failures:
            self._abort_handlers_completed = True

    def _summarize_handler_failures(self) -> tuple[str, str, str]:
        """
        Collapse the collected handler failures into a single error record.

        :returns: ``(error_message, exception_type, stack_trace)``
        """
        failures = self._handler_failures

        if len(failures) == 1:
            htype, ex, trace = failures[0]
            return (
                f"{htype.capitalize()} handler failed: {str(ex)}",
                type(ex).__name__,
                trace,
            )

        # Multiple failures
        joined_messages = "; ".join(str(ex) for _, ex, _ in failures)
        handler_types = {htype for htype, _, _ in failures}
        if handler_types == {"abort"}:
            exception_type = "MultipleAbortHandlerFailures"
        elif handler_types == {"cleanup"}:
            exception_type = "MultipleCleanupHandlerFailures"
        else:
            exception_type = "MultipleHandlerFailures"

        # Combine stack traces with clear separators
        stack_trace = "\n--- Next handler failure ---\n".join(
            f"[{htype}:{type(ex).__name__}]\n{trace}"
            for htype, ex, trace in failures
        )
        return f"Handler(s) failed: {joined_messages}", exception_type, stack_trace

    def _write_handler_failures_to_db(self) -> None:
        """
        Write collected handler failures to the database.

        Combines all failures (abort + cleanup) into a single error record and
        sets the task status to FAILURE. If the task already has an error (e.g.,
        task function threw exception), handler failures are APPENDED to
        preserve the original error context.

        Without an app context the record cannot be written; the failures are
        logged instead of being dropped silently. Collected failures are cleared
        afterwards.
        """
        from superset.commands.tasks.update import UpdateTaskCommand

        if not self._handler_failures:
            return

        (
            handler_error_msg,
            handler_exception_type,
            handler_stack_trace,
        ) = self._summarize_handler_failures()

        if not self._app:
            logger.error(
                "Could not persist handler failures for task %s "
                "(no app context available): %s",
                self._task_uuid,
                handler_error_msg,
            )
            self._handler_failures = []
            return

        with self._app.app_context():
            # Check if task already has an error (preserve original context).
            # NOTE(review): self._task is the cached entity; whether an error
            # written by another path after it was loaded is visible here
            # depends on ORM refresh behavior - confirm, or use _refresh_task().
            properties = self._task.properties_dict
            original_error = properties.get("error_message")
            original_type = properties.get("exception_type")
            original_trace = properties.get("stack_trace")

            if original_error:
                # Append handler failures to original error
                error_msg = f"{original_error} | {handler_error_msg}"
                exception_type = (
                    f"{original_type}+{handler_exception_type}"
                    if original_type
                    else handler_exception_type
                )
                stack_trace = (
                    f"{original_trace}\n\n"
                    f"=== Handler failures during cleanup ===\n\n"
                    f"{handler_stack_trace}"
                    if original_trace
                    else handler_stack_trace
                )
            else:
                # No original error, just use handler failures
                error_msg = handler_error_msg
                exception_type = handler_exception_type
                stack_trace = handler_stack_trace

            # Update task with combined error info
            UpdateTaskCommand(
                self._task_uuid,
                status=TaskStatus.FAILURE.value,
                properties={
                    "error_message": error_msg,
                    "exception_type": exception_type,
                    "stack_trace": stack_trace,
                },
                skip_security_check=True,
            ).run()

        # Clear failures after writing
        self._handler_failures = []

    def stop_abort_polling(self) -> None:
        """Stop the background abort listener."""
        if self._abort_listener is not None:
            self._abort_listener.stop()
            self._abort_listener = None

    def start_timeout_timer(self, timeout_seconds: int) -> None:
        """
        Start a timeout timer that triggers abort when elapsed.

        Called by execute_task when task transitions to IN_PROGRESS.
        Timer only triggers abort handlers if task is abortable.

        :param timeout_seconds: Timeout duration in seconds
        """
        if self._timeout_timer is not None:
            return  # Already started

        def on_timeout() -> None:
            # Nothing to do if an abort is already being handled, or if the task
            # work already finished (timer fired just before _run_cleanup
            # cancelled it) - otherwise a finished task would be moved to
            # ABORTING and marked as timed out.
            if self._abort_detected or self._execution_completed:
                return

            self._timeout_triggered = True

            # Check if task has abort handler (requires app context)
            if not self._app:
                logger.error(
                    "Timeout fired for task %s but no app context available",
                    self._task_uuid,
                )
                return

            with self._app.app_context():
                from superset.commands.tasks.update import UpdateTaskCommand

                task = self._task
                if task.properties_dict.get("is_abortable", False):
                    logger.info(
                        "Timeout reached for task %s after %d seconds - "
                        "transitioning to ABORTING and triggering abort handlers",
                        self._task_uuid,
                        timeout_seconds,
                    )
                    # Set status to ABORTING (same as user abort)
                    # The executor will determine TIMED_OUT vs FAILURE based on
                    # whether handlers complete successfully
                    UpdateTaskCommand(
                        self._task_uuid,
                        status=TaskStatus.ABORTING.value,
                        properties={"error_message": "Task timed out"},
                        skip_security_check=True,
                    ).run()
                    # Trigger abort handlers for cleanup
                    self._on_abort_detected()
                else:
                    # No abort handler - just log warning
                    logger.warning(
                        "Timeout reached for task %s after %d seconds, but no "
                        "abort handler is registered. Task will continue running.",
                        self._task_uuid,
                        timeout_seconds,
                    )

        self._timeout_timer = threading.Timer(timeout_seconds, on_timeout)
        # Timer is daemon so it won't prevent process exit. If the worker dies,
        # the task is already in an inconsistent state (stuck IN_PROGRESS) that
        # requires external recovery (orphan detection). A non-daemon timer with
        # long timeouts (hours) would block graceful worker shutdown.
        self._timeout_timer.daemon = True
        self._timeout_timer.start()
        logger.debug(
            "Started timeout timer for task %s: %d seconds",
            self._task_uuid,
            timeout_seconds,
        )

    def stop_timeout_timer(self) -> None:
        """Cancel the timeout timer if running."""
        if self._timeout_timer is not None:
            self._timeout_timer.cancel()
            self._timeout_timer = None

    @property
    def timeout_triggered(self) -> bool:
        """Check if the timeout was triggered."""
        return self._timeout_triggered

    @property
    def abort_handlers_completed(self) -> bool:
        """Check if all abort handlers have completed successfully."""
        return self._abort_handlers_completed

    def _flush_pending_updates(self) -> None:
        """Cancel the deferred-flush timer and synchronously write pending updates."""
        with self._throttle_lock:
            self._cancel_deferred_flush_timer()
            if self._has_pending_updates:
                self._write_to_db()
                self._has_pending_updates = False

    def _run_missed_abort_handlers(self) -> None:
        """
        Trigger abort handlers if the task is in an abort state but no abort
        was detected while it ran.

        The abort is claimed under ``_abort_lock`` so that a late listener or
        timeout callback cannot run the handlers a second time.
        """
        if self._task.status not in ABORT_STATES:
            return
        with self._abort_lock:
            if self._abort_detected:
                return
            self._abort_detected = True
        self._trigger_abort_handlers()

    def _run_cleanup(self) -> None:
        """
        Run cleanup handlers (called by executor in finally block).

        This runs:
        1. Flushes any pending throttled updates to ensure final state is persisted
        2. Abort handlers if task was aborting/aborted (but not yet detected)
        3. All cleanup handlers (always)

        A failure while flushing is logged rather than raised so that cleanup
        handlers still run. All handler failures (abort + cleanup) are collected
        and written to DB as a unified error record at the end.
        """
        # Flush any pending throttled updates before cleanup (best effort)
        try:
            self._flush_pending_updates()
        except Exception:
            logger.exception(
                "Failed to flush pending updates for task %s during cleanup",
                self._task_uuid,
            )

        # Stop abort listener and timeout timer
        self.stop_abort_polling()
        self.stop_timeout_timer()

        # If aborting/aborted but handlers haven't run yet, run them now
        # (This catches the case where task ended before listener detected abort)
        if self._app:
            with self._app.app_context():
                self._run_missed_abort_handlers()
        else:
            # Fallback without app context
            try:
                self._run_missed_abort_handlers()
            except Exception as ex:
                logger.warning(
                    "Could not check abort status during cleanup for task %s: %s",
                    self._task_uuid,
                    str(ex),
                )

        # Always run cleanup handlers (LIFO), collecting failures
        for handler in reversed(self._cleanup_handlers):
            try:
                handler()
            except Exception as ex:
                stack_trace = traceback.format_exc()
                logger.exception(
                    "Cleanup handler failed for task %s: %s",
                    self._task_uuid,
                    str(ex),
                )
                self._handler_failures.append(("cleanup", ex, stack_trace))

        # Write all collected failures (abort + cleanup) to DB as unified record
        if self._handler_failures:
            self._write_handler_failures_to_db()