mirror of
https://github.com/apache/superset.git
synced 2026-04-20 00:24:38 +00:00
feat: add global task framework (#36368)
This commit is contained in:
@@ -73,6 +73,7 @@ FEATURE_FLAGS = {
|
||||
"AVOID_COLORS_COLLISION": True,
|
||||
"DRILL_TO_DETAIL": True,
|
||||
"DRILL_BY": True,
|
||||
"GLOBAL_TASK_FRAMEWORK": True,
|
||||
}
|
||||
|
||||
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"
|
||||
|
||||
538
tests/integration_tests/tasks/api_tests.py
Normal file
538
tests/integration_tests/tasks/api_tests.py
Normal file
@@ -0,0 +1,538 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Integration tests for Task REST API"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import prison
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.models.tasks import Task
|
||||
from superset.utils import json
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import (
|
||||
ADMIN_USERNAME,
|
||||
GAMMA_USERNAME,
|
||||
)
|
||||
|
||||
|
||||
class TestTaskApi(SupersetTestCase):
|
||||
"""Tests for Task REST API"""
|
||||
|
||||
TASK_API_BASE = "api/v1/task"
|
||||
|
||||
@contextmanager
|
||||
def _create_tasks(self) -> Generator[list[Task], None, None]:
|
||||
"""
|
||||
Context manager to create test tasks with guaranteed cleanup.
|
||||
|
||||
Uses TaskDAO to create tasks, testing the actual production code path.
|
||||
|
||||
Usage:
|
||||
with self._create_tasks() as tasks:
|
||||
# Use tasks in test
|
||||
# Cleanup happens automatically even if test fails
|
||||
"""
|
||||
from superset_core.api.tasks import TaskScope
|
||||
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
admin = self.get_user("admin")
|
||||
gamma = self.get_user("gamma")
|
||||
|
||||
tasks = []
|
||||
|
||||
try:
|
||||
# Create tasks with different statuses using TaskDAO
|
||||
for i in range(5):
|
||||
task_key = f"test_task_{i}"
|
||||
|
||||
# Create task using DAO (this tests the dedup_key creation logic)
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=task_key,
|
||||
task_name=f"Test Task {i}",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
payload={"test": "data"},
|
||||
)
|
||||
|
||||
# Set created_by for test purposes (DAO uses Flask-AppBuilder context)
|
||||
task.created_by = admin
|
||||
|
||||
# Alternate between pending and finished tasks
|
||||
if i % 2 != 0:
|
||||
# Simulate realistic task lifecycle: PENDING → IN_PROGRESS → SUCCESS
|
||||
# This sets both started_at (on IN_PROGRESS) and ended_at (on
|
||||
# SUCCESS) so duration_seconds returns a valid value
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
task.set_status(TaskStatus.SUCCESS)
|
||||
|
||||
db.session.commit()
|
||||
tasks.append(task)
|
||||
|
||||
# Create pending task for gamma user (use PENDING so it can be aborted)
|
||||
gamma_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="gamma_task",
|
||||
task_name="Gamma Task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=gamma.id,
|
||||
payload={"user": "gamma"},
|
||||
)
|
||||
# Set created_by for test purposes
|
||||
gamma_task.created_by = gamma
|
||||
db.session.commit()
|
||||
tasks.append(gamma_task)
|
||||
|
||||
yield tasks
|
||||
finally:
|
||||
# Cleanup happens here regardless of test success/failure
|
||||
for task in tasks:
|
||||
try:
|
||||
db.session.delete(task)
|
||||
except Exception: # noqa: S110
|
||||
# Task may already be deleted or session may be in bad state
|
||||
pass
|
||||
try:
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
# Rollback if commit fails
|
||||
db.session.rollback()
|
||||
|
||||
def test_info_task(self):
|
||||
"""
|
||||
Task API: Test info endpoint
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/_info"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert "permissions" in data
|
||||
|
||||
def test_get_task_by_uuid(self):
|
||||
"""
|
||||
Task API: Test get task by UUID and verify dedup_key is hashed
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
# Get a pending task to verify active dedup_key format
|
||||
task = (
|
||||
db.session.query(Task)
|
||||
.filter_by(
|
||||
created_by_fk=admin.id,
|
||||
status=TaskStatus.PENDING.value,
|
||||
task_type="test_type",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
assert task is not None
|
||||
|
||||
# Verify active task has hashed dedup_key (64 chars for SHA-256)
|
||||
assert len(task.dedup_key) == 64
|
||||
assert all(c in "0123456789abcdef" for c in task.dedup_key)
|
||||
assert task.dedup_key != str(task.uuid)
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Compare strings since JSON response contains string UUID
|
||||
assert data["result"]["uuid"] == str(task.uuid)
|
||||
assert data["result"]["id"] == task.id
|
||||
|
||||
def test_get_task_not_found(self):
|
||||
"""
|
||||
Task API: Test get task not found with non-existent UUID
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
# Use a valid UUID that doesn't exist in the database
|
||||
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_task_invalid_uuid(self):
|
||||
"""
|
||||
Task API: Test get task with invalid UUID
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/invalid-uuid"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_task_list(self):
|
||||
"""
|
||||
Task API: Test get task list
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] >= 6 # At least the fixtures we created
|
||||
assert "result" in data
|
||||
|
||||
def test_get_task_list_filtered_by_status(self):
|
||||
"""
|
||||
Task API: Test get task list filtered by status
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
arguments = {
|
||||
"filters": [
|
||||
{"col": "status", "opr": "eq", "value": TaskStatus.PENDING.value}
|
||||
]
|
||||
}
|
||||
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
for task in data["result"]:
|
||||
assert task["status"] == TaskStatus.PENDING.value
|
||||
|
||||
def test_get_task_list_filtered_by_type(self):
|
||||
"""
|
||||
Task API: Test get task list filtered by type
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
arguments = {
|
||||
"filters": [{"col": "task_type", "opr": "eq", "value": "test_type"}]
|
||||
}
|
||||
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] >= 6
|
||||
for task in data["result"]:
|
||||
assert task["task_type"] == "test_type"
|
||||
|
||||
def test_get_task_list_ordered(self):
|
||||
"""
|
||||
Task API: Test get task list with ordering
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
arguments = {
|
||||
"order_column": "created_on",
|
||||
"order_direction": "desc",
|
||||
}
|
||||
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert len(data["result"]) > 0
|
||||
|
||||
def test_get_task_list_paginated(self):
|
||||
"""
|
||||
Task API: Test get task list with pagination
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
arguments = {"page": 0, "page_size": 2}
|
||||
uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert len(data["result"]) <= 2
|
||||
assert data["count"] >= 6
|
||||
|
||||
def test_cancel_task_by_uuid(self):
|
||||
"""
|
||||
Task API: Test cancel task by UUID
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
task = (
|
||||
db.session.query(Task)
|
||||
.filter_by(created_by_fk=admin.id, status=TaskStatus.PENDING.value)
|
||||
.first()
|
||||
)
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
|
||||
rv = self.client.post(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Compare strings since JSON response contains string UUID
|
||||
assert data["task"]["uuid"] == str(task.uuid)
|
||||
assert data["task"]["status"] == TaskStatus.ABORTED.value
|
||||
assert data["action"] == "aborted"
|
||||
|
||||
def test_cancel_task_not_found(self):
|
||||
"""
|
||||
Task API: Test cancel task not found with non-existent UUID
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/cancel"
|
||||
rv = self.client.post(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_cancel_task_not_owned(self):
|
||||
"""
|
||||
Task API: Test cancel task not owned by user
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(GAMMA_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
# Try to cancel admin's task as gamma user
|
||||
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
|
||||
rv = self.client.post(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_cancel_task_admin_can_cancel_others(self):
|
||||
"""
|
||||
Task API: Test admin can cancel other users' tasks
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
gamma = self.get_user("gamma")
|
||||
|
||||
# Admin cancels gamma's task
|
||||
task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
|
||||
rv = self.client.post(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_get_task_status_by_uuid(self):
|
||||
"""
|
||||
Task API: Test get task status by UUID
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert "status" in data
|
||||
assert data["status"] == task.status
|
||||
|
||||
def test_get_task_status_not_found(self):
|
||||
"""
|
||||
Task API: Test get task status not found with non-existent UUID
|
||||
"""
|
||||
self.login(ADMIN_USERNAME)
|
||||
uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/status"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_task_status_not_owned(self):
|
||||
"""
|
||||
Task API: Test non-owner can't see task status
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(GAMMA_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
# Try to get status of admin's task as gamma user
|
||||
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
|
||||
rv = self.client.get(uri)
|
||||
# Should be forbidden due to base filter
|
||||
assert rv.status_code == 404
|
||||
|
||||
def test_get_task_status_admin_can_see_others(self):
|
||||
"""
|
||||
Task API: Test admin can see other users' task status
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
gamma = self.get_user("gamma")
|
||||
|
||||
# Admin gets gamma's task status
|
||||
task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["status"] == task.status
|
||||
|
||||
def test_get_task_list_user_sees_own_tasks(self):
|
||||
"""
|
||||
Task API: Test non-admin user only sees their own tasks
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(GAMMA_USERNAME)
|
||||
gamma = self.get_user("gamma")
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Gamma should only see their own task
|
||||
for task in data["result"]:
|
||||
assert task["created_by"]["id"] == gamma.id
|
||||
|
||||
def test_get_task_list_admin_sees_all_tasks(self):
|
||||
"""
|
||||
Task API: Test admin sees all tasks
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# Admin should see all tasks
|
||||
assert data["count"] >= 6
|
||||
|
||||
def test_task_response_schema(self):
|
||||
"""
|
||||
Task API: Test response schema includes all expected fields
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
result = data["result"]
|
||||
|
||||
# Check all expected fields are present
|
||||
expected_fields = [
|
||||
"id",
|
||||
"uuid",
|
||||
"task_key",
|
||||
"task_type",
|
||||
"task_name",
|
||||
"status",
|
||||
"created_on",
|
||||
"created_on_delta_humanized",
|
||||
"changed_on",
|
||||
"changed_by",
|
||||
"started_at",
|
||||
"ended_at",
|
||||
"created_by",
|
||||
"user_id",
|
||||
"payload",
|
||||
"properties",
|
||||
"duration_seconds",
|
||||
"scope",
|
||||
"subscriber_count",
|
||||
"subscribers",
|
||||
]
|
||||
|
||||
for field in expected_fields:
|
||||
assert field in result, f"Field {field} missing from response"
|
||||
|
||||
# Verify properties is a dict with expected structure
|
||||
properties = result["properties"]
|
||||
assert isinstance(properties, dict)
|
||||
|
||||
def test_task_payload_serialization(self):
|
||||
"""
|
||||
Task API: Test payload is properly serialized as dict
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
task = (
|
||||
db.session.query(Task)
|
||||
.filter_by(created_by_fk=admin.id, task_type="test_type")
|
||||
.first()
|
||||
)
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
payload = data["result"]["payload"]
|
||||
|
||||
# Payload should be a dict, not a string
|
||||
assert isinstance(payload, dict)
|
||||
assert "test" in payload
|
||||
assert payload["test"] == "data"
|
||||
|
||||
def test_task_computed_properties(self):
|
||||
"""
|
||||
Task API: Test computed properties in response
|
||||
|
||||
This test verifies that computed properties (status, duration_seconds)
|
||||
are correctly returned in the API response. Internal DB columns like
|
||||
dedup_key are tested in unit tests (test_find_by_task_key_finished_not_found).
|
||||
"""
|
||||
with self._create_tasks():
|
||||
self.login(ADMIN_USERNAME)
|
||||
admin = self.get_user("admin")
|
||||
|
||||
# Get a successful task
|
||||
task = (
|
||||
db.session.query(Task)
|
||||
.filter_by(created_by_fk=admin.id, status=TaskStatus.SUCCESS.value)
|
||||
.first()
|
||||
)
|
||||
assert task is not None
|
||||
|
||||
uri = f"{self.TASK_API_BASE}/{task.uuid}"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
result = data["result"]
|
||||
|
||||
# Check status field (computed properties are now derived from status)
|
||||
assert result["status"] == TaskStatus.SUCCESS.value
|
||||
|
||||
# Properties dict should exist and be a dict
|
||||
assert "properties" in result
|
||||
assert isinstance(result["properties"], dict)
|
||||
|
||||
# Verify duration_seconds is not null for completed tasks with timestamps
|
||||
# (requires both started_at and ended_at to be set)
|
||||
if result.get("started_at") and result.get("ended_at"):
|
||||
assert result["duration_seconds"] is not None
|
||||
assert result["duration_seconds"] >= 0.0
|
||||
16
tests/integration_tests/tasks/commands/__init__.py
Normal file
16
tests/integration_tests/tasks/commands/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
482
tests/integration_tests/tasks/commands/test_cancel.py
Normal file
482
tests/integration_tests/tasks/commands/test_cancel.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks.cancel import CancelTaskCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskAbortFailedError,
|
||||
TaskNotAbortableError,
|
||||
TaskNotFoundError,
|
||||
TaskPermissionDeniedError,
|
||||
)
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.utils.core import override_user
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
def test_cancel_pending_task_aborts(app_context, get_user) -> None:
|
||||
"""Test canceling a pending task directly aborts it"""
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create a pending private task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="cancel_pending_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Cancel the pending task with admin user context
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# Verify task is aborted (pending goes directly to ABORTED)
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
assert command.action_taken == "aborted"
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.ABORTED.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_in_progress_abortable_task_sets_aborting(app_context, get_user) -> None:
|
||||
"""Test canceling an in-progress task with abort handler sets ABORTING"""
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create an in-progress abortable task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="cancel_in_progress_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
properties={"is_abortable": True},
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Cancel the in-progress task - mock publish_abort to avoid Redis dependency
|
||||
with (
|
||||
override_user(admin),
|
||||
patch("superset.tasks.manager.TaskManager.publish_abort"),
|
||||
):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# In-progress tasks go to ABORTING (not ABORTED)
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
assert command.action_taken == "aborted"
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.ABORTING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_in_progress_not_abortable_raises_error(app_context, get_user) -> None:
|
||||
"""Test canceling an in-progress task without abort handler raises error"""
|
||||
admin = get_user("admin")
|
||||
unique_key = f"cancel_not_abortable_test_{uuid4().hex[:8]}"
|
||||
|
||||
# Create an in-progress non-abortable task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=unique_key,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
properties={"is_abortable": False},
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
with pytest.raises(TaskNotAbortableError):
|
||||
command.run()
|
||||
|
||||
# Verify task status unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_task_not_found(app_context, get_user) -> None:
|
||||
"""Test canceling non-existent task raises error"""
|
||||
admin = get_user("admin")
|
||||
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(
|
||||
task_uuid=UUID("00000000-0000-0000-0000-000000000000")
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
command.run()
|
||||
|
||||
|
||||
def test_cancel_finished_task_raises_error(app_context, get_user) -> None:
|
||||
"""Test canceling an already finished task raises error"""
|
||||
|
||||
admin = get_user("admin")
|
||||
unique_key = f"cancel_finished_test_{uuid4().hex[:8]}"
|
||||
|
||||
# Create a finished task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=unique_key,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.SUCCESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
with pytest.raises(TaskAbortFailedError):
|
||||
command.run()
|
||||
|
||||
# Verify task status unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.SUCCESS.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_shared_task_with_multiple_subscribers_unsubscribes(
|
||||
app_context, get_user
|
||||
) -> None:
|
||||
"""Test canceling a shared task with multiple subscribers unsubscribes user"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
|
||||
# Create a shared task with admin as creator
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="cancel_shared_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
# Add gamma as subscriber
|
||||
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Verify we have 2 subscribers
|
||||
db.session.refresh(task)
|
||||
assert task.subscriber_count == 2
|
||||
|
||||
# Cancel as gamma (non-admin subscriber)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# Should unsubscribe, not abort
|
||||
assert command.action_taken == "unsubscribed"
|
||||
assert result.status == TaskStatus.PENDING.value # Status unchanged
|
||||
|
||||
# Verify gamma was unsubscribed
|
||||
db.session.refresh(task)
|
||||
assert task.subscriber_count == 1
|
||||
assert not task.has_subscriber(gamma.id)
|
||||
assert task.has_subscriber(admin.id)
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_shared_task_last_subscriber_aborts(app_context, get_user) -> None:
|
||||
"""Test canceling a shared task as last subscriber aborts it"""
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create a shared task with only admin as subscriber
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="cancel_last_subscriber_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Verify only 1 subscriber
|
||||
db.session.refresh(task)
|
||||
assert task.subscriber_count == 1
|
||||
|
||||
# Cancel as the only subscriber
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# Should abort since last subscriber
|
||||
assert command.action_taken == "aborted"
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_with_force_aborts_for_all_subscribers(app_context, get_user) -> None:
|
||||
"""Test force cancel aborts shared task even with multiple subscribers"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
|
||||
# Create a shared task with multiple subscribers
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="force_cancel_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
# Add gamma as subscriber
|
||||
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Verify 2 subscribers
|
||||
db.session.refresh(task)
|
||||
assert task.subscriber_count == 2
|
||||
|
||||
# Force cancel as admin
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid, force=True)
|
||||
result = command.run()
|
||||
|
||||
# Should abort despite multiple subscribers
|
||||
assert command.action_taken == "aborted"
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_with_force_requires_admin(app_context, get_user) -> None:
|
||||
"""Test force cancel requires admin privileges"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
|
||||
# Create a shared task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="force_requires_admin_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
# Add gamma as subscriber
|
||||
TaskDAO.add_subscriber(task.id, user_id=gamma.id)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Try to force cancel as gamma (non-admin)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid, force=True)
|
||||
|
||||
with pytest.raises(TaskPermissionDeniedError):
|
||||
command.run()
|
||||
|
||||
# Verify task unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_private_task_permission_denied(app_context, get_user) -> None:
|
||||
"""Test non-owner cannot cancel private task"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
unique_key = f"private_permission_test_{uuid4().hex[:8]}"
|
||||
|
||||
# Use test_request_context to ensure has_request_context() returns True
|
||||
# so that TaskFilter properly applies permission filtering
|
||||
with app.test_request_context():
|
||||
# Create a private task owned by admin
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=unique_key,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Try to cancel admin's private task as gamma (non-owner)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
# Should fail because gamma can't see admin's private task (base filter)
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
command.run()
|
||||
|
||||
# Verify task unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_system_task_requires_admin(app_context, get_user) -> None:
|
||||
"""Test system tasks can only be canceled by admin"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
unique_key = f"system_task_test_{uuid4().hex[:8]}"
|
||||
|
||||
# Use test_request_context to ensure has_request_context() returns True
|
||||
# so that TaskFilter properly applies permission filtering
|
||||
with app.test_request_context():
|
||||
# Create a system task
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=unique_key,
|
||||
scope=TaskScope.SYSTEM,
|
||||
user_id=None,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Try to cancel as gamma (non-admin)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
# System tasks are not visible to non-admins via base filter
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
command.run()
|
||||
|
||||
# Verify task unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
|
||||
# But admin can cancel it
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_already_aborting_is_idempotent(app_context, get_user) -> None:
|
||||
"""Test canceling an already aborting task is idempotent"""
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create a task already in ABORTING state
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="idempotent_cancel_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.ABORTING)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Cancel the already aborting task
|
||||
with override_user(admin):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
result = command.run()
|
||||
|
||||
# Should succeed without error
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_cancel_shared_task_not_subscribed_raises_error(app_context, get_user) -> None:
|
||||
"""Test non-subscriber cannot cancel shared task"""
|
||||
admin = get_user("admin")
|
||||
gamma = get_user("gamma")
|
||||
|
||||
# Create a shared task with only admin as subscriber
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="not_subscribed_test",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Try to cancel as gamma (not subscribed)
|
||||
with override_user(gamma):
|
||||
command = CancelTaskCommand(task_uuid=task.uuid)
|
||||
|
||||
with pytest.raises(TaskPermissionDeniedError):
|
||||
command.run()
|
||||
|
||||
# Verify task unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
419
tests/integration_tests/tasks/commands/test_internal_update.py
Normal file
419
tests/integration_tests/tasks/commands/test_internal_update.py
Normal file
@@ -0,0 +1,419 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Integration tests for internal task state update commands."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks.internal_update import (
|
||||
InternalStatusTransitionCommand,
|
||||
InternalUpdateTaskCommand,
|
||||
)
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
|
||||
def test_internal_update_properties(app_context, get_user, login_as) -> None:
|
||||
"""Test updating only properties without reading task first."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_props",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Perform zero-read update
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"is_abortable": True, "progress_percent": 0.5},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
assert task.properties_dict.get("progress_percent") == 0.5
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_payload(app_context, get_user, login_as) -> None:
|
||||
"""Test updating only payload without reading task first."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_payload",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Perform zero-read update
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
payload={"custom_key": "value", "count": 42},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.payload_dict == {"custom_key": "value", "count": 42}
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_both_properties_and_payload(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test updating both properties and payload in one call."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_both",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Perform zero-read update of both
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"progress_current": 50, "progress_total": 100},
|
||||
payload={"last_item": "xyz"},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.properties_dict.get("progress_current") == 50
|
||||
assert task.properties_dict.get("progress_total") == 100
|
||||
assert task.payload_dict == {"last_item": "xyz"}
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_returns_false_for_nonexistent_task(
|
||||
app_context, login_as
|
||||
) -> None:
|
||||
"""Test that updating non-existent task returns False."""
|
||||
login_as("admin")
|
||||
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
|
||||
properties={"is_abortable": True},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_internal_update_returns_false_when_nothing_to_update(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that passing no properties or payload returns False early."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_empty",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# No properties or payload provided
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties=None,
|
||||
payload=None,
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is False
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_does_not_change_status(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that internal update leaves status unchanged (safe for concurrent abort)."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_status",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Update properties - status should not change
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"progress_percent": 0.75},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify status unchanged
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_internal_update_replaces_entire_properties(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that internal update replaces properties entirely (no merge)."""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="internal_update_replace",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
properties={"is_abortable": True, "timeout": 300},
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Replace with new properties (caller is responsible for merging if needed)
|
||||
command = InternalUpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"error_message": "new_value"},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify entire replacement occurred
|
||||
db.session.refresh(task)
|
||||
# The caller should have merged if they wanted to preserve is_abortable
|
||||
assert task.properties_dict == {"error_message": "new_value"}
|
||||
assert "is_abortable" not in task.properties_dict
|
||||
assert "timeout" not in task.properties_dict
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# InternalStatusTransitionCommand Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_status_transition_atomic_compare_and_swap(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test atomic conditional status transitions with comprehensive scenarios.
|
||||
|
||||
Covers: success case, failure case, list of expected statuses, properties update,
|
||||
ended_at timestamp, and string status values.
|
||||
"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="status_transition_comprehensive",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# 1. SUCCESS CASE: PENDING → IN_PROGRESS (expected matches)
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.IN_PROGRESS,
|
||||
expected_status=TaskStatus.PENDING,
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value
|
||||
|
||||
# 2. FAILURE CASE: Try wrong expected status (should fail, status unchanged)
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.SUCCESS,
|
||||
expected_status=TaskStatus.PENDING, # Wrong! Current is IN_PROGRESS
|
||||
).run()
|
||||
assert result is False
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value # Unchanged
|
||||
|
||||
# 3. LIST OF EXPECTED: Transition with multiple acceptable source statuses
|
||||
task.set_status(TaskStatus.ABORTING)
|
||||
db.session.commit()
|
||||
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.FAILURE,
|
||||
expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
|
||||
properties={"error_message": "Test error"},
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.FAILURE.value
|
||||
assert task.properties_dict.get("error_message") == "Test error"
|
||||
|
||||
# 4. ENDED_AT: Reset to IN_PROGRESS and test ended_at timestamp
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
task.ended_at = None
|
||||
db.session.commit()
|
||||
assert task.ended_at is None
|
||||
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.SUCCESS,
|
||||
expected_status=TaskStatus.IN_PROGRESS,
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.SUCCESS.value
|
||||
assert task.ended_at is not None
|
||||
|
||||
# 5. STRING VALUES: Reset and test string status values
|
||||
task.set_status(TaskStatus.PENDING)
|
||||
db.session.commit()
|
||||
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status="in_progress",
|
||||
expected_status="pending",
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == "in_progress"
|
||||
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_status_transition_prevents_race_condition(
|
||||
app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that conditional update prevents overwriting concurrent abort.
|
||||
|
||||
This is the key race condition fix: if task is aborted concurrently,
|
||||
the executor's attempt to set SUCCESS should fail (return False),
|
||||
preserving the ABORTING state.
|
||||
"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="status_transition_race",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Simulate concurrent abort: directly set ABORTING in DB
|
||||
# (as if CancelTaskCommand ran in another process)
|
||||
task.set_status(TaskStatus.ABORTING)
|
||||
db.session.commit()
|
||||
|
||||
# Executor tries to set SUCCESS (expecting IN_PROGRESS) - stale expectation
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.SUCCESS,
|
||||
expected_status=TaskStatus.IN_PROGRESS,
|
||||
).run()
|
||||
|
||||
# Should fail - task was aborted concurrently
|
||||
assert result is False
|
||||
|
||||
# Verify ABORTING is preserved (not overwritten to SUCCESS)
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.ABORTING.value
|
||||
|
||||
# Verify correct transition from ABORTING still works
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=task.uuid,
|
||||
new_status=TaskStatus.ABORTED,
|
||||
expected_status=TaskStatus.ABORTING,
|
||||
set_ended_at=True,
|
||||
).run()
|
||||
assert result is True
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.ABORTED.value
|
||||
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_status_transition_nonexistent_task(app_context, login_as) -> None:
|
||||
"""Test that transitioning non-existent task returns False."""
|
||||
login_as("admin")
|
||||
|
||||
result = InternalStatusTransitionCommand(
|
||||
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
|
||||
new_status=TaskStatus.IN_PROGRESS,
|
||||
expected_status=TaskStatus.PENDING,
|
||||
).run()
|
||||
|
||||
assert result is False
|
||||
258
tests/integration_tests/tasks/commands/test_prune.py
Normal file
258
tests/integration_tests/tasks/commands/test_prune.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks import TaskPruneCommand
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.models.tasks import Task
|
||||
|
||||
|
||||
@freeze_time("2024-02-15")
|
||||
@patch("superset.tasks.utils.get_current_user")
|
||||
def test_prune_tasks_success(mock_get_user, app_context, get_user, login_as) -> None:
|
||||
"""Test successful pruning of old completed tasks"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
mock_get_user.return_value = admin.username
|
||||
|
||||
# Create old completed tasks (35 days ago = Jan 11, 2024)
|
||||
old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
|
||||
task_ids = []
|
||||
for i in range(3):
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=f"prune_task_{i}",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.SUCCESS)
|
||||
task.ended_at = old_date
|
||||
task_ids.append(task.id)
|
||||
|
||||
# Create a recent task (5 days ago = Feb 10, 2024) that should NOT be deleted
|
||||
recent_date = datetime(2024, 2, 10, tzinfo=timezone.utc)
|
||||
recent_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="recent_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
recent_task.created_by = admin
|
||||
recent_task.set_status(TaskStatus.SUCCESS)
|
||||
recent_task.ended_at = recent_date
|
||||
recent_task_id = recent_task.id
|
||||
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Prune tasks older than 30 days
|
||||
command = TaskPruneCommand(retention_period_days=30)
|
||||
command.run()
|
||||
|
||||
# Verify old tasks are deleted
|
||||
for task_id in task_ids:
|
||||
assert db.session.get(Task, task_id) is None
|
||||
|
||||
# Verify recent task is NOT deleted
|
||||
assert db.session.get(Task, recent_task_id) is not None
|
||||
finally:
|
||||
# Cleanup remaining tasks
|
||||
for task_id in task_ids:
|
||||
existing = db.session.get(Task, task_id)
|
||||
if existing:
|
||||
db.session.delete(existing)
|
||||
if db.session.get(Task, recent_task_id):
|
||||
db.session.delete(db.session.get(Task, recent_task_id))
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@freeze_time("2024-02-15")
|
||||
@patch("superset.tasks.utils.get_current_user")
|
||||
def test_prune_tasks_with_max_rows(
|
||||
mock_get_user, app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test pruning with max_rows_per_run limit"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
mock_get_user.return_value = admin.username
|
||||
|
||||
# Create old completed tasks (35 days ago = Jan 11, 2024)
|
||||
task_ids = []
|
||||
for i in range(5):
|
||||
# Different ages for ordering (older tasks have smaller hour values)
|
||||
old_date = datetime(2024, 1, 11, i, tzinfo=timezone.utc)
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key=f"max_rows_task_{i}",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.SUCCESS)
|
||||
task.ended_at = old_date
|
||||
task_ids.append(task.id)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Prune with max_rows_per_run=2 (should only delete 2 oldest)
|
||||
command = TaskPruneCommand(retention_period_days=30, max_rows_per_run=2)
|
||||
command.run()
|
||||
|
||||
# Count remaining tasks
|
||||
remaining = sum(
|
||||
1 for task_id in task_ids if db.session.get(Task, task_id) is not None
|
||||
)
|
||||
assert remaining == 3 # 5 - 2 = 3 remaining
|
||||
finally:
|
||||
# Cleanup remaining tasks
|
||||
for task_id in task_ids:
|
||||
existing = db.session.get(Task, task_id)
|
||||
if existing:
|
||||
db.session.delete(existing)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@freeze_time("2024-02-15")
|
||||
@patch("superset.tasks.utils.get_current_user")
|
||||
def test_prune_does_not_delete_pending_tasks(
|
||||
mock_get_user, app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test that pruning does not delete pending or in-progress tasks"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
mock_get_user.return_value = admin.username
|
||||
|
||||
pending_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="pending_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
pending_task.created_by = admin
|
||||
# Keep as PENDING (no ended_at)
|
||||
|
||||
in_progress_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="in_progress_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
in_progress_task.created_by = admin
|
||||
in_progress_task.set_status(TaskStatus.IN_PROGRESS)
|
||||
# No ended_at for in-progress tasks
|
||||
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Prune tasks older than 30 days
|
||||
command = TaskPruneCommand(retention_period_days=30)
|
||||
command.run()
|
||||
|
||||
# Verify non-completed tasks are NOT deleted
|
||||
assert db.session.get(Task, pending_task.id) is not None
|
||||
assert db.session.get(Task, in_progress_task.id) is not None
|
||||
finally:
|
||||
# Cleanup
|
||||
for task in [pending_task, in_progress_task]:
|
||||
existing = db.session.get(Task, task.id)
|
||||
if existing:
|
||||
db.session.delete(existing)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@freeze_time("2024-02-15")
|
||||
@patch("superset.tasks.utils.get_current_user")
|
||||
def test_prune_deletes_all_completed_statuses(
|
||||
mock_get_user, app_context, get_user, login_as
|
||||
) -> None:
|
||||
"""Test pruning deletes SUCCESS, FAILURE, and ABORTED tasks"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
mock_get_user.return_value = admin.username
|
||||
|
||||
old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
|
||||
|
||||
# Create tasks with different completed statuses
|
||||
success_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="success_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
success_task.created_by = admin
|
||||
success_task.set_status(TaskStatus.SUCCESS)
|
||||
success_task.ended_at = old_date
|
||||
success_task_id = success_task.id
|
||||
|
||||
failure_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="failure_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
failure_task.created_by = admin
|
||||
failure_task.set_status(TaskStatus.FAILURE)
|
||||
failure_task.ended_at = old_date
|
||||
failure_task_id = failure_task.id
|
||||
|
||||
aborted_task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="aborted_task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
aborted_task.created_by = admin
|
||||
aborted_task.set_status(TaskStatus.ABORTED)
|
||||
aborted_task.ended_at = old_date
|
||||
aborted_task_id = aborted_task.id
|
||||
|
||||
db.session.commit()
|
||||
task_ids = [success_task_id, failure_task_id, aborted_task_id]
|
||||
|
||||
try:
|
||||
# Prune tasks older than 30 days
|
||||
command = TaskPruneCommand(retention_period_days=30)
|
||||
command.run()
|
||||
|
||||
# Verify all completed tasks are deleted
|
||||
for task_id in task_ids:
|
||||
assert db.session.get(Task, task_id) is None
|
||||
except AssertionError:
|
||||
# Cleanup if test fails
|
||||
for task_id in task_ids:
|
||||
existing = db.session.get(Task, task_id)
|
||||
if existing:
|
||||
db.session.delete(existing)
|
||||
db.session.commit()
|
||||
raise
|
||||
|
||||
|
||||
def test_prune_no_tasks_to_delete(app_context, login_as) -> None:
|
||||
"""Test pruning when no old tasks exist"""
|
||||
login_as("admin")
|
||||
|
||||
# Don't create any tasks - should handle gracefully
|
||||
command = TaskPruneCommand(retention_period_days=30)
|
||||
command.run() # Should not raise any errors
|
||||
238
tests/integration_tests/tasks/commands/test_submit.py
Normal file
238
tests/integration_tests/tasks/commands/test_submit.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks import SubmitTaskCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskInvalidError,
|
||||
)
|
||||
|
||||
|
||||
def test_submit_task_success(app_context, login_as, get_user) -> None:
|
||||
"""Test successful task submission"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
command = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "test-key",
|
||||
"task_name": "Test Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = command.run()
|
||||
|
||||
# Verify task was created
|
||||
assert result.task_type == "test-type"
|
||||
assert result.task_key == "test-key"
|
||||
assert result.task_name == "Test Task"
|
||||
assert result.status == TaskStatus.PENDING.value
|
||||
assert result.payload == "{}"
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.uuid is not None
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(result)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_with_all_fields(app_context, login_as, get_user) -> None:
|
||||
"""Test task submission with all optional fields"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
command = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "test-key-full",
|
||||
"task_name": "Test Task Full",
|
||||
"user_id": admin.id,
|
||||
"payload": {"key": "value"},
|
||||
"properties": {"execution_mode": "async", "timeout": 300},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = command.run()
|
||||
|
||||
# Verify all fields were set
|
||||
assert result.task_type == "test-type"
|
||||
assert result.task_key == "test-key-full"
|
||||
assert result.task_name == "Test Task Full"
|
||||
assert result.user_id == admin.id
|
||||
assert result.payload_dict == {"key": "value"}
|
||||
assert result.properties_dict.get("execution_mode") == "async"
|
||||
assert result.properties_dict.get("timeout") == 300
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(result)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_missing_task_type(app_context, login_as) -> None:
|
||||
"""Test submission fails when task_type is missing"""
|
||||
login_as("admin")
|
||||
|
||||
command = SubmitTaskCommand(data={})
|
||||
|
||||
with pytest.raises(TaskInvalidError) as exc_info:
|
||||
command.run()
|
||||
|
||||
assert len(exc_info.value._exceptions) == 1
|
||||
assert "task_type" in exc_info.value._exceptions[0].field_name
|
||||
|
||||
|
||||
def test_submit_task_joins_existing(app_context, login_as, get_user) -> None:
|
||||
"""Test that submitting with duplicate key joins existing task"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create first task
|
||||
command1 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "shared-key",
|
||||
"task_name": "First Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
task1 = command1.run()
|
||||
|
||||
try:
|
||||
# Submit second task with same task_key and type
|
||||
command2 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "shared-key",
|
||||
"task_name": "Second Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
|
||||
# Should return existing task, not create new one
|
||||
task2 = command2.run()
|
||||
assert task2.id == task1.id
|
||||
assert task2.uuid == task1.uuid
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task1)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_without_task_key(app_context, login_as, get_user) -> None:
|
||||
"""Test task submission without task_key (command generates UUID)"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
command = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_name": "Test Task No ID",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = command.run()
|
||||
|
||||
# Verify task was created and command generated a task_key
|
||||
assert result.task_type == "test-type"
|
||||
assert result.task_name == "Test Task No ID"
|
||||
assert result.task_key is not None # Command generated UUID
|
||||
assert result.uuid is not None
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(result)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_run_with_info_returns_is_new_true(
|
||||
app_context, login_as, get_user
|
||||
) -> None:
|
||||
"""Test run_with_info returns is_new=True for new task"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
command = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "unique-key-is-new",
|
||||
"task_name": "Test Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
task, is_new = command.run_with_info()
|
||||
|
||||
assert is_new is True
|
||||
assert task.task_key == "unique-key-is-new"
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_submit_task_run_with_info_returns_is_new_false(
|
||||
app_context, login_as, get_user
|
||||
) -> None:
|
||||
"""Test run_with_info returns is_new=False when joining existing task"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create first task
|
||||
command1 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "shared-key-is-new",
|
||||
"task_name": "First Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
task1, is_new1 = command1.run_with_info()
|
||||
assert is_new1 is True
|
||||
|
||||
try:
|
||||
# Submit second task with same key
|
||||
command2 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "shared-key-is-new",
|
||||
"task_name": "Second Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
)
|
||||
task2, is_new2 = command2.run_with_info()
|
||||
|
||||
# Should return existing task with is_new=False
|
||||
assert is_new2 is False
|
||||
assert task2.id == task1.id
|
||||
assert task2.uuid == task1.uuid
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task1)
|
||||
db.session.commit()
|
||||
260
tests/integration_tests/tasks/commands/test_update.py
Normal file
260
tests/integration_tests/tasks/commands/test_update.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks import UpdateTaskCommand
|
||||
from superset.commands.tasks.exceptions import (
|
||||
TaskForbiddenError,
|
||||
TaskNotFoundError,
|
||||
)
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
|
||||
def test_update_task_success(app_context, get_user, login_as) -> None:
|
||||
"""Test successful task update"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
# Create a task using DAO
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="update_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Update the task status
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
status=TaskStatus.SUCCESS.value,
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
# Verify update
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.SUCCESS.value
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.SUCCESS.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_update_task_not_found(app_context, login_as) -> None:
|
||||
"""Test update fails when task not found"""
|
||||
login_as("admin")
|
||||
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
|
||||
status=TaskStatus.SUCCESS.value,
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
command.run()
|
||||
|
||||
|
||||
def test_update_task_forbidden(app_context, get_user, login_as) -> None:
|
||||
"""Test update fails when user doesn't own task (via base filter)"""
|
||||
gamma = get_user("gamma")
|
||||
login_as("gamma")
|
||||
|
||||
# Create a task owned by gamma (non-admin) using DAO
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="forbidden_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=gamma.id,
|
||||
)
|
||||
task.created_by = gamma
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Login as alpha user (different non-admin, non-owner)
|
||||
login_as("alpha")
|
||||
|
||||
# Try to update gamma's task as alpha user
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
status=TaskStatus.SUCCESS.value,
|
||||
)
|
||||
|
||||
# Should raise ForbiddenError because ownership check fails
|
||||
with pytest.raises(TaskForbiddenError):
|
||||
command.run()
|
||||
|
||||
# Verify task was NOT updated
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.IN_PROGRESS.value
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_update_task_payload(app_context, get_user, login_as) -> None:
|
||||
"""Test updating task payload"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
# Create a task using DAO
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="payload_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
payload={"initial": "data"},
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Update payload
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
payload={"progress": 50, "message": "halfway"},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
# Verify payload was updated
|
||||
assert result.uuid == task.uuid
|
||||
payload = result.payload_dict
|
||||
assert payload["progress"] == 50
|
||||
assert payload["message"] == "halfway"
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
task_payload = task.payload_dict
|
||||
assert task_payload["progress"] == 50
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_update_all_supported_fields(app_context, get_user, login_as) -> None:
|
||||
"""Test updating all supported task fields
|
||||
(status, error, progress, abortable, timeout)"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
# Create a task with initial execution_mode and timeout in properties
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="all_fields_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
properties={"execution_mode": "async", "timeout": 300},
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Update all field types at once
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
status=TaskStatus.FAILURE.value,
|
||||
properties={
|
||||
"error_message": "Task failed due to error",
|
||||
"progress_percent": 0.75,
|
||||
"progress_current": 75,
|
||||
"progress_total": 100,
|
||||
"is_abortable": True,
|
||||
},
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
# Verify all fields were updated
|
||||
assert result.uuid == task.uuid
|
||||
assert result.status == TaskStatus.FAILURE.value
|
||||
assert result.properties_dict.get("error_message") == "Task failed due to error"
|
||||
assert result.properties_dict.get("progress_percent") == 0.75
|
||||
assert result.properties_dict.get("progress_current") == 75
|
||||
assert result.properties_dict.get("progress_total") == 100
|
||||
assert result.properties_dict.get("is_abortable") is True
|
||||
assert result.properties_dict.get("execution_mode") == "async"
|
||||
assert result.properties_dict.get("timeout") == 300
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.status == TaskStatus.FAILURE.value
|
||||
assert task.properties_dict.get("error_message") == "Task failed due to error"
|
||||
assert task.properties_dict.get("progress_percent") == 0.75
|
||||
assert task.properties_dict.get("progress_current") == 75
|
||||
assert task.properties_dict.get("progress_total") == 100
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
assert task.properties_dict.get("execution_mode") == "async"
|
||||
assert task.properties_dict.get("timeout") == 300
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_update_task_skip_security_check(app_context, get_user, login_as) -> None:
|
||||
"""Test skip_security_check allows updating any task"""
|
||||
admin = get_user("admin")
|
||||
login_as("admin")
|
||||
|
||||
# Create a task owned by admin
|
||||
task = TaskDAO.create_task(
|
||||
task_type="test_type",
|
||||
task_key="skip_security_test",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=admin.id,
|
||||
)
|
||||
task.created_by = admin
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Login as gamma user (non-owner)
|
||||
login_as("gamma")
|
||||
|
||||
# With skip_security_check=True, should succeed even though gamma doesn't own it
|
||||
command = UpdateTaskCommand(
|
||||
task_uuid=task.uuid,
|
||||
properties={"progress_percent": 0.75},
|
||||
skip_security_check=True,
|
||||
)
|
||||
result = command.run()
|
||||
|
||||
# Verify update succeeded
|
||||
assert result.uuid == task.uuid
|
||||
assert result.properties_dict.get("progress_percent") == 0.75
|
||||
|
||||
# Verify in database
|
||||
db.session.refresh(task)
|
||||
assert task.properties_dict.get("progress_percent") == 0.75
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
415
tests/integration_tests/tasks/test_event_handlers.py
Normal file
415
tests/integration_tests/tasks/test_event_handlers.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""End-to-end integration tests for task event handlers (abort and cleanup)
|
||||
|
||||
These tests verify that abort and cleanup handlers work correctly through
|
||||
the full task execution path using real @task decorated functions executed
|
||||
via the Celery executor (synchronously via .apply()).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset.commands.tasks.cancel import CancelTaskCommand
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.extensions import db
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.ambient_context import get_context
|
||||
from superset.tasks.context import TaskContext
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
from superset.tasks.scheduler import execute_task
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME
|
||||
|
||||
# Module-level state to track handler calls across test executions
|
||||
# (Since decorated functions are defined at module level)
|
||||
_handler_state: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _reset_handler_state():
|
||||
"""Reset handler state before each test."""
|
||||
global _handler_state
|
||||
_handler_state = {
|
||||
"cleanup_called": False,
|
||||
"abort_called": False,
|
||||
"cleanup_order": [],
|
||||
"abort_order": [],
|
||||
"cleanup_data": {},
|
||||
}
|
||||
|
||||
|
||||
def cleanup_test_task() -> None:
|
||||
"""Task that registers a cleanup handler."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup() -> None:
|
||||
_handler_state["cleanup_called"] = True
|
||||
|
||||
# Simulate some work
|
||||
ctx.update_task(progress=1.0)
|
||||
|
||||
|
||||
def abort_test_task() -> None:
|
||||
"""Task that registers an abort handler."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort() -> None:
|
||||
_handler_state["abort_called"] = True
|
||||
|
||||
|
||||
def both_handlers_task() -> None:
|
||||
"""Task that registers both abort and cleanup handlers."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort() -> None:
|
||||
_handler_state["abort_called"] = True
|
||||
_handler_state["abort_order"].append("abort")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup() -> None:
|
||||
_handler_state["cleanup_called"] = True
|
||||
_handler_state["cleanup_order"].append("cleanup")
|
||||
|
||||
|
||||
def multiple_cleanup_handlers_task() -> None:
|
||||
"""Task that registers multiple cleanup handlers."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup_first() -> None:
|
||||
_handler_state["cleanup_order"].append("first")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup_second() -> None:
|
||||
_handler_state["cleanup_order"].append("second")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def cleanup_third() -> None:
|
||||
_handler_state["cleanup_order"].append("third")
|
||||
|
||||
|
||||
def cleanup_with_data_task() -> None:
|
||||
"""Task that uses cleanup handler to clean up partial work."""
|
||||
ctx = get_context()
|
||||
|
||||
# Simulate partial work in module-level state
|
||||
_handler_state["cleanup_data"]["temp_key"] = "temp_value"
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup() -> None:
|
||||
# Clean up the partial work
|
||||
_handler_state["cleanup_data"].clear()
|
||||
_handler_state["cleanup_called"] = True
|
||||
|
||||
|
||||
def _register_test_tasks() -> None:
|
||||
"""Register test task functions if not already registered.
|
||||
|
||||
Called in setUp() to ensure tasks are registered regardless of
|
||||
whether other tests have cleared the registry.
|
||||
"""
|
||||
registrations = [
|
||||
("test_cleanup_task", cleanup_test_task),
|
||||
("test_abort_task", abort_test_task),
|
||||
("test_both_handlers_task", both_handlers_task),
|
||||
("test_multiple_cleanup_task", multiple_cleanup_handlers_task),
|
||||
("test_cleanup_with_data", cleanup_with_data_task),
|
||||
]
|
||||
for name, func in registrations:
|
||||
if not TaskRegistry.is_registered(name):
|
||||
TaskRegistry.register(name, func)
|
||||
|
||||
|
||||
class TestCleanupHandlers(SupersetTestCase):
|
||||
"""E2E tests for on_cleanup functionality using Celery executor."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
_reset_handler_state()
|
||||
|
||||
def test_cleanup_handler_fires_on_success(self):
|
||||
"""Test cleanup handler runs when task completes successfully."""
|
||||
# Create task entry directly
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_cleanup_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Cleanup",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Execute task synchronously through Celery executor
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
|
||||
)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.SUCCESS.value
|
||||
|
||||
# Verify cleanup handler was called
|
||||
assert _handler_state["cleanup_called"]
|
||||
|
||||
def test_multiple_cleanup_handlers_in_lifo_order(self):
|
||||
"""Test multiple cleanup handlers execute in LIFO order."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_multiple_cleanup_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Multiple Cleanup",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_multiple_cleanup_task", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
|
||||
# Handlers should execute in LIFO order (last registered first)
|
||||
assert _handler_state["cleanup_order"] == ["third", "second", "first"]
|
||||
|
||||
def test_cleanup_handler_cleans_up_partial_work(self):
|
||||
"""Test cleanup handler can clean up partial work."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_cleanup_with_data",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Cleanup Data",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_cleanup_with_data", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert _handler_state["cleanup_called"]
|
||||
# Cleanup handler should have cleared the data
|
||||
assert len(_handler_state["cleanup_data"]) == 0
|
||||
|
||||
|
||||
class TestAbortHandlers(SupersetTestCase):
|
||||
"""E2E tests for on_abort functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
_reset_handler_state()
|
||||
|
||||
def test_abort_handler_fires_when_task_aborting(self):
|
||||
"""Test abort handler runs when task is in ABORTING state during cleanup."""
|
||||
# Create task entry
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_abort_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Abort",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Manually set to IN_PROGRESS and then ABORTING to simulate abort
|
||||
task_obj.status = TaskStatus.IN_PROGRESS.value
|
||||
task_obj.update_properties({"is_abortable": True})
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
# Refresh to get the updated task
|
||||
db.session.refresh(task_obj)
|
||||
|
||||
# Create context (simulating what executor does)
|
||||
ctx = TaskContext(task_obj)
|
||||
|
||||
# Register abort handler
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
_handler_state["abort_called"] = True
|
||||
|
||||
# Set status to ABORTING (simulating CancelTaskCommand)
|
||||
task_obj.status = TaskStatus.ABORTING.value
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
# Run cleanup (simulating executor's finally block)
|
||||
ctx._run_cleanup()
|
||||
|
||||
# Verify abort handler was called
|
||||
assert _handler_state["abort_called"]
|
||||
|
||||
def test_both_handlers_fire_on_abort(self):
|
||||
"""Test both abort and cleanup handlers run when task is aborted."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_both_handlers_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Both Handlers",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
task_obj.status = TaskStatus.IN_PROGRESS.value
|
||||
task_obj.update_properties({"is_abortable": True})
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
# Refresh to get the updated task
|
||||
db.session.refresh(task_obj)
|
||||
|
||||
ctx = TaskContext(task_obj)
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
_handler_state["abort_called"] = True
|
||||
_handler_state["abort_order"].append("abort")
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup():
|
||||
_handler_state["cleanup_called"] = True
|
||||
_handler_state["cleanup_order"].append("cleanup")
|
||||
|
||||
# Set to ABORTING
|
||||
task_obj.status = TaskStatus.ABORTING.value
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
ctx._run_cleanup()
|
||||
|
||||
# Both should have been called
|
||||
assert _handler_state["abort_called"]
|
||||
assert _handler_state["cleanup_called"]
|
||||
|
||||
def test_abort_handler_not_called_on_success(self):
|
||||
"""Test abort handler doesn't run when task succeeds."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_abort_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test No Abort on Success",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
task_obj.status = TaskStatus.SUCCESS.value
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
# Refresh to get the updated task
|
||||
db.session.refresh(task_obj)
|
||||
|
||||
ctx = TaskContext(task_obj)
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
_handler_state["abort_called"] = True
|
||||
|
||||
@ctx.on_cleanup
|
||||
def handle_cleanup():
|
||||
_handler_state["cleanup_called"] = True
|
||||
|
||||
ctx._run_cleanup()
|
||||
|
||||
# Abort handler should NOT be called
|
||||
assert not _handler_state["abort_called"]
|
||||
# Cleanup handler should still be called
|
||||
assert _handler_state["cleanup_called"]
|
||||
|
||||
|
||||
class TestTaskContextMethods(SupersetTestCase):
|
||||
"""Tests for TaskContext public methods."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
def test_on_abort_marks_task_abortable(self):
|
||||
"""Test that registering an on_abort handler marks task as abortable."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_abortable_flag",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Abortable",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
assert task_obj.properties_dict.get("is_abortable") is not True
|
||||
|
||||
ctx = TaskContext(task_obj)
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
db.session.expire_all()
|
||||
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
|
||||
assert task_obj.properties_dict.get("is_abortable") is True
|
||||
|
||||
|
||||
class TestAbortBeforeExecution(SupersetTestCase):
|
||||
"""Tests for aborting tasks before they start executing."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
|
||||
def test_abort_pending_task(self):
|
||||
"""Test that pending tasks can be aborted directly."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_abort_before_start",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Before Start",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Cancel immediately (task is still PENDING)
|
||||
CancelTaskCommand(task_obj.uuid, force=True).run()
|
||||
|
||||
db.session.expire_all()
|
||||
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
|
||||
assert task_obj.status == TaskStatus.ABORTED.value
|
||||
|
||||
def test_executor_skips_aborted_task(self):
|
||||
"""Test that executor skips tasks already aborted before execution."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_cleanup_task",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Skip Aborted",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Abort the task before execution
|
||||
task_obj.status = TaskStatus.ABORTED.value
|
||||
db.session.merge(task_obj)
|
||||
db.session.commit()
|
||||
|
||||
_reset_handler_state()
|
||||
|
||||
# Try to execute - should skip
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.ABORTED.value
|
||||
# Cleanup handler should NOT have been called (task was skipped)
|
||||
assert not _handler_state["cleanup_called"]
|
||||
158
tests/integration_tests/tasks/test_sync_join_wait.py
Normal file
158
tests/integration_tests/tasks/test_sync_join_wait.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Integration tests for sync join-and-wait functionality in GTF."""
|
||||
|
||||
import time
|
||||
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset import db
|
||||
from superset.commands.tasks import SubmitTaskCommand
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.tasks.manager import TaskManager
|
||||
|
||||
|
||||
def test_submit_task_distinguishes_new_vs_existing(
|
||||
app_context, login_as, get_user
|
||||
) -> None:
|
||||
"""
|
||||
Test that SubmitTaskCommand.run_with_info() correctly returns is_new flag.
|
||||
"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# First submission - should be new
|
||||
task1, is_new1 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "distinguish-key",
|
||||
"task_name": "First Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
assert is_new1 is True
|
||||
|
||||
try:
|
||||
# Second submission with same key - should join existing
|
||||
task2, is_new2 = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-type",
|
||||
"task_key": "distinguish-key",
|
||||
"task_name": "Second Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
assert is_new2 is False
|
||||
assert task2.uuid == task1.uuid
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
db.session.delete(task1)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_terminal_states_recognized_correctly(app_context) -> None:
|
||||
"""
|
||||
Test that TaskManager.TERMINAL_STATES contains the expected values.
|
||||
"""
|
||||
assert TaskStatus.SUCCESS.value in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.FAILURE.value in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.ABORTED.value in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.TIMED_OUT.value in TaskManager.TERMINAL_STATES
|
||||
|
||||
# Non-terminal states should not be in the set
|
||||
assert TaskStatus.PENDING.value not in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.IN_PROGRESS.value not in TaskManager.TERMINAL_STATES
|
||||
assert TaskStatus.ABORTING.value not in TaskManager.TERMINAL_STATES
|
||||
|
||||
|
||||
def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
|
||||
"""
|
||||
Test that wait_for_completion raises TimeoutError on timeout.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create a pending task (won't complete)
|
||||
task, _ = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-timeout",
|
||||
"task_key": "timeout-key",
|
||||
"task_name": "Timeout Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
try:
|
||||
# Force polling mode by mocking signal_cache as None
|
||||
with patch("superset.tasks.manager.cache_manager") as mock_cache_manager:
|
||||
mock_cache_manager.signal_cache = None
|
||||
with pytest.raises(TimeoutError):
|
||||
TaskManager.wait_for_completion(
|
||||
task.uuid,
|
||||
timeout=0.2,
|
||||
poll_interval=0.05,
|
||||
)
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_wait_returns_immediately_for_terminal_task(
|
||||
app_context, login_as, get_user
|
||||
) -> None:
|
||||
"""
|
||||
Test that wait_for_completion returns immediately if task is already terminal.
|
||||
"""
|
||||
login_as("admin")
|
||||
admin = get_user("admin")
|
||||
|
||||
# Create and immediately complete a task
|
||||
task, _ = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test-immediate",
|
||||
"task_key": "immediate-key",
|
||||
"task_name": "Immediate Task",
|
||||
"user_id": admin.id,
|
||||
}
|
||||
).run_with_info()
|
||||
|
||||
TaskDAO.update(task, {"status": TaskStatus.SUCCESS.value})
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
start = time.time()
|
||||
result = TaskManager.wait_for_completion(
|
||||
task.uuid,
|
||||
timeout=5.0,
|
||||
poll_interval=0.5,
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
|
||||
assert result.status == TaskStatus.SUCCESS.value
|
||||
# Should return almost immediately since task is already terminal
|
||||
assert elapsed < 0.2
|
||||
finally:
|
||||
db.session.delete(task)
|
||||
db.session.commit()
|
||||
172
tests/integration_tests/tasks/test_throttling.py
Normal file
172
tests/integration_tests/tasks/test_throttling.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Integration tests for TaskContext update_task throttling.
|
||||
|
||||
Tests verify:
|
||||
1. Final state is persisted correctly via cleanup flush
|
||||
2. Throttled updates are deferred, timer writes latest pending update
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.extensions import db
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.ambient_context import get_context
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
from superset.tasks.scheduler import execute_task
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME
|
||||
|
||||
|
||||
def task_with_throttled_updates() -> None:
|
||||
"""Task with rapid progress and payload updates (exercises throttling)."""
|
||||
ctx = get_context()
|
||||
|
||||
# Rapid-fire updates within throttle window
|
||||
for i in range(10):
|
||||
ctx.update_task(progress=(i + 1, 10), payload={"step": i + 1})
|
||||
|
||||
|
||||
def _register_test_tasks() -> None:
|
||||
"""Register test task functions if not already registered.
|
||||
|
||||
Called in setUp() to ensure tasks are registered regardless of
|
||||
whether other tests have cleared the registry.
|
||||
"""
|
||||
if not TaskRegistry.is_registered("test_throttle_combined"):
|
||||
TaskRegistry.register("test_throttle_combined", task_with_throttled_updates)
|
||||
|
||||
|
||||
class TestUpdateTaskThrottling(SupersetTestCase):
|
||||
"""Integration test for update_task() throttling behavior."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
|
||||
def test_throttled_updates_persisted_on_cleanup(self) -> None:
|
||||
"""Final state should be persisted regardless of throttling.
|
||||
|
||||
Verifies the core invariant: cleanup flush ensures final state is persisted.
|
||||
"""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_throttle_combined",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Throttled Updates",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_throttle_combined", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.SUCCESS.value
|
||||
|
||||
# Verify final state is persisted
|
||||
db.session.expire_all()
|
||||
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
|
||||
|
||||
# Progress: 10/10 = 100%
|
||||
props = task_obj.properties_dict
|
||||
assert props.get("progress_current") == 10
|
||||
assert props.get("progress_total") == 10
|
||||
assert props.get("progress_percent") == 1.0
|
||||
|
||||
# Payload: final step
|
||||
payload = task_obj.payload_dict
|
||||
assert payload.get("step") == 10
|
||||
|
||||
def test_throttle_behavior(self) -> None:
|
||||
"""Test complete throttle behavior: immediate write, deferral, and timer.
|
||||
|
||||
Verifies:
|
||||
1. First update writes immediately
|
||||
2. Second and third updates within throttle window are deferred
|
||||
3. Deferred timer fires and writes the LATEST pending update (third)
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
from superset.commands.tasks.submit import SubmitTaskCommand
|
||||
from superset.tasks.context import TaskContext
|
||||
|
||||
# Get throttle interval from config (default: 2 seconds)
|
||||
throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]
|
||||
|
||||
# Create task
|
||||
task_obj = SubmitTaskCommand(
|
||||
data={
|
||||
"task_type": "test_throttle_behavior",
|
||||
"task_key": f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
"task_name": "Test Throttle Behavior",
|
||||
"scope": TaskScope.SYSTEM,
|
||||
}
|
||||
).run()
|
||||
task_uuid = task_obj.uuid
|
||||
|
||||
# Get fresh task for context
|
||||
fresh_task = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
|
||||
assert fresh_task is not None
|
||||
ctx = TaskContext(fresh_task)
|
||||
|
||||
try:
|
||||
# === Step 1: First update - writes immediately ===
|
||||
ctx.update_task(progress=0.1, payload={"step": 1})
|
||||
|
||||
db.session.expire_all()
|
||||
task_step1 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
|
||||
assert task_step1 is not None
|
||||
assert task_step1.properties_dict.get("progress_percent") == 0.1
|
||||
assert task_step1.payload_dict.get("step") == 1
|
||||
|
||||
# === Step 2: Second update - deferred (within throttle window) ===
|
||||
ctx.update_task(progress=0.5, payload={"step": 2})
|
||||
|
||||
# === Step 3: Third update - also deferred, overwrites second in cache ===
|
||||
ctx.update_task(progress=0.7, payload={"step": 3})
|
||||
|
||||
# Verify in-memory cache has LATEST update (third)
|
||||
assert ctx._properties_cache.get("progress_percent") == 0.7
|
||||
assert ctx._payload_cache.get("step") == 3
|
||||
|
||||
# Verify DB still has first update (both second and third deferred)
|
||||
db.session.expire_all()
|
||||
task_step2 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
|
||||
assert task_step2 is not None
|
||||
assert task_step2.properties_dict.get("progress_percent") == 0.1
|
||||
assert task_step2.payload_dict.get("step") == 1
|
||||
|
||||
# === Step 4: Wait for deferred timer to fire ===
|
||||
time.sleep(throttle_interval + 0.5)
|
||||
|
||||
# Verify timer fired and wrote the LATEST update (third, not second)
|
||||
db.session.expire_all()
|
||||
task_step3 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
|
||||
assert task_step3 is not None
|
||||
assert task_step3.properties_dict.get("progress_percent") == 0.7
|
||||
assert task_step3.payload_dict.get("step") == 3
|
||||
|
||||
finally:
|
||||
ctx._cancel_deferred_flush_timer()
|
||||
226
tests/integration_tests/tasks/test_timeout.py
Normal file
226
tests/integration_tests/tasks/test_timeout.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Integration tests for GTF timeout handling.
|
||||
|
||||
Uses module-level task functions with manual registry (like test_event_handlers.py)
|
||||
to avoid mypy issues with the @task decorator's complex generic types.
|
||||
|
||||
NOTE: Tests that use background threads (timeout/abort handlers) are skipped in
|
||||
SQLite environments because SQLite connections cannot be shared across threads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskScope, TaskStatus
|
||||
|
||||
from superset.commands.tasks.cancel import CancelTaskCommand
|
||||
from superset.daos.tasks import TaskDAO
|
||||
from superset.extensions import db
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.ambient_context import get_context
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
from superset.tasks.scheduler import execute_task
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME
|
||||
|
||||
|
||||
def _skip_if_sqlite() -> None:
|
||||
"""Skip test if running with SQLite database.
|
||||
|
||||
SQLite connections cannot be shared across threads, which breaks
|
||||
timeout tests that use background threads for abort handlers.
|
||||
Must be called from within a test method (with app context).
|
||||
"""
|
||||
if "sqlite" in db.engine.url.drivername:
|
||||
pytest.skip("SQLite connections cannot be shared across threads")
|
||||
|
||||
|
||||
# Module-level state to track handler calls
|
||||
_handler_state: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _reset_handler_state() -> None:
|
||||
"""Reset handler state before each test."""
|
||||
global _handler_state
|
||||
_handler_state = {
|
||||
"abort_called": False,
|
||||
"handler_exception": None,
|
||||
}
|
||||
|
||||
|
||||
def timeout_abortable_task() -> None:
|
||||
"""Task with abort handler that exits when aborted."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def on_abort() -> None:
|
||||
_handler_state["abort_called"] = True
|
||||
|
||||
# Poll for abort signal
|
||||
for _ in range(50):
|
||||
if _handler_state["abort_called"]:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def timeout_handler_fails_task() -> None:
|
||||
"""Task with abort handler that throws an exception."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def on_abort() -> None:
|
||||
_handler_state["abort_called"] = True
|
||||
raise ValueError("Handler crashed!")
|
||||
|
||||
# Sleep longer than timeout
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def simple_task_with_abort() -> None:
|
||||
"""Simple task with abort handler for testing."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def on_abort() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def quick_task_with_abort() -> None:
|
||||
"""Quick task that completes before timeout."""
|
||||
ctx = get_context()
|
||||
|
||||
@ctx.on_abort
|
||||
def on_abort() -> None:
|
||||
pass
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
def _register_test_tasks() -> None:
|
||||
"""Register test task functions if not already registered.
|
||||
|
||||
Called in setUp() to ensure tasks are registered regardless of
|
||||
whether other tests have cleared the registry.
|
||||
"""
|
||||
registrations = [
|
||||
("test_timeout_abortable", timeout_abortable_task),
|
||||
("test_timeout_handler_fails", timeout_handler_fails_task),
|
||||
("test_timeout_simple", simple_task_with_abort),
|
||||
("test_timeout_quick", quick_task_with_abort),
|
||||
]
|
||||
for name, func in registrations:
|
||||
if not TaskRegistry.is_registered(name):
|
||||
TaskRegistry.register(name, func)
|
||||
|
||||
|
||||
class TestTimeoutHandling(SupersetTestCase):
|
||||
"""E2E tests for task timeout functionality."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""Set up test fixtures."""
|
||||
super().setUp()
|
||||
self.login(ADMIN_USERNAME)
|
||||
_register_test_tasks()
|
||||
_reset_handler_state()
|
||||
|
||||
def test_timeout_with_abort_handler_results_in_timed_out_status(self) -> None:
|
||||
"""Task with timeout and abort handler should end with TIMED_OUT status."""
|
||||
_skip_if_sqlite()
|
||||
|
||||
# Create task with timeout
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_timeout_abortable",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Timeout",
|
||||
scope=TaskScope.SYSTEM,
|
||||
properties={"timeout": 1}, # 1 second timeout
|
||||
)
|
||||
|
||||
# Execute task via Celery executor (synchronously)
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_timeout_abortable", (), {}]
|
||||
)
|
||||
|
||||
# Verify execution completed
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.TIMED_OUT.value
|
||||
|
||||
# Verify abort handler was called
|
||||
assert _handler_state["abort_called"]
|
||||
|
||||
def test_user_abort_results_in_aborted_status(self) -> None:
|
||||
"""User-initiated abort on pending task should result in ABORTED."""
|
||||
# Create task (pending state)
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_timeout_simple",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Abort Task",
|
||||
scope=TaskScope.SYSTEM,
|
||||
)
|
||||
|
||||
# Cancel before execution (pending task abort)
|
||||
CancelTaskCommand(task_obj.uuid, force=True).run()
|
||||
|
||||
# Refresh from DB
|
||||
db.session.expire_all()
|
||||
task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
|
||||
assert task_obj.status == TaskStatus.ABORTED.value
|
||||
|
||||
def test_no_timeout_when_not_configured(self) -> None:
|
||||
"""Task without timeout should run to completion regardless of duration."""
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_timeout_quick",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test No Timeout",
|
||||
scope=TaskScope.SYSTEM,
|
||||
# No timeout property
|
||||
)
|
||||
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_timeout_quick", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.SUCCESS.value
|
||||
|
||||
def test_abort_handler_exception_results_in_failure(self) -> None:
|
||||
"""If abort handler throws during timeout, task should be FAILURE."""
|
||||
_skip_if_sqlite()
|
||||
|
||||
task_obj = TaskDAO.create_task(
|
||||
task_type="test_timeout_handler_fails",
|
||||
task_key=f"test_key_{uuid.uuid4().hex[:8]}",
|
||||
task_name="Test Handler Fails",
|
||||
scope=TaskScope.SYSTEM,
|
||||
properties={"timeout": 1}, # 1 second timeout
|
||||
)
|
||||
|
||||
# Use str(uuid) since Celery serializes args as JSON strings
|
||||
result = execute_task.apply(
|
||||
args=[str(task_obj.uuid), "test_timeout_handler_fails", (), {}]
|
||||
)
|
||||
|
||||
assert result.successful()
|
||||
assert result.result["status"] == TaskStatus.FAILURE.value
|
||||
assert _handler_state["abort_called"]
|
||||
420
tests/unit_tests/daos/test_tasks.py
Normal file
420
tests/unit_tests/daos/test_tasks.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
|
||||
|
||||
from superset.commands.tasks.exceptions import TaskNotAbortableError
|
||||
from superset.models.tasks import Task
|
||||
from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key
|
||||
|
||||
# Test constants
|
||||
TASK_UUID = UUID("e7765491-40c1-4f35-a4f5-06308e79310e")
|
||||
TASK_ID = 42
|
||||
TEST_TASK_TYPE = "test_type"
|
||||
TEST_TASK_KEY = "test-key"
|
||||
TEST_USER_ID = 1
|
||||
|
||||
|
||||
def create_task(
|
||||
session: Session,
|
||||
*,
|
||||
task_id: int | None = None,
|
||||
task_uuid: UUID | None = None,
|
||||
task_key: str = TEST_TASK_KEY,
|
||||
task_type: str = TEST_TASK_TYPE,
|
||||
scope: TaskScope = TaskScope.PRIVATE,
|
||||
status: TaskStatus = TaskStatus.PENDING,
|
||||
user_id: int | None = TEST_USER_ID,
|
||||
properties: TaskProperties | None = None,
|
||||
use_finished_dedup_key: bool = False,
|
||||
) -> Task:
|
||||
"""Helper to create a task with sensible defaults for testing."""
|
||||
if use_finished_dedup_key:
|
||||
dedup_key = get_finished_dedup_key(task_uuid or TASK_UUID)
|
||||
else:
|
||||
dedup_key = get_active_dedup_key(
|
||||
scope=scope,
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
task_type=task_type,
|
||||
task_key=task_key,
|
||||
scope=scope.value,
|
||||
status=status.value,
|
||||
dedup_key=dedup_key,
|
||||
user_id=user_id,
|
||||
)
|
||||
if task_id is not None:
|
||||
task.id = task_id
|
||||
if task_uuid:
|
||||
task.uuid = task_uuid
|
||||
if properties:
|
||||
task.update_properties(properties)
|
||||
|
||||
session.add(task)
|
||||
session.flush()
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_task(session: Session) -> Iterator[Session]:
|
||||
"""Create a session with Task and TaskSubscriber tables."""
|
||||
from superset.models.task_subscribers import TaskSubscriber
|
||||
|
||||
engine = session.get_bind()
|
||||
Task.metadata.create_all(engine)
|
||||
TaskSubscriber.metadata.create_all(engine)
|
||||
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_find_by_task_key_active(session_with_task: Session) -> None:
|
||||
"""Test finding active task by task_key"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
create_task(session_with_task)
|
||||
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key=TEST_TASK_KEY,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.task_key == TEST_TASK_KEY
|
||||
assert result.task_type == TEST_TASK_TYPE
|
||||
assert result.status == TaskStatus.PENDING.value
|
||||
|
||||
|
||||
def test_find_by_task_key_not_found(session_with_task: Session) -> None:
|
||||
"""Test finding task by task_key returns None when not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="nonexistent-key",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_find_by_task_key_finished_not_found(session_with_task: Session) -> None:
|
||||
"""Test that find_by_task_key returns None for finished tasks.
|
||||
|
||||
Finished tasks have a different dedup_key format (UUID-based),
|
||||
so they won't be found by the active task lookup.
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
create_task(
|
||||
session_with_task,
|
||||
task_key="finished-key",
|
||||
status=TaskStatus.SUCCESS,
|
||||
use_finished_dedup_key=True,
|
||||
task_uuid=TASK_UUID,
|
||||
)
|
||||
|
||||
# Should not find SUCCESS task via active lookup
|
||||
result = TaskDAO.find_by_task_key(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="finished-key",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_create_task_success(session_with_task: Session) -> None:
|
||||
"""Test successful task creation."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key=TEST_TASK_KEY,
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.task_key == TEST_TASK_KEY
|
||||
assert result.task_type == TEST_TASK_TYPE
|
||||
assert result.status == TaskStatus.PENDING.value
|
||||
assert isinstance(result, Task)
|
||||
|
||||
|
||||
def test_create_task_with_user_id(session_with_task: Session) -> None:
|
||||
"""Test task creation with explicit user_id."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="user-task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=42,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.user_id == 42
|
||||
# Creator should be auto-subscribed
|
||||
assert len(result.subscribers) == 1
|
||||
assert result.subscribers[0].user_id == 42
|
||||
|
||||
|
||||
def test_create_task_with_properties(session_with_task: Session) -> None:
|
||||
"""Test task creation with properties."""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.create_task(
|
||||
task_type=TEST_TASK_TYPE,
|
||||
task_key="props-task",
|
||||
scope=TaskScope.PRIVATE,
|
||||
user_id=TEST_USER_ID,
|
||||
properties={"timeout": 300},
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.properties_dict.get("timeout") == 300
|
||||
|
||||
|
||||
def test_abort_task_pending_success(session_with_task: Session) -> None:
|
||||
"""Test successful abort of pending task - goes directly to ABORTED"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="pending-task",
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.ABORTED.value
|
||||
|
||||
|
||||
def test_abort_task_in_progress_abortable(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task with abort handler.
|
||||
|
||||
Should transition to ABORTING status.
|
||||
"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="abortable-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
properties={"is_abortable": True},
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is not None
|
||||
# Should set status to ABORTING, not ABORTED
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
|
||||
|
||||
def test_abort_task_in_progress_not_abortable(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task without abort handler - raises error"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="non-abortable-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
properties={"is_abortable": False},
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotAbortableError):
|
||||
TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
|
||||
def test_abort_task_in_progress_is_abortable_none(session_with_task: Session) -> None:
|
||||
"""Test abort of in-progress task with is_abortable not set - raises error"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="no-abortable-prop-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
# Empty properties - no is_abortable key
|
||||
)
|
||||
|
||||
with pytest.raises(TaskNotAbortableError):
|
||||
TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
|
||||
def test_abort_task_already_aborting(session_with_task: Session) -> None:
|
||||
"""Test abort of already aborting task - idempotent success"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="aborting-task",
|
||||
status=TaskStatus.ABORTING,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
# Idempotent - returns task without error
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.ABORTING.value
|
||||
|
||||
|
||||
def test_abort_task_not_found(session_with_task: Session) -> None:
|
||||
"""Test abort fails when task not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.abort_task(UUID("00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_abort_task_already_finished(session_with_task: Session) -> None:
|
||||
"""Test abort fails when task already finished"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="finished-task",
|
||||
status=TaskStatus.SUCCESS,
|
||||
use_finished_dedup_key=True,
|
||||
task_uuid=TASK_UUID,
|
||||
)
|
||||
|
||||
result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_add_subscriber(session_with_task: Session) -> None:
|
||||
"""Test adding a subscriber to a task"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Add subscriber
|
||||
result = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
assert result is True
|
||||
|
||||
# Verify subscriber was added
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
assert task.subscribers[0].user_id == TEST_USER_ID
|
||||
|
||||
|
||||
def test_add_subscriber_idempotent(session_with_task: Session) -> None:
|
||||
"""Test adding same subscriber twice is idempotent"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-2",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Add subscriber twice
|
||||
result1 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
result2 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
|
||||
assert result1 is True
|
||||
assert result2 is False # Already subscribed
|
||||
|
||||
# Verify only one subscriber
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
|
||||
|
||||
def test_remove_subscriber(session_with_task: Session) -> None:
|
||||
"""Test removing a subscriber from a task"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-3",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
session_with_task.refresh(task)
|
||||
assert len(task.subscribers) == 1
|
||||
|
||||
# Remove subscriber
|
||||
result = TaskDAO.remove_subscriber(task.id, user_id=TEST_USER_ID)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.subscribers) == 0
|
||||
|
||||
|
||||
def test_remove_subscriber_not_subscribed(session_with_task: Session) -> None:
|
||||
"""Test removing non-existent subscriber returns None"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_key="shared-task-4",
|
||||
scope=TaskScope.SHARED,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Try to remove non-existent subscriber
|
||||
result = TaskDAO.remove_subscriber(task.id, user_id=999)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_status(session_with_task: Session) -> None:
|
||||
"""Test get_status returns status string when task found by UUID"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
task = create_task(
|
||||
session_with_task,
|
||||
task_uuid=TASK_UUID,
|
||||
task_key="status-task",
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
)
|
||||
|
||||
result = TaskDAO.get_status(task.uuid)
|
||||
|
||||
assert result == TaskStatus.IN_PROGRESS.value
|
||||
|
||||
|
||||
def test_get_status_not_found(session_with_task: Session) -> None:
|
||||
"""Test get_status returns None when task not found"""
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
assert result is None
|
||||
@@ -18,17 +18,21 @@
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
# Force module loading before tests run so patches work correctly
|
||||
import superset.commands.distributed_lock.acquire as acquire_module
|
||||
import superset.commands.distributed_lock.release as release_module
|
||||
from superset import db
|
||||
from superset.distributed_lock import KeyValueDistributedLock
|
||||
from superset.distributed_lock import DistributedLock
|
||||
from superset.distributed_lock.types import LockValue
|
||||
from superset.distributed_lock.utils import get_key
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.exceptions import AcquireDistributedLockFailedException
|
||||
from superset.key_value.types import JsonKeyValueCodec
|
||||
|
||||
LOCK_VALUE: LockValue = {"value": True}
|
||||
@@ -56,9 +60,9 @@ def _get_other_session() -> Session:
|
||||
return SessionMaker()
|
||||
|
||||
|
||||
def test_key_value_distributed_lock_happy_path() -> None:
|
||||
def test_distributed_lock_kv_happy_path() -> None:
|
||||
"""
|
||||
Test successfully acquiring and returning the distributed lock.
|
||||
Test successfully acquiring and returning the distributed lock via KV backend.
|
||||
|
||||
Note, we're using another session for asserting the lock state in the Metastore
|
||||
to simulate what another worker will observe. Otherwise, there's the risk that
|
||||
@@ -66,24 +70,29 @@ def test_key_value_distributed_lock_happy_path() -> None:
|
||||
"""
|
||||
session = _get_other_session()
|
||||
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
# Ensure Redis is not configured so KV backend is used
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
with KeyValueDistributedLock("ns", a=1, b=2) as key:
|
||||
assert key == MAIN_KEY
|
||||
assert _get_lock(key, session) == LOCK_VALUE
|
||||
assert _get_lock(OTHER_KEY, session) is None
|
||||
with DistributedLock("ns", a=1, b=2) as key:
|
||||
assert key == MAIN_KEY
|
||||
assert _get_lock(key, session) == LOCK_VALUE
|
||||
assert _get_lock(OTHER_KEY, session) is None
|
||||
|
||||
with pytest.raises(CreateKeyValueDistributedLockFailedException):
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
pass
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("ns", a=1, b=2):
|
||||
pass
|
||||
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
|
||||
def test_key_value_distributed_lock_expired() -> None:
|
||||
def test_distributed_lock_kv_expired() -> None:
|
||||
"""
|
||||
Test expiration of the distributed lock
|
||||
Test expiration of the distributed lock via KV backend.
|
||||
|
||||
Note, we're using another session for asserting the lock state in the Metastore
|
||||
to simulate what another worker will observe. Otherwise, there's the risk that
|
||||
@@ -91,11 +100,112 @@ def test_key_value_distributed_lock_expired() -> None:
|
||||
"""
|
||||
session = _get_other_session()
|
||||
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
|
||||
with freeze_time("2022-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
# Ensure Redis is not configured so KV backend is used
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
with DistributedLock("ns", a=1, b=2):
|
||||
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
|
||||
with freeze_time("2022-01-01"):
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
|
||||
def test_distributed_lock_uses_redis_when_configured() -> None:
|
||||
"""Test that DistributedLock uses Redis backend when configured."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True # Lock acquired
|
||||
|
||||
# Use patch.object to patch on already-imported modules
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test_redis", key="value") as lock_key:
|
||||
assert lock_key is not None
|
||||
# Verify SET NX EX was called
|
||||
mock_redis.set.assert_called_once()
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["nx"] is True
|
||||
assert "ex" in call_args.kwargs
|
||||
|
||||
# Verify DELETE was called on exit
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_distributed_lock_redis_already_taken() -> None:
|
||||
"""Test Redis lock fails when already held."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = None # Lock not acquired (already taken)
|
||||
|
||||
with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("test_redis", key="value"):
|
||||
pass
|
||||
|
||||
|
||||
def test_distributed_lock_redis_connection_error() -> None:
|
||||
"""Test Redis connection error raises exception (fail fast)."""
|
||||
import redis
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.side_effect = redis.RedisError("Connection failed")
|
||||
|
||||
with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
|
||||
with pytest.raises(AcquireDistributedLockFailedException):
|
||||
with DistributedLock("test_redis", key="value"):
|
||||
pass
|
||||
|
||||
|
||||
def test_distributed_lock_custom_ttl() -> None:
|
||||
"""Test Redis lock with custom TTL."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test", ttl_seconds=60, key="value"):
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["ex"] == 60 # Custom TTL
|
||||
|
||||
|
||||
def test_distributed_lock_default_ttl(app_context: None) -> None:
|
||||
"""Test Redis lock uses default TTL when not specified."""
|
||||
from superset.commands.distributed_lock.base import get_default_lock_ttl
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
|
||||
patch.object(release_module, "get_redis_client", return_value=mock_redis),
|
||||
):
|
||||
with DistributedLock("test", key="value"):
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args.kwargs["ex"] == get_default_lock_ttl()
|
||||
|
||||
|
||||
def test_distributed_lock_fallback_to_kv_when_redis_not_configured() -> None:
|
||||
"""Test falls back to KV lock when Redis not configured."""
|
||||
session = _get_other_session()
|
||||
test_key = get_key("test_fallback", key="value")
|
||||
|
||||
with (
|
||||
patch.object(acquire_module, "get_redis_client", return_value=None),
|
||||
patch.object(release_module, "get_redis_client", return_value=None),
|
||||
):
|
||||
with freeze_time("2021-01-01"):
|
||||
# When Redis is not configured, should use KV backend
|
||||
with DistributedLock("test_fallback", key="value") as lock_key:
|
||||
assert lock_key == test_key
|
||||
# Verify lock exists in KV store
|
||||
assert _get_lock(test_key, session) == LOCK_VALUE
|
||||
|
||||
# Lock should be released
|
||||
assert _get_lock(test_key, session) is None
|
||||
|
||||
477
tests/unit_tests/tasks/test_decorators.py
Normal file
477
tests/unit_tests/tasks/test_decorators.py
Normal file
@@ -0,0 +1,477 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for task decorators"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskOptions, TaskScope
|
||||
|
||||
from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
|
||||
from superset.tasks.decorators import task, TaskWrapper
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
|
||||
class TestTaskDecoratorFeatureFlag:
|
||||
"""Tests for @task decorator feature flag behavior"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_decorator_succeeds_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that @task decorator can be applied even when GTF is disabled.
|
||||
|
||||
This enables safe module imports during app startup or Celery autodiscovery.
|
||||
"""
|
||||
|
||||
# Decoration should succeed - no error raised
|
||||
@task(name="test_gtf_disabled_decorator")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_task, TaskWrapper)
|
||||
assert my_task.name == "test_gtf_disabled_decorator"
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_call_raises_error_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that calling a task raises GlobalTaskFrameworkDisabledError
|
||||
when GTF is disabled."""
|
||||
|
||||
@task(name="test_gtf_disabled_call")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(GlobalTaskFrameworkDisabledError):
|
||||
my_task()
|
||||
|
||||
@patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
|
||||
def test_schedule_raises_error_when_gtf_disabled(self, mock_feature_flag):
|
||||
"""Test that scheduling a task raises GlobalTaskFrameworkDisabledError
|
||||
when GTF is disabled."""
|
||||
|
||||
@task(name="test_gtf_disabled_schedule")
|
||||
def my_task() -> None:
|
||||
pass
|
||||
|
||||
with pytest.raises(GlobalTaskFrameworkDisabledError):
|
||||
my_task.schedule()
|
||||
|
||||
|
||||
class TestTaskDecorator:
|
||||
"""Tests for @task decorator"""
|
||||
|
||||
def test_decorator_basic(self):
|
||||
"""Test basic decorator usage without options"""
|
||||
|
||||
@task(name="test_task")
|
||||
def my_task(arg1: int, arg2: str) -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_task, TaskWrapper)
|
||||
assert my_task.name == "test_task"
|
||||
assert my_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_without_parentheses(self):
|
||||
"""Test decorator usage without parentheses"""
|
||||
|
||||
@task
|
||||
def my_no_parens_task(arg1: int, arg2: str) -> None:
|
||||
pass
|
||||
|
||||
assert isinstance(my_no_parens_task, TaskWrapper)
|
||||
assert my_no_parens_task.name == "my_no_parens_task" # Uses function name
|
||||
assert my_no_parens_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_with_default_scope_private(self):
|
||||
"""Test decorator with explicit PRIVATE scope"""
|
||||
|
||||
@task(name="private_task", scope=TaskScope.PRIVATE)
|
||||
def my_private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
assert my_private_task.scope == TaskScope.PRIVATE
|
||||
|
||||
def test_decorator_with_default_scope_shared(self):
|
||||
"""Test decorator with SHARED scope"""
|
||||
|
||||
@task(name="shared_task", scope=TaskScope.SHARED)
|
||||
def my_shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
assert my_shared_task.scope == TaskScope.SHARED
|
||||
|
||||
def test_decorator_with_default_scope_system(self):
|
||||
"""Test decorator with SYSTEM scope"""
|
||||
|
||||
@task(name="system_task", scope=TaskScope.SYSTEM)
|
||||
def my_system_task() -> None:
|
||||
pass
|
||||
|
||||
assert my_system_task.scope == TaskScope.SYSTEM
|
||||
|
||||
def test_decorator_forbids_ctx_parameter(self):
|
||||
"""Test decorator rejects functions with ctx parameter"""
|
||||
|
||||
with pytest.raises(TypeError, match="must not define 'ctx'"):
|
||||
|
||||
@task(name="bad_task")
|
||||
def bad_task(ctx, arg1: int) -> None: # noqa: ARG001
|
||||
pass
|
||||
|
||||
def test_decorator_forbids_options_parameter(self):
|
||||
"""Test decorator rejects functions with options parameter"""
|
||||
|
||||
with pytest.raises(TypeError, match="must not define.*'options'"):
|
||||
|
||||
@task(name="bad_task")
|
||||
def bad_task(options, arg1: int) -> None: # noqa: ARG001
|
||||
pass
|
||||
|
||||
|
||||
class TestTaskWrapperMergeOptions:
|
||||
"""Tests for TaskWrapper._merge_options()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
def test_merge_options_no_override(self):
|
||||
"""Test merging with no override returns defaults"""
|
||||
|
||||
@task(name="test_merge_no_override_unique")
|
||||
def merge_task_1() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_1.default_options = TaskOptions(
|
||||
task_key="default_key",
|
||||
task_name="Default Name",
|
||||
)
|
||||
|
||||
merged = merge_task_1._merge_options(None)
|
||||
assert merged.task_key == "default_key"
|
||||
assert merged.task_name == "Default Name"
|
||||
|
||||
def test_merge_options_override_task_key(self):
|
||||
"""Test overriding task_key at call time"""
|
||||
|
||||
@task(name="test_merge_override_key_unique")
|
||||
def merge_task_2() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_2.default_options = TaskOptions(task_key="default_key")
|
||||
|
||||
override = TaskOptions(task_key="override_key")
|
||||
merged = merge_task_2._merge_options(override)
|
||||
assert merged.task_key == "override_key"
|
||||
|
||||
def test_merge_options_override_task_name(self):
|
||||
"""Test overriding task_name at call time"""
|
||||
|
||||
@task(name="test_merge_override_name_unique")
|
||||
def merge_task_3() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_3.default_options = TaskOptions(task_name="Default Name")
|
||||
|
||||
override = TaskOptions(task_name="Override Name")
|
||||
merged = merge_task_3._merge_options(override)
|
||||
assert merged.task_name == "Override Name"
|
||||
|
||||
def test_merge_options_override_all(self):
|
||||
"""Test overriding all options at call time"""
|
||||
|
||||
@task(name="test_merge_override_all_unique")
|
||||
def merge_task_4() -> None:
|
||||
pass
|
||||
|
||||
# Set default options for testing
|
||||
merge_task_4.default_options = TaskOptions(
|
||||
task_key="default_key",
|
||||
task_name="Default Name",
|
||||
)
|
||||
|
||||
override = TaskOptions(
|
||||
task_key="override_key",
|
||||
task_name="Override Name",
|
||||
)
|
||||
merged = merge_task_4._merge_options(override)
|
||||
assert merged.task_key == "override_key"
|
||||
assert merged.task_name == "Override Name"
|
||||
|
||||
|
||||
class TestTaskWrapperSchedule:
|
||||
"""Tests for TaskWrapper.schedule() with scope"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_uses_default_scope(self, mock_submit):
|
||||
"""Test schedule() uses decorator's default scope"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_default_unique", scope=TaskScope.SHARED)
|
||||
def schedule_task_1(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Shared tasks require explicit task_key
|
||||
schedule_task_1.schedule(123, options=TaskOptions(task_key="test_key"))
|
||||
|
||||
# Verify TaskManager.submit_task was called with correct scope
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.SHARED
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_uses_private_scope_by_default(self, mock_submit):
|
||||
"""Test schedule() uses PRIVATE scope when no scope specified"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_override_unique")
|
||||
def schedule_task_2(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
schedule_task_2.schedule(123)
|
||||
|
||||
# Verify PRIVATE scope was used (default)
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.PRIVATE
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_with_custom_options(self, mock_submit):
|
||||
"""Test schedule() with custom task options"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_custom_unique", scope=TaskScope.SYSTEM)
|
||||
def schedule_task_3(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Use custom task key and name
|
||||
schedule_task_3.schedule(
|
||||
123,
|
||||
options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
|
||||
)
|
||||
|
||||
# Verify scope from decorator and options from call time
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.SYSTEM
|
||||
assert call_args[1]["task_key"] == "custom_key"
|
||||
assert call_args[1]["task_name"] == "Custom Task Name"
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_with_no_decorator_options(self, mock_submit):
|
||||
"""Test schedule() uses default PRIVATE scope when no options provided"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_schedule_no_options_unique")
|
||||
def schedule_task_4(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
schedule_task_4.schedule(123)
|
||||
|
||||
# Verify default PRIVATE scope
|
||||
mock_submit.assert_called_once()
|
||||
call_args = mock_submit.call_args
|
||||
assert call_args[1]["scope"] == TaskScope.PRIVATE
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_shared_task_requires_task_key(self, mock_submit):
|
||||
"""Test shared task schedule() requires explicit task_key"""
|
||||
|
||||
@task(name="test_shared_requires_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should raise ValueError when no task_key provided
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Shared task.*requires an explicit task_key.*for deduplication",
|
||||
):
|
||||
shared_task.schedule(123)
|
||||
|
||||
# Should work with task_key provided
|
||||
mock_submit.return_value = MagicMock()
|
||||
shared_task.schedule(123, options=TaskOptions(task_key="valid_key"))
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.decorators.TaskManager.submit_task")
|
||||
def test_schedule_private_task_allows_no_task_key(self, mock_submit):
|
||||
"""Test private task schedule() works without task_key"""
|
||||
mock_submit.return_value = MagicMock()
|
||||
|
||||
@task(name="test_private_no_key", scope=TaskScope.PRIVATE)
|
||||
def private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work without task_key (generates random UUID)
|
||||
private_task.schedule(123)
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
|
||||
class TestTaskWrapperCall:
|
||||
"""Tests for TaskWrapper.__call__() with scope"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear task registry before each test"""
|
||||
TaskRegistry._tasks.clear()
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_uses_default_scope(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test direct call uses decorator's default scope"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_default_unique", scope=TaskScope.SHARED)
|
||||
def call_task_1(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Shared tasks require explicit task_key
|
||||
call_task_1(123, options=TaskOptions(task_key="test_key"))
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.utils.core.get_user_id")
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_uses_private_scope_by_default(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
|
||||
):
|
||||
"""Test direct call uses PRIVATE scope when no scope specified"""
|
||||
mock_get_user_id.return_value = 1
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_private_default_unique")
|
||||
def call_task_2(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
call_task_2(123)
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_with_custom_options(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test direct call with custom task options"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task # Mock the subsequent find call
|
||||
|
||||
@task(name="test_call_custom_unique", scope=TaskScope.SYSTEM)
|
||||
def call_task_3(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Use custom task key and name
|
||||
call_task_3(
|
||||
123,
|
||||
options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
|
||||
)
|
||||
|
||||
# Verify SubmitTaskCommand.run_with_info was called
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
def test_call_shared_task_requires_task_key(self):
|
||||
"""Test shared task direct call requires explicit task_key"""
|
||||
|
||||
@task(name="test_shared_call_requires_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should raise ValueError when no task_key provided
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Shared task.*requires an explicit task_key.*for deduplication",
|
||||
):
|
||||
shared_task(123)
|
||||
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_shared_task_works_with_task_key(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run
|
||||
):
|
||||
"""Test shared task direct call works with task_key"""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task
|
||||
|
||||
@task(name="test_shared_call_with_key", scope=TaskScope.SHARED)
|
||||
def shared_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work with task_key provided
|
||||
shared_task(123, options=TaskOptions(task_key="valid_key"))
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
|
||||
@patch("superset.utils.core.get_user_id")
|
||||
@patch("superset.commands.tasks.update.UpdateTaskCommand.run")
|
||||
@patch("superset.daos.tasks.TaskDAO.find_one_or_none")
|
||||
@patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
|
||||
def test_call_private_task_allows_no_task_key(
|
||||
self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
|
||||
):
|
||||
"""Test private task direct call works without task_key"""
|
||||
mock_get_user_id.return_value = 1
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.status = "in_progress"
|
||||
mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
|
||||
mock_update_run.return_value = mock_task
|
||||
mock_find.return_value = mock_task
|
||||
|
||||
@task(name="test_private_call_no_key", scope=TaskScope.PRIVATE)
|
||||
def private_task(arg1: int) -> None:
|
||||
pass
|
||||
|
||||
# Should work without task_key (generates random UUID)
|
||||
private_task(123)
|
||||
mock_submit_run_with_info.assert_called_once()
|
||||
677
tests/unit_tests/tasks/test_handlers.py
Normal file
677
tests/unit_tests/tasks/test_handlers.py
Normal file
@@ -0,0 +1,677 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for GTF handlers (abort, cleanup) and related Task model behavior."""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from superset_core.api.tasks import TaskStatus
|
||||
|
||||
from superset.tasks.context import TaskContext
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
"""Create a mock task for testing."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = TaskStatus.PENDING.value
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_dao(mock_task):
|
||||
"""Mock TaskDAO to return our test task."""
|
||||
with patch("superset.daos.tasks.TaskDAO") as mock_dao:
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
yield mock_dao
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_update_command():
|
||||
"""Mock UpdateTaskCommand to avoid database operations."""
|
||||
with patch("superset.commands.tasks.update.UpdateTaskCommand") as mock_cmd:
|
||||
mock_cmd.return_value.run.return_value = None
|
||||
yield mock_cmd
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flask_app():
|
||||
"""Create a properly configured mock Flask app."""
|
||||
mock_app = MagicMock()
|
||||
mock_app.config = {
|
||||
"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
|
||||
}
|
||||
# Make app_context() return a proper context manager
|
||||
mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
|
||||
# Use regular Mock (not MagicMock) for _get_current_object to avoid
|
||||
# AsyncMockMixin creating unawaited coroutines in Python 3.10+
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
return mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_context(mock_task, mock_task_dao, mock_update_command, mock_flask_app):
|
||||
"""Create TaskContext with mocked dependencies."""
|
||||
# Ensure mock_task has properties_dict and payload_dict (TaskContext accesses them)
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
# Configure current_app mock
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
# Use regular Mock (not MagicMock) for _get_current_object to avoid
|
||||
# AsyncMockMixin creating unawaited coroutines in Python 3.10+
|
||||
mock_current_app._get_current_object = Mock(return_value=mock_flask_app)
|
||||
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
yield ctx
|
||||
|
||||
# Cleanup: stop polling if started
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
class TestTaskStatusEnum:
|
||||
"""Test TaskStatus enum values."""
|
||||
|
||||
def test_aborting_status_exists(self):
|
||||
"""Test that ABORTING status is defined."""
|
||||
assert hasattr(TaskStatus, "ABORTING")
|
||||
assert TaskStatus.ABORTING.value == "aborting"
|
||||
|
||||
def test_all_statuses_present(self):
|
||||
"""Test all expected statuses are present."""
|
||||
expected_statuses = [
|
||||
"pending",
|
||||
"in_progress",
|
||||
"success",
|
||||
"failure",
|
||||
"aborting",
|
||||
"aborted",
|
||||
]
|
||||
actual_statuses = [s.value for s in TaskStatus]
|
||||
|
||||
for status in expected_statuses:
|
||||
assert status in actual_statuses, f"Missing status: {status}"
|
||||
|
||||
|
||||
class TestTaskAbortProperties:
|
||||
"""Test Task model abort-related properties via status and properties accessor."""
|
||||
|
||||
def test_aborting_status(self):
|
||||
"""Test ABORTING status check."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.status = TaskStatus.ABORTING.value
|
||||
|
||||
assert task.status == TaskStatus.ABORTING.value
|
||||
|
||||
def test_is_abortable_in_properties(self):
|
||||
"""Test is_abortable is accessible via properties."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.update_properties({"is_abortable": True})
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
|
||||
def test_is_abortable_default_none(self):
|
||||
"""Test is_abortable defaults to None for new tasks."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is None
|
||||
|
||||
|
||||
class TestTaskSetStatus:
|
||||
"""Test Task.set_status behavior for abort states."""
|
||||
|
||||
def test_set_status_in_progress_sets_is_abortable_false(self):
|
||||
"""Test that transitioning to IN_PROGRESS sets is_abortable to False."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
# Default is None
|
||||
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
|
||||
assert task.properties_dict.get("is_abortable") is False
|
||||
assert task.started_at is not None
|
||||
|
||||
def test_set_status_in_progress_preserves_existing_is_abortable(self):
|
||||
"""Test that re-setting IN_PROGRESS doesn't override is_abortable."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.update_properties(
|
||||
{"is_abortable": True}
|
||||
) # Already set by handler registration
|
||||
task.started_at = datetime.now(timezone.utc) # Already started
|
||||
|
||||
task.set_status(TaskStatus.IN_PROGRESS)
|
||||
|
||||
# Should not override since started_at is already set
|
||||
assert task.properties_dict.get("is_abortable") is True
|
||||
|
||||
def test_set_status_aborting_does_not_set_ended_at(self):
|
||||
"""Test that ABORTING status does not set ended_at."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.started_at = datetime.now(timezone.utc)
|
||||
|
||||
task.status = TaskStatus.ABORTING.value
|
||||
|
||||
assert task.ended_at is None
|
||||
|
||||
def test_set_status_aborted_sets_ended_at(self):
|
||||
"""Test that ABORTED status sets ended_at."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.uuid = "test-uuid"
|
||||
task.started_at = datetime.now(timezone.utc)
|
||||
|
||||
task.set_status(TaskStatus.ABORTED)
|
||||
|
||||
assert task.ended_at is not None
|
||||
|
||||
|
||||
class TestTaskDuration:
|
||||
"""Test Task duration_seconds property with different states."""
|
||||
|
||||
def test_duration_seconds_finished_task(self):
|
||||
"""Test duration for finished task returns actual duration."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.status = TaskStatus.SUCCESS.value # Must be finished to use ended_at
|
||||
task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.ended_at = datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
|
||||
|
||||
# Should use ended_at - started_at = 30 seconds
|
||||
assert task.duration_seconds == 30.0
|
||||
|
||||
@freeze_time("2024-01-01 10:00:30")
|
||||
def test_duration_seconds_running_task(self):
|
||||
"""Test duration for running task returns time since start."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.ended_at = None
|
||||
|
||||
# 30 seconds since start
|
||||
assert task.duration_seconds == 30.0
|
||||
|
||||
@freeze_time("2024-01-01 10:00:15")
|
||||
def test_duration_seconds_pending_task(self):
|
||||
"""Test duration for pending task returns queue time."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.started_at = None
|
||||
task.ended_at = None
|
||||
|
||||
# 15 seconds since creation
|
||||
assert task.duration_seconds == 15.0
|
||||
|
||||
def test_duration_seconds_no_timestamps(self):
|
||||
"""Test duration returns None when no timestamps available."""
|
||||
from superset.models.tasks import Task
|
||||
|
||||
task = Task()
|
||||
task.created_on = None
|
||||
task.started_at = None
|
||||
task.ended_at = None
|
||||
|
||||
assert task.duration_seconds is None
|
||||
|
||||
|
||||
class TestAbortHandlerRegistration:
|
||||
"""Test abort handler registration and is_abortable flag."""
|
||||
|
||||
def test_on_abort_registers_handler(self, task_context):
|
||||
"""Test that on_abort registers a handler."""
|
||||
handler_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
|
||||
assert len(task_context._abort_handlers) == 1
|
||||
assert not handler_called
|
||||
|
||||
@patch("superset.tasks.context.current_app")
|
||||
def test_on_abort_sets_abortable(self, mock_app):
|
||||
"""Test on_abort sets is_abortable to True on first handler."""
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
|
||||
patch.object(TaskContext, "start_abort_polling"),
|
||||
):
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
@ctx.on_abort
|
||||
def handler():
|
||||
pass
|
||||
|
||||
mock_set_abortable.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.context.current_app")
|
||||
def test_on_abort_only_sets_abortable_once(self, mock_app):
|
||||
"""Test on_abort only calls _set_abortable for first handler."""
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {"is_abortable": False}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
|
||||
patch.object(TaskContext, "start_abort_polling"),
|
||||
):
|
||||
ctx = TaskContext(mock_task)
|
||||
|
||||
@ctx.on_abort
|
||||
def handler1():
|
||||
pass
|
||||
|
||||
@ctx.on_abort
|
||||
def handler2():
|
||||
pass
|
||||
|
||||
# Should only be called once for first handler
|
||||
assert mock_set_abortable.call_count == 1
|
||||
|
||||
def test_abort_handlers_completed_initially_false(self):
|
||||
"""Test abort_handlers_completed is False initially."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = TEST_UUID
|
||||
mock_task.properties_dict = {}
|
||||
mock_task.payload_dict = {}
|
||||
|
||||
with patch("superset.tasks.context.current_app") as mock_app:
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
ctx = TaskContext(mock_task)
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
|
||||
class TestAbortPolling:
|
||||
"""Test abort detection polling behavior."""
|
||||
|
||||
def test_on_abort_starts_polling_automatically(self, task_context):
|
||||
"""Test that registering first handler starts abort listener."""
|
||||
assert task_context._abort_listener is None
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
def test_stop_abort_polling(self, task_context):
|
||||
"""Test that stop_abort_polling stops the abort listener."""
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
task_context.stop_abort_polling()
|
||||
|
||||
assert task_context._abort_listener is None
|
||||
|
||||
def test_start_abort_polling_only_once(self, task_context):
|
||||
"""Test that start_abort_polling is idempotent."""
|
||||
task_context.start_abort_polling(interval=0.1)
|
||||
first_listener = task_context._abort_listener
|
||||
|
||||
# Try to start again
|
||||
task_context.start_abort_polling(interval=0.1)
|
||||
second_listener = task_context._abort_listener
|
||||
|
||||
# Should be the same listener
|
||||
assert first_listener is second_listener
|
||||
|
||||
def test_on_abort_with_custom_interval(self, task_context):
|
||||
"""Test that custom interval can be set via start_abort_polling."""
|
||||
with patch("superset.tasks.context.current_app") as mock_app:
|
||||
mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
|
||||
mock_app._get_current_object = Mock(return_value=mock_app)
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Override with custom interval
|
||||
task_context.stop_abort_polling()
|
||||
task_context.start_abort_polling(interval=0.05)
|
||||
|
||||
assert task_context._abort_listener is not None
|
||||
|
||||
def test_polling_stops_after_abort_detected(self, task_context, mock_task):
|
||||
"""Test that abort is detected and handlers are triggered."""
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# Abort should have been detected
|
||||
assert task_context._abort_detected is True
|
||||
|
||||
|
||||
class TestAbortHandlerExecution:
|
||||
"""Test abort handler execution behavior."""
|
||||
|
||||
def test_on_abort_handler_fires_when_task_aborted(self, task_context, mock_task):
|
||||
"""Test that abort handler fires automatically when task is aborted."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Simulate task being aborted
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for polling to detect abort (max 0.3s with 0.1s interval)
|
||||
time.sleep(0.3)
|
||||
|
||||
assert abort_called
|
||||
assert task_context._abort_detected
|
||||
|
||||
def test_on_abort_not_called_on_success(self, task_context, mock_task):
|
||||
"""Test that abort handlers don't run on success."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Keep task in success state
|
||||
mock_task.status = TaskStatus.SUCCESS.value
|
||||
|
||||
# Wait and verify handler not called
|
||||
time.sleep(0.3)
|
||||
|
||||
assert not abort_called
|
||||
|
||||
def test_multiple_abort_handlers(self, task_context, mock_task):
|
||||
"""Test that all abort handlers execute in LIFO order."""
|
||||
calls = []
|
||||
|
||||
@task_context.on_abort
|
||||
def handler1():
|
||||
calls.append(1)
|
||||
|
||||
@task_context.on_abort
|
||||
def handler2():
|
||||
calls.append(2)
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# LIFO order: handler2 runs first
|
||||
assert calls == [2, 1]
|
||||
|
||||
def test_abort_handler_exception_doesnt_fail_task(self, task_context, mock_task):
|
||||
"""Test that exception in abort handler is logged but doesn't fail task."""
|
||||
handler2_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def bad_handler():
|
||||
raise ValueError("Handler error")
|
||||
|
||||
@task_context.on_abort
|
||||
def good_handler():
|
||||
nonlocal handler2_called
|
||||
handler2_called = True
|
||||
|
||||
# Trigger abort
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
|
||||
# Wait for detection
|
||||
time.sleep(0.3)
|
||||
|
||||
# Second handler should still run despite first handler failing
|
||||
assert handler2_called
|
||||
|
||||
|
||||
class TestBestEffortHandlerExecution:
|
||||
"""Test that all handlers execute even when some fail (best-effort)."""
|
||||
|
||||
def test_all_abort_handlers_run_even_if_all_fail(self, task_context, mock_task):
|
||||
"""Test all abort handlers execute even if every one raises an exception."""
|
||||
calls = []
|
||||
|
||||
@task_context.on_abort
|
||||
def handler1():
|
||||
calls.append(1)
|
||||
raise ValueError("Handler 1 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def handler2():
|
||||
calls.append(2)
|
||||
raise RuntimeError("Handler 2 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def handler3():
|
||||
calls.append(3)
|
||||
raise TypeError("Handler 3 failed")
|
||||
|
||||
# Trigger abort handlers directly (simulating abort detection)
|
||||
task_context._trigger_abort_handlers()
|
||||
|
||||
# All handlers should have been called (LIFO order: 3, 2, 1)
|
||||
assert calls == [3, 2, 1]
|
||||
|
||||
# Failures should be collected (abort handlers don't write to DB)
|
||||
assert len(task_context._handler_failures) == 3
|
||||
failure_types = [
|
||||
type(ex).__name__ for _, ex, _ in task_context._handler_failures
|
||||
]
|
||||
assert "TypeError" in failure_types
|
||||
assert "RuntimeError" in failure_types
|
||||
assert "ValueError" in failure_types
|
||||
|
||||
def test_all_cleanup_handlers_run_even_if_all_fail(self, task_context, mock_task):
|
||||
"""Test all cleanup handlers execute even if every one raises an exception."""
|
||||
calls = []
|
||||
captured_failures = []
|
||||
|
||||
# Mock _write_handler_failures_to_db to capture failures before clearing
|
||||
original_write = task_context._write_handler_failures_to_db
|
||||
|
||||
def mock_write():
|
||||
captured_failures.extend(task_context._handler_failures)
|
||||
original_write()
|
||||
|
||||
task_context._write_handler_failures_to_db = mock_write
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup1():
|
||||
calls.append(1)
|
||||
raise ValueError("Cleanup 1 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup2():
|
||||
calls.append(2)
|
||||
raise RuntimeError("Cleanup 2 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup3():
|
||||
calls.append(3)
|
||||
raise TypeError("Cleanup 3 failed")
|
||||
|
||||
# Set task to SUCCESS (not aborting) so only cleanup handlers run
|
||||
mock_task.status = TaskStatus.SUCCESS.value
|
||||
|
||||
# Run cleanup
|
||||
task_context._run_cleanup()
|
||||
|
||||
# All handlers should have been called (LIFO order: 3, 2, 1)
|
||||
assert calls == [3, 2, 1]
|
||||
|
||||
# Failures should have been captured before clearing
|
||||
assert len(captured_failures) == 3
|
||||
failure_types = [type(ex).__name__ for _, ex, _ in captured_failures]
|
||||
assert "TypeError" in failure_types
|
||||
assert "RuntimeError" in failure_types
|
||||
assert "ValueError" in failure_types
|
||||
|
||||
def test_mixed_abort_and_cleanup_failures_all_collected(
|
||||
self, task_context, mock_task
|
||||
):
|
||||
"""Test abort and cleanup handler failures are collected together."""
|
||||
calls = []
|
||||
captured_failures = []
|
||||
|
||||
# Mock _write_handler_failures_to_db to capture failures before clearing
|
||||
original_write = task_context._write_handler_failures_to_db
|
||||
|
||||
def mock_write():
|
||||
captured_failures.extend(task_context._handler_failures)
|
||||
original_write()
|
||||
|
||||
task_context._write_handler_failures_to_db = mock_write
|
||||
|
||||
@task_context.on_abort
|
||||
def abort1():
|
||||
calls.append("abort1")
|
||||
raise ValueError("Abort 1 failed")
|
||||
|
||||
@task_context.on_abort
|
||||
def abort2():
|
||||
calls.append("abort2")
|
||||
raise RuntimeError("Abort 2 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup1():
|
||||
calls.append("cleanup1")
|
||||
raise TypeError("Cleanup 1 failed")
|
||||
|
||||
@task_context.on_cleanup
|
||||
def cleanup2():
|
||||
calls.append("cleanup2")
|
||||
raise KeyError("Cleanup 2 failed")
|
||||
|
||||
# Set task to ABORTING so both abort and cleanup handlers run
|
||||
mock_task.status = TaskStatus.ABORTING.value
|
||||
|
||||
# Run cleanup (which triggers abort handlers first, then cleanup handlers)
|
||||
task_context._run_cleanup()
|
||||
|
||||
# All handlers should have been called
|
||||
# Abort handlers run first (LIFO: abort2, abort1)
|
||||
# Then cleanup handlers (LIFO: cleanup2, cleanup1)
|
||||
assert calls == ["abort2", "abort1", "cleanup2", "cleanup1"]
|
||||
|
||||
# All 4 failures should have been captured
|
||||
assert len(captured_failures) == 4
|
||||
|
||||
# Verify handler types are recorded correctly
|
||||
handler_types = [htype for htype, _, _ in captured_failures]
|
||||
assert handler_types.count("abort") == 2
|
||||
assert handler_types.count("cleanup") == 2
|
||||
|
||||
|
||||
class TestCleanupHandlers:
|
||||
"""Test cleanup handler behavior."""
|
||||
|
||||
def test_cleanup_triggers_abort_handlers_if_not_detected(
|
||||
self, task_context, mock_task
|
||||
):
|
||||
"""Test that _run_cleanup triggers abort handlers if task ended aborted."""
|
||||
abort_called = False
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Set task as aborted but don't let polling detect it
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
task_context._abort_detected = False
|
||||
|
||||
# Immediately run cleanup (simulating task ending before poll)
|
||||
task_context._run_cleanup()
|
||||
|
||||
assert abort_called
|
||||
|
||||
def test_cleanup_doesnt_duplicate_abort_handlers(self, task_context, mock_task):
|
||||
"""Test that abort handlers only run once even if called from cleanup."""
|
||||
call_count = 0
|
||||
|
||||
@task_context.on_abort
|
||||
def handle_abort():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
# Trigger abort via polling
|
||||
mock_task.status = TaskStatus.ABORTED.value
|
||||
time.sleep(0.3)
|
||||
|
||||
# Handlers should have been called once
|
||||
assert call_count == 1
|
||||
assert task_context._abort_detected is True
|
||||
|
||||
# Run cleanup - handlers should NOT be called again
|
||||
task_context._run_cleanup()
|
||||
|
||||
assert call_count == 1 # Still 1, not 2
|
||||
462
tests/unit_tests/tasks/test_manager.py
Normal file
462
tests/unit_tests/tasks/test_manager.py
Normal file
@@ -0,0 +1,462 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for TaskManager pub/sub functionality"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import redis
|
||||
|
||||
from superset.tasks.manager import AbortListener, TaskManager
|
||||
|
||||
|
||||
class TestAbortListener:
|
||||
"""Tests for AbortListener class"""
|
||||
|
||||
def test_stop_sets_event(self):
|
||||
"""Test that stop() sets the stop event"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = False
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
listener.stop()
|
||||
assert stop_event.is_set()
|
||||
|
||||
def test_stop_closes_pubsub(self):
|
||||
"""Test that stop() closes the pub/sub connection"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = False
|
||||
pubsub = MagicMock()
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event, pubsub)
|
||||
listener.stop()
|
||||
|
||||
pubsub.unsubscribe.assert_called_once()
|
||||
pubsub.close.assert_called_once()
|
||||
|
||||
def test_stop_joins_thread(self):
|
||||
"""Test that stop() joins the listener thread"""
|
||||
stop_event = threading.Event()
|
||||
thread = MagicMock(spec=threading.Thread)
|
||||
thread.is_alive.return_value = True
|
||||
|
||||
listener = AbortListener("test-uuid", thread, stop_event)
|
||||
listener.stop()
|
||||
|
||||
thread.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
|
||||
class TestTaskManagerInitApp:
|
||||
"""Tests for TaskManager.init_app()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def test_init_app_sets_channel_prefixes(self):
|
||||
"""Test init_app reads channel prefixes from config"""
|
||||
app = MagicMock()
|
||||
app.config.get.side_effect = lambda key, default=None: {
|
||||
"TASKS_ABORT_CHANNEL_PREFIX": "custom:abort:",
|
||||
"TASKS_COMPLETION_CHANNEL_PREFIX": "custom:complete:",
|
||||
}.get(key, default)
|
||||
|
||||
TaskManager.init_app(app)
|
||||
|
||||
assert TaskManager._initialized is True
|
||||
assert TaskManager._channel_prefix == "custom:abort:"
|
||||
assert TaskManager._completion_channel_prefix == "custom:complete:"
|
||||
|
||||
def test_init_app_skips_if_already_initialized(self):
|
||||
"""Test init_app is idempotent"""
|
||||
TaskManager._initialized = True
|
||||
|
||||
app = MagicMock()
|
||||
TaskManager.init_app(app)
|
||||
|
||||
# Should not call app.config.get since already initialized
|
||||
app.config.get.assert_not_called()
|
||||
|
||||
|
||||
class TestTaskManagerPubSub:
|
||||
"""Tests for TaskManager pub/sub methods"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_is_pubsub_available_no_redis(self, mock_cache_manager):
|
||||
"""Test is_pubsub_available returns False when Redis not configured"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
assert TaskManager.is_pubsub_available() is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_is_pubsub_available_with_redis(self, mock_cache_manager):
|
||||
"""Test is_pubsub_available returns True when Redis is configured"""
|
||||
mock_cache_manager.signal_cache = MagicMock()
|
||||
assert TaskManager.is_pubsub_available() is True
|
||||
|
||||
def test_get_abort_channel(self):
|
||||
"""Test get_abort_channel returns correct channel name"""
|
||||
task_uuid = "abc-123-def-456"
|
||||
channel = TaskManager.get_abort_channel(task_uuid)
|
||||
assert channel == "gtf:abort:abc-123-def-456"
|
||||
|
||||
def test_get_abort_channel_custom_prefix(self):
|
||||
"""Test get_abort_channel with custom prefix"""
|
||||
TaskManager._channel_prefix = "custom:prefix:"
|
||||
task_uuid = "test-uuid"
|
||||
channel = TaskManager.get_abort_channel(task_uuid)
|
||||
assert channel == "custom:prefix:test-uuid"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_no_redis(self, mock_cache_manager):
|
||||
"""Test publish_abort returns False when Redis not available"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_success(self, mock_cache_manager):
|
||||
"""Test publish_abort publishes message successfully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.return_value = 1 # One subscriber
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
|
||||
assert result is True
|
||||
mock_redis.publish.assert_called_once_with("gtf:abort:test-uuid", "abort")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_abort_redis_error(self, mock_cache_manager):
|
||||
"""Test publish_abort handles Redis errors gracefully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_abort("test-uuid")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestTaskManagerListenForAbort:
|
||||
"""Tests for TaskManager.listen_for_abort()"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_no_redis_uses_polling(self, mock_cache_manager):
|
||||
"""Test listen_for_abort falls back to polling when Redis unavailable"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
callback = MagicMock()
|
||||
|
||||
with patch.object(TaskManager, "_poll_for_abort", return_value=None):
|
||||
listener = TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
# Give thread time to start
|
||||
time.sleep(0.1)
|
||||
listener.stop()
|
||||
|
||||
# Should use polling since no Redis
|
||||
assert listener._pubsub is None
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_with_redis_uses_pubsub(self, mock_cache_manager):
|
||||
"""Test listen_for_abort uses pub/sub when Redis available"""
|
||||
mock_redis = MagicMock()
|
||||
mock_pubsub = MagicMock()
|
||||
mock_redis.pubsub.return_value = mock_pubsub
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
callback = MagicMock()
|
||||
|
||||
with patch.object(TaskManager, "_listen_pubsub", return_value=None):
|
||||
listener = TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
# Give thread time to start
|
||||
time.sleep(0.1)
|
||||
listener.stop()
|
||||
|
||||
# Should subscribe to channel
|
||||
mock_pubsub.subscribe.assert_called_once_with("gtf:abort:test-uuid")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_listen_for_abort_redis_subscribe_failure_raises(self, mock_cache_manager):
|
||||
"""Test listen_for_abort raises exception on subscribe failure
|
||||
when Redis configured"""
|
||||
import pytest
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
callback = MagicMock()
|
||||
|
||||
# With fail-fast behavior, Redis subscribe failure raises exception
|
||||
with pytest.raises(redis.RedisError, match="Connection failed"):
|
||||
TaskManager.listen_for_abort(
|
||||
task_uuid="test-uuid",
|
||||
callback=callback,
|
||||
poll_interval=1.0,
|
||||
app=None,
|
||||
)
|
||||
|
||||
|
||||
class TestTaskManagerCompletion:
|
||||
"""Tests for TaskManager completion pub/sub and wait_for_completion"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset TaskManager state before each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset TaskManager state after each test"""
|
||||
TaskManager._initialized = False
|
||||
TaskManager._channel_prefix = "gtf:abort:"
|
||||
TaskManager._completion_channel_prefix = "gtf:complete:"
|
||||
|
||||
def test_get_completion_channel(self):
|
||||
"""Test get_completion_channel returns correct channel name"""
|
||||
task_uuid = "abc-123-def-456"
|
||||
channel = TaskManager.get_completion_channel(task_uuid)
|
||||
assert channel == "gtf:complete:abc-123-def-456"
|
||||
|
||||
def test_get_completion_channel_custom_prefix(self):
|
||||
"""Test get_completion_channel with custom prefix"""
|
||||
TaskManager._completion_channel_prefix = "custom:complete:"
|
||||
task_uuid = "test-uuid"
|
||||
channel = TaskManager.get_completion_channel(task_uuid)
|
||||
assert channel == "custom:complete:test-uuid"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_no_redis(self, mock_cache_manager):
|
||||
"""Test publish_completion returns False when Redis not available"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_success(self, mock_cache_manager):
|
||||
"""Test publish_completion publishes message successfully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.return_value = 1 # One subscriber
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
|
||||
assert result is True
|
||||
mock_redis.publish.assert_called_once_with("gtf:complete:test-uuid", "success")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
def test_publish_completion_redis_error(self, mock_cache_manager):
|
||||
"""Test publish_completion handles Redis errors gracefully"""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.side_effect = redis.RedisError("Connection lost")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.publish_completion("test-uuid", "success")
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_task_not_found(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion raises ValueError for missing task"""
|
||||
import pytest
|
||||
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_dao.find_one_or_none.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
TaskManager.wait_for_completion("nonexistent-uuid")
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_already_complete(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion returns immediately for terminal state"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = "test-uuid"
|
||||
mock_task.status = "success"
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
|
||||
result = TaskManager.wait_for_completion("test-uuid")
|
||||
|
||||
assert result == mock_task
|
||||
# Should only call find_one_or_none once (initial check)
|
||||
mock_dao.find_one_or_none.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_timeout(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion raises TimeoutError when timeout expires"""
|
||||
import pytest
|
||||
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task = MagicMock()
|
||||
mock_task.uuid = "test-uuid"
|
||||
mock_task.status = "in_progress" # Never completes
|
||||
mock_dao.find_one_or_none.return_value = mock_task
|
||||
|
||||
with pytest.raises(TimeoutError, match="Timeout waiting"):
|
||||
TaskManager.wait_for_completion("test-uuid", timeout=0.1)
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_polling_success(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion returns when task completes via polling"""
|
||||
mock_cache_manager.signal_cache = None
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_task_complete = MagicMock()
|
||||
mock_task_complete.uuid = "test-uuid"
|
||||
mock_task_complete.status = "success"
|
||||
|
||||
# First call returns pending, second returns complete
|
||||
mock_dao.find_one_or_none.side_effect = [
|
||||
mock_task_pending,
|
||||
mock_task_complete,
|
||||
]
|
||||
|
||||
result = TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
poll_interval=0.1,
|
||||
)
|
||||
|
||||
assert result.status == "success"
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_with_pubsub(self, mock_dao, mock_cache_manager):
|
||||
"""Test wait_for_completion uses pub/sub when Redis available"""
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_task_complete = MagicMock()
|
||||
mock_task_complete.uuid = "test-uuid"
|
||||
mock_task_complete.status = "success"
|
||||
|
||||
# First call returns pending, second returns complete
|
||||
mock_dao.find_one_or_none.side_effect = [
|
||||
mock_task_pending,
|
||||
mock_task_complete,
|
||||
]
|
||||
|
||||
# Set up mock Redis with pub/sub
|
||||
mock_redis = MagicMock()
|
||||
mock_pubsub = MagicMock()
|
||||
# Simulate receiving a completion message
|
||||
mock_pubsub.get_message.return_value = {
|
||||
"type": "message",
|
||||
"data": "success",
|
||||
}
|
||||
mock_redis.pubsub.return_value = mock_pubsub
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
result = TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
assert result.status == "success"
|
||||
# Should have subscribed to completion channel
|
||||
mock_pubsub.subscribe.assert_called_once_with("gtf:complete:test-uuid")
|
||||
# Should have cleaned up
|
||||
mock_pubsub.unsubscribe.assert_called_once()
|
||||
mock_pubsub.close.assert_called_once()
|
||||
|
||||
@patch("superset.tasks.manager.cache_manager")
|
||||
@patch("superset.daos.tasks.TaskDAO")
|
||||
def test_wait_for_completion_pubsub_error_raises(
|
||||
self, mock_dao, mock_cache_manager
|
||||
):
|
||||
"""Test wait_for_completion raises exception on Redis error when
|
||||
Redis configured"""
|
||||
import pytest
|
||||
|
||||
mock_task_pending = MagicMock()
|
||||
mock_task_pending.uuid = "test-uuid"
|
||||
mock_task_pending.status = "pending"
|
||||
|
||||
mock_dao.find_one_or_none.return_value = mock_task_pending
|
||||
|
||||
# Set up mock Redis that fails
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
|
||||
mock_cache_manager.signal_cache = mock_redis
|
||||
|
||||
# With fail-fast behavior, Redis error is raised instead of falling back
|
||||
with pytest.raises(redis.RedisError, match="Connection failed"):
|
||||
TaskManager.wait_for_completion(
|
||||
"test-uuid",
|
||||
timeout=5.0,
|
||||
poll_interval=0.1,
|
||||
)
|
||||
|
||||
def test_terminal_states_constant(self):
|
||||
"""Test TERMINAL_STATES contains expected values"""
|
||||
expected = {"success", "failure", "aborted", "timed_out"}
|
||||
assert TaskManager.TERMINAL_STATES == expected
|
||||
612
tests/unit_tests/tasks/test_timeout.py
Normal file
612
tests/unit_tests/tasks/test_timeout.py
Normal file
@@ -0,0 +1,612 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for GTF timeout handling."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from superset_core.api.tasks import TaskOptions, TaskScope
|
||||
|
||||
from superset.tasks.context import TaskContext
|
||||
from superset.tasks.decorators import TaskWrapper
|
||||
|
||||
TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flask_app():
|
||||
"""Create a properly configured mock Flask app."""
|
||||
mock_app = MagicMock()
|
||||
mock_app.config = {
|
||||
"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
|
||||
}
|
||||
# Make app_context() return a proper context manager
|
||||
mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
|
||||
return mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_abortable():
|
||||
"""Create a mock task that is abortable."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = "in_progress"
|
||||
task.properties_dict = {"is_abortable": True}
|
||||
task.payload_dict = {}
|
||||
# Set real values for dedup_key generation (used by UpdateTaskCommand lock)
|
||||
task.scope = "shared"
|
||||
task.task_type = "test_task"
|
||||
task.task_key = "test_key"
|
||||
task.user_id = 1
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_not_abortable():
|
||||
"""Create a mock task that is NOT abortable."""
|
||||
task = MagicMock()
|
||||
task.uuid = TEST_UUID
|
||||
task.status = "in_progress"
|
||||
task.properties_dict = {} # No is_abortable means it's not abortable
|
||||
task.payload_dict = {}
|
||||
# Set real values for dedup_key generation (used by UpdateTaskCommand lock)
|
||||
task.scope = "shared"
|
||||
task.task_type = "test_task"
|
||||
task.task_key = "test_key"
|
||||
task.user_id = 1
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_context_for_timeout(mock_flask_app, mock_task_abortable):
|
||||
"""Create TaskContext with mocked dependencies for timeout tests."""
|
||||
# Ensure mock_task has required attributes for TaskContext
|
||||
mock_task_abortable.payload_dict = {}
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
# Configure current_app mock
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
|
||||
# Configure TaskDAO mock
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
yield ctx
|
||||
|
||||
# Cleanup: stop timers if started
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TaskWrapper._merge_options Timeout Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutMerging:
|
||||
"""Test timeout merging behavior in TaskWrapper._merge_options."""
|
||||
|
||||
def test_merge_options_decorator_timeout_used_when_no_override(self):
|
||||
"""Test that decorator timeout is used when no override is provided."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300, # 5-minute default
|
||||
)
|
||||
|
||||
merged = wrapper._merge_options(None)
|
||||
assert merged.timeout == 300
|
||||
|
||||
def test_merge_options_override_timeout_takes_precedence(self):
|
||||
"""Test that TaskOptions timeout overrides decorator default."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300, # 5-minute default
|
||||
)
|
||||
|
||||
override = TaskOptions(timeout=600) # 10-minute override
|
||||
merged = wrapper._merge_options(override)
|
||||
assert merged.timeout == 600
|
||||
|
||||
def test_merge_options_no_timeout_when_not_configured(self):
|
||||
"""Test that no timeout is set when not configured anywhere."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=None, # No default timeout
|
||||
)
|
||||
|
||||
merged = wrapper._merge_options(None)
|
||||
assert merged.timeout is None
|
||||
|
||||
def test_merge_options_override_with_other_options_preserves_timeout(self):
|
||||
"""Test that setting other options doesn't lose decorator timeout."""
|
||||
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
name="test_task",
|
||||
func=dummy_func,
|
||||
default_options=TaskOptions(),
|
||||
scope=TaskScope.PRIVATE,
|
||||
default_timeout=300,
|
||||
)
|
||||
|
||||
# Override only task_key, not timeout
|
||||
override = TaskOptions(task_key="my-key")
|
||||
merged = wrapper._merge_options(override)
|
||||
|
||||
# Should keep decorator timeout since override.timeout is None
|
||||
assert merged.timeout == 300
|
||||
assert merged.task_key == "my-key"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TaskContext Timeout Timer Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutTimer:
|
||||
"""Test TaskContext timeout timer behavior."""
|
||||
|
||||
def test_start_timeout_timer_sets_timer(self, task_context_for_timeout):
|
||||
"""Test that start_timeout_timer creates a timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
|
||||
assert ctx._timeout_timer is not None
|
||||
assert ctx._timeout_triggered is False
|
||||
|
||||
def test_start_timeout_timer_is_idempotent(self, task_context_for_timeout):
|
||||
"""Test that starting timer twice doesn't create duplicate timers."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
first_timer = ctx._timeout_timer
|
||||
|
||||
ctx.start_timeout_timer(20) # Try to start again
|
||||
second_timer = ctx._timeout_timer
|
||||
|
||||
assert first_timer is second_timer
|
||||
|
||||
def test_stop_timeout_timer_cancels_timer(self, task_context_for_timeout):
|
||||
"""Test that stop_timeout_timer cancels the timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
assert ctx._timeout_timer is not None
|
||||
|
||||
ctx.stop_timeout_timer()
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
def test_stop_timeout_timer_safe_when_no_timer(self, task_context_for_timeout):
|
||||
"""Test that stop_timeout_timer doesn't fail when no timer exists."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
ctx.stop_timeout_timer() # Should not raise
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
def test_timeout_triggered_property_initially_false(self, task_context_for_timeout):
|
||||
"""Test that timeout_triggered is False initially."""
|
||||
ctx = task_context_for_timeout
|
||||
assert ctx.timeout_triggered is False
|
||||
|
||||
def test_cleanup_stops_timeout_timer(self, task_context_for_timeout):
|
||||
"""Test that _run_cleanup stops the timeout timer."""
|
||||
ctx = task_context_for_timeout
|
||||
|
||||
ctx.start_timeout_timer(10)
|
||||
assert ctx._timeout_timer is not None
|
||||
|
||||
ctx._run_cleanup()
|
||||
|
||||
assert ctx._timeout_timer is None
|
||||
|
||||
|
||||
class TestTimeoutTrigger:
|
||||
"""Test timeout trigger behavior when timer fires."""
|
||||
|
||||
def test_timeout_triggers_abort_when_abortable(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout triggers abort handlers when task is abortable."""
|
||||
abort_called = False
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch(
|
||||
"superset.commands.tasks.update.UpdateTaskCommand"
|
||||
) as mock_update_cmd,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_called
|
||||
abort_called = True
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Abort handler should have been called
|
||||
assert abort_called
|
||||
assert ctx._timeout_triggered
|
||||
assert ctx._abort_detected
|
||||
|
||||
# Verify UpdateTaskCommand was called with ABORTING status
|
||||
mock_update_cmd.assert_called()
|
||||
call_kwargs = mock_update_cmd.call_args[1]
|
||||
assert call_kwargs.get("status") == "aborting"
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_timeout_logs_warning_when_not_abortable(
|
||||
self, mock_flask_app, mock_task_not_abortable
|
||||
):
|
||||
"""Test that timeout logs warning when task has no abort handler."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.tasks.context.logger") as mock_logger,
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_not_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_not_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
# No abort handler registered
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Should have logged warning
|
||||
mock_logger.warning.assert_called()
|
||||
warning_call = mock_logger.warning.call_args
|
||||
assert "no abort handler" in warning_call[0][0].lower()
|
||||
assert ctx._timeout_triggered
|
||||
assert not ctx._abort_detected # No abort since no handler
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
|
||||
def test_timeout_does_not_trigger_if_already_aborting(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout doesn't re-trigger abort if already aborting."""
|
||||
abort_count = 0
|
||||
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
nonlocal abort_count
|
||||
abort_count += 1
|
||||
|
||||
# Pre-set abort detected
|
||||
ctx._abort_detected = True
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Handler should NOT have been called since already aborting
|
||||
assert abort_count == 0
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Task Decorator Timeout Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTaskDecoratorTimeout:
|
||||
"""Test @task decorator timeout parameter."""
|
||||
|
||||
def test_task_decorator_accepts_timeout(self):
|
||||
"""Test that @task decorator accepts timeout parameter."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_1", timeout=300)
|
||||
def timeout_test_task_1():
|
||||
pass
|
||||
|
||||
assert isinstance(timeout_test_task_1, TaskWrapper)
|
||||
assert timeout_test_task_1.default_timeout == 300
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_1", None)
|
||||
|
||||
def test_task_decorator_without_timeout(self):
|
||||
"""Test that @task decorator works without timeout."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_2")
|
||||
def timeout_test_task_2():
|
||||
pass
|
||||
|
||||
assert isinstance(timeout_test_task_2, TaskWrapper)
|
||||
assert timeout_test_task_2.default_timeout is None
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_2", None)
|
||||
|
||||
def test_task_decorator_with_all_params(self):
|
||||
"""Test that @task decorator accepts all parameters together."""
|
||||
from superset.tasks.decorators import task
|
||||
from superset.tasks.registry import TaskRegistry
|
||||
|
||||
@task(name="test_timeout_task_3", scope=TaskScope.SHARED, timeout=600)
|
||||
def timeout_test_task_3():
|
||||
pass
|
||||
|
||||
assert timeout_test_task_3.name == "test_timeout_task_3"
|
||||
assert timeout_test_task_3.scope == TaskScope.SHARED
|
||||
assert timeout_test_task_3.default_timeout == 600
|
||||
|
||||
# Cleanup registry
|
||||
TaskRegistry._tasks.pop("test_timeout_task_3", None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Timeout Terminal State Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTimeoutTerminalState:
|
||||
"""Test timeout transitions to correct terminal state (TIMED_OUT vs FAILURE)."""
|
||||
|
||||
def test_timeout_triggered_flag_set_on_timeout(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that timeout_triggered flag is set when timeout fires."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Initially not triggered
|
||||
assert ctx.timeout_triggered is False
|
||||
|
||||
# Start short timeout
|
||||
ctx.start_timeout_timer(1)
|
||||
|
||||
# Wait for timeout to fire
|
||||
time.sleep(1.5)
|
||||
|
||||
# Should be set after timeout
|
||||
assert ctx.timeout_triggered is True
|
||||
|
||||
# Cleanup
|
||||
ctx.stop_timeout_timer()
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_user_abort_does_not_set_timeout_triggered(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that user abort doesn't set timeout_triggered flag."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass
|
||||
|
||||
# Simulate user abort (not timeout)
|
||||
ctx._on_abort_detected()
|
||||
|
||||
# timeout_triggered should still be False
|
||||
assert ctx.timeout_triggered is False
|
||||
# But abort_detected should be True
|
||||
assert ctx._abort_detected is True
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_abort_handlers_completed_tracks_success(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that abort_handlers_completed flag tracks successful
|
||||
handler execution."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
pass # Successful handler
|
||||
|
||||
# Initially not completed
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Trigger abort handlers
|
||||
ctx._trigger_abort_handlers()
|
||||
|
||||
# Should be marked as completed
|
||||
assert ctx.abort_handlers_completed is True
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
|
||||
def test_abort_handlers_completed_false_on_exception(
|
||||
self, mock_flask_app, mock_task_abortable
|
||||
):
|
||||
"""Test that abort_handlers_completed is False when handler throws."""
|
||||
with (
|
||||
patch("superset.tasks.context.current_app") as mock_current_app,
|
||||
patch("superset.daos.tasks.TaskDAO") as mock_dao,
|
||||
patch("superset.commands.tasks.update.UpdateTaskCommand"),
|
||||
patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
|
||||
):
|
||||
# Disable Redis by making signal_cache return None
|
||||
mock_cache_manager.signal_cache = None
|
||||
|
||||
mock_current_app.config = mock_flask_app.config
|
||||
mock_current_app._get_current_object.return_value = mock_flask_app
|
||||
mock_dao.find_one_or_none.return_value = mock_task_abortable
|
||||
|
||||
ctx = TaskContext(mock_task_abortable)
|
||||
ctx._app = mock_flask_app
|
||||
|
||||
@ctx.on_abort
|
||||
def handle_abort():
|
||||
raise ValueError("Handler failed")
|
||||
|
||||
# Initially not completed
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Trigger abort handlers (will catch the exception internally)
|
||||
ctx._trigger_abort_handlers()
|
||||
|
||||
# Should NOT be marked as completed since handler threw
|
||||
assert ctx.abort_handlers_completed is False
|
||||
|
||||
# Cleanup
|
||||
if ctx._abort_listener:
|
||||
ctx.stop_abort_polling()
|
||||
@@ -22,9 +22,19 @@ from typing import Any, Optional, Union
|
||||
|
||||
import pytest
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from superset_core.api.tasks import TaskScope
|
||||
|
||||
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
|
||||
from superset.tasks.types import Executor, ExecutorType, FixedExecutor
|
||||
from superset.tasks.utils import (
|
||||
error_update,
|
||||
get_active_dedup_key,
|
||||
get_finished_dedup_key,
|
||||
parse_properties,
|
||||
progress_update,
|
||||
serialize_properties,
|
||||
)
|
||||
from superset.utils.hashing import hash_from_str
|
||||
|
||||
FIXED_USER_ID = 1234
|
||||
FIXED_USERNAME = "admin"
|
||||
@@ -330,3 +340,242 @@ def test_get_executor(
|
||||
)
|
||||
assert executor_type == expected_executor_type
|
||||
assert executor == expected_executor
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scope,task_type,task_key,user_id,expected_composite_key",
|
||||
[
|
||||
# Private tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.PRIVATE,
|
||||
"sql_execution",
|
||||
"chart_123",
|
||||
42,
|
||||
"private|sql_execution|chart_123|42",
|
||||
),
|
||||
(
|
||||
TaskScope.PRIVATE,
|
||||
"thumbnail_gen",
|
||||
"dash_456",
|
||||
100,
|
||||
"private|thumbnail_gen|dash_456|100",
|
||||
),
|
||||
# Private tasks with string scope
|
||||
(
|
||||
"private",
|
||||
"api_call",
|
||||
"endpoint_789",
|
||||
200,
|
||||
"private|api_call|endpoint_789|200",
|
||||
),
|
||||
# Shared tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.SHARED,
|
||||
"report_gen",
|
||||
"monthly_report",
|
||||
None,
|
||||
"shared|report_gen|monthly_report",
|
||||
),
|
||||
(
|
||||
TaskScope.SHARED,
|
||||
"export_csv",
|
||||
"large_export",
|
||||
999, # user_id should be ignored for shared
|
||||
"shared|export_csv|large_export",
|
||||
),
|
||||
# Shared tasks with string scope
|
||||
(
|
||||
"shared",
|
||||
"batch_process",
|
||||
"batch_001",
|
||||
123, # user_id should be ignored for shared
|
||||
"shared|batch_process|batch_001",
|
||||
),
|
||||
# System tasks with TaskScope enum
|
||||
(
|
||||
TaskScope.SYSTEM,
|
||||
"cleanup_task",
|
||||
"daily_cleanup",
|
||||
None,
|
||||
"system|cleanup_task|daily_cleanup",
|
||||
),
|
||||
(
|
||||
TaskScope.SYSTEM,
|
||||
"db_migration",
|
||||
"version_123",
|
||||
1, # user_id should be ignored for system
|
||||
"system|db_migration|version_123",
|
||||
),
|
||||
# System tasks with string scope
|
||||
(
|
||||
"system",
|
||||
"maintenance",
|
||||
"nightly_job",
|
||||
2, # user_id should be ignored for system
|
||||
"system|maintenance|nightly_job",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_active_dedup_key(
|
||||
scope, task_type, task_key, user_id, expected_composite_key, app_context
|
||||
):
|
||||
"""Test get_active_dedup_key generates a hash of the composite key.
|
||||
|
||||
The function hashes the composite key using the configured HASH_ALGORITHM
|
||||
to produce a fixed-length dedup_key for database storage. The result is
|
||||
truncated to 64 chars to fit the database column.
|
||||
"""
|
||||
result = get_active_dedup_key(scope, task_type, task_key, user_id)
|
||||
|
||||
# The result should be a hash of the expected composite key, truncated to 64 chars
|
||||
expected_hash = hash_from_str(expected_composite_key)[:64]
|
||||
assert result == expected_hash
|
||||
assert len(result) <= 64
|
||||
|
||||
|
||||
def test_get_active_dedup_key_private_requires_user_id():
|
||||
"""Test that private tasks require explicit user_id parameter."""
|
||||
with pytest.raises(ValueError, match="user_id required for private tasks"):
|
||||
get_active_dedup_key(TaskScope.PRIVATE, "test_type", "test_key")
|
||||
|
||||
|
||||
def test_get_finished_dedup_key():
|
||||
"""Test that finished tasks use UUID as dedup_key"""
|
||||
test_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
result = get_finished_dedup_key(test_uuid)
|
||||
assert result == test_uuid
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"progress,expected",
|
||||
[
|
||||
# Float (percentage) progress
|
||||
(0.5, {"progress_percent": 0.5}),
|
||||
(0.0, {"progress_percent": 0.0}),
|
||||
(1.0, {"progress_percent": 1.0}),
|
||||
(0.25, {"progress_percent": 0.25}),
|
||||
# Int (count only) progress
|
||||
(42, {"progress_current": 42}),
|
||||
(0, {"progress_current": 0}),
|
||||
(1000, {"progress_current": 1000}),
|
||||
# Tuple (current, total) progress with auto-computed percentage
|
||||
(
|
||||
(50, 100),
|
||||
{"progress_current": 50, "progress_total": 100, "progress_percent": 0.5},
|
||||
),
|
||||
(
|
||||
(25, 100),
|
||||
{"progress_current": 25, "progress_total": 100, "progress_percent": 0.25},
|
||||
),
|
||||
(
|
||||
(100, 100),
|
||||
{"progress_current": 100, "progress_total": 100, "progress_percent": 1.0},
|
||||
),
|
||||
# Tuple with zero total (no percentage computed)
|
||||
((10, 0), {"progress_current": 10, "progress_total": 0}),
|
||||
((0, 0), {"progress_current": 0, "progress_total": 0}),
|
||||
],
|
||||
)
|
||||
def test_progress_update(progress, expected):
|
||||
"""Test progress_update returns correct TaskProperties dict."""
|
||||
result = progress_update(progress)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_error_update():
|
||||
"""Test error_update captures exception details."""
|
||||
try:
|
||||
raise ValueError("Test error message")
|
||||
except ValueError as e:
|
||||
result = error_update(e)
|
||||
|
||||
assert result["error_message"] == "Test error message"
|
||||
assert result["exception_type"] == "ValueError"
|
||||
assert "stack_trace" in result
|
||||
assert "ValueError" in result["stack_trace"]
|
||||
|
||||
|
||||
def test_error_update_custom_exception():
|
||||
"""Test error_update with custom exception class."""
|
||||
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
raise CustomError("Custom error")
|
||||
except CustomError as e:
|
||||
result = error_update(e)
|
||||
|
||||
assert result["error_message"] == "Custom error"
|
||||
assert result["exception_type"] == "CustomError"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"json_str,expected",
|
||||
[
|
||||
# Valid JSON
|
||||
(
|
||||
'{"is_abortable": true, "progress_percent": 0.5}',
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
),
|
||||
(
|
||||
'{"error_message": "Something failed"}',
|
||||
{"error_message": "Something failed"},
|
||||
),
|
||||
(
|
||||
'{"progress_current": 50, "progress_total": 100}',
|
||||
{"progress_current": 50, "progress_total": 100},
|
||||
),
|
||||
# Empty/None cases
|
||||
("", {}),
|
||||
(None, {}),
|
||||
# Invalid JSON returns empty dict
|
||||
("not valid json", {}),
|
||||
("{broken", {}),
|
||||
# Unknown keys are preserved (forward compatibility)
|
||||
(
|
||||
'{"is_abortable": true, "future_field": "value"}',
|
||||
{"is_abortable": True, "future_field": "value"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_properties(json_str, expected):
|
||||
"""Test parse_properties parses JSON to TaskProperties dict."""
|
||||
result = parse_properties(json_str)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"props,expected_contains",
|
||||
[
|
||||
# Full properties
|
||||
(
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
{"is_abortable": True, "progress_percent": 0.5},
|
||||
),
|
||||
# Empty dict
|
||||
({}, {}),
|
||||
# Sparse properties
|
||||
({"is_abortable": True}, {"is_abortable": True}),
|
||||
({"error_message": "fail"}, {"error_message": "fail"}),
|
||||
],
|
||||
)
|
||||
def test_serialize_properties(props, expected_contains):
|
||||
"""Test serialize_properties converts TaskProperties to JSON."""
|
||||
from superset.utils import json
|
||||
|
||||
result = serialize_properties(props)
|
||||
parsed = json.loads(result)
|
||||
assert parsed == expected_contains
|
||||
|
||||
|
||||
def test_properties_roundtrip():
|
||||
"""Test that serialize -> parse roundtrip preserves data."""
|
||||
original = {
|
||||
"is_abortable": True,
|
||||
"progress_percent": 0.75,
|
||||
"error_message": "Test error",
|
||||
}
|
||||
serialized = serialize_properties(original)
|
||||
parsed = parse_properties(serialized)
|
||||
assert parsed == original
|
||||
|
||||
@@ -54,7 +54,7 @@ def test_json_loads_exception():
|
||||
|
||||
|
||||
def test_json_loads_encoding():
|
||||
unicode_data = b'{"a": "\u0073\u0074\u0072"}'
|
||||
unicode_data = rb'{"a": "\u0073\u0074\u0072"}'
|
||||
data = json.loads(unicode_data)
|
||||
assert data["a"] == "str"
|
||||
utf16_data = b'\xff\xfe{\x00"\x00a\x00"\x00:\x00 \x00"\x00s\x00t\x00r\x00"\x00}\x00'
|
||||
|
||||
@@ -119,7 +119,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
|
||||
was revoked), the invalid token should be deleted and the exception re-raised.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
|
||||
class OAuth2ExceptionError(Exception):
|
||||
pass
|
||||
@@ -149,7 +149,7 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
|
||||
exception re-raised.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
|
||||
class OAuth2ExceptionError(Exception):
|
||||
pass
|
||||
@@ -175,7 +175,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
This can happen when the refresh token was revoked.
|
||||
"""
|
||||
mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.get_oauth2_fresh_token.return_value = {
|
||||
"error": "invalid_grant",
|
||||
|
||||
Reference in New Issue
Block a user