mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix(gtf): set dedup_key on atomic sql (#37820)
This commit is contained in:
@@ -50,7 +50,7 @@ When GTF is considered stable, it will replace legacy Celery tasks for built-in
|
||||
### Define a Task
|
||||
|
||||
```python
|
||||
from superset_core.api.types import task, get_context
|
||||
from superset_core.api.tasks import task, get_context
|
||||
|
||||
@task
|
||||
def process_data(dataset_id: int) -> None:
|
||||
@@ -245,7 +245,7 @@ Always implement an abort handler for long-running tasks. This allows users to c
|
||||
Set a timeout to automatically abort tasks that run too long:
|
||||
|
||||
```python
|
||||
from superset_core.api.types import task, get_context, TaskOptions
|
||||
from superset_core.api.tasks import task, get_context, TaskOptions
|
||||
|
||||
# Set default timeout in decorator
|
||||
@task(timeout=300) # 5 minutes
|
||||
@@ -299,7 +299,7 @@ Timeouts require an abort handler to be effective. Without one, the timeout trig
|
||||
Use `task_key` to prevent duplicate task execution:
|
||||
|
||||
```python
|
||||
from superset_core.api.types import TaskOptions
|
||||
from superset_core.api.tasks import TaskOptions
|
||||
|
||||
# Without key - creates new task each time (random UUID)
|
||||
task1 = my_task.schedule(x=1)
|
||||
@@ -331,7 +331,7 @@ print(task2.status) # "success" (terminal status)
|
||||
## Task Scopes
|
||||
|
||||
```python
|
||||
from superset_core.api.types import task, TaskScope
|
||||
from superset_core.api.tasks import task, TaskScope
|
||||
|
||||
@task # Private by default
|
||||
def private_task(): ...
|
||||
|
||||
@@ -259,7 +259,7 @@ def task(
|
||||
is discarded; only side effects and context updates matter.
|
||||
|
||||
Example:
|
||||
from superset_core.api.types import task, get_context, TaskScope
|
||||
from superset_core.api.tasks import task, get_context, TaskScope
|
||||
|
||||
# Private task (default scope)
|
||||
@task
|
||||
|
||||
@@ -28,9 +28,9 @@ from superset.daos.exceptions import DAODeleteFailedError
|
||||
from superset.extensions import db
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.constants import ABORTABLE_STATES
|
||||
from superset.tasks.constants import ABORTABLE_STATES, TERMINAL_STATES
|
||||
from superset.tasks.filters import TaskFilter
|
||||
from superset.tasks.utils import get_active_dedup_key, json
|
||||
from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key, json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -243,7 +243,7 @@ class TaskDAO(BaseDAO[Task]):
|
||||
)
|
||||
|
||||
# Transition to ABORTING (not ABORTED yet)
|
||||
task.status = TaskStatus.ABORTING.value
|
||||
task.set_status(TaskStatus.ABORTING)
|
||||
db.session.merge(task)
|
||||
logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope)
|
||||
|
||||
@@ -444,6 +444,10 @@ class TaskDAO(BaseDAO[Task]):
|
||||
if set_ended_at:
|
||||
update_values["ended_at"] = datetime.now(timezone.utc)
|
||||
|
||||
# Update dedup_key if transitioning to terminal state
|
||||
if new_status_val in TERMINAL_STATES:
|
||||
update_values["dedup_key"] = get_finished_dedup_key(task_uuid)
|
||||
|
||||
# Atomic compare-and-swap: only update if status matches expected
|
||||
rows_updated = (
|
||||
db.session.query(Task)
|
||||
|
||||
@@ -37,6 +37,7 @@ from superset_core.api.tasks import TaskProperties, TaskStatus
|
||||
|
||||
from superset.models.helpers import AuditMixinNullable
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
from superset.tasks.constants import TERMINAL_STATES
|
||||
from superset.tasks.utils import (
|
||||
error_update,
|
||||
get_finished_dedup_key,
|
||||
@@ -218,12 +219,7 @@ class Task(CoreTask, AuditMixinNullable, Model):
|
||||
# (will be set to True if/when an abort handler is registered)
|
||||
if self.properties_dict.get("is_abortable") is None:
|
||||
self.update_properties({"is_abortable": False})
|
||||
elif status in [
|
||||
TaskStatus.SUCCESS.value,
|
||||
TaskStatus.FAILURE.value,
|
||||
TaskStatus.ABORTED.value,
|
||||
TaskStatus.TIMED_OUT.value,
|
||||
]:
|
||||
elif status in TERMINAL_STATES:
|
||||
if not self.ended_at:
|
||||
self.ended_at = now
|
||||
# Update dedup_key to UUID to free up the slot for new tasks
|
||||
@@ -244,12 +240,7 @@ class Task(CoreTask, AuditMixinNullable, Model):
|
||||
@property
def is_finished(self) -> bool:
    """Check if task has finished (success, failure, aborted, or timed out).

    Returns:
        True when ``self.status`` is one of the values in the shared
        ``TERMINAL_STATES`` constant, False otherwise.
    """
    # Use the shared constant instead of a hand-maintained status list so
    # this check stays in sync with the other terminal-state checks.
    return self.status in TERMINAL_STATES
|
||||
|
||||
@property
|
||||
def is_successful(self) -> bool:
|
||||
|
||||
@@ -112,9 +112,6 @@ class TaskManager:
|
||||
_completion_channel_prefix: str = "gtf:complete:"
|
||||
_initialized: bool = False
|
||||
|
||||
# Backward compatibility alias - prefer importing from superset.tasks.constants
|
||||
TERMINAL_STATES = TERMINAL_STATES
|
||||
|
||||
@classmethod
|
||||
def init_app(cls, app: Flask) -> None:
|
||||
"""
|
||||
@@ -271,7 +268,7 @@ class TaskManager:
|
||||
if not task:
|
||||
raise ValueError(f"Task {task_uuid} not found")
|
||||
|
||||
if task.status in cls.TERMINAL_STATES:
|
||||
if task.status in TERMINAL_STATES:
|
||||
return task
|
||||
|
||||
logger.debug(
|
||||
@@ -342,13 +339,13 @@ class TaskManager:
|
||||
message.get("data"),
|
||||
)
|
||||
task = get_task()
|
||||
if task and task.status in cls.TERMINAL_STATES:
|
||||
if task and task.status in TERMINAL_STATES:
|
||||
return task
|
||||
|
||||
# Also check database periodically in case we missed the message
|
||||
# (e.g., task completed before we subscribed)
|
||||
task = get_task()
|
||||
if task and task.status in cls.TERMINAL_STATES:
|
||||
if task and task.status in TERMINAL_STATES:
|
||||
logger.debug(
|
||||
"Task %s completed (detected via db check): status=%s",
|
||||
task_uuid,
|
||||
@@ -384,7 +381,7 @@ class TaskManager:
|
||||
if not task:
|
||||
raise ValueError(f"Task {task_uuid} not found")
|
||||
|
||||
if task.status in cls.TERMINAL_STATES:
|
||||
if task.status in TERMINAL_STATES:
|
||||
logger.debug(
|
||||
"Task %s completed (detected via polling): status=%s",
|
||||
task_uuid,
|
||||
|
||||
@@ -25,6 +25,9 @@ get_delete_ids_schema = {"type": "array", "items": {"type": "string"}}
|
||||
# Field descriptions
|
||||
uuid_description = "The unique identifier (UUID) of the task"
|
||||
task_key_description = "The task identifier used for deduplication"
|
||||
dedup_key_description = (
|
||||
"The hashed deduplication key used internally for task deduplication"
|
||||
)
|
||||
task_type_description = (
|
||||
"The type of task (e.g., 'sql_execution', 'thumbnail_generation')"
|
||||
)
|
||||
@@ -74,6 +77,7 @@ class TaskResponseSchema(Schema):
|
||||
id = fields.Int(metadata={"description": "Internal task ID"})
|
||||
uuid = fields.UUID(metadata={"description": uuid_description})
|
||||
task_key = fields.String(metadata={"description": task_key_description})
|
||||
dedup_key = fields.String(metadata={"description": dedup_key_description})
|
||||
task_type = fields.String(metadata={"description": task_type_description})
|
||||
task_name = fields.String(
|
||||
metadata={"description": task_name_description}, allow_none=True
|
||||
|
||||
@@ -68,21 +68,6 @@ def test_submit_task_distinguishes_new_vs_existing(
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_terminal_states_recognized_correctly(app_context) -> None:
    """
    Verify which task statuses TaskManager.TERMINAL_STATES contains.

    SUCCESS, FAILURE, ABORTED and TIMED_OUT must be members; PENDING,
    IN_PROGRESS and ABORTING must not.
    """
    terminal_states = TaskManager.TERMINAL_STATES

    for finished in (
        TaskStatus.SUCCESS,
        TaskStatus.FAILURE,
        TaskStatus.ABORTED,
        TaskStatus.TIMED_OUT,
    ):
        assert finished.value in terminal_states

    # Non-terminal states should not be in the set
    for unfinished in (
        TaskStatus.PENDING,
        TaskStatus.IN_PROGRESS,
        TaskStatus.ABORTING,
    ):
        assert unfinished.value not in terminal_states
|
||||
|
||||
|
||||
def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
|
||||
"""
|
||||
Test that wait_for_completion raises TimeoutError on timeout.
|
||||
|
||||
@@ -418,3 +418,86 @@ def test_get_status_not_found(session_with_task: Session) -> None:
|
||||
result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_conditional_status_update_non_terminal_state_keeps_dedup_key(
    session_with_task: Session,
) -> None:
    """Check that a PENDING -> IN_PROGRESS conditional update leaves
    dedup_key untouched while still recording started_at."""
    from superset.daos.tasks import TaskDAO

    # Start from a task that is still PENDING
    pending_task = create_task(
        session_with_task,
        task_uuid=TASK_UUID,
        task_key="non-terminal-test-task",
        status=TaskStatus.PENDING,
    )

    # Remember the dedup_key the task had before the transition
    key_before = pending_task.dedup_key

    # Move the task to IN_PROGRESS, which is not a terminal state
    updated = TaskDAO.conditional_status_update(
        task_uuid=TASK_UUID,
        new_status=TaskStatus.IN_PROGRESS,
        expected_status=TaskStatus.PENDING,
        set_started_at=True,
    )

    # The compare-and-swap must report success
    assert updated is True

    # Reload the row and confirm only status/started_at changed
    session_with_task.refresh(pending_task)
    assert pending_task.status == TaskStatus.IN_PROGRESS.value
    assert pending_task.dedup_key == key_before  # Should remain the same
    assert pending_task.started_at is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "terminal_state",
    [
        TaskStatus.SUCCESS,
        TaskStatus.FAILURE,
        TaskStatus.ABORTED,
        TaskStatus.TIMED_OUT,
    ],
)
def test_conditional_status_update_terminal_state_updates_dedup_key(
    session_with_task: Session, terminal_state: TaskStatus
) -> None:
    """Check that moving an IN_PROGRESS task to SUCCESS, FAILURE, ABORTED
    or TIMED_OUT replaces its dedup_key with the finished key."""
    from superset.daos.tasks import TaskDAO

    running_task = create_task(
        session_with_task,
        task_uuid=TASK_UUID,
        task_key=f"terminal-test-{terminal_state.value}",
        status=TaskStatus.IN_PROGRESS,
    )

    key_before = running_task.dedup_key
    finished_key = get_finished_dedup_key(TASK_UUID)

    # Perform the conditional IN_PROGRESS -> terminal transition
    updated = TaskDAO.conditional_status_update(
        task_uuid=TASK_UUID,
        new_status=terminal_state,
        expected_status=TaskStatus.IN_PROGRESS,
        set_ended_at=True,
    )

    assert updated is True, f"Failed to update to {terminal_state.value}"

    # Reload the row and check the status and the swapped dedup_key
    session_with_task.refresh(running_task)
    assert running_task.status == terminal_state.value
    assert running_task.dedup_key == finished_key, (
        f"dedup_key not updated for {terminal_state.value}"
    )
    assert running_task.dedup_key != key_before, (
        f"dedup_key should have changed for {terminal_state.value}"
    )
|
||||
|
||||
@@ -455,8 +455,3 @@ class TestTaskManagerCompletion:
|
||||
timeout=5.0,
|
||||
poll_interval=0.1,
|
||||
)
|
||||
|
||||
def test_terminal_states_constant(self):
    """TaskManager.TERMINAL_STATES must equal the four finished status strings."""
    assert TaskManager.TERMINAL_STATES == {
        "success",
        "failure",
        "aborted",
        "timed_out",
    }
|
||||
|
||||
Reference in New Issue
Block a user