mirror of
https://github.com/apache/superset.git
synced 2026-04-12 12:47:53 +00:00
430 lines
15 KiB
Python
430 lines
15 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
"""Tests for MCP user resolution priority and stale g.user prevention."""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from flask import g
|
|
|
|
from superset.mcp_service.auth import (
|
|
_resolve_user_from_jwt_context,
|
|
get_user_from_request,
|
|
mcp_auth_hook,
|
|
)
|
|
from superset.mcp_service.mcp_config import default_user_resolver
|
|
|
|
|
|
def _make_mock_user(username: str = "testuser") -> MagicMock:
|
|
"""Create a mock User with required attributes."""
|
|
user = MagicMock()
|
|
user.username = username
|
|
user.roles = []
|
|
user.groups = []
|
|
return user
|
|
|
|
|
|
def _make_access_token(
|
|
claims: dict[str, str] | None = None, **kwargs: str
|
|
) -> MagicMock:
|
|
"""Create a mock AccessToken matching FastMCP's format."""
|
|
token = MagicMock()
|
|
token.claims = claims or {}
|
|
token.client_id = kwargs.get("client_id", "")
|
|
token.scopes = kwargs.get("scopes", [])
|
|
# Remove auto-created attributes so getattr fallbacks work correctly
|
|
for attr in ("subject", "payload"):
|
|
if attr not in kwargs:
|
|
delattr(token, attr)
|
|
for attr in kwargs:
|
|
setattr(token, attr, kwargs[attr])
|
|
return token
|
|
|
|
|
|
# -- _resolve_user_from_jwt_context --
|
|
|
|
|
|
def test_jwt_context_resolves_correct_user(app) -> None:
    """A token whose ``sub`` claim names a loadable user resolves to that user."""
    expected_user = _make_mock_user("alice")
    access_token = _make_access_token(claims={"sub": "alice"})

    # Patchers are built up front; nothing is patched until the ``with``.
    token_patch = patch(
        "fastmcp.server.dependencies.get_access_token", return_value=access_token
    )
    loader_patch = patch(
        "superset.mcp_service.auth.load_user_with_relationships",
        return_value=expected_user,
    )

    with app.app_context(), token_patch, loader_patch:
        resolved = _resolve_user_from_jwt_context(app)

    assert resolved is not None
    assert resolved.username == "alice"
|
|
|
|
|
|
def test_jwt_context_returns_none_when_no_token(app) -> None:
    """When ``get_access_token`` yields ``None`` the resolver returns ``None``."""
    no_token_patch = patch(
        "fastmcp.server.dependencies.get_access_token", return_value=None
    )

    with app.app_context(), no_token_patch:
        resolved = _resolve_user_from_jwt_context(app)

    assert resolved is None
|
|
|
|
|
|
def test_jwt_context_raises_for_unknown_user(app) -> None:
    """A token naming a user the loader cannot find raises ``ValueError``."""
    access_token = _make_access_token(claims={"sub": "nonexistent"})

    token_patch = patch(
        "fastmcp.server.dependencies.get_access_token", return_value=access_token
    )
    # The loader reports "no such user" for every lookup.
    missing_user_patch = patch(
        "superset.mcp_service.auth.load_user_with_relationships",
        return_value=None,
    )

    with app.app_context(), token_patch, missing_user_patch:
        with pytest.raises(ValueError, match="not found in Superset database"):
            _resolve_user_from_jwt_context(app)
|
|
|
|
|
|
def test_jwt_context_raises_when_no_username_in_claims(app) -> None:
    """A token carrying only an ``iss`` claim raises ``ValueError``."""
    access_token = _make_access_token(claims={"iss": "some-issuer"})
    token_patch = patch(
        "fastmcp.server.dependencies.get_access_token", return_value=access_token
    )

    with app.app_context(), token_patch:
        with pytest.raises(ValueError, match="no username could be extracted"):
            _resolve_user_from_jwt_context(app)
|
|
|
|
|
|
def test_jwt_context_uses_custom_resolver(app) -> None:
    """A callable set as ``MCP_USER_RESOLVER`` is invoked with ``(app, token)``.

    The resolver is installed through ``patch.dict`` so ``app.config`` is
    restored to its exact prior contents on exit. Setting the key by hand and
    popping it afterwards would instead delete any value that was already
    configured before the test ran.
    """
    mock_user = _make_mock_user("custom_user")
    token = _make_access_token(claims={"custom_field": "custom_user"})
    custom_resolver = MagicMock(return_value="custom_user")

    with (
        app.app_context(),
        patch.dict(app.config, {"MCP_USER_RESOLVER": custom_resolver}),
        patch("fastmcp.server.dependencies.get_access_token", return_value=token),
        patch(
            "superset.mcp_service.auth.load_user_with_relationships",
            return_value=mock_user,
        ),
    ):
        result = _resolve_user_from_jwt_context(app)

    assert result is not None
    assert result.username == "custom_user"
    custom_resolver.assert_called_once_with(app, token)
|
|
|
|
|
|
def test_jwt_context_email_fallback_lookup(app) -> None:
    """A user that is only loadable by email is still resolved from the token.

    The stand-in loader returns the user solely when called with the matching
    ``email`` keyword and ``None`` for every other call.
    """
    alice = _make_mock_user("alice")
    access_token = _make_access_token(claims={"email": "alice@example.com"})

    def _load_side_effect(username=None, email=None):
        """Return ``alice`` for the email lookup, ``None`` otherwise."""
        return alice if email == "alice@example.com" else None

    token_patch = patch(
        "fastmcp.server.dependencies.get_access_token", return_value=access_token
    )
    loader_patch = patch(
        "superset.mcp_service.auth.load_user_with_relationships",
        side_effect=_load_side_effect,
    )

    with app.app_context(), token_patch, loader_patch:
        resolved = _resolve_user_from_jwt_context(app)

    assert resolved is not None
    assert resolved.username == "alice"
|
|
|
|
|
|
# -- get_user_from_request priority order --
|
|
|
|
|
|
def test_jwt_takes_priority_over_stale_g_user(app) -> None:
    """Regression: the token's user is returned even if ``g.user`` is already set."""
    leftover_user = _make_mock_user("stale_bob")
    token_user = _make_mock_user("jwt_alice")
    access_token = _make_access_token(claims={"sub": "jwt_alice"})

    token_patch = patch(
        "fastmcp.server.dependencies.get_access_token", return_value=access_token
    )
    loader_patch = patch(
        "superset.mcp_service.auth.load_user_with_relationships",
        return_value=token_user,
    )

    with app.app_context(), token_patch, loader_patch:
        # Seed g with a different user to prove it does not win.
        g.user = leftover_user
        resolved = get_user_from_request()

    assert resolved.username == "jwt_alice"
|
|
|
|
|
|
def test_dev_username_fallback_when_no_jwt(app) -> None:
    """With no access token, the ``MCP_DEV_USERNAME`` user is returned.

    ``patch.dict`` installs the config key and restores ``app.config`` on
    exit, so a value configured before the test is put back rather than
    deleted.
    """
    mock_user = _make_mock_user("dev_admin")

    with (
        app.app_context(),
        patch.dict(app.config, {"MCP_DEV_USERNAME": "dev_admin"}),
        patch("fastmcp.server.dependencies.get_access_token", return_value=None),
        patch(
            "superset.mcp_service.auth.load_user_with_relationships",
            return_value=mock_user,
        ),
    ):
        result = get_user_from_request()

    assert result.username == "dev_admin"
|
|
|
|
|
|
def test_g_user_fallback_when_no_jwt_and_no_dev_username(app) -> None:
    """With no token and no ``MCP_DEV_USERNAME``, the user on ``g`` is returned.

    ``patch.dict(app.config)`` snapshots the config and restores it on exit,
    so removing ``MCP_DEV_USERNAME`` here does not leak into other tests.
    """
    preset_user = _make_mock_user("preset_user")

    with (
        app.app_context(),
        patch.dict(app.config),
        patch("fastmcp.server.dependencies.get_access_token", return_value=None),
    ):
        app.config.pop("MCP_DEV_USERNAME", None)
        g.user = preset_user
        result = get_user_from_request()

    assert result.username == "preset_user"
|
|
|
|
|
|
def test_raises_when_no_auth_source(app) -> None:
    """``ValueError`` is raised with no token, no dev username and no ``g.user``.

    ``patch.dict(app.config)`` snapshots the config and restores it on exit,
    so removing ``MCP_DEV_USERNAME`` here does not leak into other tests.
    """
    with (
        app.app_context(),
        patch.dict(app.config),
        patch("fastmcp.server.dependencies.get_access_token", return_value=None),
    ):
        app.config.pop("MCP_DEV_USERNAME", None)
        g.pop("user", None)
        with pytest.raises(ValueError, match="No authenticated user found"):
            get_user_from_request()
|
|
|
|
|
|
def test_dev_username_not_found_raises(app) -> None:
    """``MCP_DEV_USERNAME`` naming a user the loader cannot find raises ``ValueError``.

    ``patch.dict`` installs the config key and restores ``app.config`` on
    exit, so a value configured before the test is put back rather than
    deleted.
    """
    with (
        app.app_context(),
        patch.dict(app.config, {"MCP_DEV_USERNAME": "ghost"}),
        patch("fastmcp.server.dependencies.get_access_token", return_value=None),
        patch(
            "superset.mcp_service.auth.load_user_with_relationships",
            return_value=None,
        ),
    ):
        with pytest.raises(ValueError, match="not found"):
            get_user_from_request()
|
|
|
|
|
|
# -- g.user clearing in mcp_auth_hook --
|
|
|
|
|
|
def test_mcp_auth_hook_clears_stale_g_user(app) -> None:
    """mcp_auth_hook clears g.user before setting up user context.

    The patched ``get_user_from_request`` asserts that ``g.user`` is already
    absent (or ``None``) at the moment user resolution runs, so the test
    fails if the hook stops clearing it first.
    """
    stale_user = _make_mock_user("stale")
    fresh_user = _make_mock_user("fresh")

    def dummy_tool():
        """Dummy tool."""
        return g.user.username

    wrapped = mcp_auth_hook(dummy_tool)

    def _assert_cleared_then_return():
        """Verify stale g.user was cleared before returning fresh user."""
        assert not hasattr(g, "user") or g.user is None, (
            "g.user should have been cleared before get_user_from_request() "
            f"but found g.user={getattr(g, 'user', '<missing>')}"
        )
        return fresh_user

    with app.app_context():
        g.user = stale_user
        # Explicitly mock has_request_context to False because the test
        # framework's autouse app_context fixture may implicitly provide
        # a request context in some CI environments.
        with (
            patch("flask.has_request_context", return_value=False),
            patch(
                "superset.mcp_service.auth.get_user_from_request",
                # Pass the checker directly; a lambda wrapper adds nothing.
                side_effect=_assert_cleared_then_return,
            ),
        ):
            result = wrapped()

    assert result == "fresh"
|
|
|
|
|
|
def test_mcp_auth_hook_clears_stale_g_user_async(app) -> None:
    """mcp_auth_hook clears g.user before setting up user context (async).

    The patched ``get_user_from_request`` asserts that ``g.user`` is already
    absent (or ``None``) at the moment user resolution runs, so the test
    fails if the hook stops clearing it first. The wrapped coroutine is
    driven to completion with ``asyncio.run``.
    """
    import asyncio

    stale_user = _make_mock_user("stale")
    fresh_user = _make_mock_user("fresh")

    async def dummy_tool():
        """Dummy tool."""
        return g.user.username

    wrapped = mcp_auth_hook(dummy_tool)

    def _assert_cleared_then_return():
        """Verify stale g.user was cleared before returning fresh user."""
        assert not hasattr(g, "user") or g.user is None, (
            "g.user should have been cleared before get_user_from_request() "
            f"but found g.user={getattr(g, 'user', '<missing>')}"
        )
        return fresh_user

    with app.app_context():
        g.user = stale_user
        # Make flask.has_request_context report False for the whole call.
        with (
            patch("flask.has_request_context", return_value=False),
            patch(
                "superset.mcp_service.auth.get_user_from_request",
                # Pass the checker directly; a lambda wrapper adds nothing.
                side_effect=_assert_cleared_then_return,
            ),
        ):
            result = asyncio.run(wrapped())

    assert result == "fresh"
|
|
|
|
|
|
def test_mcp_auth_hook_preserves_g_user_in_request_context(app) -> None:
    """g.user is NOT cleared when a request context is active (middleware compat).

    The patched ``get_user_from_request`` asserts that ``g.user`` is still
    the very object set before the call, proving the hook did not clear or
    replace it while running inside ``app.test_request_context()``.
    """
    middleware_user = _make_mock_user("middleware_user")

    def dummy_tool():
        """Dummy tool."""
        return g.user.username

    wrapped = mcp_auth_hook(dummy_tool)

    def _assert_preserved_then_return():
        """Verify g.user was preserved (not cleared) before returning."""
        assert hasattr(g, "user"), (
            "g.user should be preserved in request context but was removed"
        )
        assert g.user is middleware_user, (
            "g.user should be preserved in request context but was changed; "
            f"g.user={g.user}"
        )
        return middleware_user

    with app.test_request_context():
        g.user = middleware_user
        with patch(
            "superset.mcp_service.auth.get_user_from_request",
            # Pass the checker directly; a lambda wrapper adds nothing.
            side_effect=_assert_preserved_then_return,
        ):
            result = wrapped()

    assert result == "middleware_user"
|
|
|
|
|
|
# -- default_user_resolver --
|
|
|
|
|
|
def test_default_resolver_extracts_sub_from_claims() -> None:
    """A token whose only claim is ``sub`` resolves to that value."""
    access_token = _make_access_token(claims={"sub": "alice"})
    resolved = default_user_resolver(None, access_token)
    assert resolved == "alice"
|
|
|
|
|
|
def test_default_resolver_extracts_preferred_username() -> None:
    """A token whose only claim is ``preferred_username`` resolves to it."""
    access_token = _make_access_token(claims={"preferred_username": "alice"})
    resolved = default_user_resolver(None, access_token)
    assert resolved == "alice"
|
|
|
|
|
|
def test_default_resolver_extracts_email_from_claims() -> None:
    """A token whose only claim is ``email`` resolves to that address."""
    access_token = _make_access_token(claims={"email": "alice@example.com"})
    resolved = default_user_resolver(None, access_token)
    assert resolved == "alice@example.com"
|
|
|
|
|
|
def test_default_resolver_extracts_username_from_claims() -> None:
    """A token whose only claim is ``username`` resolves to that value."""
    access_token = _make_access_token(claims={"username": "alice"})
    resolved = default_user_resolver(None, access_token)
    assert resolved == "alice"
|
|
|
|
|
|
def test_default_resolver_falls_back_to_subject_attr() -> None:
    """With empty claims, the token's ``subject`` attribute is returned."""
    access_token = _make_access_token(claims={}, subject="legacy_user")
    resolved = default_user_resolver(None, access_token)
    assert resolved == "legacy_user"
|
|
|
|
|
|
def test_default_resolver_falls_back_to_client_id() -> None:
    """With empty claims and no ``subject``, ``client_id`` is returned."""
    access_token = _make_access_token(claims={}, client_id="service-account")
    resolved = default_user_resolver(None, access_token)
    assert resolved == "service-account"
|
|
|
|
|
|
def test_default_resolver_returns_none_for_empty_token() -> None:
    """Empty claims plus an empty ``client_id`` resolve to ``None``."""
    access_token = _make_access_token(claims={}, client_id="")
    resolved = default_user_resolver(None, access_token)
    assert resolved is None
|
|
|
|
|
|
def test_default_resolver_preferred_username_takes_priority() -> None:
    """``preferred_username`` wins when ``sub`` and ``email`` are also present."""
    all_claims = {
        "sub": "opaque-id-123",
        "preferred_username": "alice",
        "email": "alice@example.com",
    }
    access_token = _make_access_token(claims=all_claims)
    resolved = default_user_resolver(None, access_token)
    assert resolved == "alice"
|