diff --git a/superset/mcp_service/composite_token_verifier.py b/superset/mcp_service/composite_token_verifier.py new file mode 100644 index 00000000000..e4f5c6a5c43 --- /dev/null +++ b/superset/mcp_service/composite_token_verifier.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Composite token verifier for MCP authentication. + +Routes Bearer tokens to the appropriate verifier based on prefix: +- Tokens matching FAB_API_KEY_PREFIXES (e.g. ``sst_``) are passed through + to the Flask layer where ``_resolve_user_from_api_key()`` handles + actual validation via FAB SecurityManager. +- All other tokens are delegated to the wrapped JWT verifier. +""" + +import logging + +from fastmcp.server.auth import AccessToken +from fastmcp.server.auth.providers.jwt import TokenVerifier + +logger = logging.getLogger(__name__) + + +class CompositeTokenVerifier(TokenVerifier): + """Routes Bearer tokens between API key pass-through and JWT verification. + + API key tokens (identified by prefix) are accepted at the transport layer + with a marker claim so that ``_resolve_user_from_jwt_context()`` can + detect them and fall through to ``_resolve_user_from_api_key()`` for + actual validation. + + Args: + jwt_verifier: The wrapped JWT verifier for non-API-key tokens. + api_key_prefixes: List of prefixes that identify API key tokens + (e.g. ``["sst_"]``). + """ + + def __init__( + self, + jwt_verifier: TokenVerifier, + api_key_prefixes: list[str], + ) -> None: + super().__init__( + base_url=getattr(jwt_verifier, "base_url", None), + required_scopes=jwt_verifier.required_scopes, + ) + self._jwt_verifier = jwt_verifier + self._api_key_prefixes = tuple(api_key_prefixes) + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify a Bearer token. + + If the token starts with an API key prefix, return a pass-through + AccessToken with a ``_api_key_passthrough`` claim. The Flask-layer + ``_resolve_user_from_api_key()`` performs the real validation. + + Otherwise, delegate to the wrapped JWT verifier. + """ + if any(token.startswith(prefix) for prefix in self._api_key_prefixes): + logger.debug("API key token detected (prefix match), passing through") + return AccessToken( + token=token, + client_id="api_key", + scopes=[], + claims={"_api_key_passthrough": True}, + ) + + return await self._jwt_verifier.verify_token(token) diff --git a/tests/unit_tests/mcp_service/test_composite_token_verifier.py b/tests/unit_tests/mcp_service/test_composite_token_verifier.py new file mode 100644 index 00000000000..537fe4d8f07 --- /dev/null +++ b/tests/unit_tests/mcp_service/test_composite_token_verifier.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for CompositeTokenVerifier.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastmcp.server.auth import AccessToken + +from superset.mcp_service.composite_token_verifier import CompositeTokenVerifier + + +@pytest.fixture +def mock_jwt_verifier(): + verifier = MagicMock() + verifier.required_scopes = [] + verifier.verify_token = AsyncMock() + return verifier + + +@pytest.fixture +def composite_verifier(mock_jwt_verifier): + return CompositeTokenVerifier( + jwt_verifier=mock_jwt_verifier, + api_key_prefixes=["sst_", "pat_"], + ) + + +@pytest.mark.asyncio +async def test_api_key_token_returns_passthrough(composite_verifier) -> None: + """Tokens matching an API key prefix return a pass-through AccessToken.""" + api_key = "sst_abc123secret" # noqa: S105 + result = await composite_verifier.verify_token(api_key) + + assert result is not None + assert result.token == api_key + assert result.client_id == "api_key" + assert result.claims.get("_api_key_passthrough") is True + + +@pytest.mark.asyncio +async def test_second_prefix_matches(composite_verifier) -> None: + """All configured prefixes are checked, not just the first.""" + result = await composite_verifier.verify_token("pat_mytoken") + + assert result is not None + assert result.claims.get("_api_key_passthrough") is True + + +@pytest.mark.asyncio +async def test_jwt_token_delegates_to_wrapped_verifier( + composite_verifier, mock_jwt_verifier +) -> None: + """Non-API-key tokens are delegated to the wrapped JWT verifier.""" + jwt_token = "eyJhbGciOiJSUzI1NiJ9.jwt_payload" # noqa: S105 + jwt_result = AccessToken( + token=jwt_token, + client_id="oauth_client", + scopes=["read"], + claims={"sub": "user1"}, + ) + mock_jwt_verifier.verify_token.return_value = jwt_result + + result = await composite_verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload") + + assert result is jwt_result + mock_jwt_verifier.verify_token.assert_awaited_once_with( + "eyJhbGciOiJSUzI1NiJ9.jwt_payload" + ) + + +@pytest.mark.asyncio +async def test_invalid_jwt_returns_none(composite_verifier, mock_jwt_verifier) -> None: + """When the JWT verifier rejects a token, None is returned.""" + mock_jwt_verifier.verify_token.return_value = None + + result = await composite_verifier.verify_token("not_a_valid_token") + + assert result is None + mock_jwt_verifier.verify_token.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_api_key_does_not_call_jwt_verifier( + composite_verifier, mock_jwt_verifier +) -> None: + """API key tokens bypass the JWT verifier entirely.""" + await composite_verifier.verify_token("sst_test_key") + + mock_jwt_verifier.verify_token.assert_not_awaited()