fix(gtf): set dedup_key on atomic sql (#37820)

Ville Brofeldt
2026-02-10 06:56:14 -08:00
committed by GitHub
parent 948b1d613b
commit 0f1278fa61
9 changed files with 106 additions and 47 deletions

View File

@@ -50,7 +50,7 @@ When GTF is considered stable, it will replace legacy Celery tasks for built-in
### Define a Task
```python
-from superset_core.api.types import task, get_context
+from superset_core.api.tasks import task, get_context
@task
def process_data(dataset_id: int) -> None:
@@ -245,7 +245,7 @@ Always implement an abort handler for long-running tasks. This allows users to c
Set a timeout to automatically abort tasks that run too long:
```python
-from superset_core.api.types import task, get_context, TaskOptions
+from superset_core.api.tasks import task, get_context, TaskOptions
# Set default timeout in decorator
@task(timeout=300) # 5 minutes
@@ -299,7 +299,7 @@ Timeouts require an abort handler to be effective. Without one, the timeout trig
Use `task_key` to prevent duplicate task execution:
```python
-from superset_core.api.types import TaskOptions
+from superset_core.api.tasks import TaskOptions
# Without key - creates new task each time (random UUID)
task1 = my_task.schedule(x=1)
@@ -331,7 +331,7 @@ print(task2.status) # "success" (terminal status)
## Task Scopes
```python
-from superset_core.api.types import task, TaskScope
+from superset_core.api.tasks import task, TaskScope
@task # Private by default
def private_task(): ...
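Taken together, the doc hunks above describe the `task_key` dedup flow. A minimal caller-side sketch of that behavior; `TaskOptions(task_key=...)` and the `options=` keyword are assumptions, not confirmed by this diff:

```python
from superset_core.api.tasks import task, TaskOptions

@task
def my_task(x: int) -> None:
    ...

# Without a key: every call schedules a fresh task (random UUID).
task1 = my_task.schedule(x=1)

# With a key: while an identical task is still active it is deduplicated,
# and the original task handle is returned instead of a new one.
opts = TaskOptions(task_key="refresh-dataset-42")  # hypothetical key
task2 = my_task.schedule(x=1, options=opts)        # options= is assumed
task3 = my_task.schedule(x=1, options=opts)        # dedups against task2
```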

View File

@@ -259,7 +259,7 @@ def task(
is discarded; only side effects and context updates matter.
Example:
-from superset_core.api.types import task, get_context, TaskScope
+from superset_core.api.tasks import task, get_context, TaskScope
# Private task (default scope)
@task
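The docstring fragment above notes that a task's return value "is discarded; only side effects and context updates matter." A short sketch of that contract; `task` and `get_context` come from the import in this hunk, everything else is illustrative:

```python
from pathlib import Path

from superset_core.api.tasks import task, get_context

@task
def export_report(report_id: int) -> None:
    ctx = get_context()  # task context; its concrete API is not shown here
    out = Path(f"/tmp/report-{report_id}.txt")
    out.write_text(f"report {report_id}")  # the side effect is the result
    # Any value returned here would be discarded by the framework.
```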

View File

@@ -28,9 +28,9 @@ from superset.daos.exceptions import DAODeleteFailedError
from superset.extensions import db
from superset.models.task_subscribers import TaskSubscriber
from superset.models.tasks import Task
-from superset.tasks.constants import ABORTABLE_STATES
+from superset.tasks.constants import ABORTABLE_STATES, TERMINAL_STATES
from superset.tasks.filters import TaskFilter
-from superset.tasks.utils import get_active_dedup_key, json
+from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key, json
logger = logging.getLogger(__name__)
@@ -243,7 +243,7 @@ class TaskDAO(BaseDAO[Task]):
)
# Transition to ABORTING (not ABORTED yet)
-task.status = TaskStatus.ABORTING.value
+task.set_status(TaskStatus.ABORTING)
db.session.merge(task)
logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope)
@@ -444,6 +444,10 @@ class TaskDAO(BaseDAO[Task]):
if set_ended_at:
update_values["ended_at"] = datetime.now(timezone.utc)
+# Update dedup_key if transitioning to terminal state
+if new_status_val in TERMINAL_STATES:
+    update_values["dedup_key"] = get_finished_dedup_key(task_uuid)
# Atomic compare-and-swap: only update if status matches expected
rows_updated = (
db.session.query(Task)
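For orientation, a condensed sketch of the compare-and-swap UPDATE this hunk extends. The imports mirror ones visible elsewhere in the commit; the filter and argument list are reconstructed assumptions:

```python
from datetime import datetime, timezone

from superset.extensions import db
from superset.models.tasks import Task
from superset.tasks.constants import TERMINAL_STATES
from superset.tasks.utils import get_finished_dedup_key

def conditional_status_update_sketch(task_uuid, new_status, expected_status):
    update_values = {"status": new_status.value}
    if new_status.value in TERMINAL_STATES:
        update_values["ended_at"] = datetime.now(timezone.utc)
        # New in this commit: rewrite dedup_key in the same atomic UPDATE,
        # so the dedup slot is freed exactly when the task turns terminal.
        update_values["dedup_key"] = get_finished_dedup_key(task_uuid)
    # Compare-and-swap: only a row still in the expected status is updated.
    rows_updated = (
        db.session.query(Task)
        .filter(Task.uuid == task_uuid, Task.status == expected_status.value)
        .update(update_values, synchronize_session=False)
    )
    return rows_updated == 1  # False: another writer changed the status first
```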

View File

@@ -37,6 +37,7 @@ from superset_core.api.tasks import TaskProperties, TaskStatus
from superset.models.helpers import AuditMixinNullable
from superset.models.task_subscribers import TaskSubscriber
+from superset.tasks.constants import TERMINAL_STATES
from superset.tasks.utils import (
error_update,
get_finished_dedup_key,
@@ -218,12 +219,7 @@ class Task(CoreTask, AuditMixinNullable, Model):
# (will be set to True if/when an abort handler is registered)
if self.properties_dict.get("is_abortable") is None:
self.update_properties({"is_abortable": False})
-elif status in [
-    TaskStatus.SUCCESS.value,
-    TaskStatus.FAILURE.value,
-    TaskStatus.ABORTED.value,
-    TaskStatus.TIMED_OUT.value,
-]:
+elif status in TERMINAL_STATES:
if not self.ended_at:
self.ended_at = now
# Update dedup_key to UUID to free up the slot for new tasks
@@ -244,12 +240,7 @@ class Task(CoreTask, AuditMixinNullable, Model):
@property
def is_finished(self) -> bool:
"""Check if task has finished (success, failure, aborted, or timed out)."""
-return self.status in [
-    TaskStatus.SUCCESS.value,
-    TaskStatus.FAILURE.value,
-    TaskStatus.ABORTED.value,
-    TaskStatus.TIMED_OUT.value,
-]
+return self.status in TERMINAL_STATES
@property
def is_successful(self) -> bool:
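The inline lists deleted above pin down what `TERMINAL_STATES` must contain; one plausible definition of the shared constant (the frozenset literal is inferred, not shown in this diff):

```python
from superset_core.api.tasks import TaskStatus

TERMINAL_STATES = frozenset(
    {
        TaskStatus.SUCCESS.value,
        TaskStatus.FAILURE.value,
        TaskStatus.ABORTED.value,
        TaskStatus.TIMED_OUT.value,
    }
)
```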

View File

@@ -112,9 +112,6 @@ class TaskManager:
_completion_channel_prefix: str = "gtf:complete:"
_initialized: bool = False
-# Backward compatibility alias - prefer importing from superset.tasks.constants
-TERMINAL_STATES = TERMINAL_STATES
@classmethod
def init_app(cls, app: Flask) -> None:
"""
@@ -271,7 +268,7 @@ class TaskManager:
if not task:
raise ValueError(f"Task {task_uuid} not found")
-if task.status in cls.TERMINAL_STATES:
+if task.status in TERMINAL_STATES:
return task
logger.debug(
@@ -342,13 +339,13 @@ class TaskManager:
message.get("data"),
)
task = get_task()
-if task and task.status in cls.TERMINAL_STATES:
+if task and task.status in TERMINAL_STATES:
return task
# Also check database periodically in case we missed the message
# (e.g., task completed before we subscribed)
task = get_task()
-if task and task.status in cls.TERMINAL_STATES:
+if task and task.status in TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via db check): status=%s",
task_uuid,
@@ -384,7 +381,7 @@ class TaskManager:
if not task:
raise ValueError(f"Task {task_uuid} not found")
-if task.status in cls.TERMINAL_STATES:
+if task.status in TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via polling): status=%s",
task_uuid,
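The call sites above all reduce to the same wait pattern; a simplified sketch (pubsub handling omitted, helper name and parameters illustrative, though the `timeout=5.0` / `poll_interval=0.1` values appear in the tests below):

```python
import time

from superset.tasks.constants import TERMINAL_STATES

def wait_for_completion_sketch(get_task, timeout=5.0, poll_interval=0.1):
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        task = get_task()  # re-fetch the Task row from the database
        if task and task.status in TERMINAL_STATES:
            return task
        time.sleep(poll_interval)
    raise TimeoutError("task did not reach a terminal state in time")
```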

View File

@@ -25,6 +25,9 @@ get_delete_ids_schema = {"type": "array", "items": {"type": "string"}}
# Field descriptions
uuid_description = "The unique identifier (UUID) of the task"
task_key_description = "The task identifier used for deduplication"
+dedup_key_description = (
+    "The hashed deduplication key used internally for task deduplication"
+)
task_type_description = (
"The type of task (e.g., 'sql_execution', 'thumbnail_generation')"
)
@@ -74,6 +77,7 @@ class TaskResponseSchema(Schema):
id = fields.Int(metadata={"description": "Internal task ID"})
uuid = fields.UUID(metadata={"description": uuid_description})
task_key = fields.String(metadata={"description": task_key_description})
+dedup_key = fields.String(metadata={"description": dedup_key_description})
task_type = fields.String(metadata={"description": task_type_description})
task_name = fields.String(
metadata={"description": task_name_description}, allow_none=True
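With the schema hunks applied, `dedup_key` is exposed next to `task_key` in API responses; a quick check, assuming the schema lives at `superset.tasks.schemas`:

```python
from superset.tasks.schemas import TaskResponseSchema  # module path assumed

schema = TaskResponseSchema()
assert "task_key" in schema.fields   # caller-facing deduplication key
assert "dedup_key" in schema.fields  # hashed key the backend dedups on
```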

View File

@@ -68,21 +68,6 @@ def test_submit_task_distinguishes_new_vs_existing(
db.session.commit()
-def test_terminal_states_recognized_correctly(app_context) -> None:
-    """
-    Test that TaskManager.TERMINAL_STATES contains the expected values.
-    """
-    assert TaskStatus.SUCCESS.value in TaskManager.TERMINAL_STATES
-    assert TaskStatus.FAILURE.value in TaskManager.TERMINAL_STATES
-    assert TaskStatus.ABORTED.value in TaskManager.TERMINAL_STATES
-    assert TaskStatus.TIMED_OUT.value in TaskManager.TERMINAL_STATES
-
-    # Non-terminal states should not be in the set
-    assert TaskStatus.PENDING.value not in TaskManager.TERMINAL_STATES
-    assert TaskStatus.IN_PROGRESS.value not in TaskManager.TERMINAL_STATES
-    assert TaskStatus.ABORTING.value not in TaskManager.TERMINAL_STATES
def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
"""
Test that wait_for_completion raises TimeoutError on timeout.

View File

@@ -418,3 +418,86 @@ def test_get_status_not_found(session_with_task: Session) -> None:
result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))
assert result is None
+def test_conditional_status_update_non_terminal_state_keeps_dedup_key(
+    session_with_task: Session,
+) -> None:
+    """Test that conditional_status_update preserves dedup_key for
+    non-terminal transitions"""
+    from superset.daos.tasks import TaskDAO
+
+    # Create task in PENDING state
+    task = create_task(
+        session_with_task,
+        task_uuid=TASK_UUID,
+        task_key="non-terminal-test-task",
+        status=TaskStatus.PENDING,
+    )
+
+    # Store original active dedup_key
+    original_dedup_key = task.dedup_key
+
+    # Transition to non-terminal state (IN_PROGRESS)
+    result = TaskDAO.conditional_status_update(
+        task_uuid=TASK_UUID,
+        new_status=TaskStatus.IN_PROGRESS,
+        expected_status=TaskStatus.PENDING,
+        set_started_at=True,
+    )
+
+    # Should succeed
+    assert result is True
+
+    # Refresh task and verify dedup_key was NOT changed
+    session_with_task.refresh(task)
+    assert task.status == TaskStatus.IN_PROGRESS.value
+    assert task.dedup_key == original_dedup_key  # Should remain the same
+    assert task.started_at is not None
+
+
+@pytest.mark.parametrize(
+    "terminal_state",
+    [
+        TaskStatus.SUCCESS,
+        TaskStatus.FAILURE,
+        TaskStatus.ABORTED,
+        TaskStatus.TIMED_OUT,
+    ],
+)
+def test_conditional_status_update_terminal_state_updates_dedup_key(
+    session_with_task: Session, terminal_state: TaskStatus
+) -> None:
+    """Test that terminal states (SUCCESS, FAILURE, ABORTED, TIMED_OUT)
+    update dedup_key"""
+    from superset.daos.tasks import TaskDAO
+
+    task = create_task(
+        session_with_task,
+        task_uuid=TASK_UUID,
+        task_key=f"terminal-test-{terminal_state.value}",
+        status=TaskStatus.IN_PROGRESS,
+    )
+    original_dedup_key = task.dedup_key
+    expected_finished_key = get_finished_dedup_key(TASK_UUID)
+
+    # Transition to terminal state
+    result = TaskDAO.conditional_status_update(
+        task_uuid=TASK_UUID,
+        new_status=terminal_state,
+        expected_status=TaskStatus.IN_PROGRESS,
+        set_ended_at=True,
+    )
+    assert result is True, f"Failed to update to {terminal_state.value}"
+
+    # Verify dedup_key was updated
+    session_with_task.refresh(task)
+    assert task.status == terminal_state.value
+    assert task.dedup_key == expected_finished_key, (
+        f"dedup_key not updated for {terminal_state.value}"
+    )
+    assert task.dedup_key != original_dedup_key, (
+        f"dedup_key should have changed for {terminal_state.value}"
+    )
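The new tests lean on the two dedup-key helpers, whose bodies are not part of this diff. The model comment "Update dedup_key to UUID to free up the slot" and the "hashed" field description suggest shapes like these, offered purely as an assumption:

```python
import hashlib
from uuid import UUID

def get_active_dedup_key_sketch(task_key: str) -> str:
    # Active tasks dedup on a hash of the caller-supplied task_key.
    return hashlib.sha256(task_key.encode()).hexdigest()

def get_finished_dedup_key_sketch(task_uuid: UUID) -> str:
    # Finished tasks collapse to their own UUID, freeing the slot.
    return str(task_uuid)
```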

View File

@@ -455,8 +455,3 @@ class TestTaskManagerCompletion:
timeout=5.0,
poll_interval=0.1,
)
-def test_terminal_states_constant(self):
-    """Test TERMINAL_STATES contains expected values"""
-    expected = {"success", "failure", "aborted", "timed_out"}
-    assert TaskManager.TERMINAL_STATES == expected