diff --git a/superset/config.py b/superset/config.py index 2637c0032be..09ef4e71a2d 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1316,7 +1316,8 @@ GUEST_TOKEN_JWT_SECRET = "test-guest-secret-change-me" GUEST_TOKEN_JWT_ALGO = "HS256" GUEST_TOKEN_HEADER_NAME = "X-GuestToken" GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes -GUEST_TOKEN_JWT_AUDIENCE = None +# Guest token audience for the embedded superset, either string or callable +GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None # A SQL dataset health check. Note if enabled it is strongly advised that the callable # be memoized to aid with performance, i.e., diff --git a/superset/security/manager.py b/superset/security/manager.py index ac494a18378..91b203e83f7 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1300,6 +1300,13 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods """ This is used so the tests can mock time """ return time.time() + @staticmethod + def _get_guest_token_jwt_audience() -> str: + audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() + if callable(audience): + audience = audience() + return audience + def create_guest_access_token( self, user: GuestTokenUser, @@ -1309,8 +1316,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] exp_seconds = current_app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] - aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() - + audience = self._get_guest_token_jwt_audience() # calculate expiration time now = self._get_current_epoch_time() exp = now + (exp_seconds * 1000) @@ -1321,7 +1327,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # standard jwt claims: "iat": now, # issued at "exp": exp, # expiration time - "aud": aud, + "aud": audience, "type": "guest", } token = jwt.encode(claims, secret, algorithm=algo) @@ -1363,8 +1369,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods token=token, roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], ) - @staticmethod - def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]: + def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]: """ Parses a guest token. Raises an error if the jwt fails standard claims checks. :param raw_token: the token gotten from the request @@ -1372,8 +1377,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods """ secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] - aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() - return jwt.decode(raw_token, secret, algorithms=[algo], audience=aud) + audience = self._get_guest_token_jwt_audience() + return jwt.decode(raw_token, secret, algorithms=[algo], audience=audience) @staticmethod def is_guest_user(user: Optional[Any] = None) -> bool: diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index efcd191ffaf..9dca5ac5137 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1299,3 +1299,25 @@ class TestGuestTokens(SupersetTestCase): self.assertRaisesRegex(jwt.exceptions.InvalidAudienceError, "Invalid audience") self.assertIsNone(guest_user) + + @patch("superset.security.SupersetSecurityManager._get_current_epoch_time") + def test_create_guest_access_token_callable_audience(self, get_time_mock): + now = time.time() + get_time_mock.return_value = now + app.config["GUEST_TOKEN_JWT_AUDIENCE"] = Mock(return_value="cool_code") + + user = {"username": "test_guest"} + resources = [{"some": "resource"}] + rls = [{"dataset": 1, "clause": "access = 1"}] + token = security_manager.create_guest_access_token(user, resources, rls) + + decoded_token = jwt.decode( + token, + self.app.config["GUEST_TOKEN_JWT_SECRET"], + algorithms=[self.app.config["GUEST_TOKEN_JWT_ALGO"]], + audience="cool_code", + ) + app.config["GUEST_TOKEN_JWT_AUDIENCE"].assert_called_once() + self.assertEqual("cool_code", decoded_token["aud"]) + self.assertEqual("guest", decoded_token["type"]) + app.config["GUEST_TOKEN_JWT_AUDIENCE"] = None