feat: add global task framework (#36368)

This commit is contained in:
Ville Brofeldt
2026-02-09 10:45:56 -08:00
committed by GitHub
parent 6984e93171
commit 59dd2fa385
89 changed files with 15535 additions and 291 deletions

View File

@@ -0,0 +1,420 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterator
from uuid import UUID
import pytest
from sqlalchemy.orm.session import Session
from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
from superset.commands.tasks.exceptions import TaskNotAbortableError
from superset.models.tasks import Task
from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key
# Test constants
# Fixed identifiers so tests can assert against deterministic values.
TASK_UUID = UUID("e7765491-40c1-4f35-a4f5-06308e79310e")
TASK_ID = 42
TEST_TASK_TYPE = "test_type"
TEST_TASK_KEY = "test-key"
TEST_USER_ID = 1
def create_task(
    session: Session,
    *,
    task_id: int | None = None,
    task_uuid: UUID | None = None,
    task_key: str = TEST_TASK_KEY,
    task_type: str = TEST_TASK_TYPE,
    scope: TaskScope = TaskScope.PRIVATE,
    status: TaskStatus = TaskStatus.PENDING,
    user_id: int | None = TEST_USER_ID,
    properties: TaskProperties | None = None,
    use_finished_dedup_key: bool = False,
) -> Task:
    """Insert and flush a Task row, filling in sensible testing defaults."""
    # Finished tasks carry a UUID-based dedup key; active tasks are keyed
    # by the (scope, type, key, user) tuple.
    if use_finished_dedup_key:
        key = get_finished_dedup_key(task_uuid or TASK_UUID)
    else:
        key = get_active_dedup_key(
            scope=scope,
            task_type=task_type,
            task_key=task_key,
            user_id=user_id,
        )
    new_task = Task(
        task_type=task_type,
        task_key=task_key,
        scope=scope.value,
        status=status.value,
        dedup_key=key,
        user_id=user_id,
    )
    if task_id is not None:
        new_task.id = task_id
    if task_uuid:
        new_task.uuid = task_uuid
    if properties:
        new_task.update_properties(properties)
    session.add(new_task)
    session.flush()
    return new_task
@pytest.fixture
def session_with_task(session: Session) -> Iterator[Session]:
    """Yield a session whose bind has the Task and TaskSubscriber tables."""
    from superset.models.task_subscribers import TaskSubscriber

    bind = session.get_bind()
    Task.metadata.create_all(bind)
    TaskSubscriber.metadata.create_all(bind)
    yield session
    # Undo any test mutations so cases remain isolated from each other.
    session.rollback()
def test_find_by_task_key_active(session_with_task: Session) -> None:
    """An active task is resolvable through its task_key."""
    from superset.daos.tasks import TaskDAO

    create_task(session_with_task)
    found = TaskDAO.find_by_task_key(
        task_type=TEST_TASK_TYPE,
        task_key=TEST_TASK_KEY,
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
    )
    assert found is not None
    assert found.task_type == TEST_TASK_TYPE
    assert found.task_key == TEST_TASK_KEY
    # Freshly created tasks start out PENDING.
    assert found.status == TaskStatus.PENDING.value
def test_find_by_task_key_not_found(session_with_task: Session) -> None:
    """Looking up a key that was never created yields None."""
    from superset.daos.tasks import TaskDAO

    missing = TaskDAO.find_by_task_key(
        task_type=TEST_TASK_TYPE,
        task_key="nonexistent-key",
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
    )
    assert missing is None
def test_find_by_task_key_finished_not_found(session_with_task: Session) -> None:
    """A finished task is invisible to the active-task lookup.

    Finished tasks carry a UUID-based dedup_key, so the active
    dedup-key query cannot match them.
    """
    from superset.daos.tasks import TaskDAO

    create_task(
        session_with_task,
        task_key="finished-key",
        status=TaskStatus.SUCCESS,
        use_finished_dedup_key=True,
        task_uuid=TASK_UUID,
    )
    lookup = TaskDAO.find_by_task_key(
        task_type=TEST_TASK_TYPE,
        task_key="finished-key",
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
    )
    # The SUCCESS task must not surface via the active lookup.
    assert lookup is None
def test_create_task_success(session_with_task: Session) -> None:
    """TaskDAO.create_task returns a persisted PENDING Task."""
    from superset.daos.tasks import TaskDAO

    created = TaskDAO.create_task(
        task_type=TEST_TASK_TYPE,
        task_key=TEST_TASK_KEY,
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
    )
    assert created is not None
    assert isinstance(created, Task)
    assert created.task_key == TEST_TASK_KEY
    assert created.task_type == TEST_TASK_TYPE
    assert created.status == TaskStatus.PENDING.value
def test_create_task_with_user_id(session_with_task: Session) -> None:
    """Creating a task with a user_id auto-subscribes that user."""
    from superset.daos.tasks import TaskDAO

    created = TaskDAO.create_task(
        task_type=TEST_TASK_TYPE,
        task_key="user-task",
        scope=TaskScope.PRIVATE,
        user_id=42,
    )
    assert created is not None
    assert created.user_id == 42
    # The creating user should be the sole automatic subscriber.
    assert [sub.user_id for sub in created.subscribers] == [42]
def test_create_task_with_properties(session_with_task: Session) -> None:
    """Properties supplied at creation are persisted on the task."""
    from superset.daos.tasks import TaskDAO

    created = TaskDAO.create_task(
        task_type=TEST_TASK_TYPE,
        task_key="props-task",
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
        properties={"timeout": 300},
    )
    assert created is not None
    assert created.properties_dict.get("timeout") == 300
def test_abort_task_pending_success(session_with_task: Session) -> None:
    """Aborting a PENDING task moves it directly to ABORTED."""
    from superset.daos.tasks import TaskDAO

    pending = create_task(
        session_with_task,
        task_key="pending-task",
        status=TaskStatus.PENDING,
    )
    aborted = TaskDAO.abort_task(pending.uuid, skip_base_filter=True)
    assert aborted is not None
    # A task that never started needs no handler; it is aborted outright.
    assert aborted.status == TaskStatus.ABORTED.value
def test_abort_task_in_progress_abortable(session_with_task: Session) -> None:
    """An in-progress task with an abort handler transitions to ABORTING."""
    from superset.daos.tasks import TaskDAO

    running = create_task(
        session_with_task,
        task_key="abortable-task",
        status=TaskStatus.IN_PROGRESS,
        properties={"is_abortable": True},
    )
    outcome = TaskDAO.abort_task(running.uuid, skip_base_filter=True)
    assert outcome is not None
    # The worker is still executing, so the task only enters ABORTING,
    # not ABORTED.
    assert outcome.status == TaskStatus.ABORTING.value
def test_abort_task_in_progress_not_abortable(session_with_task: Session) -> None:
    """Aborting an in-progress task that lacks an abort handler raises."""
    from superset.daos.tasks import TaskDAO

    running = create_task(
        session_with_task,
        task_key="non-abortable-task",
        status=TaskStatus.IN_PROGRESS,
        properties={"is_abortable": False},
    )
    with pytest.raises(TaskNotAbortableError):
        TaskDAO.abort_task(running.uuid, skip_base_filter=True)
def test_abort_task_in_progress_is_abortable_none(session_with_task: Session) -> None:
    """An unset is_abortable property is treated as not abortable."""
    from superset.daos.tasks import TaskDAO

    running = create_task(
        session_with_task,
        task_key="no-abortable-prop-task",
        status=TaskStatus.IN_PROGRESS,
        # No properties at all, so the is_abortable key is absent.
    )
    with pytest.raises(TaskNotAbortableError):
        TaskDAO.abort_task(running.uuid, skip_base_filter=True)
def test_abort_task_already_aborting(session_with_task: Session) -> None:
    """Re-aborting an ABORTING task succeeds idempotently."""
    from superset.daos.tasks import TaskDAO

    aborting = create_task(
        session_with_task,
        task_key="aborting-task",
        status=TaskStatus.ABORTING,
    )
    outcome = TaskDAO.abort_task(aborting.uuid, skip_base_filter=True)
    # No error is raised and the status is unchanged.
    assert outcome is not None
    assert outcome.status == TaskStatus.ABORTING.value
def test_abort_task_not_found(session_with_task: Session) -> None:
    """Aborting an unknown UUID returns None rather than raising."""
    from superset.daos.tasks import TaskDAO

    missing_uuid = UUID("00000000-0000-0000-0000-000000000000")
    assert TaskDAO.abort_task(missing_uuid) is None
def test_abort_task_already_finished(session_with_task: Session) -> None:
    """A finished (SUCCESS) task can no longer be aborted."""
    from superset.daos.tasks import TaskDAO

    finished = create_task(
        session_with_task,
        task_key="finished-task",
        status=TaskStatus.SUCCESS,
        use_finished_dedup_key=True,
        task_uuid=TASK_UUID,
    )
    assert TaskDAO.abort_task(finished.uuid, skip_base_filter=True) is None
def test_add_subscriber(session_with_task: Session) -> None:
    """A user can be subscribed to a shared task."""
    from superset.daos.tasks import TaskDAO

    shared = create_task(
        session_with_task,
        task_key="shared-task",
        scope=TaskScope.SHARED,
        user_id=None,
    )
    assert TaskDAO.add_subscriber(shared.id, user_id=TEST_USER_ID) is True
    # Reload from the database to observe the new subscriber row.
    session_with_task.refresh(shared)
    assert [sub.user_id for sub in shared.subscribers] == [TEST_USER_ID]
def test_add_subscriber_idempotent(session_with_task: Session) -> None:
    """Subscribing the same user twice adds only one row."""
    from superset.daos.tasks import TaskDAO

    shared = create_task(
        session_with_task,
        task_key="shared-task-2",
        scope=TaskScope.SHARED,
        user_id=None,
    )
    first = TaskDAO.add_subscriber(shared.id, user_id=TEST_USER_ID)
    second = TaskDAO.add_subscriber(shared.id, user_id=TEST_USER_ID)
    assert first is True
    # The duplicate subscription is reported as a no-op.
    assert second is False
    session_with_task.refresh(shared)
    assert len(shared.subscribers) == 1
def test_remove_subscriber(session_with_task: Session) -> None:
    """Removing an existing subscriber empties the subscriber list."""
    from superset.daos.tasks import TaskDAO

    shared = create_task(
        session_with_task,
        task_key="shared-task-3",
        scope=TaskScope.SHARED,
        user_id=None,
    )
    TaskDAO.add_subscriber(shared.id, user_id=TEST_USER_ID)
    session_with_task.refresh(shared)
    assert len(shared.subscribers) == 1
    # Now undo the subscription.
    updated = TaskDAO.remove_subscriber(shared.id, user_id=TEST_USER_ID)
    assert updated is not None
    assert len(updated.subscribers) == 0
def test_remove_subscriber_not_subscribed(session_with_task: Session) -> None:
    """Removing a user who never subscribed yields None."""
    from superset.daos.tasks import TaskDAO

    shared = create_task(
        session_with_task,
        task_key="shared-task-4",
        scope=TaskScope.SHARED,
        user_id=None,
    )
    assert TaskDAO.remove_subscriber(shared.id, user_id=999) is None
def test_get_status(session_with_task: Session) -> None:
    """get_status resolves a task UUID to its status string."""
    from superset.daos.tasks import TaskDAO

    running = create_task(
        session_with_task,
        task_uuid=TASK_UUID,
        task_key="status-task",
        status=TaskStatus.IN_PROGRESS,
    )
    assert TaskDAO.get_status(running.uuid) == TaskStatus.IN_PROGRESS.value
def test_get_status_not_found(session_with_task: Session) -> None:
    """get_status returns None for an unknown UUID."""
    from superset.daos.tasks import TaskDAO

    unknown = UUID("00000000-0000-0000-0000-000000000000")
    assert TaskDAO.get_status(unknown) is None

View File

@@ -18,17 +18,21 @@
# pylint: disable=invalid-name
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from freezegun import freeze_time
from sqlalchemy.orm import Session, sessionmaker
# Force module loading before tests run so patches work correctly
import superset.commands.distributed_lock.acquire as acquire_module
import superset.commands.distributed_lock.release as release_module
from superset import db
from superset.distributed_lock import KeyValueDistributedLock
from superset.distributed_lock import DistributedLock
from superset.distributed_lock.types import LockValue
from superset.distributed_lock.utils import get_key
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.exceptions import AcquireDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec
LOCK_VALUE: LockValue = {"value": True}
@@ -56,9 +60,9 @@ def _get_other_session() -> Session:
return SessionMaker()
def test_distributed_lock_kv_happy_path() -> None:
    """
    Test successfully acquiring and returning the distributed lock via KV backend.

    Note, we're using another session for asserting the lock state in the Metastore
    to simulate what another worker will observe. Otherwise, there's the risk that
    the lock-holding transaction masks the observed state.
    """
    # NOTE(review): this span contained interleaved old/new diff lines; the
    # new DistributedLock version is reconstructed here. Confirm nesting
    # against the upstream commit.
    session = _get_other_session()
    # Ensure Redis is not configured so KV backend is used
    with (
        patch.object(acquire_module, "get_redis_client", return_value=None),
        patch.object(release_module, "get_redis_client", return_value=None),
    ):
        with freeze_time("2021-01-01"):
            assert _get_lock(MAIN_KEY, session) is None
            with DistributedLock("ns", a=1, b=2) as key:
                # While held, only this lock's key is present.
                assert key == MAIN_KEY
                assert _get_lock(key, session) == LOCK_VALUE
                assert _get_lock(OTHER_KEY, session) is None
                # A second acquisition of the same key must fail.
                with pytest.raises(AcquireDistributedLockFailedException):
                    with DistributedLock("ns", a=1, b=2):
                        pass
            # The lock is released on context exit.
            assert _get_lock(MAIN_KEY, session) is None
def test_distributed_lock_kv_expired() -> None:
    """
    Test expiration of the distributed lock via KV backend.

    Note, we're using another session for asserting the lock state in the Metastore
    to simulate what another worker will observe. Otherwise, there's the risk that
    the lock-holding transaction masks the observed state.
    """
    # NOTE(review): this span contained interleaved old/new diff lines; the
    # new DistributedLock version is reconstructed here. Confirm nesting
    # against the upstream commit.
    session = _get_other_session()
    # Ensure Redis is not configured so KV backend is used
    with (
        patch.object(acquire_module, "get_redis_client", return_value=None),
        patch.object(release_module, "get_redis_client", return_value=None),
    ):
        with freeze_time("2021-01-01"):
            assert _get_lock(MAIN_KEY, session) is None
            with DistributedLock("ns", a=1, b=2):
                assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
                # A year later the lock has expired even though the
                # context was never exited.
                with freeze_time("2022-01-01"):
                    assert _get_lock(MAIN_KEY, session) is None
            assert _get_lock(MAIN_KEY, session) is None
def test_distributed_lock_uses_redis_when_configured() -> None:
    """DistributedLock prefers the Redis backend when a client is available."""
    fake_redis = MagicMock()
    fake_redis.set.return_value = True  # SET NX succeeded -> lock acquired
    # Patch on the already-imported modules so lookups hit the mock.
    with (
        patch.object(acquire_module, "get_redis_client", return_value=fake_redis),
        patch.object(release_module, "get_redis_client", return_value=fake_redis),
    ):
        with DistributedLock("test_redis", key="value") as lock_key:
            assert lock_key is not None
            # Acquisition must be a single atomic SET with NX and EX.
            fake_redis.set.assert_called_once()
            set_kwargs = fake_redis.set.call_args.kwargs
            assert set_kwargs["nx"] is True
            assert "ex" in set_kwargs
        # Exiting the context releases the lock with a single DELETE.
        fake_redis.delete.assert_called_once()
def test_distributed_lock_redis_already_taken() -> None:
    """Acquisition fails when the Redis key is already held elsewhere."""
    fake_redis = MagicMock()
    fake_redis.set.return_value = None  # SET NX returned nil -> already held
    with patch.object(acquire_module, "get_redis_client", return_value=fake_redis):
        with pytest.raises(AcquireDistributedLockFailedException):
            with DistributedLock("test_redis", key="value"):
                pass
def test_distributed_lock_redis_connection_error() -> None:
    """A Redis error during acquisition surfaces as a lock failure (fail fast)."""
    import redis

    fake_redis = MagicMock()
    fake_redis.set.side_effect = redis.RedisError("Connection failed")
    with patch.object(acquire_module, "get_redis_client", return_value=fake_redis):
        with pytest.raises(AcquireDistributedLockFailedException):
            with DistributedLock("test_redis", key="value"):
                pass
def test_distributed_lock_custom_ttl() -> None:
    """An explicit ttl_seconds is forwarded as the Redis EX argument."""
    fake_redis = MagicMock()
    fake_redis.set.return_value = True
    with (
        patch.object(acquire_module, "get_redis_client", return_value=fake_redis),
        patch.object(release_module, "get_redis_client", return_value=fake_redis),
    ):
        with DistributedLock("test", ttl_seconds=60, key="value"):
            # The custom TTL must be passed straight through to SET.
            assert fake_redis.set.call_args.kwargs["ex"] == 60
def test_distributed_lock_default_ttl(app_context: None) -> None:
    """Without ttl_seconds, the configured default TTL is used for EX."""
    from superset.commands.distributed_lock.base import get_default_lock_ttl

    fake_redis = MagicMock()
    fake_redis.set.return_value = True
    with (
        patch.object(acquire_module, "get_redis_client", return_value=fake_redis),
        patch.object(release_module, "get_redis_client", return_value=fake_redis),
    ):
        with DistributedLock("test", key="value"):
            assert fake_redis.set.call_args.kwargs["ex"] == get_default_lock_ttl()
def test_distributed_lock_fallback_to_kv_when_redis_not_configured() -> None:
    """With no Redis client, DistributedLock transparently uses the KV store."""
    session = _get_other_session()
    expected_key = get_key("test_fallback", key="value")
    with (
        patch.object(acquire_module, "get_redis_client", return_value=None),
        patch.object(release_module, "get_redis_client", return_value=None),
    ):
        with freeze_time("2021-01-01"):
            with DistributedLock("test_fallback", key="value") as lock_key:
                assert lock_key == expected_key
                # While held, the lock value is visible in the KV store.
                assert _get_lock(expected_key, session) == LOCK_VALUE
            # Released on context exit.
            assert _get_lock(expected_key, session) is None

View File

@@ -0,0 +1,477 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for task decorators"""
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from superset_core.api.tasks import TaskOptions, TaskScope
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
from superset.tasks.decorators import task, TaskWrapper
from superset.tasks.registry import TaskRegistry
# Fixed UUID used for deterministic task-identity assertions in mocks.
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
class TestTaskDecoratorFeatureFlag:
    """Behavior of the @task decorator while the GTF feature flag is off."""

    def setup_method(self):
        """Start every test from an empty task registry."""
        TaskRegistry._tasks.clear()

    @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
    def test_decorator_succeeds_when_gtf_disabled(self, mock_feature_flag):
        """Decoration itself must not fail while GTF is disabled.

        This keeps module imports safe during app startup or Celery
        autodiscovery.
        """

        @task(name="test_gtf_disabled_decorator")
        def decorated() -> None:
            pass

        # Nothing raised at decoration time; the wrapper was built.
        assert isinstance(decorated, TaskWrapper)
        assert decorated.name == "test_gtf_disabled_decorator"

    @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
    def test_call_raises_error_when_gtf_disabled(self, mock_feature_flag):
        """Invoking a task while GTF is disabled raises the disabled error."""

        @task(name="test_gtf_disabled_call")
        def decorated() -> None:
            pass

        with pytest.raises(GlobalTaskFrameworkDisabledError):
            decorated()

    @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
    def test_schedule_raises_error_when_gtf_disabled(self, mock_feature_flag):
        """Scheduling a task while GTF is disabled raises the disabled error."""

        @task(name="test_gtf_disabled_schedule")
        def decorated() -> None:
            pass

        with pytest.raises(GlobalTaskFrameworkDisabledError):
            decorated.schedule()
class TestTaskDecorator:
    """Decoration-time behavior of the @task decorator."""

    def test_decorator_basic(self):
        """A named task wraps the function and defaults to PRIVATE scope."""

        @task(name="test_task")
        def sample(arg1: int, arg2: str) -> None:
            pass

        assert isinstance(sample, TaskWrapper)
        assert sample.name == "test_task"
        assert sample.scope == TaskScope.PRIVATE

    def test_decorator_without_parentheses(self):
        """Bare @task falls back to the wrapped function's own name."""

        @task
        def my_no_parens_task(arg1: int, arg2: str) -> None:
            pass

        assert isinstance(my_no_parens_task, TaskWrapper)
        # The function name is adopted as the task name.
        assert my_no_parens_task.name == "my_no_parens_task"
        assert my_no_parens_task.scope == TaskScope.PRIVATE

    def test_decorator_with_default_scope_private(self):
        """An explicit PRIVATE scope is honored."""

        @task(name="private_task", scope=TaskScope.PRIVATE)
        def sample_private(arg1: int) -> None:
            pass

        assert sample_private.scope == TaskScope.PRIVATE

    def test_decorator_with_default_scope_shared(self):
        """A SHARED scope is honored."""

        @task(name="shared_task", scope=TaskScope.SHARED)
        def sample_shared(arg1: int) -> None:
            pass

        assert sample_shared.scope == TaskScope.SHARED

    def test_decorator_with_default_scope_system(self):
        """A SYSTEM scope is honored."""

        @task(name="system_task", scope=TaskScope.SYSTEM)
        def sample_system() -> None:
            pass

        assert sample_system.scope == TaskScope.SYSTEM

    def test_decorator_forbids_ctx_parameter(self):
        """Functions declaring a 'ctx' parameter are rejected."""
        with pytest.raises(TypeError, match="must not define 'ctx'"):

            @task(name="bad_task")
            def bad_task(ctx, arg1: int) -> None:  # noqa: ARG001
                pass

    def test_decorator_forbids_options_parameter(self):
        """Functions declaring an 'options' parameter are rejected."""
        with pytest.raises(TypeError, match="must not define.*'options'"):

            @task(name="bad_task")
            def bad_task(options, arg1: int) -> None:  # noqa: ARG001
                pass
class TestTaskWrapperMergeOptions:
    """Tests for TaskWrapper._merge_options()."""

    def setup_method(self):
        """Start every test from an empty task registry."""
        TaskRegistry._tasks.clear()

    def test_merge_options_no_override(self):
        """With no override, the defaults pass through unchanged."""

        @task(name="test_merge_no_override_unique")
        def merge_task_1() -> None:
            pass

        # Seed defaults directly for the merge test.
        merge_task_1.default_options = TaskOptions(
            task_key="default_key",
            task_name="Default Name",
        )
        merged = merge_task_1._merge_options(None)
        assert merged.task_key == "default_key"
        assert merged.task_name == "Default Name"

    def test_merge_options_override_task_key(self):
        """A call-time task_key wins over the default."""

        @task(name="test_merge_override_key_unique")
        def merge_task_2() -> None:
            pass

        merge_task_2.default_options = TaskOptions(task_key="default_key")
        merged = merge_task_2._merge_options(TaskOptions(task_key="override_key"))
        assert merged.task_key == "override_key"

    def test_merge_options_override_task_name(self):
        """A call-time task_name wins over the default."""

        @task(name="test_merge_override_name_unique")
        def merge_task_3() -> None:
            pass

        merge_task_3.default_options = TaskOptions(task_name="Default Name")
        merged = merge_task_3._merge_options(TaskOptions(task_name="Override Name"))
        assert merged.task_name == "Override Name"

    def test_merge_options_override_all(self):
        """Call-time options can replace every default at once."""

        @task(name="test_merge_override_all_unique")
        def merge_task_4() -> None:
            pass

        merge_task_4.default_options = TaskOptions(
            task_key="default_key",
            task_name="Default Name",
        )
        merged = merge_task_4._merge_options(
            TaskOptions(task_key="override_key", task_name="Override Name")
        )
        assert merged.task_key == "override_key"
        assert merged.task_name == "Override Name"
class TestTaskWrapperSchedule:
    """Tests for TaskWrapper.schedule() with scope"""

    def setup_method(self):
        """Clear task registry before each test"""
        # Prevents duplicate task-name collisions between tests.
        TaskRegistry._tasks.clear()

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_uses_default_scope(self, mock_submit):
        """Test schedule() uses decorator's default scope"""
        mock_submit.return_value = MagicMock()

        @task(name="test_schedule_default_unique", scope=TaskScope.SHARED)
        def schedule_task_1(arg1: int) -> None:
            pass

        # Shared tasks require explicit task_key
        schedule_task_1.schedule(123, options=TaskOptions(task_key="test_key"))
        # Verify TaskManager.submit_task was called with correct scope
        mock_submit.assert_called_once()
        call_args = mock_submit.call_args
        # call_args[1] is the kwargs mapping of the recorded call.
        assert call_args[1]["scope"] == TaskScope.SHARED

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_uses_private_scope_by_default(self, mock_submit):
        """Test schedule() uses PRIVATE scope when no scope specified"""
        mock_submit.return_value = MagicMock()

        @task(name="test_schedule_override_unique")
        def schedule_task_2(arg1: int) -> None:
            pass

        schedule_task_2.schedule(123)
        # Verify PRIVATE scope was used (default)
        mock_submit.assert_called_once()
        call_args = mock_submit.call_args
        assert call_args[1]["scope"] == TaskScope.PRIVATE

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_with_custom_options(self, mock_submit):
        """Test schedule() with custom task options"""
        mock_submit.return_value = MagicMock()

        @task(name="test_schedule_custom_unique", scope=TaskScope.SYSTEM)
        def schedule_task_3(arg1: int) -> None:
            pass

        # Use custom task key and name
        schedule_task_3.schedule(
            123,
            options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
        )
        # Verify scope from decorator and options from call time
        mock_submit.assert_called_once()
        call_args = mock_submit.call_args
        assert call_args[1]["scope"] == TaskScope.SYSTEM
        assert call_args[1]["task_key"] == "custom_key"
        assert call_args[1]["task_name"] == "Custom Task Name"

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_with_no_decorator_options(self, mock_submit):
        """Test schedule() uses default PRIVATE scope when no options provided"""
        mock_submit.return_value = MagicMock()

        @task(name="test_schedule_no_options_unique")
        def schedule_task_4(arg1: int) -> None:
            pass

        schedule_task_4.schedule(123)
        # Verify default PRIVATE scope
        mock_submit.assert_called_once()
        call_args = mock_submit.call_args
        assert call_args[1]["scope"] == TaskScope.PRIVATE

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_shared_task_requires_task_key(self, mock_submit):
        """Test shared task schedule() requires explicit task_key"""

        @task(name="test_shared_requires_key", scope=TaskScope.SHARED)
        def shared_task(arg1: int) -> None:
            pass

        # Should raise ValueError when no task_key provided
        with pytest.raises(
            ValueError,
            match="Shared task.*requires an explicit task_key.*for deduplication",
        ):
            shared_task.schedule(123)
        # Should work with task_key provided
        mock_submit.return_value = MagicMock()
        shared_task.schedule(123, options=TaskOptions(task_key="valid_key"))
        mock_submit.assert_called_once()

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_private_task_allows_no_task_key(self, mock_submit):
        """Test private task schedule() works without task_key"""
        mock_submit.return_value = MagicMock()

        @task(name="test_private_no_key", scope=TaskScope.PRIVATE)
        def private_task(arg1: int) -> None:
            pass

        # Should work without task_key (generates random UUID)
        private_task.schedule(123)
        mock_submit.assert_called_once()
class TestTaskWrapperCall:
    """Tests for TaskWrapper.__call__() with scope"""

    # NOTE: stacked @patch decorators inject mocks bottom-up — the innermost
    # decorator maps to the first mock parameter of each test method.

    def setup_method(self):
        """Clear task registry before each test"""
        # Prevents duplicate task-name collisions between tests.
        TaskRegistry._tasks.clear()

    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_uses_default_scope(
        self, mock_submit_run_with_info, mock_find, mock_update_run
    ):
        """Test direct call uses decorator's default scope"""
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task  # Mock the subsequent find call

        @task(name="test_call_default_unique", scope=TaskScope.SHARED)
        def call_task_1(arg1: int) -> None:
            pass

        # Shared tasks require explicit task_key
        call_task_1(123, options=TaskOptions(task_key="test_key"))
        # Verify SubmitTaskCommand.run_with_info was called
        mock_submit_run_with_info.assert_called_once()

    @patch("superset.utils.core.get_user_id")
    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_uses_private_scope_by_default(
        self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
    ):
        """Test direct call uses PRIVATE scope when no scope specified"""
        # Private tasks need a current user; stub one in.
        mock_get_user_id.return_value = 1
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task  # Mock the subsequent find call

        @task(name="test_call_private_default_unique")
        def call_task_2(arg1: int) -> None:
            pass

        call_task_2(123)
        # Verify SubmitTaskCommand.run_with_info was called
        mock_submit_run_with_info.assert_called_once()

    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_with_custom_options(
        self, mock_submit_run_with_info, mock_find, mock_update_run
    ):
        """Test direct call with custom task options"""
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task  # Mock the subsequent find call

        @task(name="test_call_custom_unique", scope=TaskScope.SYSTEM)
        def call_task_3(arg1: int) -> None:
            pass

        # Use custom task key and name
        call_task_3(
            123,
            options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
        )
        # Verify SubmitTaskCommand.run_with_info was called
        mock_submit_run_with_info.assert_called_once()

    def test_call_shared_task_requires_task_key(self):
        """Test shared task direct call requires explicit task_key"""

        @task(name="test_shared_call_requires_key", scope=TaskScope.SHARED)
        def shared_task(arg1: int) -> None:
            pass

        # Should raise ValueError when no task_key provided
        with pytest.raises(
            ValueError,
            match="Shared task.*requires an explicit task_key.*for deduplication",
        ):
            shared_task(123)

    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_shared_task_works_with_task_key(
        self, mock_submit_run_with_info, mock_find, mock_update_run
    ):
        """Test shared task direct call works with task_key"""
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task

        @task(name="test_shared_call_with_key", scope=TaskScope.SHARED)
        def shared_task(arg1: int) -> None:
            pass

        # Should work with task_key provided
        shared_task(123, options=TaskOptions(task_key="valid_key"))
        mock_submit_run_with_info.assert_called_once()

    @patch("superset.utils.core.get_user_id")
    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_private_task_allows_no_task_key(
        self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
    ):
        """Test private task direct call works without task_key"""
        mock_get_user_id.return_value = 1
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task

        @task(name="test_private_call_no_key", scope=TaskScope.PRIVATE)
        def private_task(arg1: int) -> None:
            pass

        # Should work without task_key (generates random UUID)
        private_task(123)
        mock_submit_run_with_info.assert_called_once()

View File

@@ -0,0 +1,677 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for GTF handlers (abort, cleanup) and related Task model behavior."""
import time
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock, patch
from uuid import UUID
import pytest
from freezegun import freeze_time
from superset_core.api.tasks import TaskStatus
from superset.tasks.context import TaskContext
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
@pytest.fixture
def mock_task():
    """Create a mock task for testing."""
    fake = MagicMock()
    fake.uuid = TEST_UUID
    fake.status = TaskStatus.PENDING.value
    return fake
@pytest.fixture
def mock_task_dao(mock_task):
    """Mock TaskDAO to return our test task."""
    with patch("superset.daos.tasks.TaskDAO") as dao:
        # Any lookup performed under this fixture resolves to mock_task.
        dao.find_one_or_none.return_value = mock_task
        yield dao
@pytest.fixture
def mock_update_command():
    """Mock UpdateTaskCommand to avoid database operations."""
    with patch("superset.commands.tasks.update.UpdateTaskCommand") as cmd:
        # run() is a no-op so status updates never hit the DB.
        cmd.return_value.run.return_value = None
        yield cmd
@pytest.fixture
def mock_flask_app():
    """Create a properly configured mock Flask app."""
    app = MagicMock()
    app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
    # app_context() must behave like a real context manager.
    app_ctx = app.app_context.return_value
    app_ctx.__enter__ = MagicMock(return_value=None)
    app_ctx.__exit__ = MagicMock(return_value=None)
    # Use regular Mock (not MagicMock) for _get_current_object to avoid
    # AsyncMockMixin creating unawaited coroutines in Python 3.10+
    app._get_current_object = Mock(return_value=app)
    return app
@pytest.fixture
def task_context(mock_task, mock_task_dao, mock_update_command, mock_flask_app):
    """Create TaskContext with mocked dependencies.

    Yields a TaskContext bound to the mock task with Redis pub/sub disabled
    (signal_cache is None), so abort detection falls back to polling. Any
    polling thread the test started is stopped on teardown.
    """
    # Ensure mock_task has properties_dict and payload_dict (TaskContext accesses them)
    mock_task.properties_dict = {"is_abortable": False}
    mock_task.payload_dict = {}
    with (
        patch("superset.tasks.context.current_app") as mock_current_app,
        patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
    ):
        # Disable Redis by making signal_cache return None
        mock_cache_manager.signal_cache = None
        # Configure current_app mock to expose the same config dict as the app
        mock_current_app.config = mock_flask_app.config
        # Use regular Mock (not MagicMock) for _get_current_object to avoid
        # AsyncMockMixin creating unawaited coroutines in Python 3.10+
        mock_current_app._get_current_object = Mock(return_value=mock_flask_app)
        ctx = TaskContext(mock_task)
        yield ctx
        # Cleanup: stop polling if started, so no thread outlives the test
        if ctx._abort_listener:
            ctx.stop_abort_polling()
class TestTaskStatusEnum:
    """Test TaskStatus enum values."""

    def test_aborting_status_exists(self):
        """Test that ABORTING status is defined."""
        assert hasattr(TaskStatus, "ABORTING")
        assert TaskStatus.ABORTING.value == "aborting"

    def test_all_statuses_present(self):
        """Test all expected statuses are present."""
        defined = [member.value for member in TaskStatus]
        # Every lifecycle state the framework relies on must exist.
        for status in (
            "pending",
            "in_progress",
            "success",
            "failure",
            "aborting",
            "aborted",
        ):
            assert status in defined, f"Missing status: {status}"
class TestTaskAbortProperties:
    """Test Task model abort-related properties via status and properties accessor."""

    def test_aborting_status(self):
        """Test ABORTING status check."""
        from superset.models.tasks import Task

        instance = Task()
        instance.status = TaskStatus.ABORTING.value
        assert instance.status == TaskStatus.ABORTING.value

    def test_is_abortable_in_properties(self):
        """Test is_abortable is accessible via properties."""
        from superset.models.tasks import Task

        instance = Task()
        instance.update_properties({"is_abortable": True})
        assert instance.properties_dict.get("is_abortable") is True

    def test_is_abortable_default_none(self):
        """Test is_abortable defaults to None for new tasks."""
        from superset.models.tasks import Task

        # A freshly constructed task has no is_abortable property at all.
        assert Task().properties_dict.get("is_abortable") is None
class TestTaskSetStatus:
    """Test Task.set_status behavior for abort states."""

    def test_set_status_in_progress_sets_is_abortable_false(self):
        """Test that transitioning to IN_PROGRESS sets is_abortable to False."""
        from superset.models.tasks import Task

        task = Task()
        task.uuid = "test-uuid"
        # Default is None; first IN_PROGRESS transition pins it to False
        # and records the start timestamp.
        task.set_status(TaskStatus.IN_PROGRESS)
        assert task.properties_dict.get("is_abortable") is False
        assert task.started_at is not None

    def test_set_status_in_progress_preserves_existing_is_abortable(self):
        """Test that re-setting IN_PROGRESS doesn't override is_abortable."""
        from superset.models.tasks import Task

        task = Task()
        task.uuid = "test-uuid"
        task.update_properties(
            {"is_abortable": True}
        )  # Already set by handler registration
        task.started_at = datetime.now(timezone.utc)  # Already started
        task.set_status(TaskStatus.IN_PROGRESS)
        # Should not override since started_at is already set
        assert task.properties_dict.get("is_abortable") is True

    def test_set_status_aborting_does_not_set_ended_at(self):
        """Test that ABORTING status does not set ended_at."""
        from superset.models.tasks import Task

        task = Task()
        task.uuid = "test-uuid"
        task.started_at = datetime.now(timezone.utc)
        # NOTE(review): this assigns the status attribute directly rather than
        # calling task.set_status(TaskStatus.ABORTING), so it verifies only the
        # default (ended_at unset), not set_status's behavior — confirm intent.
        task.status = TaskStatus.ABORTING.value
        assert task.ended_at is None

    def test_set_status_aborted_sets_ended_at(self):
        """Test that ABORTED status sets ended_at."""
        from superset.models.tasks import Task

        task = Task()
        task.uuid = "test-uuid"
        task.started_at = datetime.now(timezone.utc)
        # ABORTED is terminal, so the end timestamp must be recorded.
        task.set_status(TaskStatus.ABORTED)
        assert task.ended_at is not None
class TestTaskDuration:
    """Test Task duration_seconds property with different states."""

    def test_duration_seconds_finished_task(self):
        """Test duration for finished task returns actual duration."""
        from superset.models.tasks import Task

        finished = Task()
        finished.status = TaskStatus.SUCCESS.value  # must be terminal to use ended_at
        finished.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        finished.ended_at = datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
        # ended_at - started_at == 30 seconds
        assert finished.duration_seconds == 30.0

    @freeze_time("2024-01-01 10:00:30")
    def test_duration_seconds_running_task(self):
        """Test duration for running task returns time since start."""
        from superset.models.tasks import Task

        running = Task()
        running.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        running.ended_at = None
        # Clock is frozen 30 seconds after started_at.
        assert running.duration_seconds == 30.0

    @freeze_time("2024-01-01 10:00:15")
    def test_duration_seconds_pending_task(self):
        """Test duration for pending task returns queue time."""
        from superset.models.tasks import Task

        queued = Task()
        queued.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        queued.started_at = None
        queued.ended_at = None
        # Clock is frozen 15 seconds after creation.
        assert queued.duration_seconds == 15.0

    def test_duration_seconds_no_timestamps(self):
        """Test duration returns None when no timestamps available."""
        from superset.models.tasks import Task

        blank = Task()
        blank.created_on = None
        blank.started_at = None
        blank.ended_at = None
        assert blank.duration_seconds is None
class TestAbortHandlerRegistration:
    """Test abort handler registration and is_abortable flag."""

    def test_on_abort_registers_handler(self, task_context):
        """Test that on_abort registers a handler."""
        handler_called = False

        @task_context.on_abort
        def handle_abort():
            nonlocal handler_called
            handler_called = True

        # Registration stores the handler but must not invoke it.
        assert len(task_context._abort_handlers) == 1
        assert not handler_called

    @patch("superset.tasks.context.current_app")
    def test_on_abort_sets_abortable(self, mock_app):
        """Test on_abort sets is_abortable to True on first handler."""
        mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
        mock_app._get_current_object = Mock(return_value=mock_app)
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.properties_dict = {"is_abortable": False}
        mock_task.payload_dict = {}
        # Patch _set_abortable/start_abort_polling so no DB write or
        # background thread happens during this test.
        with (
            patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
            patch.object(TaskContext, "start_abort_polling"),
        ):
            ctx = TaskContext(mock_task)

            @ctx.on_abort
            def handler():
                pass

            mock_set_abortable.assert_called_once()

    @patch("superset.tasks.context.current_app")
    def test_on_abort_only_sets_abortable_once(self, mock_app):
        """Test on_abort only calls _set_abortable for first handler."""
        mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
        mock_app._get_current_object = Mock(return_value=mock_app)
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.properties_dict = {"is_abortable": False}
        mock_task.payload_dict = {}
        with (
            patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
            patch.object(TaskContext, "start_abort_polling"),
        ):
            ctx = TaskContext(mock_task)

            @ctx.on_abort
            def handler1():
                pass

            @ctx.on_abort
            def handler2():
                pass

            # Should only be called once for first handler
            assert mock_set_abortable.call_count == 1

    def test_abort_handlers_completed_initially_false(self):
        """Test abort_handlers_completed is False initially."""
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.properties_dict = {}
        mock_task.payload_dict = {}
        with patch("superset.tasks.context.current_app") as mock_app:
            mock_app._get_current_object = Mock(return_value=mock_app)
            ctx = TaskContext(mock_task)
            # No handlers have run yet on a fresh context.
            assert ctx.abort_handlers_completed is False
class TestAbortPolling:
    """Test abort detection polling behavior.

    NOTE: these tests use the task_context fixture's real background
    polling (0.1s interval) and real sleeps, so they are timing-sensitive.
    """

    def test_on_abort_starts_polling_automatically(self, task_context):
        """Test that registering first handler starts abort listener."""
        assert task_context._abort_listener is None

        @task_context.on_abort
        def handle_abort():
            pass

        # First handler registration spins up the listener.
        assert task_context._abort_listener is not None

    def test_stop_abort_polling(self, task_context):
        """Test that stop_abort_polling stops the abort listener."""

        @task_context.on_abort
        def handle_abort():
            pass

        assert task_context._abort_listener is not None
        task_context.stop_abort_polling()
        assert task_context._abort_listener is None

    def test_start_abort_polling_only_once(self, task_context):
        """Test that start_abort_polling is idempotent."""
        task_context.start_abort_polling(interval=0.1)
        first_listener = task_context._abort_listener
        # Try to start again
        task_context.start_abort_polling(interval=0.1)
        second_listener = task_context._abort_listener
        # Should be the same listener
        assert first_listener is second_listener

    def test_on_abort_with_custom_interval(self, task_context):
        """Test that custom interval can be set via start_abort_polling."""
        with patch("superset.tasks.context.current_app") as mock_app:
            mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
            mock_app._get_current_object = Mock(return_value=mock_app)

            @task_context.on_abort
            def handle_abort():
                pass

            # Override with custom interval (restart required)
            task_context.stop_abort_polling()
            task_context.start_abort_polling(interval=0.05)
            assert task_context._abort_listener is not None

    def test_polling_stops_after_abort_detected(self, task_context, mock_task):
        """Test that abort is detected and handlers are triggered."""

        @task_context.on_abort
        def handle_abort():
            pass

        # Trigger abort by flipping the (mocked) task status
        mock_task.status = TaskStatus.ABORTED.value
        # Wait for detection (~3 polls at the fixture's 0.1s interval)
        time.sleep(0.3)
        # Abort should have been detected
        assert task_context._abort_detected is True
class TestAbortHandlerExecution:
    """Test abort handler execution behavior.

    NOTE: these tests use the task_context fixture's real background
    polling (0.1s interval) and real sleeps, so they are timing-sensitive.
    """

    def test_on_abort_handler_fires_when_task_aborted(self, task_context, mock_task):
        """Test that abort handler fires automatically when task is aborted."""
        abort_called = False

        @task_context.on_abort
        def handle_abort():
            nonlocal abort_called
            abort_called = True

        # Simulate task being aborted
        mock_task.status = TaskStatus.ABORTED.value
        # Wait for polling to detect abort (max 0.3s with 0.1s interval)
        time.sleep(0.3)
        assert abort_called
        assert task_context._abort_detected

    def test_on_abort_not_called_on_success(self, task_context, mock_task):
        """Test that abort handlers don't run on success."""
        abort_called = False

        @task_context.on_abort
        def handle_abort():
            nonlocal abort_called
            abort_called = True

        # Keep task in success state
        mock_task.status = TaskStatus.SUCCESS.value
        # Wait and verify handler not called
        time.sleep(0.3)
        assert not abort_called

    def test_multiple_abort_handlers(self, task_context, mock_task):
        """Test that all abort handlers execute in LIFO order."""
        calls = []

        @task_context.on_abort
        def handler1():
            calls.append(1)

        @task_context.on_abort
        def handler2():
            calls.append(2)

        # Trigger abort
        mock_task.status = TaskStatus.ABORTED.value
        # Wait for detection
        time.sleep(0.3)
        # LIFO order: handler2 runs first
        assert calls == [2, 1]

    def test_abort_handler_exception_doesnt_fail_task(self, task_context, mock_task):
        """Test that exception in abort handler is logged but doesn't fail task."""
        handler2_called = False

        @task_context.on_abort
        def bad_handler():
            raise ValueError("Handler error")

        @task_context.on_abort
        def good_handler():
            nonlocal handler2_called
            handler2_called = True

        # Trigger abort
        mock_task.status = TaskStatus.ABORTED.value
        # Wait for detection
        time.sleep(0.3)
        # Second handler should still run despite first handler failing
        assert handler2_called
class TestBestEffortHandlerExecution:
    """Test that all handlers execute even when some fail (best-effort)."""

    def test_all_abort_handlers_run_even_if_all_fail(self, task_context, mock_task):
        """Test all abort handlers execute even if every one raises an exception."""
        calls = []

        @task_context.on_abort
        def handler1():
            calls.append(1)
            raise ValueError("Handler 1 failed")

        @task_context.on_abort
        def handler2():
            calls.append(2)
            raise RuntimeError("Handler 2 failed")

        @task_context.on_abort
        def handler3():
            calls.append(3)
            raise TypeError("Handler 3 failed")

        # Trigger abort handlers directly (simulating abort detection),
        # bypassing the polling thread for determinism.
        task_context._trigger_abort_handlers()
        # All handlers should have been called (LIFO order: 3, 2, 1)
        assert calls == [3, 2, 1]
        # Failures should be collected (abort handlers don't write to DB);
        # each entry is a (handler_type, exception, ...) tuple.
        assert len(task_context._handler_failures) == 3
        failure_types = [
            type(ex).__name__ for _, ex, _ in task_context._handler_failures
        ]
        assert "TypeError" in failure_types
        assert "RuntimeError" in failure_types
        assert "ValueError" in failure_types

    def test_all_cleanup_handlers_run_even_if_all_fail(self, task_context, mock_task):
        """Test all cleanup handlers execute even if every one raises an exception."""
        calls = []
        captured_failures = []
        # Wrap _write_handler_failures_to_db to capture failures before the
        # real implementation clears them.
        original_write = task_context._write_handler_failures_to_db

        def mock_write():
            captured_failures.extend(task_context._handler_failures)
            original_write()

        task_context._write_handler_failures_to_db = mock_write

        @task_context.on_cleanup
        def cleanup1():
            calls.append(1)
            raise ValueError("Cleanup 1 failed")

        @task_context.on_cleanup
        def cleanup2():
            calls.append(2)
            raise RuntimeError("Cleanup 2 failed")

        @task_context.on_cleanup
        def cleanup3():
            calls.append(3)
            raise TypeError("Cleanup 3 failed")

        # Set task to SUCCESS (not aborting) so only cleanup handlers run
        mock_task.status = TaskStatus.SUCCESS.value
        # Run cleanup
        task_context._run_cleanup()
        # All handlers should have been called (LIFO order: 3, 2, 1)
        assert calls == [3, 2, 1]
        # Failures should have been captured before clearing
        assert len(captured_failures) == 3
        failure_types = [type(ex).__name__ for _, ex, _ in captured_failures]
        assert "TypeError" in failure_types
        assert "RuntimeError" in failure_types
        assert "ValueError" in failure_types

    def test_mixed_abort_and_cleanup_failures_all_collected(
        self, task_context, mock_task
    ):
        """Test abort and cleanup handler failures are collected together."""
        calls = []
        captured_failures = []
        # Wrap _write_handler_failures_to_db to capture failures before the
        # real implementation clears them.
        original_write = task_context._write_handler_failures_to_db

        def mock_write():
            captured_failures.extend(task_context._handler_failures)
            original_write()

        task_context._write_handler_failures_to_db = mock_write

        @task_context.on_abort
        def abort1():
            calls.append("abort1")
            raise ValueError("Abort 1 failed")

        @task_context.on_abort
        def abort2():
            calls.append("abort2")
            raise RuntimeError("Abort 2 failed")

        @task_context.on_cleanup
        def cleanup1():
            calls.append("cleanup1")
            raise TypeError("Cleanup 1 failed")

        @task_context.on_cleanup
        def cleanup2():
            calls.append("cleanup2")
            raise KeyError("Cleanup 2 failed")

        # Set task to ABORTING so both abort and cleanup handlers run
        mock_task.status = TaskStatus.ABORTING.value
        # Run cleanup (which triggers abort handlers first, then cleanup handlers)
        task_context._run_cleanup()
        # All handlers should have been called
        # Abort handlers run first (LIFO: abort2, abort1)
        # Then cleanup handlers (LIFO: cleanup2, cleanup1)
        assert calls == ["abort2", "abort1", "cleanup2", "cleanup1"]
        # All 4 failures should have been captured
        assert len(captured_failures) == 4
        # Verify handler types are recorded correctly
        handler_types = [htype for htype, _, _ in captured_failures]
        assert handler_types.count("abort") == 2
        assert handler_types.count("cleanup") == 2
class TestCleanupHandlers:
    """Test cleanup handler behavior.

    NOTE: the second test uses the fixture's real polling thread plus a
    real sleep, so it is timing-sensitive.
    """

    def test_cleanup_triggers_abort_handlers_if_not_detected(
        self, task_context, mock_task
    ):
        """Test that _run_cleanup triggers abort handlers if task ended aborted."""
        abort_called = False

        @task_context.on_abort
        def handle_abort():
            nonlocal abort_called
            abort_called = True

        # Set task as aborted but don't let polling detect it
        mock_task.status = TaskStatus.ABORTED.value
        task_context._abort_detected = False
        # Immediately run cleanup (simulating task ending before poll)
        task_context._run_cleanup()
        # Cleanup must notice the aborted status and fire abort handlers.
        assert abort_called

    def test_cleanup_doesnt_duplicate_abort_handlers(self, task_context, mock_task):
        """Test that abort handlers only run once even if called from cleanup."""
        call_count = 0

        @task_context.on_abort
        def handle_abort():
            nonlocal call_count
            call_count += 1

        # Trigger abort via polling
        mock_task.status = TaskStatus.ABORTED.value
        time.sleep(0.3)
        # Handlers should have been called once
        assert call_count == 1
        assert task_context._abort_detected is True
        # Run cleanup - handlers should NOT be called again
        task_context._run_cleanup()
        assert call_count == 1  # Still 1, not 2

View File

@@ -0,0 +1,462 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for TaskManager pub/sub functionality"""
import threading
import time
from unittest.mock import MagicMock, patch
import redis
from superset.tasks.manager import AbortListener, TaskManager
class TestAbortListener:
    """Tests for AbortListener class"""

    def _make_listener(self, alive=False, pubsub=None):
        """Build an AbortListener around a mock thread; return (listener, thread, event)."""
        stop_event = threading.Event()
        worker = MagicMock(spec=threading.Thread)
        worker.is_alive.return_value = alive
        if pubsub is None:
            listener = AbortListener("test-uuid", worker, stop_event)
        else:
            listener = AbortListener("test-uuid", worker, stop_event, pubsub)
        return listener, worker, stop_event

    def test_stop_sets_event(self):
        """Test that stop() sets the stop event"""
        listener, _worker, stop_event = self._make_listener()
        assert not stop_event.is_set()
        listener.stop()
        assert stop_event.is_set()

    def test_stop_closes_pubsub(self):
        """Test that stop() closes the pub/sub connection"""
        pubsub = MagicMock()
        listener, _worker, _stop_event = self._make_listener(pubsub=pubsub)
        listener.stop()
        pubsub.unsubscribe.assert_called_once()
        pubsub.close.assert_called_once()

    def test_stop_joins_thread(self):
        """Test that stop() joins the listener thread"""
        listener, worker, _stop_event = self._make_listener(alive=True)
        listener.stop()
        # A still-running thread must be joined with a bounded timeout.
        worker.join.assert_called_once_with(timeout=2.0)
class TestTaskManagerInitApp:
    """Tests for TaskManager.init_app()"""

    @staticmethod
    def _reset_manager_state():
        """Restore TaskManager class-level state to its defaults."""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    def setup_method(self):
        """Reset TaskManager state before each test"""
        self._reset_manager_state()

    def teardown_method(self):
        """Reset TaskManager state after each test"""
        self._reset_manager_state()

    def test_init_app_sets_channel_prefixes(self):
        """Test init_app reads channel prefixes from config"""
        app = MagicMock()
        overrides = {
            "TASKS_ABORT_CHANNEL_PREFIX": "custom:abort:",
            "TASKS_COMPLETION_CHANNEL_PREFIX": "custom:complete:",
        }
        app.config.get.side_effect = lambda key, default=None: overrides.get(
            key, default
        )
        TaskManager.init_app(app)
        assert TaskManager._initialized is True
        assert TaskManager._channel_prefix == "custom:abort:"
        assert TaskManager._completion_channel_prefix == "custom:complete:"

    def test_init_app_skips_if_already_initialized(self):
        """Test init_app is idempotent"""
        TaskManager._initialized = True
        app = MagicMock()
        TaskManager.init_app(app)
        # Config must not be read when initialization already happened.
        app.config.get.assert_not_called()
class TestTaskManagerPubSub:
    """Tests for TaskManager pub/sub methods"""

    @staticmethod
    def _reset_manager_state():
        """Restore TaskManager class-level state to its defaults."""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    def setup_method(self):
        """Reset TaskManager state before each test"""
        self._reset_manager_state()

    def teardown_method(self):
        """Reset TaskManager state after each test"""
        self._reset_manager_state()

    @patch("superset.tasks.manager.cache_manager")
    def test_is_pubsub_available_no_redis(self, mock_cache_manager):
        """Test is_pubsub_available returns False when Redis not configured"""
        mock_cache_manager.signal_cache = None
        assert TaskManager.is_pubsub_available() is False

    @patch("superset.tasks.manager.cache_manager")
    def test_is_pubsub_available_with_redis(self, mock_cache_manager):
        """Test is_pubsub_available returns True when Redis is configured"""
        mock_cache_manager.signal_cache = MagicMock()
        assert TaskManager.is_pubsub_available() is True

    def test_get_abort_channel(self):
        """Test get_abort_channel returns correct channel name"""
        assert (
            TaskManager.get_abort_channel("abc-123-def-456")
            == "gtf:abort:abc-123-def-456"
        )

    def test_get_abort_channel_custom_prefix(self):
        """Test get_abort_channel with custom prefix"""
        TaskManager._channel_prefix = "custom:prefix:"
        assert TaskManager.get_abort_channel("test-uuid") == "custom:prefix:test-uuid"

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_abort_no_redis(self, mock_cache_manager):
        """Test publish_abort returns False when Redis not available"""
        mock_cache_manager.signal_cache = None
        assert TaskManager.publish_abort("test-uuid") is False

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_abort_success(self, mock_cache_manager):
        """Test publish_abort publishes message successfully"""
        fake_redis = MagicMock()
        fake_redis.publish.return_value = 1  # one subscriber received it
        mock_cache_manager.signal_cache = fake_redis
        assert TaskManager.publish_abort("test-uuid") is True
        fake_redis.publish.assert_called_once_with("gtf:abort:test-uuid", "abort")

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_abort_redis_error(self, mock_cache_manager):
        """Test publish_abort handles Redis errors gracefully"""
        fake_redis = MagicMock()
        fake_redis.publish.side_effect = redis.RedisError("Connection lost")
        mock_cache_manager.signal_cache = fake_redis
        # Errors are swallowed and surfaced as a False return value.
        assert TaskManager.publish_abort("test-uuid") is False
class TestTaskManagerListenForAbort:
    """Tests for TaskManager.listen_for_abort().

    NOTE: these tests start a real listener thread (with its worker body
    patched out) and use short sleeps, so they are mildly timing-sensitive.
    """

    def setup_method(self):
        """Reset TaskManager state before each test"""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    def teardown_method(self):
        """Reset TaskManager state after each test"""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    @patch("superset.tasks.manager.cache_manager")
    def test_listen_for_abort_no_redis_uses_polling(self, mock_cache_manager):
        """Test listen_for_abort falls back to polling when Redis unavailable"""
        mock_cache_manager.signal_cache = None
        callback = MagicMock()
        # Patch the polling loop so the thread exits immediately.
        with patch.object(TaskManager, "_poll_for_abort", return_value=None):
            listener = TaskManager.listen_for_abort(
                task_uuid="test-uuid",
                callback=callback,
                poll_interval=1.0,
                app=None,
            )
            # Give thread time to start
            time.sleep(0.1)
            listener.stop()
        # Should use polling since no Redis
        assert listener._pubsub is None

    @patch("superset.tasks.manager.cache_manager")
    def test_listen_for_abort_with_redis_uses_pubsub(self, mock_cache_manager):
        """Test listen_for_abort uses pub/sub when Redis available"""
        mock_redis = MagicMock()
        mock_pubsub = MagicMock()
        mock_redis.pubsub.return_value = mock_pubsub
        mock_cache_manager.signal_cache = mock_redis
        callback = MagicMock()
        # Patch the pub/sub loop so the thread exits immediately.
        with patch.object(TaskManager, "_listen_pubsub", return_value=None):
            listener = TaskManager.listen_for_abort(
                task_uuid="test-uuid",
                callback=callback,
                poll_interval=1.0,
                app=None,
            )
            # Give thread time to start
            time.sleep(0.1)
            listener.stop()
        # Should subscribe to channel
        mock_pubsub.subscribe.assert_called_once_with("gtf:abort:test-uuid")

    @patch("superset.tasks.manager.cache_manager")
    def test_listen_for_abort_redis_subscribe_failure_raises(self, mock_cache_manager):
        """Test listen_for_abort raises exception on subscribe failure
        when Redis configured"""
        import pytest

        mock_redis = MagicMock()
        mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
        mock_cache_manager.signal_cache = mock_redis
        callback = MagicMock()
        # With fail-fast behavior, Redis subscribe failure raises exception
        # instead of silently degrading to polling.
        with pytest.raises(redis.RedisError, match="Connection failed"):
            TaskManager.listen_for_abort(
                task_uuid="test-uuid",
                callback=callback,
                poll_interval=1.0,
                app=None,
            )
class TestTaskManagerCompletion:
"""Tests for TaskManager completion pub/sub and wait_for_completion"""
def setup_method(self):
    """Reset TaskManager state before each test"""
    # Restore every class-level attribute to its default value.
    for attr, value in (
        ("_initialized", False),
        ("_channel_prefix", "gtf:abort:"),
        ("_completion_channel_prefix", "gtf:complete:"),
    ):
        setattr(TaskManager, attr, value)
def teardown_method(self):
    """Reset TaskManager state after each test"""
    # Restore every class-level attribute to its default value.
    for attr, value in (
        ("_initialized", False),
        ("_channel_prefix", "gtf:abort:"),
        ("_completion_channel_prefix", "gtf:complete:"),
    ):
        setattr(TaskManager, attr, value)
def test_get_completion_channel(self):
    """Test get_completion_channel returns correct channel name"""
    # Default prefix plus the task UUID forms the channel name.
    assert (
        TaskManager.get_completion_channel("abc-123-def-456")
        == "gtf:complete:abc-123-def-456"
    )
def test_get_completion_channel_custom_prefix(self):
    """Test get_completion_channel with custom prefix"""
    TaskManager._completion_channel_prefix = "custom:complete:"
    # The overridden prefix must be honored when building the channel name.
    assert (
        TaskManager.get_completion_channel("test-uuid") == "custom:complete:test-uuid"
    )
@patch("superset.tasks.manager.cache_manager")
def test_publish_completion_no_redis(self, mock_cache_manager):
"""Test publish_completion returns False when Redis not available"""
mock_cache_manager.signal_cache = None
result = TaskManager.publish_completion("test-uuid", "success")
assert result is False
@patch("superset.tasks.manager.cache_manager")
def test_publish_completion_success(self, mock_cache_manager):
"""Test publish_completion publishes message successfully"""
mock_redis = MagicMock()
mock_redis.publish.return_value = 1 # One subscriber
mock_cache_manager.signal_cache = mock_redis
result = TaskManager.publish_completion("test-uuid", "success")
assert result is True
mock_redis.publish.assert_called_once_with("gtf:complete:test-uuid", "success")
@patch("superset.tasks.manager.cache_manager")
def test_publish_completion_redis_error(self, mock_cache_manager):
"""Test publish_completion handles Redis errors gracefully"""
mock_redis = MagicMock()
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
mock_cache_manager.signal_cache = mock_redis
result = TaskManager.publish_completion("test-uuid", "success")
assert result is False
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_task_not_found(self, mock_dao, mock_cache_manager):
"""Test wait_for_completion raises ValueError for missing task"""
import pytest
mock_cache_manager.signal_cache = None
mock_dao.find_one_or_none.return_value = None
with pytest.raises(ValueError, match="not found"):
TaskManager.wait_for_completion("nonexistent-uuid")
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_already_complete(self, mock_dao, mock_cache_manager):
"""Test wait_for_completion returns immediately for terminal state"""
mock_cache_manager.signal_cache = None
mock_task = MagicMock()
mock_task.uuid = "test-uuid"
mock_task.status = "success"
mock_dao.find_one_or_none.return_value = mock_task
result = TaskManager.wait_for_completion("test-uuid")
assert result == mock_task
# Should only call find_one_or_none once (initial check)
mock_dao.find_one_or_none.assert_called_once()
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_timeout(self, mock_dao, mock_cache_manager):
    """wait_for_completion raises TimeoutError once the deadline passes."""
    import pytest

    mock_cache_manager.signal_cache = None
    stuck_task = MagicMock()
    stuck_task.uuid = "test-uuid"
    stuck_task.status = "in_progress"  # never reaches a terminal state
    mock_dao.find_one_or_none.return_value = stuck_task

    with pytest.raises(TimeoutError, match="Timeout waiting"):
        TaskManager.wait_for_completion("test-uuid", timeout=0.1)
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_polling_success(self, mock_dao, mock_cache_manager):
    """wait_for_completion returns once polling observes a terminal status."""
    mock_cache_manager.signal_cache = None  # force the polling fallback

    pending = MagicMock()
    pending.uuid = "test-uuid"
    pending.status = "pending"

    complete = MagicMock()
    complete.uuid = "test-uuid"
    complete.status = "success"

    # The first poll sees a pending task; the second sees completion.
    mock_dao.find_one_or_none.side_effect = [pending, complete]

    result = TaskManager.wait_for_completion(
        "test-uuid",
        timeout=5.0,
        poll_interval=0.1,
    )
    assert result.status == "success"
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_with_pubsub(self, mock_dao, mock_cache_manager):
    """Test wait_for_completion uses pub/sub when Redis available"""
    mock_task_pending = MagicMock()
    mock_task_pending.uuid = "test-uuid"
    mock_task_pending.status = "pending"
    mock_task_complete = MagicMock()
    mock_task_complete.uuid = "test-uuid"
    mock_task_complete.status = "success"
    # First call returns pending, second returns complete
    # (the second lookup presumably happens after the pub/sub message
    # arrives — confirm against TaskManager.wait_for_completion).
    mock_dao.find_one_or_none.side_effect = [
        mock_task_pending,
        mock_task_complete,
    ]
    # Set up mock Redis with pub/sub
    mock_redis = MagicMock()
    mock_pubsub = MagicMock()
    # Simulate receiving a completion message
    mock_pubsub.get_message.return_value = {
        "type": "message",
        "data": "success",
    }
    mock_redis.pubsub.return_value = mock_pubsub
    mock_cache_manager.signal_cache = mock_redis
    result = TaskManager.wait_for_completion(
        "test-uuid",
        timeout=5.0,
    )
    assert result.status == "success"
    # Should have subscribed to completion channel
    mock_pubsub.subscribe.assert_called_once_with("gtf:complete:test-uuid")
    # Should have cleaned up (unsubscribe + close the pubsub object)
    mock_pubsub.unsubscribe.assert_called_once()
    mock_pubsub.close.assert_called_once()
@patch("superset.tasks.manager.cache_manager")
@patch("superset.daos.tasks.TaskDAO")
def test_wait_for_completion_pubsub_error_raises(
    self, mock_dao, mock_cache_manager
):
    """Test wait_for_completion raises exception on Redis error when
    Redis configured"""
    import pytest

    mock_task_pending = MagicMock()
    mock_task_pending.uuid = "test-uuid"
    mock_task_pending.status = "pending"
    mock_dao.find_one_or_none.return_value = mock_task_pending
    # Set up mock Redis that fails when creating the pub/sub channel
    mock_redis = MagicMock()
    mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
    mock_cache_manager.signal_cache = mock_redis
    # With fail-fast behavior, Redis error is raised instead of falling back
    # to polling: a configured-but-broken Redis is treated as an error.
    with pytest.raises(redis.RedisError, match="Connection failed"):
        TaskManager.wait_for_completion(
            "test-uuid",
            timeout=5.0,
            poll_interval=0.1,
        )
def test_terminal_states_constant(self):
    """TERMINAL_STATES covers exactly the four end-of-life statuses."""
    assert TaskManager.TERMINAL_STATES == {
        "success",
        "failure",
        "aborted",
        "timed_out",
    }

View File

@@ -0,0 +1,612 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for GTF timeout handling."""
import time
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from superset_core.api.tasks import TaskOptions, TaskScope
from superset.tasks.context import TaskContext
from superset.tasks.decorators import TaskWrapper
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_flask_app():
    """Mock Flask app exposing the config and app_context() TaskContext needs."""
    app = MagicMock()
    app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
    # app_context() must behave like a context manager so code can run
    # `with app.app_context():` without blowing up.
    cm = app.app_context.return_value
    cm.__enter__ = MagicMock(return_value=None)
    cm.__exit__ = MagicMock(return_value=None)
    return app
@pytest.fixture
def mock_task_abortable():
    """Mock in-progress task whose properties mark it as abortable."""
    task = MagicMock()
    # Concrete scope/type/key/user values are needed because dedup_key
    # generation (used by the UpdateTaskCommand lock) reads them.
    task.configure_mock(
        uuid=TEST_UUID,
        status="in_progress",
        properties_dict={"is_abortable": True},
        payload_dict={},
        scope="shared",
        task_type="test_task",
        task_key="test_key",
        user_id=1,
    )
    return task
@pytest.fixture
def mock_task_not_abortable():
    """Mock in-progress task that is NOT abortable (no is_abortable flag)."""
    task = MagicMock()
    # Concrete scope/type/key/user values are needed because dedup_key
    # generation (used by the UpdateTaskCommand lock) reads them.
    task.configure_mock(
        uuid=TEST_UUID,
        status="in_progress",
        properties_dict={},  # absence of is_abortable means not abortable
        payload_dict={},
        scope="shared",
        task_type="test_task",
        task_key="test_key",
        user_id=1,
    )
    return task
@pytest.fixture
def task_context_for_timeout(mock_flask_app, mock_task_abortable):
    """Create TaskContext with mocked dependencies for timeout tests.

    Yields a TaskContext bound to the abortable mock task, with Redis
    disabled so abort detection falls back to DB polling. Any timers or
    abort listeners started by a test are stopped after the yield.
    """
    # Ensure mock_task has required attributes for TaskContext
    mock_task_abortable.payload_dict = {}
    with (
        patch("superset.tasks.context.current_app") as mock_current_app,
        patch("superset.daos.tasks.TaskDAO") as mock_dao,
        patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
    ):
        # Disable Redis by making signal_cache return None
        mock_cache_manager.signal_cache = None
        # Configure current_app mock
        mock_current_app.config = mock_flask_app.config
        mock_current_app._get_current_object.return_value = mock_flask_app
        # Configure TaskDAO mock
        mock_dao.find_one_or_none.return_value = mock_task_abortable
        ctx = TaskContext(mock_task_abortable)
        ctx._app = mock_flask_app
        yield ctx
        # Cleanup: stop timers if started, so no background thread outlives
        # the test (runs while the patches above are still active)
        ctx.stop_timeout_timer()
        if ctx._abort_listener:
            ctx.stop_abort_polling()
# =============================================================================
# TaskWrapper._merge_options Timeout Tests
# =============================================================================
class TestTimeoutMerging:
    """Test timeout merging behavior in TaskWrapper._merge_options."""

    @staticmethod
    def _make_wrapper(default_timeout):
        """Build a minimal TaskWrapper with the given decorator-level timeout.

        Extracted because every test in this class constructed an identical
        wrapper, differing only in ``default_timeout``.
        """

        def dummy_func():
            pass

        return TaskWrapper(
            name="test_task",
            func=dummy_func,
            default_options=TaskOptions(),
            scope=TaskScope.PRIVATE,
            default_timeout=default_timeout,
        )

    def test_merge_options_decorator_timeout_used_when_no_override(self):
        """Test that decorator timeout is used when no override is provided."""
        merged = self._make_wrapper(300)._merge_options(None)
        assert merged.timeout == 300

    def test_merge_options_override_timeout_takes_precedence(self):
        """Test that TaskOptions timeout overrides decorator default."""
        override = TaskOptions(timeout=600)  # 10-minute override
        merged = self._make_wrapper(300)._merge_options(override)
        assert merged.timeout == 600

    def test_merge_options_no_timeout_when_not_configured(self):
        """Test that no timeout is set when not configured anywhere."""
        merged = self._make_wrapper(None)._merge_options(None)
        assert merged.timeout is None

    def test_merge_options_override_with_other_options_preserves_timeout(self):
        """Test that setting other options doesn't lose decorator timeout."""
        # Override only task_key, not timeout
        override = TaskOptions(task_key="my-key")
        merged = self._make_wrapper(300)._merge_options(override)
        # Should keep decorator timeout since override.timeout is None
        assert merged.timeout == 300
        assert merged.task_key == "my-key"
# =============================================================================
# TaskContext Timeout Timer Tests
# =============================================================================
class TestTimeoutTimer:
    """Test TaskContext timeout timer behavior."""

    def test_start_timeout_timer_sets_timer(self, task_context_for_timeout):
        """start_timeout_timer creates a timer and leaves the flag unset."""
        ctx = task_context_for_timeout
        assert ctx._timeout_timer is None

        ctx.start_timeout_timer(10)

        assert ctx._timeout_timer is not None
        assert ctx._timeout_triggered is False

    def test_start_timeout_timer_is_idempotent(self, task_context_for_timeout):
        """A second start_timeout_timer call must not replace the live timer."""
        ctx = task_context_for_timeout
        ctx.start_timeout_timer(10)
        original_timer = ctx._timeout_timer

        ctx.start_timeout_timer(20)  # second call should be a no-op

        assert ctx._timeout_timer is original_timer

    def test_stop_timeout_timer_cancels_timer(self, task_context_for_timeout):
        """stop_timeout_timer cancels and clears the active timer."""
        ctx = task_context_for_timeout
        ctx.start_timeout_timer(10)
        assert ctx._timeout_timer is not None

        ctx.stop_timeout_timer()

        assert ctx._timeout_timer is None

    def test_stop_timeout_timer_safe_when_no_timer(self, task_context_for_timeout):
        """stop_timeout_timer is a safe no-op when no timer was ever started."""
        ctx = task_context_for_timeout
        assert ctx._timeout_timer is None

        ctx.stop_timeout_timer()  # must not raise

        assert ctx._timeout_timer is None

    def test_timeout_triggered_property_initially_false(self, task_context_for_timeout):
        """The timeout_triggered property starts out False."""
        assert task_context_for_timeout.timeout_triggered is False

    def test_cleanup_stops_timeout_timer(self, task_context_for_timeout):
        """_run_cleanup cancels any outstanding timeout timer."""
        ctx = task_context_for_timeout
        ctx.start_timeout_timer(10)
        assert ctx._timeout_timer is not None

        ctx._run_cleanup()

        assert ctx._timeout_timer is None
class TestTimeoutTrigger:
    """Test timeout trigger behavior when timer fires.

    These tests start a real 1-second timer and sleep 1.5s to let it fire,
    so they exercise the actual background-timer path (slow by design).
    """

    def test_timeout_triggers_abort_when_abortable(
        self, mock_flask_app, mock_task_abortable
    ):
        """Test that timeout triggers abort handlers when task is abortable."""
        abort_called = False
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch(
                "superset.commands.tasks.update.UpdateTaskCommand"
            ) as mock_update_cmd,
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Disable Redis by making signal_cache return None
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                nonlocal abort_called
                abort_called = True

            # Start short timeout
            ctx.start_timeout_timer(1)
            # Wait for timeout to fire
            time.sleep(1.5)
            # Abort handler should have been called
            assert abort_called
            assert ctx._timeout_triggered
            assert ctx._abort_detected
            # Verify UpdateTaskCommand was called with ABORTING status
            mock_update_cmd.assert_called()
            call_kwargs = mock_update_cmd.call_args[1]
            assert call_kwargs.get("status") == "aborting"
            # Cleanup
            ctx.stop_timeout_timer()
            if ctx._abort_listener:
                ctx.stop_abort_polling()

    def test_timeout_logs_warning_when_not_abortable(
        self, mock_flask_app, mock_task_not_abortable
    ):
        """Test that timeout logs warning when task has no abort handler."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.tasks.context.logger") as mock_logger,
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Disable Redis by making signal_cache return None
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_not_abortable
            ctx = TaskContext(mock_task_not_abortable)
            ctx._app = mock_flask_app
            # No abort handler registered
            # Start short timeout
            ctx.start_timeout_timer(1)
            # Wait for timeout to fire
            time.sleep(1.5)
            # Should have logged warning (first positional arg is the format
            # string, so the substring check targets the message template)
            mock_logger.warning.assert_called()
            warning_call = mock_logger.warning.call_args
            assert "no abort handler" in warning_call[0][0].lower()
            assert ctx._timeout_triggered
            assert not ctx._abort_detected  # No abort since no handler
            # Cleanup
            ctx.stop_timeout_timer()

    def test_timeout_does_not_trigger_if_already_aborting(
        self, mock_flask_app, mock_task_abortable
    ):
        """Test that timeout doesn't re-trigger abort if already aborting."""
        abort_count = 0
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Disable Redis by making signal_cache return None
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                nonlocal abort_count
                abort_count += 1

            # Pre-set abort detected, simulating an abort already in flight
            ctx._abort_detected = True
            # Start short timeout
            ctx.start_timeout_timer(1)
            # Wait for timeout to fire
            time.sleep(1.5)
            # Handler should NOT have been called since already aborting
            assert abort_count == 0
            # Cleanup
            ctx.stop_timeout_timer()
            if ctx._abort_listener:
                ctx.stop_abort_polling()
# =============================================================================
# Task Decorator Timeout Tests
# =============================================================================
class TestTaskDecoratorTimeout:
    """Test @task decorator timeout parameter.

    Each test registers a task in the global TaskRegistry as a side effect
    of applying @task, so registry cleanup runs in a ``finally`` block —
    a failing assertion must not leak registered tasks into later tests.
    """

    def test_task_decorator_accepts_timeout(self):
        """Test that @task decorator accepts timeout parameter."""
        from superset.tasks.decorators import task
        from superset.tasks.registry import TaskRegistry

        try:
            @task(name="test_timeout_task_1", timeout=300)
            def timeout_test_task_1():
                pass

            assert isinstance(timeout_test_task_1, TaskWrapper)
            assert timeout_test_task_1.default_timeout == 300
        finally:
            # Cleanup registry even if an assertion above fails
            TaskRegistry._tasks.pop("test_timeout_task_1", None)

    def test_task_decorator_without_timeout(self):
        """Test that @task decorator works without timeout."""
        from superset.tasks.decorators import task
        from superset.tasks.registry import TaskRegistry

        try:
            @task(name="test_timeout_task_2")
            def timeout_test_task_2():
                pass

            assert isinstance(timeout_test_task_2, TaskWrapper)
            assert timeout_test_task_2.default_timeout is None
        finally:
            # Cleanup registry even if an assertion above fails
            TaskRegistry._tasks.pop("test_timeout_task_2", None)

    def test_task_decorator_with_all_params(self):
        """Test that @task decorator accepts all parameters together."""
        from superset.tasks.decorators import task
        from superset.tasks.registry import TaskRegistry

        try:
            @task(name="test_timeout_task_3", scope=TaskScope.SHARED, timeout=600)
            def timeout_test_task_3():
                pass

            assert timeout_test_task_3.name == "test_timeout_task_3"
            assert timeout_test_task_3.scope == TaskScope.SHARED
            assert timeout_test_task_3.default_timeout == 600
        finally:
            # Cleanup registry even if an assertion above fails
            TaskRegistry._tasks.pop("test_timeout_task_3", None)
# =============================================================================
# Timeout Terminal State Tests
# =============================================================================
class TestTimeoutTerminalState:
    """Test timeout transitions to correct terminal state (TIMED_OUT vs FAILURE).

    The timeout_triggered / abort_handlers_completed flags checked here are
    presumably what the runner uses to pick the terminal status — confirm
    against the TaskWrapper execution path.
    """

    def test_timeout_triggered_flag_set_on_timeout(
        self, mock_flask_app, mock_task_abortable
    ):
        """Test that timeout_triggered flag is set when timeout fires."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Disable Redis by making signal_cache return None
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                pass

            # Initially not triggered
            assert ctx.timeout_triggered is False
            # Start short timeout
            ctx.start_timeout_timer(1)
            # Wait for timeout to fire
            time.sleep(1.5)
            # Should be set after timeout
            assert ctx.timeout_triggered is True
            # Cleanup
            ctx.stop_timeout_timer()
            if ctx._abort_listener:
                ctx.stop_abort_polling()

    def test_user_abort_does_not_set_timeout_triggered(
        self, mock_flask_app, mock_task_abortable
    ):
        """Test that user abort doesn't set timeout_triggered flag."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Disable Redis by making signal_cache return None
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                pass

            # Simulate user abort (not timeout) by invoking the abort
            # callback directly instead of letting a timer fire
            ctx._on_abort_detected()
            # timeout_triggered should still be False
            assert ctx.timeout_triggered is False
            # But abort_detected should be True
            assert ctx._abort_detected is True
            # Cleanup
            if ctx._abort_listener:
                ctx.stop_abort_polling()

    def test_abort_handlers_completed_tracks_success(
        self, mock_flask_app, mock_task_abortable
    ):
        """Test that abort_handlers_completed flag tracks successful
        handler execution."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Disable Redis by making signal_cache return None
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                pass  # Successful handler

            # Initially not completed
            assert ctx.abort_handlers_completed is False
            # Trigger abort handlers
            ctx._trigger_abort_handlers()
            # Should be marked as completed
            assert ctx.abort_handlers_completed is True
            # Cleanup
            if ctx._abort_listener:
                ctx.stop_abort_polling()

    def test_abort_handlers_completed_false_on_exception(
        self, mock_flask_app, mock_task_abortable
    ):
        """Test that abort_handlers_completed is False when handler throws."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Disable Redis by making signal_cache return None
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                raise ValueError("Handler failed")

            # Initially not completed
            assert ctx.abort_handlers_completed is False
            # Trigger abort handlers (will catch the exception internally)
            ctx._trigger_abort_handlers()
            # Should NOT be marked as completed since handler threw
            assert ctx.abort_handlers_completed is False
            # Cleanup
            if ctx._abort_listener:
                ctx.stop_abort_polling()

View File

@@ -22,9 +22,19 @@ from typing import Any, Optional, Union
import pytest
from flask_appbuilder.security.sqla.models import User
from superset_core.api.tasks import TaskScope
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.types import Executor, ExecutorType, FixedExecutor
from superset.tasks.utils import (
error_update,
get_active_dedup_key,
get_finished_dedup_key,
parse_properties,
progress_update,
serialize_properties,
)
from superset.utils.hashing import hash_from_str
FIXED_USER_ID = 1234
FIXED_USERNAME = "admin"
@@ -330,3 +340,242 @@ def test_get_executor(
)
assert executor_type == expected_executor_type
assert executor == expected_executor
@pytest.mark.parametrize(
    "scope,task_type,task_key,user_id,expected_composite_key",
    [
        # Private tasks with TaskScope enum: user_id is part of the key
        (
            TaskScope.PRIVATE,
            "sql_execution",
            "chart_123",
            42,
            "private|sql_execution|chart_123|42",
        ),
        (
            TaskScope.PRIVATE,
            "thumbnail_gen",
            "dash_456",
            100,
            "private|thumbnail_gen|dash_456|100",
        ),
        # Private tasks with string scope (scope accepts enum or str)
        (
            "private",
            "api_call",
            "endpoint_789",
            200,
            "private|api_call|endpoint_789|200",
        ),
        # Shared tasks with TaskScope enum
        (
            TaskScope.SHARED,
            "report_gen",
            "monthly_report",
            None,
            "shared|report_gen|monthly_report",
        ),
        (
            TaskScope.SHARED,
            "export_csv",
            "large_export",
            999,  # user_id should be ignored for shared
            "shared|export_csv|large_export",
        ),
        # Shared tasks with string scope
        (
            "shared",
            "batch_process",
            "batch_001",
            123,  # user_id should be ignored for shared
            "shared|batch_process|batch_001",
        ),
        # System tasks with TaskScope enum
        (
            TaskScope.SYSTEM,
            "cleanup_task",
            "daily_cleanup",
            None,
            "system|cleanup_task|daily_cleanup",
        ),
        (
            TaskScope.SYSTEM,
            "db_migration",
            "version_123",
            1,  # user_id should be ignored for system
            "system|db_migration|version_123",
        ),
        # System tasks with string scope
        (
            "system",
            "maintenance",
            "nightly_job",
            2,  # user_id should be ignored for system
            "system|maintenance|nightly_job",
        ),
    ],
)
def test_get_active_dedup_key(
    scope, task_type, task_key, user_id, expected_composite_key, app_context
):
    """Test get_active_dedup_key generates a hash of the composite key.

    The function hashes the composite key using the configured HASH_ALGORITHM
    to produce a fixed-length dedup_key for database storage. The result is
    truncated to 64 chars to fit the database column.
    """
    result = get_active_dedup_key(scope, task_type, task_key, user_id)
    # The result should be a hash of the expected composite key, truncated to
    # 64 chars (hash_from_str applies the app-configured algorithm, hence the
    # app_context fixture)
    expected_hash = hash_from_str(expected_composite_key)[:64]
    assert result == expected_hash
    assert len(result) <= 64
def test_get_active_dedup_key_private_requires_user_id():
    """Private-scope tasks must fail fast when no user_id is supplied."""
    with pytest.raises(ValueError, match="user_id required for private tasks"):
        # user_id deliberately omitted for a PRIVATE-scope task
        get_active_dedup_key(TaskScope.PRIVATE, "test_type", "test_key")
def test_get_finished_dedup_key():
    """Finished tasks reuse their UUID verbatim as the dedup key."""
    uuid_str = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
    assert get_finished_dedup_key(uuid_str) == uuid_str
@pytest.mark.parametrize(
    "progress,expected",
    [
        # Float (percentage) progress -> only progress_percent is set
        (0.5, {"progress_percent": 0.5}),
        (0.0, {"progress_percent": 0.0}),
        (1.0, {"progress_percent": 1.0}),
        (0.25, {"progress_percent": 0.25}),
        # Int (count only) progress -> only progress_current is set
        (42, {"progress_current": 42}),
        (0, {"progress_current": 0}),
        (1000, {"progress_current": 1000}),
        # Tuple (current, total) progress with auto-computed percentage
        (
            (50, 100),
            {"progress_current": 50, "progress_total": 100, "progress_percent": 0.5},
        ),
        (
            (25, 100),
            {"progress_current": 25, "progress_total": 100, "progress_percent": 0.25},
        ),
        (
            (100, 100),
            {"progress_current": 100, "progress_total": 100, "progress_percent": 1.0},
        ),
        # Tuple with zero total (no percentage computed — avoids div-by-zero)
        ((10, 0), {"progress_current": 10, "progress_total": 0}),
        ((0, 0), {"progress_current": 0, "progress_total": 0}),
    ],
)
def test_progress_update(progress, expected):
    """Test progress_update returns correct TaskProperties dict.

    Covers the three accepted input forms: float percentage, int count,
    and (current, total) tuple.
    """
    result = progress_update(progress)
    assert result == expected
def test_error_update():
    """error_update captures the message, exception type, and stack trace."""
    try:
        raise ValueError("Test error message")
    except ValueError as exc:
        update = error_update(exc)

    assert update["error_message"] == "Test error message"
    assert update["exception_type"] == "ValueError"
    # The formatted traceback should mention the exception class
    assert "stack_trace" in update
    assert "ValueError" in update["stack_trace"]
def test_error_update_custom_exception():
    """error_update reports the concrete class name of custom exceptions."""

    class CustomError(Exception):
        pass

    try:
        raise CustomError("Custom error")
    except CustomError as exc:
        update = error_update(exc)

    assert update["error_message"] == "Custom error"
    assert update["exception_type"] == "CustomError"
@pytest.mark.parametrize(
    "json_str,expected",
    [
        # Valid JSON round-trips into a plain dict
        (
            '{"is_abortable": true, "progress_percent": 0.5}',
            {"is_abortable": True, "progress_percent": 0.5},
        ),
        (
            '{"error_message": "Something failed"}',
            {"error_message": "Something failed"},
        ),
        (
            '{"progress_current": 50, "progress_total": 100}',
            {"progress_current": 50, "progress_total": 100},
        ),
        # Empty/None cases both normalize to an empty dict
        ("", {}),
        (None, {}),
        # Invalid JSON returns empty dict (malformed data must not raise)
        ("not valid json", {}),
        ("{broken", {}),
        # Unknown keys are preserved (forward compatibility)
        (
            '{"is_abortable": true, "future_field": "value"}',
            {"is_abortable": True, "future_field": "value"},
        ),
    ],
)
def test_parse_properties(json_str, expected):
    """Test parse_properties parses JSON to TaskProperties dict.

    Malformed, empty, or missing input all degrade to {} rather than raising.
    """
    result = parse_properties(json_str)
    assert result == expected
@pytest.mark.parametrize(
    "props,expected",
    [
        # Full properties
        (
            {"is_abortable": True, "progress_percent": 0.5},
            {"is_abortable": True, "progress_percent": 0.5},
        ),
        # Empty dict
        ({}, {}),
        # Sparse properties
        ({"is_abortable": True}, {"is_abortable": True}),
        ({"error_message": "fail"}, {"error_message": "fail"}),
    ],
)
def test_serialize_properties(props, expected):
    """Test serialize_properties converts TaskProperties to JSON.

    The assertion is full equality of the parsed output, so the second
    parameter is named ``expected`` (not "contains") to match what the
    test actually checks.
    """
    from superset.utils import json

    result = serialize_properties(props)
    parsed = json.loads(result)
    assert parsed == expected
def test_properties_roundtrip():
    """serialize -> parse is lossless for a representative property set."""
    props = {
        "is_abortable": True,
        "progress_percent": 0.75,
        "error_message": "Test error",
    }
    assert parse_properties(serialize_properties(props)) == props

View File

@@ -54,7 +54,7 @@ def test_json_loads_exception():
def test_json_loads_encoding():
unicode_data = b'{"a": "\u0073\u0074\u0072"}'
unicode_data = rb'{"a": "\u0073\u0074\u0072"}'
data = json.loads(unicode_data)
assert data["a"] == "str"
utf16_data = b'\xff\xfe{\x00"\x00a\x00"\x00:\x00 \x00"\x00s\x00t\x00r\x00"\x00}\x00'

View File

@@ -119,7 +119,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
was revoked), the invalid token should be deleted and the exception re-raised.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
mocker.patch("superset.utils.oauth2.DistributedLock")
class OAuth2ExceptionError(Exception):
pass
@@ -149,7 +149,7 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
exception re-raised.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
mocker.patch("superset.utils.oauth2.DistributedLock")
class OAuth2ExceptionError(Exception):
pass
@@ -175,7 +175,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
This can happen when the refresh token was revoked.
"""
mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db_engine_spec.get_oauth2_fresh_token.return_value = {
"error": "invalid_grant",