fix(mcp): always push fresh app context per tool call to prevent g.user race (#39385)

This commit is contained in:
Amin Ghadersohi
2026-04-15 20:48:21 -04:00
committed by GitHub
parent 838ee870d0
commit e7b9fb277e
2 changed files with 191 additions and 9 deletions

View File

@@ -535,18 +535,36 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
import inspect
import types
from flask import has_app_context
from superset.mcp_service.flask_singleton import get_flask_app
from flask import current_app, has_app_context, has_request_context
def _get_app_context_manager() -> AbstractContextManager[None]:
"""Return app context manager only if not already in one."""
if has_app_context():
# Already in app context (e.g., in tests), use null context
"""Push a fresh app context unless a request context is active.
When a request context is present, external middleware (e.g.
Preset's WorkspaceContextMiddleware) has already set ``g.user``
on a per-request app context — reuse it via ``nullcontext()``.
When only a bare app context exists (no request context), we must
push a **new** app context. The MCP server typically runs inside
a long-lived app context (e.g. ``__main__.py`` wraps
``mcp.run()`` in ``app.app_context()``). When FastMCP dispatches
concurrent tool calls via ``asyncio.create_task()``, each task
inherits the parent's ``ContextVar`` *value* — a reference to the
**same** ``AppContext`` object. Without a fresh push, all tasks
share one ``g`` namespace and concurrent ``g.user`` mutations
race: one user's identity can overwrite another's before
``get_user_id()`` runs during the SQLAlchemy INSERT flush,
attributing the created asset to the wrong user.
"""
if has_request_context():
return contextlib.nullcontext()
# Push new app context for standalone MCP server
app = get_flask_app()
return app.app_context()
if has_app_context():
# Push a new context for the CURRENT app (not get_flask_app()
# which may return a different instance in test environments).
return current_app._get_current_object().app_context()
from superset.mcp_service.flask_singleton import get_flask_app
return get_flask_app().app_context()
is_async = inspect.iscoroutinefunction(tool_func)

View File

@@ -0,0 +1,164 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Tests for g.user isolation across concurrent MCP tool calls.
The MCP server pushes a parent app_context at startup (__main__.py).
asyncio.create_task() copies the ContextVar VALUE — a reference to the
SAME AppContext object. Without pushing a fresh app_context() per tool
call, concurrent tasks share one g namespace and g.user mutations race.
These tests verify that:
- Always pushing a new app_context() per task isolates g.user (SAFE)
- Reusing a shared parent context via nullcontext() causes races (UNSAFE)
- When a request context is active, nullcontext() is safe (middleware path)
"""
import asyncio
from types import SimpleNamespace
import pytest
from flask import Flask, g
def _make_user(user_id: int, username: str) -> SimpleNamespace:
return SimpleNamespace(id=user_id, username=username)
ALICE = _make_user(1, "alice")
BOB = _make_user(2, "bob")
def _get_user_id() -> int | None:
"""Mirrors superset.utils.core.get_user_id."""
try:
return g.user.id
except AttributeError:
return None
@pytest.mark.asyncio
async def test_fresh_app_context_per_task_isolates_g_user():
"""
Each task pushes its own app_context(). g.user is isolated.
This is the fixed code path in _get_app_context_manager() when
no request context is active (app-context-only mode).
"""
app = Flask(__name__)
async def tool_call(user, results, key):
with app.app_context():
g.user = user
await asyncio.sleep(0) # yield to other tasks
results[key] = _get_user_id()
# Parent context exists (like __main__.py:138)
with app.app_context():
for _ in range(200):
results: dict[str, int | None] = {}
await asyncio.gather(
tool_call(ALICE, results, "alice"),
tool_call(BOB, results, "bob"),
)
assert results["alice"] == ALICE.id
assert results["bob"] == BOB.id
@pytest.mark.asyncio
async def test_nullcontext_shared_context_causes_race():
"""
Both tasks reuse the parent's app context (nullcontext path).
g.user is shared — one task overwrites the other's identity.
Uses asyncio.Event for deterministic interleaving.
"""
app = Flask(__name__)
alice_set = asyncio.Event()
bob_set = asyncio.Event()
async def alice_task(results):
g.user = ALICE
alice_set.set() # Signal: Alice has set g.user
await bob_set.wait() # Wait for Bob to overwrite g.user
results["alice"] = _get_user_id()
async def bob_task(results):
await alice_set.wait() # Wait for Alice to set g.user first
g.user = BOB # Overwrite the shared g.user
bob_set.set() # Signal: Bob has overwritten
results["bob"] = _get_user_id()
with app.app_context():
results: dict[str, int | None] = {}
await asyncio.gather(
alice_task(results),
bob_task(results),
)
# Alice reads Bob's ID because they share the same g
assert results["alice"] == BOB.id, (
"Expected Alice to see Bob's ID due to shared g"
)
assert results["bob"] == BOB.id
@pytest.mark.asyncio
async def test_request_context_preserves_g_user():
"""
When a request context is active (middleware set g.user), each task
pushes its own test_request_context. The per-task app_context +
request_context provides isolation even with nullcontext() in the
auth hook.
"""
app = Flask(__name__)
async def tool_call(user, results, key):
with app.app_context():
with app.test_request_context(path="/mcp"):
g.user = user
await asyncio.sleep(0)
results[key] = _get_user_id()
with app.app_context():
for _ in range(200):
results: dict[str, int | None] = {}
await asyncio.gather(
tool_call(ALICE, results, "alice"),
tool_call(BOB, results, "bob"),
)
assert results["alice"] == ALICE.id
assert results["bob"] == BOB.id
@pytest.mark.asyncio
async def test_high_contention_isolation():
"""10 concurrent users, 50 iterations — stress test."""
app = Flask(__name__)
users = [_make_user(i, f"user_{i}") for i in range(10)]
async def tool_call(user, results, key):
with app.app_context():
g.user = user
await asyncio.sleep(0)
await asyncio.sleep(0)
results[key] = _get_user_id()
with app.app_context():
for _ in range(50):
results: dict[str, int | None] = {}
await asyncio.gather(*(tool_call(u, results, u.username) for u in users))
for u in users:
assert results[u.username] == u.id