Compare commits

...

4 Commits

Author SHA1 Message Date
Amin Ghadersohi
8ad7ef9fb5 refactor(mcp): lift deferred imports to module level in auth.py
Move non-circular stdlib, SQLAlchemy, Flask, and fastmcp imports from
function bodies to the module-level import block for readability.
Imports that must remain deferred (superset.extensions.db,
security_manager, mcp_config, flask_singleton, chart_utils) and
optional-dep guards (fastmcp.server.dependencies) are unchanged.
2026-05-07 16:56:38 +00:00
Amin Ghadersohi
1facda454c fix(mcp): broaden pre-call session cleanup to catch DBAPIError not just OperationalError
Some database drivers (e.g. MySQL, SQLite) surface dropped connections as
InterfaceError rather than OperationalError. Both are DBAPIError subclasses.
Widen the catch in _remove_session_safe() so all DBAPI-level disconnect
errors are handled consistently regardless of driver.
2026-05-07 16:56:38 +00:00
Amin Ghadersohi
214b5e172b refactor(mcp): simplify test mock and add retry comment from review
- Replace nonlocal counter in SSL error test with MagicMock side_effect list
- Add inline comment on retry db.session.remove() to clarify intent
2026-05-07 16:56:38 +00:00
Amin Ghadersohi
7bb80e2336 fix(mcp): handle SSL connection drop during pre-call session teardown
When RDS drops an SSL connection due to idle timeout or max-connection-age,
`db.session.remove()` in `sync_wrapper` raises `OperationalError` because
the implicit rollback inside `session.close()` fails on the dead DBAPI
connection. This caused the MCP tool call to fail even when the operation
itself completed successfully, and left a dead connection in the pool.

Introduce `_remove_session_safe()` which:
- Catches `OperationalError` from `session.remove()` (SSL/network errors)
- Calls `session.invalidate()` to mark the dead connection for pool discard
- Retries `session.remove()` so the scoped registry is clean before the tool runs

Replace the bare `db.session.remove()` in `sync_wrapper` with `_remove_session_safe()`.
Add a unit test verifying `invalidate()` is called and remove is retried on SSL error.
2026-05-07 16:56:38 +00:00
2 changed files with 92 additions and 22 deletions

View File

@@ -44,12 +44,19 @@ Configuration:
- MCP_DEV_USERNAME: Fallback username for development
"""
import contextlib
import functools
import inspect
import logging
import types
from contextlib import AbstractContextManager
from typing import Any, Callable, TYPE_CHECKING, TypeVar
from flask import g, has_request_context
from fastmcp import Context as FMContext
from flask import current_app, g, has_app_context, has_request_context
from flask_appbuilder.security.sqla.models import Group, User
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import joinedload
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
@@ -104,8 +111,6 @@ def check_tool_permission(func: Callable[..., Any]) -> bool:
True if user has permission or no permission is required.
"""
try:
from flask import current_app
if not current_app.config.get("MCP_RBAC_ENABLED", True):
return True
@@ -172,8 +177,6 @@ def load_user_with_relationships(
if not username and not email:
raise ValueError("Either username or email must be provided")
from sqlalchemy.orm import joinedload
from superset.extensions import db
query = db.session.query(User).options(
@@ -330,8 +333,6 @@ def get_user_from_request() -> User:
Raises:
ValueError: If user cannot be authenticated or found
"""
from flask import current_app
# Priority 1: JWT context (per-request safe via ContextVar)
if (jwt_user := _resolve_user_from_jwt_context(current_app)) is not None:
return jwt_user
@@ -445,13 +446,9 @@ def _setup_user_context() -> User | None:
# tool calls when no per-request middleware refreshes it.
# Only clear in app-context-only mode; preserve g.user when
# a request context is active (external middleware set it).
from flask import has_request_context
if not has_request_context():
g.pop("user", None)
from sqlalchemy.exc import OperationalError
user = None # Ensure defined before loop in case of unexpected exit
for attempt in range(2):
@@ -516,6 +513,45 @@ def _cleanup_session_on_error() -> None:
logger.warning("Error cleaning up session after exception: %s", e)
def _remove_session_safe() -> None:
"""Remove the scoped SQLAlchemy session, tolerating SSL/connection errors.
Thread-pool workers reuse threads across requests. Before each tool call
the session is removed to prevent a prior request's thread-local session
from leaking into the next one. If the underlying DBAPI connection died
between requests (e.g. RDS SSL idle-timeout or max-connection-age), the
rollback implicit in ``session.close()`` raises a ``DBAPIError`` subclass
(``OperationalError`` for psycopg2, ``InterfaceError`` for some other
drivers).
When that happens:
1. Invalidate the dead connection so the pool discards it (rather than
returning a broken connection to the next caller).
2. Retry ``remove()`` to deregister the session from the scoped registry.
The tool call still proceeds because a fresh connection will be obtained
on the next DB access.
"""
from superset.extensions import db
try:
db.session.remove()
except DBAPIError as exc:
logger.warning(
"Connection error during pre-call session cleanup "
"(likely SSL/idle timeout); invalidating connection and retrying: %s",
exc,
)
try:
db.session.invalidate()
except Exception as invalidate_exc:
logger.debug(
"Could not invalidate session after connection error: %s",
invalidate_exc,
)
db.session.remove() # retry: session deregisters cleanly after invalidation
def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
"""
Authentication and authorization decorator for MCP tools.
@@ -530,12 +566,6 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
Supports both sync and async tool functions.
"""
import contextlib
import functools
import inspect
import types
from flask import current_app, has_app_context, has_request_context
def _get_app_context_manager() -> AbstractContextManager[None]:
"""Push a fresh app context unless a request context is active.
@@ -570,8 +600,6 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
# Detect if the original function expects a ctx: Context parameter.
# If so, we inject it via get_context() at call time.
from fastmcp import Context as FMContext
_tool_sig = inspect.signature(tool_func)
_needs_ctx = any(
p.annotation is FMContext
@@ -638,9 +666,7 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
# still be bound to a different tenant's DB engine. Removing it here
# ensures the next DB access creates a fresh session bound to the
# correct engine for the current request.
from superset.extensions import db
db.session.remove()
_remove_session_safe()
user = _setup_user_context()
# No Flask context - this is a FastMCP internal operation

View File

@@ -409,6 +409,50 @@ def test_mcp_auth_hook_removes_stale_db_session_in_sync_wrapper(app) -> None:
assert result == "fresh"
def test_sync_wrapper_handles_ssl_error_on_pre_call_remove(app) -> None:
"""sync_wrapper tolerates OperationalError from db.session.remove() before the call.
If the underlying DBAPI connection died between requests (e.g. RDS SSL
idle-timeout), the rollback implicit in session.close() raises
OperationalError. _remove_session_safe() should:
- Log a warning
- Call session.invalidate() to mark the dead connection for pool discard
- Retry session.remove() so the registry is clean
- Allow the tool to run successfully
"""
from sqlalchemy.exc import OperationalError as SAOperationalError
fresh_user = _make_mock_user("fresh")
def dummy_tool() -> str:
"""Dummy sync tool."""
return g.user.username
wrapped = mcp_auth_hook(dummy_tool)
with app.test_request_context():
g.user = fresh_user
with patch("superset.extensions.db") as mock_db:
mock_db.session.remove.side_effect = [
SAOperationalError(
"SSL connection has been closed unexpectedly", None, None
),
None, # second call succeeds
]
with patch(
"superset.mcp_service.auth.get_user_from_request",
return_value=fresh_user,
):
result = wrapped()
assert result == "fresh"
assert mock_db.session.invalidate.called, "invalidate() must be called on SSL error"
assert mock_db.session.remove.call_count == 2, (
"remove() must be retried after SSL error"
)
# -- default_user_resolver --