mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat: add global task framework (#36368)
This commit is contained in:
420
tests/unit_tests/daos/test_tasks.py
Normal file
420
tests/unit_tests/daos/test_tasks.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
|
||||
|
||||
from superset.commands.tasks.exceptions import TaskNotAbortableError
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key
|
||||
|
||||
# Test constants
|
||||
TASK_UUID = UUID("e7765491-40c1-4f35-a4f5-06308e79310e")
|
||||
TASK_ID = 42
|
||||
TEST_TASK_TYPE = "test_type"
|
||||
TEST_TASK_KEY = "test-key"
|
||||
TEST_USER_ID = 1
|
||||
|
||||
|
||||
def create_task(
|
||||
session: Session,
|
||||
*,
|
||||
task_id: int | None = None,
|
||||
task_uuid: UUID | None = None,
|
||||
task_key: str = TEST_TASK_KEY,
|
||||
task_type: str = TEST_TASK_TYPE,
|
||||
scope: TaskScope = TaskScope.PRIVATE,
|
||||
status: TaskStatus = TaskStatus.PENDING,
|
||||
user_id: int | None = TEST_USER_ID,
|
||||
properties: TaskProperties | None = None,
|
||||
use_finished_dedup_key: bool = False,
|
||||
) -> Task:
|
||||
"""Helper to create a task with sensible defaults for testing."""
|
||||
if use_finished_dedup_key:
|
||||
dedup_key = get_finished_dedup_key(task_uuid or TASK_UUID)
|
||||
else:
|
||||
dedup_key = get_active_dedup_key(
|
||||
scope=scope,
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
scope=scope.value,
|
||||
status=status.value,
|
||||
dedup_key=dedup_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
if task_id is not None:
|
||||
task.id = task_id
|
||||
if task_uuid:
|
||||
task.uuid = task_uuid
|
||||
if properties:
|
||||
task.update_properties(properties)
|
||||
|
||||
session.add(task)
|
||||
session.flush()
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_task(session: Session) -> Iterator[Session]:
|
||||
"""Create a session with Task and TaskSubscriber tables."""
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
|
||||
engine = session.get_bind()
|
||||
Task.metadata.create_all(engine)
|
||||
TaskSubscriber.metadata.create_all(engine)
|
||||
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_find_by_task_key_active(session_with_task: Session) -> None:
|
||||
"""Test finding active task by task_key"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
create_task(session_with_task)
|
||||
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key=TEST_TASK_KEY,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.task_key == TEST_TASK_KEY
|
||||
assert result.task_type == TEST_TASK_TYPE
|
||||
assert result.status == TaskStatus.PENDING.value
|
||||
|
||||
|
||||
def test_find_by_task_key_not_found(session_with_task: Session) -> None:
|
||||
"""Test finding task by task_key returns None when not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="nonexistent-key",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_find_by_task_key_finished_not_found(session_with_task: Session) -> None:
|
||||
"""Test that find_by_task_key returns None for finished tasks.
|
||||
|
||||
Finished tasks have a different dedup_key format (UUID-based),
|
||||
so they won't be found by the active task lookup.
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
create_task(
|
||||
session_with_task,
|
||||
task_key="finished-key",
|
||||
status=TaskStatus.SUCCESS,
|
||||
use_finished_dedup_key=True,
|
||||
task_uuid=TASK_UUID,
|
||||
)
|
||||
|
||||
# Should not find SUCCESS task via active lookup
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="finished-key",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_create_task_success(session_with_task: Session) -> None:
|
||||
"""Test successful task creation."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key=TEST_TASK_KEY,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.task_key == TEST_TASK_KEY
|
||||
assert result.task_type == TEST_TASK_TYPE
|
||||
assert result.status == TaskStatus.PENDING.value
|
||||
assert isinstance(result, Task)
|
||||
|
||||
|
||||
def test_create_task_with_user_id(session_with_task: Session) -> None:
|
||||
"""Test task creation with explicit user_id."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="user-task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=42,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.user_id == 42
|
||||
# Creator should be auto-subscribed
|
||||
assert len(result.subscribers) == 1
|
||||
assert result.subscribers[0].user_id == 42
|
||||
|
||||
|
||||
def test_create_task_with_properties(session_with_task: Session) -> None:
|
||||
"""Test task creation with properties."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="props-task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
properties={"timeout": 300},
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.properties_dict.get("timeout") == 300
|
||||
|
||||
|
||||
def test_abort_task_pending_success(session_with_task: Session) -> None:
|
||||
"""Test successful abort of pending task - goes directly to ABORTED"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="pending-task",
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
|
||||
|
||||
def test_abort_task_in_progress_abortable(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task with abort handler.
|
||||
|
||||
Should transition to ABORTING status.
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="abortable-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
properties={"is_abortable": True},
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is not None
|
||||
# Should set status to ABORTING, not ABORTED
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
|
||||
|
||||
def test_abort_task_in_progress_not_abortable(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task without abort handler - raises error"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="non-abortable-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
properties={"is_abortable": False},
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotAbortableError):
|
||||
TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
|
||||
def test_abort_task_in_progress_is_abortable_none(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task with is_abortable not set - raises error"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="no-abortable-prop-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
# Empty properties - no is_abortable key
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotAbortableError):
|
||||
TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
|
||||
def test_abort_task_already_aborting(session_with_task: Session) -> None:
|
||||
"""Test abort of already aborting task - idempotent success"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="aborting-task",
|
||||
status=TaskStatus.ABORTING,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
# Idempotent - returns task without error
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
|
||||
|
||||
def test_abort_task_not_found(session_with_task: Session) -> None:
|
||||
"""Test abort fails when task not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.abort_task(UUID("00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_abort_task_already_finished(session_with_task: Session) -> None:
|
||||
"""Test abort fails when task already finished"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="finished-task",
|
||||
status=TaskStatus.SUCCESS,
|
||||
use_finished_dedup_key=True,
|
||||
task_uuid=TASK_UUID,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_add_subscriber(session_with_task: Session) -> None:
|
||||
"""Test adding a subscriber to a task"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Add subscriber
|
||||
result = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
assert result is True
|
||||
|
||||
# Verify subscriber was added
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
assert task.subscribers[0].user_id == TEST_USER_ID
|
||||
|
||||
|
||||
def test_add_subscriber_idempotent(session_with_task: Session) -> None:
|
||||
"""Test adding same subscriber twice is idempotent"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-2",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Add subscriber twice
|
||||
result1 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
result2 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
|
||||
assert result1 is True
|
||||
assert result2 is False # Already subscribed
|
||||
|
||||
# Verify only one subscriber
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
|
||||
|
||||
def test_remove_subscriber(session_with_task: Session) -> None:
|
||||
"""Test removing a subscriber from a task"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-3",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
|
||||
# Remove subscriber
|
||||
result = TaskDAO.remove_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.subscribers) == 0
|
||||
|
||||
|
||||
def test_remove_subscriber_not_subscribed(session_with_task: Session) -> None:
|
||||
"""Test removing non-existent subscriber returns None"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-4",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Try to remove non-existent subscriber
|
||||
result = TaskDAO.remove_subscriber(task.id, user_id=999)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_status(session_with_task: Session) -> None:
|
||||
"""Test get_status returns status string when task found by UUID"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_uuid=TASK_UUID,
|
||||
task_key="status-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
)
|
||||
|
||||
result = TaskDAO.get_status(task.uuid)
|
||||
|
||||
assert result == TaskStatus.IN_PROGRESS.value
|
||||
|
||||
|
||||
def test_get_status_not_found(session_with_task: Session) -> None:
|
||||
"""Test get_status returns None when task not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
assert result is None
|
||||
@@ -18,17 +18,21 @@
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
# Force module loading before tests run so patches work correctly
|
||||
import superset.commands.distributed_lock.acquire as acquire_module
|
||||
import superset.commands.distributed_lock.release as release_module
|
||||
from superset import db
|
||||
from superset.distributed_lock import KeyValueDistributedLock
|
||||
from superset.distributed_lock import DistributedLock
|
||||
from superset.distributed_lock.types import LockValue
|
||||
from superset.distributed_lock.utils import get_key
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.exceptions import AcquireDistributedLockFailedException
|
||||
from superset.key_value.types import JsonKeyValueCodec
|
||||
|
||||
LOCK_VALUE: LockValue = {"value": True}
|
||||
@@ -56,9 +60,9 @@ def _get_other_session() -> Session:
|
||||
return SessionMaker()
|
||||
|
||||
|
||||
def test_key_value_distributed_lock_happy_path() -> None:
|
||||
def test_distributed_lock_kv_happy_path() -> None:
|
||||
"""
|
||||
Test successfully acquiring and returning the distributed lock.
|
||||
Test successfully acquiring and returning the distributed lock via KV backend.
|
||||
|
||||
Note, we're using another session for asserting the lock state in the Metastore
|
||||
to simulate what another worker will observe. Otherwise, there's the risk that
|
||||
@@ -66,24 +70,29 @@ def test_key_value_distributed_lock_happy_path() -> None:
|
||||
"""
|
||||
session = _get_other_session()
|
||||
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
# Ensure Redis is not configured so KV backend is used
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
with KeyValueDistributedLock("ns", a=1, b=2) as key:
|
||||
assert key == MAIN_KEY
|
||||
assert _get_lock(key, session) == LOCK_VALUE
|
||||
assert _get_lock(OTHER_KEY, session) is None
|
||||
with DistributedLock("ns", a=1, b=2) as key:
|
||||
assert key == MAIN_KEY
|
||||
assert _get_lock(key, session) == LOCK_VALUE
|
||||
assert _get_lock(OTHER_KEY, session) is None
|
||||
|
||||
with pytest.raises(CreateKeyValueDistributedLockFailedException):
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
pass
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("ns", a=1, b=2):
|
||||
pass
|
||||
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
|
||||
def test_key_value_distributed_lock_expired() -> None:
|
||||
def test_distributed_lock_kv_expired() -> None:
|
||||
"""
|
||||
Test expiration of the distributed lock
|
||||
Test expiration of the distributed lock via KV backend.
|
||||
|
||||
Note, we're using another session for asserting the lock state in the Metastore
|
||||
to simulate what another worker will observe. Otherwise, there's the risk that
|
||||
@@ -91,11 +100,112 @@ def test_key_value_distributed_lock_expired() -> None:
|
||||
"""
|
||||
session = _get_other_session()
|
||||
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
|
||||
with freeze_time("2022-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
# Ensure Redis is not configured so KV backend is used
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
with DistributedLock("ns", a=1, b=2):
|
||||
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
|
||||
with freeze_time("2022-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
|
||||
def test_distributed_lock_uses_redis_when_configured() -> None:
|
||||
"""Test that DistributedLock uses Redis backend when configured."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True # Lock acquired
|
||||
|
||||
# Use patch.object to patch on already-imported modules
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test_redis", key="value") as lock_key:
|
||||
assert lock_key is not None
|
||||
# Verify SET NX EX was called
|
||||
mock_redis.set.assert_called_once()
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["nx"] is True
|
||||
assert "ex" in call_args.kwargs
|
||||
|
||||
# Verify DELETE was called on exit
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_distributed_lock_redis_already_taken() -> None:
|
||||
"""Test Redis lock fails when already held."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = None # Lock not acquired (already taken)
|
||||
|
||||
with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("test_redis", key="value"):
|
||||
pass
|
||||
|
||||
|
||||
def test_distributed_lock_redis_connection_error() -> None:
|
||||
"""Test Redis connection error raises exception (fail fast)."""
|
||||
import redis
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.side_effect = redis.RedisError("Connection failed")
|
||||
|
||||
with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("test_redis", key="value"):
|
||||
pass
|
||||
|
||||
|
||||
def test_distributed_lock_custom_ttl() -> None:
|
||||
"""Test Redis lock with custom TTL."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test", ttl_seconds=60, key="value"):
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["ex"] == 60 # Custom TTL
|
||||
|
||||
|
||||
def test_distributed_lock_default_ttl(app_context: None) -> None:
|
||||
"""Test Redis lock uses default TTL when not specified."""
|
||||
from superset.commands.distributed_lock.base import get_default_lock_ttl
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test", key="value"):
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["ex"] == get_default_lock_ttl()
|
||||
|
||||
|
||||
def test_distributed_lock_fallback_to_kv_when_redis_not_configured() -> None:
|
||||
"""Test falls back to KV lock when Redis not configured."""
|
||||
session = _get_other_session()
|
||||
test_key = get_key("test_fallback", key="value")
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
# When Redis is not configured, should use KV backend
|
||||
with DistributedLock("test_fallback", key="value") as lock_key:
|
||||
assert lock_key == test_key
|
||||
# Verify lock exists in KV store
|
||||
assert _get_lock(test_key, session) == LOCK_VALUE
|
||||
|
||||
# Lock should be released
|
||||
assert _get_lock(test_key, session) is None
|
||||
|
||||
477
tests/unit_tests/tasks/test_decorators.py
Normal file
477
tests/unit_tests/tasks/test_decorators.py
Normal file
@@ -0,0 +1,477 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for task decorators"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskOptions, TaskScope
|
||||
|
||||
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
|
||||
from superset.tasks.decorators import task, TaskWrapper
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
|
||||
class TestTaskDecoratorFeatureFlag:
|
||||
"""Tests for @task decorator feature flag behavior"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_decorator_succeeds_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that @task decorator can be applied even when GTF is disabled.
|
||||
|
||||
This enables safe module imports during app startup or Celery autodiscovery.
|
||||
"""
|
||||
|
||||
# Decoration should succeed - no error raised
|
||||
@task(name="test_gtf_disabled_decorator")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_task, TaskWrapper)
|
||||
assert my_task.name == "test_gtf_disabled_decorator"
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_call_raises_error_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that calling a task raises GlobalTaskFrameworkDisabledError
|
||||
when GTF is disabled."""
|
||||
|
||||
@task(name="test_gtf_disabled_call")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(GlobalTaskFrameworkDisabledError):
|
||||
my_task()
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_schedule_raises_error_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that scheduling a task raises GlobalTaskFrameworkDisabledError
|
||||
when GTF is disabled."""
|
||||
|
||||
@task(name="test_gtf_disabled_schedule")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(GlobalTaskFrameworkDisabledError):
|
||||
my_task.schedule()
|
||||
|
||||
|
||||
class TestTaskDecorator:
|
||||
"""Tests for @task decorator"""
|
||||
|
||||
def test_decorator_basic(self):
|
||||
"""Test basic decorator usage without options"""
|
||||
|
||||
@task(name="test_task")
|
||||
def my_task(arg1: int, arg2: str) -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_task, TaskWrapper)
|
||||
assert my_task.name == "test_task"
|
||||
assert my_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_without_parentheses(self):
|
||||
"""Test decorator usage without parentheses"""
|
||||
|
||||
@task
|
||||
def my_no_parens_task(arg1: int, arg2: str) -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_no_parens_task, TaskWrapper)
|
||||
assert my_no_parens_task.name == "my_no_parens_task" # Uses function name
|
||||
assert my_no_parens_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_with_default_scope_private(self):
|
||||
"""Test decorator with explicit PRIVATE scope"""
|
||||
|
||||
@task(name="private_task", scope=TaskScope.PRIVATE)
|
||||
def my_private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
assert my_private_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_with_default_scope_shared(self):
|
||||
"""Test decorator with SHARED scope"""
|
||||
|
||||
@task(name="shared_task", scope=TaskScope.SHARED)
|
||||
def my_shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
assert my_shared_task.scope == TaskScope.SHARED
|
||||
|
||||
def test_decorator_with_default_scope_system(self):
|
||||
"""Test decorator with SYSTEM scope"""
|
||||
|
||||
@task(name="system_task", scope=TaskScope.SYSTEM)
|
||||
def my_system_task() -> None:
|
||||
pass
|
||||
|
||||
assert my_system_task.scope == TaskScope.SYSTEM
|
||||
|
||||
def test_decorator_forbids_ctx_parameter(self):
|
||||
"""Test decorator rejects functions with ctx parameter"""
|
||||
|
||||
with pytest.raises(TypeError, match="must not define 'ctx'"):
|
||||
|
||||
@task(name="bad_task")
|
||||
def bad_task(ctx, arg1: int) -> None: # noqa: ARG001
|
||||
pass
|
||||
|
||||
def test_decorator_forbids_options_parameter(self):
|
||||
"""Test decorator rejects functions with options parameter"""
|
||||
|
||||
with pytest.raises(TypeError, match="must not define.*'options'"):
|
||||
|
||||
@task(name="bad_task")
|
||||
def bad_task(options, arg1: int) -> None: # noqa: ARG001
|
||||
pass
|
||||
|
||||
|
||||
class TestTaskWrapperMergeOptions:
|
||||
"""Tests for TaskWrapper._merge_options()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
def test_merge_options_no_override(self):
|
||||
"""Test merging with no override returns defaults"""
|
||||
|
||||
@task(name="test_merge_no_override_unique")
|
||||
def merge_task_1() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_1.default_options = TaskOptions(
|
||||
task_key="default_key",
|
||||
task_name="Default Name",
|
||||
)
|
||||
|
||||
merged = merge_task_1._merge_options(None)
|
||||
assert merged.task_key == "default_key"
|
||||
assert merged.task_name == "Default Name"
|
||||
|
||||
def test_merge_options_override_task_key(self):
|
||||
"""Test overriding task_key at call time"""
|
||||
|
||||
@task(name="test_merge_override_key_unique")
|
||||
def merge_task_2() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_2.default_options = TaskOptions(task_key="default_key")
|
||||
|
||||
override = TaskOptions(task_key="override_key")
|
||||
merged = merge_task_2._merge_options(override)
|
||||
assert merged.task_key == "override_key"
|
||||
|
||||
def test_merge_options_override_task_name(self):
|
||||
"""Test overriding task_name at call time"""
|
||||
|
||||
@task(name="test_merge_override_name_unique")
|
||||
def merge_task_3() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_3.default_options = TaskOptions(task_name="Default Name")
|
||||
|
||||
override = TaskOptions(task_name="Override Name")
|
||||
merged = merge_task_3._merge_options(override)
|
||||
assert merged.task_name == "Override Name"
|
||||
|
||||
def test_merge_options_override_all(self):
|
||||
"""Test overriding all options at call time"""
|
||||
|
||||
@task(name="test_merge_override_all_unique")
|
||||
def merge_task_4() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_4.default_options = TaskOptions(
|
||||
task_key="default_key",
|
||||
task_name="Default Name",
|
||||
)
|
||||
|
||||
override = TaskOptions(
|
||||
task_key="override_key",
|
||||
task_name="Override Name",
|
||||
)
|
||||
merged = merge_task_4._merge_options(override)
|
||||
assert merged.task_key == "override_key"
|
||||
assert merged.task_name == "Override Name"
|
||||
|
||||
|
||||
class TestTaskWrapperSchedule:
|
||||
"""Tests for TaskWrapper.schedule() with scope"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_uses_default_scope(self, mock_submit):
|
||||
"""Test schedule() uses decorator's default scope"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_default_unique", scope=TaskScope.SHARED)
|
||||
def schedule_task_1(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Shared tasks require explicit task_key
|
||||
schedule_task_1.schedule(123, options=TaskOptions(task_key="test_key"))
|
||||
|
||||
# Verify TaskManager.submit_task was called with correct scope
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.SHARED
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_uses_private_scope_by_default(self, mock_submit):
|
||||
"""Test schedule() uses PRIVATE scope when no scope specified"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_override_unique")
|
||||
def schedule_task_2(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
schedule_task_2.schedule(123)
|
||||
|
||||
# Verify PRIVATE scope was used (default)
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.PRIVATE
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_with_custom_options(self, mock_submit):
|
||||
"""Test schedule() with custom task options"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_custom_unique", scope=TaskScope.SYSTEM)
|
||||
def schedule_task_3(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Use custom task key and name
|
||||
schedule_task_3.schedule(
|
||||
123,
|
||||
options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
|
||||
)
|
||||
|
||||
# Verify scope from decorator and options from call time
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.SYSTEM
|
||||
assert call_args[1]["task_key"] == "custom_key"
|
||||
assert call_args[1]["task_name"] == "Custom Task Name"
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_with_no_decorator_options(self, mock_submit):
|
||||
"""Test schedule() uses default PRIVATE scope when no options provided"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_no_options_unique")
|
||||
def schedule_task_4(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
schedule_task_4.schedule(123)
|
||||
|
||||
# Verify default PRIVATE scope
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.PRIVATE
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_shared_task_requires_task_key(self, mock_submit):
|
||||
"""Test shared task schedule() requires explicit task_key"""
|
||||
|
||||
@task(name="test_shared_requires_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should raise ValueError when no task_key provided
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Shared task.*requires an explicit task_key.*for deduplication",
|
||||
):
|
||||
shared_task.schedule(123)
|
||||
|
||||
# Should work with task_key provided
|
||||
mock_submit.return_value = MagicMock()
|
||||
shared_task.schedule(123, options=TaskOptions(task_key="valid_key"))
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_private_task_allows_no_task_key(self, mock_submit):
|
||||
"""Test private task schedule() works without task_key"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_private_no_key", scope=TaskScope.PRIVATE)
|
||||
def private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work without task_key (generates random UUID)
|
||||
private_task.schedule(123)
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
|
||||
class TestTaskWrapperCall:
|
||||
"""Tests for TaskWrapper.__call__() with scope"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_uses_default_scope(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test direct call uses decorator's default scope"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_default_unique", scope=TaskScope.SHARED)
|
||||
def call_task_1(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Shared tasks require explicit task_key
|
||||
call_task_1(123, options=TaskOptions(task_key="test_key"))
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.utils.core.get_user_id")
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_uses_private_scope_by_default(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
|
||||
):
|
||||
"""Test direct call uses PRIVATE scope when no scope specified"""
|
||||
mock_get_user_id.return_value = 1
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_private_default_unique")
|
||||
def call_task_2(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
call_task_2(123)
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_with_custom_options(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test direct call with custom task options"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_custom_unique", scope=TaskScope.SYSTEM)
|
||||
def call_task_3(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Use custom task key and name
|
||||
call_task_3(
|
||||
123,
|
||||
options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
|
||||
)
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
def test_call_shared_task_requires_task_key(self):
|
||||
"""Test shared task direct call requires explicit task_key"""
|
||||
|
||||
@task(name="test_shared_call_requires_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should raise ValueError when no task_key provided
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Shared task.*requires an explicit task_key.*for deduplication",
|
||||
):
|
||||
shared_task(123)
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_shared_task_works_with_task_key(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test shared task direct call works with task_key"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task
|
||||
|
||||
@task(name="test_shared_call_with_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work with task_key provided
|
||||
shared_task(123, options=TaskOptions(task_key="valid_key"))
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.utils.core.get_user_id")
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_private_task_allows_no_task_key(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
|
||||
):
|
||||
"""Test private task direct call works without task_key"""
|
||||
mock_get_user_id.return_value = 1
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task
|
||||
|
||||
@task(name="test_private_call_no_key", scope=TaskScope.PRIVATE)
|
||||
def private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work without task_key (generates random UUID)
|
||||
private_task(123)
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
677
tests/unit_tests/tasks/test_handlers.py
Normal file
677
tests/unit_tests/tasks/test_handlers.py
Normal file
@@ -0,0 +1,677 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for GTF handlers (abort, cleanup) and related Task model behavior."""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset.tasks.context import TaskContext
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
"""Create a mock task for testing."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = TaskStatus.PENDING.value
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_dao(mock_task):
|
||||
"""Mock TaskDAO to return our test task."""
|
||||
with patch("superset.daos.tasks.TaskDAO") as mock_dao:
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
yield mock_dao
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_update_command():
|
||||
"""Mock UpdateTaskCommand to avoid database operations."""
|
||||
with patch("superset.commands.tasks.update.UpdateTaskCommand") as mock_cmd:
|
||||
mock_cmd.return_value.run.return_value = None
|
||||
yield mock_cmd
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flask_app():
|
||||
"""Create a properly configured mock Flask app."""
|
||||
mock_app = MagicMock()
|
||||
mock_app.config = {
|
||||
"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
|
||||
}
|
||||
# Make app_context() return a proper context manager
|
||||
mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
|
||||
# Use regular Mock (not MagicMock) for _get_current_object to avoid
|
||||
# AsyncMockMixin creating unawaited coroutines in Python 3.10+
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
return mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_context(mock_task, mock_task_dao, mock_update_command, mock_flask_app):
|
||||
"""Create TaskContext with mocked dependencies."""
|
||||
# Ensure mock_task has properties_dict and payload_dict (TaskContext accesses them)
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
# Configure current_app mock
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
# Use regular Mock (not MagicMock) for _get_current_object to avoid
|
||||
# AsyncMockMixin creating unawaited coroutines in Python 3.10+
|
||||
mock_current_app._get_current_object = Mock(return_value=mock_flask_app)
|
||||
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
yield ctx
|
||||
|
||||
# Cleanup: stop polling if started
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
class TestTaskStatusEnum:
|
||||
"""Test TaskStatus enum values."""
|
||||
|
||||
def test_aborting_status_exists(self):
|
||||
"""Test that ABORTING status is defined."""
|
||||
assert hasattr(TaskStatus, "ABORTING")
|
||||
assert TaskStatus.ABORTING.value == "aborting"
|
||||
|
||||
def test_all_statuses_present(self):
|
||||
"""Test all expected statuses are present."""
|
||||
expected_statuses = [
|
||||
"pending",
|
||||
"in_progress",
|
||||
"success",
|
||||
"failure",
|
||||
"aborting",
|
||||
"aborted",
|
||||
]
|
||||
actual_statuses = [s.value for s in TaskStatus]
|
||||
|
||||
for status in expected_statuses:
|
||||
assert status in actual_statuses, f"Missing status: {status}"
|
||||
|
||||
|
||||
class TestTaskAbortProperties:
|
||||
"""Test Task model abort-related properties via status and properties accessor."""
|
||||
|
||||
def test_aborting_status(self):
|
||||
"""Test ABORTING status check."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.status = TaskStatus.ABORTING.value
|
||||
|
||||
assert task.status == TaskStatus.ABORTING.value
|
||||
|
||||
def test_is_abortable_in_properties(self):
|
||||
"""Test is_abortable is accessible via properties."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.update_properties({"is_abortable": True})
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
|
||||
def test_is_abortable_default_none(self):
|
||||
"""Test is_abortable defaults to None for new tasks."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is None
|
||||
|
||||
|
||||
class TestTaskSetStatus:
|
||||
"""Test Task.set_status behavior for abort states."""
|
||||
|
||||
def test_set_status_in_progress_sets_is_abortable_false(self):
|
||||
"""Test that transitioning to IN_PROGRESS sets is_abortable to False."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
# Default is None
|
||||
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is False
|
||||
assert task.started_at is not None
|
||||
|
||||
def test_set_status_in_progress_preserves_existing_is_abortable(self):
|
||||
"""Test that re-setting IN_PROGRESS doesn't override is_abortable."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.update_properties(
|
||||
{"is_abortable": True}
|
||||
) # Already set by handler registration
|
||||
task.started_at = datetime.now(timezone.utc) # Already started
|
||||
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
|
||||
# Should not override since started_at is already set
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
|
||||
def test_set_status_aborting_does_not_set_ended_at(self):
|
||||
"""Test that ABORTING status does not set ended_at."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.started_at = datetime.now(timezone.utc)
|
||||
|
||||
task.status = TaskStatus.ABORTING.value
|
||||
|
||||
assert task.ended_at is None
|
||||
|
||||
def test_set_status_aborted_sets_ended_at(self):
|
||||
"""Test that ABORTED status sets ended_at."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.started_at = datetime.now(timezone.utc)
|
||||
|
||||
task.set_status(TaskStatus.ABORTED)
|
||||
|
||||
assert task.ended_at is not None
|
||||
|
||||
|
||||
class TestTaskDuration:
|
||||
"""Test Task duration_seconds property with different states."""
|
||||
|
||||
def test_duration_seconds_finished_task(self):
|
||||
"""Test duration for finished task returns actual duration."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.status = TaskStatus.SUCCESS.value # Must be finished to use ended_at
|
||||
task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.ended_at = datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
|
||||
|
||||
# Should use ended_at - started_at = 30 seconds
|
||||
assert task.duration_seconds == 30.0
|
||||
|
||||
@freeze_time("2024-01-01 10:00:30")
|
||||
def test_duration_seconds_running_task(self):
|
||||
"""Test duration for running task returns time since start."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.ended_at = None
|
||||
|
||||
# 30 seconds since start
|
||||
assert task.duration_seconds == 30.0
|
||||
|
||||
@freeze_time("2024-01-01 10:00:15")
|
||||
def test_duration_seconds_pending_task(self):
|
||||
"""Test duration for pending task returns queue time."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.started_at = None
|
||||
task.ended_at = None
|
||||
|
||||
# 15 seconds since creation
|
||||
assert task.duration_seconds == 15.0
|
||||
|
||||
def test_duration_seconds_no_timestamps(self):
|
||||
"""Test duration returns None when no timestamps available."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.created_on = None
|
||||
task.started_at = None
|
||||
task.ended_at = None
|
||||
|
||||
assert task.duration_seconds is None
|
||||
|
||||
|
||||
class TestAbortHandlerRegistration:
|
||||
"""Test abort handler registration and is_abortable flag."""
|
||||
|
||||
def test_on_abort_registers_handler(self, task_context):
|
||||
"""Test that on_abort registers a handler."""
|
||||
handler_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
|
||||
assert len(task_context._abort_handlers) == 1
|
||||
assert not handler_called
|
||||
|
||||
@patch("superset.tasks.context.current_app")
|
||||
def test_on_abort_sets_abortable(self, mock_app):
|
||||
"""Test on_abort sets is_abortable to True on first handler."""
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
|
||||
patch.object(TaskContext, "start_abort_polling"),
|
||||
):
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
@ctx.on_abort
|
||||
def handler():
|
||||
pass
|
||||
|
||||
mock_set_abortable.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.context.current_app")
|
||||
def test_on_abort_only_sets_abortable_once(self, mock_app):
|
||||
"""Test on_abort only calls _set_abortable for first handler."""
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
|
||||
patch.object(TaskContext, "start_abort_polling"),
|
||||
):
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
@ctx.on_abort
|
||||
def handler1():
|
||||
pass
|
||||
|
||||
@ctx.on_abort
|
||||
def handler2():
|
||||
pass
|
||||
|
||||
# Should only be called once for first handler
|
||||
assert mock_set_abortable.call_count == 1
|
||||
|
||||
def test_abort_handlers_completed_initially_false(self):
|
||||
"""Test abort_handlers_completed is False initially."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with patch("superset.tasks.context.current_app") as mock_app:
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
ctx = TaskContext(mock_task)
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
|
||||
class TestAbortPolling:
|
||||
"""Test abort detection polling behavior."""
|
||||
|
||||
def test_on_abort_starts_polling_automatically(self, task_context):
|
||||
"""Test that registering first handler starts abort listener."""
|
||||
assert task_context._abort_listener is None
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
def test_stop_abort_polling(self, task_context):
|
||||
"""Test that stop_abort_polling stops the abort listener."""
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
task_context.stop_abort_polling()
|
||||
|
||||
assert task_context._abort_listener is None
|
||||
|
||||
def test_start_abort_polling_only_once(self, task_context):
|
||||
"""Test that start_abort_polling is idempotent."""
|
||||
task_context.start_abort_polling(interval=0.1)
|
||||
first_listener = task_context._abort_listener
|
||||
|
||||
# Try to start again
|
||||
task_context.start_abort_polling(interval=0.1)
|
||||
second_listener = task_context._abort_listener
|
||||
|
||||
# Should be the same listener
|
||||
assert first_listener is second_listener
|
||||
|
||||
def test_on_abort_with_custom_interval(self, task_context):
|
||||
"""Test that custom interval can be set via start_abort_polling."""
|
||||
with patch("superset.tasks.context.current_app") as mock_app:
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Override with custom interval
|
||||
task_context.stop_abort_polling()
|
||||
task_context.start_abort_polling(interval=0.05)
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
def test_polling_stops_after_abort_detected(self, task_context, mock_task):
|
||||
"""Test that abort is detected and handlers are triggered."""
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# Abort should have been detected
|
||||
assert task_context._abort_detected is True
|
||||
|
||||
|
||||
class TestAbortHandlerExecution:
|
||||
"""Test abort handler execution behavior."""
|
||||
|
||||
def test_on_abort_handler_fires_when_task_aborted(self, task_context, mock_task):
|
||||
"""Test that abort handler fires automatically when task is aborted."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Simulate task being aborted
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for polling to detect abort (max 0.3s with 0.1s interval)
|
||||
time.sleep(0.3)
|
||||
|
||||
assert abort_called
|
||||
assert task_context._abort_detected
|
||||
|
||||
def test_on_abort_not_called_on_success(self, task_context, mock_task):
|
||||
"""Test that abort handlers don't run on success."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Keep task in success state
|
||||
mock_task.status = TaskStatus.SUCCESS.value
|
||||
|
||||
# Wait and verify handler not called
|
||||
time.sleep(0.3)
|
||||
|
||||
assert not abort_called
|
||||
|
||||
def test_multiple_abort_handlers(self, task_context, mock_task):
|
||||
"""Test that all abort handlers execute in LIFO order."""
|
||||
calls = []
|
||||
|
||||
@task_context.on_abort
|
||||
def handler1():
|
||||
calls.append(1)
|
||||
|
||||
@task_context.on_abort
|
||||
def handler2():
|
||||
calls.append(2)
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# LIFO order: handler2 runs first
|
||||
assert calls == [2, 1]
|
||||
|
||||
def test_abort_handler_exception_doesnt_fail_task(self, task_context, mock_task):
|
||||
"""Test that exception in abort handler is logged but doesn't fail task."""
|
||||
handler2_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def bad_handler():
|
||||
raise ValueError("Handler error")
|
||||
|
||||
@task_context.on_abort
|
||||
def good_handler():
|
||||
nonlocal handler2_called
|
||||
handler2_called = True
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# Second handler should still run despite first handler failing
|
||||
assert handler2_called
|
||||
|
||||
|
||||
class TestBestEffortHandlerExecution:
|
||||
"""Test that all handlers execute even when some fail (best-effort)."""
|
||||
|
||||
def test_all_abort_handlers_run_even_if_all_fail(self, task_context, mock_task):
|
||||
"""Test all abort handlers execute even if every one raises an exception."""
|
||||
calls = []
|
||||
|
||||
@task_context.on_abort
|
||||
def handler1():
|
||||
calls.append(1)
|
||||
raise ValueError("Handler 1 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def handler2():
|
||||
calls.append(2)
|
||||
raise RuntimeError("Handler 2 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def handler3():
|
||||
calls.append(3)
|
||||
raise TypeError("Handler 3 failed")
|
||||
|
||||
# Trigger abort handlers directly (simulating abort detection)
|
||||
task_context._trigger_abort_handlers()
|
||||
|
||||
# All handlers should have been called (LIFO order: 3, 2, 1)
|
||||
assert calls == [3, 2, 1]
|
||||
|
||||
# Failures should be collected (abort handlers don't write to DB)
|
||||
assert len(task_context._handler_failures) == 3
|
||||
failure_types = [
|
||||
type(ex).__name__ for _, ex, _ in task_context._handler_failures
|
||||
]
|
||||
assert "TypeError" in failure_types
|
||||
assert "RuntimeError" in failure_types
|
||||
assert "ValueError" in failure_types
|
||||
|
||||
def test_all_cleanup_handlers_run_even_if_all_fail(self, task_context, mock_task):
|
||||
"""Test all cleanup handlers execute even if every one raises an exception."""
|
||||
calls = []
|
||||
captured_failures = []
|
||||
|
||||
# Mock _write_handler_failures_to_db to capture failures before clearing
|
||||
original_write = task_context._write_handler_failures_to_db
|
||||
|
||||
def mock_write():
|
||||
captured_failures.extend(task_context._handler_failures)
|
||||
original_write()
|
||||
|
||||
task_context._write_handler_failures_to_db = mock_write
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup1():
|
||||
calls.append(1)
|
||||
raise ValueError("Cleanup 1 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup2():
|
||||
calls.append(2)
|
||||
raise RuntimeError("Cleanup 2 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup3():
|
||||
calls.append(3)
|
||||
raise TypeError("Cleanup 3 failed")
|
||||
|
||||
# Set task to SUCCESS (not aborting) so only cleanup handlers run
|
||||
mock_task.status = TaskStatus.SUCCESS.value
|
||||
|
||||
# Run cleanup
|
||||
task_context._run_cleanup()
|
||||
|
||||
# All handlers should have been called (LIFO order: 3, 2, 1)
|
||||
assert calls == [3, 2, 1]
|
||||
|
||||
# Failures should have been captured before clearing
|
||||
assert len(captured_failures) == 3
|
||||
failure_types = [type(ex).__name__ for _, ex, _ in captured_failures]
|
||||
assert "TypeError" in failure_types
|
||||
assert "RuntimeError" in failure_types
|
||||
assert "ValueError" in failure_types
|
||||
|
||||
def test_mixed_abort_and_cleanup_failures_all_collected(
|
||||
self, task_context, mock_task
|
||||
):
|
||||
"""Test abort and cleanup handler failures are collected together."""
|
||||
calls = []
|
||||
captured_failures = []
|
||||
|
||||
# Mock _write_handler_failures_to_db to capture failures before clearing
|
||||
original_write = task_context._write_handler_failures_to_db
|
||||
|
||||
def mock_write():
|
||||
captured_failures.extend(task_context._handler_failures)
|
||||
original_write()
|
||||
|
||||
task_context._write_handler_failures_to_db = mock_write
|
||||
|
||||
@task_context.on_abort
|
||||
def abort1():
|
||||
calls.append("abort1")
|
||||
raise ValueError("Abort 1 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def abort2():
|
||||
calls.append("abort2")
|
||||
raise RuntimeError("Abort 2 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup1():
|
||||
calls.append("cleanup1")
|
||||
raise TypeError("Cleanup 1 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup2():
|
||||
calls.append("cleanup2")
|
||||
raise KeyError("Cleanup 2 failed")
|
||||
|
||||
# Set task to ABORTING so both abort and cleanup handlers run
|
||||
mock_task.status = TaskStatus.ABORTING.value
|
||||
|
||||
# Run cleanup (which triggers abort handlers first, then cleanup handlers)
|
||||
task_context._run_cleanup()
|
||||
|
||||
# All handlers should have been called
|
||||
# Abort handlers run first (LIFO: abort2, abort1)
|
||||
# Then cleanup handlers (LIFO: cleanup2, cleanup1)
|
||||
assert calls == ["abort2", "abort1", "cleanup2", "cleanup1"]
|
||||
|
||||
# All 4 failures should have been captured
|
||||
assert len(captured_failures) == 4
|
||||
|
||||
# Verify handler types are recorded correctly
|
||||
handler_types = [htype for htype, _, _ in captured_failures]
|
||||
assert handler_types.count("abort") == 2
|
||||
assert handler_types.count("cleanup") == 2
|
||||
|
||||
|
||||
class TestCleanupHandlers:
|
||||
"""Test cleanup handler behavior."""
|
||||
|
||||
def test_cleanup_triggers_abort_handlers_if_not_detected(
|
||||
self, task_context, mock_task
|
||||
):
|
||||
"""Test that _run_cleanup triggers abort handlers if task ended aborted."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Set task as aborted but don't let polling detect it
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
task_context._abort_detected = False
|
||||
|
||||
# Immediately run cleanup (simulating task ending before poll)
|
||||
task_context._run_cleanup()
|
||||
|
||||
assert abort_called
|
||||
|
||||
def test_cleanup_doesnt_duplicate_abort_handlers(self, task_context, mock_task):
|
||||
"""Test that abort handlers only run once even if called from cleanup."""
|
||||
call_count = 0
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
# Trigger abort via polling
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
time.sleep(0.3)
|
||||
|
||||
# Handlers should have been called once
|
||||
assert call_count == 1
|
||||
assert task_context._abort_detected is True
|
||||
|
||||
# Run cleanup - handlers should NOT be called again
|
||||
task_context._run_cleanup()
|
||||
|
||||
assert call_count == 1 # Still 1, not 2
|
||||
462
tests/unit_tests/tasks/test_manager.py
Normal file
462
tests/unit_tests/tasks/test_manager.py
Normal file
@@ -0,0 +1,462 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for TaskManager pub/sub functionality"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import redis
|
||||
|
||||
from superset.tasks.manager import AbortListener, TaskManager
|
||||
|
||||
|
||||
class TestAbortListener:
|
||||
"""Tests for AbortListener class"""
|
||||
|
||||
def test_stop_sets_event(self):
|
||||
"""Test that stop() sets the stop event"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = False
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
listener.stop()
|
||||
assert stop_event.is_set()
|
||||
|
||||
def test_stop_closes_pubsub(self):
|
||||
"""Test that stop() closes the pub/sub connection"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = False
|
||||
pubsub = MagicMock()
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event, pubsub)
|
||||
listener.stop()
|
||||
|
||||
pubsub.unsubscribe.assert_called_once()
|
||||
pubsub.close.assert_called_once()
|
||||
|
||||
def test_stop_joins_thread(self):
|
||||
"""Test that stop() joins the listener thread"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = True
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event)
|
||||
listener.stop()
|
||||
|
||||
thread.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
|
||||
class TestTaskManagerInitApp:
|
||||
"""Tests for TaskManager.init_app()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def test_init_app_sets_channel_prefixes(self):
|
||||
"""Test init_app reads channel prefixes from config"""
|
||||
app = MagicMock()
|
||||
app.config.get.side_effect = lambda key, default=None: {
|
||||
"TASKS_ABORT_CHANNEL_PREFIX": "custom:abort:",
|
||||
"TASKS_COMPLETION_CHANNEL_PREFIX": "custom:complete:",
|
||||
}.get(key, default)
|
||||
|
||||
TaskManager.init_app(app)
|
||||
|
||||
assert TaskManager._initialized is True
|
||||
assert TaskManager._channel_prefix == "custom:abort:"
|
||||
assert TaskManager._completion_channel_prefix == "custom:complete:"
|
||||
|
||||
def test_init_app_skips_if_already_initialized(self):
|
||||
"""Test init_app is idempotent"""
|
||||
TaskManager._initialized = True
|
||||
|
||||
app = MagicMock()
|
||||
TaskManager.init_app(app)
|
||||
|
||||
# Should not call app.config.get since already initialized
|
||||
app.config.get.assert_not_called()
|
||||
|
||||
|
||||
class TestTaskManagerPubSub:
|
||||
"""Tests for TaskManager pub/sub methods"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_is_pubsub_available_no_redis(self, mock_cache_manager):
|
||||
"""Test is_pubsub_available returns False when Redis not configured"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
assert TaskManager.is_pubsub_available() is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_is_pubsub_available_with_redis(self, mock_cache_manager):
|
||||
"""Test is_pubsub_available returns True when Redis is configured"""
|
||||
mock_cache_manager.signal_cache = MagicMock()
|
||||
assert TaskManager.is_pubsub_available() is True
|
||||
|
||||
def test_get_abort_channel(self):
|
||||
"""Test get_abort_channel returns correct channel name"""
|
||||
task_uuid = "abc-123-def-456"
|
||||
channel = TaskManager.get_abort_channel(task_uuid)
|
||||
assert channel == "gtf:abort:abc-123-def-456"
|
||||
|
||||
def test_get_abort_channel_custom_prefix(self):
|
||||
"""Test get_abort_channel with custom prefix"""
|
||||
TaskManager._channel_prefix = "custom:prefix:"
|
||||
task_uuid = "test-uuid"
|
||||
channel = TaskManager.get_abort_channel(task_uuid)
|
||||
assert channel == "custom:prefix:test-uuid"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_no_redis(self, mock_cache_manager):
|
||||
"""Test publish_abort returns False when Redis not available"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_success(self, mock_cache_manager):
|
||||
"""Test publish_abort publishes message successfully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.return_value = 1 # One subscriber
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
|
||||
assert result is True
|
||||
mock_redis.publish.assert_called_once_with("gtf:abort:test-uuid", "abort")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_redis_error(self, mock_cache_manager):
|
||||
"""Test publish_abort handles Redis errors gracefully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestTaskManagerListenForAbort:
|
||||
"""Tests for TaskManager.listen_for_abort()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_no_redis_uses_polling(self, mock_cache_manager):
|
||||
"""Test listen_for_abort falls back to polling when Redis unavailable"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
callback = MagicMock()
|
||||
|
||||
with patch.object(TaskManager, "_poll_for_abort", return_value=None):
|
||||
listener = TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
# Give thread time to start
|
||||
time.sleep(0.1)
|
||||
listener.stop()
|
||||
|
||||
# Should use polling since no Redis
|
||||
assert listener._pubsub is None
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_with_redis_uses_pubsub(self, mock_cache_manager):
|
||||
"""Test listen_for_abort uses pub/sub when Redis available"""
|
||||
mock_redis = MagicMock()
|
||||
mock_pubsub = MagicMock()
|
||||
mock_redis.pubsub.return_value = mock_pubsub
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
callback = MagicMock()
|
||||
|
||||
with patch.object(TaskManager, "_listen_pubsub", return_value=None):
|
||||
listener = TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
# Give thread time to start
|
||||
time.sleep(0.1)
|
||||
listener.stop()
|
||||
|
||||
# Should subscribe to channel
|
||||
mock_pubsub.subscribe.assert_called_once_with("gtf:abort:test-uuid")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_redis_subscribe_failure_raises(self, mock_cache_manager):
|
||||
"""Test listen_for_abort raises exception on subscribe failure
|
||||
when Redis configured"""
|
||||
import pytest
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
callback = MagicMock()
|
||||
|
||||
# With fail-fast behavior, Redis subscribe failure raises exception
|
||||
with pytest.raises(redis.RedisError, match="Connection failed"):
|
||||
TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
|
||||
class TestTaskManagerCompletion:
|
||||
"""Tests for TaskManager completion pub/sub and wait_for_completion"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def test_get_completion_channel(self):
|
||||
"""Test get_completion_channel returns correct channel name"""
|
||||
task_uuid = "abc-123-def-456"
|
||||
channel = TaskManager.get_completion_channel(task_uuid)
|
||||
assert channel == "gtf:complete:abc-123-def-456"
|
||||
|
||||
def test_get_completion_channel_custom_prefix(self):
|
||||
"""Test get_completion_channel with custom prefix"""
|
||||
TaskManager._completion_channel_prefix = "custom:complete:"
|
||||
task_uuid = "test-uuid"
|
||||
channel = TaskManager.get_completion_channel(task_uuid)
|
||||
assert channel == "custom:complete:test-uuid"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_no_redis(self, mock_cache_manager):
|
||||
"""Test publish_completion returns False when Redis not available"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_success(self, mock_cache_manager):
|
||||
"""Test publish_completion publishes message successfully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.return_value = 1 # One subscriber
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
|
||||
assert result is True
|
||||
mock_redis.publish.assert_called_once_with("gtf:complete:test-uuid", "success")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_redis_error(self, mock_cache_manager):
|
||||
"""Test publish_completion handles Redis errors gracefully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_task_not_found(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion raises ValueError for missing task"""
|
||||
import pytest
|
||||
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_dao.find_one_or_none.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
TaskManager.wait_for_completion("nonexistent-uuid")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_already_complete(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion returns immediately for terminal state"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = "test-uuid"
|
||||
mock_task.status = "success"
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
|
||||
result = TaskManager.wait_for_completion("test-uuid")
|
||||
|
||||
assert result == mock_task
|
||||
# Should only call find_one_or_none once (initial check)
|
||||
mock_dao.find_one_or_none.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_timeout(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion raises TimeoutError when timeout expires"""
|
||||
import pytest
|
||||
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = "test-uuid"
|
||||
mock_task.status = "in_progress" # Never completes
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
|
||||
with pytest.raises(TimeoutError, match="Timeout waiting"):
|
||||
TaskManager.wait_for_completion("test-uuid", timeout=0.1)
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_polling_success(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion returns when task completes via polling"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_task_complete = MagicMock()
|
||||
mock_task_complete.uuid = "test-uuid"
|
||||
mock_task_complete.status = "success"
|
||||
|
||||
# First call returns pending, second returns complete
|
||||
mock_dao.find_one_or_none.side_effect = [
|
||||
mock_task_pending,
|
||||
mock_task_complete,
|
||||
]
|
||||
|
||||
result = TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
poll_interval=0.1,
|
||||
)
|
||||
|
||||
assert result.status == "success"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_with_pubsub(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion uses pub/sub when Redis available"""
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_task_complete = MagicMock()
|
||||
mock_task_complete.uuid = "test-uuid"
|
||||
mock_task_complete.status = "success"
|
||||
|
||||
# First call returns pending, second returns complete
|
||||
mock_dao.find_one_or_none.side_effect = [
|
||||
mock_task_pending,
|
||||
mock_task_complete,
|
||||
]
|
||||
|
||||
# Set up mock Redis with pub/sub
|
||||
mock_redis = MagicMock()
|
||||
mock_pubsub = MagicMock()
|
||||
# Simulate receiving a completion message
|
||||
mock_pubsub.get_message.return_value = {
|
||||
"type": "message",
|
||||
"data": "success",
|
||||
}
|
||||
mock_redis.pubsub.return_value = mock_pubsub
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
assert result.status == "success"
|
||||
# Should have subscribed to completion channel
|
||||
mock_pubsub.subscribe.assert_called_once_with("gtf:complete:test-uuid")
|
||||
# Should have cleaned up
|
||||
mock_pubsub.unsubscribe.assert_called_once()
|
||||
mock_pubsub.close.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_pubsub_error_raises(
|
||||
self, mock_dao, mock_cache_manager
|
||||
):
|
||||
"""Test wait_for_completion raises exception on Redis error when
|
||||
Redis configured"""
|
||||
import pytest
|
||||
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_dao.find_one_or_none.return_value = mock_task_pending
|
||||
|
||||
# Set up mock Redis that fails
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
# With fail-fast behavior, Redis error is raised instead of falling back
|
||||
with pytest.raises(redis.RedisError, match="Connection failed"):
|
||||
TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
poll_interval=0.1,
|
||||
)
|
||||
|
||||
def test_terminal_states_constant(self):
|
||||
"""Test TERMINAL_STATES contains expected values"""
|
||||
expected = {"success", "failure", "aborted", "timed_out"}
|
||||
assert TaskManager.TERMINAL_STATES == expected
|
||||
612
tests/unit_tests/tasks/test_timeout.py
Normal file
612
tests/unit_tests/tasks/test_timeout.py
Normal file
@@ -0,0 +1,612 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for GTF timeout handling."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskOptions, TaskScope
|
||||
|
||||
from superset.tasks.context import TaskContext
|
||||
from superset.tasks.decorators import TaskWrapper
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flask_app():
|
||||
"""Create a properly configured mock Flask app."""
|
||||
mock_app = MagicMock()
|
||||
mock_app.config = {
|
||||
"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
|
||||
}
|
||||
# Make app_context() return a proper context manager
|
||||
mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
|
||||
return mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_abortable():
|
||||
"""Create a mock task that is abortable."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = "in_progress"
|
||||
task.properties_dict = {"is_abortable": True}
|
||||
task.payload_dict = {}
|
||||
# Set real values for dedup_key generation (used by UpdateTaskCommand lock)
|
||||
task.scope = "shared"
|
||||
task.task_type = "test_task"
|
||||
task.task_key = "test_key"
|
||||
task.user_id = 1
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_not_abortable():
|
||||
"""Create a mock task that is NOT abortable."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = "in_progress"
|
||||
task.properties_dict = {} # No is_abortable means it's not abortable
|
||||
task.payload_dict = {}
|
||||
# Set real values for dedup_key generation (used by UpdateTaskCommand lock)
|
||||
task.scope = "shared"
|
||||
task.task_type = "test_task"
|
||||
task.task_key = "test_key"
|
||||
task.user_id = 1
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_context_for_timeout(mock_flask_app, mock_task_abortable):
|
||||
"""Create TaskContext with mocked dependencies for timeout tests."""
|
||||
# Ensure mock_task has required attributes for TaskContext
|
||||
mock_task_abortable.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
# Configure current_app mock
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
|
||||
# Configure TaskDAO mock
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
yield ctx
|
||||
|
||||
# Cleanup: stop timers if started
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TaskWrapper._merge_options Timeout Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutMerging:
|
||||
"""Test timeout merging behavior in TaskWrapper._merge_options."""
|
||||
|
||||
def test_merge_options_decorator_timeout_used_when_no_override(self):
|
||||
"""Test that decorator timeout is used when no override is provided."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300, # 5-minute default
|
||||
)
|
||||
|
||||
merged = wrapper._merge_options(None)
|
||||
assert merged.timeout == 300
|
||||
|
||||
def test_merge_options_override_timeout_takes_precedence(self):
|
||||
"""Test that TaskOptions timeout overrides decorator default."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300, # 5-minute default
|
||||
)
|
||||
|
||||
override = TaskOptions(timeout=600) # 10-minute override
|
||||
merged = wrapper._merge_options(override)
|
||||
assert merged.timeout == 600
|
||||
|
||||
def test_merge_options_no_timeout_when_not_configured(self):
|
||||
"""Test that no timeout is set when not configured anywhere."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=None, # No default timeout
|
||||
)
|
||||
|
||||
merged = wrapper._merge_options(None)
|
||||
assert merged.timeout is None
|
||||
|
||||
def test_merge_options_override_with_other_options_preserves_timeout(self):
|
||||
"""Test that setting other options doesn't lose decorator timeout."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300,
|
||||
)
|
||||
|
||||
# Override only task_key, not timeout
|
||||
override = TaskOptions(task_key="my-key")
|
||||
merged = wrapper._merge_options(override)
|
||||
|
||||
# Should keep decorator timeout since override.timeout is None
|
||||
assert merged.timeout == 300
|
||||
assert merged.task_key == "my-key"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TaskContext Timeout Timer Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutTimer:
|
||||
"""Test TaskContext timeout timer behavior."""
|
||||
|
||||
def test_start_timeout_timer_sets_timer(self, task_context_for_timeout):
|
||||
"""Test that start_timeout_timer creates a timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
|
||||
assert ctx._timeout_timer is not None
|
||||
assert ctx._timeout_triggered is False
|
||||
|
||||
def test_start_timeout_timer_is_idempotent(self, task_context_for_timeout):
|
||||
"""Test that starting timer twice doesn't create duplicate timers."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
first_timer = ctx._timeout_timer
|
||||
|
||||
ctx.start_timeout_timer(20) # Try to start again
|
||||
second_timer = ctx._timeout_timer
|
||||
|
||||
assert first_timer is second_timer
|
||||
|
||||
def test_stop_timeout_timer_cancels_timer(self, task_context_for_timeout):
|
||||
"""Test that stop_timeout_timer cancels the timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
assert ctx._timeout_timer is not None
|
||||
|
||||
ctx.stop_timeout_timer()
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
def test_stop_timeout_timer_safe_when_no_timer(self, task_context_for_timeout):
|
||||
"""Test that stop_timeout_timer doesn't fail when no timer exists."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
ctx.stop_timeout_timer() # Should not raise
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
def test_timeout_triggered_property_initially_false(self, task_context_for_timeout):
|
||||
"""Test that timeout_triggered is False initially."""
|
||||
ctx = task_context_for_timeout
|
||||
assert ctx.timeout_triggered is False
|
||||
|
||||
def test_cleanup_stops_timeout_timer(self, task_context_for_timeout):
|
||||
"""Test that _run_cleanup stops the timeout timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
assert ctx._timeout_timer is not None
|
||||
|
||||
ctx._run_cleanup()
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
|
||||
class TestTimeoutTrigger:
|
||||
"""Test timeout trigger behavior when timer fires."""
|
||||
|
||||
def test_timeout_triggers_abort_when_abortable(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout triggers abort handlers when task is abortable."""
|
||||
abort_called = False
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch(
|
||||
"superset.commands.tasks.update.UpdateTaskCommand"
|
||||
) as mock_update_cmd,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Abort handler should have been called
|
||||
assert abort_called
|
||||
assert ctx._timeout_triggered
|
||||
assert ctx._abort_detected
|
||||
|
||||
# Verify UpdateTaskCommand was called with ABORTING status
|
||||
mock_update_cmd.assert_called()
|
||||
call_kwargs = mock_update_cmd.call_args[1]
|
||||
assert call_kwargs.get("status") == "aborting"
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_timeout_logs_warning_when_not_abortable(
|
||||
self, mock_flask_app, mock_task_not_abortable
|
||||
):
|
||||
"""Test that timeout logs warning when task has no abort handler."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.tasks.context.logger") as mock_logger,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_not_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_not_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
# No abort handler registered
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Should have logged warning
|
||||
mock_logger.warning.assert_called()
|
||||
warning_call = mock_logger.warning.call_args
|
||||
assert "no abort handler" in warning_call[0][0].lower()
|
||||
assert ctx._timeout_triggered
|
||||
assert not ctx._abort_detected # No abort since no handler
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
|
||||
def test_timeout_does_not_trigger_if_already_aborting(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout doesn't re-trigger abort if already aborting."""
|
||||
abort_count = 0
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_count
|
||||
abort_count += 1
|
||||
|
||||
# Pre-set abort detected
|
||||
ctx._abort_detected = True
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Handler should NOT have been called since already aborting
|
||||
assert abort_count == 0
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Task Decorator Timeout Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTaskDecoratorTimeout:
|
||||
"""Test @task decorator timeout parameter."""
|
||||
|
||||
def test_task_decorator_accepts_timeout(self):
|
||||
"""Test that @task decorator accepts timeout parameter."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_1", timeout=300)
|
||||
def timeout_test_task_1():
|
||||
pass
|
||||
|
||||
assert isinstance(timeout_test_task_1, TaskWrapper)
|
||||
assert timeout_test_task_1.default_timeout == 300
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_1", None)
|
||||
|
||||
def test_task_decorator_without_timeout(self):
|
||||
"""Test that @task decorator works without timeout."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_2")
|
||||
def timeout_test_task_2():
|
||||
pass
|
||||
|
||||
assert isinstance(timeout_test_task_2, TaskWrapper)
|
||||
assert timeout_test_task_2.default_timeout is None
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_2", None)
|
||||
|
||||
def test_task_decorator_with_all_params(self):
|
||||
"""Test that @task decorator accepts all parameters together."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_3", scope=TaskScope.SHARED, timeout=600)
|
||||
def timeout_test_task_3():
|
||||
pass
|
||||
|
||||
assert timeout_test_task_3.name == "test_timeout_task_3"
|
||||
assert timeout_test_task_3.scope == TaskScope.SHARED
|
||||
assert timeout_test_task_3.default_timeout == 600
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_3", None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Timeout Terminal State Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutTerminalState:
|
||||
"""Test timeout transitions to correct terminal state (TIMED_OUT vs FAILURE)."""
|
||||
|
||||
def test_timeout_triggered_flag_set_on_timeout(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout_triggered flag is set when timeout fires."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Initially not triggered
|
||||
assert ctx.timeout_triggered is False
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Should be set after timeout
|
||||
assert ctx.timeout_triggered is True
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_user_abort_does_not_set_timeout_triggered(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that user abort doesn't set timeout_triggered flag."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Simulate user abort (not timeout)
|
||||
ctx._on_abort_detected()
|
||||
|
||||
# timeout_triggered should still be False
|
||||
assert ctx.timeout_triggered is False
|
||||
# But abort_detected should be True
|
||||
assert ctx._abort_detected is True
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_abort_handlers_completed_tracks_success(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that abort_handlers_completed flag tracks successful
|
||||
handler execution."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass # Successful handler
|
||||
|
||||
# Initially not completed
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Trigger abort handlers
|
||||
ctx._trigger_abort_handlers()
|
||||
|
||||
# Should be marked as completed
|
||||
assert ctx.abort_handlers_completed is True
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_abort_handlers_completed_false_on_exception(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that abort_handlers_completed is False when handler throws."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
raise ValueError("Handler failed")
|
||||
|
||||
# Initially not completed
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Trigger abort handlers (will catch the exception internally)
|
||||
ctx._trigger_abort_handlers()
|
||||
|
||||
# Should NOT be marked as completed since handler threw
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
@@ -22,9 +22,19 @@ from typing import Any, Optional, Union
|
||||
|
||||
import pytest
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from superset_core.api.tasks import TaskScope
|
||||
|
||||
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
|
||||
from superset.tasks.types import Executor, ExecutorType, FixedExecutor
|
||||
from superset.tasks.utils import (
|
||||
error_update,
|
||||
get_active_dedup_key,
|
||||
get_finished_dedup_key,
|
||||
parse_properties,
|
||||
progress_update,
|
||||
serialize_properties,
|
||||
)
|
||||
from superset.utils.hashing import hash_from_str
|
||||
|
||||
FIXED_USER_ID = 1234
|
||||
FIXED_USERNAME = "admin"
|
||||
@@ -330,3 +340,242 @@ def test_get_executor(
|
||||
)
|
||||
assert executor_type == expected_executor_type
|
||||
assert executor == expected_executor
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scope,task_type,task_key,user_id,expected_composite_key",
|
||||
[
|
||||
# Private tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.PRIVATE,
|
||||
"sql_execution",
|
||||
"chart_123",
|
||||
42,
|
||||
"private|sql_execution|chart_123|42",
|
||||
),
|
||||
(
|
||||
TaskScope.PRIVATE,
|
||||
"thumbnail_gen",
|
||||
"dash_456",
|
||||
100,
|
||||
"private|thumbnail_gen|dash_456|100",
|
||||
),
|
||||
# Private tasks with string scope
|
||||
(
|
||||
"private",
|
||||
"api_call",
|
||||
"endpoint_789",
|
||||
200,
|
||||
"private|api_call|endpoint_789|200",
|
||||
),
|
||||
# Shared tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.SHARED,
|
||||
"report_gen",
|
||||
"monthly_report",
|
||||
None,
|
||||
"shared|report_gen|monthly_report",
|
||||
),
|
||||
(
|
||||
TaskScope.SHARED,
|
||||
"export_csv",
|
||||
"large_export",
|
||||
999, # user_id should be ignored for shared
|
||||
"shared|export_csv|large_export",
|
||||
),
|
||||
# Shared tasks with string scope
|
||||
(
|
||||
"shared",
|
||||
"batch_process",
|
||||
"batch_001",
|
||||
123, # user_id should be ignored for shared
|
||||
"shared|batch_process|batch_001",
|
||||
),
|
||||
# System tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.SYSTEM,
|
||||
"cleanup_task",
|
||||
"daily_cleanup",
|
||||
None,
|
||||
"system|cleanup_task|daily_cleanup",
|
||||
),
|
||||
(
|
||||
TaskScope.SYSTEM,
|
||||
"db_migration",
|
||||
"version_123",
|
||||
1, # user_id should be ignored for system
|
||||
"system|db_migration|version_123",
|
||||
),
|
||||
# System tasks with string scope
|
||||
(
|
||||
"system",
|
||||
"maintenance",
|
||||
"nightly_job",
|
||||
2, # user_id should be ignored for system
|
||||
"system|maintenance|nightly_job",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_active_dedup_key(
|
||||
scope, task_type, task_key, user_id, expected_composite_key, app_context
|
||||
):
|
||||
"""Test get_active_dedup_key generates a hash of the composite key.
|
||||
|
||||
The function hashes the composite key using the configured HASH_ALGORITHM
|
||||
to produce a fixed-length dedup_key for database storage. The result is
|
||||
truncated to 64 chars to fit the database column.
|
||||
"""
|
||||
result = get_active_dedup_key(scope, task_type, task_key, user_id)
|
||||
|
||||
# The result should be a hash of the expected composite key, truncated to 64 chars
|
||||
expected_hash = hash_from_str(expected_composite_key)[:64]
|
||||
assert result == expected_hash
|
||||
assert len(result) <= 64
|
||||
|
||||
|
||||
def test_get_active_dedup_key_private_requires_user_id():
|
||||
"""Test that private tasks require explicit user_id parameter."""
|
||||
with pytest.raises(ValueError, match="user_id required for private tasks"):
|
||||
get_active_dedup_key(TaskScope.PRIVATE, "test_type", "test_key")
|
||||
|
||||
|
||||
def test_get_finished_dedup_key():
|
||||
"""Test that finished tasks use UUID as dedup_key"""
|
||||
test_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
result = get_finished_dedup_key(test_uuid)
|
||||
assert result == test_uuid
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"progress,expected",
|
||||
[
|
||||
# Float (percentage) progress
|
||||
(0.5, {"progress_percent": 0.5}),
|
||||
(0.0, {"progress_percent": 0.0}),
|
||||
(1.0, {"progress_percent": 1.0}),
|
||||
(0.25, {"progress_percent": 0.25}),
|
||||
# Int (count only) progress
|
||||
(42, {"progress_current": 42}),
|
||||
(0, {"progress_current": 0}),
|
||||
(1000, {"progress_current": 1000}),
|
||||
# Tuple (current, total) progress with auto-computed percentage
|
||||
(
|
||||
(50, 100),
|
||||
{"progress_current": 50, "progress_total": 100, "progress_percent": 0.5},
|
||||
),
|
||||
(
|
||||
(25, 100),
|
||||
{"progress_current": 25, "progress_total": 100, "progress_percent": 0.25},
|
||||
),
|
||||
(
|
||||
(100, 100),
|
||||
{"progress_current": 100, "progress_total": 100, "progress_percent": 1.0},
|
||||
),
|
||||
# Tuple with zero total (no percentage computed)
|
||||
((10, 0), {"progress_current": 10, "progress_total": 0}),
|
||||
((0, 0), {"progress_current": 0, "progress_total": 0}),
|
||||
],
|
||||
)
|
||||
def test_progress_update(progress, expected):
|
||||
"""Test progress_update returns correct TaskProperties dict."""
|
||||
result = progress_update(progress)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_error_update():
|
||||
"""Test error_update captures exception details."""
|
||||
try:
|
||||
raise ValueError("Test error message")
|
||||
except ValueError as e:
|
||||
result = error_update(e)
|
||||
|
||||
assert result["error_message"] == "Test error message"
|
||||
assert result["exception_type"] == "ValueError"
|
||||
assert "stack_trace" in result
|
||||
assert "ValueError" in result["stack_trace"]
|
||||
|
||||
|
||||
def test_error_update_custom_exception():
|
||||
"""Test error_update with custom exception class."""
|
||||
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
raise CustomError("Custom error")
|
||||
except CustomError as e:
|
||||
result = error_update(e)
|
||||
|
||||
assert result["error_message"] == "Custom error"
|
||||
assert result["exception_type"] == "CustomError"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"json_str,expected",
|
||||
[
|
||||
# Valid JSON
|
||||
(
|
||||
'{"is_abortable": true, "progress_percent": 0.5}',
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
),
|
||||
(
|
||||
'{"error_message": "Something failed"}',
|
||||
{"error_message": "Something failed"},
|
||||
),
|
||||
(
|
||||
'{"progress_current": 50, "progress_total": 100}',
|
||||
{"progress_current": 50, "progress_total": 100},
|
||||
),
|
||||
# Empty/None cases
|
||||
("", {}),
|
||||
(None, {}),
|
||||
# Invalid JSON returns empty dict
|
||||
("not valid json", {}),
|
||||
("{broken", {}),
|
||||
# Unknown keys are preserved (forward compatibility)
|
||||
(
|
||||
'{"is_abortable": true, "future_field": "value"}',
|
||||
{"is_abortable": True, "future_field": "value"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_properties(json_str, expected):
|
||||
"""Test parse_properties parses JSON to TaskProperties dict."""
|
||||
result = parse_properties(json_str)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"props,expected_contains",
|
||||
[
|
||||
# Full properties
|
||||
(
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
),
|
||||
# Empty dict
|
||||
({}, {}),
|
||||
# Sparse properties
|
||||
({"is_abortable": True}, {"is_abortable": True}),
|
||||
({"error_message": "fail"}, {"error_message": "fail"}),
|
||||
],
|
||||
)
|
||||
def test_serialize_properties(props, expected_contains):
|
||||
"""Test serialize_properties converts TaskProperties to JSON."""
|
||||
from superset.utils import json
|
||||
|
||||
result = serialize_properties(props)
|
||||
parsed = json.loads(result)
|
||||
assert parsed == expected_contains
|
||||
|
||||
|
||||
def test_properties_roundtrip():
|
||||
"""Test that serialize -> parse roundtrip preserves data."""
|
||||
original = {
|
||||
"is_abortable": True,
|
||||
"progress_percent": 0.75,
|
||||
"error_message": "Test error",
|
||||
}
|
||||
serialized = serialize_properties(original)
|
||||
parsed = parse_properties(serialized)
|
||||
assert parsed == original
|
||||
|
||||
@@ -54,7 +54,7 @@ def test_json_loads_exception():
|
||||
|
||||
|
||||
def test_json_loads_encoding():
|
||||
unicode_data = b'{"a": "\u0073\u0074\u0072"}'
|
||||
unicode_data = rb'{"a": "\u0073\u0074\u0072"}'
|
||||
data = json.loads(unicode_data)
|
||||
assert data["a"] == "str"
|
||||
utf16_data = b'\xff\xfe{\x00"\x00a\x00"\x00:\x00 \x00"\x00s\x00t\x00r\x00"\x00}\x00'
|
||||
|
||||
@@ -119,7 +119,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
|
||||
was revoked), the invalid token should be deleted and the exception re-raised.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
|
||||
class OAuth2ExceptionError(Exception):
|
||||
pass
|
||||
@@ -149,7 +149,7 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
|
||||
exception re-raised.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
|
||||
class OAuth2ExceptionError(Exception):
|
||||
pass
|
||||
@@ -175,7 +175,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
This can happen when the refresh token was revoked.
|
||||
"""
|
||||
mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.get_oauth2_fresh_token.return_value = {
|
||||
"error": "invalid_grant",
|
||||
|
||||
Reference in New Issue
Block a user