Compare commits

...

2 Commits

Author SHA1 Message Date
Evan
a8ef5227d0 test(async): assert _jwt_needs_refresh handles already-expired tokens
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-03 08:06:55 -07:00
Claude Code
da3b6d656f fix(async): add expiration to async-queries JWT with sliding refresh
The async-queries JWT was minted without an exp claim, so a leaked async-token
cookie granted access to the user's async event channel indefinitely.

Add iat/exp to the token (lifetime configurable via the new
GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS, default 1h) and refresh it transparently:
the after_request handler now reissues the cookie whenever the token is
missing, legacy (no exp), expired, or past the first half of its lifetime,
reusing the existing channel so an active session is never disrupted. The
cookie also gets a matching max_age. PyJWT validates exp on decode, so an
expired (e.g. leaked) token is rejected by parse_channel_id_from_request and by
the websocket server.

DRAFT: behavior-sensitive (session/cookie lifetime). Needs validation that a
long-idle session (idle beyond the token lifetime) refreshes cleanly and does
not surface a transient async-query failure on the first request after expiry.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-03 08:05:51 -07:00
3 changed files with 125 additions and 10 deletions

View File

@@ -18,6 +18,7 @@ from __future__ import annotations
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Literal, Optional
import jwt
@@ -112,6 +113,7 @@ class AsyncQueryManager:
self._jwt_cookie_domain: Optional[str]
self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None
self._jwt_secret: str
self._jwt_exp_seconds: int = 3600
self._load_chart_data_into_cache_job: Any = None
# pylint: disable=invalid-name
self._load_explore_json_into_cache_job: Any = None
@@ -147,6 +149,7 @@ class AsyncQueryManager:
]
self._jwt_cookie_domain = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"]
self._jwt_secret = app.config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]
self._jwt_exp_seconds = app.config["GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS"]
if app.config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]:
self.register_request_handlers(app)
@@ -160,28 +163,62 @@ class AsyncQueryManager:
self._load_chart_data_into_cache_job = load_chart_data_into_cache
self._load_explore_json_into_cache_job = load_explore_json_into_cache
def generate_jwt(self, payload: dict[str, Any]) -> str:
now = datetime.now(tz=timezone.utc)
claims = {
**payload,
"iat": now,
"exp": now + timedelta(seconds=self._jwt_exp_seconds),
}
return jwt.encode(claims, self._jwt_secret, algorithm="HS256")
def _jwt_needs_refresh(self, token: Optional[str]) -> bool:
"""
Return True when the async-queries cookie should be (re)issued.
The token is refreshed proactively, once it is past the first half of
its lifetime, so that an active session keeps a valid token while a
leaked token (which is not being refreshed by the user's session) still
expires within ``GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS``. Missing,
malformed, expired, or legacy (no ``exp``) tokens are also refreshed.
"""
if not token:
return True
try:
claims = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
except jwt.PyJWTError:
return True
exp = claims.get("exp")
if not exp:
return True
seconds_remaining = exp - datetime.now(tz=timezone.utc).timestamp()
return seconds_remaining < self._jwt_exp_seconds / 2
def register_request_handlers(self, app: Flask) -> None:
@app.after_request
def validate_session(response: Response) -> Response:
user_id = get_user_id()
reset_token = (
not request.cookies.get(self._jwt_cookie_name)
or "async_channel_id" not in session
# A new channel is only needed when there isn't one for this user
# yet (or the user changed). The token, however, is refreshed
# whenever it is missing or close to expiring — reusing the existing
# channel so an expiring token does not disrupt the session.
reset_channel = (
"async_channel_id" not in session
or "async_user_id" not in session
or user_id != session["async_user_id"]
)
if reset_token:
async_channel_id = str(uuid.uuid4())
session["async_channel_id"] = async_channel_id
if reset_channel:
session["async_channel_id"] = str(uuid.uuid4())
session["async_user_id"] = user_id
if reset_channel or self._jwt_needs_refresh(
request.cookies.get(self._jwt_cookie_name)
):
sub = str(user_id) if user_id else None
token = jwt.encode(
{"channel": async_channel_id, "sub": sub},
self._jwt_secret,
algorithm="HS256",
token = self.generate_jwt(
{"channel": session["async_channel_id"], "sub": sub}
)
response.set_cookie(
@@ -191,6 +228,7 @@ class AsyncQueryManager:
secure=self._jwt_cookie_secure,
domain=self._jwt_cookie_domain,
samesite=self._jwt_cookie_samesite,
max_age=self._jwt_exp_seconds,
)
return response

View File

@@ -2332,6 +2332,11 @@ GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (Literal["None", "Lax", "Strict
)
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me" # noqa: S105
# Lifetime (in seconds) of the async-queries JWT. The token carries an `exp`
# claim and is transparently refreshed while a session is active (see
# AsyncQueryManager.register_request_handlers), so a leaked token only grants
# access to the async event channel for at most this long.
GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS = int(timedelta(hours=1).total_seconds())
GLOBAL_ASYNC_QUERIES_TRANSPORT: Literal["polling", "ws"] = "polling"
GLOBAL_ASYNC_QUERIES_POLLING_DELAY = int(
timedelta(milliseconds=500).total_seconds() * 1000

View File

@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime, timedelta, timezone
from unittest import mock
from unittest.mock import ANY, Mock
@@ -78,6 +79,77 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
async_query_manager.parse_channel_id_from_request(request)
def test_generate_jwt_sets_expiration(async_query_manager):
"""Generated tokens carry iat/exp and round-trip through parsing."""
import jwt
token = async_query_manager.generate_jwt({"channel": "abc", "sub": "1"})
claims = jwt.decode(token, JWT_TOKEN_SECRET, algorithms=["HS256"])
assert claims["channel"] == "abc"
assert "exp" in claims
assert "iat" in claims
assert claims["exp"] > claims["iat"]
assert (
async_query_manager.parse_channel_id_from_request(
Mock(cookies={"superset_async_jwt": token})
)
== "abc"
)
def test_parse_channel_id_rejects_expired_token(async_query_manager):
"""An expired token is rejected (PyJWT validates exp on decode)."""
past = datetime.now(tz=timezone.utc) - timedelta(hours=2)
expired = encode(
{
"channel": "abc",
"iat": past,
"exp": past + timedelta(hours=1),
},
JWT_TOKEN_SECRET,
algorithm="HS256",
)
with raises(AsyncQueryTokenException):
async_query_manager.parse_channel_id_from_request(
Mock(cookies={"superset_async_jwt": expired})
)
def test_jwt_needs_refresh(async_query_manager):
"""Refresh missing/legacy/expired/near-expiry tokens; keep fresh ones."""
now = datetime.now(tz=timezone.utc)
# Missing token
assert async_query_manager._jwt_needs_refresh(None) is True
# Legacy token without exp
legacy = encode({"channel": "abc"}, JWT_TOKEN_SECRET, algorithm="HS256")
assert async_query_manager._jwt_needs_refresh(legacy) is True
# Fresh token (full lifetime remaining) is not refreshed
fresh = async_query_manager.generate_jwt({"channel": "abc"})
assert async_query_manager._jwt_needs_refresh(fresh) is False
# Token in the second half of its lifetime is refreshed
near_expiry = encode(
{"channel": "abc", "iat": now, "exp": now + timedelta(minutes=10)},
JWT_TOKEN_SECRET,
algorithm="HS256",
)
assert async_query_manager._jwt_needs_refresh(near_expiry) is True
# Already-expired token is refreshed (decode raises ExpiredSignatureError)
past = now - timedelta(hours=2)
expired = encode(
{"channel": "abc", "iat": past, "exp": past + timedelta(hours=1)},
JWT_TOKEN_SECRET,
algorithm="HS256",
)
assert async_query_manager._jwt_needs_refresh(expired) is True
@mark.parametrize(
"cache_type, cache_backend",
[