feat: add global task framework (#36368)

Ville Brofeldt
2026-02-09 10:45:56 -08:00
committed by GitHub
parent 6984e93171
commit 59dd2fa385
89 changed files with 15535 additions and 291 deletions

@@ -24,6 +24,28 @@ assists people when migrating to a new version.
## Next
### Signal Cache Backend
A new `SIGNAL_CACHE_CONFIG` configuration provides a unified Redis-based backend for real-time coordination features in Superset. This backend enables:
- **Pub/sub messaging** for real-time event notifications between workers
- **Atomic distributed locking** using Redis SET NX EX (more performant than database-backed locks)
- **Event-based coordination** for background task management
The signal cache is used by the Global Task Framework (GTF) for abort notifications and task completion signaling, and will eventually replace `GLOBAL_ASYNC_QUERIES_CACHE_BACKEND` as the standard signaling backend. Configuring it is recommended for Redis-enabled production deployments.
Example configuration in `superset_config.py`:
```python
SIGNAL_CACHE_CONFIG = {
    "CACHE_TYPE": "RedisCache",
    "CACHE_KEY_PREFIX": "signal_",
    "CACHE_REDIS_URL": "redis://localhost:6379/1",
    "CACHE_DEFAULT_TIMEOUT": 300,
}
```
See `superset/config.py` for complete configuration options.
### WebSocket config for GAQ with Docker
[35896](https://github.com/apache/superset/pull/35896) and [37624](https://github.com/apache/superset/pull/37624) updated documentation on how to run and configure Superset with Docker. Specifically for the WebSocket configuration, a new `docker/superset-websocket/config.example.json` was added to the repo, so that users could copy it to create a `docker/superset-websocket/config.json` file. The existing `docker/superset-websocket/config.json` was removed and git-ignored, so if you're using GAQ / WebSocket make sure to:

@@ -51,4 +51,5 @@ Extensions can provide:
- **[Deployment](./deployment)** - Packaging and deploying extensions
- **[MCP Integration](./mcp)** - Adding AI agent capabilities using extensions
- **[Security](./security)** - Security considerations and best practices
- **[Tasks](./tasks)** - Framework for creating and managing long-running tasks
- **[Community Extensions](./registry)** - Browse extensions shared by the community

@@ -1,6 +1,6 @@
---
title: Community Extensions
sidebar_position: 10
sidebar_position: 11
---
<!--

@@ -0,0 +1,440 @@
---
title: Tasks
sidebar_position: 10
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Global Task Framework
The Global Task Framework (GTF) provides a unified way to manage background tasks. It handles task execution, progress tracking, cancellation, and deduplication for both synchronous and asynchronous execution. The framework uses distributed locking internally to ensure race-free operations—you don't need to worry about concurrent task creation or cancellation conflicts.
## Enabling GTF
GTF is disabled by default and must be enabled via the `GLOBAL_TASK_FRAMEWORK` feature flag in your `superset_config.py`:
```python
FEATURE_FLAGS = {
    "GLOBAL_TASK_FRAMEWORK": True,
}
```
When GTF is disabled:
- The Task List UI menu item is hidden
- The `/api/v1/task/*` endpoints return 404
- Calling or scheduling a `@task`-decorated function raises `GlobalTaskFrameworkDisabledError`
:::note Future Migration
When GTF is considered stable, it will replace legacy Celery tasks for built-in features like thumbnails and alerts & reports. Enabling this flag prepares your deployment for that migration.
:::
## Quick Start
### Define a Task
```python
import logging

from superset_core.api.tasks import task, get_context

logger = logging.getLogger(__name__)


@task
def process_data(dataset_id: int) -> None:
    ctx = get_context()

    @ctx.on_cleanup
    def cleanup():
        logger.info("Processing complete")

    # fetch_dataset and process_and_cache stand in for your own logic
    data = fetch_dataset(dataset_id)
    process_and_cache(data)
```
### Execute a Task
```python
# Async execution - schedules on Celery worker
task = process_data.schedule(dataset_id=123)
print(task.status) # "pending"
# Sync execution - runs inline in current process
task = process_data(dataset_id=123)
# ... blocks until complete
print(task.status) # "success"
```
### Async vs Sync Execution
| Method | When to Use |
|--------|-------------|
| `.schedule()` | Long-running operations, background processing, when you need to return immediately |
| Direct call | Short operations, when deduplication matters, when you need the result before responding |
Both execution modes provide the same task features: deduplication, progress tracking, cancellation, and visibility in the Task List UI. The difference is whether execution happens in a Celery worker (async) or inline (sync).
## Task Lifecycle
```
PENDING ──→ IN_PROGRESS ────→ SUCCESS
   │             │
   │             ├──────────→ FAILURE
   │             ↓                ↑
   │         ABORTING ────────────┘
   │             │
   │             ├──────────→ TIMED_OUT (timeout)
   │             │
   └─────────────┴──────────→ ABORTED (user cancel)
```
| Status | Description |
|--------|-------------|
| `PENDING` | Queued, awaiting execution |
| `IN_PROGRESS` | Executing |
| `ABORTING` | Abort/timeout triggered, abort handlers running |
| `SUCCESS` | Completed successfully |
| `FAILURE` | Failed with error or abort/cleanup handler exception |
| `ABORTED` | Cancelled by user/admin |
| `TIMED_OUT` | Exceeded configured timeout |
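The diagram and table above can be read as a transition map. The following is a minimal sketch for illustration only: `Status`, `ALLOWED_TRANSITIONS`, and `can_transition` are hypothetical names that are not part of the GTF API, and the map is inferred from the diagram rather than taken from the implementation.

```python
from enum import Enum


class Status(str, Enum):
    """Local stand-in for the task statuses listed in the table."""

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    ABORTING = "aborting"
    SUCCESS = "success"
    FAILURE = "failure"
    ABORTED = "aborted"
    TIMED_OUT = "timed_out"


# Hypothetical transition map inferred from the lifecycle diagram.
# Terminal statuses have no outgoing transitions, so they are omitted.
ALLOWED_TRANSITIONS: dict[Status, set[Status]] = {
    Status.PENDING: {Status.IN_PROGRESS, Status.ABORTED},
    Status.IN_PROGRESS: {Status.SUCCESS, Status.FAILURE, Status.ABORTING},
    Status.ABORTING: {Status.FAILURE, Status.TIMED_OUT, Status.ABORTED},
}


def can_transition(current: Status, target: Status) -> bool:
    """Return True if the diagram allows moving from current to target."""
    return target in ALLOWED_TRANSITIONS.get(current, set())
```

Under this reading, a pending task can go straight to `ABORTED`, while an in-progress task always passes through `ABORTING` first.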
## Context API
Access task context via `get_context()` from within any `@task` function. The context provides methods for updating task metadata and registering handlers.
### Updating Task Metadata
Use `update_task()` to report progress and store custom payload data:
```python
@task
def my_task(items: list[int]) -> None:
    ctx = get_context()

    for i, item in enumerate(items):
        result = process(item)
        ctx.update_task(
            progress=(i + 1, len(items)),
            payload={"last_result": result},
        )
```
:::tip
Call `update_task()` once per iteration for best performance. Frequent DB writes are throttled to limit metastore load, so batching progress and payload updates together in a single call ensures both are persisted at the same time.
:::
#### Progress Formats
The `progress` parameter accepts three formats:
| Format | Example | Display |
|--------|---------|---------|
| `tuple[int, int]` | `progress=(3, 100)` | 3 of 100 (3%) with ETA |
| `float` (0.0-1.0) | `progress=0.5` | 50% with ETA |
| `int` | `progress=42` | 42 processed |
:::tip
Use the tuple format `(current, total)` whenever possible. It gives users the richest information, showing both the count and the percentage, while still computing the ETA automatically.
:::
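To make the three formats concrete, here is a small sketch of how a progress value could be normalized into count, total, and percentage. `normalize_progress` is a hypothetical helper written for this page; how the framework does this internally is not shown here and may differ.

```python
from typing import Optional, Tuple, Union

Progress = Union[float, int, Tuple[int, int]]


def normalize_progress(
    progress: Progress,
) -> Tuple[Optional[int], Optional[int], Optional[float]]:
    """Return (current, total, percent) for one of the three formats."""
    if isinstance(progress, tuple):
        current, total = progress
        # Percentage is derived from count/total; guard against total == 0.
        percent = current / total if total else None
        return current, total, percent
    if isinstance(progress, bool):
        # bool is a subclass of int, so reject it explicitly.
        raise TypeError("progress must be a float, int, or (int, int) tuple")
    if isinstance(progress, float):
        return None, None, progress  # percentage only
    if isinstance(progress, int):
        return progress, None, None  # count only, total unknown
    raise TypeError("progress must be a float, int, or (int, int) tuple")
```

The tuple format is the only one that fills all three fields, which is why it gives the most informative display.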
#### Payload
The `payload` parameter stores custom metadata that can help users understand what the task is doing. Each call to `update_task()` replaces the previous payload completely.
In the Task List UI, when a payload is defined, an info icon appears in the **Details** column. Users can hover over it to see the JSON content.
### Handlers
Register handlers to run cleanup logic or respond to abort requests:
| Handler | When it runs | Use case |
|---------|--------------|----------|
| `on_cleanup` | Always (success, failure, abort) | Release resources, close connections |
| `on_abort` | When task is aborted | Set stop flag, cancel external operations |
```python
@task
def my_task() -> None:
    ctx = get_context()

    @ctx.on_cleanup
    def cleanup():
        logger.info("Task ended, cleaning up")

    @ctx.on_abort
    def handle_abort():
        logger.info("Abort requested")

    # ... task logic
```
Multiple handlers of the same type execute in LIFO order (last registered runs first). Abort handlers run first when abort is detected, then cleanup handlers run when the task ends.
#### Best-Effort Execution
**All registered handlers will always be attempted, even if one fails.** This ensures that a failure in one handler doesn't prevent other handlers from running their cleanup logic.
For example, if you have three cleanup handlers and the second one throws an exception:
1. Handler 3 runs ✓
2. Handler 2 throws an exception ✗ (logged, but execution continues)
3. Handler 1 runs ✓
If any handler fails, the task is marked as `FAILURE` with combined error details showing all handler failures.
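The LIFO, best-effort behavior described above can be sketched in a few lines. This is an illustration under stated assumptions, not the framework's code: `run_handlers` is a hypothetical function, and the error string format is made up for the example.

```python
from typing import Callable


def run_handlers(handlers: list[Callable[[], None]]) -> list[str]:
    """Run handlers last-registered-first and collect every failure."""
    errors: list[str] = []
    for handler in reversed(handlers):  # LIFO order
        try:
            handler()
        except Exception as ex:  # keep going so later handlers still run
            errors.append(f"{handler.__name__}: {type(ex).__name__}: {ex}")
    return errors


calls: list[int] = []


def handler_1() -> None:
    calls.append(1)


def handler_2() -> None:
    raise ValueError("boom")


def handler_3() -> None:
    calls.append(3)


errors = run_handlers([handler_1, handler_2, handler_3])
```

After this runs, `calls` holds the handlers that succeeded in execution order and `errors` holds one entry for the failed handler; a non-empty `errors` list corresponds to the task ending in `FAILURE`.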
:::tip
Write handlers to be independent and self-contained. Don't assume previous handlers succeeded, and don't rely on shared state between handlers.
:::
## Making Tasks Abortable
When users click **Cancel** in the Task List, the system decides whether to **abort** (stop) the task or **unsubscribe** (remove the user from a shared task). Abort occurs when:
- It's a private or system task
- It's a shared task and the user is the last subscriber
- An admin checks **Force abort** to stop the task for all subscribers
Pending tasks can always be aborted: they simply won't start. In-progress tasks require an abort handler to be abortable:
```python
@task
def abortable_task(items: list[str]) -> None:
    ctx = get_context()
    should_stop = False

    @ctx.on_abort
    def handle_abort():
        nonlocal should_stop
        should_stop = True
        logger.info("Abort signal received")

    @ctx.on_cleanup
    def cleanup():
        logger.info("Task ended, cleaning up")

    for item in items:
        if should_stop:
            return  # Exit gracefully
        process(item)
```
**Key points:**
- Registering `on_abort` marks the task as abortable and starts the abort listener
- The abort handler fires automatically when abort is triggered
- Use a flag pattern to gracefully stop processing at safe points
- Without an abort handler, in-progress tasks cannot be aborted: the Cancel button in the Task List UI will be disabled
The framework automatically skips execution if a task was aborted while pending: no manual check needed at task start.
:::tip
Always implement an abort handler for long-running tasks. This allows users to cancel unneeded tasks and free up worker capacity for other operations.
:::
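A `threading.Event` is another way to carry the stop flag between an abort handler and the task loop. The sketch below is an illustration and not GTF code: `process_items` is a hypothetical helper, and the commented lines only suggest where `get_context()` and `on_abort` would plug in.

```python
import threading
from typing import Iterable


def process_items(
    items: Iterable[str],
    stop_event: threading.Event,
    processed: list[str],
) -> None:
    """Hypothetical worker loop that stops at the next safe point."""
    for item in items:
        if stop_event.is_set():
            return  # Exit gracefully
        processed.append(item)  # stand-in for real work


# Inside a @task function the wiring could look like this (not executed here):
#     stop_event = threading.Event()
#     get_context().on_abort(stop_event.set)
#     process_items(items, stop_event, processed=[])
stop_event = threading.Event()
done: list[str] = []
process_items(["a", "b"], stop_event, done)
```

The event replaces the `nonlocal` flag, so the handler body shrinks to `stop_event.set`, and the same event can be passed to helper functions that also need to check for an abort.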
## Timeouts
Set a timeout to automatically abort tasks that run too long:
```python
from superset_core.api.tasks import task, get_context, TaskOptions


# Set default timeout in decorator
@task(timeout=300)  # 5 minutes
def process_data(dataset_id: int) -> None:
    ctx = get_context()
    should_stop = False

    @ctx.on_abort
    def handle_abort():
        nonlocal should_stop
        should_stop = True

    for chunk in fetch_large_dataset(dataset_id):
        if should_stop:
            return
        process(chunk)


# Override timeout at call time
task = process_data.schedule(
    dataset_id=123,
    options=TaskOptions(timeout=600),  # Override to 10 minutes
)
```
### How Timeouts Work
The timeout timer starts when the task begins executing (status changes to `IN_PROGRESS`). When the timeout expires:
1. **With an abort handler registered:** The task transitions to `ABORTING`, abort handlers run, then cleanup handlers run. The final status depends on handler execution:
   - If handlers complete successfully → `TIMED_OUT` status
   - If handlers throw an exception → `FAILURE` status
2. **Without an abort handler:** The framework cannot forcibly terminate the task. A warning is logged, and the task continues running. The Task List UI shows a warning indicator (⚠️) in the Details column to alert users that the timeout cannot be enforced.
### Timeout Precedence
| Source | Priority | Example |
|--------|----------|---------|
| `TaskOptions.timeout` | Highest | `options=TaskOptions(timeout=600)` |
| `@task(timeout=...)` | Default | `@task(timeout=300)` |
| Not set | No timeout | Task runs indefinitely |
Call-time options always override decorator defaults, allowing tasks to have sensible defaults while permitting callers to extend or shorten the timeout for specific use cases.
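The precedence table boils down to a single rule, sketched below. `resolve_timeout` is a hypothetical helper for illustration; it assumes that `None` means "not set", which matches the `int | None` type of both timeout parameters.

```python
from typing import Optional


def resolve_timeout(
    decorator_timeout: Optional[int],
    options_timeout: Optional[int],
) -> Optional[int]:
    """Pick the effective timeout in seconds, or None for no timeout."""
    if options_timeout is not None:
        return options_timeout  # call-time TaskOptions wins
    return decorator_timeout  # may be None: task runs indefinitely
```

For example, `resolve_timeout(300, 600)` yields the call-time value, while `resolve_timeout(300, None)` falls back to the decorator default.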
:::warning
Timeouts require an abort handler to be effective. Without one, the timeout triggers only a warning and the task continues running. Always implement an abort handler when using timeouts.
:::
## Deduplication
Use `task_key` to prevent duplicate task execution:
```python
from superset_core.api.tasks import TaskOptions

# Without key - creates new task each time (random UUID)
task1 = my_task.schedule(x=1)
task2 = my_task.schedule(x=1)  # Different task

# With key - joins existing task if active
task1 = my_task.schedule(x=1, options=TaskOptions(task_key="report_123"))
task2 = my_task.schedule(x=1, options=TaskOptions(task_key="report_123"))  # Returns same task
```
When a task with matching key already exists, the user is added as a subscriber and the existing task is returned. This behavior is consistent across all scopes—private tasks naturally have only one subscriber since their deduplication key includes the user ID.
Deduplication only applies to active tasks (pending/in-progress). Once a task completes, a new task with the same key can be created.
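The scope-dependent uniqueness rule can be sketched as a key builder: private tasks include the user ID, while shared and system tasks do not. `build_dedup_key` and the `|` separator are assumptions made for this example; the real `dedup_key` format is not documented here and may differ (for instance, it could be hashed).

```python
from typing import Optional


def build_dedup_key(
    scope: str,
    task_type: str,
    task_key: str,
    user_id: Optional[int] = None,
) -> str:
    """Hypothetical dedup key: user-specific only for private tasks."""
    if scope == "private":
        if user_id is None:
            raise ValueError("user_id is required for private tasks")
        parts = [scope, task_type, task_key, str(user_id)]
    else:
        # shared/system tasks are user-agnostic
        parts = [scope, task_type, task_key]
    return "|".join(parts)
```

With this shape, two users scheduling the same private `task_key` get different keys and therefore separate tasks, while the same shared `task_key` maps both users onto one task.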
### Sync Join-and-Wait
When a sync call joins an existing task, it blocks until the task completes:
```python
# Schedule async task
task = my_task.schedule(options=TaskOptions(task_key="report_123"))
# Later sync call with same key blocks until completion of the active task
task2 = my_task(options=TaskOptions(task_key="report_123"))
assert task.uuid == task2.uuid # True
print(task2.status) # "success" (terminal status)
```
## Task Scopes
```python
from superset_core.api.tasks import task, TaskScope


@task  # Private by default
def private_task(): ...


@task(scope=TaskScope.SHARED)  # Multiple users can subscribe
def shared_task(): ...


@task(scope=TaskScope.SYSTEM)  # Admin-only visibility
def system_task(): ...
```
| Scope | Visibility | Cancel Behavior |
|-------|------------|-----------------|
| `PRIVATE` | Creator only | Cancels immediately |
| `SHARED` | All subscribers | Last subscriber cancels; others unsubscribe |
| `SYSTEM` | Admins only | Admin cancels |
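The cancel behavior in the table, combined with the abort rules from "Making Tasks Abortable", can be sketched as a small decision function. `cancel_action` is a hypothetical name, and the sketch assumes that permission checks (for example, that only admins may force an abort) have already happened elsewhere.

```python
from typing import Literal


def cancel_action(
    scope: str,
    subscriber_count: int,
    force_abort: bool = False,
) -> Literal["abort", "unsubscribe"]:
    """Decide what clicking Cancel does, per the documented rules."""
    if scope in ("private", "system"):
        return "abort"
    # Shared task: abort only for the last subscriber or a forced abort.
    if force_abort or subscriber_count <= 1:
        return "abort"
    return "unsubscribe"
```

For a shared task with three subscribers, a regular cancel only unsubscribes the caller; the task keeps running for the other two.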
## Task Cleanup
Completed tasks accumulate in the database over time. Configure a scheduled prune job to automatically remove old tasks:
```python
# In your superset_config.py, add to your Celery beat schedule:
from celery.schedules import crontab

CELERY_CONFIG.beat_schedule["prune_tasks"] = {
    "task": "prune_tasks",
    "schedule": crontab(minute=0, hour=0),  # Run daily at midnight
    "kwargs": {
        "retention_period_days": 90,  # Keep tasks for 90 days
        "max_rows_per_run": 10000,  # Limit deletions per run
    },
}
```
The prune job only removes tasks in terminal states (`SUCCESS`, `FAILURE`, `ABORTED`, `TIMED_OUT`). Active tasks (`PENDING`, `IN_PROGRESS`, `ABORTING`) are never pruned.
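The selection rule can be sketched as a filter over task records. This is an assumption-heavy illustration, not the prune job itself: `select_prunable` is a hypothetical function, tasks are plain dicts, and comparing `ended_at` against the cutoff (oldest first) is a guess at which timestamp and ordering the real job uses.

```python
from datetime import datetime, timedelta
from typing import Any

TERMINAL_STATUSES = {"success", "failure", "aborted", "timed_out"}


def select_prunable(
    tasks: list[dict[str, Any]],
    now: datetime,
    retention_period_days: int,
    max_rows_per_run: int,
) -> list[dict[str, Any]]:
    """Return terminal tasks older than the retention window, oldest first."""
    cutoff = now - timedelta(days=retention_period_days)
    eligible = [
        t
        for t in tasks
        if t["status"] in TERMINAL_STATUSES
        and t["ended_at"] is not None
        and t["ended_at"] < cutoff
    ]
    eligible.sort(key=lambda t: t["ended_at"])
    return eligible[:max_rows_per_run]  # cap deletions per run
```

Active tasks never match because their status is not in `TERMINAL_STATUSES`, and `max_rows_per_run` keeps a single run from deleting an unbounded number of rows.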
See `superset/config.py` for a complete example configuration.
:::tip Signal Cache for Faster Notifications
By default, abort detection and sync join-and-wait use database polling. Configure `SIGNAL_CACHE_CONFIG` to enable Redis pub/sub for real-time notifications. See [Signal Cache Backend](/docs/configuration/cache#signal-cache-backend) for configuration details.
:::
## API Reference
### @task Decorator
```python
@task(
    name: str | None = None,
    scope: TaskScope = TaskScope.PRIVATE,
    timeout: int | None = None
)
```
- `name`: Task identifier (defaults to function name)
- `scope`: `PRIVATE`, `SHARED`, or `SYSTEM`
- `timeout`: Default timeout in seconds (can be overridden via `TaskOptions`)
### TaskContext Methods
| Method | Description |
|--------|-------------|
| `update_task(progress, payload)` | Update progress and/or custom payload |
| `on_cleanup(handler)` | Register cleanup handler |
| `on_abort(handler)` | Register abort handler (makes task abortable) |
### TaskOptions
```python
TaskOptions(
    task_key: str | None = None,
    task_name: str | None = None,
    timeout: int | None = None
)
```
- `task_key`: Deduplication key (also used as display name if `task_name` is not set)
- `task_name`: Human-readable display name for the Task List UI
- `timeout`: Timeout in seconds (overrides decorator default)
:::tip
Provide a descriptive `task_name` for better readability in the Task List UI. While `task_key` is used for deduplication and may be technical (e.g., `chart_export_123`), `task_name` can be user-friendly (e.g., `"Export Sales Chart 123"`).
:::
## Error Handling
Let exceptions propagate: the framework captures them automatically and sets task status to `FAILURE`:
```python
@task
def risky_task() -> None:
    # No try/except needed - framework handles it
    result = operation_that_might_fail()
```
On failure, the framework records:
- `error_message`: Exception message
- `exception_type`: Exception class name
- `stack_trace`: Full traceback (visible when `SHOW_STACKTRACE=True`)
In the Task List UI, failed tasks show error details when hovering over the status. When stack traces are enabled, a separate bug icon appears in the **Details** column for viewing the full traceback.
Cleanup handlers still run after an exception, so resources can be properly released as necessary.
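The three recorded fields map directly onto what Python exposes for a caught exception. The sketch below shows one way to derive them with the standard library; `capture_error` is a hypothetical helper, and the framework's own capture logic may format things differently.

```python
import traceback


def capture_error(ex: BaseException) -> dict[str, str]:
    """Build the three error fields from a caught exception."""
    return {
        "error_message": str(ex),
        "exception_type": type(ex).__name__,
        "stack_trace": "".join(
            traceback.format_exception(type(ex), ex, ex.__traceback__)
        ),
    }


try:
    raise ValueError("Dataset 123 has no rows")
except ValueError as ex:
    error_info = capture_error(ex)
```

Here `error_info["error_message"]` is the text users see when stack traces are hidden, which is why a descriptive message matters.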
:::tip
Use descriptive exception messages. In environments where stack traces are hidden (`SHOW_STACKTRACE=False`), users see only the error message and exception type when hovering over failed tasks. Clear messages help users troubleshoot issues without administrator assistance.
:::

@@ -53,6 +53,7 @@ module.exports = {
'extensions/deployment',
'extensions/mcp',
'extensions/security',
'extensions/tasks',
'extensions/registry',
],
},

@@ -7,6 +7,12 @@ version: 1
# Caching
:::note
When a cache backend is configured, Superset expects it to remain available. Operations will
fail if the configured backend becomes unavailable rather than silently degrading. This
fail-fast behavior ensures operators are immediately aware of infrastructure issues.
:::
Superset uses [Flask-Caching](https://flask-caching.readthedocs.io/) for caching purposes.
Flask-Caching supports various caching backends, including Redis (recommended), Memcached,
SimpleCache (in-memory), or the local filesystem.
@@ -153,6 +159,84 @@ Then on configuration:
WEBDRIVER_AUTH_FUNC = auth_driver
```
## Signal Cache Backend
Superset supports an optional signal cache (`SIGNAL_CACHE_CONFIG`) for
high-performance distributed operations. This configuration enables:
- **Distributed locking**: Moves lock operations from the metadata database to Redis, improving
performance and reducing metastore load
- **Real-time event notifications**: Enables instant pub/sub messaging for task abort signals and
completion notifications instead of polling-based approaches
:::note
This requires Redis or Valkey specifically—it uses Redis-specific features (pub/sub, `SET NX EX`)
that are not available in general Flask-Caching backends.
:::
### Configuration
The signal cache uses Flask-Caching style configuration for consistency with other cache
backends. Configure `SIGNAL_CACHE_CONFIG` in `superset_config.py`:
```python
SIGNAL_CACHE_CONFIG = {
    "CACHE_TYPE": "RedisCache",
    "CACHE_REDIS_HOST": "localhost",
    "CACHE_REDIS_PORT": 6379,
    "CACHE_REDIS_DB": 0,
    "CACHE_REDIS_PASSWORD": "",  # Optional
}
```
For Redis Sentinel deployments:
```python
SIGNAL_CACHE_CONFIG = {
    "CACHE_TYPE": "RedisSentinelCache",
    "CACHE_REDIS_SENTINELS": [("sentinel1", 26379), ("sentinel2", 26379)],
    "CACHE_REDIS_SENTINEL_MASTER": "mymaster",
    "CACHE_REDIS_SENTINEL_PASSWORD": None,  # Sentinel password (if different)
    "CACHE_REDIS_PASSWORD": "",  # Redis password
    "CACHE_REDIS_DB": 0,
}
```
For SSL/TLS connections:
```python
SIGNAL_CACHE_CONFIG = {
    "CACHE_TYPE": "RedisCache",
    "CACHE_REDIS_HOST": "redis.example.com",
    "CACHE_REDIS_PORT": 6380,
    "CACHE_REDIS_SSL": True,
    "CACHE_REDIS_SSL_CERTFILE": "/path/to/client.crt",
    "CACHE_REDIS_SSL_KEYFILE": "/path/to/client.key",
    "CACHE_REDIS_SSL_CA_CERTS": "/path/to/ca.crt",
}
```
### Distributed Lock TTL
You can configure the default lock TTL (time-to-live) in seconds. Locks automatically expire after
this duration to prevent deadlocks from crashed processes:
```python
DISTRIBUTED_LOCK_DEFAULT_TTL = 30 # Default: 30 seconds
```
Individual lock acquisitions can override this value when needed.
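To show why a TTL prevents deadlocks, the sketch below models the "set if absent, with expiry" semantics of Redis `SET NX EX` using an in-memory dict. It is an illustration only: `InMemoryLockStore` is a hypothetical class, it does not talk to Redis, and it is not how Superset implements its locks.

```python
import time
from typing import Callable


class InMemoryLockStore:
    """Toy model of SET key value NX EX ttl, for illustration only."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._locks: dict[str, tuple[str, float]] = {}
        self._clock = clock

    def acquire(self, key: str, owner: str, ttl: float) -> bool:
        """Take the lock if it is free or its previous TTL has expired."""
        now = self._clock()
        current = self._locks.get(key)
        if current is not None and current[1] > now:
            return False  # still held by someone else
        self._locks[key] = (owner, now + ttl)
        return True

    def release(self, key: str, owner: str) -> bool:
        """Release only if the caller still owns an unexpired lock."""
        current = self._locks.get(key)
        if current and current[0] == owner and current[1] > self._clock():
            del self._locks[key]
            return True
        return False
```

If a process crashes while holding the lock, nobody calls `release`, but the next `acquire` after the TTL has elapsed still succeeds. That is the behavior `DISTRIBUTED_LOCK_DEFAULT_TTL` controls.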
### Database-Only Mode
When `SIGNAL_CACHE_CONFIG` is not configured, Superset uses database-backed operations:
- **Locking**: Uses the KeyValue table with periodic cleanup of expired entries
- **Event notifications**: Uses database polling instead of pub/sub
While database-backed operations work reliably, the Redis backend is recommended for production
deployments where low latency and reduced database load are important.
:::resources
- [Blog: The Data Engineer's Guide to Lightning-Fast Superset Dashboards](https://preset.io/blog/the-data-engineers-guide-to-lightning-fast-apache-superset-dashboards/)
- [Blog: Accelerating Dashboards with Materialized Views](https://preset.io/blog/accelerating-apache-superset-dashboards-with-materialized-views/)

@@ -97,6 +97,7 @@ const sidebars = {
'extensions/deployment',
'extensions/mcp',
'extensions/security',
'extensions/tasks',
'extensions/registry',
],
},

@@ -46,6 +46,7 @@ from superset_core.api.models import (
    Query,
    SavedQuery,
    Tag,
    Task,
    User,
)
@@ -248,6 +249,48 @@ class KeyValueDAO(BaseDAO[KeyValue]):
    id_column_name = "id"


class TaskDAO(BaseDAO[Task]):
    """
    Abstract Task DAO interface.

    Host implementations will replace this class during initialization
    with a concrete implementation providing actual functionality.
    """

    # Class variables that will be set by host implementation
    model_cls = None
    base_filter = None
    id_column_name = "id"
    uuid_column_name = "uuid"

    @classmethod
    @abstractmethod
    def find_by_task_key(
        cls,
        task_type: str,
        task_key: str,
        scope: str = "private",
        user_id: int | None = None,
    ) -> Task | None:
        """
        Find active task by type, key, scope, and user.

        Uses dedup_key internally for efficient querying with a unique index.
        Only returns tasks that are active (pending or in progress).

        Uniqueness logic by scope:
        - private: scope + task_type + task_key + user_id
        - shared/system: scope + task_type + task_key (user-agnostic)

        :param task_type: Task type to filter by
        :param task_key: Task identifier for deduplication
        :param scope: Task scope (private/shared/system)
        :param user_id: User ID (required for private tasks)
        :returns: Task instance or None if not found or not active
        """
        ...


__all__ = [
    "BaseDAO",
    "DatasetDAO",
@@ -259,4 +302,5 @@ __all__ = [
    "SavedQueryDAO",
    "TagDAO",
    "KeyValueDAO",
    "TaskDAO",
]

@@ -40,6 +40,7 @@ from flask_appbuilder import Model
from sqlalchemy.orm import scoped_session
if TYPE_CHECKING:
    from superset_core.api.tasks import TaskProperties
    from superset_core.api.types import (
        AsyncQueryHandle,
        QueryOptions,
@@ -361,6 +362,132 @@ class KeyValue(CoreModel):
    changed_by_fk: int | None


class Task(CoreModel):
    """
    Abstract Task model interface.

    Host implementations will replace this class during initialization
    with concrete implementation providing actual functionality.

    This model represents async tasks in the Global Task Framework (GTF).

    Non-filterable fields (progress, error info, execution config) are stored
    in a `properties` JSON blob for schema flexibility.
    """

    __abstract__ = True

    # Type hints for expected column attributes
    id: int
    uuid: UUID
    task_key: str  # For deduplication
    task_type: str  # e.g., 'sql_execution'
    task_name: str | None  # Human readable name
    scope: str  # private/shared/system
    status: str
    dedup_key: str  # Computed deduplication key

    # Timestamps (from AuditMixinNullable)
    created_on: datetime | None
    changed_on: datetime | None
    started_at: datetime | None
    ended_at: datetime | None

    # User context
    created_by_fk: int | None
    user_id: int | None

    # Task output data
    payload: str  # JSON serialized task output data

    def get_payload(self) -> dict[str, Any]:
        """
        Get payload as parsed JSON.

        Payload contains task-specific output data set by task code.

        Host implementations will replace this method during initialization
        with concrete implementation providing actual functionality.

        :returns: Dictionary containing payload data
        """
        raise NotImplementedError("Method will be replaced during initialization")

    def set_payload(self, data: dict[str, Any]) -> None:
        """
        Update payload with new data (merges with existing).

        Host implementations will replace this method during initialization
        with concrete implementation providing actual functionality.

        :param data: Dictionary of data to merge into payload
        """
        raise NotImplementedError("Method will be replaced during initialization")

    @property
    def properties(self) -> Any:
        """
        Get typed properties (runtime state and execution config).

        Properties contain:
        - is_abortable: bool | None - has abort handler registered
        - progress_percent: float | None - progress 0.0-1.0
        - progress_current: int | None - current iteration count
        - progress_total: int | None - total iterations
        - error_message: str | None - human-readable error message
        - exception_type: str | None - exception class name
        - stack_trace: str | None - full formatted traceback
        - timeout: int | None - timeout in seconds

        Host implementations will replace this property during initialization.

        :returns: TaskProperties dict
        """
        raise NotImplementedError("Property will be replaced during initialization")

    def update_properties(self, updates: "TaskProperties") -> None:
        """
        Update specific properties fields (merge semantics).

        Only updates fields present in the updates dict.

        Host implementations will replace this method during initialization.

        :param updates: TaskProperties dict with fields to update

        Example:
            task.update_properties({"is_abortable": True})
        """
        raise NotImplementedError("Method will be replaced during initialization")


class TaskSubscriber(CoreModel):
    """
    Abstract TaskSubscriber model interface.

    Host implementations will replace this class during initialization
    with concrete implementation providing actual functionality.

    This model tracks task subscriptions for multi-user shared tasks. When a user
    schedules a shared task with the same parameters as an existing task,
    they are subscribed to that task instead of creating a duplicate.
    """

    __abstract__ = True

    # Type hints for expected attributes (no actual field definitions)
    id: int
    task_id: int
    user_id: int
    subscribed_at: datetime

    # Audit fields from AuditMixinNullable
    created_on: datetime | None
    changed_on: datetime | None
    created_by_fk: int | None
    changed_by_fk: int | None


def get_session() -> scoped_session:
    """
    Retrieve the SQLAlchemy session to directly interface with the
@@ -384,6 +511,8 @@ __all__ = [
    "SavedQuery",
    "Tag",
    "KeyValue",
    "Task",
    "TaskSubscriber",
    "CoreModel",
    "get_session",
]

@@ -0,0 +1,361 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Generic, Literal, ParamSpec, TypedDict, TypeVar
from superset_core.api.models import Task
P = ParamSpec("P")
R = TypeVar("R")
class TaskStatus(str, Enum):
    """
    Status of task execution.
    """

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    SUCCESS = "success"
    FAILURE = "failure"
    ABORTING = "aborting"  # Abort/timeout requested, handlers running
    ABORTED = "aborted"  # User/admin cancelled
    TIMED_OUT = "timed_out"  # Timeout expired


class TaskScope(str, Enum):
    """
    Scope of task visibility and access control.
    """

    PRIVATE = "private"  # User-specific tasks (default)
    SHARED = "shared"  # Multi-user collaborative tasks
    SYSTEM = "system"  # Admin-only background tasks


class TaskProperties(TypedDict, total=False):
    """
    TypedDict for task runtime state and execution config.

    Stored as JSON in the database, accessed as a dict throughout the codebase.
    All fields are optional (total=False) - only set keys are present in the dict.

    Usage:
        # Reading - always use .get() since keys may not be present
        if task.properties.get("is_abortable"):
            ...

        # Writing/updating - only include keys you want to set
        task.update_properties({"is_abortable": True, "progress_percent": 0.5})

    Notes:
        - Sparse dict: only keys that are explicitly set are present
        - Unknown keys from JSON are preserved (forward compatibility)
        - Always use .get() for reads since keys may be absent
    """

    # Execution config - set at task creation
    execution_mode: Literal["async", "sync"]
    timeout: int

    # Runtime state - set by framework during execution
    is_abortable: bool
    progress_percent: float
    progress_current: int
    progress_total: int

    # Error info - set when task fails
    error_message: str
    exception_type: str
    stack_trace: str


@dataclass(frozen=True)
class TaskOptions:
    """
    Execution metadata for tasks.

    NOTE: This is intentionally minimal for the initial implementation.
    Additional options (queue, priority, run_at, delay_s,
    max_retries, retry_backoff_s, tags, etc.) can be added later when needed.

    Future enhancements will include:
    - Validation (e.g., run_at vs delay_s mutual exclusion)
    - Queue routing and priority management
    - Retry policies and backoff strategies

    Example:
        from superset_core.api.tasks import TaskOptions, TaskScope

        # Private task (default)
        task = my_task.schedule(arg1)

        # Custom task with deduplication
        task = my_task.schedule(
            arg1,
            options=TaskOptions(
                task_key="custom_key",
                task_name="Custom Task Name"
            )
        )

        # Task with custom name
        task = admin_task.schedule(
            options=TaskOptions(task_name="Admin Operation")
        )

        # Task with timeout (overrides decorator default)
        task = long_task.schedule(
            options=TaskOptions(timeout=600)  # 10 minute timeout
        )
    """

    task_key: str | None = None
    task_name: str | None = None
    timeout: int | None = None  # Timeout in seconds
class TaskContext(ABC):
"""
Abstract task context for write-only task state updates.
Tasks use this context to update their state (progress, payload) and
check for cancellation. Tasks should not need to read their own state -
they are the source of state, not consumers of it.
Host implementations will replace this abstract class during initialization
with a concrete implementation providing actual functionality.
"""
@abstractmethod
def update_task(
self,
progress: float | int | tuple[int, int] | None = None,
payload: dict[str, Any] | None = None,
) -> None:
"""
Update task progress and/or payload atomically.
All parameters are optional. Payload is merged with existing data,
not replaced. All updates occur in a single database transaction.
Progress can be specified in three ways:
- float (0.0-1.0): Percentage only, e.g., 0.5 means 50%
- int: Count only (total unknown), e.g., 42 means "42 items processed"
- tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100"
The percentage is automatically computed from count/total.
:param progress: Progress value, or None to leave unchanged
:param payload: Payload data to merge (dict), or None to leave unchanged
Examples:
# Percentage only - displays as "In progress: 50 %"
ctx.update_task(progress=0.5)
# Count only (total unknown) - displays as "In progress: 42"
ctx.update_task(progress=42)
# Count and total - displays as "In progress: 3 of 100 (3 %)"
ctx.update_task(progress=(3, 100))
# Update payload only
ctx.update_task(payload={"step": "processing"})
# Update both atomically
ctx.update_task(
progress=(80, 100),
payload={"processed": 80, "total": 100}
)
"""
...
@abstractmethod
def on_cleanup(self, handler: Callable[[], None]) -> Callable[[], None]:
"""
Register a cleanup handler that runs when the task ends.
Cleanup handlers are called when the task completes (success),
fails with an error, or is cancelled. Multiple handlers can be
registered and will execute in LIFO order (last registered runs first).
Can be used as a decorator:
@ctx.on_cleanup
def cleanup():
logger.info("Task ended")
Or called directly:
ctx.on_cleanup(lambda: logger.info("Task ended"))
:param handler: Cleanup function to register
:returns: The handler (for decorator compatibility)
"""
...
@abstractmethod
def on_abort(self, handler: Callable[[], None]) -> Callable[[], None]:
"""
Register a handler that runs when the task is aborted.
When the first handler is registered, background polling starts
automatically. The handler will be called when an abort is detected.
The handler executes in a background thread and the task code
continues running unless the handler takes action to stop it.
:param handler: Callback function to execute when abort is detected
:returns: The handler (for decorator compatibility)
Example:
@ctx.on_abort
def handle_abort():
logger.info("Task was aborted!")
cleanup_partial_work()
"""
...
def task(
name: str | None = None,
scope: TaskScope = TaskScope.PRIVATE,
timeout: int | None = None,
) -> Callable[[Callable[P, R]], "TaskWrapper[P]"]:
"""
Decorator to register a task.
Host implementations will replace this function during initialization
with a concrete implementation providing actual functionality.
:param name: Optional unique task name (e.g., "superset.generate_thumbnail").
If not provided, uses the function name as the task name.
:param scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM).
Defaults to TaskScope.PRIVATE.
:param timeout: Optional timeout in seconds. When the timeout is reached,
abort handlers are triggered if registered. Can be overridden
at call time via TaskOptions(timeout=...).
:returns: TaskWrapper with .schedule() method
Note:
Both direct calls and .schedule() return Task, regardless of the
original function's return type. The decorated function's return value
is discarded; only side effects and context updates matter.
Example:
from superset_core.api.types import task, get_context, TaskScope
# Private task (default scope)
@task
def generate_thumbnail(chart_id: int) -> None:
ctx = get_context()
# ... task implementation
# Named task with shared scope
@task(name="generate_report", scope=TaskScope.SHARED)
def generate_chart_thumbnail(chart_id: int) -> None:
ctx = get_context()
# Update progress and payload atomically
ctx.update_task(
progress=0.5,
payload={"chart_id": chart_id, "status": "processing"}
)
# ... task implementation
ctx.update_task(progress=1.0)
# System task (admin-only)
@task(scope=TaskScope.SYSTEM)
def cleanup_old_data() -> None:
ctx = get_context()
# ... cleanup implementation
# Task with timeout
@task(timeout=300) # 5-minute timeout
def long_running_task() -> None:
ctx = get_context()
@ctx.on_abort
def handle_abort():
# Called on timeout or manual abort
pass
# Schedule async execution
task = generate_chart_thumbnail.schedule(chart_id=123) # Returns Task
# Direct call for sync execution (blocks until task is complete)
task = generate_chart_thumbnail(chart_id=123) # Also returns Task
"""
raise NotImplementedError("Function will be replaced during initialization")
class TaskWrapper(Generic[P]):
"""
Type stub for task wrapper returned by @task decorator.
Both __call__ and .schedule() return Task.
"""
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Task:
"""Execute the task synchronously."""
raise NotImplementedError("Will be replaced during initialization")
def schedule(self, *args: P.args, **kwargs: P.kwargs) -> Task:
"""Schedule the task for async execution."""
raise NotImplementedError("Will be replaced during initialization")
def get_context() -> TaskContext:
"""
Get the current task context from ambient context.
Host implementations will replace this function during initialization
with a concrete implementation providing actual functionality.
This function provides ambient access to the task context without
requiring it to be passed as a parameter. It can only be called
from within a task execution.
:returns: The current TaskContext
:raises RuntimeError: If called outside a task execution context
Example:
@task("thumbnail_generation")
def generate_chart_thumbnail(chart_id: int):
ctx = get_context() # Access ambient context
# Update task state - no need to fetch task object
ctx.update_task(
progress=0.5,
payload={"chart_id": chart_id}
)
"""
raise NotImplementedError("Function will be replaced during initialization")
__all__ = [
"TaskStatus",
"TaskScope",
"TaskProperties",
"TaskContext",
"TaskOptions",
"task",
"get_context",
]
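
The three progress forms accepted by `update_task` (fraction, count only, count and total) can be illustrated with a small normalization helper. This is a minimal sketch, not the framework's implementation: `normalize_progress` and the keys of the returned dict are hypothetical names chosen for illustration, loosely following the `progress_percent`, `progress_current`, and `progress_total` properties the frontend reads.

```python
from __future__ import annotations


def normalize_progress(
    progress: float | int | tuple[int, int],
) -> dict[str, float | int | None]:
    """Hypothetical helper: map the three progress forms onto percent/current/total."""
    if isinstance(progress, bool):
        # bool is a subclass of int; reject it so True is not treated as a count of 1
        raise TypeError("progress must be a float, int, or (current, total) tuple")
    if isinstance(progress, tuple):
        current, total = progress
        # Percentage is derived from count/total; guard against a zero total
        percent = current / total if total > 0 else None
        return {"percent": percent, "current": current, "total": total}
    if isinstance(progress, int):
        # Count only: the total is unknown, so no percentage can be computed
        return {"percent": None, "current": progress, "total": None}
    if isinstance(progress, float):
        if not 0.0 <= progress <= 1.0:
            raise ValueError("float progress must be between 0.0 and 1.0")
        return {"percent": progress, "current": None, "total": None}
    raise TypeError("progress must be a float, int, or (current, total) tuple")


print(normalize_progress(0.5))
print(normalize_progress(42))
print(normalize_progress((3, 100)))
```

Note that `0.5` and `1` mean different things under this scheme: the float is a fraction, while the int is a count of one item with an unknown total.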

View File

@@ -109,6 +109,7 @@
"mustache": "^4.2.0",
"nanoid": "^5.1.6",
"ol": "^7.5.2",
"pretty-ms": "^9.3.0",
"query-string": "9.3.1",
"re-resizable": "^6.11.2",
"react": "^17.0.2",
@@ -43687,6 +43688,21 @@
"integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==",
"license": "MIT"
},
"node_modules/pretty-ms": {
"version": "9.3.0",
"resolved": "https://registry.npmjs.org/pretty-ms/-/pretty-ms-9.3.0.tgz",
"integrity": "sha512-gjVS5hOP+M3wMm5nmNOucbIrqudzs9v/57bWRHQWLYklXqoXKrVfYW2W9+glfGsqtPgpiz5WwyEEB+ksXIx3gQ==",
"license": "MIT",
"dependencies": {
"parse-ms": "^4.0.0"
},
"engines": {
"node": ">=18"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/prismjs": {
"version": "1.30.0",
"resolved": "https://registry.npmjs.org/prismjs/-/prismjs-1.30.0.tgz",
@@ -56437,20 +56453,6 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"packages/superset-ui-core/node_modules/pretty-ms": {
"version": "9.3.0",
"resolved": "https://registry.npmjs.org/pretty-ms/-/pretty-ms-9.3.0.tgz",
"integrity": "sha512-gjVS5hOP+M3wMm5nmNOucbIrqudzs9v/57bWRHQWLYklXqoXKrVfYW2W9+glfGsqtPgpiz5WwyEEB+ksXIx3gQ==",
"dependencies": {
"parse-ms": "^4.0.0"
},
"engines": {
"node": ">=18"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"packages/superset-ui-core/node_modules/property-information": {
"version": "7.1.0",
"resolved": "https://registry.npmjs.org/property-information/-/property-information-7.1.0.tgz",

View File

@@ -187,6 +187,7 @@
"markdown-to-jsx": "^9.7.3",
"match-sorter": "^6.3.4",
"memoize-one": "^5.2.1",
"pretty-ms": "^9.3.0",
"mousetrap": "^1.6.5",
"mustache": "^4.2.0",
"nanoid": "^5.1.6",

View File

@@ -54,6 +54,7 @@ export enum FeatureFlag {
EstimateQueryCost = 'ESTIMATE_QUERY_COST',
FilterBarClosedByDefault = 'FILTERBAR_CLOSED_BY_DEFAULT',
GlobalAsyncQueries = 'GLOBAL_ASYNC_QUERIES',
GlobalTaskFramework = 'GLOBAL_TASK_FRAMEWORK',
ListviewsDefaultCardView = 'LISTVIEWS_DEFAULT_CARD_VIEW',
Matrixify = 'MATRIXIFY',
ScheduledQueries = 'SCHEDULED_QUERIES',

View File

@@ -19,9 +19,9 @@
import getOwnerName from 'src/utils/getOwnerName';
import { t } from '@apache-superset/core';
import { Tooltip } from '@superset-ui/core/components';
import type { ModifiedInfoProps } from './types';
import type { AuditInfoProps } from './types';
export const ModifiedInfo = ({ user, date }: ModifiedInfoProps) => {
export const ModifiedInfo = ({ user, date }: AuditInfoProps) => {
const dateSpan = (
<span className="no-wrap" data-test="audit-info-date">
{date}
@@ -40,4 +40,23 @@ export const ModifiedInfo = ({ user, date }: ModifiedInfoProps) => {
return dateSpan;
};
export type { ModifiedInfoProps };
export const CreatedInfo = ({ user, date }: AuditInfoProps) => {
const dateSpan = (
<span className="no-wrap" data-test="audit-info-date">
{date}
</span>
);
if (user) {
const userName = getOwnerName(user);
const title = t('Created by: %s', userName);
return (
<Tooltip title={title} placement="bottom">
{dateSpan}
</Tooltip>
);
}
return dateSpan;
};
export type { AuditInfoProps };

View File

@@ -18,7 +18,7 @@
*/
import type Owner from 'src/types/Owner';
export type ModifiedInfoProps = {
export type AuditInfoProps = {
user?: Owner;
date: string;
};

View File

@@ -41,7 +41,7 @@ export * from './GenericLink';
export { GridTable, type TableProps } from './GridTable';
export * from './Tag';
export * from './TagsList';
export { ModifiedInfo, type ModifiedInfoProps } from './AuditInfo';
export { CreatedInfo, ModifiedInfo, type AuditInfoProps } from './AuditInfo';
export {
DynamicPluginProvider,
PluginContext,

View File

@@ -0,0 +1,76 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useState } from 'react';
import { styled } from '@apache-superset/core/ui';
import { Popover } from '@superset-ui/core/components';
import { Icons } from '@superset-ui/core/components/Icons';
const PayloadContainer = styled.div`
max-width: 400px;
max-height: 300px;
overflow: auto;
padding: ${({ theme }) => theme.sizeUnit * 2}px;
`;
const PayloadPre = styled.pre`
margin: 0;
font-size: ${({ theme }) => theme.fontSizeSM}px;
white-space: pre-wrap;
word-wrap: break-word;
`;
const InfoIconWrapper = styled.span`
cursor: pointer;
color: ${({ theme }) => theme.colorIcon};
&:hover {
color: ${({ theme }) => theme.colorPrimary};
}
`;
interface TaskPayloadPopoverProps {
payload: Record<string, any>;
}
export default function TaskPayloadPopover({
payload,
}: TaskPayloadPopoverProps) {
const [visible, setVisible] = useState(false);
const content = (
<PayloadContainer>
<PayloadPre>{JSON.stringify(payload, null, 2)}</PayloadPre>
</PayloadContainer>
);
return (
<Popover
content={content}
trigger="hover"
placement="leftTop"
visible={visible}
onVisibleChange={setVisible}
>
<InfoIconWrapper>
<Icons.InfoCircleOutlined iconSize="l" />
</InfoIconWrapper>
</Popover>
);
}

View File

@@ -0,0 +1,137 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useState, useCallback } from 'react';
import { t } from '@apache-superset/core';
import { styled } from '@apache-superset/core/ui';
import { Popover, Tooltip } from '@superset-ui/core/components';
import { Icons } from '@superset-ui/core/components/Icons';
import { useToasts } from 'src/components/MessageToasts/withToasts';
import copyTextToClipboard from 'src/utils/copy';
const StackTraceContainer = styled.div`
max-width: 600px;
max-height: 400px;
display: flex;
flex-direction: column;
`;
const Header = styled.div`
display: flex;
justify-content: flex-end;
padding: ${({ theme }) => theme.sizeUnit}px
${({ theme }) => theme.sizeUnit * 2}px;
border-bottom: 1px solid ${({ theme }) => theme.colorBorder};
`;
const CopyButton = styled.button`
background: none;
border: none;
cursor: pointer;
padding: ${({ theme }) => theme.sizeUnit / 2}px;
color: ${({ theme }) => theme.colorTextSecondary};
display: flex;
align-items: center;
gap: ${({ theme }) => theme.sizeUnit / 2}px;
font-size: ${({ theme }) => theme.fontSizeSM}px;
&:hover {
color: ${({ theme }) => theme.colorText};
}
`;
const StackTraceContent = styled.div`
overflow: auto;
padding: ${({ theme }) => theme.sizeUnit * 2}px;
flex: 1;
`;
const StackTrace = styled.pre`
margin: 0;
font-size: ${({ theme }) => theme.fontSizeSM}px;
white-space: pre-wrap;
word-wrap: break-word;
font-family: ${({ theme }) => theme.fontFamilyCode};
`;
const ErrorIconWrapper = styled.span`
cursor: pointer;
color: ${({ theme }) => theme.colorError};
&:hover {
opacity: 0.8;
}
`;
interface TaskStackTracePopoverProps {
stackTrace: string;
}
export default function TaskStackTracePopover({
stackTrace,
}: TaskStackTracePopoverProps) {
const [visible, setVisible] = useState(false);
const [copied, setCopied] = useState(false);
const { addDangerToast } = useToasts();
const handleCopy = useCallback(() => {
copyTextToClipboard(() => Promise.resolve(stackTrace))
.then(() => {
setCopied(true);
setTimeout(() => setCopied(false), 2000);
})
.catch(() => {
addDangerToast(t('Failed to copy stack trace to clipboard'));
});
}, [stackTrace, addDangerToast]);
const content = (
<StackTraceContainer>
<Header>
<Tooltip title={copied ? t('Copied!') : t('Copy to clipboard')}>
<CopyButton onClick={handleCopy}>
{copied ? (
<Icons.CheckOutlined iconSize="s" />
) : (
<Icons.CopyOutlined iconSize="s" />
)}
{t('Copy')}
</CopyButton>
</Tooltip>
</Header>
<StackTraceContent>
<StackTrace>{stackTrace}</StackTrace>
</StackTraceContent>
</StackTraceContainer>
);
return (
<Popover
content={content}
trigger="hover"
placement="leftTop"
visible={visible}
onVisibleChange={setVisible}
>
<ErrorIconWrapper>
<Icons.BugOutlined iconSize="l" />
</ErrorIconWrapper>
</Popover>
);
}

View File

@@ -0,0 +1,145 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import React from 'react';
import { useTheme, SupersetTheme, t } from '@apache-superset/core/ui';
import { Icons } from '@superset-ui/core/components/Icons';
import { Tooltip } from '@superset-ui/core/components';
import { TaskStatus } from './types';
import { formatProgressTooltip } from './timeUtils';
function getStatusColor(status: TaskStatus, theme: SupersetTheme): string {
switch (status) {
case TaskStatus.Pending:
return theme.colorPrimaryText;
case TaskStatus.InProgress:
return theme.colorPrimaryText;
case TaskStatus.Success:
return theme.colorSuccessText;
case TaskStatus.Failure:
return theme.colorErrorText;
case TaskStatus.TimedOut:
return theme.colorErrorText;
case TaskStatus.Aborting:
return theme.colorWarningText;
case TaskStatus.Aborted:
return theme.colorWarningText;
default:
return theme.colorText;
}
}
const statusIcons = {
[TaskStatus.Pending]: Icons.ClockCircleOutlined,
[TaskStatus.InProgress]: Icons.LoadingOutlined,
[TaskStatus.Success]: Icons.CheckCircleOutlined,
[TaskStatus.Failure]: Icons.CloseCircleOutlined,
[TaskStatus.TimedOut]: Icons.ClockCircleOutlined, // Clock to indicate timeout
[TaskStatus.Aborting]: Icons.LoadingOutlined, // Spinning to show in-progress abort
[TaskStatus.Aborted]: Icons.StopOutlined,
};
const statusLabels = {
[TaskStatus.Pending]: t('Pending'),
[TaskStatus.InProgress]: t('In Progress'),
[TaskStatus.Success]: t('Success'),
[TaskStatus.Failure]: t('Failed'),
[TaskStatus.TimedOut]: t('Timed Out'),
[TaskStatus.Aborting]: t('Aborting'),
[TaskStatus.Aborted]: t('Aborted'),
};
interface TaskStatusIconProps {
status: TaskStatus;
progressPercent?: number | null;
progressCurrent?: number | null;
progressTotal?: number | null;
durationSeconds?: number | null;
errorMessage?: string | null;
exceptionType?: string | null;
}
export default function TaskStatusIcon({
status,
progressPercent,
progressCurrent,
progressTotal,
durationSeconds,
errorMessage,
exceptionType,
}: TaskStatusIconProps) {
const theme = useTheme();
const IconComponent = statusIcons[status];
const label = statusLabels[status];
// Build tooltip content based on status
let tooltipContent: React.ReactNode;
if (status === TaskStatus.InProgress || status === TaskStatus.Aborting) {
// Progress tooltip for active tasks (multiline)
const lines = formatProgressTooltip(
label,
progressCurrent,
progressTotal,
progressPercent,
durationSeconds,
);
tooltipContent = (
<>
{lines.map((line, index) => (
<React.Fragment key={index}>
{index > 0 && <br />}
{line}
</React.Fragment>
))}
</>
);
} else if (
(status === TaskStatus.Failure || status === TaskStatus.TimedOut) &&
(exceptionType || errorMessage)
) {
// Error tooltip for failed/timed out tasks: "Label (ExceptionType): message"
if (exceptionType && errorMessage) {
tooltipContent = `${label} (${exceptionType}): ${errorMessage}`;
} else if (exceptionType) {
tooltipContent = `${label} (${exceptionType})`;
} else if (errorMessage) {
tooltipContent = `${label}: ${errorMessage}`;
} else {
tooltipContent = label;
}
} else {
tooltipContent = label;
}
// Spin for in-progress and aborting states
const shouldSpin =
status === TaskStatus.InProgress || status === TaskStatus.Aborting;
return (
<Tooltip title={tooltipContent} placement="top">
<span>
<IconComponent
iconSize="l"
iconColor={getStatusColor(status, theme)}
spin={shouldSpin}
/>
</span>
</Tooltip>
);
}

View File

@@ -0,0 +1,145 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import {
formatDuration,
calculateEta,
formatProgressTooltip,
} from './timeUtils';
test('formatDuration returns null for invalid inputs', () => {
expect(formatDuration(null)).toBeNull();
expect(formatDuration(undefined)).toBeNull();
expect(formatDuration(0)).toBeNull();
expect(formatDuration(-5)).toBeNull();
});
test('formatDuration formats seconds correctly', () => {
expect(formatDuration(37.5)).toBe('37s');
expect(formatDuration(1)).toBe('1s');
expect(formatDuration(30)).toBe('30s');
});
test('formatDuration formats minutes correctly', () => {
expect(formatDuration(60)).toBe('1m');
expect(formatDuration(90)).toBe('1m 30s');
expect(formatDuration(150)).toBe('2m 30s');
});
test('formatDuration formats hours correctly', () => {
expect(formatDuration(3600)).toBe('1h');
expect(formatDuration(3660)).toBe('1h 1m');
expect(formatDuration(7200)).toBe('2h');
});
test('calculateEta returns null for invalid inputs', () => {
expect(calculateEta(null, 60)).toBeNull();
expect(calculateEta(undefined, 60)).toBeNull();
expect(calculateEta(0.5, null)).toBeNull();
expect(calculateEta(0.5, undefined)).toBeNull();
});
test('calculateEta returns null for edge case progress values', () => {
// No progress yet
expect(calculateEta(0, 60)).toBeNull();
// Already complete
expect(calculateEta(1, 60)).toBeNull();
// Negative progress (invalid)
expect(calculateEta(-0.1, 60)).toBeNull();
// Over 100% (invalid)
expect(calculateEta(1.1, 60)).toBeNull();
});
test('calculateEta calculates correct remaining time', () => {
// 50% done in 60s -> ETA = 60s remaining
expect(calculateEta(0.5, 60)).toBe('1m');
// 30% done in 60s -> remaining = (60/0.3) * 0.7 = 140s
expect(calculateEta(0.3, 60)).toBe('2m 20s');
// 10% done in 10s -> remaining = (10/0.1) * 0.9 = 90s
expect(calculateEta(0.1, 10)).toBe('1m 30s');
// 90% done in 90s -> remaining = (90/0.9) * 0.1 = 10s
expect(calculateEta(0.9, 90)).toBe('10s');
});
test('calculateEta returns null for ETAs over 24 hours', () => {
// 0.1% done in 100s -> remaining = (100/0.001) * 0.999 = ~99900s > 86400s
expect(calculateEta(0.001, 100)).toBeNull();
});
test('formatProgressTooltip returns label only when no progress data', () => {
expect(formatProgressTooltip('In Progress')).toEqual(['In Progress']);
expect(formatProgressTooltip('In Progress', null, null, null, null)).toEqual([
'In Progress',
]);
});
test('formatProgressTooltip formats count and total correctly', () => {
expect(formatProgressTooltip('In Progress', 9, 60)).toEqual([
'In Progress: 9 of 60',
]);
});
test('formatProgressTooltip formats count only correctly', () => {
expect(formatProgressTooltip('In Progress', 42)).toEqual([
'In Progress: 42 processed',
]);
expect(formatProgressTooltip('In Progress', 42, null)).toEqual([
'In Progress: 42 processed',
]);
});
test('formatProgressTooltip formats percentage correctly', () => {
expect(formatProgressTooltip('In Progress', null, null, 0.5)).toEqual([
'In Progress: 50%',
]);
expect(formatProgressTooltip('In Progress', null, null, 0.333)).toEqual([
'In Progress: 33%',
]);
});
test('formatProgressTooltip combines count, total, and percentage', () => {
expect(formatProgressTooltip('In Progress', 9, 60, 0.15)).toEqual([
'In Progress: 9 of 60 (15%)',
]);
});
test('formatProgressTooltip includes ETA when duration is provided', () => {
// 50% done in 60s -> ETA = 60s remaining = 1m
expect(formatProgressTooltip('In Progress', 30, 60, 0.5, 60)).toEqual([
'In Progress: 30 of 60 (50%)',
'ETA: 1m',
]);
});
test('formatProgressTooltip works with percentage and ETA only', () => {
// 25% done in 30s -> ETA = (30/0.25) * 0.75 = 90s = 1m 30s
expect(formatProgressTooltip('In Progress', null, null, 0.25, 30)).toEqual([
'In Progress: 25%',
'ETA: 1m 30s',
]);
});
test('formatProgressTooltip works with different labels', () => {
expect(formatProgressTooltip('Aborting', 5, 10, 0.5)).toEqual([
'Aborting: 5 of 10 (50%)',
]);
});

View File

@@ -0,0 +1,151 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import prettyMs from 'pretty-ms';
/**
* Maximum ETA to display (24 hours in seconds).
* ETAs beyond this are not shown as they're unreliable.
*/
const MAX_ETA_SECONDS = 86400;
/**
* Format a duration in seconds to a human-readable string.
*
* @param seconds - Duration in seconds
* @returns Formatted string like "1m 30s" or "2h 15m", or null if invalid
*/
export function formatDuration(
seconds: number | null | undefined,
): string | null {
if (seconds === null || seconds === undefined || seconds <= 0) {
return null;
}
return prettyMs(seconds * 1000, {
unitCount: 2,
secondsDecimalDigits: 0,
keepDecimalsOnWholeSeconds: false,
});
}
/**
* Calculate and format estimated time to completion based on progress and elapsed time.
*
* Uses the formula: ETA = (elapsed / progress) * (1 - progress)
* For example, if 30% done in 60s, remaining = (60/0.3) * 0.7 = 140s
*
* @param progressPercent - Progress as a fraction (0.0 to 1.0)
* @param durationSeconds - Time elapsed so far in seconds
* @returns Formatted ETA string or null if cannot be calculated
*/
export function calculateEta(
progressPercent: number | null | undefined,
durationSeconds: number | null | undefined,
): string | null {
// Need both progress and duration to calculate ETA
if (
progressPercent === null ||
progressPercent === undefined ||
durationSeconds === null ||
durationSeconds === undefined
) {
return null;
}
// Can't calculate ETA if no progress yet or already complete
if (progressPercent <= 0 || progressPercent >= 1) {
return null;
}
// ETA = (elapsed / progress) * (1 - progress)
const estimatedTotalTime = durationSeconds / progressPercent;
const remainingSeconds = estimatedTotalTime * (1 - progressPercent);
// Only show ETA if it's reasonable (less than 24 hours)
if (remainingSeconds <= 0 || remainingSeconds > MAX_ETA_SECONDS) {
return null;
}
// unitCount: 2 caps the output at the two largest units (e.g., "1h 1m" instead of "1h 1m 5s")
// secondsDecimalDigits: 0 shows whole seconds (e.g., "52s" instead of "52.4s")
return prettyMs(remainingSeconds * 1000, {
unitCount: 2,
secondsDecimalDigits: 0,
});
}
/**
* Build a progress display for task status tooltips.
*
* Returns an array of lines for proper multiline tooltip rendering:
* - ["In Progress: 9 of 60 (15%)", "ETA: 51s"]
* - ["In Progress: 42 processed"]
* - ["In Progress: 50%"]
* - ["In Progress: 50%", "ETA: 2m"]
*
* @param label - Status label (e.g., "In Progress", "Aborting")
* @param progressCurrent - Current count of items processed
* @param progressTotal - Total count of items to process
* @param progressPercent - Progress as a fraction (0.0 to 1.0)
* @param durationSeconds - Time elapsed so far in seconds (used for ETA calculation)
* @returns Array of lines for tooltip display
*/
export function formatProgressTooltip(
label: string,
progressCurrent?: number | null,
progressTotal?: number | null,
progressPercent?: number | null,
durationSeconds?: number | null,
): string[] {
const lines: string[] = [];
let progressPart = '';
// Build progress part
if (progressCurrent !== null && progressCurrent !== undefined) {
if (progressTotal !== null && progressTotal !== undefined) {
// Count and total with percentage: "9 of 60 (15%)"
progressPart = `${progressCurrent} of ${progressTotal}`;
if (progressPercent !== null && progressPercent !== undefined) {
progressPart += ` (${Math.round(progressPercent * 100)}%)`;
}
} else {
// Count only: "3 processed"
progressPart = `${progressCurrent} processed`;
}
} else if (progressPercent !== null && progressPercent !== undefined) {
// Percentage only: "50%"
progressPart = `${Math.round(progressPercent * 100)}%`;
}
// Add the main progress line
if (progressPart) {
lines.push(`${label}: ${progressPart}`);
} else {
lines.push(label);
}
// Add ETA on a separate line if available
const eta = calculateEta(progressPercent, durationSeconds);
if (eta) {
lines.push(`ETA: ${eta}`);
}
return lines;
}
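
The ETA arithmetic used by `calculateEta` can be checked by hand with a short Python sketch. It restates only the formula and guards described in the docstrings above, `ETA = (elapsed / progress) * (1 - progress)` with a 24-hour cap. The function name `eta_seconds` is hypothetical, and it returns raw seconds rather than a formatted string.

```python
from __future__ import annotations

MAX_ETA_SECONDS = 86400  # 24 hours, mirroring the cap described above


def eta_seconds(progress: float | None, elapsed: float | None) -> float | None:
    """Remaining seconds via (elapsed / progress) * (1 - progress), or None."""
    if progress is None or elapsed is None:
        return None
    # No estimate before any progress is made or once the task is complete
    if progress <= 0 or progress >= 1:
        return None
    estimated_total = elapsed / progress
    remaining = estimated_total * (1 - progress)
    # Hide estimates that are non-positive or longer than the cap
    if remaining <= 0 or remaining > MAX_ETA_SECONDS:
        return None
    return remaining


print(eta_seconds(0.5, 60))     # half done after 60s
print(eta_seconds(0.3, 60))     # 30% done after 60s
print(eta_seconds(0.001, 100))  # estimate exceeds the 24-hour cap
```

For example, 30% done after 60 s implies an estimated total of 60 / 0.3 = 200 s, so about 200 × 0.7 = 140 s remain. The estimate assumes the task keeps progressing at its average rate so far.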

View File

@@ -0,0 +1,115 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
export interface TaskSubscriber {
user_id: number;
first_name: string;
last_name: string;
subscribed_at: string;
}
export enum TaskScope {
Private = 'private',
Shared = 'shared',
System = 'system',
}
/**
* Task properties - runtime state and execution config stored in JSON blob.
*/
export interface TaskProperties {
// Execution config - set at task creation
execution_mode: 'async' | 'sync' | null;
timeout: number | null;
// Runtime state - set by framework during execution
is_abortable: boolean | null;
progress_percent: number | null;
progress_current: number | null;
progress_total: number | null;
// Error info - set when task fails
error_message: string | null;
exception_type: string | null;
stack_trace: string | null;
}
export interface Task {
id: number;
uuid: string;
task_key: string;
task_type: string;
task_name: string | null;
status:
| 'pending'
| 'in_progress'
| 'success'
| 'failure'
| 'aborting'
| 'aborted'
| 'timed_out';
scope: TaskScope;
created_on: string;
created_on_delta_humanized?: string;
changed_on: string;
started_at: string | null;
ended_at: string | null;
created_by: {
id: number;
first_name: string;
last_name: string;
} | null;
changed_by?: {
first_name: string;
last_name: string;
} | null;
user_id: number | null;
payload: Record<string, any>;
properties: TaskProperties;
duration_seconds: number | null;
subscriber_count: number;
subscribers: TaskSubscriber[];
}
// Derived status helpers (frontend computes these from status and properties)
export function isTaskFinished(task: Task): boolean {
return ['success', 'failure', 'aborted', 'timed_out'].includes(task.status);
}
export function isTaskAborting(task: Task): boolean {
return task.status === 'aborting';
}
export function canAbortTask(task: Task): boolean {
if (task.status === 'pending') return true;
if (task.status === 'in_progress' && task.properties.is_abortable === true)
return true;
if (task.status === 'aborting') return true; // Idempotent
return false;
}
export enum TaskStatus {
Pending = 'pending',
InProgress = 'in_progress',
Success = 'success',
Failure = 'failure',
Aborting = 'aborting',
Aborted = 'aborted',
TimedOut = 'timed_out',
}
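
The status rules encoded by `isTaskFinished` and `canAbortTask` can be restated compactly in Python. This is only an illustrative mirror of the TypeScript helpers above. The function names are hypothetical, and nothing here is taken from the backend implementation.

```python
from __future__ import annotations

# Terminal states: once a task reaches one of these, it no longer changes
FINISHED_STATUSES = {"success", "failure", "aborted", "timed_out"}


def is_task_finished(status: str) -> bool:
    """True when the task is in a terminal state."""
    return status in FINISHED_STATUSES


def can_abort_task(status: str, is_abortable: bool | None) -> bool:
    """Mirror of the frontend rule for offering an abort action."""
    if status == "pending":
        # Not started yet, so it can always be aborted
        return True
    if status == "in_progress":
        # A running task must have explicitly declared itself abortable
        return is_abortable is True
    # Re-requesting an abort that is already under way is idempotent
    return status == "aborting"


for status in ("pending", "in_progress", "aborting", "success"):
    print(status, can_abort_task(status, None), is_task_finished(status))
```

A running task with `is_abortable` unset (`None`) is treated the same as one that is explicitly not abortable.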

View File

@@ -0,0 +1,328 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { MemoryRouter } from 'react-router-dom';
import fetchMock from 'fetch-mock';
import {
render,
screen,
waitFor,
fireEvent,
} from 'spec/helpers/testing-library';
import { QueryParamProvider } from 'use-query-params';
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
import { TaskStatus, TaskScope } from 'src/features/tasks/types';
import TaskList from 'src/pages/TaskList';
// Enable the GLOBAL_TASK_FRAMEWORK feature flag on window.featureFlags for these tests
window.featureFlags = { GLOBAL_TASK_FRAMEWORK: true };
// Mock getBootstrapData before importing components that use it
jest.mock('src/utils/getBootstrapData', () => ({
__esModule: true,
default: () => ({
user: {
userId: 1,
firstName: 'admin',
lastName: 'user',
roles: { Admin: [] },
},
common: {
feature_flags: { GLOBAL_TASK_FRAMEWORK: true },
conf: {},
},
}),
}));
const tasksInfoEndpoint = 'glob:*/api/v1/task/_info*';
const tasksCreatedByEndpoint = 'glob:*/api/v1/task/related/created_by*';
const tasksEndpoint = 'glob:*/api/v1/task/?*';
const taskCancelEndpoint = 'glob:*/api/v1/task/*/cancel';
const mockTasks = [
{
id: 1,
uuid: 'task-uuid-1',
task_key: 'test_task_1',
task_type: 'data_export',
task_name: 'Export Data Task',
status: TaskStatus.Success,
scope: TaskScope.Private,
created_on: '2024-01-15T10:00:00Z',
changed_on: '2024-01-15T10:05:00Z',
created_on_delta_humanized: '5 minutes ago',
started_at: '2024-01-15T10:00:01Z',
ended_at: '2024-01-15T10:05:00Z',
created_by: { id: 1, first_name: 'admin', last_name: 'user' },
user_id: 1,
payload: {},
duration_seconds: 299,
subscriber_count: 0,
subscribers: [],
properties: {
is_abortable: null,
progress_percent: 1.0,
progress_current: null,
progress_total: null,
error_message: null,
exception_type: null,
stack_trace: null,
timeout: null,
},
},
{
id: 2,
uuid: 'task-uuid-2',
task_key: 'test_task_2',
task_type: 'report_generation',
task_name: null,
status: TaskStatus.InProgress,
scope: TaskScope.Private,
created_on: '2024-01-15T11:00:00Z',
changed_on: '2024-01-15T11:00:00Z',
created_on_delta_humanized: '1 minute ago',
started_at: '2024-01-15T11:00:01Z',
ended_at: null,
created_by: { id: 1, first_name: 'admin', last_name: 'user' },
user_id: 1,
payload: { report_id: 42 },
duration_seconds: null,
subscriber_count: 0,
subscribers: [],
properties: {
is_abortable: true,
progress_percent: 0.5,
progress_current: null,
progress_total: null,
error_message: null,
exception_type: null,
stack_trace: null,
timeout: null,
},
},
{
id: 3,
uuid: 'task-uuid-3',
task_key: 'shared_task_1',
task_type: 'bulk_operation',
task_name: 'Shared Bulk Task',
status: TaskStatus.Pending,
scope: TaskScope.Shared,
created_on: '2024-01-15T12:00:00Z',
changed_on: '2024-01-15T12:00:00Z',
created_on_delta_humanized: 'just now',
started_at: null,
ended_at: null,
created_by: { id: 2, first_name: 'other', last_name: 'user' },
user_id: 2,
payload: {},
duration_seconds: null,
subscriber_count: 2,
subscribers: [
{
user_id: 1,
first_name: 'admin',
last_name: 'user',
subscribed_at: '2024-01-15T12:00:00Z',
},
{
user_id: 2,
first_name: 'other',
last_name: 'user',
subscribed_at: '2024-01-15T12:00:01Z',
},
],
properties: {
is_abortable: null,
progress_percent: null,
progress_current: null,
progress_total: null,
error_message: null,
exception_type: null,
stack_trace: null,
timeout: null,
},
},
];
const mockUser = {
userId: 1,
firstName: 'admin',
lastName: 'user',
};
fetchMock.get(
tasksInfoEndpoint,
{ permissions: ['can_read', 'can_write'] },
{ name: tasksInfoEndpoint },
);
fetchMock.get(
tasksCreatedByEndpoint,
{ result: [] },
{ name: tasksCreatedByEndpoint },
);
fetchMock.get(
tasksEndpoint,
{ result: mockTasks, count: 3 },
{ name: tasksEndpoint },
);
fetchMock.post(
taskCancelEndpoint,
{ action: 'aborted', message: 'Task cancelled' },
{ name: taskCancelEndpoint },
);
const renderTaskList = (props = {}, userProp = mockUser) =>
render(
<MemoryRouter>
<QueryParamProvider adapter={ReactRouter5Adapter}>
<TaskList {...props} user={userProp} />
</QueryParamProvider>
</MemoryRouter>,
{ useRedux: true },
);
beforeEach(() => {
fetchMock.clearHistory();
});
test('renders TaskList with title, ListView, and fetches data from endpoints', async () => {
renderTaskList();
// Wait for data to load and verify title
expect(await screen.findByText('Tasks')).toBeInTheDocument();
expect(screen.getByTestId('task-list-view')).toBeInTheDocument();
// Verify API calls were made
expect(fetchMock.callHistory.calls(/task\/_info/).length).toBe(1);
expect(fetchMock.callHistory.calls(/task\/\?q/).length).toBe(1);
});
test('displays task data including types, scope labels, and duration', async () => {
renderTaskList();
// Wait for data to load
await screen.findByText('Export Data Task');
// Task types
expect(screen.getByText('data_export')).toBeInTheDocument();
expect(screen.getByText('report_generation')).toBeInTheDocument();
// Scope labels
expect(screen.getAllByText('Private').length).toBeGreaterThan(0);
expect(screen.getByText('Shared')).toBeInTheDocument();
// Duration (299s = 4m 59s via prettyMs)
expect(screen.getByText('4m 59s')).toBeInTheDocument();
});
test('shows cancel button and modal for cancellable tasks', async () => {
renderTaskList();
// Wait for data to load
await screen.findByText('test_task_2');
// Cancel buttons exist for in-progress and shared tasks
const stopIcons = screen.getAllByRole('img', { name: 'stop' });
expect(stopIcons.length).toBeGreaterThan(0);
// Click a cancel button to show confirmation modal
const cancelButton = stopIcons.find(
icon => icon.closest('[role="button"]') !== null,
);
expect(cancelButton).toBeDefined();
fireEvent.click(cancelButton!);
expect(await screen.findByText('Cancel Task')).toBeInTheDocument();
});
test('does not show cancel button for completed shared tasks', async () => {
const completedSharedTask = {
id: 4,
uuid: 'task-uuid-4',
task_key: 'completed_shared_task',
task_type: 'bulk_operation',
task_name: 'Completed Shared Task',
status: TaskStatus.Success,
scope: TaskScope.Shared,
created_on: '2024-01-15T12:00:00Z',
changed_on: '2024-01-15T12:05:00Z',
created_on_delta_humanized: '5 minutes ago',
started_at: '2024-01-15T12:00:01Z',
ended_at: '2024-01-15T12:05:00Z',
created_by: { id: 2, first_name: 'other', last_name: 'user' },
user_id: 2,
payload: {},
duration_seconds: 299,
subscriber_count: 1,
subscribers: [
{
user_id: 1,
first_name: 'admin',
last_name: 'user',
subscribed_at: '2024-01-15T12:00:00Z',
},
],
properties: {
is_abortable: null,
progress_percent: 1.0,
progress_current: null,
progress_total: null,
error_message: null,
exception_type: null,
stack_trace: null,
timeout: null,
},
};
fetchMock.modifyRoute(tasksEndpoint, {
response: { result: [completedSharedTask], count: 1 },
});
renderTaskList();
await screen.findByText('Completed Shared Task');
// No action buttons with stop icons for completed tasks
const stopIcons = screen.queryAllByRole('img', { name: 'stop' });
const actionButtons = stopIcons.filter(
icon => icon.closest('[role="button"]') !== null,
);
expect(actionButtons).toHaveLength(0);
// Restore mock
fetchMock.modifyRoute(tasksEndpoint, {
response: { result: mockTasks, count: 3 },
});
});
test('displays empty state when no tasks', async () => {
fetchMock.modifyRoute(tasksEndpoint, {
response: { result: [], count: 0 },
});
renderTaskList();
await waitFor(() => {
expect(screen.getByText('No tasks yet')).toBeInTheDocument();
});
// Restore mock
fetchMock.modifyRoute(tasksEndpoint, {
response: { result: mockTasks, count: 3 },
});
});

View File

@@ -0,0 +1,658 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import {
FeatureFlag,
isFeatureEnabled,
SupersetClient,
} from '@superset-ui/core';
import { t, useTheme } from '@apache-superset/core';
import { useMemo, useCallback, useState } from 'react';
import { Tooltip, Label, Modal, Checkbox } from '@superset-ui/core/components';
import {
CreatedInfo,
ListView,
ListViewFilterOperator as FilterOperator,
type ListViewFilters,
FacePile,
} from 'src/components';
import { Icons } from '@superset-ui/core/components/Icons';
import withToasts from 'src/components/MessageToasts/withToasts';
import SubMenu from 'src/features/home/SubMenu';
import { useListViewResource } from 'src/views/CRUD/hooks';
import { createErrorHandler, createFetchRelated } from 'src/views/CRUD/utils';
import TaskStatusIcon from 'src/features/tasks/TaskStatusIcon';
import TaskPayloadPopover from 'src/features/tasks/TaskPayloadPopover';
import TaskStackTracePopover from 'src/features/tasks/TaskStackTracePopover';
import { formatDuration } from 'src/features/tasks/timeUtils';
import {
Task,
TaskStatus,
TaskScope,
canAbortTask,
isTaskAborting,
TaskSubscriber,
} from 'src/features/tasks/types';
import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
import getBootstrapData from 'src/utils/getBootstrapData';
const PAGE_SIZE = 25;
/**
* Typed cell props for react-table columns.
* Replaces `: any` for better type safety in Cell render functions.
*/
interface TaskCellProps {
row: {
original: Task;
};
}
interface TaskListProps {
addDangerToast: (msg: string) => void;
addSuccessToast: (msg: string) => void;
user: {
userId: string | number;
firstName: string;
lastName: string;
};
}
function TaskList({ addDangerToast, addSuccessToast, user }: TaskListProps) {
const theme = useTheme();
// Check if GTF feature flag is enabled
if (!isFeatureEnabled(FeatureFlag.GlobalTaskFramework)) {
return (
<>
<SubMenu name={t('Tasks')} />
<div
style={{
display: 'flex',
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
height: '50vh',
color: theme.colorTextSecondary,
}}
>
<h3>{t('Feature Not Enabled')}</h3>
<p>
{t(
'The Global Task Framework is not enabled. Please contact your administrator to enable the GLOBAL_TASK_FRAMEWORK feature flag.',
)}
</p>
</div>
</>
);
}
const {
state: { loading, resourceCount: tasksCount, resourceCollection: tasks },
fetchData,
refreshData,
} = useListViewResource<Task>('task', t('task'), addDangerToast);
// Get full user with roles to check admin status
const bootstrapData = getBootstrapData();
const fullUser = bootstrapData?.user;
const isAdmin = useMemo(() => isUserAdmin(fullUser), [fullUser]);
// State for cancel confirmation modal
const [cancelModalTask, setCancelModalTask] = useState<Task | null>(null);
const [forceCancel, setForceCancel] = useState(false);
// Determine dialog message based on task context
const getCancelDialogMessage = useCallback((task: Task) => {
const isSharedTask = task.scope === TaskScope.Shared;
const subscriberCount = task.subscriber_count || 0;
const otherSubscribers = subscriberCount - 1;
// If it's going to abort (private, system, or last subscriber)
if (!isSharedTask || subscriberCount <= 1) {
return t('This will cancel the task.');
}
// Shared task with multiple subscribers
return t(
"You'll be removed from this task. It will continue running for %s other subscriber(s).",
otherSubscribers,
);
}, []);
// Get force abort message for admin checkbox
const getForceAbortMessage = useCallback((task: Task) => {
const subscriberCount = task.subscriber_count || 0;
return t(
'This will abort (stop) the task for all %s subscriber(s).',
subscriberCount,
);
}, []);
// Check if current user is subscribed to a task
const isUserSubscribed = useCallback(
(task: Task) =>
task.subscribers?.some(
(sub: TaskSubscriber) => sub.user_id === user.userId,
) ?? false,
[user.userId],
);
// Check if force cancel option should be shown (for admins on shared tasks)
const showForceCancelOption = useCallback(
(task: Task) => {
const isSharedTask = task.scope === TaskScope.Shared;
const subscriberCount = task.subscriber_count || 0;
const userSubscribed = isUserSubscribed(task);
// Show for admins on shared tasks when:
// - Not subscribed (can only abort, so show checkbox pre-checked and disabled), OR
// - Multiple subscribers (can choose between unsubscribe and force abort)
// Don't show when admin is the sole subscriber - cancel will abort anyway
return (
isAdmin && isSharedTask && (subscriberCount > 1 || !userSubscribed)
);
},
[isAdmin, isUserSubscribed],
);
// Check if force cancel checkbox should be disabled (admin not subscribed)
const isForceCancelDisabled = useCallback(
(task: Task) => isAdmin && !isUserSubscribed(task),
[isAdmin, isUserSubscribed],
);
const handleTaskCancel = useCallback(
(task: Task, force: boolean = false) => {
SupersetClient.post({
endpoint: `/api/v1/task/${task.uuid}/cancel`,
jsonPayload: force ? { force: true } : {},
}).then(
({ json }) => {
refreshData();
const { action } = json as { action: string };
if (action === 'aborted') {
addSuccessToast(
t('Task cancelled: %s', task.task_name || task.task_key),
);
} else {
addSuccessToast(
t(
'You have been removed from task: %s',
task.task_name || task.task_key,
),
);
}
},
createErrorHandler(errMsg =>
addDangerToast(
t('There was an issue cancelling the task: %s', errMsg),
),
),
);
},
[addDangerToast, addSuccessToast, refreshData],
);
// Handle opening the cancel modal - set initial forceCancel state
const openCancelModal = useCallback(
(task: Task) => {
// Pre-check force cancel if admin is not subscribed
const shouldPreCheck = isAdmin && !isUserSubscribed(task);
setForceCancel(shouldPreCheck);
setCancelModalTask(task);
},
[isAdmin, isUserSubscribed],
);
// Handle modal confirmation
const handleCancelConfirm = useCallback(() => {
if (cancelModalTask) {
handleTaskCancel(cancelModalTask, forceCancel);
setCancelModalTask(null);
setForceCancel(false);
}
}, [cancelModalTask, forceCancel, handleTaskCancel]);
// Handle modal close
const handleCancelModalClose = useCallback(() => {
setCancelModalTask(null);
setForceCancel(false);
}, []);
const columns = useMemo(
() => [
{
Cell: ({
row: {
original: { task_name, task_key, uuid },
},
}: TaskCellProps) => {
// Display preference: task_name > task_key
const displayText = task_name || task_key;
const truncated =
displayText.length > 30
? `${displayText.slice(0, 30)}...`
: displayText;
// Build tooltip with all identifiers
const tooltipLines = [];
if (task_name) tooltipLines.push(`Name: ${task_name}`);
tooltipLines.push(`Key: ${task_key}`);
tooltipLines.push(`UUID: ${uuid}`);
const tooltipText = tooltipLines.join('\n');
return (
<Tooltip
title={
<span style={{ whiteSpace: 'pre-line' }}>{tooltipText}</span>
}
placement="top"
>
<span>{truncated}</span>
</Tooltip>
);
},
accessor: 'task_name',
Header: t('Task'),
size: 'xl',
id: 'task',
},
{
Cell: ({
row: {
original: { status, properties, duration_seconds },
},
}: TaskCellProps) => (
<TaskStatusIcon
status={status as TaskStatus}
progressPercent={properties?.progress_percent}
progressCurrent={properties?.progress_current}
progressTotal={properties?.progress_total}
durationSeconds={duration_seconds}
errorMessage={properties?.error_message}
exceptionType={properties?.exception_type}
/>
),
accessor: 'status',
Header: t('Status'),
size: 'xs',
id: 'status',
},
{
accessor: 'task_type',
Header: t('Type'),
size: 'md',
id: 'task_type',
},
{
Cell: ({
row: {
original: { scope },
},
}: TaskCellProps) => {
const scopeConfig: Record<
TaskScope,
{ label: string; type: 'default' | 'info' | 'warning' }
> = {
[TaskScope.Private]: { label: t('Private'), type: 'default' },
[TaskScope.Shared]: { label: t('Shared'), type: 'info' },
[TaskScope.System]: { label: t('System'), type: 'warning' },
};
const config = scopeConfig[scope as TaskScope] || {
label: scope,
type: 'default' as const,
};
return <Label type={config.type}>{config.label}</Label>;
},
accessor: 'scope',
Header: t('Scope'),
size: 'sm',
id: 'scope',
},
{
Cell: ({
row: {
original: { subscriber_count, subscribers },
},
}: TaskCellProps) => {
if (!subscribers || subscriber_count === 0) {
return '-';
}
// Convert subscribers to FacePile format
const users = subscribers.map((sub: TaskSubscriber) => ({
id: sub.user_id,
first_name: sub.first_name,
last_name: sub.last_name,
}));
return <FacePile users={users} maxCount={3} />;
},
accessor: 'subscriber_count',
Header: t('Subscribers'),
size: 'md',
id: 'subscribers',
disableSortBy: true,
},
{
Cell: ({
row: {
original: {
created_on_delta_humanized: createdOn,
created_by: createdBy,
},
},
}: TaskCellProps) => (
<CreatedInfo date={createdOn ?? ''} user={createdBy ?? undefined} />
),
Header: t('Created'),
accessor: 'created_on',
size: 'xl',
id: 'created_on',
},
{
// Hidden column for filtering by created_by
accessor: 'created_by',
id: 'created_by',
hidden: true,
},
{
Cell: ({
row: {
original: { duration_seconds },
},
}: TaskCellProps) => formatDuration(duration_seconds) ?? '-',
accessor: 'duration_seconds',
Header: t('Duration'),
size: 'sm',
id: 'duration_seconds',
disableSortBy: true,
},
{
Cell: ({
row: {
original: { payload, properties, status },
},
}: TaskCellProps) => {
const hasPayload = payload && Object.keys(payload).length > 0;
const hasStackTrace = !!properties?.stack_trace;
// Show warning if timeout is set but no abort handler during execution
// Only show for IN_PROGRESS (abort handler registers at runtime, not during PENDING)
const hasTimeoutWithoutHandler =
status === TaskStatus.InProgress &&
properties?.timeout &&
!properties?.is_abortable;
if (!hasPayload && !hasStackTrace && !hasTimeoutWithoutHandler) {
return null;
}
return (
<div style={{ display: 'flex', gap: theme.sizeUnit * 2 }}>
{hasTimeoutWithoutHandler && (
<Tooltip
title={t(
'Timeout configured (%s seconds) but no abort handler defined. ' +
'Task will continue running past the timeout.',
properties.timeout,
)}
placement="top"
>
<span>
<Icons.WarningOutlined
iconSize="l"
iconColor={theme.colorWarningText}
/>
</span>
</Tooltip>
)}
{hasPayload && <TaskPayloadPopover payload={payload} />}
{hasStackTrace && properties.stack_trace && (
<TaskStackTracePopover stackTrace={properties.stack_trace} />
)}
</div>
);
},
accessor: 'payload',
Header: t('Details'),
size: 'xs',
id: 'payload',
disableSortBy: true,
},
{
Cell: ({ row: { original } }: TaskCellProps) => {
// Unified Cancel button logic:
// - Show Cancel for any active task that the user can cancel
// - The backend handles the smart behavior (unsubscribe vs abort)
const isRunning = original.status === TaskStatus.InProgress;
// Task is not cancellable if running without abort handler
// Negation treats false, undefined, and null as "not abortable"
const isRunningButNotCancellable =
isRunning && !original.properties?.is_abortable;
const isSharedTask = original.scope === TaskScope.Shared;
const userIsSubscribed = original.subscribers?.some(
(sub: TaskSubscriber) => sub.user_id === user.userId,
);
// Check if task is in a non-active state (completed or aborting)
const isNonActiveStatus = [
TaskStatus.Success,
TaskStatus.Failure,
TaskStatus.Aborted,
TaskStatus.Aborting,
TaskStatus.TimedOut,
].includes(original.status as TaskStatus);
// Show disabled button for running tasks without abort handler
// (only for non-shared tasks or when user is the only subscriber)
const showDisabledCancel =
isRunningButNotCancellable &&
!isNonActiveStatus &&
(!isSharedTask || (original.subscriber_count || 0) <= 1);
// Show Cancel button when:
// 1. Task can be aborted (pending, or in-progress with handler), OR
// 2. User is subscribed to a shared task (can always unsubscribe)
// But NOT when disabled cancel is shown (mutually exclusive)
const canCancelTask =
!showDisabledCancel &&
((canAbortTask(original) && !isTaskAborting(original)) ||
(isSharedTask && userIsSubscribed && !isNonActiveStatus));
if (!canCancelTask && !showDisabledCancel) {
return null;
}
return (
<div style={{ display: 'flex', gap: theme.sizeUnit * 2 }}>
{showDisabledCancel && (
<Tooltip
id="cancel-disabled-tooltip"
title={t(
'Cancellation not available due to missing abort handler',
)}
placement="bottom"
>
<span
className="action-button"
style={{ cursor: 'not-allowed' }}
>
<Icons.StopOutlined
iconSize="l"
iconColor={theme.colorTextDisabled}
/>
</span>
</Tooltip>
)}
{canCancelTask && (
<Tooltip
id="cancel-action-tooltip"
title={t('Cancel')}
placement="bottom"
>
<span
role="button"
tabIndex={0}
className="action-button"
onClick={() => openCancelModal(original)}
>
<Icons.StopOutlined iconSize="l" />
</span>
</Tooltip>
)}
</div>
);
},
Header: t('Actions'),
id: 'actions',
size: 'sm',
disableSortBy: true,
},
],
[user.userId, theme, openCancelModal],
);
const filters: ListViewFilters = useMemo(
() => [
{
Header: t('Status'),
key: 'status',
id: 'status',
input: 'select',
operator: FilterOperator.Equals,
unfilteredLabel: t('Any'),
selects: [
{ label: t('Pending'), value: TaskStatus.Pending },
{ label: t('In Progress'), value: TaskStatus.InProgress },
{ label: t('Success'), value: TaskStatus.Success },
{ label: t('Failed'), value: TaskStatus.Failure },
{ label: t('Timed Out'), value: TaskStatus.TimedOut },
{ label: t('Aborting'), value: TaskStatus.Aborting },
{ label: t('Aborted'), value: TaskStatus.Aborted },
],
},
{
Header: t('Type'),
key: 'task_type',
id: 'task_type',
input: 'search',
operator: FilterOperator.Contains,
},
{
Header: t('Scope'),
key: 'scope',
id: 'scope',
input: 'select',
operator: FilterOperator.Equals,
unfilteredLabel: t('Any'),
selects: [
{ label: t('Private'), value: TaskScope.Private },
{ label: t('Shared'), value: TaskScope.Shared },
{ label: t('System'), value: TaskScope.System },
],
},
{
Header: t('Created by'),
key: 'created_by',
id: 'created_by',
input: 'select',
operator: FilterOperator.RelationOneMany,
unfilteredLabel: t('All'),
fetchSelects: createFetchRelated(
'task',
'created_by',
createErrorHandler(errMsg =>
addDangerToast(
t(
'An error occurred while fetching created by values: %s',
errMsg,
),
),
),
),
},
],
[addDangerToast],
);
const initialSort = [{ id: 'created_on', desc: true }];
const emptyState = {
title: t('No tasks yet'),
image: 'filter-results.svg',
description: t(
'Tasks will appear here as background operations are executed.',
),
};
return (
<>
<SubMenu name={t('Tasks')} />
<ListView<Task>
className="task-list-view"
columns={columns}
count={tasksCount}
data={tasks}
emptyState={emptyState}
fetchData={fetchData}
filters={filters}
initialSort={initialSort}
loading={loading}
pageSize={PAGE_SIZE}
refreshData={refreshData}
addDangerToast={addDangerToast}
addSuccessToast={addSuccessToast}
/>
{/* Cancel Confirmation Modal */}
<Modal
title={t('Cancel Task')}
show={!!cancelModalTask}
onHide={handleCancelModalClose}
primaryButtonName={t('Yes, Cancel')}
onHandledPrimaryAction={handleCancelConfirm}
>
{cancelModalTask && (
<>
<p>
{forceCancel
? getForceAbortMessage(cancelModalTask)
: getCancelDialogMessage(cancelModalTask)}
</p>
{showForceCancelOption(cancelModalTask) && (
<Checkbox
checked={forceCancel}
onChange={e => setForceCancel(e.target.checked)}
disabled={isForceCancelDisabled(cancelModalTask)}
>
{t('Force abort (stops task for all subscribers)')}
</Checkbox>
)}
</>
)}
</Modal>
</>
);
}
export default withToasts(TaskList);
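
The Actions cell above decides between three outcomes: no button, a disabled stop icon, or an active Cancel button. The sketch below restates that decision in Python as a rough aid for reasoning about the combinations; `cancel_action` and its parameters are hypothetical names introduced here, and the function only models the visible conditions in the cell renderer.

```python
from __future__ import annotations

# Statuses the cell treats as non-active (completed or already aborting)
NON_ACTIVE = {"success", "failure", "aborted", "aborting", "timed_out"}


def cancel_action(
    status: str,
    is_shared: bool,
    is_abortable: bool | None,
    subscriber_count: int,
    user_subscribed: bool,
) -> str | None:
    """Return "disabled", "enabled", or None (no button). Illustrative only."""
    running_without_handler = status == "in_progress" and not is_abortable
    non_active = status in NON_ACTIVE

    # Disabled icon: running, no abort handler, and cancelling would have to abort
    show_disabled = (
        running_without_handler
        and not non_active
        and (not is_shared or subscriber_count <= 1)
    )
    if show_disabled:
        return "disabled"

    # Abort path: pending, or in progress with a registered abort handler
    can_abort = status == "pending" or (
        status == "in_progress" and is_abortable is True
    )
    # Unsubscribe path: a subscriber can always leave an active shared task
    can_unsubscribe = is_shared and user_subscribed and not non_active
    return "enabled" if (can_abort or can_unsubscribe) else None


private_running = cancel_action("in_progress", False, None, 0, False)
shared_running = cancel_action("in_progress", True, None, 2, True)
```

Under these assumptions, a shared task with several subscribers and no abort handler still shows an active button for a subscribed user, because cancelling only unsubscribes that user.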

View File

@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
import {
lazy,
@@ -138,6 +139,10 @@ const RowLevelSecurityList = lazy(
),
);
const TaskList = lazy(
() => import(/* webpackChunkName: "TaskList" */ 'src/pages/TaskList'),
);
const RolesList = lazy(
() => import(/* webpackChunkName: "RolesList" */ 'src/pages/RolesList'),
);
@@ -297,6 +302,10 @@ export const routes: Routes = [
path: '/rowlevelsecurity/list',
Component: RowLevelSecurityList,
},
{
path: '/tasks/list/',
Component: TaskList,
},
{
path: '/sqllab/',
Component: SqlLab,

View File

@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional, Tuple
from __future__ import annotations
from typing import Any
import redis
from flask_caching.backends.rediscache import RedisCache, RedisSentinelCache
@@ -28,15 +30,15 @@ class RedisCacheBackend(RedisCache):
self,
host: str,
port: int,
password: Optional[str] = None,
password: str | None = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: Optional[str] = None,
key_prefix: str | None = None,
ssl: bool = False,
ssl_certfile: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: str | None = None,
ssl_keyfile: str | None = None,
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_certs: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -61,12 +63,61 @@ class RedisCacheBackend(RedisCache):
**kwargs,
)
def set(
self,
name: str,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
) -> bool | None:
"""
Set the value at key ``name``.
:param name: Key name
:param value: Value to set
:param ex: Expire time in seconds
:param px: Expire time in milliseconds
:param nx: If True, set only if key does not exist
:param xx: If True, set only if key already exists
:returns: True if set successfully, None if nx/xx condition not met
"""
return self._cache.set(name, value, ex=ex, px=px, nx=nx, xx=xx)
def delete(self, *names: str) -> int:
"""
Delete one or more keys.
:param names: Key names to delete
:returns: Number of keys deleted
"""
return self._cache.delete(*names)
def publish(self, channel: str, message: str) -> int:
"""
Publish a message to a Redis pub/sub channel.
:param channel: The channel name to publish to
:param message: The message to publish
:returns: Number of subscribers that received the message
"""
return self._cache.publish(channel, message)
def pubsub(self) -> redis.client.PubSub:
"""
Create a pub/sub subscription object.
:returns: PubSub object for subscribing to channels
"""
return self._cache.pubsub()
def xadd(
self,
stream_name: str,
event_data: Dict[str, Any],
event_data: dict[str, Any],
event_id: str = "*",
maxlen: Optional[int] = None,
maxlen: int | None = None,
) -> str:
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
@@ -75,13 +126,13 @@ class RedisCacheBackend(RedisCache):
stream_name: str,
start: str = "-",
end: str = "+",
count: Optional[int] = None,
) -> List[Any]:
count: int | None = None,
) -> list[Any]:
count = count or self.MAX_EVENT_COUNT
return self._cache.xrange(stream_name, start, end, count)
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "RedisCacheBackend":
def from_config(cls, config: dict[str, Any]) -> RedisCacheBackend:
kwargs = {
"host": config.get("CACHE_REDIS_HOST", "localhost"),
"port": config.get("CACHE_REDIS_PORT", 6379),
@@ -108,18 +159,18 @@ class RedisSentinelCacheBackend(RedisSentinelCache):
def __init__( # pylint: disable=too-many-arguments
self,
sentinels: List[Tuple[str, int]],
sentinels: list[tuple[str, int]],
master: str,
password: Optional[str] = None,
sentinel_password: Optional[str] = None,
password: str | None = None,
sentinel_password: str | None = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: str = "",
ssl: bool = False,
ssl_certfile: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: str | None = None,
ssl_keyfile: str | None = None,
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_certs: str | None = None,
**kwargs: Any,
) -> None:
# Sentinel doesn't directly support SSL
@@ -177,12 +228,61 @@ class RedisSentinelCacheBackend(RedisSentinelCache):
**kwargs,
)
def set(
self,
name: str,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
) -> bool | None:
"""
Set the value at key ``name``.
:param name: Key name
:param value: Value to set
:param ex: Expire time in seconds
:param px: Expire time in milliseconds
:param nx: If True, set only if key does not exist
:param xx: If True, set only if key already exists
:returns: True if set successfully, None if nx/xx condition not met
"""
return self._cache.set(name, value, ex=ex, px=px, nx=nx, xx=xx)
def delete(self, *names: str) -> int:
"""
Delete one or more keys.
:param names: Key names to delete
:returns: Number of keys deleted
"""
return self._cache.delete(*names)
def publish(self, channel: str, message: str) -> int:
"""
Publish a message to a Redis pub/sub channel.
:param channel: The channel name to publish to
:param message: The message to publish
:returns: Number of subscribers that received the message
"""
return self._cache.publish(channel, message)
def pubsub(self) -> redis.client.PubSub:
"""
Create a pub/sub subscription object.
:returns: PubSub object for subscribing to channels
"""
return self._cache.pubsub()
def xadd(
self,
stream_name: str,
event_data: Dict[str, Any],
event_data: dict[str, Any],
event_id: str = "*",
maxlen: Optional[int] = None,
maxlen: int | None = None,
) -> str:
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
@@ -191,13 +291,13 @@ class RedisSentinelCacheBackend(RedisSentinelCache):
stream_name: str,
start: str = "-",
end: str = "+",
count: Optional[int] = None,
) -> List[Any]:
count: int | None = None,
) -> list[Any]:
count = count or self.MAX_EVENT_COUNT
return self._cache.xrange(stream_name, start, end, count)
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "RedisSentinelCacheBackend":
def from_config(cls, config: dict[str, Any]) -> RedisSentinelCacheBackend:
kwargs = {
"sentinels": config.get("CACHE_REDIS_SENTINELS", [("127.0.0.1", 26379)]),
"master": config.get("CACHE_REDIS_SENTINEL_MASTER", "mymaster"),
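
The `set` wrappers above document a three-way contract: `True` when the key is written, `None` when an `nx`/`xx` condition is not met, and optional expiry through `ex`. The sketch below models that contract with an in-memory stand-in so it can be run without a Redis server. `InMemorySignalStore` and the injected clock are hypothetical helpers for illustration and say nothing about how Redis implements expiry.

```python
from __future__ import annotations

import time
from typing import Any, Callable


class InMemorySignalStore:
    """Hypothetical stand-in mimicking the documented ``set``/``delete`` results."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        self._data: dict[str, tuple[Any, float | None]] = {}

    def _exists(self, name: str) -> bool:
        entry = self._data.get(name)
        if entry is None:
            return False
        _, expires_at = entry
        if expires_at is not None and self._clock() >= expires_at:
            del self._data[name]  # lazily drop an expired key
            return False
        return True

    def set(
        self,
        name: str,
        value: Any,
        ex: int | None = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool | None:
        exists = self._exists(name)
        if (nx and exists) or (xx and not exists):
            return None  # condition not met, nothing written
        expires_at = self._clock() + ex if ex is not None else None
        self._data[name] = (value, expires_at)
        return True

    def delete(self, *names: str) -> int:
        deleted = 0
        for name in names:
            if self._exists(name):
                del self._data[name]
                deleted += 1
        return deleted


now = [0.0]  # mutable fake clock so expiry can be simulated
store = InMemorySignalStore(clock=lambda: now[0])
first = store.set("lock:demo", "1", nx=True, ex=30)
second = store.set("lock:demo", "1", nx=True, ex=30)
now[0] = 31.0  # advance past the 30-second TTL
third = store.set("lock:demo", "1", nx=True, ex=30)
```

In this model the second call returns `None` because the key is still live, and the third succeeds once the TTL has elapsed. That is the property a lock built on `SET NX EX` relies on.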

View File

@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.commands.distributed_lock.acquire import AcquireDistributedLock
from superset.commands.distributed_lock.release import ReleaseDistributedLock
__all__ = [
"AcquireDistributedLock",
"ReleaseDistributedLock",
]

View File

@@ -0,0 +1,132 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import Any
import redis
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.distributed_lock.base import (
BaseDistributedLockCommand,
get_default_lock_ttl,
get_redis_client,
)
from superset.daos.key_value import KeyValueDAO
from superset.exceptions import AcquireDistributedLockFailedException
from superset.key_value.exceptions import (
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
)
from superset.key_value.types import KeyValueResource
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
class AcquireDistributedLock(BaseDistributedLockCommand):
"""
Acquire a distributed lock with automatic backend selection.
Uses Redis SET NX EX when SIGNAL_CACHE_CONFIG is configured,
otherwise falls back to KeyValue table.
Raises AcquireDistributedLockFailedException if:
- Lock is already held by another process
- Redis connection fails
- The KeyValue (database) write fails
"""
ttl_seconds: int
def __init__(
self,
namespace: str,
params: dict[str, Any] | None = None,
ttl_seconds: int | None = None,
) -> None:
super().__init__(namespace, params)
self.ttl_seconds = ttl_seconds or get_default_lock_ttl()
def run(self) -> None:
if (redis_client := get_redis_client()) is not None:
self._acquire_redis(redis_client)
else:
self._acquire_kv()
def _acquire_redis(self, redis_client: Any) -> None:
"""Acquire lock using Redis SET NX EX (atomic)."""
try:
# SET NX EX: Set if not exists, with expiration
# Returns True if lock acquired, None if already exists
acquired = redis_client.set(
self.redis_lock_key,
"1",
nx=True,
ex=self.ttl_seconds,
)
if not acquired:
logger.debug("Redis lock on %s already taken", self.redis_lock_key)
raise AcquireDistributedLockFailedException("Lock already taken")
logger.debug(
"Acquired Redis lock: %s (TTL=%ds)",
self.redis_lock_key,
self.ttl_seconds,
)
except redis.RedisError as ex:
logger.error("Redis lock error for %s: %s", self.redis_lock_key, ex)
raise AcquireDistributedLockFailedException(
f"Redis lock failed: {ex}"
) from ex
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
SQLAlchemyError,
),
reraise=AcquireDistributedLockFailedException,
),
)
def _acquire_kv(self) -> None:
"""Acquire lock using KeyValue table (database)."""
# Delete expired entries first to prevent stale locks from blocking
KeyValueDAO.delete_expired_entries(self.resource)
# Create entry - unique constraint will raise if lock already exists
KeyValueDAO.create_entry(
resource=KeyValueResource.LOCK,
value={"value": True},
codec=self.codec,
key=self.key,
expires_on=datetime.now(timezone.utc) + timedelta(seconds=self.ttl_seconds),
)
logger.debug(
"Acquired KV lock: namespace=%s key=%s (TTL=%ds)",
self.namespace,
self.key,
self.ttl_seconds,
)
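
The `SET NX EX` behavior that `_acquire_redis` relies on can be illustrated without a Redis server. The sketch below is a minimal in-memory stand-in and is not part of this commit: `FakeRedis` and `try_acquire` are hypothetical names, and the class only mimics the return convention noted in the code above (`True` when the key was set, `None` when it already exists and has not expired).

```python
import time


class FakeRedis:
    """Hypothetical in-memory stand-in for the small Redis subset used here."""

    def __init__(self, clock=time.monotonic):
        self._clock = clock
        self._store = {}  # key -> (value, expiry timestamp or None)

    def set(self, key, value, nx=False, ex=None):
        now = self._clock()
        entry = self._store.get(key)
        # An expired key behaves as if it did not exist
        if entry is not None and entry[1] is not None and entry[1] <= now:
            del self._store[key]
            entry = None
        if nx and entry is not None:
            return None  # key exists: NX refuses to overwrite it
        expires_at = now + ex if ex is not None else None
        self._store[key] = (value, expires_at)
        return True

    def delete(self, key):
        # Returns the number of keys removed, like the real command
        return 1 if self._store.pop(key, None) is not None else 0


def try_acquire(client, lock_key, ttl_seconds):
    """Return True if the lock was acquired, False if it is already held."""
    return bool(client.set(lock_key, "1", nx=True, ex=ttl_seconds))
```

Because the check and the write happen in one call, two callers cannot both observe "no lock" and then both create it, which is the property the Redis path depends on. The TTL means a lock whose holder crashed eventually frees itself.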

View File

@@ -15,27 +15,58 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import uuid
from typing import Any, Union
from typing import Any, TYPE_CHECKING
from flask import current_app as app
from flask import current_app
from superset.commands.base import BaseCommand
from superset.distributed_lock.utils import get_key
from superset.extensions import cache_manager
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
if TYPE_CHECKING:
import redis
logger = logging.getLogger(__name__)
stats_logger = app.config["STATS_LOGGER"]
def get_default_lock_ttl() -> int:
"""Get the default lock TTL from config."""
return int(current_app.config.get("DISTRIBUTED_LOCK_DEFAULT_TTL", 30))
def get_redis_client() -> "redis.Redis[Any] | None":
"""
Get Redis client from signal cache if available.
Returns None if SIGNAL_CACHE_CONFIG is not configured,
allowing fallback to database-backed locking.
"""
backend = cache_manager.signal_cache
return backend._cache if backend else None
class BaseDistributedLockCommand(BaseCommand):
"""Base command for distributed lock operations."""
key: uuid.UUID
namespace: str
codec = JsonKeyValueCodec()
resource = KeyValueResource.LOCK
def __init__(self, namespace: str, params: Union[dict[str, Any], None] = None):
self.key = get_key(namespace, **(params or {}))
def __init__(self, namespace: str, params: dict[str, Any] | None = None) -> None:
self.namespace = namespace
self.params = params or {}
self.key = get_key(namespace, **self.params)
@property
def redis_lock_key(self) -> str:
"""Redis key for this lock."""
return f"lock:{self.namespace}:{self.key}"
def validate(self) -> None:
pass

View File

@@ -1,64 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime, timedelta
from functools import partial
from flask import current_app as app
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
from superset.daos.key_value import KeyValueDAO
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.exceptions import (
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
)
from superset.key_value.types import KeyValueResource
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
stats_logger = app.config["STATS_LOGGER"]
class CreateDistributedLock(BaseDistributedLockCommand):
lock_expiration = timedelta(seconds=30)
def validate(self) -> None:
pass
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
SQLAlchemyError,
),
reraise=CreateKeyValueDistributedLockFailedException,
),
)
def run(self) -> None:
KeyValueDAO.delete_expired_entries(self.resource)
KeyValueDAO.create_entry(
resource=KeyValueResource.LOCK,
value={"value": True},
codec=self.codec,
key=self.key,
expires_on=datetime.now() + self.lock_expiration,
)

View File

@@ -1,49 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from flask import current_app as app
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
from superset.daos.key_value import KeyValueDAO
from superset.exceptions import DeleteKeyValueDistributedLockFailedException
from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
stats_logger = app.config["STATS_LOGGER"]
class DeleteDistributedLock(BaseDistributedLockCommand):
def validate(self) -> None:
pass
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueDeleteFailedError,
SQLAlchemyError,
),
reraise=DeleteKeyValueDistributedLockFailedException,
),
)
def run(self) -> None:
KeyValueDAO.delete_entry(self.resource, self.key)

View File

@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from functools import partial
from typing import Any
import redis
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.distributed_lock.base import (
BaseDistributedLockCommand,
get_redis_client,
)
from superset.daos.key_value import KeyValueDAO
from superset.exceptions import ReleaseDistributedLockFailedException
from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
class ReleaseDistributedLock(BaseDistributedLockCommand):
"""
Release a distributed lock with automatic backend selection.
Uses Redis DEL when SIGNAL_CACHE_CONFIG is configured,
otherwise deletes from KeyValue table.
"""
def run(self) -> None:
if (redis_client := get_redis_client()) is not None:
self._release_redis(redis_client)
else:
self._release_kv()
def _release_redis(self, redis_client: Any) -> None:
"""Release lock using Redis DEL (via the client's delete method)."""
try:
redis_client.delete(self.redis_lock_key)
logger.debug("Released Redis lock: %s", self.redis_lock_key)
except redis.RedisError as ex:
# Log warning but don't raise - TTL will handle cleanup
logger.warning(
"Failed to release Redis lock %s: %s (TTL will handle cleanup)",
self.redis_lock_key,
ex,
)
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueDeleteFailedError,
SQLAlchemyError,
),
reraise=ReleaseDistributedLockFailedException,
),
)
def _release_kv(self) -> None:
"""Release lock using KeyValue table (database)."""
KeyValueDAO.delete_entry(self.resource, self.key)
logger.debug(
"Released KV lock: namespace=%s key=%s",
self.namespace,
self.key,
)

View File

@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.commands.tasks.prune import TaskPruneCommand
from superset.commands.tasks.submit import SubmitTaskCommand
from superset.commands.tasks.update import UpdateTaskCommand
__all__ = [
"CancelTaskCommand",
"SubmitTaskCommand",
"TaskPruneCommand",
"UpdateTaskCommand",
]

View File

@@ -0,0 +1,314 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unified cancel task command for GTF."""
import logging
from functools import partial
from typing import TYPE_CHECKING
from uuid import UUID
from flask import current_app
from superset_core.api.tasks import TaskScope, TaskStatus
from superset.commands.base import BaseCommand
from superset.commands.tasks.exceptions import (
TaskAbortFailedError,
TaskNotAbortableError,
TaskNotFoundError,
TaskPermissionDeniedError,
)
from superset.extensions import security_manager
from superset.stats_logger import BaseStatsLogger
from superset.tasks.locks import task_lock
from superset.tasks.utils import get_active_dedup_key
from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
if TYPE_CHECKING:
from superset.models.tasks import Task
logger = logging.getLogger(__name__)
class CancelTaskCommand(BaseCommand):
"""
Unified command to cancel a task.
Behavior:
- For private tasks or single-subscriber tasks: aborts the task
- For shared tasks with multiple subscribers (non-admin): unsubscribes user
- For shared tasks with force=True (admin only): aborts for all subscribers
The term "cancel" is user-facing; internally this may abort or unsubscribe.
This command acquires a distributed lock before starting a transaction to
prevent race conditions with concurrent submit/cancel operations.
Permission checks are deferred to inside the lock to minimize SELECTs:
we only fetch the task once, then validate permissions on the fetched data.
"""
def __init__(self, task_uuid: UUID, force: bool = False):
"""
Initialize the cancel command.
:param task_uuid: UUID of the task to cancel
:param force: If True, force abort even with multiple subscribers (admin only)
"""
self._task_uuid = task_uuid
self._force = force
self._action_taken: str = (
"cancelled" # Will be set to 'aborted' or 'unsubscribed'
)
self._should_publish_abort: bool = False
def run(self) -> "Task":
"""
Execute the cancel command with distributed locking.
The lock is acquired BEFORE starting the transaction to avoid holding
a DB connection during lock acquisition. Uses dedup_key as lock key
to ensure Submit and Cancel operations use the same lock.
:returns: The updated task model
"""
from superset.daos.tasks import TaskDAO
# Lightweight fetch to compute dedup_key for locking
# This is needed to use the same lock key as SubmitTaskCommand
task = TaskDAO.find_one_or_none(
skip_base_filter=security_manager.is_admin(), uuid=self._task_uuid
)
if not task:
raise TaskNotFoundError()
# Compute dedup_key using the same logic as SubmitTaskCommand
dedup_key = get_active_dedup_key(
scope=task.scope,
task_type=task.task_type,
task_key=task.task_key,
user_id=task.user_id,
)
# Acquire lock BEFORE transaction starts
# Using dedup_key ensures Submit and Cancel use the same lock
with task_lock(dedup_key):
result = self._execute_with_transaction()
# Publish abort notification AFTER transaction commits
# This prevents race conditions where listeners check DB before commit
if self._should_publish_abort:
from superset.tasks.manager import TaskManager
TaskManager.publish_abort(self._task_uuid)
return result
@transaction(on_error=partial(on_error, reraise=TaskAbortFailedError))
def _execute_with_transaction(self) -> "Task":
"""
Execute the cancel operation inside a transaction.
Combines fetch + validation + execution in a single transaction,
reducing the number of SELECTs from 3 to 1 (plus DAO operations).
:returns: The updated task model
"""
from superset.daos.tasks import TaskDAO
# Check admin status (no DB access)
is_admin = security_manager.is_admin()
# Force flag requires admin
if self._force and not is_admin:
raise TaskPermissionDeniedError(
"Only administrators can force cancel a task"
)
# Single SELECT: fetch task and validate permissions on it
task = TaskDAO.find_one_or_none(skip_base_filter=is_admin, uuid=self._task_uuid)
if not task:
raise TaskNotFoundError()
# Validate permissions on the fetched task
self._validate_permissions(task, is_admin)
# Execute cancel and return updated task
return self._do_cancel(task, is_admin)
def _validate_permissions(self, task: "Task", is_admin: bool) -> None:
"""
Validate permissions on an already-fetched task.
Permission rules by scope:
- private: Only creator or admin (already filtered by base_filter)
- shared: Subscribers or admin
- system: Only admin
:param task: The task to validate permissions for
:param is_admin: Whether current user is admin
:raises TaskAbortFailedError: If task is not in cancellable state
:raises TaskPermissionDeniedError: If user lacks permission
"""
# Check if task is in a cancellable state
if task.status not in [
TaskStatus.PENDING.value,
TaskStatus.IN_PROGRESS.value,
TaskStatus.ABORTING.value, # Already aborting is OK (idempotent)
]:
raise TaskAbortFailedError()
# Admin can cancel anything
if is_admin:
return
# Non-admin permission checks by scope
user_id = get_user_id()
if task.scope == TaskScope.SYSTEM.value:
# System tasks are admin-only
raise TaskPermissionDeniedError(
"Only administrators can cancel system tasks"
)
if task.is_shared:
# Shared tasks: must be a subscriber
if not user_id or not task.has_subscriber(user_id):
raise TaskPermissionDeniedError(
"You must be subscribed to cancel this shared task"
)
# Private tasks: already filtered by base_filter (only creator can see)
# If we got here, user has permission
def _do_cancel(self, task: "Task", is_admin: bool) -> "Task":
"""
Execute the cancel operation (abort or unsubscribe).
:param task: The task to cancel
:param is_admin: Whether current user is admin
:returns: The updated task model
"""
user_id = get_user_id()
# Determine action based on task scope and force flag
should_abort = (
# Admin with force flag always aborts
(is_admin and self._force)
# Private tasks always abort (only one user)
or task.is_private
# System tasks always abort (admin only anyway)
or task.is_system
# Single or last subscriber - abort
or task.subscriber_count <= 1
)
if should_abort:
return self._do_abort(task, is_admin)
else:
return self._do_unsubscribe(task, user_id)
def _do_abort(self, task: "Task", is_admin: bool) -> "Task":
"""
Execute abort operation.
:param task: The task to abort
:param is_admin: Whether current user is admin
:returns: The updated task model
"""
from superset.daos.tasks import TaskDAO
try:
result: Task | None = TaskDAO.abort_task(
task.uuid, skip_base_filter=is_admin
)
except TaskNotAbortableError:
raise
if result is None:
# abort_task returned None - task wasn't aborted
# This can happen if task is already finished
raise TaskAbortFailedError()
self._action_taken = "aborted"
# Track if we need to publish abort after commit
if TaskStatus(result.status) == TaskStatus.ABORTING:
self._should_publish_abort = True
# Emit stats metric
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
stats_logger.incr("gtf.task.abort")
logger.info(
"Task aborted: %s (scope: %s, force: %s)",
task.uuid,
task.scope,
self._force,
)
return result
def _do_unsubscribe(self, task: "Task", user_id: int | None) -> "Task":
"""
Execute unsubscribe operation.
:param task: The task to unsubscribe from
:param user_id: ID of user to unsubscribe
:returns: The updated task model
"""
from superset.daos.tasks import TaskDAO
self._action_taken = "unsubscribed"
if not user_id or not task.has_subscriber(user_id):
# User not subscribed - they shouldn't be able to cancel
raise TaskPermissionDeniedError(
"You are not subscribed to this shared task"
)
result = TaskDAO.remove_subscriber(task.id, user_id)
if result is None:
raise TaskPermissionDeniedError(
"You are not subscribed to this shared task"
)
# Emit stats metric
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
stats_logger.incr("gtf.task.unsubscribe")
logger.info(
"User %s unsubscribed from shared task: %s",
user_id,
task.uuid,
)
return result
def validate(self) -> None:
pass
@property
def action_taken(self) -> str:
"""
Get the action that was taken.
:returns: 'aborted' or 'unsubscribed' after a successful run(); the
initial value before that is 'cancelled'
"""
return self._action_taken
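
The abort-versus-unsubscribe decision in `_do_cancel` can be read as a small pure function. The sketch below restates the `should_abort` expression outside the command; `TaskView` and `resolve_cancel_action` are hypothetical names introduced for illustration and do not exist in the commit.

```python
from dataclasses import dataclass


@dataclass
class TaskView:
    """Hypothetical, simplified view of the fields the decision reads."""

    is_private: bool
    is_system: bool
    subscriber_count: int


def resolve_cancel_action(task: TaskView, is_admin: bool, force: bool) -> str:
    """Mirror the should_abort expression: return 'abort' or 'unsubscribe'."""
    should_abort = (
        # Admin with force flag always aborts
        (is_admin and force)
        # Private and system tasks always abort
        or task.is_private
        or task.is_system
        # Single or last subscriber aborts
        or task.subscriber_count <= 1
    )
    return "abort" if should_abort else "unsubscribe"
```

One consequence worth noting: as the expression is written, an admin who does not pass `force` on a shared task with several subscribers also lands in the unsubscribe branch, where `_do_unsubscribe` raises unless that admin is a subscriber.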

View File

@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
CreateFailedError,
ForbiddenError,
UpdateFailedError,
)
class TaskNotFoundError(CommandException):
"""Task not found."""
status = 404
message = _("Task not found.")
class TaskInvalidError(CommandInvalidError):
"""Task parameters are invalid."""
message = _("Task parameters are invalid.")
class TaskCreateFailedError(CreateFailedError):
"""Task creation failed."""
message = _("Task could not be created.")
class TaskUpdateFailedError(UpdateFailedError):
"""Task update failed."""
message = _("Task could not be updated.")
class TaskAbortFailedError(CommandException):
"""Task abortion failed."""
status = 422
message = _("Task could not be aborted.")
class TaskNotAbortableError(CommandException):
"""
Task cannot be aborted.
Raised when attempting to abort an in-progress task that has not
registered an abort handler (is_abortable is not True).
"""
status = 400
message = _(
"Task is not abortable. The task is in progress but has not "
"registered an abort handler."
)
class TaskForbiddenError(ForbiddenError):
"""Task operation forbidden."""
message = _("Changing this task is forbidden")
class TaskPermissionDeniedError(ForbiddenError):
"""Task operation not permitted for current user."""
def __init__(self, message: str | None = None):
super().__init__()
if message:
self.message = message
else:
self.message = _("You do not have permission to perform this operation")
class GlobalTaskFrameworkDisabledError(CommandException):
"""
Raised when a GTF task is called or scheduled but GTF is disabled.
This exception is raised at call/schedule time (not decoration time) to allow
modules with @task-decorated functions to be imported safely when GTF is disabled.
The check is deferred until someone actually tries to execute a task.
"""
message = _(
"The Global Task Framework is not enabled. "
"Set GLOBAL_TASK_FRAMEWORK=True in your feature flags to use @task. "
"See https://superset.apache.org/docs/configuration/async-queries-celery "
"for configuration details."
)

View File

@@ -0,0 +1,184 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task update commands for internal GTF use only.
These commands perform zero-read updates using targeted SQL UPDATE statements.
They're designed for use by TaskContext and executor code where the framework
owns the authoritative state and doesn't need to read before writing.
Unlike UpdateTaskCommand, these commands:
- Do NOT fetch the task entity before updating
- Do NOT check permissions (internal use only)
- Use targeted SQL UPDATE for efficiency
"""
from __future__ import annotations
import logging
from functools import partial
from typing import Any
from uuid import UUID
from superset_core.api.tasks import TaskProperties, TaskStatus
from superset.commands.base import BaseCommand
from superset.commands.tasks.exceptions import TaskUpdateFailedError
from superset.daos.tasks import TaskDAO
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
class InternalUpdateTaskCommand(BaseCommand):
"""
Zero-read task update command for properties/payload.
This command directly writes properties and/or payload to the database
without reading the current values first. The caller (TaskContext)
maintains the authoritative cached state and passes complete merged
values to write.
This is an optimization for task execution where:
1. The executor owns the properties/payload state
2. No permission checks are needed (internal framework code)
3. Status column should not be touched (use InternalStatusTransitionCommand)
WARNING: This command should ONLY be used by TaskContext and similar
internal framework code. External callers should use UpdateTaskCommand.
"""
def __init__(
self,
task_uuid: UUID,
properties: TaskProperties | None = None,
payload: dict[str, Any] | None = None,
):
"""
Initialize internal update command.
:param task_uuid: UUID of the task to update
:param properties: Complete properties dict to write (replaces existing)
:param payload: Complete payload dict to write (replaces existing)
"""
self._task_uuid = task_uuid
self._properties = properties
self._payload = payload
def validate(self) -> None:
"""No validation needed for internal command."""
pass
@transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
def run(self) -> bool:
"""
Execute zero-read update.
:returns: True if task was updated, False if not found or nothing to update
"""
if self._properties is None and self._payload is None:
return False
updated = TaskDAO.set_properties_and_payload(
task_uuid=self._task_uuid,
properties=self._properties,
payload=self._payload,
)
if updated:
logger.debug(
"Internal update for task %s: properties=%s, payload=%s",
self._task_uuid,
self._properties is not None,
self._payload is not None,
)
return updated
class InternalStatusTransitionCommand(BaseCommand):
"""
Atomic conditional status transition command for executor use.
This command provides race-safe status transitions by using atomic
compare-and-swap semantics. The status is only updated if the current
status matches the expected value(s).
Use cases:
- PENDING → IN_PROGRESS: Task pickup (executor starting)
- IN_PROGRESS → SUCCESS: Normal completion (only if not ABORTING)
- IN_PROGRESS → FAILURE: Task exception (only if not ABORTING)
- ABORTING → ABORTED: Abort handlers completed successfully
- ABORTING → TIMED_OUT: Timeout handlers completed successfully
- ABORTING → FAILURE: Abort/cleanup handlers failed
The atomic nature prevents race conditions where:
- Executor tries to set SUCCESS but task was concurrently aborted
- Multiple executors try to pick up the same task
WARNING: This command should ONLY be used by executor code (decorators.py,
scheduler.py). External callers should use UpdateTaskCommand.
"""
def __init__(
self,
task_uuid: UUID,
new_status: TaskStatus | str,
expected_status: TaskStatus | str | list[TaskStatus | str],
properties: TaskProperties | None = None,
set_started_at: bool = False,
set_ended_at: bool = False,
):
"""
Initialize status transition command.
:param task_uuid: UUID of the task to update
:param new_status: Target status to set
:param expected_status: Current status(es) required for update to succeed.
Can be a single status or list of acceptable current statuses.
:param properties: Optional properties to update atomically with status
(e.g., error_message on FAILURE)
:param set_started_at: If True, also set started_at to current timestamp.
Should be True for PENDING → IN_PROGRESS transitions.
:param set_ended_at: If True, also set ended_at to current timestamp.
Should be True for terminal status transitions.
"""
self._task_uuid = task_uuid
self._new_status = new_status
self._expected_status = expected_status
self._properties = properties
self._set_started_at = set_started_at
self._set_ended_at = set_ended_at
def validate(self) -> None:
"""No validation needed for internal command."""
pass
@transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
def run(self) -> bool:
"""
Execute atomic conditional status update.
:returns: True if status was updated (expected matched), False otherwise
"""
return TaskDAO.conditional_status_update(
task_uuid=self._task_uuid,
new_status=self._new_status,
expected_status=self._expected_status,
properties=self._properties,
set_started_at=self._set_started_at,
set_ended_at=self._set_ended_at,
)
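
The compare-and-swap semantics described for `InternalStatusTransitionCommand` usually come down to a single conditional `UPDATE`. The sketch below shows that idea with the standard-library `sqlite3` module. It is not the body of `TaskDAO.conditional_status_update`, which is not shown in this section; the `tasks` table layout, the lowercase status strings, and the `make_demo_db` helper are assumptions made for illustration.

```python
import sqlite3


def make_demo_db():
    """Create an in-memory table with one in-progress task (demo data)."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE tasks (uuid TEXT PRIMARY KEY, status TEXT NOT NULL)")
    conn.execute("INSERT INTO tasks VALUES ('t1', 'in_progress')")
    conn.commit()
    return conn


def conditional_status_update(conn, task_uuid, new_status, expected_statuses):
    """Set status only if the current status is one of expected_statuses."""
    placeholders = ", ".join("?" for _ in expected_statuses)
    cursor = conn.execute(
        f"UPDATE tasks SET status = ? WHERE uuid = ? AND status IN ({placeholders})",
        (new_status, task_uuid, *expected_statuses),
    )
    conn.commit()
    # rowcount is 1 when the expected status matched, 0 otherwise
    return cursor.rowcount == 1
```

Since the status check lives in the `WHERE` clause, the database decides the winner: an executor trying to move `in_progress` to `success` after a concurrent abort updates zero rows and gets `False` back, with no read-then-write window in between.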

View File

@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import time
from datetime import datetime, timedelta
import sqlalchemy as sa
from superset_core.api.tasks import TaskStatus
from superset import db
from superset.commands.base import BaseCommand
logger = logging.getLogger(__name__)
# pylint: disable=consider-using-transaction
class TaskPruneCommand(BaseCommand):
"""
Command to prune the tasks table by deleting rows older than the specified
retention period.
This command deletes records from the `Task` table that are in terminal states
(success, failure, aborted, or timed_out) and whose `ended_at` timestamp is older
than the specified number of days. It helps in maintaining the database by removing
outdated entries and freeing up space.
Attributes:
retention_period_days (int): The number of days for which records should be retained.
Records older than this period will be deleted.
max_rows_per_run (int | None): The maximum number of rows to delete in a single run.
If provided and greater than zero, rows are selected
deterministically from the oldest first (by timestamp then id)
up to this limit in this execution.
""" # noqa: E501
def __init__(self, retention_period_days: int, max_rows_per_run: int | None = None):
"""
:param retention_period_days: Number of days to keep in the tasks table
:param max_rows_per_run: The maximum number of rows to delete in a single run.
If provided and greater than zero, rows are selected deterministically from the
oldest first (by timestamp then id) up to this limit in this execution.
""" # noqa: E501
self.retention_period_days = retention_period_days
self.max_rows_per_run = max_rows_per_run
def run(self) -> None:
"""
Executes the prune command
"""
batch_size = 999 # Older SQLite builds cap bound parameters (e.g. in an IN clause) at 999
total_deleted = 0
start_time = time.time()
# Select all IDs that need to be deleted
# Only delete tasks in a terminal state (success, failure, aborted, or timed out)
from superset.models.tasks import Task
select_stmt = sa.select(Task.id).where(
Task.ended_at < datetime.now() - timedelta(days=self.retention_period_days),
Task.status.in_(
[
TaskStatus.SUCCESS.value,
TaskStatus.FAILURE.value,
TaskStatus.ABORTED.value,
TaskStatus.TIMED_OUT.value,
]
),
)
# Optionally limited by max_rows_per_run
# order by oldest first for deterministic deletion
if self.max_rows_per_run is not None and self.max_rows_per_run > 0:
select_stmt = select_stmt.order_by(
Task.ended_at.asc(), Task.id.asc()
).limit(self.max_rows_per_run)
ids_to_delete = db.session.execute(select_stmt).scalars().all()
total_rows = len(ids_to_delete)
logger.info("Total rows to be deleted: %s", f"{total_rows:,}")
next_logging_threshold = 1
# Iterate over the IDs in batches
for i in range(0, total_rows, batch_size):
batch_ids = ids_to_delete[i : i + batch_size]
# Delete the selected batch using IN clause
result = db.session.execute(sa.delete(Task).where(Task.id.in_(batch_ids)))
# Update the total number of deleted records
total_deleted += result.rowcount
# Commit after each batch so that, if a later batch fails, the rows
# deleted so far stay committed
db.session.commit()
# Log the number of deleted records every 1% increase in progress
percentage_complete = (total_deleted / total_rows) * 100
if percentage_complete >= next_logging_threshold:
logger.info(
"Deleted %s rows from the tasks table older than %s days (%d%% complete)", # noqa: E501
f"{total_deleted:,}",
self.retention_period_days,
percentage_complete,
)
next_logging_threshold += 1
elapsed_time = time.time() - start_time
minutes, seconds = divmod(elapsed_time, 60)
formatted_time = f"{int(minutes):02}:{int(seconds):02}"
logger.info(
"Pruning complete: %s rows deleted in %s",
f"{total_deleted:,}",
formatted_time,
)
def validate(self) -> None:
pass
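
The batching loop in `run()` can be tried in isolation. The following is a minimal sketch, assuming a throwaway in-memory SQLite table named `tasks` and a hypothetical helper `prune_in_batches`; it uses the standard-library `sqlite3` module rather than SQLAlchemy and is not the command's actual implementation.

```python
import sqlite3

BATCH_SIZE = 999  # mirrors the batch size used by the command above


def prune_in_batches(conn: sqlite3.Connection, ids: list[int]) -> int:
    """Delete rows by id in fixed-size batches, committing after each batch."""
    total_deleted = 0
    for start in range(0, len(ids), BATCH_SIZE):
        batch = ids[start : start + BATCH_SIZE]
        placeholders = ",".join("?" for _ in batch)
        cursor = conn.execute(
            f"DELETE FROM tasks WHERE id IN ({placeholders})", batch
        )
        total_deleted += cursor.rowcount
        conn.commit()  # keep progress even if a later batch fails
    return total_deleted


# Hypothetical demo data: 2,500 finished tasks, of which the oldest 2,000 are pruned.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tasks (id INTEGER PRIMARY KEY, status TEXT)")
conn.executemany(
    "INSERT INTO tasks (id, status) VALUES (?, ?)",
    [(i, "success") for i in range(1, 2501)],
)
conn.commit()

ids_to_delete = [
    row[0] for row in conn.execute("SELECT id FROM tasks ORDER BY id LIMIT 2000")
]
deleted = prune_in_batches(conn, ids_to_delete)
remaining = conn.execute("SELECT COUNT(*) FROM tasks").fetchone()[0]
```

Selecting the ids up front and ordering them (here by `id`) makes each run deterministic, which is the same reason the command orders by `ended_at` and then `id` when `max_rows_per_run` is set.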


@@ -0,0 +1,168 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Submit task command for GTF."""
import logging
import uuid
from functools import partial
from typing import Any, TYPE_CHECKING
from flask import current_app
from marshmallow import ValidationError
from superset_core.api.tasks import TaskScope
from superset.commands.base import BaseCommand
from superset.commands.tasks.exceptions import (
TaskCreateFailedError,
TaskInvalidError,
)
from superset.daos.exceptions import DAOCreateFailedError
from superset.stats_logger import BaseStatsLogger
from superset.tasks.locks import task_lock
from superset.tasks.utils import get_active_dedup_key
from superset.utils.decorators import on_error, transaction
if TYPE_CHECKING:
from superset.models.tasks import Task
logger = logging.getLogger(__name__)
class SubmitTaskCommand(BaseCommand):
"""
Command to submit a task (create new or join existing).
This command owns locking and create-vs-join business logic.
It acquires a distributed lock and then decides whether to:
- Create a new task (if no existing task with same dedup_key)
- Join an existing task by adding the user as subscriber
"""
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=TaskCreateFailedError))
def run(self) -> "Task":
"""
Execute the command with distributed locking.
Acquires lock based on dedup_key, then checks for existing task
and either creates new or joins existing (adding subscriber).
:returns: Task model (either newly created or existing)
"""
task, _ = self.run_with_info()
return task
@transaction(on_error=partial(on_error, reraise=TaskCreateFailedError))
def run_with_info(self) -> tuple["Task", bool]:
"""
Execute the command and return (task, is_new) tuple.
This variant allows callers to distinguish between creating a new task
and joining an existing one. Useful for sync execution where the caller
needs to wait for an existing task to complete rather than executing again.
:returns: Tuple of (Task, is_new) where is_new is True if task was created
"""
from superset.daos.tasks import TaskDAO
self.validate()
# Extract and normalize parameters
task_type = self._properties["task_type"]
task_key = self._properties.get("task_key") or str(uuid.uuid4())
scope = self._properties.get("scope", TaskScope.PRIVATE.value)
user_id = self._properties.get("user_id")
# Build dedup_key for lock
dedup_key = get_active_dedup_key(
scope=scope,
task_type=task_type,
task_key=task_key,
user_id=user_id,
)
# Acquire lock to prevent race conditions during create/join
with task_lock(dedup_key):
# Check for existing task (safe under lock)
existing = TaskDAO.find_by_task_key(task_type, task_key, scope, user_id)
# Get stats logger
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
if existing:
# Join existing task - add subscriber if not already subscribed
if user_id and not existing.has_subscriber(user_id):
TaskDAO.add_subscriber(existing.id, user_id)
stats_logger.incr("gtf.task.subscribe")
logger.info(
"User %s joined existing task: %s",
user_id,
task_key,
)
else:
# Already subscribed (or no user_id given) - deduplication hit
stats_logger.incr("gtf.task.dedupe")
logger.debug(
"Deduplication hit for task: %s (user_id=%s)",
task_key,
user_id,
)
return existing, False # is_new=False: joined existing task
# Create new task (DAO is now a pure data operation)
try:
task = TaskDAO.create_task(
task_type=task_type,
task_key=task_key,
scope=scope,
task_name=self._properties.get("task_name"),
user_id=user_id,
payload=self._properties.get("payload", {}),
properties=self._properties.get("properties", {}),
)
stats_logger.incr("gtf.task.create")
return task, True # is_new=True: created new task
except DAOCreateFailedError as ex:
raise TaskCreateFailedError() from ex
def validate(self) -> None:
"""Validate command parameters."""
exceptions: list[ValidationError] = []
# Require task_type
if not self._properties.get("task_type"):
exceptions.append(
ValidationError("task_type is required", field_name="task_type")
)
scope = self._properties.get("scope", TaskScope.PRIVATE.value)
scope_value = scope.value if isinstance(scope, TaskScope) else scope
valid_scopes = [s.value for s in TaskScope]
if scope_value not in valid_scopes:
exceptions.append(
ValidationError(
f"scope must be one of {valid_scopes}",
field_name="scope",
)
)
# Store normalized value for use in run()
self._properties["scope"] = scope_value
if exceptions:
raise TaskInvalidError(exceptions=exceptions)
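
The create-versus-join decision in `run_with_info()` can be illustrated without a database. This is a minimal sketch under stated assumptions: `DemoTask`, `submit`, the in-memory `_active` registry, and the dedup key string are all hypothetical, and a `threading.Lock` stands in for the distributed `task_lock`.

```python
import threading
from dataclasses import dataclass, field


@dataclass
class DemoTask:
    """Hypothetical stand-in for the Task model."""

    dedup_key: str
    subscribers: set[int] = field(default_factory=set)


_lock = threading.Lock()  # stand-in for the distributed task_lock
_active: dict[str, DemoTask] = {}  # stand-in for the tasks table


def submit(dedup_key: str, user_id: int) -> tuple[DemoTask, bool]:
    """Return (task, is_new): create a task or join the active one."""
    with _lock:
        existing = _active.get(dedup_key)
        if existing is not None:
            # Join: adding to a set is idempotent, like the has_subscriber check
            existing.subscribers.add(user_id)
            return existing, False
        task = DemoTask(dedup_key=dedup_key, subscribers={user_id})
        _active[dedup_key] = task
        return task, True


# The key format below is made up for the demo.
first, first_is_new = submit("shared|report|daily", user_id=1)
second, second_is_new = submit("shared|report|daily", user_id=2)
```

Because the lookup and the insert happen under the same lock, two concurrent submissions with the same key cannot both take the create branch.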


@@ -0,0 +1,170 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from datetime import datetime
from functools import partial
from typing import Any, TYPE_CHECKING
from uuid import UUID
from superset_core.api.tasks import TaskProperties
from superset import security_manager
from superset.commands.base import BaseCommand
from superset.commands.tasks.exceptions import (
TaskForbiddenError,
TaskNotFoundError,
TaskUpdateFailedError,
)
from superset.exceptions import SupersetSecurityException
from superset.tasks.locks import task_lock
from superset.tasks.utils import get_active_dedup_key
from superset.utils.decorators import on_error, transaction
if TYPE_CHECKING:
from superset.models.tasks import Task
logger = logging.getLogger(__name__)
class UpdateTaskCommand(BaseCommand):
"""
Command to update a task.
Uses explicit typed parameters to avoid confusion between
payload (task output) and properties (runtime state/config).
This command acquires a distributed lock to prevent race conditions with
concurrent submit/cancel operations on the same logical task.
"""
def __init__(
self,
task_uuid: UUID,
*,
status: str | None = None,
started_at: datetime | None = None,
ended_at: datetime | None = None,
payload: dict[str, Any] | None = None,
properties: TaskProperties | None = None,
skip_security_check: bool = False,
):
"""
Initialize UpdateTaskCommand.
:param task_uuid: UUID of the task to update
:param status: New status value (column field)
:param started_at: Started timestamp (column field)
:param ended_at: Ended timestamp (column field)
:param payload: Task output data to merge (stored in payload column)
:param properties: Runtime state/config updates as dict. Keys must be
valid TaskProperties field names (is_abortable, progress_percent, etc.)
:param skip_security_check: If True, skip ownership validation.
Use this for internal task updates (e.g., task executor updating
its own task's progress). Default is False for API-driven updates.
"""
self._task_uuid = task_uuid
self._status = status
self._started_at = started_at
self._ended_at = ended_at
self._payload = payload
self._properties = properties
self._model: Task | None = None
self._skip_security_check = skip_security_check
@transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
def run(self) -> Task:
"""
Execute the update command with distributed locking.
Acquires lock based on dedup_key to prevent race conditions with
concurrent submit/cancel operations on the same logical task.
:returns: The updated task model
"""
from superset.daos.tasks import TaskDAO
self.validate()
# Fetch task to compute dedup_key for locking
task = TaskDAO.find_one_or_none(
skip_base_filter=self._skip_security_check,
uuid=self._task_uuid,
)
if not task:
raise TaskNotFoundError()
self._model = task
# Build lock key from task attributes (same structure as dedup_key)
dedup_key = get_active_dedup_key(
scope=self._model.scope,
task_type=self._model.task_type,
task_key=self._model.task_key,
user_id=self._model.user_id,
)
# Acquire lock to prevent race with submit/cancel operations
with task_lock(dedup_key):
return self._execute_update()
def _execute_update(self) -> "Task":
"""
Execute the update operation under lock.
:returns: The updated task model
"""
from superset.daos.tasks import TaskDAO
# Re-fetch model under lock to get fresh state
fresh_model = TaskDAO.find_one_or_none(
skip_base_filter=self._skip_security_check,
uuid=self._task_uuid,
)
if not fresh_model:
raise TaskNotFoundError()
self._model = fresh_model
# Verify ownership (user can only update their own tasks)
# Skip this check for internal updates (e.g., task executor updating progress)
if not self._skip_security_check:
try:
security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex:
raise TaskForbiddenError() from ex
# Update status via set_status() for proper timestamp handling
if self._status is not None:
self._model.set_status(self._status)
if self._started_at is not None:
self._model.started_at = self._started_at
if self._ended_at is not None:
self._model.ended_at = self._ended_at
# Update payload (merges with existing)
if self._payload is not None:
self._model.set_payload(self._payload)
# Update properties (dict passed through to model)
if self._properties:
self._model.update_properties(self._properties)
return TaskDAO.update(self._model)
def validate(self) -> None:
pass


@@ -662,6 +662,9 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# sts:AssumeRole permissions to prevent unauthorized access.
# @lifecycle: testing
"AWS_DATABASE_IAM_AUTH": False,
# Global Task Framework - unified task management with progress tracking,
# cancellation, and deduplication.
"GLOBAL_TASK_FRAMEWORK": False,
# Use analogous colors in charts
# @lifecycle: testing
"USE_ANALOGOUS_COLORS": False,
@@ -1393,6 +1396,12 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# "schedule": crontab(minute="*", hour="*"),
# "kwargs": {"retention_period_days": 180, "max_rows_per_run": 10000},
# },
# Uncomment to enable pruning of the tasks table
# "prune_tasks": {
# "task": "prune_tasks",
# "schedule": crontab(minute=0, hour=0),
# "kwargs": {"retention_period_days": 90, "max_rows_per_run": 10000},
# },
# Uncomment to enable Slack channel cache warm-up
# "slack.cache_channels": {
# "task": "slack.cache_channels",
@@ -2456,6 +2465,62 @@ except ImportError:
LOCAL_EXTENSIONS: list[str] = []
EXTENSIONS_PATH: str | None = None
# Default interval (seconds) at which tasks poll for abort requests
TASK_ABORT_POLLING_DEFAULT_INTERVAL = 10
# Minimum interval in seconds between database writes for task progress updates.
# Set to 0 to disable throttling (write every update to DB).
TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL = 2 # seconds
# ---------------------------------------------------
# Signal Cache Configuration
# ---------------------------------------------------
# Shared Redis/Valkey configuration for signaling features that require
# Redis-specific primitives (pub/sub messaging, distributed locks).
#
# Uses Flask-Caching style configuration for consistency with other cache backends.
# Set CACHE_TYPE to 'RedisCache' for standard Redis or 'RedisSentinelCache' for
# Sentinel.
#
# These features cannot use generic cache backends because they rely on:
# - Pub/Sub: Real-time message broadcasting between workers
# - SET NX EX: Atomic lock acquisition with automatic expiration
#
# When configured, enables:
# - Real-time abort/completion notifications for GTF tasks (vs database polling)
# - Redis-based distributed locking (vs KeyValueDAO-backed DistributedLock)
#
# Future: This cache will also be used by Global Async Queries, consolidating
# GLOBAL_ASYNC_QUERIES_CACHE_BACKEND into this unified configuration.
#
# Example with standard Redis:
# SIGNAL_CACHE_CONFIG: CacheConfig = {
# "CACHE_TYPE": "RedisCache",
# "CACHE_REDIS_HOST": "localhost",
# "CACHE_REDIS_PORT": 6379,
# "CACHE_REDIS_DB": 0,
# "CACHE_REDIS_PASSWORD": "",
# }
#
# Example with Redis Sentinel:
# SIGNAL_CACHE_CONFIG: CacheConfig = {
# "CACHE_TYPE": "RedisSentinelCache",
# "CACHE_REDIS_SENTINELS": [("sentinel1", 26379), ("sentinel2", 26379)],
# "CACHE_REDIS_SENTINEL_MASTER": "mymaster",
# "CACHE_REDIS_SENTINEL_PASSWORD": None,
# "CACHE_REDIS_DB": 0,
# "CACHE_REDIS_PASSWORD": "",
# }
SIGNAL_CACHE_CONFIG: CacheConfig | None = None
# Default lock TTL (time-to-live) in seconds for distributed locks.
# Can be overridden per-call via the `ttl_seconds` parameter.
# After TTL expires, the lock is automatically released to prevent deadlocks.
DISTRIBUTED_LOCK_DEFAULT_TTL = 30
# Channel prefix for task abort pub/sub messages
TASKS_ABORT_CHANNEL_PREFIX = "gtf:abort:"
# -------------------------------------------------------------------
# * WARNING: STOP EDITING HERE *
# -------------------------------------------------------------------
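
Put together, the settings added in this file might be overridden in a deployment's `superset_config.py` roughly as follows. This is a sketch only: the Redis host, port, and database are placeholders, and the numeric values simply repeat the defaults shown above.

```python
# superset_config.py (sketch; connection details are placeholders)

# Opt in to the Global Task Framework; the flag defaults to False.
FEATURE_FLAGS = {
    "GLOBAL_TASK_FRAMEWORK": True,
}

# Redis-backed signal cache for pub/sub notifications and distributed locks.
SIGNAL_CACHE_CONFIG = {
    "CACHE_TYPE": "RedisCache",
    "CACHE_REDIS_HOST": "localhost",
    "CACHE_REDIS_PORT": 6379,
    "CACHE_REDIS_DB": 0,
    "CACHE_REDIS_PASSWORD": "",
}

# Optional tuning; these repeat the defaults defined in superset/config.py.
DISTRIBUTED_LOCK_DEFAULT_TTL = 30
TASK_ABORT_POLLING_DEFAULT_INTERVAL = 10
TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL = 2
```

Leaving `SIGNAL_CACHE_CONFIG` as `None` keeps the database-backed behavior described in the comments above (polling for aborts and the KeyValueDAO-backed lock).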


@@ -52,6 +52,7 @@ def inject_dao_implementations() -> None:
SavedQueryDAO as HostSavedQueryDAO,
)
from superset.daos.tag import TagDAO as HostTagDAO
from superset.daos.tasks import TaskDAO as HostTaskDAO
from superset.daos.user import UserDAO as HostUserDAO
# Replace abstract classes with concrete implementations
@@ -64,18 +65,7 @@ def inject_dao_implementations() -> None:
core_dao_module.SavedQueryDAO = HostSavedQueryDAO # type: ignore[assignment,misc]
core_dao_module.TagDAO = HostTagDAO # type: ignore[assignment,misc]
core_dao_module.KeyValueDAO = HostKeyValueDAO # type: ignore[assignment,misc]
core_dao_module.__all__ = [
"DatasetDAO",
"DatabaseDAO",
"ChartDAO",
"DashboardDAO",
"UserDAO",
"QueryDAO",
"SavedQueryDAO",
"TagDAO",
"KeyValueDAO",
]
core_dao_module.TaskDAO = HostTaskDAO # type: ignore[assignment,misc]
def inject_model_implementations() -> None:
@@ -94,6 +84,7 @@ def inject_model_implementations() -> None:
from superset.models.dashboard import Dashboard as HostDashboard
from superset.models.slice import Slice as HostChart
from superset.models.sql_lab import Query as HostQuery, SavedQuery as HostSavedQuery
from superset.models.tasks import Task as HostTask
from superset.tags.models import Tag as HostTag
# In-place replacement - extensions will import concrete implementations
@@ -106,6 +97,7 @@ def inject_model_implementations() -> None:
core_models_module.SavedQuery = HostSavedQuery # type: ignore[misc]
core_models_module.Tag = HostTag # type: ignore[misc]
core_models_module.KeyValue = HostKeyValue # type: ignore[misc]
core_models_module.Task = HostTask # type: ignore[misc]
def inject_query_implementations() -> None:
@@ -124,7 +116,23 @@ def inject_query_implementations() -> None:
)
core_query_module.get_sqlglot_dialect = get_sqlglot_dialect
core_query_module.__all__ = ["get_sqlglot_dialect"]
def inject_task_implementations() -> None:
"""
Replace abstract task functions in superset_core.api.tasks with concrete
implementations from Superset.
"""
import superset_core.api.tasks as core_tasks_module
from superset.tasks.ambient_context import get_context
from superset.tasks.context import TaskContext
from superset.tasks.decorators import task
# Replace abstract classes and functions with concrete implementations
core_tasks_module.TaskContext = TaskContext # type: ignore[assignment,misc]
core_tasks_module.task = task # type: ignore[assignment]
core_tasks_module.get_context = get_context
def inject_rest_api_implementations() -> None:
@@ -147,7 +155,6 @@ def inject_rest_api_implementations() -> None:
core_rest_api_module.add_api = add_api
core_rest_api_module.add_extension_api = add_extension_api
core_rest_api_module.__all__ = ["RestApi", "add_api", "add_extension_api"]
def inject_model_session_implementation() -> None:
@@ -163,7 +170,6 @@ def inject_model_session_implementation() -> None:
return db.session
core_models_module.get_session = get_session
# Update __all__ to include get_session (already done in the module)
def initialize_core_api_dependencies() -> None:
@@ -177,4 +183,5 @@ def initialize_core_api_dependencies() -> None:
inject_model_implementations()
inject_model_session_implementation()
inject_query_implementations()
inject_task_implementations()
inject_rest_api_implementations()


@@ -18,7 +18,7 @@ from __future__ import annotations
import logging
from datetime import datetime
from typing import Dict, List, TYPE_CHECKING
from typing import Dict, List
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -30,9 +30,6 @@ from superset.models.core import FavStar, FavStarClassName
from superset.models.slice import id_or_uuid_filter, Slice
from superset.utils.core import get_user_id
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
# Custom filterable fields for charts

superset/daos/tasks.py

@@ -0,0 +1,470 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task DAO for Global Task Framework (GTF)"""
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.extensions import db
from superset.models.task_subscribers import TaskSubscriber
from superset.models.tasks import Task
from superset.tasks.constants import ABORTABLE_STATES
from superset.tasks.filters import TaskFilter
from superset.tasks.utils import get_active_dedup_key, json
logger = logging.getLogger(__name__)
class TaskDAO(BaseDAO[Task]):
"""
Concrete TaskDAO for the Global Task Framework (GTF).
Provides database access operations for async tasks including
creation, status management, filtering, and subscription management
for shared tasks.
"""
base_filter = TaskFilter
@classmethod
def get_status(cls, task_uuid: UUID) -> str | None:
"""
Get only the status of a task by UUID.
This is a lightweight query that only fetches the status column,
optimized for polling endpoints where full entity loading is unnecessary.
Applies the base filter (TaskFilter) to enforce permission checks.
:param task_uuid: UUID of the task
:returns: Task status string, or None if task not found or not accessible
"""
# Start with query on Task model so base filter can be applied
query = db.session.query(Task)
query = cls._apply_base_filter(query)
query = query.filter(Task.uuid == task_uuid)
# Select only the status column for efficiency
result = query.with_entities(Task.status).one_or_none()
return result[0] if result else None
@classmethod
def find_by_task_key(
cls,
task_type: str,
task_key: str,
scope: TaskScope | str = TaskScope.PRIVATE,
user_id: int | None = None,
) -> Task | None:
"""
Find active task by type, key, scope, and user.
Uses dedup_key internally for efficient querying with a unique index.
Only returns tasks that are active (pending or in progress).
Uniqueness logic by scope:
- private: scope + task_type + task_key + user_id
- shared/system: scope + task_type + task_key (user-agnostic)
:param task_type: Task type to filter by
:param task_key: Task identifier for deduplication
:param scope: Task scope (private/shared/system)
:param user_id: User ID (required for private tasks)
:returns: Task instance or None if not found or not active
"""
dedup_key = get_active_dedup_key(
scope=scope,
task_type=task_type,
task_key=task_key,
user_id=user_id,
)
# Simple single-column query with unique index
return db.session.query(Task).filter(Task.dedup_key == dedup_key).one_or_none()
@classmethod
def create_task(
cls,
task_type: str,
task_key: str,
scope: TaskScope | str = TaskScope.PRIVATE,
user_id: int | None = None,
payload: dict[str, Any] | None = None,
properties: TaskProperties | None = None,
**kwargs: Any,
) -> Task:
"""
Create a new task record in the database.
This is a pure data operation - assumes caller holds lock and has
already checked for existing tasks. Business logic (create vs join)
is handled by SubmitTaskCommand.
:param task_type: Type of task to create
:param task_key: Task identifier (required)
:param scope: Task scope (private/shared/system), defaults to private
:param user_id: User ID creating the task
:param payload: Optional user-defined context data (dict)
:param properties: Optional framework-managed runtime state (e.g., timeout)
:param kwargs: Additional task attributes (e.g., task_name)
:returns: Created Task instance
"""
# Handle both TaskScope enum and string values
scope_value = scope.value if isinstance(scope, TaskScope) else scope
scope_enum = scope if isinstance(scope, TaskScope) else TaskScope(scope)
# Validate user_id is required for private tasks
if scope_enum == TaskScope.PRIVATE and user_id is None:
raise ValueError("user_id is required for private tasks")
# Build dedup_key for active task
dedup_key = get_active_dedup_key(
scope=scope,
task_type=task_type,
task_key=task_key,
user_id=user_id,
)
# Note: properties is handled separately via update_properties()
task_data = {
"task_type": task_type,
"task_key": task_key,
"scope": scope_value,
"status": TaskStatus.PENDING.value,
"dedup_key": dedup_key,
**kwargs,
}
# Handle payload - serialize to JSON if dict provided
if payload:
task_data["payload"] = json.dumps(payload)
if user_id is not None:
task_data["user_id"] = user_id
task = cls.create(attributes=task_data)
# Set properties after creation via update_properties (handles caching)
if properties:
task.update_properties(properties)
# Flush to get the task ID (auto-incremented primary key)
db.session.flush()
# Auto-subscribe creator for all tasks
# This enables consistent subscriber display across all task types
if user_id:
cls.add_subscriber(task.id, user_id)
logger.info(
"Creator %s auto-subscribed to task: %s (scope: %s)",
user_id,
task_key,
scope_value,
)
logger.info(
"Created new async task: %s (type: %s, scope: %s)",
task_key,
task_type,
scope_value,
)
return task
@classmethod
def abort_task(cls, task_uuid: UUID, skip_base_filter: bool = False) -> Task | None:
"""
Abort a task by UUID.
This is a pure data operation. Business logic (subscriber count checks,
permission validation) is handled by CancelTaskCommand which holds the lock.
Abort behavior by status:
- PENDING: Goes directly to ABORTED (always abortable)
- IN_PROGRESS with is_abortable=True: Goes to ABORTING
- IN_PROGRESS with is_abortable=False/None: Raises TaskNotAbortableError
- ABORTING: Returns task (idempotent)
- Finished statuses: Returns None
Note: Caller is responsible for calling TaskManager.publish_abort() AFTER
the transaction commits if task.status == ABORTING. This prevents race
conditions where listeners check the DB before the status is visible.
:param task_uuid: UUID of task to abort
:param skip_base_filter: If True, skip base filter (for admin-initiated aborts)
:returns: Task if aborted/aborting, None if not found or already finished
:raises TaskNotAbortableError: If in-progress task has no abort handler
"""
from superset.commands.tasks.exceptions import TaskNotAbortableError
task = cls.find_one_or_none(skip_base_filter=skip_base_filter, uuid=task_uuid)
if not task:
return None
# Already aborting - idempotent success
if task.status == TaskStatus.ABORTING.value:
logger.info("Task %s is already aborting", task_uuid)
return task
# Already finished - cannot abort
if task.status not in ABORTABLE_STATES:
return None
# PENDING: Go directly to ABORTED
if task.status == TaskStatus.PENDING.value:
task.set_status(TaskStatus.ABORTED)
logger.info("Aborted pending task: %s (scope: %s)", task_uuid, task.scope)
return task
# IN_PROGRESS: Check if abortable
if task.status == TaskStatus.IN_PROGRESS.value:
if task.properties_dict.get("is_abortable") is not True:
raise TaskNotAbortableError(
f"Task {task_uuid} is in progress but has not registered "
"an abort handler (is_abortable is not true)"
)
# Transition to ABORTING (not ABORTED yet)
task.status = TaskStatus.ABORTING.value
db.session.merge(task)
logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope)
# NOTE: publish_abort is NOT called here - caller handles it after commit
# This prevents race conditions where listeners check DB before commit
return task
return None
# Subscription management methods
@classmethod
def add_subscriber(cls, task_id: int, user_id: int) -> bool:
"""
Add a user as a subscriber to a task.
:param task_id: ID of the task
:param user_id: ID of the user to subscribe
:returns: True if subscriber was added, False if already exists
"""
# Check first to avoid IntegrityError which invalidates the session
# in nested transaction contexts (IntegrityError can't be recovered from)
existing = (
db.session.query(TaskSubscriber)
.filter_by(task_id=task_id, user_id=user_id)
.first()
)
if existing:
logger.debug(
"Subscriber %s already subscribed to task %s", user_id, task_id
)
return False
subscription = TaskSubscriber(
task_id=task_id,
user_id=user_id,
subscribed_at=datetime.now(timezone.utc),
)
db.session.add(subscription)
db.session.flush()
logger.info("Added subscriber %s to task %s", user_id, task_id)
return True
@classmethod
def remove_subscriber(cls, task_id: int, user_id: int) -> Task | None:
"""
Remove a user's subscription from a task and return the updated task.
This is a pure data operation. Business logic (whether to abort after
last subscriber leaves) is handled by CancelTaskCommand which holds
the lock and decides whether to call abort_task() separately.
:param task_id: ID of the task
:param user_id: ID of the user to unsubscribe
:returns: Updated Task if subscriber was removed, None if not subscribed
:raises DAODeleteFailedError: If subscription removal fails
"""
subscription = (
db.session.query(TaskSubscriber)
.filter(
TaskSubscriber.task_id == task_id,
TaskSubscriber.user_id == user_id,
)
.one_or_none()
)
if not subscription:
return None
try:
db.session.delete(subscription)
db.session.flush()
logger.info("Removed subscriber %s from task %s", user_id, task_id)
# Return the updated task
task = cls.find_by_id(task_id, skip_base_filter=True)
if task:
db.session.refresh(task) # Ensure subscribers list is fresh
return task
except DAODeleteFailedError:
raise
except Exception as ex:
raise DAODeleteFailedError(
f"Failed to remove subscription for task {task_id}, user {user_id}"
) from ex
@classmethod
def set_properties_and_payload(
cls,
task_uuid: UUID,
properties: TaskProperties | None = None,
payload: dict[str, Any] | None = None,
) -> bool:
"""
Perform a zero-read SQL UPDATE on properties and/or payload columns.
This method directly writes the provided values without reading first.
The caller (TaskContext) is responsible for maintaining the authoritative
cached state and passing complete values to write.
This method is designed for internal task updates (progress, is_abortable)
where the executor owns the state and doesn't need to read before writing.
IMPORTANT: This method only touches properties and payload columns.
It does NOT touch the status column, so it's safe to use concurrently
with operations that modify status (like abort).
:param task_uuid: UUID of the task to update
:param properties: Complete properties dict to write (replaces existing)
:param payload: Complete payload dict to write (replaces existing)
:returns: True if task was updated, False if not found or nothing to update
"""
if properties is None and payload is None:
return False
# Build update values dict - no reads, just write what caller provides
update_values: dict[str, Any] = {}
if properties is not None:
# Write complete properties (caller manages merging in their cache)
update_values["properties"] = json.dumps(properties)
if payload is not None:
# Write complete payload (payload column name matches attribute name)
update_values["payload"] = json.dumps(payload)
if not update_values:
return False
# Execute targeted UPDATE - zero read, just write
rows_updated = (
db.session.query(Task)
.filter(Task.uuid == task_uuid)
.update(update_values, synchronize_session=False)
)
return rows_updated > 0
@classmethod
def conditional_status_update(
cls,
task_uuid: UUID,
new_status: TaskStatus | str,
expected_status: TaskStatus | str | list[TaskStatus | str],
properties: TaskProperties | None = None,
set_started_at: bool = False,
set_ended_at: bool = False,
) -> bool:
"""
Atomically update task status only if current status matches expected.
This provides atomic compare-and-swap semantics for status transitions,
preventing race conditions between executor status updates and concurrent
abort operations. Uses a single UPDATE with WHERE clause for atomicity.
Use cases:
- Executor transitioning IN_PROGRESS → SUCCESS (only if not ABORTING)
- Executor transitioning ABORTING → ABORTED/TIMED_OUT (cleanup complete)
- Initial PENDING → IN_PROGRESS (task pickup)
:param task_uuid: UUID of the task to update
:param new_status: Target status to set
:param expected_status: Current status(es) required for update to succeed.
Can be a single status or list of statuses.
:param properties: Optional properties to update atomically with status
:param set_started_at: If True, also set started_at to current timestamp
:param set_ended_at: If True, also set ended_at to current timestamp
:returns: True if status was updated (expected matched), False otherwise
"""
# Normalize status values
new_status_val = (
new_status.value if isinstance(new_status, TaskStatus) else new_status
)
# Build list of expected status values
if isinstance(expected_status, list):
expected_vals = [
s.value if isinstance(s, TaskStatus) else s for s in expected_status
]
else:
expected_vals = [
expected_status.value
if isinstance(expected_status, TaskStatus)
else expected_status
]
# Build update values
update_values: dict[str, Any] = {"status": new_status_val}
if properties is not None:
update_values["properties"] = json.dumps(properties)
if set_started_at:
update_values["started_at"] = datetime.now(timezone.utc)
if set_ended_at:
update_values["ended_at"] = datetime.now(timezone.utc)
# Atomic compare-and-swap: only update if status matches expected
rows_updated = (
db.session.query(Task)
.filter(Task.uuid == task_uuid, Task.status.in_(expected_vals))
.update(update_values, synchronize_session=False)
)
if rows_updated > 0:
logger.debug(
"Conditional status update succeeded: %s -> %s (expected: %s)",
task_uuid,
new_status_val,
expected_vals,
)
else:
logger.debug(
"Conditional status update skipped: %s -> %s "
"(current status not in expected: %s)",
task_uuid,
new_status_val,
expected_vals,
)
return rows_updated > 0
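
The compare-and-swap behaviour described in the docstring above can be illustrated in isolation. The following is a minimal sketch, not the Superset implementation: it uses the standard-library `sqlite3` module instead of SQLAlchemy, and the table layout, the `conditional_status_update` helper, and the lowercase status strings are assumptions made for the example.

```python
import sqlite3


def conditional_status_update(conn, task_id, new_status, expected):
    """Set status to new_status only if the current status is in expected."""
    placeholders = ", ".join("?" for _ in expected)
    cursor = conn.execute(
        f"UPDATE tasks SET status = ? WHERE id = ? AND status IN ({placeholders})",
        (new_status, task_id, *expected),
    )
    conn.commit()
    # rowcount is 1 when the WHERE clause matched, 0 when the status had
    # already moved on (for example, because of a concurrent abort).
    return cursor.rowcount > 0


# Hypothetical single-column schema, just enough to show the transition logic.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tasks (id INTEGER PRIMARY KEY, status TEXT NOT NULL)")
conn.execute("INSERT INTO tasks (id, status) VALUES (1, 'pending')")

# Task pickup: pending -> in_progress.
picked_up = conditional_status_update(conn, 1, "in_progress", ["pending"])
# An abort request arrives while the task is running.
abort_requested = conditional_status_update(
    conn, 1, "aborting", ["pending", "in_progress"]
)
# The executor's success transition no longer matches and is skipped.
finished = conditional_status_update(conn, 1, "success", ["in_progress"])
final_status = conn.execute("SELECT status FROM tasks WHERE id = 1").fetchone()[0]
```

Because the status check and the write happen in one statement, the executor cannot overwrite an abort that landed between its read and its write; it simply sees `False` and can react accordingly.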


@@ -17,61 +17,46 @@
from __future__ import annotations
import logging
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import timedelta
from typing import Any
from superset.distributed_lock.utils import get_key
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
logger = logging.getLogger(__name__)
CODEC = JsonKeyValueCodec()
LOCK_EXPIRATION = timedelta(seconds=30)
RESOURCE = KeyValueResource.LOCK
@contextmanager
def KeyValueDistributedLock( # pylint: disable=invalid-name # noqa: N802
def DistributedLock( # noqa: N802
namespace: str,
ttl_seconds: int | None = None,
**kwargs: Any,
) -> Iterator[uuid.UUID]:
"""
KV global lock for refreshing tokens.
Distributed lock for coordinating operations across workers.
This context manager acquires a distributed lock for a given namespace, with
optional parameters (eg, namespace="cache", user_id=1). It yields a UUID for the
lock that can be used within the context, and corresponds to the key in the KV
store.
Automatically uses Redis-based locking when SIGNAL_CACHE_CONFIG is
configured, falling back to database-backed locking otherwise.
:param namespace: The namespace for which the lock is to be acquired.
:param kwargs: Additional keyword arguments.
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
Redis locking uses SET NX EX for atomic acquisition with automatic expiration.
Database locking uses the KeyValue table with manual expiration cleanup.
:param namespace: Lock namespace for grouping related locks
:param ttl_seconds: Lock TTL in seconds. Defaults to 30 seconds.
After expiration, the lock is automatically released
to prevent deadlocks from crashed processes.
:param kwargs: Additional key parameters to differentiate locks
:yields: UUID identifying this lock acquisition
:raises AcquireDistributedLockFailedException: If lock is already held
or Redis connection fails
"""
# pylint: disable=import-outside-toplevel
from superset.commands.distributed_lock.create import CreateDistributedLock
from superset.commands.distributed_lock.delete import DeleteDistributedLock
from superset.commands.distributed_lock.get import GetDistributedLock
from superset.commands.distributed_lock.acquire import AcquireDistributedLock
from superset.commands.distributed_lock.release import ReleaseDistributedLock
key = get_key(namespace, **kwargs)
value = GetDistributedLock(namespace=namespace, params=kwargs).run()
if value:
logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
raise CreateKeyValueDistributedLockFailedException("Lock already taken")
logger.debug("Acquiring lock on namespace %s for key %s", namespace, key)
AcquireDistributedLock(namespace, kwargs, ttl_seconds).run()
try:
CreateDistributedLock(namespace=namespace, params=kwargs).run()
except CreateKeyValueDistributedLockFailedException as ex:
logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
raise CreateKeyValueDistributedLockFailedException("Lock already taken") from ex
yield key
DeleteDistributedLock(namespace=namespace, params=kwargs).run()
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
yield key
finally:
ReleaseDistributedLock(namespace, kwargs).run()
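
The acquire/yield/release shape of the context manager above, together with the SET NX EX semantics mentioned in its docstring, can be sketched without a Redis server. This is an illustrative stand-in under stated assumptions: `FakeStore`, `LockHeldError`, and `distributed_lock` are hypothetical names, the store is an in-memory dict, and the random token used for safe release is a choice made for the sketch rather than something the diff shows.

```python
import time
import uuid
from contextlib import contextmanager


class LockHeldError(Exception):
    """Hypothetical stand-in for AcquireDistributedLockFailedException."""


class FakeStore:
    """In-memory stand-in for Redis; keys map to (value, expiry deadline)."""

    def __init__(self):
        self._data = {}

    def set_nx_ex(self, key, value, ttl_seconds):
        """Mimic SET key value NX EX ttl: write only if absent or expired."""
        now = time.monotonic()
        current = self._data.get(key)
        if current is not None and current[1] > now:
            return False  # key exists and has not expired
        self._data[key] = (value, now + ttl_seconds)
        return True

    def delete_if_value(self, key, value):
        """Release only if we still own the key (it may have expired)."""
        current = self._data.get(key)
        if current is not None and current[0] == value:
            del self._data[key]


@contextmanager
def distributed_lock(store, key, ttl_seconds=30):
    token = str(uuid.uuid4())
    if not store.set_nx_ex(key, token, ttl_seconds):
        raise LockHeldError(key)
    try:
        yield token
    finally:
        # Runs even when the protected block raises.
        store.delete_if_value(key, token)


store = FakeStore()
events = []
with distributed_lock(store, "lock:refresh"):
    events.append("acquired")
    try:
        with distributed_lock(store, "lock:refresh"):
            events.append("nested acquired")
    except LockHeldError:
        events.append("rejected")
with distributed_lock(store, "lock:refresh"):
    events.append("reacquired")
```

The TTL is what keeps a crashed holder from blocking everyone else indefinitely: once the deadline passes, the next `set_nx_ex` call succeeds even though nobody released the key.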


@@ -414,15 +414,15 @@ class SupersetDisallowedSQLTableException(SupersetErrorException):
)
class CreateKeyValueDistributedLockFailedException(Exception): # noqa: N818
class AcquireDistributedLockFailedException(Exception): # noqa: N818
"""
Exception to signal failure to acquire lock.
"""
class DeleteKeyValueDistributedLockFailedException(Exception): # noqa: N818
class ReleaseDistributedLockFailedException(Exception): # noqa: N818
"""
Exception to signalize failure to delete lock.
Exception to signal failure to release lock.
"""


@@ -23,13 +23,10 @@ import os
import threading
import time
from pathlib import Path
from typing import Any, TYPE_CHECKING
from typing import Any
from flask import Flask
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
# Guard to prevent multiple initializations


@@ -218,6 +218,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
)
from superset.views.sqllab import SqllabView
from superset.views.tags import TagModelView, TagView
from superset.views.tasks import TaskModelView
from superset.views.themes import ThemeModelView
from superset.views.user_info import UserInfoView
from superset.views.user_registrations import UserRegistrationsView
@@ -275,6 +276,11 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
appbuilder.add_api(ExtensionsRestApi)
if feature_flag_manager.is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
from superset.tasks.api import TaskRestApi
appbuilder.add_api(TaskRestApi)
#
# Setup regular views
#
@@ -408,6 +414,18 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
),
)
appbuilder.add_view(
TaskModelView,
"Tasks",
label=_("Tasks"),
icon="fa-clock-o",
category="Manage",
category_label=_("Manage"),
menu_cond=lambda: feature_flag_manager.is_feature_enabled(
"GLOBAL_TASK_FRAMEWORK"
),
)
#
# Setup views with no menu
#
@@ -588,6 +606,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
self.configure_async_queries()
self.configure_ssh_manager()
self.configure_stats_manager()
self.configure_task_manager()
# Hook that provides administrators a handle on the Flask APP
# after initialization
@@ -928,6 +947,13 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
if feature_flag_manager.is_feature_enabled("GLOBAL_ASYNC_QUERIES"):
async_query_manager_factory.init_app(self.superset_app)
def configure_task_manager(self) -> None:
"""Initialize the TaskManager for GTF real-time notifications."""
if feature_flag_manager.is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
from superset.tasks.manager import TaskManager
TaskManager.init_app(self.superset_app)
def register_blueprints(self) -> None:
# Register custom blueprints from config
for bp in self.config["BLUEPRINTS"]:


@@ -0,0 +1,221 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Create tasks and task_subscriber tables for Global Task Framework (GTF)
Revision ID: 4b2a8c9d3e1f
Revises: 9787190b3d89
Create Date: 2025-12-18 02:20:00.000000
"""
from sqlalchemy import (
Column,
DateTime,
Integer,
String,
Text,
UniqueConstraint,
)
from sqlalchemy_utils import UUIDType
from superset.migrations.shared.utils import (
create_fks_for_table,
create_index,
create_table,
drop_fks_for_table,
drop_index,
drop_table,
)
# revision identifiers, used by Alembic.
revision = "4b2a8c9d3e1f"
down_revision = "9787190b3d89"
TASKS_TABLE = "tasks"
TASK_SUBSCRIBERS_TABLE = "task_subscribers"
def upgrade():
"""
Create tasks and task_subscribers tables for the Global Task Framework (GTF).
This migration creates:
1. tasks table - unified tracking for all long-running tasks
2. task_subscribers table - multi-user task subscriptions for shared tasks
The scope feature allows tasks to be:
- private: user-specific (default)
- shared: multi-user collaborative tasks
- system: admin-only background tasks
"""
# Create tasks table
create_table(
TASKS_TABLE,
Column("id", Integer, primary_key=True),
Column("uuid", UUIDType(binary=True), nullable=False, unique=True),
Column("task_key", String(256), nullable=False),
Column("task_type", String(100), nullable=False),
Column("task_name", String(256), nullable=True),
Column("scope", String(20), nullable=False, server_default="private"),
Column("status", String(50), nullable=False),
Column("dedup_key", String(64), nullable=False),
# AuditMixinNullable columns
Column("created_on", DateTime, nullable=True),
Column("changed_on", DateTime, nullable=True),
Column("created_by_fk", Integer, nullable=True),
Column("changed_by_fk", Integer, nullable=True),
# Task-specific columns
Column("started_at", DateTime, nullable=True),
Column("ended_at", DateTime, nullable=True),
Column("user_id", Integer, nullable=True),
Column("payload", Text, nullable=True),
Column("properties", Text, nullable=True),
)
# Create indexes for optimal query performance
create_index(TASKS_TABLE, "idx_tasks_dedup_key", ["dedup_key"], unique=True)
create_index(TASKS_TABLE, "idx_tasks_status", ["status"])
create_index(TASKS_TABLE, "idx_tasks_scope", ["scope"])
create_index(TASKS_TABLE, "idx_tasks_ended_at", ["ended_at"])
create_index(TASKS_TABLE, "idx_tasks_created_by", ["created_by_fk"])
create_index(TASKS_TABLE, "idx_tasks_created_on", ["created_on"])
create_index(TASKS_TABLE, "idx_tasks_task_key", ["task_key"])
create_index(TASKS_TABLE, "idx_tasks_task_type", ["task_type"])
create_index(TASKS_TABLE, "idx_tasks_uuid", ["uuid"], unique=True)
# Create foreign key constraints for tasks
create_fks_for_table(
foreign_key_name="fk_tasks_created_by_fk_ab_user",
table_name=TASKS_TABLE,
referenced_table="ab_user",
local_cols=["created_by_fk"],
remote_cols=["id"],
ondelete="SET NULL",
)
create_fks_for_table(
foreign_key_name="fk_tasks_changed_by_fk_ab_user",
table_name=TASKS_TABLE,
referenced_table="ab_user",
local_cols=["changed_by_fk"],
remote_cols=["id"],
ondelete="SET NULL",
)
create_fks_for_table(
foreign_key_name="fk_tasks_user_id_ab_user",
table_name=TASKS_TABLE,
referenced_table="ab_user",
local_cols=["user_id"],
remote_cols=["id"],
ondelete="SET NULL",
)
# Create task_subscribers table for multi-user task subscriptions
create_table(
TASK_SUBSCRIBERS_TABLE,
Column("id", Integer, primary_key=True),
Column("task_id", Integer, nullable=False),
Column("user_id", Integer, nullable=False),
Column("subscribed_at", DateTime, nullable=False),
# AuditMixinNullable columns
Column("created_on", DateTime, nullable=True),
Column("created_by_fk", Integer, nullable=True),
Column("changed_on", DateTime, nullable=True),
Column("changed_by_fk", Integer, nullable=True),
# Unique constraint defined as part of table creation (SQLite compatible)
UniqueConstraint("task_id", "user_id", name="uq_task_subscribers_task_user"),
)
# Create indexes for task_subscribers table
create_index(TASK_SUBSCRIBERS_TABLE, "idx_task_subscribers_user_id", ["user_id"])
# Create foreign key constraints for task_subscribers
create_fks_for_table(
foreign_key_name="fk_task_subscribers_task_id_tasks",
table_name=TASK_SUBSCRIBERS_TABLE,
referenced_table=TASKS_TABLE,
local_cols=["task_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
create_fks_for_table(
foreign_key_name="fk_task_subscribers_user_id_ab_user",
table_name=TASK_SUBSCRIBERS_TABLE,
referenced_table="ab_user",
local_cols=["user_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
create_fks_for_table(
foreign_key_name="fk_task_subscribers_created_by_fk_ab_user",
table_name=TASK_SUBSCRIBERS_TABLE,
referenced_table="ab_user",
local_cols=["created_by_fk"],
remote_cols=["id"],
ondelete="SET NULL",
)
create_fks_for_table(
foreign_key_name="fk_task_subscribers_changed_by_fk_ab_user",
table_name=TASK_SUBSCRIBERS_TABLE,
referenced_table="ab_user",
local_cols=["changed_by_fk"],
remote_cols=["id"],
ondelete="SET NULL",
)
def downgrade():
"""
Drop tasks and task_subscribers tables and all related indexes and foreign keys.
"""
drop_fks_for_table(
TASK_SUBSCRIBERS_TABLE,
[
"fk_task_subscribers_task_id_tasks",
"fk_task_subscribers_user_id_ab_user",
"fk_task_subscribers_created_by_fk_ab_user",
"fk_task_subscribers_changed_by_fk_ab_user",
],
)
drop_index(TASK_SUBSCRIBERS_TABLE, "idx_task_subscribers_user_id")
drop_table(TASK_SUBSCRIBERS_TABLE)
drop_fks_for_table(
TASKS_TABLE,
[
"fk_tasks_created_by_fk_ab_user",
"fk_tasks_changed_by_fk_ab_user",
"fk_tasks_user_id_ab_user",
],
)
drop_index(TASKS_TABLE, "idx_tasks_dedup_key")
drop_index(TASKS_TABLE, "idx_tasks_status")
drop_index(TASKS_TABLE, "idx_tasks_scope")
drop_index(TASKS_TABLE, "idx_tasks_ended_at")
drop_index(TASKS_TABLE, "idx_tasks_created_by")
drop_index(TASKS_TABLE, "idx_tasks_created_on")
drop_index(TASKS_TABLE, "idx_tasks_task_key")
drop_index(TASKS_TABLE, "idx_tasks_task_type")
drop_index(TASKS_TABLE, "idx_tasks_uuid")
drop_table(TASKS_TABLE)


@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TaskSubscriber model for tracking multi-user task subscriptions"""
from datetime import datetime, timezone
from flask_appbuilder import Model
from sqlalchemy import Column, DateTime, ForeignKey, Integer, UniqueConstraint
from sqlalchemy.orm import relationship
from superset_core.api.models import TaskSubscriber as CoreTaskSubscriber
from superset.models.helpers import AuditMixinNullable
class TaskSubscriber(CoreTaskSubscriber, AuditMixinNullable, Model):
"""
Model for tracking task subscriptions in shared tasks.
This model enables multi-user collaboration on async tasks. When a user
schedules a shared task with the same parameters as an existing task,
they are automatically subscribed to that task instead of creating a
duplicate.
Subscribers can unsubscribe from shared tasks. When the last subscriber
unsubscribes, the task is automatically aborted.
"""
__tablename__ = "task_subscribers"
id = Column(Integer, primary_key=True)
task_id = Column(
Integer, ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False
)
user_id = Column(
Integer, ForeignKey("ab_user.id", ondelete="CASCADE"), nullable=False
)
# Pass a callable so the timestamp is computed per insert, not once at import time
subscribed_at = Column(
DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)
)
# Relationships
task = relationship("Task", back_populates="subscribers")
user = relationship("User", foreign_keys=[user_id], lazy="joined")
__table_args__ = (
UniqueConstraint("task_id", "user_id", name="uq_task_subscribers_task_user"),
)
def __repr__(self) -> str:
return f"<TaskSubscriber user_id={self.user_id} task_id={self.task_id}>"

superset/models/tasks.py

@@ -0,0 +1,367 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task model for Global Task Framework (GTF)"""
from __future__ import annotations
import uuid as uuid_module
from datetime import datetime, timezone
from typing import Any, cast
from flask_appbuilder import Model
from sqlalchemy import (
Column,
DateTime,
Integer,
String,
Text,
)
from sqlalchemy.orm import relationship
from sqlalchemy_utils import UUIDType
from superset_core.api.models import Task as CoreTask
from superset_core.api.tasks import TaskProperties, TaskStatus
from superset.models.helpers import AuditMixinNullable
from superset.models.task_subscribers import TaskSubscriber
from superset.tasks.utils import (
error_update,
get_finished_dedup_key,
parse_properties,
serialize_properties,
)
from superset.utils import json
class Task(CoreTask, AuditMixinNullable, Model):
"""
Concrete Task model for the Global Task Framework (GTF).
This model represents async tasks in Superset, providing unified tracking
for all background operations including SQL queries, thumbnail generation,
reports, and other async operations.
Non-filterable fields (progress, error info, execution config) are stored
in a `properties` JSON blob for schema flexibility.
"""
__tablename__ = "tasks"
# Primary key and identifiers
id = Column(Integer, primary_key=True)
uuid = Column(
UUIDType(binary=True), nullable=False, unique=True, default=uuid_module.uuid4
)
# Task metadata (filterable)
task_key = Column(String(256), nullable=False, index=True) # For deduplication
task_type = Column(String(100), nullable=False, index=True) # e.g., 'sql_execution'
task_name = Column(String(256), nullable=True) # Human readable name
scope = Column(
String(20), nullable=False, index=True, default="private"
) # private/shared/system
status = Column(
String(50), nullable=False, index=True, default=TaskStatus.PENDING.value
)
dedup_key = Column(
String(64), nullable=False, unique=True, index=True
) # Hashed deduplication key (SHA-256 = 64 chars, UUID = 36 chars)
# Timestamps
started_at = Column(DateTime, nullable=True)
ended_at = Column(DateTime, nullable=True)
# User context for execution
user_id = Column(Integer, nullable=True)
# Task-specific output data (set by task code via ctx.update_task(payload=...))
payload = Column(Text, nullable=True, default="{}")
# Properties JSON blob - contains runtime state and execution config:
# - is_abortable: bool - has abort handler registered
# - progress_percent: float - progress 0.0-1.0
# - progress_current: int - current iteration count
# - progress_total: int - total iterations
# - error_message: str - human-readable error message
# - exception_type: str - exception class name
# - stack_trace: str - full formatted traceback
# - timeout: int - timeout in seconds
properties = Column(Text, nullable=True, default="{}")
# Relationships
# Use lazy="selectin" to avoid N+1 queries when listing tasks with subscribers
subscribers = relationship(
TaskSubscriber,
back_populates="task",
cascade="all, delete-orphan",
lazy="selectin",
)
def __repr__(self) -> str:
return f"<Task {self.task_type}:{self.task_key} [{self.status}]>"
# -------------------------------------------------------------------------
# Properties accessor
# -------------------------------------------------------------------------
@property
def properties_dict(self) -> TaskProperties:
"""
Get typed properties.
Properties contain runtime state and execution config that doesn't
need database filtering. Always use .get() for reads since keys may
be absent.
:returns: TaskProperties dict (sparse - only contains keys that were set)
"""
return parse_properties(self.properties)
def update_properties(self, updates: TaskProperties) -> None:
"""
Update specific properties fields (merge semantics).
Only updates fields present in the updates dict.
:param updates: TaskProperties dict with fields to update
Example:
task.update_properties({"is_abortable": True})
task.update_properties(progress_update((50, 100)))
"""
current = cast(TaskProperties, dict(self.properties_dict))
current.update(updates) # Merge updates
self.properties = serialize_properties(current)
# -------------------------------------------------------------------------
# Payload accessor (for task-specific output data)
# -------------------------------------------------------------------------
@property
def payload_dict(self) -> dict[str, Any]:
"""
Get payload as parsed JSON.
Payload contains task-specific output data set by task code via
ctx.update_task(payload=...).
:returns: Dictionary containing payload data
"""
try:
return json.loads(self.payload or "{}")
except (json.JSONDecodeError, TypeError):
return {}
def set_payload(self, data: dict[str, Any]) -> None:
"""
Update payload with new data.
The payload is merged with existing data, not replaced.
:param data: Dictionary of data to merge into payload
"""
current = self.payload_dict
current.update(data)
self.payload = json.dumps(current)
# -------------------------------------------------------------------------
# Error handling
# -------------------------------------------------------------------------
def set_error_from_exception(self, exception: BaseException) -> None:
"""
Set error fields from an exception.
Captures the error message, exception type, and full stack trace.
Called automatically by the executor when a task raises an exception.
:param exception: The exception that caused the failure
"""
self.update_properties(error_update(exception))
# -------------------------------------------------------------------------
# Status management
# -------------------------------------------------------------------------
def set_status(self, status: TaskStatus | str) -> None:
"""
Update task status and dedup_key.
When a task finishes (success, failure, or abort), the dedup_key is
changed to the task's UUID. This frees up the slot so new tasks with
the same parameters can be created.
:param status: New task status
"""
if isinstance(status, TaskStatus):
status = status.value
self.status = status
# Update timestamps and is_abortable based on status
now = datetime.now(timezone.utc)
if status == TaskStatus.IN_PROGRESS.value and not self.started_at:
self.started_at = now
# Set is_abortable to False when task starts executing
# (will be set to True if/when an abort handler is registered)
if self.properties_dict.get("is_abortable") is None:
self.update_properties({"is_abortable": False})
elif status in [
TaskStatus.SUCCESS.value,
TaskStatus.FAILURE.value,
TaskStatus.ABORTED.value,
TaskStatus.TIMED_OUT.value,
]:
if not self.ended_at:
self.ended_at = now
# Update dedup_key to UUID to free up the slot for new tasks
self.dedup_key = get_finished_dedup_key(self.uuid)
# Note: ABORTING status doesn't set ended_at yet - that happens when
# the task transitions to ABORTED after handlers complete
@property
def is_pending(self) -> bool:
"""Check if task is pending."""
return self.status == TaskStatus.PENDING.value
@property
def is_running(self) -> bool:
"""Check if task is currently running."""
return self.status == TaskStatus.IN_PROGRESS.value
@property
def is_finished(self) -> bool:
"""Check if task has finished (success, failure, aborted, or timed out)."""
return self.status in [
TaskStatus.SUCCESS.value,
TaskStatus.FAILURE.value,
TaskStatus.ABORTED.value,
TaskStatus.TIMED_OUT.value,
]
@property
def is_successful(self) -> bool:
"""Check if task completed successfully."""
return self.status == TaskStatus.SUCCESS.value
@property
def duration_seconds(self) -> float | None:
"""
Get task duration in seconds.
- Finished tasks: Time from started_at to ended_at (None if never started)
- Running/aborting tasks: Time from started_at to now
- Pending tasks: Time from created_on to now (queue time)
Note: started_at/ended_at are stored in UTC, but created_on from
AuditMixinNullable is stored as naive local time. We handle both cases.
"""
if self.is_finished:
# Task has completed - use fixed timestamps, never increment
if self.started_at and self.ended_at:
# Finished task - both timestamps use the same timezone (UTC)
# Just compute the difference directly
return (self.ended_at - self.started_at).total_seconds()
# Never started (e.g., aborted while pending) - no duration
return None
elif self.started_at:
# Running or aborting - started_at is UTC (set by set_status)
# Use UTC now for comparison
now = datetime.now(timezone.utc)
started = (
self.started_at.replace(tzinfo=timezone.utc)
if self.started_at.tzinfo is None
else self.started_at
)
return (now - started).total_seconds()
elif self.created_on:
# Pending - created_on is naive LOCAL time (from AuditMixinNullable)
# Use naive local time for comparison
now = datetime.now() # Local time, no timezone
created = (
self.created_on.replace(tzinfo=None)
if self.created_on.tzinfo is not None
else self.created_on
)
return (now - created).total_seconds()
return None
# Scope-related properties
@property
def is_private(self) -> bool:
"""Check if task is private (user-specific)."""
return self.scope == "private"
@property
def is_shared(self) -> bool:
"""Check if task is shared (multi-user)."""
return self.scope == "shared"
@property
def is_system(self) -> bool:
"""Check if task is system (admin-only)."""
return self.scope == "system"
# Subscriber-related methods
@property
def subscriber_count(self) -> int:
"""Get number of subscribers to this task."""
return len(self.subscribers)
def has_subscriber(self, user_id: int) -> bool:
"""
Check if a user is subscribed to this task.
:param user_id: User ID to check
:returns: True if user is subscribed
"""
return any(sub.user_id == user_id for sub in self.subscribers)
def get_subscriber_ids(self) -> list[int]:
"""
Get list of all subscriber user IDs.
:returns: List of user IDs subscribed to this task
"""
return [sub.user_id for sub in self.subscribers]
def to_dict(self) -> dict[str, Any]:
"""
Convert task to dictionary representation.
Minimal API payload - frontend derives status booleans and abort logic
from status and properties.is_abortable.
:returns: Dictionary representation of the task
"""
return {
"id": self.id,
"uuid": str(self.uuid),
"task_key": self.task_key,
"task_type": self.task_type,
"task_name": self.task_name,
"scope": self.scope,
"status": self.status,
"created_on": self.created_on.isoformat() if self.created_on else None,
"changed_on": self.changed_on.isoformat() if self.changed_on else None,
"started_at": self.started_at.isoformat() if self.started_at else None,
"ended_at": self.ended_at.isoformat() if self.ended_at else None,
"created_by_fk": self.created_by_fk,
"user_id": self.user_id,
"payload": self.payload_dict,
"properties": self.properties_dict,
"subscriber_count": self.subscriber_count,
"subscriber_ids": self.get_subscriber_ids(),
}
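
The dedup-key lifecycle described in `set_status` (an active task holds a hashed key; a finished task swaps it for its own UUID so the slot becomes free) can be sketched as follows. This is a simplified illustration under assumptions: the `active_dedup_key` hashing scheme and the `DedupRegistry` class are hypothetical, a dict stands in for the unique index on `dedup_key`, and the real key composition lives in `superset.tasks.utils`, which is not shown here.

```python
import hashlib
import uuid


def active_dedup_key(scope, task_type, task_key, user_id=None):
    """Hypothetical hashing scheme; yields a 64-character SHA-256 hex digest."""
    raw = f"{scope}|{task_type}|{task_key}|{user_id}"
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


class DedupRegistry:
    """A dict keyed by dedup_key, standing in for a unique database index."""

    def __init__(self):
        self.by_key = {}

    def submit(self, scope, task_type, task_key, user_id=None):
        key = active_dedup_key(scope, task_type, task_key, user_id)
        existing = self.by_key.get(key)
        if existing is not None:
            return existing, False  # join the task that is already active
        task = {"uuid": uuid.uuid4(), "dedup_key": key, "status": "pending"}
        self.by_key[key] = task
        return task, True

    def finish(self, task, status):
        # Swap the hashed key for the task's own UUID (36 characters),
        # which frees the hashed slot for a new task with the same parameters.
        del self.by_key[task["dedup_key"]]
        task["status"] = status
        task["dedup_key"] = str(task["uuid"])
        self.by_key[task["dedup_key"]] = task


registry = DedupRegistry()
first, created_first = registry.submit("shared", "thumbnail", "chart_1")
second, created_second = registry.submit("shared", "thumbnail", "chart_1")
registry.finish(first, "success")
third, created_third = registry.submit("shared", "thumbnail", "chart_1")
```

Keeping the finished row under a unique UUID-based key means history is preserved while the uniqueness constraint only ever blocks duplicates of tasks that are still active.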


@@ -26,7 +26,7 @@ from __future__ import annotations
import dataclasses
import logging
import uuid
from typing import Any, TYPE_CHECKING
from typing import Any
import msgpack
from celery.exceptions import SoftTimeLimitExceeded
@@ -56,9 +56,6 @@ from superset.utils.core import override_user, zlib_compress
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
BYTES_IN_MB = 1024 * 1024


@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Ambient context management for the Global Task Framework (GTF)"""
from contextlib import contextmanager
from contextvars import ContextVar
from collections.abc import Iterator
from superset.tasks.context import TaskContext
# Global context variable for ambient context pattern
# This is thread-safe and async-safe via Python's contextvars
_current_context: ContextVar[TaskContext | None] = ContextVar(
"task_context", default=None
)
def get_context() -> TaskContext:
"""
Get the current task context from contextvars.
This function provides ambient access to the task context without
requiring it to be passed as a parameter. It can only be called
from within a task execution.
:returns: The current TaskContext
:raises RuntimeError: If called outside a task execution context
Example:
>>> @task()
... def my_task(chart_id: int) -> None:
...     ctx = get_context()  # Access ambient context
...
...     # Update progress and payload atomically
...     ctx.update_task(
...         progress=0.5,
...         payload={"chart_id": chart_id},
...     )
"""
ctx = _current_context.get()
if ctx is None:
raise RuntimeError(
"get_context() called outside task execution context. "
"This function can only be called from within a @task "
"decorated function."
)
return ctx
@contextmanager
def use_context(ctx: TaskContext) -> Iterator[None]:
"""
Context manager to set ambient context for task execution.
This is used internally by the framework to establish the ambient context
before executing a task function. The context is automatically cleaned up
after execution, even if the task raises an exception.
:param ctx: TaskContext to set as the current context
:yields: None
Example (internal framework use):
>>> ctx = TaskContext(task_uuid=task.uuid)
>>> with use_context(ctx):
...     # Task function can now call get_context()
...     task_function(*args, **kwargs)
>>> # Context automatically reset after execution
"""
token = _current_context.set(ctx)
try:
yield
finally:
_current_context.reset(token)
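
The ambient-context pattern in this module can be exercised on its own. The sketch below mirrors `get_context` and `use_context` but swaps in a hypothetical `DemoContext` dataclass for `TaskContext`, whose constructor and methods are not shown in this file; everything else relies only on the standard-library `contextvars` and `contextlib` modules.

```python
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field


@dataclass
class DemoContext:
    """Hypothetical stand-in for TaskContext."""

    task_uuid: str
    updates: list = field(default_factory=list)


_current = ContextVar("demo_task_context", default=None)


def get_context():
    ctx = _current.get()
    if ctx is None:
        raise RuntimeError("no active task context")
    return ctx


@contextmanager
def use_context(ctx):
    token = _current.set(ctx)
    try:
        yield
    finally:
        # Restore the previous value even if the task body raised.
        _current.reset(token)


def nested_helper():
    # No ctx parameter needed: the helper reads the ambient context.
    get_context().updates.append("progress=0.5")


ctx = DemoContext(task_uuid="demo")
with use_context(ctx):
    nested_helper()

# Outside the with block the context is gone again.
try:
    get_context()
    outside_error = None
except RuntimeError as ex:
    outside_error = str(ex)
```

Using `reset(token)` rather than setting the variable back to `None` matters when contexts are nested: the token restores whatever value was active before, not a fixed default.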

superset/tasks/api.py

@@ -0,0 +1,471 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task REST API"""
import logging
from uuid import UUID
from flask import Response
from flask_appbuilder.api import expose, protect, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.commands.tasks.exceptions import (
TaskAbortFailedError,
TaskForbiddenError,
TaskNotAbortableError,
TaskNotFoundError,
TaskPermissionDeniedError,
)
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.extensions import event_logger
from superset.models.tasks import Task
from superset.tasks.filters import TaskFilter
from superset.tasks.schemas import (
openapi_spec_methods_override,
TaskCancelRequestSchema,
TaskCancelResponseSchema,
TaskResponseSchema,
TaskStatusResponseSchema,
)
from superset.views.base_api import (
BaseSupersetModelRestApi,
RelatedFieldFilter,
statsd_metrics,
)
from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners
logger = logging.getLogger(__name__)
class TaskRestApi(BaseSupersetModelRestApi):
"""REST API for task management"""
datamodel = SQLAInterface(Task)
resource_name = "task"
allow_browser_login = True
class_permission_name = "Task"
# Map cancel and status to write/read permissions
method_permission_name = {
**MODEL_API_RW_METHOD_PERMISSION_MAP,
"cancel": "write",
"status": "read",
}
# Only allow read operations - no create/update/delete through REST API
# Tasks are created via SubmitTaskCommand, cancelled via /cancel endpoint
include_route_methods = {
RouteMethod.GET,
RouteMethod.GET_LIST,
RouteMethod.INFO,
"cancel",
"status",
"related_subscribers",
"related",
}
list_columns = [
"id",
"uuid",
"task_type",
"task_key",
"task_name",
"scope",
"status",
"created_on",
"created_on_delta_humanized",
"changed_on",
"changed_by.first_name",
"changed_by.last_name",
"started_at",
"ended_at",
"created_by.id",
"created_by.first_name",
"created_by.last_name",
"user_id",
"payload",
"properties",
"duration_seconds",
"subscriber_count",
"subscribers",
]
list_select_columns = list_columns + ["created_by_fk", "changed_by_fk"]
show_columns = list_columns
order_columns = [
"task_type",
"scope",
"status",
"created_on",
"changed_on",
"started_at",
"ended_at",
]
search_columns = [
"task_type",
"task_key",
"task_name",
"scope",
"status",
"created_by",
"created_on",
]
base_order = ("created_on", "desc")
base_filters = [["id", TaskFilter, lambda: []]]
# Related field configuration for filter dropdowns
allowed_rel_fields = {"created_by"}
related_field_filters = {
"created_by": RelatedFieldFilter("first_name", FilterRelatedOwners),
}
base_related_field_filters = {
"created_by": [["id", BaseFilterRelatedUsers, lambda: []]],
}
show_model_schema = TaskResponseSchema()
list_model_schema = TaskResponseSchema()
cancel_request_schema = TaskCancelRequestSchema()
openapi_spec_tag = "Tasks"
openapi_spec_component_schemas = (
TaskResponseSchema,
TaskCancelRequestSchema,
TaskCancelResponseSchema,
TaskStatusResponseSchema,
)
openapi_spec_methods = openapi_spec_methods_override
@expose("/<task_uuid>", methods=("GET",))
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
log_to_statsd=False,
)
def get(self, task_uuid: str) -> Response:
"""Get a task.
---
get:
summary: Get a task
parameters:
- in: path
schema:
type: string
format: uuid
name: task_uuid
description: The UUID of the task
responses:
200:
description: Task detail
content:
application/json:
schema:
type: object
properties:
result:
$ref: '#/components/schemas/TaskResponseSchema'
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
"""
from superset.daos.tasks import TaskDAO
try:
uuid = UUID(task_uuid)
task = TaskDAO.find_one_or_none(uuid=uuid)
if not task:
return self.response_404()
result = self.show_model_schema.dump(task)
return self.response(200, result=result)
except (ValueError, TypeError):
return self.response_404()
@expose("/<task_uuid>/status", methods=("GET",))
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.status",
log_to_statsd=False,
)
def status(self, task_uuid: str) -> Response:
"""Get only the status of a task (lightweight for polling).
---
get:
summary: Get task status
parameters:
- in: path
schema:
type: string
format: uuid
name: task_uuid
description: The UUID of the task
responses:
200:
description: Task status
content:
application/json:
schema:
type: object
properties:
status:
type: string
description: Current status of the task
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
"""
from superset.daos.tasks import TaskDAO
try:
uuid = UUID(task_uuid)
status = TaskDAO.get_status(uuid)
if status is None:
return self.response_404()
return self.response(200, status=status)
except (ValueError, TypeError):
return self.response_404()
@expose("/<task_uuid>/cancel", methods=("POST",))
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.cancel",
log_to_statsd=False,
)
def cancel(self, task_uuid: str) -> Response:
"""Cancel a task.
---
post:
summary: Cancel a task
description: >
Cancel a task. The behavior depends on task scope and subscriber
count:
- **Private tasks**: Aborts the task
- **Shared tasks (single subscriber)**: Aborts the task
- **Shared tasks (multiple subscribers)**: Removes current user's
subscription; the task continues for other subscribers
- **Shared tasks with force=true (admin only)**: Aborts task for
all subscribers
The `action` field in the response indicates what happened:
- `aborted`: Task was terminated
- `unsubscribed`: User was removed from task (task continues)
parameters:
- in: path
schema:
type: string
format: uuid
name: task_uuid
description: The UUID of the task to cancel
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/TaskCancelRequestSchema'
responses:
200:
description: Task cancelled successfully
content:
application/json:
schema:
$ref: '#/components/schemas/TaskCancelResponseSchema'
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
"""
return self._execute_cancel(task_uuid)
def _execute_cancel(self, task_uuid_str: str) -> Response:
"""Execute the cancel operation with error handling."""
try:
task_uuid = UUID(task_uuid_str)
command, updated_task = self._run_cancel_command(task_uuid)
return self._build_cancel_response(command, updated_task)
except TaskNotFoundError:
return self.response_404()
except (TaskForbiddenError, TaskPermissionDeniedError) as ex:
if isinstance(ex, TaskPermissionDeniedError):
logger.warning(
"Permission denied cancelling task %s: %s",
task_uuid_str,
str(ex),
)
return self.response_403()
except TaskNotAbortableError as ex:
logger.warning("Task %s is not cancellable: %s", task_uuid_str, str(ex))
return self.response_422(message=str(ex))
except TaskAbortFailedError as ex:
logger.error(
"Error cancelling task %s: %s", task_uuid_str, str(ex), exc_info=True
)
return self.response_422(message=str(ex))
except (ValueError, TypeError):
return self.response_404()
def _run_cancel_command(self, task_uuid: UUID) -> tuple[CancelTaskCommand, "Task"]:
"""Parse request and run the cancel command."""
from flask import request
force = False
# Use get_json with silent=True to handle missing Content-Type gracefully
json_data = request.get_json(silent=True)
if json_data:
parsed = self.cancel_request_schema.load(json_data)
force = parsed.get("force", False)
command = CancelTaskCommand(task_uuid, force=force)
updated_task = command.run()
return command, updated_task
def _build_cancel_response(
self, command: CancelTaskCommand, updated_task: "Task"
) -> Response:
"""Build the response for a successful cancel operation."""
action = command.action_taken
message = (
"Task cancelled"
if action == "aborted"
else "You have been removed from this task"
)
result = {
"message": message,
"action": action,
"task": self.show_model_schema.dump(updated_task),
}
return self.response(200, **result)
@expose("/related/subscribers", methods=("GET",))
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
".related_subscribers",
log_to_statsd=False,
)
def related_subscribers(self) -> Response:
"""Get users who are subscribers to tasks.
---
get:
summary: Get related subscribers
description: >
Returns a list of users who are subscribed to tasks, for use in filter
dropdowns. Results can be filtered by a search query parameter.
parameters:
- in: query
schema:
type: string
name: q
description: Search query to filter subscribers by name
responses:
200:
description: List of subscribers
content:
application/json:
schema:
type: object
properties:
count:
type: integer
description: Total number of matching subscribers
result:
type: array
items:
type: object
properties:
value:
type: integer
description: User ID
text:
type: string
description: User display name
401:
$ref: '#/components/responses/401'
"""
from flask import request
from superset import db, security_manager
from superset.models.task_subscribers import TaskSubscriber
# Get user model
user_model = security_manager.user_model
# Query distinct users who are task subscribers
query = (
db.session.query(user_model.id, user_model.first_name, user_model.last_name)
.join(TaskSubscriber, user_model.id == TaskSubscriber.user_id)
.distinct()
)
# Apply search filter if provided
if search_query := request.args.get("q", ""):
like_value = f"%{search_query}%"
query = query.filter(
(user_model.first_name + " " + user_model.last_name).ilike(like_value)
| user_model.username.ilike(like_value)
)
# Order by name
query = query.order_by(user_model.first_name, user_model.last_name)
# Limit results
query = query.limit(100)
# Execute and format results
results = query.all()
return self.response(
200,
count=len(results),
result=[
{
"value": user_id,
"text": f"{first_name or ''} {last_name or ''}".strip()
or str(user_id),
}
for user_id, first_name, last_name in results
],
)
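
The cancel semantics described in the `cancel` endpoint docstring above can be summarised as a small decision function. This is a sketch under stated assumptions, not the logic of `CancelTaskCommand`: the function name `expected_cancel_action` and the scope strings `"private"` and `"shared"` are hypothetical, and permission checks (including the admin-only restriction on `force`) are deliberately left out.

```python
def expected_cancel_action(
    scope: str, subscriber_count: int, force: bool = False
) -> str:
    """Return the `action` value the docstring describes for a cancel call.

    The scope strings are assumed for illustration; permission checks
    are out of scope for this sketch.
    """
    if scope == "private":
        # Private tasks are always aborted.
        return "aborted"
    if scope == "shared":
        # A forced cancel, or a cancel by the only subscriber, aborts the task.
        if force or subscriber_count <= 1:
            return "aborted"
        # Otherwise only the caller's subscription is removed.
        return "unsubscribed"
    raise ValueError(f"unknown scope: {scope!r}")
```

A client can compare this expectation with the `action` field of the response to decide whether to tell the user that the task stopped or that it is still running for other subscribers.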

View File

@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constants for the Global Task Framework (GTF)."""
from superset_core.api.tasks import TaskStatus
# Terminal states: Task execution has ended and dedup_key slot is freed
TERMINAL_STATES: frozenset[str] = frozenset(
    {
        TaskStatus.SUCCESS.value,
        TaskStatus.FAILURE.value,
        TaskStatus.ABORTED.value,
        TaskStatus.TIMED_OUT.value,
    }
)

# Active states: Task is still in progress and dedup_key is reserved
ACTIVE_STATES: frozenset[str] = frozenset(
    {
        TaskStatus.PENDING.value,
        TaskStatus.IN_PROGRESS.value,
        TaskStatus.ABORTING.value,
    }
)

# Abortable states: Task can be aborted (pending tasks, or in-progress tasks
# that are abortable)
ABORTABLE_STATES: frozenset[str] = frozenset(
    {
        TaskStatus.PENDING.value,
        TaskStatus.IN_PROGRESS.value,
    }
)

# Abort-related states: Task is being or has been aborted
ABORT_STATES: frozenset[str] = frozenset(
    {
        TaskStatus.ABORTING.value,
        TaskStatus.ABORTED.value,
    }
)
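
The relationships between the four state sets above can be checked with a short, self-contained sketch. `DemoStatus` is a hypothetical stand-in for `TaskStatus`: its member names mirror the seven statuses referenced in this file, but the string values are assumed, and the real enum may differ.

```python
from enum import Enum


class DemoStatus(str, Enum):
    """Hypothetical stand-in for TaskStatus; string values are assumed."""

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    ABORTING = "aborting"
    SUCCESS = "success"
    FAILURE = "failure"
    ABORTED = "aborted"
    TIMED_OUT = "timed_out"


S = DemoStatus
TERMINAL = frozenset(
    {S.SUCCESS.value, S.FAILURE.value, S.ABORTED.value, S.TIMED_OUT.value}
)
ACTIVE = frozenset({S.PENDING.value, S.IN_PROGRESS.value, S.ABORTING.value})
ABORTABLE = frozenset({S.PENDING.value, S.IN_PROGRESS.value})
ABORT = frozenset({S.ABORTING.value, S.ABORTED.value})


def sets_are_consistent() -> bool:
    """Check the invariants implied by the comments in the constants file."""
    all_values = {status.value for status in DemoStatus}
    return (
        TERMINAL | ACTIVE == all_values  # every status is terminal or active
        and not (TERMINAL & ACTIVE)  # and never both
        and ABORTABLE < ACTIVE  # ABORTING is active but no longer abortable
        and ABORT & ACTIVE == {S.ABORTING.value}
        and ABORT & TERMINAL == {S.ABORTED.value}
    )
```

The abort-related set straddles the active/terminal boundary: a task in `ABORTING` still holds its dedup_key slot, while one in `ABORTED` has released it.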

673
superset/tasks/context.py Normal file
View File

@@ -0,0 +1,673 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Concrete TaskContext implementation for GTF"""
import logging
import threading
import time
import traceback
from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar
from flask import current_app
from superset_core.api.tasks import (
TaskContext as CoreTaskContext,
TaskProperties,
TaskStatus,
)
from superset.stats_logger import BaseStatsLogger
from superset.tasks.constants import ABORT_STATES
from superset.tasks.utils import progress_update
if TYPE_CHECKING:
from superset.models.tasks import Task
from superset.tasks.manager import AbortListener
logger = logging.getLogger(__name__)
T = TypeVar("T")
class TaskContext(CoreTaskContext):
"""
Concrete implementation of TaskContext for the Global Task Framework (GTF).
Provides write-only access to task state. Tasks use this context to update
their progress and payload, and check for cancellation. Tasks should not
need to read their own state - they are the source of state, not consumers.
"""
# Type alias for handler failures: (handler_type, exception, stack_trace)
HandlerFailure = tuple[str, Exception, str]
def __init__(self, task: "Task") -> None:
"""
Initialize TaskContext with a pre-fetched task entity.
The task entity must be pre-fetched by the caller (executor) to ensure
caching works correctly and to enforce the pattern of a single initial fetch.
:param task: Pre-fetched Task entity (required)
"""
self._task_uuid = task.uuid
self._cleanup_handlers: list[Callable[[], None]] = []
self._abort_handlers: list[Callable[[], None]] = []
self._abort_listener: "AbortListener | None" = None
self._abort_detected = False
self._abort_handlers_completed = False # Track if all abort handlers finished
self._execution_completed = False # Set by executor after task work completes
# Collected handler failures for unified reporting
self._handler_failures: list[TaskContext.HandlerFailure] = []
# Timeout timer state
self._timeout_timer: threading.Timer | None = None
self._timeout_triggered = False
# Throttling state for update_task()
# These manage the minimum interval between DB writes
self._last_db_write_time: float | None = None
self._has_pending_updates: bool = False
self._deferred_flush_timer: threading.Timer | None = None
self._throttle_lock = threading.Lock()
# Cached task entity - avoids repeated DB fetches.
# Updated only by _refresh_task() when checking external state changes.
self._task: "Task" = task
# In-memory state caches - authoritative during execution
# These are initialized from the task entity and updated locally
# before being written to DB via targeted SQL updates.
# We copy the dicts to avoid mutating the Task's cached instances.
self._properties_cache: TaskProperties = cast(
TaskProperties, {**task.properties_dict}
)
self._payload_cache: dict[str, Any] = {**task.payload_dict}
# Store Flask app reference for background thread database access
# Use _get_current_object() to get actual app, not proxy
try:
self._app = current_app._get_current_object()
# Cache stats logger to avoid repeated config lookups
self._stats_logger: BaseStatsLogger = current_app.config.get(
"STATS_LOGGER", BaseStatsLogger()
)
except RuntimeError:
# Handle case where app context isn't available (e.g., tests)
self._app = None
self._stats_logger = BaseStatsLogger()
def _refresh_task(self) -> "Task":
"""
Force refresh the task entity from the database.
Use this method when you need to check for external state changes,
such as whether the task has been aborted by a concurrent operation.
This method:
- Fetches fresh task entity from database
- Updates the cached _task reference
- Updates properties/payload caches from fresh data
:returns: Fresh task entity from database
:raises ValueError: If task is not found
"""
from superset.daos.tasks import TaskDAO
fresh_task = TaskDAO.find_one_or_none(uuid=self._task_uuid)
if not fresh_task:
raise ValueError(f"Task {self._task_uuid} not found")
self._task = fresh_task
# Update caches from fresh data (copy to avoid mutating Task's cache)
self._properties_cache = cast(TaskProperties, {**fresh_task.properties_dict})
self._payload_cache = {**fresh_task.payload_dict}
return self._task
def update_task(
self,
progress: float | int | tuple[int, int] | None = None,
payload: dict[str, object] | None = None,
) -> None:
"""
Update task progress and/or payload atomically.
All parameters are optional. Payload is merged with existing cached data.
In-memory caches are always updated immediately, but DB writes are
throttled according to TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL to prevent
excessive database load from eager tasks.
Progress can be specified in three ways:
- float (0.0-1.0): Percentage only, e.g., 0.5 means 50%
- int: Count only (total unknown), e.g., 42 means "42 items processed"
- tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100"
The percentage is automatically computed from count/total.
:param progress: Progress value, or None to leave unchanged
:param payload: Payload data to merge (dict), or None to leave unchanged
"""
has_updates = False
# Handle progress updates - always update in-memory cache
if progress is not None:
progress_props = progress_update(progress)
if progress_props:
# Merge progress into cached properties
self._properties_cache.update(progress_props)
has_updates = True
else:
# Invalid progress format - progress_update returns empty dict
logger.warning(
"Invalid progress value for task %s: %s "
"(expected float, int, or tuple[int, int])",
self._task_uuid,
progress,
)
# Handle payload updates - always update in-memory cache
if payload is not None:
# Merge payload into cached payload
self._payload_cache.update(payload)
has_updates = True
if not has_updates:
return
# Get throttle interval from config
throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]
# If throttling is disabled (0), write immediately
if throttle_interval <= 0:
self._write_to_db()
return
# Apply throttling with deferred flush
with self._throttle_lock:
now = time.time()
if self._last_db_write_time is None:
# First update - write immediately
self._write_to_db()
self._last_db_write_time = now
elif now - self._last_db_write_time >= throttle_interval:
# Throttle window has passed - write immediately
self._cancel_deferred_flush_timer()
self._write_to_db()
self._last_db_write_time = now
self._has_pending_updates = False
else:
# Within throttle window - defer the write
self._has_pending_updates = True
self._stats_logger.incr("gtf.task.update_deferred")
# Start deferred flush timer if not already running
if self._deferred_flush_timer is None:
remaining_time = throttle_interval - (
now - self._last_db_write_time
)
self._deferred_flush_timer = threading.Timer(
remaining_time, self._deferred_flush
)
self._deferred_flush_timer.daemon = True
self._deferred_flush_timer.start()
def _write_to_db(self) -> None:
"""
Write current cached state to database.
This method performs the actual DB write using InternalUpdateTaskCommand.
It writes whatever is in the caches at the time of the call.
"""
from superset.commands.tasks.internal_update import InternalUpdateTaskCommand
self._stats_logger.incr("gtf.task.update_write")
InternalUpdateTaskCommand(
task_uuid=self._task_uuid,
properties=self._properties_cache,
payload=self._payload_cache,
).run()
def _deferred_flush(self) -> None:
"""
Timer callback that flushes pending updates at end of throttle window.
This ensures the UI never shows stale progress for longer than the
throttle interval.
"""
with self._throttle_lock:
self._deferred_flush_timer = None
if self._has_pending_updates:
# Need app context for DB operations in timer thread
if self._app:
with self._app.app_context():
self._write_to_db()
else:
self._write_to_db()
self._last_db_write_time = time.time()
self._has_pending_updates = False
def _cancel_deferred_flush_timer(self) -> None:
"""Cancel the deferred flush timer if running."""
if self._deferred_flush_timer is not None:
self._deferred_flush_timer.cancel()
self._deferred_flush_timer = None
def on_cleanup(self, handler: Callable[[], None]) -> Callable[[], None]:
"""
Register a cleanup handler that runs when the task ends.
Cleanup handlers are called when the task completes (success),
fails with an error, or is aborted. Multiple handlers can be
registered and will execute in LIFO order (last registered runs first).
Can be used as a decorator:
@ctx.on_cleanup
def cleanup():
logger.info("Task ended")
Or called directly:
ctx.on_cleanup(lambda: logger.info("Task ended"))
:param handler: Cleanup function to register
:returns: The handler (for decorator compatibility)
"""
self._cleanup_handlers.append(handler)
return handler
def on_abort(self, handler: Callable[[], None]) -> Callable[[], None]:
"""
Register abort handler with automatic background listening.
When the first handler is registered:
1. is_abortable is set to true in the database (marks the task as abortable)
2. A background abort listener starts automatically (pub/sub or polling)
The handler will be called automatically when an abort is detected.
:param handler: Callback function to execute when abort is detected
:returns: The handler (for decorator compatibility)
Example:
@ctx.on_abort
def handle_abort():
logger.info("Task was aborted!")
cleanup_partial_work()
Note:
The handler executes in a background thread when abort is detected.
The task code continues running unless the handler does something
to stop it (e.g., raises an exception, modifies shared state, etc.)
"""
is_first_handler = len(self._abort_handlers) == 0
self._abort_handlers.append(handler)
if is_first_handler:
# Mark task as abortable in database
self._set_abortable()
# Auto-start abort listener when first handler is registered
interval = current_app.config["TASK_ABORT_POLLING_DEFAULT_INTERVAL"]
self._start_abort_listener(interval)
return handler
def _set_abortable(self) -> None:
"""Mark the task as abortable (abort handler has been registered)."""
from superset.commands.tasks.internal_update import InternalUpdateTaskCommand
# Update local cache and write to DB
self._properties_cache["is_abortable"] = True
InternalUpdateTaskCommand(
task_uuid=self._task_uuid,
properties=self._properties_cache,
).run()
def _start_abort_listener(self, interval: float) -> None:
"""
Start background abort listener via TaskManager.
Uses Redis pub/sub if available, otherwise falls back to database polling.
The implementation is encapsulated in TaskManager.
"""
if self._abort_listener is not None:
return # Already listening
from superset.tasks.manager import TaskManager
self._abort_listener = TaskManager.listen_for_abort(
task_uuid=self._task_uuid,
callback=self._on_abort_detected,
poll_interval=interval,
app=self._app,
)
def _on_abort_detected(self) -> None:
"""
Callback invoked by TaskManager when abort is detected.
Triggers all registered abort handlers.
"""
if self._abort_detected:
return # Already handled
# Check if task execution has already completed (late abort race).
# Executor sets _execution_completed after task work finishes.
if self._execution_completed:
logger.info(
"Abort detected for task %s but execution already completed",
self._task_uuid,
)
return
self._abort_detected = True
logger.info("Abort detected for task %s", self._task_uuid)
self._trigger_abort_handlers()
def mark_execution_completed(self) -> None:
"""
Mark that the task's main execution has completed.
Called by the executor after the task function returns (successfully
or with an exception). This prevents late abort callbacks from running
handlers when the task work has already finished. Cleanup handlers
still run after this is set.
"""
self._execution_completed = True
def start_abort_polling(self, interval: float | None = None) -> None:
"""
Start background abort listener.
This method is kept for backwards compatibility. It now delegates
to _start_abort_listener which uses TaskManager.
:param interval: Polling interval in seconds (uses config default if None)
"""
if interval is None:
interval = current_app.config["TASK_ABORT_POLLING_DEFAULT_INTERVAL"]
self._start_abort_listener(interval)
def _trigger_abort_handlers(self) -> None:
"""
Execute all registered abort handlers (called by the abort listener thread,
the timeout timer, or cleanup).
All handlers are attempted even if some fail (best-effort cleanup).
Failures are collected in self._handler_failures for unified reporting.
Note: This method never writes to DB directly. All failures are collected
and written by _run_cleanup() in the executor's finally block, ensuring
abort and cleanup handler failures are combined into a single record.
"""
for handler in reversed(self._abort_handlers):
try:
handler()
except Exception as ex:
stack_trace = traceback.format_exc()
logger.error(
"Abort handler failed for task %s: %s",
self._task_uuid,
str(ex),
exc_info=True,
)
self._handler_failures.append(("abort", ex, stack_trace))
# Check if all abort handlers completed successfully
abort_failures = [f for f in self._handler_failures if f[0] == "abort"]
if not abort_failures:
self._abort_handlers_completed = True
def _write_handler_failures_to_db(self) -> None:
"""
Write collected handler failures to the database.
Combines all failures (abort + cleanup) into a single error record.
If the task already has an error (e.g., task function threw exception),
handler failures are APPENDED to preserve the original error context.
"""
from superset.commands.tasks.update import UpdateTaskCommand
if not self._handler_failures:
return
# Build error message from all handler failures
error_messages = [str(ex) for _, ex, _ in self._handler_failures]
handler_types = {htype for htype, _, _ in self._handler_failures}
if len(self._handler_failures) == 1:
htype, ex, handler_stack_trace = self._handler_failures[0]
handler_error_msg = (
f"{htype.capitalize()} handler failed: {error_messages[0]}"
)
handler_exception_type = type(ex).__name__
else:
# Multiple failures
handler_error_msg = f"Handler(s) failed: {'; '.join(error_messages)}"
if handler_types == {"abort"}:
handler_exception_type = "MultipleAbortHandlerFailures"
elif handler_types == {"cleanup"}:
handler_exception_type = "MultipleCleanupHandlerFailures"
else:
handler_exception_type = "MultipleHandlerFailures"
# Combine stack traces with clear separators
handler_stack_trace = "\n--- Next handler failure ---\n".join(
f"[{htype}:{type(ex).__name__}]\n{trace}"
for htype, ex, trace in self._handler_failures
)
if self._app:
with self._app.app_context():
# Check if task already has an error (preserve original context)
task = self._task
original_error = task.properties_dict.get("error_message")
original_type = task.properties_dict.get("exception_type")
original_trace = task.properties_dict.get("stack_trace")
if original_error:
# Append handler failures to original error
error_msg = f"{original_error} | {handler_error_msg}"
exception_type = (
f"{original_type}+{handler_exception_type}"
if original_type
else handler_exception_type
)
stack_trace = (
f"{original_trace}\n\n"
f"=== Handler failures during cleanup ===\n\n"
f"{handler_stack_trace}"
if original_trace
else handler_stack_trace
)
else:
# No original error, just use handler failures
error_msg = handler_error_msg
exception_type = handler_exception_type
stack_trace = handler_stack_trace
# Update task with combined error info
UpdateTaskCommand(
self._task_uuid,
status=TaskStatus.FAILURE.value,
properties={
"error_message": error_msg,
"exception_type": exception_type,
"stack_trace": stack_trace,
},
skip_security_check=True,
).run()
# Clear failures after writing
self._handler_failures = []
def stop_abort_polling(self) -> None:
"""Stop the background abort listener."""
if self._abort_listener is not None:
self._abort_listener.stop()
self._abort_listener = None
def start_timeout_timer(self, timeout_seconds: int) -> None:
"""
Start a timeout timer that triggers abort when elapsed.
Called by execute_task when task transitions to IN_PROGRESS.
Timer only triggers abort handlers if task is abortable.
:param timeout_seconds: Timeout duration in seconds
"""
if self._timeout_timer is not None:
return # Already started
def on_timeout() -> None:
if self._abort_detected:
return # Already aborting
self._timeout_triggered = True
# Check if task has abort handler (requires app context)
if not self._app:
logger.error(
"Timeout fired for task %s but no app context available",
self._task_uuid,
)
return
with self._app.app_context():
from superset.commands.tasks.update import UpdateTaskCommand
task = self._task
if task.properties_dict.get("is_abortable", False):
logger.info(
"Timeout reached for task %s after %d seconds - "
"transitioning to ABORTING and triggering abort handlers",
self._task_uuid,
timeout_seconds,
)
# Set status to ABORTING (same as user abort)
# The executor will determine TIMED_OUT vs FAILURE based on
# whether handlers complete successfully
UpdateTaskCommand(
self._task_uuid,
status=TaskStatus.ABORTING.value,
properties={"error_message": "Task timed out"},
skip_security_check=True,
).run()
# Trigger abort handlers for cleanup
self._on_abort_detected()
else:
# No abort handler - just log warning
logger.warning(
"Timeout reached for task %s after %d seconds, but no "
"abort handler is registered. Task will continue running.",
self._task_uuid,
timeout_seconds,
)
self._timeout_timer = threading.Timer(timeout_seconds, on_timeout)
# Timer is daemon so it won't prevent process exit. If the worker dies,
# the task is already in an inconsistent state (stuck IN_PROGRESS) that
# requires external recovery (orphan detection). A non-daemon timer with
# long timeouts (hours) would block graceful worker shutdown.
self._timeout_timer.daemon = True
self._timeout_timer.start()
logger.debug(
"Started timeout timer for task %s: %d seconds",
self._task_uuid,
timeout_seconds,
)
def stop_timeout_timer(self) -> None:
"""Cancel the timeout timer if running."""
if self._timeout_timer is not None:
self._timeout_timer.cancel()
self._timeout_timer = None
@property
def timeout_triggered(self) -> bool:
"""Check if the timeout was triggered."""
return self._timeout_triggered
@property
def abort_handlers_completed(self) -> bool:
"""Check if all abort handlers have completed successfully."""
return self._abort_handlers_completed
def _run_cleanup(self) -> None:
"""
Run cleanup handlers (called by executor in finally block).
This method:
1. Flushes any pending throttled updates so the final state is persisted
2. Stops the abort listener and the timeout timer
3. Runs abort handlers if the task is aborting/aborted but the abort was
not yet detected
4. Runs all cleanup handlers (always)
All handler failures (abort + cleanup) are collected and written to DB
as a unified error record at the end.
"""
# Flush any pending throttled updates before cleanup
with self._throttle_lock:
self._cancel_deferred_flush_timer()
if self._has_pending_updates:
self._write_to_db()
self._has_pending_updates = False
# Stop abort listener and timeout timer
self.stop_abort_polling()
self.stop_timeout_timer()
# If aborting/aborted but handlers haven't run yet, run them now
# (This catches the case where task ended before listener detected abort)
if self._app:
with self._app.app_context():
task = self._task
if task.status in ABORT_STATES and not self._abort_detected:
self._trigger_abort_handlers()
else:
# Fallback without app context
try:
task = self._task
if task.status in ABORT_STATES and not self._abort_detected:
self._trigger_abort_handlers()
except Exception as ex:
logger.warning(
"Could not check abort status during cleanup for task %s: %s",
self._task_uuid,
str(ex),
)
# Always run cleanup handlers, collecting failures
for handler in reversed(self._cleanup_handlers):
try:
handler()
except Exception as ex:
stack_trace = traceback.format_exc()
logger.error(
"Cleanup handler failed for task %s: %s",
self._task_uuid,
str(ex),
exc_info=True,
)
self._handler_failures.append(("cleanup", ex, stack_trace))
# Write all collected failures (abort + cleanup) to DB as unified record
if self._handler_failures:
self._write_handler_failures_to_db()
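
The cleanup loop above runs handlers in reverse registration order and keeps going after a failure. A minimal, self-contained sketch of that pattern follows; `CleanupRunner`, `on_cleanup`, and the handler names are hypothetical stand-ins, not the real `TaskContext` API, and the sketch collects failures in memory instead of writing them to the database.

```python
import traceback
from typing import Callable, List, Tuple


class CleanupRunner:
    """Hypothetical, reduced stand-in for the cleanup part of TaskContext."""

    def __init__(self) -> None:
        self._cleanup_handlers: List[Callable[[], None]] = []
        # (handler kind, exception, formatted stack trace), as in the loop above
        self.failures: List[Tuple[str, Exception, str]] = []

    def on_cleanup(self, handler: Callable[[], None]) -> Callable[[], None]:
        """Register a handler; usable as a decorator."""
        self._cleanup_handlers.append(handler)
        return handler

    def run_cleanup(self) -> None:
        """Run handlers last-registered-first; never stop on a failure."""
        for handler in reversed(self._cleanup_handlers):
            try:
                handler()
            except Exception as ex:
                self.failures.append(("cleanup", ex, traceback.format_exc()))


calls: List[str] = []
runner = CleanupRunner()


@runner.on_cleanup
def close_connection() -> None:
    calls.append("close_connection")


@runner.on_cleanup
def remove_temp_file() -> None:
    calls.append("remove_temp_file")
    raise OSError("temp file already gone")


@runner.on_cleanup
def release_lock() -> None:
    calls.append("release_lock")


runner.run_cleanup()
```

Reverse order means resources acquired last are released first, and collecting failures rather than re-raising lets every handler run even when an earlier one breaks.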


@@ -0,0 +1,609 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Decorators for the Global Task Framework (GTF)"""
from __future__ import annotations
import inspect
import logging
from typing import Any, Callable, cast, Generic, ParamSpec, TYPE_CHECKING, TypeVar
from superset_core.api.tasks import TaskOptions, TaskScope, TaskStatus
from superset import is_feature_enabled
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
from superset.tasks.ambient_context import use_context
from superset.tasks.constants import TERMINAL_STATES
from superset.tasks.context import TaskContext
from superset.tasks.manager import TaskManager
from superset.tasks.registry import TaskRegistry
from superset.tasks.utils import generate_random_task_key
from superset.utils.core import get_user_id
if TYPE_CHECKING:
from superset.models.tasks import Task
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def task(
func: Callable[P, R] | None = None,
*,
name: str | None = None,
scope: TaskScope = TaskScope.PRIVATE,
timeout: int | None = None,
) -> Callable[[Callable[P, R]], "TaskWrapper[P]"] | "TaskWrapper[P]":
"""
Decorator to register a task with default scope.
Can be used with or without parentheses:
@task
def my_func(): ...
@task()
def my_func(): ...
@task(name="custom_name", scope=TaskScope.SHARED)
def my_func(): ...
@task(timeout=300) # 5-minute timeout
def long_running_func(): ...
Args:
func: The function to decorate (when used without parentheses).
name: Optional unique task name (e.g., "superset.generate_thumbnail").
If not provided, uses the function name as the task name.
scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM).
Defaults to TaskScope.PRIVATE.
timeout: Optional timeout in seconds. When the timeout is reached,
abort handlers are triggered if registered. Can be overridden
at call time via TaskOptions(timeout=...).
Usage:
# Private task (default scope) - no parentheses
@task
def my_async_func(chart_id: int) -> None:
ctx = get_context()
...
# Named task with shared scope
@task(name="generate_report", scope=TaskScope.SHARED)
def generate_expensive_report(report_id: int) -> None:
ctx = get_context()
...
# System task (admin-only)
@task(scope=TaskScope.SYSTEM)
def cleanup_task() -> None:
ctx = get_context()
...
# Task with timeout
@task(timeout=300)
def long_task() -> None:
ctx = get_context()
@ctx.on_abort
def handle_abort():
# Called when timeout is reached or user cancels
...
Note:
Both direct calls and .schedule() return Task, regardless of the
original function's return type. The decorated function's return value
is discarded; only side effects and context updates matter.
"""
def decorator(f: Callable[P, R]) -> "TaskWrapper[P]":
# Use function name if no name provided
task_name = name if name is not None else f.__name__
# Create default options with no scope (scope is now in decorator)
default_options = TaskOptions()
# Validate function signature - must not have ctx or options params
sig = inspect.signature(f)
forbidden = {"ctx", "options"}
if any(param in forbidden for param in sig.parameters):
raise TypeError(
f"Task function {f.__name__} must not define 'ctx' or "
"'options' parameters. "
"Use get_context() instead for ambient context access."
)
# Register task
TaskRegistry.register(task_name, f)
# Create wrapper with schedule() method, default options, scope, and timeout
wrapper = TaskWrapper(task_name, f, default_options, scope, timeout)
# Preserve signature for introspection
wrapper.__signature__ = sig # type: ignore[attr-defined]
return wrapper
if func is None:
# Called with parentheses: @task() or @task(name="foo", scope=TaskScope.SHARED)
return decorator
else:
# Called without parentheses: @task
return decorator(func)
class TaskWrapper(Generic[P]):
"""
Wrapper for task functions that provides .schedule() method.
Both direct calls and .schedule() return Task. The original function's
return value is discarded.
Direct calls execute synchronously; .schedule() runs asynchronously via Celery.
"""
def __init__(
self,
name: str,
func: Callable[P, R],
default_options: TaskOptions,
scope: TaskScope = TaskScope.PRIVATE,
default_timeout: int | None = None,
) -> None:
self.name = name
self.func = func
self.default_options = default_options
self.scope = scope
self.default_timeout = default_timeout
self.__name__ = func.__name__
self.__doc__ = func.__doc__
self.__module__ = func.__module__
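# Patch schedule.__signature__ to mirror function + options parameter
# This enables proper IDE support and introspection.
# NOTE: schedule.__func__ is the single function object shared by every
# TaskWrapper instance, so this assignment replaces the signature set by
# previously created wrappers; the most recently decorated task wins.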
sig = inspect.signature(func)
params = list(sig.parameters.values())
# Add keyword-only options parameter
params.append(
inspect.Parameter(
"options",
inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=TaskOptions | None,
)
)
self.schedule.__func__.__signature__ = sig.replace( # type: ignore[attr-defined]
parameters=params, return_annotation="Task"
)
def _merge_options(self, override_options: TaskOptions | None) -> TaskOptions:
"""
Merge decorator defaults with call-time overrides.
Call-time options take precedence over decorator defaults.
For timeout, a value of None in TaskOptions is treated as "not set" and
falls back to the decorator timeout; it does not disable the decorator timeout.
Args:
override_options: Options provided at call time, or None
Returns:
Merged TaskOptions with overrides applied
"""
if override_options is None:
return TaskOptions(
task_key=self.default_options.task_key,
task_name=self.default_options.task_name,
timeout=self.default_timeout, # Use decorator default
)
# Merge: use override if provided, otherwise use default
# For timeout: use override_options.timeout when it is not None;
# otherwise fall back to decorator default (None cannot disable it)
return TaskOptions(
task_key=override_options.task_key or self.default_options.task_key,
task_name=override_options.task_name or self.default_options.task_name,
timeout=override_options.timeout
if override_options.timeout is not None
else self.default_timeout,
)
def _validate_task(self, options: TaskOptions) -> None:
"""
Validate task configuration before execution.
Args:
options: Merged task options to validate
Raises:
ValueError: If validation fails
"""
# Shared tasks must have an explicit task_key for deduplication
if self.scope == TaskScope.SHARED and options.task_key is None:
raise ValueError(
f"Shared task '{self.name}' requires an explicit task_key in "
"TaskOptions for deduplication. Without a task_key, each "
"invocation creates a separate task with a random UUID, "
"defeating the purpose of shared tasks."
)
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Task":
"""
Call the function synchronously.
This is invoked when you call the decorated function directly:
task = generate_thumbnail(chart_id) # Blocks until completion
Flow:
1. Submit task (create new or join existing via deduplication)
2. If joining existing task: wait for it to complete (blocking)
3. If new task: execute inline and return completed task
Sync execution blocks until completion or until the timeout expires - even
when joining an existing task that's running in another process/worker.
Returns the Task entity, normally in terminal state (SUCCESS, FAILURE, etc.).
If the timeout expires while waiting for an existing task, the task is
returned in its current (possibly non-terminal) state; check task.status.
Raises:
GlobalTaskFrameworkDisabledError: If GTF feature flag is not enabled
ValueError: If task validation fails
"""
from superset.commands.tasks.submit import SubmitTaskCommand
if not is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
raise GlobalTaskFrameworkDisabledError()
# Extract and merge options (decorator defaults + call-time overrides)
override_options = cast(TaskOptions | None, kwargs.pop("options", None))
options = self._merge_options(override_options)
# Validate task configuration
self._validate_task(options)
# Extract task_name and task_key from merged options, scope from decorator
task_name = (
options.task_name or f"{self.name}:{generate_random_task_key()[:50]}"
)
task_key = options.task_key or generate_random_task_key()
scope = self.scope # Use scope from decorator
# Build properties with execution_mode and timeout
properties: dict[str, str | int] = {"execution_mode": "sync"}
if options.timeout:
properties["timeout"] = options.timeout
# Submit task - may create new or join existing
task, is_new = SubmitTaskCommand(
{
"task_type": self.name,
"task_key": task_key,
"task_name": task_name,
"scope": scope.value,
"properties": properties,
"user_id": get_user_id(),
}
).run_with_info()
# If joining existing task, wait for it to complete
if not is_new:
return self._wait_for_existing_task(task, options.timeout)
# New task - execute inline
return self._execute_inline(task, options, args, kwargs)
def _wait_for_existing_task(self, task: "Task", timeout: int | None) -> "Task":
"""
Wait for an existing task to complete.
Called when sync execution joins a pre-existing task via deduplication.
Blocks until the task reaches a terminal state or the timeout expires.
A TimeoutError from TaskManager.wait_for_completion is caught and logged,
not re-raised.
:param task: The existing task to wait for
:param timeout: Maximum time to wait in seconds (None = no limit)
:returns: Task in terminal state, or in its current state if the timeout
expired before the task completed
"""
from flask import current_app
from superset.daos.tasks import TaskDAO
# Check if already in terminal state
if task.status in TERMINAL_STATES:
logger.info(
"Joined already-completed task %s (uuid=%s, status=%s)",
self.name,
task.uuid,
task.status,
)
return task
# Wait for the existing task to complete
logger.info(
"Joined active task %s (uuid=%s, status=%s), waiting for completion",
self.name,
task.uuid,
task.status,
)
try:
app = current_app._get_current_object()
except RuntimeError:
app = None
try:
task = TaskManager.wait_for_completion(
task_uuid=task.uuid,
timeout=float(timeout) if timeout else None,
poll_interval=1.0,
app=app,
)
logger.info(
"Task %s (uuid=%s) completed with status=%s",
self.name,
task.uuid,
task.status,
)
return task
except TimeoutError:
logger.warning(
"Timeout waiting for task %s (uuid=%s)",
self.name,
task.uuid,
)
# Return task in current state (caller can check status)
refreshed = TaskDAO.find_one_or_none(uuid=task.uuid)
return refreshed if refreshed else task
def _execute_inline(
self,
task: "Task",
options: TaskOptions,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> "Task":
"""
Execute task function inline (synchronously).
Called when this is a new task (not joining existing).
Uses atomic conditional status transitions for race-safe execution.
:param task: The newly created task
:param options: Merged task options
:param args: Positional arguments for the task function
:param kwargs: Keyword arguments for the task function
:returns: Task in terminal state
"""
from superset.commands.tasks.internal_update import (
InternalStatusTransitionCommand,
)
from superset.daos.tasks import TaskDAO
from superset.tasks.constants import ABORT_STATES
# PRE-EXECUTION CHECK: Don't execute if already aborted/aborting
# (Matches async flow in scheduler.py)
if task.status in ABORT_STATES:
logger.info(
"Task %s (uuid=%s) was aborted before execution started",
self.name,
task.uuid,
)
# Ensure status is ABORTED (not just ABORTING)
InternalStatusTransitionCommand(
task_uuid=task.uuid,
new_status=TaskStatus.ABORTED,
expected_status=[TaskStatus.PENDING, TaskStatus.ABORTING],
set_ended_at=True,
).run()
# Refresh to get updated task
refreshed = TaskDAO.find_one_or_none(uuid=task.uuid)
return refreshed if refreshed else task
# Atomic transition: PENDING → IN_PROGRESS (set started_at for duration
# tracking)
task_uuid = task.uuid # Cache UUID before any potential state changes
if not InternalStatusTransitionCommand(
task_uuid=task_uuid,
new_status=TaskStatus.IN_PROGRESS,
expected_status=TaskStatus.PENDING,
set_started_at=True,
).run():
# Status wasn't PENDING - task may have been aborted concurrently
logger.warning(
"Task %s (uuid=%s) failed PENDING → IN_PROGRESS transition "
"(may have been aborted concurrently)",
self.name,
task_uuid,
)
refreshed = TaskDAO.find_one_or_none(uuid=task_uuid)
return refreshed if refreshed else task
# Update cached status (no DB read needed - we just wrote IN_PROGRESS)
task.status = TaskStatus.IN_PROGRESS.value
# Build context with the updated task entity
ctx = TaskContext(task)
# Start timeout timer if configured
if options.timeout:
ctx.start_timeout_timer(options.timeout)
logger.debug(
"Started timeout timer for task %s: %d seconds",
task.uuid,
options.timeout,
)
# Track final task state for completion notification
final_task: Task | None = None
try:
# Execute with ambient context
with use_context(ctx):
self.func(*args, **kwargs)
# Determine terminal status based on abort detection
# Use atomic conditional updates to prevent overwriting concurrent abort
if ctx._abort_detected or ctx.timeout_triggered:
# Abort was detected - transition ABORTING → terminal
if ctx.timeout_triggered:
InternalStatusTransitionCommand(
task_uuid=task_uuid,
new_status=TaskStatus.TIMED_OUT,
expected_status=TaskStatus.ABORTING,
set_ended_at=True,
).run()
logger.info(
"Task %s (uuid=%s) timed out and completed cleanup",
self.name,
task_uuid,
)
else:
InternalStatusTransitionCommand(
task_uuid=task_uuid,
new_status=TaskStatus.ABORTED,
expected_status=TaskStatus.ABORTING,
set_ended_at=True,
).run()
logger.info(
"Task %s (uuid=%s) was aborted by user",
self.name,
task_uuid,
)
else:
# Normal completion - atomic IN_PROGRESS → SUCCESS
# This will fail (return False) if task was concurrently aborted
if InternalStatusTransitionCommand(
task_uuid=task_uuid,
new_status=TaskStatus.SUCCESS,
expected_status=TaskStatus.IN_PROGRESS,
set_ended_at=True,
).run():
logger.debug(
"Synchronous execution of task %s (uuid=%s) "
"completed successfully",
self.name,
task_uuid,
)
else:
# Transition failed - task was likely aborted concurrently
logger.info(
"Task %s (uuid=%s) IN_PROGRESS → SUCCESS failed "
"(may have been aborted concurrently)",
self.name,
task_uuid,
)
# Refresh once at end to return current state
final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
return final_task if final_task else task
except Exception as ex:
# Atomic transition to FAILURE (only if still IN_PROGRESS or ABORTING)
InternalStatusTransitionCommand(
task_uuid=task_uuid,
new_status=TaskStatus.FAILURE,
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
properties={"error_message": str(ex)},
set_ended_at=True,
).run()
logger.error(
"Synchronous execution of task %s (uuid=%s) failed: %s",
self.name,
task_uuid,
str(ex),
exc_info=True,
)
# Refresh once at end to return current state
final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
return final_task if final_task else task
finally:
# Always clean up timer and handlers
ctx._run_cleanup()
# Publish completion notification for any waiters
# Use final_task if set by try/except, otherwise refresh (fallback)
if final_task is None:
final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
if final_task and final_task.status in TERMINAL_STATES:
TaskManager.publish_completion(task_uuid, final_task.status)
def schedule(self, *args: P.args, **kwargs: P.kwargs) -> "Task":
"""
Schedule this task for asynchronous execution.
The signature mirrors the original task function, with an additional
keyword-only 'options' parameter for execution metadata.
Args:
*args, **kwargs: Business arguments for the task function
options: Execution options
Returns:
Task model representing the scheduled task (PENDING status)
Raises:
GlobalTaskFrameworkDisabledError: If GTF feature flag is not enabled
ValueError: If task is SHARED scope but no task_key is provided
Usage:
# Auto-generated task_key (random UUID, no deduplication):
task = generate_thumbnail.schedule(chart_id)
# Custom task_key for task deduplication:
task = generate_thumbnail.schedule(
chart_id,
options=TaskOptions(task_key=f"thumb_{chart_id}")
)
# SHARED tasks require task_key:
task = shared_task.schedule(
data_id,
options=TaskOptions(task_key=f"shared_{data_id}")
)
Note: Unlike direct calls (__call__), this schedules async execution.
The function returns immediately with the Task model in PENDING status.
"""
if not is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
raise GlobalTaskFrameworkDisabledError()
# Extract and merge options (decorator defaults + call-time overrides)
override_options = cast(TaskOptions | None, kwargs.pop("options", None))
options = self._merge_options(override_options)
# Validate task configuration
self._validate_task(options)
# Extract task_name and task_key from merged options, scope from decorator
task_name = options.task_name
task_key = options.task_key
scope = self.scope # Use scope from decorator
# Create task entry in metastore and schedule execution
return TaskManager.submit_task(
task_type=self.name,
task_name=task_name,
task_key=task_key,
scope=scope,
timeout=options.timeout,
args=args,
kwargs=kwargs,
)
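
The two central ideas in this file, a decorator that works with or without parentheses and a call-time timeout that falls back to the decorator default, can be tried in isolation. The sketch below is a reduced illustration only: `Options` and `Wrapper` are hypothetical stand-ins for `TaskOptions` and `TaskWrapper`, and nothing is registered, persisted, or scheduled.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class Options:
    """Hypothetical, reduced stand-in for TaskOptions."""

    task_key: Optional[str] = None
    timeout: Optional[int] = None


class Wrapper:
    """Hypothetical, reduced stand-in for TaskWrapper."""

    def __init__(
        self, name: str, func: Callable[..., Any], default_timeout: Optional[int]
    ) -> None:
        self.name = name
        self.func = func
        self.default_timeout = default_timeout

    def merge_options(self, override: Optional[Options]) -> Options:
        """Call-time values win; a timeout of None means 'use the default'."""
        if override is None:
            return Options(timeout=self.default_timeout)
        return Options(
            task_key=override.task_key,
            timeout=(
                override.timeout
                if override.timeout is not None
                else self.default_timeout
            ),
        )


def task(
    func: Optional[Callable[..., Any]] = None,
    *,
    name: Optional[str] = None,
    timeout: Optional[int] = None,
) -> Any:
    """Usable as @task, @task() or @task(name=..., timeout=...)."""

    def decorator(f: Callable[..., Any]) -> Wrapper:
        return Wrapper(name if name is not None else f.__name__, f, timeout)

    # With parentheses Python calls task(...) first and func is None, so the
    # inner decorator is returned; without parentheses func is the function.
    return decorator if func is None else decorator(func)


@task
def plain() -> None:
    pass


@task(name="reports.generate", timeout=300)
def generate() -> None:
    pass


no_override = generate.merge_options(None)
key_only = generate.merge_options(Options(task_key="monthly"))
shorter = generate.merge_options(Options(timeout=60))
```

In this sketch, as in `_merge_options`, passing `Options(timeout=None)` cannot switch off a decorator-level timeout; doing so would need a separate sentinel value to distinguish "not set" from "explicitly disabled".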

superset/tasks/filters.py Normal file

@@ -0,0 +1,112 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Filters for Task model"""
from typing import Any
from sqlalchemy.orm.query import Query
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter
class TaskFilter(BaseFilter): # pylint: disable=too-few-public-methods
"""
Filter for Task that shows tasks based on scope and user permissions.
Filtering rules:
- Admins: See all tasks (private, shared, system)
- Non-admins:
- Private tasks: Only their own tasks
- Shared tasks: Tasks they're subscribed to
- System tasks: None (admin-only)
"""
def apply(self, query: Query, value: Any) -> Query:
"""Apply the filter to the query."""
from flask import g, has_request_context
from sqlalchemy import or_
from superset import db, security_manager
from superset.models.task_subscribers import TaskSubscriber
from superset.models.tasks import Task
# If no request context or no user, return unfiltered query
# (this handles background tasks and system operations)
if not has_request_context() or not hasattr(g, "user"):
return query
# If user is admin, return unfiltered query
if security_manager.is_admin():
return query
# For non-admins, filter by scope and permissions
user_id = get_user_id()
# Use subquery for shared tasks to avoid join ambiguity
shared_task_ids_query = (
db.session.query(Task.id)
.join(TaskSubscriber, Task.id == TaskSubscriber.task_id)
.filter(
Task.scope == "shared",
TaskSubscriber.user_id == user_id,
)
)
# Build filter conditions:
# 1. Private tasks created by current user
# 2. Shared tasks where user is subscribed (via subquery)
# 3. System tasks are excluded (admin-only)
return query.filter(
or_(
# Own private tasks
(Task.scope == "private") & (Task.created_by_fk == user_id),
# Shared tasks where user is subscribed
Task.id.in_(shared_task_ids_query),
)
)
class TaskSubscriberFilter(BaseFilter): # pylint: disable=too-few-public-methods
"""
Filter tasks by subscriber user ID.
This filter allows finding tasks where a specific user is subscribed.
Used by the frontend for the subscribers filter dropdown.
"""
def apply(self, query: Query, value: Any) -> Query:
"""Apply the filter to the query."""
from superset import db
from superset.models.task_subscribers import TaskSubscriber
from superset.models.tasks import Task
if not value:
return query
# Handle both single ID and list of IDs
if isinstance(value, (list, tuple)):
user_ids = [int(v) for v in value]
else:
user_ids = [int(value)]
# Find tasks where any of these users are subscribers
subscribed_task_ids = db.session.query(TaskSubscriber.task_id).filter(
TaskSubscriber.user_id.in_(user_ids)
)
return query.filter(Task.id.in_(subscribed_task_ids))
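
The visibility rules that `TaskFilter` expresses in SQL can be restated as a plain predicate, which is handy for reasoning about edge cases. This is an illustrative sketch only: `TaskRow` and `is_visible` are hypothetical names, and the real filter operates on a SQLAlchemy query rather than on in-memory objects.

```python
from dataclasses import dataclass, field
from typing import Set


@dataclass
class TaskRow:
    """Hypothetical in-memory stand-in for a Task row."""

    id: int
    scope: str  # "private", "shared" or "system"
    created_by: int
    subscriber_ids: Set[int] = field(default_factory=set)


def is_visible(task: TaskRow, user_id: int, is_admin: bool) -> bool:
    """Mirror the rules documented on TaskFilter."""
    if is_admin:
        return True  # admins see every scope
    if task.scope == "private":
        return task.created_by == user_id  # only the owner
    if task.scope == "shared":
        return user_id in task.subscriber_ids  # only subscribers
    return False  # system tasks are admin-only


tasks = [
    TaskRow(id=1, scope="private", created_by=10),
    TaskRow(id=2, scope="private", created_by=20),
    TaskRow(id=3, scope="shared", created_by=20, subscriber_ids={10, 20}),
    TaskRow(id=4, scope="shared", created_by=20, subscriber_ids={20}),
    TaskRow(id=5, scope="system", created_by=1),
]

visible_to_user_10 = [t.id for t in tasks if is_visible(t, 10, is_admin=False)]
visible_to_admin = [t.id for t in tasks if is_visible(t, 99, is_admin=True)]
```

Note that a shared task is matched by subscription, not by creator, so under these rules a creator who is not subscribed to their own shared task would not see it.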

superset/tasks/locks.py Normal file

@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Distributed locking utilities for the Global Task Framework (GTF).
This module provides distributed locks for task operations to prevent race
conditions during concurrent task creation, subscription, and cancellation.
The lock key uses the task's dedup_key, ensuring all operations on the same
logical task serialize correctly.
When SIGNAL_CACHE_CONFIG is configured, uses Redis SET NX EX for
efficient single-command locking. Otherwise falls back to database-backed
locking via DistributedLock.
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import Iterator
from superset.distributed_lock import DistributedLock
logger = logging.getLogger(__name__)
# Task operations use a shorter TTL than the global default since
# they complete quickly (just DB operations, no external calls)
TASK_LOCK_TTL_SECONDS = 10
@contextmanager
def task_lock(dedup_key: str) -> Iterator[None]:
"""
Acquire a distributed lock for task operations.
Uses the task's dedup_key as the lock key. All operations on the same
logical task (create, subscribe, cancel) use the same lock, ensuring
mutual exclusion. This prevents race conditions such as:
- Two concurrent creates with the same key
- Subscribe racing with cancel
- Multiple concurrent cancel requests
When SIGNAL_CACHE_CONFIG is configured, uses Redis SET NX EX
for efficient single-command locking. Otherwise falls back to
database-backed DistributedLock.
:param dedup_key: Task deduplication key (from get_active_dedup_key)
:yields: Nothing; used as context manager
:raises AcquireDistributedLockFailedException: If lock is already held
Example:
dedup_key = get_active_dedup_key(TaskScope.SHARED, "report", "monthly")
with task_lock(dedup_key):
# Create, subscribe, or cancel task here
...
"""
logger.debug("Acquiring task lock for key: %s", dedup_key)
with DistributedLock(
namespace="gtf:task",
key=dedup_key,
ttl_seconds=TASK_LOCK_TTL_SECONDS,
):
yield
logger.debug("Released task lock for key: %s", dedup_key)
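
The module docstring mentions Redis `SET NX EX` locking. The sketch below imitates those semantics (set only if the key is absent, with an expiry) using an in-memory store so it runs without Redis. Everything in it is an assumption made for illustration: `InMemoryLockStore`, `LockHeldError`, `sketch_task_lock`, the owner-token release, and the way the `gtf:task` namespace is joined to the key are not taken from `DistributedLock`, and the store is not thread-safe.

```python
import time
import uuid
from contextlib import contextmanager
from typing import Dict, Iterator, Tuple


class LockHeldError(Exception):
    """Hypothetical error raised when the lock is already held."""


class InMemoryLockStore:
    """In-memory imitation of SET key value NX EX ttl (not thread-safe)."""

    def __init__(self) -> None:
        # key -> (owner token, expiry on the monotonic clock)
        self._entries: Dict[str, Tuple[str, float]] = {}

    def set_nx_ex(self, key: str, token: str, ttl_seconds: float) -> bool:
        """Store the token only if the key is absent or expired."""
        now = time.monotonic()
        entry = self._entries.get(key)
        if entry is not None and entry[1] > now:
            return False  # still held by someone else
        self._entries[key] = (token, now + ttl_seconds)
        return True

    def delete_if_owner(self, key: str, token: str) -> None:
        """Release only when the stored token matches the caller's token."""
        entry = self._entries.get(key)
        if entry is not None and entry[0] == token:
            del self._entries[key]


@contextmanager
def sketch_task_lock(
    store: InMemoryLockStore, dedup_key: str, ttl_seconds: float = 10.0
) -> Iterator[None]:
    key = f"gtf:task:{dedup_key}"  # assumed key layout
    token = uuid.uuid4().hex
    if not store.set_nx_ex(key, token, ttl_seconds):
        raise LockHeldError(dedup_key)
    try:
        yield
    finally:
        store.delete_if_owner(key, token)


store = InMemoryLockStore()
dedup_key = "example-dedup-key"  # hypothetical value
second_attempt_blocked = False
with sketch_task_lock(store, dedup_key):
    try:
        with sketch_task_lock(store, dedup_key):
            pass
    except LockHeldError:
        second_attempt_blocked = True

# After release the same key can be acquired again.
reacquired = False
with sketch_task_lock(store, dedup_key):
    reacquired = True
```

The TTL matters because a holder that crashes never reaches its `finally` block; the expiry then frees the key, which is why a short value such as `TASK_LOCK_TTL_SECONDS = 10` suits quick database-only operations.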

superset/tasks/manager.py Normal file

@@ -0,0 +1,764 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task manager for the Global Task Framework (GTF)"""
from __future__ import annotations
import logging
import threading
import time
from typing import Any, Callable, TYPE_CHECKING
from uuid import UUID
import redis
from superset_core.api.tasks import TaskProperties, TaskScope
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
from superset.extensions import cache_manager
from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES
from superset.tasks.utils import generate_random_task_key
if TYPE_CHECKING:
from flask import Flask
from superset.models.tasks import Task
logger = logging.getLogger(__name__)
class AbortListener:
"""
Handle for a background abort listener.
Returned by TaskManager.listen_for_abort() to allow stopping the listener.
"""
def __init__(
self,
task_uuid: UUID,
thread: threading.Thread,
stop_event: threading.Event,
pubsub: redis.client.PubSub | None = None,
) -> None:
self._task_uuid = task_uuid
self._thread = thread
self._stop_event = stop_event
self._pubsub = pubsub
def stop(self) -> None:
"""Stop the abort listener."""
self._stop_event.set()
# Close pub/sub subscription if active
if self._pubsub is not None:
try:
self._pubsub.unsubscribe()
self._pubsub.close()
except Exception as ex:
logger.debug("Error closing pub/sub during stop: %s", ex)
# Wait for thread to finish (with timeout to avoid blocking indefinitely)
if self._thread.is_alive():
self._thread.join(timeout=2.0)
# Check if thread is still running after timeout
if self._thread.is_alive():
# Thread is a daemon, so it will be killed when process exits.
# Log warning but continue - cleanup will still proceed.
logger.warning(
"Abort listener thread for task %s did not terminate within "
"2 seconds. Thread will be terminated when process exits.",
self._task_uuid,
)
else:
logger.debug("Stopped abort listener for task %s", self._task_uuid)
else:
logger.debug("Stopped abort listener for task %s", self._task_uuid)
class TaskManager:
"""
Handles task creation, scheduling, and abort notifications.
The TaskManager is responsible for:
1. Creating task entries in the metastore (Task model)
2. Scheduling task execution via Celery
3. Handling deduplication (returning existing active task if duplicate)
4. Managing real-time abort notifications (optional)
Redis pub/sub is opt-in via SIGNAL_CACHE_CONFIG configuration. When not
configured, tasks use database polling for abort detection.
"""
# Class-level state (initialized once via init_app)
_channel_prefix: str = "gtf:abort:"
_completion_channel_prefix: str = "gtf:complete:"
_initialized: bool = False
# Backward compatibility alias - prefer importing from superset.tasks.constants
TERMINAL_STATES = TERMINAL_STATES
@classmethod
def init_app(cls, app: Flask) -> None:
"""
Initialize the TaskManager with Flask app config.
Redis connection is managed by CacheManager - this just reads channel prefixes.
:param app: Flask application instance
"""
if cls._initialized:
return
cls._channel_prefix = app.config.get("TASKS_ABORT_CHANNEL_PREFIX", "gtf:abort:")
cls._completion_channel_prefix = app.config.get(
"TASKS_COMPLETION_CHANNEL_PREFIX", "gtf:complete:"
)
cls._initialized = True
@classmethod
def _get_cache(cls) -> RedisCacheBackend | RedisSentinelCacheBackend | None:
"""
Get the signal cache backend.
:returns: The signal cache backend, or None if not configured
"""
return cache_manager.signal_cache
@classmethod
def is_pubsub_available(cls) -> bool:
"""
Check if Redis pub/sub backend is configured and available.
:returns: True if Redis is available for pub/sub, False otherwise
"""
return cls._get_cache() is not None
@classmethod
def get_abort_channel(cls, task_uuid: UUID) -> str:
"""
Get the abort channel name for a task.
:param task_uuid: UUID of the task
:returns: Channel name for the task's abort notifications
"""
return f"{cls._channel_prefix}{task_uuid}"
@classmethod
def publish_abort(cls, task_uuid: UUID) -> bool:
"""
Publish an abort message to the task's channel.
:param task_uuid: UUID of the task to abort
:returns: True if message was published, False if Redis unavailable
"""
cache = cls._get_cache()
if not cache:
return False
try:
channel = cls.get_abort_channel(task_uuid)
subscriber_count = cache.publish(channel, "abort")
logger.debug(
"Published abort to channel %s (%d subscribers)",
channel,
subscriber_count,
)
return True
except redis.RedisError as ex:
logger.error("Failed to publish abort for task %s: %s", task_uuid, ex)
return False
@classmethod
def get_completion_channel(cls, task_uuid: UUID) -> str:
"""
Get the completion channel name for a task.
:param task_uuid: UUID of the task
:returns: Channel name for the task's completion notifications
"""
return f"{cls._completion_channel_prefix}{task_uuid}"
@classmethod
def publish_completion(cls, task_uuid: UUID, status: str) -> bool:
"""
Publish a completion message to the task's channel.
Called when task reaches terminal state (SUCCESS, FAILURE, ABORTED, TIMED_OUT).
This notifies any waiters (e.g., sync callers waiting for an existing task).
:param task_uuid: UUID of the completed task
:param status: Final status of the task
:returns: True if message was published, False if Redis unavailable
"""
cache = cls._get_cache()
if not cache:
return False
try:
channel = cls.get_completion_channel(task_uuid)
subscriber_count = cache.publish(channel, status)
logger.debug(
"Published completion to channel %s (status=%s, %d subscribers)",
channel,
status,
subscriber_count,
)
return True
except redis.RedisError as ex:
logger.error("Failed to publish completion for task %s: %s", task_uuid, ex)
return False
@classmethod
def wait_for_completion(
cls,
task_uuid: UUID,
timeout: float | None = None,
poll_interval: float = 1.0,
app: Any = None,
) -> "Task":
"""
Block until task reaches terminal state.
Uses Redis pub/sub if configured for low-latency, low-CPU waiting.
Uses database polling if Redis is not configured.
:param task_uuid: UUID of the task to wait for
:param timeout: Maximum time to wait in seconds (None = no limit)
:param poll_interval: Interval for database polling (seconds)
:param app: Flask app for database access
:returns: Task in terminal state
:raises TimeoutError: If timeout expires before task completes
:raises ValueError: If task not found
"""
from superset.daos.tasks import TaskDAO
start_time = time.monotonic()
def time_remaining() -> float | None:
if timeout is None:
return None
elapsed = time.monotonic() - start_time
remaining = timeout - elapsed
return remaining if remaining > 0 else 0
def get_task() -> "Task | None":
if app:
with app.app_context():
return TaskDAO.find_one_or_none(uuid=task_uuid)
return TaskDAO.find_one_or_none(uuid=task_uuid)
# Check current state first
task = get_task()
if not task:
raise ValueError(f"Task {task_uuid} not found")
if task.status in cls.TERMINAL_STATES:
return task
logger.debug(
"Waiting for task %s to complete (current status=%s, timeout=%s)",
task_uuid,
task.status,
timeout,
)
# Use Redis pub/sub if configured
if (cache := cls._get_cache()) is not None:
task = cls._wait_via_pubsub(
task_uuid,
cache.pubsub(),
timeout,
poll_interval,
get_task,
time_remaining,
)
if task:
return task
# Should not reach here - _wait_via_pubsub returns task or raises
raise RuntimeError(f"Unexpected state waiting for task {task_uuid}")
# Use database polling when Redis is not configured
return cls._wait_via_polling(task_uuid, poll_interval, get_task, time_remaining)
@classmethod
def _wait_via_pubsub(
cls,
task_uuid: UUID,
pubsub: redis.client.PubSub,
timeout: float | None,
poll_interval: float,
get_task: Callable[[], "Task | None"],
time_remaining: Callable[[], float | None],
) -> "Task | None":
"""
Wait for task completion using Redis pub/sub.
:returns: Task when completed
:raises TimeoutError: If timeout expires
:raises redis.RedisError: If Redis connection fails
"""
channel = cls.get_completion_channel(task_uuid)
pubsub.subscribe(channel)
try:
while True:
remaining = time_remaining()
if remaining is not None and remaining <= 0:
raise TimeoutError(
f"Timeout waiting for task {task_uuid} to complete"
)
# Wait for message with short timeout for responsive checking
wait_time = min(1.0, remaining) if remaining else 1.0
message = pubsub.get_message(
ignore_subscribe_messages=True,
timeout=wait_time,
)
if message and message.get("type") == "message":
# Completion received - fetch fresh task state
logger.debug(
"Received completion message for task %s: %s",
task_uuid,
message.get("data"),
)
task = get_task()
if task and task.status in cls.TERMINAL_STATES:
return task
# Also check database periodically in case we missed the message
# (e.g., task completed before we subscribed)
task = get_task()
if task and task.status in cls.TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via db check): status=%s",
task_uuid,
task.status,
)
return task
finally:
pubsub.unsubscribe()
pubsub.close()
@classmethod
def _wait_via_polling(
cls,
task_uuid: UUID,
poll_interval: float,
get_task: Callable[[], "Task | None"],
time_remaining: Callable[[], float | None],
) -> "Task":
"""
Wait for task completion using database polling.
:returns: Task when completed
:raises TimeoutError: If timeout expires
:raises ValueError: If task not found
"""
while True:
remaining = time_remaining()
if remaining is not None and remaining <= 0:
raise TimeoutError(f"Timeout waiting for task {task_uuid} to complete")
task = get_task()
if not task:
raise ValueError(f"Task {task_uuid} not found")
if task.status in cls.TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via polling): status=%s",
task_uuid,
task.status,
)
return task
# Sleep with timeout awareness
sleep_time = min(poll_interval, remaining) if remaining else poll_interval
time.sleep(sleep_time)
@classmethod
def listen_for_abort(
cls,
task_uuid: UUID,
callback: Callable[[], None],
poll_interval: float,
app: Any = None,
) -> AbortListener:
"""
Start listening for abort notifications for a task.
Uses Redis pub/sub if configured, otherwise uses database polling.
The callback is invoked when an abort is detected.
:param task_uuid: UUID of the task to monitor (native UUID)
:param callback: Function to call when abort is detected
:param poll_interval: Interval for database polling (when Redis not configured)
:param app: Flask app for database access in background thread
:returns: AbortListener handle to stop listening
"""
stop_event = threading.Event()
pubsub: redis.client.PubSub | None = None
uuid_str = str(task_uuid)
# Use Redis pub/sub if configured
if (cache := cls._get_cache()) is not None:
pubsub = cache.pubsub()
channel = cls.get_abort_channel(task_uuid)
pubsub.subscribe(channel)
logger.debug("Subscribed to abort channel: %s", channel)
# Start pub/sub listener thread
thread = threading.Thread(
target=cls._listen_pubsub,
args=(task_uuid, pubsub, callback, stop_event, app),
daemon=True,
name=f"abort-listener-{uuid_str[:8]}",
)
logger.debug("Started pub/sub abort listener for task %s", task_uuid)
else:
# Use polling when Redis is not configured
pubsub = None
thread = threading.Thread(
target=cls._poll_for_abort,
args=(task_uuid, callback, stop_event, poll_interval, app),
daemon=True,
name=f"abort-poller-{uuid_str[:8]}",
)
logger.debug(
"Started database abort polling for task %s (interval=%ss)",
task_uuid,
poll_interval,
)
thread.start()
return AbortListener(task_uuid, thread, stop_event, pubsub)
@staticmethod
def _invoke_callback_with_context(
callback: Callable[[], None],
app: Any,
) -> None:
"""
Invoke callback with Flask app context if provided.
:param callback: Function to invoke
:param app: Flask app for context, or None
"""
if app:
with app.app_context():
callback()
else:
callback()
@classmethod
def _check_abort_status(cls, task_uuid: UUID) -> bool:
"""
Check if task has been aborted via database query.
:param task_uuid: UUID of the task to check (native UUID)
:returns: True if task is in ABORTING or ABORTED state
"""
from superset.daos.tasks import TaskDAO
task = TaskDAO.find_one_or_none(uuid=task_uuid)
return task is not None and task.status in ABORT_STATES
@classmethod
def _run_abort_listener_loop(
cls,
task_uuid: UUID,
callback: Callable[[], None],
stop_event: threading.Event,
interval: float,
app: Any,
check_fn: Callable[[], bool],
source: str,
) -> None:
"""
Common abort listener loop used by both pub/sub and polling modes.
:param task_uuid: UUID of the task to monitor (native UUID)
:param callback: Function to call when abort is detected
:param stop_event: Event to signal loop termination
:param interval: Wait interval between checks
:param app: Flask app for context
:param check_fn: Function that returns True if abort was detected
:param source: Source identifier for logging ("pub/sub" or "polling")
"""
while not stop_event.is_set():
try:
if check_fn():
logger.info(
"Abort detected via %s for task %s",
source,
task_uuid,
)
cls._invoke_callback_with_context(callback, app)
break
# Wait for interval or until stop is requested
stop_event.wait(timeout=interval)
except (ValueError, OSError) as ex:
# ValueError/OSError with "I/O operation on closed file" or
# "Bad file descriptor" typically means the connection was closed
# during shutdown. Check if stop was requested.
if stop_event.is_set():
logger.debug(
"Abort %s for task %s stopped cleanly (connection closed)",
source,
task_uuid,
)
else:
logger.error(
"Error in abort %s for task %s: %s",
source,
task_uuid,
str(ex),
exc_info=True,
)
break
except Exception as ex:
# Check if stop was requested - if so, this may be expected
if stop_event.is_set():
logger.debug(
"Abort %s for task %s stopped with exception: %s",
source,
task_uuid,
ex,
)
else:
logger.error(
"Error in abort %s for task %s: %s",
source,
task_uuid,
str(ex),
exc_info=True,
)
break
@classmethod
def _listen_pubsub(
cls,
task_uuid: UUID,
pubsub: redis.client.PubSub,
callback: Callable[[], None],
stop_event: threading.Event,
app: Any,
) -> None:
"""Listen for abort via Redis pub/sub."""
# Track if abort was received to avoid double-callback
abort_received = False
def check_pubsub() -> bool:
nonlocal abort_received
message = pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message is not None and message.get("type") == "message":
abort_received = True
return True
return False
try:
cls._run_abort_listener_loop(
task_uuid=task_uuid,
callback=callback,
stop_event=stop_event,
interval=0, # pub/sub has its own timeout in get_message
app=app,
check_fn=check_pubsub,
source="pub/sub",
)
except redis.RedisError as ex:
# Check if we were asked to stop - if so, this is expected
if stop_event.is_set():
logger.debug(
"Abort listener for task %s stopped (Redis error: %s)",
task_uuid,
ex,
)
else:
# Log error but don't fall back - let the failure be visible
logger.error(
"Redis signal backend failed for task %s abort listener: %s. "
"Task may not receive abort signal.",
task_uuid,
ex,
)
except (ValueError, OSError) as ex:
# ValueError: "I/O operation on closed file" - expected when stop() closes
# the pub/sub connection
# OSError: Similar connection-closed errors
if stop_event.is_set():
# Clean shutdown, expected behavior
logger.debug(
"Abort listener for task %s stopped cleanly",
task_uuid,
)
else:
# Unexpected error while running
logger.error(
"Error in abort listener for task %s: %s",
task_uuid,
str(ex),
exc_info=True,
)
except Exception as ex:
# Only log as error if we weren't asked to stop
if stop_event.is_set():
logger.debug(
"Abort listener for task %s stopped with exception: %s",
task_uuid,
ex,
)
else:
logger.error(
"Error in abort listener for task %s: %s",
task_uuid,
str(ex),
exc_info=True,
)
finally:
# Clean up pub/sub subscription
try:
pubsub.unsubscribe()
pubsub.close()
except Exception as ex:
logger.debug("Error closing pub/sub during cleanup: %s", ex)
@classmethod
def _poll_for_abort(
cls,
task_uuid: UUID,
callback: Callable[[], None],
stop_event: threading.Event,
interval: float,
app: Any,
) -> None:
"""Background polling loop - used when Redis pub/sub is not configured."""
def check_database() -> bool:
# Need app context for database access
if app:
with app.app_context():
return cls._check_abort_status(task_uuid)
else:
return cls._check_abort_status(task_uuid)
cls._run_abort_listener_loop(
task_uuid=task_uuid,
callback=callback,
stop_event=stop_event,
interval=interval,
app=app,
check_fn=check_database,
source="polling",
)
@staticmethod
def submit_task(
task_type: str,
task_key: str | None,
task_name: str | None,
scope: TaskScope,
timeout: int | None,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> "Task":
"""
Create task entry and schedule for async execution.
Flow:
1. Generate task_key if not provided (random UUID)
2. Submit to SubmitTaskCommand which handles locking and create-vs-join
3. Schedule Celery task ONLY for new tasks (not deduplicated ones)
4. Return Task model to caller
The SubmitTaskCommand uses a distributed lock to prevent race conditions,
returning either a new task or an existing active task with the same key.
:param task_type: Task type identifier (e.g., "superset.generate_thumbnail")
:param task_key: Optional deduplication key (None for random UUID)
:param task_name: Human-readable task name
:param scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM)
:param timeout: Optional timeout in seconds
:param args: Positional arguments for the task function
:param kwargs: Keyword arguments for the task function
:returns: Task model representing the scheduled task
"""
from superset.commands.tasks.submit import SubmitTaskCommand
if task_key is None:
task_key = generate_random_task_key()
# Build properties with execution_mode and timeout
properties: TaskProperties = {"execution_mode": "async"}
if timeout:
properties["timeout"] = timeout
# Create or join task entry in metastore
# SubmitTaskCommand handles locking and create-vs-join logic:
# - Acquires distributed lock on dedup_key
# - If active task exists: adds subscriber and returns existing task
# (is_new=False)
# - If no active task: creates new task (is_new=True)
task, is_new = SubmitTaskCommand(
{
"task_key": task_key,
"task_type": task_type,
"task_name": task_name,
"scope": scope.value,
"properties": properties,
}
).run_with_info()
# Only schedule Celery task for NEW tasks, not deduplicated ones
# Deduplicated tasks are already pending or running
if is_new:
# Import here to avoid circular dependency
from superset.tasks.scheduler import execute_task
# Schedule Celery task for async execution
execute_task.delay(
task_uuid=str(task.uuid),
task_type=task_type,
args=args,
kwargs=kwargs,
)
logger.debug(
"Scheduled task %s (uuid=%s) for async execution",
task_type,
task.uuid,
)
else:
logger.debug(
"Joined existing task %s (uuid=%s) - no new Celery task scheduled",
task_type,
task.uuid,
)
return task
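The deadline handling in `_wait_via_polling` above can be illustrated in isolation. The following is a minimal sketch, not the framework's code: `wait_until_terminal`, the `get_status` callable, and the literal status strings are assumptions made for illustration. Only the loop structure (check the remaining time, check the state, sleep no longer than the remaining budget) follows the method in the diff.

```python
from __future__ import annotations

import time
from typing import Callable

# Assumed literal values for illustration; the diff imports TERMINAL_STATES
# from superset.tasks.constants and does not show its contents.
TERMINAL_STATES = frozenset({"success", "failure", "aborted", "timed_out"})


def wait_until_terminal(
    get_status: Callable[[], str | None],
    timeout: float | None = None,
    poll_interval: float = 1.0,
) -> str:
    """Poll get_status() until it reports a terminal state or the deadline passes."""
    start = time.monotonic()  # monotonic clock is unaffected by wall-clock changes

    while True:
        remaining = None
        if timeout is not None:
            remaining = max(timeout - (time.monotonic() - start), 0.0)
            if remaining <= 0:
                raise TimeoutError("timed out waiting for task")

        status = get_status()
        if status is None:
            raise ValueError("task not found")
        if status in TERMINAL_STATES:
            return status

        # Never sleep past the deadline.
        if remaining is None:
            time.sleep(poll_interval)
        else:
            time.sleep(min(poll_interval, remaining))
```

A caller would pass a function that reads the current status from storage, for example `wait_until_terminal(lambda: load_status(task_uuid), timeout=30)`, where `load_status` is a hypothetical helper.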

superset/tasks/registry.py Normal file

@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task registry for the Global Task Framework (GTF)"""
import logging
from typing import Any, Callable
logger = logging.getLogger(__name__)
class TaskRegistry:
"""
Registry for task functions.
Stores task functions by name, allowing the Celery executor to look up
and execute registered tasks. This enables the decorator pattern where
functions are registered at module import time.
"""
_tasks: dict[str, Callable[..., Any]] = {}
@classmethod
def register(cls, task_name: str, func: Callable[..., Any]) -> None:
"""
Register a task function by name.
:param task_name: Unique task identifier (e.g., "superset.generate_thumbnail")
:param func: The task function to register
:raises ValueError: If task name is already registered with a different function
"""
if task_name in cls._tasks:
existing_func = cls._tasks[task_name]
if existing_func is not func:
raise ValueError(
f"Task '{task_name}' is already registered with a different "
"function. "
f"Existing: {existing_func.__module__}.{existing_func.__name__}, "
f"New: {func.__module__}.{func.__name__}"
)
# Same function being registered again (e.g., module reload) - allow it
logger.debug("Task '%s' re-registered with same function", task_name)
return
cls._tasks[task_name] = func
logger.info(
"Registered async task: %s -> %s.%s",
task_name,
func.__module__,
func.__name__,
)
@classmethod
def get_executor(cls, task_name: str) -> Callable[..., Any]:
"""
Get the executor function for a task.
:param task_name: Task identifier to look up
:returns: The registered task function
:raises KeyError: If task name is not registered
"""
if task_name not in cls._tasks:
raise KeyError(
f"Task '{task_name}' is not registered. "
f"Available tasks: {', '.join(sorted(cls._tasks.keys()))}"
)
return cls._tasks[task_name]
@classmethod
def is_registered(cls, task_name: str) -> bool:
"""
Check if a task is registered.
:param task_name: Task identifier to check
:returns: True if task is registered
"""
return task_name in cls._tasks
@classmethod
def list_tasks(cls) -> list[str]:
"""
Get list of all registered task names.
:returns: Sorted list of task names
"""
return sorted(cls._tasks.keys())
@classmethod
def clear(cls) -> None:
"""
Clear all registered tasks.
WARNING: This is primarily for testing purposes. In production,
tasks should remain registered for the lifetime of the process.
"""
cls._tasks.clear()
logger.warning("Task registry cleared")
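The class docstring mentions a decorator pattern where functions are registered at import time. A minimal sketch of how that could look is shown below. `MiniTaskRegistry` is an illustrative stand-in that copies the `register` and `get_executor` rules from the diff, and `register_task` is a hypothetical decorator name; the framework's real decorator is not shown in this section.

```python
from __future__ import annotations

from typing import Any, Callable


class MiniTaskRegistry:
    """Illustrative stand-in for TaskRegistry (same register/get_executor rules)."""

    _tasks: dict[str, Callable[..., Any]] = {}

    @classmethod
    def register(cls, task_name: str, func: Callable[..., Any]) -> None:
        existing = cls._tasks.get(task_name)
        if existing is not None and existing is not func:
            raise ValueError(f"Task '{task_name}' is already registered")
        cls._tasks[task_name] = func  # re-registering the same function is allowed

    @classmethod
    def get_executor(cls, task_name: str) -> Callable[..., Any]:
        if task_name not in cls._tasks:
            raise KeyError(f"Task '{task_name}' is not registered")
        return cls._tasks[task_name]


def register_task(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator: registers the function when its module is imported."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        MiniTaskRegistry.register(task_name, func)
        return func

    return decorator


@register_task("demo.add")
def add(a: int, b: int) -> int:
    return a + b


# An executor can later resolve the function by name and call it with stored args.
result = MiniTaskRegistry.get_executor("demo.add")(2, 3)
```

Keeping the lookup keyed by a string name means only the name and arguments have to travel through the Celery message, not the function object itself.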


@@ -19,11 +19,13 @@ from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.signals import task_failure
from flask import current_app
from superset_core.api.tasks import TaskStatus
from superset import is_feature_enabled
from superset.commands.exceptions import CommandException
@@ -32,10 +34,17 @@ from superset.commands.report.exceptions import ReportScheduleUnexpectedError
from superset.commands.report.execute import AsyncExecuteReportScheduleCommand
from superset.commands.report.log_prune import AsyncPruneReportScheduleLogCommand
from superset.commands.sql_lab.query import QueryPruneCommand
from superset.commands.tasks.prune import TaskPruneCommand
from superset.daos.report import ReportScheduleDAO
from superset.daos.tasks import TaskDAO
from superset.extensions import celery_app
from superset.stats_logger import BaseStatsLogger
from superset.tasks.ambient_context import use_context
from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES
from superset.tasks.context import TaskContext
from superset.tasks.cron_util import cron_schedule_window
from superset.tasks.manager import TaskManager
from superset.tasks.registry import TaskRegistry
from superset.utils.core import LoggerLevel
from superset.utils.log import get_logger_from_status
@@ -199,3 +208,251 @@ def prune_logs(
LogPruneCommand(retention_period_days, max_rows_per_run).run()
except CommandException as ex:
logger.exception("An error occurred while pruning logs: %s", ex)
@celery_app.task(name="prune_tasks", bind=True)
def prune_tasks(
self: Task,
retention_period_days: int | None = None,
max_rows_per_run: int | None = None,
**kwargs: Any,
) -> None:
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
stats_logger.incr("prune_tasks")
# TODO: Deprecated: Remove support for passing retention period via options in 6.0
if retention_period_days is None:
retention_period_days = prune_tasks.request.properties.get(
"retention_period_days"
)
logger.warning(
"Your `prune_tasks` beat schedule uses `options` to pass the "
"retention period, please use `kwargs` instead."
)
try:
TaskPruneCommand(retention_period_days, max_rows_per_run).run()
except CommandException as ex:
logger.exception("An error occurred while pruning async tasks: %s", ex)
@celery_app.task(name="tasks.execute", bind=True)
def execute_task( # noqa: C901
self: Any, # Celery task instance
task_uuid: str,
task_type: str,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> dict[str, Any]:
"""
Generic task executor for GTF tasks.
This executor:
1. Fetches task from metastore
2. Skips execution if the task was aborted before it started
3. Builds context (task + user) and sets ambient context via contextvars
4. Executes the task function (which accesses context via get_context())
5. Updates task status throughout lifecycle using atomic conditional updates
6. Runs cleanup handlers on task end (success/failure/abortion)
7. Resets context after execution
Uses atomic conditional status updates to prevent race conditions with
concurrent abort operations.
:param task_uuid: UUID of the task to execute
:param task_type: Type of the task (for registry lookup)
:param args: Positional arguments for the task function
:param kwargs: Keyword arguments for the task function
:returns: Dict with status and task_uuid, or status "error" and a message if the task is not found
"""
from superset.commands.tasks.internal_update import InternalStatusTransitionCommand
# Convert string UUID to native UUID (Celery deserializes as string)
native_uuid = UUID(task_uuid)
task = TaskDAO.find_one_or_none(uuid=native_uuid)
if not task:
logger.error("Task %s not found in metastore", task_uuid)
return {"status": "error", "message": "Task not found"}
# AUTOMATIC PRE-EXECUTION CHECK: Don't execute if already aborted/aborting
if task.status in ABORT_STATES:
logger.info(
"Task %s (uuid=%s) was aborted before execution started",
task_type,
task_uuid,
)
# Atomic transition to ABORTED (if not already)
InternalStatusTransitionCommand(
task_uuid=native_uuid,
new_status=TaskStatus.ABORTED,
expected_status=[TaskStatus.PENDING, TaskStatus.ABORTING],
set_ended_at=True,
).run()
return {"status": TaskStatus.ABORTED.value, "task_uuid": task_uuid}
# Atomic transition: PENDING → IN_PROGRESS (set started_at for duration tracking)
if not InternalStatusTransitionCommand(
task_uuid=native_uuid,
new_status=TaskStatus.IN_PROGRESS,
expected_status=TaskStatus.PENDING,
set_started_at=True,
).run():
# Status wasn't PENDING - task may have been aborted concurrently
logger.warning(
"Task %s (uuid=%s) failed PENDING → IN_PROGRESS transition "
"(may have been aborted concurrently)",
task_type,
task_uuid,
)
refreshed = TaskDAO.find_one_or_none(uuid=native_uuid)
return {
"status": refreshed.status if refreshed else "unknown",
"task_uuid": task_uuid,
}
# Update cached status (no DB read needed - we just wrote IN_PROGRESS)
task.status = TaskStatus.IN_PROGRESS.value
# Build context from task (includes user who created the task)
ctx = TaskContext(task)
# Start timeout timer if configured (timer starts from execution time)
if timeout := task.properties_dict.get("timeout"):
ctx.start_timeout_timer(timeout)
logger.debug(
"Started timeout timer for task %s: %d seconds",
task_uuid,
timeout,
)
try:
# Get registered executor function
executor_fn = TaskRegistry.get_executor(task_type)
logger.info(
"Executing task %s (uuid=%s) with function %s.%s",
task_type,
task_uuid,
executor_fn.__module__,
executor_fn.__name__,
)
# Execute with ambient context (no ctx parameter!)
with use_context(ctx):
executor_fn(*args, **kwargs)
# Mark execution as completed to prevent late abort handlers
ctx.mark_execution_completed()
# Determine terminal status based on abort detection
# Use atomic conditional updates to prevent overwriting concurrent abort
if ctx._abort_detected or ctx.timeout_triggered:
# Abort was detected - will be handled in finally block
pass
else:
# Normal completion - also allow ABORTING → SUCCESS for late abort
# (task finished before abort was detected)
if InternalStatusTransitionCommand(
task_uuid=native_uuid,
new_status=TaskStatus.SUCCESS,
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
set_ended_at=True,
).run():
# Emit stats metric for success
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
stats_logger.incr("gtf.task.success")
logger.info(
"Task %s (uuid=%s) completed successfully", task_type, task_uuid
)
else:
# Transition failed - task was likely already in a terminal state
logger.info(
"Task %s (uuid=%s) completion transition failed "
"(task may already be in terminal state)",
task_type,
task_uuid,
)
except Exception as ex:
# Mark execution as completed to prevent late abort handlers
ctx.mark_execution_completed()
# Atomic transition to FAILURE (only if still IN_PROGRESS or ABORTING)
InternalStatusTransitionCommand(
task_uuid=native_uuid,
new_status=TaskStatus.FAILURE,
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
properties={"error_message": str(ex)},
set_ended_at=True,
).run()
logger.error(
"Task %s (uuid=%s) failed with error: %s",
task_type,
task_uuid,
str(ex),
exc_info=True,
)
# Emit stats metric for failure
stats_logger = current_app.config["STATS_LOGGER"]
stats_logger.incr("gtf.task.failure")
finally:
# ALWAYS run cleanup handlers (also stops timeout timer)
ctx._run_cleanup()
# Handle abort/timeout terminal transitions
# Use atomic updates to safely transition ABORTING → terminal state
if ctx._abort_detected or ctx.timeout_triggered:
if ctx.abort_handlers_completed:
# All handlers succeeded - determine terminal state based on cause
if ctx.timeout_triggered:
InternalStatusTransitionCommand(
task_uuid=native_uuid,
new_status=TaskStatus.TIMED_OUT,
expected_status=TaskStatus.ABORTING,
set_ended_at=True,
).run()
logger.info(
"Task %s (uuid=%s) timed out and completed cleanup",
task_type,
task_uuid,
)
else:
InternalStatusTransitionCommand(
task_uuid=native_uuid,
new_status=TaskStatus.ABORTED,
expected_status=TaskStatus.ABORTING,
set_ended_at=True,
).run()
logger.info(
"Task %s (uuid=%s) was aborted by user",
task_type,
task_uuid,
)
else:
# Handlers didn't complete successfully - mark as FAILURE
InternalStatusTransitionCommand(
task_uuid=native_uuid,
new_status=TaskStatus.FAILURE,
expected_status=TaskStatus.ABORTING,
properties={"error_message": "Abort handlers did not complete"},
set_ended_at=True,
).run()
logger.warning(
"Task %s (uuid=%s) stuck in ABORTING - marking as FAILURE",
task_type,
task_uuid,
)
# Refresh to get final status for return value and completion notification
refreshed = TaskDAO.find_one_or_none(uuid=native_uuid)
final_status = refreshed.status if refreshed else "unknown"
# Publish completion notification for any waiters (e.g., sync callers)
if final_status in TERMINAL_STATES:
TaskManager.publish_completion(native_uuid, final_status)
return {"status": final_status, "task_uuid": task_uuid}
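The executor above relies on atomic conditional status updates: a transition is applied only if the current status is one of the expected values, and the caller learns whether it won. The internals of `InternalStatusTransitionCommand` are not shown in this section, so the sketch below is an assumption about one common way to get that behavior, a guarded `UPDATE ... WHERE status IN (...)`. The table layout, the `transition_status` helper, and the status strings are all hypothetical.

```python
import sqlite3
from typing import Sequence


def transition_status(
    conn: sqlite3.Connection,
    task_uuid: str,
    new_status: str,
    expected: Sequence[str],
) -> bool:
    """Set new_status only if the current status is one of `expected`."""
    placeholders = ", ".join("?" for _ in expected)
    cursor = conn.execute(
        f"UPDATE tasks SET status = ? WHERE uuid = ? AND status IN ({placeholders})",
        (new_status, task_uuid, *expected),
    )
    conn.commit()
    # rowcount is 1 when the guard matched and the row changed, 0 otherwise.
    return cursor.rowcount == 1


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tasks (uuid TEXT PRIMARY KEY, status TEXT NOT NULL)")
conn.execute("INSERT INTO tasks VALUES ('t1', 'pending')")

# A worker claims the task; a second claim attempt is rejected.
started = transition_status(conn, "t1", "in_progress", ["pending"])
started_again = transition_status(conn, "t1", "in_progress", ["pending"])

# An abort request arrives while the task runs, then cleanup finishes.
aborting = transition_status(conn, "t1", "aborting", ["in_progress"])
aborted = transition_status(conn, "t1", "aborted", ["aborting"])

# A late success write cannot overwrite the terminal state.
late_success = transition_status(conn, "t1", "success", ["in_progress", "aborting"])
final_status = conn.execute("SELECT status FROM tasks WHERE uuid = 't1'").fetchone()[0]
```

Because the check and the write happen in a single statement, two writers racing on the same row cannot both succeed, which is the property the executor depends on when an abort arrives mid-run.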

superset/tasks/schemas.py Normal file

@@ -0,0 +1,200 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task API schemas"""
from marshmallow import fields, Schema
from marshmallow.fields import Method
# RISON/JSON schemas for query parameters
get_delete_ids_schema = {"type": "array", "items": {"type": "string"}}
# Field descriptions
uuid_description = "The unique identifier (UUID) of the task"
task_key_description = "The task identifier used for deduplication"
task_type_description = (
"The type of task (e.g., 'sql_execution', 'thumbnail_generation')"
)
task_name_description = "Human-readable name for the task"
status_description = "Current status of the task"
created_on_description = "Timestamp when the task was created"
changed_on_description = "Timestamp when the task was last updated"
started_at_description = "Timestamp when the task started execution"
ended_at_description = "Timestamp when the task completed or failed"
created_by_description = "User who created the task"
user_id_description = "ID of the user context for task execution"
payload_description = "Task-specific data in JSON format"
properties_description = (
"Runtime state and execution config. Contains: is_abortable, progress_percent, "
"progress_current, progress_total, error_message, exception_type, stack_trace, "
"timeout"
)
duration_seconds_description = (
"Duration in seconds - for finished tasks: execution time, "
"for running tasks: time since start, for pending: queue time"
)
scope_description = (
"Task scope: 'private' (user-specific), 'shared' (multi-user), "
"or 'system' (admin-only)"
)
subscriber_count_description = (
"Number of users subscribed to this task (for shared tasks)"
)
subscribers_description = "List of users subscribed to this task (for shared tasks)"
class UserSchema(Schema):
"""Schema for user information"""
id = fields.Int()
first_name = fields.String()
last_name = fields.String()
class TaskResponseSchema(Schema):
"""
Schema for task response.
Used for both list and detail endpoints.
"""
id = fields.Int(metadata={"description": "Internal task ID"})
uuid = fields.UUID(metadata={"description": uuid_description})
task_key = fields.String(metadata={"description": task_key_description})
task_type = fields.String(metadata={"description": task_type_description})
task_name = fields.String(
metadata={"description": task_name_description}, allow_none=True
)
status = fields.String(metadata={"description": status_description})
created_on = fields.DateTime(metadata={"description": created_on_description})
created_on_delta_humanized = Method(
"get_created_on_delta_humanized",
metadata={"description": "Humanized time since creation"},
)
changed_on = fields.DateTime(metadata={"description": changed_on_description})
changed_by = fields.Nested(UserSchema, allow_none=True)
started_at = fields.DateTime(
metadata={"description": started_at_description}, allow_none=True
)
ended_at = fields.DateTime(
metadata={"description": ended_at_description}, allow_none=True
)
created_by = fields.Nested(UserSchema, allow_none=True)
user_id = fields.Int(metadata={"description": user_id_description}, allow_none=True)
payload = Method("get_payload_dict", metadata={"description": payload_description})
properties = Method(
"get_properties", metadata={"description": properties_description}
)
duration_seconds = Method(
"get_duration",
metadata={"description": duration_seconds_description},
)
scope = fields.String(metadata={"description": scope_description})
subscriber_count = Method(
"get_subscriber_count", metadata={"description": subscriber_count_description}
)
subscribers = Method(
"get_subscribers", metadata={"description": subscribers_description}
)
def get_payload_dict(self, obj: object) -> dict[str, object] | None:
"""Get payload as dictionary"""
return obj.payload_dict # type: ignore[attr-defined]
def get_properties(self, obj: object) -> dict[str, object]:
"""Get properties dict, filtering stack_trace if SHOW_STACKTRACE is disabled."""
from flask import current_app
properties = dict(obj.properties_dict) # type: ignore[attr-defined]
# Remove stack_trace unless SHOW_STACKTRACE is enabled
if not current_app.config.get("SHOW_STACKTRACE", False):
properties.pop("stack_trace", None)
return properties
def get_duration(self, obj: object) -> float | None:
"""Get duration in seconds"""
return obj.duration_seconds # type: ignore[attr-defined]
def get_created_on_delta_humanized(self, obj: object) -> str:
"""Get humanized time since creation"""
return obj.created_on_delta_humanized() # type: ignore[attr-defined]
def get_subscriber_count(self, obj: object) -> int:
"""Get number of subscribers"""
return obj.subscriber_count # type: ignore[attr-defined]
def get_subscribers(self, obj: object) -> list[dict[str, object]]:
"""Get list of subscribers with user info"""
subscribers = []
for sub in obj.subscribers: # type: ignore[attr-defined]
subscribers.append(
{
"user_id": sub.user_id,
"first_name": sub.user.first_name if sub.user else None,
"last_name": sub.user.last_name if sub.user else None,
"subscribed_at": sub.subscribed_at.isoformat()
if sub.subscribed_at
else None,
}
)
return subscribers
class TaskStatusResponseSchema(Schema):
"""Schema for task status response (lightweight for polling)"""
status = fields.String(metadata={"description": status_description})
class TaskCancelRequestSchema(Schema):
"""Schema for task cancellation request"""
force = fields.Boolean(
load_default=False,
metadata={
"description": "Force cancel the task for all subscribers (admin only). "
"Only applicable for shared tasks with multiple subscribers."
},
)
class TaskCancelResponseSchema(Schema):
"""Schema for task cancellation response"""
message = fields.String(metadata={"description": "Success or status message"})
action = fields.String(
metadata={
"description": "The action taken: 'aborted' (task terminated) or "
"'unsubscribed' (user removed from shared task)"
}
)
task = fields.Nested(TaskResponseSchema, allow_none=True)
openapi_spec_methods_override = {
"get": {"get": {"summary": "Get a task detail"}},
"get_list": {
"get": {
"summary": "Get a list of tasks",
"description": "Gets a list of tasks for the current user. "
"Use Rison or JSON query parameters for filtering, sorting, "
"pagination and for selecting specific columns and metadata.",
}
},
"info": {"get": {"summary": "Get metadata information about this API resource"}},
}
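The `get_properties` method in `TaskResponseSchema` hides `stack_trace` unless `SHOW_STACKTRACE` is enabled. The same filtering can be sketched without Flask or marshmallow. `visible_properties` and its `show_stacktrace` argument are hypothetical names used for illustration; in the diff the flag comes from `current_app.config`.

```python
from __future__ import annotations

from typing import Any, Mapping


def visible_properties(
    properties: Mapping[str, Any], show_stacktrace: bool = False
) -> dict[str, Any]:
    """Return a copy of the properties that is safe to serialize for API clients."""
    visible = dict(properties)  # copy so the stored mapping is never mutated
    if not show_stacktrace:
        visible.pop("stack_trace", None)  # no error if the key is absent
    return visible


stored = {"error_message": "boom", "stack_trace": "Traceback ...", "timeout": 30}
public = visible_properties(stored)
debug = visible_properties(stored, show_stacktrace=True)
```

Copying before popping matters here: filtering the original dict in place would silently drop the stack trace from the task object for every later reader in the same request.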


@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import NamedTuple
from superset.utils.backports import StrEnum


@@ -18,16 +18,25 @@
from __future__ import annotations
import logging
import traceback
from http.client import HTTPResponse
from typing import cast, TYPE_CHECKING
from urllib import request
from uuid import UUID, uuid4
from celery.utils.log import get_task_logger
from flask import g
from superset_core.api.tasks import TaskProperties, TaskScope
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.types import ChosenExecutor, Executor, ExecutorType, FixedExecutor
from superset.tasks.types import (
ChosenExecutor,
Executor,
ExecutorType,
FixedExecutor,
)
from superset.utils import json
from superset.utils.hashing import hash_from_str
from superset.utils.urls import get_url_path
if TYPE_CHECKING:
@@ -123,7 +132,7 @@ def fetch_csrf_token(
response: HTTPResponse
with request.urlopen(req, timeout=600) as response: # noqa: S310
body = response.read().decode("utf-8")
session_cookie: Optional[str] = None
session_cookie: str | None = None
cookie_headers = response.headers.get_all("set-cookie")
if cookie_headers:
for cookie in cookie_headers:
@@ -142,3 +151,164 @@ def fetch_csrf_token(
logger.error("Error fetching CSRF token, status code: %s", response.status)
return {}
def generate_random_task_key() -> str:
"""
Generate a random task key.
This is the default behavior - each task submission gets a unique UUID
unless an explicit task_key is provided in TaskOptions.
:returns: A random UUID string
"""
return str(uuid4())
def get_active_dedup_key(
scope: TaskScope | str,
task_type: str,
task_key: str,
user_id: int | None = None,
) -> str:
"""
Build a deduplication key for active tasks.
The dedup_key enforces uniqueness at the database level via a unique index.
Active tasks use a composite key based on scope, which is then hashed using
the configured HASH_ALGORITHM to produce a fixed-length key.
The composite key format before hashing is:
- Private: private|task_type|task_key|user_id
- Shared: shared|task_type|task_key
- System: system|task_type|task_key
The final key is a hash digest (64 chars for sha256, 32 chars for md5).
:param scope: Task scope (PRIVATE/SHARED/SYSTEM) as TaskScope enum or string
:param task_type: Type of task (e.g., 'sql_execution')
:param task_key: Task identifier for deduplication
:param user_id: User ID (required for private tasks)
:returns: Hashed deduplication key string
:raises ValueError: If user_id is missing for private scope
"""
# Convert string to TaskScope if needed
if isinstance(scope, str):
scope = TaskScope(scope)
# Build composite key
match scope:
case TaskScope.PRIVATE:
if user_id is None:
raise ValueError("user_id required for private tasks")
composite_key = f"{scope.value}|{task_type}|{task_key}|{user_id}"
case TaskScope.SHARED:
composite_key = f"{scope.value}|{task_type}|{task_key}"
case TaskScope.SYSTEM:
composite_key = f"{scope.value}|{task_type}|{task_key}"
case _:
raise ValueError(f"Invalid scope: {scope}")
# Hash the composite key to produce a fixed-length dedup_key.
# Truncate to 64 chars so the value still fits the database column
# if the configured algorithm produces a longer digest.
return hash_from_str(composite_key)[:64]
def get_finished_dedup_key(task_uuid: UUID) -> str:
"""
Build a deduplication key for finished tasks.
When a task completes (success, failure, or abort), its dedup_key is
changed to its UUID. This frees up the slot so new tasks with the same
parameters can be created.
:param task_uuid: Task UUID (native UUID type)
:returns: The task UUID string as the dedup key
Example:
>>> from uuid import UUID
>>> get_finished_dedup_key(UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890"))
'a1b2c3d4-e5f6-7890-abcd-ef1234567890'
"""
return str(task_uuid)
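
To make the two dedup-key functions above more concrete, here is a minimal standalone sketch. It assumes the configured hash is SHA-256 and substitutes `hashlib.sha256` for `hash_from_str`, whose exact behavior is not shown in this diff; `sketch_active_dedup_key` and the sample task values are hypothetical names used only for illustration.

```python
import hashlib
from uuid import UUID


def sketch_active_dedup_key(scope, task_type, task_key, user_id=None):
    """Approximate get_active_dedup_key, assuming SHA-256 as the hash."""
    if scope == "private":
        if user_id is None:
            raise ValueError("user_id required for private tasks")
        # Private keys include the user, so two users never collide
        composite = f"{scope}|{task_type}|{task_key}|{user_id}"
    elif scope in ("shared", "system"):
        composite = f"{scope}|{task_type}|{task_key}"
    else:
        raise ValueError(f"Invalid scope: {scope}")
    return hashlib.sha256(composite.encode("utf-8")).hexdigest()[:64]


alice = sketch_active_dedup_key("private", "sql_execution", "report_42", user_id=1)
bob = sketch_active_dedup_key("private", "sql_execution", "report_42", user_id=2)
shared = sketch_active_dedup_key("shared", "sql_execution", "report_42")

# Once a task finishes, its key becomes its UUID, which frees the active slot
finished = str(UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890"))

print(len(alice), alice != bob, shared != alice, finished)
```

Two private submissions by the same user with the same `task_type` and `task_key` map to the same active key, which is what lets the unique index reject the duplicate while the first task is still active.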
# -----------------------------------------------------------------------------
# TaskProperties helper functions
# -----------------------------------------------------------------------------
def progress_update(progress: float | int | tuple[int, int]) -> TaskProperties:
"""
Create a properties update dict for progress values.
:param progress: One of:
- float (0.0-1.0): Percentage only
- int: Count only (total unknown)
- tuple[int, int]: (current, total) with auto-computed percentage
:returns: TaskProperties dict with appropriate progress fields set
Example:
task.update_properties(progress_update((50, 100)))
"""
if isinstance(progress, float):
return {"progress_percent": progress}
if isinstance(progress, int):
return {"progress_current": progress}
# tuple
current, total = progress
result: TaskProperties = {
"progress_current": current,
"progress_total": total,
}
if total > 0:
result["progress_percent"] = current / total
return result
def error_update(exception: BaseException) -> TaskProperties:
"""
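Create a properties update dict from an exception.
The stack trace comes from traceback.format_exc(), which formats the exception
currently being handled, so call this from inside an ``except`` block.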
:param exception: The exception that caused the failure
:returns: TaskProperties dict with error fields populated
"""
return {
"error_message": str(exception),
"exception_type": type(exception).__name__,
"stack_trace": traceback.format_exc(),
}
def parse_properties(json_str: str | None) -> TaskProperties:
"""
Parse JSON string into TaskProperties dict.
Returns an empty dict if the input is empty, is not valid JSON, or does
not decode to a JSON object. Unknown keys are preserved for forward
compatibility (allows adding new properties without breaking existing code).
:param json_str: JSON string or None
:returns: TaskProperties dict (sparse - only contains keys that were set)
"""
if not json_str:
return {}
try:
raw = json.loads(json_str)
if isinstance(raw, dict):
return cast(TaskProperties, raw)
return {}
except (json.JSONDecodeError, TypeError):
return {}
def serialize_properties(props: TaskProperties) -> str:
"""
Serialize TaskProperties to JSON string.
:param props: TaskProperties dict
:returns: JSON string
"""
return json.dumps(props)
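
The TaskProperties helpers above can be exercised with a small standalone sketch. It restates the three `progress_update` branches with plain dicts and adds a hypothetical `merge_properties` helper to show how a sparse update might be applied to a stored JSON string; how the framework itself merges updates is not shown in this excerpt, so treat that part as an assumption.

```python
import json


def sketch_progress_update(progress):
    """Mirror the float / int / (current, total) branches of progress_update."""
    if isinstance(progress, float):
        return {"progress_percent": progress}
    if isinstance(progress, int):
        return {"progress_current": progress}
    current, total = progress
    result = {"progress_current": current, "progress_total": total}
    if total > 0:
        # Percentage is only derived when the total is known and non-zero
        result["progress_percent"] = current / total
    return result


def merge_properties(stored_json, update):
    """Hypothetical helper: apply a sparse update to stored JSON properties."""
    try:
        props = json.loads(stored_json) if stored_json else {}
    except (json.JSONDecodeError, TypeError):
        props = {}
    if not isinstance(props, dict):
        props = {}
    props.update(update)
    return json.dumps(props)


stored = merge_properties(None, {"is_abortable": True})
stored = merge_properties(stored, sketch_progress_update((50, 100)))
print(json.loads(stored))
```

One detail worth keeping in mind: the int `1` is treated as a count of one item, while the float `1.0` is treated as 100 percent, so callers need to pass the type that matches their intent.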

View File

@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import hashlib
import logging
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from flask import current_app, Flask
from flask_caching import Cache
@@ -24,6 +26,12 @@ from markupsafe import Markup
from superset.utils.core import DatasourceType
if TYPE_CHECKING:
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
logger = logging.getLogger(__name__)
CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache"
@@ -185,6 +193,7 @@ class CacheManager:
self._thumbnail_cache = SupersetCache()
self._filter_state_cache = SupersetCache()
self._explore_form_data_cache = ExploreFormDataCache()
self._signal_cache: RedisCacheBackend | RedisSentinelCacheBackend | None = None
@staticmethod
def _init_cache(
@@ -226,6 +235,30 @@ class CacheManager:
"EXPLORE_FORM_DATA_CACHE_CONFIG",
required=True,
)
self._init_signal_cache(app)
def _init_signal_cache(self, app: Flask) -> None:
"""Initialize the signal cache for pub/sub and distributed locks."""
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
config = app.config.get("SIGNAL_CACHE_CONFIG")
if not config:
return
cache_type = config.get("CACHE_TYPE")
if cache_type == "RedisCache":
self._signal_cache = RedisCacheBackend.from_config(config)
elif cache_type == "RedisSentinelCache":
self._signal_cache = RedisSentinelCacheBackend.from_config(config)
else:
logger.warning(
"Unsupported CACHE_TYPE for SIGNAL_CACHE_CONFIG: %s. "
"Use 'RedisCache' or 'RedisSentinelCache'.",
cache_type,
)
@property
def data_cache(self) -> Cache:
@@ -246,3 +279,23 @@ class CacheManager:
@property
def explore_form_data_cache(self) -> Cache:
return self._explore_form_data_cache
@property
def signal_cache(
self,
) -> RedisCacheBackend | RedisSentinelCacheBackend | None:
"""
Return the signal cache backend.
Used for signaling features that require Redis-specific primitives:
- Pub/Sub messaging for real-time abort/completion notifications
- SET NX EX for atomic distributed lock acquisition
The backend provides:
- `._cache`: Raw Redis client
- `.key_prefix`: Configured key prefix (from CACHE_KEY_PREFIX)
- `.default_timeout`: Default timeout in seconds (from CACHE_DEFAULT_TIMEOUT)
Returns None if SIGNAL_CACHE_CONFIG is not configured.
"""
return self._signal_cache
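
The `signal_cache` docstring above mentions `SET NX EX` for atomic lock acquisition. The sketch below illustrates that pattern against a hypothetical in-memory stand-in, so it runs without a Redis server. `InMemoryClient`, `try_acquire`, `release` and the key name are illustrative assumptions and are not part of this change; the stand-in's `set(name, value, nx=..., ex=...)` call is shaped after the redis-py client method of the same name.

```python
import time
import uuid


class InMemoryClient:
    """Hypothetical stand-in that imitates SET NX EX, GET and DELETE for the demo."""

    def __init__(self):
        self._data = {}

    def set(self, name, value, nx=False, ex=None):
        now = time.monotonic()
        entry = self._data.get(name)
        if entry is not None and entry[1] is not None and entry[1] <= now:
            entry = None  # an expired key behaves as if it were absent
        if nx and entry is not None:
            return None  # key already held by another caller
        self._data[name] = (value, now + ex if ex is not None else None)
        return True

    def get(self, name):
        entry = self._data.get(name)
        if entry is None or (entry[1] is not None and entry[1] <= time.monotonic()):
            return None
        return entry[0]

    def delete(self, name):
        return 1 if self._data.pop(name, None) is not None else 0


def try_acquire(client, key, ttl_seconds):
    """Return an ownership token if the lock was acquired, otherwise None."""
    token = uuid.uuid4().hex
    return token if client.set(key, token, nx=True, ex=ttl_seconds) else None


def release(client, key, token):
    """Release only if we still own the lock (check-then-delete, not atomic here)."""
    if client.get(key) == token:
        client.delete(key)
        return True
    return False


client = InMemoryClient()
first = try_acquire(client, "signal_lock:refresh", ttl_seconds=30)
second = try_acquire(client, "signal_lock:refresh", ttl_seconds=30)
print(first is not None, second is None)
```

Because the set-if-absent and the expiry happen in one command, a crashed holder cannot leave the lock stuck forever; the TTL bounds how long other workers wait. A production release step would normally need to be atomic as well, which this sketch does not attempt.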

View File

@@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Callable, cast, Literal, TYPE_CHECKING
from typing import Any, Callable, cast, Literal
from flask import g, has_request_context, request
from flask_appbuilder.const import API_URI_RIS_KEY
@@ -34,9 +34,6 @@ from superset.extensions import stats_logger_manager
from superset.utils import json
from superset.utils.core import get_user_id, LoggerLevel, to_int
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)

View File

@@ -31,8 +31,8 @@ from flask import current_app as app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from superset import db
from superset.distributed_lock import KeyValueDistributedLock
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.distributed_lock import DistributedLock
from superset.exceptions import AcquireDistributedLockFailedException
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
if TYPE_CHECKING:
@@ -77,7 +77,7 @@ def generate_code_challenge(code_verifier: str) -> str:
@backoff.on_exception(
backoff.expo,
CreateKeyValueDistributedLockFailedException,
AcquireDistributedLockFailedException,
factor=10,
base=2,
max_tries=5,
@@ -128,8 +128,10 @@ def refresh_oauth2_token(
db_engine_spec: type[BaseEngineSpec],
token: DatabaseUserOAuth2Tokens,
) -> str | None:
with KeyValueDistributedLock(
# Use longer TTL for OAuth2 token refresh (may involve network calls)
with DistributedLock(
namespace="refresh_oauth2_token",
ttl_seconds=30,
user_id=user_id,
database_id=database_id,
):
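
For a sense of how long callers may wait when the lock is contended, the sketch below computes the nominal retry schedule implied by the `backoff.on_exception` arguments shown earlier in this file's diff (`factor=10`, `base=2`, `max_tries=5`). It assumes the exponential generator yields `factor * base**n` per retry and that any jitter applied by the library only shortens each wait; `nominal_backoff_delays` is a hypothetical helper, not part of the `backoff` package.

```python
def nominal_backoff_delays(factor, base, max_tries):
    """Upper-bound waits between attempts: factor * base**n for each retry."""
    # max_tries counts attempts, so there are max_tries - 1 waits between them
    return [factor * base**n for n in range(max_tries - 1)]


delays = nominal_backoff_delays(factor=10, base=2, max_tries=5)
print(delays, sum(delays))
```

Under those assumptions the first nominal wait is shorter than the 30-second `ttl_seconds` used for the refresh lock, so an early retry can still find the lock held.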

View File

@@ -14,32 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_appbuilder import expose, has_access
from __future__ import annotations
import logging
from typing import cast
from flask import current_app as app
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
from superset.daos.key_value import KeyValueDAO
from superset.distributed_lock.types import LockValue
logger = logging.getLogger(__name__)
stats_logger = app.config["STATS_LOGGER"]
from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP
from superset.superset_typing import FlaskResponse
from superset.views.base import BaseSupersetView
class GetDistributedLock(BaseDistributedLockCommand):
def validate(self) -> None:
pass
class TaskModelView(BaseSupersetView):
route_base = "/tasks"
class_permission_name = "Task"
method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
def run(self) -> LockValue | None:
entry = KeyValueDAO.get_entry(
resource=self.resource,
key=self.key,
)
if not entry or entry.is_expired():
return None
return cast(LockValue, self.codec.decode(entry.value))
@expose("/list/")
@has_access
def list(self) -> FlaskResponse:
return super().render_app_template()

View File

@@ -73,6 +73,7 @@ FEATURE_FLAGS = {
"AVOID_COLORS_COLLISION": True,
"DRILL_TO_DETAIL": True,
"DRILL_BY": True,
"GLOBAL_TASK_FRAMEWORK": True,
}
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"

View File

@@ -0,0 +1,538 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for Task REST API"""
from contextlib import contextmanager
from typing import Generator
import prison
from superset_core.api.tasks import TaskStatus
from superset import db
from superset.models.tasks import Task
from superset.utils import json
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import (
ADMIN_USERNAME,
GAMMA_USERNAME,
)
class TestTaskApi(SupersetTestCase):
"""Tests for Task REST API"""
TASK_API_BASE = "api/v1/task"
@contextmanager
def _create_tasks(self) -> Generator[list[Task], None, None]:
"""
Context manager to create test tasks with guaranteed cleanup.
Uses TaskDAO to create tasks, testing the actual production code path.
Usage:
with self._create_tasks() as tasks:
# Use tasks in test
# Cleanup happens automatically even if test fails
"""
from superset_core.api.tasks import TaskScope
from superset.daos.tasks import TaskDAO
admin = self.get_user("admin")
gamma = self.get_user("gamma")
tasks = []
try:
# Create tasks with different statuses using TaskDAO
for i in range(5):
task_key = f"test_task_{i}"
# Create task using DAO (this tests the dedup_key creation logic)
task = TaskDAO.create_task(
task_type="test_type",
task_key=task_key,
task_name=f"Test Task {i}",
scope=TaskScope.PRIVATE,
user_id=admin.id,
payload={"test": "data"},
)
# Set created_by for test purposes (DAO uses Flask-AppBuilder context)
task.created_by = admin
# Alternate between pending and finished tasks
if i % 2 != 0:
# Simulate realistic task lifecycle: PENDING → IN_PROGRESS → SUCCESS
# This sets both started_at (on IN_PROGRESS) and ended_at (on
# SUCCESS) so duration_seconds returns a valid value
task.set_status(TaskStatus.IN_PROGRESS)
task.set_status(TaskStatus.SUCCESS)
db.session.commit()
tasks.append(task)
# Create pending task for gamma user (use PENDING so it can be aborted)
gamma_task = TaskDAO.create_task(
task_type="test_type",
task_key="gamma_task",
task_name="Gamma Task",
scope=TaskScope.PRIVATE,
user_id=gamma.id,
payload={"user": "gamma"},
)
# Set created_by for test purposes
gamma_task.created_by = gamma
db.session.commit()
tasks.append(gamma_task)
yield tasks
finally:
# Cleanup happens here regardless of test success/failure
for task in tasks:
try:
db.session.delete(task)
except Exception: # noqa: S110
# Task may already be deleted or session may be in bad state
pass
try:
db.session.commit()
except Exception:
# Rollback if commit fails
db.session.rollback()
def test_info_task(self):
"""
Task API: Test info endpoint
"""
self.login(ADMIN_USERNAME)
uri = f"{self.TASK_API_BASE}/_info"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert "permissions" in data
def test_get_task_by_uuid(self):
"""
Task API: Test get task by UUID and verify dedup_key is hashed
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
admin = self.get_user("admin")
# Get a pending task to verify active dedup_key format
task = (
db.session.query(Task)
.filter_by(
created_by_fk=admin.id,
status=TaskStatus.PENDING.value,
task_type="test_type",
)
.first()
)
assert task is not None
# Verify active task has hashed dedup_key (64 chars for SHA-256)
assert len(task.dedup_key) == 64
assert all(c in "0123456789abcdef" for c in task.dedup_key)
assert task.dedup_key != str(task.uuid)
uri = f"{self.TASK_API_BASE}/{task.uuid}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
# Compare strings since JSON response contains string UUID
assert data["result"]["uuid"] == str(task.uuid)
assert data["result"]["id"] == task.id
def test_get_task_not_found(self):
"""
Task API: Test get task not found with non-existent UUID
"""
self.login(ADMIN_USERNAME)
# Use a valid UUID that doesn't exist in the database
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000"
rv = self.client.get(uri)
assert rv.status_code == 404
def test_get_task_invalid_uuid(self):
"""
Task API: Test get task with invalid UUID
"""
self.login(ADMIN_USERNAME)
uri = f"{self.TASK_API_BASE}/invalid-uuid"
rv = self.client.get(uri)
assert rv.status_code == 404
def test_get_task_list(self):
"""
Task API: Test get task list
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
uri = f"{self.TASK_API_BASE}/"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] >= 6 # At least the 5 admin tasks + 1 gamma task we created
assert "result" in data
def test_get_task_list_filtered_by_status(self):
"""
Task API: Test get task list filtered by status
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
arguments = {
"filters": [
{"col": "status", "opr": "eq", "value": TaskStatus.PENDING.value}
]
}
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
for task in data["result"]:
assert task["status"] == TaskStatus.PENDING.value
def test_get_task_list_filtered_by_type(self):
"""
Task API: Test get task list filtered by type
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
arguments = {
"filters": [{"col": "task_type", "opr": "eq", "value": "test_type"}]
}
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] >= 6
for task in data["result"]:
assert task["task_type"] == "test_type"
def test_get_task_list_ordered(self):
"""
Task API: Test get task list with ordering
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
arguments = {
"order_column": "created_on",
"order_direction": "desc",
}
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert len(data["result"]) > 0
def test_get_task_list_paginated(self):
"""
Task API: Test get task list with pagination
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
arguments = {"page": 0, "page_size": 2}
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert len(data["result"]) <= 2
assert data["count"] >= 6
def test_cancel_task_by_uuid(self):
"""
Task API: Test cancel task by UUID
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
admin = self.get_user("admin")
task = (
db.session.query(Task)
.filter_by(created_by_fk=admin.id, status=TaskStatus.PENDING.value)
.first()
)
assert task is not None
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
rv = self.client.post(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
# Compare strings since JSON response contains string UUID
assert data["task"]["uuid"] == str(task.uuid)
assert data["task"]["status"] == TaskStatus.ABORTED.value
assert data["action"] == "aborted"
def test_cancel_task_not_found(self):
"""
Task API: Test cancel task not found with non-existent UUID
"""
self.login(ADMIN_USERNAME)
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/cancel"
rv = self.client.post(uri)
assert rv.status_code == 404
def test_cancel_task_not_owned(self):
"""
Task API: Test cancel task not owned by user
"""
with self._create_tasks():
self.login(GAMMA_USERNAME)
admin = self.get_user("admin")
# Try to cancel admin's task as gamma user
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
assert task is not None
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
rv = self.client.post(uri)
assert rv.status_code == 404
def test_cancel_task_admin_can_cancel_others(self):
"""
Task API: Test admin can cancel other users' tasks
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
gamma = self.get_user("gamma")
# Admin cancels gamma's task
task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
assert task is not None
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
rv = self.client.post(uri)
assert rv.status_code == 200
def test_get_task_status_by_uuid(self):
"""
Task API: Test get task status by UUID
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
admin = self.get_user("admin")
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
assert task is not None
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert "status" in data
assert data["status"] == task.status
def test_get_task_status_not_found(self):
"""
Task API: Test get task status not found with non-existent UUID
"""
self.login(ADMIN_USERNAME)
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/status"
rv = self.client.get(uri)
assert rv.status_code == 404
def test_get_task_status_not_owned(self):
"""
Task API: Test non-owner can't see task status
"""
with self._create_tasks():
self.login(GAMMA_USERNAME)
admin = self.get_user("admin")
# Try to get status of admin's task as gamma user
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
assert task is not None
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
rv = self.client.get(uri)
# Returns 404 rather than 403: the base filter hides tasks the user cannot access
assert rv.status_code == 404
def test_get_task_status_admin_can_see_others(self):
"""
Task API: Test admin can see other users' task status
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
gamma = self.get_user("gamma")
# Admin gets gamma's task status
task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
assert task is not None
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["status"] == task.status
def test_get_task_list_user_sees_own_tasks(self):
"""
Task API: Test non-admin user only sees their own tasks
"""
with self._create_tasks():
self.login(GAMMA_USERNAME)
gamma = self.get_user("gamma")
uri = f"{self.TASK_API_BASE}/"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
# Gamma should only see their own task
for task in data["result"]:
assert task["created_by"]["id"] == gamma.id
def test_get_task_list_admin_sees_all_tasks(self):
"""
Task API: Test admin sees all tasks
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
uri = f"{self.TASK_API_BASE}/"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
# Admin should see all tasks
assert data["count"] >= 6
def test_task_response_schema(self):
"""
Task API: Test response schema includes all expected fields
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
admin = self.get_user("admin")
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
uri = f"{self.TASK_API_BASE}/{task.uuid}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
# Check all expected fields are present
expected_fields = [
"id",
"uuid",
"task_key",
"task_type",
"task_name",
"status",
"created_on",
"created_on_delta_humanized",
"changed_on",
"changed_by",
"started_at",
"ended_at",
"created_by",
"user_id",
"payload",
"properties",
"duration_seconds",
"scope",
"subscriber_count",
"subscribers",
]
for field in expected_fields:
assert field in result, f"Field {field} missing from response"
# Verify properties is a dict with expected structure
properties = result["properties"]
assert isinstance(properties, dict)
def test_task_payload_serialization(self):
"""
Task API: Test payload is properly serialized as dict
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
admin = self.get_user("admin")
task = (
db.session.query(Task)
.filter_by(created_by_fk=admin.id, task_type="test_type")
.first()
)
uri = f"{self.TASK_API_BASE}/{task.uuid}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
payload = data["result"]["payload"]
# Payload should be a dict, not a string
assert isinstance(payload, dict)
assert "test" in payload
assert payload["test"] == "data"
def test_task_computed_properties(self):
"""
Task API: Test computed properties in response
This test verifies that computed properties (status, duration_seconds)
are correctly returned in the API response. Internal DB columns like
dedup_key are tested in unit tests (test_find_by_task_key_finished_not_found).
"""
with self._create_tasks():
self.login(ADMIN_USERNAME)
admin = self.get_user("admin")
# Get a successful task
task = (
db.session.query(Task)
.filter_by(created_by_fk=admin.id, status=TaskStatus.SUCCESS.value)
.first()
)
assert task is not None
uri = f"{self.TASK_API_BASE}/{task.uuid}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
# Check status field (computed properties are now derived from status)
assert result["status"] == TaskStatus.SUCCESS.value
# Properties dict should exist and be a dict
assert "properties" in result
assert isinstance(result["properties"], dict)
# Verify duration_seconds is not null for completed tasks with timestamps
# (requires both started_at and ended_at to be set)
if result.get("started_at") and result.get("ended_at"):
assert result["duration_seconds"] is not None
assert result["duration_seconds"] >= 0.0

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,482 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import patch
from uuid import UUID, uuid4
import pytest
from superset_core.api.tasks import TaskScope, TaskStatus
from superset import db
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.commands.tasks.exceptions import (
TaskAbortFailedError,
TaskNotAbortableError,
TaskNotFoundError,
TaskPermissionDeniedError,
)
from superset.daos.tasks import TaskDAO
from superset.utils.core import override_user
from tests.integration_tests.test_app import app
def test_cancel_pending_task_aborts(app_context, get_user) -> None:
"""Test canceling a pending task directly aborts it"""
admin = get_user("admin")
# Create a pending private task
task = TaskDAO.create_task(
task_type="test_type",
task_key="cancel_pending_test",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
try:
# Cancel the pending task with admin user context
with override_user(admin):
command = CancelTaskCommand(task_uuid=task.uuid)
result = command.run()
# Verify task is aborted (pending goes directly to ABORTED)
assert result.uuid == task.uuid
assert result.status == TaskStatus.ABORTED.value
assert command.action_taken == "aborted"
# Verify in database
db.session.refresh(task)
assert task.status == TaskStatus.ABORTED.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_in_progress_abortable_task_sets_aborting(app_context, get_user) -> None:
"""Test canceling an in-progress task with abort handler sets ABORTING"""
admin = get_user("admin")
# Create an in-progress abortable task
task = TaskDAO.create_task(
task_type="test_type",
task_key="cancel_in_progress_test",
scope=TaskScope.PRIVATE,
user_id=admin.id,
properties={"is_abortable": True},
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Cancel the in-progress task - mock publish_abort to avoid Redis dependency
with (
override_user(admin),
patch("superset.tasks.manager.TaskManager.publish_abort"),
):
command = CancelTaskCommand(task_uuid=task.uuid)
result = command.run()
# In-progress tasks go to ABORTING (not ABORTED)
assert result.uuid == task.uuid
assert result.status == TaskStatus.ABORTING.value
assert command.action_taken == "aborted"
# Verify in database
db.session.refresh(task)
assert task.status == TaskStatus.ABORTING.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_in_progress_not_abortable_raises_error(app_context, get_user) -> None:
"""Test canceling an in-progress task without abort handler raises error"""
admin = get_user("admin")
unique_key = f"cancel_not_abortable_test_{uuid4().hex[:8]}"
# Create an in-progress non-abortable task
task = TaskDAO.create_task(
task_type="test_type",
task_key=unique_key,
scope=TaskScope.PRIVATE,
user_id=admin.id,
properties={"is_abortable": False},
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
with override_user(admin):
command = CancelTaskCommand(task_uuid=task.uuid)
with pytest.raises(TaskNotAbortableError):
command.run()
# Verify task status unchanged
db.session.refresh(task)
assert task.status == TaskStatus.IN_PROGRESS.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_task_not_found(app_context, get_user) -> None:
"""Test canceling non-existent task raises error"""
admin = get_user("admin")
with override_user(admin):
command = CancelTaskCommand(
task_uuid=UUID("00000000-0000-0000-0000-000000000000")
)
with pytest.raises(TaskNotFoundError):
command.run()
def test_cancel_finished_task_raises_error(app_context, get_user) -> None:
"""Test canceling an already finished task raises error"""
admin = get_user("admin")
unique_key = f"cancel_finished_test_{uuid4().hex[:8]}"
# Create a finished task
task = TaskDAO.create_task(
task_type="test_type",
task_key=unique_key,
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.SUCCESS)
db.session.commit()
try:
with override_user(admin):
command = CancelTaskCommand(task_uuid=task.uuid)
with pytest.raises(TaskAbortFailedError):
command.run()
# Verify task status unchanged
db.session.refresh(task)
assert task.status == TaskStatus.SUCCESS.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_shared_task_with_multiple_subscribers_unsubscribes(
app_context, get_user
) -> None:
"""Test canceling a shared task with multiple subscribers unsubscribes user"""
admin = get_user("admin")
gamma = get_user("gamma")
# Create a shared task with admin as creator
task = TaskDAO.create_task(
task_type="test_type",
task_key="cancel_shared_test",
scope=TaskScope.SHARED,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
# Add gamma as subscriber
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
db.session.commit()
try:
# Verify we have 2 subscribers
db.session.refresh(task)
assert task.subscriber_count == 2
# Cancel as gamma (non-admin subscriber)
with override_user(gamma):
command = CancelTaskCommand(task_uuid=task.uuid)
result = command.run()
# Should unsubscribe, not abort
assert command.action_taken == "unsubscribed"
assert result.status == TaskStatus.PENDING.value # Status unchanged
# Verify gamma was unsubscribed
db.session.refresh(task)
assert task.subscriber_count == 1
assert not task.has_subscriber(gamma.id)
assert task.has_subscriber(admin.id)
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_shared_task_last_subscriber_aborts(app_context, get_user) -> None:
"""Test canceling a shared task as last subscriber aborts it"""
admin = get_user("admin")
# Create a shared task with only admin as subscriber
task = TaskDAO.create_task(
task_type="test_type",
task_key="cancel_last_subscriber_test",
scope=TaskScope.SHARED,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
try:
# Verify only 1 subscriber
db.session.refresh(task)
assert task.subscriber_count == 1
# Cancel as the only subscriber
with override_user(admin):
command = CancelTaskCommand(task_uuid=task.uuid)
result = command.run()
# Should abort since last subscriber
assert command.action_taken == "aborted"
assert result.status == TaskStatus.ABORTED.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_with_force_aborts_for_all_subscribers(app_context, get_user) -> None:
"""Test force cancel aborts shared task even with multiple subscribers"""
admin = get_user("admin")
gamma = get_user("gamma")
# Create a shared task with multiple subscribers
task = TaskDAO.create_task(
task_type="test_type",
task_key="force_cancel_test",
scope=TaskScope.SHARED,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
# Add gamma as subscriber
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
db.session.commit()
try:
# Verify 2 subscribers
db.session.refresh(task)
assert task.subscriber_count == 2
# Force cancel as admin
with override_user(admin):
command = CancelTaskCommand(task_uuid=task.uuid, force=True)
result = command.run()
# Should abort despite multiple subscribers
assert command.action_taken == "aborted"
assert result.status == TaskStatus.ABORTED.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_with_force_requires_admin(app_context, get_user) -> None:
"""Test force cancel requires admin privileges"""
admin = get_user("admin")
gamma = get_user("gamma")
# Create a shared task
task = TaskDAO.create_task(
task_type="test_type",
task_key="force_requires_admin_test",
scope=TaskScope.SHARED,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
# Add gamma as subscriber
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
db.session.commit()
try:
# Try to force cancel as gamma (non-admin)
with override_user(gamma):
command = CancelTaskCommand(task_uuid=task.uuid, force=True)
with pytest.raises(TaskPermissionDeniedError):
command.run()
# Verify task unchanged
db.session.refresh(task)
assert task.status == TaskStatus.PENDING.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_private_task_permission_denied(app_context, get_user) -> None:
"""Test non-owner cannot cancel private task"""
admin = get_user("admin")
gamma = get_user("gamma")
unique_key = f"private_permission_test_{uuid4().hex[:8]}"
# Use test_request_context to ensure has_request_context() returns True
# so that TaskFilter properly applies permission filtering
with app.test_request_context():
# Create a private task owned by admin
task = TaskDAO.create_task(
task_type="test_type",
task_key=unique_key,
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
try:
# Try to cancel admin's private task as gamma (non-owner)
with override_user(gamma):
command = CancelTaskCommand(task_uuid=task.uuid)
# Should fail because gamma can't see admin's private task (base filter)
with pytest.raises(TaskNotFoundError):
command.run()
# Verify task unchanged
db.session.refresh(task)
assert task.status == TaskStatus.PENDING.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_system_task_requires_admin(app_context, get_user) -> None:
"""Test system tasks can only be canceled by admin"""
admin = get_user("admin")
gamma = get_user("gamma")
unique_key = f"system_task_test_{uuid4().hex[:8]}"
# Use test_request_context to ensure has_request_context() returns True
# so that TaskFilter properly applies permission filtering
with app.test_request_context():
# Create a system task
task = TaskDAO.create_task(
task_type="test_type",
task_key=unique_key,
scope=TaskScope.SYSTEM,
user_id=None,
)
task.created_by = admin
db.session.commit()
try:
# Try to cancel as gamma (non-admin)
with override_user(gamma):
command = CancelTaskCommand(task_uuid=task.uuid)
# System tasks are not visible to non-admins via base filter
with pytest.raises(TaskNotFoundError):
command.run()
# Verify task unchanged
db.session.refresh(task)
assert task.status == TaskStatus.PENDING.value
# But admin can cancel it
with override_user(admin):
command = CancelTaskCommand(task_uuid=task.uuid)
result = command.run()
assert result.status == TaskStatus.ABORTED.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_already_aborting_is_idempotent(app_context, get_user) -> None:
"""Test canceling an already aborting task is idempotent"""
admin = get_user("admin")
# Create a task already in ABORTING state
task = TaskDAO.create_task(
task_type="test_type",
task_key="idempotent_cancel_test",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.ABORTING)
db.session.commit()
try:
# Cancel the already aborting task
with override_user(admin):
command = CancelTaskCommand(task_uuid=task.uuid)
result = command.run()
# Should succeed without error
assert result.uuid == task.uuid
assert result.status == TaskStatus.ABORTING.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_cancel_shared_task_not_subscribed_raises_error(app_context, get_user) -> None:
"""Test non-subscriber cannot cancel shared task"""
admin = get_user("admin")
gamma = get_user("gamma")
# Create a shared task with only admin as subscriber
task = TaskDAO.create_task(
task_type="test_type",
task_key="not_subscribed_test",
scope=TaskScope.SHARED,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
try:
# Try to cancel as gamma (not subscribed)
with override_user(gamma):
command = CancelTaskCommand(task_uuid=task.uuid)
with pytest.raises(TaskPermissionDeniedError):
command.run()
# Verify task unchanged
db.session.refresh(task)
assert task.status == TaskStatus.PENDING.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
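
The shared-task cancel tests above imply a decision rule: a forced cancel needs admin rights, a non-subscriber cannot cancel, a subscriber leaving a task that still has other subscribers is only unsubscribed, and the last subscriber (or a forced cancel) aborts the task. A minimal sketch of that rule follows. It is inferred from the assertions only and is not the actual `CancelTaskCommand` implementation; `SketchTask`, `PermissionDenied`, and `decide_cancel_action` are hypothetical names, and the base-filter visibility checks that the private and system task tests exercise are deliberately left out.

```python
from dataclasses import dataclass, field
from typing import Set


class PermissionDenied(Exception):
    """Hypothetical stand-in for TaskPermissionDeniedError."""


@dataclass
class SketchTask:
    # Hypothetical model: only the fields the decision needs.
    scope: str  # "private", "shared" or "system"
    subscribers: Set[int] = field(default_factory=set)


def decide_cancel_action(
    task: SketchTask, user_id: int, is_admin: bool, force: bool = False
) -> str:
    """Return "unsubscribed" or "aborted", updating subscribers if needed."""
    if force and not is_admin:
        raise PermissionDenied("force cancel requires admin")
    if task.scope == "shared" and not force:
        if user_id not in task.subscribers:
            raise PermissionDenied("user is not subscribed to this task")
        if len(task.subscribers) > 1:
            # Others still depend on the task: drop this user only.
            task.subscribers.discard(user_id)
            return "unsubscribed"
    # Last subscriber, forced cancel, or a non-shared task: abort it.
    return "aborted"
```

How an admin who is not subscribed to a shared task is treated without `force` is not covered by the tests, so the sketch's behavior in that case is an assumption.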

@@ -0,0 +1,419 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for internal task state update commands."""
from uuid import UUID
from superset_core.api.tasks import TaskScope, TaskStatus
from superset import db
from superset.commands.tasks.internal_update import (
InternalStatusTransitionCommand,
InternalUpdateTaskCommand,
)
from superset.daos.tasks import TaskDAO
def test_internal_update_properties(app_context, get_user, login_as) -> None:
"""Test updating only properties without reading task first."""
admin = get_user("admin")
login_as("admin")
task = TaskDAO.create_task(
task_type="test_type",
task_key="internal_update_props",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Perform zero-read update
command = InternalUpdateTaskCommand(
task_uuid=task.uuid,
properties={"is_abortable": True, "progress_percent": 0.5},
)
result = command.run()
assert result is True
# Verify in database
db.session.refresh(task)
assert task.properties_dict.get("is_abortable") is True
assert task.properties_dict.get("progress_percent") == 0.5
finally:
db.session.delete(task)
db.session.commit()
def test_internal_update_payload(app_context, get_user, login_as) -> None:
"""Test updating only payload without reading task first."""
admin = get_user("admin")
login_as("admin")
task = TaskDAO.create_task(
task_type="test_type",
task_key="internal_update_payload",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Perform zero-read update
command = InternalUpdateTaskCommand(
task_uuid=task.uuid,
payload={"custom_key": "value", "count": 42},
)
result = command.run()
assert result is True
# Verify in database
db.session.refresh(task)
assert task.payload_dict == {"custom_key": "value", "count": 42}
finally:
db.session.delete(task)
db.session.commit()
def test_internal_update_both_properties_and_payload(
app_context, get_user, login_as
) -> None:
"""Test updating both properties and payload in one call."""
admin = get_user("admin")
login_as("admin")
task = TaskDAO.create_task(
task_type="test_type",
task_key="internal_update_both",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Perform zero-read update of both
command = InternalUpdateTaskCommand(
task_uuid=task.uuid,
properties={"progress_current": 50, "progress_total": 100},
payload={"last_item": "xyz"},
)
result = command.run()
assert result is True
# Verify in database
db.session.refresh(task)
assert task.properties_dict.get("progress_current") == 50
assert task.properties_dict.get("progress_total") == 100
assert task.payload_dict == {"last_item": "xyz"}
finally:
db.session.delete(task)
db.session.commit()
def test_internal_update_returns_false_for_nonexistent_task(
app_context, login_as
) -> None:
"""Test that updating non-existent task returns False."""
login_as("admin")
command = InternalUpdateTaskCommand(
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
properties={"is_abortable": True},
)
result = command.run()
assert result is False
def test_internal_update_returns_false_when_nothing_to_update(
app_context, get_user, login_as
) -> None:
"""Test that passing no properties or payload returns False early."""
admin = get_user("admin")
login_as("admin")
task = TaskDAO.create_task(
task_type="test_type",
task_key="internal_update_empty",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
try:
# No properties or payload provided
command = InternalUpdateTaskCommand(
task_uuid=task.uuid,
properties=None,
payload=None,
)
result = command.run()
assert result is False
finally:
db.session.delete(task)
db.session.commit()
def test_internal_update_does_not_change_status(
app_context, get_user, login_as
) -> None:
"""Test that internal update leaves status unchanged (safe for concurrent abort)."""
admin = get_user("admin")
login_as("admin")
task = TaskDAO.create_task(
task_type="test_type",
task_key="internal_update_status",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Update properties - status should not change
command = InternalUpdateTaskCommand(
task_uuid=task.uuid,
properties={"progress_percent": 0.75},
)
result = command.run()
assert result is True
# Verify status unchanged
db.session.refresh(task)
assert task.status == TaskStatus.IN_PROGRESS.value
finally:
db.session.delete(task)
db.session.commit()
def test_internal_update_replaces_entire_properties(
app_context, get_user, login_as
) -> None:
"""Test that internal update replaces properties entirely (no merge)."""
admin = get_user("admin")
login_as("admin")
task = TaskDAO.create_task(
task_type="test_type",
task_key="internal_update_replace",
scope=TaskScope.PRIVATE,
user_id=admin.id,
properties={"is_abortable": True, "timeout": 300},
)
task.created_by = admin
db.session.commit()
try:
# Replace with new properties (caller is responsible for merging if needed)
command = InternalUpdateTaskCommand(
task_uuid=task.uuid,
properties={"error_message": "new_value"},
)
result = command.run()
assert result is True
# Verify entire replacement occurred
db.session.refresh(task)
# The caller should have merged if they wanted to preserve is_abortable
assert task.properties_dict == {"error_message": "new_value"}
assert "is_abortable" not in task.properties_dict
assert "timeout" not in task.properties_dict
finally:
db.session.delete(task)
db.session.commit()
# =============================================================================
# InternalStatusTransitionCommand Tests
# =============================================================================
def test_status_transition_atomic_compare_and_swap(
app_context, get_user, login_as
) -> None:
"""Test atomic conditional status transitions with comprehensive scenarios.
Covers: success case, failure case, list of expected statuses, properties update,
ended_at timestamp, and string status values.
"""
admin = get_user("admin")
login_as("admin")
task = TaskDAO.create_task(
task_type="test_type",
task_key="status_transition_comprehensive",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
db.session.commit()
try:
# 1. SUCCESS CASE: PENDING → IN_PROGRESS (expected matches)
result = InternalStatusTransitionCommand(
task_uuid=task.uuid,
new_status=TaskStatus.IN_PROGRESS,
expected_status=TaskStatus.PENDING,
).run()
assert result is True
db.session.refresh(task)
assert task.status == TaskStatus.IN_PROGRESS.value
# 2. FAILURE CASE: Try wrong expected status (should fail, status unchanged)
result = InternalStatusTransitionCommand(
task_uuid=task.uuid,
new_status=TaskStatus.SUCCESS,
expected_status=TaskStatus.PENDING, # Wrong! Current is IN_PROGRESS
).run()
assert result is False
db.session.refresh(task)
assert task.status == TaskStatus.IN_PROGRESS.value # Unchanged
# 3. LIST OF EXPECTED: Transition with multiple acceptable source statuses
task.set_status(TaskStatus.ABORTING)
db.session.commit()
result = InternalStatusTransitionCommand(
task_uuid=task.uuid,
new_status=TaskStatus.FAILURE,
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
properties={"error_message": "Test error"},
).run()
assert result is True
db.session.refresh(task)
assert task.status == TaskStatus.FAILURE.value
assert task.properties_dict.get("error_message") == "Test error"
# 4. ENDED_AT: Reset to IN_PROGRESS and test ended_at timestamp
task.set_status(TaskStatus.IN_PROGRESS)
task.ended_at = None
db.session.commit()
assert task.ended_at is None
result = InternalStatusTransitionCommand(
task_uuid=task.uuid,
new_status=TaskStatus.SUCCESS,
expected_status=TaskStatus.IN_PROGRESS,
set_ended_at=True,
).run()
assert result is True
db.session.refresh(task)
assert task.status == TaskStatus.SUCCESS.value
assert task.ended_at is not None
# 5. STRING VALUES: Reset and test string status values
task.set_status(TaskStatus.PENDING)
db.session.commit()
result = InternalStatusTransitionCommand(
task_uuid=task.uuid,
new_status="in_progress",
expected_status="pending",
).run()
assert result is True
db.session.refresh(task)
assert task.status == "in_progress"
finally:
db.session.delete(task)
db.session.commit()
def test_status_transition_prevents_race_condition(
app_context, get_user, login_as
) -> None:
"""Test that conditional update prevents overwriting concurrent abort.
This is the key race condition fix: if task is aborted concurrently,
the executor's attempt to set SUCCESS should fail (return False),
preserving the ABORTING state.
"""
admin = get_user("admin")
login_as("admin")
task = TaskDAO.create_task(
task_type="test_type",
task_key="status_transition_race",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Simulate concurrent abort: directly set ABORTING in DB
# (as if CancelTaskCommand ran in another process)
task.set_status(TaskStatus.ABORTING)
db.session.commit()
# Executor tries to set SUCCESS (expecting IN_PROGRESS) - stale expectation
result = InternalStatusTransitionCommand(
task_uuid=task.uuid,
new_status=TaskStatus.SUCCESS,
expected_status=TaskStatus.IN_PROGRESS,
).run()
# Should fail - task was aborted concurrently
assert result is False
# Verify ABORTING is preserved (not overwritten to SUCCESS)
db.session.refresh(task)
assert task.status == TaskStatus.ABORTING.value
# Verify correct transition from ABORTING still works
result = InternalStatusTransitionCommand(
task_uuid=task.uuid,
new_status=TaskStatus.ABORTED,
expected_status=TaskStatus.ABORTING,
set_ended_at=True,
).run()
assert result is True
db.session.refresh(task)
assert task.status == TaskStatus.ABORTED.value
finally:
db.session.delete(task)
db.session.commit()
def test_status_transition_nonexistent_task(app_context, login_as) -> None:
"""Test that transitioning non-existent task returns False."""
login_as("admin")
result = InternalStatusTransitionCommand(
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
new_status=TaskStatus.IN_PROGRESS,
expected_status=TaskStatus.PENDING,
).run()
assert result is False
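
The status transition tests above rely on an atomic compare-and-swap: the new status is written only if the row still holds one of the expected statuses, and the caller learns from the affected row count whether the write happened. A minimal sketch of that technique using the standard-library `sqlite3` module follows. The `tasks` table layout and the `transition_status` helper are assumptions for illustration; the real `InternalStatusTransitionCommand` may be implemented differently.

```python
import sqlite3


def transition_status(conn, task_uuid, new_status, expected_status):
    """Set status only if the current status is one of the expected ones."""
    if isinstance(expected_status, str):
        expected = [expected_status]
    else:
        expected = list(expected_status)
    placeholders = ", ".join("?" for _ in expected)
    # The WHERE clause is the guard: a stale expectation matches zero rows.
    cursor = conn.execute(
        f"UPDATE tasks SET status = ? WHERE uuid = ? "
        f"AND status IN ({placeholders})",
        [new_status, task_uuid, *expected],
    )
    conn.commit()
    # One row changed means the swap happened; zero means the task is
    # missing or its status was changed concurrently.
    return cursor.rowcount == 1


# Hypothetical single-column schema, just enough to exercise the guard.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tasks (uuid TEXT PRIMARY KEY, status TEXT NOT NULL)")
conn.execute("INSERT INTO tasks VALUES ('t1', 'in_progress')")
conn.commit()
```

Because the check and the write happen in one statement, an executor that still expects `in_progress` cannot overwrite an `aborting` status set by another process in the meantime.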

@@ -0,0 +1,258 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime, timezone
from unittest.mock import patch
from freezegun import freeze_time
from superset_core.api.tasks import TaskScope, TaskStatus
from superset import db
from superset.commands.tasks import TaskPruneCommand
from superset.daos.tasks import TaskDAO
from superset.models.tasks import Task
@freeze_time("2024-02-15")
@patch("superset.tasks.utils.get_current_user")
def test_prune_tasks_success(mock_get_user, app_context, get_user, login_as) -> None:
"""Test successful pruning of old completed tasks"""
login_as("admin")
admin = get_user("admin")
mock_get_user.return_value = admin.username
# Create old completed tasks (35 days ago = Jan 11, 2024)
old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
task_ids = []
for i in range(3):
task = TaskDAO.create_task(
task_type="test_type",
task_key=f"prune_task_{i}",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.SUCCESS)
task.ended_at = old_date
task_ids.append(task.id)
# Create a recent task (5 days ago = Feb 10, 2024) that should NOT be deleted
recent_date = datetime(2024, 2, 10, tzinfo=timezone.utc)
recent_task = TaskDAO.create_task(
task_type="test_type",
task_key="recent_task",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
recent_task.created_by = admin
recent_task.set_status(TaskStatus.SUCCESS)
recent_task.ended_at = recent_date
recent_task_id = recent_task.id
db.session.commit()
try:
# Prune tasks older than 30 days
command = TaskPruneCommand(retention_period_days=30)
command.run()
# Verify old tasks are deleted
for task_id in task_ids:
assert db.session.get(Task, task_id) is None
# Verify recent task is NOT deleted
assert db.session.get(Task, recent_task_id) is not None
finally:
# Cleanup remaining tasks
for task_id in task_ids:
existing = db.session.get(Task, task_id)
if existing:
db.session.delete(existing)
if db.session.get(Task, recent_task_id):
db.session.delete(db.session.get(Task, recent_task_id))
db.session.commit()
@freeze_time("2024-02-15")
@patch("superset.tasks.utils.get_current_user")
def test_prune_tasks_with_max_rows(
mock_get_user, app_context, get_user, login_as
) -> None:
"""Test pruning with max_rows_per_run limit"""
login_as("admin")
admin = get_user("admin")
mock_get_user.return_value = admin.username
# Create old completed tasks (35 days ago = Jan 11, 2024)
task_ids = []
for i in range(5):
# Different ages for ordering (older tasks have smaller hour values)
old_date = datetime(2024, 1, 11, i, tzinfo=timezone.utc)
task = TaskDAO.create_task(
task_type="test_type",
task_key=f"max_rows_task_{i}",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.SUCCESS)
task.ended_at = old_date
task_ids.append(task.id)
db.session.commit()
try:
# Prune with max_rows_per_run=2 (should only delete 2 oldest)
command = TaskPruneCommand(retention_period_days=30, max_rows_per_run=2)
command.run()
# Count remaining tasks
remaining = sum(
1 for task_id in task_ids if db.session.get(Task, task_id) is not None
)
assert remaining == 3 # 5 - 2 = 3 remaining
finally:
# Cleanup remaining tasks
for task_id in task_ids:
existing = db.session.get(Task, task_id)
if existing:
db.session.delete(existing)
db.session.commit()
@freeze_time("2024-02-15")
@patch("superset.tasks.utils.get_current_user")
def test_prune_does_not_delete_pending_tasks(
mock_get_user, app_context, get_user, login_as
) -> None:
"""Test that pruning does not delete pending or in-progress tasks"""
login_as("admin")
admin = get_user("admin")
mock_get_user.return_value = admin.username
pending_task = TaskDAO.create_task(
task_type="test_type",
task_key="pending_task",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
pending_task.created_by = admin
# Keep as PENDING (no ended_at)
in_progress_task = TaskDAO.create_task(
task_type="test_type",
task_key="in_progress_task",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
in_progress_task.created_by = admin
in_progress_task.set_status(TaskStatus.IN_PROGRESS)
# No ended_at for in-progress tasks
db.session.commit()
try:
# Prune tasks older than 30 days
command = TaskPruneCommand(retention_period_days=30)
command.run()
# Verify non-completed tasks are NOT deleted
assert db.session.get(Task, pending_task.id) is not None
assert db.session.get(Task, in_progress_task.id) is not None
finally:
# Cleanup
for task in [pending_task, in_progress_task]:
existing = db.session.get(Task, task.id)
if existing:
db.session.delete(existing)
db.session.commit()
@freeze_time("2024-02-15")
@patch("superset.tasks.utils.get_current_user")
def test_prune_deletes_all_completed_statuses(
    mock_get_user, app_context, get_user, login_as
) -> None:
    """Test pruning deletes SUCCESS, FAILURE, and ABORTED tasks"""
    login_as("admin")
    admin = get_user("admin")
    mock_get_user.return_value = admin.username
    old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
    # Create tasks with different completed statuses
    success_task = TaskDAO.create_task(
        task_type="test_type",
        task_key="success_task",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    success_task.created_by = admin
    success_task.set_status(TaskStatus.SUCCESS)
    success_task.ended_at = old_date
    success_task_id = success_task.id
    failure_task = TaskDAO.create_task(
        task_type="test_type",
        task_key="failure_task",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    failure_task.created_by = admin
    failure_task.set_status(TaskStatus.FAILURE)
    failure_task.ended_at = old_date
    failure_task_id = failure_task.id
    aborted_task = TaskDAO.create_task(
        task_type="test_type",
        task_key="aborted_task",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    aborted_task.created_by = admin
    aborted_task.set_status(TaskStatus.ABORTED)
    aborted_task.ended_at = old_date
    aborted_task_id = aborted_task.id
    db.session.commit()
    task_ids = [success_task_id, failure_task_id, aborted_task_id]
    try:
        # Prune tasks older than 30 days
        command = TaskPruneCommand(retention_period_days=30)
        command.run()
        # Verify all completed tasks are deleted
        for task_id in task_ids:
            assert db.session.get(Task, task_id) is None
    finally:
        # Cleanup any tasks left behind, whether an assertion failed or the
        # prune command itself raised
        for task_id in task_ids:
            existing = db.session.get(Task, task_id)
            if existing:
                db.session.delete(existing)
        db.session.commit()
def test_prune_no_tasks_to_delete(app_context, login_as) -> None:
"""Test pruning when no old tasks exist"""
login_as("admin")
# Don't create any tasks - should handle gracefully
command = TaskPruneCommand(retention_period_days=30)
command.run() # Should not raise any errors
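
Taken together, the prune tests above describe a selection rule: only tasks in a completed status whose `ended_at` is older than the retention cutoff are eligible, and when `max_rows_per_run` is set the oldest eligible tasks go first. A minimal sketch of that rule over plain dictionaries follows. The `select_prunable` helper is hypothetical, the status strings `"success"`, `"failure"`, and `"aborted"` are assumed values, and the real `TaskPruneCommand` may treat further statuses as completed.

```python
from datetime import datetime, timedelta, timezone

# Assumed string values for the completed statuses named in the tests.
TERMINAL_STATUSES = {"success", "failure", "aborted"}


def select_prunable(tasks, now, retention_period_days, max_rows_per_run=None):
    """Return ids of tasks eligible for pruning, oldest first."""
    cutoff = now - timedelta(days=retention_period_days)
    candidates = [
        task
        for task in tasks
        if task["status"] in TERMINAL_STATUSES
        and task["ended_at"] is not None
        and task["ended_at"] < cutoff
    ]
    # Oldest first, so a per-run limit removes the oldest rows.
    candidates.sort(key=lambda task: task["ended_at"])
    if max_rows_per_run is not None:
        candidates = candidates[:max_rows_per_run]
    return [task["id"] for task in candidates]


now = datetime(2024, 2, 15, tzinfo=timezone.utc)
tasks = [
    {"id": 1, "status": "success", "ended_at": datetime(2024, 1, 11, 1, tzinfo=timezone.utc)},
    {"id": 2, "status": "failure", "ended_at": datetime(2024, 1, 11, 0, tzinfo=timezone.utc)},
    {"id": 3, "status": "success", "ended_at": datetime(2024, 2, 10, tzinfo=timezone.utc)},
    {"id": 4, "status": "pending", "ended_at": None},
]
```

With a 30-day retention period the cutoff is January 16, 2024, so tasks 1 and 2 are eligible, task 3 is too recent, and task 4 never ended.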

@@ -0,0 +1,238 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset_core.api.tasks import TaskStatus
from superset import db
from superset.commands.tasks import SubmitTaskCommand
from superset.commands.tasks.exceptions import (
TaskInvalidError,
)
def test_submit_task_success(app_context, login_as, get_user) -> None:
    """Test successful task submission"""
    login_as("admin")
    admin = get_user("admin")
    command = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "test-key",
            "task_name": "Test Task",
            "user_id": admin.id,
        }
    )
    # Submit before entering try so cleanup never references an unbound name
    result = command.run()
    try:
        # Verify task was created
        assert result.task_type == "test-type"
        assert result.task_key == "test-key"
        assert result.task_name == "Test Task"
        assert result.status == TaskStatus.PENDING.value
        assert result.payload == "{}"
        # Verify in database
        db.session.refresh(result)
        assert result.id is not None
        assert result.uuid is not None
    finally:
        # Cleanup
        db.session.delete(result)
        db.session.commit()
def test_submit_task_with_all_fields(app_context, login_as, get_user) -> None:
    """Test task submission with all optional fields"""
    login_as("admin")
    admin = get_user("admin")
    command = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "test-key-full",
            "task_name": "Test Task Full",
            "user_id": admin.id,
            "payload": {"key": "value"},
            "properties": {"execution_mode": "async", "timeout": 300},
        }
    )
    # Submit before entering try so cleanup never references an unbound name
    result = command.run()
    try:
        # Verify all fields were set
        assert result.task_type == "test-type"
        assert result.task_key == "test-key-full"
        assert result.task_name == "Test Task Full"
        assert result.user_id == admin.id
        assert result.payload_dict == {"key": "value"}
        assert result.properties_dict.get("execution_mode") == "async"
        assert result.properties_dict.get("timeout") == 300
    finally:
        # Cleanup
        db.session.delete(result)
        db.session.commit()
def test_submit_task_missing_task_type(app_context, login_as) -> None:
"""Test submission fails when task_type is missing"""
login_as("admin")
command = SubmitTaskCommand(data={})
with pytest.raises(TaskInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert "task_type" in exc_info.value._exceptions[0].field_name
def test_submit_task_joins_existing(app_context, login_as, get_user) -> None:
"""Test that submitting with duplicate key joins existing task"""
login_as("admin")
admin = get_user("admin")
# Create first task
command1 = SubmitTaskCommand(
data={
"task_type": "test-type",
"task_key": "shared-key",
"task_name": "First Task",
"user_id": admin.id,
}
)
task1 = command1.run()
try:
# Submit second task with same task_key and type
command2 = SubmitTaskCommand(
data={
"task_type": "test-type",
"task_key": "shared-key",
"task_name": "Second Task",
"user_id": admin.id,
}
)
# Should return existing task, not create new one
task2 = command2.run()
assert task2.id == task1.id
assert task2.uuid == task1.uuid
finally:
# Cleanup
db.session.delete(task1)
db.session.commit()
def test_submit_task_without_task_key(app_context, login_as, get_user) -> None:
    """Test task submission without task_key (command generates UUID)"""
    login_as("admin")
    admin = get_user("admin")
    command = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_name": "Test Task No ID",
            "user_id": admin.id,
        }
    )
    # Submit before entering try so cleanup never references an unbound name
    result = command.run()
    try:
        # Verify task was created and command generated a task_key
        assert result.task_type == "test-type"
        assert result.task_name == "Test Task No ID"
        assert result.task_key is not None  # Command generated UUID
        assert result.uuid is not None
    finally:
        # Cleanup
        db.session.delete(result)
        db.session.commit()
def test_submit_task_run_with_info_returns_is_new_true(
    app_context, login_as, get_user
) -> None:
    """Test run_with_info returns is_new=True for new task"""
    login_as("admin")
    admin = get_user("admin")
    command = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "unique-key-is-new",
            "task_name": "Test Task",
            "user_id": admin.id,
        }
    )
    # Submit before entering try so cleanup never references an unbound name
    task, is_new = command.run_with_info()
    try:
        assert is_new is True
        assert task.task_key == "unique-key-is-new"
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_submit_task_run_with_info_returns_is_new_false(
app_context, login_as, get_user
) -> None:
"""Test run_with_info returns is_new=False when joining existing task"""
login_as("admin")
admin = get_user("admin")
# Create first task
command1 = SubmitTaskCommand(
data={
"task_type": "test-type",
"task_key": "shared-key-is-new",
"task_name": "First Task",
"user_id": admin.id,
}
)
task1, is_new1 = command1.run_with_info()
assert is_new1 is True
try:
# Submit second task with same key
command2 = SubmitTaskCommand(
data={
"task_type": "test-type",
"task_key": "shared-key-is-new",
"task_name": "Second Task",
"user_id": admin.id,
}
)
task2, is_new2 = command2.run_with_info()
# Should return existing task with is_new=False
assert is_new2 is False
assert task2.id == task1.id
assert task2.uuid == task1.uuid
finally:
# Cleanup
db.session.delete(task1)
db.session.commit()
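
The submit tests above show join-or-create semantics: a second submission with the same `task_type` and `task_key` returns the existing task with `is_new=False`, and a submission without a `task_key` gets a generated one. A minimal in-memory sketch of that behavior follows. `SketchTaskStore` and its dictionary-based tasks are hypothetical; the real `SubmitTaskCommand` persists to the database and may also take scope, user, or task status into account when deduplicating.

```python
import uuid


class SketchTaskStore:
    """Hypothetical in-memory store keyed by (task_type, task_key)."""

    def __init__(self):
        self._tasks = {}

    def submit(self, task_type, task_key=None):
        """Return (task, is_new), joining an existing task when keys match."""
        # Generate a key when none is given, so every task can be looked up.
        key = task_key or str(uuid.uuid4())
        dedup_key = (task_type, key)
        existing = self._tasks.get(dedup_key)
        if existing is not None:
            return existing, False
        task = {
            "task_type": task_type,
            "task_key": key,
            "uuid": uuid.uuid4(),
            "status": "pending",
        }
        self._tasks[dedup_key] = task
        return task, True
```

Returning the `is_new` flag alongside the task lets a caller decide whether to schedule execution or simply attach to work that is already underway.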

@@ -0,0 +1,260 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from uuid import UUID
import pytest
from superset_core.api.tasks import TaskScope, TaskStatus
from superset import db
from superset.commands.tasks import UpdateTaskCommand
from superset.commands.tasks.exceptions import (
TaskForbiddenError,
TaskNotFoundError,
)
from superset.daos.tasks import TaskDAO
def test_update_task_success(app_context, get_user, login_as) -> None:
"""Test successful task update"""
admin = get_user("admin")
login_as("admin")
# Create a task using DAO
task = TaskDAO.create_task(
task_type="test_type",
task_key="update_test",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Update the task status
command = UpdateTaskCommand(
task_uuid=task.uuid,
status=TaskStatus.SUCCESS.value,
)
result = command.run()
# Verify update
assert result.uuid == task.uuid
assert result.status == TaskStatus.SUCCESS.value
# Verify in database
db.session.refresh(task)
assert task.status == TaskStatus.SUCCESS.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_update_task_not_found(app_context, login_as) -> None:
"""Test update fails when task not found"""
login_as("admin")
command = UpdateTaskCommand(
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
status=TaskStatus.SUCCESS.value,
)
with pytest.raises(TaskNotFoundError):
command.run()
def test_update_task_forbidden(app_context, get_user, login_as) -> None:
"""Test update fails when user doesn't own task (via base filter)"""
gamma = get_user("gamma")
login_as("gamma")
# Create a task owned by gamma (non-admin) using DAO
task = TaskDAO.create_task(
task_type="test_type",
task_key="forbidden_test",
scope=TaskScope.PRIVATE,
user_id=gamma.id,
)
task.created_by = gamma
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Login as alpha user (different non-admin, non-owner)
login_as("alpha")
# Try to update gamma's task as alpha user
command = UpdateTaskCommand(
task_uuid=task.uuid,
status=TaskStatus.SUCCESS.value,
)
# Should raise TaskForbiddenError because the ownership check fails
with pytest.raises(TaskForbiddenError):
command.run()
# Verify task was NOT updated
db.session.refresh(task)
assert task.status == TaskStatus.IN_PROGRESS.value
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_update_task_payload(app_context, get_user, login_as) -> None:
"""Test updating task payload"""
admin = get_user("admin")
login_as("admin")
# Create a task using DAO
task = TaskDAO.create_task(
task_type="test_type",
task_key="payload_test",
scope=TaskScope.PRIVATE,
user_id=admin.id,
payload={"initial": "data"},
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Update payload
command = UpdateTaskCommand(
task_uuid=task.uuid,
payload={"progress": 50, "message": "halfway"},
)
result = command.run()
# Verify payload was updated
assert result.uuid == task.uuid
payload = result.payload_dict
assert payload["progress"] == 50
assert payload["message"] == "halfway"
# Verify in database
db.session.refresh(task)
task_payload = task.payload_dict
assert task_payload["progress"] == 50
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_update_all_supported_fields(app_context, get_user, login_as) -> None:
"""Test updating all supported task fields
(status, error, progress, abortable, timeout)"""
admin = get_user("admin")
login_as("admin")
# Create a task with initial execution_mode and timeout in properties
task = TaskDAO.create_task(
task_type="test_type",
task_key="all_fields_test",
scope=TaskScope.PRIVATE,
user_id=admin.id,
properties={"execution_mode": "async", "timeout": 300},
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Update all field types at once
command = UpdateTaskCommand(
task_uuid=task.uuid,
status=TaskStatus.FAILURE.value,
properties={
"error_message": "Task failed due to error",
"progress_percent": 0.75,
"progress_current": 75,
"progress_total": 100,
"is_abortable": True,
},
)
result = command.run()
# Verify all fields were updated
assert result.uuid == task.uuid
assert result.status == TaskStatus.FAILURE.value
assert result.properties_dict.get("error_message") == "Task failed due to error"
assert result.properties_dict.get("progress_percent") == 0.75
assert result.properties_dict.get("progress_current") == 75
assert result.properties_dict.get("progress_total") == 100
assert result.properties_dict.get("is_abortable") is True
assert result.properties_dict.get("execution_mode") == "async"
assert result.properties_dict.get("timeout") == 300
# Verify in database
db.session.refresh(task)
assert task.status == TaskStatus.FAILURE.value
assert task.properties_dict.get("error_message") == "Task failed due to error"
assert task.properties_dict.get("progress_percent") == 0.75
assert task.properties_dict.get("progress_current") == 75
assert task.properties_dict.get("progress_total") == 100
assert task.properties_dict.get("is_abortable") is True
assert task.properties_dict.get("execution_mode") == "async"
assert task.properties_dict.get("timeout") == 300
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
def test_update_task_skip_security_check(app_context, get_user, login_as) -> None:
"""Test skip_security_check allows updating any task"""
admin = get_user("admin")
login_as("admin")
# Create a task owned by admin
task = TaskDAO.create_task(
task_type="test_type",
task_key="skip_security_test",
scope=TaskScope.PRIVATE,
user_id=admin.id,
)
task.created_by = admin
task.set_status(TaskStatus.IN_PROGRESS)
db.session.commit()
try:
# Login as gamma user (non-owner)
login_as("gamma")
# With skip_security_check=True, should succeed even though gamma doesn't own it
command = UpdateTaskCommand(
task_uuid=task.uuid,
properties={"progress_percent": 0.75},
skip_security_check=True,
)
result = command.run()
# Verify update succeeded
assert result.uuid == task.uuid
assert result.properties_dict.get("progress_percent") == 0.75
# Verify in database
db.session.refresh(task)
assert task.properties_dict.get("progress_percent") == 0.75
finally:
# Cleanup
db.session.delete(task)
db.session.commit()
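

# Illustrative sketch (an assumption, not UpdateTaskCommand's implementation):
# test_update_all_supported_fields expects a properties update to keep keys
# the caller did not send (execution_mode, timeout). A shallow merge has that
# effect. `_sketch_merge_properties` is a hypothetical helper name.
def _sketch_merge_properties(existing, updates):
    """Return a new dict with `updates` layered over `existing`."""
    merged = dict(existing)  # copy so the stored properties are not mutated
    merged.update(updates)  # keys present in `updates` win
    return merged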

View File

@@ -0,0 +1,415 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""End-to-end integration tests for task event handlers (abort and cleanup)

These tests verify that abort and cleanup handlers work correctly through
the full task execution path, using module-level task functions registered
with TaskRegistry and executed via the Celery executor (synchronously via
.apply()).
"""
from __future__ import annotations
import uuid
from typing import Any
from superset_core.api.tasks import TaskScope, TaskStatus
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.daos.tasks import TaskDAO
from superset.extensions import db
from superset.models.tasks import Task
from superset.tasks.ambient_context import get_context
from superset.tasks.context import TaskContext
from superset.tasks.registry import TaskRegistry
from superset.tasks.scheduler import execute_task
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME
# Module-level state to track handler calls across test executions
# (the task functions are defined at module level, so the state lives here too)
_handler_state: dict[str, Any] = {}


def _reset_handler_state():
    """Reset handler state before each test."""
    global _handler_state
    _handler_state = {
        "cleanup_called": False,
        "abort_called": False,
        "cleanup_order": [],
        "abort_order": [],
        "cleanup_data": {},
    }


def cleanup_test_task() -> None:
    """Task that registers a cleanup handler."""
    ctx = get_context()

    @ctx.on_cleanup
    def handle_cleanup() -> None:
        _handler_state["cleanup_called"] = True

    # Simulate some work
    ctx.update_task(progress=1.0)


def abort_test_task() -> None:
    """Task that registers an abort handler."""
    ctx = get_context()

    @ctx.on_abort
    def handle_abort() -> None:
        _handler_state["abort_called"] = True


def both_handlers_task() -> None:
    """Task that registers both abort and cleanup handlers."""
    ctx = get_context()

    @ctx.on_abort
    def handle_abort() -> None:
        _handler_state["abort_called"] = True
        _handler_state["abort_order"].append("abort")

    @ctx.on_cleanup
    def handle_cleanup() -> None:
        _handler_state["cleanup_called"] = True
        _handler_state["cleanup_order"].append("cleanup")


def multiple_cleanup_handlers_task() -> None:
    """Task that registers multiple cleanup handlers."""
    ctx = get_context()

    @ctx.on_cleanup
    def cleanup_first() -> None:
        _handler_state["cleanup_order"].append("first")

    @ctx.on_cleanup
    def cleanup_second() -> None:
        _handler_state["cleanup_order"].append("second")

    @ctx.on_cleanup
    def cleanup_third() -> None:
        _handler_state["cleanup_order"].append("third")


def cleanup_with_data_task() -> None:
    """Task that uses cleanup handler to clean up partial work."""
    ctx = get_context()

    # Simulate partial work in module-level state
    _handler_state["cleanup_data"]["temp_key"] = "temp_value"

    @ctx.on_cleanup
    def handle_cleanup() -> None:
        # Clean up the partial work
        _handler_state["cleanup_data"].clear()
        _handler_state["cleanup_called"] = True


def _register_test_tasks() -> None:
    """Register test task functions if not already registered.

    Called in setUp() to ensure tasks are registered regardless of
    whether other tests have cleared the registry.
    """
    registrations = [
        ("test_cleanup_task", cleanup_test_task),
        ("test_abort_task", abort_test_task),
        ("test_both_handlers_task", both_handlers_task),
        ("test_multiple_cleanup_task", multiple_cleanup_handlers_task),
        ("test_cleanup_with_data", cleanup_with_data_task),
    ]
    for name, func in registrations:
        if not TaskRegistry.is_registered(name):
            TaskRegistry.register(name, func)
class TestCleanupHandlers(SupersetTestCase):
"""E2E tests for on_cleanup functionality using Celery executor."""
def setUp(self):
"""Set up test fixtures."""
super().setUp()
self.login(ADMIN_USERNAME)
_register_test_tasks()
_reset_handler_state()
def test_cleanup_handler_fires_on_success(self):
"""Test cleanup handler runs when task completes successfully."""
# Create task entry directly
task_obj = TaskDAO.create_task(
task_type="test_cleanup_task",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Cleanup",
scope=TaskScope.SYSTEM,
)
# Execute task synchronously through Celery executor
# Use str(uuid) since Celery serializes args as JSON strings
result = execute_task.apply(
args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
)
# Verify task completed successfully
assert result.successful()
assert result.result["status"] == TaskStatus.SUCCESS.value
# Verify cleanup handler was called
assert _handler_state["cleanup_called"]
def test_multiple_cleanup_handlers_in_lifo_order(self):
"""Test multiple cleanup handlers execute in LIFO order."""
task_obj = TaskDAO.create_task(
task_type="test_multiple_cleanup_task",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Multiple Cleanup",
scope=TaskScope.SYSTEM,
)
result = execute_task.apply(
args=[str(task_obj.uuid), "test_multiple_cleanup_task", (), {}]
)
assert result.successful()
# Handlers should execute in LIFO order (last registered first)
assert _handler_state["cleanup_order"] == ["third", "second", "first"]
def test_cleanup_handler_cleans_up_partial_work(self):
"""Test cleanup handler can clean up partial work."""
task_obj = TaskDAO.create_task(
task_type="test_cleanup_with_data",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Cleanup Data",
scope=TaskScope.SYSTEM,
)
result = execute_task.apply(
args=[str(task_obj.uuid), "test_cleanup_with_data", (), {}]
)
assert result.successful()
assert _handler_state["cleanup_called"]
# Cleanup handler should have cleared the data
assert len(_handler_state["cleanup_data"]) == 0


class TestAbortHandlers(SupersetTestCase):
    """E2E tests for on_abort functionality."""

    def setUp(self):
        """Set up test fixtures."""
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
        _reset_handler_state()

    def test_abort_handler_fires_when_task_aborting(self):
        """Test abort handler runs when task is in ABORTING state during cleanup."""
        # Create task entry
        task_obj = TaskDAO.create_task(
            task_type="test_abort_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Abort",
            scope=TaskScope.SYSTEM,
        )

        # Manually set to IN_PROGRESS and then ABORTING to simulate abort
        task_obj.status = TaskStatus.IN_PROGRESS.value
        task_obj.update_properties({"is_abortable": True})
        db.session.merge(task_obj)
        db.session.commit()

        # Refresh to get the updated task
        db.session.refresh(task_obj)

        # Create context (simulating what executor does)
        ctx = TaskContext(task_obj)

        # Register abort handler
        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True

        # Set status to ABORTING (simulating CancelTaskCommand)
        task_obj.status = TaskStatus.ABORTING.value
        db.session.merge(task_obj)
        db.session.commit()

        # Run cleanup (simulating executor's finally block)
        ctx._run_cleanup()

        # Verify abort handler was called
        assert _handler_state["abort_called"]

    def test_both_handlers_fire_on_abort(self):
        """Test both abort and cleanup handlers run when task is aborted."""
        task_obj = TaskDAO.create_task(
            task_type="test_both_handlers_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Both Handlers",
            scope=TaskScope.SYSTEM,
        )
        task_obj.status = TaskStatus.IN_PROGRESS.value
        task_obj.update_properties({"is_abortable": True})
        db.session.merge(task_obj)
        db.session.commit()

        # Refresh to get the updated task
        db.session.refresh(task_obj)

        ctx = TaskContext(task_obj)

        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True
            _handler_state["abort_order"].append("abort")

        @ctx.on_cleanup
        def handle_cleanup():
            _handler_state["cleanup_called"] = True
            _handler_state["cleanup_order"].append("cleanup")

        # Set to ABORTING
        task_obj.status = TaskStatus.ABORTING.value
        db.session.merge(task_obj)
        db.session.commit()

        ctx._run_cleanup()

        # Both should have been called
        assert _handler_state["abort_called"]
        assert _handler_state["cleanup_called"]

    def test_abort_handler_not_called_on_success(self):
        """Test abort handler doesn't run when task succeeds."""
        task_obj = TaskDAO.create_task(
            task_type="test_abort_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test No Abort on Success",
            scope=TaskScope.SYSTEM,
        )
        task_obj.status = TaskStatus.SUCCESS.value
        db.session.merge(task_obj)
        db.session.commit()

        # Refresh to get the updated task
        db.session.refresh(task_obj)

        ctx = TaskContext(task_obj)

        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True

        @ctx.on_cleanup
        def handle_cleanup():
            _handler_state["cleanup_called"] = True

        ctx._run_cleanup()

        # Abort handler should NOT be called
        assert not _handler_state["abort_called"]
        # Cleanup handler should still be called
        assert _handler_state["cleanup_called"]


class TestTaskContextMethods(SupersetTestCase):
    """Tests for TaskContext public methods."""

    def setUp(self):
        """Set up test fixtures."""
        super().setUp()
        self.login(ADMIN_USERNAME)

    def test_on_abort_marks_task_abortable(self):
        """Test that registering an on_abort handler marks task as abortable."""
        task_obj = TaskDAO.create_task(
            task_type="test_abortable_flag",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Abortable",
            scope=TaskScope.SYSTEM,
        )
        assert task_obj.properties_dict.get("is_abortable") is not True

        ctx = TaskContext(task_obj)

        @ctx.on_abort
        def handle_abort():
            pass

        db.session.expire_all()
        task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
        assert task_obj.properties_dict.get("is_abortable") is True
class TestAbortBeforeExecution(SupersetTestCase):
"""Tests for aborting tasks before they start executing."""
def setUp(self):
"""Set up test fixtures."""
super().setUp()
self.login(ADMIN_USERNAME)
_register_test_tasks()
def test_abort_pending_task(self):
"""Test that pending tasks can be aborted directly."""
task_obj = TaskDAO.create_task(
task_type="test_abort_before_start",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Before Start",
scope=TaskScope.SYSTEM,
)
# Cancel immediately (task is still PENDING)
CancelTaskCommand(task_obj.uuid, force=True).run()
db.session.expire_all()
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
assert task_obj.status == TaskStatus.ABORTED.value
def test_executor_skips_aborted_task(self):
"""Test that executor skips tasks already aborted before execution."""
task_obj = TaskDAO.create_task(
task_type="test_cleanup_task",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Skip Aborted",
scope=TaskScope.SYSTEM,
)
# Abort the task before execution
task_obj.status = TaskStatus.ABORTED.value
db.session.merge(task_obj)
db.session.commit()
_reset_handler_state()
# Try to execute - should skip
result = execute_task.apply(
args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
)
assert result.successful()
assert result.result["status"] == TaskStatus.ABORTED.value
# Cleanup handler should NOT have been called (task was skipped)
assert not _handler_state["cleanup_called"]
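

# Illustrative sketch (an assumption, not Superset's TaskContext): the tests
# above rely on two behaviours. Cleanup handlers run in LIFO order, and abort
# handlers run only when the task is aborting. A minimal standalone model of
# that contract could look like this. `_SketchContext` and `run_cleanup` are
# hypothetical names, and running abort handlers before cleanup handlers is
# also an assumption, since the tests do not assert a relative order.
class _SketchContext:
    def __init__(self):
        self._cleanup_handlers = []
        self._abort_handlers = []

    def on_cleanup(self, func):
        """Decorator that registers a cleanup handler."""
        self._cleanup_handlers.append(func)
        return func

    def on_abort(self, func):
        """Decorator that registers an abort handler."""
        self._abort_handlers.append(func)
        return func

    def run_cleanup(self, aborting):
        """Run abort handlers (only if aborting), then cleanup handlers, LIFO."""
        if aborting:
            for handler in reversed(self._abort_handlers):
                handler()
        for handler in reversed(self._cleanup_handlers):
            handler()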

View File

@@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for sync join-and-wait functionality in GTF."""
import time
from superset_core.api.tasks import TaskStatus
from superset import db
from superset.commands.tasks import SubmitTaskCommand
from superset.daos.tasks import TaskDAO
from superset.tasks.manager import TaskManager
def test_submit_task_distinguishes_new_vs_existing(
app_context, login_as, get_user
) -> None:
"""
Test that SubmitTaskCommand.run_with_info() correctly returns is_new flag.
"""
login_as("admin")
admin = get_user("admin")
# First submission - should be new
task1, is_new1 = SubmitTaskCommand(
data={
"task_type": "test-type",
"task_key": "distinguish-key",
"task_name": "First Task",
"user_id": admin.id,
}
).run_with_info()
assert is_new1 is True
try:
# Second submission with same key - should join existing
task2, is_new2 = SubmitTaskCommand(
data={
"task_type": "test-type",
"task_key": "distinguish-key",
"task_name": "Second Task",
"user_id": admin.id,
}
).run_with_info()
assert is_new2 is False
assert task2.uuid == task1.uuid
finally:
# Cleanup
db.session.delete(task1)
db.session.commit()
def test_terminal_states_recognized_correctly(app_context) -> None:
"""
Test that TaskManager.TERMINAL_STATES contains the expected values.
"""
assert TaskStatus.SUCCESS.value in TaskManager.TERMINAL_STATES
assert TaskStatus.FAILURE.value in TaskManager.TERMINAL_STATES
assert TaskStatus.ABORTED.value in TaskManager.TERMINAL_STATES
assert TaskStatus.TIMED_OUT.value in TaskManager.TERMINAL_STATES
# Non-terminal states should not be in the set
assert TaskStatus.PENDING.value not in TaskManager.TERMINAL_STATES
assert TaskStatus.IN_PROGRESS.value not in TaskManager.TERMINAL_STATES
assert TaskStatus.ABORTING.value not in TaskManager.TERMINAL_STATES


def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
    """
    Test that wait_for_completion raises TimeoutError on timeout.
    """
    from unittest.mock import patch

    import pytest

    login_as("admin")
    admin = get_user("admin")

    # Create a pending task (won't complete)
    task, _ = SubmitTaskCommand(
        data={
            "task_type": "test-timeout",
            "task_key": "timeout-key",
            "task_name": "Timeout Task",
            "user_id": admin.id,
        }
    ).run_with_info()

    try:
        # Force polling mode by mocking signal_cache as None
        with patch("superset.tasks.manager.cache_manager") as mock_cache_manager:
            mock_cache_manager.signal_cache = None
            with pytest.raises(TimeoutError):
                TaskManager.wait_for_completion(
                    task.uuid,
                    timeout=0.2,
                    poll_interval=0.05,
                )
    finally:
        db.session.delete(task)
        db.session.commit()
def test_wait_returns_immediately_for_terminal_task(
app_context, login_as, get_user
) -> None:
"""
Test that wait_for_completion returns immediately if task is already terminal.
"""
login_as("admin")
admin = get_user("admin")
# Create and immediately complete a task
task, _ = SubmitTaskCommand(
data={
"task_type": "test-immediate",
"task_key": "immediate-key",
"task_name": "Immediate Task",
"user_id": admin.id,
}
).run_with_info()
TaskDAO.update(task, {"status": TaskStatus.SUCCESS.value})
db.session.commit()
try:
start = time.time()
result = TaskManager.wait_for_completion(
task.uuid,
timeout=5.0,
poll_interval=0.5,
)
elapsed = time.time() - start
assert result.status == TaskStatus.SUCCESS.value
# Should return almost immediately since task is already terminal
assert elapsed < 0.2
finally:
db.session.delete(task)
db.session.commit()
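

# Illustrative sketch (an assumption, not TaskManager.wait_for_completion):
# the two tests above exercise a polling wait that returns at once for a task
# already in a terminal state and raises TimeoutError otherwise. The helper
# name, the `get_status` callable and the lowercase status strings are all
# hypothetical; the real status values come from TaskStatus.
import time as _sketch_time

_SKETCH_TERMINAL_STATES = {"success", "failure", "aborted", "timed_out"}


def _sketch_wait_for_completion(get_status, timeout, poll_interval):
    """Poll get_status() until it is terminal or `timeout` seconds elapse."""
    deadline = _sketch_time.monotonic() + timeout
    while True:
        status = get_status()
        if status in _SKETCH_TERMINAL_STATES:
            return status  # checked before sleeping, so no initial delay
        remaining = deadline - _sketch_time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"task did not complete within {timeout}s")
        # Never sleep past the deadline
        _sketch_time.sleep(min(poll_interval, remaining))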

View File

@@ -0,0 +1,172 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for TaskContext update_task throttling.
Tests verify:
1. Final state is persisted correctly via cleanup flush
2. Throttled updates are deferred, timer writes latest pending update
"""
from __future__ import annotations
import time
import uuid
from superset_core.api.tasks import TaskScope, TaskStatus
from superset.daos.tasks import TaskDAO
from superset.extensions import db
from superset.models.tasks import Task
from superset.tasks.ambient_context import get_context
from superset.tasks.registry import TaskRegistry
from superset.tasks.scheduler import execute_task
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME


def task_with_throttled_updates() -> None:
    """Task with rapid progress and payload updates (exercises throttling)."""
    ctx = get_context()

    # Rapid-fire updates within throttle window
    for i in range(10):
        ctx.update_task(progress=(i + 1, 10), payload={"step": i + 1})


def _register_test_tasks() -> None:
    """Register test task functions if not already registered.

    Called in setUp() to ensure tasks are registered regardless of
    whether other tests have cleared the registry.
    """
    if not TaskRegistry.is_registered("test_throttle_combined"):
        TaskRegistry.register("test_throttle_combined", task_with_throttled_updates)
class TestUpdateTaskThrottling(SupersetTestCase):
"""Integration test for update_task() throttling behavior."""
def setUp(self) -> None:
super().setUp()
self.login(ADMIN_USERNAME)
_register_test_tasks()
def test_throttled_updates_persisted_on_cleanup(self) -> None:
"""Final state should be persisted regardless of throttling.
Verifies the core invariant: cleanup flush ensures final state is persisted.
"""
task_obj = TaskDAO.create_task(
task_type="test_throttle_combined",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Throttled Updates",
scope=TaskScope.SYSTEM,
)
# Use str(uuid) since Celery serializes args as JSON strings
result = execute_task.apply(
args=[str(task_obj.uuid), "test_throttle_combined", (), {}]
)
assert result.successful()
assert result.result["status"] == TaskStatus.SUCCESS.value
# Verify final state is persisted
db.session.expire_all()
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
# Progress: 10/10 = 100%
props = task_obj.properties_dict
assert props.get("progress_current") == 10
assert props.get("progress_total") == 10
assert props.get("progress_percent") == 1.0
# Payload: final step
payload = task_obj.payload_dict
assert payload.get("step") == 10

    def test_throttle_behavior(self) -> None:
        """Test complete throttle behavior: immediate write, deferral, and timer.

        Verifies:
        1. First update writes immediately
        2. Second and third updates within throttle window are deferred
        3. Deferred timer fires and writes the LATEST pending update (third)
        """
        from flask import current_app

        from superset.commands.tasks.submit import SubmitTaskCommand
        from superset.tasks.context import TaskContext

        # Get throttle interval from config (default: 2 seconds)
        throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]

        # Create task
        task_obj = SubmitTaskCommand(
            data={
                "task_type": "test_throttle_behavior",
                "task_key": f"test_key_{uuid.uuid4().hex[:8]}",
                "task_name": "Test Throttle Behavior",
                "scope": TaskScope.SYSTEM,
            }
        ).run()
        task_uuid = task_obj.uuid

        # Get fresh task for context
        fresh_task = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
        assert fresh_task is not None
        ctx = TaskContext(fresh_task)

        try:
            # === Step 1: First update - writes immediately ===
            ctx.update_task(progress=0.1, payload={"step": 1})

            db.session.expire_all()
            task_step1 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
            assert task_step1 is not None
            assert task_step1.properties_dict.get("progress_percent") == 0.1
            assert task_step1.payload_dict.get("step") == 1

            # === Step 2: Second update - deferred (within throttle window) ===
            ctx.update_task(progress=0.5, payload={"step": 2})

            # === Step 3: Third update - also deferred, overwrites second in cache ===
            ctx.update_task(progress=0.7, payload={"step": 3})

            # Verify in-memory cache has LATEST update (third)
            assert ctx._properties_cache.get("progress_percent") == 0.7
            assert ctx._payload_cache.get("step") == 3

            # Verify DB still has first update (both second and third deferred)
            db.session.expire_all()
            task_step2 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
            assert task_step2 is not None
            assert task_step2.properties_dict.get("progress_percent") == 0.1
            assert task_step2.payload_dict.get("step") == 1

            # === Step 4: Wait for deferred timer to fire ===
            time.sleep(throttle_interval + 0.5)

            # Verify timer fired and wrote the LATEST update (third, not second)
            db.session.expire_all()
            task_step3 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
            assert task_step3 is not None
            assert task_step3.properties_dict.get("progress_percent") == 0.7
            assert task_step3.payload_dict.get("step") == 3
        finally:
            ctx._cancel_deferred_flush_timer()
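

# Illustrative sketch (an assumption, not TaskContext's implementation): the
# throttle contract tested above is "first write goes through immediately,
# later writes inside the window are buffered, and the flush persists only
# the latest buffered value". This model uses an injectable clock and an
# explicit flush() in place of the real deferred timer, so it can be
# exercised without sleeping. All names here are hypothetical.
import time as _sketch_time


class _SketchThrottledWriter:
    def __init__(self, write, interval, clock=_sketch_time.monotonic):
        self._write = write  # callable that persists a value
        self._interval = interval  # throttle window in seconds
        self._clock = clock
        self._last_write = None
        self._pending = None

    def update(self, value):
        now = self._clock()
        if self._last_write is None or now - self._last_write >= self._interval:
            # Outside the throttle window: persist immediately
            self._write(value)
            self._last_write = now
            self._pending = None
        else:
            # Inside the window: keep only the most recent value
            self._pending = value

    def flush(self):
        """Persist the latest deferred value, if any (timer or cleanup path)."""
        if self._pending is not None:
            self._write(self._pending)
            self._last_write = self._clock()
            self._pending = None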

View File

@@ -0,0 +1,226 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for GTF timeout handling.
Uses module-level task functions with manual registry (like test_event_handlers.py)
to avoid mypy issues with the @task decorator's complex generic types.
NOTE: Tests that use background threads (timeout/abort handlers) are skipped in
SQLite environments because SQLite connections cannot be shared across threads.
"""
from __future__ import annotations
import time
import uuid
from typing import Any
import pytest
from superset_core.api.tasks import TaskScope, TaskStatus
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.daos.tasks import TaskDAO
from superset.extensions import db
from superset.models.tasks import Task
from superset.tasks.ambient_context import get_context
from superset.tasks.registry import TaskRegistry
from superset.tasks.scheduler import execute_task
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME


def _skip_if_sqlite() -> None:
    """Skip test if running with SQLite database.

    SQLite connections cannot be shared across threads, which breaks
    timeout tests that use background threads for abort handlers.
    Must be called from within a test method (with app context).
    """
    if "sqlite" in db.engine.url.drivername:
        pytest.skip("SQLite connections cannot be shared across threads")


# Module-level state to track handler calls
_handler_state: dict[str, Any] = {}


def _reset_handler_state() -> None:
    """Reset handler state before each test."""
    global _handler_state
    _handler_state = {
        "abort_called": False,
        "handler_exception": None,
    }


def timeout_abortable_task() -> None:
    """Task with abort handler that exits when aborted."""
    ctx = get_context()

    @ctx.on_abort
    def on_abort() -> None:
        _handler_state["abort_called"] = True

    # Poll for abort signal
    for _ in range(50):
        if _handler_state["abort_called"]:
            return
        time.sleep(0.1)


def timeout_handler_fails_task() -> None:
    """Task with abort handler that throws an exception."""
    ctx = get_context()

    @ctx.on_abort
    def on_abort() -> None:
        _handler_state["abort_called"] = True
        raise ValueError("Handler crashed!")

    # Sleep longer than timeout
    time.sleep(5)


def simple_task_with_abort() -> None:
    """Simple task with abort handler for testing."""
    ctx = get_context()

    @ctx.on_abort
    def on_abort() -> None:
        pass


def quick_task_with_abort() -> None:
    """Quick task that completes before timeout."""
    ctx = get_context()

    @ctx.on_abort
    def on_abort() -> None:
        pass

    time.sleep(0.2)


def _register_test_tasks() -> None:
    """Register test task functions if not already registered.

    Called in setUp() to ensure tasks are registered regardless of
    whether other tests have cleared the registry.
    """
    registrations = [
        ("test_timeout_abortable", timeout_abortable_task),
        ("test_timeout_handler_fails", timeout_handler_fails_task),
        ("test_timeout_simple", simple_task_with_abort),
        ("test_timeout_quick", quick_task_with_abort),
    ]
    for name, func in registrations:
        if not TaskRegistry.is_registered(name):
            TaskRegistry.register(name, func)
class TestTimeoutHandling(SupersetTestCase):
"""E2E tests for task timeout functionality."""
def setUp(self) -> None:
"""Set up test fixtures."""
super().setUp()
self.login(ADMIN_USERNAME)
_register_test_tasks()
_reset_handler_state()
def test_timeout_with_abort_handler_results_in_timed_out_status(self) -> None:
"""Task with timeout and abort handler should end with TIMED_OUT status."""
_skip_if_sqlite()
# Create task with timeout
task_obj = TaskDAO.create_task(
task_type="test_timeout_abortable",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Timeout",
scope=TaskScope.SYSTEM,
properties={"timeout": 1}, # 1 second timeout
)
# Execute task via Celery executor (synchronously)
# Use str(uuid) since Celery serializes args as JSON strings
result = execute_task.apply(
args=[str(task_obj.uuid), "test_timeout_abortable", (), {}]
)
# Verify execution completed
assert result.successful()
assert result.result["status"] == TaskStatus.TIMED_OUT.value
# Verify abort handler was called
assert _handler_state["abort_called"]
def test_user_abort_results_in_aborted_status(self) -> None:
"""User-initiated abort on pending task should result in ABORTED."""
# Create task (pending state)
task_obj = TaskDAO.create_task(
task_type="test_timeout_simple",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Abort Task",
scope=TaskScope.SYSTEM,
)
# Cancel before execution (pending task abort)
CancelTaskCommand(task_obj.uuid, force=True).run()
# Refresh from DB
db.session.expire_all()
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
assert task_obj.status == TaskStatus.ABORTED.value
def test_no_timeout_when_not_configured(self) -> None:
"""Task without timeout should run to completion regardless of duration."""
task_obj = TaskDAO.create_task(
task_type="test_timeout_quick",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test No Timeout",
scope=TaskScope.SYSTEM,
# No timeout property
)
# Use str(uuid) since Celery serializes args as JSON strings
result = execute_task.apply(
args=[str(task_obj.uuid), "test_timeout_quick", (), {}]
)
assert result.successful()
assert result.result["status"] == TaskStatus.SUCCESS.value
def test_abort_handler_exception_results_in_failure(self) -> None:
"""If the abort handler raises during a timeout, the task should end in FAILURE."""
_skip_if_sqlite()
task_obj = TaskDAO.create_task(
task_type="test_timeout_handler_fails",
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
task_name="Test Handler Fails",
scope=TaskScope.SYSTEM,
properties={"timeout": 1}, # 1 second timeout
)
# Use str(uuid) since Celery serializes args as JSON strings
result = execute_task.apply(
args=[str(task_obj.uuid), "test_timeout_handler_fails", (), {}]
)
assert result.successful()
assert result.result["status"] == TaskStatus.FAILURE.value
assert _handler_state["abort_called"]
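
Note on the timeout tests above: they pin down two outcomes once a timeout fires. If the abort handler runs cleanly the task ends as TIMED_OUT, and if the handler raises the task ends as FAILURE. The sketch below models only that decision. It is not the framework's executor; `Status`, `status_after_timeout`, and the two handlers are hypothetical names, and the enum values are illustrative.

```python
from enum import Enum
from typing import Callable


class Status(Enum):
    # Hypothetical stand-in for TaskStatus; the string values are assumptions.
    TIMED_OUT = "timed_out"
    FAILURE = "failure"


def status_after_timeout(abort_handlers: list[Callable[[], None]]) -> Status:
    """Run the abort handlers after a timeout and derive the final status."""
    try:
        for handler in abort_handlers:
            handler()
    except Exception:
        # A crashing handler turns the timeout into a failure.
        return Status.FAILURE
    return Status.TIMED_OUT


def ok_handler() -> None:
    """Handler that cleans up without error."""


def crashing_handler() -> None:
    """Handler that fails, like timeout_handler_fails_task above."""
    raise ValueError("Handler crashed!")
```

With these definitions, `status_after_timeout([ok_handler])` yields `Status.TIMED_OUT` and `status_after_timeout([crashing_handler])` yields `Status.FAILURE`, matching the two timeout assertions in the tests.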

@@ -0,0 +1,420 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterator
from uuid import UUID
import pytest
from sqlalchemy.orm.session import Session
from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
from superset.commands.tasks.exceptions import TaskNotAbortableError
from superset.models.tasks import Task
from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key
# Test constants
TASK_UUID = UUID("e7765491-40c1-4f35-a4f5-06308e79310e")
TASK_ID = 42
TEST_TASK_TYPE = "test_type"
TEST_TASK_KEY = "test-key"
TEST_USER_ID = 1
def create_task(
session: Session,
*,
task_id: int | None = None,
task_uuid: UUID | None = None,
task_key: str = TEST_TASK_KEY,
task_type: str = TEST_TASK_TYPE,
scope: TaskScope = TaskScope.PRIVATE,
status: TaskStatus = TaskStatus.PENDING,
user_id: int | None = TEST_USER_ID,
properties: TaskProperties | None = None,
use_finished_dedup_key: bool = False,
) -> Task:
"""Helper to create a task with sensible defaults for testing."""
if use_finished_dedup_key:
dedup_key = get_finished_dedup_key(task_uuid or TASK_UUID)
else:
dedup_key = get_active_dedup_key(
scope=scope,
task_type=task_type,
task_key=task_key,
user_id=user_id,
)
task = Task(
task_type=task_type,
task_key=task_key,
scope=scope.value,
status=status.value,
dedup_key=dedup_key,
user_id=user_id,
)
if task_id is not None:
task.id = task_id
if task_uuid:
task.uuid = task_uuid
if properties:
task.update_properties(properties)
session.add(task)
session.flush()
return task
@pytest.fixture
def session_with_task(session: Session) -> Iterator[Session]:
"""Create a session with Task and TaskSubscriber tables."""
from superset.models.task_subscribers import TaskSubscriber
engine = session.get_bind()
Task.metadata.create_all(engine)
TaskSubscriber.metadata.create_all(engine)
yield session
session.rollback()
def test_find_by_task_key_active(session_with_task: Session) -> None:
"""Test finding active task by task_key"""
from superset.daos.tasks import TaskDAO
create_task(session_with_task)
result = TaskDAO.find_by_task_key(
task_type=TEST_TASK_TYPE,
task_key=TEST_TASK_KEY,
scope=TaskScope.PRIVATE,
user_id=TEST_USER_ID,
)
assert result is not None
assert result.task_key == TEST_TASK_KEY
assert result.task_type == TEST_TASK_TYPE
assert result.status == TaskStatus.PENDING.value
def test_find_by_task_key_not_found(session_with_task: Session) -> None:
"""Test finding task by task_key returns None when not found"""
from superset.daos.tasks import TaskDAO
result = TaskDAO.find_by_task_key(
task_type=TEST_TASK_TYPE,
task_key="nonexistent-key",
scope=TaskScope.PRIVATE,
user_id=TEST_USER_ID,
)
assert result is None
def test_find_by_task_key_finished_not_found(session_with_task: Session) -> None:
"""Test that find_by_task_key returns None for finished tasks.
Finished tasks have a different dedup_key format (UUID-based),
so they won't be found by the active task lookup.
"""
from superset.daos.tasks import TaskDAO
create_task(
session_with_task,
task_key="finished-key",
status=TaskStatus.SUCCESS,
use_finished_dedup_key=True,
task_uuid=TASK_UUID,
)
# Should not find SUCCESS task via active lookup
result = TaskDAO.find_by_task_key(
task_type=TEST_TASK_TYPE,
task_key="finished-key",
scope=TaskScope.PRIVATE,
user_id=TEST_USER_ID,
)
assert result is None
def test_create_task_success(session_with_task: Session) -> None:
"""Test successful task creation."""
from superset.daos.tasks import TaskDAO
result = TaskDAO.create_task(
task_type=TEST_TASK_TYPE,
task_key=TEST_TASK_KEY,
scope=TaskScope.PRIVATE,
user_id=TEST_USER_ID,
)
assert result is not None
assert result.task_key == TEST_TASK_KEY
assert result.task_type == TEST_TASK_TYPE
assert result.status == TaskStatus.PENDING.value
assert isinstance(result, Task)
def test_create_task_with_user_id(session_with_task: Session) -> None:
"""Test task creation with explicit user_id."""
from superset.daos.tasks import TaskDAO
result = TaskDAO.create_task(
task_type=TEST_TASK_TYPE,
task_key="user-task",
scope=TaskScope.PRIVATE,
user_id=42,
)
assert result is not None
assert result.user_id == 42
# Creator should be auto-subscribed
assert len(result.subscribers) == 1
assert result.subscribers[0].user_id == 42
def test_create_task_with_properties(session_with_task: Session) -> None:
"""Test task creation with properties."""
from superset.daos.tasks import TaskDAO
result = TaskDAO.create_task(
task_type=TEST_TASK_TYPE,
task_key="props-task",
scope=TaskScope.PRIVATE,
user_id=TEST_USER_ID,
properties={"timeout": 300},
)
assert result is not None
assert result.properties_dict.get("timeout") == 300
def test_abort_task_pending_success(session_with_task: Session) -> None:
"""Test successful abort of pending task - goes directly to ABORTED"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="pending-task",
status=TaskStatus.PENDING,
)
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
assert result is not None
assert result.status == TaskStatus.ABORTED.value
def test_abort_task_in_progress_abortable(session_with_task: Session) -> None:
"""Test abort of in-progress task with abort handler.
Should transition to ABORTING status.
"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="abortable-task",
status=TaskStatus.IN_PROGRESS,
properties={"is_abortable": True},
)
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
assert result is not None
# Should set status to ABORTING, not ABORTED
assert result.status == TaskStatus.ABORTING.value
def test_abort_task_in_progress_not_abortable(session_with_task: Session) -> None:
"""Test abort of in-progress task without abort handler - raises error"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="non-abortable-task",
status=TaskStatus.IN_PROGRESS,
properties={"is_abortable": False},
)
with pytest.raises(TaskNotAbortableError):
TaskDAO.abort_task(task.uuid, skip_base_filter=True)
def test_abort_task_in_progress_is_abortable_none(session_with_task: Session) -> None:
"""Test abort of in-progress task with is_abortable not set - raises error"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="no-abortable-prop-task",
status=TaskStatus.IN_PROGRESS,
# Empty properties - no is_abortable key
)
with pytest.raises(TaskNotAbortableError):
TaskDAO.abort_task(task.uuid, skip_base_filter=True)
def test_abort_task_already_aborting(session_with_task: Session) -> None:
"""Test abort of already aborting task - idempotent success"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="aborting-task",
status=TaskStatus.ABORTING,
)
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
# Idempotent - returns task without error
assert result is not None
assert result.status == TaskStatus.ABORTING.value
def test_abort_task_not_found(session_with_task: Session) -> None:
"""Test abort returns None when the task is not found"""
from superset.daos.tasks import TaskDAO
result = TaskDAO.abort_task(UUID("00000000-0000-0000-0000-000000000000"))
assert result is None
def test_abort_task_already_finished(session_with_task: Session) -> None:
"""Test abort returns None when the task has already finished"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="finished-task",
status=TaskStatus.SUCCESS,
use_finished_dedup_key=True,
task_uuid=TASK_UUID,
)
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
assert result is None
def test_add_subscriber(session_with_task: Session) -> None:
"""Test adding a subscriber to a task"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="shared-task",
scope=TaskScope.SHARED,
user_id=None,
)
# Add subscriber
result = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
assert result is True
# Verify subscriber was added
session_with_task.refresh(task)
assert len(task.subscribers) == 1
assert task.subscribers[0].user_id == TEST_USER_ID
def test_add_subscriber_idempotent(session_with_task: Session) -> None:
"""Test adding same subscriber twice is idempotent"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="shared-task-2",
scope=TaskScope.SHARED,
user_id=None,
)
# Add subscriber twice
result1 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
result2 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
assert result1 is True
assert result2 is False # Already subscribed
# Verify only one subscriber
session_with_task.refresh(task)
assert len(task.subscribers) == 1
def test_remove_subscriber(session_with_task: Session) -> None:
"""Test removing a subscriber from a task"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="shared-task-3",
scope=TaskScope.SHARED,
user_id=None,
)
TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
session_with_task.refresh(task)
assert len(task.subscribers) == 1
# Remove subscriber
result = TaskDAO.remove_subscriber(task.id, user_id=TEST_USER_ID)
assert result is not None
assert len(result.subscribers) == 0
def test_remove_subscriber_not_subscribed(session_with_task: Session) -> None:
"""Test removing non-existent subscriber returns None"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_key="shared-task-4",
scope=TaskScope.SHARED,
user_id=None,
)
# Try to remove non-existent subscriber
result = TaskDAO.remove_subscriber(task.id, user_id=999)
assert result is None
def test_get_status(session_with_task: Session) -> None:
"""Test get_status returns status string when task found by UUID"""
from superset.daos.tasks import TaskDAO
task = create_task(
session_with_task,
task_uuid=TASK_UUID,
task_key="status-task",
status=TaskStatus.IN_PROGRESS,
)
result = TaskDAO.get_status(task.uuid)
assert result == TaskStatus.IN_PROGRESS.value
def test_get_status_not_found(session_with_task: Session) -> None:
"""Test get_status returns None when task not found"""
from superset.daos.tasks import TaskDAO
result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))
assert result is None
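
Note on the abort tests above: taken together they describe a small state machine. A pending task goes straight to ABORTED, an in-progress task with `is_abortable` set moves to ABORTING, an in-progress task without it raises, an already-aborting task is left unchanged, and a finished task yields None. The sketch below restates those rules as a pure function. It is a model of what the tests assert, not the `TaskDAO.abort_task` implementation; `Status`, `NotAbortable`, and `next_status_on_abort` are hypothetical names.

```python
from enum import Enum
from typing import Optional


class Status(Enum):
    # Hypothetical stand-in for TaskStatus; the string values are assumptions.
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    ABORTING = "aborting"
    ABORTED = "aborted"
    SUCCESS = "success"


class NotAbortable(Exception):
    """Hypothetical stand-in for TaskNotAbortableError."""


def next_status_on_abort(
    status: Status, is_abortable: Optional[bool]
) -> Optional[Status]:
    """Return the status after an abort request, or None if nothing changes hands."""
    if status is Status.PENDING:
        # Nothing is running yet, so the task can be aborted immediately.
        return Status.ABORTED
    if status is Status.ABORTING:
        # Idempotent: a second abort request leaves the task as it is.
        return Status.ABORTING
    if status is Status.IN_PROGRESS:
        if is_abortable:
            # The running task still has to notice the request and clean up.
            return Status.ABORTING
        raise NotAbortable("in-progress task has no abort handler")
    # Finished tasks (for example SUCCESS) cannot be aborted.
    return None
```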

@@ -18,17 +18,21 @@
# pylint: disable=invalid-name
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from freezegun import freeze_time
from sqlalchemy.orm import Session, sessionmaker
# Force module loading before tests run so patches work correctly
import superset.commands.distributed_lock.acquire as acquire_module
import superset.commands.distributed_lock.release as release_module
from superset import db
from superset.distributed_lock import KeyValueDistributedLock
from superset.distributed_lock import DistributedLock
from superset.distributed_lock.types import LockValue
from superset.distributed_lock.utils import get_key
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.exceptions import AcquireDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec
LOCK_VALUE: LockValue = {"value": True}
@@ -56,9 +60,9 @@ def _get_other_session() -> Session:
return SessionMaker()
def test_key_value_distributed_lock_happy_path() -> None:
def test_distributed_lock_kv_happy_path() -> None:
"""
Test successfully acquiring and returning the distributed lock.
Test successfully acquiring and returning the distributed lock via KV backend.
Note, we're using another session for asserting the lock state in the Metastore
to simulate what another worker will observe. Otherwise, there's the risk that
@@ -66,24 +70,29 @@ def test_key_value_distributed_lock_happy_path() -> None:
"""
session = _get_other_session()
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
# Ensure Redis is not configured so KV backend is used
with (
patch.object(acquire_module, "get_redis_client", return_value=None),
patch.object(release_module, "get_redis_client", return_value=None),
):
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
with KeyValueDistributedLock("ns", a=1, b=2) as key:
assert key == MAIN_KEY
assert _get_lock(key, session) == LOCK_VALUE
assert _get_lock(OTHER_KEY, session) is None
with DistributedLock("ns", a=1, b=2) as key:
assert key == MAIN_KEY
assert _get_lock(key, session) == LOCK_VALUE
assert _get_lock(OTHER_KEY, session) is None
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
with pytest.raises(AcquireDistributedLockFailedException):
with DistributedLock("ns", a=1, b=2):
pass
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY, session) is None
def test_key_value_distributed_lock_expired() -> None:
def test_distributed_lock_kv_expired() -> None:
"""
Test expiration of the distributed lock
Test expiration of the distributed lock via KV backend.
Note, we're using another session for asserting the lock state in the Metastore
to simulate what another worker will observe. Otherwise, there's the risk that
@@ -91,11 +100,112 @@ def test_key_value_distributed_lock_expired() -> None:
"""
session = _get_other_session()
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
with KeyValueDistributedLock("ns", a=1, b=2):
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
with freeze_time("2022-01-01"):
assert _get_lock(MAIN_KEY, session) is None
# Ensure Redis is not configured so KV backend is used
with (
patch.object(acquire_module, "get_redis_client", return_value=None),
patch.object(release_module, "get_redis_client", return_value=None),
):
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
with DistributedLock("ns", a=1, b=2):
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
with freeze_time("2022-01-01"):
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY, session) is None
def test_distributed_lock_uses_redis_when_configured() -> None:
"""Test that DistributedLock uses Redis backend when configured."""
mock_redis = MagicMock()
mock_redis.set.return_value = True # Lock acquired
# Use patch.object to patch on already-imported modules
with (
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
patch.object(release_module, "get_redis_client", return_value=mock_redis),
):
with DistributedLock("test_redis", key="value") as lock_key:
assert lock_key is not None
# Verify SET NX EX was called
mock_redis.set.assert_called_once()
call_args = mock_redis.set.call_args
assert call_args.kwargs["nx"] is True
assert "ex" in call_args.kwargs
# Verify DELETE was called on exit
mock_redis.delete.assert_called_once()
def test_distributed_lock_redis_already_taken() -> None:
"""Test Redis lock fails when already held."""
mock_redis = MagicMock()
mock_redis.set.return_value = None # Lock not acquired (already taken)
with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
with pytest.raises(AcquireDistributedLockFailedException):
with DistributedLock("test_redis", key="value"):
pass
def test_distributed_lock_redis_connection_error() -> None:
"""Test Redis connection error raises exception (fail fast)."""
import redis
mock_redis = MagicMock()
mock_redis.set.side_effect = redis.RedisError("Connection failed")
with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
with pytest.raises(AcquireDistributedLockFailedException):
with DistributedLock("test_redis", key="value"):
pass
def test_distributed_lock_custom_ttl() -> None:
"""Test Redis lock with custom TTL."""
mock_redis = MagicMock()
mock_redis.set.return_value = True
with (
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
patch.object(release_module, "get_redis_client", return_value=mock_redis),
):
with DistributedLock("test", ttl_seconds=60, key="value"):
call_args = mock_redis.set.call_args
assert call_args.kwargs["ex"] == 60 # Custom TTL
def test_distributed_lock_default_ttl(app_context: None) -> None:
"""Test Redis lock uses default TTL when not specified."""
from superset.commands.distributed_lock.base import get_default_lock_ttl
mock_redis = MagicMock()
mock_redis.set.return_value = True
with (
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
patch.object(release_module, "get_redis_client", return_value=mock_redis),
):
with DistributedLock("test", key="value"):
call_args = mock_redis.set.call_args
assert call_args.kwargs["ex"] == get_default_lock_ttl()
def test_distributed_lock_fallback_to_kv_when_redis_not_configured() -> None:
"""Test falls back to KV lock when Redis not configured."""
session = _get_other_session()
test_key = get_key("test_fallback", key="value")
with (
patch.object(acquire_module, "get_redis_client", return_value=None),
patch.object(release_module, "get_redis_client", return_value=None),
):
with freeze_time("2021-01-01"):
# When Redis is not configured, should use KV backend
with DistributedLock("test_fallback", key="value") as lock_key:
assert lock_key == test_key
# Verify lock exists in KV store
assert _get_lock(test_key, session) == LOCK_VALUE
# Lock should be released
assert _get_lock(test_key, session) is None
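
Note on the Redis-backed tests above: they check that acquiring the lock issues a single `set` call with `nx=True` and an `ex` TTL, that a falsy return means the lock is already held, and that `delete` runs on exit. The sketch below shows that acquire/release shape as a context manager. It assumes nothing about Superset's command classes: `FakeRedis`, `LockNotAcquired`, and `redis_lock` are hypothetical, and the fake client keeps keys in a dict and ignores expiry so the example runs without a Redis server.

```python
from contextlib import contextmanager
from typing import Iterator, Optional


class LockNotAcquired(Exception):
    """Hypothetical stand-in for AcquireDistributedLockFailedException."""


class FakeRedis:
    """In-memory stand-in exposing only set(nx=, ex=) and delete()."""

    def __init__(self) -> None:
        self.store: dict[str, str] = {}

    def set(
        self, key: str, value: str, nx: bool = False, ex: Optional[int] = None
    ) -> Optional[bool]:
        if nx and key in self.store:
            # Same convention the tests mock: None when NX prevents the write.
            return None
        self.store[key] = value  # `ex` is accepted but expiry is not simulated
        return True

    def delete(self, key: str) -> int:
        return 1 if self.store.pop(key, None) is not None else 0


@contextmanager
def redis_lock(client: FakeRedis, key: str, ttl_seconds: int = 30) -> Iterator[str]:
    """Acquire with SET NX EX, release with DELETE when the block exits."""
    if not client.set(key, "1", nx=True, ex=ttl_seconds):
        raise LockNotAcquired(key)
    try:
        yield key
    finally:
        client.delete(key)
```

The TTL matters in a real deployment because a worker that dies inside the block never reaches the `delete`; the expiry is what eventually frees the key.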

@@ -0,0 +1,477 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for task decorators"""
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from superset_core.api.tasks import TaskOptions, TaskScope
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
from superset.tasks.decorators import task, TaskWrapper
from superset.tasks.registry import TaskRegistry
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
class TestTaskDecoratorFeatureFlag:
"""Tests for @task decorator feature flag behavior"""
def setup_method(self):
"""Clear task registry before each test"""
TaskRegistry._tasks.clear()
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
def test_decorator_succeeds_when_gtf_disabled(self, mock_feature_flag):
"""Test that @task decorator can be applied even when GTF is disabled.
This enables safe module imports during app startup or Celery autodiscovery.
"""
# Decoration should succeed - no error raised
@task(name="test_gtf_disabled_decorator")
def my_task() -> None:
pass
assert isinstance(my_task, TaskWrapper)
assert my_task.name == "test_gtf_disabled_decorator"
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
def test_call_raises_error_when_gtf_disabled(self, mock_feature_flag):
"""Test that calling a task raises GlobalTaskFrameworkDisabledError
when GTF is disabled."""
@task(name="test_gtf_disabled_call")
def my_task() -> None:
pass
with pytest.raises(GlobalTaskFrameworkDisabledError):
my_task()
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
def test_schedule_raises_error_when_gtf_disabled(self, mock_feature_flag):
"""Test that scheduling a task raises GlobalTaskFrameworkDisabledError
when GTF is disabled."""
@task(name="test_gtf_disabled_schedule")
def my_task() -> None:
pass
with pytest.raises(GlobalTaskFrameworkDisabledError):
my_task.schedule()
class TestTaskDecorator:
"""Tests for @task decorator"""
def test_decorator_basic(self):
"""Test basic decorator usage without options"""
@task(name="test_task")
def my_task(arg1: int, arg2: str) -> None:
pass
assert isinstance(my_task, TaskWrapper)
assert my_task.name == "test_task"
assert my_task.scope == TaskScope.PRIVATE
def test_decorator_without_parentheses(self):
"""Test decorator usage without parentheses"""
@task
def my_no_parens_task(arg1: int, arg2: str) -> None:
pass
assert isinstance(my_no_parens_task, TaskWrapper)
assert my_no_parens_task.name == "my_no_parens_task" # Uses function name
assert my_no_parens_task.scope == TaskScope.PRIVATE
def test_decorator_with_default_scope_private(self):
"""Test decorator with explicit PRIVATE scope"""
@task(name="private_task", scope=TaskScope.PRIVATE)
def my_private_task(arg1: int) -> None:
pass
assert my_private_task.scope == TaskScope.PRIVATE
def test_decorator_with_default_scope_shared(self):
"""Test decorator with SHARED scope"""
@task(name="shared_task", scope=TaskScope.SHARED)
def my_shared_task(arg1: int) -> None:
pass
assert my_shared_task.scope == TaskScope.SHARED
def test_decorator_with_default_scope_system(self):
"""Test decorator with SYSTEM scope"""
@task(name="system_task", scope=TaskScope.SYSTEM)
def my_system_task() -> None:
pass
assert my_system_task.scope == TaskScope.SYSTEM
def test_decorator_forbids_ctx_parameter(self):
"""Test decorator rejects functions with ctx parameter"""
with pytest.raises(TypeError, match="must not define 'ctx'"):
@task(name="bad_task")
def bad_task(ctx, arg1: int) -> None: # noqa: ARG001
pass
def test_decorator_forbids_options_parameter(self):
"""Test decorator rejects functions with options parameter"""
with pytest.raises(TypeError, match="must not define.*'options'"):
@task(name="bad_task")
def bad_task(options, arg1: int) -> None: # noqa: ARG001
pass
class TestTaskWrapperMergeOptions:
"""Tests for TaskWrapper._merge_options()"""
def setup_method(self):
"""Clear task registry before each test"""
TaskRegistry._tasks.clear()
def test_merge_options_no_override(self):
"""Test merging with no override returns defaults"""
@task(name="test_merge_no_override_unique")
def merge_task_1() -> None:
pass
# Set default options for testing
merge_task_1.default_options = TaskOptions(
task_key="default_key",
task_name="Default Name",
)
merged = merge_task_1._merge_options(None)
assert merged.task_key == "default_key"
assert merged.task_name == "Default Name"
def test_merge_options_override_task_key(self):
"""Test overriding task_key at call time"""
@task(name="test_merge_override_key_unique")
def merge_task_2() -> None:
pass
# Set default options for testing
merge_task_2.default_options = TaskOptions(task_key="default_key")
override = TaskOptions(task_key="override_key")
merged = merge_task_2._merge_options(override)
assert merged.task_key == "override_key"
def test_merge_options_override_task_name(self):
"""Test overriding task_name at call time"""
@task(name="test_merge_override_name_unique")
def merge_task_3() -> None:
pass
# Set default options for testing
merge_task_3.default_options = TaskOptions(task_name="Default Name")
override = TaskOptions(task_name="Override Name")
merged = merge_task_3._merge_options(override)
assert merged.task_name == "Override Name"
def test_merge_options_override_all(self):
"""Test overriding all options at call time"""
@task(name="test_merge_override_all_unique")
def merge_task_4() -> None:
pass
# Set default options for testing
merge_task_4.default_options = TaskOptions(
task_key="default_key",
task_name="Default Name",
)
override = TaskOptions(
task_key="override_key",
task_name="Override Name",
)
merged = merge_task_4._merge_options(override)
assert merged.task_key == "override_key"
assert merged.task_name == "Override Name"
class TestTaskWrapperSchedule:
"""Tests for TaskWrapper.schedule() with scope"""
def setup_method(self):
"""Clear task registry before each test"""
TaskRegistry._tasks.clear()
@patch("superset.tasks.decorators.TaskManager.submit_task")
def test_schedule_uses_default_scope(self, mock_submit):
"""Test schedule() uses decorator's default scope"""
mock_submit.return_value = MagicMock()
@task(name="test_schedule_default_unique", scope=TaskScope.SHARED)
def schedule_task_1(arg1: int) -> None:
pass
# Shared tasks require explicit task_key
schedule_task_1.schedule(123, options=TaskOptions(task_key="test_key"))
# Verify TaskManager.submit_task was called with correct scope
mock_submit.assert_called_once()
call_args = mock_submit.call_args
assert call_args[1]["scope"] == TaskScope.SHARED
@patch("superset.tasks.decorators.TaskManager.submit_task")
def test_schedule_uses_private_scope_by_default(self, mock_submit):
"""Test schedule() uses PRIVATE scope when no scope specified"""
mock_submit.return_value = MagicMock()
@task(name="test_schedule_override_unique")
def schedule_task_2(arg1: int) -> None:
pass
schedule_task_2.schedule(123)
# Verify PRIVATE scope was used (default)
mock_submit.assert_called_once()
call_args = mock_submit.call_args
assert call_args[1]["scope"] == TaskScope.PRIVATE
@patch("superset.tasks.decorators.TaskManager.submit_task")
def test_schedule_with_custom_options(self, mock_submit):
"""Test schedule() with custom task options"""
mock_submit.return_value = MagicMock()
@task(name="test_schedule_custom_unique", scope=TaskScope.SYSTEM)
def schedule_task_3(arg1: int) -> None:
pass
# Use custom task key and name
schedule_task_3.schedule(
123,
options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
)
# Verify scope from decorator and options from call time
mock_submit.assert_called_once()
call_args = mock_submit.call_args
assert call_args[1]["scope"] == TaskScope.SYSTEM
assert call_args[1]["task_key"] == "custom_key"
assert call_args[1]["task_name"] == "Custom Task Name"
@patch("superset.tasks.decorators.TaskManager.submit_task")
def test_schedule_with_no_decorator_options(self, mock_submit):
"""Test schedule() uses default PRIVATE scope when no options provided"""
mock_submit.return_value = MagicMock()
@task(name="test_schedule_no_options_unique")
def schedule_task_4(arg1: int) -> None:
pass
schedule_task_4.schedule(123)
# Verify default PRIVATE scope
mock_submit.assert_called_once()
call_args = mock_submit.call_args
assert call_args[1]["scope"] == TaskScope.PRIVATE
@patch("superset.tasks.decorators.TaskManager.submit_task")
def test_schedule_shared_task_requires_task_key(self, mock_submit):
"""Test shared task schedule() requires explicit task_key"""
@task(name="test_shared_requires_key", scope=TaskScope.SHARED)
def shared_task(arg1: int) -> None:
pass
# Should raise ValueError when no task_key provided
with pytest.raises(
ValueError,
match="Shared task.*requires an explicit task_key.*for deduplication",
):
shared_task.schedule(123)
# Should work with task_key provided
mock_submit.return_value = MagicMock()
shared_task.schedule(123, options=TaskOptions(task_key="valid_key"))
mock_submit.assert_called_once()
@patch("superset.tasks.decorators.TaskManager.submit_task")
def test_schedule_private_task_allows_no_task_key(self, mock_submit):
"""Test private task schedule() works without task_key"""
mock_submit.return_value = MagicMock()
@task(name="test_private_no_key", scope=TaskScope.PRIVATE)
def private_task(arg1: int) -> None:
pass
# Should work without task_key (generates random UUID)
private_task.schedule(123)
mock_submit.assert_called_once()
class TestTaskWrapperCall:
"""Tests for TaskWrapper.__call__() with scope"""
def setup_method(self):
"""Clear task registry before each test"""
TaskRegistry._tasks.clear()
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
def test_call_uses_default_scope(
self, mock_submit_run_with_info, mock_find, mock_update_run
):
"""Test direct call uses decorator's default scope"""
mock_task = MagicMock()
mock_task.uuid = TEST_UUID
mock_task.status = "in_progress"
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
mock_update_run.return_value = mock_task
mock_find.return_value = mock_task # Mock the subsequent find call
@task(name="test_call_default_unique", scope=TaskScope.SHARED)
def call_task_1(arg1: int) -> None:
pass
# Shared tasks require explicit task_key
call_task_1(123, options=TaskOptions(task_key="test_key"))
# Verify SubmitTaskCommand.run_with_info was called
mock_submit_run_with_info.assert_called_once()
@patch("superset.utils.core.get_user_id")
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
def test_call_uses_private_scope_by_default(
self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
):
"""Test direct call uses PRIVATE scope when no scope specified"""
mock_get_user_id.return_value = 1
mock_task = MagicMock()
mock_task.uuid = TEST_UUID
mock_task.status = "in_progress"
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
mock_update_run.return_value = mock_task
mock_find.return_value = mock_task # Mock the subsequent find call
@task(name="test_call_private_default_unique")
def call_task_2(arg1: int) -> None:
pass
call_task_2(123)
# Verify SubmitTaskCommand.run_with_info was called
mock_submit_run_with_info.assert_called_once()
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
def test_call_with_custom_options(
self, mock_submit_run_with_info, mock_find, mock_update_run
):
"""Test direct call with custom task options"""
mock_task = MagicMock()
mock_task.uuid = TEST_UUID
mock_task.status = "in_progress"
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
mock_update_run.return_value = mock_task
mock_find.return_value = mock_task # Mock the subsequent find call
@task(name="test_call_custom_unique", scope=TaskScope.SYSTEM)
def call_task_3(arg1: int) -> None:
pass
# Use custom task key and name
call_task_3(
123,
options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
)
# Verify SubmitTaskCommand.run_with_info was called
mock_submit_run_with_info.assert_called_once()
def test_call_shared_task_requires_task_key(self):
"""Test shared task direct call requires explicit task_key"""
@task(name="test_shared_call_requires_key", scope=TaskScope.SHARED)
def shared_task(arg1: int) -> None:
pass
# Should raise ValueError when no task_key provided
with pytest.raises(
ValueError,
match="Shared task.*requires an explicit task_key.*for deduplication",
):
shared_task(123)
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
def test_call_shared_task_works_with_task_key(
self, mock_submit_run_with_info, mock_find, mock_update_run
):
"""Test shared task direct call works with task_key"""
mock_task = MagicMock()
mock_task.uuid = TEST_UUID
mock_task.status = "in_progress"
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
mock_update_run.return_value = mock_task
mock_find.return_value = mock_task
@task(name="test_shared_call_with_key", scope=TaskScope.SHARED)
def shared_task(arg1: int) -> None:
pass
# Should work with task_key provided
shared_task(123, options=TaskOptions(task_key="valid_key"))
mock_submit_run_with_info.assert_called_once()
@patch("superset.utils.core.get_user_id")
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
def test_call_private_task_allows_no_task_key(
self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
):
"""Test private task direct call works without task_key"""
mock_get_user_id.return_value = 1
mock_task = MagicMock()
mock_task.uuid = TEST_UUID
mock_task.status = "in_progress"
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
mock_update_run.return_value = mock_task
mock_find.return_value = mock_task
@task(name="test_private_call_no_key", scope=TaskScope.PRIVATE)
def private_task(arg1: int) -> None:
pass
# Should work without task_key (generates random UUID)
private_task(123)
mock_submit_run_with_info.assert_called_once()
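
The direct-call tests above pin down one rule: a SHARED task must be given an explicit `task_key`, while a PRIVATE task may omit it and receives a random one. The following is only a minimal, self-contained sketch of that rule and is not Superset's implementation. `Scope`, `Options` and `resolve_task_key` are hypothetical stand-ins for `TaskScope`, `TaskOptions` and whatever the decorator does internally, and the fallback for SYSTEM scope is an assumption, since the tests do not cover it.

```python
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Scope(Enum):
    """Hypothetical stand-in for TaskScope."""

    PRIVATE = "private"
    SHARED = "shared"
    SYSTEM = "system"


@dataclass
class Options:
    """Hypothetical stand-in for TaskOptions."""

    task_key: Optional[str] = None
    task_name: Optional[str] = None


def resolve_task_key(name: str, scope: Scope, options: Optional[Options] = None) -> str:
    """Return the deduplication key for a direct call."""
    key = options.task_key if options else None
    if key:
        # An explicit key always wins, whatever the scope.
        return key
    if scope is Scope.SHARED:
        # Wording chosen to match the regex asserted in the test above.
        raise ValueError(
            f"Shared task '{name}' requires an explicit task_key for deduplication"
        )
    # Assumption: every non-shared scope falls back to a random key,
    # so two calls without a key never deduplicate against each other.
    return str(uuid.uuid4())
```

Under these assumptions, `resolve_task_key("report", Scope.SHARED, Options(task_key="daily"))` returns `"daily"`, and the same call without options raises `ValueError`.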


@@ -0,0 +1,677 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for GTF handlers (abort, cleanup) and related Task model behavior."""
import time
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock, patch
from uuid import UUID
import pytest
from freezegun import freeze_time
from superset_core.api.tasks import TaskStatus
from superset.tasks.context import TaskContext
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
@pytest.fixture
def mock_task():
"""Create a mock task for testing."""
task = MagicMock()
task.uuid = TEST_UUID
task.status = TaskStatus.PENDING.value
return task
@pytest.fixture
def mock_task_dao(mock_task):
"""Mock TaskDAO to return our test task."""
with patch("superset.daos.tasks.TaskDAO") as mock_dao:
mock_dao.find_one_or_none.return_value = mock_task
yield mock_dao
@pytest.fixture
def mock_update_command():
"""Mock UpdateTaskCommand to avoid database operations."""
with patch("superset.commands.tasks.update.UpdateTaskCommand") as mock_cmd:
mock_cmd.return_value.run.return_value = None
yield mock_cmd
@pytest.fixture
def mock_flask_app():
"""Create a properly configured mock Flask app."""
mock_app = MagicMock()
mock_app.config = {
"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
}
# Make app_context() return a proper context manager
mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
# Use regular Mock (not MagicMock) for _get_current_object to avoid
# AsyncMockMixin creating unawaited coroutines in Python 3.10+
mock_app._get_current_object = Mock(return_value=mock_app)
return mock_app
@pytest.fixture
def task_context(mock_task, mock_task_dao, mock_update_command, mock_flask_app):
"""Create TaskContext with mocked dependencies."""
# Ensure mock_task has properties_dict and payload_dict (TaskContext accesses them)
mock_task.properties_dict = {"is_abortable": False}
mock_task.payload_dict = {}
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by making signal_cache return None
mock_cache_manager.signal_cache = None
# Configure current_app mock
mock_current_app.config = mock_flask_app.config
# Use regular Mock (not MagicMock) for _get_current_object to avoid
# AsyncMockMixin creating unawaited coroutines in Python 3.10+
mock_current_app._get_current_object = Mock(return_value=mock_flask_app)
ctx = TaskContext(mock_task)
yield ctx
# Cleanup: stop polling if started
if ctx._abort_listener:
ctx.stop_abort_polling()
class TestTaskStatusEnum:
"""Test TaskStatus enum values."""
def test_aborting_status_exists(self):
"""Test that ABORTING status is defined."""
assert hasattr(TaskStatus, "ABORTING")
assert TaskStatus.ABORTING.value == "aborting"
def test_all_statuses_present(self):
"""Test all expected statuses are present."""
expected_statuses = [
"pending",
"in_progress",
"success",
"failure",
"aborting",
"aborted",
]
actual_statuses = [s.value for s in TaskStatus]
for status in expected_statuses:
assert status in actual_statuses, f"Missing status: {status}"
class TestTaskAbortProperties:
"""Test Task model abort-related properties via status and properties accessor."""
def test_aborting_status(self):
"""Test ABORTING status check."""
from superset.models.tasks import Task
task = Task()
task.status = TaskStatus.ABORTING.value
assert task.status == TaskStatus.ABORTING.value
def test_is_abortable_in_properties(self):
"""Test is_abortable is accessible via properties."""
from superset.models.tasks import Task
task = Task()
task.update_properties({"is_abortable": True})
assert task.properties_dict.get("is_abortable") is True
def test_is_abortable_default_none(self):
"""Test is_abortable defaults to None for new tasks."""
from superset.models.tasks import Task
task = Task()
assert task.properties_dict.get("is_abortable") is None
class TestTaskSetStatus:
"""Test Task.set_status behavior for abort states."""
def test_set_status_in_progress_sets_is_abortable_false(self):
"""Test that transitioning to IN_PROGRESS sets is_abortable to False."""
from superset.models.tasks import Task
task = Task()
task.uuid = "test-uuid"
# Default is None
task.set_status(TaskStatus.IN_PROGRESS)
assert task.properties_dict.get("is_abortable") is False
assert task.started_at is not None
def test_set_status_in_progress_preserves_existing_is_abortable(self):
"""Test that re-setting IN_PROGRESS doesn't override is_abortable."""
from superset.models.tasks import Task
task = Task()
task.uuid = "test-uuid"
task.update_properties(
{"is_abortable": True}
) # Already set by handler registration
task.started_at = datetime.now(timezone.utc) # Already started
task.set_status(TaskStatus.IN_PROGRESS)
# Should not override since started_at is already set
assert task.properties_dict.get("is_abortable") is True
def test_set_status_aborting_does_not_set_ended_at(self):
"""Test that ABORTING status does not set ended_at."""
from superset.models.tasks import Task
task = Task()
task.uuid = "test-uuid"
task.started_at = datetime.now(timezone.utc)
task.set_status(TaskStatus.ABORTING)
assert task.ended_at is None
def test_set_status_aborted_sets_ended_at(self):
"""Test that ABORTED status sets ended_at."""
from superset.models.tasks import Task
task = Task()
task.uuid = "test-uuid"
task.started_at = datetime.now(timezone.utc)
task.set_status(TaskStatus.ABORTED)
assert task.ended_at is not None
class TestTaskDuration:
"""Test Task duration_seconds property with different states."""
def test_duration_seconds_finished_task(self):
"""Test duration for finished task returns actual duration."""
from superset.models.tasks import Task
task = Task()
task.status = TaskStatus.SUCCESS.value # Must be finished to use ended_at
task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
task.ended_at = datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
# Should use ended_at - started_at = 30 seconds
assert task.duration_seconds == 30.0
@freeze_time("2024-01-01 10:00:30")
def test_duration_seconds_running_task(self):
"""Test duration for running task returns time since start."""
from superset.models.tasks import Task
task = Task()
task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
task.ended_at = None
# 30 seconds since start
assert task.duration_seconds == 30.0
@freeze_time("2024-01-01 10:00:15")
def test_duration_seconds_pending_task(self):
"""Test duration for pending task returns queue time."""
from superset.models.tasks import Task
task = Task()
task.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
task.started_at = None
task.ended_at = None
# 15 seconds since creation
assert task.duration_seconds == 15.0
def test_duration_seconds_no_timestamps(self):
"""Test duration returns None when no timestamps available."""
from superset.models.tasks import Task
task = Task()
task.created_on = None
task.started_at = None
task.ended_at = None
assert task.duration_seconds is None
class TestAbortHandlerRegistration:
"""Test abort handler registration and is_abortable flag."""
def test_on_abort_registers_handler(self, task_context):
"""Test that on_abort registers a handler."""
handler_called = False
@task_context.on_abort
def handle_abort():
nonlocal handler_called
handler_called = True
assert len(task_context._abort_handlers) == 1
assert not handler_called
@patch("superset.tasks.context.current_app")
def test_on_abort_sets_abortable(self, mock_app):
"""Test on_abort sets is_abortable to True on first handler."""
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
mock_app._get_current_object = Mock(return_value=mock_app)
mock_task = MagicMock()
mock_task.uuid = TEST_UUID
mock_task.properties_dict = {"is_abortable": False}
mock_task.payload_dict = {}
with (
patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
patch.object(TaskContext, "start_abort_polling"),
):
ctx = TaskContext(mock_task)
@ctx.on_abort
def handler():
pass
mock_set_abortable.assert_called_once()
@patch("superset.tasks.context.current_app")
def test_on_abort_only_sets_abortable_once(self, mock_app):
"""Test on_abort only calls _set_abortable for first handler."""
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
mock_app._get_current_object = Mock(return_value=mock_app)
mock_task = MagicMock()
mock_task.uuid = TEST_UUID
mock_task.properties_dict = {"is_abortable": False}
mock_task.payload_dict = {}
with (
patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
patch.object(TaskContext, "start_abort_polling"),
):
ctx = TaskContext(mock_task)
@ctx.on_abort
def handler1():
pass
@ctx.on_abort
def handler2():
pass
# Should only be called once for first handler
assert mock_set_abortable.call_count == 1
def test_abort_handlers_completed_initially_false(self):
"""Test abort_handlers_completed is False initially."""
mock_task = MagicMock()
mock_task.uuid = TEST_UUID
mock_task.properties_dict = {}
mock_task.payload_dict = {}
with patch("superset.tasks.context.current_app") as mock_app:
mock_app._get_current_object = Mock(return_value=mock_app)
ctx = TaskContext(mock_task)
assert ctx.abort_handlers_completed is False
class TestAbortPolling:
"""Test abort detection polling behavior."""
def test_on_abort_starts_polling_automatically(self, task_context):
"""Test that registering first handler starts abort listener."""
assert task_context._abort_listener is None
@task_context.on_abort
def handle_abort():
pass
assert task_context._abort_listener is not None
def test_stop_abort_polling(self, task_context):
"""Test that stop_abort_polling stops the abort listener."""
@task_context.on_abort
def handle_abort():
pass
assert task_context._abort_listener is not None
task_context.stop_abort_polling()
assert task_context._abort_listener is None
def test_start_abort_polling_only_once(self, task_context):
"""Test that start_abort_polling is idempotent."""
task_context.start_abort_polling(interval=0.1)
first_listener = task_context._abort_listener
# Try to start again
task_context.start_abort_polling(interval=0.1)
second_listener = task_context._abort_listener
# Should be the same listener
assert first_listener is second_listener
def test_on_abort_with_custom_interval(self, task_context):
"""Test that custom interval can be set via start_abort_polling."""
with patch("superset.tasks.context.current_app") as mock_app:
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
mock_app._get_current_object = Mock(return_value=mock_app)
@task_context.on_abort
def handle_abort():
pass
# Override with custom interval
task_context.stop_abort_polling()
task_context.start_abort_polling(interval=0.05)
assert task_context._abort_listener is not None
def test_polling_stops_after_abort_detected(self, task_context, mock_task):
"""Test that abort is detected and handlers are triggered."""
@task_context.on_abort
def handle_abort():
pass
# Trigger abort
mock_task.status = TaskStatus.ABORTED.value
# Wait for detection
time.sleep(0.3)
# Abort should have been detected
assert task_context._abort_detected is True
class TestAbortHandlerExecution:
"""Test abort handler execution behavior."""
def test_on_abort_handler_fires_when_task_aborted(self, task_context, mock_task):
"""Test that abort handler fires automatically when task is aborted."""
abort_called = False
@task_context.on_abort
def handle_abort():
nonlocal abort_called
abort_called = True
# Simulate task being aborted
mock_task.status = TaskStatus.ABORTED.value
# Wait for polling to detect abort (max 0.3s with 0.1s interval)
time.sleep(0.3)
assert abort_called
assert task_context._abort_detected
def test_on_abort_not_called_on_success(self, task_context, mock_task):
"""Test that abort handlers don't run on success."""
abort_called = False
@task_context.on_abort
def handle_abort():
nonlocal abort_called
abort_called = True
# Keep task in success state
mock_task.status = TaskStatus.SUCCESS.value
# Wait and verify handler not called
time.sleep(0.3)
assert not abort_called
def test_multiple_abort_handlers(self, task_context, mock_task):
"""Test that all abort handlers execute in LIFO order."""
calls = []
@task_context.on_abort
def handler1():
calls.append(1)
@task_context.on_abort
def handler2():
calls.append(2)
# Trigger abort
mock_task.status = TaskStatus.ABORTED.value
# Wait for detection
time.sleep(0.3)
# LIFO order: handler2 runs first
assert calls == [2, 1]
def test_abort_handler_exception_doesnt_fail_task(self, task_context, mock_task):
"""Test that exception in abort handler is logged but doesn't fail task."""
good_handler_called = False
# Register the good handler first: handlers run in LIFO order, so
# bad_handler (registered last) runs first and raises before it.
@task_context.on_abort
def good_handler():
nonlocal good_handler_called
good_handler_called = True
@task_context.on_abort
def bad_handler():
raise ValueError("Handler error")
# Trigger abort
mock_task.status = TaskStatus.ABORTED.value
# Wait for detection
time.sleep(0.3)
# good_handler should still run even though bad_handler raised first
assert good_handler_called
class TestBestEffortHandlerExecution:
"""Test that all handlers execute even when some fail (best-effort)."""
def test_all_abort_handlers_run_even_if_all_fail(self, task_context, mock_task):
"""Test all abort handlers execute even if every one raises an exception."""
calls = []
@task_context.on_abort
def handler1():
calls.append(1)
raise ValueError("Handler 1 failed")
@task_context.on_abort
def handler2():
calls.append(2)
raise RuntimeError("Handler 2 failed")
@task_context.on_abort
def handler3():
calls.append(3)
raise TypeError("Handler 3 failed")
# Trigger abort handlers directly (simulating abort detection)
task_context._trigger_abort_handlers()
# All handlers should have been called (LIFO order: 3, 2, 1)
assert calls == [3, 2, 1]
# Failures should be collected (abort handlers don't write to DB)
assert len(task_context._handler_failures) == 3
failure_types = [
type(ex).__name__ for _, ex, _ in task_context._handler_failures
]
assert "TypeError" in failure_types
assert "RuntimeError" in failure_types
assert "ValueError" in failure_types
def test_all_cleanup_handlers_run_even_if_all_fail(self, task_context, mock_task):
"""Test all cleanup handlers execute even if every one raises an exception."""
calls = []
captured_failures = []
# Wrap _write_handler_failures_to_db to capture failures before they are cleared
original_write = task_context._write_handler_failures_to_db
def mock_write():
captured_failures.extend(task_context._handler_failures)
original_write()
task_context._write_handler_failures_to_db = mock_write
@task_context.on_cleanup
def cleanup1():
calls.append(1)
raise ValueError("Cleanup 1 failed")
@task_context.on_cleanup
def cleanup2():
calls.append(2)
raise RuntimeError("Cleanup 2 failed")
@task_context.on_cleanup
def cleanup3():
calls.append(3)
raise TypeError("Cleanup 3 failed")
# Set task to SUCCESS (not aborting) so only cleanup handlers run
mock_task.status = TaskStatus.SUCCESS.value
# Run cleanup
task_context._run_cleanup()
# All handlers should have been called (LIFO order: 3, 2, 1)
assert calls == [3, 2, 1]
# Failures should have been captured before clearing
assert len(captured_failures) == 3
failure_types = [type(ex).__name__ for _, ex, _ in captured_failures]
assert "TypeError" in failure_types
assert "RuntimeError" in failure_types
assert "ValueError" in failure_types
def test_mixed_abort_and_cleanup_failures_all_collected(
self, task_context, mock_task
):
"""Test abort and cleanup handler failures are collected together."""
calls = []
captured_failures = []
# Wrap _write_handler_failures_to_db to capture failures before they are cleared
original_write = task_context._write_handler_failures_to_db
def mock_write():
captured_failures.extend(task_context._handler_failures)
original_write()
task_context._write_handler_failures_to_db = mock_write
@task_context.on_abort
def abort1():
calls.append("abort1")
raise ValueError("Abort 1 failed")
@task_context.on_abort
def abort2():
calls.append("abort2")
raise RuntimeError("Abort 2 failed")
@task_context.on_cleanup
def cleanup1():
calls.append("cleanup1")
raise TypeError("Cleanup 1 failed")
@task_context.on_cleanup
def cleanup2():
calls.append("cleanup2")
raise KeyError("Cleanup 2 failed")
# Set task to ABORTING so both abort and cleanup handlers run
mock_task.status = TaskStatus.ABORTING.value
# Run cleanup (which triggers abort handlers first, then cleanup handlers)
task_context._run_cleanup()
# All handlers should have been called
# Abort handlers run first (LIFO: abort2, abort1)
# Then cleanup handlers (LIFO: cleanup2, cleanup1)
assert calls == ["abort2", "abort1", "cleanup2", "cleanup1"]
# All 4 failures should have been captured
assert len(captured_failures) == 4
# Verify handler types are recorded correctly
handler_types = [htype for htype, _, _ in captured_failures]
assert handler_types.count("abort") == 2
assert handler_types.count("cleanup") == 2
class TestCleanupHandlers:
"""Test cleanup handler behavior."""
def test_cleanup_triggers_abort_handlers_if_not_detected(
self, task_context, mock_task
):
"""Test that _run_cleanup triggers abort handlers if task ended aborted."""
abort_called = False
@task_context.on_abort
def handle_abort():
nonlocal abort_called
abort_called = True
# Set task as aborted but don't let polling detect it
mock_task.status = TaskStatus.ABORTED.value
task_context._abort_detected = False
# Immediately run cleanup (simulating task ending before poll)
task_context._run_cleanup()
assert abort_called
def test_cleanup_doesnt_duplicate_abort_handlers(self, task_context, mock_task):
"""Test that abort handlers only run once even if called from cleanup."""
call_count = 0
@task_context.on_abort
def handle_abort():
nonlocal call_count
call_count += 1
# Trigger abort via polling
mock_task.status = TaskStatus.ABORTED.value
time.sleep(0.3)
# Handlers should have been called once
assert call_count == 1
assert task_context._abort_detected is True
# Run cleanup - handlers should NOT be called again
task_context._run_cleanup()
assert call_count == 1 # Still 1, not 2
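
The handler tests above assert three properties: handlers run in LIFO order, a failing handler does not stop the remaining ones, and a second trigger does not run them again. A minimal sketch of a structure with those properties follows. `HandlerStack` is a hypothetical helper written for illustration only; it is not `TaskContext`, and it does not poll, touch the database, or distinguish abort from cleanup handlers.

```python
from typing import Callable, List, Tuple

Handler = Callable[[], None]


class HandlerStack:
    """Hypothetical helper: LIFO, best-effort, run-once handler execution."""

    def __init__(self) -> None:
        self._handlers: List[Handler] = []
        # (handler name, exception) pairs collected instead of being raised
        self.failures: List[Tuple[str, Exception]] = []
        self._ran = False

    def register(self, fn: Handler) -> Handler:
        """Add a handler; returns it unchanged so this works as a decorator."""
        self._handlers.append(fn)
        return fn

    def run(self) -> None:
        """Run every handler once, newest first, collecting any failures."""
        if self._ran:
            # A second trigger (for example from cleanup) is a no-op.
            return
        self._ran = True
        for fn in reversed(self._handlers):
            try:
                fn()
            except Exception as ex:
                # Best effort: record the failure and keep going.
                self.failures.append((fn.__name__, ex))
```

Collecting exceptions rather than re-raising them is what lets every handler run; a caller can inspect `failures` afterwards and decide how to report them.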


@@ -0,0 +1,462 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for TaskManager pub/sub functionality"""
import threading
import time
from unittest.mock import MagicMock, patch
import redis
from superset.tasks.manager import AbortListener, TaskManager
class TestAbortListener:
"""Tests for AbortListener class"""
def test_stop_sets_event(self):
"""Test that stop() sets the stop event"""
stop_event = threading.Event()
thread = MagicMock(spec=threading.Thread)
thread.is_alive.return_value = False
listener = AbortListener("test-uuid", thread, stop_event)
assert not stop_event.is_set()
listener.stop()
assert stop_event.is_set()
def test_stop_closes_pubsub(self):
"""Test that stop() closes the pub/sub connection"""
stop_event = threading.Event()
thread = MagicMock(spec=threading.Thread)
thread.is_alive.return_value = False
pubsub = MagicMock()
listener = AbortListener("test-uuid", thread, stop_event, pubsub)
listener.stop()
pubsub.unsubscribe.assert_called_once()
pubsub.close.assert_called_once()
def test_stop_joins_thread(self):
"""Test that stop() joins the listener thread"""
stop_event = threading.Event()
thread = MagicMock(spec=threading.Thread)
thread.is_alive.return_value = True
listener = AbortListener("test-uuid", thread, stop_event)
listener.stop()
thread.join.assert_called_once_with(timeout=2.0)
class TestTaskManagerInitApp:
"""Tests for TaskManager.init_app()"""
def setup_method(self):
"""Reset TaskManager state before each test"""
TaskManager._initialized = False
TaskManager._channel_prefix = "gtf:abort:"
TaskManager._completion_channel_prefix = "gtf:complete:"
def teardown_method(self):
"""Reset TaskManager state after each test"""
TaskManager._initialized = False
TaskManager._channel_prefix = "gtf:abort:"
TaskManager._completion_channel_prefix = "gtf:complete:"
def test_init_app_sets_channel_prefixes(self):
"""Test init_app reads channel prefixes from config"""
app = MagicMock()
app.config.get.side_effect = lambda key, default=None: {
"TASKS_ABORT_CHANNEL_PREFIX": "custom:abort:",
"TASKS_COMPLETION_CHANNEL_PREFIX": "custom:complete:",
}.get(key, default)
TaskManager.init_app(app)
assert TaskManager._initialized is True
assert TaskManager._channel_prefix == "custom:abort:"
assert TaskManager._completion_channel_prefix == "custom:complete:"
def test_init_app_skips_if_already_initialized(self):
"""Test init_app is idempotent"""
TaskManager._initialized = True
app = MagicMock()
TaskManager.init_app(app)
# Should not call app.config.get since already initialized
app.config.get.assert_not_called()
class TestTaskManagerPubSub:
"""Tests for TaskManager pub/sub methods"""
def setup_method(self):
"""Reset TaskManager state before each test"""
TaskManager._initialized = False
TaskManager._channel_prefix = "gtf:abort:"
TaskManager._completion_channel_prefix = "gtf:complete:"
def teardown_method(self):
"""Reset TaskManager state after each test"""
TaskManager._initialized = False
TaskManager._channel_prefix = "gtf:abort:"
TaskManager._completion_channel_prefix = "gtf:complete:"
@patch("superset.tasks.manager.cache_manager")
def test_is_pubsub_available_no_redis(self, mock_cache_manager):
"""Test is_pubsub_available returns False when Redis not configured"""
mock_cache_manager.signal_cache = None
assert TaskManager.is_pubsub_available() is False
@patch("superset.tasks.manager.cache_manager")
def test_is_pubsub_available_with_redis(self, mock_cache_manager):
"""Test is_pubsub_available returns True when Redis is configured"""
mock_cache_manager.signal_cache = MagicMock()
assert TaskManager.is_pubsub_available() is True
def test_get_abort_channel(self):
"""Test get_abort_channel returns correct channel name"""
task_uuid = "abc-123-def-456"
channel = TaskManager.get_abort_channel(task_uuid)
assert channel == "gtf:abort:abc-123-def-456"
def test_get_abort_channel_custom_prefix(self):
"""Test get_abort_channel with custom prefix"""
TaskManager._channel_prefix = "custom:prefix:"
task_uuid = "test-uuid"
channel = TaskManager.get_abort_channel(task_uuid)
assert channel == "custom:prefix:test-uuid"
@patch("superset.tasks.manager.cache_manager")
def test_publish_abort_no_redis(self, mock_cache_manager):
"""Test publish_abort returns False when Redis not available"""
mock_cache_manager.signal_cache = None
result = TaskManager.publish_abort("test-uuid")
assert result is False
@patch("superset.tasks.manager.cache_manager")
def test_publish_abort_success(self, mock_cache_manager):
"""Test publish_abort publishes message successfully"""
mock_redis = MagicMock()
mock_redis.publish.return_value = 1 # One subscriber
mock_cache_manager.signal_cache = mock_redis
result = TaskManager.publish_abort("test-uuid")
assert result is True
mock_redis.publish.assert_called_once_with("gtf:abort:test-uuid", "abort")
@patch("superset.tasks.manager.cache_manager")
def test_publish_abort_redis_error(self, mock_cache_manager):
"""Test publish_abort handles Redis errors gracefully"""
mock_redis = MagicMock()
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
mock_cache_manager.signal_cache = mock_redis
result = TaskManager.publish_abort("test-uuid")
assert result is False
class TestTaskManagerListenForAbort:
"""Tests for TaskManager.listen_for_abort()"""
def setup_method(self):
"""Reset TaskManager state before each test"""
TaskManager._initialized = False
TaskManager._channel_prefix = "gtf:abort:"
TaskManager._completion_channel_prefix = "gtf:complete:"
def teardown_method(self):
"""Reset TaskManager state after each test"""
TaskManager._initialized = False
TaskManager._channel_prefix = "gtf:abort:"
TaskManager._completion_channel_prefix = "gtf:complete:"
@patch("superset.tasks.manager.cache_manager")
def test_listen_for_abort_no_redis_uses_polling(self, mock_cache_manager):
"""Test listen_for_abort falls back to polling when Redis unavailable"""
mock_cache_manager.signal_cache = None
callback = MagicMock()
with patch.object(TaskManager, "_poll_for_abort", return_value=None):
listener = TaskManager.listen_for_abort(
task_uuid="test-uuid",
callback=callback,
poll_interval=1.0,
app=None,
)
# Give thread time to start
time.sleep(0.1)
listener.stop()
# Should use polling since no Redis
assert listener._pubsub is None
@patch("superset.tasks.manager.cache_manager")
def test_listen_for_abort_with_redis_uses_pubsub(self, mock_cache_manager):
"""Test listen_for_abort uses pub/sub when Redis available"""
mock_redis = MagicMock()
mock_pubsub = MagicMock()
mock_redis.pubsub.return_value = mock_pubsub
mock_cache_manager.signal_cache = mock_redis
callback = MagicMock()
with patch.object(TaskManager, "_listen_pubsub", return_value=None):
listener = TaskManager.listen_for_abort(
task_uuid="test-uuid",
callback=callback,
poll_interval=1.0,
app=None,
)
# Give thread time to start
time.sleep(0.1)
listener.stop()
# Should subscribe to channel
mock_pubsub.subscribe.assert_called_once_with("gtf:abort:test-uuid")
@patch("superset.tasks.manager.cache_manager")
def test_listen_for_abort_redis_subscribe_failure_raises(self, mock_cache_manager):
"""Test listen_for_abort raises exception on subscribe failure
when Redis configured"""
import pytest
mock_redis = MagicMock()
mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
mock_cache_manager.signal_cache = mock_redis
callback = MagicMock()
# With fail-fast behavior, Redis subscribe failure raises exception
with pytest.raises(redis.RedisError, match="Connection failed"):
TaskManager.listen_for_abort(
task_uuid="test-uuid",
callback=callback,
poll_interval=1.0,
app=None,
)
class TestTaskManagerCompletion:
"""Tests for TaskManager completion pub/sub and wait_for_completion"""
def setup_method(self):
"""Reset TaskManager state before each test"""
TaskManager._initialized = False
TaskManager._channel_prefix = "gtf:abort:"
TaskManager._completion_channel_prefix = "gtf:complete:"
def teardown_method(self):
"""Reset TaskManager state after each test"""
TaskManager._initialized = False
TaskManager._channel_prefix = "gtf:abort:"
TaskManager._completion_channel_prefix = "gtf:complete:"
def test_get_completion_channel(self):
"""Test get_completion_channel returns correct channel name"""
task_uuid = "abc-123-def-456"
channel = TaskManager.get_completion_channel(task_uuid)
assert channel == "gtf:complete:abc-123-def-456"
def test_get_completion_channel_custom_prefix(self):
"""Test get_completion_channel with custom prefix"""
TaskManager._completion_channel_prefix = "custom:complete:"
task_uuid = "test-uuid"
channel = TaskManager.get_completion_channel(task_uuid)
assert channel == "custom:complete:test-uuid"
@patch("superset.tasks.manager.cache_manager")
def test_publish_completion_no_redis(self, mock_cache_manager):
"""Test publish_completion returns False when Redis not available"""
mock_cache_manager.signal_cache = None
result = TaskManager.publish_completion("test-uuid", "success")
assert result is False
@patch("superset.tasks.manager.cache_manager")
def test_publish_completion_success(self, mock_cache_manager):
"""Test publish_completion publishes message successfully"""
mock_redis = MagicMock()
mock_redis.publish.return_value = 1 # One subscriber
mock_cache_manager.signal_cache = mock_redis
result = TaskManager.publish_completion("test-uuid", "success")
assert result is True
mock_redis.publish.assert_called_once_with("gtf:complete:test-uuid", "success")
@patch("superset.tasks.manager.cache_manager")
def test_publish_completion_redis_error(self, mock_cache_manager):
"""Test publish_completion handles Redis errors gracefully"""
mock_redis = MagicMock()
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
mock_cache_manager.signal_cache = mock_redis
result = TaskManager.publish_completion("test-uuid", "success")
assert result is False
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_task_not_found(self, mock_dao, mock_cache_manager):
"""Test wait_for_completion raises ValueError for missing task"""
import pytest
mock_cache_manager.signal_cache = None
mock_dao.find_one_or_none.return_value = None
with pytest.raises(ValueError, match="not found"):
TaskManager.wait_for_completion("nonexistent-uuid")
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_already_complete(self, mock_dao, mock_cache_manager):
"""Test wait_for_completion returns immediately for terminal state"""
mock_cache_manager.signal_cache = None
mock_task = MagicMock()
mock_task.uuid = "test-uuid"
mock_task.status = "success"
mock_dao.find_one_or_none.return_value = mock_task
result = TaskManager.wait_for_completion("test-uuid")
assert result == mock_task
# Should only call find_one_or_none once (initial check)
mock_dao.find_one_or_none.assert_called_once()
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_timeout(self, mock_dao, mock_cache_manager):
"""Test wait_for_completion raises TimeoutError when timeout expires"""
import pytest
mock_cache_manager.signal_cache = None
mock_task = MagicMock()
mock_task.uuid = "test-uuid"
mock_task.status = "in_progress" # Never completes
mock_dao.find_one_or_none.return_value = mock_task
with pytest.raises(TimeoutError, match="Timeout waiting"):
TaskManager.wait_for_completion("test-uuid", timeout=0.1)
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_polling_success(self, mock_dao, mock_cache_manager):
"""Test wait_for_completion returns when task completes via polling"""
mock_cache_manager.signal_cache = None
mock_task_pending = MagicMock()
mock_task_pending.uuid = "test-uuid"
mock_task_pending.status = "pending"
mock_task_complete = MagicMock()
mock_task_complete.uuid = "test-uuid"
mock_task_complete.status = "success"
# First call returns pending, second returns complete
mock_dao.find_one_or_none.side_effect = [
mock_task_pending,
mock_task_complete,
]
result = TaskManager.wait_for_completion(
"test-uuid",
timeout=5.0,
poll_interval=0.1,
)
assert result.status == "success"
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_with_pubsub(self, mock_dao, mock_cache_manager):
"""Test wait_for_completion uses pub/sub when Redis available"""
mock_task_pending = MagicMock()
mock_task_pending.uuid = "test-uuid"
mock_task_pending.status = "pending"
mock_task_complete = MagicMock()
mock_task_complete.uuid = "test-uuid"
mock_task_complete.status = "success"
# First call returns pending, second returns complete
mock_dao.find_one_or_none.side_effect = [
mock_task_pending,
mock_task_complete,
]
# Set up mock Redis with pub/sub
mock_redis = MagicMock()
mock_pubsub = MagicMock()
# Simulate receiving a completion message
mock_pubsub.get_message.return_value = {
"type": "message",
"data": "success",
}
mock_redis.pubsub.return_value = mock_pubsub
mock_cache_manager.signal_cache = mock_redis
result = TaskManager.wait_for_completion(
"test-uuid",
timeout=5.0,
)
assert result.status == "success"
# Should have subscribed to completion channel
mock_pubsub.subscribe.assert_called_once_with("gtf:complete:test-uuid")
# Should have cleaned up
mock_pubsub.unsubscribe.assert_called_once()
mock_pubsub.close.assert_called_once()
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_pubsub_error_raises(
self, mock_dao, mock_cache_manager
):
"""Test wait_for_completion raises exception on Redis error when
Redis is configured"""
import pytest
mock_task_pending = MagicMock()
mock_task_pending.uuid = "test-uuid"
mock_task_pending.status = "pending"
mock_dao.find_one_or_none.return_value = mock_task_pending
# Set up mock Redis that fails
mock_redis = MagicMock()
mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
mock_cache_manager.signal_cache = mock_redis
# With fail-fast behavior, Redis error is raised instead of falling back
with pytest.raises(redis.RedisError, match="Connection failed"):
TaskManager.wait_for_completion(
"test-uuid",
timeout=5.0,
poll_interval=0.1,
)
def test_terminal_states_constant(self):
"""Test TERMINAL_STATES contains expected values"""
expected = {"success", "failure", "aborted", "timed_out"}
assert TaskManager.TERMINAL_STATES == expected
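
The polling path exercised by the tests above (no Redis configured, task status re-read until it reaches a terminal state or the timeout expires) can be sketched roughly as follows. This is an illustrative sketch, not Superset's implementation: `wait_by_polling` and the `fetch_status` callable are hypothetical stand-ins for `TaskManager.wait_for_completion` and the `TaskDAO` lookup, and only the terminal-state set and the error types are taken from the tests.

```python
import time
from typing import Callable, Optional

# Terminal states as asserted in test_terminal_states_constant.
TERMINAL_STATES = {"success", "failure", "aborted", "timed_out"}


def wait_by_polling(
    fetch_status: Callable[[], Optional[str]],
    timeout: float = 5.0,
    poll_interval: float = 0.1,
) -> str:
    """Poll until the task reports a terminal state.

    fetch_status is a hypothetical stand-in for the DAO lookup: it returns
    the current status string, or None when the task does not exist.
    """
    deadline = time.monotonic() + timeout
    status = fetch_status()
    if status is None:
        raise ValueError("Task not found")
    while status not in TERMINAL_STATES:
        if time.monotonic() >= deadline:
            raise TimeoutError("Timeout waiting for task to complete")
        time.sleep(poll_interval)
        status = fetch_status()
        if status is None:
            raise ValueError("Task not found")
    return status
```

A task that is already terminal returns after a single lookup, which is the behaviour `test_wait_for_completion_already_complete` checks with `assert_called_once()`.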


@@ -0,0 +1,612 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for GTF timeout handling."""
import time
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from superset_core.api.tasks import TaskOptions, TaskScope
from superset.tasks.context import TaskContext
from superset.tasks.decorators import TaskWrapper
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_flask_app():
"""Create a properly configured mock Flask app."""
mock_app = MagicMock()
mock_app.config = {
"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
}
# Make app_context() return a proper context manager
mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
return mock_app
@pytest.fixture
def mock_task_abortable():
"""Create a mock task that is abortable."""
task = MagicMock()
task.uuid = TEST_UUID
task.status = "in_progress"
task.properties_dict = {"is_abortable": True}
task.payload_dict = {}
# Set real values for dedup_key generation (used by UpdateTaskCommand lock)
task.scope = "shared"
task.task_type = "test_task"
task.task_key = "test_key"
task.user_id = 1
return task
@pytest.fixture
def mock_task_not_abortable():
"""Create a mock task that is NOT abortable."""
task = MagicMock()
task.uuid = TEST_UUID
task.status = "in_progress"
task.properties_dict = {} # No is_abortable means it's not abortable
task.payload_dict = {}
# Set real values for dedup_key generation (used by UpdateTaskCommand lock)
task.scope = "shared"
task.task_type = "test_task"
task.task_key = "test_key"
task.user_id = 1
return task
@pytest.fixture
def task_context_for_timeout(mock_flask_app, mock_task_abortable):
"""Create TaskContext with mocked dependencies for timeout tests."""
# Ensure mock_task has required attributes for TaskContext
mock_task_abortable.payload_dict = {}
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.daos.tasks.TaskDAO") as mock_dao,
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by setting signal_cache to None
mock_cache_manager.signal_cache = None
# Configure current_app mock
mock_current_app.config = mock_flask_app.config
mock_current_app._get_current_object.return_value = mock_flask_app
# Configure TaskDAO mock
mock_dao.find_one_or_none.return_value = mock_task_abortable
ctx = TaskContext(mock_task_abortable)
ctx._app = mock_flask_app
yield ctx
# Cleanup: stop timers if started
ctx.stop_timeout_timer()
if ctx._abort_listener:
ctx.stop_abort_polling()
# =============================================================================
# TaskWrapper._merge_options Timeout Tests
# =============================================================================
class TestTimeoutMerging:
"""Test timeout merging behavior in TaskWrapper._merge_options."""
def test_merge_options_decorator_timeout_used_when_no_override(self):
"""Test that decorator timeout is used when no override is provided."""
def dummy_func():
pass
wrapper = TaskWrapper(
name="test_task",
func=dummy_func,
default_options=TaskOptions(),
scope=TaskScope.PRIVATE,
default_timeout=300, # 5-minute default
)
merged = wrapper._merge_options(None)
assert merged.timeout == 300
def test_merge_options_override_timeout_takes_precedence(self):
"""Test that TaskOptions timeout overrides decorator default."""
def dummy_func():
pass
wrapper = TaskWrapper(
name="test_task",
func=dummy_func,
default_options=TaskOptions(),
scope=TaskScope.PRIVATE,
default_timeout=300, # 5-minute default
)
override = TaskOptions(timeout=600) # 10-minute override
merged = wrapper._merge_options(override)
assert merged.timeout == 600
def test_merge_options_no_timeout_when_not_configured(self):
"""Test that no timeout is set when not configured anywhere."""
def dummy_func():
pass
wrapper = TaskWrapper(
name="test_task",
func=dummy_func,
default_options=TaskOptions(),
scope=TaskScope.PRIVATE,
default_timeout=None, # No default timeout
)
merged = wrapper._merge_options(None)
assert merged.timeout is None
def test_merge_options_override_with_other_options_preserves_timeout(self):
"""Test that setting other options doesn't lose decorator timeout."""
def dummy_func():
pass
wrapper = TaskWrapper(
name="test_task",
func=dummy_func,
default_options=TaskOptions(),
scope=TaskScope.PRIVATE,
default_timeout=300,
)
# Override only task_key, not timeout
override = TaskOptions(task_key="my-key")
merged = wrapper._merge_options(override)
# Should keep decorator timeout since override.timeout is None
assert merged.timeout == 300
assert merged.task_key == "my-key"
# =============================================================================
# TaskContext Timeout Timer Tests
# =============================================================================
class TestTimeoutTimer:
"""Test TaskContext timeout timer behavior."""
def test_start_timeout_timer_sets_timer(self, task_context_for_timeout):
"""Test that start_timeout_timer creates a timer."""
ctx = task_context_for_timeout
assert ctx._timeout_timer is None
ctx.start_timeout_timer(10)
assert ctx._timeout_timer is not None
assert ctx._timeout_triggered is False
def test_start_timeout_timer_is_idempotent(self, task_context_for_timeout):
"""Test that starting timer twice doesn't create duplicate timers."""
ctx = task_context_for_timeout
ctx.start_timeout_timer(10)
first_timer = ctx._timeout_timer
ctx.start_timeout_timer(20) # Try to start again
second_timer = ctx._timeout_timer
assert first_timer is second_timer
def test_stop_timeout_timer_cancels_timer(self, task_context_for_timeout):
"""Test that stop_timeout_timer cancels the timer."""
ctx = task_context_for_timeout
ctx.start_timeout_timer(10)
assert ctx._timeout_timer is not None
ctx.stop_timeout_timer()
assert ctx._timeout_timer is None
def test_stop_timeout_timer_safe_when_no_timer(self, task_context_for_timeout):
"""Test that stop_timeout_timer doesn't fail when no timer exists."""
ctx = task_context_for_timeout
assert ctx._timeout_timer is None
ctx.stop_timeout_timer() # Should not raise
assert ctx._timeout_timer is None
def test_timeout_triggered_property_initially_false(self, task_context_for_timeout):
"""Test that timeout_triggered is False initially."""
ctx = task_context_for_timeout
assert ctx.timeout_triggered is False
def test_cleanup_stops_timeout_timer(self, task_context_for_timeout):
"""Test that _run_cleanup stops the timeout timer."""
ctx = task_context_for_timeout
ctx.start_timeout_timer(10)
assert ctx._timeout_timer is not None
ctx._run_cleanup()
assert ctx._timeout_timer is None
class TestTimeoutTrigger:
"""Test timeout trigger behavior when timer fires."""
def test_timeout_triggers_abort_when_abortable(
self, mock_flask_app, mock_task_abortable
):
"""Test that timeout triggers abort handlers when task is abortable."""
abort_called = False
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.daos.tasks.TaskDAO") as mock_dao,
patch(
"superset.commands.tasks.update.UpdateTaskCommand"
) as mock_update_cmd,
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by setting signal_cache to None
mock_cache_manager.signal_cache = None
mock_current_app.config = mock_flask_app.config
mock_current_app._get_current_object.return_value = mock_flask_app
mock_dao.find_one_or_none.return_value = mock_task_abortable
ctx = TaskContext(mock_task_abortable)
ctx._app = mock_flask_app
@ctx.on_abort
def handle_abort():
nonlocal abort_called
abort_called = True
# Start short timeout
ctx.start_timeout_timer(1)
# Wait for timeout to fire
time.sleep(1.5)
# Abort handler should have been called
assert abort_called
assert ctx._timeout_triggered
assert ctx._abort_detected
# Verify UpdateTaskCommand was called with ABORTING status
mock_update_cmd.assert_called()
call_kwargs = mock_update_cmd.call_args[1]
assert call_kwargs.get("status") == "aborting"
# Cleanup
ctx.stop_timeout_timer()
if ctx._abort_listener:
ctx.stop_abort_polling()
def test_timeout_logs_warning_when_not_abortable(
self, mock_flask_app, mock_task_not_abortable
):
"""Test that timeout logs warning when task has no abort handler."""
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.daos.tasks.TaskDAO") as mock_dao,
patch("superset.tasks.context.logger") as mock_logger,
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by setting signal_cache to None
mock_cache_manager.signal_cache = None
mock_current_app.config = mock_flask_app.config
mock_current_app._get_current_object.return_value = mock_flask_app
mock_dao.find_one_or_none.return_value = mock_task_not_abortable
ctx = TaskContext(mock_task_not_abortable)
ctx._app = mock_flask_app
# No abort handler registered
# Start short timeout
ctx.start_timeout_timer(1)
# Wait for timeout to fire
time.sleep(1.5)
# Should have logged warning
mock_logger.warning.assert_called()
warning_call = mock_logger.warning.call_args
assert "no abort handler" in warning_call[0][0].lower()
assert ctx._timeout_triggered
assert not ctx._abort_detected # No abort since no handler
# Cleanup
ctx.stop_timeout_timer()
def test_timeout_does_not_trigger_if_already_aborting(
self, mock_flask_app, mock_task_abortable
):
"""Test that timeout doesn't re-trigger abort if already aborting."""
abort_count = 0
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.daos.tasks.TaskDAO") as mock_dao,
patch("superset.commands.tasks.update.UpdateTaskCommand"),
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by setting signal_cache to None
mock_cache_manager.signal_cache = None
mock_current_app.config = mock_flask_app.config
mock_current_app._get_current_object.return_value = mock_flask_app
mock_dao.find_one_or_none.return_value = mock_task_abortable
ctx = TaskContext(mock_task_abortable)
ctx._app = mock_flask_app
@ctx.on_abort
def handle_abort():
nonlocal abort_count
abort_count += 1
# Pre-set abort detected
ctx._abort_detected = True
# Start short timeout
ctx.start_timeout_timer(1)
# Wait for timeout to fire
time.sleep(1.5)
# Handler should NOT have been called since already aborting
assert abort_count == 0
# Cleanup
ctx.stop_timeout_timer()
if ctx._abort_listener:
ctx.stop_abort_polling()
# =============================================================================
# Task Decorator Timeout Tests
# =============================================================================
class TestTaskDecoratorTimeout:
"""Test @task decorator timeout parameter."""
def test_task_decorator_accepts_timeout(self):
"""Test that @task decorator accepts timeout parameter."""
from superset.tasks.decorators import task
from superset.tasks.registry import TaskRegistry
@task(name="test_timeout_task_1", timeout=300)
def timeout_test_task_1():
pass
assert isinstance(timeout_test_task_1, TaskWrapper)
assert timeout_test_task_1.default_timeout == 300
# Cleanup registry
TaskRegistry._tasks.pop("test_timeout_task_1", None)
def test_task_decorator_without_timeout(self):
"""Test that @task decorator works without timeout."""
from superset.tasks.decorators import task
from superset.tasks.registry import TaskRegistry
@task(name="test_timeout_task_2")
def timeout_test_task_2():
pass
assert isinstance(timeout_test_task_2, TaskWrapper)
assert timeout_test_task_2.default_timeout is None
# Cleanup registry
TaskRegistry._tasks.pop("test_timeout_task_2", None)
def test_task_decorator_with_all_params(self):
"""Test that @task decorator accepts all parameters together."""
from superset.tasks.decorators import task
from superset.tasks.registry import TaskRegistry
@task(name="test_timeout_task_3", scope=TaskScope.SHARED, timeout=600)
def timeout_test_task_3():
pass
assert timeout_test_task_3.name == "test_timeout_task_3"
assert timeout_test_task_3.scope == TaskScope.SHARED
assert timeout_test_task_3.default_timeout == 600
# Cleanup registry
TaskRegistry._tasks.pop("test_timeout_task_3", None)
# =============================================================================
# Timeout Terminal State Tests
# =============================================================================
class TestTimeoutTerminalState:
"""Test timeout transitions to correct terminal state (TIMED_OUT vs FAILURE)."""
def test_timeout_triggered_flag_set_on_timeout(
self, mock_flask_app, mock_task_abortable
):
"""Test that timeout_triggered flag is set when timeout fires."""
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.daos.tasks.TaskDAO") as mock_dao,
patch("superset.commands.tasks.update.UpdateTaskCommand"),
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by setting signal_cache to None
mock_cache_manager.signal_cache = None
mock_current_app.config = mock_flask_app.config
mock_current_app._get_current_object.return_value = mock_flask_app
mock_dao.find_one_or_none.return_value = mock_task_abortable
ctx = TaskContext(mock_task_abortable)
ctx._app = mock_flask_app
@ctx.on_abort
def handle_abort():
pass
# Initially not triggered
assert ctx.timeout_triggered is False
# Start short timeout
ctx.start_timeout_timer(1)
# Wait for timeout to fire
time.sleep(1.5)
# Should be set after timeout
assert ctx.timeout_triggered is True
# Cleanup
ctx.stop_timeout_timer()
if ctx._abort_listener:
ctx.stop_abort_polling()
def test_user_abort_does_not_set_timeout_triggered(
self, mock_flask_app, mock_task_abortable
):
"""Test that user abort doesn't set timeout_triggered flag."""
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.daos.tasks.TaskDAO") as mock_dao,
patch("superset.commands.tasks.update.UpdateTaskCommand"),
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by setting signal_cache to None
mock_cache_manager.signal_cache = None
mock_current_app.config = mock_flask_app.config
mock_current_app._get_current_object.return_value = mock_flask_app
mock_dao.find_one_or_none.return_value = mock_task_abortable
ctx = TaskContext(mock_task_abortable)
ctx._app = mock_flask_app
@ctx.on_abort
def handle_abort():
pass
# Simulate user abort (not timeout)
ctx._on_abort_detected()
# timeout_triggered should still be False
assert ctx.timeout_triggered is False
# But abort_detected should be True
assert ctx._abort_detected is True
# Cleanup
if ctx._abort_listener:
ctx.stop_abort_polling()
def test_abort_handlers_completed_tracks_success(
self, mock_flask_app, mock_task_abortable
):
"""Test that abort_handlers_completed flag tracks successful
handler execution."""
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.daos.tasks.TaskDAO") as mock_dao,
patch("superset.commands.tasks.update.UpdateTaskCommand"),
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by setting signal_cache to None
mock_cache_manager.signal_cache = None
mock_current_app.config = mock_flask_app.config
mock_current_app._get_current_object.return_value = mock_flask_app
mock_dao.find_one_or_none.return_value = mock_task_abortable
ctx = TaskContext(mock_task_abortable)
ctx._app = mock_flask_app
@ctx.on_abort
def handle_abort():
pass # Successful handler
# Initially not completed
assert ctx.abort_handlers_completed is False
# Trigger abort handlers
ctx._trigger_abort_handlers()
# Should be marked as completed
assert ctx.abort_handlers_completed is True
# Cleanup
if ctx._abort_listener:
ctx.stop_abort_polling()
def test_abort_handlers_completed_false_on_exception(
self, mock_flask_app, mock_task_abortable
):
"""Test that abort_handlers_completed is False when handler throws."""
with (
patch("superset.tasks.context.current_app") as mock_current_app,
patch("superset.daos.tasks.TaskDAO") as mock_dao,
patch("superset.commands.tasks.update.UpdateTaskCommand"),
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
):
# Disable Redis by setting signal_cache to None
mock_cache_manager.signal_cache = None
mock_current_app.config = mock_flask_app.config
mock_current_app._get_current_object.return_value = mock_flask_app
mock_dao.find_one_or_none.return_value = mock_task_abortable
ctx = TaskContext(mock_task_abortable)
ctx._app = mock_flask_app
@ctx.on_abort
def handle_abort():
raise ValueError("Handler failed")
# Initially not completed
assert ctx.abort_handlers_completed is False
# Trigger abort handlers (will catch the exception internally)
ctx._trigger_abort_handlers()
# Should NOT be marked as completed since handler threw
assert ctx.abort_handlers_completed is False
# Cleanup
if ctx._abort_listener:
ctx.stop_abort_polling()
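
The timer behaviour these tests pin down (idempotent start, safe stop, abort handlers fired once on timeout, no handler calls when an abort is already in progress) can be approximated with `threading.Timer`. The sketch below is a simplified illustration under that assumption; `TimeoutGuard` is a hypothetical class, not the real `TaskContext`, and it omits the database status update and app-context handling.

```python
import threading
from typing import Callable, List, Optional


class TimeoutGuard:
    """Hypothetical, stripped-down model of the timeout timer in the tests."""

    def __init__(self) -> None:
        self._timer: Optional[threading.Timer] = None
        self._handlers: List[Callable[[], None]] = []
        self.timeout_triggered = False
        self.abort_detected = False

    def on_abort(self, handler: Callable[[], None]) -> Callable[[], None]:
        # Usable as a decorator, like ctx.on_abort in the tests.
        self._handlers.append(handler)
        return handler

    def start(self, seconds: float) -> None:
        if self._timer is not None:
            return  # Idempotent: a second start keeps the first timer.
        self._timer = threading.Timer(seconds, self._on_timeout)
        self._timer.daemon = True
        self._timer.start()

    def stop(self) -> None:
        # Safe to call when no timer exists.
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None

    def _on_timeout(self) -> None:
        self.timeout_triggered = True
        # Skip handlers if an abort is already underway or none are registered.
        if self.abort_detected or not self._handlers:
            return
        self.abort_detected = True
        for handler in self._handlers:
            handler()
```

Keeping `timeout_triggered` separate from `abort_detected` is what lets a caller distinguish a timeout from a user-initiated abort when choosing the final state.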


@@ -22,9 +22,19 @@ from typing import Any, Optional, Union
import pytest
from flask_appbuilder.security.sqla.models import User
from superset_core.api.tasks import TaskScope
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.types import Executor, ExecutorType, FixedExecutor
from superset.tasks.utils import (
error_update,
get_active_dedup_key,
get_finished_dedup_key,
parse_properties,
progress_update,
serialize_properties,
)
from superset.utils.hashing import hash_from_str
FIXED_USER_ID = 1234
FIXED_USERNAME = "admin"
@@ -330,3 +340,242 @@ def test_get_executor(
)
assert executor_type == expected_executor_type
assert executor == expected_executor
@pytest.mark.parametrize(
"scope,task_type,task_key,user_id,expected_composite_key",
[
# Private tasks with TaskScope enum
(
TaskScope.PRIVATE,
"sql_execution",
"chart_123",
42,
"private|sql_execution|chart_123|42",
),
(
TaskScope.PRIVATE,
"thumbnail_gen",
"dash_456",
100,
"private|thumbnail_gen|dash_456|100",
),
# Private tasks with string scope
(
"private",
"api_call",
"endpoint_789",
200,
"private|api_call|endpoint_789|200",
),
# Shared tasks with TaskScope enum
(
TaskScope.SHARED,
"report_gen",
"monthly_report",
None,
"shared|report_gen|monthly_report",
),
(
TaskScope.SHARED,
"export_csv",
"large_export",
999, # user_id should be ignored for shared
"shared|export_csv|large_export",
),
# Shared tasks with string scope
(
"shared",
"batch_process",
"batch_001",
123, # user_id should be ignored for shared
"shared|batch_process|batch_001",
),
# System tasks with TaskScope enum
(
TaskScope.SYSTEM,
"cleanup_task",
"daily_cleanup",
None,
"system|cleanup_task|daily_cleanup",
),
(
TaskScope.SYSTEM,
"db_migration",
"version_123",
1, # user_id should be ignored for system
"system|db_migration|version_123",
),
# System tasks with string scope
(
"system",
"maintenance",
"nightly_job",
2, # user_id should be ignored for system
"system|maintenance|nightly_job",
),
],
)
def test_get_active_dedup_key(
scope, task_type, task_key, user_id, expected_composite_key, app_context
):
"""Test get_active_dedup_key generates a hash of the composite key.
The function hashes the composite key using the configured HASH_ALGORITHM
to produce a fixed-length dedup_key for database storage. The result is
truncated to 64 chars to fit the database column.
"""
result = get_active_dedup_key(scope, task_type, task_key, user_id)
# The result should be a hash of the expected composite key, truncated to 64 chars
expected_hash = hash_from_str(expected_composite_key)[:64]
assert result == expected_hash
assert len(result) <= 64
def test_get_active_dedup_key_private_requires_user_id():
"""Test that private tasks require explicit user_id parameter."""
with pytest.raises(ValueError, match="user_id required for private tasks"):
get_active_dedup_key(TaskScope.PRIVATE, "test_type", "test_key")
def test_get_finished_dedup_key():
"""Test that finished tasks use UUID as dedup_key"""
test_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
result = get_finished_dedup_key(test_uuid)
assert result == test_uuid
@pytest.mark.parametrize(
"progress,expected",
[
# Float (percentage) progress
(0.5, {"progress_percent": 0.5}),
(0.0, {"progress_percent": 0.0}),
(1.0, {"progress_percent": 1.0}),
(0.25, {"progress_percent": 0.25}),
# Int (count only) progress
(42, {"progress_current": 42}),
(0, {"progress_current": 0}),
(1000, {"progress_current": 1000}),
# Tuple (current, total) progress with auto-computed percentage
(
(50, 100),
{"progress_current": 50, "progress_total": 100, "progress_percent": 0.5},
),
(
(25, 100),
{"progress_current": 25, "progress_total": 100, "progress_percent": 0.25},
),
(
(100, 100),
{"progress_current": 100, "progress_total": 100, "progress_percent": 1.0},
),
# Tuple with zero total (no percentage computed)
((10, 0), {"progress_current": 10, "progress_total": 0}),
((0, 0), {"progress_current": 0, "progress_total": 0}),
],
)
def test_progress_update(progress, expected):
"""Test progress_update returns correct TaskProperties dict."""
result = progress_update(progress)
assert result == expected
def test_error_update():
"""Test error_update captures exception details."""
try:
raise ValueError("Test error message")
except ValueError as e:
result = error_update(e)
assert result["error_message"] == "Test error message"
assert result["exception_type"] == "ValueError"
assert "stack_trace" in result
assert "ValueError" in result["stack_trace"]
def test_error_update_custom_exception():
"""Test error_update with custom exception class."""
class CustomError(Exception):
pass
try:
raise CustomError("Custom error")
except CustomError as e:
result = error_update(e)
assert result["error_message"] == "Custom error"
assert result["exception_type"] == "CustomError"
@pytest.mark.parametrize(
"json_str,expected",
[
# Valid JSON
(
'{"is_abortable": true, "progress_percent": 0.5}',
{"is_abortable": True, "progress_percent": 0.5},
),
(
'{"error_message": "Something failed"}',
{"error_message": "Something failed"},
),
(
'{"progress_current": 50, "progress_total": 100}',
{"progress_current": 50, "progress_total": 100},
),
# Empty/None cases
("", {}),
(None, {}),
# Invalid JSON returns empty dict
("not valid json", {}),
("{broken", {}),
# Unknown keys are preserved (forward compatibility)
(
'{"is_abortable": true, "future_field": "value"}',
{"is_abortable": True, "future_field": "value"},
),
],
)
def test_parse_properties(json_str, expected):
"""Test parse_properties parses JSON to TaskProperties dict."""
result = parse_properties(json_str)
assert result == expected
@pytest.mark.parametrize(
"props,expected_contains",
[
# Full properties
(
{"is_abortable": True, "progress_percent": 0.5},
{"is_abortable": True, "progress_percent": 0.5},
),
# Empty dict
({}, {}),
# Sparse properties
({"is_abortable": True}, {"is_abortable": True}),
({"error_message": "fail"}, {"error_message": "fail"}),
],
)
def test_serialize_properties(props, expected_contains):
"""Test serialize_properties converts TaskProperties to JSON."""
from superset.utils import json
result = serialize_properties(props)
parsed = json.loads(result)
assert parsed == expected_contains
def test_properties_roundtrip():
"""Test that serialize -> parse roundtrip preserves data."""
original = {
"is_abortable": True,
"progress_percent": 0.75,
"error_message": "Test error",
}
serialized = serialize_properties(original)
parsed = parse_properties(serialized)
assert parsed == original
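
The parametrized cases above imply a composite-key format for deduplication and three input shapes for progress updates. The sketch below reproduces those expectations in plain Python. It is illustrative only: the function names mirror the tests but the bodies are guesses, scopes are plain strings rather than `TaskScope` members, and SHA-256 is an assumption standing in for `hash_from_str` with the configured `HASH_ALGORITHM`.

```python
import hashlib
from typing import Dict, Optional, Tuple, Union

Progress = Union[float, int, Tuple[int, int]]


def composite_key(
    scope: str, task_type: str, task_key: str, user_id: Optional[int] = None
) -> str:
    """Build the pre-hash key; only private tasks include the user id."""
    if scope == "private":
        if user_id is None:
            raise ValueError("user_id required for private tasks")
        return f"private|{task_type}|{task_key}|{user_id}"
    # Shared and system tasks ignore user_id entirely.
    return f"{scope}|{task_type}|{task_key}"


def active_dedup_key(
    scope: str, task_type: str, task_key: str, user_id: Optional[int] = None
) -> str:
    """Hash the composite key and cap it at 64 characters.

    SHA-256 is an assumption here; the real code uses a configurable hash.
    """
    key = composite_key(scope, task_type, task_key, user_id)
    return hashlib.sha256(key.encode("utf-8")).hexdigest()[:64]


def progress_update(progress: Progress) -> Dict[str, Union[int, float]]:
    """Map a float, an int, or a (current, total) tuple to property fields."""
    if isinstance(progress, float):
        return {"progress_percent": progress}
    if isinstance(progress, int):
        return {"progress_current": progress}
    current, total = progress
    result: Dict[str, Union[int, float]] = {
        "progress_current": current,
        "progress_total": total,
    }
    if total > 0:
        # The percentage is only derived when the total is non-zero.
        result["progress_percent"] = current / total
    return result
```

Hashing the composite key gives a fixed-length value regardless of how long `task_key` is, which is why the tests compare against a truncated hash rather than the raw string.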


@@ -54,7 +54,7 @@ def test_json_loads_exception():
def test_json_loads_encoding():
-unicode_data = b'{"a": "\u0073\u0074\u0072"}'
+unicode_data = rb'{"a": "\u0073\u0074\u0072"}'
data = json.loads(unicode_data)
assert data["a"] == "str"
utf16_data = b'\xff\xfe{\x00"\x00a\x00"\x00:\x00 \x00"\x00s\x00t\x00r\x00"\x00}\x00'


@@ -119,7 +119,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
was revoked), the invalid token should be deleted and the exception re-raised.
"""
db = mocker.patch("superset.utils.oauth2.db")
-mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
+mocker.patch("superset.utils.oauth2.DistributedLock")
class OAuth2ExceptionError(Exception):
pass
@@ -149,7 +149,7 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
exception re-raised.
"""
db = mocker.patch("superset.utils.oauth2.db")
-mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
+mocker.patch("superset.utils.oauth2.DistributedLock")
class OAuth2ExceptionError(Exception):
pass
@@ -175,7 +175,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
This can happen when the refresh token was revoked.
"""
mocker.patch("superset.utils.oauth2.db")
-mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
+mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db_engine_spec.get_oauth2_fresh_token.return_value = {
"error": "invalid_grant",