mirror of
https://github.com/apache/superset.git
synced 2026-04-11 04:15:33 +00:00
678 lines
22 KiB
Python
678 lines
22 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
"""Unit tests for GTF handlers (abort, cleanup) and related Task model behavior."""
|
|
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
from uuid import UUID
|
|
|
|
import pytest
|
|
from freezegun import freeze_time
|
|
from superset_core.api.tasks import TaskStatus
|
|
|
|
from superset.tasks.context import TaskContext
|
|
|
|
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
|
|
|
|
|
@pytest.fixture
def mock_task():
    """Return a ``MagicMock`` task with ``TEST_UUID`` and a PENDING status."""
    # MagicMock keyword arguments are applied as attributes on the mock.
    return MagicMock(uuid=TEST_UUID, status=TaskStatus.PENDING.value)
|
|
|
|
|
|
@pytest.fixture
def mock_task_dao(mock_task):
    """Patch ``superset.daos.tasks.TaskDAO`` so lookups return ``mock_task``.

    Yields the patched DAO mock; the patch is undone when the fixture
    finishes.
    """
    dao_patcher = patch("superset.daos.tasks.TaskDAO")
    with dao_patcher as patched_dao:
        patched_dao.find_one_or_none.return_value = mock_task
        yield patched_dao
|
|
|
|
|
|
@pytest.fixture
def mock_update_command():
    """Patch ``UpdateTaskCommand`` so running it is a no-op returning ``None``.

    Yields the patched command class mock.
    """
    command_patcher = patch("superset.commands.tasks.update.UpdateTaskCommand")
    with command_patcher as patched_command:
        # Instances produced by the patched class get a run() that returns None.
        command_instance = patched_command.return_value
        command_instance.run.return_value = None
        yield patched_command
|
|
|
|
|
|
@pytest.fixture
def mock_flask_app():
    """Build a mock Flask app with abort-polling config and a usable app context."""
    flask_app = MagicMock()
    flask_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}

    # app_context() has to work in a ``with`` statement, so give its return
    # value explicit __enter__/__exit__ mocks.
    app_ctx = flask_app.app_context.return_value
    app_ctx.__enter__ = MagicMock(return_value=None)
    app_ctx.__exit__ = MagicMock(return_value=None)

    # A plain Mock (rather than MagicMock) is used for _get_current_object so
    # that AsyncMockMixin does not create unawaited coroutines on Python 3.10+.
    flask_app._get_current_object = Mock(return_value=flask_app)
    return flask_app
|
|
|
|
|
|
@pytest.fixture
def task_context(mock_task, mock_task_dao, mock_update_command, mock_flask_app):
    """Yield a ``TaskContext`` built on ``mock_task`` with its collaborators mocked.

    ``current_app`` and ``cache_manager`` stay patched for the duration of the
    test. On teardown, abort polling is stopped if a listener was started.
    """
    # TaskContext reads both dict accessors from the task, so set them up front.
    mock_task.payload_dict = {}
    mock_task.properties_dict = {"is_abortable": False}

    app_patcher = patch("superset.tasks.context.current_app")
    cache_patcher = patch("superset.tasks.manager.cache_manager")
    with app_patcher as mock_current_app, cache_patcher as mock_cache_manager:
        # No distributed coordination backend: this disables the Redis path.
        mock_cache_manager.distributed_coordination = None

        # current_app mirrors the mock Flask app's config and resolves to it.
        mock_current_app.config = mock_flask_app.config
        # Plain Mock (not MagicMock) here avoids AsyncMockMixin creating
        # unawaited coroutines on Python 3.10+.
        mock_current_app._get_current_object = Mock(return_value=mock_flask_app)

        context = TaskContext(mock_task)

        yield context

        # Teardown: shut the abort listener down if the test started one.
        if context._abort_listener:
            context.stop_abort_polling()
|
|
|
|
|
|
class TestTaskStatusEnum:
    """Checks on the members and values of the ``TaskStatus`` enum."""

    def test_aborting_status_exists(self):
        """``TaskStatus.ABORTING`` exists and its value is ``"aborting"``."""
        assert hasattr(TaskStatus, "ABORTING")
        assert TaskStatus.ABORTING.value == "aborting"

    def test_all_statuses_present(self):
        """Each expected status string appears among the enum values."""
        actual_statuses = [member.value for member in TaskStatus]
        expected_statuses = (
            "pending",
            "in_progress",
            "success",
            "failure",
            "aborting",
            "aborted",
        )

        for status in expected_statuses:
            assert status in actual_statuses, f"Missing status: {status}"
|
|
|
|
|
|
class TestTaskAbortProperties:
    """Abort-related state on ``Task``: the status field and ``properties_dict``."""

    def test_aborting_status(self):
        """A task assigned the ABORTING value reports that same value back."""
        from superset.models.tasks import Task

        aborting = TaskStatus.ABORTING.value
        task = Task()
        task.status = aborting

        assert task.status == aborting

    def test_is_abortable_in_properties(self):
        """``is_abortable`` set via ``update_properties`` is readable afterwards."""
        from superset.models.tasks import Task

        task = Task()
        task.update_properties({"is_abortable": True})

        stored_flag = task.properties_dict.get("is_abortable")
        assert stored_flag is True

    def test_is_abortable_default_none(self):
        """A freshly constructed task has no ``is_abortable`` property."""
        from superset.models.tasks import Task

        stored_flag = Task().properties_dict.get("is_abortable")

        assert stored_flag is None
|
|
|
|
|
|
class TestTaskSetStatus:
    """Exercise ``Task.set_status`` for the abort-related transitions.

    Every test drives the transition through ``set_status`` rather than by
    assigning ``status`` directly, so the assertions about timestamps and
    ``is_abortable`` actually cover the method under test.
    """

    def test_set_status_in_progress_sets_is_abortable_false(self):
        """Transitioning a new task to IN_PROGRESS sets ``is_abortable`` to False."""
        from superset.models.tasks import Task

        task = Task()
        task.uuid = "test-uuid"
        # No is_abortable property has been written yet (it reads as None).

        task.set_status(TaskStatus.IN_PROGRESS)

        assert task.properties_dict.get("is_abortable") is False
        assert task.started_at is not None

    def test_set_status_in_progress_preserves_existing_is_abortable(self):
        """Re-setting IN_PROGRESS must not overwrite an existing ``is_abortable``."""
        from superset.models.tasks import Task

        task = Task()
        task.uuid = "test-uuid"
        # Simulate a handler registration having already flagged the task.
        task.update_properties({"is_abortable": True})
        # Simulate a task that has already started.
        task.started_at = datetime.now(timezone.utc)

        task.set_status(TaskStatus.IN_PROGRESS)

        # The flag is expected to survive because started_at was already set.
        assert task.properties_dict.get("is_abortable") is True

    def test_set_status_aborting_does_not_set_ended_at(self):
        """Transitioning to ABORTING via ``set_status`` leaves ``ended_at`` unset."""
        from superset.models.tasks import Task

        task = Task()
        task.uuid = "test-uuid"
        task.started_at = datetime.now(timezone.utc)

        # Go through set_status: a bare ``task.status = ...`` assignment would
        # bypass the method and make the ended_at assertion vacuous.
        task.set_status(TaskStatus.ABORTING)

        assert task.ended_at is None

    def test_set_status_aborted_sets_ended_at(self):
        """Transitioning to ABORTED via ``set_status`` populates ``ended_at``."""
        from superset.models.tasks import Task

        task = Task()
        task.uuid = "test-uuid"
        task.started_at = datetime.now(timezone.utc)

        task.set_status(TaskStatus.ABORTED)

        assert task.ended_at is not None
|
|
|
|
|
|
class TestTaskDuration:
    """Cover ``Task.duration_seconds`` for finished, running, pending and empty tasks."""

    def test_duration_seconds_finished_task(self):
        """A finished task reports ``ended_at - started_at``."""
        from superset.models.tasks import Task

        finished_task = Task()
        # The status has to be a finished one for ended_at to be used.
        finished_task.status = TaskStatus.SUCCESS.value
        finished_task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        finished_task.ended_at = datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)

        # 10:00:00 -> 10:00:30 is thirty seconds.
        elapsed = finished_task.duration_seconds
        assert elapsed == 30.0

    @freeze_time("2024-01-01 10:00:30")
    def test_duration_seconds_running_task(self):
        """A started but unfinished task reports the time elapsed since start."""
        from superset.models.tasks import Task

        running_task = Task()
        running_task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        running_task.ended_at = None

        # The clock is frozen thirty seconds after started_at.
        elapsed = running_task.duration_seconds
        assert elapsed == 30.0

    @freeze_time("2024-01-01 10:00:15")
    def test_duration_seconds_pending_task(self):
        """A task that never started reports the time since it was created."""
        from superset.models.tasks import Task

        pending_task = Task()
        pending_task.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        pending_task.started_at = None
        pending_task.ended_at = None

        # The clock is frozen fifteen seconds after created_on.
        elapsed = pending_task.duration_seconds
        assert elapsed == 15.0

    def test_duration_seconds_no_timestamps(self):
        """With every timestamp unset the duration is ``None``."""
        from superset.models.tasks import Task

        blank_task = Task()
        blank_task.created_on = None
        blank_task.started_at = None
        blank_task.ended_at = None

        elapsed = blank_task.duration_seconds
        assert elapsed is None
|
|
|
|
|
|
class TestAbortHandlerRegistration:
    """Registering abort handlers and the resulting ``is_abortable`` bookkeeping."""

    @staticmethod
    def _make_task(properties):
        """Build a ``MagicMock`` task with ``TEST_UUID`` and the given properties."""
        fake_task = MagicMock()
        fake_task.uuid = TEST_UUID
        fake_task.properties_dict = properties
        fake_task.payload_dict = {}
        return fake_task

    def test_on_abort_registers_handler(self, task_context):
        """``on_abort`` stores the handler without invoking it."""
        invocations = []

        @task_context.on_abort
        def handle_abort():
            invocations.append(True)

        assert len(task_context._abort_handlers) == 1
        assert not invocations

    @patch("superset.tasks.context.current_app")
    def test_on_abort_sets_abortable(self, mock_app):
        """Registering the first handler calls ``_set_abortable`` exactly once."""
        mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
        mock_app._get_current_object = Mock(return_value=mock_app)
        fake_task = self._make_task({"is_abortable": False})

        set_abortable_patch = patch.object(TaskContext, "_set_abortable")
        polling_patch = patch.object(TaskContext, "start_abort_polling")
        with set_abortable_patch as mock_set_abortable, polling_patch:
            ctx = TaskContext(fake_task)

            @ctx.on_abort
            def handler():
                pass

            mock_set_abortable.assert_called_once()

    @patch("superset.tasks.context.current_app")
    def test_on_abort_only_sets_abortable_once(self, mock_app):
        """A second handler registration does not call ``_set_abortable`` again."""
        mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
        mock_app._get_current_object = Mock(return_value=mock_app)
        fake_task = self._make_task({"is_abortable": False})

        set_abortable_patch = patch.object(TaskContext, "_set_abortable")
        polling_patch = patch.object(TaskContext, "start_abort_polling")
        with set_abortable_patch as mock_set_abortable, polling_patch:
            ctx = TaskContext(fake_task)

            @ctx.on_abort
            def handler1():
                pass

            @ctx.on_abort
            def handler2():
                pass

            # Only the first registration should have triggered the call.
            assert mock_set_abortable.call_count == 1

    def test_abort_handlers_completed_initially_false(self):
        """A new context reports ``abort_handlers_completed`` as False."""
        fake_task = self._make_task({})

        with patch("superset.tasks.context.current_app") as mock_app:
            mock_app._get_current_object = Mock(return_value=mock_app)
            ctx = TaskContext(fake_task)
            assert ctx.abort_handlers_completed is False
|
|
|
|
|
|
class TestAbortPolling:
    """Starting, stopping and triggering of the abort listener."""

    def test_on_abort_starts_polling_automatically(self, task_context):
        """Registering the first handler creates an abort listener."""
        assert task_context._abort_listener is None

        @task_context.on_abort
        def handle_abort():
            pass

        assert task_context._abort_listener is not None

    def test_stop_abort_polling(self, task_context):
        """``stop_abort_polling`` clears the listener created by registration."""

        @task_context.on_abort
        def handle_abort():
            pass

        assert task_context._abort_listener is not None

        task_context.stop_abort_polling()

        assert task_context._abort_listener is None

    def test_start_abort_polling_only_once(self, task_context):
        """Calling ``start_abort_polling`` twice keeps the same listener object."""
        task_context.start_abort_polling(interval=0.1)
        listener_before = task_context._abort_listener

        # A repeated start must not replace the existing listener.
        task_context.start_abort_polling(interval=0.1)
        listener_after = task_context._abort_listener

        assert listener_before is listener_after

    def test_on_abort_with_custom_interval(self, task_context):
        """Polling can be restarted with an explicit interval."""
        with patch("superset.tasks.context.current_app") as mock_app:
            mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
            mock_app._get_current_object = Mock(return_value=mock_app)

            @task_context.on_abort
            def handle_abort():
                pass

            # Replace the automatically started listener with a 0.05s one.
            task_context.stop_abort_polling()
            task_context.start_abort_polling(interval=0.05)

            assert task_context._abort_listener is not None

    def test_polling_stops_after_abort_detected(self, task_context, mock_task):
        """Once the task status is ABORTED, polling flags the abort as detected."""

        @task_context.on_abort
        def handle_abort():
            pass

        # Flip the mocked task into the aborted state.
        mock_task.status = TaskStatus.ABORTED.value

        # Give the listener time to observe the new status.
        time.sleep(0.3)

        assert task_context._abort_detected is True
|
|
|
|
|
|
class TestAbortHandlerExecution:
    """How registered abort handlers are invoked once an abort is observed."""

    def test_on_abort_handler_fires_when_task_aborted(self, task_context, mock_task):
        """The handler runs on its own after the task status becomes ABORTED."""
        fired = []

        @task_context.on_abort
        def handle_abort():
            fired.append(True)

        # Put the mocked task into the aborted state.
        mock_task.status = TaskStatus.ABORTED.value

        # Allow up to 0.3s for the 0.1s-interval polling to notice.
        time.sleep(0.3)

        assert fired
        assert task_context._abort_detected

    def test_on_abort_not_called_on_success(self, task_context, mock_task):
        """A task in the SUCCESS state does not trigger abort handlers."""
        fired = []

        @task_context.on_abort
        def handle_abort():
            fired.append(True)

        # The task stays in a successful state throughout.
        mock_task.status = TaskStatus.SUCCESS.value

        # Wait as long as the abort tests do, then confirm nothing ran.
        time.sleep(0.3)

        assert not fired

    def test_multiple_abort_handlers(self, task_context, mock_task):
        """Handlers run in LIFO order: the last one registered goes first."""
        order = []

        @task_context.on_abort
        def handler1():
            order.append(1)

        @task_context.on_abort
        def handler2():
            order.append(2)

        # Abort the task and let polling pick it up.
        mock_task.status = TaskStatus.ABORTED.value
        time.sleep(0.3)

        # handler2 was registered last, so it is expected to run first.
        assert order == [2, 1]

    def test_abort_handler_exception_doesnt_fail_task(self, task_context, mock_task):
        """A raising handler does not stop the other handler from running."""
        good_handler_runs = []

        @task_context.on_abort
        def bad_handler():
            raise ValueError("Handler error")

        @task_context.on_abort
        def good_handler():
            good_handler_runs.append(True)

        # Abort the task and let polling pick it up.
        mock_task.status = TaskStatus.ABORTED.value
        time.sleep(0.3)

        # The well-behaved handler must have run despite the failing one.
        assert good_handler_runs
|
|
|
|
|
|
class TestBestEffortHandlerExecution:
    """Best-effort semantics: every handler runs even when handlers raise."""

    @staticmethod
    def _record_failures_before_write(ctx):
        """Wrap ``ctx._write_handler_failures_to_db`` to snapshot failures first.

        Returns the list that receives a copy of ``ctx._handler_failures`` each
        time the wrapped writer is called, before the real writer runs.
        """
        recorded = []
        real_write = ctx._write_handler_failures_to_db

        def recording_write():
            recorded.extend(ctx._handler_failures)
            real_write()

        ctx._write_handler_failures_to_db = recording_write
        return recorded

    def test_all_abort_handlers_run_even_if_all_fail(self, task_context, mock_task):
        """Three raising abort handlers all run, and all three failures are kept."""
        order = []

        @task_context.on_abort
        def handler1():
            order.append(1)
            raise ValueError("Handler 1 failed")

        @task_context.on_abort
        def handler2():
            order.append(2)
            raise RuntimeError("Handler 2 failed")

        @task_context.on_abort
        def handler3():
            order.append(3)
            raise TypeError("Handler 3 failed")

        # Invoke the handlers directly instead of waiting for abort detection.
        task_context._trigger_abort_handlers()

        # LIFO: last registered runs first.
        assert order == [3, 2, 1]

        # The failures are expected to still be on the context afterwards.
        assert len(task_context._handler_failures) == 3
        failure_types = [
            type(error).__name__ for _, error, _ in task_context._handler_failures
        ]
        for expected_name in ("TypeError", "RuntimeError", "ValueError"):
            assert expected_name in failure_types

    def test_all_cleanup_handlers_run_even_if_all_fail(self, task_context, mock_task):
        """Three raising cleanup handlers all run, and all failures are captured."""
        order = []
        # Snapshot failures at write time, since they may be cleared afterwards.
        captured_failures = self._record_failures_before_write(task_context)

        @task_context.on_cleanup
        def cleanup1():
            order.append(1)
            raise ValueError("Cleanup 1 failed")

        @task_context.on_cleanup
        def cleanup2():
            order.append(2)
            raise RuntimeError("Cleanup 2 failed")

        @task_context.on_cleanup
        def cleanup3():
            order.append(3)
            raise TypeError("Cleanup 3 failed")

        # A SUCCESS status (not aborting) means only cleanup handlers should run.
        mock_task.status = TaskStatus.SUCCESS.value

        task_context._run_cleanup()

        # LIFO: last registered runs first.
        assert order == [3, 2, 1]

        assert len(captured_failures) == 3
        failure_types = [type(error).__name__ for _, error, _ in captured_failures]
        for expected_name in ("TypeError", "RuntimeError", "ValueError"):
            assert expected_name in failure_types

    def test_mixed_abort_and_cleanup_failures_all_collected(
        self, task_context, mock_task
    ):
        """Failures from abort and cleanup handlers end up in one collection."""
        order = []
        # Snapshot failures at write time, since they may be cleared afterwards.
        captured_failures = self._record_failures_before_write(task_context)

        @task_context.on_abort
        def abort1():
            order.append("abort1")
            raise ValueError("Abort 1 failed")

        @task_context.on_abort
        def abort2():
            order.append("abort2")
            raise RuntimeError("Abort 2 failed")

        @task_context.on_cleanup
        def cleanup1():
            order.append("cleanup1")
            raise TypeError("Cleanup 1 failed")

        @task_context.on_cleanup
        def cleanup2():
            order.append("cleanup2")
            raise KeyError("Cleanup 2 failed")

        # With an ABORTING status, cleanup is expected to run both handler kinds.
        mock_task.status = TaskStatus.ABORTING.value

        task_context._run_cleanup()

        # Abort handlers go first (LIFO), followed by cleanup handlers (LIFO).
        assert order == ["abort2", "abort1", "cleanup2", "cleanup1"]

        assert len(captured_failures) == 4

        # Each failure record starts with the kind of handler that raised.
        handler_types = [kind for kind, _, _ in captured_failures]
        assert handler_types.count("abort") == 2
        assert handler_types.count("cleanup") == 2
|
|
|
|
|
|
class TestCleanupHandlers:
|
|
"""Test cleanup handler behavior."""
|
|
|
|
def test_cleanup_triggers_abort_handlers_if_not_detected(
    self, task_context, mock_task
):
    """``_run_cleanup`` runs abort handlers when the task ended up ABORTED."""
    fired = []

    @task_context.on_abort
    def handle_abort():
        fired.append(True)

    # Mark the task aborted while keeping the context's detection flag off,
    # as if the task finished before polling could notice.
    mock_task.status = TaskStatus.ABORTED.value
    task_context._abort_detected = False

    # Run cleanup straight away.
    task_context._run_cleanup()

    assert fired
|
|
|
|
def test_cleanup_doesnt_duplicate_abort_handlers(self, task_context, mock_task):
    """Abort handlers already run by polling are not run again by cleanup."""
    invocations = []

    @task_context.on_abort
    def handle_abort():
        invocations.append(True)

    # Let polling observe the aborted status.
    mock_task.status = TaskStatus.ABORTED.value
    time.sleep(0.3)

    # Polling should have invoked the handler exactly once.
    assert len(invocations) == 1
    assert task_context._abort_detected is True

    # Cleanup afterwards must not invoke the handler a second time.
    task_context._run_cleanup()

    assert len(invocations) == 1  # Still 1, not 2
|