mirror of
https://github.com/apache/superset.git
synced 2026-06-10 10:09:14 +00:00
Compare commits
2 Commits
fix/dropdo
...
fix/async-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a8ef5227d0 | ||
|
|
da3b6d656f |
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import jwt
|
||||
@@ -112,6 +113,7 @@ class AsyncQueryManager:
|
||||
self._jwt_cookie_domain: Optional[str]
|
||||
self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None
|
||||
self._jwt_secret: str
|
||||
self._jwt_exp_seconds: int = 3600
|
||||
self._load_chart_data_into_cache_job: Any = None
|
||||
# pylint: disable=invalid-name
|
||||
self._load_explore_json_into_cache_job: Any = None
|
||||
@@ -147,6 +149,7 @@ class AsyncQueryManager:
|
||||
]
|
||||
self._jwt_cookie_domain = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"]
|
||||
self._jwt_secret = app.config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]
|
||||
self._jwt_exp_seconds = app.config["GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS"]
|
||||
|
||||
if app.config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]:
|
||||
self.register_request_handlers(app)
|
||||
@@ -160,28 +163,62 @@ class AsyncQueryManager:
|
||||
self._load_chart_data_into_cache_job = load_chart_data_into_cache
|
||||
self._load_explore_json_into_cache_job = load_explore_json_into_cache
|
||||
|
||||
def generate_jwt(self, payload: dict[str, Any]) -> str:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
claims = {
|
||||
**payload,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(seconds=self._jwt_exp_seconds),
|
||||
}
|
||||
return jwt.encode(claims, self._jwt_secret, algorithm="HS256")
|
||||
|
||||
def _jwt_needs_refresh(self, token: Optional[str]) -> bool:
|
||||
"""
|
||||
Return True when the async-queries cookie should be (re)issued.
|
||||
|
||||
The token is refreshed proactively, once it is past the first half of
|
||||
its lifetime, so that an active session keeps a valid token while a
|
||||
leaked token (which is not being refreshed by the user's session) still
|
||||
expires within ``GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS``. Missing,
|
||||
malformed, expired, or legacy (no ``exp``) tokens are also refreshed.
|
||||
"""
|
||||
if not token:
|
||||
return True
|
||||
try:
|
||||
claims = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
|
||||
except jwt.PyJWTError:
|
||||
return True
|
||||
exp = claims.get("exp")
|
||||
if not exp:
|
||||
return True
|
||||
seconds_remaining = exp - datetime.now(tz=timezone.utc).timestamp()
|
||||
return seconds_remaining < self._jwt_exp_seconds / 2
|
||||
|
||||
def register_request_handlers(self, app: Flask) -> None:
|
||||
@app.after_request
|
||||
def validate_session(response: Response) -> Response:
|
||||
user_id = get_user_id()
|
||||
|
||||
reset_token = (
|
||||
not request.cookies.get(self._jwt_cookie_name)
|
||||
or "async_channel_id" not in session
|
||||
# A new channel is only needed when there isn't one for this user
|
||||
# yet (or the user changed). The token, however, is refreshed
|
||||
# whenever it is missing or close to expiring — reusing the existing
|
||||
# channel so an expiring token does not disrupt the session.
|
||||
reset_channel = (
|
||||
"async_channel_id" not in session
|
||||
or "async_user_id" not in session
|
||||
or user_id != session["async_user_id"]
|
||||
)
|
||||
|
||||
if reset_token:
|
||||
async_channel_id = str(uuid.uuid4())
|
||||
session["async_channel_id"] = async_channel_id
|
||||
if reset_channel:
|
||||
session["async_channel_id"] = str(uuid.uuid4())
|
||||
session["async_user_id"] = user_id
|
||||
|
||||
if reset_channel or self._jwt_needs_refresh(
|
||||
request.cookies.get(self._jwt_cookie_name)
|
||||
):
|
||||
sub = str(user_id) if user_id else None
|
||||
token = jwt.encode(
|
||||
{"channel": async_channel_id, "sub": sub},
|
||||
self._jwt_secret,
|
||||
algorithm="HS256",
|
||||
token = self.generate_jwt(
|
||||
{"channel": session["async_channel_id"], "sub": sub}
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
@@ -191,6 +228,7 @@ class AsyncQueryManager:
|
||||
secure=self._jwt_cookie_secure,
|
||||
domain=self._jwt_cookie_domain,
|
||||
samesite=self._jwt_cookie_samesite,
|
||||
max_age=self._jwt_exp_seconds,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -2332,6 +2332,11 @@ GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (Literal["None", "Lax", "Strict
|
||||
)
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
|
||||
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me" # noqa: S105
|
||||
# Lifetime (in seconds) of the async-queries JWT. The token carries an `exp`
|
||||
# claim and is transparently refreshed while a session is active (see
|
||||
# AsyncQueryManager.register_request_handlers), so a leaked token only grants
|
||||
# access to the async event channel for at most this long.
|
||||
GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS = int(timedelta(hours=1).total_seconds())
|
||||
GLOBAL_ASYNC_QUERIES_TRANSPORT: Literal["polling", "ws"] = "polling"
|
||||
GLOBAL_ASYNC_QUERIES_POLLING_DELAY = int(
|
||||
timedelta(milliseconds=500).total_seconds() * 1000
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY, Mock
|
||||
|
||||
@@ -78,6 +79,77 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
|
||||
async_query_manager.parse_channel_id_from_request(request)
|
||||
|
||||
|
||||
def test_generate_jwt_sets_expiration(async_query_manager):
|
||||
"""Generated tokens carry iat/exp and round-trip through parsing."""
|
||||
import jwt
|
||||
|
||||
token = async_query_manager.generate_jwt({"channel": "abc", "sub": "1"})
|
||||
claims = jwt.decode(token, JWT_TOKEN_SECRET, algorithms=["HS256"])
|
||||
|
||||
assert claims["channel"] == "abc"
|
||||
assert "exp" in claims
|
||||
assert "iat" in claims
|
||||
assert claims["exp"] > claims["iat"]
|
||||
assert (
|
||||
async_query_manager.parse_channel_id_from_request(
|
||||
Mock(cookies={"superset_async_jwt": token})
|
||||
)
|
||||
== "abc"
|
||||
)
|
||||
|
||||
|
||||
def test_parse_channel_id_rejects_expired_token(async_query_manager):
|
||||
"""An expired token is rejected (PyJWT validates exp on decode)."""
|
||||
past = datetime.now(tz=timezone.utc) - timedelta(hours=2)
|
||||
expired = encode(
|
||||
{
|
||||
"channel": "abc",
|
||||
"iat": past,
|
||||
"exp": past + timedelta(hours=1),
|
||||
},
|
||||
JWT_TOKEN_SECRET,
|
||||
algorithm="HS256",
|
||||
)
|
||||
|
||||
with raises(AsyncQueryTokenException):
|
||||
async_query_manager.parse_channel_id_from_request(
|
||||
Mock(cookies={"superset_async_jwt": expired})
|
||||
)
|
||||
|
||||
|
||||
def test_jwt_needs_refresh(async_query_manager):
|
||||
"""Refresh missing/legacy/expired/near-expiry tokens; keep fresh ones."""
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Missing token
|
||||
assert async_query_manager._jwt_needs_refresh(None) is True
|
||||
|
||||
# Legacy token without exp
|
||||
legacy = encode({"channel": "abc"}, JWT_TOKEN_SECRET, algorithm="HS256")
|
||||
assert async_query_manager._jwt_needs_refresh(legacy) is True
|
||||
|
||||
# Fresh token (full lifetime remaining) is not refreshed
|
||||
fresh = async_query_manager.generate_jwt({"channel": "abc"})
|
||||
assert async_query_manager._jwt_needs_refresh(fresh) is False
|
||||
|
||||
# Token in the second half of its lifetime is refreshed
|
||||
near_expiry = encode(
|
||||
{"channel": "abc", "iat": now, "exp": now + timedelta(minutes=10)},
|
||||
JWT_TOKEN_SECRET,
|
||||
algorithm="HS256",
|
||||
)
|
||||
assert async_query_manager._jwt_needs_refresh(near_expiry) is True
|
||||
|
||||
# Already-expired token is refreshed (decode raises ExpiredSignatureError)
|
||||
past = now - timedelta(hours=2)
|
||||
expired = encode(
|
||||
{"channel": "abc", "iat": past, "exp": past + timedelta(hours=1)},
|
||||
JWT_TOKEN_SECRET,
|
||||
algorithm="HS256",
|
||||
)
|
||||
assert async_query_manager._jwt_needs_refresh(expired) is True
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
"cache_type, cache_backend",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user