# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import uuid
from typing import Any, Literal, Optional

import jwt
from flask import Flask, Request, request, Response, session
from flask_caching.backends.base import BaseCache

from superset.async_events.cache_backend import (
    RedisCacheBackend,
    RedisSentinelCacheBackend,
)
from superset.utils import json
from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)


class CacheBackendNotInitialized(Exception):  # noqa: N818
    pass


class AsyncQueryTokenException(Exception):  # noqa: N818
    pass


class UnsupportedCacheBackendError(Exception):  # noqa: N818
    pass


class AsyncQueryJobException(Exception):  # noqa: N818
    pass


def build_job_metadata(
    channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any
) -> dict[str, Any]:
    return {
        "channel_id": channel_id,
        "job_id": job_id,
        "user_id": user_id,
        "status": kwargs.get("status"),
        "errors": kwargs.get("errors", []),
        "result_url": kwargs.get("result_url"),
    }


def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]:
    event_id = event_data[0]
    event_payload = event_data[1]["data"]
    return {"id": event_id, **json.loads(event_payload)}


def increment_id(entry_id: str) -> str:
    # Redis stream IDs are in the format '<timestamp>-<sequence>', e.g.
    # '1607477697866-0'. Increment the sequence part so reads resume strictly
    # after the last-seen entry.
    try:
        timestamp, sequence = entry_id.split("-")
        return f"{timestamp}-{int(sequence) + 1}"
    except Exception:  # pylint: disable=broad-except
        return entry_id


def get_cache_backend(
    config: dict[str, Any],
) -> RedisCacheBackend | RedisSentinelCacheBackend:
    cache_config = config.get("GLOBAL_ASYNC_QUERIES_CACHE_BACKEND", {})
    cache_type = cache_config.get("CACHE_TYPE")

    if cache_type == "RedisCache":
        return RedisCacheBackend.from_config(cache_config)

    if cache_type == "RedisSentinelCache":
        return RedisSentinelCacheBackend.from_config(cache_config)

    # TODO: Expand cache backend options.
    raise UnsupportedCacheBackendError("Unsupported cache backend configuration")
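
# A minimal sketch of the config block consumed by get_cache_backend above.
# Only CACHE_TYPE is inspected here; the connection keys shown are
# assumptions about what the from_config() constructors read and may differ
# in a given deployment:
#
#     GLOBAL_ASYNC_QUERIES_CACHE_BACKEND = {
#         "CACHE_TYPE": "RedisCache",  # or "RedisSentinelCache"
#         "CACHE_REDIS_HOST": "localhost",  # hypothetical connection settings
#         "CACHE_REDIS_PORT": 6379,
#         "CACHE_REDIS_DB": 0,
#     }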


class AsyncQueryManager:
    MAX_EVENT_COUNT = 100
    STATUS_PENDING = "pending"
    STATUS_RUNNING = "running"
    STATUS_ERROR = "error"
    STATUS_DONE = "done"

    def __init__(self) -> None:
        super().__init__()
        self._cache: Optional[BaseCache] = None
        self._stream_prefix: str = ""
        self._stream_limit: Optional[int]
        self._stream_limit_firehose: Optional[int]
        self._jwt_cookie_name: str = ""
        self._jwt_cookie_secure: bool = False
        self._jwt_cookie_domain: Optional[str]
        self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None
        self._jwt_secret: str
        self._load_chart_data_into_cache_job: Any = None  # pylint: disable=invalid-name
        self._load_explore_json_into_cache_job: Any = None

    def init_app(self, app: Flask) -> None:
        cache_type = app.config.get("CACHE_CONFIG", {}).get("CACHE_TYPE")
        data_cache_type = app.config.get("DATA_CACHE_CONFIG", {}).get("CACHE_TYPE")
        if cache_type in [None, "null"] or data_cache_type in [None, "null"]:
            raise Exception(  # pylint: disable=broad-exception-raised
                """
                Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
                and non-null in order to enable async queries
                """
            )

        self._cache = get_cache_backend(app.config)
        logger.debug("Using GAQ Cache backend as %s", type(self._cache).__name__)

        if len(app.config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
            raise AsyncQueryTokenException(
                "Please provide a JWT secret at least 32 bytes long"
            )

        self._stream_prefix = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
        self._stream_limit = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
        self._stream_limit_firehose = app.config[
            "GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE"
        ]
        self._jwt_cookie_name = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"]
        self._jwt_cookie_secure = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"]
        self._jwt_cookie_samesite = app.config[
            "GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE"
        ]
        self._jwt_cookie_domain = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"]
        self._jwt_secret = app.config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]

        if app.config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]:
            self.register_request_handlers(app)

        # pylint: disable=import-outside-toplevel
        from superset.tasks.async_queries import (
            load_chart_data_into_cache,
            load_explore_json_into_cache,
        )

        self._load_chart_data_into_cache_job = load_chart_data_into_cache
        self._load_explore_json_into_cache_job = load_explore_json_into_cache

    def register_request_handlers(self, app: Flask) -> None:
        @app.after_request
        def validate_session(response: Response) -> Response:
            user_id = get_user_id()

            reset_token = (
                not request.cookies.get(self._jwt_cookie_name)
                or "async_channel_id" not in session
                or "async_user_id" not in session
                or user_id != session["async_user_id"]
            )

            if reset_token:
                async_channel_id = str(uuid.uuid4())
                session["async_channel_id"] = async_channel_id
                session["async_user_id"] = user_id

                sub = str(user_id) if user_id else None
                token = jwt.encode(
                    {"channel": async_channel_id, "sub": sub},
                    self._jwt_secret,
                    algorithm="HS256",
                )

                response.set_cookie(
                    self._jwt_cookie_name,
                    value=token,
                    httponly=True,
                    secure=self._jwt_cookie_secure,
                    domain=self._jwt_cookie_domain,
                    samesite=self._jwt_cookie_samesite,
                )

            return response
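
    # Note: the cookie set by validate_session above carries an HS256-signed
    # JWT whose payload is {"channel": "<uuid4>", "sub": "<user id as a
    # string, or None>"}. parse_channel_id_from_request below verifies the
    # signature with the same secret and extracts the channel ID. For
    # illustration only (hypothetical secret), an equivalent token is:
    #
    #     jwt.encode(
    #         {"channel": str(uuid.uuid4()), "sub": "1"},
    #         "a-32-byte-or-longer-secret-string",
    #         algorithm="HS256",
    #     )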
logger.warning("Parse jwt failed", exc_info=True) raise AsyncQueryTokenException("Failed to parse token") from ex def init_job(self, channel_id: str, user_id: Optional[int]) -> dict[str, Any]: job_id = str(uuid.uuid4()) return build_job_metadata( channel_id, job_id, user_id, status=self.STATUS_PENDING ) # pylint: disable=too-many-arguments def submit_explore_json_job( self, channel_id: str, form_data: dict[str, Any], response_type: str, force: Optional[bool] = False, user_id: Optional[int] = None, ) -> dict[str, Any]: # pylint: disable=import-outside-toplevel from superset import security_manager job_metadata = self.init_job(channel_id, user_id) self._load_explore_json_into_cache_job.delay( {**job_metadata, "guest_token": guest_user.guest_token} if (guest_user := security_manager.get_current_guest_user_if_guest()) else job_metadata, form_data, response_type, force, ) return job_metadata def submit_chart_data_job( self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int] = None, ) -> dict[str, Any]: # pylint: disable=import-outside-toplevel from superset import security_manager # if it's guest user, we want to pass the guest token to the celery task # chart data cache key is calculated based on the current user # this way we can keep the cache key consistent between sync and async command # so that it can be looked up consistently job_metadata = self.init_job(channel_id, user_id) self._load_chart_data_into_cache_job.delay( {**job_metadata, "guest_token": guest_user.guest_token} if (guest_user := security_manager.get_current_guest_user_if_guest()) else job_metadata, form_data, ) return job_metadata def read_events( self, channel: str, last_id: Optional[str] ) -> list[Optional[dict[str, Any]]]: if not self._cache: raise CacheBackendNotInitialized("Cache backend not initialized") stream_name = f"{self._stream_prefix}{channel}" start_id = increment_id(last_id) if last_id else "-" results = self._cache.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT) # Decode bytes to strings, decode_responses is not supported at RedisCache and RedisSentinelCache # noqa: E501 if isinstance(self._cache, (RedisSentinelCacheBackend, RedisCacheBackend)): decoded_results = [ ( event_id.decode("utf-8"), { key.decode("utf-8"): value.decode("utf-8") for key, value in event_data.items() }, ) for event_id, event_data in results ] return ( [] if not decoded_results else list(map(parse_event, decoded_results)) ) return [] if not results else list(map(parse_event, results)) def update_job( self, job_metadata: dict[str, Any], status: str, **kwargs: Any ) -> None: if not self._cache: raise CacheBackendNotInitialized("Cache backend not initialized") if "channel_id" not in job_metadata: raise AsyncQueryJobException("No channel ID specified") if "job_id" not in job_metadata: raise AsyncQueryJobException("No job ID specified") updates = {"status": status, **kwargs} event_data = {"data": json.dumps({**job_metadata, **updates})} full_stream_name = f"{self._stream_prefix}full" scoped_stream_name = f"{self._stream_prefix}{job_metadata['channel_id']}" logger.debug("********** logging event data to stream %s", scoped_stream_name) logger.debug(event_data) self._cache.xadd(scoped_stream_name, event_data, "*", self._stream_limit) self._cache.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)