mirror of
https://github.com/apache/superset.git
synced 2026-04-13 13:18:25 +00:00
315 lines
11 KiB
Python
315 lines
11 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import traceback
|
|
from http.client import HTTPResponse
|
|
from typing import cast, TYPE_CHECKING
|
|
from urllib import request
|
|
from uuid import UUID, uuid4
|
|
|
|
from celery.utils.log import get_task_logger
|
|
from flask import g
|
|
from superset_core.tasks.types import TaskProperties, TaskScope
|
|
|
|
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
|
|
from superset.tasks.types import (
|
|
ChosenExecutor,
|
|
Executor,
|
|
ExecutorType,
|
|
FixedExecutor,
|
|
)
|
|
from superset.utils import json
|
|
from superset.utils.hashing import hash_from_str
|
|
from superset.utils.urls import get_url_path
|
|
|
|
if TYPE_CHECKING:
|
|
from superset.models.dashboard import Dashboard
|
|
from superset.models.slice import Slice
|
|
from superset.reports.models import ReportSchedule
|
|
|
|
|
|
logger = get_task_logger(__name__)
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
# pylint: disable=too-many-branches
def get_executor(  # noqa: C901
    executors: list[Executor],
    model: Dashboard | ReportSchedule | Slice,
    current_user: str | None = None,
) -> ChosenExecutor:
    """
    Extract the user that should be used to execute a scheduled task. Certain executor
    types extract the user from the underlying object (e.g. CREATOR), the constant
    Selenium user (SELENIUM), or the user that initiated the request.

    :param executors: The requested executors in descending order. When the
           first user is found it is returned.
    :param model: The underlying object
    :param current_user: The username of the user that initiated the task. For
           thumbnails this is the user that requested the thumbnail, while for alerts
           and reports this is None (=initiated by Celery).
    :return: User to execute the async task as. The first element of the
             tuple represents the type of the executor, and the second represents the
             username of the executor.
    :raises ExecutorNotFoundError: If no users were found after
            iterating through all entries in `executors`
    :raises InvalidExecutorError: If a bare FIXED_USER executor type is requested
            instead of a ``FixedExecutor`` instance carrying the username
    """
    owners = model.owners
    owner_dict = {owner.id: owner for owner in owners}
    for executor in executors:
        if isinstance(executor, FixedExecutor):
            # A fixed executor carries its own username.
            return ExecutorType.FIXED_USER, executor.username
        if executor == ExecutorType.FIXED_USER:
            # FIXED_USER must be configured via a FixedExecutor instance,
            # never as a bare executor type (there'd be no username).
            raise InvalidExecutorError()
        if executor == ExecutorType.CURRENT_USER and current_user:
            return executor, current_user
        if executor == ExecutorType.CREATOR_OWNER:
            if (user := model.created_by) and (owner := owner_dict.get(user.id)):
                return executor, owner.username
        if executor == ExecutorType.CREATOR:
            if user := model.created_by:
                return executor, user.username
        if executor == ExecutorType.MODIFIER_OWNER:
            if (user := model.changed_by) and (owner := owner_dict.get(user.id)):
                return executor, owner.username
        if executor == ExecutorType.MODIFIER:
            if user := model.changed_by:
                return executor, user.username
        if executor == ExecutorType.OWNER:
            if len(owners) == 1:
                return executor, owners[0].username
            if len(owners) > 1:
                # Prefer the last modifier, then the creator, if they are owners;
                # otherwise fall back to the first owner.
                if (modifier := model.changed_by) and (
                    user := owner_dict.get(modifier.id)
                ):
                    return executor, user.username
                if (creator := model.created_by) and (
                    user := owner_dict.get(creator.id)
                ):
                    return executor, user.username
                return executor, owners[0].username

    raise ExecutorNotFoundError()
|
|
|
|
|
|
def get_current_user() -> str | None:
    """
    Return the username of the currently logged-in user, if any.

    :returns: The username, or None when there is no authenticated
        (non-anonymous) user bound to the Flask application context.
    """
    current = getattr(g, "user", None)
    if current and not current.is_anonymous:
        return current.username
    return None
|
|
|
|
|
|
def fetch_csrf_token(
    headers: dict[str, str], session_cookie_name: str = "session"
) -> dict[str, str]:
    """
    Fetch a CSRF token for API requests.

    :param headers: A map of headers to use in the request, including the session cookie
    :param session_cookie_name: Name of the session cookie to look for in the
        response's ``Set-Cookie`` headers
    :returns: A map of headers, including the session cookie and csrf token;
        empty on a non-200 response
    """
    url = get_url_path("SecurityRestApi.csrf_token")
    logger.info("Fetching %s", url)
    req = request.Request(url, headers=headers, method="GET")  # noqa: S310

    response: HTTPResponse
    with request.urlopen(req, timeout=600) as response:  # noqa: S310
        payload = response.read().decode("utf-8")

        # The server may rotate the session; capture a refreshed cookie if sent.
        refreshed_cookie: str | None = None
        for raw_cookie in response.headers.get_all("set-cookie") or []:
            name, value = raw_cookie.split(";", 1)[0].split("=", 1)
            if name == session_cookie_name:
                refreshed_cookie = value
                break

        if response.status == 200:
            token_headers = {"X-CSRF-Token": json.loads(payload)["result"]}
            if refreshed_cookie is not None:
                token_headers["Cookie"] = f"{session_cookie_name}={refreshed_cookie}"
            return token_headers

    logger.error("Error fetching CSRF token, status code: %s", response.status)
    return {}
|
|
|
|
|
|
def generate_random_task_key() -> str:
    """
    Generate a random task key.

    This is the default behavior - each task submission gets a unique UUID
    unless an explicit task_key is provided in TaskOptions.

    :returns: A random UUID string
    """
    return f"{uuid4()}"
|
|
|
|
|
|
def get_active_dedup_key(
|
|
scope: TaskScope | str,
|
|
task_type: str,
|
|
task_key: str,
|
|
user_id: int | None = None,
|
|
) -> str:
|
|
"""
|
|
Build a deduplication key for active tasks.
|
|
|
|
The dedup_key enforces uniqueness at the database level via a unique index.
|
|
Active tasks use a composite key based on scope, which is then hashed using
|
|
the configured HASH_ALGORITHM to produce a fixed-length key.
|
|
|
|
The composite key format before hashing is:
|
|
- Private: private|task_type|task_key|user_id
|
|
- Shared: shared|task_type|task_key
|
|
- System: system|task_type|task_key
|
|
|
|
The final key is a hash digest (64 chars for sha256, 32 chars for md5).
|
|
|
|
:param scope: Task scope (PRIVATE/SHARED/SYSTEM) as TaskScope enum or string
|
|
:param task_type: Type of task (e.g., 'sql_execution')
|
|
:param task_key: Task identifier for deduplication
|
|
:param user_id: User ID (required for private tasks)
|
|
:returns: Hashed deduplication key string
|
|
:raises ValueError: If user_id is missing for private scope
|
|
"""
|
|
# Convert string to TaskScope if needed
|
|
if isinstance(scope, str):
|
|
scope = TaskScope(scope)
|
|
|
|
# Build composite key
|
|
match scope:
|
|
case TaskScope.PRIVATE:
|
|
if user_id is None:
|
|
raise ValueError("user_id required for private tasks")
|
|
composite_key = f"{scope.value}|{task_type}|{task_key}|{user_id}"
|
|
case TaskScope.SHARED:
|
|
composite_key = f"{scope.value}|{task_type}|{task_key}"
|
|
case TaskScope.SYSTEM:
|
|
composite_key = f"{scope.value}|{task_type}|{task_key}"
|
|
case _:
|
|
raise ValueError(f"Invalid scope: {scope}")
|
|
|
|
# Hash the composite key to produce a fixed-length dedup_key
|
|
# Truncate to 64 chars max to fit the database column in case
|
|
# a hash algo is used that generates hashes that exceed 64 chars
|
|
return hash_from_str(composite_key)[:64]
|
|
|
|
|
|
def get_finished_dedup_key(task_uuid: UUID) -> str:
    """
    Build a deduplication key for finished tasks.

    When a task completes (success, failure, or abort), its dedup_key is
    changed to its UUID. This frees up the slot so new tasks with the same
    parameters can be created.

    :param task_uuid: Task UUID (native UUID type)
    :returns: The task UUID string as the dedup key

    Example:
        >>> from uuid import UUID
        >>> get_finished_dedup_key(UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890"))
        'a1b2c3d4-e5f6-7890-abcd-ef1234567890'
    """
    return f"{task_uuid}"
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# TaskProperties helper functions
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
def progress_update(progress: float | int | tuple[int, int]) -> TaskProperties:
|
|
"""
|
|
Create a properties update dict for progress values.
|
|
|
|
:param progress: One of:
|
|
- float (0.0-1.0): Percentage only
|
|
- int: Count only (total unknown)
|
|
- tuple[int, int]: (current, total) with auto-computed percentage
|
|
:returns: TaskProperties dict with appropriate progress fields set
|
|
|
|
Example:
|
|
task.update_properties(progress_update((50, 100)))
|
|
"""
|
|
if isinstance(progress, float):
|
|
return {"progress_percent": progress}
|
|
if isinstance(progress, int):
|
|
return {"progress_current": progress}
|
|
# tuple
|
|
current, total = progress
|
|
result: TaskProperties = {
|
|
"progress_current": current,
|
|
"progress_total": total,
|
|
}
|
|
if total > 0:
|
|
result["progress_percent"] = current / total
|
|
return result
|
|
|
|
|
|
def error_update(exception: BaseException) -> TaskProperties:
|
|
"""
|
|
Create a properties update dict from an exception.
|
|
|
|
:param exception: The exception that caused the failure
|
|
:returns: TaskProperties dict with error fields populated
|
|
"""
|
|
return {
|
|
"error_message": str(exception),
|
|
"exception_type": type(exception).__name__,
|
|
"stack_trace": traceback.format_exc(),
|
|
}
|
|
|
|
|
|
def parse_properties(json_str: str | None) -> TaskProperties:
    """
    Parse JSON string into TaskProperties dict.

    Returns empty dict on parse errors. Unknown keys are preserved
    for forward compatibility (allows adding new properties without
    breaking existing code).

    :param json_str: JSON string or None
    :returns: TaskProperties dict (sparse - only contains keys that were set)
    """
    if not json_str:
        return {}

    try:
        parsed = json.loads(json_str)
    except (json.JSONDecodeError, TypeError):
        return {}

    # Anything that isn't a JSON object is treated as "no properties".
    return cast(TaskProperties, parsed) if isinstance(parsed, dict) else {}
|
|
|
|
|
|
def serialize_properties(props: TaskProperties) -> str:
    """
    Serialize TaskProperties to a JSON string.

    :param props: TaskProperties dict (sparse mapping of task property fields)
    :returns: JSON string representation of *props*
    """
    serialized = json.dumps(props)
    return serialized
|