# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Decorators for the Global Task Framework (GTF)"""

from __future__ import annotations

import inspect
import logging
from typing import Any, Callable, cast, Generic, ParamSpec, TYPE_CHECKING, TypeVar

from superset_core.tasks.types import TaskOptions, TaskScope, TaskStatus

from superset import is_feature_enabled
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
from superset.tasks.ambient_context import use_context
from superset.tasks.constants import TERMINAL_STATES
from superset.tasks.context import TaskContext
from superset.tasks.manager import TaskManager
from superset.tasks.registry import TaskRegistry
from superset.tasks.utils import generate_random_task_key

if TYPE_CHECKING:
    from superset.models.tasks import Task

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")


def task(
    func: Callable[P, R] | None = None,
    *,
    name: str | None = None,
    scope: TaskScope = TaskScope.PRIVATE,
    timeout: int | None = None,
) -> Callable[[Callable[P, R]], "TaskWrapper[P]"] | "TaskWrapper[P]":
    """
    Decorator to register a task with a default scope.

    Can be used with or without parentheses:

        @task
        def my_func(): ...

        @task()
        def my_func(): ...

        @task(name="custom_name", scope=TaskScope.SHARED)
        def my_func(): ...

        @task(timeout=300)  # 5-minute timeout
        def long_running_func(): ...

    Args:
        func: The function to decorate (when used without parentheses).
        name: Optional unique task name (e.g., "superset.generate_thumbnail").
            If not provided, the function name is used as the task name.
        scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM).
            Defaults to TaskScope.PRIVATE.
        timeout: Optional timeout in seconds. When the timeout is reached,
            abort handlers are triggered if registered. Can be overridden
            at call time via TaskOptions(timeout=...).

    Usage:
        # Private task (default scope) - no parentheses
        @task
        def my_async_func(chart_id: int) -> None:
            ctx = get_context()
            ...

        # Named task with shared scope
        @task(name="generate_report", scope=TaskScope.SHARED)
        def generate_expensive_report(report_id: int) -> None:
            ctx = get_context()
            ...

        # System task (admin-only)
        @task(scope=TaskScope.SYSTEM)
        def cleanup_task() -> None:
            ctx = get_context()
            ...

        # Task with timeout
        @task(timeout=300)
        def long_task() -> None:
            ctx = get_context()

            @ctx.on_abort
            def handle_abort():
                # Called when the timeout is reached or the user cancels
                ...

    Note:
        Both direct calls and .schedule() return Task, regardless of the
        original function's return type. The decorated function's return value
        is discarded; only side effects and context updates matter.
    """

    def decorator(f: Callable[P, R]) -> "TaskWrapper[P]":
        # Use the function name if no name was provided
        base_task_name = name if name is not None else f.__name__

        # Apply ambient context detection for ID prefixing (like MCP decorators)
        from superset.extensions.context import get_current_extension_context

        if context := get_current_extension_context():
            # Extension context: prefix the task name to prevent collisions
            manifest = context.manifest
            task_name = (
                f"extensions.{manifest.publisher}.{manifest.name}.{base_task_name}"
            )
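            # Illustrative: for a hypothetical extension with publisher "acme"
            # and name "weather", a task function `refresh` would register as
            # "extensions.acme.weather.refresh".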
        else:
            # Host context: use the original task name
            task_name = base_task_name

        # Create default options with no scope (scope now lives on the decorator)
        default_options = TaskOptions()

        # Validate the function signature - it must not define ctx or options params
        sig = inspect.signature(f)
        forbidden = {"ctx", "options"}
        if any(param in forbidden for param in sig.parameters):
            raise TypeError(
                f"Task function {f.__name__} must not define 'ctx' or "
                "'options' parameters. "
                "Use get_context() instead for ambient context access."
            )
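        # Illustrative: this raises at decoration time because 'options' is
        # reserved for TaskWrapper call sites (hypothetical function):
        #   @task
        #   def bad(chart_id: int, options: dict) -> None: ...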

        # Register the task
        TaskRegistry.register(task_name, f)

        # Create a wrapper with schedule() method, default options, scope, and timeout
        wrapper = TaskWrapper(task_name, f, default_options, scope, timeout)

        # Preserve the signature for introspection
        wrapper.__signature__ = sig  # type: ignore[attr-defined]

        return wrapper

    if func is None:
        # Called with parentheses: @task() or @task(name="foo", scope=TaskScope.SHARED)
        return decorator
    else:
        # Called without parentheses: @task
        return decorator(func)


class TaskWrapper(Generic[P]):
    """
    Wrapper for task functions that provides a .schedule() method.

    Both direct calls and .schedule() return Task. The original function's
    return value is discarded.

    Direct calls execute synchronously; .schedule() runs async via Celery.
    """

    def __init__(
        self,
        name: str,
        func: Callable[P, R],
        default_options: TaskOptions,
        scope: TaskScope = TaskScope.PRIVATE,
        default_timeout: int | None = None,
    ) -> None:
        self.name = name
        self.func = func
        self.default_options = default_options
        self.scope = scope
        self.default_timeout = default_timeout
        self.__name__ = func.__name__
        self.__doc__ = func.__doc__
        self.__module__ = func.__module__

        # Patch schedule.__signature__ to mirror the function plus an options
        # parameter. This enables proper IDE support and introspection.
        sig = inspect.signature(func)
        params = list(sig.parameters.values())
        # Add a keyword-only options parameter
        params.append(
            inspect.Parameter(
                "options",
                inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=TaskOptions | None,
            )
        )
        self.schedule.__func__.__signature__ = sig.replace(  # type: ignore[attr-defined]
            parameters=params, return_annotation="Task"
        )
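        # Illustrative result: for `def thumb(chart_id: int) -> None`, the
        # patched signature reads roughly (hypothetical function):
        #   (chart_id: int, *, options: TaskOptions | None = None) -> Task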

    def _merge_options(self, override_options: TaskOptions | None) -> TaskOptions:
        """
        Merge decorator defaults with call-time overrides.

        Call-time options take precedence over decorator defaults. A timeout
        of None in call-time TaskOptions falls back to the decorator timeout;
        only a non-None value overrides it.

        Args:
            override_options: Options provided at call time, or None

        Returns:
            Merged TaskOptions with overrides applied
        """
        if override_options is None:
            return TaskOptions(
                task_key=self.default_options.task_key,
                task_name=self.default_options.task_name,
                timeout=self.default_timeout,  # Use the decorator default
            )

        # Merge: use the override if provided, otherwise fall back to the default.
        # For timeout, a None in override_options cannot be distinguished from
        # "unset", so None always falls back to the decorator default.
        return TaskOptions(
            task_key=override_options.task_key or self.default_options.task_key,
            task_name=override_options.task_name or self.default_options.task_name,
            timeout=override_options.timeout
            if override_options.timeout is not None
            else self.default_timeout,
        )

    def _validate_task(self, options: TaskOptions) -> None:
        """
        Validate the task configuration before execution.

        Args:
            options: Merged task options to validate

        Raises:
            ValueError: If validation fails
        """
        # Shared tasks must have an explicit task_key for deduplication
        if self.scope == TaskScope.SHARED and options.task_key is None:
            raise ValueError(
                f"Shared task '{self.name}' requires an explicit task_key in "
                "TaskOptions for deduplication. Without a task_key, each "
                "invocation creates a separate task with a random UUID, "
                "defeating the purpose of shared tasks."
            )

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Task":
        """
        Call the function synchronously.

        This is invoked when you call the decorated function directly:
            task = generate_thumbnail(chart_id)  # Blocks until completion

        Flow:
            1. Submit the task (create new or join existing via deduplication)
            2. If joining an existing task: wait for it to complete (blocking)
            3. If a new task: execute inline and return the completed task

        Sync execution always blocks until completion - even when joining an
        existing task that's running in another process/worker.

        Returns the Task entity in a terminal state (SUCCESS, FAILURE, etc.).
        If the timeout expires while waiting for an existing task, the Task is
        returned in its current (possibly non-terminal) state instead.

        Raises:
            GlobalTaskFrameworkDisabledError: If the GTF feature flag is not enabled
            ValueError: If task validation fails
        """
        from superset.commands.tasks.submit import SubmitTaskCommand

        if not is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
            raise GlobalTaskFrameworkDisabledError()

        # Extract and merge options (decorator defaults + call-time overrides)
        override_options = cast(TaskOptions | None, kwargs.pop("options", None))
        options = self._merge_options(override_options)

        # Validate the task configuration
        self._validate_task(options)

        # Extract task_name and task_key from merged options; scope comes from
        # the decorator
        task_name = (
            options.task_name or f"{self.name}:{generate_random_task_key()[:50]}"
        )
        task_key = options.task_key or generate_random_task_key()
        scope = self.scope  # Use the scope from the decorator

        # Build properties with execution_mode and timeout
        properties: dict[str, str | int] = {"execution_mode": "sync"}
        if options.timeout:
            properties["timeout"] = options.timeout

        # Submit the task - may create a new one or join an existing one
        task, is_new = SubmitTaskCommand(
            {
                "task_type": self.name,
                "task_key": task_key,
                "task_name": task_name,
                "scope": scope.value,
                "properties": properties,
            }
        ).run_with_info()

        # If joining an existing task, wait for it to complete
        if not is_new:
            return self._wait_for_existing_task(task, options.timeout)

        # New task - execute inline
        return self._execute_inline(task, options, args, kwargs)

    def _wait_for_existing_task(self, task: "Task", timeout: int | None) -> "Task":
        """
        Wait for an existing task to complete.

        Called when sync execution joins a pre-existing task via deduplication.
        Blocks until the task reaches a terminal state. If the timeout expires
        first, the task is returned in its current state rather than raising.

        :param task: The existing task to wait for
        :param timeout: Maximum time to wait in seconds (None = no limit)
        :returns: The task, in a terminal state unless the wait timed out
        """
        from flask import current_app

        from superset.daos.tasks import TaskDAO

        # Check if the task is already in a terminal state
        if task.status in TERMINAL_STATES:
            logger.info(
                "Joined already-completed task %s (uuid=%s, status=%s)",
                self.name,
                task.uuid,
                task.status,
            )
            return task

        # Wait for the existing task to complete
        logger.info(
            "Joined active task %s (uuid=%s, status=%s), waiting for completion",
            self.name,
            task.uuid,
            task.status,
        )

        try:
            app = current_app._get_current_object()
        except RuntimeError:
            app = None

        try:
            task = TaskManager.wait_for_completion(
                task_uuid=task.uuid,
                timeout=float(timeout) if timeout else None,
                poll_interval=1.0,
                app=app,
            )
            logger.info(
                "Task %s (uuid=%s) completed with status=%s",
                self.name,
                task.uuid,
                task.status,
            )
            return task

        except TimeoutError:
            logger.warning(
                "Timeout waiting for task %s (uuid=%s)",
                self.name,
                task.uuid,
            )
            # Return the task in its current state (the caller can check status)
            refreshed = TaskDAO.find_one_or_none(uuid=task.uuid)
            return refreshed if refreshed else task

    def _execute_inline(
        self,
        task: "Task",
        options: TaskOptions,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> "Task":
        """
        Execute the task function inline (synchronously).

        Called when this is a new task (not joining an existing one).
        Uses atomic conditional status transitions for race-safe execution.

        :param task: The newly created task
        :param options: Merged task options
        :param args: Positional arguments for the task function
        :param kwargs: Keyword arguments for the task function
        :returns: Task in a terminal state
        """
        from superset.commands.tasks.internal_update import (
            InternalStatusTransitionCommand,
        )
        from superset.daos.tasks import TaskDAO
        from superset.tasks.constants import ABORT_STATES

        # PRE-EXECUTION CHECK: don't execute if already aborted/aborting
        # (matches the async flow in scheduler.py)
        if task.status in ABORT_STATES:
            logger.info(
                "Task %s (uuid=%s) was aborted before execution started",
                self.name,
                task.uuid,
            )
            # Ensure the status is ABORTED (not just ABORTING)
            InternalStatusTransitionCommand(
                task_uuid=task.uuid,
                new_status=TaskStatus.ABORTED,
                expected_status=[TaskStatus.PENDING, TaskStatus.ABORTING],
                set_ended_at=True,
            ).run()
            # Refresh to get the updated task
            refreshed = TaskDAO.find_one_or_none(uuid=task.uuid)
            return refreshed if refreshed else task

        # Atomic transition: PENDING → IN_PROGRESS (set started_at for duration
        # tracking)
        task_uuid = task.uuid  # Cache the UUID before any potential state changes
        if not InternalStatusTransitionCommand(
            task_uuid=task_uuid,
            new_status=TaskStatus.IN_PROGRESS,
            expected_status=TaskStatus.PENDING,
            set_started_at=True,
        ).run():
            # Status wasn't PENDING - the task may have been aborted concurrently
            logger.warning(
                "Task %s (uuid=%s) failed PENDING → IN_PROGRESS transition "
                "(may have been aborted concurrently)",
                self.name,
                task_uuid,
            )
            refreshed = TaskDAO.find_one_or_none(uuid=task_uuid)
            return refreshed if refreshed else task

        # Update the cached status (no DB read needed - we just wrote IN_PROGRESS)
        task.status = TaskStatus.IN_PROGRESS.value

        # Build the context with the updated task entity
        ctx = TaskContext(task)

        # Start the timeout timer if configured
        if options.timeout:
            ctx.start_timeout_timer(options.timeout)
            logger.debug(
                "Started timeout timer for task %s: %d seconds",
                task.uuid,
                options.timeout,
            )

        # Track the final task state for the completion notification
        final_task: Task | None = None

        try:
            # Execute with ambient context
            with use_context(ctx):
                self.func(*args, **kwargs)

            # Determine the terminal status based on abort detection.
            # Use atomic conditional updates to prevent overwriting a concurrent abort.
            if ctx._abort_detected or ctx.timeout_triggered:
                # An abort was detected - transition ABORTING → terminal
                if ctx.timeout_triggered:
                    InternalStatusTransitionCommand(
                        task_uuid=task_uuid,
                        new_status=TaskStatus.TIMED_OUT,
                        expected_status=TaskStatus.ABORTING,
                        set_ended_at=True,
                    ).run()
                    logger.info(
                        "Task %s (uuid=%s) timed out and completed cleanup",
                        self.name,
                        task_uuid,
                    )
                else:
                    InternalStatusTransitionCommand(
                        task_uuid=task_uuid,
                        new_status=TaskStatus.ABORTED,
                        expected_status=TaskStatus.ABORTING,
                        set_ended_at=True,
                    ).run()
                    logger.info(
                        "Task %s (uuid=%s) was aborted by user",
                        self.name,
                        task_uuid,
                    )
            else:
                # Normal completion - atomic IN_PROGRESS → SUCCESS.
                # This will fail (return False) if the task was concurrently aborted.
                if InternalStatusTransitionCommand(
                    task_uuid=task_uuid,
                    new_status=TaskStatus.SUCCESS,
                    expected_status=TaskStatus.IN_PROGRESS,
                    set_ended_at=True,
                ).run():
                    logger.debug(
                        "Synchronous execution of task %s (uuid=%s) "
                        "completed successfully",
                        self.name,
                        task_uuid,
                    )
                else:
                    # Transition failed - the task was likely aborted concurrently
                    logger.info(
                        "Task %s (uuid=%s) IN_PROGRESS → SUCCESS failed "
                        "(may have been aborted concurrently)",
                        self.name,
                        task_uuid,
                    )

            # Refresh once at the end to return the current state
            final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
            return final_task if final_task else task

        except Exception as ex:
            # Atomic transition to FAILURE (only if still IN_PROGRESS or ABORTING)
            InternalStatusTransitionCommand(
                task_uuid=task_uuid,
                new_status=TaskStatus.FAILURE,
                expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
                properties={"error_message": str(ex)},
                set_ended_at=True,
            ).run()

            logger.error(
                "Synchronous execution of task %s (uuid=%s) failed: %s",
                self.name,
                task_uuid,
                str(ex),
                exc_info=True,
            )

            # Refresh once at the end to return the current state
            final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
            return final_task if final_task else task

        finally:
            # Always clean up the timer and handlers
            ctx._run_cleanup()

            # Publish a completion notification for any waiters.
            # Use final_task if set by try/except, otherwise refresh (fallback).
            if final_task is None:
                final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
            if final_task and final_task.status in TERMINAL_STATES:
                TaskManager.publish_completion(task_uuid, final_task.status)

    def schedule(self, *args: P.args, **kwargs: P.kwargs) -> "Task":
        """
        Schedule this task for asynchronous execution.

        The signature mirrors the original task function, with an additional
        keyword-only 'options' parameter for execution metadata.

        Args:
            *args, **kwargs: Business arguments for the task function
            options: Execution options

        Returns:
            Task model representing the scheduled task (PENDING status)

        Raises:
            GlobalTaskFrameworkDisabledError: If the GTF feature flag is not enabled
            ValueError: If the task is SHARED scope but no task_key is provided

        Usage:
            # Auto-generated task_key (random UUID, no deduplication):
            task = generate_thumbnail.schedule(chart_id)

            # Custom task_key for task deduplication:
            task = generate_thumbnail.schedule(
                chart_id,
                options=TaskOptions(task_key=f"thumb_{chart_id}"),
            )

            # SHARED tasks require a task_key:
            task = shared_task.schedule(
                data_id,
                options=TaskOptions(task_key=f"shared_{data_id}"),
            )

        Note: Unlike direct calls (__call__), this schedules async execution.
        The method returns immediately with the Task model in PENDING status.
        """
        if not is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
            raise GlobalTaskFrameworkDisabledError()

        # Extract and merge options (decorator defaults + call-time overrides)
        override_options = cast(TaskOptions | None, kwargs.pop("options", None))
        options = self._merge_options(override_options)

        # Validate the task configuration
        self._validate_task(options)

        # Extract task_name and task_key from merged options; scope comes from
        # the decorator
        task_name = options.task_name
        task_key = options.task_key
        scope = self.scope  # Use the scope from the decorator

        # Create the task entry in the metastore and schedule execution
        return TaskManager.submit_task(
            task_type=self.name,
            task_name=task_name,
            task_key=task_key,
            scope=scope,
            timeout=options.timeout,
            args=args,
            kwargs=kwargs,
        )