mirror of
https://github.com/apache/superset.git
synced 2026-05-30 04:39:20 +00:00
feat: add global task framework (#36368)
This commit is contained in:
22
UPDATING.md
22
UPDATING.md
@@ -24,6 +24,28 @@ assists people when migrating to a new version.
|
||||
|
||||
## Next
|
||||
|
||||
### Signal Cache Backend
|
||||
|
||||
A new `SIGNAL_CACHE_CONFIG` configuration provides a unified Redis-based backend for real-time coordination features in Superset. This backend enables:
|
||||
|
||||
- **Pub/sub messaging** for real-time event notifications between workers
|
||||
- **Atomic distributed locking** using Redis SET NX EX (more performant than database-backed locks)
|
||||
- **Event-based coordination** for background task management
|
||||
|
||||
The signal cache is used by the Global Task Framework (GTF) for abort notifications and task completion signaling, and will eventually replace `GLOBAL_ASYNC_QUERIES_CACHE_BACKEND` as the standard signaling backend. Configuring this is recommended for Redis enabled production deployments.
|
||||
|
||||
Example configuration in `superset_config.py`:
|
||||
```python
|
||||
SIGNAL_CACHE_CONFIG = {
|
||||
"CACHE_TYPE": "RedisCache",
|
||||
"CACHE_KEY_PREFIX": "signal_",
|
||||
"CACHE_REDIS_URL": "redis://localhost:6379/1",
|
||||
"CACHE_DEFAULT_TIMEOUT": 300,
|
||||
}
|
||||
```
|
||||
|
||||
See `superset/config.py` for complete configuration options.
|
||||
|
||||
### WebSocket config for GAQ with Docker
|
||||
|
||||
[35896](https://github.com/apache/superset/pull/35896) and [37624](https://github.com/apache/superset/pull/37624) updated documentation on how to run and configure Superset with Docker. Specifically for the WebSocket configuration, a new `docker/superset-websocket/config.example.json` was added to the repo, so that users could copy it to create a `docker/superset-websocket/config.json` file. The existing `docker/superset-websocket/config.json` was removed and git-ignored, so if you're using GAQ / WebSocket make sure to:
|
||||
|
||||
@@ -51,4 +51,5 @@ Extensions can provide:
|
||||
- **[Deployment](./deployment)** - Packaging and deploying extensions
|
||||
- **[MCP Integration](./mcp)** - Adding AI agent capabilities using extensions
|
||||
- **[Security](./security)** - Security considerations and best practices
|
||||
- **[Tasks](./tasks)** - Framework for creating and managing long running tasks
|
||||
- **[Community Extensions](./registry)** - Browse extensions shared by the community
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
title: Community Extensions
|
||||
sidebar_position: 10
|
||||
sidebar_position: 11
|
||||
---
|
||||
|
||||
<!--
|
||||
|
||||
440
docs/developer_portal/extensions/tasks.md
Normal file
440
docs/developer_portal/extensions/tasks.md
Normal file
@@ -0,0 +1,440 @@
|
||||
---
|
||||
title: Tasks
|
||||
sidebar_position: 10
|
||||
---
|
||||
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
# Global Task Framework
|
||||
|
||||
The Global Task Framework (GTF) provides a unified way to manage background tasks. It handles task execution, progress tracking, cancellation, and deduplication for both synchronous and asynchronous execution. The framework uses distributed locking internally to ensure race-free operations—you don't need to worry about concurrent task creation or cancellation conflicts.
|
||||
|
||||
## Enabling GTF
|
||||
|
||||
GTF is disabled by default and must be enabled via the `GLOBAL_TASK_FRAMEWORK` feature flag in your `superset_config.py`:
|
||||
|
||||
```python
|
||||
FEATURE_FLAGS = {
|
||||
"GLOBAL_TASK_FRAMEWORK": True,
|
||||
}
|
||||
```
|
||||
|
||||
When GTF is disabled:
|
||||
- The Task List UI menu item is hidden
|
||||
- The `/api/v1/task/*` endpoints return 404
|
||||
- Calling or scheduling a `@task`-decorated function raises `GlobalTaskFrameworkDisabledError`
|
||||
|
||||
:::note Future Migration
|
||||
When GTF is considered stable, it will replace legacy Celery tasks for built-in features like thumbnails and alerts & reports. Enabling this flag prepares your deployment for that migration.
|
||||
:::
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Define a Task
|
||||
|
||||
```python
|
||||
from superset_core.api.types import task, get_context
|
||||
|
||||
@task
|
||||
def process_data(dataset_id: int) -> None:
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup():
|
||||
logger.info("Processing complete")
|
||||
|
||||
data = fetch_dataset(dataset_id)
|
||||
process_and_cache(data)
|
||||
```
|
||||
|
||||
### Execute a Task
|
||||
|
||||
```python
|
||||
# Async execution - schedules on Celery worker
|
||||
task = process_data.schedule(dataset_id=123)
|
||||
print(task.status) # "pending"
|
||||
|
||||
# Sync execution - runs inline in current process
|
||||
task = process_data(dataset_id=123)
|
||||
# ... blocks until complete
|
||||
print(task.status) # "success"
|
||||
```
|
||||
|
||||
### Async vs Sync Execution
|
||||
|
||||
| Method | When to Use |
|
||||
|--------|-------------|
|
||||
| `.schedule()` | Long-running operations, background processing, when you need to return immediately |
|
||||
| Direct call | Short operations, when deduplication matters, when you need the result before responding |
|
||||
|
||||
Both execution modes provide the same task features: deduplication, progress tracking, cancellation, and visibility in the Task List UI. The difference is whether execution happens in a Celery worker (async) or inline (sync).
|
||||
|
||||
## Task Lifecycle
|
||||
|
||||
```
|
||||
PENDING ──→ IN_PROGRESS ────→ SUCCESS
|
||||
│ │
|
||||
│ ├──────────→ FAILURE
|
||||
│ ↓ ↑
|
||||
│ ABORTING ────────────┘
|
||||
│ │
|
||||
│ ├──────────→ TIMED_OUT (timeout)
|
||||
│ │
|
||||
└─────────────┴──────────→ ABORTED (user cancel)
|
||||
```
|
||||
|
||||
| Status | Description |
|
||||
|--------|-------------|
|
||||
| `PENDING` | Queued, awaiting execution |
|
||||
| `IN_PROGRESS` | Executing |
|
||||
| `ABORTING` | Abort/timeout triggered, abort handlers running |
|
||||
| `SUCCESS` | Completed successfully |
|
||||
| `FAILURE` | Failed with error or abort/cleanup handler exception |
|
||||
| `ABORTED` | Cancelled by user/admin |
|
||||
| `TIMED_OUT` | Exceeded configured timeout |
|
||||
|
||||
## Context API
|
||||
|
||||
Access task context via `get_context()` from within any `@task` function. The context provides methods for updating task metadata and registering handlers.
|
||||
|
||||
### Updating Task Metadata
|
||||
|
||||
Use `update_task()` to report progress and store custom payload data:
|
||||
|
||||
```python
|
||||
@task
|
||||
def my_task(items: list[int]) -> None:
|
||||
ctx = get_context()
|
||||
|
||||
for i, item in enumerate(items):
|
||||
result = process(item)
|
||||
ctx.update_task(
|
||||
progress=(i + 1, len(items)),
|
||||
payload={"last_result": result}
|
||||
)
|
||||
```
|
||||
|
||||
:::tip
|
||||
Call `update_task()` once per iteration for best performance. Frequent DB writes are throttled to limit metastore load, so batching progress and payload updates together in a single call ensures both are persisted at the same time.
|
||||
:::
|
||||
|
||||
#### Progress Formats
|
||||
|
||||
The `progress` parameter accepts three formats:
|
||||
|
||||
| Format | Example | Display |
|
||||
|--------|---------|---------|
|
||||
| `tuple[int, int]` | `progress=(3, 100)` | 3 of 100 (3%) with ETA |
|
||||
| `float` (0.0-1.0) | `progress=0.5` | 50% with ETA |
|
||||
| `int` | `progress=42` | 42 processed |
|
||||
|
||||
:::tip
|
||||
Use the tuple format `(current, total)` whenever possible. It provides the richest information to users: showing both the count and percentage, while still computing ETA automatically.
|
||||
:::
|
||||
|
||||
#### Payload
|
||||
|
||||
The `payload` parameter stores custom metadata that can help users understand what the task is doing. Each call to `update_task()` replaces the previous payload completely.
|
||||
|
||||
In the Task List UI, when a payload is defined, an info icon appears in the **Details** column. Users can hover over it to see the JSON content.
|
||||
|
||||
### Handlers
|
||||
|
||||
Register handlers to run cleanup logic or respond to abort requests:
|
||||
|
||||
| Handler | When it runs | Use case |
|
||||
|---------|--------------|----------|
|
||||
| `on_cleanup` | Always (success, failure, abort) | Release resources, close connections |
|
||||
| `on_abort` | When task is aborted | Set stop flag, cancel external operations |
|
||||
|
||||
```python
|
||||
@task
|
||||
def my_task() -> None:
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup():
|
||||
logger.info("Task ended, cleaning up")
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
logger.info("Abort requested")
|
||||
|
||||
# ... task logic
|
||||
```
|
||||
|
||||
Multiple handlers of the same type execute in LIFO order (last registered runs first). Abort handlers run first when abort is detected, then cleanup handlers run when the task ends.
|
||||
|
||||
#### Best-Effort Execution
|
||||
|
||||
**All registered handlers will always be attempted, even if one fails.** This ensures that a failure in one handler doesn't prevent other handlers from running their cleanup logic.
|
||||
|
||||
For example, if you have three cleanup handlers and the second one throws an exception:
|
||||
1. Handler 3 runs ✓
|
||||
2. Handler 2 throws an exception ✗ (logged, but execution continues)
|
||||
3. Handler 1 runs ✓
|
||||
|
||||
If any handler fails, the task is marked as `FAILURE` with combined error details showing all handler failures.
|
||||
|
||||
:::tip
|
||||
Write handlers to be independent and self-contained. Don't assume previous handlers succeeded, and don't rely on shared state between handlers.
|
||||
:::
|
||||
|
||||
## Making Tasks Abortable
|
||||
|
||||
When users click **Cancel** in the Task List, the system decides whether to **abort** (stop) the task or **unsubscribe** (remove the user from a shared task). Abort occurs when:
|
||||
- It's a private or system task
|
||||
- It's a shared task and the user is the last subscriber
|
||||
- An admin checks **Force abort** to stop the task for all subscribers
|
||||
|
||||
Pending tasks can always be aborted: they simply won't start. In-progress tasks require an abort handler to be abortable:
|
||||
|
||||
```python
|
||||
@task
|
||||
def abortable_task(items: list[str]) -> None:
|
||||
ctx = get_context()
|
||||
should_stop = False
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
logger.info("Abort signal received")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup():
|
||||
logger.info("Task ended, cleaning up")
|
||||
|
||||
for item in items:
|
||||
if should_stop:
|
||||
return # Exit gracefully
|
||||
process(item)
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
- Registering `on_abort` marks the task as abortable and starts the abort listener
|
||||
- The abort handler fires automatically when abort is triggered
|
||||
- Use a flag pattern to gracefully stop processing at safe points
|
||||
- Without an abort handler, in-progress tasks cannot be aborted: the Cancel button in the Task List UI will be disabled
|
||||
|
||||
The framework automatically skips execution if a task was aborted while pending: no manual check needed at task start.
|
||||
|
||||
:::tip
|
||||
Always implement an abort handler for long-running tasks. This allows users to cancel unneeded tasks and free up worker capacity for other operations.
|
||||
:::
|
||||
|
||||
## Timeouts
|
||||
|
||||
Set a timeout to automatically abort tasks that run too long:
|
||||
|
||||
```python
|
||||
from superset_core.api.types import task, get_context, TaskOptions
|
||||
|
||||
# Set default timeout in decorator
|
||||
@task(timeout=300) # 5 minutes
|
||||
def process_data(dataset_id: int) -> None:
|
||||
ctx = get_context()
|
||||
should_stop = False
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
for chunk in fetch_large_dataset(dataset_id):
|
||||
if should_stop:
|
||||
return
|
||||
process(chunk)
|
||||
|
||||
# Override timeout at call time
|
||||
task = process_data.schedule(
|
||||
dataset_id=123,
|
||||
options=TaskOptions(timeout=600) # Override to 10 minutes
|
||||
)
|
||||
```
|
||||
|
||||
### How Timeouts Work
|
||||
|
||||
The timeout timer starts when the task begins executing (status changes to `IN_PROGRESS`). When the timeout expires:
|
||||
|
||||
1. **With an abort handler registered:** The task transitions to `ABORTING`, abort handlers run, then cleanup handlers run. The final status depends on handler execution:
|
||||
- If handlers complete successfully → `TIMED_OUT` status
|
||||
- If handlers throw an exception → `FAILURE` status
|
||||
|
||||
2. **Without an abort handler:** The framework cannot forcibly terminate the task. A warning is logged, and the task continues running. The Task List UI shows a warning indicator (⚠️) in the Details column to alert users that the timeout cannot be enforced.
|
||||
|
||||
### Timeout Precedence
|
||||
|
||||
| Source | Priority | Example |
|
||||
|--------|----------|---------|
|
||||
| `TaskOptions.timeout` | Highest | `options=TaskOptions(timeout=600)` |
|
||||
| `@task(timeout=...)` | Default | `@task(timeout=300)` |
|
||||
| Not set | No timeout | Task runs indefinitely |
|
||||
|
||||
Call-time options always override decorator defaults, allowing tasks to have sensible defaults while permitting callers to extend or shorten the timeout for specific use cases.
|
||||
|
||||
:::warning
|
||||
Timeouts require an abort handler to be effective. Without one, the timeout triggers only a warning and the task continues running. Always implement an abort handler when using timeouts.
|
||||
:::
|
||||
|
||||
## Deduplication
|
||||
|
||||
Use `task_key` to prevent duplicate task execution:
|
||||
|
||||
```python
|
||||
from superset_core.api.types import TaskOptions
|
||||
|
||||
# Without key - creates new task each time (random UUID)
|
||||
task1 = my_task.schedule(x=1)
|
||||
task2 = my_task.schedule(x=1) # Different task
|
||||
|
||||
# With key - joins existing task if active
|
||||
task1 = my_task.schedule(x=1, options=TaskOptions(task_key="report_123"))
|
||||
task2 = my_task.schedule(x=1, options=TaskOptions(task_key="report_123")) # Returns same task
|
||||
```
|
||||
|
||||
When a task with matching key already exists, the user is added as a subscriber and the existing task is returned. This behavior is consistent across all scopes—private tasks naturally have only one subscriber since their deduplication key includes the user ID.
|
||||
|
||||
Deduplication only applies to active tasks (pending/in-progress). Once a task completes, a new task with the same key can be created.
|
||||
|
||||
### Sync Join-and-Wait
|
||||
|
||||
When a sync call joins an existing task, it blocks until the task completes:
|
||||
|
||||
```python
|
||||
# Schedule async task
|
||||
task = my_task.schedule(options=TaskOptions(task_key="report_123"))
|
||||
|
||||
# Later sync call with same key blocks until completion of the active task
|
||||
task2 = my_task(options=TaskOptions(task_key="report_123"))
|
||||
assert task.uuid == task2.uuid # True
|
||||
print(task2.status) # "success" (terminal status)
|
||||
```
|
||||
|
||||
## Task Scopes
|
||||
|
||||
```python
|
||||
from superset_core.api.types import task, TaskScope
|
||||
|
||||
@task # Private by default
|
||||
def private_task(): ...
|
||||
|
||||
@task(scope=TaskScope.SHARED) # Multiple users can subscribe
|
||||
def shared_task(): ...
|
||||
|
||||
@task(scope=TaskScope.SYSTEM) # Admin-only visibility
|
||||
def system_task(): ...
|
||||
```
|
||||
|
||||
| Scope | Visibility | Cancel Behavior |
|
||||
|-------|------------|-----------------|
|
||||
| `PRIVATE` | Creator only | Cancels immediately |
|
||||
| `SHARED` | All subscribers | Last subscriber cancels; others unsubscribe |
|
||||
| `SYSTEM` | Admins only | Admin cancels |
|
||||
|
||||
## Task Cleanup
|
||||
|
||||
Completed tasks accumulate in the database over time. Configure a scheduled prune job to automatically remove old tasks:
|
||||
|
||||
```python
|
||||
# In your superset_config.py, add to your Celery beat schedule:
|
||||
CELERY_CONFIG.beat_schedule["prune_tasks"] = {
|
||||
"task": "prune_tasks",
|
||||
"schedule": crontab(minute=0, hour=0), # Run daily at midnight
|
||||
"kwargs": {
|
||||
"retention_period_days": 90, # Keep tasks for 90 days
|
||||
"max_rows_per_run": 10000, # Limit deletions per run
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
The prune job only removes tasks in terminal states (`SUCCESS`, `FAILURE`, `ABORTED`, `TIMED_OUT`). Active tasks (`PENDING`, `IN_PROGRESS`, `ABORTING`) are never pruned.
|
||||
|
||||
See `superset/config.py` for a complete example configuration.
|
||||
|
||||
:::tip Signal Cache for Faster Notifications
|
||||
By default, abort detection and sync join-and-wait use database polling. Configure `SIGNAL_CACHE_CONFIG` to enable Redis pub/sub for real-time notifications. See [Signal Cache Backend](/docs/configuration/cache#signal-cache-backend) for configuration details.
|
||||
:::
|
||||
|
||||
## API Reference
|
||||
|
||||
### @task Decorator
|
||||
|
||||
```python
|
||||
@task(
|
||||
name: str | None = None,
|
||||
scope: TaskScope = TaskScope.PRIVATE,
|
||||
timeout: int | None = None
|
||||
)
|
||||
```
|
||||
|
||||
- `name`: Task identifier (defaults to function name)
|
||||
- `scope`: `PRIVATE`, `SHARED`, or `SYSTEM`
|
||||
- `timeout`: Default timeout in seconds (can be overridden via `TaskOptions`)
|
||||
|
||||
### TaskContext Methods
|
||||
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| `update_task(progress, payload)` | Update progress and/or custom payload |
|
||||
| `on_cleanup(handler)` | Register cleanup handler |
|
||||
| `on_abort(handler)` | Register abort handler (makes task abortable) |
|
||||
|
||||
### TaskOptions
|
||||
|
||||
```python
|
||||
TaskOptions(
|
||||
task_key: str | None = None,
|
||||
task_name: str | None = None,
|
||||
timeout: int | None = None
|
||||
)
|
||||
```
|
||||
|
||||
- `task_key`: Deduplication key (also used as display name if `task_name` is not set)
|
||||
- `task_name`: Human-readable display name for the Task List UI
|
||||
- `timeout`: Timeout in seconds (overrides decorator default)
|
||||
|
||||
:::tip
|
||||
Provide a descriptive `task_name` for better readability in the Task List UI. While `task_key` is used for deduplication and may be technical (e.g., `chart_export_123`), `task_name` can be user-friendly (e.g., `"Export Sales Chart 123"`).
|
||||
:::
|
||||
|
||||
## Error Handling
|
||||
|
||||
Let exceptions propagate: the framework captures them automatically and sets task status to `FAILURE`:
|
||||
|
||||
```python
|
||||
@task
|
||||
def risky_task() -> None:
|
||||
# No try/catch needed - framework handles it
|
||||
result = operation_that_might_fail()
|
||||
```
|
||||
|
||||
On failure, the framework records:
|
||||
- `error_message`: Exception message
|
||||
- `exception_type`: Exception class name
|
||||
- `stack_trace`: Full traceback (visible when `SHOW_STACKTRACE=True`)
|
||||
|
||||
In the Task List UI, failed tasks show error details when hovering over the status. When stack traces are enabled, a separate bug icon appears in the **Details** column for viewing the full traceback.
|
||||
|
||||
Cleanup handlers still run after an exception, so resources can be properly released as necessary.
|
||||
|
||||
:::tip
|
||||
Use descriptive exception messages. In environments where stack traces are hidden (`SHOW_STACKTRACE=False`), users see only the error message and exception type when hovering over failed tasks. Clear messages help users troubleshoot issues without administrator assistance.
|
||||
:::
|
||||
@@ -53,6 +53,7 @@ module.exports = {
|
||||
'extensions/deployment',
|
||||
'extensions/mcp',
|
||||
'extensions/security',
|
||||
'extensions/tasks',
|
||||
'extensions/registry',
|
||||
],
|
||||
},
|
||||
|
||||
@@ -7,6 +7,12 @@ version: 1
|
||||
|
||||
# Caching
|
||||
|
||||
:::note
|
||||
When a cache backend is configured, Superset expects it to remain available. Operations will
|
||||
fail if the configured backend becomes unavailable rather than silently degrading. This
|
||||
fail-fast behavior ensures operators are immediately aware of infrastructure issues.
|
||||
:::
|
||||
|
||||
Superset uses [Flask-Caching](https://flask-caching.readthedocs.io/) for caching purposes.
|
||||
Flask-Caching supports various caching backends, including Redis (recommended), Memcached,
|
||||
SimpleCache (in-memory), or the local filesystem.
|
||||
@@ -153,6 +159,84 @@ Then on configuration:
|
||||
WEBDRIVER_AUTH_FUNC = auth_driver
|
||||
```
|
||||
|
||||
## Signal Cache Backend
|
||||
|
||||
Superset supports an optional signal cache (`SIGNAL_CACHE_CONFIG`) for
|
||||
high-performance distributed operations. This configuration enables:
|
||||
|
||||
- **Distributed locking**: Moves lock operations from the metadata database to Redis, improving
|
||||
performance and reducing metastore load
|
||||
- **Real-time event notifications**: Enables instant pub/sub messaging for task abort signals and
|
||||
completion notifications instead of polling-based approaches
|
||||
|
||||
:::note
|
||||
This requires Redis or Valkey specifically—it uses Redis-specific features (pub/sub, `SET NX EX`)
|
||||
that are not available in general Flask-Caching backends.
|
||||
:::
|
||||
|
||||
### Configuration
|
||||
|
||||
The signal cache uses Flask-Caching style configuration for consistency with other cache
|
||||
backends. Configure `SIGNAL_CACHE_CONFIG` in `superset_config.py`:
|
||||
|
||||
```python
|
||||
SIGNAL_CACHE_CONFIG = {
|
||||
"CACHE_TYPE": "RedisCache",
|
||||
"CACHE_REDIS_HOST": "localhost",
|
||||
"CACHE_REDIS_PORT": 6379,
|
||||
"CACHE_REDIS_DB": 0,
|
||||
"CACHE_REDIS_PASSWORD": "", # Optional
|
||||
}
|
||||
```
|
||||
|
||||
For Redis Sentinel deployments:
|
||||
|
||||
```python
|
||||
SIGNAL_CACHE_CONFIG = {
|
||||
"CACHE_TYPE": "RedisSentinelCache",
|
||||
"CACHE_REDIS_SENTINELS": [("sentinel1", 26379), ("sentinel2", 26379)],
|
||||
"CACHE_REDIS_SENTINEL_MASTER": "mymaster",
|
||||
"CACHE_REDIS_SENTINEL_PASSWORD": None, # Sentinel password (if different)
|
||||
"CACHE_REDIS_PASSWORD": "", # Redis password
|
||||
"CACHE_REDIS_DB": 0,
|
||||
}
|
||||
```
|
||||
|
||||
For SSL/TLS connections:
|
||||
|
||||
```python
|
||||
SIGNAL_CACHE_CONFIG = {
|
||||
"CACHE_TYPE": "RedisCache",
|
||||
"CACHE_REDIS_HOST": "redis.example.com",
|
||||
"CACHE_REDIS_PORT": 6380,
|
||||
"CACHE_REDIS_SSL": True,
|
||||
"CACHE_REDIS_SSL_CERTFILE": "/path/to/client.crt",
|
||||
"CACHE_REDIS_SSL_KEYFILE": "/path/to/client.key",
|
||||
"CACHE_REDIS_SSL_CA_CERTS": "/path/to/ca.crt",
|
||||
}
|
||||
```
|
||||
|
||||
### Distributed Lock TTL
|
||||
|
||||
You can configure the default lock TTL (time-to-live) in seconds. Locks automatically expire after
|
||||
this duration to prevent deadlocks from crashed processes:
|
||||
|
||||
```python
|
||||
DISTRIBUTED_LOCK_DEFAULT_TTL = 30 # Default: 30 seconds
|
||||
```
|
||||
|
||||
Individual lock acquisitions can override this value when needed.
|
||||
|
||||
### Database-Only Mode
|
||||
|
||||
When `SIGNAL_CACHE_CONFIG` is not configured, Superset uses database-backed operations:
|
||||
|
||||
- **Locking**: Uses the KeyValue table with periodic cleanup of expired entries
|
||||
- **Event notifications**: Uses database polling instead of pub/sub
|
||||
|
||||
While database-backed operations work reliably, the Redis backend is recommended for production
|
||||
deployments where low latency and reduced database load are important.
|
||||
|
||||
:::resources
|
||||
- [Blog: The Data Engineer's Guide to Lightning-Fast Superset Dashboards](https://preset.io/blog/the-data-engineers-guide-to-lightning-fast-apache-superset-dashboards/)
|
||||
- [Blog: Accelerating Dashboards with Materialized Views](https://preset.io/blog/accelerating-apache-superset-dashboards-with-materialized-views/)
|
||||
|
||||
@@ -97,6 +97,7 @@ const sidebars = {
|
||||
'extensions/deployment',
|
||||
'extensions/mcp',
|
||||
'extensions/security',
|
||||
'extensions/tasks',
|
||||
'extensions/registry',
|
||||
],
|
||||
},
|
||||
|
||||
@@ -46,6 +46,7 @@ from superset_core.api.models import (
|
||||
Query,
|
||||
SavedQuery,
|
||||
Tag,
|
||||
Task,
|
||||
User,
|
||||
)
|
||||
|
||||
@@ -248,6 +249,48 @@ class KeyValueDAO(BaseDAO[KeyValue]):
|
||||
id_column_name = "id"
|
||||
|
||||
|
||||
class TaskDAO(BaseDAO[Task]):
|
||||
"""
|
||||
Abstract Task DAO interface.
|
||||
|
||||
Host implementations will replace this class during initialization
|
||||
with a concrete implementation providing actual functionality.
|
||||
"""
|
||||
|
||||
# Class variables that will be set by host implementation
|
||||
model_cls = None
|
||||
base_filter = None
|
||||
id_column_name = "id"
|
||||
uuid_column_name = "uuid"
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def find_by_task_key(
|
||||
cls,
|
||||
task_type: str,
|
||||
task_key: str,
|
||||
scope: str = "private",
|
||||
user_id: int | None = None,
|
||||
) -> Task | None:
|
||||
"""
|
||||
Find active task by type, key, scope, and user.
|
||||
|
||||
Uses dedup_key internally for efficient querying with a unique index.
|
||||
Only returns tasks that are active (pending or in progress).
|
||||
|
||||
Uniqueness logic by scope:
|
||||
- private: scope + task_type + task_key + user_id
|
||||
- shared/system: scope + task_type + task_key (user-agnostic)
|
||||
|
||||
:param task_type: Task type to filter by
|
||||
:param task_key: Task identifier for deduplication
|
||||
:param scope: Task scope (private/shared/system)
|
||||
:param user_id: User ID (required for private tasks)
|
||||
:returns: Task instance or None if not found or not active
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseDAO",
|
||||
"DatasetDAO",
|
||||
@@ -259,4 +302,5 @@ __all__ = [
|
||||
"SavedQueryDAO",
|
||||
"TagDAO",
|
||||
"KeyValueDAO",
|
||||
"TaskDAO",
|
||||
]
|
||||
|
||||
@@ -40,6 +40,7 @@ from flask_appbuilder import Model
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset_core.api.tasks import TaskProperties
|
||||
from superset_core.api.types import (
|
||||
AsyncQueryHandle,
|
||||
QueryOptions,
|
||||
@@ -361,6 +362,132 @@ class KeyValue(CoreModel):
|
||||
changed_by_fk: int | None
|
||||
|
||||
|
||||
class Task(CoreModel):
|
||||
"""
|
||||
Abstract Task model interface.
|
||||
|
||||
Host implementations will replace this class during initialization
|
||||
with concrete implementation providing actual functionality.
|
||||
|
||||
This model represents async tasks in the Global Task Framework (GTF).
|
||||
|
||||
Non-filterable fields (progress, error info, execution config) are stored
|
||||
in a `properties` JSON blob for schema flexibility.
|
||||
"""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
# Type hints for expected column attributes
|
||||
id: int
|
||||
uuid: UUID
|
||||
task_key: str # For deduplication
|
||||
task_type: str # e.g., 'sql_execution'
|
||||
task_name: str | None # Human readable name
|
||||
scope: str # private/shared/system
|
||||
status: str
|
||||
dedup_key: str # Computed deduplication key
|
||||
|
||||
# Timestamps (from AuditMixinNullable)
|
||||
created_on: datetime | None
|
||||
changed_on: datetime | None
|
||||
started_at: datetime | None
|
||||
ended_at: datetime | None
|
||||
|
||||
# User context
|
||||
created_by_fk: int | None
|
||||
user_id: int | None
|
||||
|
||||
# Task output data
|
||||
payload: str # JSON serialized task output data
|
||||
|
||||
def get_payload(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get payload as parsed JSON.
|
||||
|
||||
Payload contains task-specific output data set by task code.
|
||||
|
||||
Host implementations will replace this method during initialization
|
||||
with concrete implementation providing actual functionality.
|
||||
|
||||
:returns: Dictionary containing payload data
|
||||
"""
|
||||
raise NotImplementedError("Method will be replaced during initialization")
|
||||
|
||||
def set_payload(self, data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update payload with new data (merges with existing).
|
||||
|
||||
Host implementations will replace this method during initialization
|
||||
with concrete implementation providing actual functionality.
|
||||
|
||||
:param data: Dictionary of data to merge into payload
|
||||
"""
|
||||
raise NotImplementedError("Method will be replaced during initialization")
|
||||
|
||||
@property
|
||||
def properties(self) -> Any:
|
||||
"""
|
||||
Get typed properties (runtime state and execution config).
|
||||
|
||||
Properties contain:
|
||||
- is_abortable: bool | None - has abort handler registered
|
||||
- progress_percent: float | None - progress 0.0-1.0
|
||||
- progress_current: int | None - current iteration count
|
||||
- progress_total: int | None - total iterations
|
||||
- error_message: str | None - human-readable error message
|
||||
- exception_type: str | None - exception class name
|
||||
- stack_trace: str | None - full formatted traceback
|
||||
- timeout: int | None - timeout in seconds
|
||||
|
||||
Host implementations will replace this property during initialization.
|
||||
|
||||
:returns: TaskProperties dataclass instance
|
||||
"""
|
||||
raise NotImplementedError("Property will be replaced during initialization")
|
||||
|
||||
def update_properties(self, updates: "TaskProperties") -> None:
|
||||
"""
|
||||
Update specific properties fields (merge semantics).
|
||||
|
||||
Only updates fields present in the updates dict.
|
||||
|
||||
Host implementations will replace this method during initialization.
|
||||
|
||||
:param updates: TaskProperties dict with fields to update
|
||||
|
||||
Example:
|
||||
task.update_properties({"is_abortable": True})
|
||||
"""
|
||||
raise NotImplementedError("Method will be replaced during initialization")
|
||||
|
||||
|
||||
class TaskSubscriber(CoreModel):
|
||||
"""
|
||||
Abstract TaskSubscriber model interface.
|
||||
|
||||
Host implementations will replace this class during initialization
|
||||
with concrete implementation providing actual functionality.
|
||||
|
||||
This model tracks task subscriptions for multi-user shared tasks. When a user
|
||||
schedules a shared task with the same parameters as an existing task,
|
||||
they are subscribed to that task instead of creating a duplicate.
|
||||
"""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
# Type hints for expected attributes (no actual field definitions)
|
||||
id: int
|
||||
task_id: int
|
||||
user_id: int
|
||||
subscribed_at: datetime
|
||||
|
||||
# Audit fields from AuditMixinNullable
|
||||
created_on: datetime | None
|
||||
changed_on: datetime | None
|
||||
created_by_fk: int | None
|
||||
changed_by_fk: int | None
|
||||
|
||||
|
||||
def get_session() -> scoped_session:
|
||||
"""
|
||||
Retrieve the SQLAlchemy session to directly interface with the
|
||||
@@ -384,6 +511,8 @@ __all__ = [
|
||||
"SavedQuery",
|
||||
"Tag",
|
||||
"KeyValue",
|
||||
"Task",
|
||||
"TaskSubscriber",
|
||||
"CoreModel",
|
||||
"get_session",
|
||||
]
|
||||
|
||||
361
superset-core/src/superset_core/api/tasks.py
Normal file
361
superset-core/src/superset_core/api/tasks.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Generic, Literal, ParamSpec, TypedDict, TypeVar
|
||||
|
||||
from superset_core.api.models import Task
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""
|
||||
Status of task execution.
|
||||
"""
|
||||
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
ABORTING = "aborting" # Abort/timeout requested, handlers running
|
||||
ABORTED = "aborted" # User/admin cancelled
|
||||
TIMED_OUT = "timed_out" # Timeout expired
|
||||
|
||||
|
||||
class TaskScope(str, Enum):
|
||||
"""
|
||||
Scope of task visibility and access control.
|
||||
"""
|
||||
|
||||
PRIVATE = "private" # User-specific tasks (default)
|
||||
SHARED = "shared" # Multi-user collaborative tasks
|
||||
SYSTEM = "system" # Admin-only background tasks
|
||||
|
||||
|
||||
class TaskProperties(TypedDict, total=False):
|
||||
"""
|
||||
TypedDict for task runtime state and execution config.
|
||||
|
||||
Stored as JSON in the database, accessed as a dict throughout the codebase.
|
||||
All fields are optional (total=False) - only set keys are present in the dict.
|
||||
|
||||
Usage:
|
||||
# Reading - always use .get() since keys may not be present
|
||||
if task.properties.get("is_abortable"):
|
||||
...
|
||||
|
||||
# Writing/updating - only include keys you want to set
|
||||
task.update_properties({"is_abortable": True, "progress_percent": 0.5})
|
||||
|
||||
Notes:
|
||||
- Sparse dict: only keys that are explicitly set are present
|
||||
- Unknown keys from JSON are preserved (forward compatibility)
|
||||
- Always use .get() for reads since keys may be absent
|
||||
"""
|
||||
|
||||
# Execution config - set at task creation
|
||||
execution_mode: Literal["async", "sync"]
|
||||
timeout: int
|
||||
|
||||
# Runtime state - set by framework during execution
|
||||
is_abortable: bool
|
||||
progress_percent: float
|
||||
progress_current: int
|
||||
progress_total: int
|
||||
|
||||
# Error info - set when task fails
|
||||
error_message: str
|
||||
exception_type: str
|
||||
stack_trace: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskOptions:
|
||||
"""
|
||||
Execution metadata for tasks.
|
||||
|
||||
NOTE: This is intentionally minimal for the initial implementation.
|
||||
Additional options (queue, priority, run_at, delay_s,
|
||||
max_retries, retry_backoff_s, tags, etc.) can be added later when needed.
|
||||
|
||||
Future enhancements will include:
|
||||
- Validation (e.g., run_at vs delay_s mutual exclusion)
|
||||
- Queue routing and priority management
|
||||
- Retry policies and backoff strategies
|
||||
|
||||
Example:
|
||||
from superset_core.api.tasks import TaskOptions, TaskScope
|
||||
|
||||
# Private task (default)
|
||||
task = my_task.schedule(arg1)
|
||||
|
||||
# Custom task with deduplication
|
||||
task = my_task.schedule(
|
||||
arg1,
|
||||
options=TaskOptions(
|
||||
task_key="custom_key",
|
||||
task_name="Custom Task Name"
|
||||
)
|
||||
)
|
||||
|
||||
# Task with custom name
|
||||
task = admin_task.schedule(
|
||||
options=TaskOptions(task_name="Admin Operation")
|
||||
)
|
||||
|
||||
# Task with timeout (overrides decorator default)
|
||||
task = long_task.schedule(
|
||||
options=TaskOptions(timeout=600) # 10 minute timeout
|
||||
)
|
||||
"""
|
||||
|
||||
task_key: str | None = None
|
||||
task_name: str | None = None
|
||||
timeout: int | None = None # Timeout in seconds
|
||||
|
||||
|
||||
class TaskContext(ABC):
|
||||
"""
|
||||
Abstract task context for write-only task state updates.
|
||||
|
||||
Tasks use this context to update their state (progress, payload) and
|
||||
check for cancellation. Tasks should not need to read their own state -
|
||||
they are the source of state, not consumers of it.
|
||||
|
||||
Host implementations will replace this abstract class during initialization
|
||||
with a concrete implementation providing actual functionality.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_task(
|
||||
self,
|
||||
progress: float | int | tuple[int, int] | None = None,
|
||||
payload: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update task progress and/or payload atomically.
|
||||
|
||||
All parameters are optional. Payload is merged with existing data,
|
||||
not replaced. All updates occur in a single database transaction.
|
||||
|
||||
Progress can be specified in three ways:
|
||||
- float (0.0-1.0): Percentage only, e.g., 0.5 means 50%
|
||||
- int: Count only (total unknown), e.g., 42 means "42 items processed"
|
||||
- tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100"
|
||||
The percentage is automatically computed from count/total.
|
||||
|
||||
:param progress: Progress value, or None to leave unchanged
|
||||
:param payload: Payload data to merge (dict), or None to leave unchanged
|
||||
|
||||
Examples:
|
||||
# Percentage only - displays as "In progress: 50 %"
|
||||
ctx.update_task(progress=0.5)
|
||||
|
||||
# Count only (total unknown) - displays as "In progress: 42"
|
||||
ctx.update_task(progress=42)
|
||||
|
||||
# Count and total - displays as "In progress: 3 of 100 (3 %)"
|
||||
ctx.update_task(progress=(3, 100))
|
||||
|
||||
# Update payload only
|
||||
ctx.update_task(payload={"step": "processing"})
|
||||
|
||||
# Update both atomically
|
||||
ctx.update_task(
|
||||
progress=(80, 100),
|
||||
payload={"processed": 80, "total": 100}
|
||||
)
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def on_cleanup(self, handler: Callable[[], None]) -> Callable[[], None]:
|
||||
"""
|
||||
Register a cleanup handler that runs when the task ends.
|
||||
|
||||
Cleanup handlers are called when the task completes (success),
|
||||
fails with an error, or is cancelled. Multiple handlers can be
|
||||
registered and will execute in LIFO order (last registered runs first).
|
||||
|
||||
Can be used as a decorator:
|
||||
@ctx.on_cleanup
|
||||
def cleanup():
|
||||
logger.info("Task ended")
|
||||
|
||||
Or called directly:
|
||||
ctx.on_cleanup(lambda: logger.info("Task ended"))
|
||||
|
||||
:param handler: Cleanup function to register
|
||||
:returns: The handler (for decorator compatibility)
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def on_abort(self, handler: Callable[[], None]) -> Callable[[], None]:
|
||||
"""
|
||||
Register handler that runs when task is aborted.
|
||||
|
||||
When the first handler is registered, background polling starts
|
||||
automatically. The handler will be called when an abort is detected.
|
||||
|
||||
The handler executes in a background thread and the task code
|
||||
continues running unless the handler takes action to stop it.
|
||||
|
||||
:param handler: Callback function to execute when abort is detected
|
||||
:returns: The handler (for decorator compatibility)
|
||||
|
||||
Example:
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
logger.info("Task was aborted!")
|
||||
cleanup_partial_work()
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def task(
|
||||
name: str | None = None,
|
||||
scope: TaskScope = TaskScope.PRIVATE,
|
||||
timeout: int | None = None,
|
||||
) -> Callable[[Callable[P, R]], "TaskWrapper[P]"]:
|
||||
"""
|
||||
Decorator to register a task.
|
||||
|
||||
Host implementations will replace this function during initialization
|
||||
with a concrete implementation providing actual functionality.
|
||||
|
||||
:param name: Optional unique task name (e.g., "superset.generate_thumbnail").
|
||||
If not provided, uses the function name as the task name.
|
||||
:param scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM).
|
||||
Defaults to TaskScope.PRIVATE.
|
||||
:param timeout: Optional timeout in seconds. When the timeout is reached,
|
||||
abort handlers are triggered if registered. Can be overridden
|
||||
at call time via TaskOptions(timeout=...).
|
||||
:returns: TaskWrapper with .schedule() method
|
||||
|
||||
Note:
|
||||
Both direct calls and .schedule() return Task, regardless of the
|
||||
original function's return type. The decorated function's return value
|
||||
is discarded; only side effects and context updates matter.
|
||||
|
||||
Example:
|
||||
from superset_core.api.types import task, get_context, TaskScope
|
||||
|
||||
# Private task (default scope)
|
||||
@task
|
||||
def generate_thumbnail(chart_id: int) -> None:
|
||||
ctx = get_context()
|
||||
# ... task implementation
|
||||
|
||||
# Named task with shared scope
|
||||
@task(name="generate_report", scope=TaskScope.SHARED)
|
||||
def generate_chart_thumbnail(chart_id: int) -> None:
|
||||
ctx = get_context()
|
||||
|
||||
# Update progress and payload atomically
|
||||
ctx.update_task(
|
||||
progress=0.5,
|
||||
payload={"chart_id": chart_id, "status": "processing"}
|
||||
)
|
||||
# ... task implementation
|
||||
|
||||
ctx.update_task(progress=1.0)
|
||||
|
||||
# System task (admin-only)
|
||||
@task(scope=TaskScope.SYSTEM)
|
||||
def cleanup_old_data() -> None:
|
||||
ctx = get_context()
|
||||
# ... cleanup implementation
|
||||
|
||||
# Task with timeout
|
||||
@task(timeout=300) # 5-minute timeout
|
||||
def long_running_task() -> None:
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
# Called when timeout or manual abort
|
||||
pass
|
||||
|
||||
# Schedule async execution
|
||||
task = generate_chart_thumbnail.schedule(chart_id=123) # Returns Task
|
||||
|
||||
# Direct call for sync execution (blocks until task is complete)
|
||||
task = generate_chart_thumbnail(chart_id=123) # Also returns Task
|
||||
"""
|
||||
raise NotImplementedError("Function will be replaced during initialization")
|
||||
|
||||
|
||||
class TaskWrapper(Generic[P]):
|
||||
"""
|
||||
Type stub for task wrapper returned by @task decorator.
|
||||
|
||||
Both __call__ and .schedule() return Task.
|
||||
"""
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Task:
|
||||
"""Execute the task synchronously."""
|
||||
raise NotImplementedError("Will be replaced during initialization")
|
||||
|
||||
def schedule(self, *args: P.args, **kwargs: P.kwargs) -> Task:
|
||||
"""Schedule the task for async execution."""
|
||||
raise NotImplementedError("Will be replaced during initialization")
|
||||
|
||||
|
||||
def get_context() -> TaskContext:
|
||||
"""
|
||||
Get the current task context from ambient context.
|
||||
|
||||
Host implementations will replace this function during initialization
|
||||
with a concrete implementation providing actual functionality.
|
||||
|
||||
This function provides ambient access to the task context without
|
||||
requiring it to be passed as a parameter. It can only be called
|
||||
from within an async task execution.
|
||||
|
||||
:returns: The current TaskContext
|
||||
:raises RuntimeError: If called outside a task execution context
|
||||
|
||||
Example:
|
||||
@task("thumbnail_generation")
|
||||
def generate_chart_thumbnail(chart_id: int):
|
||||
ctx = get_context() # Access ambient context
|
||||
|
||||
# Update task state - no need to fetch task object
|
||||
ctx.update_task(
|
||||
progress=0.5,
|
||||
payload={"chart_id": chart_id}
|
||||
)
|
||||
"""
|
||||
raise NotImplementedError("Function will be replaced during initialization")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TaskStatus",
|
||||
"TaskScope",
|
||||
"TaskProperties",
|
||||
"TaskContext",
|
||||
"TaskOptions",
|
||||
"task",
|
||||
"get_context",
|
||||
]
|
||||
30
superset-frontend/package-lock.json
generated
30
superset-frontend/package-lock.json
generated
@@ -109,6 +109,7 @@
|
||||
"mustache": "^4.2.0",
|
||||
"nanoid": "^5.1.6",
|
||||
"ol": "^7.5.2",
|
||||
"pretty-ms": "^9.3.0",
|
||||
"query-string": "9.3.1",
|
||||
"re-resizable": "^6.11.2",
|
||||
"react": "^17.0.2",
|
||||
@@ -43687,6 +43688,21 @@
|
||||
"integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/pretty-ms": {
|
||||
"version": "9.3.0",
|
||||
"resolved": "https://registry.npmjs.org/pretty-ms/-/pretty-ms-9.3.0.tgz",
|
||||
"integrity": "sha512-gjVS5hOP+M3wMm5nmNOucbIrqudzs9v/57bWRHQWLYklXqoXKrVfYW2W9+glfGsqtPgpiz5WwyEEB+ksXIx3gQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"parse-ms": "^4.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/prismjs": {
|
||||
"version": "1.30.0",
|
||||
"resolved": "https://registry.npmjs.org/prismjs/-/prismjs-1.30.0.tgz",
|
||||
@@ -56437,20 +56453,6 @@
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"packages/superset-ui-core/node_modules/pretty-ms": {
|
||||
"version": "9.3.0",
|
||||
"resolved": "https://registry.npmjs.org/pretty-ms/-/pretty-ms-9.3.0.tgz",
|
||||
"integrity": "sha512-gjVS5hOP+M3wMm5nmNOucbIrqudzs9v/57bWRHQWLYklXqoXKrVfYW2W9+glfGsqtPgpiz5WwyEEB+ksXIx3gQ==",
|
||||
"dependencies": {
|
||||
"parse-ms": "^4.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"packages/superset-ui-core/node_modules/property-information": {
|
||||
"version": "7.1.0",
|
||||
"resolved": "https://registry.npmjs.org/property-information/-/property-information-7.1.0.tgz",
|
||||
|
||||
@@ -187,6 +187,7 @@
|
||||
"markdown-to-jsx": "^9.7.3",
|
||||
"match-sorter": "^6.3.4",
|
||||
"memoize-one": "^5.2.1",
|
||||
"pretty-ms": "^9.3.0",
|
||||
"mousetrap": "^1.6.5",
|
||||
"mustache": "^4.2.0",
|
||||
"nanoid": "^5.1.6",
|
||||
|
||||
@@ -54,6 +54,7 @@ export enum FeatureFlag {
|
||||
EstimateQueryCost = 'ESTIMATE_QUERY_COST',
|
||||
FilterBarClosedByDefault = 'FILTERBAR_CLOSED_BY_DEFAULT',
|
||||
GlobalAsyncQueries = 'GLOBAL_ASYNC_QUERIES',
|
||||
GlobalTaskFramework = 'GLOBAL_TASK_FRAMEWORK',
|
||||
ListviewsDefaultCardView = 'LISTVIEWS_DEFAULT_CARD_VIEW',
|
||||
Matrixify = 'MATRIXIFY',
|
||||
ScheduledQueries = 'SCHEDULED_QUERIES',
|
||||
|
||||
@@ -19,9 +19,9 @@
|
||||
import getOwnerName from 'src/utils/getOwnerName';
|
||||
import { t } from '@apache-superset/core';
|
||||
import { Tooltip } from '@superset-ui/core/components';
|
||||
import type { ModifiedInfoProps } from './types';
|
||||
import type { AuditInfoProps } from './types';
|
||||
|
||||
export const ModifiedInfo = ({ user, date }: ModifiedInfoProps) => {
|
||||
export const ModifiedInfo = ({ user, date }: AuditInfoProps) => {
|
||||
const dateSpan = (
|
||||
<span className="no-wrap" data-test="audit-info-date">
|
||||
{date}
|
||||
@@ -40,4 +40,23 @@ export const ModifiedInfo = ({ user, date }: ModifiedInfoProps) => {
|
||||
return dateSpan;
|
||||
};
|
||||
|
||||
export type { ModifiedInfoProps };
|
||||
export const CreatedInfo = ({ user, date }: AuditInfoProps) => {
|
||||
const dateSpan = (
|
||||
<span className="no-wrap" data-test="audit-info-date">
|
||||
{date}
|
||||
</span>
|
||||
);
|
||||
|
||||
if (user) {
|
||||
const userName = getOwnerName(user);
|
||||
const title = t('Created by: %s', userName);
|
||||
return (
|
||||
<Tooltip title={title} placement="bottom">
|
||||
{dateSpan}
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
return dateSpan;
|
||||
};
|
||||
|
||||
export type { AuditInfoProps };
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
*/
|
||||
import type Owner from 'src/types/Owner';
|
||||
|
||||
export type ModifiedInfoProps = {
|
||||
export type AuditInfoProps = {
|
||||
user?: Owner;
|
||||
date: string;
|
||||
};
|
||||
|
||||
@@ -41,7 +41,7 @@ export * from './GenericLink';
|
||||
export { GridTable, type TableProps } from './GridTable';
|
||||
export * from './Tag';
|
||||
export * from './TagsList';
|
||||
export { ModifiedInfo, type ModifiedInfoProps } from './AuditInfo';
|
||||
export { CreatedInfo, ModifiedInfo, type AuditInfoProps } from './AuditInfo';
|
||||
export {
|
||||
DynamicPluginProvider,
|
||||
PluginContext,
|
||||
|
||||
76
superset-frontend/src/features/tasks/TaskPayloadPopover.tsx
Normal file
76
superset-frontend/src/features/tasks/TaskPayloadPopover.tsx
Normal file
@@ -0,0 +1,76 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { useState } from 'react';
|
||||
import { styled } from '@apache-superset/core/ui';
|
||||
import { Popover } from '@superset-ui/core/components';
|
||||
import { Icons } from '@superset-ui/core/components/Icons';
|
||||
|
||||
const PayloadContainer = styled.div`
|
||||
max-width: 400px;
|
||||
max-height: 300px;
|
||||
overflow: auto;
|
||||
padding: ${({ theme }) => theme.sizeUnit * 2}px;
|
||||
`;
|
||||
|
||||
const PayloadPre = styled.pre`
|
||||
margin: 0;
|
||||
font-size: ${({ theme }) => theme.fontSizeSM}px;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
`;
|
||||
|
||||
const InfoIconWrapper = styled.span`
|
||||
cursor: pointer;
|
||||
color: ${({ theme }) => theme.colorIcon};
|
||||
|
||||
&:hover {
|
||||
color: ${({ theme }) => theme.colorPrimary};
|
||||
}
|
||||
`;
|
||||
|
||||
interface TaskPayloadPopoverProps {
|
||||
payload: Record<string, any>;
|
||||
}
|
||||
|
||||
export default function TaskPayloadPopover({
|
||||
payload,
|
||||
}: TaskPayloadPopoverProps) {
|
||||
const [visible, setVisible] = useState(false);
|
||||
|
||||
const content = (
|
||||
<PayloadContainer>
|
||||
<PayloadPre>{JSON.stringify(payload, null, 2)}</PayloadPre>
|
||||
</PayloadContainer>
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
content={content}
|
||||
trigger="hover"
|
||||
placement="leftTop"
|
||||
visible={visible}
|
||||
onVisibleChange={setVisible}
|
||||
>
|
||||
<InfoIconWrapper>
|
||||
<Icons.InfoCircleOutlined iconSize="l" />
|
||||
</InfoIconWrapper>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
137
superset-frontend/src/features/tasks/TaskStackTracePopover.tsx
Normal file
137
superset-frontend/src/features/tasks/TaskStackTracePopover.tsx
Normal file
@@ -0,0 +1,137 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { useState, useCallback } from 'react';
|
||||
import { t } from '@apache-superset/core';
|
||||
import { styled } from '@apache-superset/core/ui';
|
||||
import { Popover, Tooltip } from '@superset-ui/core/components';
|
||||
import { Icons } from '@superset-ui/core/components/Icons';
|
||||
import { useToasts } from 'src/components/MessageToasts/withToasts';
|
||||
import copyTextToClipboard from 'src/utils/copy';
|
||||
|
||||
const StackTraceContainer = styled.div`
|
||||
max-width: 600px;
|
||||
max-height: 400px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
`;
|
||||
|
||||
const Header = styled.div`
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
padding: ${({ theme }) => theme.sizeUnit}px
|
||||
${({ theme }) => theme.sizeUnit * 2}px;
|
||||
border-bottom: 1px solid ${({ theme }) => theme.colorBorder};
|
||||
`;
|
||||
|
||||
const CopyButton = styled.button`
|
||||
background: none;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
padding: ${({ theme }) => theme.sizeUnit / 2}px;
|
||||
color: ${({ theme }) => theme.colorTextSecondary};
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: ${({ theme }) => theme.sizeUnit / 2}px;
|
||||
font-size: ${({ theme }) => theme.fontSizeSM}px;
|
||||
|
||||
&:hover {
|
||||
color: ${({ theme }) => theme.colorText};
|
||||
}
|
||||
`;
|
||||
|
||||
const StackTraceContent = styled.div`
|
||||
overflow: auto;
|
||||
padding: ${({ theme }) => theme.sizeUnit * 2}px;
|
||||
flex: 1;
|
||||
`;
|
||||
|
||||
const StackTrace = styled.pre`
|
||||
margin: 0;
|
||||
font-size: ${({ theme }) => theme.fontSizeSM}px;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-family: ${({ theme }) => theme.fontFamilyCode};
|
||||
`;
|
||||
|
||||
const ErrorIconWrapper = styled.span`
|
||||
cursor: pointer;
|
||||
color: ${({ theme }) => theme.colorError};
|
||||
|
||||
&:hover {
|
||||
opacity: 0.8;
|
||||
}
|
||||
`;
|
||||
|
||||
interface TaskStackTracePopoverProps {
|
||||
stackTrace: string;
|
||||
}
|
||||
|
||||
export default function TaskStackTracePopover({
|
||||
stackTrace,
|
||||
}: TaskStackTracePopoverProps) {
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [copied, setCopied] = useState(false);
|
||||
const { addDangerToast } = useToasts();
|
||||
|
||||
const handleCopy = useCallback(() => {
|
||||
copyTextToClipboard(() => Promise.resolve(stackTrace))
|
||||
.then(() => {
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
})
|
||||
.catch(() => {
|
||||
addDangerToast(t('Failed to copy stack trace to clipboard'));
|
||||
});
|
||||
}, [stackTrace, addDangerToast]);
|
||||
|
||||
const content = (
|
||||
<StackTraceContainer>
|
||||
<Header>
|
||||
<Tooltip title={copied ? t('Copied!') : t('Copy to clipboard')}>
|
||||
<CopyButton onClick={handleCopy}>
|
||||
{copied ? (
|
||||
<Icons.CheckOutlined iconSize="s" />
|
||||
) : (
|
||||
<Icons.CopyOutlined iconSize="s" />
|
||||
)}
|
||||
{t('Copy')}
|
||||
</CopyButton>
|
||||
</Tooltip>
|
||||
</Header>
|
||||
<StackTraceContent>
|
||||
<StackTrace>{stackTrace}</StackTrace>
|
||||
</StackTraceContent>
|
||||
</StackTraceContainer>
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
content={content}
|
||||
trigger="hover"
|
||||
placement="leftTop"
|
||||
visible={visible}
|
||||
onVisibleChange={setVisible}
|
||||
>
|
||||
<ErrorIconWrapper>
|
||||
<Icons.BugOutlined iconSize="l" />
|
||||
</ErrorIconWrapper>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
145
superset-frontend/src/features/tasks/TaskStatusIcon.tsx
Normal file
145
superset-frontend/src/features/tasks/TaskStatusIcon.tsx
Normal file
@@ -0,0 +1,145 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { useTheme, SupersetTheme, t } from '@apache-superset/core/ui';
|
||||
import { Icons } from '@superset-ui/core/components/Icons';
|
||||
import { Tooltip } from '@superset-ui/core/components';
|
||||
import { TaskStatus } from './types';
|
||||
import { formatProgressTooltip } from './timeUtils';
|
||||
|
||||
function getStatusColor(status: TaskStatus, theme: SupersetTheme): string {
|
||||
switch (status) {
|
||||
case TaskStatus.Pending:
|
||||
return theme.colorPrimaryText;
|
||||
case TaskStatus.InProgress:
|
||||
return theme.colorPrimaryText;
|
||||
case TaskStatus.Success:
|
||||
return theme.colorSuccessText;
|
||||
case TaskStatus.Failure:
|
||||
return theme.colorErrorText;
|
||||
case TaskStatus.TimedOut:
|
||||
return theme.colorErrorText;
|
||||
case TaskStatus.Aborting:
|
||||
return theme.colorWarningText;
|
||||
case TaskStatus.Aborted:
|
||||
return theme.colorWarningText;
|
||||
default:
|
||||
return theme.colorText;
|
||||
}
|
||||
}
|
||||
|
||||
const statusIcons = {
|
||||
[TaskStatus.Pending]: Icons.ClockCircleOutlined,
|
||||
[TaskStatus.InProgress]: Icons.LoadingOutlined,
|
||||
[TaskStatus.Success]: Icons.CheckCircleOutlined,
|
||||
[TaskStatus.Failure]: Icons.CloseCircleOutlined,
|
||||
[TaskStatus.TimedOut]: Icons.ClockCircleOutlined, // Clock to indicate timeout
|
||||
[TaskStatus.Aborting]: Icons.LoadingOutlined, // Spinning to show in-progress abort
|
||||
[TaskStatus.Aborted]: Icons.StopOutlined,
|
||||
};
|
||||
|
||||
const statusLabels = {
|
||||
[TaskStatus.Pending]: t('Pending'),
|
||||
[TaskStatus.InProgress]: t('In Progress'),
|
||||
[TaskStatus.Success]: t('Success'),
|
||||
[TaskStatus.Failure]: t('Failed'),
|
||||
[TaskStatus.TimedOut]: t('Timed Out'),
|
||||
[TaskStatus.Aborting]: t('Aborting'),
|
||||
[TaskStatus.Aborted]: t('Aborted'),
|
||||
};
|
||||
|
||||
interface TaskStatusIconProps {
|
||||
status: TaskStatus;
|
||||
progressPercent?: number | null;
|
||||
progressCurrent?: number | null;
|
||||
progressTotal?: number | null;
|
||||
durationSeconds?: number | null;
|
||||
errorMessage?: string | null;
|
||||
exceptionType?: string | null;
|
||||
}
|
||||
|
||||
export default function TaskStatusIcon({
|
||||
status,
|
||||
progressPercent,
|
||||
progressCurrent,
|
||||
progressTotal,
|
||||
durationSeconds,
|
||||
errorMessage,
|
||||
exceptionType,
|
||||
}: TaskStatusIconProps) {
|
||||
const theme = useTheme();
|
||||
const IconComponent = statusIcons[status];
|
||||
const label = statusLabels[status];
|
||||
|
||||
// Build tooltip content based on status
|
||||
let tooltipContent: React.ReactNode;
|
||||
if (status === TaskStatus.InProgress || status === TaskStatus.Aborting) {
|
||||
// Progress tooltip for active tasks (multiline)
|
||||
const lines = formatProgressTooltip(
|
||||
label,
|
||||
progressCurrent,
|
||||
progressTotal,
|
||||
progressPercent,
|
||||
durationSeconds,
|
||||
);
|
||||
tooltipContent = (
|
||||
<>
|
||||
{lines.map((line, index) => (
|
||||
<React.Fragment key={index}>
|
||||
{index > 0 && <br />}
|
||||
{line}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
} else if (
|
||||
(status === TaskStatus.Failure || status === TaskStatus.TimedOut) &&
|
||||
(exceptionType || errorMessage)
|
||||
) {
|
||||
// Error tooltip for failed/timed out tasks: "Label (ExceptionType): message"
|
||||
if (exceptionType && errorMessage) {
|
||||
tooltipContent = `${label} (${exceptionType}): ${errorMessage}`;
|
||||
} else if (exceptionType) {
|
||||
tooltipContent = `${label} (${exceptionType})`;
|
||||
} else if (errorMessage) {
|
||||
tooltipContent = `${label}: ${errorMessage}`;
|
||||
} else {
|
||||
tooltipContent = label;
|
||||
}
|
||||
} else {
|
||||
tooltipContent = label;
|
||||
}
|
||||
|
||||
// Spin for in-progress and aborting states
|
||||
const shouldSpin =
|
||||
status === TaskStatus.InProgress || status === TaskStatus.Aborting;
|
||||
|
||||
return (
|
||||
<Tooltip title={tooltipContent} placement="top">
|
||||
<span>
|
||||
<IconComponent
|
||||
iconSize="l"
|
||||
iconColor={getStatusColor(status, theme)}
|
||||
spin={shouldSpin}
|
||||
/>
|
||||
</span>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
145
superset-frontend/src/features/tasks/timeUtils.test.ts
Normal file
145
superset-frontend/src/features/tasks/timeUtils.test.ts
Normal file
@@ -0,0 +1,145 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import {
|
||||
formatDuration,
|
||||
calculateEta,
|
||||
formatProgressTooltip,
|
||||
} from './timeUtils';
|
||||
|
||||
test('formatDuration returns null for invalid inputs', () => {
|
||||
expect(formatDuration(null)).toBeNull();
|
||||
expect(formatDuration(undefined)).toBeNull();
|
||||
expect(formatDuration(0)).toBeNull();
|
||||
expect(formatDuration(-5)).toBeNull();
|
||||
});
|
||||
|
||||
test('formatDuration formats seconds correctly', () => {
|
||||
expect(formatDuration(37.5)).toBe('37s');
|
||||
expect(formatDuration(1)).toBe('1s');
|
||||
expect(formatDuration(30)).toBe('30s');
|
||||
});
|
||||
|
||||
test('formatDuration formats minutes correctly', () => {
|
||||
expect(formatDuration(60)).toBe('1m');
|
||||
expect(formatDuration(90)).toBe('1m 30s');
|
||||
expect(formatDuration(150)).toBe('2m 30s');
|
||||
});
|
||||
|
||||
test('formatDuration formats hours correctly', () => {
|
||||
expect(formatDuration(3600)).toBe('1h');
|
||||
expect(formatDuration(3660)).toBe('1h 1m');
|
||||
expect(formatDuration(7200)).toBe('2h');
|
||||
});
|
||||
|
||||
test('calculateEta returns null for invalid inputs', () => {
|
||||
expect(calculateEta(null, 60)).toBeNull();
|
||||
expect(calculateEta(undefined, 60)).toBeNull();
|
||||
expect(calculateEta(0.5, null)).toBeNull();
|
||||
expect(calculateEta(0.5, undefined)).toBeNull();
|
||||
});
|
||||
|
||||
test('calculateEta returns null for edge case progress values', () => {
|
||||
// No progress yet
|
||||
expect(calculateEta(0, 60)).toBeNull();
|
||||
// Already complete
|
||||
expect(calculateEta(1, 60)).toBeNull();
|
||||
// Negative progress (invalid)
|
||||
expect(calculateEta(-0.1, 60)).toBeNull();
|
||||
// Over 100% (invalid)
|
||||
expect(calculateEta(1.1, 60)).toBeNull();
|
||||
});
|
||||
|
||||
test('calculateEta calculates correct remaining time', () => {
|
||||
// 50% done in 60s -> ETA = 60s remaining
|
||||
expect(calculateEta(0.5, 60)).toBe('1m');
|
||||
|
||||
// 30% done in 60s -> remaining = (60/0.3) * 0.7 = 140s
|
||||
expect(calculateEta(0.3, 60)).toBe('2m 20s');
|
||||
|
||||
// 10% done in 10s -> remaining = (10/0.1) * 0.9 = 90s
|
||||
expect(calculateEta(0.1, 10)).toBe('1m 30s');
|
||||
|
||||
// 90% done in 90s -> remaining = (90/0.9) * 0.1 = 10s
|
||||
expect(calculateEta(0.9, 90)).toBe('10s');
|
||||
});
|
||||
|
||||
test('calculateEta returns null for ETAs over 24 hours', () => {
|
||||
// 0.1% done in 100s -> remaining = (100/0.001) * 0.999 = ~99900s > 86400s
|
||||
expect(calculateEta(0.001, 100)).toBeNull();
|
||||
});
|
||||
|
||||
test('formatProgressTooltip returns label only when no progress data', () => {
|
||||
expect(formatProgressTooltip('In Progress')).toEqual(['In Progress']);
|
||||
expect(formatProgressTooltip('In Progress', null, null, null, null)).toEqual([
|
||||
'In Progress',
|
||||
]);
|
||||
});
|
||||
|
||||
test('formatProgressTooltip formats count and total correctly', () => {
|
||||
expect(formatProgressTooltip('In Progress', 9, 60)).toEqual([
|
||||
'In Progress: 9 of 60',
|
||||
]);
|
||||
});
|
||||
|
||||
test('formatProgressTooltip formats count only correctly', () => {
|
||||
expect(formatProgressTooltip('In Progress', 42)).toEqual([
|
||||
'In Progress: 42 processed',
|
||||
]);
|
||||
expect(formatProgressTooltip('In Progress', 42, null)).toEqual([
|
||||
'In Progress: 42 processed',
|
||||
]);
|
||||
});
|
||||
|
||||
test('formatProgressTooltip formats percentage correctly', () => {
|
||||
expect(formatProgressTooltip('In Progress', null, null, 0.5)).toEqual([
|
||||
'In Progress: 50%',
|
||||
]);
|
||||
expect(formatProgressTooltip('In Progress', null, null, 0.333)).toEqual([
|
||||
'In Progress: 33%',
|
||||
]);
|
||||
});
|
||||
|
||||
test('formatProgressTooltip combines count, total, and percentage', () => {
|
||||
expect(formatProgressTooltip('In Progress', 9, 60, 0.15)).toEqual([
|
||||
'In Progress: 9 of 60 (15%)',
|
||||
]);
|
||||
});
|
||||
|
||||
test('formatProgressTooltip includes ETA when duration is provided', () => {
|
||||
// 50% done in 60s -> ETA = 60s = ~1m
|
||||
expect(formatProgressTooltip('In Progress', 30, 60, 0.5, 60)).toEqual([
|
||||
'In Progress: 30 of 60 (50%)',
|
||||
'ETA: 1m',
|
||||
]);
|
||||
});
|
||||
|
||||
test('formatProgressTooltip works with percentage and ETA only', () => {
|
||||
// 25% done in 30s -> ETA = (30/0.25) * 0.75 = 90s = 1m 30s
|
||||
expect(formatProgressTooltip('In Progress', null, null, 0.25, 30)).toEqual([
|
||||
'In Progress: 25%',
|
||||
'ETA: 1m 30s',
|
||||
]);
|
||||
});
|
||||
|
||||
test('formatProgressTooltip works with different labels', () => {
|
||||
expect(formatProgressTooltip('Aborting', 5, 10, 0.5)).toEqual([
|
||||
'Aborting: 5 of 10 (50%)',
|
||||
]);
|
||||
});
|
||||
151
superset-frontend/src/features/tasks/timeUtils.ts
Normal file
151
superset-frontend/src/features/tasks/timeUtils.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import prettyMs from 'pretty-ms';
|
||||
|
||||
/**
|
||||
* Maximum ETA to display (24 hours in seconds).
|
||||
* ETAs beyond this are not shown as they're unreliable.
|
||||
*/
|
||||
const MAX_ETA_SECONDS = 86400;
|
||||
|
||||
/**
|
||||
* Format a duration in seconds to a human-readable string.
|
||||
*
|
||||
* @param seconds - Duration in seconds
|
||||
* @returns Formatted string like "1m 30s" or "2h 15m", or null if invalid
|
||||
*/
|
||||
export function formatDuration(
|
||||
seconds: number | null | undefined,
|
||||
): string | null {
|
||||
if (seconds === null || seconds === undefined || seconds <= 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return prettyMs(seconds * 1000, {
|
||||
unitCount: 2,
|
||||
secondsDecimalDigits: 0,
|
||||
keepDecimalsOnWholeSeconds: false,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate and format estimated time to completion based on progress and elapsed time.
|
||||
*
|
||||
* Uses the formula: ETA = (elapsed / progress) * (1 - progress)
|
||||
* For example, if 30% done in 60s, remaining = (60/0.3) * 0.7 = 140s
|
||||
*
|
||||
* @param progressPercent - Progress as a fraction (0.0 to 1.0)
|
||||
* @param durationSeconds - Time elapsed so far in seconds
|
||||
* @returns Formatted ETA string or null if cannot be calculated
|
||||
*/
|
||||
export function calculateEta(
|
||||
progressPercent: number | null | undefined,
|
||||
durationSeconds: number | null | undefined,
|
||||
): string | null {
|
||||
// Need both progress and duration to calculate ETA
|
||||
if (
|
||||
progressPercent === null ||
|
||||
progressPercent === undefined ||
|
||||
durationSeconds === null ||
|
||||
durationSeconds === undefined
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Can't calculate ETA if no progress yet or already complete
|
||||
if (progressPercent <= 0 || progressPercent >= 1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// ETA = (elapsed / progress) * (1 - progress)
|
||||
const estimatedTotalTime = durationSeconds / progressPercent;
|
||||
const remainingSeconds = estimatedTotalTime * (1 - progressPercent);
|
||||
|
||||
// Only show ETA if it's reasonable (less than 24 hours)
|
||||
if (remainingSeconds <= 0 || remainingSeconds > MAX_ETA_SECONDS) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Use unitCount: 2 to show up to 2 units (e.g., "1m 30s" instead of just "1m")
|
||||
// Use secondsDecimalDigits: 0 to show whole seconds (e.g., "52s" instead of "52.4s")
|
||||
return prettyMs(remainingSeconds * 1000, {
|
||||
unitCount: 2,
|
||||
secondsDecimalDigits: 0,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a progress display for task status tooltips.
|
||||
*
|
||||
* Returns an array of lines for proper multiline tooltip rendering:
|
||||
* - ["In Progress: 9 of 60 (15%)", "ETA: 51s"]
|
||||
* - ["In Progress: 42 processed"]
|
||||
* - ["In Progress: 50%"]
|
||||
* - ["In Progress: 50%", "ETA: 2m"]
|
||||
*
|
||||
* @param label - Status label (e.g., "In Progress", "Aborting")
|
||||
* @param progressCurrent - Current count of items processed
|
||||
* @param progressTotal - Total count of items to process
|
||||
* @param progressPercent - Progress as a fraction (0.0 to 1.0)
|
||||
* @param durationSeconds - Time elapsed so far in seconds (used for ETA calculation)
|
||||
* @returns Array of lines for tooltip display
|
||||
*/
|
||||
export function formatProgressTooltip(
|
||||
label: string,
|
||||
progressCurrent?: number | null,
|
||||
progressTotal?: number | null,
|
||||
progressPercent?: number | null,
|
||||
durationSeconds?: number | null,
|
||||
): string[] {
|
||||
const lines: string[] = [];
|
||||
let progressPart = '';
|
||||
|
||||
// Build progress part
|
||||
if (progressCurrent !== null && progressCurrent !== undefined) {
|
||||
if (progressTotal !== null && progressTotal !== undefined) {
|
||||
// Count and total with percentage: "3 of 278 (15%)"
|
||||
progressPart = `${progressCurrent} of ${progressTotal}`;
|
||||
if (progressPercent !== null && progressPercent !== undefined) {
|
||||
progressPart += ` (${Math.round(progressPercent * 100)}%)`;
|
||||
}
|
||||
} else {
|
||||
// Count only: "3 processed"
|
||||
progressPart = `${progressCurrent} processed`;
|
||||
}
|
||||
} else if (progressPercent !== null && progressPercent !== undefined) {
|
||||
// Percentage only: "50%"
|
||||
progressPart = `${Math.round(progressPercent * 100)}%`;
|
||||
}
|
||||
|
||||
// Add the main progress line
|
||||
if (progressPart) {
|
||||
lines.push(`${label}: ${progressPart}`);
|
||||
} else {
|
||||
lines.push(label);
|
||||
}
|
||||
|
||||
// Add ETA on a separate line if available
|
||||
const eta = calculateEta(progressPercent, durationSeconds);
|
||||
if (eta) {
|
||||
lines.push(`ETA: ${eta}`);
|
||||
}
|
||||
|
||||
return lines;
|
||||
}
|
||||
115
superset-frontend/src/features/tasks/types.ts
Normal file
115
superset-frontend/src/features/tasks/types.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
export interface TaskSubscriber {
|
||||
user_id: number;
|
||||
first_name: string;
|
||||
last_name: string;
|
||||
subscribed_at: string;
|
||||
}
|
||||
|
||||
export enum TaskScope {
|
||||
Private = 'private',
|
||||
Shared = 'shared',
|
||||
System = 'system',
|
||||
}
|
||||
|
||||
/**
|
||||
* Task properties - runtime state and execution config stored in JSON blob.
|
||||
*/
|
||||
export interface TaskProperties {
|
||||
// Execution config - set at task creation
|
||||
execution_mode: 'async' | 'sync' | null;
|
||||
timeout: number | null;
|
||||
|
||||
// Runtime state - set by framework during execution
|
||||
is_abortable: boolean | null;
|
||||
progress_percent: number | null;
|
||||
progress_current: number | null;
|
||||
progress_total: number | null;
|
||||
|
||||
// Error info - set when task fails
|
||||
error_message: string | null;
|
||||
exception_type: string | null;
|
||||
stack_trace: string | null;
|
||||
}
|
||||
|
||||
export interface Task {
|
||||
id: number;
|
||||
uuid: string;
|
||||
task_key: string;
|
||||
task_type: string;
|
||||
task_name: string | null;
|
||||
status:
|
||||
| 'pending'
|
||||
| 'in_progress'
|
||||
| 'success'
|
||||
| 'failure'
|
||||
| 'aborting'
|
||||
| 'aborted'
|
||||
| 'timed_out';
|
||||
scope: TaskScope;
|
||||
created_on: string;
|
||||
created_on_delta_humanized?: string;
|
||||
changed_on: string;
|
||||
started_at: string | null;
|
||||
ended_at: string | null;
|
||||
created_by: {
|
||||
id: number;
|
||||
first_name: string;
|
||||
last_name: string;
|
||||
} | null;
|
||||
changed_by?: {
|
||||
first_name: string;
|
||||
last_name: string;
|
||||
} | null;
|
||||
user_id: number | null;
|
||||
payload: Record<string, any>;
|
||||
properties: TaskProperties;
|
||||
duration_seconds: number | null;
|
||||
subscriber_count: number;
|
||||
subscribers: TaskSubscriber[];
|
||||
}
|
||||
|
||||
// Derived status helpers (frontend computes these from status and properties)
|
||||
export function isTaskFinished(task: Task): boolean {
|
||||
return ['success', 'failure', 'aborted', 'timed_out'].includes(task.status);
|
||||
}
|
||||
|
||||
export function isTaskAborting(task: Task): boolean {
|
||||
return task.status === 'aborting';
|
||||
}
|
||||
|
||||
export function canAbortTask(task: Task): boolean {
|
||||
if (task.status === 'pending') return true;
|
||||
if (task.status === 'in_progress' && task.properties.is_abortable === true)
|
||||
return true;
|
||||
if (task.status === 'aborting') return true; // Idempotent
|
||||
return false;
|
||||
}
|
||||
|
||||
export enum TaskStatus {
|
||||
Pending = 'pending',
|
||||
InProgress = 'in_progress',
|
||||
Success = 'success',
|
||||
Failure = 'failure',
|
||||
Aborting = 'aborting',
|
||||
Aborted = 'aborted',
|
||||
TimedOut = 'timed_out',
|
||||
}
|
||||
328
superset-frontend/src/pages/TaskList/TaskList.test.tsx
Normal file
328
superset-frontend/src/pages/TaskList/TaskList.test.tsx
Normal file
@@ -0,0 +1,328 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { MemoryRouter } from 'react-router-dom';
|
||||
import fetchMock from 'fetch-mock';
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
fireEvent,
|
||||
} from 'spec/helpers/testing-library';
|
||||
import { QueryParamProvider } from 'use-query-params';
|
||||
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
|
||||
import { TaskStatus, TaskScope } from 'src/features/tasks/types';
|
||||
import TaskList from 'src/pages/TaskList';
|
||||
|
||||
// Set up window.featureFlags before importing TaskList
|
||||
window.featureFlags = { GLOBAL_TASK_FRAMEWORK: true };
|
||||
|
||||
// Mock getBootstrapData before importing components that use it
|
||||
jest.mock('src/utils/getBootstrapData', () => ({
|
||||
__esModule: true,
|
||||
default: () => ({
|
||||
user: {
|
||||
userId: 1,
|
||||
firstName: 'admin',
|
||||
lastName: 'user',
|
||||
roles: { Admin: [] },
|
||||
},
|
||||
common: {
|
||||
feature_flags: { GLOBAL_TASK_FRAMEWORK: true },
|
||||
conf: {},
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
const tasksInfoEndpoint = 'glob:*/api/v1/task/_info*';
|
||||
const tasksCreatedByEndpoint = 'glob:*/api/v1/task/related/created_by*';
|
||||
const tasksEndpoint = 'glob:*/api/v1/task/?*';
|
||||
const taskCancelEndpoint = 'glob:*/api/v1/task/*/cancel';
|
||||
|
||||
const mockTasks = [
|
||||
{
|
||||
id: 1,
|
||||
uuid: 'task-uuid-1',
|
||||
task_key: 'test_task_1',
|
||||
task_type: 'data_export',
|
||||
task_name: 'Export Data Task',
|
||||
status: TaskStatus.Success,
|
||||
scope: TaskScope.Private,
|
||||
created_on: '2024-01-15T10:00:00Z',
|
||||
changed_on: '2024-01-15T10:05:00Z',
|
||||
created_on_delta_humanized: '5 minutes ago',
|
||||
started_at: '2024-01-15T10:00:01Z',
|
||||
ended_at: '2024-01-15T10:05:00Z',
|
||||
created_by: { id: 1, first_name: 'admin', last_name: 'user' },
|
||||
user_id: 1,
|
||||
payload: {},
|
||||
duration_seconds: 299,
|
||||
subscriber_count: 0,
|
||||
subscribers: [],
|
||||
properties: {
|
||||
is_abortable: null,
|
||||
progress_percent: 1.0,
|
||||
progress_current: null,
|
||||
progress_total: null,
|
||||
error_message: null,
|
||||
exception_type: null,
|
||||
stack_trace: null,
|
||||
timeout: null,
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 2,
|
||||
uuid: 'task-uuid-2',
|
||||
task_key: 'test_task_2',
|
||||
task_type: 'report_generation',
|
||||
task_name: null,
|
||||
status: TaskStatus.InProgress,
|
||||
scope: TaskScope.Private,
|
||||
created_on: '2024-01-15T11:00:00Z',
|
||||
changed_on: '2024-01-15T11:00:00Z',
|
||||
created_on_delta_humanized: '1 minute ago',
|
||||
started_at: '2024-01-15T11:00:01Z',
|
||||
ended_at: null,
|
||||
created_by: { id: 1, first_name: 'admin', last_name: 'user' },
|
||||
user_id: 1,
|
||||
payload: { report_id: 42 },
|
||||
duration_seconds: null,
|
||||
subscriber_count: 0,
|
||||
subscribers: [],
|
||||
properties: {
|
||||
is_abortable: true,
|
||||
progress_percent: 0.5,
|
||||
progress_current: null,
|
||||
progress_total: null,
|
||||
error_message: null,
|
||||
exception_type: null,
|
||||
stack_trace: null,
|
||||
timeout: null,
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 3,
|
||||
uuid: 'task-uuid-3',
|
||||
task_key: 'shared_task_1',
|
||||
task_type: 'bulk_operation',
|
||||
task_name: 'Shared Bulk Task',
|
||||
status: TaskStatus.Pending,
|
||||
scope: TaskScope.Shared,
|
||||
created_on: '2024-01-15T12:00:00Z',
|
||||
changed_on: '2024-01-15T12:00:00Z',
|
||||
created_on_delta_humanized: 'just now',
|
||||
started_at: null,
|
||||
ended_at: null,
|
||||
created_by: { id: 2, first_name: 'other', last_name: 'user' },
|
||||
user_id: 2,
|
||||
payload: {},
|
||||
duration_seconds: null,
|
||||
subscriber_count: 2,
|
||||
subscribers: [
|
||||
{
|
||||
user_id: 1,
|
||||
first_name: 'admin',
|
||||
last_name: 'user',
|
||||
subscribed_at: '2024-01-15T12:00:00Z',
|
||||
},
|
||||
{
|
||||
user_id: 2,
|
||||
first_name: 'other',
|
||||
last_name: 'user',
|
||||
subscribed_at: '2024-01-15T12:00:01Z',
|
||||
},
|
||||
],
|
||||
properties: {
|
||||
is_abortable: null,
|
||||
progress_percent: null,
|
||||
progress_current: null,
|
||||
progress_total: null,
|
||||
error_message: null,
|
||||
exception_type: null,
|
||||
stack_trace: null,
|
||||
timeout: null,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const mockUser = {
|
||||
userId: 1,
|
||||
firstName: 'admin',
|
||||
lastName: 'user',
|
||||
};
|
||||
|
||||
fetchMock.get(
|
||||
tasksInfoEndpoint,
|
||||
{ permissions: ['can_read', 'can_write'] },
|
||||
{ name: tasksInfoEndpoint },
|
||||
);
|
||||
fetchMock.get(
|
||||
tasksCreatedByEndpoint,
|
||||
{ result: [] },
|
||||
{ name: tasksCreatedByEndpoint },
|
||||
);
|
||||
fetchMock.get(
|
||||
tasksEndpoint,
|
||||
{ result: mockTasks, count: 3 },
|
||||
{ name: tasksEndpoint },
|
||||
);
|
||||
fetchMock.post(
|
||||
taskCancelEndpoint,
|
||||
{ action: 'aborted', message: 'Task cancelled' },
|
||||
{ name: taskCancelEndpoint },
|
||||
);
|
||||
|
||||
const renderTaskList = (props = {}, userProp = mockUser) =>
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<QueryParamProvider adapter={ReactRouter5Adapter}>
|
||||
<TaskList {...props} user={userProp} />
|
||||
</QueryParamProvider>
|
||||
</MemoryRouter>,
|
||||
{ useRedux: true },
|
||||
);
|
||||
|
||||
beforeEach(() => {
|
||||
fetchMock.clearHistory();
|
||||
});
|
||||
|
||||
test('renders TaskList with title, ListView, and fetches data from endpoints', async () => {
|
||||
renderTaskList();
|
||||
|
||||
// Wait for data to load and verify title
|
||||
expect(await screen.findByText('Tasks')).toBeInTheDocument();
|
||||
expect(screen.getByTestId('task-list-view')).toBeInTheDocument();
|
||||
|
||||
// Verify API calls were made
|
||||
expect(fetchMock.callHistory.calls(/task\/_info/).length).toBe(1);
|
||||
expect(fetchMock.callHistory.calls(/task\/\?q/).length).toBe(1);
|
||||
});
|
||||
|
||||
test('displays task data including types, scope labels, and duration', async () => {
|
||||
renderTaskList();
|
||||
|
||||
// Wait for data to load
|
||||
await screen.findByText('Export Data Task');
|
||||
|
||||
// Task types
|
||||
expect(screen.getByText('data_export')).toBeInTheDocument();
|
||||
expect(screen.getByText('report_generation')).toBeInTheDocument();
|
||||
|
||||
// Scope labels
|
||||
expect(screen.getAllByText('Private').length).toBeGreaterThan(0);
|
||||
expect(screen.getByText('Shared')).toBeInTheDocument();
|
||||
|
||||
// Duration (299s = 4m 59s via prettyMs)
|
||||
expect(screen.getByText('4m 59s')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('shows cancel button and modal for cancellable tasks', async () => {
|
||||
renderTaskList();
|
||||
|
||||
// Wait for data to load
|
||||
await screen.findByText('test_task_2');
|
||||
|
||||
// Cancel buttons exist for in-progress and shared tasks
|
||||
const stopIcons = screen.getAllByRole('img', { name: 'stop' });
|
||||
expect(stopIcons.length).toBeGreaterThan(0);
|
||||
|
||||
// Click a cancel button to show confirmation modal
|
||||
const cancelButton = stopIcons.find(
|
||||
icon => icon.closest('[role="button"]') !== null,
|
||||
);
|
||||
expect(cancelButton).toBeDefined();
|
||||
fireEvent.click(cancelButton!);
|
||||
|
||||
expect(await screen.findByText('Cancel Task')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('does not show cancel button for completed shared tasks', async () => {
|
||||
const completedSharedTask = {
|
||||
id: 4,
|
||||
uuid: 'task-uuid-4',
|
||||
task_key: 'completed_shared_task',
|
||||
task_type: 'bulk_operation',
|
||||
task_name: 'Completed Shared Task',
|
||||
status: TaskStatus.Success,
|
||||
scope: TaskScope.Shared,
|
||||
created_on: '2024-01-15T12:00:00Z',
|
||||
changed_on: '2024-01-15T12:05:00Z',
|
||||
created_on_delta_humanized: '5 minutes ago',
|
||||
started_at: '2024-01-15T12:00:01Z',
|
||||
ended_at: '2024-01-15T12:05:00Z',
|
||||
created_by: { id: 2, first_name: 'other', last_name: 'user' },
|
||||
user_id: 2,
|
||||
payload: {},
|
||||
duration_seconds: 299,
|
||||
subscriber_count: 1,
|
||||
subscribers: [
|
||||
{
|
||||
user_id: 1,
|
||||
first_name: 'admin',
|
||||
last_name: 'user',
|
||||
subscribed_at: '2024-01-15T12:00:00Z',
|
||||
},
|
||||
],
|
||||
properties: {
|
||||
is_abortable: null,
|
||||
progress_percent: 1.0,
|
||||
progress_current: null,
|
||||
progress_total: null,
|
||||
error_message: null,
|
||||
exception_type: null,
|
||||
stack_trace: null,
|
||||
timeout: null,
|
||||
},
|
||||
};
|
||||
|
||||
fetchMock.modifyRoute(tasksEndpoint, {
|
||||
response: { result: [completedSharedTask], count: 1 },
|
||||
});
|
||||
|
||||
renderTaskList();
|
||||
await screen.findByText('Completed Shared Task');
|
||||
|
||||
// No action buttons with stop icons for completed tasks
|
||||
const stopIcons = screen.queryAllByRole('img', { name: 'stop' });
|
||||
const actionButtons = stopIcons.filter(
|
||||
icon => icon.closest('[role="button"]') !== null,
|
||||
);
|
||||
expect(actionButtons).toHaveLength(0);
|
||||
|
||||
// Restore mock
|
||||
fetchMock.modifyRoute(tasksEndpoint, {
|
||||
response: { result: mockTasks, count: 3 },
|
||||
});
|
||||
});
|
||||
|
||||
test('displays empty state when no tasks', async () => {
|
||||
fetchMock.modifyRoute(tasksEndpoint, {
|
||||
response: { result: [], count: 0 },
|
||||
});
|
||||
|
||||
renderTaskList();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('No tasks yet')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Restore mock
|
||||
fetchMock.modifyRoute(tasksEndpoint, {
|
||||
response: { result: mockTasks, count: 3 },
|
||||
});
|
||||
});
|
||||
658
superset-frontend/src/pages/TaskList/index.tsx
Normal file
658
superset-frontend/src/pages/TaskList/index.tsx
Normal file
@@ -0,0 +1,658 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import {
|
||||
FeatureFlag,
|
||||
isFeatureEnabled,
|
||||
SupersetClient,
|
||||
} from '@superset-ui/core';
|
||||
import { t, useTheme } from '@apache-superset/core';
|
||||
import { useMemo, useCallback, useState } from 'react';
|
||||
import { Tooltip, Label, Modal, Checkbox } from '@superset-ui/core/components';
|
||||
import {
|
||||
CreatedInfo,
|
||||
ListView,
|
||||
ListViewFilterOperator as FilterOperator,
|
||||
type ListViewFilters,
|
||||
FacePile,
|
||||
} from 'src/components';
|
||||
import { Icons } from '@superset-ui/core/components/Icons';
|
||||
import withToasts from 'src/components/MessageToasts/withToasts';
|
||||
import SubMenu from 'src/features/home/SubMenu';
|
||||
import { useListViewResource } from 'src/views/CRUD/hooks';
|
||||
import { createErrorHandler, createFetchRelated } from 'src/views/CRUD/utils';
|
||||
import TaskStatusIcon from 'src/features/tasks/TaskStatusIcon';
|
||||
import TaskPayloadPopover from 'src/features/tasks/TaskPayloadPopover';
|
||||
import TaskStackTracePopover from 'src/features/tasks/TaskStackTracePopover';
|
||||
import { formatDuration } from 'src/features/tasks/timeUtils';
|
||||
import {
|
||||
Task,
|
||||
TaskStatus,
|
||||
TaskScope,
|
||||
canAbortTask,
|
||||
isTaskAborting,
|
||||
TaskSubscriber,
|
||||
} from 'src/features/tasks/types';
|
||||
import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
|
||||
import getBootstrapData from 'src/utils/getBootstrapData';
|
||||
|
||||
const PAGE_SIZE = 25;
|
||||
|
||||
/**
|
||||
* Typed cell props for react-table columns.
|
||||
* Replaces `: any` for better type safety in Cell render functions.
|
||||
*/
|
||||
interface TaskCellProps {
|
||||
row: {
|
||||
original: Task;
|
||||
};
|
||||
}
|
||||
|
||||
interface TaskListProps {
|
||||
addDangerToast: (msg: string) => void;
|
||||
addSuccessToast: (msg: string) => void;
|
||||
user: {
|
||||
userId: string | number;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
};
|
||||
}
|
||||
|
||||
function TaskList({ addDangerToast, addSuccessToast, user }: TaskListProps) {
|
||||
const theme = useTheme();
|
||||
|
||||
// Check if GTF feature flag is enabled
|
||||
if (!isFeatureEnabled(FeatureFlag.GlobalTaskFramework)) {
|
||||
return (
|
||||
<>
|
||||
<SubMenu name={t('Tasks')} />
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
height: '50vh',
|
||||
color: theme.colorTextSecondary,
|
||||
}}
|
||||
>
|
||||
<h3>{t('Feature Not Enabled')}</h3>
|
||||
<p>
|
||||
{t(
|
||||
'The Global Task Framework is not enabled. Please contact your administrator to enable the GLOBAL_TASK_FRAMEWORK feature flag.',
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
const {
|
||||
state: { loading, resourceCount: tasksCount, resourceCollection: tasks },
|
||||
fetchData,
|
||||
refreshData,
|
||||
} = useListViewResource<Task>('task', t('task'), addDangerToast);
|
||||
|
||||
// Get full user with roles to check admin status
|
||||
const bootstrapData = getBootstrapData();
|
||||
const fullUser = bootstrapData?.user;
|
||||
const isAdmin = useMemo(() => isUserAdmin(fullUser), [fullUser]);
|
||||
|
||||
// State for cancel confirmation modal
|
||||
const [cancelModalTask, setCancelModalTask] = useState<Task | null>(null);
|
||||
const [forceCancel, setForceCancel] = useState(false);
|
||||
|
||||
// Determine dialog message based on task context
|
||||
const getCancelDialogMessage = useCallback((task: Task) => {
|
||||
const isSharedTask = task.scope === TaskScope.Shared;
|
||||
const subscriberCount = task.subscriber_count || 0;
|
||||
const otherSubscribers = subscriberCount - 1;
|
||||
|
||||
// If it's going to abort (private, system, or last subscriber)
|
||||
if (!isSharedTask || subscriberCount <= 1) {
|
||||
return t('This will cancel the task.');
|
||||
}
|
||||
|
||||
// Shared task with multiple subscribers
|
||||
return t(
|
||||
"You'll be removed from this task. It will continue running for %s other subscriber(s).",
|
||||
otherSubscribers,
|
||||
);
|
||||
}, []);
|
||||
|
||||
// Get force abort message for admin checkbox
|
||||
const getForceAbortMessage = useCallback((task: Task) => {
|
||||
const subscriberCount = task.subscriber_count || 0;
|
||||
return t(
|
||||
'This will abort (stop) the task for all %s subscriber(s).',
|
||||
subscriberCount,
|
||||
);
|
||||
}, []);
|
||||
|
||||
// Check if current user is subscribed to a task
|
||||
const isUserSubscribed = useCallback(
|
||||
(task: Task) =>
|
||||
task.subscribers?.some(
|
||||
(sub: TaskSubscriber) => sub.user_id === user.userId,
|
||||
) ?? false,
|
||||
[user.userId],
|
||||
);
|
||||
|
||||
// Check if force cancel option should be shown (for admins on shared tasks)
|
||||
const showForceCancelOption = useCallback(
|
||||
(task: Task) => {
|
||||
const isSharedTask = task.scope === TaskScope.Shared;
|
||||
const subscriberCount = task.subscriber_count || 0;
|
||||
const userSubscribed = isUserSubscribed(task);
|
||||
// Show for admins on shared tasks when:
|
||||
// - Not subscribed (can only abort, so show checkbox pre-checked disabled), OR
|
||||
// - Multiple subscribers (can choose between unsubscribe and force abort)
|
||||
// Don't show when admin is the sole subscriber - cancel will abort anyway
|
||||
return (
|
||||
isAdmin && isSharedTask && (subscriberCount > 1 || !userSubscribed)
|
||||
);
|
||||
},
|
||||
[isAdmin, isUserSubscribed],
|
||||
);
|
||||
|
||||
// Check if force cancel checkbox should be disabled (admin not subscribed)
|
||||
const isForceCancelDisabled = useCallback(
|
||||
(task: Task) => isAdmin && !isUserSubscribed(task),
|
||||
[isAdmin, isUserSubscribed],
|
||||
);
|
||||
|
||||
const handleTaskCancel = useCallback(
|
||||
(task: Task, force: boolean = false) => {
|
||||
SupersetClient.post({
|
||||
endpoint: `/api/v1/task/${task.uuid}/cancel`,
|
||||
jsonPayload: force ? { force: true } : {},
|
||||
}).then(
|
||||
({ json }) => {
|
||||
refreshData();
|
||||
const { action } = json as { action: string };
|
||||
if (action === 'aborted') {
|
||||
addSuccessToast(
|
||||
t('Task cancelled: %s', task.task_name || task.task_key),
|
||||
);
|
||||
} else {
|
||||
addSuccessToast(
|
||||
t(
|
||||
'You have been removed from task: %s',
|
||||
task.task_name || task.task_key,
|
||||
),
|
||||
);
|
||||
}
|
||||
},
|
||||
createErrorHandler(errMsg =>
|
||||
addDangerToast(
|
||||
t('There was an issue cancelling the task: %s', errMsg),
|
||||
),
|
||||
),
|
||||
);
|
||||
},
|
||||
[addDangerToast, addSuccessToast, refreshData],
|
||||
);
|
||||
|
||||
// Handle opening the cancel modal - set initial forceCancel state
|
||||
const openCancelModal = useCallback(
|
||||
(task: Task) => {
|
||||
// Pre-check force cancel if admin is not subscribed
|
||||
const shouldPreCheck = isAdmin && !isUserSubscribed(task);
|
||||
setForceCancel(shouldPreCheck);
|
||||
setCancelModalTask(task);
|
||||
},
|
||||
[isAdmin, isUserSubscribed],
|
||||
);
|
||||
|
||||
// Handle modal confirmation
|
||||
const handleCancelConfirm = useCallback(() => {
|
||||
if (cancelModalTask) {
|
||||
handleTaskCancel(cancelModalTask, forceCancel);
|
||||
setCancelModalTask(null);
|
||||
setForceCancel(false);
|
||||
}
|
||||
}, [cancelModalTask, forceCancel, handleTaskCancel]);
|
||||
|
||||
// Handle modal close
|
||||
const handleCancelModalClose = useCallback(() => {
|
||||
setCancelModalTask(null);
|
||||
setForceCancel(false);
|
||||
}, []);
|
||||
|
||||
const columns = useMemo(
|
||||
() => [
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { task_name, task_key, uuid },
|
||||
},
|
||||
}: TaskCellProps) => {
|
||||
// Display preference: task_name > task_key
|
||||
const displayText = task_name || task_key;
|
||||
const truncated =
|
||||
displayText.length > 30
|
||||
? `${displayText.slice(0, 30)}...`
|
||||
: displayText;
|
||||
|
||||
// Build tooltip with all identifiers
|
||||
const tooltipLines = [];
|
||||
if (task_name) tooltipLines.push(`Name: ${task_name}`);
|
||||
tooltipLines.push(`Key: ${task_key}`);
|
||||
tooltipLines.push(`UUID: ${uuid}`);
|
||||
const tooltipText = tooltipLines.join('\n');
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
title={
|
||||
<span style={{ whiteSpace: 'pre-line' }}>{tooltipText}</span>
|
||||
}
|
||||
placement="top"
|
||||
>
|
||||
<span>{truncated}</span>
|
||||
</Tooltip>
|
||||
);
|
||||
},
|
||||
accessor: 'task_name',
|
||||
Header: t('Task'),
|
||||
size: 'xl',
|
||||
id: 'task',
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { status, properties, duration_seconds },
|
||||
},
|
||||
}: TaskCellProps) => (
|
||||
<TaskStatusIcon
|
||||
status={status as TaskStatus}
|
||||
progressPercent={properties?.progress_percent}
|
||||
progressCurrent={properties?.progress_current}
|
||||
progressTotal={properties?.progress_total}
|
||||
durationSeconds={duration_seconds}
|
||||
errorMessage={properties?.error_message}
|
||||
exceptionType={properties?.exception_type}
|
||||
/>
|
||||
),
|
||||
accessor: 'status',
|
||||
Header: t('Status'),
|
||||
size: 'xs',
|
||||
id: 'status',
|
||||
},
|
||||
{
|
||||
accessor: 'task_type',
|
||||
Header: t('Type'),
|
||||
size: 'md',
|
||||
id: 'task_type',
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { scope },
|
||||
},
|
||||
}: TaskCellProps) => {
|
||||
const scopeConfig: Record<
|
||||
TaskScope,
|
||||
{ label: string; type: 'default' | 'info' | 'warning' }
|
||||
> = {
|
||||
[TaskScope.Private]: { label: t('Private'), type: 'default' },
|
||||
[TaskScope.Shared]: { label: t('Shared'), type: 'info' },
|
||||
[TaskScope.System]: { label: t('System'), type: 'warning' },
|
||||
};
|
||||
|
||||
const config = scopeConfig[scope as TaskScope] || {
|
||||
label: scope,
|
||||
type: 'default' as const,
|
||||
};
|
||||
|
||||
return <Label type={config.type}>{config.label}</Label>;
|
||||
},
|
||||
accessor: 'scope',
|
||||
Header: t('Scope'),
|
||||
size: 'sm',
|
||||
id: 'scope',
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { subscriber_count, subscribers },
|
||||
},
|
||||
}: TaskCellProps) => {
|
||||
if (!subscribers || subscriber_count === 0) {
|
||||
return '-';
|
||||
}
|
||||
|
||||
// Convert subscribers to FacePile format
|
||||
const users = subscribers.map((sub: TaskSubscriber) => ({
|
||||
id: sub.user_id,
|
||||
first_name: sub.first_name,
|
||||
last_name: sub.last_name,
|
||||
}));
|
||||
|
||||
return <FacePile users={users} maxCount={3} />;
|
||||
},
|
||||
accessor: 'subscriber_count',
|
||||
Header: t('Subscribers'),
|
||||
size: 'md',
|
||||
id: 'subscribers',
|
||||
disableSortBy: true,
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: {
|
||||
created_on_delta_humanized: createdOn,
|
||||
created_by: createdBy,
|
||||
},
|
||||
},
|
||||
}: TaskCellProps) => (
|
||||
<CreatedInfo date={createdOn ?? ''} user={createdBy ?? undefined} />
|
||||
),
|
||||
Header: t('Created'),
|
||||
accessor: 'created_on',
|
||||
size: 'xl',
|
||||
id: 'created_on',
|
||||
},
|
||||
{
|
||||
// Hidden column for filtering by created_by
|
||||
accessor: 'created_by',
|
||||
id: 'created_by',
|
||||
hidden: true,
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { duration_seconds },
|
||||
},
|
||||
}: TaskCellProps) => formatDuration(duration_seconds) ?? '-',
|
||||
accessor: 'duration_seconds',
|
||||
Header: t('Duration'),
|
||||
size: 'sm',
|
||||
id: 'duration_seconds',
|
||||
disableSortBy: true,
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { payload, properties, status },
|
||||
},
|
||||
}: TaskCellProps) => {
|
||||
const hasPayload = payload && Object.keys(payload).length > 0;
|
||||
const hasStackTrace = !!properties?.stack_trace;
|
||||
|
||||
// Show warning if timeout is set but no abort handler during execution
|
||||
// Only show for IN_PROGRESS (abort handler registers at runtime, not during PENDING)
|
||||
const hasTimeoutWithoutHandler =
|
||||
status === TaskStatus.InProgress &&
|
||||
properties?.timeout &&
|
||||
!properties?.is_abortable;
|
||||
|
||||
if (!hasPayload && !hasStackTrace && !hasTimeoutWithoutHandler) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', gap: theme.sizeUnit * 2 }}>
|
||||
{hasTimeoutWithoutHandler && (
|
||||
<Tooltip
|
||||
title={t(
|
||||
'Timeout configured (%s seconds) but no abort handler defined. ' +
|
||||
'Task will continue running past the timeout.',
|
||||
properties.timeout,
|
||||
)}
|
||||
placement="top"
|
||||
>
|
||||
<span>
|
||||
<Icons.WarningOutlined
|
||||
iconSize="l"
|
||||
iconColor={theme.colorWarningText}
|
||||
/>
|
||||
</span>
|
||||
</Tooltip>
|
||||
)}
|
||||
{hasPayload && <TaskPayloadPopover payload={payload} />}
|
||||
{hasStackTrace && properties.stack_trace && (
|
||||
<TaskStackTracePopover stackTrace={properties.stack_trace} />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
accessor: 'payload',
|
||||
Header: t('Details'),
|
||||
size: 'xs',
|
||||
id: 'payload',
|
||||
disableSortBy: true,
|
||||
},
|
||||
{
|
||||
Cell: ({ row: { original } }: TaskCellProps) => {
|
||||
// Unified Cancel button logic:
|
||||
// - Show Cancel for any active task that the user can cancel
|
||||
// - The backend handles the smart behavior (unsubscribe vs abort)
|
||||
const isRunning = original.status === TaskStatus.InProgress;
|
||||
// Task is not cancellable if running without abort handler
|
||||
// Use !== true to catch false, undefined, and null
|
||||
const isRunningButNotCancellable =
|
||||
isRunning && !original.properties?.is_abortable;
|
||||
|
||||
const isSharedTask = original.scope === TaskScope.Shared;
|
||||
const userIsSubscribed = original.subscribers?.some(
|
||||
(sub: any) => sub.user_id === user.userId,
|
||||
);
|
||||
|
||||
// Check if task is in a non-active state (completed or aborting)
|
||||
const isNonActiveStatus = [
|
||||
TaskStatus.Success,
|
||||
TaskStatus.Failure,
|
||||
TaskStatus.Aborted,
|
||||
TaskStatus.Aborting,
|
||||
TaskStatus.TimedOut,
|
||||
].includes(original.status as TaskStatus);
|
||||
|
||||
// Show disabled button for running tasks without abort handler
|
||||
// (only for non-shared tasks or when user is the only subscriber)
|
||||
const showDisabledCancel =
|
||||
isRunningButNotCancellable &&
|
||||
!isNonActiveStatus &&
|
||||
(!isSharedTask || (original.subscriber_count || 0) <= 1);
|
||||
|
||||
// Show Cancel button when:
|
||||
// 1. Task can be aborted (pending, or in-progress with handler), OR
|
||||
// 2. User is subscribed to a shared task (can always unsubscribe)
|
||||
// But NOT when disabled cancel is shown (mutually exclusive)
|
||||
const canCancelTask =
|
||||
!showDisabledCancel &&
|
||||
((canAbortTask(original) && !isTaskAborting(original)) ||
|
||||
(isSharedTask && userIsSubscribed && !isNonActiveStatus));
|
||||
|
||||
if (!canCancelTask && !showDisabledCancel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', gap: theme.sizeUnit * 2 }}>
|
||||
{showDisabledCancel && (
|
||||
<Tooltip
|
||||
id="cancel-disabled-tooltip"
|
||||
title={t(
|
||||
'Cancellation not available due to missing abort handler',
|
||||
)}
|
||||
placement="bottom"
|
||||
>
|
||||
<span
|
||||
className="action-button"
|
||||
style={{ cursor: 'not-allowed' }}
|
||||
>
|
||||
<Icons.StopOutlined
|
||||
iconSize="l"
|
||||
iconColor={theme.colorTextDisabled}
|
||||
/>
|
||||
</span>
|
||||
</Tooltip>
|
||||
)}
|
||||
{canCancelTask && (
|
||||
<Tooltip
|
||||
id="cancel-action-tooltip"
|
||||
title={t('Cancel')}
|
||||
placement="bottom"
|
||||
>
|
||||
<span
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
className="action-button"
|
||||
onClick={() => openCancelModal(original)}
|
||||
>
|
||||
<Icons.StopOutlined iconSize="l" />
|
||||
</span>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
Header: t('Actions'),
|
||||
id: 'actions',
|
||||
size: 'sm',
|
||||
disableSortBy: true,
|
||||
},
|
||||
],
|
||||
[user.userId, theme, openCancelModal],
|
||||
);
|
||||
|
||||
const filters: ListViewFilters = useMemo(
|
||||
() => [
|
||||
{
|
||||
Header: t('Status'),
|
||||
key: 'status',
|
||||
id: 'status',
|
||||
input: 'select',
|
||||
operator: FilterOperator.Equals,
|
||||
unfilteredLabel: t('Any'),
|
||||
selects: [
|
||||
{ label: t('Pending'), value: TaskStatus.Pending },
|
||||
{ label: t('In Progress'), value: TaskStatus.InProgress },
|
||||
{ label: t('Success'), value: TaskStatus.Success },
|
||||
{ label: t('Failed'), value: TaskStatus.Failure },
|
||||
{ label: t('Timed Out'), value: TaskStatus.TimedOut },
|
||||
{ label: t('Aborting'), value: TaskStatus.Aborting },
|
||||
{ label: t('Aborted'), value: TaskStatus.Aborted },
|
||||
],
|
||||
},
|
||||
{
|
||||
Header: t('Type'),
|
||||
key: 'task_type',
|
||||
id: 'task_type',
|
||||
input: 'search',
|
||||
operator: FilterOperator.Contains,
|
||||
},
|
||||
{
|
||||
Header: t('Scope'),
|
||||
key: 'scope',
|
||||
id: 'scope',
|
||||
input: 'select',
|
||||
operator: FilterOperator.Equals,
|
||||
unfilteredLabel: t('Any'),
|
||||
selects: [
|
||||
{ label: t('Private'), value: TaskScope.Private },
|
||||
{ label: t('Shared'), value: TaskScope.Shared },
|
||||
{ label: t('System'), value: TaskScope.System },
|
||||
],
|
||||
},
|
||||
{
|
||||
Header: t('Created by'),
|
||||
key: 'created_by',
|
||||
id: 'created_by',
|
||||
input: 'select',
|
||||
operator: FilterOperator.RelationOneMany,
|
||||
unfilteredLabel: t('All'),
|
||||
fetchSelects: createFetchRelated(
|
||||
'task',
|
||||
'created_by',
|
||||
createErrorHandler(errMsg =>
|
||||
addDangerToast(
|
||||
t(
|
||||
'An error occurred while fetching created by values: %s',
|
||||
errMsg,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
],
|
||||
[addDangerToast],
|
||||
);
|
||||
|
||||
const initialSort = [{ id: 'created_on', desc: true }];
|
||||
|
||||
const emptyState = {
|
||||
title: t('No tasks yet'),
|
||||
image: 'filter-results.svg',
|
||||
description: t(
|
||||
'Tasks will appear here as background operations are executed.',
|
||||
),
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<SubMenu name={t('Tasks')} />
|
||||
<ListView<Task>
|
||||
className="task-list-view"
|
||||
columns={columns}
|
||||
count={tasksCount}
|
||||
data={tasks}
|
||||
emptyState={emptyState}
|
||||
fetchData={fetchData}
|
||||
filters={filters}
|
||||
initialSort={initialSort}
|
||||
loading={loading}
|
||||
pageSize={PAGE_SIZE}
|
||||
refreshData={refreshData}
|
||||
addDangerToast={addDangerToast}
|
||||
addSuccessToast={addSuccessToast}
|
||||
/>
|
||||
|
||||
{/* Cancel Confirmation Modal */}
|
||||
<Modal
|
||||
title={t('Cancel Task')}
|
||||
show={!!cancelModalTask}
|
||||
onHide={handleCancelModalClose}
|
||||
primaryButtonName={t('Yes, Cancel')}
|
||||
onHandledPrimaryAction={handleCancelConfirm}
|
||||
>
|
||||
{cancelModalTask && (
|
||||
<>
|
||||
<p>
|
||||
{forceCancel
|
||||
? getForceAbortMessage(cancelModalTask)
|
||||
: getCancelDialogMessage(cancelModalTask)}
|
||||
</p>
|
||||
{showForceCancelOption(cancelModalTask) && (
|
||||
<Checkbox
|
||||
checked={forceCancel}
|
||||
onChange={e => setForceCancel(e.target.checked)}
|
||||
disabled={isForceCancelDisabled(cancelModalTask)}
|
||||
>
|
||||
{t('Force abort (stops task for all subscribers)')}
|
||||
</Checkbox>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default withToasts(TaskList);
|
||||
@@ -16,6 +16,7 @@
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
|
||||
import {
|
||||
lazy,
|
||||
@@ -138,6 +139,10 @@ const RowLevelSecurityList = lazy(
|
||||
),
|
||||
);
|
||||
|
||||
const TaskList = lazy(
|
||||
() => import(/* webpackChunkName: "TaskList" */ 'src/pages/TaskList'),
|
||||
);
|
||||
|
||||
const RolesList = lazy(
|
||||
() => import(/* webpackChunkName: "RolesList" */ 'src/pages/RolesList'),
|
||||
);
|
||||
@@ -297,6 +302,10 @@ export const routes: Routes = [
|
||||
path: '/rowlevelsecurity/list',
|
||||
Component: RowLevelSecurityList,
|
||||
},
|
||||
{
|
||||
path: '/tasks/list/',
|
||||
Component: TaskList,
|
||||
},
|
||||
{
|
||||
path: '/sqllab/',
|
||||
Component: SqlLab,
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from flask_caching.backends.rediscache import RedisCache, RedisSentinelCache
|
||||
@@ -28,15 +30,15 @@ class RedisCacheBackend(RedisCache):
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
password: Optional[str] = None,
|
||||
password: str | None = None,
|
||||
db: int = 0,
|
||||
default_timeout: int = 300,
|
||||
key_prefix: Optional[str] = None,
|
||||
key_prefix: str | None = None,
|
||||
ssl: bool = False,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: str | None = None,
|
||||
ssl_keyfile: str | None = None,
|
||||
ssl_cert_reqs: str = "required",
|
||||
ssl_ca_certs: Optional[str] = None,
|
||||
ssl_ca_certs: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@@ -61,12 +63,61 @@ class RedisCacheBackend(RedisCache):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
) -> bool | None:
|
||||
"""
|
||||
Set the value at key ``name``.
|
||||
|
||||
:param name: Key name
|
||||
:param value: Value to set
|
||||
:param ex: Expire time in seconds
|
||||
:param px: Expire time in milliseconds
|
||||
:param nx: If True, set only if key does not exist
|
||||
:param xx: If True, set only if key already exists
|
||||
:returns: True if set successfully, None if nx/xx condition not met
|
||||
"""
|
||||
return self._cache.set(name, value, ex=ex, px=px, nx=nx, xx=xx)
|
||||
|
||||
def delete(self, *names: str) -> int:
|
||||
"""
|
||||
Delete one or more keys.
|
||||
|
||||
:param names: Key names to delete
|
||||
:returns: Number of keys deleted
|
||||
"""
|
||||
return self._cache.delete(*names)
|
||||
|
||||
def publish(self, channel: str, message: str) -> int:
|
||||
"""
|
||||
Publish a message to a Redis pub/sub channel.
|
||||
|
||||
:param channel: The channel name to publish to
|
||||
:param message: The message to publish
|
||||
:returns: Number of subscribers that received the message
|
||||
"""
|
||||
return self._cache.publish(channel, message)
|
||||
|
||||
def pubsub(self) -> redis.client.PubSub:
|
||||
"""
|
||||
Create a pub/sub subscription object.
|
||||
|
||||
:returns: PubSub object for subscribing to channels
|
||||
"""
|
||||
return self._cache.pubsub()
|
||||
|
||||
def xadd(
|
||||
self,
|
||||
stream_name: str,
|
||||
event_data: Dict[str, Any],
|
||||
event_data: dict[str, Any],
|
||||
event_id: str = "*",
|
||||
maxlen: Optional[int] = None,
|
||||
maxlen: int | None = None,
|
||||
) -> str:
|
||||
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
|
||||
|
||||
@@ -75,13 +126,13 @@ class RedisCacheBackend(RedisCache):
|
||||
stream_name: str,
|
||||
start: str = "-",
|
||||
end: str = "+",
|
||||
count: Optional[int] = None,
|
||||
) -> List[Any]:
|
||||
count: int | None = None,
|
||||
) -> list[Any]:
|
||||
count = count or self.MAX_EVENT_COUNT
|
||||
return self._cache.xrange(stream_name, start, end, count)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "RedisCacheBackend":
|
||||
def from_config(cls, config: dict[str, Any]) -> RedisCacheBackend:
|
||||
kwargs = {
|
||||
"host": config.get("CACHE_REDIS_HOST", "localhost"),
|
||||
"port": config.get("CACHE_REDIS_PORT", 6379),
|
||||
@@ -108,18 +159,18 @@ class RedisSentinelCacheBackend(RedisSentinelCache):
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
sentinels: List[Tuple[str, int]],
|
||||
sentinels: list[tuple[str, int]],
|
||||
master: str,
|
||||
password: Optional[str] = None,
|
||||
sentinel_password: Optional[str] = None,
|
||||
password: str | None = None,
|
||||
sentinel_password: str | None = None,
|
||||
db: int = 0,
|
||||
default_timeout: int = 300,
|
||||
key_prefix: str = "",
|
||||
ssl: bool = False,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: str | None = None,
|
||||
ssl_keyfile: str | None = None,
|
||||
ssl_cert_reqs: str = "required",
|
||||
ssl_ca_certs: Optional[str] = None,
|
||||
ssl_ca_certs: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Sentinel dont directly support SSL
|
||||
@@ -177,12 +228,61 @@ class RedisSentinelCacheBackend(RedisSentinelCache):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
) -> bool | None:
|
||||
"""
|
||||
Set the value at key ``name``.
|
||||
|
||||
:param name: Key name
|
||||
:param value: Value to set
|
||||
:param ex: Expire time in seconds
|
||||
:param px: Expire time in milliseconds
|
||||
:param nx: If True, set only if key does not exist
|
||||
:param xx: If True, set only if key already exists
|
||||
:returns: True if set successfully, None if nx/xx condition not met
|
||||
"""
|
||||
return self._cache.set(name, value, ex=ex, px=px, nx=nx, xx=xx)
|
||||
|
||||
def delete(self, *names: str) -> int:
|
||||
"""
|
||||
Delete one or more keys.
|
||||
|
||||
:param names: Key names to delete
|
||||
:returns: Number of keys deleted
|
||||
"""
|
||||
return self._cache.delete(*names)
|
||||
|
||||
def publish(self, channel: str, message: str) -> int:
|
||||
"""
|
||||
Publish a message to a Redis pub/sub channel.
|
||||
|
||||
:param channel: The channel name to publish to
|
||||
:param message: The message to publish
|
||||
:returns: Number of subscribers that received the message
|
||||
"""
|
||||
return self._cache.publish(channel, message)
|
||||
|
||||
def pubsub(self) -> redis.client.PubSub:
|
||||
"""
|
||||
Create a pub/sub subscription object.
|
||||
|
||||
:returns: PubSub object for subscribing to channels
|
||||
"""
|
||||
return self._cache.pubsub()
|
||||
|
||||
def xadd(
|
||||
self,
|
||||
stream_name: str,
|
||||
event_data: Dict[str, Any],
|
||||
event_data: dict[str, Any],
|
||||
event_id: str = "*",
|
||||
maxlen: Optional[int] = None,
|
||||
maxlen: int | None = None,
|
||||
) -> str:
|
||||
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
|
||||
|
||||
@@ -191,13 +291,13 @@ class RedisSentinelCacheBackend(RedisSentinelCache):
|
||||
stream_name: str,
|
||||
start: str = "-",
|
||||
end: str = "+",
|
||||
count: Optional[int] = None,
|
||||
) -> List[Any]:
|
||||
count: int | None = None,
|
||||
) -> list[Any]:
|
||||
count = count or self.MAX_EVENT_COUNT
|
||||
return self._cache.xrange(stream_name, start, end, count)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "RedisSentinelCacheBackend":
|
||||
def from_config(cls, config: dict[str, Any]) -> RedisSentinelCacheBackend:
|
||||
kwargs = {
|
||||
"sentinels": config.get("CACHE_REDIS_SENTINELS", [("127.0.0.1", 26379)]),
|
||||
"master": config.get("CACHE_REDIS_SENTINEL_MASTER", "mymaster"),
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from superset.commands.distributed_lock.acquire import AcquireDistributedLock
|
||||
from superset.commands.distributed_lock.release import ReleaseDistributedLock
|
||||
|
||||
__all__ = [
|
||||
"AcquireDistributedLock",
|
||||
"ReleaseDistributedLock",
|
||||
]
|
||||
|
||||
132
superset/commands/distributed_lock/acquire.py
Normal file
132
superset/commands/distributed_lock/acquire.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.commands.distributed_lock.base import (
|
||||
BaseDistributedLockCommand,
|
||||
get_default_lock_ttl,
|
||||
get_redis_client,
|
||||
)
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.exceptions import AcquireDistributedLockFailedException
|
||||
from superset.key_value.exceptions import (
|
||||
KeyValueCodecEncodeException,
|
||||
KeyValueUpsertFailedError,
|
||||
)
|
||||
from superset.key_value.types import KeyValueResource
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AcquireDistributedLock(BaseDistributedLockCommand):
|
||||
"""
|
||||
Acquire a distributed lock with automatic backend selection.
|
||||
|
||||
Uses Redis SET NX EX when SIGNAL_CACHE_CONFIG is configured,
|
||||
otherwise falls back to KeyValue table.
|
||||
|
||||
Raises AcquireDistributedLockFailedException if:
|
||||
- Lock is already held by another process
|
||||
- Redis connection fails
|
||||
"""
|
||||
|
||||
ttl_seconds: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
namespace: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
super().__init__(namespace, params)
|
||||
self.ttl_seconds = ttl_seconds or get_default_lock_ttl()
|
||||
|
||||
def run(self) -> None:
|
||||
if (redis_client := get_redis_client()) is not None:
|
||||
self._acquire_redis(redis_client)
|
||||
else:
|
||||
self._acquire_kv()
|
||||
|
||||
def _acquire_redis(self, redis_client: Any) -> None:
|
||||
"""Acquire lock using Redis SET NX EX (atomic)."""
|
||||
try:
|
||||
# SET NX EX: Set if not exists, with expiration
|
||||
# Returns True if lock acquired, None if already exists
|
||||
acquired = redis_client.set(
|
||||
self.redis_lock_key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=self.ttl_seconds,
|
||||
)
|
||||
|
||||
if not acquired:
|
||||
logger.debug("Redis lock on %s already taken", self.redis_lock_key)
|
||||
raise AcquireDistributedLockFailedException("Lock already taken")
|
||||
|
||||
logger.debug(
|
||||
"Acquired Redis lock: %s (TTL=%ds)",
|
||||
self.redis_lock_key,
|
||||
self.ttl_seconds,
|
||||
)
|
||||
|
||||
except redis.RedisError as ex:
|
||||
logger.error("Redis lock error for %s: %s", self.redis_lock_key, ex)
|
||||
raise AcquireDistributedLockFailedException(
|
||||
f"Redis lock failed: {ex}"
|
||||
) from ex
|
||||
|
||||
@transaction(
|
||||
on_error=partial(
|
||||
on_error,
|
||||
catches=(
|
||||
KeyValueCodecEncodeException,
|
||||
KeyValueUpsertFailedError,
|
||||
SQLAlchemyError,
|
||||
),
|
||||
reraise=AcquireDistributedLockFailedException,
|
||||
),
|
||||
)
|
||||
def _acquire_kv(self) -> None:
|
||||
"""Acquire lock using KeyValue table (database)."""
|
||||
# Delete expired entries first to prevent stale locks from blocking
|
||||
KeyValueDAO.delete_expired_entries(self.resource)
|
||||
|
||||
# Create entry - unique constraint will raise if lock already exists
|
||||
KeyValueDAO.create_entry(
|
||||
resource=KeyValueResource.LOCK,
|
||||
value={"value": True},
|
||||
codec=self.codec,
|
||||
key=self.key,
|
||||
expires_on=datetime.now(timezone.utc) + timedelta(seconds=self.ttl_seconds),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Acquired KV lock: namespace=%s key=%s (TTL=%ds)",
|
||||
self.namespace,
|
||||
self.key,
|
||||
self.ttl_seconds,
|
||||
)
|
||||
@@ -15,27 +15,58 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Union
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from flask import current_app as app
|
||||
from flask import current_app
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.distributed_lock.utils import get_key
|
||||
from superset.extensions import cache_manager
|
||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = app.config["STATS_LOGGER"]
|
||||
|
||||
|
||||
def get_default_lock_ttl() -> int:
|
||||
"""Get the default lock TTL from config."""
|
||||
return int(current_app.config.get("DISTRIBUTED_LOCK_DEFAULT_TTL", 30))
|
||||
|
||||
|
||||
def get_redis_client() -> "redis.Redis[Any] | None":
|
||||
"""
|
||||
Get Redis client from signal cache if available.
|
||||
|
||||
Returns None if SIGNAL_CACHE_CONFIG is not configured,
|
||||
allowing fallback to database-backed locking.
|
||||
"""
|
||||
backend = cache_manager.signal_cache
|
||||
return backend._cache if backend else None
|
||||
|
||||
|
||||
class BaseDistributedLockCommand(BaseCommand):
|
||||
"""Base command for distributed lock operations."""
|
||||
|
||||
key: uuid.UUID
|
||||
namespace: str
|
||||
codec = JsonKeyValueCodec()
|
||||
resource = KeyValueResource.LOCK
|
||||
|
||||
def __init__(self, namespace: str, params: Union[dict[str, Any], None] = None):
|
||||
self.key = get_key(namespace, **(params or {}))
|
||||
def __init__(self, namespace: str, params: dict[str, Any] | None = None) -> None:
|
||||
self.namespace = namespace
|
||||
self.params = params or {}
|
||||
self.key = get_key(namespace, **self.params)
|
||||
|
||||
@property
|
||||
def redis_lock_key(self) -> str:
|
||||
"""Redis key for this lock."""
|
||||
return f"lock:{self.namespace}:{self.key}"
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
|
||||
from flask import current_app as app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.key_value.exceptions import (
|
||||
KeyValueCodecEncodeException,
|
||||
KeyValueUpsertFailedError,
|
||||
)
|
||||
from superset.key_value.types import KeyValueResource
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = app.config["STATS_LOGGER"]
|
||||
|
||||
|
||||
class CreateDistributedLock(BaseDistributedLockCommand):
|
||||
lock_expiration = timedelta(seconds=30)
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
@transaction(
|
||||
on_error=partial(
|
||||
on_error,
|
||||
catches=(
|
||||
KeyValueCodecEncodeException,
|
||||
KeyValueUpsertFailedError,
|
||||
SQLAlchemyError,
|
||||
),
|
||||
reraise=CreateKeyValueDistributedLockFailedException,
|
||||
),
|
||||
)
|
||||
def run(self) -> None:
|
||||
KeyValueDAO.delete_expired_entries(self.resource)
|
||||
KeyValueDAO.create_entry(
|
||||
resource=KeyValueResource.LOCK,
|
||||
value={"value": True},
|
||||
codec=self.codec,
|
||||
key=self.key,
|
||||
expires_on=datetime.now() + self.lock_expiration,
|
||||
)
|
||||
@@ -1,49 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
from flask import current_app as app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.exceptions import DeleteKeyValueDistributedLockFailedException
|
||||
from superset.key_value.exceptions import KeyValueDeleteFailedError
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = app.config["STATS_LOGGER"]
|
||||
|
||||
|
||||
class DeleteDistributedLock(BaseDistributedLockCommand):
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
@transaction(
|
||||
on_error=partial(
|
||||
on_error,
|
||||
catches=(
|
||||
KeyValueDeleteFailedError,
|
||||
SQLAlchemyError,
|
||||
),
|
||||
reraise=DeleteKeyValueDistributedLockFailedException,
|
||||
),
|
||||
)
|
||||
def run(self) -> None:
|
||||
KeyValueDAO.delete_entry(self.resource, self.key)
|
||||
83
superset/commands/distributed_lock/release.py
Normal file
83
superset/commands/distributed_lock/release.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.commands.distributed_lock.base import (
|
||||
BaseDistributedLockCommand,
|
||||
get_redis_client,
|
||||
)
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.exceptions import ReleaseDistributedLockFailedException
|
||||
from superset.key_value.exceptions import KeyValueDeleteFailedError
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReleaseDistributedLock(BaseDistributedLockCommand):
|
||||
"""
|
||||
Release a distributed lock with automatic backend selection.
|
||||
|
||||
Uses Redis DELETE when SIGNAL_CACHE_CONFIG is configured,
|
||||
otherwise deletes from KeyValue table.
|
||||
"""
|
||||
|
||||
def run(self) -> None:
|
||||
if (redis_client := get_redis_client()) is not None:
|
||||
self._release_redis(redis_client)
|
||||
else:
|
||||
self._release_kv()
|
||||
|
||||
def _release_redis(self, redis_client: Any) -> None:
|
||||
"""Release lock using Redis DELETE."""
|
||||
try:
|
||||
redis_client.delete(self.redis_lock_key)
|
||||
logger.debug("Released Redis lock: %s", self.redis_lock_key)
|
||||
except redis.RedisError as ex:
|
||||
# Log warning but don't raise - TTL will handle cleanup
|
||||
logger.warning(
|
||||
"Failed to release Redis lock %s: %s (TTL will handle cleanup)",
|
||||
self.redis_lock_key,
|
||||
ex,
|
||||
)
|
||||
|
||||
@transaction(
|
||||
on_error=partial(
|
||||
on_error,
|
||||
catches=(
|
||||
KeyValueDeleteFailedError,
|
||||
SQLAlchemyError,
|
||||
),
|
||||
reraise=ReleaseDistributedLockFailedException,
|
||||
),
|
||||
)
|
||||
def _release_kv(self) -> None:
|
||||
"""Release lock using KeyValue table (database)."""
|
||||
KeyValueDAO.delete_entry(self.resource, self.key)
|
||||
logger.debug(
|
||||
"Released KV lock: namespace=%s key=%s",
|
||||
self.namespace,
|
||||
self.key,
|
||||
)
|
||||
28
superset/commands/tasks/__init__.py
Normal file
28
superset/commands/tasks/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from superset.commands.tasks.cancel import CancelTaskCommand
|
||||
from superset.commands.tasks.prune import TaskPruneCommand
|
||||
from superset.commands.tasks.submit import SubmitTaskCommand
|
||||
from superset.commands.tasks.update import UpdateTaskCommand
|
||||
|
||||
__all__ = [
|
||||
"CancelTaskCommand",
|
||||
"SubmitTaskCommand",
|
||||
"TaskPruneCommand",
|
||||
"UpdateTaskCommand",
|
||||
]
|
||||
314
superset/commands/tasks/cancel.py
Normal file
314
superset/commands/tasks/cancel.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unified cancel task command for GTF."""
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from flask import current_app
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskAbortFailedError,
|
||||
TaskNotAbortableError,
|
||||
TaskNotFoundError,
|
||||
TaskPermissionDeniedError,
|
||||
)
|
||||
from superset.extensions import security_manager
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
from superset.tasks.locks import task_lock
|
||||
from superset.tasks.utils import get_active_dedup_key
|
||||
from superset.utils.core import get_user_id
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.tasks import Task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CancelTaskCommand(BaseCommand):
|
||||
"""
|
||||
Unified command to cancel a task.
|
||||
|
||||
Behavior:
|
||||
- For private tasks or single-subscriber tasks: aborts the task
|
||||
- For shared tasks with multiple subscribers (non-admin): unsubscribes user
|
||||
- For shared tasks with force=True (admin only): aborts for all subscribers
|
||||
|
||||
The term "cancel" is user-facing; internally this may abort or unsubscribe.
|
||||
|
||||
This command acquires a distributed lock before starting a transaction to
|
||||
prevent race conditions with concurrent submit/cancel operations.
|
||||
|
||||
Permission checks are deferred to inside the lock to minimize SELECTs:
|
||||
we only fetch the task once, then validate permissions on the fetched data.
|
||||
"""
|
||||
|
||||
def __init__(self, task_uuid: UUID, force: bool = False):
|
||||
"""
|
||||
Initialize the cancel command.
|
||||
|
||||
:param task_uuid: UUID of the task to cancel
|
||||
:param force: If True, force abort even with multiple subscribers (admin only)
|
||||
"""
|
||||
self._task_uuid = task_uuid
|
||||
self._force = force
|
||||
self._action_taken: str = (
|
||||
"cancelled" # Will be set to 'aborted' or 'unsubscribed'
|
||||
)
|
||||
self._should_publish_abort: bool = False
|
||||
|
||||
def run(self) -> "Task":
|
||||
"""
|
||||
Execute the cancel command with distributed locking.
|
||||
|
||||
The lock is acquired BEFORE starting the transaction to avoid holding
|
||||
a DB connection during lock acquisition. Uses dedup_key as lock key
|
||||
to ensure Submit and Cancel operations use the same lock.
|
||||
|
||||
:returns: The updated task model
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
# Lightweight fetch to compute dedup_key for locking
|
||||
# This is needed to use the same lock key as SubmitTaskCommand
|
||||
task = TaskDAO.find_one_or_none(
|
||||
skip_base_filter=security_manager.is_admin(), uuid=self._task_uuid
|
||||
)
|
||||
|
||||
if not task:
|
||||
raise TaskNotFoundError()
|
||||
|
||||
# Compute dedup_key using the same logic as SubmitTaskCommand
|
||||
dedup_key = get_active_dedup_key(
|
||||
scope=task.scope,
|
||||
task_type=task.task_type,
|
||||
task_key=task.task_key,
|
||||
user_id=task.user_id,
|
||||
)
|
||||
|
||||
# Acquire lock BEFORE transaction starts
|
||||
# Using dedup_key ensures Submit and Cancel use the same lock
|
||||
with task_lock(dedup_key):
|
||||
result = self._execute_with_transaction()
|
||||
|
||||
# Publish abort notification AFTER transaction commits
|
||||
# This prevents race conditions where listeners check DB before commit
|
||||
if self._should_publish_abort:
|
||||
from superset.tasks.manager import TaskManager
|
||||
|
||||
TaskManager.publish_abort(self._task_uuid)
|
||||
|
||||
return result
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=TaskAbortFailedError))
|
||||
def _execute_with_transaction(self) -> "Task":
|
||||
"""
|
||||
Execute the cancel operation inside a transaction.
|
||||
|
||||
Combines fetch + validation + execution in a single transaction,
|
||||
reducing the number of SELECTs from 3 to 1 (plus DAO operations).
|
||||
|
||||
:returns: The updated task model
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
# Check admin status (no DB access)
|
||||
is_admin = security_manager.is_admin()
|
||||
|
||||
# Force flag requires admin
|
||||
if self._force and not is_admin:
|
||||
raise TaskPermissionDeniedError(
|
||||
"Only administrators can force cancel a task"
|
||||
)
|
||||
|
||||
# Single SELECT: fetch task and validate permissions on it
|
||||
task = TaskDAO.find_one_or_none(skip_base_filter=is_admin, uuid=self._task_uuid)
|
||||
|
||||
if not task:
|
||||
raise TaskNotFoundError()
|
||||
|
||||
# Validate permissions on the fetched task
|
||||
self._validate_permissions(task, is_admin)
|
||||
|
||||
# Execute cancel and return updated task
|
||||
return self._do_cancel(task, is_admin)
|
||||
|
||||
def _validate_permissions(self, task: "Task", is_admin: bool) -> None:
|
||||
"""
|
||||
Validate permissions on an already-fetched task.
|
||||
|
||||
Permission rules by scope:
|
||||
- private: Only creator or admin (already filtered by base_filter)
|
||||
- shared: Subscribers or admin
|
||||
- system: Only admin
|
||||
|
||||
:param task: The task to validate permissions for
|
||||
:param is_admin: Whether current user is admin
|
||||
:raises TaskAbortFailedError: If task is not in cancellable state
|
||||
:raises TaskPermissionDeniedError: If user lacks permission
|
||||
"""
|
||||
# Check if task is in a cancellable state
|
||||
if task.status not in [
|
||||
TaskStatus.PENDING.value,
|
||||
TaskStatus.IN_PROGRESS.value,
|
||||
TaskStatus.ABORTING.value, # Already aborting is OK (idempotent)
|
||||
]:
|
||||
raise TaskAbortFailedError()
|
||||
|
||||
# Admin can cancel anything
|
||||
if is_admin:
|
||||
return
|
||||
|
||||
# Non-admin permission checks by scope
|
||||
user_id = get_user_id()
|
||||
|
||||
if task.scope == TaskScope.SYSTEM.value:
|
||||
# System tasks are admin-only
|
||||
raise TaskPermissionDeniedError(
|
||||
"Only administrators can cancel system tasks"
|
||||
)
|
||||
|
||||
if task.is_shared:
|
||||
# Shared tasks: must be a subscriber
|
||||
if not user_id or not task.has_subscriber(user_id):
|
||||
raise TaskPermissionDeniedError(
|
||||
"You must be subscribed to cancel this shared task"
|
||||
)
|
||||
|
||||
# Private tasks: already filtered by base_filter (only creator can see)
|
||||
# If we got here, user has permission
|
||||
|
||||
def _do_cancel(self, task: "Task", is_admin: bool) -> "Task":
|
||||
"""
|
||||
Execute the cancel operation (abort or unsubscribe).
|
||||
|
||||
:param task: The task to cancel
|
||||
:param is_admin: Whether current user is admin
|
||||
:returns: The updated task model
|
||||
"""
|
||||
user_id = get_user_id()
|
||||
|
||||
# Determine action based on task scope and force flag
|
||||
should_abort = (
|
||||
# Admin with force flag always aborts
|
||||
(is_admin and self._force)
|
||||
# Private tasks always abort (only one user)
|
||||
or task.is_private
|
||||
# System tasks always abort (admin only anyway)
|
||||
or task.is_system
|
||||
# Single or last subscriber - abort
|
||||
or task.subscriber_count <= 1
|
||||
)
|
||||
|
||||
if should_abort:
|
||||
return self._do_abort(task, is_admin)
|
||||
else:
|
||||
return self._do_unsubscribe(task, user_id)
|
||||
|
||||
def _do_abort(self, task: "Task", is_admin: bool) -> "Task":
|
||||
"""
|
||||
Execute abort operation.
|
||||
|
||||
:param task: The task to abort
|
||||
:param is_admin: Whether current user is admin
|
||||
:returns: The updated task model
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
try:
|
||||
result: Task | None = TaskDAO.abort_task(
|
||||
task.uuid, skip_base_filter=is_admin
|
||||
)
|
||||
except TaskNotAbortableError:
|
||||
raise
|
||||
|
||||
if result is None:
|
||||
# abort_task returned None - task wasn't aborted
|
||||
# This can happen if task is already finished
|
||||
raise TaskAbortFailedError()
|
||||
|
||||
self._action_taken = "aborted"
|
||||
|
||||
# Track if we need to publish abort after commit
|
||||
if TaskStatus(result.status) == TaskStatus.ABORTING:
|
||||
self._should_publish_abort = True
|
||||
|
||||
# Emit stats metric
|
||||
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
|
||||
stats_logger.incr("gtf.task.abort")
|
||||
|
||||
logger.info(
|
||||
"Task aborted: %s (scope: %s, force: %s)",
|
||||
task.uuid,
|
||||
task.scope,
|
||||
self._force,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _do_unsubscribe(self, task: "Task", user_id: int | None) -> "Task":
|
||||
"""
|
||||
Execute unsubscribe operation.
|
||||
|
||||
:param task: The task to unsubscribe from
|
||||
:param user_id: ID of user to unsubscribe
|
||||
:returns: The updated task model
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
self._action_taken = "unsubscribed"
|
||||
|
||||
if not user_id or not task.has_subscriber(user_id):
|
||||
# User not subscribed - they shouldn't be able to cancel
|
||||
raise TaskPermissionDeniedError(
|
||||
"You are not subscribed to this shared task"
|
||||
)
|
||||
|
||||
result = TaskDAO.remove_subscriber(task.id, user_id)
|
||||
if result is None:
|
||||
raise TaskPermissionDeniedError(
|
||||
"You are not subscribed to this shared task"
|
||||
)
|
||||
|
||||
# Emit stats metric
|
||||
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
|
||||
stats_logger.incr("gtf.task.unsubscribe")
|
||||
|
||||
logger.info(
|
||||
"User %s unsubscribed from shared task: %s",
|
||||
user_id,
|
||||
task.uuid,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def action_taken(self) -> str:
|
||||
"""
|
||||
Get the action that was taken.
|
||||
|
||||
:returns: 'aborted' or 'unsubscribed'
|
||||
"""
|
||||
return self._action_taken
|
||||
106
superset/commands/tasks/exceptions.py
Normal file
106
superset/commands/tasks/exceptions.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
||||
from superset.commands.exceptions import (
|
||||
CommandException,
|
||||
CommandInvalidError,
|
||||
CreateFailedError,
|
||||
ForbiddenError,
|
||||
UpdateFailedError,
|
||||
)
|
||||
|
||||
|
||||
class TaskNotFoundError(CommandException):
|
||||
"""Task not found."""
|
||||
|
||||
status = 404
|
||||
message = "Task not found."
|
||||
|
||||
|
||||
class TaskInvalidError(CommandInvalidError):
|
||||
"""Task parameters are invalid."""
|
||||
|
||||
message = _("Task parameters are invalid.")
|
||||
|
||||
|
||||
class TaskCreateFailedError(CreateFailedError):
|
||||
"""Task creation failed."""
|
||||
|
||||
message = _("Task could not be created.")
|
||||
|
||||
|
||||
class TaskUpdateFailedError(UpdateFailedError):
|
||||
"""Task update failed."""
|
||||
|
||||
message = _("Task could not be updated.")
|
||||
|
||||
|
||||
class TaskAbortFailedError(CommandException):
|
||||
"""Task abortion failed."""
|
||||
|
||||
status = 422
|
||||
message = _("Task could not be aborted.")
|
||||
|
||||
|
||||
class TaskNotAbortableError(CommandException):
|
||||
"""
|
||||
Task cannot be aborted.
|
||||
|
||||
Raised when attempting to abort an in-progress task that has not
|
||||
registered an abort handler (is_abortable is not True).
|
||||
"""
|
||||
|
||||
status = 400
|
||||
message = _(
|
||||
"Task is not abortable. The task is in progress but has not "
|
||||
"registered an abort handler."
|
||||
)
|
||||
|
||||
|
||||
class TaskForbiddenError(ForbiddenError):
|
||||
"""Task operation forbidden."""
|
||||
|
||||
message = _("Changing this task is forbidden")
|
||||
|
||||
|
||||
class TaskPermissionDeniedError(ForbiddenError):
|
||||
"""Task operation not permitted for current user."""
|
||||
|
||||
def __init__(self, message: str | None = None):
|
||||
super().__init__()
|
||||
if message:
|
||||
self.message = message
|
||||
else:
|
||||
self.message = _("You do not have permission to perform this operation")
|
||||
|
||||
|
||||
class GlobalTaskFrameworkDisabledError(CommandException):
|
||||
"""
|
||||
Raised when a GTF task is called or scheduled but GTF is disabled.
|
||||
|
||||
This exception is raised at call/schedule time (not decoration time) to allow
|
||||
modules with @task-decorated functions to be imported safely when GTF is disabled.
|
||||
The check is deferred until someone actually tries to execute a task.
|
||||
"""
|
||||
|
||||
message = _(
|
||||
"The Global Task Framework is not enabled. "
|
||||
"Set GLOBAL_TASK_FRAMEWORK=True in your feature flags to use @task. "
|
||||
"See https://superset.apache.org/docs/configuration/async-queries-celery "
|
||||
"for configuration details."
|
||||
)
|
||||
184
superset/commands/tasks/internal_update.py
Normal file
184
superset/commands/tasks/internal_update.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Internal task update commands for GTF internal use only.
|
||||
|
||||
These commands perform zero-read updates using targeted SQL UPDATE statements.
|
||||
They're designed for use by TaskContext and executor code where the framework
|
||||
owns the authoritative state and doesn't need to read before writing.
|
||||
|
||||
Unlike UpdateTaskCommand, these commands:
|
||||
- Do NOT fetch the task entity before updating
|
||||
- Do NOT check permissions (internal use only)
|
||||
- Use targeted SQL UPDATE for efficiency
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from superset_core.api.tasks import TaskProperties, TaskStatus
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.tasks.exceptions import TaskUpdateFailedError
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InternalUpdateTaskCommand(BaseCommand):
|
||||
"""
|
||||
Zero-read task update command for properties/payload.
|
||||
|
||||
This command directly writes properties and/or payload to the database
|
||||
without reading the current values first. The caller (TaskContext)
|
||||
maintains the authoritative cached state and passes complete merged
|
||||
values to write.
|
||||
|
||||
This is an optimization for task execution where:
|
||||
1. The executor owns the properties/payload state
|
||||
2. No permission checks are needed (internal framework code)
|
||||
3. Status column should not be touched (use InternalStatusTransitionCommand)
|
||||
|
||||
WARNING: This command should ONLY be used by TaskContext and similar
|
||||
internal framework code. External callers should use UpdateTaskCommand.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_uuid: UUID,
|
||||
properties: TaskProperties | None = None,
|
||||
payload: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize internal update command.
|
||||
|
||||
:param task_uuid: UUID of the task to update
|
||||
:param properties: Complete properties dict to write (replaces existing)
|
||||
:param payload: Complete payload dict to write (replaces existing)
|
||||
"""
|
||||
self._task_uuid = task_uuid
|
||||
self._properties = properties
|
||||
self._payload = payload
|
||||
|
||||
def validate(self) -> None:
|
||||
"""No validation needed for internal command."""
|
||||
pass
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
|
||||
def run(self) -> bool:
|
||||
"""
|
||||
Execute zero-read update.
|
||||
|
||||
:returns: True if task was updated, False if not found or nothing to update
|
||||
"""
|
||||
if self._properties is None and self._payload is None:
|
||||
return False
|
||||
|
||||
updated = TaskDAO.set_properties_and_payload(
|
||||
task_uuid=self._task_uuid,
|
||||
properties=self._properties,
|
||||
payload=self._payload,
|
||||
)
|
||||
|
||||
if updated:
|
||||
logger.debug(
|
||||
"Internal update for task %s: properties=%s, payload=%s",
|
||||
self._task_uuid,
|
||||
self._properties is not None,
|
||||
self._payload is not None,
|
||||
)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
class InternalStatusTransitionCommand(BaseCommand):
|
||||
"""
|
||||
Atomic conditional status transition command for executor use.
|
||||
|
||||
This command provides race-safe status transitions by using atomic
|
||||
compare-and-swap semantics. The status is only updated if the current
|
||||
status matches the expected value(s).
|
||||
|
||||
Use cases:
|
||||
- PENDING → IN_PROGRESS: Task pickup (executor starting)
|
||||
- IN_PROGRESS → SUCCESS: Normal completion (only if not ABORTING)
|
||||
- IN_PROGRESS → FAILURE: Task exception (only if not ABORTING)
|
||||
- ABORTING → ABORTED: Abort handlers completed successfully
|
||||
- ABORTING → TIMED_OUT: Timeout handlers completed successfully
|
||||
- ABORTING → FAILURE: Abort/cleanup handlers failed
|
||||
|
||||
The atomic nature prevents race conditions where:
|
||||
- Executor tries to set SUCCESS but task was concurrently aborted
|
||||
- Multiple executors try to pick up the same task
|
||||
|
||||
WARNING: This command should ONLY be used by executor code (decorators.py,
|
||||
scheduler.py). External callers should use UpdateTaskCommand.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_uuid: UUID,
|
||||
new_status: TaskStatus | str,
|
||||
expected_status: TaskStatus | str | list[TaskStatus | str],
|
||||
properties: TaskProperties | None = None,
|
||||
set_started_at: bool = False,
|
||||
set_ended_at: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize status transition command.
|
||||
|
||||
:param task_uuid: UUID of the task to update
|
||||
:param new_status: Target status to set
|
||||
:param expected_status: Current status(es) required for update to succeed.
|
||||
Can be a single status or list of acceptable current statuses.
|
||||
:param properties: Optional properties to update atomically with status
|
||||
(e.g., error_message on FAILURE)
|
||||
:param set_started_at: If True, also set started_at to current timestamp.
|
||||
Should be True for PENDING → IN_PROGRESS transitions.
|
||||
:param set_ended_at: If True, also set ended_at to current timestamp.
|
||||
Should be True for terminal status transitions.
|
||||
"""
|
||||
self._task_uuid = task_uuid
|
||||
self._new_status = new_status
|
||||
self._expected_status = expected_status
|
||||
self._properties = properties
|
||||
self._set_started_at = set_started_at
|
||||
self._set_ended_at = set_ended_at
|
||||
|
||||
def validate(self) -> None:
|
||||
"""No validation needed for internal command."""
|
||||
pass
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
|
||||
def run(self) -> bool:
|
||||
"""
|
||||
Execute atomic conditional status update.
|
||||
|
||||
:returns: True if status was updated (expected matched), False otherwise
|
||||
"""
|
||||
return TaskDAO.conditional_status_update(
|
||||
task_uuid=self._task_uuid,
|
||||
new_status=self._new_status,
|
||||
expected_status=self._expected_status,
|
||||
properties=self._properties,
|
||||
set_started_at=self._set_started_at,
|
||||
set_ended_at=self._set_ended_at,
|
||||
)
|
||||
134
superset/commands/tasks/prune.py
Normal file
134
superset/commands/tasks/prune.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import sqlalchemy as sa
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# pylint: disable=consider-using-transaction
|
||||
class TaskPruneCommand(BaseCommand):
|
||||
"""
|
||||
Command to prune the tasks table by deleting rows older than the specified
|
||||
retention period.
|
||||
|
||||
This command deletes records from the `Task` table that are in terminal states
|
||||
(success, failure, aborted, or timed_out) and have not been changed within the
|
||||
specified number of days. It helps in maintaining the database by removing
|
||||
outdated entries and freeing up space.
|
||||
|
||||
Attributes:
|
||||
retention_period_days (int): The number of days for which records should be retained.
|
||||
Records older than this period will be deleted.
|
||||
max_rows_per_run (int | None): The maximum number of rows to delete in a single run.
|
||||
If provided and greater than zero, rows are selected
|
||||
deterministically from the oldest first (by timestamp then id)
|
||||
up to this limit in this execution.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self, retention_period_days: int, max_rows_per_run: int | None = None):
|
||||
"""
|
||||
:param retention_period_days: Number of days to keep in the tasks table
|
||||
:param max_rows_per_run: The maximum number of rows to delete in a single run.
|
||||
If provided and greater than zero, rows are selected deterministically from the
|
||||
oldest first (by timestamp then id) up to this limit in this execution.
|
||||
""" # noqa: E501
|
||||
self.retention_period_days = retention_period_days
|
||||
self.max_rows_per_run = max_rows_per_run
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Executes the prune command
|
||||
"""
|
||||
batch_size = 999 # SQLite has a IN clause limit of 999
|
||||
total_deleted = 0
|
||||
start_time = time.time()
|
||||
|
||||
# Select all IDs that need to be deleted
|
||||
# Only delete completed tasks (success, failure, or aborted)
|
||||
from superset.models.tasks import Task
|
||||
|
||||
select_stmt = sa.select(Task.id).where(
|
||||
Task.ended_at < datetime.now() - timedelta(days=self.retention_period_days),
|
||||
Task.status.in_(
|
||||
[
|
||||
TaskStatus.SUCCESS.value,
|
||||
TaskStatus.FAILURE.value,
|
||||
TaskStatus.ABORTED.value,
|
||||
TaskStatus.TIMED_OUT.value,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Optionally limited by max_rows_per_run
|
||||
# order by oldest first for deterministic deletion
|
||||
if self.max_rows_per_run is not None and self.max_rows_per_run > 0:
|
||||
select_stmt = select_stmt.order_by(
|
||||
Task.ended_at.asc(), Task.id.asc()
|
||||
).limit(self.max_rows_per_run)
|
||||
|
||||
ids_to_delete = db.session.execute(select_stmt).scalars().all()
|
||||
|
||||
total_rows = len(ids_to_delete)
|
||||
|
||||
logger.info("Total rows to be deleted: %s", f"{total_rows:,}")
|
||||
|
||||
next_logging_threshold = 1
|
||||
|
||||
# Iterate over the IDs in batches
|
||||
for i in range(0, total_rows, batch_size):
|
||||
batch_ids = ids_to_delete[i : i + batch_size]
|
||||
|
||||
# Delete the selected batch using IN clause
|
||||
result = db.session.execute(sa.delete(Task).where(Task.id.in_(batch_ids)))
|
||||
|
||||
# Update the total number of deleted records
|
||||
total_deleted += result.rowcount
|
||||
|
||||
# Explicitly commit the transaction given that if an error occurs, we want to ensure that the # noqa: E501
|
||||
# records that have been deleted so far are committed
|
||||
db.session.commit()
|
||||
|
||||
# Log the number of deleted records every 1% increase in progress
|
||||
percentage_complete = (total_deleted / total_rows) * 100
|
||||
if percentage_complete >= next_logging_threshold:
|
||||
logger.info(
|
||||
"Deleted %s rows from the tasks table older than %s days (%d%% complete)", # noqa: E501
|
||||
f"{total_deleted:,}",
|
||||
self.retention_period_days,
|
||||
percentage_complete,
|
||||
)
|
||||
next_logging_threshold += 1
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
minutes, seconds = divmod(elapsed_time, 60)
|
||||
formatted_time = f"{int(minutes):02}:{int(seconds):02}"
|
||||
logger.info(
|
||||
"Pruning complete: %s rows deleted in %s",
|
||||
f"{total_deleted:,}",
|
||||
formatted_time,
|
||||
)
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
168
superset/commands/tasks/submit.py
Normal file
168
superset/commands/tasks/submit.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Submit task command for GTF."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from flask import current_app
|
||||
from marshmallow import ValidationError
|
||||
from superset_core.api.tasks import TaskScope
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskCreateFailedError,
|
||||
TaskInvalidError,
|
||||
)
|
||||
from superset.daos.exceptions import DAOCreateFailedError
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
from superset.tasks.locks import task_lock
|
||||
from superset.tasks.utils import get_active_dedup_key
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.tasks import Task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubmitTaskCommand(BaseCommand):
|
||||
"""
|
||||
Command to submit a task (create new or join existing).
|
||||
|
||||
This command owns locking and create-vs-join business logic.
|
||||
It acquires a distributed lock and then decides whether to:
|
||||
- Create a new task (if no existing task with same dedup_key)
|
||||
- Join an existing task by adding the user as subscriber
|
||||
"""
|
||||
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=TaskCreateFailedError))
|
||||
def run(self) -> "Task":
|
||||
"""
|
||||
Execute the command with distributed locking.
|
||||
|
||||
Acquires lock based on dedup_key, then checks for existing task
|
||||
and either creates new or joins existing (adding subscriber).
|
||||
|
||||
:returns: Task model (either newly created or existing)
|
||||
"""
|
||||
task, _ = self.run_with_info()
|
||||
return task
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=TaskCreateFailedError))
|
||||
def run_with_info(self) -> tuple["Task", bool]:
|
||||
"""
|
||||
Execute the command and return (task, is_new) tuple.
|
||||
|
||||
This variant allows callers to distinguish between creating a new task
|
||||
and joining an existing one. Useful for sync execution where the caller
|
||||
needs to wait for an existing task to complete rather than executing again.
|
||||
|
||||
:returns: Tuple of (Task, is_new) where is_new is True if task was created
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
self.validate()
|
||||
|
||||
# Extract and normalize parameters
|
||||
task_type = self._properties["task_type"]
|
||||
task_key = self._properties.get("task_key") or str(uuid.uuid4())
|
||||
scope = self._properties.get("scope", TaskScope.PRIVATE.value)
|
||||
user_id = self._properties.get("user_id")
|
||||
|
||||
# Build dedup_key for lock
|
||||
dedup_key = get_active_dedup_key(
|
||||
scope=scope,
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Acquire lock to prevent race conditions during create/join
|
||||
with task_lock(dedup_key):
|
||||
# Check for existing task (safe under lock)
|
||||
existing = TaskDAO.find_by_task_key(task_type, task_key, scope, user_id)
|
||||
|
||||
# Get stats logger
|
||||
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
|
||||
|
||||
if existing:
|
||||
# Join existing task - add subscriber if not already subscribed
|
||||
if user_id and not existing.has_subscriber(user_id):
|
||||
TaskDAO.add_subscriber(existing.id, user_id)
|
||||
stats_logger.incr("gtf.task.subscribe")
|
||||
logger.info(
|
||||
"User %s joined existing task: %s",
|
||||
user_id,
|
||||
task_key,
|
||||
)
|
||||
else:
|
||||
# Same user submitted the same task - deduplication hit
|
||||
stats_logger.incr("gtf.task.dedupe")
|
||||
logger.debug(
|
||||
"Deduplication hit for task: %s (user_id=%s)",
|
||||
task_key,
|
||||
user_id,
|
||||
)
|
||||
return existing, False # is_new=False: joined existing task
|
||||
|
||||
# Create new task (DAO is now a pure data operation)
|
||||
try:
|
||||
task = TaskDAO.create_task(
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
scope=scope,
|
||||
task_name=self._properties.get("task_name"),
|
||||
user_id=user_id,
|
||||
payload=self._properties.get("payload", {}),
|
||||
properties=self._properties.get("properties", {}),
|
||||
)
|
||||
stats_logger.incr("gtf.task.create")
|
||||
return task, True # is_new=True: created new task
|
||||
except DAOCreateFailedError as ex:
|
||||
raise TaskCreateFailedError() from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate command parameters."""
|
||||
exceptions: list[ValidationError] = []
|
||||
|
||||
# Require task_type
|
||||
if not self._properties.get("task_type"):
|
||||
exceptions.append(
|
||||
ValidationError("task_type is required", field_name="task_type")
|
||||
)
|
||||
|
||||
scope = self._properties.get("scope", TaskScope.PRIVATE.value)
|
||||
scope_value = scope.value if isinstance(scope, TaskScope) else scope
|
||||
valid_scopes = [s.value for s in TaskScope]
|
||||
if scope_value not in valid_scopes:
|
||||
exceptions.append(
|
||||
ValidationError(
|
||||
f"scope must be one of {valid_scopes}",
|
||||
field_name="scope",
|
||||
)
|
||||
)
|
||||
# Store normalized value for use in run()
|
||||
self._properties["scope"] = scope_value
|
||||
|
||||
if exceptions:
|
||||
raise TaskInvalidError(exceptions=exceptions)
|
||||
170
superset/commands/tasks/update.py
Normal file
170
superset/commands/tasks/update.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from superset_core.api.tasks import TaskProperties
|
||||
|
||||
from superset import security_manager
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskForbiddenError,
|
||||
TaskNotFoundError,
|
||||
TaskUpdateFailedError,
|
||||
)
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.tasks.locks import task_lock
|
||||
from superset.tasks.utils import get_active_dedup_key
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.tasks import Task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateTaskCommand(BaseCommand):
|
||||
"""
|
||||
Command to update a task.
|
||||
|
||||
Uses explicit typed parameters to avoid confusion between
|
||||
payload (task output) and properties (runtime state/config).
|
||||
|
||||
This command acquires a distributed lock to prevent race conditions with
|
||||
concurrent submit/cancel operations on the same logical task.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_uuid: UUID,
|
||||
*,
|
||||
status: str | None = None,
|
||||
started_at: datetime | None = None,
|
||||
ended_at: datetime | None = None,
|
||||
payload: dict[str, Any] | None = None,
|
||||
properties: TaskProperties | None = None,
|
||||
skip_security_check: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize UpdateTaskCommand.
|
||||
|
||||
:param task_uuid: UUID of the task to update
|
||||
:param status: New status value (column field)
|
||||
:param started_at: Started timestamp (column field)
|
||||
:param ended_at: Ended timestamp (column field)
|
||||
:param payload: Task output data to merge (stored in payload column)
|
||||
:param properties: Runtime state/config updates as dict. Keys must be
|
||||
valid TaskProperties field names (is_abortable, progress_percent, etc.)
|
||||
:param skip_security_check: If True, skip ownership validation.
|
||||
Use this for internal task updates (e.g., task executor updating
|
||||
its own task's progress). Default is False for API-driven updates.
|
||||
"""
|
||||
self._task_uuid = task_uuid
|
||||
self._status = status
|
||||
self._started_at = started_at
|
||||
self._ended_at = ended_at
|
||||
self._payload = payload
|
||||
self._properties = properties
|
||||
self._model: Task | None = None
|
||||
self._skip_security_check = skip_security_check
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
|
||||
def run(self) -> Task:
|
||||
"""
|
||||
Execute the update command with distributed locking.
|
||||
|
||||
Acquires lock based on dedup_key to prevent race conditions with
|
||||
concurrent submit/cancel operations on the same logical task.
|
||||
|
||||
:returns: The updated task model
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
self.validate()
|
||||
|
||||
# Fetch task to compute dedup_key for locking
|
||||
task = TaskDAO.find_one_or_none(
|
||||
skip_base_filter=self._skip_security_check,
|
||||
uuid=self._task_uuid,
|
||||
)
|
||||
if not task:
|
||||
raise TaskNotFoundError()
|
||||
|
||||
self._model = task
|
||||
|
||||
# Build lock key from task properties (same structure as dedup_key)
|
||||
dedup_key = get_active_dedup_key(
|
||||
scope=self._model.scope,
|
||||
task_type=self._model.task_type,
|
||||
task_key=self._model.task_key,
|
||||
user_id=self._model.user_id,
|
||||
)
|
||||
|
||||
# Acquire lock to prevent race with submit/cancel operations
|
||||
with task_lock(dedup_key):
|
||||
return self._execute_update()
|
||||
|
||||
def _execute_update(self) -> "Task":
|
||||
"""
|
||||
Execute the update operation under lock.
|
||||
|
||||
:returns: The updated task model
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
# Re-fetch model under lock to get fresh state
|
||||
fresh_model = TaskDAO.find_one_or_none(
|
||||
skip_base_filter=self._skip_security_check,
|
||||
uuid=self._task_uuid,
|
||||
)
|
||||
if not fresh_model:
|
||||
raise TaskNotFoundError()
|
||||
self._model = fresh_model
|
||||
|
||||
# Verify ownership (user can only update their own tasks)
|
||||
# Skip this check for internal updates (e.g., task executor updating progress)
|
||||
if not self._skip_security_check:
|
||||
try:
|
||||
security_manager.raise_for_ownership(self._model)
|
||||
except SupersetSecurityException as ex:
|
||||
raise TaskForbiddenError() from ex
|
||||
|
||||
# Update status via set_status() for proper timestamp handling
|
||||
if self._status is not None:
|
||||
self._model.set_status(self._status)
|
||||
if self._started_at is not None:
|
||||
self._model.started_at = self._started_at
|
||||
if self._ended_at is not None:
|
||||
self._model.ended_at = self._ended_at
|
||||
|
||||
# Update payload (merges with existing)
|
||||
if self._payload is not None:
|
||||
self._model.set_payload(self._payload)
|
||||
|
||||
# Update properties (dict passed through to model)
|
||||
if self._properties:
|
||||
self._model.update_properties(self._properties)
|
||||
|
||||
return TaskDAO.update(self._model)
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
@@ -662,6 +662,9 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
|
||||
# sts:AssumeRole permissions to prevent unauthorized access.
|
||||
# @lifecycle: testing
|
||||
"AWS_DATABASE_IAM_AUTH": False,
|
||||
# Global Task Framework - unified task management with progress tracking,
|
||||
# cancellation, and deduplication.
|
||||
"GLOBAL_TASK_FRAMEWORK": False,
|
||||
# Use analogous colors in charts
|
||||
# @lifecycle: testing
|
||||
"USE_ANALOGOUS_COLORS": False,
|
||||
@@ -1393,6 +1396,12 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
|
||||
# "schedule": crontab(minute="*", hour="*"),
|
||||
# "kwargs": {"retention_period_days": 180, "max_rows_per_run": 10000},
|
||||
# },
|
||||
# Uncomment to enable pruning of the tasks table
|
||||
# "prune_tasks": {
|
||||
# "task": "prune_tasks",
|
||||
# "schedule": crontab(minute=0, hour=0),
|
||||
# "kwargs": {"retention_period_days": 90, "max_rows_per_run": 10000},
|
||||
# },
|
||||
# Uncomment to enable Slack channel cache warm-up
|
||||
# "slack.cache_channels": {
|
||||
# "task": "slack.cache_channels",
|
||||
@@ -2456,6 +2465,62 @@ except ImportError:
|
||||
LOCAL_EXTENSIONS: list[str] = []
|
||||
EXTENSIONS_PATH: str | None = None
|
||||
|
||||
# Default polling interval for tasks (seconds)
|
||||
TASK_ABORT_POLLING_DEFAULT_INTERVAL = 10
|
||||
|
||||
# Minimum interval in seconds between database writes for task progress updates.
|
||||
# Set to 0 to disable throttling (write every update to DB).
|
||||
TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL = 2 # seconds
|
||||
|
||||
# ---------------------------------------------------
|
||||
# Signal Cache Configuration
|
||||
# ---------------------------------------------------
|
||||
# Shared Redis/Valkey configuration for signaling features that require
|
||||
# Redis-specific primitives (pub/sub messaging, distributed locks).
|
||||
#
|
||||
# Uses Flask-Caching style configuration for consistency with other cache backends.
|
||||
# Set CACHE_TYPE to 'RedisCache' for standard Redis or 'RedisSentinelCache' for
|
||||
# Sentinel.
|
||||
#
|
||||
# These features cannot use generic cache backends because they rely on:
|
||||
# - Pub/Sub: Real-time message broadcasting between workers
|
||||
# - SET NX EX: Atomic lock acquisition with automatic expiration
|
||||
#
|
||||
# When configured, enables:
|
||||
# - Real-time abort/completion notifications for GTF tasks (vs database polling)
|
||||
# - Redis-based distributed locking (vs KeyValueDAO-backed DistributedLock)
|
||||
#
|
||||
# Future: This cache will also be used by Global Async Queries, consolidating
|
||||
# GLOBAL_ASYNC_QUERIES_CACHE_BACKEND into this unified configuration.
|
||||
#
|
||||
# Example with standard Redis:
|
||||
# SIGNAL_CACHE_CONFIG: CacheConfig = {
|
||||
# "CACHE_TYPE": "RedisCache",
|
||||
# "CACHE_REDIS_HOST": "localhost",
|
||||
# "CACHE_REDIS_PORT": 6379,
|
||||
# "CACHE_REDIS_DB": 0,
|
||||
# "CACHE_REDIS_PASSWORD": "",
|
||||
# }
|
||||
#
|
||||
# Example with Redis Sentinel:
|
||||
# SIGNAL_CACHE_CONFIG: CacheConfig = {
|
||||
# "CACHE_TYPE": "RedisSentinelCache",
|
||||
# "CACHE_REDIS_SENTINELS": [("sentinel1", 26379), ("sentinel2", 26379)],
|
||||
# "CACHE_REDIS_SENTINEL_MASTER": "mymaster",
|
||||
# "CACHE_REDIS_SENTINEL_PASSWORD": None,
|
||||
# "CACHE_REDIS_DB": 0,
|
||||
# "CACHE_REDIS_PASSWORD": "",
|
||||
# }
|
||||
SIGNAL_CACHE_CONFIG: CacheConfig | None = None
|
||||
|
||||
# Default lock TTL (time-to-live) in seconds for distributed locks.
|
||||
# Can be overridden per-call via the `ttl_seconds` parameter.
|
||||
# After TTL expires, the lock is automatically released to prevent deadlocks.
|
||||
DISTRIBUTED_LOCK_DEFAULT_TTL = 30
|
||||
|
||||
# Channel prefix for task abort pub/sub messages
|
||||
TASKS_ABORT_CHANNEL_PREFIX = "gtf:abort:"
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# * WARNING: STOP EDITING HERE *
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
@@ -52,6 +52,7 @@ def inject_dao_implementations() -> None:
|
||||
SavedQueryDAO as HostSavedQueryDAO,
|
||||
)
|
||||
from superset.daos.tag import TagDAO as HostTagDAO
|
||||
from superset.daos.tasks import TaskDAO as HostTaskDAO
|
||||
from superset.daos.user import UserDAO as HostUserDAO
|
||||
|
||||
# Replace abstract classes with concrete implementations
|
||||
@@ -64,18 +65,7 @@ def inject_dao_implementations() -> None:
|
||||
core_dao_module.SavedQueryDAO = HostSavedQueryDAO # type: ignore[assignment,misc]
|
||||
core_dao_module.TagDAO = HostTagDAO # type: ignore[assignment,misc]
|
||||
core_dao_module.KeyValueDAO = HostKeyValueDAO # type: ignore[assignment,misc]
|
||||
|
||||
core_dao_module.__all__ = [
|
||||
"DatasetDAO",
|
||||
"DatabaseDAO",
|
||||
"ChartDAO",
|
||||
"DashboardDAO",
|
||||
"UserDAO",
|
||||
"QueryDAO",
|
||||
"SavedQueryDAO",
|
||||
"TagDAO",
|
||||
"KeyValueDAO",
|
||||
]
|
||||
core_dao_module.TaskDAO = HostTaskDAO # type: ignore[assignment,misc]
|
||||
|
||||
|
||||
def inject_model_implementations() -> None:
|
||||
@@ -94,6 +84,7 @@ def inject_model_implementations() -> None:
|
||||
from superset.models.dashboard import Dashboard as HostDashboard
|
||||
from superset.models.slice import Slice as HostChart
|
||||
from superset.models.sql_lab import Query as HostQuery, SavedQuery as HostSavedQuery
|
||||
from superset.models.tasks import Task as HostTask
|
||||
from superset.tags.models import Tag as HostTag
|
||||
|
||||
# In-place replacement - extensions will import concrete implementations
|
||||
@@ -106,6 +97,7 @@ def inject_model_implementations() -> None:
|
||||
core_models_module.SavedQuery = HostSavedQuery # type: ignore[misc]
|
||||
core_models_module.Tag = HostTag # type: ignore[misc]
|
||||
core_models_module.KeyValue = HostKeyValue # type: ignore[misc]
|
||||
core_models_module.Task = HostTask # type: ignore[misc]
|
||||
|
||||
|
||||
def inject_query_implementations() -> None:
|
||||
@@ -124,7 +116,23 @@ def inject_query_implementations() -> None:
|
||||
)
|
||||
|
||||
core_query_module.get_sqlglot_dialect = get_sqlglot_dialect
|
||||
core_query_module.__all__ = ["get_sqlglot_dialect"]
|
||||
|
||||
|
||||
def inject_task_implementations() -> None:
|
||||
"""
|
||||
Replace abstract task functions in superset_core.api.tasks with concrete
|
||||
implementations from Superset.
|
||||
"""
|
||||
import superset_core.api.tasks as core_tasks_module
|
||||
|
||||
from superset.tasks.ambient_context import get_context
|
||||
from superset.tasks.context import TaskContext
|
||||
from superset.tasks.decorators import task
|
||||
|
||||
# Replace abstract classes and functions with concrete implementations
|
||||
core_tasks_module.TaskContext = TaskContext # type: ignore[assignment,misc]
|
||||
core_tasks_module.task = task # type: ignore[assignment]
|
||||
core_tasks_module.get_context = get_context
|
||||
|
||||
|
||||
def inject_rest_api_implementations() -> None:
|
||||
@@ -147,7 +155,6 @@ def inject_rest_api_implementations() -> None:
|
||||
|
||||
core_rest_api_module.add_api = add_api
|
||||
core_rest_api_module.add_extension_api = add_extension_api
|
||||
core_rest_api_module.__all__ = ["RestApi", "add_api", "add_extension_api"]
|
||||
|
||||
|
||||
def inject_model_session_implementation() -> None:
|
||||
@@ -163,7 +170,6 @@ def inject_model_session_implementation() -> None:
|
||||
return db.session
|
||||
|
||||
core_models_module.get_session = get_session
|
||||
# Update __all__ to include get_session (already done in the module)
|
||||
|
||||
|
||||
def initialize_core_api_dependencies() -> None:
|
||||
@@ -177,4 +183,5 @@ def initialize_core_api_dependencies() -> None:
|
||||
inject_model_implementations()
|
||||
inject_model_session_implementation()
|
||||
inject_query_implementations()
|
||||
inject_task_implementations()
|
||||
inject_rest_api_implementations()
|
||||
|
||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
from typing import Dict, List
|
||||
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
|
||||
@@ -30,9 +30,6 @@ from superset.models.core import FavStar, FavStarClassName
|
||||
from superset.models.slice import id_or_uuid_filter, Slice
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Custom filterable fields for charts
|
||||
|
||||
470
superset/daos/tasks.py
Normal file
470
superset/daos/tasks.py
Normal file
@@ -0,0 +1,470 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Task DAO for Global Task Framework (GTF)"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
|
||||
|
||||
from superset.daos.base import BaseDAO
|
||||
from superset.daos.exceptions import DAODeleteFailedError
|
||||
from superset.extensions import db
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.constants import ABORTABLE_STATES
|
||||
from superset.tasks.filters import TaskFilter
|
||||
from superset.tasks.utils import get_active_dedup_key, json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskDAO(BaseDAO[Task]):
|
||||
"""
|
||||
Concrete TaskDAO for the Global Task Framework (GTF).
|
||||
|
||||
Provides database access operations for async tasks including
|
||||
creation, status management, filtering, and subscription management
|
||||
for shared tasks.
|
||||
"""
|
||||
|
||||
base_filter = TaskFilter
|
||||
|
||||
@classmethod
|
||||
def get_status(cls, task_uuid: UUID) -> str | None:
|
||||
"""
|
||||
Get only the status of a task by UUID.
|
||||
|
||||
This is a lightweight query that only fetches the status column,
|
||||
optimized for polling endpoints where full entity loading is unnecessary.
|
||||
Applies the base filter (TaskFilter) to enforce permission checks.
|
||||
|
||||
:param task_uuid: UUID of the task
|
||||
:returns: Task status string, or None if task not found or not accessible
|
||||
"""
|
||||
# Start with query on Task model so base filter can be applied
|
||||
query = db.session.query(Task)
|
||||
query = cls._apply_base_filter(query)
|
||||
query = query.filter(Task.uuid == task_uuid)
|
||||
|
||||
# Select only the status column for efficiency
|
||||
result = query.with_entities(Task.status).one_or_none()
|
||||
return result[0] if result else None
|
||||
|
||||
@classmethod
|
||||
def find_by_task_key(
|
||||
cls,
|
||||
task_type: str,
|
||||
task_key: str,
|
||||
scope: TaskScope | str = TaskScope.PRIVATE,
|
||||
user_id: int | None = None,
|
||||
) -> Task | None:
|
||||
"""
|
||||
Find active task by type, key, scope, and user.
|
||||
|
||||
Uses dedup_key internally for efficient querying with a unique index.
|
||||
Only returns tasks that are active (pending or in progress).
|
||||
|
||||
Uniqueness logic by scope:
|
||||
- private: scope + task_type + task_key + user_id
|
||||
- shared/system: scope + task_type + task_key (user-agnostic)
|
||||
|
||||
:param task_type: Task type to filter by
|
||||
:param task_key: Task identifier for deduplication
|
||||
:param scope: Task scope (private/shared/system)
|
||||
:param user_id: User ID (required for private tasks)
|
||||
:returns: Task instance or None if not found or not active
|
||||
"""
|
||||
dedup_key = get_active_dedup_key(
|
||||
scope=scope,
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Simple single-column query with unique index
|
||||
return db.session.query(Task).filter(Task.dedup_key == dedup_key).one_or_none()
|
||||
|
||||
@classmethod
|
||||
def create_task(
|
||||
cls,
|
||||
task_type: str,
|
||||
task_key: str,
|
||||
scope: TaskScope | str = TaskScope.PRIVATE,
|
||||
user_id: int | None = None,
|
||||
payload: dict[str, Any] | None = None,
|
||||
properties: TaskProperties | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Task:
|
||||
"""
|
||||
Create a new task record in the database.
|
||||
|
||||
This is a pure data operation - assumes caller holds lock and has
|
||||
already checked for existing tasks. Business logic (create vs join)
|
||||
is handled by SubmitTaskCommand.
|
||||
|
||||
:param task_type: Type of task to create
|
||||
:param task_key: Task identifier (required)
|
||||
:param scope: Task scope (private/shared/system), defaults to private
|
||||
:param user_id: User ID creating the task
|
||||
:param payload: Optional user-defined context data (dict)
|
||||
:param properties: Optional framework-managed runtime state (e.g., timeout)
|
||||
:param kwargs: Additional task attributes (e.g., task_name)
|
||||
:returns: Created Task instance
|
||||
"""
|
||||
# Handle both TaskScope enum and string values
|
||||
scope_value = scope.value if isinstance(scope, TaskScope) else scope
|
||||
scope_enum = scope if isinstance(scope, TaskScope) else TaskScope(scope)
|
||||
|
||||
# Validate user_id is required for private tasks
|
||||
if scope_enum == TaskScope.PRIVATE and user_id is None:
|
||||
raise ValueError("user_id is required for private tasks")
|
||||
|
||||
# Build dedup_key for active task
|
||||
dedup_key = get_active_dedup_key(
|
||||
scope=scope,
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Note: properties is handled separately via update_properties()
|
||||
task_data = {
|
||||
"task_type": task_type,
|
||||
"task_key": task_key,
|
||||
"scope": scope_value,
|
||||
"status": TaskStatus.PENDING.value,
|
||||
"dedup_key": dedup_key,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Handle payload - serialize to JSON if dict provided
|
||||
if payload:
|
||||
task_data["payload"] = json.dumps(payload)
|
||||
|
||||
if user_id is not None:
|
||||
task_data["user_id"] = user_id
|
||||
|
||||
task = cls.create(attributes=task_data)
|
||||
|
||||
# Set properties after creation via update_properties (handles caching)
|
||||
if properties:
|
||||
task.update_properties(properties)
|
||||
|
||||
# Flush to get the task ID (auto-incremented primary key)
|
||||
db.session.flush()
|
||||
|
||||
# Auto-subscribe creator for all tasks
|
||||
# This enables consistent subscriber display across all task types
|
||||
if user_id:
|
||||
cls.add_subscriber(task.id, user_id)
|
||||
logger.info(
|
||||
"Creator %s auto-subscribed to task: %s (scope: %s)",
|
||||
user_id,
|
||||
task_key,
|
||||
scope_value,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created new async task: %s (type: %s, scope: %s)",
|
||||
task_key,
|
||||
task_type,
|
||||
scope_value,
|
||||
)
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def abort_task(cls, task_uuid: UUID, skip_base_filter: bool = False) -> Task | None:
|
||||
"""
|
||||
Abort a task by UUID.
|
||||
|
||||
This is a pure data operation. Business logic (subscriber count checks,
|
||||
permission validation) is handled by CancelTaskCommand which holds the lock.
|
||||
|
||||
Abort behavior by status:
|
||||
- PENDING: Goes directly to ABORTED (always abortable)
|
||||
- IN_PROGRESS with is_abortable=True: Goes to ABORTING
|
||||
- IN_PROGRESS with is_abortable=False/None: Raises TaskNotAbortableError
|
||||
- ABORTING: Returns task (idempotent)
|
||||
- Finished statuses: Returns None
|
||||
|
||||
Note: Caller is responsible for calling TaskManager.publish_abort() AFTER
|
||||
the transaction commits if task.status == ABORTING. This prevents race
|
||||
conditions where listeners check the DB before the status is visible.
|
||||
|
||||
:param task_uuid: UUID of task to abort
|
||||
:param skip_base_filter: If True, skip base filter (for admin abortions)
|
||||
:returns: Task if aborted/aborting, None if not found or already finished
|
||||
:raises TaskNotAbortableError: If in-progress task has no abort handler
|
||||
"""
|
||||
from superset.commands.tasks.exceptions import TaskNotAbortableError
|
||||
|
||||
task = cls.find_one_or_none(skip_base_filter=skip_base_filter, uuid=task_uuid)
|
||||
if not task:
|
||||
return None
|
||||
|
||||
# Already aborting - idempotent success
|
||||
if task.status == TaskStatus.ABORTING.value:
|
||||
logger.info("Task %s is already aborting", task_uuid)
|
||||
return task
|
||||
|
||||
# Already finished - cannot abort
|
||||
if task.status not in ABORTABLE_STATES:
|
||||
return None
|
||||
|
||||
# PENDING: Go directly to ABORTED
|
||||
if task.status == TaskStatus.PENDING.value:
|
||||
task.set_status(TaskStatus.ABORTED)
|
||||
logger.info("Aborted pending task: %s (scope: %s)", task_uuid, task.scope)
|
||||
return task
|
||||
|
||||
# IN_PROGRESS: Check if abortable
|
||||
if task.status == TaskStatus.IN_PROGRESS.value:
|
||||
if task.properties_dict.get("is_abortable") is not True:
|
||||
raise TaskNotAbortableError(
|
||||
f"Task {task_uuid} is in progress but has not registered "
|
||||
"an abort handler (is_abortable is not true)"
|
||||
)
|
||||
|
||||
# Transition to ABORTING (not ABORTED yet)
|
||||
task.status = TaskStatus.ABORTING.value
|
||||
db.session.merge(task)
|
||||
logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope)
|
||||
|
||||
# NOTE: publish_abort is NOT called here - caller handles it after commit
|
||||
# This prevents race conditions where listeners check DB before commit
|
||||
|
||||
return task
|
||||
|
||||
return None
|
||||
|
||||
# Subscription management methods
|
||||
|
||||
@classmethod
|
||||
def add_subscriber(cls, task_id: int, user_id: int) -> bool:
|
||||
"""
|
||||
Add a user as a subscriber to a task.
|
||||
|
||||
:param task_id: ID of the task
|
||||
:param user_id: ID of the user to subscribe
|
||||
:returns: True if subscriber was added, False if already exists
|
||||
"""
|
||||
# Check first to avoid IntegrityError which invalidates the session
|
||||
# in nested transaction contexts (IntegrityError can't be recovered from)
|
||||
existing = (
|
||||
db.session.query(TaskSubscriber)
|
||||
.filter_by(task_id=task_id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
logger.debug(
|
||||
"Subscriber %s already subscribed to task %s", user_id, task_id
|
||||
)
|
||||
return False
|
||||
|
||||
subscription = TaskSubscriber(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
subscribed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.session.add(subscription)
|
||||
db.session.flush()
|
||||
logger.info("Added subscriber %s to task %s", user_id, task_id)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def remove_subscriber(cls, task_id: int, user_id: int) -> Task | None:
|
||||
"""
|
||||
Remove a user's subscription from a task and return the updated task.
|
||||
|
||||
This is a pure data operation. Business logic (whether to abort after
|
||||
last subscriber leaves) is handled by CancelTaskCommand which holds
|
||||
the lock and decides whether to call abort_task() separately.
|
||||
|
||||
:param task_id: ID of the task
|
||||
:param user_id: ID of the user to unsubscribe
|
||||
:returns: Updated Task if subscriber was removed, None if not subscribed
|
||||
:raises DAODeleteFailedError: If subscription removal fails
|
||||
"""
|
||||
subscription = (
|
||||
db.session.query(TaskSubscriber)
|
||||
.filter(
|
||||
TaskSubscriber.task_id == task_id,
|
||||
TaskSubscriber.user_id == user_id,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
return None
|
||||
|
||||
try:
|
||||
db.session.delete(subscription)
|
||||
db.session.flush()
|
||||
logger.info("Removed subscriber %s from task %s", user_id, task_id)
|
||||
|
||||
# Return the updated task
|
||||
task = cls.find_by_id(task_id, skip_base_filter=True)
|
||||
if task:
|
||||
db.session.refresh(task) # Ensure subscribers list is fresh
|
||||
return task
|
||||
|
||||
except DAODeleteFailedError:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise DAODeleteFailedError(
|
||||
f"Failed to remove subscription for task {task_id}, user {user_id}"
|
||||
) from ex
|
||||
|
||||
@classmethod
|
||||
def set_properties_and_payload(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
properties: TaskProperties | None = None,
|
||||
payload: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Perform a zero-read SQL UPDATE on properties and/or payload columns.
|
||||
|
||||
This method directly writes the provided values without reading first.
|
||||
The caller (TaskContext) is responsible for maintaining the authoritative
|
||||
cached state and passing complete values to write.
|
||||
|
||||
This method is designed for internal task updates (progress, is_abortable)
|
||||
where the executor owns the state and doesn't need to read before writing.
|
||||
|
||||
IMPORTANT: This method only touches properties and payload columns.
|
||||
It does NOT touch the status column, so it's safe to use concurrently
|
||||
with operations that modify status (like abort).
|
||||
|
||||
:param task_uuid: UUID of the task to update
|
||||
:param properties: Complete properties dict to write (replaces existing)
|
||||
:param payload: Complete payload dict to write (replaces existing)
|
||||
:returns: True if task was updated, False if not found or nothing to update
|
||||
"""
|
||||
if properties is None and payload is None:
|
||||
return False
|
||||
|
||||
# Build update values dict - no reads, just write what caller provides
|
||||
update_values: dict[str, Any] = {}
|
||||
|
||||
if properties is not None:
|
||||
# Write complete properties (caller manages merging in their cache)
|
||||
update_values["properties"] = json.dumps(properties)
|
||||
|
||||
if payload is not None:
|
||||
# Write complete payload (payload column name matches attribute name)
|
||||
update_values["payload"] = json.dumps(payload)
|
||||
|
||||
if not update_values:
|
||||
return False
|
||||
|
||||
# Execute targeted UPDATE - zero read, just write
|
||||
rows_updated = (
|
||||
db.session.query(Task)
|
||||
.filter(Task.uuid == task_uuid)
|
||||
.update(update_values, synchronize_session=False)
|
||||
)
|
||||
|
||||
return rows_updated > 0
|
||||
|
||||
@classmethod
|
||||
def conditional_status_update(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
new_status: TaskStatus | str,
|
||||
expected_status: TaskStatus | str | list[TaskStatus | str],
|
||||
properties: TaskProperties | None = None,
|
||||
set_started_at: bool = False,
|
||||
set_ended_at: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Atomically update task status only if current status matches expected.
|
||||
|
||||
This provides atomic compare-and-swap semantics for status transitions,
|
||||
preventing race conditions between executor status updates and concurrent
|
||||
abort operations. Uses a single UPDATE with WHERE clause for atomicity.
|
||||
|
||||
Use cases:
|
||||
- Executor transitioning IN_PROGRESS → SUCCESS (only if not ABORTING)
|
||||
- Executor transitioning ABORTING → ABORTED/TIMED_OUT (cleanup complete)
|
||||
- Initial PENDING → IN_PROGRESS (task pickup)
|
||||
|
||||
:param task_uuid: UUID of the task to update
|
||||
:param new_status: Target status to set
|
||||
:param expected_status: Current status(es) required for update to succeed.
|
||||
Can be a single status or list of statuses.
|
||||
:param properties: Optional properties to update atomically with status
|
||||
:param set_started_at: If True, also set started_at to current timestamp
|
||||
:param set_ended_at: If True, also set ended_at to current timestamp
|
||||
:returns: True if status was updated (expected matched), False otherwise
|
||||
"""
|
||||
# Normalize status values
|
||||
new_status_val = (
|
||||
new_status.value if isinstance(new_status, TaskStatus) else new_status
|
||||
)
|
||||
|
||||
# Build list of expected status values
|
||||
if isinstance(expected_status, list):
|
||||
expected_vals = [
|
||||
s.value if isinstance(s, TaskStatus) else s for s in expected_status
|
||||
]
|
||||
else:
|
||||
expected_vals = [
|
||||
expected_status.value
|
||||
if isinstance(expected_status, TaskStatus)
|
||||
else expected_status
|
||||
]
|
||||
|
||||
# Build update values
|
||||
update_values: dict[str, Any] = {"status": new_status_val}
|
||||
|
||||
if properties is not None:
|
||||
update_values["properties"] = json.dumps(properties)
|
||||
|
||||
if set_started_at:
|
||||
update_values["started_at"] = datetime.now(timezone.utc)
|
||||
|
||||
if set_ended_at:
|
||||
update_values["ended_at"] = datetime.now(timezone.utc)
|
||||
|
||||
# Atomic compare-and-swap: only update if status matches expected
|
||||
rows_updated = (
|
||||
db.session.query(Task)
|
||||
.filter(Task.uuid == task_uuid, Task.status.in_(expected_vals))
|
||||
.update(update_values, synchronize_session=False)
|
||||
)
|
||||
|
||||
if rows_updated > 0:
|
||||
logger.debug(
|
||||
"Conditional status update succeeded: %s -> %s (expected: %s)",
|
||||
task_uuid,
|
||||
new_status_val,
|
||||
expected_vals,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Conditional status update skipped: %s -> %s "
|
||||
"(current status not in expected: %s)",
|
||||
task_uuid,
|
||||
new_status_val,
|
||||
expected_vals,
|
||||
)
|
||||
|
||||
return rows_updated > 0
|
||||
@@ -17,61 +17,46 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from superset.distributed_lock.utils import get_key
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CODEC = JsonKeyValueCodec()
|
||||
LOCK_EXPIRATION = timedelta(seconds=30)
|
||||
RESOURCE = KeyValueResource.LOCK
|
||||
|
||||
|
||||
@contextmanager
|
||||
def KeyValueDistributedLock( # pylint: disable=invalid-name # noqa: N802
|
||||
def DistributedLock( # noqa: N802
|
||||
namespace: str,
|
||||
ttl_seconds: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[uuid.UUID]:
|
||||
"""
|
||||
KV global lock for refreshing tokens.
|
||||
Distributed lock for coordinating operations across workers.
|
||||
|
||||
This context manager acquires a distributed lock for a given namespace, with
|
||||
optional parameters (eg, namespace="cache", user_id=1). It yields a UUID for the
|
||||
lock that can be used within the context, and corresponds to the key in the KV
|
||||
store.
|
||||
Automatically uses Redis-based locking when SIGNAL_CACHE_CONFIG is
|
||||
configured, falling back to database-backed locking otherwise.
|
||||
|
||||
:param namespace: The namespace for which the lock is to be acquired.
|
||||
:param kwargs: Additional keyword arguments.
|
||||
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
|
||||
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
|
||||
Redis locking uses SET NX EX for atomic acquisition with automatic expiration.
|
||||
Database locking uses the KeyValue table with manual expiration cleanup.
|
||||
|
||||
:param namespace: Lock namespace for grouping related locks
|
||||
:param ttl_seconds: Lock TTL in seconds. Defaults to 30 seconds.
|
||||
After expiration, the lock is automatically released
|
||||
to prevent deadlocks from crashed processes.
|
||||
:param kwargs: Additional key parameters to differentiate locks
|
||||
:yields: UUID identifying this lock acquisition
|
||||
:raises AcquireDistributedLockFailedException: If lock is already held
|
||||
or Redis connection fails
|
||||
"""
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.distributed_lock.create import CreateDistributedLock
|
||||
from superset.commands.distributed_lock.delete import DeleteDistributedLock
|
||||
from superset.commands.distributed_lock.get import GetDistributedLock
|
||||
from superset.commands.distributed_lock.acquire import AcquireDistributedLock
|
||||
from superset.commands.distributed_lock.release import ReleaseDistributedLock
|
||||
|
||||
key = get_key(namespace, **kwargs)
|
||||
value = GetDistributedLock(namespace=namespace, params=kwargs).run()
|
||||
if value:
|
||||
logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
|
||||
raise CreateKeyValueDistributedLockFailedException("Lock already taken")
|
||||
|
||||
logger.debug("Acquiring lock on namespace %s for key %s", namespace, key)
|
||||
AcquireDistributedLock(namespace, kwargs, ttl_seconds).run()
|
||||
try:
|
||||
CreateDistributedLock(namespace=namespace, params=kwargs).run()
|
||||
except CreateKeyValueDistributedLockFailedException as ex:
|
||||
logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
|
||||
raise CreateKeyValueDistributedLockFailedException("Lock already taken") from ex
|
||||
|
||||
yield key
|
||||
DeleteDistributedLock(namespace=namespace, params=kwargs).run()
|
||||
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
|
||||
yield key
|
||||
finally:
|
||||
ReleaseDistributedLock(namespace, kwargs).run()
|
||||
|
||||
@@ -414,15 +414,15 @@ class SupersetDisallowedSQLTableException(SupersetErrorException):
|
||||
)
|
||||
|
||||
|
||||
class CreateKeyValueDistributedLockFailedException(Exception): # noqa: N818
|
||||
class AcquireDistributedLockFailedException(Exception): # noqa: N818
|
||||
"""
|
||||
Exception to signalize failure to acquire lock.
|
||||
"""
|
||||
|
||||
|
||||
class DeleteKeyValueDistributedLockFailedException(Exception): # noqa: N818
|
||||
class ReleaseDistributedLockFailedException(Exception): # noqa: N818
|
||||
"""
|
||||
Exception to signalize failure to delete lock.
|
||||
Exception to signalize failure to release lock.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -23,13 +23,10 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Guard to prevent multiple initializations
|
||||
|
||||
@@ -218,6 +218,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
)
|
||||
from superset.views.sqllab import SqllabView
|
||||
from superset.views.tags import TagModelView, TagView
|
||||
from superset.views.tasks import TaskModelView
|
||||
from superset.views.themes import ThemeModelView
|
||||
from superset.views.user_info import UserInfoView
|
||||
from superset.views.user_registrations import UserRegistrationsView
|
||||
@@ -275,6 +276,11 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
|
||||
appbuilder.add_api(ExtensionsRestApi)
|
||||
|
||||
if feature_flag_manager.is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
|
||||
from superset.tasks.api import TaskRestApi
|
||||
|
||||
appbuilder.add_api(TaskRestApi)
|
||||
|
||||
#
|
||||
# Setup regular views
|
||||
#
|
||||
@@ -408,6 +414,18 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
),
|
||||
)
|
||||
|
||||
appbuilder.add_view(
|
||||
TaskModelView,
|
||||
"Tasks",
|
||||
label=_("Tasks"),
|
||||
icon="fa-clock-o",
|
||||
category="Manage",
|
||||
category_label=_("Manage"),
|
||||
menu_cond=lambda: feature_flag_manager.is_feature_enabled(
|
||||
"GLOBAL_TASK_FRAMEWORK"
|
||||
),
|
||||
)
|
||||
|
||||
#
|
||||
# Setup views with no menu
|
||||
#
|
||||
@@ -588,6 +606,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
self.configure_async_queries()
|
||||
self.configure_ssh_manager()
|
||||
self.configure_stats_manager()
|
||||
self.configure_task_manager()
|
||||
|
||||
# Hook that provides administrators a handle on the Flask APP
|
||||
# after initialization
|
||||
@@ -928,6 +947,13 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
if feature_flag_manager.is_feature_enabled("GLOBAL_ASYNC_QUERIES"):
|
||||
async_query_manager_factory.init_app(self.superset_app)
|
||||
|
||||
def configure_task_manager(self) -> None:
|
||||
"""Initialize the TaskManager for GTF realtime notifications."""
|
||||
if feature_flag_manager.is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
|
||||
from superset.tasks.manager import TaskManager
|
||||
|
||||
TaskManager.init_app(self.superset_app)
|
||||
|
||||
def register_blueprints(self) -> None:
|
||||
# Register custom blueprints from config
|
||||
for bp in self.config["BLUEPRINTS"]:
|
||||
|
||||
@@ -0,0 +1,221 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Create tasks and task_subscriber tables for Global Task Framework (GTF)
|
||||
|
||||
Revision ID: 4b2a8c9d3e1f
|
||||
Revises: 9787190b3d89
|
||||
Create Date: 2025-12-18 02:20:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy_utils import UUIDType
|
||||
|
||||
from superset.migrations.shared.utils import (
|
||||
create_fks_for_table,
|
||||
create_index,
|
||||
create_table,
|
||||
drop_fks_for_table,
|
||||
drop_index,
|
||||
drop_table,
|
||||
)
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4b2a8c9d3e1f"
|
||||
down_revision = "9787190b3d89"
|
||||
|
||||
TASKS_TABLE = "tasks"
|
||||
TASK_SUBSCRIBERS_TABLE = "task_subscribers"
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
Create tasks and task_subscribers tables for the Global Task Framework (GTF).
|
||||
|
||||
This migration creates:
|
||||
1. tasks table - unified tracking for all long running tasks
|
||||
2. task_subscribers table - multi-user task subscriptions for shared tasks
|
||||
|
||||
The scope feature allows tasks to be:
|
||||
- private: user-specific (default)
|
||||
- shared: multi-user collaborative tasks
|
||||
- system: admin-only background tasks
|
||||
"""
|
||||
# Create tasks table
|
||||
create_table(
|
||||
TASKS_TABLE,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("uuid", UUIDType(binary=True), nullable=False, unique=True),
|
||||
Column("task_key", String(256), nullable=False),
|
||||
Column("task_type", String(100), nullable=False),
|
||||
Column("task_name", String(256), nullable=True),
|
||||
Column("scope", String(20), nullable=False, server_default="private"),
|
||||
Column("status", String(50), nullable=False),
|
||||
Column("dedup_key", String(64), nullable=False),
|
||||
# AuditMixinNullable columns
|
||||
Column("created_on", DateTime, nullable=True),
|
||||
Column("changed_on", DateTime, nullable=True),
|
||||
Column("created_by_fk", Integer, nullable=True),
|
||||
Column("changed_by_fk", Integer, nullable=True),
|
||||
# Task-specific columns
|
||||
Column("started_at", DateTime, nullable=True),
|
||||
Column("ended_at", DateTime, nullable=True),
|
||||
Column("user_id", Integer, nullable=True),
|
||||
Column("payload", Text, nullable=True),
|
||||
Column("properties", Text, nullable=True),
|
||||
)
|
||||
|
||||
# Create indexes for optimal query performance
|
||||
create_index(TASKS_TABLE, "idx_tasks_dedup_key", ["dedup_key"], unique=True)
|
||||
create_index(TASKS_TABLE, "idx_tasks_status", ["status"])
|
||||
create_index(TASKS_TABLE, "idx_tasks_scope", ["scope"])
|
||||
create_index(TASKS_TABLE, "idx_tasks_ended_at", ["ended_at"])
|
||||
create_index(TASKS_TABLE, "idx_tasks_created_by", ["created_by_fk"])
|
||||
create_index(TASKS_TABLE, "idx_tasks_created_on", ["created_on"])
|
||||
create_index(TASKS_TABLE, "idx_tasks_task_key", ["task_key"])
|
||||
create_index(TASKS_TABLE, "idx_tasks_task_type", ["task_type"])
|
||||
create_index(TASKS_TABLE, "idx_tasks_uuid", ["uuid"], unique=True)
|
||||
|
||||
# Create foreign key constraints for tasks
|
||||
create_fks_for_table(
|
||||
foreign_key_name="fk_tasks_created_by_fk_ab_user",
|
||||
table_name=TASKS_TABLE,
|
||||
referenced_table="ab_user",
|
||||
local_cols=["created_by_fk"],
|
||||
remote_cols=["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
create_fks_for_table(
|
||||
foreign_key_name="fk_tasks_changed_by_fk_ab_user",
|
||||
table_name=TASKS_TABLE,
|
||||
referenced_table="ab_user",
|
||||
local_cols=["changed_by_fk"],
|
||||
remote_cols=["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
create_fks_for_table(
|
||||
foreign_key_name="fk_tasks_user_id_ab_user",
|
||||
table_name=TASKS_TABLE,
|
||||
referenced_table="ab_user",
|
||||
local_cols=["user_id"],
|
||||
remote_cols=["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
# Create task_subscribers table for multi-user task subscriptions
|
||||
create_table(
|
||||
TASK_SUBSCRIBERS_TABLE,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("task_id", Integer, nullable=False),
|
||||
Column("user_id", Integer, nullable=False),
|
||||
Column("subscribed_at", DateTime, nullable=False),
|
||||
# AuditMixinNullable columns
|
||||
Column("created_on", DateTime, nullable=True),
|
||||
Column("created_by_fk", Integer, nullable=True),
|
||||
Column("changed_on", DateTime, nullable=True),
|
||||
Column("changed_by_fk", Integer, nullable=True),
|
||||
# Unique constraint defined as part of table creation (SQLite compatible)
|
||||
UniqueConstraint("task_id", "user_id", name="uq_task_subscribers_task_user"),
|
||||
)
|
||||
|
||||
# Create indexes for task_subscribers table
|
||||
create_index(TASK_SUBSCRIBERS_TABLE, "idx_task_subscribers_user_id", ["user_id"])
|
||||
|
||||
# Create foreign key constraints for task_subscribers
|
||||
create_fks_for_table(
|
||||
foreign_key_name="fk_task_subscribers_task_id_tasks",
|
||||
table_name=TASK_SUBSCRIBERS_TABLE,
|
||||
referenced_table=TASKS_TABLE,
|
||||
local_cols=["task_id"],
|
||||
remote_cols=["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
create_fks_for_table(
|
||||
foreign_key_name="fk_task_subscribers_user_id_ab_user",
|
||||
table_name=TASK_SUBSCRIBERS_TABLE,
|
||||
referenced_table="ab_user",
|
||||
local_cols=["user_id"],
|
||||
remote_cols=["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
create_fks_for_table(
|
||||
foreign_key_name="fk_task_subscribers_created_by_fk_ab_user",
|
||||
table_name=TASK_SUBSCRIBERS_TABLE,
|
||||
referenced_table="ab_user",
|
||||
local_cols=["created_by_fk"],
|
||||
remote_cols=["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
create_fks_for_table(
|
||||
foreign_key_name="fk_task_subscribers_changed_by_fk_ab_user",
|
||||
table_name=TASK_SUBSCRIBERS_TABLE,
|
||||
referenced_table="ab_user",
|
||||
local_cols=["changed_by_fk"],
|
||||
remote_cols=["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""
|
||||
Drop tasks and task_subscribers tables and all related indexes and foreign keys.
|
||||
"""
|
||||
drop_fks_for_table(
|
||||
TASK_SUBSCRIBERS_TABLE,
|
||||
[
|
||||
"fk_task_subscribers_task_id_tasks",
|
||||
"fk_task_subscribers_user_id_ab_user",
|
||||
"fk_task_subscribers_created_by_fk_ab_user",
|
||||
"fk_task_subscribers_changed_by_fk_ab_user",
|
||||
],
|
||||
)
|
||||
|
||||
drop_index(TASK_SUBSCRIBERS_TABLE, "idx_task_subscribers_user_id")
|
||||
drop_table(TASK_SUBSCRIBERS_TABLE)
|
||||
|
||||
drop_fks_for_table(
|
||||
TASKS_TABLE,
|
||||
[
|
||||
"fk_tasks_created_by_fk_ab_user",
|
||||
"fk_tasks_changed_by_fk_ab_user",
|
||||
"fk_tasks_user_id_ab_user",
|
||||
],
|
||||
)
|
||||
|
||||
drop_index(TASKS_TABLE, "idx_tasks_dedup_key")
|
||||
drop_index(TASKS_TABLE, "idx_tasks_status")
|
||||
drop_index(TASKS_TABLE, "idx_tasks_scope")
|
||||
drop_index(TASKS_TABLE, "idx_tasks_ended_at")
|
||||
drop_index(TASKS_TABLE, "idx_tasks_created_by")
|
||||
drop_index(TASKS_TABLE, "idx_tasks_created_on")
|
||||
drop_index(TASKS_TABLE, "idx_tasks_task_key")
|
||||
drop_index(TASKS_TABLE, "idx_tasks_task_type")
|
||||
drop_index(TASKS_TABLE, "idx_tasks_uuid")
|
||||
|
||||
drop_table(TASKS_TABLE)
|
||||
62
superset/models/task_subscribers.py
Normal file
62
superset/models/task_subscribers.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""TaskSubscriber model for tracking multi-user task subscriptions"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from superset_core.api.models import TaskSubscriber as CoreTaskSubscriber
|
||||
|
||||
from superset.models.helpers import AuditMixinNullable
|
||||
|
||||
|
||||
class TaskSubscriber(CoreTaskSubscriber, AuditMixinNullable, Model):
|
||||
"""
|
||||
Model for tracking task subscriptions in shared tasks.
|
||||
|
||||
This model enables multi-user collaboration on async tasks. When a user
|
||||
schedules a shared task with the same parameters as an existing task,
|
||||
they are automatically subscribed to that task instead of creating a
|
||||
duplicate.
|
||||
|
||||
Subscribers can unsubscribe from shared tasks. When the last subscriber
|
||||
unsubscribes, the task is automatically aborted.
|
||||
"""
|
||||
|
||||
__tablename__ = "task_subscribers"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
task_id = Column(
|
||||
Integer, ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user_id = Column(
|
||||
Integer, ForeignKey("ab_user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
subscribed_at = Column(DateTime, nullable=False, default=datetime.now(timezone.utc))
|
||||
|
||||
# Relationships
|
||||
task = relationship("Task", back_populates="subscribers")
|
||||
user = relationship("User", foreign_keys=[user_id], lazy="joined")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("task_id", "user_id", name="uq_task_subscribers_task_user"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TaskSubscriber user_id={self.user_id} task_id={self.task_id}>"
|
||||
367
superset/models/tasks.py
Normal file
367
superset/models/tasks.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Task model for Global Task Framework (GTF)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as uuid_module
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy_utils import UUIDType
|
||||
from superset_core.api.models import Task as CoreTask
|
||||
from superset_core.api.tasks import TaskProperties, TaskStatus
|
||||
|
||||
from superset.models.helpers import AuditMixinNullable
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
from superset.tasks.utils import (
|
||||
error_update,
|
||||
get_finished_dedup_key,
|
||||
parse_properties,
|
||||
serialize_properties,
|
||||
)
|
||||
from superset.utils import json
|
||||
|
||||
|
||||
class Task(CoreTask, AuditMixinNullable, Model):
|
||||
"""
|
||||
Concrete Task model for the Global Task Framework (GTF).
|
||||
|
||||
This model represents async tasks in Superset, providing unified tracking
|
||||
for all background operations including SQL queries, thumbnail generation,
|
||||
reports, and other async operations.
|
||||
|
||||
Non-filterable fields (progress, error info, execution config) are stored
|
||||
in a `properties` JSON blob for schema flexibility.
|
||||
"""
|
||||
|
||||
__tablename__ = "tasks"
|
||||
|
||||
# Primary key and identifiers
|
||||
id = Column(Integer, primary_key=True)
|
||||
uuid = Column(
|
||||
UUIDType(binary=True), nullable=False, unique=True, default=uuid_module.uuid4
|
||||
)
|
||||
|
||||
# Task metadata (filterable)
|
||||
task_key = Column(String(256), nullable=False, index=True) # For deduplication
|
||||
task_type = Column(String(100), nullable=False, index=True) # e.g., 'sql_execution'
|
||||
task_name = Column(String(256), nullable=True) # Human readable name
|
||||
scope = Column(
|
||||
String(20), nullable=False, index=True, default="private"
|
||||
) # private/shared/system
|
||||
status = Column(
|
||||
String(50), nullable=False, index=True, default=TaskStatus.PENDING.value
|
||||
)
|
||||
dedup_key = Column(
|
||||
String(64), nullable=False, unique=True, index=True
|
||||
) # Hashed deduplication key (SHA-256 = 64 chars, UUID = 36 chars)
|
||||
|
||||
# Timestamps
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
ended_at = Column(DateTime, nullable=True)
|
||||
|
||||
# User context for execution
|
||||
user_id = Column(Integer, nullable=True)
|
||||
|
||||
# Task-specific output data (set by task code via ctx.update_task(payload=...))
|
||||
payload = Column(Text, nullable=True, default="{}")
|
||||
|
||||
# Properties JSON blob - contains runtime state and execution config:
|
||||
# - is_abortable: bool - has abort handler registered
|
||||
# - progress_percent: float - progress 0.0-1.0
|
||||
# - progress_current: int - current iteration count
|
||||
# - progress_total: int - total iterations
|
||||
# - error_message: str - human-readable error message
|
||||
# - exception_type: str - exception class name
|
||||
# - stack_trace: str - full formatted traceback
|
||||
# - timeout: int - timeout in seconds
|
||||
properties = Column(Text, nullable=True, default="{}")
|
||||
|
||||
# Relationships
|
||||
# Use lazy="selectin" to avoid N+1 queries when listing tasks with subscribers
|
||||
subscribers = relationship(
|
||||
TaskSubscriber,
|
||||
back_populates="task",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Task {self.task_type}:{self.task_key} [{self.status}]>"
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Properties accessor
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def properties_dict(self) -> TaskProperties:
|
||||
"""
|
||||
Get typed properties.
|
||||
|
||||
Properties contain runtime state and execution config that doesn't
|
||||
need database filtering. Always use .get() for reads since keys may
|
||||
be absent.
|
||||
|
||||
:returns: TaskProperties dict (sparse - only contains keys that were set)
|
||||
"""
|
||||
return parse_properties(self.properties)
|
||||
|
||||
def update_properties(self, updates: TaskProperties) -> None:
|
||||
"""
|
||||
Update specific properties fields (merge semantics).
|
||||
|
||||
Only updates fields present in the updates dict.
|
||||
|
||||
:param updates: TaskProperties dict with fields to update
|
||||
|
||||
Example:
|
||||
task.update_properties({"is_abortable": True})
|
||||
task.update_properties(progress_update((50, 100)))
|
||||
"""
|
||||
current = cast(TaskProperties, dict(self.properties_dict))
|
||||
current.update(updates) # Merge updates
|
||||
self.properties = serialize_properties(current)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Payload accessor (for task-specific output data)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def payload_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get payload as parsed JSON.
|
||||
|
||||
Payload contains task-specific output data set by task code via
|
||||
ctx.update_task(payload=...).
|
||||
|
||||
:returns: Dictionary containing payload data
|
||||
"""
|
||||
try:
|
||||
return json.loads(self.payload or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
def set_payload(self, data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update payload with new data.
|
||||
|
||||
The payload is merged with existing data, not replaced.
|
||||
|
||||
:param data: Dictionary of data to merge into payload
|
||||
"""
|
||||
current = self.payload_dict
|
||||
current.update(data)
|
||||
self.payload = json.dumps(current)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Error handling
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def set_error_from_exception(self, exception: BaseException) -> None:
|
||||
"""
|
||||
Set error fields from an exception.
|
||||
|
||||
Captures the error message, exception type, and full stack trace.
|
||||
Called automatically by the executor when a task raises an exception.
|
||||
|
||||
:param exception: The exception that caused the failure
|
||||
"""
|
||||
self.update_properties(error_update(exception))
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Status management
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def set_status(self, status: TaskStatus | str) -> None:
|
||||
"""
|
||||
Update task status and dedup_key.
|
||||
|
||||
When a task finishes (success, failure, or abort), the dedup_key is
|
||||
changed to the task's UUID. This frees up the slot so new tasks with
|
||||
the same parameters can be created.
|
||||
|
||||
:param status: New task status
|
||||
"""
|
||||
if isinstance(status, TaskStatus):
|
||||
status = status.value
|
||||
self.status = status
|
||||
|
||||
# Update timestamps and is_abortable based on status
|
||||
now = datetime.now(timezone.utc)
|
||||
if status == TaskStatus.IN_PROGRESS.value and not self.started_at:
|
||||
self.started_at = now
|
||||
# Set is_abortable to False when task starts executing
|
||||
# (will be set to True if/when an abort handler is registered)
|
||||
if self.properties_dict.get("is_abortable") is None:
|
||||
self.update_properties({"is_abortable": False})
|
||||
elif status in [
|
||||
TaskStatus.SUCCESS.value,
|
||||
TaskStatus.FAILURE.value,
|
||||
TaskStatus.ABORTED.value,
|
||||
TaskStatus.TIMED_OUT.value,
|
||||
]:
|
||||
if not self.ended_at:
|
||||
self.ended_at = now
|
||||
# Update dedup_key to UUID to free up the slot for new tasks
|
||||
self.dedup_key = get_finished_dedup_key(self.uuid)
|
||||
# Note: ABORTING status doesn't set ended_at yet - that happens when
|
||||
# the task transitions to ABORTED after handlers complete
|
||||
|
||||
@property
|
||||
def is_pending(self) -> bool:
|
||||
"""Check if task is pending."""
|
||||
return self.status == TaskStatus.PENDING.value
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if task is currently running."""
|
||||
return self.status == TaskStatus.IN_PROGRESS.value
|
||||
|
||||
@property
|
||||
def is_finished(self) -> bool:
|
||||
"""Check if task has finished (success, failure, aborted, or timed out)."""
|
||||
return self.status in [
|
||||
TaskStatus.SUCCESS.value,
|
||||
TaskStatus.FAILURE.value,
|
||||
TaskStatus.ABORTED.value,
|
||||
TaskStatus.TIMED_OUT.value,
|
||||
]
|
||||
|
||||
@property
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if task completed successfully."""
|
||||
return self.status == TaskStatus.SUCCESS.value
|
||||
|
||||
@property
|
||||
def duration_seconds(self) -> float | None:
|
||||
"""
|
||||
Get task duration in seconds.
|
||||
|
||||
- Finished tasks: Time from started_at to ended_at (None if never started)
|
||||
- Running/aborting tasks: Time from started_at to now
|
||||
- Pending tasks: Time from created_on to now (queue time)
|
||||
|
||||
Note: started_at/ended_at are stored in UTC, but created_on from
|
||||
AuditMixinNullable is stored as naive local time. We handle both cases.
|
||||
"""
|
||||
if self.is_finished:
|
||||
# Task has completed - use fixed timestamps, never increment
|
||||
if self.started_at and self.ended_at:
|
||||
# Finished task - both timestamps use the same timezone (UTC)
|
||||
# Just compute the difference directly
|
||||
return (self.ended_at - self.started_at).total_seconds()
|
||||
# Never started (e.g., aborted while pending) - no duration
|
||||
return None
|
||||
elif self.started_at:
|
||||
# Running or aborting - started_at is UTC (set by set_status)
|
||||
# Use UTC now for comparison
|
||||
now = datetime.now(timezone.utc)
|
||||
started = (
|
||||
self.started_at.replace(tzinfo=timezone.utc)
|
||||
if self.started_at.tzinfo is None
|
||||
else self.started_at
|
||||
)
|
||||
return (now - started).total_seconds()
|
||||
elif self.created_on:
|
||||
# Pending - created_on is naive LOCAL time (from AuditMixinNullable)
|
||||
# Use naive local time for comparison
|
||||
now = datetime.now() # Local time, no timezone
|
||||
created = (
|
||||
self.created_on.replace(tzinfo=None)
|
||||
if self.created_on.tzinfo is not None
|
||||
else self.created_on
|
||||
)
|
||||
return (now - created).total_seconds()
|
||||
return None
|
||||
|
||||
# Scope-related properties
|
||||
@property
|
||||
def is_private(self) -> bool:
|
||||
"""Check if task is private (user-specific)."""
|
||||
return self.scope == "private"
|
||||
|
||||
@property
|
||||
def is_shared(self) -> bool:
|
||||
"""Check if task is shared (multi-user)."""
|
||||
return self.scope == "shared"
|
||||
|
||||
@property
|
||||
def is_system(self) -> bool:
|
||||
"""Check if task is system (admin-only)."""
|
||||
return self.scope == "system"
|
||||
|
||||
# Subscriber-related methods
|
||||
@property
|
||||
def subscriber_count(self) -> int:
|
||||
"""Get number of subscribers to this task."""
|
||||
return len(self.subscribers)
|
||||
|
||||
def has_subscriber(self, user_id: int) -> bool:
|
||||
"""
|
||||
Check if a user is subscribed to this task.
|
||||
|
||||
:param user_id: User ID to check
|
||||
:returns: True if user is subscribed
|
||||
"""
|
||||
return any(sub.user_id == user_id for sub in self.subscribers)
|
||||
|
||||
def get_subscriber_ids(self) -> list[int]:
|
||||
"""
|
||||
Get list of all subscriber user IDs.
|
||||
|
||||
:returns: List of user IDs subscribed to this task
|
||||
"""
|
||||
return [sub.user_id for sub in self.subscribers]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert task to dictionary representation.
|
||||
|
||||
Minimal API payload - frontend derives status booleans and abort logic
|
||||
from status and properties.is_abortable.
|
||||
|
||||
:returns: Dictionary representation of the task
|
||||
"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"uuid": str(self.uuid),
|
||||
"task_key": self.task_key,
|
||||
"task_type": self.task_type,
|
||||
"task_name": self.task_name,
|
||||
"scope": self.scope,
|
||||
"status": self.status,
|
||||
"created_on": self.created_on.isoformat() if self.created_on else None,
|
||||
"changed_on": self.changed_on.isoformat() if self.changed_on else None,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
"ended_at": self.ended_at.isoformat() if self.ended_at else None,
|
||||
"created_by_fk": self.created_by_fk,
|
||||
"user_id": self.user_id,
|
||||
"payload": self.payload_dict,
|
||||
"properties": self.properties_dict,
|
||||
"subscriber_count": self.subscriber_count,
|
||||
"subscriber_ids": self.get_subscriber_ids(),
|
||||
}
|
||||
@@ -26,7 +26,7 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
@@ -56,9 +56,6 @@ from superset.utils.core import override_user, zlib_compress
|
||||
from superset.utils.dates import now_as_float
|
||||
from superset.utils.decorators import stats_timing
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BYTES_IN_MB = 1024 * 1024
|
||||
|
||||
87
superset/tasks/ambient_context.py
Normal file
87
superset/tasks/ambient_context.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Ambient context management for the Global Task Framework (GTF)"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Iterator
|
||||
|
||||
from superset.tasks.context import TaskContext
|
||||
|
||||
# Global context variable for ambient context pattern
|
||||
# This is thread-safe and async-safe via Python's contextvars
|
||||
_current_context: ContextVar[TaskContext | None] = ContextVar(
|
||||
"task_context", default=None
|
||||
)
|
||||
|
||||
|
||||
def get_context() -> TaskContext:
|
||||
"""
|
||||
Get the current task context from contextvars.
|
||||
|
||||
This function provides ambient access to the task context without
|
||||
requiring it to be passed as a parameter. It can only be called
|
||||
from within a task execution.
|
||||
|
||||
:returns: The current TaskContext
|
||||
:raises RuntimeError: If called outside a task execution context
|
||||
|
||||
Example:
|
||||
>>> @task()
|
||||
>>> def my_task(chart_id: int) -> None:
|
||||
>>> ctx = get_context() # Access ambient context
|
||||
>>>
|
||||
>>> # Update progress and payload atomically
|
||||
>>> ctx.update_task(
|
||||
>>> progress=0.5,
|
||||
>>> payload={"chart_id": chart_id}
|
||||
>>> )
|
||||
"""
|
||||
ctx = _current_context.get()
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"get_context() called outside task execution context. "
|
||||
"This function can only be called from within a @task "
|
||||
"decorated function."
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_context(ctx: TaskContext) -> Iterator[None]:
|
||||
"""
|
||||
Context manager to set ambient context for task execution.
|
||||
|
||||
This is used internally by the framework to establish the ambient context
|
||||
before executing a task function. The context is automatically cleaned up
|
||||
after execution, even if the task raises an exception.
|
||||
|
||||
:param ctx: TaskContext to set as the current context
|
||||
:yields: None
|
||||
|
||||
Example (internal framework use):
|
||||
>>> ctx = TaskContext(task_uuid=task.uuid)
|
||||
>>> with use_context(ctx):
|
||||
>>> # Task function can now call get_context()
|
||||
>>> task_function(*args, **kwargs)
|
||||
>>> # Context automatically reset after execution
|
||||
"""
|
||||
token = _current_context.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_current_context.reset(token)
|
||||
471
superset/tasks/api.py
Normal file
471
superset/tasks/api.py
Normal file
@@ -0,0 +1,471 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Task REST API"""
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_appbuilder.api import expose, protect, safe
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
|
||||
from superset.commands.tasks.cancel import CancelTaskCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskAbortFailedError,
|
||||
TaskForbiddenError,
|
||||
TaskNotAbortableError,
|
||||
TaskNotFoundError,
|
||||
TaskPermissionDeniedError,
|
||||
)
|
||||
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.filters import TaskFilter
|
||||
from superset.tasks.schemas import (
|
||||
openapi_spec_methods_override,
|
||||
TaskCancelRequestSchema,
|
||||
TaskCancelResponseSchema,
|
||||
TaskResponseSchema,
|
||||
TaskStatusResponseSchema,
|
||||
)
|
||||
from superset.views.base_api import (
|
||||
BaseSupersetModelRestApi,
|
||||
RelatedFieldFilter,
|
||||
statsd_metrics,
|
||||
)
|
||||
from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskRestApi(BaseSupersetModelRestApi):
|
||||
"""REST API for task management"""
|
||||
|
||||
datamodel = SQLAInterface(Task)
|
||||
resource_name = "task"
|
||||
allow_browser_login = True
|
||||
|
||||
class_permission_name = "Task"
|
||||
|
||||
# Map cancel and status to write/read permissions
|
||||
method_permission_name = {
|
||||
**MODEL_API_RW_METHOD_PERMISSION_MAP,
|
||||
"cancel": "write",
|
||||
"status": "read",
|
||||
}
|
||||
|
||||
# Only allow read operations - no create/update/delete through REST API
|
||||
# Tasks are created via SubmitTaskCommand, cancelled via /cancel endpoint
|
||||
include_route_methods = {
|
||||
RouteMethod.GET,
|
||||
RouteMethod.GET_LIST,
|
||||
RouteMethod.INFO,
|
||||
"cancel",
|
||||
"status",
|
||||
"related_subscribers",
|
||||
"related",
|
||||
}
|
||||
|
||||
list_columns = [
|
||||
"id",
|
||||
"uuid",
|
||||
"task_type",
|
||||
"task_key",
|
||||
"task_name",
|
||||
"scope",
|
||||
"status",
|
||||
"created_on",
|
||||
"created_on_delta_humanized",
|
||||
"changed_on",
|
||||
"changed_by.first_name",
|
||||
"changed_by.last_name",
|
||||
"started_at",
|
||||
"ended_at",
|
||||
"created_by.id",
|
||||
"created_by.first_name",
|
||||
"created_by.last_name",
|
||||
"user_id",
|
||||
"payload",
|
||||
"properties",
|
||||
"duration_seconds",
|
||||
"subscriber_count",
|
||||
"subscribers",
|
||||
]
|
||||
|
||||
list_select_columns = list_columns + ["created_by_fk", "changed_by_fk"]
|
||||
|
||||
show_columns = list_columns
|
||||
|
||||
order_columns = [
|
||||
"task_type",
|
||||
"scope",
|
||||
"status",
|
||||
"created_on",
|
||||
"changed_on",
|
||||
"started_at",
|
||||
"ended_at",
|
||||
]
|
||||
|
||||
search_columns = [
|
||||
"task_type",
|
||||
"task_key",
|
||||
"task_name",
|
||||
"scope",
|
||||
"status",
|
||||
"created_by",
|
||||
"created_on",
|
||||
]
|
||||
|
||||
base_order = ("created_on", "desc")
|
||||
base_filters = [["id", TaskFilter, lambda: []]]
|
||||
|
||||
# Related field configuration for filter dropdowns
|
||||
allowed_rel_fields = {"created_by"}
|
||||
related_field_filters = {
|
||||
"created_by": RelatedFieldFilter("first_name", FilterRelatedOwners),
|
||||
}
|
||||
base_related_field_filters = {
|
||||
"created_by": [["id", BaseFilterRelatedUsers, lambda: []]],
|
||||
}
|
||||
|
||||
show_model_schema = TaskResponseSchema()
|
||||
list_model_schema = TaskResponseSchema()
|
||||
cancel_request_schema = TaskCancelRequestSchema()
|
||||
|
||||
openapi_spec_tag = "Tasks"
|
||||
openapi_spec_component_schemas = (
|
||||
TaskResponseSchema,
|
||||
TaskCancelRequestSchema,
|
||||
TaskCancelResponseSchema,
|
||||
TaskStatusResponseSchema,
|
||||
)
|
||||
openapi_spec_methods = openapi_spec_methods_override
|
||||
|
||||
@expose("/<task_uuid>", methods=("GET",))
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def get(self, task_uuid: str) -> Response:
|
||||
"""Get a task.
|
||||
---
|
||||
get:
|
||||
summary: Get a task
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: string
|
||||
format: uuid
|
||||
name: task_uuid
|
||||
description: The UUID of the task
|
||||
responses:
|
||||
200:
|
||||
description: Task detail
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
result:
|
||||
$ref: '#/components/schemas/TaskResponseSchema'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
try:
|
||||
uuid = UUID(task_uuid)
|
||||
task = TaskDAO.find_one_or_none(uuid=uuid)
|
||||
|
||||
if not task:
|
||||
return self.response_404()
|
||||
|
||||
result = self.show_model_schema.dump(task)
|
||||
return self.response(200, result=result)
|
||||
except (ValueError, TypeError):
|
||||
return self.response_404()
|
||||
|
||||
@expose("/<task_uuid>/status", methods=("GET",))
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.status",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def status(self, task_uuid: str) -> Response:
|
||||
"""Get only the status of a task (lightweight for polling).
|
||||
---
|
||||
get:
|
||||
summary: Get task status
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: string
|
||||
format: uuid
|
||||
name: task_uuid
|
||||
description: The UUID of the task
|
||||
responses:
|
||||
200:
|
||||
description: Task status
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
status:
|
||||
type: string
|
||||
description: Current status of the task
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
try:
|
||||
uuid = UUID(task_uuid)
|
||||
status = TaskDAO.get_status(uuid)
|
||||
|
||||
if status is None:
|
||||
return self.response_404()
|
||||
|
||||
return self.response(200, status=status)
|
||||
except (ValueError, TypeError):
|
||||
return self.response_404()
|
||||
|
||||
@expose("/<task_uuid>/cancel", methods=("POST",))
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.cancel",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def cancel(self, task_uuid: str) -> Response:
|
||||
"""Cancel a task.
|
||||
---
|
||||
post:
|
||||
summary: Cancel a task
|
||||
description: >
|
||||
Cancel a task. The behavior depends on task scope and subscriber
|
||||
count:
|
||||
|
||||
- **Private tasks**: Aborts the task
|
||||
- **Shared tasks (single subscriber)**: Aborts the task
|
||||
- **Shared tasks (multiple subscribers)**: Removes current user's
|
||||
subscription; the task continues for other subscribers
|
||||
- **Shared tasks with force=true (admin only)**: Aborts task for
|
||||
all subscribers
|
||||
|
||||
The `action` field in the response indicates what happened:
|
||||
- `aborted`: Task was terminated
|
||||
- `unsubscribed`: User was removed from task (task continues)
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: string
|
||||
format: uuid
|
||||
name: task_uuid
|
||||
description: The UUID of the task to cancel
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/TaskCancelRequestSchema'
|
||||
responses:
|
||||
200:
|
||||
description: Task cancelled successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/TaskCancelResponseSchema'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
"""
|
||||
return self._execute_cancel(task_uuid)
|
||||
|
||||
def _execute_cancel(self, task_uuid_str: str) -> Response:
|
||||
"""Execute the cancel operation with error handling."""
|
||||
try:
|
||||
task_uuid = UUID(task_uuid_str)
|
||||
|
||||
command, updated_task = self._run_cancel_command(task_uuid)
|
||||
return self._build_cancel_response(command, updated_task)
|
||||
|
||||
except TaskNotFoundError:
|
||||
return self.response_404()
|
||||
except (TaskForbiddenError, TaskPermissionDeniedError) as ex:
|
||||
if isinstance(ex, TaskPermissionDeniedError):
|
||||
logger.warning(
|
||||
"Permission denied cancelling task %s: %s",
|
||||
task_uuid_str,
|
||||
str(ex),
|
||||
)
|
||||
return self.response_403()
|
||||
except TaskNotAbortableError as ex:
|
||||
logger.warning("Task %s is not cancellable: %s", task_uuid_str, str(ex))
|
||||
return self.response_422(message=str(ex))
|
||||
except TaskAbortFailedError as ex:
|
||||
logger.error(
|
||||
"Error cancelling task %s: %s", task_uuid_str, str(ex), exc_info=True
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
except (ValueError, TypeError):
|
||||
return self.response_404()
|
||||
|
||||
def _run_cancel_command(self, task_uuid: UUID) -> tuple[CancelTaskCommand, "Task"]:
|
||||
"""Parse request and run the cancel command."""
|
||||
from flask import request
|
||||
|
||||
force = False
|
||||
# Use get_json with silent=True to handle missing Content-Type gracefully
|
||||
json_data = request.get_json(silent=True)
|
||||
if json_data:
|
||||
parsed = self.cancel_request_schema.load(json_data)
|
||||
force = parsed.get("force", False)
|
||||
|
||||
command = CancelTaskCommand(task_uuid, force=force)
|
||||
updated_task = command.run()
|
||||
return command, updated_task
|
||||
|
||||
def _build_cancel_response(
|
||||
self, command: CancelTaskCommand, updated_task: "Task"
|
||||
) -> Response:
|
||||
"""Build the response for a successful cancel operation."""
|
||||
action = command.action_taken
|
||||
message = (
|
||||
"Task cancelled"
|
||||
if action == "aborted"
|
||||
else "You have been removed from this task"
|
||||
)
|
||||
result = {
|
||||
"message": message,
|
||||
"action": action,
|
||||
"task": self.show_model_schema.dump(updated_task),
|
||||
}
|
||||
return self.response(200, **result)
|
||||
|
||||
@expose("/related/subscribers", methods=("GET",))
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
|
||||
".related_subscribers",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def related_subscribers(self) -> Response:
|
||||
"""Get users who are subscribers to tasks.
|
||||
---
|
||||
get:
|
||||
summary: Get related subscribers
|
||||
description: >
|
||||
Returns a list of users who are subscribed to tasks, for use in filter
|
||||
dropdowns. Results can be filtered by a search query parameter.
|
||||
parameters:
|
||||
- in: query
|
||||
schema:
|
||||
type: string
|
||||
name: q
|
||||
description: Search query to filter subscribers by name
|
||||
responses:
|
||||
200:
|
||||
description: List of subscribers
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
count:
|
||||
type: integer
|
||||
description: Total number of matching subscribers
|
||||
result:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
value:
|
||||
type: integer
|
||||
description: User ID
|
||||
text:
|
||||
type: string
|
||||
description: User display name
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
"""
|
||||
from flask import request
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
|
||||
# Get search query
|
||||
|
||||
# Get user model
|
||||
user_model = security_manager.user_model
|
||||
|
||||
# Query distinct users who are task subscribers
|
||||
query = (
|
||||
db.session.query(user_model.id, user_model.first_name, user_model.last_name)
|
||||
.join(TaskSubscriber, user_model.id == TaskSubscriber.user_id)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
# Apply search filter if provided
|
||||
if search_query := request.args.get("q", ""):
|
||||
like_value = f"%{search_query}%"
|
||||
query = query.filter(
|
||||
(user_model.first_name + " " + user_model.last_name).ilike(like_value)
|
||||
| user_model.username.ilike(like_value)
|
||||
)
|
||||
|
||||
# Order by name
|
||||
query = query.order_by(user_model.first_name, user_model.last_name)
|
||||
|
||||
# Limit results
|
||||
query = query.limit(100)
|
||||
|
||||
# Execute and format results
|
||||
results = query.all()
|
||||
|
||||
return self.response(
|
||||
200,
|
||||
count=len(results),
|
||||
result=[
|
||||
{
|
||||
"value": user_id,
|
||||
"text": f"{first_name or ''} {last_name or ''}".strip()
|
||||
or str(user_id),
|
||||
}
|
||||
for user_id, first_name, last_name in results
|
||||
],
|
||||
)
|
||||
54
superset/tasks/constants.py
Normal file
54
superset/tasks/constants.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Constants for the Global Task Framework (GTF)."""
|
||||
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
# Terminal states: Task execution has ended and dedup_key slot is freed
|
||||
TERMINAL_STATES: frozenset[str] = frozenset(
|
||||
{
|
||||
TaskStatus.SUCCESS.value,
|
||||
TaskStatus.FAILURE.value,
|
||||
TaskStatus.ABORTED.value,
|
||||
TaskStatus.TIMED_OUT.value,
|
||||
}
|
||||
)
|
||||
|
||||
# Active states: Task is still in progress and dedup_key is reserved
|
||||
ACTIVE_STATES: frozenset[str] = frozenset(
|
||||
{
|
||||
TaskStatus.PENDING.value,
|
||||
TaskStatus.IN_PROGRESS.value,
|
||||
TaskStatus.ABORTING.value,
|
||||
}
|
||||
)
|
||||
|
||||
# Abortable states: Task can be aborted (for pending or abortable in-progress)
|
||||
ABORTABLE_STATES: frozenset[str] = frozenset(
|
||||
{
|
||||
TaskStatus.PENDING.value,
|
||||
TaskStatus.IN_PROGRESS.value,
|
||||
}
|
||||
)
|
||||
|
||||
# Abort-related states: Task is being or has been aborted
|
||||
ABORT_STATES: frozenset[str] = frozenset(
|
||||
{
|
||||
TaskStatus.ABORTING.value,
|
||||
TaskStatus.ABORTED.value,
|
||||
}
|
||||
)
|
||||
673
superset/tasks/context.py
Normal file
673
superset/tasks/context.py
Normal file
@@ -0,0 +1,673 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Concrete TaskContext implementation for GTF"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar
|
||||
|
||||
from flask import current_app
|
||||
from superset_core.api.tasks import (
|
||||
TaskContext as CoreTaskContext,
|
||||
TaskProperties,
|
||||
TaskStatus,
|
||||
)
|
||||
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
from superset.tasks.constants import ABORT_STATES
|
||||
from superset.tasks.utils import progress_update
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.manager import AbortListener
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TaskContext(CoreTaskContext):
|
||||
"""
|
||||
Concrete implementation of TaskContext for the Global Async Task Framework.
|
||||
|
||||
Provides write-only access to task state. Tasks use this context to update
|
||||
their progress and payload, and check for cancellation. Tasks should not
|
||||
need to read their own state - they are the source of state, not consumers.
|
||||
"""
|
||||
|
||||
# Type alias for handler failures: (handler_type, exception, stack_trace)
|
||||
HandlerFailure = tuple[str, Exception, str]
|
||||
|
||||
def __init__(self, task: "Task") -> None:
|
||||
"""
|
||||
Initialize TaskContext with a pre-fetched task entity.
|
||||
|
||||
The task entity must be pre-fetched by the caller (executor) to ensure
|
||||
caching works correctly and to enforce the pattern of single initial fetch.
|
||||
|
||||
:param task: Pre-fetched Task entity (required)
|
||||
"""
|
||||
self._task_uuid = task.uuid
|
||||
self._cleanup_handlers: list[Callable[[], None]] = []
|
||||
self._abort_handlers: list[Callable[[], None]] = []
|
||||
self._abort_listener: "AbortListener | None" = None
|
||||
self._abort_detected = False
|
||||
self._abort_handlers_completed = False # Track if all abort handlers finished
|
||||
self._execution_completed = False # Set by executor after task work completes
|
||||
|
||||
# Collected handler failures for unified reporting
|
||||
self._handler_failures: list[TaskContext.HandlerFailure] = []
|
||||
|
||||
# Timeout timer state
|
||||
self._timeout_timer: threading.Timer | None = None
|
||||
self._timeout_triggered = False
|
||||
|
||||
# Throttling state for update_task()
|
||||
# These manage the minimum interval between DB writes
|
||||
self._last_db_write_time: float | None = None
|
||||
self._has_pending_updates: bool = False
|
||||
self._deferred_flush_timer: threading.Timer | None = None
|
||||
self._throttle_lock = threading.Lock()
|
||||
|
||||
# Cached task entity - avoids repeated DB fetches.
|
||||
# Updated only by _refresh_task() when checking external state changes.
|
||||
self._task: "Task" = task
|
||||
|
||||
# In-memory state caches - authoritative during execution
|
||||
# These are initialized from the task entity and updated locally
|
||||
# before being written to DB via targeted SQL updates.
|
||||
# We copy the dicts to avoid mutating the Task's cached instances.
|
||||
self._properties_cache: TaskProperties = cast(
|
||||
TaskProperties, {**task.properties_dict}
|
||||
)
|
||||
self._payload_cache: dict[str, Any] = {**task.payload_dict}
|
||||
|
||||
# Store Flask app reference for background thread database access
|
||||
# Use _get_current_object() to get actual app, not proxy
|
||||
try:
|
||||
self._app = current_app._get_current_object()
|
||||
# Cache stats logger to avoid repeated config lookups
|
||||
self._stats_logger: BaseStatsLogger = current_app.config.get(
|
||||
"STATS_LOGGER", BaseStatsLogger()
|
||||
)
|
||||
except RuntimeError:
|
||||
# Handle case where app context isn't available (e.g., tests)
|
||||
self._app = None
|
||||
self._stats_logger = BaseStatsLogger()
|
||||
|
||||
def _refresh_task(self) -> "Task":
|
||||
"""
|
||||
Force refresh the task entity from the database.
|
||||
|
||||
Use this method when you need to check for external state changes,
|
||||
such as whether the task has been aborted by a concurrent operation.
|
||||
|
||||
This method:
|
||||
- Fetches fresh task entity from database
|
||||
- Updates the cached _task reference
|
||||
- Updates properties/payload caches from fresh data
|
||||
|
||||
:returns: Fresh task entity from database
|
||||
:raises ValueError: If task is not found
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
fresh_task = TaskDAO.find_one_or_none(uuid=self._task_uuid)
|
||||
if not fresh_task:
|
||||
raise ValueError(f"Task {self._task_uuid} not found")
|
||||
|
||||
self._task = fresh_task
|
||||
|
||||
# Update caches from fresh data (copy to avoid mutating Task's cache)
|
||||
self._properties_cache = cast(TaskProperties, {**fresh_task.properties_dict})
|
||||
self._payload_cache = {**fresh_task.payload_dict}
|
||||
|
||||
return self._task
|
||||
|
||||
def update_task(
|
||||
self,
|
||||
progress: float | int | tuple[int, int] | None = None,
|
||||
payload: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update task progress and/or payload atomically.
|
||||
|
||||
All parameters are optional. Payload is merged with existing cached data.
|
||||
In-memory caches are always updated immediately, but DB writes are
|
||||
throttled according to TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL to prevent
|
||||
excessive database load from eager tasks.
|
||||
|
||||
Progress can be specified in three ways:
|
||||
- float (0.0-1.0): Percentage only, e.g., 0.5 means 50%
|
||||
- int: Count only (total unknown), e.g., 42 means "42 items processed"
|
||||
- tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100"
|
||||
The percentage is automatically computed from count/total.
|
||||
|
||||
:param progress: Progress value, or None to leave unchanged
|
||||
:param payload: Payload data to merge (dict), or None to leave unchanged
|
||||
"""
|
||||
has_updates = False
|
||||
|
||||
# Handle progress updates - always update in-memory cache
|
||||
if progress is not None:
|
||||
progress_props = progress_update(progress)
|
||||
if progress_props:
|
||||
# Merge progress into cached properties
|
||||
self._properties_cache.update(progress_props)
|
||||
has_updates = True
|
||||
else:
|
||||
# Invalid progress format - progress_update returns empty dict
|
||||
logger.warning(
|
||||
"Invalid progress value for task %s: %s "
|
||||
"(expected float, int, or tuple[int, int])",
|
||||
self._task_uuid,
|
||||
progress,
|
||||
)
|
||||
|
||||
# Handle payload updates - always update in-memory cache
|
||||
if payload is not None:
|
||||
# Merge payload into cached payload
|
||||
self._payload_cache.update(payload)
|
||||
has_updates = True
|
||||
|
||||
if not has_updates:
|
||||
return
|
||||
|
||||
# Get throttle interval from config
|
||||
throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]
|
||||
|
||||
# If throttling is disabled (0), write immediately
|
||||
if throttle_interval <= 0:
|
||||
self._write_to_db()
|
||||
return
|
||||
|
||||
# Apply throttling with deferred flush
|
||||
with self._throttle_lock:
|
||||
now = time.time()
|
||||
|
||||
if self._last_db_write_time is None:
|
||||
# First update - write immediately
|
||||
self._write_to_db()
|
||||
self._last_db_write_time = now
|
||||
elif now - self._last_db_write_time >= throttle_interval:
|
||||
# Throttle window has passed - write immediately
|
||||
self._cancel_deferred_flush_timer()
|
||||
self._write_to_db()
|
||||
self._last_db_write_time = now
|
||||
self._has_pending_updates = False
|
||||
else:
|
||||
# Within throttle window - defer the write
|
||||
self._has_pending_updates = True
|
||||
self._stats_logger.incr("gtf.task.update_deferred")
|
||||
|
||||
# Start deferred flush timer if not already running
|
||||
if self._deferred_flush_timer is None:
|
||||
remaining_time = throttle_interval - (
|
||||
now - self._last_db_write_time
|
||||
)
|
||||
self._deferred_flush_timer = threading.Timer(
|
||||
remaining_time, self._deferred_flush
|
||||
)
|
||||
self._deferred_flush_timer.daemon = True
|
||||
self._deferred_flush_timer.start()
|
||||
|
||||
def _write_to_db(self) -> None:
|
||||
"""
|
||||
Write current cached state to database.
|
||||
|
||||
This method performs the actual DB write using InternalUpdateTaskCommand.
|
||||
It writes whatever is in the caches at the time of the call.
|
||||
"""
|
||||
from superset.commands.tasks.internal_update import InternalUpdateTaskCommand
|
||||
|
||||
self._stats_logger.incr("gtf.task.update_write")
|
||||
|
||||
InternalUpdateTaskCommand(
|
||||
task_uuid=self._task_uuid,
|
||||
properties=self._properties_cache,
|
||||
payload=self._payload_cache,
|
||||
).run()
|
||||
|
||||
def _deferred_flush(self) -> None:
|
||||
"""
|
||||
Timer callback that flushes pending updates at end of throttle window.
|
||||
|
||||
This ensures the UI never shows stale progress for longer than the
|
||||
throttle interval.
|
||||
"""
|
||||
with self._throttle_lock:
|
||||
self._deferred_flush_timer = None
|
||||
|
||||
if self._has_pending_updates:
|
||||
# Need app context for DB operations in timer thread
|
||||
if self._app:
|
||||
with self._app.app_context():
|
||||
self._write_to_db()
|
||||
else:
|
||||
self._write_to_db()
|
||||
|
||||
self._last_db_write_time = time.time()
|
||||
self._has_pending_updates = False
|
||||
|
||||
def _cancel_deferred_flush_timer(self) -> None:
|
||||
"""Cancel the deferred flush timer if running."""
|
||||
if self._deferred_flush_timer is not None:
|
||||
self._deferred_flush_timer.cancel()
|
||||
self._deferred_flush_timer = None
|
||||
|
||||
def on_cleanup(self, handler: Callable[[], None]) -> Callable[[], None]:
|
||||
"""
|
||||
Register a cleanup handler that runs when the task ends.
|
||||
|
||||
Cleanup handlers are called when the task completes (success),
|
||||
fails with an error, or is aborted. Multiple handlers can be
|
||||
registered and will execute in LIFO order (last registered runs first).
|
||||
|
||||
Can be used as a decorator:
|
||||
@ctx.on_cleanup
|
||||
def cleanup():
|
||||
logger.info("Task ended")
|
||||
|
||||
Or called directly:
|
||||
ctx.on_cleanup(lambda: logger.info("Task ended"))
|
||||
|
||||
:param handler: Cleanup function to register
|
||||
:returns: The handler (for decorator compatibility)
|
||||
"""
|
||||
self._cleanup_handlers.append(handler)
|
||||
return handler
|
||||
|
||||
def on_abort(self, handler: Callable[[], None]) -> Callable[[], None]:
|
||||
"""
|
||||
Register abort handler with automatic background listening.
|
||||
|
||||
When the first handler is registered:
|
||||
1. Sets is_abortable=true in the database (marks task as abortable)
|
||||
2. Background abort listener starts automatically (pub/sub or polling)
|
||||
|
||||
The handler will be called automatically when an abort is detected.
|
||||
|
||||
:param handler: Callback function to execute when abort is detected
|
||||
:returns: The handler (for decorator compatibility)
|
||||
|
||||
Example:
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
logger.info("Task was aborted!")
|
||||
cleanup_partial_work()
|
||||
|
||||
Note:
|
||||
The handler executes in a background thread when abort is detected.
|
||||
The task code continues running unless the handler does something
|
||||
to stop it (e.g., raises an exception, modifies shared state, etc.)
|
||||
"""
|
||||
is_first_handler = len(self._abort_handlers) == 0
|
||||
self._abort_handlers.append(handler)
|
||||
|
||||
if is_first_handler:
|
||||
# Mark task as abortable in database
|
||||
self._set_abortable()
|
||||
|
||||
# Auto-start abort listener when first handler is registered
|
||||
interval = current_app.config["TASK_ABORT_POLLING_DEFAULT_INTERVAL"]
|
||||
self._start_abort_listener(interval)
|
||||
|
||||
return handler
|
||||
|
||||
def _set_abortable(self) -> None:
|
||||
"""Mark the task as abortable (abort handler has been registered)."""
|
||||
from superset.commands.tasks.internal_update import InternalUpdateTaskCommand
|
||||
|
||||
# Update local cache and write to DB
|
||||
self._properties_cache["is_abortable"] = True
|
||||
InternalUpdateTaskCommand(
|
||||
task_uuid=self._task_uuid,
|
||||
properties=self._properties_cache,
|
||||
).run()
|
||||
|
||||
def _start_abort_listener(self, interval: float) -> None:
|
||||
"""
|
||||
Start background abort listener via TaskManager.
|
||||
|
||||
Uses Redis pub/sub if available, otherwise falls back to database polling.
|
||||
The implementation is encapsulated in TaskManager.
|
||||
"""
|
||||
if self._abort_listener is not None:
|
||||
return # Already listening
|
||||
|
||||
from superset.tasks.manager import TaskManager
|
||||
|
||||
self._abort_listener = TaskManager.listen_for_abort(
|
||||
task_uuid=self._task_uuid,
|
||||
callback=self._on_abort_detected,
|
||||
poll_interval=interval,
|
||||
app=self._app,
|
||||
)
|
||||
|
||||
def _on_abort_detected(self) -> None:
|
||||
"""
|
||||
Callback invoked by TaskManager when abort is detected.
|
||||
|
||||
Triggers all registered abort handlers.
|
||||
"""
|
||||
if self._abort_detected:
|
||||
return # Already handled
|
||||
|
||||
# Check if task execution has already completed (late abort race).
|
||||
# Executor sets _execution_completed after task work finishes.
|
||||
if self._execution_completed:
|
||||
logger.info(
|
||||
"Abort detected for task %s but execution already completed",
|
||||
self._task_uuid,
|
||||
)
|
||||
return
|
||||
|
||||
self._abort_detected = True
|
||||
logger.info("Abort detected for task %s", self._task_uuid)
|
||||
self._trigger_abort_handlers()
|
||||
|
||||
def mark_execution_completed(self) -> None:
|
||||
"""
|
||||
Mark that the task's main execution has completed.
|
||||
|
||||
Called by the executor after the task function returns (successfully
|
||||
or with an exception). This prevents late abort callbacks from running
|
||||
handlers when the task work has already finished. Cleanup handlers
|
||||
still run after this is set.
|
||||
"""
|
||||
self._execution_completed = True
|
||||
|
||||
def start_abort_polling(self, interval: float | None = None) -> None:
|
||||
"""
|
||||
Start background abort listener.
|
||||
|
||||
This method is kept for backwards compatibility. It now delegates
|
||||
to _start_abort_listener which uses TaskManager.
|
||||
|
||||
:param interval: Polling interval in seconds (uses config default if None)
|
||||
"""
|
||||
if interval is None:
|
||||
interval = current_app.config["TASK_ABORT_POLLING_DEFAULT_INTERVAL"]
|
||||
self._start_abort_listener(interval)
|
||||
|
||||
def _trigger_abort_handlers(self) -> None:
|
||||
"""
|
||||
Execute all registered abort handlers (called by polling thread or cleanup).
|
||||
|
||||
All handlers are attempted even if some fail (best-effort cleanup).
|
||||
Failures are collected in self._handler_failures for unified reporting.
|
||||
|
||||
Note: This method never writes to DB directly. All failures are collected
|
||||
and written by _run_cleanup() in the executor's finally block, ensuring
|
||||
abort and cleanup handler failures are combined into a single record.
|
||||
"""
|
||||
for handler in reversed(self._abort_handlers):
|
||||
try:
|
||||
handler()
|
||||
except Exception as ex:
|
||||
stack_trace = traceback.format_exc()
|
||||
logger.error(
|
||||
"Abort handler failed for task %s: %s",
|
||||
self._task_uuid,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
self._handler_failures.append(("abort", ex, stack_trace))
|
||||
|
||||
# Check if all abort handlers completed successfully
|
||||
abort_failures = [f for f in self._handler_failures if f[0] == "abort"]
|
||||
if not abort_failures:
|
||||
self._abort_handlers_completed = True
|
||||
|
||||
def _write_handler_failures_to_db(self) -> None:
|
||||
"""
|
||||
Write collected handler failures to the database.
|
||||
|
||||
Combines all failures (abort + cleanup) into a single error record.
|
||||
If the task already has an error (e.g., task function threw exception),
|
||||
handler failures are APPENDED to preserve the original error context.
|
||||
"""
|
||||
from superset.commands.tasks.update import UpdateTaskCommand
|
||||
|
||||
if not self._handler_failures:
|
||||
return
|
||||
|
||||
# Build error message from all handler failures
|
||||
error_messages = [str(ex) for _, ex, _ in self._handler_failures]
|
||||
handler_types = {htype for htype, _, _ in self._handler_failures}
|
||||
|
||||
if len(self._handler_failures) == 1:
|
||||
htype, ex, handler_stack_trace = self._handler_failures[0]
|
||||
handler_error_msg = (
|
||||
f"{htype.capitalize()} handler failed: {error_messages[0]}"
|
||||
)
|
||||
handler_exception_type = type(ex).__name__
|
||||
else:
|
||||
# Multiple failures
|
||||
handler_error_msg = f"Handler(s) failed: {'; '.join(error_messages)}"
|
||||
if handler_types == {"abort"}:
|
||||
handler_exception_type = "MultipleAbortHandlerFailures"
|
||||
elif handler_types == {"cleanup"}:
|
||||
handler_exception_type = "MultipleCleanupHandlerFailures"
|
||||
else:
|
||||
handler_exception_type = "MultipleHandlerFailures"
|
||||
|
||||
# Combine stack traces with clear separators
|
||||
handler_stack_trace = "\n--- Next handler failure ---\n".join(
|
||||
f"[{htype}:{type(ex).__name__}]\n{trace}"
|
||||
for htype, ex, trace in self._handler_failures
|
||||
)
|
||||
|
||||
if self._app:
|
||||
with self._app.app_context():
|
||||
# Check if task already has an error (preserve original context)
|
||||
task = self._task
|
||||
original_error = task.properties_dict.get("error_message")
|
||||
original_type = task.properties_dict.get("exception_type")
|
||||
original_trace = task.properties_dict.get("stack_trace")
|
||||
|
||||
if original_error:
|
||||
# Append handler failures to original error
|
||||
error_msg = f"{original_error} | {handler_error_msg}"
|
||||
exception_type = (
|
||||
f"{original_type}+{handler_exception_type}"
|
||||
if original_type
|
||||
else handler_exception_type
|
||||
)
|
||||
stack_trace = (
|
||||
f"{original_trace}\n\n"
|
||||
f"=== Handler failures during cleanup ===\n\n"
|
||||
f"{handler_stack_trace}"
|
||||
if original_trace
|
||||
else handler_stack_trace
|
||||
)
|
||||
else:
|
||||
# No original error, just use handler failures
|
||||
error_msg = handler_error_msg
|
||||
exception_type = handler_exception_type
|
||||
stack_trace = handler_stack_trace
|
||||
|
||||
# Update task with combined error info
|
||||
UpdateTaskCommand(
|
||||
self._task_uuid,
|
||||
status=TaskStatus.FAILURE.value,
|
||||
properties={
|
||||
"error_message": error_msg,
|
||||
"exception_type": exception_type,
|
||||
"stack_trace": stack_trace,
|
||||
},
|
||||
skip_security_check=True,
|
||||
).run()
|
||||
|
||||
# Clear failures after writing
|
||||
self._handler_failures = []
|
||||
|
||||
def stop_abort_polling(self) -> None:
|
||||
"""Stop the background abort listener."""
|
||||
if self._abort_listener is not None:
|
||||
self._abort_listener.stop()
|
||||
self._abort_listener = None
|
||||
|
||||
def start_timeout_timer(self, timeout_seconds: int) -> None:
|
||||
"""
|
||||
Start a timeout timer that triggers abort when elapsed.
|
||||
|
||||
Called by execute_task when task transitions to IN_PROGRESS.
|
||||
Timer only triggers abort handlers if task is abortable.
|
||||
|
||||
:param timeout_seconds: Timeout duration in seconds
|
||||
"""
|
||||
if self._timeout_timer is not None:
|
||||
return # Already started
|
||||
|
||||
def on_timeout() -> None:
|
||||
if self._abort_detected:
|
||||
return # Already aborting
|
||||
|
||||
self._timeout_triggered = True
|
||||
|
||||
# Check if task has abort handler (requires app context)
|
||||
if not self._app:
|
||||
logger.error(
|
||||
"Timeout fired for task %s but no app context available",
|
||||
self._task_uuid,
|
||||
)
|
||||
return
|
||||
|
||||
with self._app.app_context():
|
||||
from superset.commands.tasks.update import UpdateTaskCommand
|
||||
|
||||
task = self._task
|
||||
if task.properties_dict.get("is_abortable", False):
|
||||
logger.info(
|
||||
"Timeout reached for task %s after %d seconds - "
|
||||
"transitioning to ABORTING and triggering abort handlers",
|
||||
self._task_uuid,
|
||||
timeout_seconds,
|
||||
)
|
||||
# Set status to ABORTING (same as user abort)
|
||||
# The executor will determine TIMED_OUT vs FAILURE based on
|
||||
# whether handlers complete successfully
|
||||
UpdateTaskCommand(
|
||||
self._task_uuid,
|
||||
status=TaskStatus.ABORTING.value,
|
||||
properties={"error_message": "Task timed out"},
|
||||
skip_security_check=True,
|
||||
).run()
|
||||
|
||||
# Trigger abort handlers for cleanup
|
||||
self._on_abort_detected()
|
||||
else:
|
||||
# No abort handler - just log warning
|
||||
logger.warning(
|
||||
"Timeout reached for task %s after %d seconds, but no "
|
||||
"abort handler is registered. Task will continue running.",
|
||||
self._task_uuid,
|
||||
timeout_seconds,
|
||||
)
|
||||
|
||||
self._timeout_timer = threading.Timer(timeout_seconds, on_timeout)
|
||||
# Timer is daemon so it won't prevent process exit. If the worker dies,
|
||||
# the task is already in an inconsistent state (stuck IN_PROGRESS) that
|
||||
# requires external recovery (orphan detection). A non-daemon timer with
|
||||
# long timeouts (hours) would block graceful worker shutdown.
|
||||
self._timeout_timer.daemon = True
|
||||
self._timeout_timer.start()
|
||||
logger.debug(
|
||||
"Started timeout timer for task %s: %d seconds",
|
||||
self._task_uuid,
|
||||
timeout_seconds,
|
||||
)
|
||||
|
||||
def stop_timeout_timer(self) -> None:
|
||||
"""Cancel the timeout timer if running."""
|
||||
if self._timeout_timer is not None:
|
||||
self._timeout_timer.cancel()
|
||||
self._timeout_timer = None
|
||||
|
||||
@property
|
||||
def timeout_triggered(self) -> bool:
|
||||
"""Check if the timeout was triggered."""
|
||||
return self._timeout_triggered
|
||||
|
||||
@property
|
||||
def abort_handlers_completed(self) -> bool:
|
||||
"""Check if all abort handlers have completed successfully."""
|
||||
return self._abort_handlers_completed
|
||||
|
||||
def _run_cleanup(self) -> None:
|
||||
"""
|
||||
Run cleanup handlers (called by executor in finally block).
|
||||
|
||||
This runs:
|
||||
1. Flushes any pending throttled updates to ensure final state is persisted
|
||||
2. Abort handlers if task was aborting/aborted (but not yet detected)
|
||||
3. All cleanup handlers (always)
|
||||
|
||||
All handler failures (abort + cleanup) are collected and written to DB
|
||||
as a unified error record at the end.
|
||||
"""
|
||||
# Flush any pending throttled updates before cleanup
|
||||
with self._throttle_lock:
|
||||
self._cancel_deferred_flush_timer()
|
||||
if self._has_pending_updates:
|
||||
self._write_to_db()
|
||||
self._has_pending_updates = False
|
||||
|
||||
# Stop abort listener and timeout timer
|
||||
self.stop_abort_polling()
|
||||
self.stop_timeout_timer()
|
||||
|
||||
# If aborting/aborted but handlers haven't run yet, run them now
|
||||
# (This catches the case where task ended before listener detected abort)
|
||||
if self._app:
|
||||
with self._app.app_context():
|
||||
task = self._task
|
||||
if task.status in ABORT_STATES and not self._abort_detected:
|
||||
self._trigger_abort_handlers()
|
||||
else:
|
||||
# Fallback without app context
|
||||
try:
|
||||
task = self._task
|
||||
if task.status in ABORT_STATES and not self._abort_detected:
|
||||
self._trigger_abort_handlers()
|
||||
except Exception as ex:
|
||||
logger.warning(
|
||||
"Could not check abort status during cleanup for task %s: %s",
|
||||
self._task_uuid,
|
||||
str(ex),
|
||||
)
|
||||
|
||||
# Always run cleanup handlers, collecting failures
|
||||
for handler in reversed(self._cleanup_handlers):
|
||||
try:
|
||||
handler()
|
||||
except Exception as ex:
|
||||
stack_trace = traceback.format_exc()
|
||||
logger.error(
|
||||
"Cleanup handler failed for task %s: %s",
|
||||
self._task_uuid,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
self._handler_failures.append(("cleanup", ex, stack_trace))
|
||||
|
||||
# Write all collected failures (abort + cleanup) to DB as unified record
|
||||
if self._handler_failures:
|
||||
self._write_handler_failures_to_db()
|
||||
609
superset/tasks/decorators.py
Normal file
609
superset/tasks/decorators.py
Normal file
@@ -0,0 +1,609 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Decorators for the Global Task Framework (GTF)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Callable, cast, Generic, ParamSpec, TYPE_CHECKING, TypeVar
|
||||
|
||||
from superset_core.api.tasks import TaskOptions, TaskScope, TaskStatus
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
|
||||
from superset.tasks.ambient_context import use_context
|
||||
from superset.tasks.constants import TERMINAL_STATES
|
||||
from superset.tasks.context import TaskContext
|
||||
from superset.tasks.manager import TaskManager
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
from superset.tasks.utils import generate_random_task_key
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.tasks import Task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def task(
|
||||
func: Callable[P, R] | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
scope: TaskScope = TaskScope.PRIVATE,
|
||||
timeout: int | None = None,
|
||||
) -> Callable[[Callable[P, R]], "TaskWrapper[P]"] | "TaskWrapper[P]":
|
||||
"""
|
||||
Decorator to register a task with default scope.
|
||||
|
||||
Can be used with or without parentheses:
|
||||
@task
|
||||
def my_func(): ...
|
||||
|
||||
@task()
|
||||
def my_func(): ...
|
||||
|
||||
@task(name="custom_name", scope=TaskScope.SHARED)
|
||||
def my_func(): ...
|
||||
|
||||
@task(timeout=300) # 5-minute timeout
|
||||
def long_running_func(): ...
|
||||
|
||||
Args:
|
||||
func: The function to decorate (when used without parentheses).
|
||||
name: Optional unique task name (e.g., "superset.generate_thumbnail").
|
||||
If not provided, uses the function name as the task name.
|
||||
scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM).
|
||||
Defaults to TaskScope.PRIVATE.
|
||||
timeout: Optional timeout in seconds. When the timeout is reached,
|
||||
abort handlers are triggered if registered. Can be overridden
|
||||
at call time via TaskOptions(timeout=...).
|
||||
|
||||
Usage:
|
||||
# Private task (default scope) - no parentheses
|
||||
@task
|
||||
def my_async_func(chart_id: int) -> None:
|
||||
ctx = get_context()
|
||||
...
|
||||
|
||||
# Named task with shared scope
|
||||
@task(name="generate_report", scope=TaskScope.SHARED)
|
||||
def generate_expensive_report(report_id: int) -> None:
|
||||
ctx = get_context()
|
||||
...
|
||||
|
||||
# System task (admin-only)
|
||||
@task(scope=TaskScope.SYSTEM)
|
||||
def cleanup_task() -> None:
|
||||
ctx = get_context()
|
||||
...
|
||||
|
||||
# Task with timeout
|
||||
@task(timeout=300)
|
||||
def long_task() -> None:
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
# Called when timeout is reached or user cancels
|
||||
...
|
||||
|
||||
Note:
|
||||
Both direct calls and .schedule() return Task, regardless of the
|
||||
original function's return type. The decorated function's return value
|
||||
is discarded; only side effects and context updates matter.
|
||||
"""
|
||||
|
||||
def decorator(f: Callable[P, R]) -> "TaskWrapper[P]":
|
||||
# Use function name if no name provided
|
||||
task_name = name if name is not None else f.__name__
|
||||
|
||||
# Create default options with no scope (scope is now in decorator)
|
||||
default_options = TaskOptions()
|
||||
|
||||
# Validate function signature - must not have ctx or options params
|
||||
sig = inspect.signature(f)
|
||||
forbidden = {"ctx", "options"}
|
||||
if any(param in forbidden for param in sig.parameters):
|
||||
raise TypeError(
|
||||
f"Task function {f.__name__} must not define 'ctx' or "
|
||||
"'options' parameters. "
|
||||
f"Use get_context() instead for ambient context access."
|
||||
)
|
||||
|
||||
# Register task
|
||||
TaskRegistry.register(task_name, f)
|
||||
|
||||
# Create wrapper with schedule() method, default options, scope, and timeout
|
||||
wrapper = TaskWrapper(task_name, f, default_options, scope, timeout)
|
||||
|
||||
# Preserve signature for introspection
|
||||
wrapper.__signature__ = sig # type: ignore[attr-defined]
|
||||
|
||||
return wrapper
|
||||
|
||||
if func is None:
|
||||
# Called with parentheses: @task() or @task(name="foo", scope=TaskScope.SHARED)
|
||||
return decorator
|
||||
else:
|
||||
# Called without parentheses: @task
|
||||
return decorator(func)
|
||||
|
||||
|
||||
class TaskWrapper(Generic[P]):
|
||||
"""
|
||||
Wrapper for task functions that provides .schedule() method.
|
||||
|
||||
Both direct calls and .schedule() return Task. The original function's
|
||||
return value is discarded.
|
||||
|
||||
Direct calls execute synchronously, .schedule() runs async via Celery.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
func: Callable[P, R],
|
||||
default_options: TaskOptions,
|
||||
scope: TaskScope = TaskScope.PRIVATE,
|
||||
default_timeout: int | None = None,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.func = func
|
||||
self.default_options = default_options
|
||||
self.scope = scope
|
||||
self.default_timeout = default_timeout
|
||||
self.__name__ = func.__name__
|
||||
self.__doc__ = func.__doc__
|
||||
self.__module__ = func.__module__
|
||||
|
||||
# Patch schedule.__signature__ to mirror function + options parameter
|
||||
# This enables proper IDE support and introspection
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
# Add keyword-only options parameter
|
||||
params.append(
|
||||
inspect.Parameter(
|
||||
"options",
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
default=None,
|
||||
annotation=TaskOptions | None,
|
||||
)
|
||||
)
|
||||
self.schedule.__func__.__signature__ = sig.replace( # type: ignore[attr-defined]
|
||||
parameters=params, return_annotation="Task"
|
||||
)
|
||||
|
||||
def _merge_options(self, override_options: TaskOptions | None) -> TaskOptions:
|
||||
"""
|
||||
Merge decorator defaults with call-time overrides.
|
||||
|
||||
Call-time options take precedence over decorator defaults.
|
||||
For timeout, an explicit None in TaskOptions disables the decorator timeout.
|
||||
|
||||
Args:
|
||||
override_options: Options provided at call time, or None
|
||||
|
||||
Returns:
|
||||
Merged TaskOptions with overrides applied
|
||||
"""
|
||||
if override_options is None:
|
||||
return TaskOptions(
|
||||
task_key=self.default_options.task_key,
|
||||
task_name=self.default_options.task_name,
|
||||
timeout=self.default_timeout, # Use decorator default
|
||||
)
|
||||
|
||||
# Merge: use override if provided, otherwise use default
|
||||
# For timeout: if override_options.timeout is explicitly set (even to None),
|
||||
# use it; otherwise fall back to decorator default
|
||||
return TaskOptions(
|
||||
task_key=override_options.task_key or self.default_options.task_key,
|
||||
task_name=override_options.task_name or self.default_options.task_name,
|
||||
timeout=override_options.timeout
|
||||
if override_options.timeout is not None
|
||||
else self.default_timeout,
|
||||
)
|
||||
|
||||
def _validate_task(self, options: TaskOptions) -> None:
|
||||
"""
|
||||
Validate task configuration before execution.
|
||||
|
||||
Args:
|
||||
options: Merged task options to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails
|
||||
"""
|
||||
# Shared tasks must have an explicit task_key for deduplication
|
||||
if self.scope == TaskScope.SHARED and options.task_key is None:
|
||||
raise ValueError(
|
||||
f"Shared task '{self.name}' requires an explicit task_key in "
|
||||
"TaskOptions for deduplication. Without a task_key, each "
|
||||
"invocation creates a separate task with a random UUID, "
|
||||
"defeating the purpose of shared tasks."
|
||||
)
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Task":
|
||||
"""
|
||||
Call the function synchronously.
|
||||
|
||||
This is invoked when you call the decorated function directly:
|
||||
task = generate_thumbnail(chart_id) # Blocks until completion
|
||||
|
||||
Flow:
|
||||
1. Submit task (create new or join existing via deduplication)
|
||||
2. If joining existing task: wait for it to complete (blocking)
|
||||
3. If new task: execute inline and return completed task
|
||||
|
||||
Sync execution always blocks until completion - even when joining an
|
||||
existing task that's running in another process/worker.
|
||||
|
||||
Returns the Task entity in terminal state (SUCCESS, FAILURE, etc.).
|
||||
|
||||
Raises:
|
||||
GlobalTaskFrameworkDisabledError: If GTF feature flag is not enabled
|
||||
ValueError: If task validation fails
|
||||
TimeoutError: If timeout expires while waiting for existing task
|
||||
"""
|
||||
from superset.commands.tasks.submit import SubmitTaskCommand
|
||||
|
||||
if not is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
|
||||
raise GlobalTaskFrameworkDisabledError()
|
||||
|
||||
# Extract and merge options (decorator defaults + call-time overrides)
|
||||
override_options = cast(TaskOptions | None, kwargs.pop("options", None))
|
||||
options = self._merge_options(override_options)
|
||||
|
||||
# Validate task configuration
|
||||
self._validate_task(options)
|
||||
|
||||
# Extract task_name and task_key from merged options, scope from decorator
|
||||
task_name = (
|
||||
options.task_name or f"{self.name}:{generate_random_task_key()[:50]}"
|
||||
)
|
||||
task_key = options.task_key or generate_random_task_key()
|
||||
scope = self.scope # Use scope from decorator
|
||||
|
||||
# Build properties with execution_mode and timeout
|
||||
properties: dict[str, str | int] = {"execution_mode": "sync"}
|
||||
if options.timeout:
|
||||
properties["timeout"] = options.timeout
|
||||
|
||||
# Submit task - may create new or join existing
|
||||
task, is_new = SubmitTaskCommand(
|
||||
{
|
||||
"task_type": self.name,
|
||||
"task_key": task_key,
|
||||
"task_name": task_name,
|
||||
"scope": scope.value,
|
||||
"properties": properties,
|
||||
"user_id": get_user_id(),
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
# If joining existing task, wait for it to complete
|
||||
if not is_new:
|
||||
return self._wait_for_existing_task(task, options.timeout)
|
||||
|
||||
# New task - execute inline
|
||||
return self._execute_inline(task, options, args, kwargs)
|
||||
|
||||
def _wait_for_existing_task(self, task: "Task", timeout: int | None) -> "Task":
|
||||
"""
|
||||
Wait for an existing task to complete.
|
||||
|
||||
Called when sync execution joins a pre-existing task via deduplication.
|
||||
Blocks until the task reaches a terminal state.
|
||||
|
||||
:param task: The existing task to wait for
|
||||
:param timeout: Maximum time to wait in seconds (None = no limit)
|
||||
:returns: Task in terminal state
|
||||
:raises TimeoutError: If timeout expires before task completes
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
# Check if already in terminal state
|
||||
if task.status in TERMINAL_STATES:
|
||||
logger.info(
|
||||
"Joined already-completed task %s (uuid=%s, status=%s)",
|
||||
self.name,
|
||||
task.uuid,
|
||||
task.status,
|
||||
)
|
||||
return task
|
||||
|
||||
# Wait for the existing task to complete
|
||||
logger.info(
|
||||
"Joined active task %s (uuid=%s, status=%s), waiting for completion",
|
||||
self.name,
|
||||
task.uuid,
|
||||
task.status,
|
||||
)
|
||||
|
||||
try:
|
||||
app = current_app._get_current_object()
|
||||
except RuntimeError:
|
||||
app = None
|
||||
|
||||
try:
|
||||
task = TaskManager.wait_for_completion(
|
||||
task_uuid=task.uuid,
|
||||
timeout=float(timeout) if timeout else None,
|
||||
poll_interval=1.0,
|
||||
app=app,
|
||||
)
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) completed with status=%s",
|
||||
self.name,
|
||||
task.uuid,
|
||||
task.status,
|
||||
)
|
||||
return task
|
||||
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Timeout waiting for task %s (uuid=%s)",
|
||||
self.name,
|
||||
task.uuid,
|
||||
)
|
||||
# Return task in current state (caller can check status)
|
||||
refreshed = TaskDAO.find_one_or_none(uuid=task.uuid)
|
||||
return refreshed if refreshed else task
|
||||
|
||||
def _execute_inline(
|
||||
self,
|
||||
task: "Task",
|
||||
options: TaskOptions,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> "Task":
|
||||
"""
|
||||
Execute task function inline (synchronously).
|
||||
|
||||
Called when this is a new task (not joining existing).
|
||||
Uses atomic conditional status transitions for race-safe execution.
|
||||
|
||||
:param task: The newly created task
|
||||
:param options: Merged task options
|
||||
:param args: Positional arguments for the task function
|
||||
:param kwargs: Keyword arguments for the task function
|
||||
:returns: Task in terminal state
|
||||
"""
|
||||
from superset.commands.tasks.internal_update import (
|
||||
InternalStatusTransitionCommand,
|
||||
)
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.tasks.constants import ABORT_STATES
|
||||
|
||||
# PRE-EXECUTION CHECK: Don't execute if already aborted/aborting
|
||||
# (Matches async flow in scheduler.py)
|
||||
if task.status in ABORT_STATES:
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) was aborted before execution started",
|
||||
self.name,
|
||||
task.uuid,
|
||||
)
|
||||
# Ensure status is ABORTED (not just ABORTING)
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.ABORTED,
|
||||
expected_status=[TaskStatus.PENDING, TaskStatus.ABORTING],
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
# Refresh to get updated task
|
||||
refreshed = TaskDAO.find_one_or_none(uuid=task.uuid)
|
||||
return refreshed if refreshed else task
|
||||
|
||||
# Atomic transition: PENDING → IN_PROGRESS (set started_at for duration
|
||||
# tracking)
|
||||
task_uuid = task.uuid # Cache UUID before any potential state changes
|
||||
if not InternalStatusTransitionCommand(
|
||||
task_uuid=task_uuid,
|
||||
new_status=TaskStatus.IN_PROGRESS,
|
||||
expected_status=TaskStatus.PENDING,
|
||||
set_started_at=True,
|
||||
).run():
|
||||
# Status wasn't PENDING - task may have been aborted concurrently
|
||||
logger.warning(
|
||||
"Task %s (uuid=%s) failed PENDING → IN_PROGRESS transition "
|
||||
"(may have been aborted concurrently)",
|
||||
self.name,
|
||||
task_uuid,
|
||||
)
|
||||
refreshed = TaskDAO.find_one_or_none(uuid=task_uuid)
|
||||
return refreshed if refreshed else task
|
||||
|
||||
# Update cached status (no DB read needed - we just wrote IN_PROGRESS)
|
||||
task.status = TaskStatus.IN_PROGRESS.value
|
||||
|
||||
# Build context with the updated task entity
|
||||
ctx = TaskContext(task)
|
||||
|
||||
# Start timeout timer if configured
|
||||
if options.timeout:
|
||||
ctx.start_timeout_timer(options.timeout)
|
||||
logger.debug(
|
||||
"Started timeout timer for task %s: %d seconds",
|
||||
task.uuid,
|
||||
options.timeout,
|
||||
)
|
||||
|
||||
# Track final task state for completion notification
|
||||
final_task: Task | None = None
|
||||
|
||||
try:
|
||||
# Execute with ambient context
|
||||
with use_context(ctx):
|
||||
self.func(*args, **kwargs)
|
||||
|
||||
# Determine terminal status based on abort detection
|
||||
# Use atomic conditional updates to prevent overwriting concurrent abort
|
||||
if ctx._abort_detected or ctx.timeout_triggered:
|
||||
# Abort was detected - transition ABORTING → terminal
|
||||
if ctx.timeout_triggered:
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=task_uuid,
|
||||
new_status=TaskStatus.TIMED_OUT,
|
||||
expected_status=TaskStatus.ABORTING,
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) timed out and completed cleanup",
|
||||
self.name,
|
||||
task_uuid,
|
||||
)
|
||||
else:
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=task_uuid,
|
||||
new_status=TaskStatus.ABORTED,
|
||||
expected_status=TaskStatus.ABORTING,
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) was aborted by user",
|
||||
self.name,
|
||||
task_uuid,
|
||||
)
|
||||
else:
|
||||
# Normal completion - atomic IN_PROGRESS → SUCCESS
|
||||
# This will fail (return False) if task was concurrently aborted
|
||||
if InternalStatusTransitionCommand(
|
||||
task_uuid=task_uuid,
|
||||
new_status=TaskStatus.SUCCESS,
|
||||
expected_status=TaskStatus.IN_PROGRESS,
|
||||
set_ended_at=True,
|
||||
).run():
|
||||
logger.debug(
|
||||
"Synchronous execution of task %s (uuid=%s) "
|
||||
"completed successfully",
|
||||
self.name,
|
||||
task_uuid,
|
||||
)
|
||||
else:
|
||||
# Transition failed - task was likely aborted concurrently
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) IN_PROGRESS → SUCCESS failed "
|
||||
"(may have been aborted concurrently)",
|
||||
self.name,
|
||||
task_uuid,
|
||||
)
|
||||
|
||||
# Refresh once at end to return current state
|
||||
final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
|
||||
return final_task if final_task else task
|
||||
|
||||
except Exception as ex:
|
||||
# Atomic transition to FAILURE (only if still IN_PROGRESS)
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=task_uuid,
|
||||
new_status=TaskStatus.FAILURE,
|
||||
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
|
||||
properties={"error_message": str(ex)},
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
|
||||
logger.error(
|
||||
"Synchronous execution of task %s (uuid=%s) failed: %s",
|
||||
self.name,
|
||||
task_uuid,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Refresh once at end to return current state
|
||||
final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
|
||||
return final_task if final_task else task
|
||||
|
||||
finally:
|
||||
# Always clean up timer and handlers
|
||||
ctx._run_cleanup()
|
||||
|
||||
# Publish completion notification for any waiters
|
||||
# Use final_task if set by try/except, otherwise refresh (fallback)
|
||||
if final_task is None:
|
||||
final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
|
||||
if final_task and final_task.status in TERMINAL_STATES:
|
||||
TaskManager.publish_completion(task_uuid, final_task.status)
|
||||
|
||||
def schedule(self, *args: P.args, **kwargs: P.kwargs) -> "Task":
|
||||
"""
|
||||
Schedule this task for asynchronous execution.
|
||||
|
||||
The signature mirrors the original task function, with an additional
|
||||
keyword-only 'options' parameter for execution metadata.
|
||||
|
||||
Args:
|
||||
*args, **kwargs: Business arguments for the task function
|
||||
options: Execution options
|
||||
|
||||
Returns:
|
||||
Task model representing the scheduled task (PENDING status)
|
||||
|
||||
Raises:
|
||||
GlobalTaskFrameworkDisabledError: If GTF feature flag is not enabled
|
||||
ValueError: If task is SHARED scope but no task_key is provided
|
||||
|
||||
Usage:
|
||||
# Auto-generated task_key (random UUID, no deduplication):
|
||||
task = generate_thumbnail.schedule(chart_id)
|
||||
|
||||
# Custom task_key for task deduplication:
|
||||
task = generate_thumbnail.schedule(
|
||||
chart_id,
|
||||
options=TaskOptions(task_key=f"thumb_{chart_id}")
|
||||
)
|
||||
|
||||
# SHARED tasks require task_key:
|
||||
task = shared_task.schedule(
|
||||
data_id,
|
||||
options=TaskOptions(task_key=f"shared_{data_id}")
|
||||
)
|
||||
|
||||
Note: Unlike direct calls (__call__), this schedules async execution.
|
||||
The function returns immediately with the Task model in PENDING status.
|
||||
"""
|
||||
if not is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
|
||||
raise GlobalTaskFrameworkDisabledError()
|
||||
|
||||
# Extract and merge options (decorator defaults + call-time overrides)
|
||||
override_options = cast(TaskOptions | None, kwargs.pop("options", None))
|
||||
options = self._merge_options(override_options)
|
||||
|
||||
# Validate task configuration
|
||||
self._validate_task(options)
|
||||
|
||||
# Extract task_name and task_key from merged options, scope from decorator
|
||||
task_name = options.task_name
|
||||
task_key = options.task_key
|
||||
scope = self.scope # Use scope from decorator
|
||||
|
||||
# Create task entry in metastore and schedule execution
|
||||
return TaskManager.submit_task(
|
||||
task_type=self.name,
|
||||
task_name=task_name,
|
||||
task_key=task_key,
|
||||
scope=scope,
|
||||
timeout=options.timeout,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
112
superset/tasks/filters.py
Normal file
112
superset/tasks/filters.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Filters for Task model"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm.query import Query
|
||||
|
||||
from superset.utils.core import get_user_id
|
||||
from superset.views.base import BaseFilter
|
||||
|
||||
|
||||
class TaskFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Filter for Task that shows tasks based on scope and user permissions.
|
||||
|
||||
Filtering rules:
|
||||
- Admins: See all tasks (private, shared, system)
|
||||
- Non-admins:
|
||||
- Private tasks: Only their own tasks
|
||||
- Shared tasks: Tasks they're subscribed to
|
||||
- System tasks: None (admin-only)
|
||||
"""
|
||||
|
||||
def apply(self, query: Query, value: Any) -> Query:
|
||||
"""Apply the filter to the query."""
|
||||
from flask import g, has_request_context
|
||||
from sqlalchemy import or_
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
from superset.models.tasks import Task
|
||||
|
||||
# If no request context or no user, return unfiltered query
|
||||
# (this handles background tasks and system operations)
|
||||
if not has_request_context() or not hasattr(g, "user"):
|
||||
return query
|
||||
|
||||
# If user is admin, return unfiltered query
|
||||
if security_manager.is_admin():
|
||||
return query
|
||||
|
||||
# For non-admins, filter by scope and permissions
|
||||
user_id = get_user_id()
|
||||
|
||||
# Use subquery for shared tasks to avoid join ambiguity
|
||||
shared_task_ids_query = (
|
||||
db.session.query(Task.id)
|
||||
.join(TaskSubscriber, Task.id == TaskSubscriber.task_id)
|
||||
.filter(
|
||||
Task.scope == "shared",
|
||||
TaskSubscriber.user_id == user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Build filter conditions:
|
||||
# 1. Private tasks created by current user
|
||||
# 2. Shared tasks where user is subscribed (via subquery)
|
||||
# 3. System tasks are excluded (admin-only)
|
||||
return query.filter(
|
||||
or_(
|
||||
# Own private tasks
|
||||
(Task.scope == "private") & (Task.created_by_fk == user_id),
|
||||
# Shared tasks where user is subscribed
|
||||
Task.id.in_(shared_task_ids_query),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TaskSubscriberFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Filter tasks by subscriber user ID.
|
||||
|
||||
This filter allows finding tasks where a specific user is subscribed.
|
||||
Used by the frontend for the subscribers filter dropdown.
|
||||
"""
|
||||
|
||||
def apply(self, query: Query, value: Any) -> Query:
|
||||
"""Apply the filter to the query."""
|
||||
from superset import db
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
from superset.models.tasks import Task
|
||||
|
||||
if not value:
|
||||
return query
|
||||
|
||||
# Handle both single ID and list of IDs
|
||||
if isinstance(value, (list, tuple)):
|
||||
user_ids = [int(v) for v in value]
|
||||
else:
|
||||
user_ids = [int(value)]
|
||||
|
||||
# Find tasks where any of these users are subscribers
|
||||
subscribed_task_ids = db.session.query(TaskSubscriber.task_id).filter(
|
||||
TaskSubscriber.user_id.in_(user_ids)
|
||||
)
|
||||
|
||||
return query.filter(Task.id.in_(subscribed_task_ids))
|
||||
81
superset/tasks/locks.py
Normal file
81
superset/tasks/locks.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Distributed locking utilities for the Global Task Framework (GTF).
|
||||
|
||||
This module provides distributed locks for task operations to prevent race
|
||||
conditions during concurrent task creation, subscription, and cancellation.
|
||||
|
||||
The lock key uses the task's dedup_key, ensuring all operations on the same
|
||||
logical task serialize correctly.
|
||||
|
||||
When SIGNAL_CACHE_CONFIG is configured, uses Redis SET NX EX for
|
||||
efficient single-command locking. Otherwise falls back to database-backed
|
||||
locking via DistributedLock.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
from superset.distributed_lock import DistributedLock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Task operations use a shorter TTL than the global default since
|
||||
# they complete quickly (just DB operations, no external calls)
|
||||
TASK_LOCK_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@contextmanager
|
||||
def task_lock(dedup_key: str) -> Iterator[None]:
|
||||
"""
|
||||
Acquire a distributed lock for task operations.
|
||||
|
||||
Uses the task's dedup_key as the lock key. All operations on the same
|
||||
logical task (create, subscribe, cancel) use the same lock, ensuring
|
||||
mutual exclusion. This prevents race conditions such as:
|
||||
- Two concurrent creates with the same key
|
||||
- Subscribe racing with cancel
|
||||
- Multiple concurrent cancel requests
|
||||
|
||||
When SIGNAL_CACHE_CONFIG is configured, uses Redis SET NX EX
|
||||
for efficient single-command locking. Otherwise falls back to
|
||||
database-backed DistributedLock.
|
||||
|
||||
:param dedup_key: Task deduplication key (from get_active_dedup_key)
|
||||
:yields: Nothing; used as context manager
|
||||
:raises AcquireDistributedLockFailedException: If lock is already held
|
||||
|
||||
Example:
|
||||
dedup_key = get_active_dedup_key(TaskScope.SHARED, "report", "monthly")
|
||||
with task_lock(dedup_key):
|
||||
# Create, subscribe, or cancel task here
|
||||
...
|
||||
"""
|
||||
logger.debug("Acquiring task lock for key: %s", dedup_key)
|
||||
|
||||
with DistributedLock(
|
||||
namespace="gtf:task",
|
||||
key=dedup_key,
|
||||
ttl_seconds=TASK_LOCK_TTL_SECONDS,
|
||||
):
|
||||
yield
|
||||
|
||||
logger.debug("Released task lock for key: %s", dedup_key)
|
||||
764
superset/tasks/manager.py
Normal file
764
superset/tasks/manager.py
Normal file
@@ -0,0 +1,764 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Task manager for the Global Task Framework (GTF)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
from superset_core.api.tasks import TaskProperties, TaskScope
|
||||
|
||||
from superset.async_events.cache_backend import (
|
||||
RedisCacheBackend,
|
||||
RedisSentinelCacheBackend,
|
||||
)
|
||||
from superset.extensions import cache_manager
|
||||
from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES
|
||||
from superset.tasks.utils import generate_random_task_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flask import Flask
|
||||
|
||||
from superset.models.tasks import Task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbortListener:
|
||||
"""
|
||||
Handle for a background abort listener.
|
||||
|
||||
Returned by TaskManager.listen_for_abort() to allow stopping the listener.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_uuid: UUID,
|
||||
thread: threading.Thread,
|
||||
stop_event: threading.Event,
|
||||
pubsub: redis.client.PubSub | None = None,
|
||||
) -> None:
|
||||
self._task_uuid = task_uuid
|
||||
self._thread = thread
|
||||
self._stop_event = stop_event
|
||||
self._pubsub = pubsub
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the abort listener."""
|
||||
self._stop_event.set()
|
||||
|
||||
# Close pub/sub subscription if active
|
||||
if self._pubsub is not None:
|
||||
try:
|
||||
self._pubsub.unsubscribe()
|
||||
self._pubsub.close()
|
||||
except Exception as ex:
|
||||
logger.debug("Error closing pub/sub during stop: %s", ex)
|
||||
|
||||
# Wait for thread to finish (with timeout to avoid blocking indefinitely)
|
||||
if self._thread.is_alive():
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
# Check if thread is still running after timeout
|
||||
if self._thread.is_alive():
|
||||
# Thread is a daemon, so it will be killed when process exits.
|
||||
# Log warning but continue - cleanup will still proceed.
|
||||
logger.warning(
|
||||
"Abort listener thread for task %s did not terminate within "
|
||||
"2 seconds. Thread will be terminated when process exits.",
|
||||
self._task_uuid,
|
||||
)
|
||||
else:
|
||||
logger.debug("Stopped abort listener for task %s", self._task_uuid)
|
||||
else:
|
||||
logger.debug("Stopped abort listener for task %s", self._task_uuid)
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""
|
||||
Handles task creation, scheduling, and abort notifications.
|
||||
|
||||
The TaskManager is responsible for:
|
||||
1. Creating task entries in the metastore (Task model)
|
||||
2. Scheduling task execution via Celery
|
||||
3. Handling deduplication (returning existing active task if duplicate)
|
||||
4. Managing real-time abort notifications (optional)
|
||||
|
||||
Redis pub/sub is opt-in via SIGNAL_CACHE_CONFIG configuration. When not
|
||||
configured, tasks use database polling for abort detection.
|
||||
"""
|
||||
|
||||
# Class-level state (initialized once via init_app)
|
||||
_channel_prefix: str = "gtf:abort:"
|
||||
_completion_channel_prefix: str = "gtf:complete:"
|
||||
_initialized: bool = False
|
||||
|
||||
# Backward compatibility alias - prefer importing from superset.tasks.constants
|
||||
TERMINAL_STATES = TERMINAL_STATES
|
||||
|
||||
@classmethod
|
||||
def init_app(cls, app: Flask) -> None:
|
||||
"""
|
||||
Initialize the TaskManager with Flask app config.
|
||||
|
||||
Redis connection is managed by CacheManager - this just reads channel prefixes.
|
||||
|
||||
:param app: Flask application instance
|
||||
"""
|
||||
if cls._initialized:
|
||||
return
|
||||
|
||||
cls._channel_prefix = app.config.get("TASKS_ABORT_CHANNEL_PREFIX", "gtf:abort:")
|
||||
cls._completion_channel_prefix = app.config.get(
|
||||
"TASKS_COMPLETION_CHANNEL_PREFIX", "gtf:complete:"
|
||||
)
|
||||
|
||||
cls._initialized = True
|
||||
|
||||
@classmethod
|
||||
def _get_cache(cls) -> RedisCacheBackend | RedisSentinelCacheBackend | None:
|
||||
"""
|
||||
Get the signal cache backend.
|
||||
|
||||
:returns: The signal cache backend, or None if not configured
|
||||
"""
|
||||
return cache_manager.signal_cache
|
||||
|
||||
@classmethod
|
||||
def is_pubsub_available(cls) -> bool:
|
||||
"""
|
||||
Check if Redis pub/sub backend is configured and available.
|
||||
|
||||
:returns: True if Redis is available for pub/sub, False otherwise
|
||||
"""
|
||||
return cls._get_cache() is not None
|
||||
|
||||
@classmethod
|
||||
def get_abort_channel(cls, task_uuid: UUID) -> str:
|
||||
"""
|
||||
Get the abort channel name for a task.
|
||||
|
||||
:param task_uuid: UUID of the task
|
||||
:returns: Channel name for the task's abort notifications
|
||||
"""
|
||||
return f"{cls._channel_prefix}{task_uuid}"
|
||||
|
||||
@classmethod
|
||||
def publish_abort(cls, task_uuid: UUID) -> bool:
|
||||
"""
|
||||
Publish an abort message to the task's channel.
|
||||
|
||||
:param task_uuid: UUID of the task to abort
|
||||
:returns: True if message was published, False if Redis unavailable
|
||||
"""
|
||||
cache = cls._get_cache()
|
||||
if not cache:
|
||||
return False
|
||||
|
||||
try:
|
||||
channel = cls.get_abort_channel(task_uuid)
|
||||
subscriber_count = cache.publish(channel, "abort")
|
||||
logger.debug(
|
||||
"Published abort to channel %s (%d subscribers)",
|
||||
channel,
|
||||
subscriber_count,
|
||||
)
|
||||
return True
|
||||
except redis.RedisError as ex:
|
||||
logger.error("Failed to publish abort for task %s: %s", task_uuid, ex)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_completion_channel(cls, task_uuid: UUID) -> str:
|
||||
"""
|
||||
Get the completion channel name for a task.
|
||||
|
||||
:param task_uuid: UUID of the task
|
||||
:returns: Channel name for the task's completion notifications
|
||||
"""
|
||||
return f"{cls._completion_channel_prefix}{task_uuid}"
|
||||
|
||||
@classmethod
|
||||
def publish_completion(cls, task_uuid: UUID, status: str) -> bool:
|
||||
"""
|
||||
Publish a completion message to the task's channel.
|
||||
|
||||
Called when task reaches terminal state (SUCCESS, FAILURE, ABORTED, TIMED_OUT).
|
||||
This notifies any waiters (e.g., sync callers waiting for an existing task).
|
||||
|
||||
:param task_uuid: UUID of the completed task
|
||||
:param status: Final status of the task
|
||||
:returns: True if message was published, False if Redis unavailable
|
||||
"""
|
||||
cache = cls._get_cache()
|
||||
if not cache:
|
||||
return False
|
||||
|
||||
try:
|
||||
channel = cls.get_completion_channel(task_uuid)
|
||||
subscriber_count = cache.publish(channel, status)
|
||||
logger.debug(
|
||||
"Published completion to channel %s (status=%s, %d subscribers)",
|
||||
channel,
|
||||
status,
|
||||
subscriber_count,
|
||||
)
|
||||
return True
|
||||
except redis.RedisError as ex:
|
||||
logger.error("Failed to publish completion for task %s: %s", task_uuid, ex)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def wait_for_completion(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
timeout: float | None = None,
|
||||
poll_interval: float = 1.0,
|
||||
app: Any = None,
|
||||
) -> "Task":
|
||||
"""
|
||||
Block until task reaches terminal state.
|
||||
|
||||
Uses Redis pub/sub if configured for low-latency, low-CPU waiting.
|
||||
Uses database polling if Redis is not configured.
|
||||
|
||||
:param task_uuid: UUID of the task to wait for
|
||||
:param timeout: Maximum time to wait in seconds (None = no limit)
|
||||
:param poll_interval: Interval for database polling (seconds)
|
||||
:param app: Flask app for database access
|
||||
:returns: Task in terminal state
|
||||
:raises TimeoutError: If timeout expires before task completes
|
||||
:raises ValueError: If task not found
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
def time_remaining() -> float | None:
|
||||
if timeout is None:
|
||||
return None
|
||||
elapsed = time.monotonic() - start_time
|
||||
remaining = timeout - elapsed
|
||||
return remaining if remaining > 0 else 0
|
||||
|
||||
def get_task() -> "Task | None":
|
||||
if app:
|
||||
with app.app_context():
|
||||
return TaskDAO.find_one_or_none(uuid=task_uuid)
|
||||
return TaskDAO.find_one_or_none(uuid=task_uuid)
|
||||
|
||||
# Check current state first
|
||||
task = get_task()
|
||||
if not task:
|
||||
raise ValueError(f"Task {task_uuid} not found")
|
||||
|
||||
if task.status in cls.TERMINAL_STATES:
|
||||
return task
|
||||
|
||||
logger.debug(
|
||||
"Waiting for task %s to complete (current status=%s, timeout=%s)",
|
||||
task_uuid,
|
||||
task.status,
|
||||
timeout,
|
||||
)
|
||||
|
||||
# Use Redis pub/sub if configured
|
||||
if (cache := cls._get_cache()) is not None:
|
||||
task = cls._wait_via_pubsub(
|
||||
task_uuid,
|
||||
cache.pubsub(),
|
||||
timeout,
|
||||
poll_interval,
|
||||
get_task,
|
||||
time_remaining,
|
||||
)
|
||||
if task:
|
||||
return task
|
||||
# Should not reach here - _wait_via_pubsub returns task or raises
|
||||
raise RuntimeError(f"Unexpected state waiting for task {task_uuid}")
|
||||
|
||||
# Use database polling when Redis is not configured
|
||||
return cls._wait_via_polling(task_uuid, poll_interval, get_task, time_remaining)
|
||||
|
||||
@classmethod
|
||||
def _wait_via_pubsub(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
pubsub: redis.client.PubSub,
|
||||
timeout: float | None,
|
||||
poll_interval: float,
|
||||
get_task: Callable[[], "Task | None"],
|
||||
time_remaining: Callable[[], float | None],
|
||||
) -> "Task | None":
|
||||
"""
|
||||
Wait for task completion using Redis pub/sub.
|
||||
|
||||
:returns: Task when completed
|
||||
:raises TimeoutError: If timeout expires
|
||||
:raises redis.RedisError: If Redis connection fails
|
||||
"""
|
||||
channel = cls.get_completion_channel(task_uuid)
|
||||
pubsub.subscribe(channel)
|
||||
|
||||
try:
|
||||
while True:
|
||||
remaining = time_remaining()
|
||||
if remaining is not None and remaining <= 0:
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for task {task_uuid} to complete"
|
||||
)
|
||||
|
||||
# Wait for message with short timeout for responsive checking
|
||||
wait_time = min(1.0, remaining) if remaining else 1.0
|
||||
message = pubsub.get_message(
|
||||
ignore_subscribe_messages=True,
|
||||
timeout=wait_time,
|
||||
)
|
||||
|
||||
if message and message.get("type") == "message":
|
||||
# Completion received - fetch fresh task state
|
||||
logger.debug(
|
||||
"Received completion message for task %s: %s",
|
||||
task_uuid,
|
||||
message.get("data"),
|
||||
)
|
||||
task = get_task()
|
||||
if task and task.status in cls.TERMINAL_STATES:
|
||||
return task
|
||||
|
||||
# Also check database periodically in case we missed the message
|
||||
# (e.g., task completed before we subscribed)
|
||||
task = get_task()
|
||||
if task and task.status in cls.TERMINAL_STATES:
|
||||
logger.debug(
|
||||
"Task %s completed (detected via db check): status=%s",
|
||||
task_uuid,
|
||||
task.status,
|
||||
)
|
||||
return task
|
||||
|
||||
finally:
|
||||
pubsub.unsubscribe()
|
||||
pubsub.close()
|
||||
|
||||
@classmethod
|
||||
def _wait_via_polling(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
poll_interval: float,
|
||||
get_task: Callable[[], "Task | None"],
|
||||
time_remaining: Callable[[], float | None],
|
||||
) -> "Task":
|
||||
"""
|
||||
Wait for task completion using database polling.
|
||||
|
||||
:returns: Task when completed
|
||||
:raises TimeoutError: If timeout expires
|
||||
:raises ValueError: If task not found
|
||||
"""
|
||||
while True:
|
||||
remaining = time_remaining()
|
||||
if remaining is not None and remaining <= 0:
|
||||
raise TimeoutError(f"Timeout waiting for task {task_uuid} to complete")
|
||||
|
||||
task = get_task()
|
||||
if not task:
|
||||
raise ValueError(f"Task {task_uuid} not found")
|
||||
|
||||
if task.status in cls.TERMINAL_STATES:
|
||||
logger.debug(
|
||||
"Task %s completed (detected via polling): status=%s",
|
||||
task_uuid,
|
||||
task.status,
|
||||
)
|
||||
return task
|
||||
|
||||
# Sleep with timeout awareness
|
||||
sleep_time = min(poll_interval, remaining) if remaining else poll_interval
|
||||
time.sleep(sleep_time)
|
||||
|
||||
@classmethod
|
||||
def listen_for_abort(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
callback: Callable[[], None],
|
||||
poll_interval: float,
|
||||
app: Any = None,
|
||||
) -> AbortListener:
|
||||
"""
|
||||
Start listening for abort notifications for a task.
|
||||
|
||||
Uses Redis pub/sub if configured, otherwise uses database polling.
|
||||
The callback is invoked when an abort is detected.
|
||||
|
||||
:param task_uuid: UUID of the task to monitor (native UUID)
|
||||
:param callback: Function to call when abort is detected
|
||||
:param poll_interval: Interval for database polling (when Redis not configured)
|
||||
:param app: Flask app for database access in background thread
|
||||
:returns: AbortListener handle to stop listening
|
||||
"""
|
||||
stop_event = threading.Event()
|
||||
pubsub: redis.client.PubSub | None = None
|
||||
uuid_str = str(task_uuid)
|
||||
|
||||
# Use Redis pub/sub if configured
|
||||
if (cache := cls._get_cache()) is not None:
|
||||
pubsub = cache.pubsub()
|
||||
channel = cls.get_abort_channel(task_uuid)
|
||||
pubsub.subscribe(channel)
|
||||
logger.debug("Subscribed to abort channel: %s", channel)
|
||||
|
||||
# Start pub/sub listener thread
|
||||
thread = threading.Thread(
|
||||
target=cls._listen_pubsub,
|
||||
args=(task_uuid, pubsub, callback, stop_event, app),
|
||||
daemon=True,
|
||||
name=f"abort-listener-{uuid_str[:8]}",
|
||||
)
|
||||
logger.debug("Started pub/sub abort listener for task %s", task_uuid)
|
||||
else:
|
||||
# Use polling when Redis is not configured
|
||||
pubsub = None
|
||||
thread = threading.Thread(
|
||||
target=cls._poll_for_abort,
|
||||
args=(task_uuid, callback, stop_event, poll_interval, app),
|
||||
daemon=True,
|
||||
name=f"abort-poller-{uuid_str[:8]}",
|
||||
)
|
||||
logger.debug(
|
||||
"Started database abort polling for task %s (interval=%ss)",
|
||||
task_uuid,
|
||||
poll_interval,
|
||||
)
|
||||
|
||||
thread.start()
|
||||
return AbortListener(task_uuid, thread, stop_event, pubsub)
|
||||
|
||||
@staticmethod
|
||||
def _invoke_callback_with_context(
|
||||
callback: Callable[[], None],
|
||||
app: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Invoke callback with Flask app context if provided.
|
||||
|
||||
:param callback: Function to invoke
|
||||
:param app: Flask app for context, or None
|
||||
"""
|
||||
if app:
|
||||
with app.app_context():
|
||||
callback()
|
||||
else:
|
||||
callback()
|
||||
|
||||
@classmethod
|
||||
def _check_abort_status(cls, task_uuid: UUID) -> bool:
|
||||
"""
|
||||
Check if task has been aborted via database query.
|
||||
|
||||
:param task_uuid: UUID of the task to check (native UUID)
|
||||
:returns: True if task is in ABORTING or ABORTED state
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = TaskDAO.find_one_or_none(uuid=task_uuid)
|
||||
return task is not None and task.status in ABORT_STATES
|
||||
|
||||
@classmethod
|
||||
def _run_abort_listener_loop(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
callback: Callable[[], None],
|
||||
stop_event: threading.Event,
|
||||
interval: float,
|
||||
app: Any,
|
||||
check_fn: Callable[[], bool],
|
||||
source: str,
|
||||
) -> None:
|
||||
"""
|
||||
Common abort listener loop used by both pub/sub and polling modes.
|
||||
|
||||
:param task_uuid: UUID of the task to monitor (native UUID)
|
||||
:param callback: Function to call when abort is detected
|
||||
:param stop_event: Event to signal loop termination
|
||||
:param interval: Wait interval between checks
|
||||
:param app: Flask app for context
|
||||
:param check_fn: Function that returns True if abort was detected
|
||||
:param source: Source identifier for logging ("pub/sub" or "polling")
|
||||
"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
if check_fn():
|
||||
logger.info(
|
||||
"Abort detected via %s for task %s",
|
||||
source,
|
||||
task_uuid,
|
||||
)
|
||||
cls._invoke_callback_with_context(callback, app)
|
||||
break
|
||||
|
||||
# Wait for interval or until stop is requested
|
||||
stop_event.wait(timeout=interval)
|
||||
|
||||
except (ValueError, OSError) as ex:
|
||||
# ValueError/OSError with "I/O operation on closed file" or
|
||||
# "Bad file descriptor" typically means the connection was closed
|
||||
# during shutdown. Check if stop was requested.
|
||||
if stop_event.is_set():
|
||||
logger.debug(
|
||||
"Abort %s for task %s stopped cleanly (connection closed)",
|
||||
source,
|
||||
task_uuid,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Error in abort %s for task %s: %s",
|
||||
source,
|
||||
task_uuid,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as ex:
|
||||
# Check if stop was requested - if so, this may be expected
|
||||
if stop_event.is_set():
|
||||
logger.debug(
|
||||
"Abort %s for task %s stopped with exception: %s",
|
||||
source,
|
||||
task_uuid,
|
||||
ex,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Error in abort %s for task %s: %s",
|
||||
source,
|
||||
task_uuid,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
break
|
||||
|
||||
@classmethod
|
||||
def _listen_pubsub(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
pubsub: redis.client.PubSub,
|
||||
callback: Callable[[], None],
|
||||
stop_event: threading.Event,
|
||||
app: Any,
|
||||
) -> None:
|
||||
"""Listen for abort via Redis pub/sub."""
|
||||
# Track if abort was received to avoid double-callback
|
||||
abort_received = False
|
||||
|
||||
def check_pubsub() -> bool:
|
||||
nonlocal abort_received
|
||||
message = pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||
if message is not None and message.get("type") == "message":
|
||||
abort_received = True
|
||||
return True
|
||||
return False
|
||||
|
||||
try:
|
||||
cls._run_abort_listener_loop(
|
||||
task_uuid=task_uuid,
|
||||
callback=callback,
|
||||
stop_event=stop_event,
|
||||
interval=0, # pub/sub has its own timeout in get_message
|
||||
app=app,
|
||||
check_fn=check_pubsub,
|
||||
source="pub/sub",
|
||||
)
|
||||
|
||||
except redis.RedisError as ex:
|
||||
# Check if we were asked to stop - if so, this is expected
|
||||
if stop_event.is_set():
|
||||
logger.debug(
|
||||
"Abort listener for task %s stopped (Redis error: %s)",
|
||||
task_uuid,
|
||||
ex,
|
||||
)
|
||||
else:
|
||||
# Log error but don't fall back - let the failure be visible
|
||||
logger.error(
|
||||
"Redis signal backend failed for task %s abort listener: %s. "
|
||||
"Task may not receive abort signal.",
|
||||
task_uuid,
|
||||
ex,
|
||||
)
|
||||
|
||||
except (ValueError, OSError) as ex:
|
||||
# ValueError: "I/O operation on closed file" - expected when stop() closes
|
||||
# OSError: Similar connection-closed errors
|
||||
if stop_event.is_set():
|
||||
# Clean shutdown, expected behavior
|
||||
logger.debug(
|
||||
"Abort listener for task %s stopped cleanly",
|
||||
task_uuid,
|
||||
)
|
||||
else:
|
||||
# Unexpected error while running
|
||||
logger.error(
|
||||
"Error in abort listener for task %s: %s",
|
||||
task_uuid,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
# Only log as error if we weren't asked to stop
|
||||
if stop_event.is_set():
|
||||
logger.debug(
|
||||
"Abort listener for task %s stopped with exception: %s",
|
||||
task_uuid,
|
||||
ex,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Error in abort listener for task %s: %s",
|
||||
task_uuid,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up pub/sub subscription
|
||||
try:
|
||||
pubsub.unsubscribe()
|
||||
pubsub.close()
|
||||
except Exception as ex:
|
||||
logger.debug("Error closing pub/sub during cleanup: %s", ex)
|
||||
|
||||
@classmethod
|
||||
def _poll_for_abort(
|
||||
cls,
|
||||
task_uuid: UUID,
|
||||
callback: Callable[[], None],
|
||||
stop_event: threading.Event,
|
||||
interval: float,
|
||||
app: Any,
|
||||
) -> None:
|
||||
"""Background polling loop - used when Redis pub/sub is not configured."""
|
||||
|
||||
def check_database() -> bool:
|
||||
# Need app context for database access
|
||||
if app:
|
||||
with app.app_context():
|
||||
return cls._check_abort_status(task_uuid)
|
||||
else:
|
||||
return cls._check_abort_status(task_uuid)
|
||||
|
||||
cls._run_abort_listener_loop(
|
||||
task_uuid=task_uuid,
|
||||
callback=callback,
|
||||
stop_event=stop_event,
|
||||
interval=interval,
|
||||
app=app,
|
||||
check_fn=check_database,
|
||||
source="polling",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def submit_task(
|
||||
task_type: str,
|
||||
task_key: str | None,
|
||||
task_name: str | None,
|
||||
scope: TaskScope,
|
||||
timeout: int | None,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> "Task":
|
||||
"""
|
||||
Create task entry and schedule for async execution.
|
||||
|
||||
Flow:
|
||||
1. Generate task_key if not provided (random UUID)
|
||||
2. Submit to SubmitTaskCommand which handles locking and create-vs-join
|
||||
3. Schedule Celery task ONLY for new tasks (not deduplicated ones)
|
||||
4. Return Task model to caller
|
||||
|
||||
The SubmitTaskCommand uses a distributed lock to prevent race conditions,
|
||||
returning either a new task or an existing active task with the same key.
|
||||
|
||||
:param task_type: Task type identifier (e.g., "superset.generate_thumbnail")
|
||||
:param task_key: Optional deduplication key (None for random UUID)
|
||||
:param task_name: Human readable task name
|
||||
:param scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM)
|
||||
:param timeout: Optional timeout in seconds
|
||||
:param args: Positional arguments for the task function
|
||||
:param kwargs: Keyword arguments for the task function
|
||||
:returns: Task model representing the scheduled task
|
||||
"""
|
||||
from superset.commands.tasks.submit import SubmitTaskCommand
|
||||
|
||||
if task_key is None:
|
||||
task_key = generate_random_task_key()
|
||||
|
||||
# Build properties with execution_mode and timeout
|
||||
properties: TaskProperties = {"execution_mode": "async"}
|
||||
if timeout:
|
||||
properties["timeout"] = timeout
|
||||
|
||||
# Create or join task entry in metastore
|
||||
# SubmitTaskCommand handles locking and create-vs-join logic:
|
||||
# - Acquires distributed lock on dedup_key
|
||||
# - If active task exists: adds subscriber and returns existing task
|
||||
# (is_new=False)
|
||||
# - If no active task: creates new task (is_new=True)
|
||||
task, is_new = SubmitTaskCommand(
|
||||
{
|
||||
"task_key": task_key,
|
||||
"task_type": task_type,
|
||||
"task_name": task_name,
|
||||
"scope": scope.value,
|
||||
"properties": properties,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
# Only schedule Celery task for NEW tasks, not deduplicated ones
|
||||
# Deduplicated tasks are already pending or running
|
||||
if is_new:
|
||||
# Import here to avoid circular dependency
|
||||
from superset.tasks.scheduler import execute_task
|
||||
|
||||
# Schedule Celery task for async execution
|
||||
execute_task.delay(
|
||||
task_uuid=str(task.uuid),
|
||||
task_type=task_type,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Scheduled task %s (uuid=%s) for async execution",
|
||||
task_type,
|
||||
task.uuid,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Joined existing task %s (uuid=%s) - no new Celery task scheduled",
|
||||
task_type,
|
||||
task.uuid,
|
||||
)
|
||||
|
||||
return task
|
||||
110
superset/tasks/registry.py
Normal file
110
superset/tasks/registry.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Task registry for the Global Task Framework (GTF)"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskRegistry:
|
||||
"""
|
||||
Registry for task functions.
|
||||
|
||||
Stores task functions by name, allowing the Celery executor to look up
|
||||
and execute registered tasks. This enables the decorator pattern where
|
||||
functions are registered at module import time.
|
||||
"""
|
||||
|
||||
_tasks: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, task_name: str, func: Callable[..., Any]) -> None:
|
||||
"""
|
||||
Register a task function by name.
|
||||
|
||||
:param task_name: Unique task identifier (e.g., "superset.generate_thumbnail")
|
||||
:param func: The task function to register
|
||||
:raises ValueError: If task name is already registered
|
||||
"""
|
||||
if task_name in cls._tasks:
|
||||
existing_func = cls._tasks[task_name]
|
||||
if existing_func is not func:
|
||||
raise ValueError(
|
||||
f"Task '{task_name}' is already registered with a different "
|
||||
"function. "
|
||||
f"Existing: {existing_func.__module__}.{existing_func.__name__}, "
|
||||
f"New: {func.__module__}.{func.__name__}"
|
||||
)
|
||||
# Same function being registered again (e.g., module reload) - allow it
|
||||
logger.debug("Task '%s' re-registered with same function", task_name)
|
||||
return
|
||||
|
||||
cls._tasks[task_name] = func
|
||||
logger.info(
|
||||
"Registered async task: %s -> %s.%s",
|
||||
task_name,
|
||||
func.__module__,
|
||||
func.__name__,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_executor(cls, task_name: str) -> Callable[..., Any]:
|
||||
"""
|
||||
Get the executor function for a task.
|
||||
|
||||
:param task_name: Task identifier to look up
|
||||
:returns: The registered task function
|
||||
:raises KeyError: If task name is not registered
|
||||
"""
|
||||
if task_name not in cls._tasks:
|
||||
raise KeyError(
|
||||
f"Task '{task_name}' is not registered. "
|
||||
f"Available tasks: {', '.join(sorted(cls._tasks.keys()))}"
|
||||
)
|
||||
return cls._tasks[task_name]
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, task_name: str) -> bool:
|
||||
"""
|
||||
Check if a task is registered.
|
||||
|
||||
:param task_name: Task identifier to check
|
||||
:returns: True if task is registered
|
||||
"""
|
||||
return task_name in cls._tasks
|
||||
|
||||
@classmethod
|
||||
def list_tasks(cls) -> list[str]:
|
||||
"""
|
||||
Get list of all registered task names.
|
||||
|
||||
:returns: Sorted list of task names
|
||||
"""
|
||||
return sorted(cls._tasks.keys())
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""
|
||||
Clear all registered tasks.
|
||||
|
||||
WARNING: This is primarily for testing purposes. In production,
|
||||
tasks should remain registered for the lifetime of the process.
|
||||
"""
|
||||
cls._tasks.clear()
|
||||
logger.warning("Task registry cleared")
|
||||
@@ -19,11 +19,13 @@ from __future__ import annotations
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.signals import task_failure
|
||||
from flask import current_app
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.exceptions import CommandException
|
||||
@@ -32,10 +34,17 @@ from superset.commands.report.exceptions import ReportScheduleUnexpectedError
|
||||
from superset.commands.report.execute import AsyncExecuteReportScheduleCommand
|
||||
from superset.commands.report.log_prune import AsyncPruneReportScheduleLogCommand
|
||||
from superset.commands.sql_lab.query import QueryPruneCommand
|
||||
from superset.commands.tasks.prune import TaskPruneCommand
|
||||
from superset.daos.report import ReportScheduleDAO
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.extensions import celery_app
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
from superset.tasks.ambient_context import use_context
|
||||
from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES
|
||||
from superset.tasks.context import TaskContext
|
||||
from superset.tasks.cron_util import cron_schedule_window
|
||||
from superset.tasks.manager import TaskManager
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
from superset.utils.core import LoggerLevel
|
||||
from superset.utils.log import get_logger_from_status
|
||||
|
||||
@@ -199,3 +208,251 @@ def prune_logs(
|
||||
LogPruneCommand(retention_period_days, max_rows_per_run).run()
|
||||
except CommandException as ex:
|
||||
logger.exception("An error occurred while pruning logs: %s", ex)
|
||||
|
||||
|
||||
@celery_app.task(name="prune_tasks", bind=True)
|
||||
def prune_tasks(
|
||||
self: Task,
|
||||
retention_period_days: int | None = None,
|
||||
max_rows_per_run: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
|
||||
stats_logger.incr("prune_tasks")
|
||||
|
||||
# TODO: Deprecated: Remove support for passing retention period via options in 6.0
|
||||
if retention_period_days is None:
|
||||
retention_period_days = prune_tasks.request.properties.get(
|
||||
"retention_period_days"
|
||||
)
|
||||
logger.warning(
|
||||
"Your `prune_tasks` beat schedule uses `options` to pass the "
|
||||
"retention period, please use `kwargs` instead."
|
||||
)
|
||||
|
||||
try:
|
||||
TaskPruneCommand(retention_period_days, max_rows_per_run).run()
|
||||
except CommandException as ex:
|
||||
logger.exception("An error occurred while pruning async tasks: %s", ex)
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.execute", bind=True)
|
||||
def execute_task( # noqa: C901
|
||||
self: Any, # Celery task instance
|
||||
task_uuid: str,
|
||||
task_type: str,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generic task executor for GTF tasks.
|
||||
|
||||
This executor:
|
||||
1. Checks if task was aborted before execution starts
|
||||
2. Fetches task from metastore
|
||||
3. Builds context (task + user) and sets ambient context via contextvars
|
||||
4. Executes the task function (which accesses context via get_context())
|
||||
5. Updates task status throughout lifecycle using atomic conditional updates
|
||||
6. Runs cleanup handlers on task end (success/failure/abortion)
|
||||
7. Resets context after execution
|
||||
|
||||
Uses atomic conditional status updates to prevent race conditions with
|
||||
concurrent abort operations.
|
||||
|
||||
:param task_uuid: UUID of the task to execute
|
||||
:param task_type: Type of the task (for registry lookup)
|
||||
:param args: Positional arguments for the task function
|
||||
:param kwargs: Keyword arguments for the task function
|
||||
:returns: Dict with status and task_uuid
|
||||
"""
|
||||
from superset.commands.tasks.internal_update import InternalStatusTransitionCommand
|
||||
|
||||
# Convert string UUID to native UUID (Celery deserializes as string)
|
||||
native_uuid = UUID(task_uuid)
|
||||
|
||||
task = TaskDAO.find_one_or_none(uuid=native_uuid)
|
||||
if not task:
|
||||
logger.error("Task %s not found in metastore", task_uuid)
|
||||
return {"status": "error", "message": "Task not found"}
|
||||
|
||||
# AUTOMATIC PRE-EXECUTION CHECK: Don't execute if already aborted/aborting
|
||||
if task.status in ABORT_STATES:
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) was aborted before execution started",
|
||||
task_type,
|
||||
task_uuid,
|
||||
)
|
||||
# Atomic transition to ABORTED (if not already)
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=native_uuid,
|
||||
new_status=TaskStatus.ABORTED,
|
||||
expected_status=[TaskStatus.PENDING, TaskStatus.ABORTING],
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
return {"status": TaskStatus.ABORTED.value, "task_uuid": task_uuid}
|
||||
|
||||
# Atomic transition: PENDING → IN_PROGRESS (set started_at for duration tracking)
|
||||
if not InternalStatusTransitionCommand(
|
||||
task_uuid=native_uuid,
|
||||
new_status=TaskStatus.IN_PROGRESS,
|
||||
expected_status=TaskStatus.PENDING,
|
||||
set_started_at=True,
|
||||
).run():
|
||||
# Status wasn't PENDING - task may have been aborted concurrently
|
||||
logger.warning(
|
||||
"Task %s (uuid=%s) failed PENDING → IN_PROGRESS transition "
|
||||
"(may have been aborted concurrently)",
|
||||
task_type,
|
||||
task_uuid,
|
||||
)
|
||||
refreshed = TaskDAO.find_one_or_none(uuid=native_uuid)
|
||||
return {
|
||||
"status": refreshed.status if refreshed else "unknown",
|
||||
"task_uuid": task_uuid,
|
||||
}
|
||||
|
||||
# Update cached status (no DB read needed - we just wrote IN_PROGRESS)
|
||||
task.status = TaskStatus.IN_PROGRESS.value
|
||||
|
||||
# Build context from task (includes user who created the task)
|
||||
ctx = TaskContext(task)
|
||||
|
||||
# Start timeout timer if configured (timer starts from execution time)
|
||||
if timeout := task.properties_dict.get("timeout"):
|
||||
ctx.start_timeout_timer(timeout)
|
||||
logger.debug(
|
||||
"Started timeout timer for task %s: %d seconds",
|
||||
task_uuid,
|
||||
timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
# Get registered executor function
|
||||
executor_fn = TaskRegistry.get_executor(task_type)
|
||||
|
||||
logger.info(
|
||||
"Executing task %s (uuid=%s) with function %s.%s",
|
||||
task_type,
|
||||
task_uuid,
|
||||
executor_fn.__module__,
|
||||
executor_fn.__name__,
|
||||
)
|
||||
|
||||
# Execute with ambient context (no ctx parameter!)
|
||||
with use_context(ctx):
|
||||
executor_fn(*args, **kwargs)
|
||||
|
||||
# Mark execution as completed to prevent late abort handlers
|
||||
ctx.mark_execution_completed()
|
||||
|
||||
# Determine terminal status based on abort detection
|
||||
# Use atomic conditional updates to prevent overwriting concurrent abort
|
||||
if ctx._abort_detected or ctx.timeout_triggered:
|
||||
# Abort was detected - will be handled in finally block
|
||||
pass
|
||||
else:
|
||||
# Normal completion - also allow ABORTING → SUCCESS for late abort
|
||||
# (task finished before abort was detected)
|
||||
if InternalStatusTransitionCommand(
|
||||
task_uuid=native_uuid,
|
||||
new_status=TaskStatus.SUCCESS,
|
||||
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
|
||||
set_ended_at=True,
|
||||
).run():
|
||||
# Emit stats metric for success
|
||||
stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
|
||||
stats_logger.incr("gtf.task.success")
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) completed successfully", task_type, task_uuid
|
||||
)
|
||||
else:
|
||||
# Transition failed - task was likely already in a terminal state
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) completion transition failed "
|
||||
"(task may already be in terminal state)",
|
||||
task_type,
|
||||
task_uuid,
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
# Mark execution as completed to prevent late abort handlers
|
||||
ctx.mark_execution_completed()
|
||||
|
||||
# Atomic transition to FAILURE (only if still IN_PROGRESS or ABORTING)
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=native_uuid,
|
||||
new_status=TaskStatus.FAILURE,
|
||||
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
|
||||
properties={"error_message": str(ex)},
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
|
||||
logger.error(
|
||||
"Task %s (uuid=%s) failed with error: %s",
|
||||
task_type,
|
||||
task_uuid,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Emit stats metric for failure
|
||||
stats_logger = current_app.config["STATS_LOGGER"]
|
||||
stats_logger.incr("gtf.task.failure")
|
||||
|
||||
finally:
|
||||
# ALWAYS run cleanup handlers (also stops timeout timer)
|
||||
ctx._run_cleanup()
|
||||
|
||||
# Handle abort/timeout terminal transitions
|
||||
# Use atomic updates to safely transition ABORTING → terminal state
|
||||
if ctx._abort_detected or ctx.timeout_triggered:
|
||||
if ctx.abort_handlers_completed:
|
||||
# All handlers succeeded - determine terminal state based on cause
|
||||
if ctx.timeout_triggered:
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=native_uuid,
|
||||
new_status=TaskStatus.TIMED_OUT,
|
||||
expected_status=TaskStatus.ABORTING,
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) timed out and completed cleanup",
|
||||
task_type,
|
||||
task_uuid,
|
||||
)
|
||||
else:
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=native_uuid,
|
||||
new_status=TaskStatus.ABORTED,
|
||||
expected_status=TaskStatus.ABORTING,
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
logger.info(
|
||||
"Task %s (uuid=%s) was aborted by user",
|
||||
task_type,
|
||||
task_uuid,
|
||||
)
|
||||
else:
|
||||
# Handlers didn't complete successfully - mark as FAILURE
|
||||
InternalStatusTransitionCommand(
|
||||
task_uuid=native_uuid,
|
||||
new_status=TaskStatus.FAILURE,
|
||||
expected_status=TaskStatus.ABORTING,
|
||||
properties={"error_message": "Abort handlers did not complete"},
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
logger.warning(
|
||||
"Task %s (uuid=%s) stuck in ABORTING - marking as FAILURE",
|
||||
task_type,
|
||||
task_uuid,
|
||||
)
|
||||
|
||||
# Refresh to get final status for return value and completion notification
|
||||
refreshed = TaskDAO.find_one_or_none(uuid=native_uuid)
|
||||
final_status = refreshed.status if refreshed else "unknown"
|
||||
|
||||
# Publish completion notification for any waiters (e.g., sync callers)
|
||||
if final_status in TERMINAL_STATES:
|
||||
TaskManager.publish_completion(native_uuid, final_status)
|
||||
|
||||
return {"status": final_status, "task_uuid": task_uuid}
|
||||
|
||||
200
superset/tasks/schemas.py
Normal file
200
superset/tasks/schemas.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Task API schemas"""
|
||||
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.fields import Method
|
||||
|
||||
# RISON/JSON schemas for query parameters
|
||||
get_delete_ids_schema = {"type": "array", "items": {"type": "string"}}
|
||||
|
||||
# Field descriptions
|
||||
uuid_description = "The unique identifier (UUID) of the task"
|
||||
task_key_description = "The task identifier used for deduplication"
|
||||
task_type_description = (
|
||||
"The type of task (e.g., 'sql_execution', 'thumbnail_generation')"
|
||||
)
|
||||
task_name_description = "Human-readable name for the task"
|
||||
status_description = "Current status of the task"
|
||||
created_on_description = "Timestamp when the task was created"
|
||||
changed_on_description = "Timestamp when the task was last updated"
|
||||
started_at_description = "Timestamp when the task started execution"
|
||||
ended_at_description = "Timestamp when the task completed or failed"
|
||||
created_by_description = "User who created the task"
|
||||
user_id_description = "ID of the user context for task execution"
|
||||
payload_description = "Task-specific data in JSON format"
|
||||
properties_description = (
|
||||
"Runtime state and execution config. Contains: is_abortable, progress_percent, "
|
||||
"progress_current, progress_total, error_message, exception_type, stack_trace, "
|
||||
"timeout"
|
||||
)
|
||||
duration_seconds_description = (
|
||||
"Duration in seconds - for finished tasks: execution time, "
|
||||
"for running tasks: time since start, for pending: queue time"
|
||||
)
|
||||
scope_description = (
|
||||
"Task scope: 'private' (user-specific), 'shared' (multi-user), "
|
||||
"or 'system' (admin-only)"
|
||||
)
|
||||
subscriber_count_description = (
|
||||
"Number of users subscribed to this task (for shared tasks)"
|
||||
)
|
||||
subscribers_description = "List of users subscribed to this task (for shared tasks)"
|
||||
|
||||
|
||||
class UserSchema(Schema):
|
||||
"""Schema for user information"""
|
||||
|
||||
id = fields.Int()
|
||||
first_name = fields.String()
|
||||
last_name = fields.String()
|
||||
|
||||
|
||||
class TaskResponseSchema(Schema):
|
||||
"""
|
||||
Schema for task response.
|
||||
|
||||
Used for both list and detail endpoints.
|
||||
"""
|
||||
|
||||
id = fields.Int(metadata={"description": "Internal task ID"})
|
||||
uuid = fields.UUID(metadata={"description": uuid_description})
|
||||
task_key = fields.String(metadata={"description": task_key_description})
|
||||
task_type = fields.String(metadata={"description": task_type_description})
|
||||
task_name = fields.String(
|
||||
metadata={"description": task_name_description}, allow_none=True
|
||||
)
|
||||
status = fields.String(metadata={"description": status_description})
|
||||
created_on = fields.DateTime(metadata={"description": created_on_description})
|
||||
created_on_delta_humanized = Method(
|
||||
"get_created_on_delta_humanized",
|
||||
metadata={"description": "Humanized time since creation"},
|
||||
)
|
||||
changed_on = fields.DateTime(metadata={"description": changed_on_description})
|
||||
changed_by = fields.Nested(UserSchema, allow_none=True)
|
||||
started_at = fields.DateTime(
|
||||
metadata={"description": started_at_description}, allow_none=True
|
||||
)
|
||||
ended_at = fields.DateTime(
|
||||
metadata={"description": ended_at_description}, allow_none=True
|
||||
)
|
||||
created_by = fields.Nested(UserSchema, allow_none=True)
|
||||
user_id = fields.Int(metadata={"description": user_id_description}, allow_none=True)
|
||||
payload = Method("get_payload_dict", metadata={"description": payload_description})
|
||||
properties = Method(
|
||||
"get_properties", metadata={"description": properties_description}
|
||||
)
|
||||
duration_seconds = Method(
|
||||
"get_duration",
|
||||
metadata={"description": duration_seconds_description},
|
||||
)
|
||||
scope = fields.String(metadata={"description": scope_description})
|
||||
subscriber_count = Method(
|
||||
"get_subscriber_count", metadata={"description": subscriber_count_description}
|
||||
)
|
||||
subscribers = Method(
|
||||
"get_subscribers", metadata={"description": subscribers_description}
|
||||
)
|
||||
|
||||
def get_payload_dict(self, obj: object) -> dict[str, object] | None:
|
||||
"""Get payload as dictionary"""
|
||||
return obj.payload_dict # type: ignore[attr-defined]
|
||||
|
||||
def get_properties(self, obj: object) -> dict[str, object]:
|
||||
"""Get properties dict, filtering stack_trace if SHOW_STACKTRACE is disabled."""
|
||||
from flask import current_app
|
||||
|
||||
properties = dict(obj.properties_dict) # type: ignore[attr-defined]
|
||||
|
||||
# Remove stack_trace unless SHOW_STACKTRACE is enabled
|
||||
if not current_app.config.get("SHOW_STACKTRACE", False):
|
||||
properties.pop("stack_trace", None)
|
||||
|
||||
return properties
|
||||
|
||||
def get_duration(self, obj: object) -> float | None:
|
||||
"""Get duration in seconds"""
|
||||
return obj.duration_seconds # type: ignore[attr-defined]
|
||||
|
||||
def get_created_on_delta_humanized(self, obj: object) -> str:
|
||||
"""Get humanized time since creation"""
|
||||
return obj.created_on_delta_humanized() # type: ignore[attr-defined]
|
||||
|
||||
def get_subscriber_count(self, obj: object) -> int:
|
||||
"""Get number of subscribers"""
|
||||
return obj.subscriber_count # type: ignore[attr-defined]
|
||||
|
||||
def get_subscribers(self, obj: object) -> list[dict[str, object]]:
|
||||
"""Get list of subscribers with user info"""
|
||||
subscribers = []
|
||||
for sub in obj.subscribers: # type: ignore[attr-defined]
|
||||
subscribers.append(
|
||||
{
|
||||
"user_id": sub.user_id,
|
||||
"first_name": sub.user.first_name if sub.user else None,
|
||||
"last_name": sub.user.last_name if sub.user else None,
|
||||
"subscribed_at": sub.subscribed_at.isoformat()
|
||||
if sub.subscribed_at
|
||||
else None,
|
||||
}
|
||||
)
|
||||
return subscribers
|
||||
|
||||
|
||||
class TaskStatusResponseSchema(Schema):
|
||||
"""Schema for task status response (lightweight for polling)"""
|
||||
|
||||
status = fields.String(metadata={"description": status_description})
|
||||
|
||||
|
||||
class TaskCancelRequestSchema(Schema):
|
||||
"""Schema for task cancellation request"""
|
||||
|
||||
force = fields.Boolean(
|
||||
load_default=False,
|
||||
metadata={
|
||||
"description": "Force cancel the task for all subscribers (admin only). "
|
||||
"Only applicable for shared tasks with multiple subscribers."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskCancelResponseSchema(Schema):
|
||||
"""Schema for task cancellation response"""
|
||||
|
||||
message = fields.String(metadata={"description": "Success or status message"})
|
||||
action = fields.String(
|
||||
metadata={
|
||||
"description": "The action taken: 'aborted' (task terminated) or "
|
||||
"'unsubscribed' (user removed from shared task)"
|
||||
}
|
||||
)
|
||||
task = fields.Nested(TaskResponseSchema, allow_none=True)
|
||||
|
||||
|
||||
openapi_spec_methods_override = {
|
||||
"get": {"get": {"summary": "Get a task detail"}},
|
||||
"get_list": {
|
||||
"get": {
|
||||
"summary": "Get a list of tasks",
|
||||
"description": "Gets a list of tasks for the current user. "
|
||||
"Use Rison or JSON query parameters for filtering, sorting, "
|
||||
"pagination and for selecting specific columns and metadata.",
|
||||
}
|
||||
},
|
||||
"info": {"get": {"summary": "Get metadata information about this API resource"}},
|
||||
}
|
||||
@@ -14,6 +14,8 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from superset.utils.backports import StrEnum
|
||||
|
||||
@@ -18,16 +18,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from http.client import HTTPResponse
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import cast, TYPE_CHECKING
|
||||
from urllib import request
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from flask import g
|
||||
from superset_core.api.tasks import TaskProperties, TaskScope
|
||||
|
||||
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
|
||||
from superset.tasks.types import ChosenExecutor, Executor, ExecutorType, FixedExecutor
|
||||
from superset.tasks.types import (
|
||||
ChosenExecutor,
|
||||
Executor,
|
||||
ExecutorType,
|
||||
FixedExecutor,
|
||||
)
|
||||
from superset.utils import json
|
||||
from superset.utils.hashing import hash_from_str
|
||||
from superset.utils.urls import get_url_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -123,7 +132,7 @@ def fetch_csrf_token(
|
||||
response: HTTPResponse
|
||||
with request.urlopen(req, timeout=600) as response: # noqa: S310
|
||||
body = response.read().decode("utf-8")
|
||||
session_cookie: Optional[str] = None
|
||||
session_cookie: str | None = None
|
||||
cookie_headers = response.headers.get_all("set-cookie")
|
||||
if cookie_headers:
|
||||
for cookie in cookie_headers:
|
||||
@@ -142,3 +151,164 @@ def fetch_csrf_token(
|
||||
|
||||
logger.error("Error fetching CSRF token, status code: %s", response.status)
|
||||
return {}
|
||||
|
||||
|
||||
def generate_random_task_key() -> str:
|
||||
"""
|
||||
Generate a random task key.
|
||||
|
||||
This is the default behavior - each task submission gets a unique UUID
|
||||
unless an explicit task_key is provided in TaskOptions.
|
||||
|
||||
:returns: A random UUID string
|
||||
"""
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def get_active_dedup_key(
|
||||
scope: TaskScope | str,
|
||||
task_type: str,
|
||||
task_key: str,
|
||||
user_id: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a deduplication key for active tasks.
|
||||
|
||||
The dedup_key enforces uniqueness at the database level via a unique index.
|
||||
Active tasks use a composite key based on scope, which is then hashed using
|
||||
the configured HASH_ALGORITHM to produce a fixed-length key.
|
||||
|
||||
The composite key format before hashing is:
|
||||
- Private: private|task_type|task_key|user_id
|
||||
- Shared: shared|task_type|task_key
|
||||
- System: system|task_type|task_key
|
||||
|
||||
The final key is a hash digest (64 chars for sha256, 32 chars for md5).
|
||||
|
||||
:param scope: Task scope (PRIVATE/SHARED/SYSTEM) as TaskScope enum or string
|
||||
:param task_type: Type of task (e.g., 'sql_execution')
|
||||
:param task_key: Task identifier for deduplication
|
||||
:param user_id: User ID (required for private tasks)
|
||||
:returns: Hashed deduplication key string
|
||||
:raises ValueError: If user_id is missing for private scope
|
||||
"""
|
||||
# Convert string to TaskScope if needed
|
||||
if isinstance(scope, str):
|
||||
scope = TaskScope(scope)
|
||||
|
||||
# Build composite key
|
||||
match scope:
|
||||
case TaskScope.PRIVATE:
|
||||
if user_id is None:
|
||||
raise ValueError("user_id required for private tasks")
|
||||
composite_key = f"{scope.value}|{task_type}|{task_key}|{user_id}"
|
||||
case TaskScope.SHARED:
|
||||
composite_key = f"{scope.value}|{task_type}|{task_key}"
|
||||
case TaskScope.SYSTEM:
|
||||
composite_key = f"{scope.value}|{task_type}|{task_key}"
|
||||
case _:
|
||||
raise ValueError(f"Invalid scope: {scope}")
|
||||
|
||||
# Hash the composite key to produce a fixed-length dedup_key
|
||||
# Truncate to 64 chars max to fit the database column in case
|
||||
# a hash algo is used that generates hashes that exceed 64 chars
|
||||
return hash_from_str(composite_key)[:64]
|
||||
|
||||
|
||||
def get_finished_dedup_key(task_uuid: UUID) -> str:
|
||||
"""
|
||||
Build a deduplication key for finished tasks.
|
||||
|
||||
When a task completes (success, failure, or abort), its dedup_key is
|
||||
changed to its UUID. This frees up the slot so new tasks with the same
|
||||
parameters can be created.
|
||||
|
||||
:param task_uuid: Task UUID (native UUID type)
|
||||
:returns: The task UUID string as the dedup key
|
||||
|
||||
Example:
|
||||
>>> from uuid import UUID
|
||||
>>> get_finished_dedup_key(UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890"))
|
||||
'a1b2c3d4-e5f6-7890-abcd-ef1234567890'
|
||||
"""
|
||||
return str(task_uuid)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# TaskProperties helper functions
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def progress_update(progress: float | int | tuple[int, int]) -> TaskProperties:
|
||||
"""
|
||||
Create a properties update dict for progress values.
|
||||
|
||||
:param progress: One of:
|
||||
- float (0.0-1.0): Percentage only
|
||||
- int: Count only (total unknown)
|
||||
- tuple[int, int]: (current, total) with auto-computed percentage
|
||||
:returns: TaskProperties dict with appropriate progress fields set
|
||||
|
||||
Example:
|
||||
task.update_properties(progress_update((50, 100)))
|
||||
"""
|
||||
if isinstance(progress, float):
|
||||
return {"progress_percent": progress}
|
||||
if isinstance(progress, int):
|
||||
return {"progress_current": progress}
|
||||
# tuple
|
||||
current, total = progress
|
||||
result: TaskProperties = {
|
||||
"progress_current": current,
|
||||
"progress_total": total,
|
||||
}
|
||||
if total > 0:
|
||||
result["progress_percent"] = current / total
|
||||
return result
|
||||
|
||||
|
||||
def error_update(exception: BaseException) -> TaskProperties:
|
||||
"""
|
||||
Create a properties update dict from an exception.
|
||||
|
||||
:param exception: The exception that caused the failure
|
||||
:returns: TaskProperties dict with error fields populated
|
||||
"""
|
||||
return {
|
||||
"error_message": str(exception),
|
||||
"exception_type": type(exception).__name__,
|
||||
"stack_trace": traceback.format_exc(),
|
||||
}
|
||||
|
||||
|
||||
def parse_properties(json_str: str | None) -> TaskProperties:
|
||||
"""
|
||||
Parse JSON string into TaskProperties dict.
|
||||
|
||||
Returns empty dict on parse errors. Unknown keys are preserved
|
||||
for forward compatibility (allows adding new properties without
|
||||
breaking existing code).
|
||||
|
||||
:param json_str: JSON string or None
|
||||
:returns: TaskProperties dict (sparse - only contains keys that were set)
|
||||
"""
|
||||
if not json_str:
|
||||
return {}
|
||||
|
||||
try:
|
||||
raw = json.loads(json_str)
|
||||
if isinstance(raw, dict):
|
||||
return cast(TaskProperties, raw)
|
||||
return {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
|
||||
def serialize_properties(props: TaskProperties) -> str:
|
||||
"""
|
||||
Serialize TaskProperties to JSON string.
|
||||
|
||||
:param props: TaskProperties dict
|
||||
:returns: JSON string
|
||||
"""
|
||||
return json.dumps(props)
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from flask import current_app, Flask
|
||||
from flask_caching import Cache
|
||||
@@ -24,6 +26,12 @@ from markupsafe import Markup
|
||||
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.async_events.cache_backend import (
|
||||
RedisCacheBackend,
|
||||
RedisSentinelCacheBackend,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache"
|
||||
@@ -185,6 +193,7 @@ class CacheManager:
|
||||
self._thumbnail_cache = SupersetCache()
|
||||
self._filter_state_cache = SupersetCache()
|
||||
self._explore_form_data_cache = ExploreFormDataCache()
|
||||
self._signal_cache: RedisCacheBackend | RedisSentinelCacheBackend | None = None
|
||||
|
||||
@staticmethod
|
||||
def _init_cache(
|
||||
@@ -226,6 +235,30 @@ class CacheManager:
|
||||
"EXPLORE_FORM_DATA_CACHE_CONFIG",
|
||||
required=True,
|
||||
)
|
||||
self._init_signal_cache(app)
|
||||
|
||||
def _init_signal_cache(self, app: Flask) -> None:
|
||||
"""Initialize the signal cache for pub/sub and distributed locks."""
|
||||
from superset.async_events.cache_backend import (
|
||||
RedisCacheBackend,
|
||||
RedisSentinelCacheBackend,
|
||||
)
|
||||
|
||||
config = app.config.get("SIGNAL_CACHE_CONFIG")
|
||||
if not config:
|
||||
return
|
||||
|
||||
cache_type = config.get("CACHE_TYPE")
|
||||
if cache_type == "RedisCache":
|
||||
self._signal_cache = RedisCacheBackend.from_config(config)
|
||||
elif cache_type == "RedisSentinelCache":
|
||||
self._signal_cache = RedisSentinelCacheBackend.from_config(config)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported CACHE_TYPE for SIGNAL_CACHE_CONFIG: %s. "
|
||||
"Use 'RedisCache' or 'RedisSentinelCache'.",
|
||||
cache_type,
|
||||
)
|
||||
|
||||
@property
|
||||
def data_cache(self) -> Cache:
|
||||
@@ -246,3 +279,23 @@ class CacheManager:
|
||||
@property
|
||||
def explore_form_data_cache(self) -> Cache:
|
||||
return self._explore_form_data_cache
|
||||
|
||||
@property
|
||||
def signal_cache(
|
||||
self,
|
||||
) -> RedisCacheBackend | RedisSentinelCacheBackend | None:
|
||||
"""
|
||||
Return the signal cache backend.
|
||||
|
||||
Used for signaling features that require Redis-specific primitives:
|
||||
- Pub/Sub messaging for real-time abort/completion notifications
|
||||
- SET NX EX for atomic distributed lock acquisition
|
||||
|
||||
The backend provides:
|
||||
- `._cache`: Raw Redis client
|
||||
- `.key_prefix`: Configured key prefix (from CACHE_KEY_PREFIX)
|
||||
- `.default_timeout`: Default timeout in seconds (from CACHE_DEFAULT_TIMEOUT)
|
||||
|
||||
Returns None if SIGNAL_CACHE_CONFIG is not configured.
|
||||
"""
|
||||
return self._signal_cache
|
||||
|
||||
@@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, cast, Literal, TYPE_CHECKING
|
||||
from typing import Any, Callable, cast, Literal
|
||||
|
||||
from flask import g, has_request_context, request
|
||||
from flask_appbuilder.const import API_URI_RIS_KEY
|
||||
@@ -34,9 +34,6 @@ from superset.extensions import stats_logger_manager
|
||||
from superset.utils import json
|
||||
from superset.utils.core import get_user_id, LoggerLevel, to_int
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -31,8 +31,8 @@ from flask import current_app as app, url_for
|
||||
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
||||
|
||||
from superset import db
|
||||
from superset.distributed_lock import KeyValueDistributedLock
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.distributed_lock import DistributedLock
|
||||
from superset.exceptions import AcquireDistributedLockFailedException
|
||||
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -77,7 +77,7 @@ def generate_code_challenge(code_verifier: str) -> str:
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
CreateKeyValueDistributedLockFailedException,
|
||||
AcquireDistributedLockFailedException,
|
||||
factor=10,
|
||||
base=2,
|
||||
max_tries=5,
|
||||
@@ -128,8 +128,10 @@ def refresh_oauth2_token(
|
||||
db_engine_spec: type[BaseEngineSpec],
|
||||
token: DatabaseUserOAuth2Tokens,
|
||||
) -> str | None:
|
||||
with KeyValueDistributedLock(
|
||||
# Use longer TTL for OAuth2 token refresh (may involve network calls)
|
||||
with DistributedLock(
|
||||
namespace="refresh_oauth2_token",
|
||||
ttl_seconds=30,
|
||||
user_id=user_id,
|
||||
database_id=database_id,
|
||||
):
|
||||
|
||||
@@ -14,32 +14,19 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from flask_appbuilder import expose, has_access
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask import current_app as app
|
||||
|
||||
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.distributed_lock.types import LockValue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = app.config["STATS_LOGGER"]
|
||||
from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP
|
||||
from superset.superset_typing import FlaskResponse
|
||||
from superset.views.base import BaseSupersetView
|
||||
|
||||
|
||||
class GetDistributedLock(BaseDistributedLockCommand):
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
class TaskModelView(BaseSupersetView):
|
||||
route_base = "/tasks"
|
||||
class_permission_name = "Task"
|
||||
method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
|
||||
|
||||
def run(self) -> LockValue | None:
|
||||
entry = KeyValueDAO.get_entry(
|
||||
resource=self.resource,
|
||||
key=self.key,
|
||||
)
|
||||
if not entry or entry.is_expired():
|
||||
return None
|
||||
|
||||
return cast(LockValue, self.codec.decode(entry.value))
|
||||
@expose("/list/")
|
||||
@has_access
|
||||
def list(self) -> FlaskResponse:
|
||||
return super().render_app_template()
|
||||
@@ -73,6 +73,7 @@ FEATURE_FLAGS = {
|
||||
"AVOID_COLORS_COLLISION": True,
|
||||
"DRILL_TO_DETAIL": True,
|
||||
"DRILL_BY": True,
|
||||
"GLOBAL_TASK_FRAMEWORK": True,
|
||||
}
|
||||
|
||||
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"
|
||||
|
||||
538
tests/integration_tests/tasks/api_tests.py
Normal file
538
tests/integration_tests/tasks/api_tests.py
Normal file
@@ -0,0 +1,538 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Integration tests for Task REST API"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import prison
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.models.tasks import Task
|
||||
from superset.utils import json
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import (
|
||||
ADMIN_USERNAME,
|
||||
GAMMA_USERNAME,
|
||||
)
|
||||
|
||||
|
||||
class TestTaskApi(SupersetTestCase):
|
||||
"""Tests for Task REST API"""
|
||||
|
||||
TASK_API_BASE = "api/v1/task"
|
||||
|
||||
@contextmanager
|
||||
def _create_tasks(self) -> Generator[list[Task], None, None]:
|
||||
"""
|
||||
Context manager to create test tasks with guaranteed cleanup.
|
||||
|
||||
Uses TaskDAO to create tasks, testing the actual production code path.
|
||||
|
||||
Usage:
|
||||
with self._create_tasks() as tasks:
|
||||
# Use tasks in test
|
||||
# Cleanup happens automatically even if test fails
|
||||
"""
|
||||
from superset_core.api.tasks import TaskScope
|
||||
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
admin = self.get_user("admin")
|
||||
gamma = self.get_user("gamma")
|
||||
|
||||
tasks = []
|
||||
|
||||
try:
|
||||
# Create tasks with different statuses using TaskDAO
|
||||
for i in range(5):
|
||||
task_key = f"test_task_{i}"
|
||||
|
||||
# Create task using DAO (this tests the dedup_key creation logic)
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=task_key,
|
||||
task_name=f"Test Task {i}",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
payload={"test": "data"},
|
||||
)
|
||||
|
||||
# Set created_by for test purposes (DAO uses Flask-AppBuilder context)
|
||||
task.created_by = admin
|
||||
|
||||
# Alternate between pending and finished tasks
|
||||
if i % 2 != 0:
|
||||
# Simulate realistic task lifecycle: PENDING → IN_PROGRESS → SUCCESS
|
||||
# This sets both started_at (on IN_PROGRESS) and ended_at (on
|
||||
# SUCCESS) so duration_seconds returns a valid value
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
task.set_status(TaskStatus.SUCCESS)
|
||||
|
||||
db.session.commit()
|
||||
tasks.append(task)
|
||||
|
||||
# Create pending task for gamma user (use PENDING so it can be aborted)
|
||||
gamma_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="gamma_task",
|
||||
task_name="Gamma Task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=gamma.id,
|
||||
payload={"user": "gamma"},
|
||||
)
|
||||
# Set created_by for test purposes
|
||||
gamma_task.created_by = gamma
|
||||
db.session.commit()
|
||||
tasks.append(gamma_task)
|
||||
|
||||
yield tasks
|
||||
finally:
|
||||
# Cleanup happens here regardless of test success/failure
|
||||
for task in tasks:
|
||||
try:
|
||||
db.session.delete(task)
|
||||
except Exception: # noqa: S110
|
||||
# Task may already be deleted or session may be in bad state
|
||||
pass
|
||||
try:
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
# Rollback if commit fails
|
||||
db.session.rollback()
|
||||
|
||||
def test_info_task(self):
|
||||
"""
|
||||
Task API: Test info endpoint
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/_info"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert "permissions" in data
|
||||
|
||||
def test_get_task_by_uuid(self):
|
||||
"""
|
||||
Task API: Test get task by UUID and verify dedup_key is hashed
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
# Get a pending task to verify active dedup_key format
|
||||
task = (
|
||||
db.session.query(Task)
|
||||
.filter_by(
|
||||
created_by_fk=admin.id,
|
||||
status=TaskStatus.PENDING.value,
|
||||
task_type="test_type",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
assert task is not None
|
||||
|
||||
# Verify active task has hashed dedup_key (64 chars for SHA-256)
|
||||
assert len(task.dedup_key) == 64
|
||||
assert all(c in "0123456789abcdef" for c in task.dedup_key)
|
||||
assert task.dedup_key != str(task.uuid)
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Compare strings since JSON response contains string UUID
|
||||
assert data["result"]["uuid"] == str(task.uuid)
|
||||
assert data["result"]["id"] == task.id
|
||||
|
||||
def test_get_task_not_found(self):
|
||||
"""
|
||||
Task API: Test get task not found with non-existent UUID
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
# Use a valid UUID that doesn't exist in the database
|
||||
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_task_invalid_uuid(self):
|
||||
"""
|
||||
Task API: Test get task with invalid UUID
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/invalid-uuid"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_task_list(self):
|
||||
"""
|
||||
Task API: Test get task list
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] >= 6 # At least the fixtures we created
|
||||
assert "result" in data
|
||||
|
||||
def test_get_task_list_filtered_by_status(self):
|
||||
"""
|
||||
Task API: Test get task list filtered by status
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
arguments = {
|
||||
"filters": [
|
||||
{"col": "status", "opr": "eq", "value": TaskStatus.PENDING.value}
|
||||
]
|
||||
}
|
||||
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
for task in data["result"]:
|
||||
assert task["status"] == TaskStatus.PENDING.value
|
||||
|
||||
def test_get_task_list_filtered_by_type(self):
|
||||
"""
|
||||
Task API: Test get task list filtered by type
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
arguments = {
|
||||
"filters": [{"col": "task_type", "opr": "eq", "value": "test_type"}]
|
||||
}
|
||||
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] >= 6
|
||||
for task in data["result"]:
|
||||
assert task["task_type"] == "test_type"
|
||||
|
||||
def test_get_task_list_ordered(self):
|
||||
"""
|
||||
Task API: Test get task list with ordering
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
arguments = {
|
||||
"order_column": "created_on",
|
||||
"order_direction": "desc",
|
||||
}
|
||||
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert len(data["result"]) > 0
|
||||
|
||||
def test_get_task_list_paginated(self):
|
||||
"""
|
||||
Task API: Test get task list with pagination
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
arguments = {"page": 0, "page_size": 2}
|
||||
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert len(data["result"]) <= 2
|
||||
assert data["count"] >= 6
|
||||
|
||||
def test_cancel_task_by_uuid(self):
|
||||
"""
|
||||
Task API: Test cancel task by UUID
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
task = (
|
||||
db.session.query(Task)
|
||||
.filter_by(created_by_fk=admin.id, status=TaskStatus.PENDING.value)
|
||||
.first()
|
||||
)
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
|
||||
rv = self.client.post(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Compare strings since JSON response contains string UUID
|
||||
assert data["task"]["uuid"] == str(task.uuid)
|
||||
assert data["task"]["status"] == TaskStatus.ABORTED.value
|
||||
assert data["action"] == "aborted"
|
||||
|
||||
def test_cancel_task_not_found(self):
|
||||
"""
|
||||
Task API: Test cancel task not found with non-existent UUID
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/cancel"
|
||||
rv = self.client.post(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_cancel_task_not_owned(self):
|
||||
"""
|
||||
Task API: Test cancel task not owned by user
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(GAMMA_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
# Try to cancel admin's task as gamma user
|
||||
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
|
||||
rv = self.client.post(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_cancel_task_admin_can_cancel_others(self):
|
||||
"""
|
||||
Task API: Test admin can cancel other users' tasks
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
gamma = self.get_user("gamma")
|
||||
|
||||
# Admin cancels gamma's task
|
||||
task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
|
||||
rv = self.client.post(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_get_task_status_by_uuid(self):
|
||||
"""
|
||||
Task API: Test get task status by UUID
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert "status" in data
|
||||
assert data["status"] == task.status
|
||||
|
||||
def test_get_task_status_not_found(self):
|
||||
"""
|
||||
Task API: Test get task status not found with non-existent UUID
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/status"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_task_status_not_owned(self):
|
||||
"""
|
||||
Task API: Test non-owner can't see task status
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(GAMMA_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
# Try to get status of admin's task as gamma user
|
||||
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
|
||||
rv = self.client.get(uri)
|
||||
# Should be forbidden due to base filter
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_task_status_admin_can_see_others(self):
|
||||
"""
|
||||
Task API: Test admin can see other users' task status
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
gamma = self.get_user("gamma")
|
||||
|
||||
# Admin gets gamma's task status
|
||||
task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["status"] == task.status
|
||||
|
||||
def test_get_task_list_user_sees_own_tasks(self):
|
||||
"""
|
||||
Task API: Test non-admin user only sees their own tasks
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(GAMMA_USERNAME)
|
||||
gamma = self.get_user("gamma")
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Gamma should only see their own task
|
||||
for task in data["result"]:
|
||||
assert task["created_by"]["id"] == gamma.id
|
||||
|
||||
def test_get_task_list_admin_sees_all_tasks(self):
|
||||
"""
|
||||
Task API: Test admin sees all tasks
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Admin should see all tasks
|
||||
assert data["count"] >= 6
|
||||
|
||||
def test_task_response_schema(self):
|
||||
"""
|
||||
Task API: Test response schema includes all expected fields
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
result = data["result"]
|
||||
|
||||
# Check all expected fields are present
|
||||
expected_fields = [
|
||||
"id",
|
||||
"uuid",
|
||||
"task_key",
|
||||
"task_type",
|
||||
"task_name",
|
||||
"status",
|
||||
"created_on",
|
||||
"created_on_delta_humanized",
|
||||
"changed_on",
|
||||
"changed_by",
|
||||
"started_at",
|
||||
"ended_at",
|
||||
"created_by",
|
||||
"user_id",
|
||||
"payload",
|
||||
"properties",
|
||||
"duration_seconds",
|
||||
"scope",
|
||||
"subscriber_count",
|
||||
"subscribers",
|
||||
]
|
||||
|
||||
for field in expected_fields:
|
||||
assert field in result, f"Field {field} missing from response"
|
||||
|
||||
# Verify properties is a dict with expected structure
|
||||
properties = result["properties"]
|
||||
assert isinstance(properties, dict)
|
||||
|
||||
def test_task_payload_serialization(self):
|
||||
"""
|
||||
Task API: Test payload is properly serialized as dict
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
task = (
|
||||
db.session.query(Task)
|
||||
.filter_by(created_by_fk=admin.id, task_type="test_type")
|
||||
.first()
|
||||
)
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
payload = data["result"]["payload"]
|
||||
|
||||
# Payload should be a dict, not a string
|
||||
assert isinstance(payload, dict)
|
||||
assert "test" in payload
|
||||
assert payload["test"] == "data"
|
||||
|
||||
def test_task_computed_properties(self):
|
||||
"""
|
||||
Task API: Test computed properties in response
|
||||
|
||||
This test verifies that computed properties (status, duration_seconds)
|
||||
are correctly returned in the API response. Internal DB columns like
|
||||
dedup_key are tested in unit tests (test_find_by_task_key_finished_not_found).
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
# Get a successful task
|
||||
task = (
|
||||
db.session.query(Task)
|
||||
.filter_by(created_by_fk=admin.id, status=TaskStatus.SUCCESS.value)
|
||||
.first()
|
||||
)
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
result = data["result"]
|
||||
|
||||
# Check status field (computed properties are now derived from status)
|
||||
assert result["status"] == TaskStatus.SUCCESS.value
|
||||
|
||||
# Properties dict should exist and be a dict
|
||||
assert "properties" in result
|
||||
assert isinstance(result["properties"], dict)
|
||||
|
||||
# Verify duration_seconds is not null for completed tasks with timestamps
|
||||
# (requires both started_at and ended_at to be set)
|
||||
if result.get("started_at") and result.get("ended_at"):
|
||||
assert result["duration_seconds"] is not None
|
||||
assert result["duration_seconds"] >= 0.0
|
||||
16
tests/integration_tests/tasks/commands/__init__.py
Normal file
16
tests/integration_tests/tasks/commands/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
482
tests/integration_tests/tasks/commands/test_cancel.py
Normal file
482
tests/integration_tests/tasks/commands/test_cancel.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks.cancel import CancelTaskCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskAbortFailedError,
|
||||
TaskNotAbortableError,
|
||||
TaskNotFoundError,
|
||||
TaskPermissionDeniedError,
|
||||
)
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.utils.core import override_user
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
def test_cancel_pending_task_aborts(app_context, get_user) -> None:
|
||||
"""Test canceling a pending task directly aborts it"""
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create a pending private task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="cancel_pending_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Cancel the pending task with admin user context
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# Verify task is aborted (pending goes directly to ABORTED)
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
assert command.action_taken == "aborted"
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.ABORTED.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_in_progress_abortable_task_sets_aborting(app_context, get_user) -> None:
|
||||
"""Test canceling an in-progress task with abort handler sets ABORTING"""
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create an in-progress abortable task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="cancel_in_progress_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
properties={"is_abortable": True},
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Cancel the in-progress task - mock publish_abort to avoid Redis dependency
|
||||
with (
|
||||
override_user(admin),
|
||||
patch("superset.tasks.manager.TaskManager.publish_abort"),
|
||||
):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# In-progress tasks go to ABORTING (not ABORTED)
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
assert command.action_taken == "aborted"
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.ABORTING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_in_progress_not_abortable_raises_error(app_context, get_user) -> None:
|
||||
"""Test canceling an in-progress task without abort handler raises error"""
|
||||
admin = get_user("admin")
|
||||
unique_key = f"cancel_not_abortable_test_{uuid4().hex[:8]}"
|
||||
|
||||
# Create an in-progress non-abortable task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=unique_key,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
properties={"is_abortable": False},
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
with pytest.raises(TaskNotAbortableError):
|
||||
command.run()
|
||||
|
||||
# Verify task status unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_task_not_found(app_context, get_user) -> None:
|
||||
"""Test canceling non-existent task raises error"""
|
||||
admin = get_user("admin")
|
||||
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(
|
||||
task_uuid=UUID("00000000-0000-0000-0000-000000000000")
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
command.run()
|
||||
|
||||
|
||||
def test_cancel_finished_task_raises_error(app_context, get_user) -> None:
|
||||
"""Test canceling an already finished task raises error"""
|
||||
|
||||
admin = get_user("admin")
|
||||
unique_key = f"cancel_finished_test_{uuid4().hex[:8]}"
|
||||
|
||||
# Create a finished task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=unique_key,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.SUCCESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
with pytest.raises(TaskAbortFailedError):
|
||||
command.run()
|
||||
|
||||
# Verify task status unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.SUCCESS.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_shared_task_with_multiple_subscribers_unsubscribes(
|
||||
app_context, get_user
|
||||
) -> None:
|
||||
"""Test canceling a shared task with multiple subscribers unsubscribes user"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
|
||||
# Create a shared task with admin as creator
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="cancel_shared_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
# Add gamma as subscriber
|
||||
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Verify we have 2 subscribers
|
||||
db.session.refresh(task)
|
||||
assert task.subscriber_count == 2
|
||||
|
||||
# Cancel as gamma (non-admin subscriber)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# Should unsubscribe, not abort
|
||||
assert command.action_taken == "unsubscribed"
|
||||
assert result.status == TaskStatus.PENDING.value # Status unchanged
|
||||
|
||||
# Verify gamma was unsubscribed
|
||||
db.session.refresh(task)
|
||||
assert task.subscriber_count == 1
|
||||
assert not task.has_subscriber(gamma.id)
|
||||
assert task.has_subscriber(admin.id)
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_shared_task_last_subscriber_aborts(app_context, get_user) -> None:
|
||||
"""Test canceling a shared task as last subscriber aborts it"""
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create a shared task with only admin as subscriber
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="cancel_last_subscriber_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Verify only 1 subscriber
|
||||
db.session.refresh(task)
|
||||
assert task.subscriber_count == 1
|
||||
|
||||
# Cancel as the only subscriber
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# Should abort since last subscriber
|
||||
assert command.action_taken == "aborted"
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_with_force_aborts_for_all_subscribers(app_context, get_user) -> None:
|
||||
"""Test force cancel aborts shared task even with multiple subscribers"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
|
||||
# Create a shared task with multiple subscribers
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="force_cancel_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
# Add gamma as subscriber
|
||||
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Verify 2 subscribers
|
||||
db.session.refresh(task)
|
||||
assert task.subscriber_count == 2
|
||||
|
||||
# Force cancel as admin
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid, force=True)
|
||||
result = command.run()
|
||||
|
||||
# Should abort despite multiple subscribers
|
||||
assert command.action_taken == "aborted"
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_with_force_requires_admin(app_context, get_user) -> None:
|
||||
"""Test force cancel requires admin privileges"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
|
||||
# Create a shared task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="force_requires_admin_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
# Add gamma as subscriber
|
||||
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Try to force cancel as gamma (non-admin)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid, force=True)
|
||||
|
||||
with pytest.raises(TaskPermissionDeniedError):
|
||||
command.run()
|
||||
|
||||
# Verify task unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_private_task_permission_denied(app_context, get_user) -> None:
|
||||
"""Test non-owner cannot cancel private task"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
unique_key = f"private_permission_test_{uuid4().hex[:8]}"
|
||||
|
||||
# Use test_request_context to ensure has_request_context() returns True
|
||||
# so that TaskFilter properly applies permission filtering
|
||||
with app.test_request_context():
|
||||
# Create a private task owned by admin
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=unique_key,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Try to cancel admin's private task as gamma (non-owner)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
# Should fail because gamma can't see admin's private task (base filter)
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
command.run()
|
||||
|
||||
# Verify task unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_system_task_requires_admin(app_context, get_user) -> None:
|
||||
"""Test system tasks can only be canceled by admin"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
unique_key = f"system_task_test_{uuid4().hex[:8]}"
|
||||
|
||||
# Use test_request_context to ensure has_request_context() returns True
|
||||
# so that TaskFilter properly applies permission filtering
|
||||
with app.test_request_context():
|
||||
# Create a system task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=unique_key,
|
||||
scope=TaskScope.SYSTEM,
|
||||
user_id=None,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Try to cancel as gamma (non-admin)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
# System tasks are not visible to non-admins via base filter
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
command.run()
|
||||
|
||||
# Verify task unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
|
||||
# But admin can cancel it
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_already_aborting_is_idempotent(app_context, get_user) -> None:
|
||||
"""Test canceling an already aborting task is idempotent"""
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create a task already in ABORTING state
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="idempotent_cancel_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.ABORTING)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Cancel the already aborting task
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# Should succeed without error
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_shared_task_not_subscribed_raises_error(app_context, get_user) -> None:
|
||||
"""Test non-subscriber cannot cancel shared task"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
|
||||
# Create a shared task with only admin as subscriber
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="not_subscribed_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Try to cancel as gamma (not subscribed)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
with pytest.raises(TaskPermissionDeniedError):
|
||||
command.run()
|
||||
|
||||
# Verify task unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
419
tests/integration_tests/tasks/commands/test_internal_update.py
Normal file
419
tests/integration_tests/tasks/commands/test_internal_update.py
Normal file
@@ -0,0 +1,419 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Integration tests for internal task state update commands."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks.internal_update import (
|
||||
InternalStatusTransitionCommand,
|
||||
InternalUpdateTaskCommand,
|
||||
)
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
|
||||
def test_internal_update_properties(app_context, get_user, login_as) -> None:
|
||||
"""Test updating only properties without reading task first."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_props",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Perform zero-read update
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"is_abortable": True, "progress_percent": 0.5},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
assert task.properties_dict.get("progress_percent") == 0.5
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_payload(app_context, get_user, login_as) -> None:
|
||||
"""Test updating only payload without reading task first."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_payload",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Perform zero-read update
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
payload={"custom_key": "value", "count": 42},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.payload_dict == {"custom_key": "value", "count": 42}
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_both_properties_and_payload(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test updating both properties and payload in one call."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_both",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Perform zero-read update of both
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"progress_current": 50, "progress_total": 100},
|
||||
payload={"last_item": "xyz"},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.properties_dict.get("progress_current") == 50
|
||||
assert task.properties_dict.get("progress_total") == 100
|
||||
assert task.payload_dict == {"last_item": "xyz"}
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_returns_false_for_nonexistent_task(
|
||||
app_context, login_as
|
||||
) -> None:
|
||||
"""Test that updating non-existent task returns False."""
|
||||
login_as("admin")
|
||||
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
|
||||
properties={"is_abortable": True},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_internal_update_returns_false_when_nothing_to_update(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that passing no properties or payload returns False early."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_empty",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# No properties or payload provided
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties=None,
|
||||
payload=None,
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is False
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_does_not_change_status(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that internal update leaves status unchanged (safe for concurrent abort)."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_status",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Update properties - status should not change
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"progress_percent": 0.75},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify status unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_replaces_entire_properties(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that internal update replaces properties entirely (no merge)."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_replace",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
properties={"is_abortable": True, "timeout": 300},
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Replace with new properties (caller is responsible for merging if needed)
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"error_message": "new_value"},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify entire replacement occurred
|
||||
db.session.refresh(task)
|
||||
# The caller should have merged if they wanted to preserve is_abortable
|
||||
assert task.properties_dict == {"error_message": "new_value"}
|
||||
assert "is_abortable" not in task.properties_dict
|
||||
assert "timeout" not in task.properties_dict
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# InternalStatusTransitionCommand Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_status_transition_atomic_compare_and_swap(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test atomic conditional status transitions with comprehensive scenarios.
|
||||
|
||||
Covers: success case, failure case, list of expected statuses, properties update,
|
||||
ended_at timestamp, and string status values.
|
||||
"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="status_transition_comprehensive",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# 1. SUCCESS CASE: PENDING → IN_PROGRESS (expected matches)
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.IN_PROGRESS,
|
||||
expected_status=TaskStatus.PENDING,
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value
|
||||
|
||||
# 2. FAILURE CASE: Try wrong expected status (should fail, status unchanged)
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.SUCCESS,
|
||||
expected_status=TaskStatus.PENDING, # Wrong! Current is IN_PROGRESS
|
||||
).run()
|
||||
assert result is False
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value # Unchanged
|
||||
|
||||
# 3. LIST OF EXPECTED: Transition with multiple acceptable source statuses
|
||||
task.set_status(TaskStatus.ABORTING)
|
||||
db.session.commit()
|
||||
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.FAILURE,
|
||||
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
|
||||
properties={"error_message": "Test error"},
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.FAILURE.value
|
||||
assert task.properties_dict.get("error_message") == "Test error"
|
||||
|
||||
# 4. ENDED_AT: Reset to IN_PROGRESS and test ended_at timestamp
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
task.ended_at = None
|
||||
db.session.commit()
|
||||
assert task.ended_at is None
|
||||
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.SUCCESS,
|
||||
expected_status=TaskStatus.IN_PROGRESS,
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.SUCCESS.value
|
||||
assert task.ended_at is not None
|
||||
|
||||
# 5. STRING VALUES: Reset and test string status values
|
||||
task.set_status(TaskStatus.PENDING)
|
||||
db.session.commit()
|
||||
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status="in_progress",
|
||||
expected_status="pending",
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == "in_progress"
|
||||
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_status_transition_prevents_race_condition(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that conditional update prevents overwriting concurrent abort.
|
||||
|
||||
This is the key race condition fix: if task is aborted concurrently,
|
||||
the executor's attempt to set SUCCESS should fail (return False),
|
||||
preserving the ABORTING state.
|
||||
"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="status_transition_race",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Simulate concurrent abort: directly set ABORTING in DB
|
||||
# (as if CancelTaskCommand ran in another process)
|
||||
task.set_status(TaskStatus.ABORTING)
|
||||
db.session.commit()
|
||||
|
||||
# Executor tries to set SUCCESS (expecting IN_PROGRESS) - stale expectation
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.SUCCESS,
|
||||
expected_status=TaskStatus.IN_PROGRESS,
|
||||
).run()
|
||||
|
||||
# Should fail - task was aborted concurrently
|
||||
assert result is False
|
||||
|
||||
# Verify ABORTING is preserved (not overwritten to SUCCESS)
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.ABORTING.value
|
||||
|
||||
# Verify correct transition from ABORTING still works
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.ABORTED,
|
||||
expected_status=TaskStatus.ABORTING,
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.ABORTED.value
|
||||
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_status_transition_nonexistent_task(app_context, login_as) -> None:
|
||||
"""Test that transitioning non-existent task returns False."""
|
||||
login_as("admin")
|
||||
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
|
||||
new_status=TaskStatus.IN_PROGRESS,
|
||||
expected_status=TaskStatus.PENDING,
|
||||
).run()
|
||||
|
||||
assert result is False
|
||||
258
tests/integration_tests/tasks/commands/test_prune.py
Normal file
258
tests/integration_tests/tasks/commands/test_prune.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks import TaskPruneCommand
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.models.tasks import Task
|
||||
|
||||
|
||||
@freeze_time("2024-02-15")
|
||||
@patch("superset.tasks.utils.get_current_user")
|
||||
def test_prune_tasks_success(mock_get_user, app_context, get_user, login_as) -> None:
|
||||
"""Test successful pruning of old completed tasks"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
mock_get_user.return_value = admin.username
|
||||
|
||||
# Create old completed tasks (35 days ago = Jan 11, 2024)
|
||||
old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
|
||||
task_ids = []
|
||||
for i in range(3):
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=f"prune_task_{i}",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.SUCCESS)
|
||||
task.ended_at = old_date
|
||||
task_ids.append(task.id)
|
||||
|
||||
# Create a recent task (5 days ago = Feb 10, 2024) that should NOT be deleted
|
||||
recent_date = datetime(2024, 2, 10, tzinfo=timezone.utc)
|
||||
recent_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="recent_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
recent_task.created_by = admin
|
||||
recent_task.set_status(TaskStatus.SUCCESS)
|
||||
recent_task.ended_at = recent_date
|
||||
recent_task_id = recent_task.id
|
||||
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Prune tasks older than 30 days
|
||||
command = TaskPruneCommand(retention_period_days=30)
|
||||
command.run()
|
||||
|
||||
# Verify old tasks are deleted
|
||||
for task_id in task_ids:
|
||||
assert db.session.get(Task, task_id) is None
|
||||
|
||||
# Verify recent task is NOT deleted
|
||||
assert db.session.get(Task, recent_task_id) is not None
|
||||
finally:
|
||||
# Cleanup remaining tasks
|
||||
for task_id in task_ids:
|
||||
existing = db.session.get(Task, task_id)
|
||||
if existing:
|
||||
db.session.delete(existing)
|
||||
if db.session.get(Task, recent_task_id):
|
||||
db.session.delete(db.session.get(Task, recent_task_id))
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@freeze_time("2024-02-15")
|
||||
@patch("superset.tasks.utils.get_current_user")
|
||||
def test_prune_tasks_with_max_rows(
|
||||
mock_get_user, app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test pruning with max_rows_per_run limit"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
mock_get_user.return_value = admin.username
|
||||
|
||||
# Create old completed tasks (35 days ago = Jan 11, 2024)
|
||||
task_ids = []
|
||||
for i in range(5):
|
||||
# Different ages for ordering (older tasks have smaller hour values)
|
||||
old_date = datetime(2024, 1, 11, i, tzinfo=timezone.utc)
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=f"max_rows_task_{i}",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.SUCCESS)
|
||||
task.ended_at = old_date
|
||||
task_ids.append(task.id)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Prune with max_rows_per_run=2 (should only delete 2 oldest)
|
||||
command = TaskPruneCommand(retention_period_days=30, max_rows_per_run=2)
|
||||
command.run()
|
||||
|
||||
# Count remaining tasks
|
||||
remaining = sum(
|
||||
1 for task_id in task_ids if db.session.get(Task, task_id) is not None
|
||||
)
|
||||
assert remaining == 3 # 5 - 2 = 3 remaining
|
||||
finally:
|
||||
# Cleanup remaining tasks
|
||||
for task_id in task_ids:
|
||||
existing = db.session.get(Task, task_id)
|
||||
if existing:
|
||||
db.session.delete(existing)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@freeze_time("2024-02-15")
|
||||
@patch("superset.tasks.utils.get_current_user")
|
||||
def test_prune_does_not_delete_pending_tasks(
|
||||
mock_get_user, app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that pruning does not delete pending or in-progress tasks"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
mock_get_user.return_value = admin.username
|
||||
|
||||
pending_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="pending_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
pending_task.created_by = admin
|
||||
# Keep as PENDING (no ended_at)
|
||||
|
||||
in_progress_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="in_progress_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
in_progress_task.created_by = admin
|
||||
in_progress_task.set_status(TaskStatus.IN_PROGRESS)
|
||||
# No ended_at for in-progress tasks
|
||||
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Prune tasks older than 30 days
|
||||
command = TaskPruneCommand(retention_period_days=30)
|
||||
command.run()
|
||||
|
||||
# Verify non-completed tasks are NOT deleted
|
||||
assert db.session.get(Task, pending_task.id) is not None
|
||||
assert db.session.get(Task, in_progress_task.id) is not None
|
||||
finally:
|
||||
# Cleanup
|
||||
for task in [pending_task, in_progress_task]:
|
||||
existing = db.session.get(Task, task.id)
|
||||
if existing:
|
||||
db.session.delete(existing)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@freeze_time("2024-02-15")
|
||||
@patch("superset.tasks.utils.get_current_user")
|
||||
def test_prune_deletes_all_completed_statuses(
|
||||
mock_get_user, app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test pruning deletes SUCCESS, FAILURE, and ABORTED tasks"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
mock_get_user.return_value = admin.username
|
||||
|
||||
old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
|
||||
|
||||
# Create tasks with different completed statuses
|
||||
success_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="success_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
success_task.created_by = admin
|
||||
success_task.set_status(TaskStatus.SUCCESS)
|
||||
success_task.ended_at = old_date
|
||||
success_task_id = success_task.id
|
||||
|
||||
failure_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="failure_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
failure_task.created_by = admin
|
||||
failure_task.set_status(TaskStatus.FAILURE)
|
||||
failure_task.ended_at = old_date
|
||||
failure_task_id = failure_task.id
|
||||
|
||||
aborted_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="aborted_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
aborted_task.created_by = admin
|
||||
aborted_task.set_status(TaskStatus.ABORTED)
|
||||
aborted_task.ended_at = old_date
|
||||
aborted_task_id = aborted_task.id
|
||||
|
||||
db.session.commit()
|
||||
task_ids = [success_task_id, failure_task_id, aborted_task_id]
|
||||
|
||||
try:
|
||||
# Prune tasks older than 30 days
|
||||
command = TaskPruneCommand(retention_period_days=30)
|
||||
command.run()
|
||||
|
||||
# Verify all completed tasks are deleted
|
||||
for task_id in task_ids:
|
||||
assert db.session.get(Task, task_id) is None
|
||||
except AssertionError:
|
||||
# Cleanup if test fails
|
||||
for task_id in task_ids:
|
||||
existing = db.session.get(Task, task_id)
|
||||
if existing:
|
||||
db.session.delete(existing)
|
||||
db.session.commit()
|
||||
raise
|
||||
|
||||
|
||||
def test_prune_no_tasks_to_delete(app_context, login_as) -> None:
|
||||
"""Test pruning when no old tasks exist"""
|
||||
login_as("admin")
|
||||
|
||||
# Don't create any tasks - should handle gracefully
|
||||
command = TaskPruneCommand(retention_period_days=30)
|
||||
command.run() # Should not raise any errors
|
||||
238
tests/integration_tests/tasks/commands/test_submit.py
Normal file
238
tests/integration_tests/tasks/commands/test_submit.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks import SubmitTaskCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskInvalidError,
|
||||
)
|
||||
|
||||
|
||||
def test_submit_task_success(app_context, login_as, get_user) -> None:
|
||||
"""Test successful task submission"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
command = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "test-key",
|
||||
"task_name": "Test Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = command.run()
|
||||
|
||||
# Verify task was created
|
||||
assert result.task_type == "test-type"
|
||||
assert result.task_key == "test-key"
|
||||
assert result.task_name == "Test Task"
|
||||
assert result.status == TaskStatus.PENDING.value
|
||||
assert result.payload == "{}"
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.uuid is not None
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(result)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_with_all_fields(app_context, login_as, get_user) -> None:
|
||||
"""Test task submission with all optional fields"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
command = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "test-key-full",
|
||||
"task_name": "Test Task Full",
|
||||
"user_id": admin.id,
|
||||
"payload": {"key": "value"},
|
||||
"properties": {"execution_mode": "async", "timeout": 300},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = command.run()
|
||||
|
||||
# Verify all fields were set
|
||||
assert result.task_type == "test-type"
|
||||
assert result.task_key == "test-key-full"
|
||||
assert result.task_name == "Test Task Full"
|
||||
assert result.user_id == admin.id
|
||||
assert result.payload_dict == {"key": "value"}
|
||||
assert result.properties_dict.get("execution_mode") == "async"
|
||||
assert result.properties_dict.get("timeout") == 300
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(result)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_missing_task_type(app_context, login_as) -> None:
|
||||
"""Test submission fails when task_type is missing"""
|
||||
login_as("admin")
|
||||
|
||||
command = SubmitTaskCommand(data={})
|
||||
|
||||
with pytest.raises(TaskInvalidError) as exc_info:
|
||||
command.run()
|
||||
|
||||
assert len(exc_info.value._exceptions) == 1
|
||||
assert "task_type" in exc_info.value._exceptions[0].field_name
|
||||
|
||||
|
||||
def test_submit_task_joins_existing(app_context, login_as, get_user) -> None:
|
||||
"""Test that submitting with duplicate key joins existing task"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create first task
|
||||
command1 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "shared-key",
|
||||
"task_name": "First Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
task1 = command1.run()
|
||||
|
||||
try:
|
||||
# Submit second task with same task_key and type
|
||||
command2 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "shared-key",
|
||||
"task_name": "Second Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
|
||||
# Should return existing task, not create new one
|
||||
task2 = command2.run()
|
||||
assert task2.id == task1.id
|
||||
assert task2.uuid == task1.uuid
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task1)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_without_task_key(app_context, login_as, get_user) -> None:
|
||||
"""Test task submission without task_key (command generates UUID)"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
command = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_name": "Test Task No ID",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = command.run()
|
||||
|
||||
# Verify task was created and command generated a task_key
|
||||
assert result.task_type == "test-type"
|
||||
assert result.task_name == "Test Task No ID"
|
||||
assert result.task_key is not None # Command generated UUID
|
||||
assert result.uuid is not None
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(result)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_run_with_info_returns_is_new_true(
|
||||
app_context, login_as, get_user
|
||||
) -> None:
|
||||
"""Test run_with_info returns is_new=True for new task"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
command = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "unique-key-is-new",
|
||||
"task_name": "Test Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
task, is_new = command.run_with_info()
|
||||
|
||||
assert is_new is True
|
||||
assert task.task_key == "unique-key-is-new"
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_run_with_info_returns_is_new_false(
|
||||
app_context, login_as, get_user
|
||||
) -> None:
|
||||
"""Test run_with_info returns is_new=False when joining existing task"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create first task
|
||||
command1 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "shared-key-is-new",
|
||||
"task_name": "First Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
task1, is_new1 = command1.run_with_info()
|
||||
assert is_new1 is True
|
||||
|
||||
try:
|
||||
# Submit second task with same key
|
||||
command2 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "shared-key-is-new",
|
||||
"task_name": "Second Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
task2, is_new2 = command2.run_with_info()
|
||||
|
||||
# Should return existing task with is_new=False
|
||||
assert is_new2 is False
|
||||
assert task2.id == task1.id
|
||||
assert task2.uuid == task1.uuid
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task1)
|
||||
db.session.commit()
|
||||
260
tests/integration_tests/tasks/commands/test_update.py
Normal file
260
tests/integration_tests/tasks/commands/test_update.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks import UpdateTaskCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskForbiddenError,
|
||||
TaskNotFoundError,
|
||||
)
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
|
||||
def test_update_task_success(app_context, get_user, login_as) -> None:
|
||||
"""Test successful task update"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
# Create a task using DAO
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="update_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Update the task status
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
status=TaskStatus.SUCCESS.value,
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
# Verify update
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.SUCCESS.value
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.SUCCESS.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_update_task_not_found(app_context, login_as) -> None:
|
||||
"""Test update fails when task not found"""
|
||||
login_as("admin")
|
||||
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
|
||||
status=TaskStatus.SUCCESS.value,
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
command.run()
|
||||
|
||||
|
||||
def test_update_task_forbidden(app_context, get_user, login_as) -> None:
|
||||
"""Test update fails when user doesn't own task (via base filter)"""
|
||||
gamma = get_user("gamma")
|
||||
login_as("gamma")
|
||||
|
||||
# Create a task owned by gamma (non-admin) using DAO
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="forbidden_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=gamma.id,
|
||||
)
|
||||
task.created_by = gamma
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Login as alpha user (different non-admin, non-owner)
|
||||
login_as("alpha")
|
||||
|
||||
# Try to update gamma's task as alpha user
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
status=TaskStatus.SUCCESS.value,
|
||||
)
|
||||
|
||||
# Should raise ForbiddenError because ownership check fails
|
||||
with pytest.raises(TaskForbiddenError):
|
||||
command.run()
|
||||
|
||||
# Verify task was NOT updated
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_update_task_payload(app_context, get_user, login_as) -> None:
|
||||
"""Test updating task payload"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
# Create a task using DAO
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="payload_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
payload={"initial": "data"},
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Update payload
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
payload={"progress": 50, "message": "halfway"},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
# Verify payload was updated
|
||||
assert result.uuid == task.uuid
|
||||
payload = result.payload_dict
|
||||
assert payload["progress"] == 50
|
||||
assert payload["message"] == "halfway"
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
task_payload = task.payload_dict
|
||||
assert task_payload["progress"] == 50
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_update_all_supported_fields(app_context, get_user, login_as) -> None:
|
||||
"""Test updating all supported task fields
|
||||
(status, error, progress, abortable, timeout)"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
# Create a task with initial execution_mode and timeout in properties
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="all_fields_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
properties={"execution_mode": "async", "timeout": 300},
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Update all field types at once
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
status=TaskStatus.FAILURE.value,
|
||||
properties={
|
||||
"error_message": "Task failed due to error",
|
||||
"progress_percent": 0.75,
|
||||
"progress_current": 75,
|
||||
"progress_total": 100,
|
||||
"is_abortable": True,
|
||||
},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
# Verify all fields were updated
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.FAILURE.value
|
||||
assert result.properties_dict.get("error_message") == "Task failed due to error"
|
||||
assert result.properties_dict.get("progress_percent") == 0.75
|
||||
assert result.properties_dict.get("progress_current") == 75
|
||||
assert result.properties_dict.get("progress_total") == 100
|
||||
assert result.properties_dict.get("is_abortable") is True
|
||||
assert result.properties_dict.get("execution_mode") == "async"
|
||||
assert result.properties_dict.get("timeout") == 300
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.FAILURE.value
|
||||
assert task.properties_dict.get("error_message") == "Task failed due to error"
|
||||
assert task.properties_dict.get("progress_percent") == 0.75
|
||||
assert task.properties_dict.get("progress_current") == 75
|
||||
assert task.properties_dict.get("progress_total") == 100
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
assert task.properties_dict.get("execution_mode") == "async"
|
||||
assert task.properties_dict.get("timeout") == 300
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_update_task_skip_security_check(app_context, get_user, login_as) -> None:
|
||||
"""Test skip_security_check allows updating any task"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
# Create a task owned by admin
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="skip_security_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Login as gamma user (non-owner)
|
||||
login_as("gamma")
|
||||
|
||||
# With skip_security_check=True, should succeed even though gamma doesn't own it
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"progress_percent": 0.75},
|
||||
skip_security_check=True,
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
# Verify update succeeded
|
||||
assert result.uuid == task.uuid
|
||||
assert result.properties_dict.get("progress_percent") == 0.75
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.properties_dict.get("progress_percent") == 0.75
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
415
tests/integration_tests/tasks/test_event_handlers.py
Normal file
415
tests/integration_tests/tasks/test_event_handlers.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""End-to-end integration tests for task event handlers (abort and cleanup)
|
||||
|
||||
These tests verify that abort and cleanup handlers work correctly through
|
||||
the full task execution path using real @task decorated functions executed
|
||||
via the Celery executor (synchronously via .apply()).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset.commands.tasks.cancel import CancelTaskCommand
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.extensions import db
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.ambient_context import get_context
|
||||
from superset.tasks.context import TaskContext
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
from superset.tasks.scheduler import execute_task
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME
|
||||
|
||||
# Module-level state to track handler calls across test executions
|
||||
# (Since decorated functions are defined at module level)
|
||||
_handler_state: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _reset_handler_state():
|
||||
"""Reset handler state before each test."""
|
||||
global _handler_state
|
||||
_handler_state = {
|
||||
"cleanup_called": False,
|
||||
"abort_called": False,
|
||||
"cleanup_order": [],
|
||||
"abort_order": [],
|
||||
"cleanup_data": {},
|
||||
}
|
||||
|
||||
|
||||
def cleanup_test_task() -> None:
|
||||
"""Task that registers a cleanup handler."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup() -> None:
|
||||
_handler_state["cleanup_called"] = True
|
||||
|
||||
# Simulate some work
|
||||
ctx.update_task(progress=1.0)
|
||||
|
||||
|
||||
def abort_test_task() -> None:
|
||||
"""Task that registers an abort handler."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort() -> None:
|
||||
_handler_state["abort_called"] = True
|
||||
|
||||
|
||||
def both_handlers_task() -> None:
|
||||
"""Task that registers both abort and cleanup handlers."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort() -> None:
|
||||
_handler_state["abort_called"] = True
|
||||
_handler_state["abort_order"].append("abort")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup() -> None:
|
||||
_handler_state["cleanup_called"] = True
|
||||
_handler_state["cleanup_order"].append("cleanup")
|
||||
|
||||
|
||||
def multiple_cleanup_handlers_task() -> None:
|
||||
"""Task that registers multiple cleanup handlers."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup_first() -> None:
|
||||
_handler_state["cleanup_order"].append("first")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup_second() -> None:
|
||||
_handler_state["cleanup_order"].append("second")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup_third() -> None:
|
||||
_handler_state["cleanup_order"].append("third")
|
||||
|
||||
|
||||
def cleanup_with_data_task() -> None:
|
||||
"""Task that uses cleanup handler to clean up partial work."""
|
||||
ctx = get_context()
|
||||
|
||||
# Simulate partial work in module-level state
|
||||
_handler_state["cleanup_data"]["temp_key"] = "temp_value"
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup() -> None:
|
||||
# Clean up the partial work
|
||||
_handler_state["cleanup_data"].clear()
|
||||
_handler_state["cleanup_called"] = True
|
||||
|
||||
|
||||
def _register_test_tasks() -> None:
|
||||
"""Register test task functions if not already registered.
|
||||
|
||||
Called in setUp() to ensure tasks are registered regardless of
|
||||
whether other tests have cleared the registry.
|
||||
"""
|
||||
registrations = [
|
||||
("test_cleanup_task", cleanup_test_task),
|
||||
("test_abort_task", abort_test_task),
|
||||
("test_both_handlers_task", both_handlers_task),
|
||||
("test_multiple_cleanup_task", multiple_cleanup_handlers_task),
|
||||
("test_cleanup_with_data", cleanup_with_data_task),
|
||||
]
|
||||
for name, func in registrations:
|
||||
if not TaskRegistry.is_registered(name):
|
||||
TaskRegistry.register(name, func)
|
||||
|
||||
|
||||
class TestCleanupHandlers(SupersetTestCase):
|
||||
"""E2E tests for on_cleanup functionality using Celery executor."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
_reset_handler_state()
|
||||
|
||||
def test_cleanup_handler_fires_on_success(self):
|
||||
"""Test cleanup handler runs when task completes successfully."""
|
||||
# Create task entry directly
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_cleanup_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Cleanup",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Execute task synchronously through Celery executor
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
|
||||
)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.SUCCESS.value
|
||||
|
||||
# Verify cleanup handler was called
|
||||
assert _handler_state["cleanup_called"]
|
||||
|
||||
def test_multiple_cleanup_handlers_in_lifo_order(self):
|
||||
"""Test multiple cleanup handlers execute in LIFO order."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_multiple_cleanup_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Multiple Cleanup",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_multiple_cleanup_task", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
|
||||
# Handlers should execute in LIFO order (last registered first)
|
||||
assert _handler_state["cleanup_order"] == ["third", "second", "first"]
|
||||
|
||||
def test_cleanup_handler_cleans_up_partial_work(self):
|
||||
"""Test cleanup handler can clean up partial work."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_cleanup_with_data",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Cleanup Data",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_cleanup_with_data", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert _handler_state["cleanup_called"]
|
||||
# Cleanup handler should have cleared the data
|
||||
assert len(_handler_state["cleanup_data"]) == 0
|
||||
|
||||
|
||||
class TestAbortHandlers(SupersetTestCase):
|
||||
"""E2E tests for on_abort functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
_reset_handler_state()
|
||||
|
||||
def test_abort_handler_fires_when_task_aborting(self):
|
||||
"""Test abort handler runs when task is in ABORTING state during cleanup."""
|
||||
# Create task entry
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_abort_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Abort",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Manually set to IN_PROGRESS and then ABORTING to simulate abort
|
||||
task_obj.status = TaskStatus.IN_PROGRESS.value
|
||||
task_obj.update_properties({"is_abortable": True})
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
# Refresh to get the updated task
|
||||
db.session.refresh(task_obj)
|
||||
|
||||
# Create context (simulating what executor does)
|
||||
ctx = TaskContext(task_obj)
|
||||
|
||||
# Register abort handler
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
_handler_state["abort_called"] = True
|
||||
|
||||
# Set status to ABORTING (simulating CancelTaskCommand)
|
||||
task_obj.status = TaskStatus.ABORTING.value
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
# Run cleanup (simulating executor's finally block)
|
||||
ctx._run_cleanup()
|
||||
|
||||
# Verify abort handler was called
|
||||
assert _handler_state["abort_called"]
|
||||
|
||||
def test_both_handlers_fire_on_abort(self):
|
||||
"""Test both abort and cleanup handlers run when task is aborted."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_both_handlers_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Both Handlers",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
task_obj.status = TaskStatus.IN_PROGRESS.value
|
||||
task_obj.update_properties({"is_abortable": True})
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
# Refresh to get the updated task
|
||||
db.session.refresh(task_obj)
|
||||
|
||||
ctx = TaskContext(task_obj)
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
_handler_state["abort_called"] = True
|
||||
_handler_state["abort_order"].append("abort")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup():
|
||||
_handler_state["cleanup_called"] = True
|
||||
_handler_state["cleanup_order"].append("cleanup")
|
||||
|
||||
# Set to ABORTING
|
||||
task_obj.status = TaskStatus.ABORTING.value
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
ctx._run_cleanup()
|
||||
|
||||
# Both should have been called
|
||||
assert _handler_state["abort_called"]
|
||||
assert _handler_state["cleanup_called"]
|
||||
|
||||
def test_abort_handler_not_called_on_success(self):
|
||||
"""Test abort handler doesn't run when task succeeds."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_abort_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test No Abort on Success",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
task_obj.status = TaskStatus.SUCCESS.value
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
# Refresh to get the updated task
|
||||
db.session.refresh(task_obj)
|
||||
|
||||
ctx = TaskContext(task_obj)
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
_handler_state["abort_called"] = True
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup():
|
||||
_handler_state["cleanup_called"] = True
|
||||
|
||||
ctx._run_cleanup()
|
||||
|
||||
# Abort handler should NOT be called
|
||||
assert not _handler_state["abort_called"]
|
||||
# Cleanup handler should still be called
|
||||
assert _handler_state["cleanup_called"]
|
||||
|
||||
|
||||
class TestTaskContextMethods(SupersetTestCase):
|
||||
"""Tests for TaskContext public methods."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
def test_on_abort_marks_task_abortable(self):
|
||||
"""Test that registering an on_abort handler marks task as abortable."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_abortable_flag",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Abortable",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
assert task_obj.properties_dict.get("is_abortable") is not True
|
||||
|
||||
ctx = TaskContext(task_obj)
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
db.session.expire_all()
|
||||
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
|
||||
assert task_obj.properties_dict.get("is_abortable") is True
|
||||
|
||||
|
||||
class TestAbortBeforeExecution(SupersetTestCase):
|
||||
"""Tests for aborting tasks before they start executing."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
|
||||
def test_abort_pending_task(self):
|
||||
"""Test that pending tasks can be aborted directly."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_abort_before_start",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Before Start",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Cancel immediately (task is still PENDING)
|
||||
CancelTaskCommand(task_obj.uuid, force=True).run()
|
||||
|
||||
db.session.expire_all()
|
||||
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
|
||||
assert task_obj.status == TaskStatus.ABORTED.value
|
||||
|
||||
def test_executor_skips_aborted_task(self):
|
||||
"""Test that executor skips tasks already aborted before execution."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_cleanup_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Skip Aborted",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Abort the task before execution
|
||||
task_obj.status = TaskStatus.ABORTED.value
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
_reset_handler_state()
|
||||
|
||||
# Try to execute - should skip
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.ABORTED.value
|
||||
# Cleanup handler should NOT have been called (task was skipped)
|
||||
assert not _handler_state["cleanup_called"]
|
||||
158
tests/integration_tests/tasks/test_sync_join_wait.py
Normal file
158
tests/integration_tests/tasks/test_sync_join_wait.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Integration tests for sync join-and-wait functionality in GTF."""
|
||||
|
||||
import time
|
||||
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks import SubmitTaskCommand
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.tasks.manager import TaskManager
|
||||
|
||||
|
||||
def test_submit_task_distinguishes_new_vs_existing(
|
||||
app_context, login_as, get_user
|
||||
) -> None:
|
||||
"""
|
||||
Test that SubmitTaskCommand.run_with_info() correctly returns is_new flag.
|
||||
"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# First submission - should be new
|
||||
task1, is_new1 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "distinguish-key",
|
||||
"task_name": "First Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
assert is_new1 is True
|
||||
|
||||
try:
|
||||
# Second submission with same key - should join existing
|
||||
task2, is_new2 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "distinguish-key",
|
||||
"task_name": "Second Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
assert is_new2 is False
|
||||
assert task2.uuid == task1.uuid
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task1)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_terminal_states_recognized_correctly(app_context) -> None:
|
||||
"""
|
||||
Test that TaskManager.TERMINAL_STATES contains the expected values.
|
||||
"""
|
||||
assert TaskStatus.SUCCESS.value in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.FAILURE.value in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.ABORTED.value in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.TIMED_OUT.value in TaskManager.TERMINAL_STATES
|
||||
|
||||
# Non-terminal states should not be in the set
|
||||
assert TaskStatus.PENDING.value not in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.IN_PROGRESS.value not in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.ABORTING.value not in TaskManager.TERMINAL_STATES
|
||||
|
||||
|
||||
def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
|
||||
"""
|
||||
Test that wait_for_completion raises TimeoutError on timeout.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create a pending task (won't complete)
|
||||
task, _ = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-timeout",
|
||||
"task_key": "timeout-key",
|
||||
"task_name": "Timeout Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
try:
|
||||
# Force polling mode by mocking signal_cache as None
|
||||
with patch("superset.tasks.manager.cache_manager") as mock_cache_manager:
|
||||
mock_cache_manager.signal_cache = None
|
||||
with pytest.raises(TimeoutError):
|
||||
TaskManager.wait_for_completion(
|
||||
task.uuid,
|
||||
timeout=0.2,
|
||||
poll_interval=0.05,
|
||||
)
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_wait_returns_immediately_for_terminal_task(
|
||||
app_context, login_as, get_user
|
||||
) -> None:
|
||||
"""
|
||||
Test that wait_for_completion returns immediately if task is already terminal.
|
||||
"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create and immediately complete a task
|
||||
task, _ = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-immediate",
|
||||
"task_key": "immediate-key",
|
||||
"task_name": "Immediate Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
TaskDAO.update(task, {"status": TaskStatus.SUCCESS.value})
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
start = time.time()
|
||||
result = TaskManager.wait_for_completion(
|
||||
task.uuid,
|
||||
timeout=5.0,
|
||||
poll_interval=0.5,
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
|
||||
assert result.status == TaskStatus.SUCCESS.value
|
||||
# Should return almost immediately since task is already terminal
|
||||
assert elapsed < 0.2
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
172
tests/integration_tests/tasks/test_throttling.py
Normal file
172
tests/integration_tests/tasks/test_throttling.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Integration tests for TaskContext update_task throttling.
|
||||
|
||||
Tests verify:
|
||||
1. Final state is persisted correctly via cleanup flush
|
||||
2. Throttled updates are deferred, timer writes latest pending update
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.extensions import db
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.ambient_context import get_context
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
from superset.tasks.scheduler import execute_task
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME
|
||||
|
||||
|
||||
def task_with_throttled_updates() -> None:
|
||||
"""Task with rapid progress and payload updates (exercises throttling)."""
|
||||
ctx = get_context()
|
||||
|
||||
# Rapid-fire updates within throttle window
|
||||
for i in range(10):
|
||||
ctx.update_task(progress=(i + 1, 10), payload={"step": i + 1})
|
||||
|
||||
|
||||
def _register_test_tasks() -> None:
|
||||
"""Register test task functions if not already registered.
|
||||
|
||||
Called in setUp() to ensure tasks are registered regardless of
|
||||
whether other tests have cleared the registry.
|
||||
"""
|
||||
if not TaskRegistry.is_registered("test_throttle_combined"):
|
||||
TaskRegistry.register("test_throttle_combined", task_with_throttled_updates)
|
||||
|
||||
|
||||
class TestUpdateTaskThrottling(SupersetTestCase):
|
||||
"""Integration test for update_task() throttling behavior."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
|
||||
def test_throttled_updates_persisted_on_cleanup(self) -> None:
|
||||
"""Final state should be persisted regardless of throttling.
|
||||
|
||||
Verifies the core invariant: cleanup flush ensures final state is persisted.
|
||||
"""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_throttle_combined",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Throttled Updates",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_throttle_combined", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.SUCCESS.value
|
||||
|
||||
# Verify final state is persisted
|
||||
db.session.expire_all()
|
||||
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
|
||||
|
||||
# Progress: 10/10 = 100%
|
||||
props = task_obj.properties_dict
|
||||
assert props.get("progress_current") == 10
|
||||
assert props.get("progress_total") == 10
|
||||
assert props.get("progress_percent") == 1.0
|
||||
|
||||
# Payload: final step
|
||||
payload = task_obj.payload_dict
|
||||
assert payload.get("step") == 10
|
||||
|
||||
def test_throttle_behavior(self) -> None:
|
||||
"""Test complete throttle behavior: immediate write, deferral, and timer.
|
||||
|
||||
Verifies:
|
||||
1. First update writes immediately
|
||||
2. Second and third updates within throttle window are deferred
|
||||
3. Deferred timer fires and writes the LATEST pending update (third)
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
from superset.commands.tasks.submit import SubmitTaskCommand
|
||||
from superset.tasks.context import TaskContext
|
||||
|
||||
# Get throttle interval from config (default: 2 seconds)
|
||||
throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]
|
||||
|
||||
# Create task
|
||||
task_obj = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test_throttle_behavior",
|
||||
"task_key": f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
"task_name": "Test Throttle Behavior",
|
||||
"scope": TaskScope.SYSTEM,
|
||||
}
|
||||
).run()
|
||||
task_uuid = task_obj.uuid
|
||||
|
||||
# Get fresh task for context
|
||||
fresh_task = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
|
||||
assert fresh_task is not None
|
||||
ctx = TaskContext(fresh_task)
|
||||
|
||||
try:
|
||||
# === Step 1: First update - writes immediately ===
|
||||
ctx.update_task(progress=0.1, payload={"step": 1})
|
||||
|
||||
db.session.expire_all()
|
||||
task_step1 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
|
||||
assert task_step1 is not None
|
||||
assert task_step1.properties_dict.get("progress_percent") == 0.1
|
||||
assert task_step1.payload_dict.get("step") == 1
|
||||
|
||||
# === Step 2: Second update - deferred (within throttle window) ===
|
||||
ctx.update_task(progress=0.5, payload={"step": 2})
|
||||
|
||||
# === Step 3: Third update - also deferred, overwrites second in cache ===
|
||||
ctx.update_task(progress=0.7, payload={"step": 3})
|
||||
|
||||
# Verify in-memory cache has LATEST update (third)
|
||||
assert ctx._properties_cache.get("progress_percent") == 0.7
|
||||
assert ctx._payload_cache.get("step") == 3
|
||||
|
||||
# Verify DB still has first update (both second and third deferred)
|
||||
db.session.expire_all()
|
||||
task_step2 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
|
||||
assert task_step2 is not None
|
||||
assert task_step2.properties_dict.get("progress_percent") == 0.1
|
||||
assert task_step2.payload_dict.get("step") == 1
|
||||
|
||||
# === Step 4: Wait for deferred timer to fire ===
|
||||
time.sleep(throttle_interval + 0.5)
|
||||
|
||||
# Verify timer fired and wrote the LATEST update (third, not second)
|
||||
db.session.expire_all()
|
||||
task_step3 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
|
||||
assert task_step3 is not None
|
||||
assert task_step3.properties_dict.get("progress_percent") == 0.7
|
||||
assert task_step3.payload_dict.get("step") == 3
|
||||
|
||||
finally:
|
||||
ctx._cancel_deferred_flush_timer()
|
||||
226
tests/integration_tests/tasks/test_timeout.py
Normal file
226
tests/integration_tests/tasks/test_timeout.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Integration tests for GTF timeout handling.
|
||||
|
||||
Uses module-level task functions with manual registry (like test_event_handlers.py)
|
||||
to avoid mypy issues with the @task decorator's complex generic types.
|
||||
|
||||
NOTE: Tests that use background threads (timeout/abort handlers) are skipped in
|
||||
SQLite environments because SQLite connections cannot be shared across threads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset.commands.tasks.cancel import CancelTaskCommand
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.extensions import db
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.ambient_context import get_context
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
from superset.tasks.scheduler import execute_task
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME
|
||||
|
||||
|
||||
def _skip_if_sqlite() -> None:
|
||||
"""Skip test if running with SQLite database.
|
||||
|
||||
SQLite connections cannot be shared across threads, which breaks
|
||||
timeout tests that use background threads for abort handlers.
|
||||
Must be called from within a test method (with app context).
|
||||
"""
|
||||
if "sqlite" in db.engine.url.drivername:
|
||||
pytest.skip("SQLite connections cannot be shared across threads")
|
||||
|
||||
|
||||
# Module-level state to track handler calls
|
||||
_handler_state: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _reset_handler_state() -> None:
|
||||
"""Reset handler state before each test."""
|
||||
global _handler_state
|
||||
_handler_state = {
|
||||
"abort_called": False,
|
||||
"handler_exception": None,
|
||||
}
|
||||
|
||||
|
||||
def timeout_abortable_task() -> None:
|
||||
"""Task with abort handler that exits when aborted."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def on_abort() -> None:
|
||||
_handler_state["abort_called"] = True
|
||||
|
||||
# Poll for abort signal
|
||||
for _ in range(50):
|
||||
if _handler_state["abort_called"]:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def timeout_handler_fails_task() -> None:
|
||||
"""Task with abort handler that throws an exception."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def on_abort() -> None:
|
||||
_handler_state["abort_called"] = True
|
||||
raise ValueError("Handler crashed!")
|
||||
|
||||
# Sleep longer than timeout
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def simple_task_with_abort() -> None:
|
||||
"""Simple task with abort handler for testing."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def on_abort() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def quick_task_with_abort() -> None:
|
||||
"""Quick task that completes before timeout."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def on_abort() -> None:
|
||||
pass
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
def _register_test_tasks() -> None:
|
||||
"""Register test task functions if not already registered.
|
||||
|
||||
Called in setUp() to ensure tasks are registered regardless of
|
||||
whether other tests have cleared the registry.
|
||||
"""
|
||||
registrations = [
|
||||
("test_timeout_abortable", timeout_abortable_task),
|
||||
("test_timeout_handler_fails", timeout_handler_fails_task),
|
||||
("test_timeout_simple", simple_task_with_abort),
|
||||
("test_timeout_quick", quick_task_with_abort),
|
||||
]
|
||||
for name, func in registrations:
|
||||
if not TaskRegistry.is_registered(name):
|
||||
TaskRegistry.register(name, func)
|
||||
|
||||
|
||||
class TestTimeoutHandling(SupersetTestCase):
|
||||
"""E2E tests for task timeout functionality."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
_reset_handler_state()
|
||||
|
||||
def test_timeout_with_abort_handler_results_in_timed_out_status(self) -> None:
|
||||
"""Task with timeout and abort handler should end with TIMED_OUT status."""
|
||||
_skip_if_sqlite()
|
||||
|
||||
# Create task with timeout
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_timeout_abortable",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Timeout",
|
||||
scope=TaskScope.SYSTEM,
|
||||
properties={"timeout": 1}, # 1 second timeout
|
||||
)
|
||||
|
||||
# Execute task via Celery executor (synchronously)
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_timeout_abortable", (), {}]
|
||||
)
|
||||
|
||||
# Verify execution completed
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.TIMED_OUT.value
|
||||
|
||||
# Verify abort handler was called
|
||||
assert _handler_state["abort_called"]
|
||||
|
||||
def test_user_abort_results_in_aborted_status(self) -> None:
|
||||
"""User-initiated abort on pending task should result in ABORTED."""
|
||||
# Create task (pending state)
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_timeout_simple",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Abort Task",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Cancel before execution (pending task abort)
|
||||
CancelTaskCommand(task_obj.uuid, force=True).run()
|
||||
|
||||
# Refresh from DB
|
||||
db.session.expire_all()
|
||||
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
|
||||
assert task_obj.status == TaskStatus.ABORTED.value
|
||||
|
||||
def test_no_timeout_when_not_configured(self) -> None:
|
||||
"""Task without timeout should run to completion regardless of duration."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_timeout_quick",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test No Timeout",
|
||||
scope=TaskScope.SYSTEM,
|
||||
# No timeout property
|
||||
)
|
||||
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_timeout_quick", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.SUCCESS.value
|
||||
|
||||
def test_abort_handler_exception_results_in_failure(self) -> None:
|
||||
"""If abort handler throws during timeout, task should be FAILURE."""
|
||||
_skip_if_sqlite()
|
||||
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_timeout_handler_fails",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Handler Fails",
|
||||
scope=TaskScope.SYSTEM,
|
||||
properties={"timeout": 1}, # 1 second timeout
|
||||
)
|
||||
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_timeout_handler_fails", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.FAILURE.value
|
||||
assert _handler_state["abort_called"]
|
||||
420
tests/unit_tests/daos/test_tasks.py
Normal file
420
tests/unit_tests/daos/test_tasks.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
|
||||
|
||||
from superset.commands.tasks.exceptions import TaskNotAbortableError
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key
|
||||
|
||||
# Test constants
|
||||
TASK_UUID = UUID("e7765491-40c1-4f35-a4f5-06308e79310e")
|
||||
TASK_ID = 42
|
||||
TEST_TASK_TYPE = "test_type"
|
||||
TEST_TASK_KEY = "test-key"
|
||||
TEST_USER_ID = 1
|
||||
|
||||
|
||||
def create_task(
|
||||
session: Session,
|
||||
*,
|
||||
task_id: int | None = None,
|
||||
task_uuid: UUID | None = None,
|
||||
task_key: str = TEST_TASK_KEY,
|
||||
task_type: str = TEST_TASK_TYPE,
|
||||
scope: TaskScope = TaskScope.PRIVATE,
|
||||
status: TaskStatus = TaskStatus.PENDING,
|
||||
user_id: int | None = TEST_USER_ID,
|
||||
properties: TaskProperties | None = None,
|
||||
use_finished_dedup_key: bool = False,
|
||||
) -> Task:
|
||||
"""Helper to create a task with sensible defaults for testing."""
|
||||
if use_finished_dedup_key:
|
||||
dedup_key = get_finished_dedup_key(task_uuid or TASK_UUID)
|
||||
else:
|
||||
dedup_key = get_active_dedup_key(
|
||||
scope=scope,
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
scope=scope.value,
|
||||
status=status.value,
|
||||
dedup_key=dedup_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
if task_id is not None:
|
||||
task.id = task_id
|
||||
if task_uuid:
|
||||
task.uuid = task_uuid
|
||||
if properties:
|
||||
task.update_properties(properties)
|
||||
|
||||
session.add(task)
|
||||
session.flush()
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_task(session: Session) -> Iterator[Session]:
|
||||
"""Create a session with Task and TaskSubscriber tables."""
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
|
||||
engine = session.get_bind()
|
||||
Task.metadata.create_all(engine)
|
||||
TaskSubscriber.metadata.create_all(engine)
|
||||
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_find_by_task_key_active(session_with_task: Session) -> None:
|
||||
"""Test finding active task by task_key"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
create_task(session_with_task)
|
||||
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key=TEST_TASK_KEY,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.task_key == TEST_TASK_KEY
|
||||
assert result.task_type == TEST_TASK_TYPE
|
||||
assert result.status == TaskStatus.PENDING.value
|
||||
|
||||
|
||||
def test_find_by_task_key_not_found(session_with_task: Session) -> None:
|
||||
"""Test finding task by task_key returns None when not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="nonexistent-key",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_find_by_task_key_finished_not_found(session_with_task: Session) -> None:
|
||||
"""Test that find_by_task_key returns None for finished tasks.
|
||||
|
||||
Finished tasks have a different dedup_key format (UUID-based),
|
||||
so they won't be found by the active task lookup.
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
create_task(
|
||||
session_with_task,
|
||||
task_key="finished-key",
|
||||
status=TaskStatus.SUCCESS,
|
||||
use_finished_dedup_key=True,
|
||||
task_uuid=TASK_UUID,
|
||||
)
|
||||
|
||||
# Should not find SUCCESS task via active lookup
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="finished-key",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_create_task_success(session_with_task: Session) -> None:
|
||||
"""Test successful task creation."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key=TEST_TASK_KEY,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.task_key == TEST_TASK_KEY
|
||||
assert result.task_type == TEST_TASK_TYPE
|
||||
assert result.status == TaskStatus.PENDING.value
|
||||
assert isinstance(result, Task)
|
||||
|
||||
|
||||
def test_create_task_with_user_id(session_with_task: Session) -> None:
|
||||
"""Test task creation with explicit user_id."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="user-task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=42,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.user_id == 42
|
||||
# Creator should be auto-subscribed
|
||||
assert len(result.subscribers) == 1
|
||||
assert result.subscribers[0].user_id == 42
|
||||
|
||||
|
||||
def test_create_task_with_properties(session_with_task: Session) -> None:
|
||||
"""Test task creation with properties."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="props-task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
properties={"timeout": 300},
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.properties_dict.get("timeout") == 300
|
||||
|
||||
|
||||
def test_abort_task_pending_success(session_with_task: Session) -> None:
|
||||
"""Test successful abort of pending task - goes directly to ABORTED"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="pending-task",
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
|
||||
|
||||
def test_abort_task_in_progress_abortable(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task with abort handler.
|
||||
|
||||
Should transition to ABORTING status.
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="abortable-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
properties={"is_abortable": True},
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is not None
|
||||
# Should set status to ABORTING, not ABORTED
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
|
||||
|
||||
def test_abort_task_in_progress_not_abortable(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task without abort handler - raises error"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="non-abortable-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
properties={"is_abortable": False},
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotAbortableError):
|
||||
TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
|
||||
def test_abort_task_in_progress_is_abortable_none(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task with is_abortable not set - raises error"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="no-abortable-prop-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
# Empty properties - no is_abortable key
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotAbortableError):
|
||||
TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
|
||||
def test_abort_task_already_aborting(session_with_task: Session) -> None:
|
||||
"""Test abort of already aborting task - idempotent success"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="aborting-task",
|
||||
status=TaskStatus.ABORTING,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
# Idempotent - returns task without error
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
|
||||
|
||||
def test_abort_task_not_found(session_with_task: Session) -> None:
|
||||
"""Test abort fails when task not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.abort_task(UUID("00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_abort_task_already_finished(session_with_task: Session) -> None:
|
||||
"""Test abort fails when task already finished"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="finished-task",
|
||||
status=TaskStatus.SUCCESS,
|
||||
use_finished_dedup_key=True,
|
||||
task_uuid=TASK_UUID,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_add_subscriber(session_with_task: Session) -> None:
|
||||
"""Test adding a subscriber to a task"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Add subscriber
|
||||
result = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
assert result is True
|
||||
|
||||
# Verify subscriber was added
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
assert task.subscribers[0].user_id == TEST_USER_ID
|
||||
|
||||
|
||||
def test_add_subscriber_idempotent(session_with_task: Session) -> None:
|
||||
"""Test adding same subscriber twice is idempotent"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-2",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Add subscriber twice
|
||||
result1 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
result2 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
|
||||
assert result1 is True
|
||||
assert result2 is False # Already subscribed
|
||||
|
||||
# Verify only one subscriber
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
|
||||
|
||||
def test_remove_subscriber(session_with_task: Session) -> None:
|
||||
"""Test removing a subscriber from a task"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-3",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
|
||||
# Remove subscriber
|
||||
result = TaskDAO.remove_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.subscribers) == 0
|
||||
|
||||
|
||||
def test_remove_subscriber_not_subscribed(session_with_task: Session) -> None:
|
||||
"""Test removing non-existent subscriber returns None"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-4",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Try to remove non-existent subscriber
|
||||
result = TaskDAO.remove_subscriber(task.id, user_id=999)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_status(session_with_task: Session) -> None:
|
||||
"""Test get_status returns status string when task found by UUID"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_uuid=TASK_UUID,
|
||||
task_key="status-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
)
|
||||
|
||||
result = TaskDAO.get_status(task.uuid)
|
||||
|
||||
assert result == TaskStatus.IN_PROGRESS.value
|
||||
|
||||
|
||||
def test_get_status_not_found(session_with_task: Session) -> None:
|
||||
"""Test get_status returns None when task not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
assert result is None
|
||||
@@ -18,17 +18,21 @@
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
# Force module loading before tests run so patches work correctly
|
||||
import superset.commands.distributed_lock.acquire as acquire_module
|
||||
import superset.commands.distributed_lock.release as release_module
|
||||
from superset import db
|
||||
from superset.distributed_lock import KeyValueDistributedLock
|
||||
from superset.distributed_lock import DistributedLock
|
||||
from superset.distributed_lock.types import LockValue
|
||||
from superset.distributed_lock.utils import get_key
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.exceptions import AcquireDistributedLockFailedException
|
||||
from superset.key_value.types import JsonKeyValueCodec
|
||||
|
||||
LOCK_VALUE: LockValue = {"value": True}
|
||||
@@ -56,9 +60,9 @@ def _get_other_session() -> Session:
|
||||
return SessionMaker()
|
||||
|
||||
|
||||
def test_key_value_distributed_lock_happy_path() -> None:
|
||||
def test_distributed_lock_kv_happy_path() -> None:
|
||||
"""
|
||||
Test successfully acquiring and returning the distributed lock.
|
||||
Test successfully acquiring and returning the distributed lock via KV backend.
|
||||
|
||||
Note, we're using another session for asserting the lock state in the Metastore
|
||||
to simulate what another worker will observe. Otherwise, there's the risk that
|
||||
@@ -66,24 +70,29 @@ def test_key_value_distributed_lock_happy_path() -> None:
|
||||
"""
|
||||
session = _get_other_session()
|
||||
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
# Ensure Redis is not configured so KV backend is used
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
with KeyValueDistributedLock("ns", a=1, b=2) as key:
|
||||
assert key == MAIN_KEY
|
||||
assert _get_lock(key, session) == LOCK_VALUE
|
||||
assert _get_lock(OTHER_KEY, session) is None
|
||||
with DistributedLock("ns", a=1, b=2) as key:
|
||||
assert key == MAIN_KEY
|
||||
assert _get_lock(key, session) == LOCK_VALUE
|
||||
assert _get_lock(OTHER_KEY, session) is None
|
||||
|
||||
with pytest.raises(CreateKeyValueDistributedLockFailedException):
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
pass
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("ns", a=1, b=2):
|
||||
pass
|
||||
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
|
||||
def test_key_value_distributed_lock_expired() -> None:
|
||||
def test_distributed_lock_kv_expired() -> None:
|
||||
"""
|
||||
Test expiration of the distributed lock
|
||||
Test expiration of the distributed lock via KV backend.
|
||||
|
||||
Note, we're using another session for asserting the lock state in the Metastore
|
||||
to simulate what another worker will observe. Otherwise, there's the risk that
|
||||
@@ -91,11 +100,112 @@ def test_key_value_distributed_lock_expired() -> None:
|
||||
"""
|
||||
session = _get_other_session()
|
||||
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
|
||||
with freeze_time("2022-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
# Ensure Redis is not configured so KV backend is used
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
with DistributedLock("ns", a=1, b=2):
|
||||
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
|
||||
with freeze_time("2022-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
|
||||
def test_distributed_lock_uses_redis_when_configured() -> None:
|
||||
"""Test that DistributedLock uses Redis backend when configured."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True # Lock acquired
|
||||
|
||||
# Use patch.object to patch on already-imported modules
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test_redis", key="value") as lock_key:
|
||||
assert lock_key is not None
|
||||
# Verify SET NX EX was called
|
||||
mock_redis.set.assert_called_once()
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["nx"] is True
|
||||
assert "ex" in call_args.kwargs
|
||||
|
||||
# Verify DELETE was called on exit
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_distributed_lock_redis_already_taken() -> None:
|
||||
"""Test Redis lock fails when already held."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = None # Lock not acquired (already taken)
|
||||
|
||||
with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("test_redis", key="value"):
|
||||
pass
|
||||
|
||||
|
||||
def test_distributed_lock_redis_connection_error() -> None:
|
||||
"""Test Redis connection error raises exception (fail fast)."""
|
||||
import redis
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.side_effect = redis.RedisError("Connection failed")
|
||||
|
||||
with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("test_redis", key="value"):
|
||||
pass
|
||||
|
||||
|
||||
def test_distributed_lock_custom_ttl() -> None:
|
||||
"""Test Redis lock with custom TTL."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test", ttl_seconds=60, key="value"):
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["ex"] == 60 # Custom TTL
|
||||
|
||||
|
||||
def test_distributed_lock_default_ttl(app_context: None) -> None:
|
||||
"""Test Redis lock uses default TTL when not specified."""
|
||||
from superset.commands.distributed_lock.base import get_default_lock_ttl
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test", key="value"):
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["ex"] == get_default_lock_ttl()
|
||||
|
||||
|
||||
def test_distributed_lock_fallback_to_kv_when_redis_not_configured() -> None:
|
||||
"""Test falls back to KV lock when Redis not configured."""
|
||||
session = _get_other_session()
|
||||
test_key = get_key("test_fallback", key="value")
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
# When Redis is not configured, should use KV backend
|
||||
with DistributedLock("test_fallback", key="value") as lock_key:
|
||||
assert lock_key == test_key
|
||||
# Verify lock exists in KV store
|
||||
assert _get_lock(test_key, session) == LOCK_VALUE
|
||||
|
||||
# Lock should be released
|
||||
assert _get_lock(test_key, session) is None
|
||||
|
||||
477
tests/unit_tests/tasks/test_decorators.py
Normal file
477
tests/unit_tests/tasks/test_decorators.py
Normal file
@@ -0,0 +1,477 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for task decorators"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskOptions, TaskScope
|
||||
|
||||
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
|
||||
from superset.tasks.decorators import task, TaskWrapper
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
|
||||
class TestTaskDecoratorFeatureFlag:
|
||||
"""Tests for @task decorator feature flag behavior"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_decorator_succeeds_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that @task decorator can be applied even when GTF is disabled.
|
||||
|
||||
This enables safe module imports during app startup or Celery autodiscovery.
|
||||
"""
|
||||
|
||||
# Decoration should succeed - no error raised
|
||||
@task(name="test_gtf_disabled_decorator")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_task, TaskWrapper)
|
||||
assert my_task.name == "test_gtf_disabled_decorator"
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_call_raises_error_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that calling a task raises GlobalTaskFrameworkDisabledError
|
||||
when GTF is disabled."""
|
||||
|
||||
@task(name="test_gtf_disabled_call")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(GlobalTaskFrameworkDisabledError):
|
||||
my_task()
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_schedule_raises_error_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that scheduling a task raises GlobalTaskFrameworkDisabledError
|
||||
when GTF is disabled."""
|
||||
|
||||
@task(name="test_gtf_disabled_schedule")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(GlobalTaskFrameworkDisabledError):
|
||||
my_task.schedule()
|
||||
|
||||
|
||||
class TestTaskDecorator:
|
||||
"""Tests for @task decorator"""
|
||||
|
||||
def test_decorator_basic(self):
|
||||
"""Test basic decorator usage without options"""
|
||||
|
||||
@task(name="test_task")
|
||||
def my_task(arg1: int, arg2: str) -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_task, TaskWrapper)
|
||||
assert my_task.name == "test_task"
|
||||
assert my_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_without_parentheses(self):
|
||||
"""Test decorator usage without parentheses"""
|
||||
|
||||
@task
|
||||
def my_no_parens_task(arg1: int, arg2: str) -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_no_parens_task, TaskWrapper)
|
||||
assert my_no_parens_task.name == "my_no_parens_task" # Uses function name
|
||||
assert my_no_parens_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_with_default_scope_private(self):
|
||||
"""Test decorator with explicit PRIVATE scope"""
|
||||
|
||||
@task(name="private_task", scope=TaskScope.PRIVATE)
|
||||
def my_private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
assert my_private_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_with_default_scope_shared(self):
|
||||
"""Test decorator with SHARED scope"""
|
||||
|
||||
@task(name="shared_task", scope=TaskScope.SHARED)
|
||||
def my_shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
assert my_shared_task.scope == TaskScope.SHARED
|
||||
|
||||
def test_decorator_with_default_scope_system(self):
|
||||
"""Test decorator with SYSTEM scope"""
|
||||
|
||||
@task(name="system_task", scope=TaskScope.SYSTEM)
|
||||
def my_system_task() -> None:
|
||||
pass
|
||||
|
||||
assert my_system_task.scope == TaskScope.SYSTEM
|
||||
|
||||
def test_decorator_forbids_ctx_parameter(self):
|
||||
"""Test decorator rejects functions with ctx parameter"""
|
||||
|
||||
with pytest.raises(TypeError, match="must not define 'ctx'"):
|
||||
|
||||
@task(name="bad_task")
|
||||
def bad_task(ctx, arg1: int) -> None: # noqa: ARG001
|
||||
pass
|
||||
|
||||
def test_decorator_forbids_options_parameter(self):
|
||||
"""Test decorator rejects functions with options parameter"""
|
||||
|
||||
with pytest.raises(TypeError, match="must not define.*'options'"):
|
||||
|
||||
@task(name="bad_task")
|
||||
def bad_task(options, arg1: int) -> None: # noqa: ARG001
|
||||
pass
|
||||
|
||||
|
||||
class TestTaskWrapperMergeOptions:
|
||||
"""Tests for TaskWrapper._merge_options()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
def test_merge_options_no_override(self):
|
||||
"""Test merging with no override returns defaults"""
|
||||
|
||||
@task(name="test_merge_no_override_unique")
|
||||
def merge_task_1() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_1.default_options = TaskOptions(
|
||||
task_key="default_key",
|
||||
task_name="Default Name",
|
||||
)
|
||||
|
||||
merged = merge_task_1._merge_options(None)
|
||||
assert merged.task_key == "default_key"
|
||||
assert merged.task_name == "Default Name"
|
||||
|
||||
def test_merge_options_override_task_key(self):
|
||||
"""Test overriding task_key at call time"""
|
||||
|
||||
@task(name="test_merge_override_key_unique")
|
||||
def merge_task_2() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_2.default_options = TaskOptions(task_key="default_key")
|
||||
|
||||
override = TaskOptions(task_key="override_key")
|
||||
merged = merge_task_2._merge_options(override)
|
||||
assert merged.task_key == "override_key"
|
||||
|
||||
def test_merge_options_override_task_name(self):
|
||||
"""Test overriding task_name at call time"""
|
||||
|
||||
@task(name="test_merge_override_name_unique")
|
||||
def merge_task_3() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_3.default_options = TaskOptions(task_name="Default Name")
|
||||
|
||||
override = TaskOptions(task_name="Override Name")
|
||||
merged = merge_task_3._merge_options(override)
|
||||
assert merged.task_name == "Override Name"
|
||||
|
||||
def test_merge_options_override_all(self):
|
||||
"""Test overriding all options at call time"""
|
||||
|
||||
@task(name="test_merge_override_all_unique")
|
||||
def merge_task_4() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_4.default_options = TaskOptions(
|
||||
task_key="default_key",
|
||||
task_name="Default Name",
|
||||
)
|
||||
|
||||
override = TaskOptions(
|
||||
task_key="override_key",
|
||||
task_name="Override Name",
|
||||
)
|
||||
merged = merge_task_4._merge_options(override)
|
||||
assert merged.task_key == "override_key"
|
||||
assert merged.task_name == "Override Name"
|
||||
|
||||
|
||||
class TestTaskWrapperSchedule:
|
||||
"""Tests for TaskWrapper.schedule() with scope"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_uses_default_scope(self, mock_submit):
|
||||
"""Test schedule() uses decorator's default scope"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_default_unique", scope=TaskScope.SHARED)
|
||||
def schedule_task_1(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Shared tasks require explicit task_key
|
||||
schedule_task_1.schedule(123, options=TaskOptions(task_key="test_key"))
|
||||
|
||||
# Verify TaskManager.submit_task was called with correct scope
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.SHARED
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_uses_private_scope_by_default(self, mock_submit):
|
||||
"""Test schedule() uses PRIVATE scope when no scope specified"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_override_unique")
|
||||
def schedule_task_2(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
schedule_task_2.schedule(123)
|
||||
|
||||
# Verify PRIVATE scope was used (default)
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.PRIVATE
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_with_custom_options(self, mock_submit):
|
||||
"""Test schedule() with custom task options"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_custom_unique", scope=TaskScope.SYSTEM)
|
||||
def schedule_task_3(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Use custom task key and name
|
||||
schedule_task_3.schedule(
|
||||
123,
|
||||
options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
|
||||
)
|
||||
|
||||
# Verify scope from decorator and options from call time
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.SYSTEM
|
||||
assert call_args[1]["task_key"] == "custom_key"
|
||||
assert call_args[1]["task_name"] == "Custom Task Name"
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_with_no_decorator_options(self, mock_submit):
|
||||
"""Test schedule() uses default PRIVATE scope when no options provided"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_no_options_unique")
|
||||
def schedule_task_4(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
schedule_task_4.schedule(123)
|
||||
|
||||
# Verify default PRIVATE scope
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.PRIVATE
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_shared_task_requires_task_key(self, mock_submit):
|
||||
"""Test shared task schedule() requires explicit task_key"""
|
||||
|
||||
@task(name="test_shared_requires_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should raise ValueError when no task_key provided
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Shared task.*requires an explicit task_key.*for deduplication",
|
||||
):
|
||||
shared_task.schedule(123)
|
||||
|
||||
# Should work with task_key provided
|
||||
mock_submit.return_value = MagicMock()
|
||||
shared_task.schedule(123, options=TaskOptions(task_key="valid_key"))
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_private_task_allows_no_task_key(self, mock_submit):
|
||||
"""Test private task schedule() works without task_key"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_private_no_key", scope=TaskScope.PRIVATE)
|
||||
def private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work without task_key (generates random UUID)
|
||||
private_task.schedule(123)
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
|
||||
class TestTaskWrapperCall:
|
||||
"""Tests for TaskWrapper.__call__() with scope"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_uses_default_scope(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test direct call uses decorator's default scope"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_default_unique", scope=TaskScope.SHARED)
|
||||
def call_task_1(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Shared tasks require explicit task_key
|
||||
call_task_1(123, options=TaskOptions(task_key="test_key"))
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.utils.core.get_user_id")
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_uses_private_scope_by_default(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
|
||||
):
|
||||
"""Test direct call uses PRIVATE scope when no scope specified"""
|
||||
mock_get_user_id.return_value = 1
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_private_default_unique")
|
||||
def call_task_2(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
call_task_2(123)
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_with_custom_options(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test direct call with custom task options"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_custom_unique", scope=TaskScope.SYSTEM)
|
||||
def call_task_3(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Use custom task key and name
|
||||
call_task_3(
|
||||
123,
|
||||
options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
|
||||
)
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
def test_call_shared_task_requires_task_key(self):
|
||||
"""Test shared task direct call requires explicit task_key"""
|
||||
|
||||
@task(name="test_shared_call_requires_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should raise ValueError when no task_key provided
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Shared task.*requires an explicit task_key.*for deduplication",
|
||||
):
|
||||
shared_task(123)
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_shared_task_works_with_task_key(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test shared task direct call works with task_key"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task
|
||||
|
||||
@task(name="test_shared_call_with_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work with task_key provided
|
||||
shared_task(123, options=TaskOptions(task_key="valid_key"))
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.utils.core.get_user_id")
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_private_task_allows_no_task_key(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
|
||||
):
|
||||
"""Test private task direct call works without task_key"""
|
||||
mock_get_user_id.return_value = 1
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task
|
||||
|
||||
@task(name="test_private_call_no_key", scope=TaskScope.PRIVATE)
|
||||
def private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work without task_key (generates random UUID)
|
||||
private_task(123)
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
677
tests/unit_tests/tasks/test_handlers.py
Normal file
677
tests/unit_tests/tasks/test_handlers.py
Normal file
@@ -0,0 +1,677 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for GTF handlers (abort, cleanup) and related Task model behavior."""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset.tasks.context import TaskContext
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
"""Create a mock task for testing."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = TaskStatus.PENDING.value
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_dao(mock_task):
|
||||
"""Mock TaskDAO to return our test task."""
|
||||
with patch("superset.daos.tasks.TaskDAO") as mock_dao:
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
yield mock_dao
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_update_command():
|
||||
"""Mock UpdateTaskCommand to avoid database operations."""
|
||||
with patch("superset.commands.tasks.update.UpdateTaskCommand") as mock_cmd:
|
||||
mock_cmd.return_value.run.return_value = None
|
||||
yield mock_cmd
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flask_app():
|
||||
"""Create a properly configured mock Flask app."""
|
||||
mock_app = MagicMock()
|
||||
mock_app.config = {
|
||||
"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
|
||||
}
|
||||
# Make app_context() return a proper context manager
|
||||
mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
|
||||
# Use regular Mock (not MagicMock) for _get_current_object to avoid
|
||||
# AsyncMockMixin creating unawaited coroutines in Python 3.10+
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
return mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_context(mock_task, mock_task_dao, mock_update_command, mock_flask_app):
|
||||
"""Create TaskContext with mocked dependencies."""
|
||||
# Ensure mock_task has properties_dict and payload_dict (TaskContext accesses them)
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
# Configure current_app mock
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
# Use regular Mock (not MagicMock) for _get_current_object to avoid
|
||||
# AsyncMockMixin creating unawaited coroutines in Python 3.10+
|
||||
mock_current_app._get_current_object = Mock(return_value=mock_flask_app)
|
||||
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
yield ctx
|
||||
|
||||
# Cleanup: stop polling if started
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
class TestTaskStatusEnum:
|
||||
"""Test TaskStatus enum values."""
|
||||
|
||||
def test_aborting_status_exists(self):
|
||||
"""Test that ABORTING status is defined."""
|
||||
assert hasattr(TaskStatus, "ABORTING")
|
||||
assert TaskStatus.ABORTING.value == "aborting"
|
||||
|
||||
def test_all_statuses_present(self):
|
||||
"""Test all expected statuses are present."""
|
||||
expected_statuses = [
|
||||
"pending",
|
||||
"in_progress",
|
||||
"success",
|
||||
"failure",
|
||||
"aborting",
|
||||
"aborted",
|
||||
]
|
||||
actual_statuses = [s.value for s in TaskStatus]
|
||||
|
||||
for status in expected_statuses:
|
||||
assert status in actual_statuses, f"Missing status: {status}"
|
||||
|
||||
|
||||
class TestTaskAbortProperties:
|
||||
"""Test Task model abort-related properties via status and properties accessor."""
|
||||
|
||||
def test_aborting_status(self):
|
||||
"""Test ABORTING status check."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.status = TaskStatus.ABORTING.value
|
||||
|
||||
assert task.status == TaskStatus.ABORTING.value
|
||||
|
||||
def test_is_abortable_in_properties(self):
|
||||
"""Test is_abortable is accessible via properties."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.update_properties({"is_abortable": True})
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
|
||||
def test_is_abortable_default_none(self):
|
||||
"""Test is_abortable defaults to None for new tasks."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is None
|
||||
|
||||
|
||||
class TestTaskSetStatus:
|
||||
"""Test Task.set_status behavior for abort states."""
|
||||
|
||||
def test_set_status_in_progress_sets_is_abortable_false(self):
|
||||
"""Test that transitioning to IN_PROGRESS sets is_abortable to False."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
# Default is None
|
||||
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is False
|
||||
assert task.started_at is not None
|
||||
|
||||
def test_set_status_in_progress_preserves_existing_is_abortable(self):
|
||||
"""Test that re-setting IN_PROGRESS doesn't override is_abortable."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.update_properties(
|
||||
{"is_abortable": True}
|
||||
) # Already set by handler registration
|
||||
task.started_at = datetime.now(timezone.utc) # Already started
|
||||
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
|
||||
# Should not override since started_at is already set
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
|
||||
def test_set_status_aborting_does_not_set_ended_at(self):
|
||||
"""Test that ABORTING status does not set ended_at."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.started_at = datetime.now(timezone.utc)
|
||||
|
||||
task.status = TaskStatus.ABORTING.value
|
||||
|
||||
assert task.ended_at is None
|
||||
|
||||
def test_set_status_aborted_sets_ended_at(self):
|
||||
"""Test that ABORTED status sets ended_at."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.started_at = datetime.now(timezone.utc)
|
||||
|
||||
task.set_status(TaskStatus.ABORTED)
|
||||
|
||||
assert task.ended_at is not None
|
||||
|
||||
|
||||
class TestTaskDuration:
|
||||
"""Test Task duration_seconds property with different states."""
|
||||
|
||||
def test_duration_seconds_finished_task(self):
|
||||
"""Test duration for finished task returns actual duration."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.status = TaskStatus.SUCCESS.value # Must be finished to use ended_at
|
||||
task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.ended_at = datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
|
||||
|
||||
# Should use ended_at - started_at = 30 seconds
|
||||
assert task.duration_seconds == 30.0
|
||||
|
||||
@freeze_time("2024-01-01 10:00:30")
|
||||
def test_duration_seconds_running_task(self):
|
||||
"""Test duration for running task returns time since start."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.ended_at = None
|
||||
|
||||
# 30 seconds since start
|
||||
assert task.duration_seconds == 30.0
|
||||
|
||||
@freeze_time("2024-01-01 10:00:15")
|
||||
def test_duration_seconds_pending_task(self):
|
||||
"""Test duration for pending task returns queue time."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.started_at = None
|
||||
task.ended_at = None
|
||||
|
||||
# 15 seconds since creation
|
||||
assert task.duration_seconds == 15.0
|
||||
|
||||
def test_duration_seconds_no_timestamps(self):
|
||||
"""Test duration returns None when no timestamps available."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.created_on = None
|
||||
task.started_at = None
|
||||
task.ended_at = None
|
||||
|
||||
assert task.duration_seconds is None
|
||||
|
||||
|
||||
class TestAbortHandlerRegistration:
|
||||
"""Test abort handler registration and is_abortable flag."""
|
||||
|
||||
def test_on_abort_registers_handler(self, task_context):
|
||||
"""Test that on_abort registers a handler."""
|
||||
handler_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
|
||||
assert len(task_context._abort_handlers) == 1
|
||||
assert not handler_called
|
||||
|
||||
@patch("superset.tasks.context.current_app")
|
||||
def test_on_abort_sets_abortable(self, mock_app):
|
||||
"""Test on_abort sets is_abortable to True on first handler."""
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
|
||||
patch.object(TaskContext, "start_abort_polling"),
|
||||
):
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
@ctx.on_abort
|
||||
def handler():
|
||||
pass
|
||||
|
||||
mock_set_abortable.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.context.current_app")
|
||||
def test_on_abort_only_sets_abortable_once(self, mock_app):
|
||||
"""Test on_abort only calls _set_abortable for first handler."""
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
|
||||
patch.object(TaskContext, "start_abort_polling"),
|
||||
):
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
@ctx.on_abort
|
||||
def handler1():
|
||||
pass
|
||||
|
||||
@ctx.on_abort
|
||||
def handler2():
|
||||
pass
|
||||
|
||||
# Should only be called once for first handler
|
||||
assert mock_set_abortable.call_count == 1
|
||||
|
||||
def test_abort_handlers_completed_initially_false(self):
|
||||
"""Test abort_handlers_completed is False initially."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with patch("superset.tasks.context.current_app") as mock_app:
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
ctx = TaskContext(mock_task)
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
|
||||
class TestAbortPolling:
|
||||
"""Test abort detection polling behavior."""
|
||||
|
||||
def test_on_abort_starts_polling_automatically(self, task_context):
|
||||
"""Test that registering first handler starts abort listener."""
|
||||
assert task_context._abort_listener is None
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
def test_stop_abort_polling(self, task_context):
|
||||
"""Test that stop_abort_polling stops the abort listener."""
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
task_context.stop_abort_polling()
|
||||
|
||||
assert task_context._abort_listener is None
|
||||
|
||||
def test_start_abort_polling_only_once(self, task_context):
|
||||
"""Test that start_abort_polling is idempotent."""
|
||||
task_context.start_abort_polling(interval=0.1)
|
||||
first_listener = task_context._abort_listener
|
||||
|
||||
# Try to start again
|
||||
task_context.start_abort_polling(interval=0.1)
|
||||
second_listener = task_context._abort_listener
|
||||
|
||||
# Should be the same listener
|
||||
assert first_listener is second_listener
|
||||
|
||||
def test_on_abort_with_custom_interval(self, task_context):
|
||||
"""Test that custom interval can be set via start_abort_polling."""
|
||||
with patch("superset.tasks.context.current_app") as mock_app:
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Override with custom interval
|
||||
task_context.stop_abort_polling()
|
||||
task_context.start_abort_polling(interval=0.05)
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
def test_polling_stops_after_abort_detected(self, task_context, mock_task):
|
||||
"""Test that abort is detected and handlers are triggered."""
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# Abort should have been detected
|
||||
assert task_context._abort_detected is True
|
||||
|
||||
|
||||
class TestAbortHandlerExecution:
|
||||
"""Test abort handler execution behavior."""
|
||||
|
||||
def test_on_abort_handler_fires_when_task_aborted(self, task_context, mock_task):
|
||||
"""Test that abort handler fires automatically when task is aborted."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Simulate task being aborted
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for polling to detect abort (max 0.3s with 0.1s interval)
|
||||
time.sleep(0.3)
|
||||
|
||||
assert abort_called
|
||||
assert task_context._abort_detected
|
||||
|
||||
def test_on_abort_not_called_on_success(self, task_context, mock_task):
|
||||
"""Test that abort handlers don't run on success."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Keep task in success state
|
||||
mock_task.status = TaskStatus.SUCCESS.value
|
||||
|
||||
# Wait and verify handler not called
|
||||
time.sleep(0.3)
|
||||
|
||||
assert not abort_called
|
||||
|
||||
def test_multiple_abort_handlers(self, task_context, mock_task):
|
||||
"""Test that all abort handlers execute in LIFO order."""
|
||||
calls = []
|
||||
|
||||
@task_context.on_abort
|
||||
def handler1():
|
||||
calls.append(1)
|
||||
|
||||
@task_context.on_abort
|
||||
def handler2():
|
||||
calls.append(2)
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# LIFO order: handler2 runs first
|
||||
assert calls == [2, 1]
|
||||
|
||||
def test_abort_handler_exception_doesnt_fail_task(self, task_context, mock_task):
|
||||
"""Test that exception in abort handler is logged but doesn't fail task."""
|
||||
handler2_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def bad_handler():
|
||||
raise ValueError("Handler error")
|
||||
|
||||
@task_context.on_abort
|
||||
def good_handler():
|
||||
nonlocal handler2_called
|
||||
handler2_called = True
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# Second handler should still run despite first handler failing
|
||||
assert handler2_called
|
||||
|
||||
|
||||
class TestBestEffortHandlerExecution:
|
||||
"""Test that all handlers execute even when some fail (best-effort)."""
|
||||
|
||||
def test_all_abort_handlers_run_even_if_all_fail(self, task_context, mock_task):
|
||||
"""Test all abort handlers execute even if every one raises an exception."""
|
||||
calls = []
|
||||
|
||||
@task_context.on_abort
|
||||
def handler1():
|
||||
calls.append(1)
|
||||
raise ValueError("Handler 1 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def handler2():
|
||||
calls.append(2)
|
||||
raise RuntimeError("Handler 2 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def handler3():
|
||||
calls.append(3)
|
||||
raise TypeError("Handler 3 failed")
|
||||
|
||||
# Trigger abort handlers directly (simulating abort detection)
|
||||
task_context._trigger_abort_handlers()
|
||||
|
||||
# All handlers should have been called (LIFO order: 3, 2, 1)
|
||||
assert calls == [3, 2, 1]
|
||||
|
||||
# Failures should be collected (abort handlers don't write to DB)
|
||||
assert len(task_context._handler_failures) == 3
|
||||
failure_types = [
|
||||
type(ex).__name__ for _, ex, _ in task_context._handler_failures
|
||||
]
|
||||
assert "TypeError" in failure_types
|
||||
assert "RuntimeError" in failure_types
|
||||
assert "ValueError" in failure_types
|
||||
|
||||
def test_all_cleanup_handlers_run_even_if_all_fail(self, task_context, mock_task):
|
||||
"""Test all cleanup handlers execute even if every one raises an exception."""
|
||||
calls = []
|
||||
captured_failures = []
|
||||
|
||||
# Mock _write_handler_failures_to_db to capture failures before clearing
|
||||
original_write = task_context._write_handler_failures_to_db
|
||||
|
||||
def mock_write():
|
||||
captured_failures.extend(task_context._handler_failures)
|
||||
original_write()
|
||||
|
||||
task_context._write_handler_failures_to_db = mock_write
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup1():
|
||||
calls.append(1)
|
||||
raise ValueError("Cleanup 1 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup2():
|
||||
calls.append(2)
|
||||
raise RuntimeError("Cleanup 2 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup3():
|
||||
calls.append(3)
|
||||
raise TypeError("Cleanup 3 failed")
|
||||
|
||||
# Set task to SUCCESS (not aborting) so only cleanup handlers run
|
||||
mock_task.status = TaskStatus.SUCCESS.value
|
||||
|
||||
# Run cleanup
|
||||
task_context._run_cleanup()
|
||||
|
||||
# All handlers should have been called (LIFO order: 3, 2, 1)
|
||||
assert calls == [3, 2, 1]
|
||||
|
||||
# Failures should have been captured before clearing
|
||||
assert len(captured_failures) == 3
|
||||
failure_types = [type(ex).__name__ for _, ex, _ in captured_failures]
|
||||
assert "TypeError" in failure_types
|
||||
assert "RuntimeError" in failure_types
|
||||
assert "ValueError" in failure_types
|
||||
|
||||
def test_mixed_abort_and_cleanup_failures_all_collected(
|
||||
self, task_context, mock_task
|
||||
):
|
||||
"""Test abort and cleanup handler failures are collected together."""
|
||||
calls = []
|
||||
captured_failures = []
|
||||
|
||||
# Mock _write_handler_failures_to_db to capture failures before clearing
|
||||
original_write = task_context._write_handler_failures_to_db
|
||||
|
||||
def mock_write():
|
||||
captured_failures.extend(task_context._handler_failures)
|
||||
original_write()
|
||||
|
||||
task_context._write_handler_failures_to_db = mock_write
|
||||
|
||||
@task_context.on_abort
|
||||
def abort1():
|
||||
calls.append("abort1")
|
||||
raise ValueError("Abort 1 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def abort2():
|
||||
calls.append("abort2")
|
||||
raise RuntimeError("Abort 2 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup1():
|
||||
calls.append("cleanup1")
|
||||
raise TypeError("Cleanup 1 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup2():
|
||||
calls.append("cleanup2")
|
||||
raise KeyError("Cleanup 2 failed")
|
||||
|
||||
# Set task to ABORTING so both abort and cleanup handlers run
|
||||
mock_task.status = TaskStatus.ABORTING.value
|
||||
|
||||
# Run cleanup (which triggers abort handlers first, then cleanup handlers)
|
||||
task_context._run_cleanup()
|
||||
|
||||
# All handlers should have been called
|
||||
# Abort handlers run first (LIFO: abort2, abort1)
|
||||
# Then cleanup handlers (LIFO: cleanup2, cleanup1)
|
||||
assert calls == ["abort2", "abort1", "cleanup2", "cleanup1"]
|
||||
|
||||
# All 4 failures should have been captured
|
||||
assert len(captured_failures) == 4
|
||||
|
||||
# Verify handler types are recorded correctly
|
||||
handler_types = [htype for htype, _, _ in captured_failures]
|
||||
assert handler_types.count("abort") == 2
|
||||
assert handler_types.count("cleanup") == 2
|
||||
|
||||
|
||||
class TestCleanupHandlers:
|
||||
"""Test cleanup handler behavior."""
|
||||
|
||||
def test_cleanup_triggers_abort_handlers_if_not_detected(
|
||||
self, task_context, mock_task
|
||||
):
|
||||
"""Test that _run_cleanup triggers abort handlers if task ended aborted."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Set task as aborted but don't let polling detect it
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
task_context._abort_detected = False
|
||||
|
||||
# Immediately run cleanup (simulating task ending before poll)
|
||||
task_context._run_cleanup()
|
||||
|
||||
assert abort_called
|
||||
|
||||
def test_cleanup_doesnt_duplicate_abort_handlers(self, task_context, mock_task):
|
||||
"""Test that abort handlers only run once even if called from cleanup."""
|
||||
call_count = 0
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
# Trigger abort via polling
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
time.sleep(0.3)
|
||||
|
||||
# Handlers should have been called once
|
||||
assert call_count == 1
|
||||
assert task_context._abort_detected is True
|
||||
|
||||
# Run cleanup - handlers should NOT be called again
|
||||
task_context._run_cleanup()
|
||||
|
||||
assert call_count == 1 # Still 1, not 2
|
||||
462
tests/unit_tests/tasks/test_manager.py
Normal file
462
tests/unit_tests/tasks/test_manager.py
Normal file
@@ -0,0 +1,462 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for TaskManager pub/sub functionality"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import redis
|
||||
|
||||
from superset.tasks.manager import AbortListener, TaskManager
|
||||
|
||||
|
||||
class TestAbortListener:
|
||||
"""Tests for AbortListener class"""
|
||||
|
||||
def test_stop_sets_event(self):
|
||||
"""Test that stop() sets the stop event"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = False
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
listener.stop()
|
||||
assert stop_event.is_set()
|
||||
|
||||
def test_stop_closes_pubsub(self):
|
||||
"""Test that stop() closes the pub/sub connection"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = False
|
||||
pubsub = MagicMock()
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event, pubsub)
|
||||
listener.stop()
|
||||
|
||||
pubsub.unsubscribe.assert_called_once()
|
||||
pubsub.close.assert_called_once()
|
||||
|
||||
def test_stop_joins_thread(self):
|
||||
"""Test that stop() joins the listener thread"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = True
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event)
|
||||
listener.stop()
|
||||
|
||||
thread.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
|
||||
class TestTaskManagerInitApp:
|
||||
"""Tests for TaskManager.init_app()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def test_init_app_sets_channel_prefixes(self):
|
||||
"""Test init_app reads channel prefixes from config"""
|
||||
app = MagicMock()
|
||||
app.config.get.side_effect = lambda key, default=None: {
|
||||
"TASKS_ABORT_CHANNEL_PREFIX": "custom:abort:",
|
||||
"TASKS_COMPLETION_CHANNEL_PREFIX": "custom:complete:",
|
||||
}.get(key, default)
|
||||
|
||||
TaskManager.init_app(app)
|
||||
|
||||
assert TaskManager._initialized is True
|
||||
assert TaskManager._channel_prefix == "custom:abort:"
|
||||
assert TaskManager._completion_channel_prefix == "custom:complete:"
|
||||
|
||||
def test_init_app_skips_if_already_initialized(self):
|
||||
"""Test init_app is idempotent"""
|
||||
TaskManager._initialized = True
|
||||
|
||||
app = MagicMock()
|
||||
TaskManager.init_app(app)
|
||||
|
||||
# Should not call app.config.get since already initialized
|
||||
app.config.get.assert_not_called()
|
||||
|
||||
|
||||
class TestTaskManagerPubSub:
|
||||
"""Tests for TaskManager pub/sub methods"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_is_pubsub_available_no_redis(self, mock_cache_manager):
|
||||
"""Test is_pubsub_available returns False when Redis not configured"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
assert TaskManager.is_pubsub_available() is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_is_pubsub_available_with_redis(self, mock_cache_manager):
|
||||
"""Test is_pubsub_available returns True when Redis is configured"""
|
||||
mock_cache_manager.signal_cache = MagicMock()
|
||||
assert TaskManager.is_pubsub_available() is True
|
||||
|
||||
def test_get_abort_channel(self):
|
||||
"""Test get_abort_channel returns correct channel name"""
|
||||
task_uuid = "abc-123-def-456"
|
||||
channel = TaskManager.get_abort_channel(task_uuid)
|
||||
assert channel == "gtf:abort:abc-123-def-456"
|
||||
|
||||
def test_get_abort_channel_custom_prefix(self):
|
||||
"""Test get_abort_channel with custom prefix"""
|
||||
TaskManager._channel_prefix = "custom:prefix:"
|
||||
task_uuid = "test-uuid"
|
||||
channel = TaskManager.get_abort_channel(task_uuid)
|
||||
assert channel == "custom:prefix:test-uuid"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_no_redis(self, mock_cache_manager):
|
||||
"""Test publish_abort returns False when Redis not available"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_success(self, mock_cache_manager):
|
||||
"""Test publish_abort publishes message successfully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.return_value = 1 # One subscriber
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
|
||||
assert result is True
|
||||
mock_redis.publish.assert_called_once_with("gtf:abort:test-uuid", "abort")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_redis_error(self, mock_cache_manager):
|
||||
"""Test publish_abort handles Redis errors gracefully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestTaskManagerListenForAbort:
|
||||
"""Tests for TaskManager.listen_for_abort()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_no_redis_uses_polling(self, mock_cache_manager):
|
||||
"""Test listen_for_abort falls back to polling when Redis unavailable"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
callback = MagicMock()
|
||||
|
||||
with patch.object(TaskManager, "_poll_for_abort", return_value=None):
|
||||
listener = TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
# Give thread time to start
|
||||
time.sleep(0.1)
|
||||
listener.stop()
|
||||
|
||||
# Should use polling since no Redis
|
||||
assert listener._pubsub is None
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_with_redis_uses_pubsub(self, mock_cache_manager):
|
||||
"""Test listen_for_abort uses pub/sub when Redis available"""
|
||||
mock_redis = MagicMock()
|
||||
mock_pubsub = MagicMock()
|
||||
mock_redis.pubsub.return_value = mock_pubsub
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
callback = MagicMock()
|
||||
|
||||
with patch.object(TaskManager, "_listen_pubsub", return_value=None):
|
||||
listener = TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
# Give thread time to start
|
||||
time.sleep(0.1)
|
||||
listener.stop()
|
||||
|
||||
# Should subscribe to channel
|
||||
mock_pubsub.subscribe.assert_called_once_with("gtf:abort:test-uuid")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_redis_subscribe_failure_raises(self, mock_cache_manager):
|
||||
"""Test listen_for_abort raises exception on subscribe failure
|
||||
when Redis configured"""
|
||||
import pytest
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
callback = MagicMock()
|
||||
|
||||
# With fail-fast behavior, Redis subscribe failure raises exception
|
||||
with pytest.raises(redis.RedisError, match="Connection failed"):
|
||||
TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
|
||||
class TestTaskManagerCompletion:
|
||||
"""Tests for TaskManager completion pub/sub and wait_for_completion"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def test_get_completion_channel(self):
|
||||
"""Test get_completion_channel returns correct channel name"""
|
||||
task_uuid = "abc-123-def-456"
|
||||
channel = TaskManager.get_completion_channel(task_uuid)
|
||||
assert channel == "gtf:complete:abc-123-def-456"
|
||||
|
||||
def test_get_completion_channel_custom_prefix(self):
|
||||
"""Test get_completion_channel with custom prefix"""
|
||||
TaskManager._completion_channel_prefix = "custom:complete:"
|
||||
task_uuid = "test-uuid"
|
||||
channel = TaskManager.get_completion_channel(task_uuid)
|
||||
assert channel == "custom:complete:test-uuid"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_no_redis(self, mock_cache_manager):
|
||||
"""Test publish_completion returns False when Redis not available"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_success(self, mock_cache_manager):
|
||||
"""Test publish_completion publishes message successfully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.return_value = 1 # One subscriber
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
|
||||
assert result is True
|
||||
mock_redis.publish.assert_called_once_with("gtf:complete:test-uuid", "success")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_redis_error(self, mock_cache_manager):
|
||||
"""Test publish_completion handles Redis errors gracefully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_task_not_found(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion raises ValueError for missing task"""
|
||||
import pytest
|
||||
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_dao.find_one_or_none.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
TaskManager.wait_for_completion("nonexistent-uuid")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_already_complete(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion returns immediately for terminal state"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = "test-uuid"
|
||||
mock_task.status = "success"
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
|
||||
result = TaskManager.wait_for_completion("test-uuid")
|
||||
|
||||
assert result == mock_task
|
||||
# Should only call find_one_or_none once (initial check)
|
||||
mock_dao.find_one_or_none.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_timeout(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion raises TimeoutError when timeout expires"""
|
||||
import pytest
|
||||
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = "test-uuid"
|
||||
mock_task.status = "in_progress" # Never completes
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
|
||||
with pytest.raises(TimeoutError, match="Timeout waiting"):
|
||||
TaskManager.wait_for_completion("test-uuid", timeout=0.1)
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_polling_success(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion returns when task completes via polling"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_task_complete = MagicMock()
|
||||
mock_task_complete.uuid = "test-uuid"
|
||||
mock_task_complete.status = "success"
|
||||
|
||||
# First call returns pending, second returns complete
|
||||
mock_dao.find_one_or_none.side_effect = [
|
||||
mock_task_pending,
|
||||
mock_task_complete,
|
||||
]
|
||||
|
||||
result = TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
poll_interval=0.1,
|
||||
)
|
||||
|
||||
assert result.status == "success"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_with_pubsub(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion uses pub/sub when Redis available"""
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_task_complete = MagicMock()
|
||||
mock_task_complete.uuid = "test-uuid"
|
||||
mock_task_complete.status = "success"
|
||||
|
||||
# First call returns pending, second returns complete
|
||||
mock_dao.find_one_or_none.side_effect = [
|
||||
mock_task_pending,
|
||||
mock_task_complete,
|
||||
]
|
||||
|
||||
# Set up mock Redis with pub/sub
|
||||
mock_redis = MagicMock()
|
||||
mock_pubsub = MagicMock()
|
||||
# Simulate receiving a completion message
|
||||
mock_pubsub.get_message.return_value = {
|
||||
"type": "message",
|
||||
"data": "success",
|
||||
}
|
||||
mock_redis.pubsub.return_value = mock_pubsub
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
assert result.status == "success"
|
||||
# Should have subscribed to completion channel
|
||||
mock_pubsub.subscribe.assert_called_once_with("gtf:complete:test-uuid")
|
||||
# Should have cleaned up
|
||||
mock_pubsub.unsubscribe.assert_called_once()
|
||||
mock_pubsub.close.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_pubsub_error_raises(
|
||||
self, mock_dao, mock_cache_manager
|
||||
):
|
||||
"""Test wait_for_completion raises exception on Redis error when
|
||||
Redis configured"""
|
||||
import pytest
|
||||
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_dao.find_one_or_none.return_value = mock_task_pending
|
||||
|
||||
# Set up mock Redis that fails
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
# With fail-fast behavior, Redis error is raised instead of falling back
|
||||
with pytest.raises(redis.RedisError, match="Connection failed"):
|
||||
TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
poll_interval=0.1,
|
||||
)
|
||||
|
||||
def test_terminal_states_constant(self):
|
||||
"""Test TERMINAL_STATES contains expected values"""
|
||||
expected = {"success", "failure", "aborted", "timed_out"}
|
||||
assert TaskManager.TERMINAL_STATES == expected
|
||||
612
tests/unit_tests/tasks/test_timeout.py
Normal file
612
tests/unit_tests/tasks/test_timeout.py
Normal file
@@ -0,0 +1,612 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for GTF timeout handling."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskOptions, TaskScope
|
||||
|
||||
from superset.tasks.context import TaskContext
|
||||
from superset.tasks.decorators import TaskWrapper
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flask_app():
|
||||
"""Create a properly configured mock Flask app."""
|
||||
mock_app = MagicMock()
|
||||
mock_app.config = {
|
||||
"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
|
||||
}
|
||||
# Make app_context() return a proper context manager
|
||||
mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
|
||||
return mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_abortable():
|
||||
"""Create a mock task that is abortable."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = "in_progress"
|
||||
task.properties_dict = {"is_abortable": True}
|
||||
task.payload_dict = {}
|
||||
# Set real values for dedup_key generation (used by UpdateTaskCommand lock)
|
||||
task.scope = "shared"
|
||||
task.task_type = "test_task"
|
||||
task.task_key = "test_key"
|
||||
task.user_id = 1
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_not_abortable():
|
||||
"""Create a mock task that is NOT abortable."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = "in_progress"
|
||||
task.properties_dict = {} # No is_abortable means it's not abortable
|
||||
task.payload_dict = {}
|
||||
# Set real values for dedup_key generation (used by UpdateTaskCommand lock)
|
||||
task.scope = "shared"
|
||||
task.task_type = "test_task"
|
||||
task.task_key = "test_key"
|
||||
task.user_id = 1
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_context_for_timeout(mock_flask_app, mock_task_abortable):
|
||||
"""Create TaskContext with mocked dependencies for timeout tests."""
|
||||
# Ensure mock_task has required attributes for TaskContext
|
||||
mock_task_abortable.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
# Configure current_app mock
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
|
||||
# Configure TaskDAO mock
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
yield ctx
|
||||
|
||||
# Cleanup: stop timers if started
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TaskWrapper._merge_options Timeout Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutMerging:
|
||||
"""Test timeout merging behavior in TaskWrapper._merge_options."""
|
||||
|
||||
def test_merge_options_decorator_timeout_used_when_no_override(self):
|
||||
"""Test that decorator timeout is used when no override is provided."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300, # 5-minute default
|
||||
)
|
||||
|
||||
merged = wrapper._merge_options(None)
|
||||
assert merged.timeout == 300
|
||||
|
||||
def test_merge_options_override_timeout_takes_precedence(self):
|
||||
"""Test that TaskOptions timeout overrides decorator default."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300, # 5-minute default
|
||||
)
|
||||
|
||||
override = TaskOptions(timeout=600) # 10-minute override
|
||||
merged = wrapper._merge_options(override)
|
||||
assert merged.timeout == 600
|
||||
|
||||
def test_merge_options_no_timeout_when_not_configured(self):
|
||||
"""Test that no timeout is set when not configured anywhere."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=None, # No default timeout
|
||||
)
|
||||
|
||||
merged = wrapper._merge_options(None)
|
||||
assert merged.timeout is None
|
||||
|
||||
def test_merge_options_override_with_other_options_preserves_timeout(self):
|
||||
"""Test that setting other options doesn't lose decorator timeout."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300,
|
||||
)
|
||||
|
||||
# Override only task_key, not timeout
|
||||
override = TaskOptions(task_key="my-key")
|
||||
merged = wrapper._merge_options(override)
|
||||
|
||||
# Should keep decorator timeout since override.timeout is None
|
||||
assert merged.timeout == 300
|
||||
assert merged.task_key == "my-key"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TaskContext Timeout Timer Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutTimer:
|
||||
"""Test TaskContext timeout timer behavior."""
|
||||
|
||||
def test_start_timeout_timer_sets_timer(self, task_context_for_timeout):
|
||||
"""Test that start_timeout_timer creates a timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
|
||||
assert ctx._timeout_timer is not None
|
||||
assert ctx._timeout_triggered is False
|
||||
|
||||
def test_start_timeout_timer_is_idempotent(self, task_context_for_timeout):
|
||||
"""Test that starting timer twice doesn't create duplicate timers."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
first_timer = ctx._timeout_timer
|
||||
|
||||
ctx.start_timeout_timer(20) # Try to start again
|
||||
second_timer = ctx._timeout_timer
|
||||
|
||||
assert first_timer is second_timer
|
||||
|
||||
def test_stop_timeout_timer_cancels_timer(self, task_context_for_timeout):
|
||||
"""Test that stop_timeout_timer cancels the timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
assert ctx._timeout_timer is not None
|
||||
|
||||
ctx.stop_timeout_timer()
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
def test_stop_timeout_timer_safe_when_no_timer(self, task_context_for_timeout):
|
||||
"""Test that stop_timeout_timer doesn't fail when no timer exists."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
ctx.stop_timeout_timer() # Should not raise
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
def test_timeout_triggered_property_initially_false(self, task_context_for_timeout):
|
||||
"""Test that timeout_triggered is False initially."""
|
||||
ctx = task_context_for_timeout
|
||||
assert ctx.timeout_triggered is False
|
||||
|
||||
def test_cleanup_stops_timeout_timer(self, task_context_for_timeout):
|
||||
"""Test that _run_cleanup stops the timeout timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
assert ctx._timeout_timer is not None
|
||||
|
||||
ctx._run_cleanup()
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
|
||||
class TestTimeoutTrigger:
|
||||
"""Test timeout trigger behavior when timer fires."""
|
||||
|
||||
def test_timeout_triggers_abort_when_abortable(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout triggers abort handlers when task is abortable."""
|
||||
abort_called = False
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch(
|
||||
"superset.commands.tasks.update.UpdateTaskCommand"
|
||||
) as mock_update_cmd,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Abort handler should have been called
|
||||
assert abort_called
|
||||
assert ctx._timeout_triggered
|
||||
assert ctx._abort_detected
|
||||
|
||||
# Verify UpdateTaskCommand was called with ABORTING status
|
||||
mock_update_cmd.assert_called()
|
||||
call_kwargs = mock_update_cmd.call_args[1]
|
||||
assert call_kwargs.get("status") == "aborting"
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_timeout_logs_warning_when_not_abortable(
|
||||
self, mock_flask_app, mock_task_not_abortable
|
||||
):
|
||||
"""Test that timeout logs warning when task has no abort handler."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.tasks.context.logger") as mock_logger,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_not_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_not_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
# No abort handler registered
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Should have logged warning
|
||||
mock_logger.warning.assert_called()
|
||||
warning_call = mock_logger.warning.call_args
|
||||
assert "no abort handler" in warning_call[0][0].lower()
|
||||
assert ctx._timeout_triggered
|
||||
assert not ctx._abort_detected # No abort since no handler
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
|
||||
def test_timeout_does_not_trigger_if_already_aborting(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout doesn't re-trigger abort if already aborting."""
|
||||
abort_count = 0
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_count
|
||||
abort_count += 1
|
||||
|
||||
# Pre-set abort detected
|
||||
ctx._abort_detected = True
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Handler should NOT have been called since already aborting
|
||||
assert abort_count == 0
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Task Decorator Timeout Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTaskDecoratorTimeout:
|
||||
"""Test @task decorator timeout parameter."""
|
||||
|
||||
def test_task_decorator_accepts_timeout(self):
|
||||
"""Test that @task decorator accepts timeout parameter."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_1", timeout=300)
|
||||
def timeout_test_task_1():
|
||||
pass
|
||||
|
||||
assert isinstance(timeout_test_task_1, TaskWrapper)
|
||||
assert timeout_test_task_1.default_timeout == 300
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_1", None)
|
||||
|
||||
def test_task_decorator_without_timeout(self):
|
||||
"""Test that @task decorator works without timeout."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_2")
|
||||
def timeout_test_task_2():
|
||||
pass
|
||||
|
||||
assert isinstance(timeout_test_task_2, TaskWrapper)
|
||||
assert timeout_test_task_2.default_timeout is None
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_2", None)
|
||||
|
||||
def test_task_decorator_with_all_params(self):
|
||||
"""Test that @task decorator accepts all parameters together."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_3", scope=TaskScope.SHARED, timeout=600)
|
||||
def timeout_test_task_3():
|
||||
pass
|
||||
|
||||
assert timeout_test_task_3.name == "test_timeout_task_3"
|
||||
assert timeout_test_task_3.scope == TaskScope.SHARED
|
||||
assert timeout_test_task_3.default_timeout == 600
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_3", None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Timeout Terminal State Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutTerminalState:
|
||||
"""Test timeout transitions to correct terminal state (TIMED_OUT vs FAILURE)."""
|
||||
|
||||
def test_timeout_triggered_flag_set_on_timeout(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout_triggered flag is set when timeout fires."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Initially not triggered
|
||||
assert ctx.timeout_triggered is False
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Should be set after timeout
|
||||
assert ctx.timeout_triggered is True
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_user_abort_does_not_set_timeout_triggered(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that user abort doesn't set timeout_triggered flag."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Simulate user abort (not timeout)
|
||||
ctx._on_abort_detected()
|
||||
|
||||
# timeout_triggered should still be False
|
||||
assert ctx.timeout_triggered is False
|
||||
# But abort_detected should be True
|
||||
assert ctx._abort_detected is True
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_abort_handlers_completed_tracks_success(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that abort_handlers_completed flag tracks successful
|
||||
handler execution."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass # Successful handler
|
||||
|
||||
# Initially not completed
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Trigger abort handlers
|
||||
ctx._trigger_abort_handlers()
|
||||
|
||||
# Should be marked as completed
|
||||
assert ctx.abort_handlers_completed is True
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_abort_handlers_completed_false_on_exception(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that abort_handlers_completed is False when handler throws."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
raise ValueError("Handler failed")
|
||||
|
||||
# Initially not completed
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Trigger abort handlers (will catch the exception internally)
|
||||
ctx._trigger_abort_handlers()
|
||||
|
||||
# Should NOT be marked as completed since handler threw
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
@@ -22,9 +22,19 @@ from typing import Any, Optional, Union
|
||||
|
||||
import pytest
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from superset_core.api.tasks import TaskScope
|
||||
|
||||
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
|
||||
from superset.tasks.types import Executor, ExecutorType, FixedExecutor
|
||||
from superset.tasks.utils import (
|
||||
error_update,
|
||||
get_active_dedup_key,
|
||||
get_finished_dedup_key,
|
||||
parse_properties,
|
||||
progress_update,
|
||||
serialize_properties,
|
||||
)
|
||||
from superset.utils.hashing import hash_from_str
|
||||
|
||||
FIXED_USER_ID = 1234
|
||||
FIXED_USERNAME = "admin"
|
||||
@@ -330,3 +340,242 @@ def test_get_executor(
|
||||
)
|
||||
assert executor_type == expected_executor_type
|
||||
assert executor == expected_executor
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scope,task_type,task_key,user_id,expected_composite_key",
|
||||
[
|
||||
# Private tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.PRIVATE,
|
||||
"sql_execution",
|
||||
"chart_123",
|
||||
42,
|
||||
"private|sql_execution|chart_123|42",
|
||||
),
|
||||
(
|
||||
TaskScope.PRIVATE,
|
||||
"thumbnail_gen",
|
||||
"dash_456",
|
||||
100,
|
||||
"private|thumbnail_gen|dash_456|100",
|
||||
),
|
||||
# Private tasks with string scope
|
||||
(
|
||||
"private",
|
||||
"api_call",
|
||||
"endpoint_789",
|
||||
200,
|
||||
"private|api_call|endpoint_789|200",
|
||||
),
|
||||
# Shared tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.SHARED,
|
||||
"report_gen",
|
||||
"monthly_report",
|
||||
None,
|
||||
"shared|report_gen|monthly_report",
|
||||
),
|
||||
(
|
||||
TaskScope.SHARED,
|
||||
"export_csv",
|
||||
"large_export",
|
||||
999, # user_id should be ignored for shared
|
||||
"shared|export_csv|large_export",
|
||||
),
|
||||
# Shared tasks with string scope
|
||||
(
|
||||
"shared",
|
||||
"batch_process",
|
||||
"batch_001",
|
||||
123, # user_id should be ignored for shared
|
||||
"shared|batch_process|batch_001",
|
||||
),
|
||||
# System tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.SYSTEM,
|
||||
"cleanup_task",
|
||||
"daily_cleanup",
|
||||
None,
|
||||
"system|cleanup_task|daily_cleanup",
|
||||
),
|
||||
(
|
||||
TaskScope.SYSTEM,
|
||||
"db_migration",
|
||||
"version_123",
|
||||
1, # user_id should be ignored for system
|
||||
"system|db_migration|version_123",
|
||||
),
|
||||
# System tasks with string scope
|
||||
(
|
||||
"system",
|
||||
"maintenance",
|
||||
"nightly_job",
|
||||
2, # user_id should be ignored for system
|
||||
"system|maintenance|nightly_job",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_active_dedup_key(
|
||||
scope, task_type, task_key, user_id, expected_composite_key, app_context
|
||||
):
|
||||
"""Test get_active_dedup_key generates a hash of the composite key.
|
||||
|
||||
The function hashes the composite key using the configured HASH_ALGORITHM
|
||||
to produce a fixed-length dedup_key for database storage. The result is
|
||||
truncated to 64 chars to fit the database column.
|
||||
"""
|
||||
result = get_active_dedup_key(scope, task_type, task_key, user_id)
|
||||
|
||||
# The result should be a hash of the expected composite key, truncated to 64 chars
|
||||
expected_hash = hash_from_str(expected_composite_key)[:64]
|
||||
assert result == expected_hash
|
||||
assert len(result) <= 64
|
||||
|
||||
|
||||
def test_get_active_dedup_key_private_requires_user_id():
|
||||
"""Test that private tasks require explicit user_id parameter."""
|
||||
with pytest.raises(ValueError, match="user_id required for private tasks"):
|
||||
get_active_dedup_key(TaskScope.PRIVATE, "test_type", "test_key")
|
||||
|
||||
|
||||
def test_get_finished_dedup_key():
|
||||
"""Test that finished tasks use UUID as dedup_key"""
|
||||
test_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
result = get_finished_dedup_key(test_uuid)
|
||||
assert result == test_uuid
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"progress,expected",
|
||||
[
|
||||
# Float (percentage) progress
|
||||
(0.5, {"progress_percent": 0.5}),
|
||||
(0.0, {"progress_percent": 0.0}),
|
||||
(1.0, {"progress_percent": 1.0}),
|
||||
(0.25, {"progress_percent": 0.25}),
|
||||
# Int (count only) progress
|
||||
(42, {"progress_current": 42}),
|
||||
(0, {"progress_current": 0}),
|
||||
(1000, {"progress_current": 1000}),
|
||||
# Tuple (current, total) progress with auto-computed percentage
|
||||
(
|
||||
(50, 100),
|
||||
{"progress_current": 50, "progress_total": 100, "progress_percent": 0.5},
|
||||
),
|
||||
(
|
||||
(25, 100),
|
||||
{"progress_current": 25, "progress_total": 100, "progress_percent": 0.25},
|
||||
),
|
||||
(
|
||||
(100, 100),
|
||||
{"progress_current": 100, "progress_total": 100, "progress_percent": 1.0},
|
||||
),
|
||||
# Tuple with zero total (no percentage computed)
|
||||
((10, 0), {"progress_current": 10, "progress_total": 0}),
|
||||
((0, 0), {"progress_current": 0, "progress_total": 0}),
|
||||
],
|
||||
)
|
||||
def test_progress_update(progress, expected):
|
||||
"""Test progress_update returns correct TaskProperties dict."""
|
||||
result = progress_update(progress)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_error_update():
|
||||
"""Test error_update captures exception details."""
|
||||
try:
|
||||
raise ValueError("Test error message")
|
||||
except ValueError as e:
|
||||
result = error_update(e)
|
||||
|
||||
assert result["error_message"] == "Test error message"
|
||||
assert result["exception_type"] == "ValueError"
|
||||
assert "stack_trace" in result
|
||||
assert "ValueError" in result["stack_trace"]
|
||||
|
||||
|
||||
def test_error_update_custom_exception():
|
||||
"""Test error_update with custom exception class."""
|
||||
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
raise CustomError("Custom error")
|
||||
except CustomError as e:
|
||||
result = error_update(e)
|
||||
|
||||
assert result["error_message"] == "Custom error"
|
||||
assert result["exception_type"] == "CustomError"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"json_str,expected",
|
||||
[
|
||||
# Valid JSON
|
||||
(
|
||||
'{"is_abortable": true, "progress_percent": 0.5}',
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
),
|
||||
(
|
||||
'{"error_message": "Something failed"}',
|
||||
{"error_message": "Something failed"},
|
||||
),
|
||||
(
|
||||
'{"progress_current": 50, "progress_total": 100}',
|
||||
{"progress_current": 50, "progress_total": 100},
|
||||
),
|
||||
# Empty/None cases
|
||||
("", {}),
|
||||
(None, {}),
|
||||
# Invalid JSON returns empty dict
|
||||
("not valid json", {}),
|
||||
("{broken", {}),
|
||||
# Unknown keys are preserved (forward compatibility)
|
||||
(
|
||||
'{"is_abortable": true, "future_field": "value"}',
|
||||
{"is_abortable": True, "future_field": "value"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_properties(json_str, expected):
|
||||
"""Test parse_properties parses JSON to TaskProperties dict."""
|
||||
result = parse_properties(json_str)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"props,expected_contains",
|
||||
[
|
||||
# Full properties
|
||||
(
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
),
|
||||
# Empty dict
|
||||
({}, {}),
|
||||
# Sparse properties
|
||||
({"is_abortable": True}, {"is_abortable": True}),
|
||||
({"error_message": "fail"}, {"error_message": "fail"}),
|
||||
],
|
||||
)
|
||||
def test_serialize_properties(props, expected_contains):
|
||||
"""Test serialize_properties converts TaskProperties to JSON."""
|
||||
from superset.utils import json
|
||||
|
||||
result = serialize_properties(props)
|
||||
parsed = json.loads(result)
|
||||
assert parsed == expected_contains
|
||||
|
||||
|
||||
def test_properties_roundtrip():
|
||||
"""Test that serialize -> parse roundtrip preserves data."""
|
||||
original = {
|
||||
"is_abortable": True,
|
||||
"progress_percent": 0.75,
|
||||
"error_message": "Test error",
|
||||
}
|
||||
serialized = serialize_properties(original)
|
||||
parsed = parse_properties(serialized)
|
||||
assert parsed == original
|
||||
|
||||
@@ -54,7 +54,7 @@ def test_json_loads_exception():
|
||||
|
||||
|
||||
def test_json_loads_encoding():
|
||||
unicode_data = b'{"a": "\u0073\u0074\u0072"}'
|
||||
unicode_data = rb'{"a": "\u0073\u0074\u0072"}'
|
||||
data = json.loads(unicode_data)
|
||||
assert data["a"] == "str"
|
||||
utf16_data = b'\xff\xfe{\x00"\x00a\x00"\x00:\x00 \x00"\x00s\x00t\x00r\x00"\x00}\x00'
|
||||
|
||||
@@ -119,7 +119,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
|
||||
was revoked), the invalid token should be deleted and the exception re-raised.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
|
||||
class OAuth2ExceptionError(Exception):
|
||||
pass
|
||||
@@ -149,7 +149,7 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
|
||||
exception re-raised.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
|
||||
class OAuth2ExceptionError(Exception):
|
||||
pass
|
||||
@@ -175,7 +175,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
This can happen when the refresh token was revoked.
|
||||
"""
|
||||
mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.get_oauth2_fresh_token.return_value = {
|
||||
"error": "invalid_grant",
|
||||
|
||||
Reference in New Issue
Block a user