# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Server-side management of global async queries.

Job status updates are published to Redis streams (one per client
channel plus a global "firehose" stream); clients are identified by a
channel ID carried in a JWT cookie issued via a Flask ``after_request``
hook.
"""
import json
import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple

import jwt
import redis
from flask import Flask, Request, Response, session

logger = logging.getLogger(__name__)


class AsyncQueryTokenException(Exception):
    """The async-query JWT is missing, misconfigured, or unparsable."""


class AsyncQueryJobException(Exception):
    """Job metadata is missing a required identifier."""


def build_job_metadata(channel_id: str, job_id: str, **kwargs: Any) -> Dict[str, Any]:
    """Build the canonical metadata dict describing an async query job.

    Recognized keyword arguments: ``status``, ``errors`` (defaults to an
    empty list) and ``result_url``. The current Flask session's
    ``user_id`` (if any) is captured into the metadata.
    """
    return {
        "channel_id": channel_id,
        "job_id": job_id,
        "user_id": session.get("user_id"),
        "status": kwargs.get("status"),
        "errors": kwargs.get("errors", []),
        "result_url": kwargs.get("result_url"),
    }


def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Flatten a Redis stream entry ``(id, {"data": <json>})`` into one dict."""
    event_id, payload = event_data
    return {"id": event_id, **json.loads(payload["data"])}


def increment_id(redis_id: str) -> str:
    """Return the Redis stream ID immediately following *redis_id*.

    Stream IDs have the form ``'<ms-timestamp>-<sequence>'``, e.g.
    ``'1607477697866-0'``. Incrementing the sequence component produces
    an inclusive XRANGE start that excludes *redis_id* itself.

    Fix: the previous implementation incremented only the final *digit*
    (``redis_id[:-1] + str(int(redis_id[-1]) + 1)``), which corrupts any
    ID whose sequence ends in 9 — ``'...-19'`` became ``'...-110'``
    instead of ``'...-20'``, causing events to be skipped or replayed.
    """
    try:
        timestamp, sequence = redis_id.split("-")
        return f"{timestamp}-{int(sequence) + 1}"
    except Exception:  # pylint: disable=broad-except
        # Malformed ID: fall back to the value we were given, preserving
        # the original best-effort behavior.
        return redis_id


class AsyncQueryManager:
    """Creates async query jobs and publishes their status over Redis streams."""

    MAX_EVENT_COUNT = 100
    STATUS_PENDING = "pending"
    STATUS_RUNNING = "running"
    STATUS_ERROR = "error"
    STATUS_DONE = "done"

    def __init__(self) -> None:
        super().__init__()
        # All fields below are populated by init_app(); the bare
        # annotations exist for type checkers.
        self._redis: redis.Redis
        self._stream_prefix: str = ""
        self._stream_limit: Optional[int]
        self._stream_limit_firehose: Optional[int]
        self._jwt_cookie_name: str
        self._jwt_cookie_secure: bool = False
        self._jwt_secret: str

    def init_app(self, app: Flask) -> None:
        """Read configuration from *app* and install the ``after_request``
        hook that issues/refreshes the async-channel JWT cookie.

        Raises:
            Exception: if either cache backend is the "null" cache.
            AsyncQueryTokenException: if the configured JWT secret is
                shorter than 32 bytes.
        """
        config = app.config
        if (
            config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
            or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
        ):
            raise Exception(
                """
                Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
                and non-null in order to enable async queries
                """
            )

        if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
            raise AsyncQueryTokenException(
                "Please provide a JWT secret at least 32 bytes long"
            )

        self._redis = redis.Redis(  # type: ignore
            **config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
        )
        self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
        self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
        self._stream_limit_firehose = config[
            "GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE"
        ]
        self._jwt_cookie_name = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"]
        self._jwt_cookie_secure = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"]
        self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]

        @app.after_request
        def validate_session(  # pylint: disable=unused-variable
            response: Response,
        ) -> Response:
            # (Re)issue the channel JWT cookie whenever the session has no
            # channel yet, or the logged-in user has changed since the
            # channel was created (login/logout invalidates the channel).
            reset_token = False
            user_id = session.get("user_id")

            if "async_channel_id" not in session or "async_user_id" not in session:
                reset_token = True
            elif user_id != session["async_user_id"]:
                reset_token = True

            if reset_token:
                async_channel_id = str(uuid.uuid4())
                session["async_channel_id"] = async_channel_id
                session["async_user_id"] = user_id

                sub = str(user_id) if user_id else None
                token = self.generate_jwt({"channel": async_channel_id, "sub": sub})

                response.set_cookie(
                    self._jwt_cookie_name,
                    value=token,
                    httponly=True,
                    secure=self._jwt_cookie_secure,
                    # max_age=max_age or config.cookie_max_age,
                    # domain=config.cookie_domain,
                    # path=config.access_cookie_path,
                    # samesite=config.cookie_samesite
                )

            return response

    def generate_jwt(self, data: Dict[str, Any]) -> str:
        """Sign *data* as an HS256 JWT using the configured secret."""
        # NOTE(review): the .decode() assumes PyJWT < 2.0, where encode()
        # returns bytes; PyJWT >= 2.0 returns str and this line would raise
        # AttributeError — confirm the pinned PyJWT version.
        encoded_jwt = jwt.encode(data, self._jwt_secret, algorithm="HS256")
        return encoded_jwt.decode("utf-8")

    def parse_jwt(self, token: str) -> Dict[str, Any]:
        """Decode and verify *token* with the configured secret (HS256 only)."""
        return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])

    def parse_jwt_from_request(self, request: Request) -> Dict[str, Any]:
        """Extract and decode the async-query JWT from the request cookie.

        Raises:
            AsyncQueryTokenException: if the cookie is absent or the token
                cannot be decoded/verified.
        """
        token = request.cookies.get(self._jwt_cookie_name)
        if not token:
            # Fix: message typo ("preset" -> "present").
            raise AsyncQueryTokenException("Token not present")

        try:
            return self.parse_jwt(token)
        except Exception as exc:  # pylint: disable=broad-except
            logger.warning(exc)
            # Chain the original error so the root cause is visible in logs.
            raise AsyncQueryTokenException("Failed to parse token") from exc

    def init_job(self, channel_id: str) -> Dict[str, Any]:
        """Create metadata for a fresh job in the "pending" state."""
        job_id = str(uuid.uuid4())
        return build_job_metadata(channel_id, job_id, status=self.STATUS_PENDING)

    def read_events(
        self, channel: str, last_id: Optional[str]
    ) -> List[Optional[Dict[str, Any]]]:
        """Read up to MAX_EVENT_COUNT events from the channel's stream,
        starting strictly after *last_id* (or from the beginning if None).
        """
        stream_name = f"{self._stream_prefix}{channel}"
        start_id = increment_id(last_id) if last_id else "-"
        results = self._redis.xrange(  # type: ignore
            stream_name, start_id, "+", self.MAX_EVENT_COUNT
        )
        return [] if not results else list(map(parse_event, results))

    def update_job(
        self, job_metadata: Dict[str, Any], status: str, **kwargs: Any
    ) -> None:
        """Publish a status update for a job to its channel stream and to
        the global "firehose" stream.

        Extra keyword arguments (e.g. ``errors``, ``result_url``) are merged
        into the published payload.

        Raises:
            AsyncQueryJobException: if *job_metadata* lacks ``channel_id``
                or ``job_id``.
        """
        if "channel_id" not in job_metadata:
            raise AsyncQueryJobException("No channel ID specified")

        if "job_id" not in job_metadata:
            raise AsyncQueryJobException("No job ID specified")

        updates = {"status": status, **kwargs}
        event_data = {"data": json.dumps({**job_metadata, **updates})}

        full_stream_name = f"{self._stream_prefix}full"
        scoped_stream_name = f"{self._stream_prefix}{job_metadata['channel_id']}"

        logger.debug("********** logging event data to stream %s", scoped_stream_name)
        logger.debug(event_data)

        self._redis.xadd(  # type: ignore
            scoped_stream_name, event_data, "*", self._stream_limit
        )
        self._redis.xadd(  # type: ignore
            full_stream_name, event_data, "*", self._stream_limit_firehose
        )