fix(mcp): handle SSL connection drop during pre-call session teardown

When RDS drops an SSL connection due to idle timeout or max-connection-age,
`db.session.remove()` in `sync_wrapper` raises `OperationalError` because
the implicit rollback inside `session.close()` fails on the dead DBAPI
connection. This caused the MCP tool call to fail before the tool itself
ever ran (the error came from session cleanup, not from the tool's own
operation), and left a dead connection in the pool.

Introduce `_remove_session_safe()` which:
- Catches `OperationalError` from `session.remove()` (SSL/network errors)
- Calls `session.invalidate()` to mark the dead connection for pool discard
- Retries `session.remove()` so the scoped registry is clean before the tool runs

Replace the bare `db.session.remove()` in `sync_wrapper` with `_remove_session_safe()`.
Add a unit test verifying `invalidate()` is called and remove is retried on SSL error.
This commit is contained in:
Amin Ghadersohi
2026-05-06 22:55:05 +00:00
parent 8c80caefa3
commit 7bb80e2336
2 changed files with 90 additions and 6 deletions

View File

@@ -516,6 +516,45 @@ def _cleanup_session_on_error() -> None:
logger.warning("Error cleaning up session after exception: %s", e)
def _remove_session_safe() -> None:
    """Remove the scoped SQLAlchemy session, tolerating SSL/connection errors.

    Thread-pool workers reuse threads across requests. Before each tool call
    the session is removed to prevent a prior request's thread-local session
    from leaking into the next one. If the underlying DBAPI connection died
    between requests (e.g. RDS SSL idle-timeout or max-connection-age), the
    rollback implicit in ``session.close()`` raises ``OperationalError``.

    When that happens:

    1. Invalidate the dead connection so the pool discards it (rather than
       returning a broken connection to the next caller). This step is
       best-effort: any failure is logged at debug level and ignored.
    2. Retry ``remove()`` to deregister the session from the scoped registry.

    The tool call still proceeds because a fresh connection will be obtained
    on the next DB access.

    Raises:
        OperationalError: if the retried ``remove()`` fails as well; only the
            first attempt is tolerated.
    """
    from sqlalchemy.exc import OperationalError

    from superset.extensions import db

    try:
        db.session.remove()
        return
    except OperationalError as exc:
        logger.warning(
            "Connection error during pre-call session cleanup "
            "(likely SSL/idle timeout); invalidating connection and retrying: %s",
            exc,
        )

    try:
        registry = db.session
        # ``scoped_session`` proxies only a fixed set of ``Session`` methods and
        # ``invalidate()`` is not one of them, so calling it on the registry
        # raises AttributeError and the dead connection is never discarded.
        # Calling the registry returns the current thread-local ``Session``,
        # which does expose ``invalidate()``. Objects that already provide
        # ``invalidate`` (a plain ``Session``, a test double) are used as-is.
        target = registry if hasattr(registry, "invalidate") else registry()
        target.invalidate()
    except Exception as invalidate_exc:
        # Best-effort: a failed invalidation must not prevent the retry below.
        logger.debug(
            "Could not invalidate session after connection error: %s",
            invalidate_exc,
        )

    # Second attempt; an error here propagates to the caller.
    db.session.remove()
def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
"""
Authentication and authorization decorator for MCP tools.
@@ -638,9 +677,7 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
# still be bound to a different tenant's DB engine. Removing it here
# ensures the next DB access creates a fresh session bound to the
# correct engine for the current request.
from superset.extensions import db
db.session.remove()
_remove_session_safe()
user = _setup_user_context()
# No Flask context - this is a FastMCP internal operation

View File

@@ -352,9 +352,9 @@ def test_mcp_auth_hook_preserves_g_user_in_request_context(app) -> None:
def _assert_preserved_then_return():
"""Verify g.user was preserved (not cleared) before returning."""
assert hasattr(g, "user"), (
"g.user should be preserved in request context but was removed"
)
assert hasattr(
g, "user"
), "g.user should be preserved in request context but was removed"
assert g.user is middleware_user, (
"g.user should be preserved in request context but was changed; "
f"g.user={g.user}"
@@ -409,6 +409,53 @@ def test_mcp_auth_hook_removes_stale_db_session_in_sync_wrapper(app) -> None:
assert result == "fresh"
def test_sync_wrapper_handles_ssl_error_on_pre_call_remove(app) -> None:
    """A dead connection during the pre-call ``remove()`` must not fail the tool.

    The first ``db.session.remove()`` is made to raise ``OperationalError``,
    simulating a DBAPI connection that died between requests (e.g. RDS SSL
    idle-timeout). The test then checks that:

    - ``session.invalidate()`` was called on the patched ``db``
    - ``session.remove()`` was called a second time
    - the wrapped tool still ran and returned the current user's name
    """
    from sqlalchemy.exc import OperationalError as SAOperationalError

    fresh_user = _make_mock_user("fresh")

    def dummy_tool() -> str:
        """Dummy sync tool."""
        return g.user.username

    wrapped = mcp_auth_hook(dummy_tool)

    # One queued failure: the first remove() raises, every later call succeeds.
    pending_failures = iter(
        [SAOperationalError("SSL connection has been closed unexpectedly", None, None)]
    )

    def _fail_first_call_only() -> None:
        error = next(pending_failures, None)
        if error is not None:
            raise error

    with app.test_request_context():
        g.user = fresh_user
        with patch("superset.extensions.db") as mock_db, patch(
            "superset.mcp_service.auth.get_user_from_request",
            return_value=fresh_user,
        ):
            mock_db.session.remove.side_effect = _fail_first_call_only
            result = wrapped()

    # The mock counts every call, including the one whose side effect raised.
    remove_calls = mock_db.session.remove.call_count

    assert result == "fresh"
    assert mock_db.session.invalidate.called, "invalidate() must be called on SSL error"
    assert remove_calls == 2, "remove() must be retried after SSL error"
# -- default_user_resolver --