Files
superset2/tests/integration_tests/tasks/test_event_handlers.py
2026-02-09 10:45:56 -08:00

416 lines
14 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""End-to-end integration tests for task event handlers (abort and cleanup)
These tests verify that abort and cleanup handlers work correctly through
the full task execution path using real @task decorated functions executed
via the Celery executor (synchronously via .apply()).
"""
from __future__ import annotations
import uuid
from typing import Any
from superset_core.api.tasks import TaskScope, TaskStatus
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.daos.tasks import TaskDAO
from superset.extensions import db
from superset.models.tasks import Task
from superset.tasks.ambient_context import get_context
from superset.tasks.context import TaskContext
from superset.tasks.registry import TaskRegistry
from superset.tasks.scheduler import execute_task
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME
# Module-level state to track handler calls across test executions.
# (The @task-decorated functions are defined at module level, so they cannot
# close over per-test state; tests reset this via _reset_handler_state().)
_handler_state: dict[str, Any] = {}
def _reset_handler_state() -> None:
    """Re-initialise the shared handler-tracking state before each test."""
    global _handler_state
    _handler_state = dict(
        cleanup_called=False,
        abort_called=False,
        cleanup_order=[],
        abort_order=[],
        cleanup_data={},
    )
def cleanup_test_task() -> None:
    """Task that registers a single cleanup handler, then reports progress."""
    ctx = get_context()

    def handle_cleanup() -> None:
        _handler_state["cleanup_called"] = True

    # Register via a direct call (equivalent to decorator syntax).
    ctx.on_cleanup(handle_cleanup)
    # Simulate some work.
    ctx.update_task(progress=1.0)
def abort_test_task() -> None:
    """Task that registers an abort handler."""
    context = get_context()

    def handle_abort() -> None:
        _handler_state["abort_called"] = True

    # Register via a direct call (equivalent to decorator syntax).
    context.on_abort(handle_abort)
def both_handlers_task() -> None:
    """Task that registers both an abort and a cleanup handler."""
    ctx = get_context()

    def handle_abort() -> None:
        _handler_state["abort_called"] = True
        _handler_state["abort_order"].append("abort")

    def handle_cleanup() -> None:
        _handler_state["cleanup_called"] = True
        _handler_state["cleanup_order"].append("cleanup")

    # Same registration order as the decorator form: abort, then cleanup.
    ctx.on_abort(handle_abort)
    ctx.on_cleanup(handle_cleanup)
def multiple_cleanup_handlers_task() -> None:
    """Task that registers three cleanup handlers (used to verify LIFO order)."""
    ctx = get_context()
    # Registered in first/second/third order; the executor is expected to
    # invoke them in reverse. The default argument pins each label at
    # definition time (avoids the late-binding closure pitfall).
    for label in ("first", "second", "third"):
        ctx.on_cleanup(
            lambda label=label: _handler_state["cleanup_order"].append(label)
        )
def cleanup_with_data_task() -> None:
    """Task that uses a cleanup handler to undo partial work."""
    ctx = get_context()
    # Simulate partial work recorded in module-level state.
    _handler_state["cleanup_data"]["temp_key"] = "temp_value"

    def handle_cleanup() -> None:
        # Discard the partial work and record that cleanup ran.
        _handler_state["cleanup_data"].clear()
        _handler_state["cleanup_called"] = True

    ctx.on_cleanup(handle_cleanup)
def _register_test_tasks() -> None:
    """Idempotently register this module's test task functions.

    Called in setUp() so the tasks exist regardless of whether other
    tests have cleared the registry.
    """
    registrations = {
        "test_cleanup_task": cleanup_test_task,
        "test_abort_task": abort_test_task,
        "test_both_handlers_task": both_handlers_task,
        "test_multiple_cleanup_task": multiple_cleanup_handlers_task,
        "test_cleanup_with_data": cleanup_with_data_task,
    }
    for name, func in registrations.items():
        if not TaskRegistry.is_registered(name):
            TaskRegistry.register(name, func)
class TestCleanupHandlers(SupersetTestCase):
    """E2E tests for on_cleanup functionality using the Celery executor."""

    def setUp(self):
        """Log in, register the test tasks, and reset shared handler state."""
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
        _reset_handler_state()

    @staticmethod
    def _execute(task_type: str, task_name: str):
        """Create a task row and run it synchronously via the Celery executor."""
        task_obj = TaskDAO.create_task(
            task_type=task_type,
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name=task_name,
            scope=TaskScope.SYSTEM,
        )
        # Celery serializes args as JSON, so the UUID crosses as a string.
        return execute_task.apply(args=[str(task_obj.uuid), task_type, (), {}])

    def test_cleanup_handler_fires_on_success(self):
        """Cleanup handler runs when the task completes successfully."""
        result = self._execute("test_cleanup_task", "Test Cleanup")
        # Task itself must have succeeded ...
        assert result.successful()
        assert result.result["status"] == TaskStatus.SUCCESS.value
        # ... and the cleanup handler must have fired.
        assert _handler_state["cleanup_called"]

    def test_multiple_cleanup_handlers_in_lifo_order(self):
        """Multiple cleanup handlers execute in LIFO order."""
        result = self._execute(
            "test_multiple_cleanup_task", "Test Multiple Cleanup"
        )
        assert result.successful()
        # Last registered handler runs first.
        assert _handler_state["cleanup_order"] == ["third", "second", "first"]

    def test_cleanup_handler_cleans_up_partial_work(self):
        """Cleanup handler can undo partial work done by the task body."""
        result = self._execute("test_cleanup_with_data", "Test Cleanup Data")
        assert result.successful()
        assert _handler_state["cleanup_called"]
        # The handler should have cleared the partially-written data.
        assert len(_handler_state["cleanup_data"]) == 0
class TestAbortHandlers(SupersetTestCase):
    """E2E tests for on_abort functionality."""

    def setUp(self):
        """Log in, register the test tasks, and reset shared handler state."""
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
        _reset_handler_state()

    @staticmethod
    def _create_task(task_type: str, task_name: str):
        """Create a task row with a unique key."""
        return TaskDAO.create_task(
            task_type=task_type,
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name=task_name,
            scope=TaskScope.SYSTEM,
        )

    @staticmethod
    def _persist_status(task_obj, status: str) -> None:
        """Write a status change for the task to the database."""
        task_obj.status = status
        db.session.merge(task_obj)
        db.session.commit()

    def test_abort_handler_fires_when_task_aborting(self):
        """Abort handler runs when the task is ABORTING during cleanup."""
        task_obj = self._create_task("test_abort_task", "Test Abort")
        # Simulate a running, abortable task.
        task_obj.status = TaskStatus.IN_PROGRESS.value
        task_obj.update_properties({"is_abortable": True})
        db.session.merge(task_obj)
        db.session.commit()
        # Refresh to pick up the persisted state.
        db.session.refresh(task_obj)
        # Build the context the executor would normally create.
        ctx = TaskContext(task_obj)

        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True

        # Flip to ABORTING, as CancelTaskCommand would.
        self._persist_status(task_obj, TaskStatus.ABORTING.value)
        # Simulate the executor's finally block.
        ctx._run_cleanup()
        assert _handler_state["abort_called"]

    def test_both_handlers_fire_on_abort(self):
        """Both abort and cleanup handlers run when the task is aborted."""
        task_obj = self._create_task(
            "test_both_handlers_task", "Test Both Handlers"
        )
        task_obj.status = TaskStatus.IN_PROGRESS.value
        task_obj.update_properties({"is_abortable": True})
        db.session.merge(task_obj)
        db.session.commit()
        # Refresh to pick up the persisted state.
        db.session.refresh(task_obj)
        ctx = TaskContext(task_obj)

        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True
            _handler_state["abort_order"].append("abort")

        @ctx.on_cleanup
        def handle_cleanup():
            _handler_state["cleanup_called"] = True
            _handler_state["cleanup_order"].append("cleanup")

        # Flip to ABORTING, then run the executor's cleanup path.
        self._persist_status(task_obj, TaskStatus.ABORTING.value)
        ctx._run_cleanup()
        # Both handlers must have fired.
        assert _handler_state["abort_called"]
        assert _handler_state["cleanup_called"]

    def test_abort_handler_not_called_on_success(self):
        """Abort handler does not run when the task succeeds."""
        task_obj = self._create_task(
            "test_abort_task", "Test No Abort on Success"
        )
        self._persist_status(task_obj, TaskStatus.SUCCESS.value)
        # Refresh to pick up the persisted state.
        db.session.refresh(task_obj)
        ctx = TaskContext(task_obj)

        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True

        @ctx.on_cleanup
        def handle_cleanup():
            _handler_state["cleanup_called"] = True

        ctx._run_cleanup()
        # Abort handler must NOT fire; cleanup handler still must.
        assert not _handler_state["abort_called"]
        assert _handler_state["cleanup_called"]
class TestTaskContextMethods(SupersetTestCase):
    """Tests for TaskContext public methods."""

    def setUp(self):
        """Log in as admin before each test."""
        super().setUp()
        self.login(ADMIN_USERNAME)

    def test_on_abort_marks_task_abortable(self):
        """Registering an on_abort handler marks the task as abortable."""
        task_obj = TaskDAO.create_task(
            task_type="test_abortable_flag",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Abortable",
            scope=TaskScope.SYSTEM,
        )
        # The flag must not be set before any handler is registered.
        assert task_obj.properties_dict.get("is_abortable") is not True
        ctx = TaskContext(task_obj)

        @ctx.on_abort
        def handle_abort():
            pass

        # Re-read the row from the database to confirm the flag persisted.
        db.session.expire_all()
        refreshed = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
        assert refreshed.properties_dict.get("is_abortable") is True
class TestAbortBeforeExecution(SupersetTestCase):
    """Tests for aborting tasks before they start executing."""

    def setUp(self):
        """Set up test fixtures.

        Resets the shared handler state here, matching the other test
        classes in this module, instead of ad hoc inside individual tests.
        """
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
        _reset_handler_state()

    def test_abort_pending_task(self):
        """Test that pending tasks can be aborted directly."""
        task_obj = TaskDAO.create_task(
            task_type="test_abort_before_start",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Before Start",
            scope=TaskScope.SYSTEM,
        )
        # Cancel immediately while the task is still PENDING.
        CancelTaskCommand(task_obj.uuid, force=True).run()
        # Re-read the row to observe the persisted status change.
        db.session.expire_all()
        task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
        assert task_obj.status == TaskStatus.ABORTED.value

    def test_executor_skips_aborted_task(self):
        """Test that executor skips tasks already aborted before execution."""
        task_obj = TaskDAO.create_task(
            task_type="test_cleanup_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Skip Aborted",
            scope=TaskScope.SYSTEM,
        )
        # Abort the task before it ever runs.
        task_obj.status = TaskStatus.ABORTED.value
        db.session.merge(task_obj)
        db.session.commit()
        # Try to execute - the executor should short-circuit and skip it.
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
        )
        assert result.successful()
        assert result.result["status"] == TaskStatus.ABORTED.value
        # Cleanup handler must NOT have run, since the task was skipped.
        assert not _handler_state["cleanup_called"]