mirror of
https://github.com/apache/superset.git
synced 2026-07-02 12:55:35 +00:00
Compare commits
2 Commits
chore/ci/s
...
fix/async-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a8ef5227d0 | ||
|
|
da3b6d656f |
@@ -18,6 +18,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
@@ -112,6 +113,7 @@ class AsyncQueryManager:
|
|||||||
self._jwt_cookie_domain: Optional[str]
|
self._jwt_cookie_domain: Optional[str]
|
||||||
self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None
|
self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None
|
||||||
self._jwt_secret: str
|
self._jwt_secret: str
|
||||||
|
self._jwt_exp_seconds: int = 3600
|
||||||
self._load_chart_data_into_cache_job: Any = None
|
self._load_chart_data_into_cache_job: Any = None
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
self._load_explore_json_into_cache_job: Any = None
|
self._load_explore_json_into_cache_job: Any = None
|
||||||
@@ -147,6 +149,7 @@ class AsyncQueryManager:
|
|||||||
]
|
]
|
||||||
self._jwt_cookie_domain = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"]
|
self._jwt_cookie_domain = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"]
|
||||||
self._jwt_secret = app.config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]
|
self._jwt_secret = app.config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]
|
||||||
|
self._jwt_exp_seconds = app.config["GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS"]
|
||||||
|
|
||||||
if app.config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]:
|
if app.config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]:
|
||||||
self.register_request_handlers(app)
|
self.register_request_handlers(app)
|
||||||
@@ -160,28 +163,62 @@ class AsyncQueryManager:
|
|||||||
self._load_chart_data_into_cache_job = load_chart_data_into_cache
|
self._load_chart_data_into_cache_job = load_chart_data_into_cache
|
||||||
self._load_explore_json_into_cache_job = load_explore_json_into_cache
|
self._load_explore_json_into_cache_job = load_explore_json_into_cache
|
||||||
|
|
||||||
|
def generate_jwt(self, payload: dict[str, Any]) -> str:
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
claims = {
|
||||||
|
**payload,
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + timedelta(seconds=self._jwt_exp_seconds),
|
||||||
|
}
|
||||||
|
return jwt.encode(claims, self._jwt_secret, algorithm="HS256")
|
||||||
|
|
||||||
|
def _jwt_needs_refresh(self, token: Optional[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Return True when the async-queries cookie should be (re)issued.
|
||||||
|
|
||||||
|
The token is refreshed proactively, once it is past the first half of
|
||||||
|
its lifetime, so that an active session keeps a valid token while a
|
||||||
|
leaked token (which is not being refreshed by the user's session) still
|
||||||
|
expires within ``GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS``. Missing,
|
||||||
|
malformed, expired, or legacy (no ``exp``) tokens are also refreshed.
|
||||||
|
"""
|
||||||
|
if not token:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
claims = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
|
||||||
|
except jwt.PyJWTError:
|
||||||
|
return True
|
||||||
|
exp = claims.get("exp")
|
||||||
|
if not exp:
|
||||||
|
return True
|
||||||
|
seconds_remaining = exp - datetime.now(tz=timezone.utc).timestamp()
|
||||||
|
return seconds_remaining < self._jwt_exp_seconds / 2
|
||||||
|
|
||||||
def register_request_handlers(self, app: Flask) -> None:
|
def register_request_handlers(self, app: Flask) -> None:
|
||||||
@app.after_request
|
@app.after_request
|
||||||
def validate_session(response: Response) -> Response:
|
def validate_session(response: Response) -> Response:
|
||||||
user_id = get_user_id()
|
user_id = get_user_id()
|
||||||
|
|
||||||
reset_token = (
|
# A new channel is only needed when there isn't one for this user
|
||||||
not request.cookies.get(self._jwt_cookie_name)
|
# yet (or the user changed). The token, however, is refreshed
|
||||||
or "async_channel_id" not in session
|
# whenever it is missing or close to expiring — reusing the existing
|
||||||
|
# channel so an expiring token does not disrupt the session.
|
||||||
|
reset_channel = (
|
||||||
|
"async_channel_id" not in session
|
||||||
or "async_user_id" not in session
|
or "async_user_id" not in session
|
||||||
or user_id != session["async_user_id"]
|
or user_id != session["async_user_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if reset_token:
|
if reset_channel:
|
||||||
async_channel_id = str(uuid.uuid4())
|
session["async_channel_id"] = str(uuid.uuid4())
|
||||||
session["async_channel_id"] = async_channel_id
|
|
||||||
session["async_user_id"] = user_id
|
session["async_user_id"] = user_id
|
||||||
|
|
||||||
|
if reset_channel or self._jwt_needs_refresh(
|
||||||
|
request.cookies.get(self._jwt_cookie_name)
|
||||||
|
):
|
||||||
sub = str(user_id) if user_id else None
|
sub = str(user_id) if user_id else None
|
||||||
token = jwt.encode(
|
token = self.generate_jwt(
|
||||||
{"channel": async_channel_id, "sub": sub},
|
{"channel": session["async_channel_id"], "sub": sub}
|
||||||
self._jwt_secret,
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
@@ -191,6 +228,7 @@ class AsyncQueryManager:
|
|||||||
secure=self._jwt_cookie_secure,
|
secure=self._jwt_cookie_secure,
|
||||||
domain=self._jwt_cookie_domain,
|
domain=self._jwt_cookie_domain,
|
||||||
samesite=self._jwt_cookie_samesite,
|
samesite=self._jwt_cookie_samesite,
|
||||||
|
max_age=self._jwt_exp_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -2332,6 +2332,11 @@ GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (Literal["None", "Lax", "Strict
|
|||||||
)
|
)
|
||||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
|
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
|
||||||
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me" # noqa: S105
|
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me" # noqa: S105
|
||||||
|
# Lifetime (in seconds) of the async-queries JWT. The token carries an `exp`
|
||||||
|
# claim and is transparently refreshed while a session is active (see
|
||||||
|
# AsyncQueryManager.register_request_handlers), so a leaked token only grants
|
||||||
|
# access to the async event channel for at most this long.
|
||||||
|
GLOBAL_ASYNC_QUERIES_JWT_EXP_SECONDS = int(timedelta(hours=1).total_seconds())
|
||||||
GLOBAL_ASYNC_QUERIES_TRANSPORT: Literal["polling", "ws"] = "polling"
|
GLOBAL_ASYNC_QUERIES_TRANSPORT: Literal["polling", "ws"] = "polling"
|
||||||
GLOBAL_ASYNC_QUERIES_POLLING_DELAY = int(
|
GLOBAL_ASYNC_QUERIES_POLLING_DELAY = int(
|
||||||
timedelta(milliseconds=500).total_seconds() * 1000
|
timedelta(milliseconds=500).total_seconds() * 1000
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import ANY, Mock
|
from unittest.mock import ANY, Mock
|
||||||
|
|
||||||
@@ -78,6 +79,77 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
|
|||||||
async_query_manager.parse_channel_id_from_request(request)
|
async_query_manager.parse_channel_id_from_request(request)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_jwt_sets_expiration(async_query_manager):
|
||||||
|
"""Generated tokens carry iat/exp and round-trip through parsing."""
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
token = async_query_manager.generate_jwt({"channel": "abc", "sub": "1"})
|
||||||
|
claims = jwt.decode(token, JWT_TOKEN_SECRET, algorithms=["HS256"])
|
||||||
|
|
||||||
|
assert claims["channel"] == "abc"
|
||||||
|
assert "exp" in claims
|
||||||
|
assert "iat" in claims
|
||||||
|
assert claims["exp"] > claims["iat"]
|
||||||
|
assert (
|
||||||
|
async_query_manager.parse_channel_id_from_request(
|
||||||
|
Mock(cookies={"superset_async_jwt": token})
|
||||||
|
)
|
||||||
|
== "abc"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_channel_id_rejects_expired_token(async_query_manager):
|
||||||
|
"""An expired token is rejected (PyJWT validates exp on decode)."""
|
||||||
|
past = datetime.now(tz=timezone.utc) - timedelta(hours=2)
|
||||||
|
expired = encode(
|
||||||
|
{
|
||||||
|
"channel": "abc",
|
||||||
|
"iat": past,
|
||||||
|
"exp": past + timedelta(hours=1),
|
||||||
|
},
|
||||||
|
JWT_TOKEN_SECRET,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
|
||||||
|
with raises(AsyncQueryTokenException):
|
||||||
|
async_query_manager.parse_channel_id_from_request(
|
||||||
|
Mock(cookies={"superset_async_jwt": expired})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_jwt_needs_refresh(async_query_manager):
|
||||||
|
"""Refresh missing/legacy/expired/near-expiry tokens; keep fresh ones."""
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
# Missing token
|
||||||
|
assert async_query_manager._jwt_needs_refresh(None) is True
|
||||||
|
|
||||||
|
# Legacy token without exp
|
||||||
|
legacy = encode({"channel": "abc"}, JWT_TOKEN_SECRET, algorithm="HS256")
|
||||||
|
assert async_query_manager._jwt_needs_refresh(legacy) is True
|
||||||
|
|
||||||
|
# Fresh token (full lifetime remaining) is not refreshed
|
||||||
|
fresh = async_query_manager.generate_jwt({"channel": "abc"})
|
||||||
|
assert async_query_manager._jwt_needs_refresh(fresh) is False
|
||||||
|
|
||||||
|
# Token in the second half of its lifetime is refreshed
|
||||||
|
near_expiry = encode(
|
||||||
|
{"channel": "abc", "iat": now, "exp": now + timedelta(minutes=10)},
|
||||||
|
JWT_TOKEN_SECRET,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
assert async_query_manager._jwt_needs_refresh(near_expiry) is True
|
||||||
|
|
||||||
|
# Already-expired token is refreshed (decode raises ExpiredSignatureError)
|
||||||
|
past = now - timedelta(hours=2)
|
||||||
|
expired = encode(
|
||||||
|
{"channel": "abc", "iat": past, "exp": past + timedelta(hours=1)},
|
||||||
|
JWT_TOKEN_SECRET,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
assert async_query_manager._jwt_needs_refresh(expired) is True
|
||||||
|
|
||||||
|
|
||||||
@mark.parametrize(
|
@mark.parametrize(
|
||||||
"cache_type, cache_backend",
|
"cache_type, cache_backend",
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user