mirror of
https://github.com/apache/superset.git
synced 2026-05-14 20:35:23 +00:00
Compare commits
4 Commits
embedded-e
...
fix/mcp-ss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8ad7ef9fb5 | ||
|
|
1facda454c | ||
|
|
214b5e172b | ||
|
|
7bb80e2336 |
@@ -44,12 +44,19 @@ Configuration:
|
||||
- MCP_DEV_USERNAME: Fallback username for development
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||
|
||||
from flask import g, has_request_context
|
||||
from fastmcp import Context as FMContext
|
||||
from flask import current_app, g, has_app_context, has_request_context
|
||||
from flask_appbuilder.security.sqla.models import Group, User
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
@@ -104,8 +111,6 @@ def check_tool_permission(func: Callable[..., Any]) -> bool:
|
||||
True if user has permission or no permission is required.
|
||||
"""
|
||||
try:
|
||||
from flask import current_app
|
||||
|
||||
if not current_app.config.get("MCP_RBAC_ENABLED", True):
|
||||
return True
|
||||
|
||||
@@ -172,8 +177,6 @@ def load_user_with_relationships(
|
||||
if not username and not email:
|
||||
raise ValueError("Either username or email must be provided")
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from superset.extensions import db
|
||||
|
||||
query = db.session.query(User).options(
|
||||
@@ -330,8 +333,6 @@ def get_user_from_request() -> User:
|
||||
Raises:
|
||||
ValueError: If user cannot be authenticated or found
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
# Priority 1: JWT context (per-request safe via ContextVar)
|
||||
if (jwt_user := _resolve_user_from_jwt_context(current_app)) is not None:
|
||||
return jwt_user
|
||||
@@ -445,13 +446,9 @@ def _setup_user_context() -> User | None:
|
||||
# tool calls when no per-request middleware refreshes it.
|
||||
# Only clear in app-context-only mode; preserve g.user when
|
||||
# a request context is active (external middleware set it).
|
||||
from flask import has_request_context
|
||||
|
||||
if not has_request_context():
|
||||
g.pop("user", None)
|
||||
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
user = None # Ensure defined before loop in case of unexpected exit
|
||||
|
||||
for attempt in range(2):
|
||||
@@ -516,6 +513,45 @@ def _cleanup_session_on_error() -> None:
|
||||
logger.warning("Error cleaning up session after exception: %s", e)
|
||||
|
||||
|
||||
def _remove_session_safe() -> None:
|
||||
"""Remove the scoped SQLAlchemy session, tolerating SSL/connection errors.
|
||||
|
||||
Thread-pool workers reuse threads across requests. Before each tool call
|
||||
the session is removed to prevent a prior request's thread-local session
|
||||
from leaking into the next one. If the underlying DBAPI connection died
|
||||
between requests (e.g. RDS SSL idle-timeout or max-connection-age), the
|
||||
rollback implicit in ``session.close()`` raises a ``DBAPIError`` subclass
|
||||
(``OperationalError`` for psycopg2, ``InterfaceError`` for some other
|
||||
drivers).
|
||||
|
||||
When that happens:
|
||||
1. Invalidate the dead connection so the pool discards it (rather than
|
||||
returning a broken connection to the next caller).
|
||||
2. Retry ``remove()`` to deregister the session from the scoped registry.
|
||||
|
||||
The tool call still proceeds because a fresh connection will be obtained
|
||||
on the next DB access.
|
||||
"""
|
||||
from superset.extensions import db
|
||||
|
||||
try:
|
||||
db.session.remove()
|
||||
except DBAPIError as exc:
|
||||
logger.warning(
|
||||
"Connection error during pre-call session cleanup "
|
||||
"(likely SSL/idle timeout); invalidating connection and retrying: %s",
|
||||
exc,
|
||||
)
|
||||
try:
|
||||
db.session.invalidate()
|
||||
except Exception as invalidate_exc:
|
||||
logger.debug(
|
||||
"Could not invalidate session after connection error: %s",
|
||||
invalidate_exc,
|
||||
)
|
||||
db.session.remove() # retry: session deregisters cleanly after invalidation
|
||||
|
||||
|
||||
def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
||||
"""
|
||||
Authentication and authorization decorator for MCP tools.
|
||||
@@ -530,12 +566,6 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
||||
|
||||
Supports both sync and async tool functions.
|
||||
"""
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
|
||||
from flask import current_app, has_app_context, has_request_context
|
||||
|
||||
def _get_app_context_manager() -> AbstractContextManager[None]:
|
||||
"""Push a fresh app context unless a request context is active.
|
||||
@@ -570,8 +600,6 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
||||
|
||||
# Detect if the original function expects a ctx: Context parameter.
|
||||
# If so, we inject it via get_context() at call time.
|
||||
from fastmcp import Context as FMContext
|
||||
|
||||
_tool_sig = inspect.signature(tool_func)
|
||||
_needs_ctx = any(
|
||||
p.annotation is FMContext
|
||||
@@ -638,9 +666,7 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
||||
# still be bound to a different tenant's DB engine. Removing it here
|
||||
# ensures the next DB access creates a fresh session bound to the
|
||||
# correct engine for the current request.
|
||||
from superset.extensions import db
|
||||
|
||||
db.session.remove()
|
||||
_remove_session_safe()
|
||||
user = _setup_user_context()
|
||||
|
||||
# No Flask context - this is a FastMCP internal operation
|
||||
|
||||
@@ -409,6 +409,50 @@ def test_mcp_auth_hook_removes_stale_db_session_in_sync_wrapper(app) -> None:
|
||||
assert result == "fresh"
|
||||
|
||||
|
||||
def test_sync_wrapper_handles_ssl_error_on_pre_call_remove(app) -> None:
|
||||
"""sync_wrapper tolerates OperationalError from db.session.remove() before the call.
|
||||
|
||||
If the underlying DBAPI connection died between requests (e.g. RDS SSL
|
||||
idle-timeout), the rollback implicit in session.close() raises
|
||||
OperationalError. _remove_session_safe() should:
|
||||
- Log a warning
|
||||
- Call session.invalidate() to mark the dead connection for pool discard
|
||||
- Retry session.remove() so the registry is clean
|
||||
- Allow the tool to run successfully
|
||||
"""
|
||||
from sqlalchemy.exc import OperationalError as SAOperationalError
|
||||
|
||||
fresh_user = _make_mock_user("fresh")
|
||||
|
||||
def dummy_tool() -> str:
|
||||
"""Dummy sync tool."""
|
||||
return g.user.username
|
||||
|
||||
wrapped = mcp_auth_hook(dummy_tool)
|
||||
|
||||
with app.test_request_context():
|
||||
g.user = fresh_user
|
||||
with patch("superset.extensions.db") as mock_db:
|
||||
mock_db.session.remove.side_effect = [
|
||||
SAOperationalError(
|
||||
"SSL connection has been closed unexpectedly", None, None
|
||||
),
|
||||
None, # second call succeeds
|
||||
]
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
return_value=fresh_user,
|
||||
):
|
||||
result = wrapped()
|
||||
|
||||
assert result == "fresh"
|
||||
assert mock_db.session.invalidate.called, "invalidate() must be called on SSL error"
|
||||
assert mock_db.session.remove.call_count == 2, (
|
||||
"remove() must be retried after SSL error"
|
||||
)
|
||||
|
||||
|
||||
# -- default_user_resolver --
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user