feat: add global task framework (#36368)

This commit is contained in:
Ville Brofeldt
2026-02-09 10:45:56 -08:00
committed by GitHub
parent 6984e93171
commit 59dd2fa385
89 changed files with 15535 additions and 291 deletions

View File

@@ -73,6 +73,7 @@ FEATURE_FLAGS = {
"AVOID_COLORS_COLLISION": True,
"DRILL_TO_DETAIL": True,
"DRILL_BY": True,
"GLOBAL_TASK_FRAMEWORK": True,
}
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"

View File

@@ -0,0 +1,538 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for Task REST API"""
from contextlib import contextmanager
from typing import Generator
import prison
from superset_core.api.tasks import TaskStatus
from superset import db
from superset.models.tasks import Task
from superset.utils import json
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import (
ADMIN_USERNAME,
GAMMA_USERNAME,
)
class TestTaskApi(SupersetTestCase):
    """Tests for Task REST API"""

    # Base route for every task endpoint exercised below.
    TASK_API_BASE = "api/v1/task"

    @contextmanager
    def _create_tasks(self) -> Generator[list[Task], None, None]:
        """
        Context manager to create test tasks with guaranteed cleanup.

        Uses TaskDAO to create tasks, testing the actual production code path.

        Creates five private tasks for the admin user (even indexes stay
        PENDING, odd indexes are driven to SUCCESS) plus one PENDING task
        for the gamma user — six tasks total.

        Usage:
            with self._create_tasks() as tasks:
                # Use tasks in test
                # Cleanup happens automatically even if test fails
        """
        from superset_core.api.tasks import TaskScope

        from superset.daos.tasks import TaskDAO

        admin = self.get_user("admin")
        gamma = self.get_user("gamma")
        tasks = []
        try:
            # Create tasks with different statuses using TaskDAO
            for i in range(5):
                task_key = f"test_task_{i}"
                # Create task using DAO (this tests the dedup_key creation logic)
                task = TaskDAO.create_task(
                    task_type="test_type",
                    task_key=task_key,
                    task_name=f"Test Task {i}",
                    scope=TaskScope.PRIVATE,
                    user_id=admin.id,
                    payload={"test": "data"},
                )
                # Set created_by for test purposes (DAO uses Flask-AppBuilder context)
                task.created_by = admin
                # Alternate between pending and finished tasks
                if i % 2 != 0:
                    # Simulate realistic task lifecycle: PENDING → IN_PROGRESS → SUCCESS
                    # This sets both started_at (on IN_PROGRESS) and ended_at (on
                    # SUCCESS) so duration_seconds returns a valid value
                    task.set_status(TaskStatus.IN_PROGRESS)
                    task.set_status(TaskStatus.SUCCESS)
                db.session.commit()
                tasks.append(task)
            # Create pending task for gamma user (use PENDING so it can be aborted)
            gamma_task = TaskDAO.create_task(
                task_type="test_type",
                task_key="gamma_task",
                task_name="Gamma Task",
                scope=TaskScope.PRIVATE,
                user_id=gamma.id,
                payload={"user": "gamma"},
            )
            # Set created_by for test purposes
            gamma_task.created_by = gamma
            db.session.commit()
            tasks.append(gamma_task)
            yield tasks
        finally:
            # Cleanup happens here regardless of test success/failure
            for task in tasks:
                try:
                    db.session.delete(task)
                except Exception:  # noqa: S110
                    # Task may already be deleted or session may be in bad state
                    pass
            try:
                db.session.commit()
            except Exception:
                # Rollback if commit fails
                db.session.rollback()

    def test_info_task(self):
        """
        Task API: Test info endpoint
        """
        self.login(ADMIN_USERNAME)
        uri = f"{self.TASK_API_BASE}/_info"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        # _info advertises the caller's permissions on this resource.
        assert "permissions" in data

    def test_get_task_by_uuid(self):
        """
        Task API: Test get task by UUID and verify dedup_key is hashed
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            admin = self.get_user("admin")
            # Get a pending task to verify active dedup_key format
            task = (
                db.session.query(Task)
                .filter_by(
                    created_by_fk=admin.id,
                    status=TaskStatus.PENDING.value,
                    task_type="test_type",
                )
                .first()
            )
            assert task is not None
            # Verify active task has hashed dedup_key (64 chars for SHA-256)
            assert len(task.dedup_key) == 64
            assert all(c in "0123456789abcdef" for c in task.dedup_key)
            assert task.dedup_key != str(task.uuid)
            uri = f"{self.TASK_API_BASE}/{task.uuid}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            # Compare strings since JSON response contains string UUID
            assert data["result"]["uuid"] == str(task.uuid)
            assert data["result"]["id"] == task.id

    def test_get_task_not_found(self):
        """
        Task API: Test get task not found with non-existent UUID
        """
        self.login(ADMIN_USERNAME)
        # Use a valid UUID that doesn't exist in the database
        uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000"
        rv = self.client.get(uri)
        assert rv.status_code == 404

    def test_get_task_invalid_uuid(self):
        """
        Task API: Test get task with invalid UUID
        """
        self.login(ADMIN_USERNAME)
        # A malformed UUID is answered with 404 (not a 4xx validation error).
        uri = f"{self.TASK_API_BASE}/invalid-uuid"
        rv = self.client.get(uri)
        assert rv.status_code == 404

    def test_get_task_list(self):
        """
        Task API: Test get task list
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            uri = f"{self.TASK_API_BASE}/"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] >= 6  # At least the fixtures we created
            assert "result" in data

    def test_get_task_list_filtered_by_status(self):
        """
        Task API: Test get task list filtered by status
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            # Rison-encoded FAB filter: status == PENDING
            arguments = {
                "filters": [
                    {"col": "status", "opr": "eq", "value": TaskStatus.PENDING.value}
                ]
            }
            uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            for task in data["result"]:
                assert task["status"] == TaskStatus.PENDING.value

    def test_get_task_list_filtered_by_type(self):
        """
        Task API: Test get task list filtered by type
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            arguments = {
                "filters": [{"col": "task_type", "opr": "eq", "value": "test_type"}]
            }
            uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] >= 6
            for task in data["result"]:
                assert task["task_type"] == "test_type"

    def test_get_task_list_ordered(self):
        """
        Task API: Test get task list with ordering
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            arguments = {
                "order_column": "created_on",
                "order_direction": "desc",
            }
            uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            # Only checks the ordered query succeeds and returns rows;
            # the actual sort order is not asserted here.
            assert len(data["result"]) > 0

    def test_get_task_list_paginated(self):
        """
        Task API: Test get task list with pagination
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            arguments = {"page": 0, "page_size": 2}
            uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            # Page is capped at page_size, but count reflects the full set.
            assert len(data["result"]) <= 2
            assert data["count"] >= 6

    def test_cancel_task_by_uuid(self):
        """
        Task API: Test cancel task by UUID
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            admin = self.get_user("admin")
            task = (
                db.session.query(Task)
                .filter_by(created_by_fk=admin.id, status=TaskStatus.PENDING.value)
                .first()
            )
            assert task is not None
            uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
            rv = self.client.post(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            # Compare strings since JSON response contains string UUID
            assert data["task"]["uuid"] == str(task.uuid)
            assert data["task"]["status"] == TaskStatus.ABORTED.value
            assert data["action"] == "aborted"

    def test_cancel_task_not_found(self):
        """
        Task API: Test cancel task not found with non-existent UUID
        """
        self.login(ADMIN_USERNAME)
        uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/cancel"
        rv = self.client.post(uri)
        assert rv.status_code == 404

    def test_cancel_task_not_owned(self):
        """
        Task API: Test cancel task not owned by user
        """
        with self._create_tasks():
            self.login(GAMMA_USERNAME)
            admin = self.get_user("admin")
            # Try to cancel admin's task as gamma user
            task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
            assert task is not None
            uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
            rv = self.client.post(uri)
            # 404 (not 403): the base filter hides the row entirely from gamma.
            assert rv.status_code == 404

    def test_cancel_task_admin_can_cancel_others(self):
        """
        Task API: Test admin can cancel other users' tasks
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            gamma = self.get_user("gamma")
            # Admin cancels gamma's task
            task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
            assert task is not None
            uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
            rv = self.client.post(uri)
            assert rv.status_code == 200

    def test_get_task_status_by_uuid(self):
        """
        Task API: Test get task status by UUID
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            admin = self.get_user("admin")
            task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
            assert task is not None
            uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            assert "status" in data
            assert data["status"] == task.status

    def test_get_task_status_not_found(self):
        """
        Task API: Test get task status not found with non-existent UUID
        """
        self.login(ADMIN_USERNAME)
        uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/status"
        rv = self.client.get(uri)
        assert rv.status_code == 404

    def test_get_task_status_not_owned(self):
        """
        Task API: Test non-owner can't see task status
        """
        with self._create_tasks():
            self.login(GAMMA_USERNAME)
            admin = self.get_user("admin")
            # Try to get status of admin's task as gamma user
            task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
            assert task is not None
            uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
            rv = self.client.get(uri)
            # Should be forbidden due to base filter
            assert rv.status_code == 404

    def test_get_task_status_admin_can_see_others(self):
        """
        Task API: Test admin can see other users' task status
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            gamma = self.get_user("gamma")
            # Admin gets gamma's task status
            task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
            assert task is not None
            uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            assert data["status"] == task.status

    def test_get_task_list_user_sees_own_tasks(self):
        """
        Task API: Test non-admin user only sees their own tasks
        """
        with self._create_tasks():
            self.login(GAMMA_USERNAME)
            gamma = self.get_user("gamma")
            uri = f"{self.TASK_API_BASE}/"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            # Gamma should only see their own task
            for task in data["result"]:
                assert task["created_by"]["id"] == gamma.id

    def test_get_task_list_admin_sees_all_tasks(self):
        """
        Task API: Test admin sees all tasks
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            uri = f"{self.TASK_API_BASE}/"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            # Admin should see all tasks
            assert data["count"] >= 6

    def test_task_response_schema(self):
        """
        Task API: Test response schema includes all expected fields
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            admin = self.get_user("admin")
            # NOTE(review): unlike the sibling tests there is no
            # `assert task is not None` guard here; a missing fixture would
            # surface as an AttributeError on `task.uuid` below.
            task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
            uri = f"{self.TASK_API_BASE}/{task.uuid}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            result = data["result"]
            # Check all expected fields are present
            expected_fields = [
                "id",
                "uuid",
                "task_key",
                "task_type",
                "task_name",
                "status",
                "created_on",
                "created_on_delta_humanized",
                "changed_on",
                "changed_by",
                "started_at",
                "ended_at",
                "created_by",
                "user_id",
                "payload",
                "properties",
                "duration_seconds",
                "scope",
                "subscriber_count",
                "subscribers",
            ]
            for field in expected_fields:
                assert field in result, f"Field {field} missing from response"
            # Verify properties is a dict with expected structure
            properties = result["properties"]
            assert isinstance(properties, dict)

    def test_task_payload_serialization(self):
        """
        Task API: Test payload is properly serialized as dict
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            admin = self.get_user("admin")
            # NOTE(review): no `assert task is not None` guard before
            # dereferencing `task.uuid` — consider adding one.
            task = (
                db.session.query(Task)
                .filter_by(created_by_fk=admin.id, task_type="test_type")
                .first()
            )
            uri = f"{self.TASK_API_BASE}/{task.uuid}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            payload = data["result"]["payload"]
            # Payload should be a dict, not a string
            assert isinstance(payload, dict)
            assert "test" in payload
            assert payload["test"] == "data"

    def test_task_computed_properties(self):
        """
        Task API: Test computed properties in response

        This test verifies that computed properties (status, duration_seconds)
        are correctly returned in the API response. Internal DB columns like
        dedup_key are tested in unit tests (test_find_by_task_key_finished_not_found).
        """
        with self._create_tasks():
            self.login(ADMIN_USERNAME)
            admin = self.get_user("admin")
            # Get a successful task
            task = (
                db.session.query(Task)
                .filter_by(created_by_fk=admin.id, status=TaskStatus.SUCCESS.value)
                .first()
            )
            assert task is not None
            uri = f"{self.TASK_API_BASE}/{task.uuid}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            result = data["result"]
            # Check status field (computed properties are now derived from status)
            assert result["status"] == TaskStatus.SUCCESS.value
            # Properties dict should exist and be a dict
            assert "properties" in result
            assert isinstance(result["properties"], dict)
            # Verify duration_seconds is not null for completed tasks with timestamps
            # (requires both started_at and ended_at to be set)
            if result.get("started_at") and result.get("ended_at"):
                assert result["duration_seconds"] is not None
                assert result["duration_seconds"] >= 0.0

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,482 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import patch
from uuid import UUID, uuid4
import pytest
from superset_core.api.tasks import TaskScope, TaskStatus
from superset import db
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.commands.tasks.exceptions import (
TaskAbortFailedError,
TaskNotAbortableError,
TaskNotFoundError,
TaskPermissionDeniedError,
)
from superset.daos.tasks import TaskDAO
from superset.utils.core import override_user
from tests.integration_tests.test_app import app
def test_cancel_pending_task_aborts(app_context, get_user) -> None:
    """Test canceling a pending task directly aborts it.

    A PENDING task has no running worker to signal, so cancellation
    transitions it straight to ABORTED (no intermediate ABORTING state).
    """
    admin = get_user("admin")
    # Use a unique task_key (consistent with the other tests in this module)
    # so a row left behind by a previously failed run cannot collide on dedup_key.
    unique_key = f"cancel_pending_test_{uuid4().hex[:8]}"
    # Create a pending private task
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    db.session.commit()
    try:
        # Cancel the pending task with admin user context
        with override_user(admin):
            command = CancelTaskCommand(task_uuid=task.uuid)
            result = command.run()
            # Verify task is aborted (pending goes directly to ABORTED)
            assert result.uuid == task.uuid
            assert result.status == TaskStatus.ABORTED.value
            assert command.action_taken == "aborted"
            # Verify in database
            db.session.refresh(task)
            assert task.status == TaskStatus.ABORTED.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_cancel_in_progress_abortable_task_sets_aborting(app_context, get_user) -> None:
    """Test canceling an in-progress task with abort handler sets ABORTING.

    The abort is asynchronous: the command publishes an abort signal and marks
    the task ABORTING; the worker is responsible for the terminal ABORTED.
    """
    admin = get_user("admin")
    # Unique task_key (consistent with sibling tests) prevents dedup_key
    # collisions with rows left behind by earlier failed runs.
    unique_key = f"cancel_in_progress_test_{uuid4().hex[:8]}"
    # Create an in-progress abortable task
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
        properties={"is_abortable": True},
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        # Cancel the in-progress task - mock publish_abort to avoid Redis dependency
        with (
            override_user(admin),
            patch("superset.tasks.manager.TaskManager.publish_abort"),
        ):
            command = CancelTaskCommand(task_uuid=task.uuid)
            result = command.run()
            # In-progress tasks go to ABORTING (not ABORTED)
            assert result.uuid == task.uuid
            assert result.status == TaskStatus.ABORTING.value
            assert command.action_taken == "aborted"
            # Verify in database
            db.session.refresh(task)
            assert task.status == TaskStatus.ABORTING.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_cancel_in_progress_not_abortable_raises_error(app_context, get_user) -> None:
    """Canceling an in-progress task that lacks an abort handler must fail."""
    admin = get_user("admin")
    unique_key = f"cancel_not_abortable_test_{uuid4().hex[:8]}"
    # Build a running task that explicitly opts out of aborting.
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
        properties={"is_abortable": False},
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        with override_user(admin):
            cmd = CancelTaskCommand(task_uuid=task.uuid)
            with pytest.raises(TaskNotAbortableError):
                cmd.run()
            # The rejected cancel must not have touched the status.
            db.session.refresh(task)
            assert task.status == TaskStatus.IN_PROGRESS.value
    finally:
        db.session.delete(task)
        db.session.commit()
def test_cancel_task_not_found(app_context, get_user) -> None:
    """Canceling a task that does not exist must raise TaskNotFoundError."""
    missing_uuid = UUID("00000000-0000-0000-0000-000000000000")
    with override_user(get_user("admin")):
        cmd = CancelTaskCommand(task_uuid=missing_uuid)
        with pytest.raises(TaskNotFoundError):
            cmd.run()
def test_cancel_finished_task_raises_error(app_context, get_user) -> None:
    """Canceling a task that already completed must raise TaskAbortFailedError."""
    admin = get_user("admin")
    unique_key = f"cancel_finished_test_{uuid4().hex[:8]}"
    # Build a task that has already run to completion.
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.SUCCESS)
    db.session.commit()
    try:
        with override_user(admin):
            cmd = CancelTaskCommand(task_uuid=task.uuid)
            with pytest.raises(TaskAbortFailedError):
                cmd.run()
            # A terminal task stays terminal.
            db.session.refresh(task)
            assert task.status == TaskStatus.SUCCESS.value
    finally:
        db.session.delete(task)
        db.session.commit()
def test_cancel_shared_task_with_multiple_subscribers_unsubscribes(
    app_context, get_user
) -> None:
    """Test canceling a shared task with multiple subscribers unsubscribes user.

    When other subscribers remain, "cancel" only removes the caller's
    subscription; the underlying task keeps running for everyone else.
    """
    admin = get_user("admin")
    gamma = get_user("gamma")
    # Unique task_key (consistent with sibling tests) prevents dedup_key
    # collisions with rows left behind by earlier failed runs.
    unique_key = f"cancel_shared_test_{uuid4().hex[:8]}"
    # Create a shared task with admin as creator
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.SHARED,
        user_id=admin.id,
    )
    task.created_by = admin
    db.session.commit()
    # Add gamma as subscriber
    TaskDAO.add_subscriber(task.id, user_id=gamma.id)
    db.session.commit()
    try:
        # Verify we have 2 subscribers
        db.session.refresh(task)
        assert task.subscriber_count == 2
        # Cancel as gamma (non-admin subscriber)
        with override_user(gamma):
            command = CancelTaskCommand(task_uuid=task.uuid)
            result = command.run()
            # Should unsubscribe, not abort
            assert command.action_taken == "unsubscribed"
            assert result.status == TaskStatus.PENDING.value  # Status unchanged
            # Verify gamma was unsubscribed
            db.session.refresh(task)
            assert task.subscriber_count == 1
            assert not task.has_subscriber(gamma.id)
            assert task.has_subscriber(admin.id)
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_cancel_shared_task_last_subscriber_aborts(app_context, get_user) -> None:
    """Test canceling a shared task as last subscriber aborts it.

    With nobody else subscribed there is no one left to keep the task
    alive for, so cancel escalates to an actual abort.
    """
    admin = get_user("admin")
    # Unique task_key (consistent with sibling tests) prevents dedup_key
    # collisions with rows left behind by earlier failed runs.
    unique_key = f"cancel_last_subscriber_test_{uuid4().hex[:8]}"
    # Create a shared task with only admin as subscriber
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.SHARED,
        user_id=admin.id,
    )
    task.created_by = admin
    db.session.commit()
    try:
        # Verify only 1 subscriber
        db.session.refresh(task)
        assert task.subscriber_count == 1
        # Cancel as the only subscriber
        with override_user(admin):
            command = CancelTaskCommand(task_uuid=task.uuid)
            result = command.run()
            # Should abort since last subscriber
            assert command.action_taken == "aborted"
            assert result.status == TaskStatus.ABORTED.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_cancel_with_force_aborts_for_all_subscribers(app_context, get_user) -> None:
    """Test force cancel aborts shared task even with multiple subscribers.

    With force=True an admin aborts the task outright instead of merely
    unsubscribing, even though other subscribers still depend on it.
    """
    admin = get_user("admin")
    gamma = get_user("gamma")
    # Unique task_key (consistent with sibling tests) prevents dedup_key
    # collisions with rows left behind by earlier failed runs.
    unique_key = f"force_cancel_test_{uuid4().hex[:8]}"
    # Create a shared task with multiple subscribers
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.SHARED,
        user_id=admin.id,
    )
    task.created_by = admin
    db.session.commit()
    # Add gamma as subscriber
    TaskDAO.add_subscriber(task.id, user_id=gamma.id)
    db.session.commit()
    try:
        # Verify 2 subscribers
        db.session.refresh(task)
        assert task.subscriber_count == 2
        # Force cancel as admin
        with override_user(admin):
            command = CancelTaskCommand(task_uuid=task.uuid, force=True)
            result = command.run()
            # Should abort despite multiple subscribers
            assert command.action_taken == "aborted"
            assert result.status == TaskStatus.ABORTED.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_cancel_with_force_requires_admin(app_context, get_user) -> None:
    """Test force cancel requires admin privileges.

    A non-admin subscriber attempting force=True gets
    TaskPermissionDeniedError and the task is left untouched.
    """
    admin = get_user("admin")
    gamma = get_user("gamma")
    # Unique task_key (consistent with sibling tests) prevents dedup_key
    # collisions with rows left behind by earlier failed runs.
    unique_key = f"force_requires_admin_test_{uuid4().hex[:8]}"
    # Create a shared task
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.SHARED,
        user_id=admin.id,
    )
    task.created_by = admin
    db.session.commit()
    # Add gamma as subscriber
    TaskDAO.add_subscriber(task.id, user_id=gamma.id)
    db.session.commit()
    try:
        # Try to force cancel as gamma (non-admin)
        with override_user(gamma):
            command = CancelTaskCommand(task_uuid=task.uuid, force=True)
            with pytest.raises(TaskPermissionDeniedError):
                command.run()
            # Verify task unchanged
            db.session.refresh(task)
            assert task.status == TaskStatus.PENDING.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_cancel_private_task_permission_denied(app_context, get_user) -> None:
    """A private task is invisible — hence uncancelable — to non-owners."""
    admin = get_user("admin")
    gamma = get_user("gamma")
    unique_key = f"private_permission_test_{uuid4().hex[:8]}"
    # Use test_request_context to ensure has_request_context() returns True
    # so that TaskFilter properly applies permission filtering
    with app.test_request_context():
        # Admin owns a private task that gamma should not even be able to see.
        task = TaskDAO.create_task(
            task_type="test_type",
            task_key=unique_key,
            scope=TaskScope.PRIVATE,
            user_id=admin.id,
        )
        task.created_by = admin
        db.session.commit()
        try:
            with override_user(gamma):
                cmd = CancelTaskCommand(task_uuid=task.uuid)
                # The base filter hides the row, so the command reports
                # "not found" rather than leaking the task's existence.
                with pytest.raises(TaskNotFoundError):
                    cmd.run()
            # Nothing about the task changed.
            db.session.refresh(task)
            assert task.status == TaskStatus.PENDING.value
        finally:
            db.session.delete(task)
            db.session.commit()
def test_cancel_system_task_requires_admin(app_context, get_user) -> None:
    """Only admins may cancel SYSTEM-scoped tasks."""
    admin = get_user("admin")
    gamma = get_user("gamma")
    unique_key = f"system_task_test_{uuid4().hex[:8]}"
    # Use test_request_context to ensure has_request_context() returns True
    # so that TaskFilter properly applies permission filtering
    with app.test_request_context():
        # A system task is created without an owning user.
        task = TaskDAO.create_task(
            task_type="test_type",
            task_key=unique_key,
            scope=TaskScope.SYSTEM,
            user_id=None,
        )
        task.created_by = admin
        db.session.commit()
        try:
            # Non-admins cannot even see system tasks through the base filter.
            with override_user(gamma):
                cmd = CancelTaskCommand(task_uuid=task.uuid)
                with pytest.raises(TaskNotFoundError):
                    cmd.run()
            # The failed attempt left the task as-is.
            db.session.refresh(task)
            assert task.status == TaskStatus.PENDING.value
            # An admin, however, can abort it.
            with override_user(admin):
                outcome = CancelTaskCommand(task_uuid=task.uuid).run()
                assert outcome.status == TaskStatus.ABORTED.value
        finally:
            db.session.delete(task)
            db.session.commit()
def test_cancel_already_aborting_is_idempotent(app_context, get_user) -> None:
    """Test canceling an already aborting task is idempotent.

    Re-issuing cancel while the task is still ABORTING succeeds without
    error and leaves the status at ABORTING.
    """
    admin = get_user("admin")
    # Unique task_key (consistent with sibling tests) prevents dedup_key
    # collisions with rows left behind by earlier failed runs.
    unique_key = f"idempotent_cancel_test_{uuid4().hex[:8]}"
    # Create a task already in ABORTING state
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.ABORTING)
    db.session.commit()
    try:
        # Cancel the already aborting task
        with override_user(admin):
            command = CancelTaskCommand(task_uuid=task.uuid)
            result = command.run()
            # Should succeed without error
            assert result.uuid == task.uuid
            assert result.status == TaskStatus.ABORTING.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_cancel_shared_task_not_subscribed_raises_error(app_context, get_user) -> None:
    """Test non-subscriber cannot cancel shared task.

    A shared task is visible to others, but only its subscribers may cancel;
    non-subscribers get TaskPermissionDeniedError.
    """
    admin = get_user("admin")
    gamma = get_user("gamma")
    # Unique task_key (consistent with sibling tests) prevents dedup_key
    # collisions with rows left behind by earlier failed runs.
    unique_key = f"not_subscribed_test_{uuid4().hex[:8]}"
    # Create a shared task with only admin as subscriber
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key=unique_key,
        scope=TaskScope.SHARED,
        user_id=admin.id,
    )
    task.created_by = admin
    db.session.commit()
    try:
        # Try to cancel as gamma (not subscribed)
        with override_user(gamma):
            command = CancelTaskCommand(task_uuid=task.uuid)
            with pytest.raises(TaskPermissionDeniedError):
                command.run()
            # Verify task unchanged
            db.session.refresh(task)
            assert task.status == TaskStatus.PENDING.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()

View File

@@ -0,0 +1,419 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for internal task state update commands."""
from uuid import UUID
from superset_core.api.tasks import TaskScope, TaskStatus
from superset import db
from superset.commands.tasks.internal_update import (
InternalStatusTransitionCommand,
InternalUpdateTaskCommand,
)
from superset.daos.tasks import TaskDAO
def test_internal_update_properties(app_context, get_user, login_as) -> None:
    """Properties can be written via a zero-read internal update."""
    admin = get_user("admin")
    login_as("admin")
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="internal_update_props",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        new_props = {"is_abortable": True, "progress_percent": 0.5}
        # The command issues a direct UPDATE without loading the row first.
        cmd = InternalUpdateTaskCommand(task_uuid=task.uuid, properties=new_props)
        assert cmd.run() is True
        # Every written key must round-trip through the database.
        db.session.refresh(task)
        for key, expected in new_props.items():
            assert task.properties_dict.get(key) == expected
    finally:
        db.session.delete(task)
        db.session.commit()
def test_internal_update_payload(app_context, get_user, login_as) -> None:
    """Payload can be written via a zero-read internal update."""
    admin = get_user("admin")
    login_as("admin")
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="internal_update_payload",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        cmd = InternalUpdateTaskCommand(
            task_uuid=task.uuid,
            payload={"custom_key": "value", "count": 42},
        )
        assert cmd.run() is True
        # The stored payload equals exactly what we sent (full replacement).
        db.session.refresh(task)
        assert task.payload_dict == {"custom_key": "value", "count": 42}
    finally:
        db.session.delete(task)
        db.session.commit()
def test_internal_update_both_properties_and_payload(
    app_context, get_user, login_as
) -> None:
    """A single internal update can write properties and payload together."""
    admin = get_user("admin")
    login_as("admin")
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="internal_update_both",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        outcome = InternalUpdateTaskCommand(
            task_uuid=task.uuid,
            properties={"progress_current": 50, "progress_total": 100},
            payload={"last_item": "xyz"},
        ).run()
        assert outcome is True
        # Both JSON columns were written by the one command invocation.
        db.session.refresh(task)
        assert task.properties_dict.get("progress_current") == 50
        assert task.properties_dict.get("progress_total") == 100
        assert task.payload_dict == {"last_item": "xyz"}
    finally:
        db.session.delete(task)
        db.session.commit()
def test_internal_update_returns_false_for_nonexistent_task(
    app_context, login_as
) -> None:
    """Updating a task that does not exist reports False, not an exception."""
    login_as("admin")
    ghost_uuid = UUID("00000000-0000-0000-0000-000000000000")
    cmd = InternalUpdateTaskCommand(
        task_uuid=ghost_uuid,
        properties={"is_abortable": True},
    )
    # No matching row -> nothing updated -> False.
    assert cmd.run() is False
def test_internal_update_returns_false_when_nothing_to_update(
    app_context, get_user, login_as
) -> None:
    """Test that passing no properties or payload returns False early.

    The command should short-circuit without issuing any UPDATE.
    """
    admin = get_user("admin")
    login_as("admin")
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="internal_update_empty",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    db.session.commit()
    try:
        # No properties or payload provided
        command = InternalUpdateTaskCommand(
            task_uuid=task.uuid,
            properties=None,
            payload=None,
        )
        result = command.run()
        assert result is False
    finally:
        db.session.delete(task)
        db.session.commit()
def test_internal_update_does_not_change_status(
    app_context, get_user, login_as
) -> None:
    """Test that internal update leaves status unchanged (safe for concurrent abort)."""
    admin = get_user("admin")
    login_as("admin")
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="internal_update_status",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        # Update properties - status should not change; this is what makes
        # the command safe to race against an abort flipping the status.
        command = InternalUpdateTaskCommand(
            task_uuid=task.uuid,
            properties={"progress_percent": 0.75},
        )
        result = command.run()
        assert result is True
        # Verify status unchanged
        db.session.refresh(task)
        assert task.status == TaskStatus.IN_PROGRESS.value
    finally:
        db.session.delete(task)
        db.session.commit()
def test_internal_update_replaces_entire_properties(
    app_context, get_user, login_as
) -> None:
    """Test that internal update replaces properties entirely (no merge).

    The command writes the properties column verbatim; merging pre-existing
    keys is the caller's responsibility.
    """
    admin = get_user("admin")
    login_as("admin")
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="internal_update_replace",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
        properties={"is_abortable": True, "timeout": 300},
    )
    task.created_by = admin
    db.session.commit()
    try:
        # Replace with new properties (caller is responsible for merging if needed)
        command = InternalUpdateTaskCommand(
            task_uuid=task.uuid,
            properties={"error_message": "new_value"},
        )
        result = command.run()
        assert result is True
        # Verify entire replacement occurred
        db.session.refresh(task)
        # The caller should have merged if they wanted to preserve is_abortable
        assert task.properties_dict == {"error_message": "new_value"}
        assert "is_abortable" not in task.properties_dict
        assert "timeout" not in task.properties_dict
    finally:
        db.session.delete(task)
        db.session.commit()
# =============================================================================
# InternalStatusTransitionCommand Tests
# =============================================================================
def test_status_transition_atomic_compare_and_swap(
    app_context, get_user, login_as
) -> None:
    """Test atomic conditional status transitions with comprehensive scenarios.
    Covers: success case, failure case, list of expected statuses, properties update,
    ended_at timestamp, and string status values.

    Each scenario exercises the compare-and-swap contract: the transition
    applies only when the current status matches the expected one(s).
    """
    admin = get_user("admin")
    login_as("admin")
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="status_transition_comprehensive",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    db.session.commit()
    try:
        # 1. SUCCESS CASE: PENDING → IN_PROGRESS (expected matches)
        result = InternalStatusTransitionCommand(
            task_uuid=task.uuid,
            new_status=TaskStatus.IN_PROGRESS,
            expected_status=TaskStatus.PENDING,
        ).run()
        assert result is True
        db.session.refresh(task)
        assert task.status == TaskStatus.IN_PROGRESS.value
        # 2. FAILURE CASE: Try wrong expected status (should fail, status unchanged)
        result = InternalStatusTransitionCommand(
            task_uuid=task.uuid,
            new_status=TaskStatus.SUCCESS,
            expected_status=TaskStatus.PENDING,  # Wrong! Current is IN_PROGRESS
        ).run()
        assert result is False
        db.session.refresh(task)
        assert task.status == TaskStatus.IN_PROGRESS.value  # Unchanged
        # 3. LIST OF EXPECTED: Transition with multiple acceptable source statuses
        task.set_status(TaskStatus.ABORTING)
        db.session.commit()
        result = InternalStatusTransitionCommand(
            task_uuid=task.uuid,
            new_status=TaskStatus.FAILURE,
            expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
            properties={"error_message": "Test error"},
        ).run()
        assert result is True
        db.session.refresh(task)
        assert task.status == TaskStatus.FAILURE.value
        assert task.properties_dict.get("error_message") == "Test error"
        # 4. ENDED_AT: Reset to IN_PROGRESS and test ended_at timestamp
        task.set_status(TaskStatus.IN_PROGRESS)
        task.ended_at = None
        db.session.commit()
        assert task.ended_at is None
        result = InternalStatusTransitionCommand(
            task_uuid=task.uuid,
            new_status=TaskStatus.SUCCESS,
            expected_status=TaskStatus.IN_PROGRESS,
            set_ended_at=True,
        ).run()
        assert result is True
        db.session.refresh(task)
        assert task.status == TaskStatus.SUCCESS.value
        assert task.ended_at is not None
        # 5. STRING VALUES: Reset and test string status values
        task.set_status(TaskStatus.PENDING)
        db.session.commit()
        result = InternalStatusTransitionCommand(
            task_uuid=task.uuid,
            new_status="in_progress",
            expected_status="pending",
        ).run()
        assert result is True
        db.session.refresh(task)
        assert task.status == "in_progress"
    finally:
        db.session.delete(task)
        db.session.commit()
def test_status_transition_prevents_race_condition(
    app_context, get_user, login_as
) -> None:
    """Test that conditional update prevents overwriting concurrent abort.
    This is the key race condition fix: if task is aborted concurrently,
    the executor's attempt to set SUCCESS should fail (return False),
    preserving the ABORTING state.
    """
    admin = get_user("admin")
    login_as("admin")
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="status_transition_race",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        # Simulate concurrent abort: directly set ABORTING in DB
        # (as if CancelTaskCommand ran in another process)
        task.set_status(TaskStatus.ABORTING)
        db.session.commit()
        # Executor tries to set SUCCESS (expecting IN_PROGRESS) - stale expectation
        result = InternalStatusTransitionCommand(
            task_uuid=task.uuid,
            new_status=TaskStatus.SUCCESS,
            expected_status=TaskStatus.IN_PROGRESS,
        ).run()
        # Should fail - task was aborted concurrently
        assert result is False
        # Verify ABORTING is preserved (not overwritten to SUCCESS)
        db.session.refresh(task)
        assert task.status == TaskStatus.ABORTING.value
        # Verify correct transition from ABORTING still works
        result = InternalStatusTransitionCommand(
            task_uuid=task.uuid,
            new_status=TaskStatus.ABORTED,
            expected_status=TaskStatus.ABORTING,
            set_ended_at=True,
        ).run()
        assert result is True
        db.session.refresh(task)
        assert task.status == TaskStatus.ABORTED.value
    finally:
        db.session.delete(task)
        db.session.commit()
def test_status_transition_nonexistent_task(app_context, login_as) -> None:
    """Test that transitioning non-existent task returns False."""
    login_as("admin")
    # The all-zero UUID is guaranteed not to match any task row.
    command = InternalStatusTransitionCommand(
        task_uuid=UUID(int=0),
        new_status=TaskStatus.IN_PROGRESS,
        expected_status=TaskStatus.PENDING,
    )
    assert command.run() is False

View File

@@ -0,0 +1,258 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime, timezone
from unittest.mock import patch
from freezegun import freeze_time
from superset_core.api.tasks import TaskScope, TaskStatus
from superset import db
from superset.commands.tasks import TaskPruneCommand
from superset.daos.tasks import TaskDAO
from superset.models.tasks import Task
@freeze_time("2024-02-15")
@patch("superset.tasks.utils.get_current_user")
def test_prune_tasks_success(mock_get_user, app_context, get_user, login_as) -> None:
    """Test successful pruning of old completed tasks.

    With "now" frozen at 2024-02-15, tasks ended 35 days ago fall outside
    the 30-day retention window and must be deleted, while a task ended
    5 days ago must survive.
    """
    login_as("admin")
    admin = get_user("admin")
    mock_get_user.return_value = admin.username
    # Create old completed tasks (35 days ago = Jan 11, 2024)
    old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
    task_ids = []
    for i in range(3):
        task = TaskDAO.create_task(
            task_type="test_type",
            task_key=f"prune_task_{i}",
            scope=TaskScope.PRIVATE,
            user_id=admin.id,
        )
        task.created_by = admin
        task.set_status(TaskStatus.SUCCESS)
        task.ended_at = old_date
        task_ids.append(task.id)
    # Create a recent task (5 days ago = Feb 10, 2024) that should NOT be deleted
    recent_date = datetime(2024, 2, 10, tzinfo=timezone.utc)
    recent_task = TaskDAO.create_task(
        task_type="test_type",
        task_key="recent_task",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    recent_task.created_by = admin
    recent_task.set_status(TaskStatus.SUCCESS)
    recent_task.ended_at = recent_date
    recent_task_id = recent_task.id
    db.session.commit()
    try:
        # Prune tasks older than 30 days
        command = TaskPruneCommand(retention_period_days=30)
        command.run()
        # Verify old tasks are deleted
        for task_id in task_ids:
            assert db.session.get(Task, task_id) is None
        # Verify recent task is NOT deleted
        assert db.session.get(Task, recent_task_id) is not None
    finally:
        # Cleanup remaining tasks
        for task_id in task_ids:
            existing = db.session.get(Task, task_id)
            if existing:
                db.session.delete(existing)
        if db.session.get(Task, recent_task_id):
            db.session.delete(db.session.get(Task, recent_task_id))
        db.session.commit()
@freeze_time("2024-02-15")
@patch("superset.tasks.utils.get_current_user")
def test_prune_tasks_with_max_rows(
    mock_get_user, app_context, get_user, login_as
) -> None:
    """Test pruning with max_rows_per_run limit.

    With five expired tasks and max_rows_per_run=2, exactly two rows
    should be removed in a single run.
    """
    login_as("admin")
    admin = get_user("admin")
    mock_get_user.return_value = admin.username
    # Create old completed tasks (35 days ago = Jan 11, 2024)
    task_ids = []
    for i in range(5):
        # Different ages for ordering (older tasks have smaller hour values)
        old_date = datetime(2024, 1, 11, i, tzinfo=timezone.utc)
        task = TaskDAO.create_task(
            task_type="test_type",
            task_key=f"max_rows_task_{i}",
            scope=TaskScope.PRIVATE,
            user_id=admin.id,
        )
        task.created_by = admin
        task.set_status(TaskStatus.SUCCESS)
        task.ended_at = old_date
        task_ids.append(task.id)
    db.session.commit()
    try:
        # Prune with max_rows_per_run=2 (should only delete 2 oldest)
        command = TaskPruneCommand(retention_period_days=30, max_rows_per_run=2)
        command.run()
        # Count remaining tasks
        remaining = sum(
            1 for task_id in task_ids if db.session.get(Task, task_id) is not None
        )
        assert remaining == 3  # 5 - 2 = 3 remaining
    finally:
        # Cleanup remaining tasks
        for task_id in task_ids:
            existing = db.session.get(Task, task_id)
            if existing:
                db.session.delete(existing)
        db.session.commit()
@freeze_time("2024-02-15")
@patch("superset.tasks.utils.get_current_user")
def test_prune_does_not_delete_pending_tasks(
    mock_get_user, app_context, get_user, login_as
) -> None:
    """Test that pruning does not delete pending or in-progress tasks.

    Non-terminal tasks have no ended_at, so they must never match the
    retention cutoff regardless of age.
    """
    login_as("admin")
    admin = get_user("admin")
    mock_get_user.return_value = admin.username
    pending_task = TaskDAO.create_task(
        task_type="test_type",
        task_key="pending_task",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    pending_task.created_by = admin
    # Keep as PENDING (no ended_at)
    in_progress_task = TaskDAO.create_task(
        task_type="test_type",
        task_key="in_progress_task",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    in_progress_task.created_by = admin
    in_progress_task.set_status(TaskStatus.IN_PROGRESS)
    # No ended_at for in-progress tasks
    db.session.commit()
    try:
        # Prune tasks older than 30 days
        command = TaskPruneCommand(retention_period_days=30)
        command.run()
        # Verify non-completed tasks are NOT deleted
        assert db.session.get(Task, pending_task.id) is not None
        assert db.session.get(Task, in_progress_task.id) is not None
    finally:
        # Cleanup
        for task in [pending_task, in_progress_task]:
            existing = db.session.get(Task, task.id)
            if existing:
                db.session.delete(existing)
        db.session.commit()
@freeze_time("2024-02-15")
@patch("superset.tasks.utils.get_current_user")
def test_prune_deletes_all_completed_statuses(
    mock_get_user, app_context, get_user, login_as
) -> None:
    """Test pruning deletes SUCCESS, FAILURE, and ABORTED tasks.

    All three terminal statuses are eligible for pruning once their
    ``ended_at`` falls outside the retention window.
    """
    login_as("admin")
    admin = get_user("admin")
    mock_get_user.return_value = admin.username
    # 35 days before the frozen "now" of 2024-02-15, i.e. past retention
    old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
    # Create one expired task per terminal status
    task_ids = []
    for task_key, status in (
        ("success_task", TaskStatus.SUCCESS),
        ("failure_task", TaskStatus.FAILURE),
        ("aborted_task", TaskStatus.ABORTED),
    ):
        task = TaskDAO.create_task(
            task_type="test_type",
            task_key=task_key,
            scope=TaskScope.PRIVATE,
            user_id=admin.id,
        )
        task.created_by = admin
        task.set_status(status)
        task.ended_at = old_date
        task_ids.append(task.id)
    db.session.commit()
    try:
        # Prune tasks older than 30 days
        command = TaskPruneCommand(retention_period_days=30)
        command.run()
        # Verify all completed tasks are deleted
        for task_id in task_ids:
            assert db.session.get(Task, task_id) is None
    finally:
        # Cleanup any leftovers. Previously cleanup only ran on
        # AssertionError, leaking rows whenever command.run() raised a
        # different exception; `finally` covers every failure mode and is
        # a no-op on success (the rows are already gone).
        for task_id in task_ids:
            existing = db.session.get(Task, task_id)
            if existing:
                db.session.delete(existing)
        db.session.commit()
def test_prune_no_tasks_to_delete(app_context, login_as) -> None:
    """Test pruning when no old tasks exist"""
    login_as("admin")
    # With no eligible rows the command must be a graceful no-op.
    TaskPruneCommand(retention_period_days=30).run()

View File

@@ -0,0 +1,238 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset_core.api.tasks import TaskStatus
from superset import db
from superset.commands.tasks import SubmitTaskCommand
from superset.commands.tasks.exceptions import (
TaskInvalidError,
)
def test_submit_task_success(app_context, login_as, get_user) -> None:
    """Test successful task submission.

    Verifies the created task carries the requested type/key/name, starts
    in the PENDING state with an empty payload, and is persisted (id/uuid).
    """
    login_as("admin")
    admin = get_user("admin")
    command = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "test-key",
            "task_name": "Test Task",
            "user_id": admin.id,
        }
    )
    result = None
    try:
        result = command.run()
        # Verify task was created
        assert result.task_type == "test-type"
        assert result.task_key == "test-key"
        assert result.task_name == "Test Task"
        assert result.status == TaskStatus.PENDING.value
        assert result.payload == "{}"
        # Verify in database
        db.session.refresh(result)
        assert result.id is not None
        assert result.uuid is not None
    finally:
        # Cleanup; guard against command.run() raising before `result` was
        # bound, which previously caused a masking NameError here.
        if result is not None:
            db.session.delete(result)
            db.session.commit()
def test_submit_task_with_all_fields(app_context, login_as, get_user) -> None:
    """Test task submission with all optional fields.

    Verifies that payload and properties round-trip through the command.
    """
    login_as("admin")
    admin = get_user("admin")
    command = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "test-key-full",
            "task_name": "Test Task Full",
            "user_id": admin.id,
            "payload": {"key": "value"},
            "properties": {"execution_mode": "async", "timeout": 300},
        }
    )
    result = None
    try:
        result = command.run()
        # Verify all fields were set
        assert result.task_type == "test-type"
        assert result.task_key == "test-key-full"
        assert result.task_name == "Test Task Full"
        assert result.user_id == admin.id
        assert result.payload_dict == {"key": "value"}
        assert result.properties_dict.get("execution_mode") == "async"
        assert result.properties_dict.get("timeout") == 300
    finally:
        # Cleanup; guard against command.run() raising before `result` was
        # bound, which previously caused a masking NameError here.
        if result is not None:
            db.session.delete(result)
            db.session.commit()
def test_submit_task_missing_task_type(app_context, login_as) -> None:
    """Test submission fails when task_type is missing"""
    login_as("admin")
    command = SubmitTaskCommand(data={})
    with pytest.raises(TaskInvalidError) as exc_info:
        command.run()
    # Exactly one validation error, pointing at the missing task_type field.
    validation_errors = exc_info.value._exceptions
    assert len(validation_errors) == 1
    assert "task_type" in validation_errors[0].field_name
def test_submit_task_joins_existing(app_context, login_as, get_user) -> None:
    """Test that submitting with duplicate key joins existing task.

    A second submission with the same (task_type, task_key) should return
    the already-created task rather than inserting a new row.
    """
    login_as("admin")
    admin = get_user("admin")
    # Create first task
    command1 = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "shared-key",
            "task_name": "First Task",
            "user_id": admin.id,
        }
    )
    task1 = command1.run()
    try:
        # Submit second task with same task_key and type
        command2 = SubmitTaskCommand(
            data={
                "task_type": "test-type",
                "task_key": "shared-key",
                "task_name": "Second Task",
                "user_id": admin.id,
            }
        )
        # Should return existing task, not create new one
        task2 = command2.run()
        assert task2.id == task1.id
        assert task2.uuid == task1.uuid
    finally:
        # Cleanup (task2 is the same row, so deleting task1 suffices)
        db.session.delete(task1)
        db.session.commit()
def test_submit_task_without_task_key(app_context, login_as, get_user) -> None:
    """Test task submission without task_key (command generates UUID)."""
    login_as("admin")
    admin = get_user("admin")
    command = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_name": "Test Task No ID",
            "user_id": admin.id,
        }
    )
    result = None
    try:
        result = command.run()
        # Verify task was created and command generated a task_key
        assert result.task_type == "test-type"
        assert result.task_name == "Test Task No ID"
        assert result.task_key is not None  # Command generated UUID
        assert result.uuid is not None
    finally:
        # Cleanup; guard against command.run() raising before `result` was
        # bound, which previously caused a masking NameError here.
        if result is not None:
            db.session.delete(result)
            db.session.commit()
def test_submit_task_run_with_info_returns_is_new_true(
    app_context, login_as, get_user
) -> None:
    """Test run_with_info returns is_new=True for new task."""
    login_as("admin")
    admin = get_user("admin")
    command = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "unique-key-is-new",
            "task_name": "Test Task",
            "user_id": admin.id,
        }
    )
    task = None
    try:
        task, is_new = command.run_with_info()
        assert is_new is True
        assert task.task_key == "unique-key-is-new"
    finally:
        # Cleanup; guard against run_with_info() raising before `task` was
        # bound, which previously caused a masking NameError here.
        if task is not None:
            db.session.delete(task)
            db.session.commit()
def test_submit_task_run_with_info_returns_is_new_false(
    app_context, login_as, get_user
) -> None:
    """Test run_with_info returns is_new=False when joining existing task."""
    login_as("admin")
    admin = get_user("admin")
    # Create first task
    command1 = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "shared-key-is-new",
            "task_name": "First Task",
            "user_id": admin.id,
        }
    )
    task1, is_new1 = command1.run_with_info()
    assert is_new1 is True
    try:
        # Submit second task with same key
        command2 = SubmitTaskCommand(
            data={
                "task_type": "test-type",
                "task_key": "shared-key-is-new",
                "task_name": "Second Task",
                "user_id": admin.id,
            }
        )
        task2, is_new2 = command2.run_with_info()
        # Should return existing task with is_new=False
        assert is_new2 is False
        assert task2.id == task1.id
        assert task2.uuid == task1.uuid
    finally:
        # Cleanup (task2 is the same row, so deleting task1 suffices)
        db.session.delete(task1)
        db.session.commit()

View File

@@ -0,0 +1,260 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from uuid import UUID
import pytest
from superset_core.api.tasks import TaskScope, TaskStatus
from superset import db
from superset.commands.tasks import UpdateTaskCommand
from superset.commands.tasks.exceptions import (
TaskForbiddenError,
TaskNotFoundError,
)
from superset.daos.tasks import TaskDAO
def test_update_task_success(app_context, get_user, login_as) -> None:
    """Test successful task update.

    The owning admin user transitions their own task from IN_PROGRESS to
    SUCCESS and the change is persisted.
    """
    admin = get_user("admin")
    login_as("admin")
    # Create a task using DAO
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="update_test",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        # Update the task status
        command = UpdateTaskCommand(
            task_uuid=task.uuid,
            status=TaskStatus.SUCCESS.value,
        )
        result = command.run()
        # Verify update
        assert result.uuid == task.uuid
        assert result.status == TaskStatus.SUCCESS.value
        # Verify in database
        db.session.refresh(task)
        assert task.status == TaskStatus.SUCCESS.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_update_task_not_found(app_context, login_as) -> None:
    """Test update fails when task not found"""
    login_as("admin")
    # The all-zero UUID cannot match any persisted task.
    missing_uuid = UUID(int=0)
    with pytest.raises(TaskNotFoundError):
        UpdateTaskCommand(
            task_uuid=missing_uuid,
            status=TaskStatus.SUCCESS.value,
        ).run()
def test_update_task_forbidden(app_context, get_user, login_as) -> None:
    """Test update fails when user doesn't own task (via base filter).

    A non-admin, non-owner user (alpha) must not be able to update a task
    created by gamma.
    """
    gamma = get_user("gamma")
    login_as("gamma")
    # Create a task owned by gamma (non-admin) using DAO
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="forbidden_test",
        scope=TaskScope.PRIVATE,
        user_id=gamma.id,
    )
    task.created_by = gamma
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        # Login as alpha user (different non-admin, non-owner)
        login_as("alpha")
        # Try to update gamma's task as alpha user
        command = UpdateTaskCommand(
            task_uuid=task.uuid,
            status=TaskStatus.SUCCESS.value,
        )
        # Should raise ForbiddenError because ownership check fails
        with pytest.raises(TaskForbiddenError):
            command.run()
        # Verify task was NOT updated
        db.session.refresh(task)
        assert task.status == TaskStatus.IN_PROGRESS.value
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_update_task_payload(app_context, get_user, login_as) -> None:
    """Test updating task payload.

    The command replaces the stored payload with the one supplied.
    """
    admin = get_user("admin")
    login_as("admin")
    # Create a task using DAO
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="payload_test",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
        payload={"initial": "data"},
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        # Update payload
        command = UpdateTaskCommand(
            task_uuid=task.uuid,
            payload={"progress": 50, "message": "halfway"},
        )
        result = command.run()
        # Verify payload was updated
        assert result.uuid == task.uuid
        payload = result.payload_dict
        assert payload["progress"] == 50
        assert payload["message"] == "halfway"
        # Verify in database
        db.session.refresh(task)
        task_payload = task.payload_dict
        assert task_payload["progress"] == 50
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_update_all_supported_fields(app_context, get_user, login_as) -> None:
    """Test updating all supported task fields
    (status, error, progress, abortable, timeout).

    Also verifies that pre-existing properties (execution_mode, timeout)
    survive the update, i.e. properties are merged, not replaced.
    """
    admin = get_user("admin")
    login_as("admin")
    # Create a task with initial execution_mode and timeout in properties
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="all_fields_test",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
        properties={"execution_mode": "async", "timeout": 300},
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        # Update all field types at once
        command = UpdateTaskCommand(
            task_uuid=task.uuid,
            status=TaskStatus.FAILURE.value,
            properties={
                "error_message": "Task failed due to error",
                "progress_percent": 0.75,
                "progress_current": 75,
                "progress_total": 100,
                "is_abortable": True,
            },
        )
        result = command.run()
        # Verify all fields were updated
        assert result.uuid == task.uuid
        assert result.status == TaskStatus.FAILURE.value
        assert result.properties_dict.get("error_message") == "Task failed due to error"
        assert result.properties_dict.get("progress_percent") == 0.75
        assert result.properties_dict.get("progress_current") == 75
        assert result.properties_dict.get("progress_total") == 100
        assert result.properties_dict.get("is_abortable") is True
        assert result.properties_dict.get("execution_mode") == "async"
        assert result.properties_dict.get("timeout") == 300
        # Verify in database
        db.session.refresh(task)
        assert task.status == TaskStatus.FAILURE.value
        assert task.properties_dict.get("error_message") == "Task failed due to error"
        assert task.properties_dict.get("progress_percent") == 0.75
        assert task.properties_dict.get("progress_current") == 75
        assert task.properties_dict.get("progress_total") == 100
        assert task.properties_dict.get("is_abortable") is True
        assert task.properties_dict.get("execution_mode") == "async"
        assert task.properties_dict.get("timeout") == 300
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()
def test_update_task_skip_security_check(app_context, get_user, login_as) -> None:
    """Test skip_security_check allows updating any task.

    With the flag set, a non-owner (gamma) can update admin's task — the
    path used by internal/executor callers.
    """
    admin = get_user("admin")
    login_as("admin")
    # Create a task owned by admin
    task = TaskDAO.create_task(
        task_type="test_type",
        task_key="skip_security_test",
        scope=TaskScope.PRIVATE,
        user_id=admin.id,
    )
    task.created_by = admin
    task.set_status(TaskStatus.IN_PROGRESS)
    db.session.commit()
    try:
        # Login as gamma user (non-owner)
        login_as("gamma")
        # With skip_security_check=True, should succeed even though gamma doesn't own it
        command = UpdateTaskCommand(
            task_uuid=task.uuid,
            properties={"progress_percent": 0.75},
            skip_security_check=True,
        )
        result = command.run()
        # Verify update succeeded
        assert result.uuid == task.uuid
        assert result.properties_dict.get("progress_percent") == 0.75
        # Verify in database
        db.session.refresh(task)
        assert task.properties_dict.get("progress_percent") == 0.75
    finally:
        # Cleanup
        db.session.delete(task)
        db.session.commit()

View File

@@ -0,0 +1,415 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""End-to-end integration tests for task event handlers (abort and cleanup)
These tests verify that abort and cleanup handlers work correctly through
the full task execution path using real @task decorated functions executed
via the Celery executor (synchronously via .apply()).
"""
from __future__ import annotations
import uuid
from typing import Any
from superset_core.api.tasks import TaskScope, TaskStatus
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.daos.tasks import TaskDAO
from superset.extensions import db
from superset.models.tasks import Task
from superset.tasks.ambient_context import get_context
from superset.tasks.context import TaskContext
from superset.tasks.registry import TaskRegistry
from superset.tasks.scheduler import execute_task
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME
# Module-level state to track handler calls across test executions
# (Since decorated functions are defined at module level)
_handler_state: dict[str, Any] = {}
def _reset_handler_state():
"""Reset handler state before each test."""
global _handler_state
_handler_state = {
"cleanup_called": False,
"abort_called": False,
"cleanup_order": [],
"abort_order": [],
"cleanup_data": {},
}
def cleanup_test_task() -> None:
    """Task body that registers a single cleanup handler, then does work."""
    context = get_context()

    def _mark_cleanup_done() -> None:
        _handler_state["cleanup_called"] = True

    # Register via a direct call instead of decorator syntax.
    context.on_cleanup(_mark_cleanup_done)
    # Simulate some work
    context.update_task(progress=1.0)
def abort_test_task() -> None:
    """Task body that registers only an abort handler."""
    context = get_context()

    def _mark_aborted() -> None:
        _handler_state["abort_called"] = True

    # Register via a direct call instead of decorator syntax.
    context.on_abort(_mark_aborted)
def both_handlers_task() -> None:
    """Task body that registers both an abort and a cleanup handler."""
    context = get_context()

    def _record_abort() -> None:
        _handler_state["abort_called"] = True
        _handler_state["abort_order"].append("abort")

    def _record_cleanup() -> None:
        _handler_state["cleanup_called"] = True
        _handler_state["cleanup_order"].append("cleanup")

    # Same registration order as before: abort first, then cleanup.
    context.on_abort(_record_abort)
    context.on_cleanup(_record_cleanup)
def multiple_cleanup_handlers_task() -> None:
    """Task body registering three cleanup handlers (expected to run LIFO)."""
    context = get_context()

    def _make_recorder(label: str):
        # Each handler appends its own label so ordering can be asserted.
        def _record() -> None:
            _handler_state["cleanup_order"].append(label)

        return _record

    # Registered first/second/third; the executor runs them in reverse.
    for label in ("first", "second", "third"):
        context.on_cleanup(_make_recorder(label))
def cleanup_with_data_task() -> None:
    """Task body that leaves partial work for its cleanup handler to undo."""
    context = get_context()
    # Simulate partial work in module-level state
    _handler_state["cleanup_data"]["temp_key"] = "temp_value"

    def _undo_partial_work() -> None:
        # Clean up the partial work and record that cleanup ran.
        _handler_state["cleanup_data"].clear()
        _handler_state["cleanup_called"] = True

    context.on_cleanup(_undo_partial_work)
def _register_test_tasks() -> None:
    """Idempotently register this module's test task functions.

    Called in setUp() so registration survives other tests clearing the
    registry; already-registered names are left untouched.
    """
    for task_name, task_func in (
        ("test_cleanup_task", cleanup_test_task),
        ("test_abort_task", abort_test_task),
        ("test_both_handlers_task", both_handlers_task),
        ("test_multiple_cleanup_task", multiple_cleanup_handlers_task),
        ("test_cleanup_with_data", cleanup_with_data_task),
    ):
        if not TaskRegistry.is_registered(task_name):
            TaskRegistry.register(task_name, task_func)
class TestCleanupHandlers(SupersetTestCase):
    """E2E tests for on_cleanup functionality using Celery executor.

    Tasks are executed synchronously via ``execute_task.apply()`` so the
    full executor path runs in-process without a Celery worker.
    """
    def setUp(self):
        """Set up test fixtures: login, (re)register tasks, reset state."""
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
        _reset_handler_state()
    def test_cleanup_handler_fires_on_success(self):
        """Test cleanup handler runs when task completes successfully."""
        # Create task entry directly
        task_obj = TaskDAO.create_task(
            task_type="test_cleanup_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Cleanup",
            scope=TaskScope.SYSTEM,
        )
        # Execute task synchronously through Celery executor
        # Use str(uuid) since Celery serializes args as JSON strings
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
        )
        # Verify task completed successfully
        assert result.successful()
        assert result.result["status"] == TaskStatus.SUCCESS.value
        # Verify cleanup handler was called
        assert _handler_state["cleanup_called"]
    def test_multiple_cleanup_handlers_in_lifo_order(self):
        """Test multiple cleanup handlers execute in LIFO order."""
        task_obj = TaskDAO.create_task(
            task_type="test_multiple_cleanup_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Multiple Cleanup",
            scope=TaskScope.SYSTEM,
        )
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_multiple_cleanup_task", (), {}]
        )
        assert result.successful()
        # Handlers should execute in LIFO order (last registered first)
        assert _handler_state["cleanup_order"] == ["third", "second", "first"]
    def test_cleanup_handler_cleans_up_partial_work(self):
        """Test cleanup handler can clean up partial work."""
        task_obj = TaskDAO.create_task(
            task_type="test_cleanup_with_data",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Cleanup Data",
            scope=TaskScope.SYSTEM,
        )
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_cleanup_with_data", (), {}]
        )
        assert result.successful()
        assert _handler_state["cleanup_called"]
        # Cleanup handler should have cleared the data
        assert len(_handler_state["cleanup_data"]) == 0
class TestAbortHandlers(SupersetTestCase):
    """E2E tests for on_abort functionality.

    These tests do not run the executor; they drive the task's status
    transitions directly (merge/commit) and then invoke
    ``TaskContext._run_cleanup`` the way the executor's ``finally`` block
    would, observing handler calls via module-level ``_handler_state``.
    """
    def setUp(self):
        """Set up test fixtures."""
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
        _reset_handler_state()
    def test_abort_handler_fires_when_task_aborting(self):
        """Test abort handler runs when task is in ABORTING state during cleanup."""
        # Create task entry
        task_obj = TaskDAO.create_task(
            task_type="test_abort_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Abort",
            scope=TaskScope.SYSTEM,
        )
        # Manually set to IN_PROGRESS and then ABORTING to simulate abort
        task_obj.status = TaskStatus.IN_PROGRESS.value
        task_obj.update_properties({"is_abortable": True})
        db.session.merge(task_obj)
        db.session.commit()
        # Refresh to get the updated task
        db.session.refresh(task_obj)
        # Create context (simulating what executor does)
        ctx = TaskContext(task_obj)
        # Register abort handler
        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True
        # Set status to ABORTING (simulating CancelTaskCommand)
        task_obj.status = TaskStatus.ABORTING.value
        db.session.merge(task_obj)
        db.session.commit()
        # Run cleanup (simulating executor's finally block)
        ctx._run_cleanup()
        # Verify abort handler was called
        assert _handler_state["abort_called"]
    def test_both_handlers_fire_on_abort(self):
        """Test both abort and cleanup handlers run when task is aborted."""
        task_obj = TaskDAO.create_task(
            task_type="test_both_handlers_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Both Handlers",
            scope=TaskScope.SYSTEM,
        )
        # Simulate a running, abortable task before registering handlers.
        task_obj.status = TaskStatus.IN_PROGRESS.value
        task_obj.update_properties({"is_abortable": True})
        db.session.merge(task_obj)
        db.session.commit()
        # Refresh to get the updated task
        db.session.refresh(task_obj)
        ctx = TaskContext(task_obj)
        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True
            _handler_state["abort_order"].append("abort")
        @ctx.on_cleanup
        def handle_cleanup():
            _handler_state["cleanup_called"] = True
            _handler_state["cleanup_order"].append("cleanup")
        # Set to ABORTING
        task_obj.status = TaskStatus.ABORTING.value
        db.session.merge(task_obj)
        db.session.commit()
        ctx._run_cleanup()
        # Both should have been called
        assert _handler_state["abort_called"]
        assert _handler_state["cleanup_called"]
    def test_abort_handler_not_called_on_success(self):
        """Test abort handler doesn't run when task succeeds."""
        task_obj = TaskDAO.create_task(
            task_type="test_abort_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test No Abort on Success",
            scope=TaskScope.SYSTEM,
        )
        # Put the task directly into a terminal SUCCESS state.
        task_obj.status = TaskStatus.SUCCESS.value
        db.session.merge(task_obj)
        db.session.commit()
        # Refresh to get the updated task
        db.session.refresh(task_obj)
        ctx = TaskContext(task_obj)
        @ctx.on_abort
        def handle_abort():
            _handler_state["abort_called"] = True
        @ctx.on_cleanup
        def handle_cleanup():
            _handler_state["cleanup_called"] = True
        ctx._run_cleanup()
        # Abort handler should NOT be called
        assert not _handler_state["abort_called"]
        # Cleanup handler should still be called
        assert _handler_state["cleanup_called"]
class TestTaskContextMethods(SupersetTestCase):
    """Tests for the public surface of TaskContext."""

    def setUp(self):
        """Log in as admin before each test."""
        super().setUp()
        self.login(ADMIN_USERNAME)

    def test_on_abort_marks_task_abortable(self):
        """Registering an on_abort handler should flag the task as abortable."""
        created = TaskDAO.create_task(
            task_type="test_abortable_flag",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Abortable",
            scope=TaskScope.SYSTEM,
        )
        # A fresh task must not be marked abortable yet.
        assert created.properties_dict.get("is_abortable") is not True

        context = TaskContext(created)

        def noop_abort():
            pass

        context.on_abort(noop_abort)

        # Reload from the DB to confirm the flag was persisted.
        db.session.expire_all()
        persisted = db.session.query(Task).filter_by(uuid=created.uuid).first()
        assert persisted.properties_dict.get("is_abortable") is True
class TestAbortBeforeExecution(SupersetTestCase):
    """Tests for aborting tasks before they start executing.

    Covers the PENDING -> ABORTED transition and the executor's skip path
    for tasks that were aborted before their Celery job ran.
    """
    def setUp(self):
        """Set up test fixtures."""
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
    def test_abort_pending_task(self):
        """Test that pending tasks can be aborted directly."""
        task_obj = TaskDAO.create_task(
            task_type="test_abort_before_start",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Before Start",
            scope=TaskScope.SYSTEM,
        )
        # Cancel immediately (task is still PENDING)
        CancelTaskCommand(task_obj.uuid, force=True).run()
        # Reload the row to observe the committed status change.
        db.session.expire_all()
        task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
        assert task_obj.status == TaskStatus.ABORTED.value
    def test_executor_skips_aborted_task(self):
        """Test that executor skips tasks already aborted before execution."""
        task_obj = TaskDAO.create_task(
            task_type="test_cleanup_task",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Skip Aborted",
            scope=TaskScope.SYSTEM,
        )
        # Abort the task before execution
        task_obj.status = TaskStatus.ABORTED.value
        db.session.merge(task_obj)
        db.session.commit()
        # Reset shared handler state so the assertion below is meaningful.
        _reset_handler_state()
        # Try to execute - should skip
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
        )
        assert result.successful()
        assert result.result["status"] == TaskStatus.ABORTED.value
        # Cleanup handler should NOT have been called (task was skipped)
        assert not _handler_state["cleanup_called"]

# ---------------------------------------------------------------------------
# (file boundary: the following lines begin a new test module — extraction
# residue "View File / @@ -0,0 +1,158 @@" replaced with this marker)
# ---------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for sync join-and-wait functionality in GTF."""
import time
from superset_core.api.tasks import TaskStatus
from superset import db
from superset.commands.tasks import SubmitTaskCommand
from superset.daos.tasks import TaskDAO
from superset.tasks.manager import TaskManager
def test_submit_task_distinguishes_new_vs_existing(
    app_context, login_as, get_user
) -> None:
    """
    SubmitTaskCommand.run_with_info() should flag only the first submission
    for a given key as new; later submissions join the existing task.
    """
    login_as("admin")
    admin = get_user("admin")

    # Initial submission for this key: a brand-new task is created.
    first_task, first_is_new = SubmitTaskCommand(
        data={
            "task_type": "test-type",
            "task_key": "distinguish-key",
            "task_name": "First Task",
            "user_id": admin.id,
        }
    ).run_with_info()
    assert first_is_new is True
    try:
        # Re-submitting the same key must join the existing task instead.
        second_task, second_is_new = SubmitTaskCommand(
            data={
                "task_type": "test-type",
                "task_key": "distinguish-key",
                "task_name": "Second Task",
                "user_id": admin.id,
            }
        ).run_with_info()
        assert second_is_new is False
        assert second_task.uuid == first_task.uuid
    finally:
        # Remove the task so the dedup key does not leak into other tests.
        db.session.delete(first_task)
        db.session.commit()
def test_terminal_states_recognized_correctly(app_context) -> None:
    """
    TaskManager.TERMINAL_STATES should contain exactly the finished statuses.
    """
    # Every finished status must be classified as terminal.
    for status in (
        TaskStatus.SUCCESS,
        TaskStatus.FAILURE,
        TaskStatus.ABORTED,
        TaskStatus.TIMED_OUT,
    ):
        assert status.value in TaskManager.TERMINAL_STATES

    # Statuses of live (or still-aborting) tasks must not be terminal.
    for status in (
        TaskStatus.PENDING,
        TaskStatus.IN_PROGRESS,
        TaskStatus.ABORTING,
    ):
        assert status.value not in TaskManager.TERMINAL_STATES
def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
    """
    wait_for_completion should raise TimeoutError when the task never finishes.
    """
    from unittest.mock import patch

    import pytest

    login_as("admin")
    admin = get_user("admin")

    # A freshly submitted task stays PENDING, so the wait can never succeed.
    task, _ = SubmitTaskCommand(
        data={
            "task_type": "test-timeout",
            "task_key": "timeout-key",
            "task_name": "Timeout Task",
            "user_id": admin.id,
        }
    ).run_with_info()
    try:
        # Disable the signal cache so the manager falls back to polling.
        with patch("superset.tasks.manager.cache_manager") as cache_manager_mock:
            cache_manager_mock.signal_cache = None
            with pytest.raises(TimeoutError):
                TaskManager.wait_for_completion(
                    task.uuid,
                    timeout=0.2,
                    poll_interval=0.05,
                )
    finally:
        # Always drop the leftover PENDING task.
        db.session.delete(task)
        db.session.commit()
def test_wait_returns_immediately_for_terminal_task(
    app_context, login_as, get_user
) -> None:
    """
    wait_for_completion should short-circuit when the task is already terminal.
    """
    login_as("admin")
    admin = get_user("admin")

    task, _ = SubmitTaskCommand(
        data={
            "task_type": "test-immediate",
            "task_key": "immediate-key",
            "task_name": "Immediate Task",
            "user_id": admin.id,
        }
    ).run_with_info()
    # Mark the task finished before waiting on it.
    TaskDAO.update(task, {"status": TaskStatus.SUCCESS.value})
    db.session.commit()
    try:
        started_at = time.time()
        outcome = TaskManager.wait_for_completion(
            task.uuid,
            timeout=5.0,
            poll_interval=0.5,
        )
        waited = time.time() - started_at

        assert outcome.status == TaskStatus.SUCCESS.value
        # Must return well before the first 0.5s poll interval would elapse.
        assert waited < 0.2
    finally:
        db.session.delete(task)
        db.session.commit()

# ---------------------------------------------------------------------------
# (file boundary: the following lines begin a new test module — extraction
# residue "View File / @@ -0,0 +1,172 @@" replaced with this marker)
# ---------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for TaskContext update_task throttling.
Tests verify:
1. Final state is persisted correctly via cleanup flush
2. Throttled updates are deferred, timer writes latest pending update
"""
from __future__ import annotations
import time
import uuid
from superset_core.api.tasks import TaskScope, TaskStatus
from superset.daos.tasks import TaskDAO
from superset.extensions import db
from superset.models.tasks import Task
from superset.tasks.ambient_context import get_context
from superset.tasks.registry import TaskRegistry
from superset.tasks.scheduler import execute_task
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME
def task_with_throttled_updates() -> None:
    """Task emitting rapid progress/payload updates to exercise throttling."""
    context = get_context()
    # Fire ten updates back-to-back, well inside the throttle window.
    for step in range(1, 11):
        context.update_task(progress=(step, 10), payload={"step": step})
def _register_test_tasks() -> None:
    """Register the throttling test task if not already registered.

    Invoked from setUp() so the registration survives other tests that may
    have cleared the registry.
    """
    task_name = "test_throttle_combined"
    if not TaskRegistry.is_registered(task_name):
        TaskRegistry.register(task_name, task_with_throttled_updates)
class TestUpdateTaskThrottling(SupersetTestCase):
    """Integration test for update_task() throttling behavior.

    Exercises both the synchronous executor path (cleanup flush) and a
    hand-built TaskContext (immediate write, deferral, deferred-timer flush).
    """
    def setUp(self) -> None:
        """Log in and ensure the throttling test task is registered."""
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
    def test_throttled_updates_persisted_on_cleanup(self) -> None:
        """Final state should be persisted regardless of throttling.
        Verifies the core invariant: cleanup flush ensures final state is persisted.
        """
        task_obj = TaskDAO.create_task(
            task_type="test_throttle_combined",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Throttled Updates",
            scope=TaskScope.SYSTEM,
        )
        # Use str(uuid) since Celery serializes args as JSON strings
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_throttle_combined", (), {}]
        )
        assert result.successful()
        assert result.result["status"] == TaskStatus.SUCCESS.value
        # Verify final state is persisted
        db.session.expire_all()
        task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
        # Progress: 10/10 = 100%
        props = task_obj.properties_dict
        assert props.get("progress_current") == 10
        assert props.get("progress_total") == 10
        assert props.get("progress_percent") == 1.0
        # Payload: final step
        payload = task_obj.payload_dict
        assert payload.get("step") == 10
    def test_throttle_behavior(self) -> None:
        """Test complete throttle behavior: immediate write, deferral, and timer.
        Verifies:
        1. First update writes immediately
        2. Second and third updates within throttle window are deferred
        3. Deferred timer fires and writes the LATEST pending update (third)
        """
        from flask import current_app
        from superset.commands.tasks.submit import SubmitTaskCommand
        from superset.tasks.context import TaskContext
        # Get throttle interval from config (default: 2 seconds)
        throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]
        # Create task
        task_obj = SubmitTaskCommand(
            data={
                "task_type": "test_throttle_behavior",
                "task_key": f"test_key_{uuid.uuid4().hex[:8]}",
                "task_name": "Test Throttle Behavior",
                "scope": TaskScope.SYSTEM,
            }
        ).run()
        task_uuid = task_obj.uuid
        # Get fresh task for context
        fresh_task = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
        assert fresh_task is not None
        ctx = TaskContext(fresh_task)
        try:
            # === Step 1: First update - writes immediately ===
            ctx.update_task(progress=0.1, payload={"step": 1})
            db.session.expire_all()
            task_step1 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
            assert task_step1 is not None
            assert task_step1.properties_dict.get("progress_percent") == 0.1
            assert task_step1.payload_dict.get("step") == 1
            # === Step 2: Second update - deferred (within throttle window) ===
            ctx.update_task(progress=0.5, payload={"step": 2})
            # === Step 3: Third update - also deferred, overwrites second in cache ===
            ctx.update_task(progress=0.7, payload={"step": 3})
            # Verify in-memory cache has LATEST update (third)
            assert ctx._properties_cache.get("progress_percent") == 0.7
            assert ctx._payload_cache.get("step") == 3
            # Verify DB still has first update (both second and third deferred)
            db.session.expire_all()
            task_step2 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
            assert task_step2 is not None
            assert task_step2.properties_dict.get("progress_percent") == 0.1
            assert task_step2.payload_dict.get("step") == 1
            # === Step 4: Wait for deferred timer to fire ===
            time.sleep(throttle_interval + 0.5)
            # Verify timer fired and wrote the LATEST update (third, not second)
            db.session.expire_all()
            task_step3 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
            assert task_step3 is not None
            assert task_step3.properties_dict.get("progress_percent") == 0.7
            assert task_step3.payload_dict.get("step") == 3
        finally:
            # Always cancel the deferred timer so it cannot fire after the test.
            ctx._cancel_deferred_flush_timer()

# ---------------------------------------------------------------------------
# (file boundary: the following lines begin a new test module — extraction
# residue "View File / @@ -0,0 +1,226 @@" replaced with this marker)
# ---------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for GTF timeout handling.
Uses module-level task functions with manual registry (like test_event_handlers.py)
to avoid mypy issues with the @task decorator's complex generic types.
NOTE: Tests that use background threads (timeout/abort handlers) are skipped in
SQLite environments because SQLite connections cannot be shared across threads.
"""
from __future__ import annotations
import time
import uuid
from typing import Any
import pytest
from superset_core.api.tasks import TaskScope, TaskStatus
from superset.commands.tasks.cancel import CancelTaskCommand
from superset.daos.tasks import TaskDAO
from superset.extensions import db
from superset.models.tasks import Task
from superset.tasks.ambient_context import get_context
from superset.tasks.registry import TaskRegistry
from superset.tasks.scheduler import execute_task
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME
def _skip_if_sqlite() -> None:
    """Skip the current test when the metadata database is SQLite.

    SQLite connections cannot be shared across threads, which breaks the
    timeout tests that rely on background threads for abort handlers.
    Must be invoked inside a test method (an app context is required).
    """
    drivername = db.engine.url.drivername
    if "sqlite" in drivername:
        pytest.skip("SQLite connections cannot be shared across threads")
# Module-level state to track handler calls across task functions and tests;
# reset via _reset_handler_state() before each test.
_handler_state: dict[str, Any] = {}
def _reset_handler_state() -> None:
    """Rebind the shared handler-tracking state to a clean slate."""
    global _handler_state
    _handler_state = dict(
        abort_called=False,
        handler_exception=None,
    )
def timeout_abortable_task() -> None:
    """Task that exits promptly once its abort handler has fired."""
    context = get_context()

    def record_abort() -> None:
        _handler_state["abort_called"] = True

    context.on_abort(record_abort)

    # Poll (up to ~5s) for the abort flag set by the handler, then exit.
    remaining_checks = 50
    while remaining_checks > 0:
        if _handler_state["abort_called"]:
            return
        time.sleep(0.1)
        remaining_checks -= 1
def timeout_handler_fails_task() -> None:
    """Task whose abort handler raises, simulating a crashing handler."""
    context = get_context()

    def crashing_abort() -> None:
        _handler_state["abort_called"] = True
        raise ValueError("Handler crashed!")

    context.on_abort(crashing_abort)

    # Outlive the configured timeout so the abort path is exercised.
    time.sleep(5)
def simple_task_with_abort() -> None:
    """Minimal task that registers a no-op abort handler and returns."""
    context = get_context()

    def noop_abort() -> None:
        pass

    context.on_abort(noop_abort)
def quick_task_with_abort() -> None:
    """Short-running abortable task that finishes well before any timeout."""
    context = get_context()

    def noop_abort() -> None:
        pass

    context.on_abort(noop_abort)
    time.sleep(0.2)
def _register_test_tasks() -> None:
    """Register the timeout test task functions if not already registered.

    Invoked from setUp() so registration is guaranteed even when other
    tests have cleared the registry.
    """
    handlers = {
        "test_timeout_abortable": timeout_abortable_task,
        "test_timeout_handler_fails": timeout_handler_fails_task,
        "test_timeout_simple": simple_task_with_abort,
        "test_timeout_quick": quick_task_with_abort,
    }
    for task_name, task_func in handlers.items():
        if not TaskRegistry.is_registered(task_name):
            TaskRegistry.register(task_name, task_func)
class TestTimeoutHandling(SupersetTestCase):
    """E2E tests for task timeout functionality.

    Tasks run synchronously through ``execute_task.apply``; the ones with a
    ``timeout`` property are expected to be aborted (or failed) by the
    executor's timeout machinery, which uses background threads — hence the
    SQLite skips.
    """
    def setUp(self) -> None:
        """Set up test fixtures."""
        super().setUp()
        self.login(ADMIN_USERNAME)
        _register_test_tasks()
        _reset_handler_state()
    def test_timeout_with_abort_handler_results_in_timed_out_status(self) -> None:
        """Task with timeout and abort handler should end with TIMED_OUT status."""
        _skip_if_sqlite()
        # Create task with timeout
        task_obj = TaskDAO.create_task(
            task_type="test_timeout_abortable",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Timeout",
            scope=TaskScope.SYSTEM,
            properties={"timeout": 1},  # 1 second timeout
        )
        # Execute task via Celery executor (synchronously)
        # Use str(uuid) since Celery serializes args as JSON strings
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_timeout_abortable", (), {}]
        )
        # Verify execution completed
        assert result.successful()
        assert result.result["status"] == TaskStatus.TIMED_OUT.value
        # Verify abort handler was called
        assert _handler_state["abort_called"]
    def test_user_abort_results_in_aborted_status(self) -> None:
        """User-initiated abort on pending task should result in ABORTED."""
        # Create task (pending state)
        task_obj = TaskDAO.create_task(
            task_type="test_timeout_simple",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Abort Task",
            scope=TaskScope.SYSTEM,
        )
        # Cancel before execution (pending task abort)
        CancelTaskCommand(task_obj.uuid, force=True).run()
        # Refresh from DB
        db.session.expire_all()
        task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
        assert task_obj.status == TaskStatus.ABORTED.value
    def test_no_timeout_when_not_configured(self) -> None:
        """Task without timeout should run to completion regardless of duration."""
        task_obj = TaskDAO.create_task(
            task_type="test_timeout_quick",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test No Timeout",
            scope=TaskScope.SYSTEM,
            # No timeout property
        )
        # Use str(uuid) since Celery serializes args as JSON strings
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_timeout_quick", (), {}]
        )
        assert result.successful()
        assert result.result["status"] == TaskStatus.SUCCESS.value
    def test_abort_handler_exception_results_in_failure(self) -> None:
        """If abort handler throws during timeout, task should be FAILURE."""
        _skip_if_sqlite()
        task_obj = TaskDAO.create_task(
            task_type="test_timeout_handler_fails",
            task_key=f"test_key_{uuid.uuid4().hex[:8]}",
            task_name="Test Handler Fails",
            scope=TaskScope.SYSTEM,
            properties={"timeout": 1},  # 1 second timeout
        )
        # Use str(uuid) since Celery serializes args as JSON strings
        result = execute_task.apply(
            args=[str(task_obj.uuid), "test_timeout_handler_fails", (), {}]
        )
        assert result.successful()
        assert result.result["status"] == TaskStatus.FAILURE.value
        assert _handler_state["abort_called"]

# ---------------------------------------------------------------------------
# (file boundary: the following lines begin a new test module — extraction
# residue "View File / @@ -0,0 +1,420 @@" replaced with this marker)
# ---------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterator
from uuid import UUID
import pytest
from sqlalchemy.orm.session import Session
from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
from superset.commands.tasks.exceptions import TaskNotAbortableError
from superset.models.tasks import Task
from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key
# Test constants shared by the DAO tests below.
TASK_UUID = UUID("e7765491-40c1-4f35-a4f5-06308e79310e")  # fixed UUID for finished-dedup-key tests
TASK_ID = 42  # arbitrary explicit primary key
TEST_TASK_TYPE = "test_type"
TEST_TASK_KEY = "test-key"
TEST_USER_ID = 1
def create_task(
    session: Session,
    *,
    task_id: int | None = None,
    task_uuid: UUID | None = None,
    task_key: str = TEST_TASK_KEY,
    task_type: str = TEST_TASK_TYPE,
    scope: TaskScope = TaskScope.PRIVATE,
    status: TaskStatus = TaskStatus.PENDING,
    user_id: int | None = TEST_USER_ID,
    properties: TaskProperties | None = None,
    use_finished_dedup_key: bool = False,
) -> Task:
    """Create and flush a Task row with sensible defaults for testing."""
    # Finished tasks use a UUID-based dedup key; active tasks derive theirs
    # from scope/type/key/user.
    dedup_key = (
        get_finished_dedup_key(task_uuid or TASK_UUID)
        if use_finished_dedup_key
        else get_active_dedup_key(
            scope=scope,
            task_type=task_type,
            task_key=task_key,
            user_id=user_id,
        )
    )
    new_task = Task(
        task_type=task_type,
        task_key=task_key,
        scope=scope.value,
        status=status.value,
        dedup_key=dedup_key,
        user_id=user_id,
    )
    if task_id is not None:
        new_task.id = task_id
    if task_uuid:
        new_task.uuid = task_uuid
    if properties:
        new_task.update_properties(properties)
    session.add(new_task)
    # Flush (not commit) so the row gets its IDs while staying rollback-able.
    session.flush()
    return new_task
@pytest.fixture
def session_with_task(session: Session) -> Iterator[Session]:
    """Create a session with Task and TaskSubscriber tables."""
    # Imported here to avoid model side effects at module import time.
    from superset.models.task_subscribers import TaskSubscriber
    engine = session.get_bind()
    # Create the tables these tests need on the test engine.
    Task.metadata.create_all(engine)
    TaskSubscriber.metadata.create_all(engine)
    yield session
    # Roll back any uncommitted writes after the test finishes.
    session.rollback()
def test_find_by_task_key_active(session_with_task: Session) -> None:
    """An active task should be retrievable via its task_key."""
    from superset.daos.tasks import TaskDAO

    create_task(session_with_task)

    found = TaskDAO.find_by_task_key(
        task_type=TEST_TASK_TYPE,
        task_key=TEST_TASK_KEY,
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
    )

    assert found is not None
    assert found.task_key == TEST_TASK_KEY
    assert found.task_type == TEST_TASK_TYPE
    assert found.status == TaskStatus.PENDING.value
def test_find_by_task_key_not_found(session_with_task: Session) -> None:
    """Looking up an unknown task_key should yield None."""
    from superset.daos.tasks import TaskDAO

    found = TaskDAO.find_by_task_key(
        task_type=TEST_TASK_TYPE,
        task_key="nonexistent-key",
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
    )
    assert found is None
def test_find_by_task_key_finished_not_found(session_with_task: Session) -> None:
    """find_by_task_key should ignore finished tasks.

    Finished tasks carry a UUID-based dedup_key, so the active-task lookup
    cannot match them.
    """
    from superset.daos.tasks import TaskDAO

    create_task(
        session_with_task,
        task_key="finished-key",
        status=TaskStatus.SUCCESS,
        use_finished_dedup_key=True,
        task_uuid=TASK_UUID,
    )

    # The SUCCESS task must be invisible to the active lookup.
    found = TaskDAO.find_by_task_key(
        task_type=TEST_TASK_TYPE,
        task_key="finished-key",
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
    )
    assert found is None
def test_create_task_success(session_with_task: Session) -> None:
    """TaskDAO.create_task should produce a pending Task."""
    from superset.daos.tasks import TaskDAO

    created = TaskDAO.create_task(
        task_type=TEST_TASK_TYPE,
        task_key=TEST_TASK_KEY,
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
    )

    assert created is not None
    assert created.task_key == TEST_TASK_KEY
    assert created.task_type == TEST_TASK_TYPE
    assert created.status == TaskStatus.PENDING.value
    assert isinstance(created, Task)
def test_create_task_with_user_id(session_with_task: Session) -> None:
    """Creating a task for a user should auto-subscribe that user."""
    from superset.daos.tasks import TaskDAO

    created = TaskDAO.create_task(
        task_type=TEST_TASK_TYPE,
        task_key="user-task",
        scope=TaskScope.PRIVATE,
        user_id=42,
    )

    assert created is not None
    assert created.user_id == 42
    # The creating user is subscribed automatically.
    assert len(created.subscribers) == 1
    assert created.subscribers[0].user_id == 42
def test_create_task_with_properties(session_with_task: Session) -> None:
    """Properties supplied at creation should be stored on the task."""
    from superset.daos.tasks import TaskDAO

    created = TaskDAO.create_task(
        task_type=TEST_TASK_TYPE,
        task_key="props-task",
        scope=TaskScope.PRIVATE,
        user_id=TEST_USER_ID,
        properties={"timeout": 300},
    )

    assert created is not None
    assert created.properties_dict.get("timeout") == 300
def test_abort_task_pending_success(session_with_task: Session) -> None:
    """Aborting a pending task should move it straight to ABORTED."""
    from superset.daos.tasks import TaskDAO

    pending = create_task(
        session_with_task,
        task_key="pending-task",
        status=TaskStatus.PENDING,
    )

    aborted = TaskDAO.abort_task(pending.uuid, skip_base_filter=True)
    assert aborted is not None
    assert aborted.status == TaskStatus.ABORTED.value
def test_abort_task_in_progress_abortable(session_with_task: Session) -> None:
    """An in-progress task with an abort handler should enter ABORTING.

    The executor later observes the ABORTING status and finalizes the abort.
    """
    from superset.daos.tasks import TaskDAO

    running = create_task(
        session_with_task,
        task_key="abortable-task",
        status=TaskStatus.IN_PROGRESS,
        properties={"is_abortable": True},
    )

    aborting = TaskDAO.abort_task(running.uuid, skip_base_filter=True)
    assert aborting is not None
    # Only flagged for abort (ABORTING) — never jumped straight to ABORTED.
    assert aborting.status == TaskStatus.ABORTING.value
def test_abort_task_in_progress_not_abortable(session_with_task: Session) -> None:
    """Aborting an in-progress task without an abort handler should raise."""
    from superset.daos.tasks import TaskDAO

    running = create_task(
        session_with_task,
        task_key="non-abortable-task",
        status=TaskStatus.IN_PROGRESS,
        properties={"is_abortable": False},
    )

    with pytest.raises(TaskNotAbortableError):
        TaskDAO.abort_task(running.uuid, skip_base_filter=True)
def test_abort_task_in_progress_is_abortable_none(session_with_task: Session) -> None:
    """A missing is_abortable property should be treated as not abortable."""
    from superset.daos.tasks import TaskDAO

    running = create_task(
        session_with_task,
        task_key="no-abortable-prop-task",
        status=TaskStatus.IN_PROGRESS,
        # No properties at all, so the is_abortable key is absent.
    )

    with pytest.raises(TaskNotAbortableError):
        TaskDAO.abort_task(running.uuid, skip_base_filter=True)
def test_abort_task_already_aborting(session_with_task: Session) -> None:
    """Aborting a task already in ABORTING should be an idempotent no-op."""
    from superset.daos.tasks import TaskDAO

    aborting_task = create_task(
        session_with_task,
        task_key="aborting-task",
        status=TaskStatus.ABORTING,
    )

    outcome = TaskDAO.abort_task(aborting_task.uuid, skip_base_filter=True)
    # No error is raised; the task simply remains in ABORTING.
    assert outcome is not None
    assert outcome.status == TaskStatus.ABORTING.value
def test_abort_task_not_found(session_with_task: Session) -> None:
    """abort_task should return None for an unknown UUID."""
    from superset.daos.tasks import TaskDAO

    outcome = TaskDAO.abort_task(UUID("00000000-0000-0000-0000-000000000000"))
    assert outcome is None
def test_abort_task_already_finished(session_with_task: Session) -> None:
    """abort_task should return None for a task that already finished."""
    from superset.daos.tasks import TaskDAO

    finished = create_task(
        session_with_task,
        task_key="finished-task",
        status=TaskStatus.SUCCESS,
        use_finished_dedup_key=True,
        task_uuid=TASK_UUID,
    )

    assert TaskDAO.abort_task(finished.uuid, skip_base_filter=True) is None
def test_add_subscriber(session_with_task: Session) -> None:
    """Adding a subscriber should attach the user to the task."""
    from superset.daos.tasks import TaskDAO

    shared = create_task(
        session_with_task,
        task_key="shared-task",
        scope=TaskScope.SHARED,
        user_id=None,
    )

    assert TaskDAO.add_subscriber(shared.id, user_id=TEST_USER_ID) is True

    # Reload the task and confirm the subscription was persisted.
    session_with_task.refresh(shared)
    assert len(shared.subscribers) == 1
    assert shared.subscribers[0].user_id == TEST_USER_ID
def test_add_subscriber_idempotent(session_with_task: Session) -> None:
    """Subscribing the same user twice should only store one subscription."""
    from superset.daos.tasks import TaskDAO

    shared = create_task(
        session_with_task,
        task_key="shared-task-2",
        scope=TaskScope.SHARED,
        user_id=None,
    )

    first_add = TaskDAO.add_subscriber(shared.id, user_id=TEST_USER_ID)
    second_add = TaskDAO.add_subscriber(shared.id, user_id=TEST_USER_ID)

    assert first_add is True
    # The second call reports "already subscribed".
    assert second_add is False

    session_with_task.refresh(shared)
    assert len(shared.subscribers) == 1
def test_remove_subscriber(session_with_task: Session) -> None:
    """An existing subscriber can be removed from a shared task."""
    from superset.daos.tasks import TaskDAO

    shared = create_task(
        session_with_task,
        task_key="shared-task-3",
        scope=TaskScope.SHARED,
        user_id=None,
    )
    TaskDAO.add_subscriber(shared.id, user_id=TEST_USER_ID)
    session_with_task.refresh(shared)
    assert len(shared.subscribers) == 1
    # Removal returns the updated task with the subscriber list emptied.
    updated = TaskDAO.remove_subscriber(shared.id, user_id=TEST_USER_ID)
    assert updated is not None
    assert len(updated.subscribers) == 0
def test_remove_subscriber_not_subscribed(session_with_task: Session) -> None:
    """Removing a user who never subscribed yields None."""
    from superset.daos.tasks import TaskDAO

    shared = create_task(
        session_with_task,
        task_key="shared-task-4",
        scope=TaskScope.SHARED,
        user_id=None,
    )
    # User 999 has no subscription on this task.
    assert TaskDAO.remove_subscriber(shared.id, user_id=999) is None
def test_get_status(session_with_task: Session) -> None:
    """get_status resolves a task UUID to its status string."""
    from superset.daos.tasks import TaskDAO

    in_progress = create_task(
        session_with_task,
        task_uuid=TASK_UUID,
        task_key="status-task",
        status=TaskStatus.IN_PROGRESS,
    )
    assert TaskDAO.get_status(in_progress.uuid) == TaskStatus.IN_PROGRESS.value
def test_get_status_not_found(session_with_task: Session) -> None:
    """get_status yields None for an unknown task UUID."""
    from superset.daos.tasks import TaskDAO

    unknown_uuid = UUID("00000000-0000-0000-0000-000000000000")
    assert TaskDAO.get_status(unknown_uuid) is None

View File

@@ -18,17 +18,21 @@
# pylint: disable=invalid-name
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from freezegun import freeze_time
from sqlalchemy.orm import Session, sessionmaker
# Force module loading before tests run so patches work correctly
import superset.commands.distributed_lock.acquire as acquire_module
import superset.commands.distributed_lock.release as release_module
from superset import db
from superset.distributed_lock import KeyValueDistributedLock
from superset.distributed_lock import DistributedLock
from superset.distributed_lock.types import LockValue
from superset.distributed_lock.utils import get_key
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.exceptions import AcquireDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec
LOCK_VALUE: LockValue = {"value": True}
@@ -56,9 +60,9 @@ def _get_other_session() -> Session:
return SessionMaker()
def test_key_value_distributed_lock_happy_path() -> None:
def test_distributed_lock_kv_happy_path() -> None:
"""
Test successfully acquiring and returning the distributed lock.
Test successfully acquiring and returning the distributed lock via KV backend.
Note, we're using another session for asserting the lock state in the Metastore
to simulate what another worker will observe. Otherwise, there's the risk that
@@ -66,24 +70,29 @@ def test_key_value_distributed_lock_happy_path() -> None:
"""
session = _get_other_session()
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
# Ensure Redis is not configured so KV backend is used
with (
patch.object(acquire_module, "get_redis_client", return_value=None),
patch.object(release_module, "get_redis_client", return_value=None),
):
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
with KeyValueDistributedLock("ns", a=1, b=2) as key:
assert key == MAIN_KEY
assert _get_lock(key, session) == LOCK_VALUE
assert _get_lock(OTHER_KEY, session) is None
with DistributedLock("ns", a=1, b=2) as key:
assert key == MAIN_KEY
assert _get_lock(key, session) == LOCK_VALUE
assert _get_lock(OTHER_KEY, session) is None
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
with pytest.raises(AcquireDistributedLockFailedException):
with DistributedLock("ns", a=1, b=2):
pass
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY, session) is None
def test_key_value_distributed_lock_expired() -> None:
def test_distributed_lock_kv_expired() -> None:
"""
Test expiration of the distributed lock
Test expiration of the distributed lock via KV backend.
Note, we're using another session for asserting the lock state in the Metastore
to simulate what another worker will observe. Otherwise, there's the risk that
@@ -91,11 +100,112 @@ def test_key_value_distributed_lock_expired() -> None:
"""
session = _get_other_session()
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
with KeyValueDistributedLock("ns", a=1, b=2):
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
with freeze_time("2022-01-01"):
assert _get_lock(MAIN_KEY, session) is None
# Ensure Redis is not configured so KV backend is used
with (
patch.object(acquire_module, "get_redis_client", return_value=None),
patch.object(release_module, "get_redis_client", return_value=None),
):
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY, session) is None
with DistributedLock("ns", a=1, b=2):
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
with freeze_time("2022-01-01"):
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY, session) is None
def test_distributed_lock_uses_redis_when_configured() -> None:
    """DistributedLock goes through Redis when a client is available."""
    redis_stub = MagicMock()
    redis_stub.set.return_value = True  # SET NX succeeded -> lock acquired
    # Patch on the already-imported acquire/release modules so both code
    # paths see the same stub client.
    with (
        patch.object(acquire_module, "get_redis_client", return_value=redis_stub),
        patch.object(release_module, "get_redis_client", return_value=redis_stub),
    ):
        with DistributedLock("test_redis", key="value") as lock_key:
            assert lock_key is not None
        # Acquisition must use SET with NX and an expiry.
        redis_stub.set.assert_called_once()
        set_kwargs = redis_stub.set.call_args.kwargs
        assert set_kwargs["nx"] is True
        assert "ex" in set_kwargs
        # Leaving the context releases the lock via DELETE.
        redis_stub.delete.assert_called_once()
def test_distributed_lock_redis_already_taken() -> None:
    """Acquisition fails fast when another holder owns the Redis lock."""
    redis_stub = MagicMock()
    redis_stub.set.return_value = None  # SET NX yields None when key exists
    with patch.object(acquire_module, "get_redis_client", return_value=redis_stub):
        with pytest.raises(AcquireDistributedLockFailedException):
            with DistributedLock("test_redis", key="value"):
                pass
def test_distributed_lock_redis_connection_error() -> None:
    """A Redis error during acquisition surfaces as a lock failure (fail fast)."""
    import redis

    redis_stub = MagicMock()
    redis_stub.set.side_effect = redis.RedisError("Connection failed")
    with patch.object(acquire_module, "get_redis_client", return_value=redis_stub):
        with pytest.raises(AcquireDistributedLockFailedException):
            with DistributedLock("test_redis", key="value"):
                pass
def test_distributed_lock_custom_ttl() -> None:
    """An explicit ttl_seconds is forwarded as the Redis expiry."""
    redis_stub = MagicMock()
    redis_stub.set.return_value = True
    with (
        patch.object(acquire_module, "get_redis_client", return_value=redis_stub),
        patch.object(release_module, "get_redis_client", return_value=redis_stub),
    ):
        with DistributedLock("test", ttl_seconds=60, key="value"):
            # The SET call carries the custom TTL as its expiry.
            assert redis_stub.set.call_args.kwargs["ex"] == 60
def test_distributed_lock_default_ttl(app_context: None) -> None:
    """Without ttl_seconds, the configured default TTL is applied."""
    from superset.commands.distributed_lock.base import get_default_lock_ttl

    redis_stub = MagicMock()
    redis_stub.set.return_value = True
    with (
        patch.object(acquire_module, "get_redis_client", return_value=redis_stub),
        patch.object(release_module, "get_redis_client", return_value=redis_stub),
    ):
        with DistributedLock("test", key="value"):
            # The expiry passed to Redis matches the app-level default.
            assert redis_stub.set.call_args.kwargs["ex"] == get_default_lock_ttl()
def test_distributed_lock_fallback_to_kv_when_redis_not_configured() -> None:
    """Without a Redis client, the lock falls back to the key-value store."""
    session = _get_other_session()
    expected_key = get_key("test_fallback", key="value")
    with (
        patch.object(acquire_module, "get_redis_client", return_value=None),
        patch.object(release_module, "get_redis_client", return_value=None),
    ):
        with freeze_time("2021-01-01"):
            with DistributedLock("test_fallback", key="value") as lock_key:
                # KV backend in use: the key matches and the value is stored.
                assert lock_key == expected_key
                assert _get_lock(expected_key, session) == LOCK_VALUE
            # Exiting the context releases the KV lock.
            assert _get_lock(expected_key, session) is None

View File

@@ -0,0 +1,477 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for task decorators"""
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from superset_core.api.tasks import TaskOptions, TaskScope
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
from superset.tasks.decorators import task, TaskWrapper
from superset.tasks.registry import TaskRegistry
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
class TestTaskDecoratorFeatureFlag:
    """Behavior of the @task decorator when GLOBAL_TASK_FRAMEWORK is off."""

    def setup_method(self):
        """Start every test from an empty task registry."""
        TaskRegistry._tasks.clear()

    @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
    def test_decorator_succeeds_when_gtf_disabled(self, mock_feature_flag):
        """Decoration must not fail with GTF off, so module imports stay safe
        during app startup and Celery autodiscovery."""

        @task(name="test_gtf_disabled_decorator")
        def my_task() -> None:
            pass

        # The wrapper is created normally despite the flag being off.
        assert isinstance(my_task, TaskWrapper)
        assert my_task.name == "test_gtf_disabled_decorator"

    @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
    def test_call_raises_error_when_gtf_disabled(self, mock_feature_flag):
        """Invoking a decorated task with GTF off raises the disabled error."""

        @task(name="test_gtf_disabled_call")
        def my_task() -> None:
            pass

        with pytest.raises(GlobalTaskFrameworkDisabledError):
            my_task()

    @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
    def test_schedule_raises_error_when_gtf_disabled(self, mock_feature_flag):
        """Scheduling a decorated task with GTF off raises the disabled error."""

        @task(name="test_gtf_disabled_schedule")
        def my_task() -> None:
            pass

        with pytest.raises(GlobalTaskFrameworkDisabledError):
            my_task.schedule()
class TestTaskDecorator:
    """Basic @task decorator behavior and signature validation."""

    def test_decorator_basic(self):
        """Naming a task via the decorator yields a PRIVATE TaskWrapper."""

        @task(name="test_task")
        def my_task(arg1: int, arg2: str) -> None:
            pass

        assert isinstance(my_task, TaskWrapper)
        assert my_task.name == "test_task"
        assert my_task.scope == TaskScope.PRIVATE

    def test_decorator_without_parentheses(self):
        """Bare @task derives the task name from the function itself."""

        @task
        def my_no_parens_task(arg1: int, arg2: str) -> None:
            pass

        assert isinstance(my_no_parens_task, TaskWrapper)
        # Name falls back to the decorated function's name.
        assert my_no_parens_task.name == "my_no_parens_task"
        assert my_no_parens_task.scope == TaskScope.PRIVATE

    def test_decorator_with_default_scope_private(self):
        """An explicit PRIVATE scope is preserved on the wrapper."""

        @task(name="private_task", scope=TaskScope.PRIVATE)
        def my_private_task(arg1: int) -> None:
            pass

        assert my_private_task.scope == TaskScope.PRIVATE

    def test_decorator_with_default_scope_shared(self):
        """A SHARED scope is preserved on the wrapper."""

        @task(name="shared_task", scope=TaskScope.SHARED)
        def my_shared_task(arg1: int) -> None:
            pass

        assert my_shared_task.scope == TaskScope.SHARED

    def test_decorator_with_default_scope_system(self):
        """A SYSTEM scope is preserved on the wrapper."""

        @task(name="system_task", scope=TaskScope.SYSTEM)
        def my_system_task() -> None:
            pass

        assert my_system_task.scope == TaskScope.SYSTEM

    def test_decorator_forbids_ctx_parameter(self):
        """Functions declaring the reserved 'ctx' parameter are rejected."""
        with pytest.raises(TypeError, match="must not define 'ctx'"):

            @task(name="bad_task")
            def bad_task(ctx, arg1: int) -> None:  # noqa: ARG001
                pass

    def test_decorator_forbids_options_parameter(self):
        """Functions declaring the reserved 'options' parameter are rejected."""
        with pytest.raises(TypeError, match="must not define.*'options'"):

            @task(name="bad_task")
            def bad_task(options, arg1: int) -> None:  # noqa: ARG001
                pass
class TestTaskWrapperMergeOptions:
    """Default/override precedence in TaskWrapper._merge_options()."""

    def setup_method(self):
        """Start every test from an empty task registry."""
        TaskRegistry._tasks.clear()

    def test_merge_options_no_override(self):
        """Passing None yields the wrapper's default options unchanged."""

        @task(name="test_merge_no_override_unique")
        def merge_task_1() -> None:
            pass

        # Seed defaults directly on the wrapper for the test.
        merge_task_1.default_options = TaskOptions(
            task_key="default_key",
            task_name="Default Name",
        )
        combined = merge_task_1._merge_options(None)
        assert combined.task_key == "default_key"
        assert combined.task_name == "Default Name"

    def test_merge_options_override_task_key(self):
        """A call-time task_key wins over the default."""

        @task(name="test_merge_override_key_unique")
        def merge_task_2() -> None:
            pass

        merge_task_2.default_options = TaskOptions(task_key="default_key")
        combined = merge_task_2._merge_options(TaskOptions(task_key="override_key"))
        assert combined.task_key == "override_key"

    def test_merge_options_override_task_name(self):
        """A call-time task_name wins over the default."""

        @task(name="test_merge_override_name_unique")
        def merge_task_3() -> None:
            pass

        merge_task_3.default_options = TaskOptions(task_name="Default Name")
        combined = merge_task_3._merge_options(TaskOptions(task_name="Override Name"))
        assert combined.task_name == "Override Name"

    def test_merge_options_override_all(self):
        """Every field can be overridden at call time simultaneously."""

        @task(name="test_merge_override_all_unique")
        def merge_task_4() -> None:
            pass

        merge_task_4.default_options = TaskOptions(
            task_key="default_key",
            task_name="Default Name",
        )
        combined = merge_task_4._merge_options(
            TaskOptions(task_key="override_key", task_name="Override Name")
        )
        assert combined.task_key == "override_key"
        assert combined.task_name == "Override Name"
class TestTaskWrapperSchedule:
    """Scope and task_key handling in TaskWrapper.schedule()."""

    def setup_method(self):
        """Start every test from an empty task registry."""
        TaskRegistry._tasks.clear()

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_uses_default_scope(self, mock_submit):
        """schedule() forwards the scope declared on the decorator."""
        mock_submit.return_value = MagicMock()

        @task(name="test_schedule_default_unique", scope=TaskScope.SHARED)
        def schedule_task_1(arg1: int) -> None:
            pass

        # SHARED tasks need an explicit task_key for deduplication.
        schedule_task_1.schedule(123, options=TaskOptions(task_key="test_key"))
        mock_submit.assert_called_once()
        assert mock_submit.call_args[1]["scope"] == TaskScope.SHARED

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_uses_private_scope_by_default(self, mock_submit):
        """Without an explicit scope, schedule() submits as PRIVATE."""
        mock_submit.return_value = MagicMock()

        @task(name="test_schedule_override_unique")
        def schedule_task_2(arg1: int) -> None:
            pass

        schedule_task_2.schedule(123)
        mock_submit.assert_called_once()
        assert mock_submit.call_args[1]["scope"] == TaskScope.PRIVATE

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_with_custom_options(self, mock_submit):
        """Call-time TaskOptions flow through alongside the decorator scope."""
        mock_submit.return_value = MagicMock()

        @task(name="test_schedule_custom_unique", scope=TaskScope.SYSTEM)
        def schedule_task_3(arg1: int) -> None:
            pass

        schedule_task_3.schedule(
            123,
            options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
        )
        mock_submit.assert_called_once()
        submitted = mock_submit.call_args[1]
        # Scope comes from the decorator; key/name come from the call site.
        assert submitted["scope"] == TaskScope.SYSTEM
        assert submitted["task_key"] == "custom_key"
        assert submitted["task_name"] == "Custom Task Name"

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_with_no_decorator_options(self, mock_submit):
        """schedule() without any options defaults to PRIVATE scope."""
        mock_submit.return_value = MagicMock()

        @task(name="test_schedule_no_options_unique")
        def schedule_task_4(arg1: int) -> None:
            pass

        schedule_task_4.schedule(123)
        mock_submit.assert_called_once()
        assert mock_submit.call_args[1]["scope"] == TaskScope.PRIVATE

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_shared_task_requires_task_key(self, mock_submit):
        """A SHARED task cannot be scheduled without an explicit task_key."""

        @task(name="test_shared_requires_key", scope=TaskScope.SHARED)
        def shared_task(arg1: int) -> None:
            pass

        with pytest.raises(
            ValueError,
            match="Shared task.*requires an explicit task_key.*for deduplication",
        ):
            shared_task.schedule(123)
        # The same call succeeds once a task_key is supplied.
        mock_submit.return_value = MagicMock()
        shared_task.schedule(123, options=TaskOptions(task_key="valid_key"))
        mock_submit.assert_called_once()

    @patch("superset.tasks.decorators.TaskManager.submit_task")
    def test_schedule_private_task_allows_no_task_key(self, mock_submit):
        """A PRIVATE task schedules fine without a task_key."""
        mock_submit.return_value = MagicMock()

        @task(name="test_private_no_key", scope=TaskScope.PRIVATE)
        def private_task(arg1: int) -> None:
            pass

        # No task_key needed; a random UUID is generated internally.
        private_task.schedule(123)
        mock_submit.assert_called_once()
class TestTaskWrapperCall:
    """Tests for TaskWrapper.__call__() with scope.

    The collaborators patched below are SubmitTaskCommand.run_with_info,
    TaskDAO.find_one_or_none and UpdateTaskCommand.run.

    NOTE: stacked ``@patch`` decorators inject mocks bottom-up, so the
    mock parameters are ordered submit -> find -> update (the reverse of
    the decorator stack).
    """
    def setup_method(self):
        """Clear task registry before each test"""
        TaskRegistry._tasks.clear()
    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_uses_default_scope(
        self, mock_submit_run_with_info, mock_find, mock_update_run
    ):
        """Test direct call uses decorator's default scope"""
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task  # Mock the subsequent find call
        @task(name="test_call_default_unique", scope=TaskScope.SHARED)
        def call_task_1(arg1: int) -> None:
            pass
        # Shared tasks require explicit task_key
        call_task_1(123, options=TaskOptions(task_key="test_key"))
        # Verify SubmitTaskCommand.run_with_info was called
        mock_submit_run_with_info.assert_called_once()
    @patch("superset.utils.core.get_user_id")
    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_uses_private_scope_by_default(
        self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
    ):
        """Test direct call uses PRIVATE scope when no scope specified"""
        # get_user_id is patched since the private-scope path reads the
        # current user (see the extra decorator above).
        mock_get_user_id.return_value = 1
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task  # Mock the subsequent find call
        @task(name="test_call_private_default_unique")
        def call_task_2(arg1: int) -> None:
            pass
        call_task_2(123)
        # Verify SubmitTaskCommand.run_with_info was called
        mock_submit_run_with_info.assert_called_once()
    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_with_custom_options(
        self, mock_submit_run_with_info, mock_find, mock_update_run
    ):
        """Test direct call with custom task options"""
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task  # Mock the subsequent find call
        @task(name="test_call_custom_unique", scope=TaskScope.SYSTEM)
        def call_task_3(arg1: int) -> None:
            pass
        # Use custom task key and name
        call_task_3(
            123,
            options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
        )
        # Verify SubmitTaskCommand.run_with_info was called
        mock_submit_run_with_info.assert_called_once()
    def test_call_shared_task_requires_task_key(self):
        """Test shared task direct call requires explicit task_key"""
        @task(name="test_shared_call_requires_key", scope=TaskScope.SHARED)
        def shared_task(arg1: int) -> None:
            pass
        # Should raise ValueError when no task_key provided
        with pytest.raises(
            ValueError,
            match="Shared task.*requires an explicit task_key.*for deduplication",
        ):
            shared_task(123)
    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_shared_task_works_with_task_key(
        self, mock_submit_run_with_info, mock_find, mock_update_run
    ):
        """Test shared task direct call works with task_key"""
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task
        @task(name="test_shared_call_with_key", scope=TaskScope.SHARED)
        def shared_task(arg1: int) -> None:
            pass
        # Should work with task_key provided
        shared_task(123, options=TaskOptions(task_key="valid_key"))
        mock_submit_run_with_info.assert_called_once()
    @patch("superset.utils.core.get_user_id")
    @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
    @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
    @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
    def test_call_private_task_allows_no_task_key(
        self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
    ):
        """Test private task direct call works without task_key"""
        mock_get_user_id.return_value = 1
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.status = "in_progress"
        mock_submit_run_with_info.return_value = (mock_task, True)  # (task, is_new)
        mock_update_run.return_value = mock_task
        mock_find.return_value = mock_task
        @task(name="test_private_call_no_key", scope=TaskScope.PRIVATE)
        def private_task(arg1: int) -> None:
            pass
        # Should work without task_key (generates random UUID)
        private_task(123)
        mock_submit_run_with_info.assert_called_once()

View File

@@ -0,0 +1,677 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for GTF handlers (abort, cleanup) and related Task model behavior."""
import time
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock, patch
from uuid import UUID
import pytest
from freezegun import freeze_time
from superset_core.api.tasks import TaskStatus
from superset.tasks.context import TaskContext
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
@pytest.fixture
def mock_task():
    """Provide a MagicMock standing in for a pending Task."""
    stub = MagicMock()
    stub.uuid = TEST_UUID
    stub.status = TaskStatus.PENDING.value
    return stub
@pytest.fixture
def mock_task_dao(mock_task):
    """Patch TaskDAO so lookups resolve to the stub task."""
    with patch("superset.daos.tasks.TaskDAO") as dao_mock:
        dao_mock.find_one_or_none.return_value = mock_task
        yield dao_mock
@pytest.fixture
def mock_update_command():
    """Patch UpdateTaskCommand so no real update runs during tests."""
    with patch("superset.commands.tasks.update.UpdateTaskCommand") as command_mock:
        command_mock.return_value.run.return_value = None
        yield command_mock
@pytest.fixture
def mock_flask_app():
    """Build a mock Flask app exposing the config keys the tests rely on."""
    app = MagicMock()
    app.config = {
        "TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
    }
    # app_context() must behave like a context manager.
    app.app_context.return_value.__enter__ = MagicMock(return_value=None)
    app.app_context.return_value.__exit__ = MagicMock(return_value=None)
    # A plain Mock is used for _get_current_object: MagicMock's
    # AsyncMockMixin would create unawaited coroutines on Python 3.10+.
    app._get_current_object = Mock(return_value=app)
    return app
@pytest.fixture
def task_context(mock_task, mock_task_dao, mock_update_command, mock_flask_app):
    """Yield a TaskContext wired to fully mocked dependencies."""
    # TaskContext reads these accessors, so they must exist on the stub.
    mock_task.properties_dict = {"is_abortable": False}
    mock_task.payload_dict = {}
    with (
        patch("superset.tasks.context.current_app") as current_app_mock,
        patch("superset.tasks.manager.cache_manager") as cache_manager_mock,
    ):
        # A None signal_cache disables the Redis signaling path.
        cache_manager_mock.signal_cache = None
        current_app_mock.config = mock_flask_app.config
        # Plain Mock for _get_current_object: MagicMock's AsyncMockMixin
        # would create unawaited coroutines on Python 3.10+.
        current_app_mock._get_current_object = Mock(return_value=mock_flask_app)
        ctx = TaskContext(mock_task)
        yield ctx
        # Tear down the polling thread if a handler registration started it.
        if ctx._abort_listener:
            ctx.stop_abort_polling()
class TestTaskStatusEnum:
    """Sanity checks on the TaskStatus enum."""

    def test_aborting_status_exists(self):
        """The ABORTING member exists and carries the expected value."""
        assert hasattr(TaskStatus, "ABORTING")
        assert TaskStatus.ABORTING.value == "aborting"

    def test_all_statuses_present(self):
        """Every lifecycle status the framework relies on is defined."""
        required = [
            "pending",
            "in_progress",
            "success",
            "failure",
            "aborting",
            "aborted",
        ]
        defined = {member.value for member in TaskStatus}
        for value in required:
            assert value in defined, f"Missing status: {value}"
class TestTaskAbortProperties:
    """Abort-related state exposed via Task.status and properties_dict."""

    def test_aborting_status(self):
        """A task can carry the ABORTING status value."""
        from superset.models.tasks import Task

        subject = Task()
        subject.status = TaskStatus.ABORTING.value
        assert subject.status == TaskStatus.ABORTING.value

    def test_is_abortable_in_properties(self):
        """update_properties surfaces is_abortable through properties_dict."""
        from superset.models.tasks import Task

        subject = Task()
        subject.update_properties({"is_abortable": True})
        assert subject.properties_dict.get("is_abortable") is True

    def test_is_abortable_default_none(self):
        """A freshly constructed task has no is_abortable entry."""
        from superset.models.tasks import Task

        subject = Task()
        assert subject.properties_dict.get("is_abortable") is None
class TestTaskSetStatus:
    """State transitions applied by Task.set_status for abort handling."""

    def test_set_status_in_progress_sets_is_abortable_false(self):
        """Entering IN_PROGRESS initializes is_abortable to False."""
        from superset.models.tasks import Task

        subject = Task()
        subject.uuid = "test-uuid"
        # is_abortable starts as None on a fresh task.
        subject.set_status(TaskStatus.IN_PROGRESS)
        assert subject.properties_dict.get("is_abortable") is False
        assert subject.started_at is not None

    def test_set_status_in_progress_preserves_existing_is_abortable(self):
        """Re-entering IN_PROGRESS keeps a previously set is_abortable."""
        from superset.models.tasks import Task

        subject = Task()
        subject.uuid = "test-uuid"
        # Simulate a handler registration having already set the flag,
        # with the task already started.
        subject.update_properties({"is_abortable": True})
        subject.started_at = datetime.now(timezone.utc)
        subject.set_status(TaskStatus.IN_PROGRESS)
        # started_at being present means the flag must not be reset.
        assert subject.properties_dict.get("is_abortable") is True

    def test_set_status_aborting_does_not_set_ended_at(self):
        """ABORTING is not terminal, so ended_at stays unset."""
        from superset.models.tasks import Task

        subject = Task()
        subject.uuid = "test-uuid"
        subject.started_at = datetime.now(timezone.utc)
        subject.status = TaskStatus.ABORTING.value
        assert subject.ended_at is None

    def test_set_status_aborted_sets_ended_at(self):
        """ABORTED is terminal and stamps ended_at."""
        from superset.models.tasks import Task

        subject = Task()
        subject.uuid = "test-uuid"
        subject.started_at = datetime.now(timezone.utc)
        subject.set_status(TaskStatus.ABORTED)
        assert subject.ended_at is not None
class TestTaskDuration:
    """Task.duration_seconds across the task lifecycle."""

    def test_duration_seconds_finished_task(self):
        """Finished tasks report ended_at minus started_at."""
        from superset.models.tasks import Task

        finished = Task()
        # The status must be terminal for ended_at to be used.
        finished.status = TaskStatus.SUCCESS.value
        finished.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        finished.ended_at = datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
        assert finished.duration_seconds == 30.0

    @freeze_time("2024-01-01 10:00:30")
    def test_duration_seconds_running_task(self):
        """Running tasks report elapsed time since started_at."""
        from superset.models.tasks import Task

        running = Task()
        running.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        running.ended_at = None
        # Frozen clock is 30s past the start time.
        assert running.duration_seconds == 30.0

    @freeze_time("2024-01-01 10:00:15")
    def test_duration_seconds_pending_task(self):
        """Pending tasks report time spent waiting since creation."""
        from superset.models.tasks import Task

        pending = Task()
        pending.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
        pending.started_at = None
        pending.ended_at = None
        # Frozen clock is 15s past the creation time.
        assert pending.duration_seconds == 15.0

    def test_duration_seconds_no_timestamps(self):
        """With no timestamps at all, duration is undefined (None)."""
        from superset.models.tasks import Task

        blank = Task()
        blank.created_on = None
        blank.started_at = None
        blank.ended_at = None
        assert blank.duration_seconds is None
class TestAbortHandlerRegistration:
    """Test abort handler registration and is_abortable flag."""

    def test_on_abort_registers_handler(self, task_context):
        """Test that on_abort registers a handler."""
        handler_called = False

        @task_context.on_abort
        def handle_abort():
            nonlocal handler_called
            handler_called = True

        # Registration records the handler but must not invoke it.
        assert len(task_context._abort_handlers) == 1
        assert not handler_called

    @patch("superset.tasks.context.current_app")
    def test_on_abort_sets_abortable(self, mock_app):
        """Test on_abort sets is_abortable to True on first handler."""
        mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
        mock_app._get_current_object = Mock(return_value=mock_app)
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.properties_dict = {"is_abortable": False}
        mock_task.payload_dict = {}
        # Patch out the persistence call and the background poller so only
        # the registration side effect is observable.
        with (
            patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
            patch.object(TaskContext, "start_abort_polling"),
        ):
            ctx = TaskContext(mock_task)

            @ctx.on_abort
            def handler():
                pass

            mock_set_abortable.assert_called_once()

    @patch("superset.tasks.context.current_app")
    def test_on_abort_only_sets_abortable_once(self, mock_app):
        """Test on_abort only calls _set_abortable for first handler."""
        mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
        mock_app._get_current_object = Mock(return_value=mock_app)
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.properties_dict = {"is_abortable": False}
        mock_task.payload_dict = {}
        with (
            patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
            patch.object(TaskContext, "start_abort_polling"),
        ):
            ctx = TaskContext(mock_task)

            @ctx.on_abort
            def handler1():
                pass

            @ctx.on_abort
            def handler2():
                pass

            # Should only be called once for first handler
            assert mock_set_abortable.call_count == 1

    def test_abort_handlers_completed_initially_false(self):
        """Test abort_handlers_completed is False initially."""
        mock_task = MagicMock()
        mock_task.uuid = TEST_UUID
        mock_task.properties_dict = {}
        mock_task.payload_dict = {}
        with patch("superset.tasks.context.current_app") as mock_app:
            mock_app._get_current_object = Mock(return_value=mock_app)
            ctx = TaskContext(mock_task)
            # No handlers have run yet, so the completion flag starts False.
            assert ctx.abort_handlers_completed is False
class TestAbortPolling:
    """Test abort detection polling behavior."""

    def test_on_abort_starts_polling_automatically(self, task_context):
        """Test that registering first handler starts abort listener."""
        assert task_context._abort_listener is None

        @task_context.on_abort
        def handle_abort():
            pass

        # First registration lazily spins up the background listener.
        assert task_context._abort_listener is not None

    def test_stop_abort_polling(self, task_context):
        """Test that stop_abort_polling stops the abort listener."""

        @task_context.on_abort
        def handle_abort():
            pass

        assert task_context._abort_listener is not None
        task_context.stop_abort_polling()
        assert task_context._abort_listener is None

    def test_start_abort_polling_only_once(self, task_context):
        """Test that start_abort_polling is idempotent."""
        task_context.start_abort_polling(interval=0.1)
        first_listener = task_context._abort_listener
        # Try to start again
        task_context.start_abort_polling(interval=0.1)
        second_listener = task_context._abort_listener
        # Should be the same listener
        assert first_listener is second_listener

    def test_on_abort_with_custom_interval(self, task_context):
        """Test that custom interval can be set via start_abort_polling."""
        with patch("superset.tasks.context.current_app") as mock_app:
            mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
            mock_app._get_current_object = Mock(return_value=mock_app)

            @task_context.on_abort
            def handle_abort():
                pass

            # Override with custom interval
            task_context.stop_abort_polling()
            task_context.start_abort_polling(interval=0.05)
            assert task_context._abort_listener is not None

    def test_polling_stops_after_abort_detected(self, task_context, mock_task):
        """Test that abort is detected and handlers are triggered."""

        @task_context.on_abort
        def handle_abort():
            pass

        # Trigger abort
        mock_task.status = TaskStatus.ABORTED.value
        # Wait for detection (listener polls at a sub-second interval)
        time.sleep(0.3)
        # Abort should have been detected
        assert task_context._abort_detected is True
class TestAbortHandlerExecution:
    """Test abort handler execution behavior."""

    def test_on_abort_handler_fires_when_task_aborted(self, task_context, mock_task):
        """Test that abort handler fires automatically when task is aborted."""
        abort_called = False

        @task_context.on_abort
        def handle_abort():
            nonlocal abort_called
            abort_called = True

        # Simulate task being aborted
        mock_task.status = TaskStatus.ABORTED.value
        # Wait for polling to detect abort (max 0.3s with 0.1s interval)
        time.sleep(0.3)
        assert abort_called
        assert task_context._abort_detected

    def test_on_abort_not_called_on_success(self, task_context, mock_task):
        """Test that abort handlers don't run on success."""
        abort_called = False

        @task_context.on_abort
        def handle_abort():
            nonlocal abort_called
            abort_called = True

        # Keep task in success state
        mock_task.status = TaskStatus.SUCCESS.value
        # Wait and verify handler not called
        time.sleep(0.3)
        assert not abort_called

    def test_multiple_abort_handlers(self, task_context, mock_task):
        """Test that all abort handlers execute in LIFO order."""
        calls = []

        @task_context.on_abort
        def handler1():
            calls.append(1)

        @task_context.on_abort
        def handler2():
            calls.append(2)

        # Trigger abort
        mock_task.status = TaskStatus.ABORTED.value
        # Wait for detection
        time.sleep(0.3)
        # LIFO order: handler2 runs first
        assert calls == [2, 1]

    def test_abort_handler_exception_doesnt_fail_task(self, task_context, mock_task):
        """Test that exception in abort handler is logged but doesn't fail task."""
        handler2_called = False

        @task_context.on_abort
        def bad_handler():
            raise ValueError("Handler error")

        @task_context.on_abort
        def good_handler():
            nonlocal handler2_called
            handler2_called = True

        # Trigger abort
        mock_task.status = TaskStatus.ABORTED.value
        # Wait for detection
        time.sleep(0.3)
        # Second handler should still run despite first handler failing
        assert handler2_called
class TestBestEffortHandlerExecution:
    """Test that all handlers execute even when some fail (best-effort)."""

    def test_all_abort_handlers_run_even_if_all_fail(self, task_context, mock_task):
        """Test all abort handlers execute even if every one raises an exception."""
        calls = []

        @task_context.on_abort
        def handler1():
            calls.append(1)
            raise ValueError("Handler 1 failed")

        @task_context.on_abort
        def handler2():
            calls.append(2)
            raise RuntimeError("Handler 2 failed")

        @task_context.on_abort
        def handler3():
            calls.append(3)
            raise TypeError("Handler 3 failed")

        # Trigger abort handlers directly (simulating abort detection)
        task_context._trigger_abort_handlers()
        # All handlers should have been called (LIFO order: 3, 2, 1)
        assert calls == [3, 2, 1]
        # Failures should be collected (abort handlers don't write to DB)
        assert len(task_context._handler_failures) == 3
        # _handler_failures holds (handler_type, exception, handler_name) tuples.
        failure_types = [
            type(ex).__name__ for _, ex, _ in task_context._handler_failures
        ]
        assert "TypeError" in failure_types
        assert "RuntimeError" in failure_types
        assert "ValueError" in failure_types

    def test_all_cleanup_handlers_run_even_if_all_fail(self, task_context, mock_task):
        """Test all cleanup handlers execute even if every one raises an exception."""
        calls = []
        captured_failures = []
        # Mock _write_handler_failures_to_db to capture failures before clearing
        original_write = task_context._write_handler_failures_to_db

        def mock_write():
            captured_failures.extend(task_context._handler_failures)
            original_write()

        # Monkeypatch the bound method on the instance so the real write
        # still runs after we snapshot the failure list.
        task_context._write_handler_failures_to_db = mock_write

        @task_context.on_cleanup
        def cleanup1():
            calls.append(1)
            raise ValueError("Cleanup 1 failed")

        @task_context.on_cleanup
        def cleanup2():
            calls.append(2)
            raise RuntimeError("Cleanup 2 failed")

        @task_context.on_cleanup
        def cleanup3():
            calls.append(3)
            raise TypeError("Cleanup 3 failed")

        # Set task to SUCCESS (not aborting) so only cleanup handlers run
        mock_task.status = TaskStatus.SUCCESS.value
        # Run cleanup
        task_context._run_cleanup()
        # All handlers should have been called (LIFO order: 3, 2, 1)
        assert calls == [3, 2, 1]
        # Failures should have been captured before clearing
        assert len(captured_failures) == 3
        failure_types = [type(ex).__name__ for _, ex, _ in captured_failures]
        assert "TypeError" in failure_types
        assert "RuntimeError" in failure_types
        assert "ValueError" in failure_types

    def test_mixed_abort_and_cleanup_failures_all_collected(
        self, task_context, mock_task
    ):
        """Test abort and cleanup handler failures are collected together."""
        calls = []
        captured_failures = []
        # Mock _write_handler_failures_to_db to capture failures before clearing
        original_write = task_context._write_handler_failures_to_db

        def mock_write():
            captured_failures.extend(task_context._handler_failures)
            original_write()

        task_context._write_handler_failures_to_db = mock_write

        @task_context.on_abort
        def abort1():
            calls.append("abort1")
            raise ValueError("Abort 1 failed")

        @task_context.on_abort
        def abort2():
            calls.append("abort2")
            raise RuntimeError("Abort 2 failed")

        @task_context.on_cleanup
        def cleanup1():
            calls.append("cleanup1")
            raise TypeError("Cleanup 1 failed")

        @task_context.on_cleanup
        def cleanup2():
            calls.append("cleanup2")
            raise KeyError("Cleanup 2 failed")

        # Set task to ABORTING so both abort and cleanup handlers run
        mock_task.status = TaskStatus.ABORTING.value
        # Run cleanup (which triggers abort handlers first, then cleanup handlers)
        task_context._run_cleanup()
        # All handlers should have been called
        # Abort handlers run first (LIFO: abort2, abort1)
        # Then cleanup handlers (LIFO: cleanup2, cleanup1)
        assert calls == ["abort2", "abort1", "cleanup2", "cleanup1"]
        # All 4 failures should have been captured
        assert len(captured_failures) == 4
        # Verify handler types are recorded correctly
        handler_types = [htype for htype, _, _ in captured_failures]
        assert handler_types.count("abort") == 2
        assert handler_types.count("cleanup") == 2
class TestCleanupHandlers:
    """Test cleanup handler behavior."""

    def test_cleanup_triggers_abort_handlers_if_not_detected(
        self, task_context, mock_task
    ):
        """Test that _run_cleanup triggers abort handlers if task ended aborted."""
        abort_called = False

        @task_context.on_abort
        def handle_abort():
            nonlocal abort_called
            abort_called = True

        # Set task as aborted but don't let polling detect it
        mock_task.status = TaskStatus.ABORTED.value
        task_context._abort_detected = False
        # Immediately run cleanup (simulating task ending before poll)
        task_context._run_cleanup()
        assert abort_called

    def test_cleanup_doesnt_duplicate_abort_handlers(self, task_context, mock_task):
        """Test that abort handlers only run once even if called from cleanup."""
        call_count = 0

        @task_context.on_abort
        def handle_abort():
            nonlocal call_count
            call_count += 1

        # Trigger abort via polling
        mock_task.status = TaskStatus.ABORTED.value
        time.sleep(0.3)
        # Handlers should have been called once
        assert call_count == 1
        assert task_context._abort_detected is True
        # Run cleanup - handlers should NOT be called again
        task_context._run_cleanup()
        assert call_count == 1  # Still 1, not 2

View File

@@ -0,0 +1,462 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for TaskManager pub/sub functionality"""
import threading
import time
from unittest.mock import MagicMock, patch
import redis
from superset.tasks.manager import AbortListener, TaskManager
class TestAbortListener:
    """Unit tests covering AbortListener shutdown semantics."""

    @staticmethod
    def _fake_thread(alive):
        """Return a mock thread whose is_alive() reports ``alive``."""
        worker = MagicMock(spec=threading.Thread)
        worker.is_alive.return_value = alive
        return worker

    def test_stop_sets_event(self):
        """stop() must set the shared stop event."""
        halt = threading.Event()
        listener = AbortListener("test-uuid", self._fake_thread(False), halt)
        assert not halt.is_set()
        listener.stop()
        assert halt.is_set()

    def test_stop_closes_pubsub(self):
        """stop() must unsubscribe and close the pub/sub connection."""
        halt = threading.Event()
        channel = MagicMock()
        listener = AbortListener(
            "test-uuid", self._fake_thread(False), halt, channel
        )
        listener.stop()
        channel.unsubscribe.assert_called_once()
        channel.close.assert_called_once()

    def test_stop_joins_thread(self):
        """stop() must join a listener thread that is still running."""
        halt = threading.Event()
        worker = self._fake_thread(True)
        listener = AbortListener("test-uuid", worker, halt)
        listener.stop()
        worker.join.assert_called_once_with(timeout=2.0)
class TestTaskManagerInitApp:
    """Tests for TaskManager.init_app()"""

    @staticmethod
    def _reset_state():
        """Put TaskManager class-level state back to its defaults."""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    def setup_method(self):
        self._reset_state()

    def teardown_method(self):
        self._reset_state()

    def test_init_app_sets_channel_prefixes(self):
        """init_app reads both channel prefixes out of the app config."""
        overrides = {
            "TASKS_ABORT_CHANNEL_PREFIX": "custom:abort:",
            "TASKS_COMPLETION_CHANNEL_PREFIX": "custom:complete:",
        }
        app = MagicMock()
        app.config.get.side_effect = (
            lambda key, default=None: overrides.get(key, default)
        )
        TaskManager.init_app(app)
        assert TaskManager._initialized is True
        assert TaskManager._channel_prefix == "custom:abort:"
        assert TaskManager._completion_channel_prefix == "custom:complete:"

    def test_init_app_skips_if_already_initialized(self):
        """A second init_app call is a no-op and never touches the config."""
        TaskManager._initialized = True
        app = MagicMock()
        TaskManager.init_app(app)
        app.config.get.assert_not_called()
class TestTaskManagerPubSub:
    """Tests for TaskManager pub/sub helper methods."""

    @staticmethod
    def _reset_state():
        """Restore TaskManager class-level state to its defaults."""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    def setup_method(self):
        self._reset_state()

    def teardown_method(self):
        self._reset_state()

    @patch("superset.tasks.manager.cache_manager")
    def test_is_pubsub_available_no_redis(self, cache_mgr):
        """Pub/sub reports unavailable when no Redis is configured."""
        cache_mgr.signal_cache = None
        assert TaskManager.is_pubsub_available() is False

    @patch("superset.tasks.manager.cache_manager")
    def test_is_pubsub_available_with_redis(self, cache_mgr):
        """Pub/sub reports available when a Redis client is present."""
        cache_mgr.signal_cache = MagicMock()
        assert TaskManager.is_pubsub_available() is True

    def test_get_abort_channel(self):
        """The abort channel is the configured prefix plus the task UUID."""
        assert (
            TaskManager.get_abort_channel("abc-123-def-456")
            == "gtf:abort:abc-123-def-456"
        )

    def test_get_abort_channel_custom_prefix(self):
        """A custom prefix is honored when building the channel name."""
        TaskManager._channel_prefix = "custom:prefix:"
        assert TaskManager.get_abort_channel("test-uuid") == "custom:prefix:test-uuid"

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_abort_no_redis(self, cache_mgr):
        """Publishing an abort without Redis reports failure."""
        cache_mgr.signal_cache = None
        assert TaskManager.publish_abort("test-uuid") is False

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_abort_success(self, cache_mgr):
        """A successful publish sends "abort" on the task's abort channel."""
        fake_redis = MagicMock()
        fake_redis.publish.return_value = 1  # one subscriber received it
        cache_mgr.signal_cache = fake_redis
        assert TaskManager.publish_abort("test-uuid") is True
        fake_redis.publish.assert_called_once_with("gtf:abort:test-uuid", "abort")

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_abort_redis_error(self, cache_mgr):
        """Redis errors during publish are swallowed and reported as failure."""
        fake_redis = MagicMock()
        fake_redis.publish.side_effect = redis.RedisError("Connection lost")
        cache_mgr.signal_cache = fake_redis
        assert TaskManager.publish_abort("test-uuid") is False
class TestTaskManagerListenForAbort:
    """Tests for TaskManager.listen_for_abort()"""

    def setup_method(self):
        """Reset TaskManager state before each test"""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    def teardown_method(self):
        """Reset TaskManager state after each test"""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    @patch("superset.tasks.manager.cache_manager")
    def test_listen_for_abort_no_redis_uses_polling(self, mock_cache_manager):
        """Test listen_for_abort falls back to polling when Redis unavailable"""
        mock_cache_manager.signal_cache = None
        callback = MagicMock()
        # Stub the polling loop so the listener thread exits immediately.
        with patch.object(TaskManager, "_poll_for_abort", return_value=None):
            listener = TaskManager.listen_for_abort(
                task_uuid="test-uuid",
                callback=callback,
                poll_interval=1.0,
                app=None,
            )
            # Give thread time to start
            time.sleep(0.1)
            listener.stop()
            # Should use polling since no Redis
            assert listener._pubsub is None

    @patch("superset.tasks.manager.cache_manager")
    def test_listen_for_abort_with_redis_uses_pubsub(self, mock_cache_manager):
        """Test listen_for_abort uses pub/sub when Redis available"""
        mock_redis = MagicMock()
        mock_pubsub = MagicMock()
        mock_redis.pubsub.return_value = mock_pubsub
        mock_cache_manager.signal_cache = mock_redis
        callback = MagicMock()
        # Stub the pub/sub loop so the listener thread exits immediately.
        with patch.object(TaskManager, "_listen_pubsub", return_value=None):
            listener = TaskManager.listen_for_abort(
                task_uuid="test-uuid",
                callback=callback,
                poll_interval=1.0,
                app=None,
            )
            # Give thread time to start
            time.sleep(0.1)
            listener.stop()
            # Should subscribe to channel
            mock_pubsub.subscribe.assert_called_once_with("gtf:abort:test-uuid")

    @patch("superset.tasks.manager.cache_manager")
    def test_listen_for_abort_redis_subscribe_failure_raises(self, mock_cache_manager):
        """Test listen_for_abort raises exception on subscribe failure
        when Redis configured"""
        import pytest

        mock_redis = MagicMock()
        mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
        mock_cache_manager.signal_cache = mock_redis
        callback = MagicMock()
        # With fail-fast behavior, Redis subscribe failure raises exception
        with pytest.raises(redis.RedisError, match="Connection failed"):
            TaskManager.listen_for_abort(
                task_uuid="test-uuid",
                callback=callback,
                poll_interval=1.0,
                app=None,
            )
class TestTaskManagerCompletion:
    """Tests for TaskManager completion pub/sub and wait_for_completion"""

    def setup_method(self):
        """Reset TaskManager state before each test"""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    def teardown_method(self):
        """Reset TaskManager state after each test"""
        TaskManager._initialized = False
        TaskManager._channel_prefix = "gtf:abort:"
        TaskManager._completion_channel_prefix = "gtf:complete:"

    def test_get_completion_channel(self):
        """Test get_completion_channel returns correct channel name"""
        task_uuid = "abc-123-def-456"
        channel = TaskManager.get_completion_channel(task_uuid)
        assert channel == "gtf:complete:abc-123-def-456"

    def test_get_completion_channel_custom_prefix(self):
        """Test get_completion_channel with custom prefix"""
        TaskManager._completion_channel_prefix = "custom:complete:"
        task_uuid = "test-uuid"
        channel = TaskManager.get_completion_channel(task_uuid)
        assert channel == "custom:complete:test-uuid"

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_completion_no_redis(self, mock_cache_manager):
        """Test publish_completion returns False when Redis not available"""
        mock_cache_manager.signal_cache = None
        result = TaskManager.publish_completion("test-uuid", "success")
        assert result is False

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_completion_success(self, mock_cache_manager):
        """Test publish_completion publishes message successfully"""
        mock_redis = MagicMock()
        mock_redis.publish.return_value = 1  # One subscriber
        mock_cache_manager.signal_cache = mock_redis
        result = TaskManager.publish_completion("test-uuid", "success")
        assert result is True
        mock_redis.publish.assert_called_once_with("gtf:complete:test-uuid", "success")

    @patch("superset.tasks.manager.cache_manager")
    def test_publish_completion_redis_error(self, mock_cache_manager):
        """Test publish_completion handles Redis errors gracefully"""
        mock_redis = MagicMock()
        mock_redis.publish.side_effect = redis.RedisError("Connection lost")
        mock_cache_manager.signal_cache = mock_redis
        result = TaskManager.publish_completion("test-uuid", "success")
        assert result is False

    # NOTE: decorators apply bottom-up, so TaskDAO is the first mock
    # parameter and cache_manager the second in the tests below.
    @patch("superset.tasks.manager.cache_manager")
    @patch("superset.daos.tasks.TaskDAO")
    def test_wait_for_completion_task_not_found(self, mock_dao, mock_cache_manager):
        """Test wait_for_completion raises ValueError for missing task"""
        import pytest

        mock_cache_manager.signal_cache = None
        mock_dao.find_one_or_none.return_value = None
        with pytest.raises(ValueError, match="not found"):
            TaskManager.wait_for_completion("nonexistent-uuid")

    @patch("superset.tasks.manager.cache_manager")
    @patch("superset.daos.tasks.TaskDAO")
    def test_wait_for_completion_already_complete(self, mock_dao, mock_cache_manager):
        """Test wait_for_completion returns immediately for terminal state"""
        mock_cache_manager.signal_cache = None
        mock_task = MagicMock()
        mock_task.uuid = "test-uuid"
        mock_task.status = "success"
        mock_dao.find_one_or_none.return_value = mock_task
        result = TaskManager.wait_for_completion("test-uuid")
        assert result == mock_task
        # Should only call find_one_or_none once (initial check)
        mock_dao.find_one_or_none.assert_called_once()

    @patch("superset.tasks.manager.cache_manager")
    @patch("superset.daos.tasks.TaskDAO")
    def test_wait_for_completion_timeout(self, mock_dao, mock_cache_manager):
        """Test wait_for_completion raises TimeoutError when timeout expires"""
        import pytest

        mock_cache_manager.signal_cache = None
        mock_task = MagicMock()
        mock_task.uuid = "test-uuid"
        mock_task.status = "in_progress"  # Never completes
        mock_dao.find_one_or_none.return_value = mock_task
        with pytest.raises(TimeoutError, match="Timeout waiting"):
            TaskManager.wait_for_completion("test-uuid", timeout=0.1)

    @patch("superset.tasks.manager.cache_manager")
    @patch("superset.daos.tasks.TaskDAO")
    def test_wait_for_completion_polling_success(self, mock_dao, mock_cache_manager):
        """Test wait_for_completion returns when task completes via polling"""
        mock_cache_manager.signal_cache = None
        mock_task_pending = MagicMock()
        mock_task_pending.uuid = "test-uuid"
        mock_task_pending.status = "pending"
        mock_task_complete = MagicMock()
        mock_task_complete.uuid = "test-uuid"
        mock_task_complete.status = "success"
        # First call returns pending, second returns complete
        mock_dao.find_one_or_none.side_effect = [
            mock_task_pending,
            mock_task_complete,
        ]
        result = TaskManager.wait_for_completion(
            "test-uuid",
            timeout=5.0,
            poll_interval=0.1,
        )
        assert result.status == "success"

    @patch("superset.tasks.manager.cache_manager")
    @patch("superset.daos.tasks.TaskDAO")
    def test_wait_for_completion_with_pubsub(self, mock_dao, mock_cache_manager):
        """Test wait_for_completion uses pub/sub when Redis available"""
        mock_task_pending = MagicMock()
        mock_task_pending.uuid = "test-uuid"
        mock_task_pending.status = "pending"
        mock_task_complete = MagicMock()
        mock_task_complete.uuid = "test-uuid"
        mock_task_complete.status = "success"
        # First call returns pending, second returns complete
        mock_dao.find_one_or_none.side_effect = [
            mock_task_pending,
            mock_task_complete,
        ]
        # Set up mock Redis with pub/sub
        mock_redis = MagicMock()
        mock_pubsub = MagicMock()
        # Simulate receiving a completion message
        mock_pubsub.get_message.return_value = {
            "type": "message",
            "data": "success",
        }
        mock_redis.pubsub.return_value = mock_pubsub
        mock_cache_manager.signal_cache = mock_redis
        result = TaskManager.wait_for_completion(
            "test-uuid",
            timeout=5.0,
        )
        assert result.status == "success"
        # Should have subscribed to completion channel
        mock_pubsub.subscribe.assert_called_once_with("gtf:complete:test-uuid")
        # Should have cleaned up
        mock_pubsub.unsubscribe.assert_called_once()
        mock_pubsub.close.assert_called_once()

    @patch("superset.tasks.manager.cache_manager")
    @patch("superset.daos.tasks.TaskDAO")
    def test_wait_for_completion_pubsub_error_raises(
        self, mock_dao, mock_cache_manager
    ):
        """Test wait_for_completion raises exception on Redis error when
        Redis configured"""
        import pytest

        mock_task_pending = MagicMock()
        mock_task_pending.uuid = "test-uuid"
        mock_task_pending.status = "pending"
        mock_dao.find_one_or_none.return_value = mock_task_pending
        # Set up mock Redis that fails
        mock_redis = MagicMock()
        mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
        mock_cache_manager.signal_cache = mock_redis
        # With fail-fast behavior, Redis error is raised instead of falling back
        with pytest.raises(redis.RedisError, match="Connection failed"):
            TaskManager.wait_for_completion(
                "test-uuid",
                timeout=5.0,
                poll_interval=0.1,
            )

    def test_terminal_states_constant(self):
        """Test TERMINAL_STATES contains expected values"""
        expected = {"success", "failure", "aborted", "timed_out"}
        assert TaskManager.TERMINAL_STATES == expected

View File

@@ -0,0 +1,612 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for GTF timeout handling."""
import time
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from superset_core.api.tasks import TaskOptions, TaskScope
from superset.tasks.context import TaskContext
from superset.tasks.decorators import TaskWrapper
# Fixed UUID shared by the fixtures below so task identity is deterministic.
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_flask_app():
    """Build a MagicMock standing in for a configured Flask app."""
    app = MagicMock()
    app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
    # app_context() must behave as a context manager for `with` blocks.
    ctx = app.app_context.return_value
    ctx.__enter__ = MagicMock(return_value=None)
    ctx.__exit__ = MagicMock(return_value=None)
    return app
@pytest.fixture
def mock_task_abortable():
    """Mock task flagged as abortable via its properties dict."""
    fake = MagicMock()
    fake.uuid = TEST_UUID
    fake.status = "in_progress"
    fake.properties_dict = {"is_abortable": True}
    fake.payload_dict = {}
    # Concrete identity fields so dedup_key generation (used by the
    # UpdateTaskCommand lock) gets real strings/ints rather than mocks.
    fake.scope = "shared"
    fake.task_type = "test_task"
    fake.task_key = "test_key"
    fake.user_id = 1
    return fake
@pytest.fixture
def mock_task_not_abortable():
    """Mock task without an is_abortable flag, i.e. not abortable."""
    fake = MagicMock()
    fake.uuid = TEST_UUID
    fake.status = "in_progress"
    fake.properties_dict = {}  # absent is_abortable means "not abortable"
    fake.payload_dict = {}
    # Concrete identity fields so dedup_key generation (used by the
    # UpdateTaskCommand lock) gets real strings/ints rather than mocks.
    fake.scope = "shared"
    fake.task_type = "test_task"
    fake.task_key = "test_key"
    fake.user_id = 1
    return fake
@pytest.fixture
def task_context_for_timeout(mock_flask_app, mock_task_abortable):
    """Create TaskContext with mocked dependencies for timeout tests."""
    # Ensure mock_task has required attributes for TaskContext
    mock_task_abortable.payload_dict = {}
    with (
        patch("superset.tasks.context.current_app") as mock_current_app,
        patch("superset.daos.tasks.TaskDAO") as mock_dao,
        patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
    ):
        # Disable Redis by making signal_cache return None
        mock_cache_manager.signal_cache = None
        # Configure current_app mock
        mock_current_app.config = mock_flask_app.config
        mock_current_app._get_current_object.return_value = mock_flask_app
        # Configure TaskDAO mock
        mock_dao.find_one_or_none.return_value = mock_task_abortable
        ctx = TaskContext(mock_task_abortable)
        ctx._app = mock_flask_app
        # Yield inside the patch context so the mocks stay active for the
        # duration of the test; teardown runs when control returns here.
        yield ctx
        # Cleanup: stop timers if started
        ctx.stop_timeout_timer()
        if ctx._abort_listener:
            ctx.stop_abort_polling()
# =============================================================================
# TaskWrapper._merge_options Timeout Tests
# =============================================================================
class TestTimeoutMerging:
    """Test timeout merging behavior in TaskWrapper._merge_options."""

    @staticmethod
    def _wrapper_with_timeout(default_timeout):
        """Build a TaskWrapper around a no-op with the given default timeout."""

        def noop():
            pass

        return TaskWrapper(
            name="test_task",
            func=noop,
            default_options=TaskOptions(),
            scope=TaskScope.PRIVATE,
            default_timeout=default_timeout,
        )

    def test_merge_options_decorator_timeout_used_when_no_override(self):
        """The decorator-level timeout applies when no override is supplied."""
        merged = self._wrapper_with_timeout(300)._merge_options(None)
        assert merged.timeout == 300

    def test_merge_options_override_timeout_takes_precedence(self):
        """An explicit TaskOptions timeout wins over the decorator default."""
        merged = self._wrapper_with_timeout(300)._merge_options(
            TaskOptions(timeout=600)
        )
        assert merged.timeout == 600

    def test_merge_options_no_timeout_when_not_configured(self):
        """With no timeout configured anywhere, the merged timeout stays unset."""
        merged = self._wrapper_with_timeout(None)._merge_options(None)
        assert merged.timeout is None

    def test_merge_options_override_with_other_options_preserves_timeout(self):
        """Overriding unrelated options keeps the decorator timeout intact."""
        merged = self._wrapper_with_timeout(300)._merge_options(
            TaskOptions(task_key="my-key")
        )
        # override.timeout is None, so the decorator default survives.
        assert merged.timeout == 300
        assert merged.task_key == "my-key"
# =============================================================================
# TaskContext Timeout Timer Tests
# =============================================================================
class TestTimeoutTimer:
    """Exercise TaskContext timeout-timer start/stop semantics."""

    def test_start_timeout_timer_sets_timer(self, task_context_for_timeout):
        """Starting the timer installs it without marking a timeout."""
        context = task_context_for_timeout
        assert context._timeout_timer is None
        context.start_timeout_timer(10)
        assert context._timeout_timer is not None
        assert context._timeout_triggered is False

    def test_start_timeout_timer_is_idempotent(self, task_context_for_timeout):
        """A second start call must not replace the existing timer."""
        context = task_context_for_timeout
        context.start_timeout_timer(10)
        original_timer = context._timeout_timer
        context.start_timeout_timer(20)  # second attempt is ignored
        assert context._timeout_timer is original_timer

    def test_stop_timeout_timer_cancels_timer(self, task_context_for_timeout):
        """Stopping an active timer clears it."""
        context = task_context_for_timeout
        context.start_timeout_timer(10)
        assert context._timeout_timer is not None
        context.stop_timeout_timer()
        assert context._timeout_timer is None

    def test_stop_timeout_timer_safe_when_no_timer(self, task_context_for_timeout):
        """Stopping when no timer is installed is a harmless no-op."""
        context = task_context_for_timeout
        assert context._timeout_timer is None
        context.stop_timeout_timer()  # must not raise
        assert context._timeout_timer is None

    def test_timeout_triggered_property_initially_false(self, task_context_for_timeout):
        """The timeout_triggered flag starts out False."""
        assert task_context_for_timeout.timeout_triggered is False

    def test_cleanup_stops_timeout_timer(self, task_context_for_timeout):
        """_run_cleanup tears down any outstanding timeout timer."""
        context = task_context_for_timeout
        context.start_timeout_timer(10)
        assert context._timeout_timer is not None
        context._run_cleanup()
        assert context._timeout_timer is None
class TestTimeoutTrigger:
    """Verify behavior when the timeout timer actually fires.

    Fix over the original: cleanup (stop_timeout_timer / stop_abort_polling)
    previously ran only after the assertions, so a failing assertion leaked a
    live timer thread or abort listener into subsequent tests. Cleanup is now
    guaranteed via try/finally.
    """

    def test_timeout_triggers_abort_when_abortable(
        self, mock_flask_app, mock_task_abortable
    ):
        """A fired timeout must run abort handlers on an abortable task."""
        abort_called = False
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch(
                "superset.commands.tasks.update.UpdateTaskCommand"
            ) as mock_update_cmd,
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Force the polling code path by disabling the Redis signal cache.
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                nonlocal abort_called
                abort_called = True

            try:
                # Start a short timeout and give the 1s timer time to fire.
                ctx.start_timeout_timer(1)
                time.sleep(1.5)
                # The abort handler must have run and both flags must be set.
                assert abort_called
                assert ctx._timeout_triggered
                assert ctx._abort_detected
                # The task record must have been moved to the ABORTING status.
                mock_update_cmd.assert_called()
                call_kwargs = mock_update_cmd.call_args[1]
                assert call_kwargs.get("status") == "aborting"
            finally:
                # Always tear down, even on assertion failure.
                ctx.stop_timeout_timer()
                if ctx._abort_listener:
                    ctx.stop_abort_polling()

    def test_timeout_logs_warning_when_not_abortable(
        self, mock_flask_app, mock_task_not_abortable
    ):
        """Without an abort handler, a fired timeout only logs a warning."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.tasks.context.logger") as mock_logger,
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Force the polling code path by disabling the Redis signal cache.
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_not_abortable
            ctx = TaskContext(mock_task_not_abortable)
            ctx._app = mock_flask_app
            # Deliberately register no abort handler.
            try:
                ctx.start_timeout_timer(1)
                time.sleep(1.5)
                # A warning mentioning the missing handler must be logged.
                mock_logger.warning.assert_called()
                warning_call = mock_logger.warning.call_args
                assert "no abort handler" in warning_call[0][0].lower()
                assert ctx._timeout_triggered
                # No handler means no abort was initiated.
                assert not ctx._abort_detected
            finally:
                ctx.stop_timeout_timer()

    def test_timeout_does_not_trigger_if_already_aborting(
        self, mock_flask_app, mock_task_abortable
    ):
        """A timeout must not re-run abort handlers mid-abort."""
        abort_count = 0
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Force the polling code path by disabling the Redis signal cache.
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                nonlocal abort_count
                abort_count += 1

            # Pretend an abort is already in flight before the timer fires.
            ctx._abort_detected = True
            try:
                ctx.start_timeout_timer(1)
                time.sleep(1.5)
                # The handler must not have run a second time.
                assert abort_count == 0
            finally:
                ctx.stop_timeout_timer()
                if ctx._abort_listener:
                    ctx.stop_abort_polling()
# =============================================================================
# Task Decorator Timeout Tests
# =============================================================================
class TestTaskDecoratorTimeout:
    """Tests for the @task decorator's timeout parameter.

    Fix over the original: registry cleanup (``TaskRegistry._tasks.pop``)
    previously ran only after the assertions, so a failing assertion left
    the test task registered and polluted the registry for other tests.
    Cleanup is now guaranteed via try/finally.
    """

    def test_task_decorator_accepts_timeout(self):
        """timeout=300 is stored as the wrapper's default timeout."""
        from superset.tasks.decorators import task
        from superset.tasks.registry import TaskRegistry

        @task(name="test_timeout_task_1", timeout=300)
        def timeout_test_task_1():
            pass

        try:
            assert isinstance(timeout_test_task_1, TaskWrapper)
            assert timeout_test_task_1.default_timeout == 300
        finally:
            # Always unregister, even if an assertion above failed.
            TaskRegistry._tasks.pop("test_timeout_task_1", None)

    def test_task_decorator_without_timeout(self):
        """Omitting timeout leaves the wrapper default as None."""
        from superset.tasks.decorators import task
        from superset.tasks.registry import TaskRegistry

        @task(name="test_timeout_task_2")
        def timeout_test_task_2():
            pass

        try:
            assert isinstance(timeout_test_task_2, TaskWrapper)
            assert timeout_test_task_2.default_timeout is None
        finally:
            TaskRegistry._tasks.pop("test_timeout_task_2", None)

    def test_task_decorator_with_all_params(self):
        """name, scope, and timeout can all be supplied together."""
        from superset.tasks.decorators import task
        from superset.tasks.registry import TaskRegistry

        @task(name="test_timeout_task_3", scope=TaskScope.SHARED, timeout=600)
        def timeout_test_task_3():
            pass

        try:
            assert timeout_test_task_3.name == "test_timeout_task_3"
            assert timeout_test_task_3.scope == TaskScope.SHARED
            assert timeout_test_task_3.default_timeout == 600
        finally:
            TaskRegistry._tasks.pop("test_timeout_task_3", None)
# =============================================================================
# Timeout Terminal State Tests
# =============================================================================
class TestTimeoutTerminalState:
    """Timeout transitions to the correct terminal state (TIMED_OUT vs FAILURE).

    Fix over the original: cleanup (stop_timeout_timer / stop_abort_polling)
    previously ran only after the assertions, so a failing assertion leaked a
    live timer thread or abort listener into subsequent tests. Cleanup is now
    guaranteed via try/finally.
    """

    def test_timeout_triggered_flag_set_on_timeout(
        self, mock_flask_app, mock_task_abortable
    ):
        """timeout_triggered flips to True once the timer fires."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Force the polling code path by disabling the Redis signal cache.
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                pass

            try:
                # Not triggered before the timer has fired.
                assert ctx.timeout_triggered is False
                ctx.start_timeout_timer(1)
                time.sleep(1.5)
                # Triggered after the 1s timer has fired.
                assert ctx.timeout_triggered is True
            finally:
                ctx.stop_timeout_timer()
                if ctx._abort_listener:
                    ctx.stop_abort_polling()

    def test_user_abort_does_not_set_timeout_triggered(
        self, mock_flask_app, mock_task_abortable
    ):
        """A user-initiated abort must not be mistaken for a timeout."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Force the polling code path by disabling the Redis signal cache.
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                pass

            try:
                # Simulate a user abort (not a timeout).
                ctx._on_abort_detected()
                # Abort is recorded, but the timeout flag stays unset.
                assert ctx.timeout_triggered is False
                assert ctx._abort_detected is True
            finally:
                if ctx._abort_listener:
                    ctx.stop_abort_polling()

    def test_abort_handlers_completed_tracks_success(
        self, mock_flask_app, mock_task_abortable
    ):
        """abort_handlers_completed turns True after handlers run cleanly."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Force the polling code path by disabling the Redis signal cache.
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                pass  # handler succeeds

            try:
                assert ctx.abort_handlers_completed is False
                ctx._trigger_abort_handlers()
                assert ctx.abort_handlers_completed is True
            finally:
                if ctx._abort_listener:
                    ctx.stop_abort_polling()

    def test_abort_handlers_completed_false_on_exception(
        self, mock_flask_app, mock_task_abortable
    ):
        """abort_handlers_completed stays False when a handler raises."""
        with (
            patch("superset.tasks.context.current_app") as mock_current_app,
            patch("superset.daos.tasks.TaskDAO") as mock_dao,
            patch("superset.commands.tasks.update.UpdateTaskCommand"),
            patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
        ):
            # Force the polling code path by disabling the Redis signal cache.
            mock_cache_manager.signal_cache = None
            mock_current_app.config = mock_flask_app.config
            mock_current_app._get_current_object.return_value = mock_flask_app
            mock_dao.find_one_or_none.return_value = mock_task_abortable
            ctx = TaskContext(mock_task_abortable)
            ctx._app = mock_flask_app

            @ctx.on_abort
            def handle_abort():
                raise ValueError("Handler failed")

            try:
                assert ctx.abort_handlers_completed is False
                # The context catches the handler's exception internally.
                ctx._trigger_abort_handlers()
                assert ctx.abort_handlers_completed is False
            finally:
                if ctx._abort_listener:
                    ctx.stop_abort_polling()

View File

@@ -22,9 +22,19 @@ from typing import Any, Optional, Union
import pytest
from flask_appbuilder.security.sqla.models import User
from superset_core.api.tasks import TaskScope
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.types import Executor, ExecutorType, FixedExecutor
from superset.tasks.utils import (
error_update,
get_active_dedup_key,
get_finished_dedup_key,
parse_properties,
progress_update,
serialize_properties,
)
from superset.utils.hashing import hash_from_str
FIXED_USER_ID = 1234
FIXED_USERNAME = "admin"
@@ -330,3 +340,242 @@ def test_get_executor(
)
assert executor_type == expected_executor_type
assert executor == expected_executor
@pytest.mark.parametrize(
    "scope,task_type,task_key,user_id,expected_composite_key",
    [
        # Private scope: user_id becomes part of the key (enum and string forms).
        (TaskScope.PRIVATE, "sql_execution", "chart_123", 42, "private|sql_execution|chart_123|42"),
        (TaskScope.PRIVATE, "thumbnail_gen", "dash_456", 100, "private|thumbnail_gen|dash_456|100"),
        ("private", "api_call", "endpoint_789", 200, "private|api_call|endpoint_789|200"),
        # Shared scope: user_id is ignored even when supplied.
        (TaskScope.SHARED, "report_gen", "monthly_report", None, "shared|report_gen|monthly_report"),
        (TaskScope.SHARED, "export_csv", "large_export", 999, "shared|export_csv|large_export"),
        ("shared", "batch_process", "batch_001", 123, "shared|batch_process|batch_001"),
        # System scope: user_id is ignored as well.
        (TaskScope.SYSTEM, "cleanup_task", "daily_cleanup", None, "system|cleanup_task|daily_cleanup"),
        (TaskScope.SYSTEM, "db_migration", "version_123", 1, "system|db_migration|version_123"),
        ("system", "maintenance", "nightly_job", 2, "system|maintenance|nightly_job"),
    ],
)
def test_get_active_dedup_key(
    scope, task_type, task_key, user_id, expected_composite_key, app_context
):
    """get_active_dedup_key hashes the composite key for database storage.

    The composite key is hashed with the configured HASH_ALGORITHM and
    truncated to 64 characters to fit the database column.
    """
    dedup_key = get_active_dedup_key(scope, task_type, task_key, user_id)
    assert dedup_key == hash_from_str(expected_composite_key)[:64]
    assert len(dedup_key) <= 64
def test_get_active_dedup_key_private_requires_user_id():
    """Omitting user_id for a private-scope task must raise ValueError."""
    with pytest.raises(ValueError, match="user_id required for private tasks"):
        get_active_dedup_key(TaskScope.PRIVATE, "test_type", "test_key")
def test_get_finished_dedup_key():
    """A finished task's dedup_key is simply its UUID, unchanged."""
    uuid_value = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
    assert get_finished_dedup_key(uuid_value) == uuid_value
@pytest.mark.parametrize(
    "progress,expected",
    [
        # A float is interpreted as a completion percentage.
        (0.5, {"progress_percent": 0.5}),
        (0.0, {"progress_percent": 0.0}),
        (1.0, {"progress_percent": 1.0}),
        (0.25, {"progress_percent": 0.25}),
        # An int is interpreted as a bare item count.
        (42, {"progress_current": 42}),
        (0, {"progress_current": 0}),
        (1000, {"progress_current": 1000}),
        # A (current, total) tuple also derives the percentage...
        ((50, 100), {"progress_current": 50, "progress_total": 100, "progress_percent": 0.5}),
        ((25, 100), {"progress_current": 25, "progress_total": 100, "progress_percent": 0.25}),
        ((100, 100), {"progress_current": 100, "progress_total": 100, "progress_percent": 1.0}),
        # ...except when total is zero, where no percentage is computed.
        ((10, 0), {"progress_current": 10, "progress_total": 0}),
        ((0, 0), {"progress_current": 0, "progress_total": 0}),
    ],
)
def test_progress_update(progress, expected):
    """progress_update maps each supported input shape to TaskProperties."""
    assert progress_update(progress) == expected
def test_error_update():
    """error_update extracts message, type, and traceback from an exception."""
    try:
        raise ValueError("Test error message")
    except ValueError as caught:
        update = error_update(caught)
    assert update["error_message"] == "Test error message"
    assert update["exception_type"] == "ValueError"
    assert "stack_trace" in update
    assert "ValueError" in update["stack_trace"]
def test_error_update_custom_exception():
    """error_update reports the subclass name for user-defined exceptions."""

    class CustomError(Exception):
        pass

    try:
        raise CustomError("Custom error")
    except CustomError as caught:
        update = error_update(caught)
    assert update["error_message"] == "Custom error"
    assert update["exception_type"] == "CustomError"
@pytest.mark.parametrize(
    "json_str,expected",
    [
        # Well-formed JSON parses to the equivalent dict.
        ('{"is_abortable": true, "progress_percent": 0.5}', {"is_abortable": True, "progress_percent": 0.5}),
        ('{"error_message": "Something failed"}', {"error_message": "Something failed"}),
        ('{"progress_current": 50, "progress_total": 100}', {"progress_current": 50, "progress_total": 100}),
        # Empty or missing input yields an empty dict.
        ("", {}),
        (None, {}),
        # Malformed JSON also yields an empty dict rather than raising.
        ("not valid json", {}),
        ("{broken", {}),
        # Unrecognized keys pass through for forward compatibility.
        ('{"is_abortable": true, "future_field": "value"}', {"is_abortable": True, "future_field": "value"}),
    ],
)
def test_parse_properties(json_str, expected):
    """parse_properties converts a JSON string into a TaskProperties dict."""
    assert parse_properties(json_str) == expected
@pytest.mark.parametrize(
    "props,expected_contains",
    [
        # Fully populated properties.
        ({"is_abortable": True, "progress_percent": 0.5}, {"is_abortable": True, "progress_percent": 0.5}),
        # Empty dict serializes and parses back to empty.
        ({}, {}),
        # Sparse properties survive unchanged.
        ({"is_abortable": True}, {"is_abortable": True}),
        ({"error_message": "fail"}, {"error_message": "fail"}),
    ],
)
def test_serialize_properties(props, expected_contains):
    """serialize_properties emits JSON that loads back to the input dict."""
    from superset.utils import json

    assert json.loads(serialize_properties(props)) == expected_contains
def test_properties_roundtrip():
    """serialize_properties and parse_properties are inverse operations."""
    source = {
        "is_abortable": True,
        "progress_percent": 0.75,
        "error_message": "Test error",
    }
    assert parse_properties(serialize_properties(source)) == source

View File

@@ -54,7 +54,7 @@ def test_json_loads_exception():
def test_json_loads_encoding():
unicode_data = b'{"a": "\u0073\u0074\u0072"}'
unicode_data = rb'{"a": "\u0073\u0074\u0072"}'
data = json.loads(unicode_data)
assert data["a"] == "str"
utf16_data = b'\xff\xfe{\x00"\x00a\x00"\x00:\x00 \x00"\x00s\x00t\x00r\x00"\x00}\x00'

View File

@@ -119,7 +119,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
was revoked), the invalid token should be deleted and the exception re-raised.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
mocker.patch("superset.utils.oauth2.DistributedLock")
class OAuth2ExceptionError(Exception):
pass
@@ -149,7 +149,7 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
exception re-raised.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
mocker.patch("superset.utils.oauth2.DistributedLock")
class OAuth2ExceptionError(Exception):
pass
@@ -175,7 +175,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
This can happen when the refresh token was revoked.
"""
mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db_engine_spec.get_oauth2_fresh_token.return_value = {
"error": "invalid_grant",