# Source: superset2/superset/tasks/manager.py (762 lines, 26 KiB, Python)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task manager for the Global Task Framework (GTF)"""
from __future__ import annotations
import logging
import threading
import time
from typing import Any, Callable, TYPE_CHECKING
from uuid import UUID
import redis
from superset_core.tasks.types import TaskProperties, TaskScope
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
from superset.extensions import cache_manager
from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES
from superset.tasks.utils import generate_random_task_key
if TYPE_CHECKING:
from flask import Flask
from superset.models.tasks import Task
logger = logging.getLogger(__name__)
class AbortListener:
    """
    Handle for a background abort listener.

    Returned by TaskManager.listen_for_abort() to allow stopping the listener.
    """

    def __init__(
        self,
        task_uuid: UUID,
        thread: threading.Thread,
        stop_event: threading.Event,
        pubsub: redis.client.PubSub | None = None,
    ) -> None:
        # UUID of the monitored task (used for logging).
        self._task_uuid = task_uuid
        # Background listener/poller thread.
        self._thread = thread
        # Event used to request loop termination.
        self._stop_event = stop_event
        # Optional Redis pub/sub connection to tear down on stop.
        self._pubsub = pubsub

    def stop(self) -> None:
        """
        Stop the abort listener.

        Signals the background thread to exit, closes any active pub/sub
        subscription, and waits briefly for the thread to terminate.
        """
        self._stop_event.set()
        # Close pub/sub subscription if active; this also unblocks a listener
        # thread that is waiting inside get_message().
        if self._pubsub is not None:
            try:
                self._pubsub.unsubscribe()
                self._pubsub.close()
            except Exception as ex:
                logger.debug("Error closing pub/sub during stop: %s", ex)
        # Wait for the thread to finish (with timeout to avoid blocking
        # indefinitely). join() on an already-finished thread is a no-op, so
        # no is_alive() pre-check is needed.
        self._thread.join(timeout=2.0)
        if self._thread.is_alive():
            # Thread is a daemon, so it will be killed when process exits.
            # Log warning but continue - cleanup will still proceed.
            logger.warning(
                "Abort listener thread for task %s did not terminate within "
                "2 seconds. Thread will be terminated when process exits.",
                self._task_uuid,
            )
        else:
            logger.debug("Stopped abort listener for task %s", self._task_uuid)
class TaskManager:
    """
    Handles task creation, scheduling, and abort notifications.

    The TaskManager is responsible for:

    1. Creating task entries in the metastore (Task model)
    2. Scheduling task execution via Celery
    3. Handling deduplication (returning existing active task if duplicate)
    4. Managing real-time abort notifications (optional)

    Redis pub/sub is opt-in via DISTRIBUTED_COORDINATION_CONFIG configuration. When not
    configured, tasks use database polling for abort detection.
    """

    # Class-level state (initialized once via init_app)
    # Prefix for per-task abort channels (override: TASKS_ABORT_CHANNEL_PREFIX).
    _channel_prefix: str = "gtf:abort:"
    # Prefix for per-task completion channels
    # (override: TASKS_COMPLETION_CHANNEL_PREFIX).
    _completion_channel_prefix: str = "gtf:complete:"
    # Guard so init_app() applies configuration only once per process.
    _initialized: bool = False
@classmethod
def init_app(cls, app: Flask) -> None:
"""
Initialize the TaskManager with Flask app config.
Redis connection is managed by CacheManager - this just reads channel prefixes.
:param app: Flask application instance
"""
if cls._initialized:
return
cls._channel_prefix = app.config.get("TASKS_ABORT_CHANNEL_PREFIX", "gtf:abort:")
cls._completion_channel_prefix = app.config.get(
"TASKS_COMPLETION_CHANNEL_PREFIX", "gtf:complete:"
)
cls._initialized = True
@classmethod
def _get_cache(cls) -> RedisCacheBackend | RedisSentinelCacheBackend | None:
"""
Get the distributed coordination backend.
:returns: The distributed coordination backend, or None if not configured
"""
return cache_manager.distributed_coordination
@classmethod
def is_pubsub_available(cls) -> bool:
"""
Check if Redis pub/sub backend is configured and available.
:returns: True if Redis is available for pub/sub, False otherwise
"""
return cls._get_cache() is not None
@classmethod
def get_abort_channel(cls, task_uuid: UUID) -> str:
"""
Get the abort channel name for a task.
:param task_uuid: UUID of the task
:returns: Channel name for the task's abort notifications
"""
return f"{cls._channel_prefix}{task_uuid}"
@classmethod
def publish_abort(cls, task_uuid: UUID) -> bool:
"""
Publish an abort message to the task's channel.
:param task_uuid: UUID of the task to abort
:returns: True if message was published, False if Redis unavailable
"""
cache = cls._get_cache()
if not cache:
return False
try:
channel = cls.get_abort_channel(task_uuid)
subscriber_count = cache.publish(channel, "abort")
logger.debug(
"Published abort to channel %s (%d subscribers)",
channel,
subscriber_count,
)
return True
except redis.RedisError as ex:
logger.error("Failed to publish abort for task %s: %s", task_uuid, ex)
return False
@classmethod
def get_completion_channel(cls, task_uuid: UUID) -> str:
"""
Get the completion channel name for a task.
:param task_uuid: UUID of the task
:returns: Channel name for the task's completion notifications
"""
return f"{cls._completion_channel_prefix}{task_uuid}"
@classmethod
def publish_completion(cls, task_uuid: UUID, status: str) -> bool:
"""
Publish a completion message to the task's channel.
Called when task reaches terminal state (SUCCESS, FAILURE, ABORTED, TIMED_OUT).
This notifies any waiters (e.g., sync callers waiting for an existing task).
:param task_uuid: UUID of the completed task
:param status: Final status of the task
:returns: True if message was published, False if Redis unavailable
"""
cache = cls._get_cache()
if not cache:
return False
try:
channel = cls.get_completion_channel(task_uuid)
subscriber_count = cache.publish(channel, status)
logger.debug(
"Published completion to channel %s (status=%s, %d subscribers)",
channel,
status,
subscriber_count,
)
return True
except redis.RedisError as ex:
logger.error("Failed to publish completion for task %s: %s", task_uuid, ex)
return False
@classmethod
def wait_for_completion(
cls,
task_uuid: UUID,
timeout: float | None = None,
poll_interval: float = 1.0,
app: Any = None,
) -> "Task":
"""
Block until task reaches terminal state.
Uses Redis pub/sub if configured for low-latency, low-CPU waiting.
Uses database polling if Redis is not configured.
:param task_uuid: UUID of the task to wait for
:param timeout: Maximum time to wait in seconds (None = no limit)
:param poll_interval: Interval for database polling (seconds)
:param app: Flask app for database access
:returns: Task in terminal state
:raises TimeoutError: If timeout expires before task completes
:raises ValueError: If task not found
"""
from superset.daos.tasks import TaskDAO
start_time = time.monotonic()
def time_remaining() -> float | None:
if timeout is None:
return None
elapsed = time.monotonic() - start_time
remaining = timeout - elapsed
return remaining if remaining > 0 else 0
def get_task() -> "Task | None":
if app:
with app.app_context():
return TaskDAO.find_one_or_none(uuid=task_uuid)
return TaskDAO.find_one_or_none(uuid=task_uuid)
# Check current state first
task = get_task()
if not task:
raise ValueError(f"Task {task_uuid} not found")
if task.status in TERMINAL_STATES:
return task
logger.debug(
"Waiting for task %s to complete (current status=%s, timeout=%s)",
task_uuid,
task.status,
timeout,
)
# Use Redis pub/sub if configured
if (cache := cls._get_cache()) is not None:
task = cls._wait_via_pubsub(
task_uuid,
cache.pubsub(),
timeout,
poll_interval,
get_task,
time_remaining,
)
if task:
return task
# Should not reach here - _wait_via_pubsub returns task or raises
raise RuntimeError(f"Unexpected state waiting for task {task_uuid}")
# Use database polling when Redis is not configured
return cls._wait_via_polling(task_uuid, poll_interval, get_task, time_remaining)
@classmethod
def _wait_via_pubsub(
cls,
task_uuid: UUID,
pubsub: redis.client.PubSub,
timeout: float | None,
poll_interval: float,
get_task: Callable[[], "Task | None"],
time_remaining: Callable[[], float | None],
) -> "Task | None":
"""
Wait for task completion using Redis pub/sub.
:returns: Task when completed
:raises TimeoutError: If timeout expires
:raises redis.RedisError: If Redis connection fails
"""
channel = cls.get_completion_channel(task_uuid)
pubsub.subscribe(channel)
try:
while True:
remaining = time_remaining()
if remaining is not None and remaining <= 0:
raise TimeoutError(
f"Timeout waiting for task {task_uuid} to complete"
)
# Wait for message with short timeout for responsive checking
wait_time = min(1.0, remaining) if remaining else 1.0
message = pubsub.get_message(
ignore_subscribe_messages=True,
timeout=wait_time,
)
if message and message.get("type") == "message":
# Completion received - fetch fresh task state
logger.debug(
"Received completion message for task %s: %s",
task_uuid,
message.get("data"),
)
task = get_task()
if task and task.status in TERMINAL_STATES:
return task
# Also check database periodically in case we missed the message
# (e.g., task completed before we subscribed)
task = get_task()
if task and task.status in TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via db check): status=%s",
task_uuid,
task.status,
)
return task
finally:
pubsub.unsubscribe()
pubsub.close()
@classmethod
def _wait_via_polling(
cls,
task_uuid: UUID,
poll_interval: float,
get_task: Callable[[], "Task | None"],
time_remaining: Callable[[], float | None],
) -> "Task":
"""
Wait for task completion using database polling.
:returns: Task when completed
:raises TimeoutError: If timeout expires
:raises ValueError: If task not found
"""
while True:
remaining = time_remaining()
if remaining is not None and remaining <= 0:
raise TimeoutError(f"Timeout waiting for task {task_uuid} to complete")
task = get_task()
if not task:
raise ValueError(f"Task {task_uuid} not found")
if task.status in TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via polling): status=%s",
task_uuid,
task.status,
)
return task
# Sleep with timeout awareness
sleep_time = min(poll_interval, remaining) if remaining else poll_interval
time.sleep(sleep_time)
@classmethod
def listen_for_abort(
cls,
task_uuid: UUID,
callback: Callable[[], None],
poll_interval: float,
app: Any = None,
) -> AbortListener:
"""
Start listening for abort notifications for a task.
Uses Redis pub/sub if configured, otherwise uses database polling.
The callback is invoked when an abort is detected.
:param task_uuid: UUID of the task to monitor (native UUID)
:param callback: Function to call when abort is detected
:param poll_interval: Interval for database polling (when Redis not configured)
:param app: Flask app for database access in background thread
:returns: AbortListener handle to stop listening
"""
stop_event = threading.Event()
pubsub: redis.client.PubSub | None = None
uuid_str = str(task_uuid)
# Use Redis pub/sub if configured
if (cache := cls._get_cache()) is not None:
pubsub = cache.pubsub()
channel = cls.get_abort_channel(task_uuid)
pubsub.subscribe(channel)
logger.debug("Subscribed to abort channel: %s", channel)
# Start pub/sub listener thread
thread = threading.Thread(
target=cls._listen_pubsub,
args=(task_uuid, pubsub, callback, stop_event, app),
daemon=True,
name=f"abort-listener-{uuid_str[:8]}",
)
logger.debug("Started pub/sub abort listener for task %s", task_uuid)
else:
# Use polling when Redis is not configured
pubsub = None
thread = threading.Thread(
target=cls._poll_for_abort,
args=(task_uuid, callback, stop_event, poll_interval, app),
daemon=True,
name=f"abort-poller-{uuid_str[:8]}",
)
logger.debug(
"Started database abort polling for task %s (interval=%ss)",
task_uuid,
poll_interval,
)
thread.start()
return AbortListener(task_uuid, thread, stop_event, pubsub)
@staticmethod
def _invoke_callback_with_context(
callback: Callable[[], None],
app: Any,
) -> None:
"""
Invoke callback with Flask app context if provided.
:param callback: Function to invoke
:param app: Flask app for context, or None
"""
if app:
with app.app_context():
callback()
else:
callback()
@classmethod
def _check_abort_status(cls, task_uuid: UUID) -> bool:
"""
Check if task has been aborted via database query.
:param task_uuid: UUID of the task to check (native UUID)
:returns: True if task is in ABORTING or ABORTED state
"""
from superset.daos.tasks import TaskDAO
task = TaskDAO.find_one_or_none(uuid=task_uuid)
return task is not None and task.status in ABORT_STATES
@classmethod
def _run_abort_listener_loop(
cls,
task_uuid: UUID,
callback: Callable[[], None],
stop_event: threading.Event,
interval: float,
app: Any,
check_fn: Callable[[], bool],
source: str,
) -> None:
"""
Common abort listener loop used by both pub/sub and polling modes.
:param task_uuid: UUID of the task to monitor (native UUID)
:param callback: Function to call when abort is detected
:param stop_event: Event to signal loop termination
:param interval: Wait interval between checks
:param app: Flask app for context
:param check_fn: Function that returns True if abort was detected
:param source: Source identifier for logging ("pub/sub" or "polling")
"""
while not stop_event.is_set():
try:
if check_fn():
logger.info(
"Abort detected via %s for task %s",
source,
task_uuid,
)
cls._invoke_callback_with_context(callback, app)
break
# Wait for interval or until stop is requested
stop_event.wait(timeout=interval)
except (ValueError, OSError) as ex:
# ValueError/OSError with "I/O operation on closed file" or
# "Bad file descriptor" typically means the connection was closed
# during shutdown. Check if stop was requested.
if stop_event.is_set():
logger.debug(
"Abort %s for task %s stopped cleanly (connection closed)",
source,
task_uuid,
)
else:
logger.error(
"Error in abort %s for task %s: %s",
source,
task_uuid,
str(ex),
exc_info=True,
)
break
except Exception as ex:
# Check if stop was requested - if so, this may be expected
if stop_event.is_set():
logger.debug(
"Abort %s for task %s stopped with exception: %s",
source,
task_uuid,
ex,
)
else:
logger.error(
"Error in abort %s for task %s: %s",
source,
task_uuid,
str(ex),
exc_info=True,
)
break
@classmethod
def _listen_pubsub(
cls,
task_uuid: UUID,
pubsub: redis.client.PubSub,
callback: Callable[[], None],
stop_event: threading.Event,
app: Any,
) -> None:
"""Listen for abort via Redis pub/sub."""
# Track if abort was received to avoid double-callback
abort_received = False
def check_pubsub() -> bool:
nonlocal abort_received
message = pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message is not None and message.get("type") == "message":
abort_received = True
return True
return False
try:
cls._run_abort_listener_loop(
task_uuid=task_uuid,
callback=callback,
stop_event=stop_event,
interval=0, # pub/sub has its own timeout in get_message
app=app,
check_fn=check_pubsub,
source="pub/sub",
)
except redis.RedisError as ex:
# Check if we were asked to stop - if so, this is expected
if stop_event.is_set():
logger.debug(
"Abort listener for task %s stopped (Redis error: %s)",
task_uuid,
ex,
)
else:
# Log error but don't fall back - let the failure be visible
logger.error(
"Redis signal backend failed for task %s abort listener: %s. "
"Task may not receive abort signal.",
task_uuid,
ex,
)
except (ValueError, OSError) as ex:
# ValueError: "I/O operation on closed file" - expected when stop() closes
# OSError: Similar connection-closed errors
if stop_event.is_set():
# Clean shutdown, expected behavior
logger.debug(
"Abort listener for task %s stopped cleanly",
task_uuid,
)
else:
# Unexpected error while running
logger.error(
"Error in abort listener for task %s: %s",
task_uuid,
str(ex),
exc_info=True,
)
except Exception as ex:
# Only log as error if we weren't asked to stop
if stop_event.is_set():
logger.debug(
"Abort listener for task %s stopped with exception: %s",
task_uuid,
ex,
)
else:
logger.error(
"Error in abort listener for task %s: %s",
task_uuid,
str(ex),
exc_info=True,
)
finally:
# Clean up pub/sub subscription
try:
pubsub.unsubscribe()
pubsub.close()
except Exception as ex:
logger.debug("Error closing pub/sub during cleanup: %s", ex)
@classmethod
def _poll_for_abort(
cls,
task_uuid: UUID,
callback: Callable[[], None],
stop_event: threading.Event,
interval: float,
app: Any,
) -> None:
"""Background polling loop - used when Redis pub/sub is not configured."""
def check_database() -> bool:
# Need app context for database access
if app:
with app.app_context():
return cls._check_abort_status(task_uuid)
else:
return cls._check_abort_status(task_uuid)
cls._run_abort_listener_loop(
task_uuid=task_uuid,
callback=callback,
stop_event=stop_event,
interval=interval,
app=app,
check_fn=check_database,
source="polling",
)
@staticmethod
def submit_task(
task_type: str,
task_key: str | None,
task_name: str | None,
scope: TaskScope,
timeout: int | None,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> "Task":
"""
Create task entry and schedule for async execution.
Flow:
1. Generate task_key if not provided (random UUID)
2. Submit to SubmitTaskCommand which handles locking and create-vs-join
3. Schedule Celery task ONLY for new tasks (not deduplicated ones)
4. Return Task model to caller
The SubmitTaskCommand uses a distributed lock to prevent race conditions,
returning either a new task or an existing active task with the same key.
:param task_type: Task type identifier (e.g., "superset.generate_thumbnail")
:param task_key: Optional deduplication key (None for random UUID)
:param task_name: Human readable task name
:param scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM)
:param timeout: Optional timeout in seconds
:param args: Positional arguments for the task function
:param kwargs: Keyword arguments for the task function
:returns: Task model representing the scheduled task
"""
from superset.commands.tasks.submit import SubmitTaskCommand
if task_key is None:
task_key = generate_random_task_key()
# Build properties with execution_mode and timeout
properties: TaskProperties = {"execution_mode": "async"}
if timeout:
properties["timeout"] = timeout
# Create or join task entry in metastore
# SubmitTaskCommand handles locking and create-vs-join logic:
# - Acquires distributed lock on dedup_key
# - If active task exists: adds subscriber and returns existing task
# (is_new=False)
# - If no active task: creates new task (is_new=True)
task, is_new = SubmitTaskCommand(
{
"task_key": task_key,
"task_type": task_type,
"task_name": task_name,
"scope": scope.value,
"properties": properties,
}
).run_with_info()
# Only schedule Celery task for NEW tasks, not deduplicated ones
# Deduplicated tasks are already pending or running
if is_new:
# Import here to avoid circular dependency
from superset.tasks.scheduler import execute_task
# Schedule Celery task for async execution
execute_task.delay(
task_uuid=str(task.uuid),
task_type=task_type,
args=args,
kwargs=kwargs,
)
logger.debug(
"Scheduled task %s (uuid=%s) for async execution",
task_type,
task.uuid,
)
else:
logger.debug(
"Joined existing task %s (uuid=%s) - no new Celery task scheduled",
task_type,
task.uuid,
)
return task