mirror of
https://github.com/apache/superset.git
synced 2026-04-18 07:35:09 +00:00
fix(mcp): always push fresh app context per tool call to prevent g.user race (#39385)
This commit is contained in:
@@ -535,18 +535,36 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
||||
import inspect
|
||||
import types
|
||||
|
||||
from flask import has_app_context
|
||||
|
||||
from superset.mcp_service.flask_singleton import get_flask_app
|
||||
from flask import current_app, has_app_context, has_request_context
|
||||
|
||||
def _get_app_context_manager() -> AbstractContextManager[None]:
    """Push a fresh app context unless a request context is active.

    When a request context is present, external middleware (e.g.
    Preset's WorkspaceContextMiddleware) has already set ``g.user``
    on a per-request app context — reuse it via ``nullcontext()``.

    When only a bare app context exists (no request context), we must
    push a **new** app context. The MCP server typically runs inside
    a long-lived app context (e.g. ``__main__.py`` wraps
    ``mcp.run()`` in ``app.app_context()``). When FastMCP dispatches
    concurrent tool calls via ``asyncio.create_task()``, each task
    inherits the parent's ``ContextVar`` *value* — a reference to the
    **same** ``AppContext`` object. Without a fresh push, all tasks
    share one ``g`` namespace and concurrent ``g.user`` mutations
    race: one user's identity can overwrite another's before
    ``get_user_id()`` runs during the SQLAlchemy INSERT flush,
    attributing the created asset to the wrong user.

    Returns:
        ``contextlib.nullcontext()`` when a request context is active,
        otherwise a new (not yet entered) app context for the current
        app, or for the ``get_flask_app()`` singleton when no app
        context exists at all.
    """
    if has_request_context():
        # The request-scoped context already carries the right g.user.
        return contextlib.nullcontext()

    if has_app_context():
        # Push a new context for the CURRENT app (not get_flask_app()
        # which may return a different instance in test environments).
        return current_app._get_current_object().app_context()

    # No context of any kind: fall back to the singleton app, imported
    # at call time only on this path.
    from superset.mcp_service.flask_singleton import get_flask_app

    return get_flask_app().app_context()
|
||||
|
||||
is_async = inspect.iscoroutinefunction(tool_func)
|
||||
|
||||
|
||||
164
tests/unit_tests/mcp_service/test_g_user_race_condition.py
Normal file
164
tests/unit_tests/mcp_service/test_g_user_race_condition.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Tests for g.user isolation across concurrent MCP tool calls.
|
||||
|
||||
The MCP server pushes a parent app_context at startup (__main__.py).
|
||||
asyncio.create_task() copies the ContextVar VALUE — a reference to the
|
||||
SAME AppContext object. Without pushing a fresh app_context() per tool
|
||||
call, concurrent tasks share one g namespace and g.user mutations race.
|
||||
|
||||
These tests verify that:
|
||||
- Always pushing a new app_context() per task isolates g.user (SAFE)
|
||||
- Reusing a shared parent context via nullcontext() causes races (UNSAFE)
|
||||
- When a request context is active, nullcontext() is safe (middleware path)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
|
||||
def _make_user(user_id: int, username: str) -> SimpleNamespace:
    """Build a lightweight stand-in user exposing ``id`` and ``username``."""
    attributes = {"id": user_id, "username": username}
    return SimpleNamespace(**attributes)
|
||||
|
||||
|
||||
# Two distinct identities; the tests below use them to detect one task's
# g.user leaking into another task.
ALICE = _make_user(user_id=1, username="alice")
BOB = _make_user(user_id=2, username="bob")
|
||||
|
||||
|
||||
def _get_user_id() -> int | None:
    """Return ``g.user.id``, or ``None`` if ``g.user`` or its ``id`` is absent.

    Mirrors superset.utils.core.get_user_id.
    """
    current_user = getattr(g, "user", None)
    return getattr(current_user, "id", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fresh_app_context_per_task_isolates_g_user():
    """
    Every simulated tool call enters its own app_context(), so the
    g.user it writes is the one it reads back after yielding.

    This mirrors the app-context-only path (no request context active),
    where a fresh app context is pushed per tool call.
    """
    flask_app = Flask(__name__)
    expected = {"alice": ALICE, "bob": BOB}

    async def simulate_tool_call(identity, observed, slot):
        task_ctx = flask_app.app_context()
        with task_ctx:
            g.user = identity
            # Give the other task a chance to run and set its own g.user.
            await asyncio.sleep(0)
            observed[slot] = _get_user_id()

    # Long-lived parent context, as when the server runs inside one.
    with flask_app.app_context():
        for _round in range(200):
            observed: dict[str, int | None] = {}
            pending = [
                simulate_tool_call(identity, observed, slot)
                for slot, identity in expected.items()
            ]
            await asyncio.gather(*pending)
            for slot, identity in expected.items():
                assert observed[slot] == identity.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_nullcontext_shared_context_causes_race():
    """
    Neither task pushes its own app context; both run under the parent's
    (the nullcontext path), so they write to the same g.user and the
    later write wins.

    Two asyncio.Event objects force the order: Alice writes, Bob
    overwrites, then Alice reads.
    """
    flask_app = Flask(__name__)

    alice_wrote = asyncio.Event()
    bob_overwrote = asyncio.Event()
    observed: dict[str, int | None] = {}

    async def run_as_alice():
        g.user = ALICE
        alice_wrote.set()  # let Bob proceed
        await bob_overwrote.wait()  # resume only after Bob's write
        observed["alice"] = _get_user_id()

    async def run_as_bob():
        await alice_wrote.wait()  # Alice must write first
        g.user = BOB  # clobbers the shared g.user
        bob_overwrote.set()  # let Alice read
        observed["bob"] = _get_user_id()

    with flask_app.app_context():
        await asyncio.gather(run_as_alice(), run_as_bob())

        # Shared g: Alice's read returns Bob's identity.
        alice_seen = observed["alice"]
        assert alice_seen == BOB.id, "Expected Alice to see Bob's ID due to shared g"
        bob_seen = observed["bob"]
        assert bob_seen == BOB.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_request_context_preserves_g_user():
    """
    Models the middleware path: each task enters its own app context
    and test_request_context before setting g.user, and reads back its
    own identity after yielding to the other task.
    """
    flask_app = Flask(__name__)
    expected = {"alice": ALICE, "bob": BOB}

    async def simulate_tool_call(identity, observed, slot):
        with flask_app.app_context(), flask_app.test_request_context(path="/mcp"):
            g.user = identity
            await asyncio.sleep(0)
            observed[slot] = _get_user_id()

    with flask_app.app_context():
        for _round in range(200):
            observed: dict[str, int | None] = {}
            pending = [
                simulate_tool_call(identity, observed, slot)
                for slot, identity in expected.items()
            ]
            await asyncio.gather(*pending)
            for slot, identity in expected.items():
                assert observed[slot] == identity.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_high_contention_isolation():
    """Stress test: 10 concurrent users per round, 50 rounds."""
    flask_app = Flask(__name__)
    identities = []
    for index in range(10):
        identities.append(_make_user(index, f"user_{index}"))

    async def simulate_tool_call(identity, observed, slot):
        with flask_app.app_context():
            g.user = identity
            # Yield twice so the other tasks interleave between write and read.
            for _yield in range(2):
                await asyncio.sleep(0)
            observed[slot] = _get_user_id()

    with flask_app.app_context():
        for _round in range(50):
            observed: dict[str, int | None] = {}
            pending = [
                simulate_tool_call(identity, observed, identity.username)
                for identity in identities
            ]
            await asyncio.gather(*pending)
            for identity in identities:
                assert observed[identity.username] == identity.id
|
||||
Reference in New Issue
Block a user