mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
fix(mcp): always push fresh app context per tool call to prevent g.user race (#39385)
(cherry picked from commit e7b9fb277e)
This commit is contained in:
committed by
Michael S. Molina
parent
2dca313e41
commit
8b9a74a3d4
@@ -385,18 +385,36 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
|||||||
import inspect
|
import inspect
|
||||||
import types
|
import types
|
||||||
|
|
||||||
from flask import has_app_context
|
from flask import current_app, has_app_context, has_request_context
|
||||||
|
|
||||||
from superset.mcp_service.flask_singleton import get_flask_app
|
|
||||||
|
|
||||||
def _get_app_context_manager() -> AbstractContextManager[None]:
|
def _get_app_context_manager() -> AbstractContextManager[None]:
|
||||||
"""Return app context manager only if not already in one."""
|
"""Push a fresh app context unless a request context is active.
|
||||||
if has_app_context():
|
|
||||||
# Already in app context (e.g., in tests), use null context
|
When a request context is present, external middleware (e.g.
|
||||||
|
Preset's WorkspaceContextMiddleware) has already set ``g.user``
|
||||||
|
on a per-request app context — reuse it via ``nullcontext()``.
|
||||||
|
|
||||||
|
When only a bare app context exists (no request context), we must
|
||||||
|
push a **new** app context. The MCP server typically runs inside
|
||||||
|
a long-lived app context (e.g. ``__main__.py`` wraps
|
||||||
|
``mcp.run()`` in ``app.app_context()``). When FastMCP dispatches
|
||||||
|
concurrent tool calls via ``asyncio.create_task()``, each task
|
||||||
|
inherits the parent's ``ContextVar`` *value* — a reference to the
|
||||||
|
**same** ``AppContext`` object. Without a fresh push, all tasks
|
||||||
|
share one ``g`` namespace and concurrent ``g.user`` mutations
|
||||||
|
race: one user's identity can overwrite another's before
|
||||||
|
``get_user_id()`` runs during the SQLAlchemy INSERT flush,
|
||||||
|
attributing the created asset to the wrong user.
|
||||||
|
"""
|
||||||
|
if has_request_context():
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
# Push new app context for standalone MCP server
|
if has_app_context():
|
||||||
app = get_flask_app()
|
# Push a new context for the CURRENT app (not get_flask_app()
|
||||||
return app.app_context()
|
# which may return a different instance in test environments).
|
||||||
|
return current_app._get_current_object().app_context()
|
||||||
|
from superset.mcp_service.flask_singleton import get_flask_app
|
||||||
|
|
||||||
|
return get_flask_app().app_context()
|
||||||
|
|
||||||
is_async = inspect.iscoroutinefunction(tool_func)
|
is_async = inspect.iscoroutinefunction(tool_func)
|
||||||
|
|
||||||
|
|||||||
164
tests/unit_tests/mcp_service/test_g_user_race_condition.py
Normal file
164
tests/unit_tests/mcp_service/test_g_user_race_condition.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Tests for g.user isolation across concurrent MCP tool calls.
|
||||||
|
|
||||||
|
The MCP server pushes a parent app_context at startup (__main__.py).
|
||||||
|
asyncio.create_task() copies the ContextVar VALUE — a reference to the
|
||||||
|
SAME AppContext object. Without pushing a fresh app_context() per tool
|
||||||
|
call, concurrent tasks share one g namespace and g.user mutations race.
|
||||||
|
|
||||||
|
These tests verify that:
|
||||||
|
- Always pushing a new app_context() per task isolates g.user (SAFE)
|
||||||
|
- Reusing a shared parent context via nullcontext() causes races (UNSAFE)
|
||||||
|
- When a request context is active, nullcontext() is safe (middleware path)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask, g
|
||||||
|
|
||||||
|
|
||||||
|
def _make_user(user_id: int, username: str) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(id=user_id, username=username)
|
||||||
|
|
||||||
|
|
||||||
|
ALICE = _make_user(1, "alice")
|
||||||
|
BOB = _make_user(2, "bob")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_id() -> int | None:
|
||||||
|
"""Mirrors superset.utils.core.get_user_id."""
|
||||||
|
try:
|
||||||
|
return g.user.id
|
||||||
|
except AttributeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fresh_app_context_per_task_isolates_g_user():
|
||||||
|
"""
|
||||||
|
Each task pushes its own app_context(). g.user is isolated.
|
||||||
|
This is the fixed code path in _get_app_context_manager() when
|
||||||
|
no request context is active (app-context-only mode).
|
||||||
|
"""
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
async def tool_call(user, results, key):
|
||||||
|
with app.app_context():
|
||||||
|
g.user = user
|
||||||
|
await asyncio.sleep(0) # yield to other tasks
|
||||||
|
results[key] = _get_user_id()
|
||||||
|
|
||||||
|
# Parent context exists (like __main__.py:138)
|
||||||
|
with app.app_context():
|
||||||
|
for _ in range(200):
|
||||||
|
results: dict[str, int | None] = {}
|
||||||
|
await asyncio.gather(
|
||||||
|
tool_call(ALICE, results, "alice"),
|
||||||
|
tool_call(BOB, results, "bob"),
|
||||||
|
)
|
||||||
|
assert results["alice"] == ALICE.id
|
||||||
|
assert results["bob"] == BOB.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_nullcontext_shared_context_causes_race():
|
||||||
|
"""
|
||||||
|
Both tasks reuse the parent's app context (nullcontext path).
|
||||||
|
g.user is shared — one task overwrites the other's identity.
|
||||||
|
Uses asyncio.Event for deterministic interleaving.
|
||||||
|
"""
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
alice_set = asyncio.Event()
|
||||||
|
bob_set = asyncio.Event()
|
||||||
|
|
||||||
|
async def alice_task(results):
|
||||||
|
g.user = ALICE
|
||||||
|
alice_set.set() # Signal: Alice has set g.user
|
||||||
|
await bob_set.wait() # Wait for Bob to overwrite g.user
|
||||||
|
results["alice"] = _get_user_id()
|
||||||
|
|
||||||
|
async def bob_task(results):
|
||||||
|
await alice_set.wait() # Wait for Alice to set g.user first
|
||||||
|
g.user = BOB # Overwrite the shared g.user
|
||||||
|
bob_set.set() # Signal: Bob has overwritten
|
||||||
|
results["bob"] = _get_user_id()
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
results: dict[str, int | None] = {}
|
||||||
|
await asyncio.gather(
|
||||||
|
alice_task(results),
|
||||||
|
bob_task(results),
|
||||||
|
)
|
||||||
|
# Alice reads Bob's ID because they share the same g
|
||||||
|
assert results["alice"] == BOB.id, (
|
||||||
|
"Expected Alice to see Bob's ID due to shared g"
|
||||||
|
)
|
||||||
|
assert results["bob"] == BOB.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_context_preserves_g_user():
|
||||||
|
"""
|
||||||
|
When a request context is active (middleware set g.user), each task
|
||||||
|
pushes its own test_request_context. The per-task app_context +
|
||||||
|
request_context provides isolation even with nullcontext() in the
|
||||||
|
auth hook.
|
||||||
|
"""
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
async def tool_call(user, results, key):
|
||||||
|
with app.app_context():
|
||||||
|
with app.test_request_context(path="/mcp"):
|
||||||
|
g.user = user
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
results[key] = _get_user_id()
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
for _ in range(200):
|
||||||
|
results: dict[str, int | None] = {}
|
||||||
|
await asyncio.gather(
|
||||||
|
tool_call(ALICE, results, "alice"),
|
||||||
|
tool_call(BOB, results, "bob"),
|
||||||
|
)
|
||||||
|
assert results["alice"] == ALICE.id
|
||||||
|
assert results["bob"] == BOB.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_high_contention_isolation():
|
||||||
|
"""10 concurrent users, 50 iterations — stress test."""
|
||||||
|
app = Flask(__name__)
|
||||||
|
users = [_make_user(i, f"user_{i}") for i in range(10)]
|
||||||
|
|
||||||
|
async def tool_call(user, results, key):
|
||||||
|
with app.app_context():
|
||||||
|
g.user = user
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
results[key] = _get_user_id()
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
for _ in range(50):
|
||||||
|
results: dict[str, int | None] = {}
|
||||||
|
await asyncio.gather(*(tool_call(u, results, u.username) for u in users))
|
||||||
|
for u in users:
|
||||||
|
assert results[u.username] == u.id
|
||||||
Reference in New Issue
Block a user