Files
superset2/superset/tasks/utils.py

315 lines
11 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import traceback
from http.client import HTTPResponse
from typing import cast, TYPE_CHECKING
from urllib import request
from uuid import UUID, uuid4
from celery.utils.log import get_task_logger
from flask import g
from superset_core.tasks.types import TaskProperties, TaskScope
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.types import (
ChosenExecutor,
Executor,
ExecutorType,
FixedExecutor,
)
from superset.utils import json
from superset.utils.hashing import hash_from_str
from superset.utils.urls import get_url_path
if TYPE_CHECKING:
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.reports.models import ReportSchedule
logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)
# pylint: disable=too-many-branches
def get_executor( # noqa: C901
executors: list[Executor],
model: Dashboard | ReportSchedule | Slice,
current_user: str | None = None,
) -> ChosenExecutor:
"""
Extract the user that should be used to execute a scheduled task. Certain executor
types extract the user from the underlying object (e.g. CREATOR), the constant
Selenium user (SELENIUM), or the user that initiated the request.
:param executors: The requested executor in descending order. When the
first user is found it is returned.
:param model: The underlying object
:param current_user: The username of the user that initiated the task. For
thumbnails this is the user that requested the thumbnail, while for alerts
and reports this is None (=initiated by Celery).
:return: User to execute the execute the async task as. The first element of the
tuple represents the type of the executor, and the second represents the
username of the executor.
:raises ExecutorNotFoundError: If no users were found in after
iterating through all entries in `executors`
"""
owners = model.owners
owner_dict = {owner.id: owner for owner in owners}
for executor in executors:
if isinstance(executor, FixedExecutor):
return ExecutorType.FIXED_USER, executor.username
if executor == ExecutorType.FIXED_USER:
raise InvalidExecutorError()
if executor == ExecutorType.CURRENT_USER and current_user:
return executor, current_user
if executor == ExecutorType.CREATOR_OWNER:
if (user := model.created_by) and (owner := owner_dict.get(user.id)):
return executor, owner.username
if executor == ExecutorType.CREATOR:
if user := model.created_by:
return executor, user.username
if executor == ExecutorType.MODIFIER_OWNER:
if (user := model.changed_by) and (owner := owner_dict.get(user.id)):
return executor, owner.username
if executor == ExecutorType.MODIFIER:
if user := model.changed_by:
return executor, user.username
if executor == ExecutorType.OWNER:
owners = model.owners
if len(owners) == 1:
return executor, owners[0].username
if len(owners) > 1:
if modifier := model.changed_by:
if modifier and (user := owner_dict.get(modifier.id)):
return executor, user.username
if creator := model.created_by:
if creator and (user := owner_dict.get(creator.id)):
return executor, user.username
return executor, owners[0].username
raise ExecutorNotFoundError()
def get_current_user() -> str | None:
user = g.user if hasattr(g, "user") and g.user else None
if user and not user.is_anonymous:
return user.username
return None
def fetch_csrf_token(
headers: dict[str, str], session_cookie_name: str = "session"
) -> dict[str, str]:
"""
Fetches a CSRF token for API requests
:param headers: A map of headers to use in the request, including the session cookie
:returns: A map of headers, including the session cookie and csrf token
"""
url = get_url_path("SecurityRestApi.csrf_token")
logger.info("Fetching %s", url)
req = request.Request(url, headers=headers, method="GET") # noqa: S310
response: HTTPResponse
with request.urlopen(req, timeout=600) as response: # noqa: S310
body = response.read().decode("utf-8")
session_cookie: str | None = None
cookie_headers = response.headers.get_all("set-cookie")
if cookie_headers:
for cookie in cookie_headers:
cookie = cookie.split(";", 1)[0]
name, value = cookie.split("=", 1)
if name == session_cookie_name:
session_cookie = value
break
if response.status == 200:
data = json.loads(body)
res = {"X-CSRF-Token": data["result"]}
if session_cookie is not None:
res["Cookie"] = f"{session_cookie_name}={session_cookie}"
return res
logger.error("Error fetching CSRF token, status code: %s", response.status)
return {}
def generate_random_task_key() -> str:
"""
Generate a random task key.
This is the default behavior - each task submission gets a unique UUID
unless an explicit task_key is provided in TaskOptions.
:returns: A random UUID string
"""
return str(uuid4())
def get_active_dedup_key(
scope: TaskScope | str,
task_type: str,
task_key: str,
user_id: int | None = None,
) -> str:
"""
Build a deduplication key for active tasks.
The dedup_key enforces uniqueness at the database level via a unique index.
Active tasks use a composite key based on scope, which is then hashed using
the configured HASH_ALGORITHM to produce a fixed-length key.
The composite key format before hashing is:
- Private: private|task_type|task_key|user_id
- Shared: shared|task_type|task_key
- System: system|task_type|task_key
The final key is a hash digest (64 chars for sha256, 32 chars for md5).
:param scope: Task scope (PRIVATE/SHARED/SYSTEM) as TaskScope enum or string
:param task_type: Type of task (e.g., 'sql_execution')
:param task_key: Task identifier for deduplication
:param user_id: User ID (required for private tasks)
:returns: Hashed deduplication key string
:raises ValueError: If user_id is missing for private scope
"""
# Convert string to TaskScope if needed
if isinstance(scope, str):
scope = TaskScope(scope)
# Build composite key
match scope:
case TaskScope.PRIVATE:
if user_id is None:
raise ValueError("user_id required for private tasks")
composite_key = f"{scope.value}|{task_type}|{task_key}|{user_id}"
case TaskScope.SHARED:
composite_key = f"{scope.value}|{task_type}|{task_key}"
case TaskScope.SYSTEM:
composite_key = f"{scope.value}|{task_type}|{task_key}"
case _:
raise ValueError(f"Invalid scope: {scope}")
# Hash the composite key to produce a fixed-length dedup_key
# Truncate to 64 chars max to fit the database column in case
# a hash algo is used that generates hashes that exceed 64 chars
return hash_from_str(composite_key)[:64]
def get_finished_dedup_key(task_uuid: UUID) -> str:
"""
Build a deduplication key for finished tasks.
When a task completes (success, failure, or abort), its dedup_key is
changed to its UUID. This frees up the slot so new tasks with the same
parameters can be created.
:param task_uuid: Task UUID (native UUID type)
:returns: The task UUID string as the dedup key
Example:
>>> from uuid import UUID
>>> get_finished_dedup_key(UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890"))
'a1b2c3d4-e5f6-7890-abcd-ef1234567890'
"""
return str(task_uuid)
# -----------------------------------------------------------------------------
# TaskProperties helper functions
# -----------------------------------------------------------------------------
def progress_update(progress: float | int | tuple[int, int]) -> TaskProperties:
"""
Create a properties update dict for progress values.
:param progress: One of:
- float (0.0-1.0): Percentage only
- int: Count only (total unknown)
- tuple[int, int]: (current, total) with auto-computed percentage
:returns: TaskProperties dict with appropriate progress fields set
Example:
task.update_properties(progress_update((50, 100)))
"""
if isinstance(progress, float):
return {"progress_percent": progress}
if isinstance(progress, int):
return {"progress_current": progress}
# tuple
current, total = progress
result: TaskProperties = {
"progress_current": current,
"progress_total": total,
}
if total > 0:
result["progress_percent"] = current / total
return result
def error_update(exception: BaseException) -> TaskProperties:
"""
Create a properties update dict from an exception.
:param exception: The exception that caused the failure
:returns: TaskProperties dict with error fields populated
"""
return {
"error_message": str(exception),
"exception_type": type(exception).__name__,
"stack_trace": traceback.format_exc(),
}
def parse_properties(json_str: str | None) -> TaskProperties:
"""
Parse JSON string into TaskProperties dict.
Returns empty dict on parse errors. Unknown keys are preserved
for forward compatibility (allows adding new properties without
breaking existing code).
:param json_str: JSON string or None
:returns: TaskProperties dict (sparse - only contains keys that were set)
"""
if not json_str:
return {}
try:
raw = json.loads(json_str)
if isinstance(raw, dict):
return cast(TaskProperties, raw)
return {}
except (json.JSONDecodeError, TypeError):
return {}
def serialize_properties(props: TaskProperties) -> str:
"""
Serialize TaskProperties to JSON string.
:param props: TaskProperties dict
:returns: JSON string
"""
return json.dumps(props)