feat(GAQ): Add Redis Sentinel Support for Global Async Queries (#29912)

Co-authored-by: Sivarajan Narayanan <narayanan_sivarajan@apple.com>
This commit is contained in:
nsivarajan
2024-08-30 23:12:29 +05:30
committed by GitHub
parent cd6b8b2f6d
commit 103cd3d6f3
6 changed files with 450 additions and 45 deletions

View File

@@ -14,20 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import uuid
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union
import jwt
import redis
from flask import Flask, Request, request, Response, session
from flask_caching.backends.base import BaseCache
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
from superset.utils import json
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
class CacheBackendNotInitialized(Exception):
pass
class AsyncQueryTokenException(Exception):
pass
@@ -55,13 +66,32 @@ def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]:
return {"id": event_id, **json.loads(event_payload)}
def increment_id(redis_id: str) -> str:
def increment_id(entry_id: str) -> str:
# redis stream IDs are in this format: '1607477697866-0'
try:
prefix, last = redis_id[:-1], int(redis_id[-1])
prefix, last = entry_id[:-1], int(entry_id[-1])
return prefix + str(last + 1)
except Exception: # pylint: disable=broad-except
return redis_id
return entry_id
def get_cache_backend(
config: dict[str, Any],
) -> Union[RedisCacheBackend, RedisSentinelCacheBackend, redis.Redis]: # type: ignore
cache_config = config.get("GLOBAL_ASYNC_QUERIES_CACHE_BACKEND", {})
cache_type = cache_config.get("CACHE_TYPE")
if cache_type == "RedisCache":
return RedisCacheBackend.from_config(cache_config)
if cache_type == "RedisSentinelCache":
return RedisSentinelCacheBackend.from_config(cache_config)
# TODO: Deprecate hardcoded plain Redis code and expand cache backend options.
# Maintain backward compatibility with 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG' until it is deprecated.
return redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
class AsyncQueryManager:
@@ -73,7 +103,7 @@ class AsyncQueryManager:
def __init__(self) -> None:
super().__init__()
self._redis: redis.Redis # type: ignore
self._cache: Optional[BaseCache] = None
self._stream_prefix: str = ""
self._stream_limit: Optional[int]
self._stream_limit_firehose: Optional[int]
@@ -88,10 +118,9 @@ class AsyncQueryManager:
def init_app(self, app: Flask) -> None:
config = app.config
if (
config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
):
cache_type = config.get("CACHE_CONFIG", {}).get("CACHE_TYPE")
data_cache_type = config.get("DATA_CACHE_CONFIG", {}).get("CACHE_TYPE")
if cache_type in [None, "null"] or data_cache_type in [None, "null"]:
raise Exception( # pylint: disable=broad-exception-raised
"""
Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
@@ -99,14 +128,14 @@ class AsyncQueryManager:
"""
)
self._cache = get_cache_backend(config)
logger.debug("Using GAQ Cache backend as %s", type(self._cache).__name__)
if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
raise AsyncQueryTokenException(
"Please provide a JWT secret at least 32 bytes long"
)
self._redis = redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
self._stream_limit_firehose = config[
@@ -230,14 +259,35 @@ class AsyncQueryManager:
def read_events(
self, channel: str, last_id: Optional[str]
) -> list[Optional[dict[str, Any]]]:
if not self._cache:
raise CacheBackendNotInitialized("Cache backend not initialized")
stream_name = f"{self._stream_prefix}{channel}"
start_id = increment_id(last_id) if last_id else "-"
results = self._redis.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
results = self._cache.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
# Decode bytes to strings, decode_responses is not supported at RedisCache and RedisSentinelCache
if isinstance(self._cache, (RedisSentinelCacheBackend, RedisCacheBackend)):
decoded_results = [
(
event_id.decode("utf-8"),
{
key.decode("utf-8"): value.decode("utf-8")
for key, value in event_data.items()
},
)
for event_id, event_data in results
]
return (
[] if not decoded_results else list(map(parse_event, decoded_results))
)
return [] if not results else list(map(parse_event, results))
def update_job(
self, job_metadata: dict[str, Any], status: str, **kwargs: Any
) -> None:
if not self._cache:
raise CacheBackendNotInitialized("Cache backend not initialized")
if "channel_id" not in job_metadata:
raise AsyncQueryJobException("No channel ID specified")
@@ -253,5 +303,5 @@ class AsyncQueryManager:
logger.debug("********** logging event data to stream %s", scoped_stream_name)
logger.debug(event_data)
self._redis.xadd(scoped_stream_name, event_data, "*", self._stream_limit)
self._redis.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)
self._cache.xadd(scoped_stream_name, event_data, "*", self._stream_limit)
self._cache.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)