fix(mcp): handle stale SSL connections, heatmap duplicate labels, and session rollback (#39015)

This commit is contained in:
Amin Ghadersohi
2026-04-03 22:07:29 +02:00
committed by GitHub
parent c7d175b842
commit b3a402d936
4 changed files with 151 additions and 62 deletions

View File

@@ -434,6 +434,10 @@ def _setup_user_context() -> User | None:
"""
Set up user context for MCP tool execution.
Includes retry logic for stale database connections (e.g., SSL dropped
by proxy/load balancer after idle periods). On OperationalError, the
session is reset and the user lookup is retried once.
Returns:
User object with roles and groups loaded, or None if no Flask context
"""
@@ -446,38 +450,64 @@ def _setup_user_context() -> User | None:
if not has_request_context():
g.pop("user", None)
try:
user = get_user_from_request()
except RuntimeError as e:
# No Flask application context (e.g., prompts before middleware runs)
# This is expected for some FastMCP operations - return None gracefully
if "application context" in str(e):
logger.debug("No Flask app context available for user setup")
return None
raise
except ValueError as e:
# JWT user resolution failed (e.g. SAML subject not in DB).
# If middleware already set g.user (request context exists),
# use that instead of failing closed.
from flask import has_request_context
from sqlalchemy.exc import OperationalError
if has_request_context() and hasattr(g, "user") and g.user:
logger.warning(
"JWT user resolution failed (%s), using middleware-provided g.user=%s",
e,
g.user.username,
)
# Assign to local so relationship validation below runs
# (same as the normal path) to prevent detached instance errors.
user = g.user
else:
user = None # Ensure defined before loop in case of unexpected exit
for attempt in range(2):
try:
user = get_user_from_request()
# Validate user has necessary relationships loaded.
# Force access to ensure they're loaded if lazy.
# This is inside the retry loop because relationship loading
# also hits the DB and can fail on stale SSL connections.
user_roles = user.roles # noqa: F841
if hasattr(user, "groups"):
user_groups = user.groups # noqa: F841
break
except RuntimeError as e:
# No Flask application context (e.g., prompts before middleware runs)
if "application context" in str(e):
logger.debug("No Flask app context available for user setup")
return None
raise
except OperationalError as e:
if attempt == 0:
# Only retry on connection-level errors (SSL drops, server
# closed connection). Other OperationalErrors (e.g., lock
# timeouts) are unlikely to succeed on immediate retry but
# are bounded to one attempt so the cost is acceptable.
logger.warning(
"Stale DB connection during user setup (attempt 1), "
"resetting session and retrying: %s",
e,
)
_cleanup_session_on_error()
continue
logger.error("DB connection failed on retry during user setup: %s", e)
raise
except ValueError as e:
# JWT user resolution failed (e.g. SAML subject not in DB).
# Log a security warning but fall back to middleware-provided
# g.user if available. This handles cases where the JWT
# resolver username format doesn't match the DB username
# (e.g., SAML subject vs email). A separate story should
# investigate whether any deployments hit this path and
# migrate them before removing the fallback entirely.
if has_request_context() and hasattr(g, "user") and g.user:
logger.warning(
"SECURITY: JWT user resolution failed (%s), falling "
"back to middleware-provided g.user=%s. This fallback "
"should be investigated and removed once all "
"deployments use consistent username resolution.",
e,
g.user.username,
)
user = g.user
break
raise
# Validate user has necessary relationships loaded
# (Force access to ensure they're loaded if lazy)
user_roles = user.roles # noqa: F841
if hasattr(user, "groups"):
user_groups = user.groups # noqa: F841
g.user = user
return user

View File

@@ -301,7 +301,13 @@ async def get_chart_data( # noqa: C901
cached_groupby: list[str] = []
else:
cached_metrics = cached_form_data_dict.get("metrics", [])
cached_groupby = cached_form_data_dict.get("groupby", [])
raw_groupby = cached_form_data_dict.get("groupby", [])
# Guard against string groupby (e.g. heatmap_v2 migrated
# from legacy heatmap where all_columns_y was a string)
if isinstance(raw_groupby, str):
cached_groupby = [raw_groupby]
else:
cached_groupby = list(raw_groupby)
_apply_extra_form_data(cached_form_data_dict, request.extra_form_data)
@@ -443,7 +449,13 @@ async def get_chart_data( # noqa: C901
else:
# Standard charts use "metrics" (plural) and "groupby"
metrics = form_data.get("metrics", [])
groupby_columns = list(form_data.get("groupby") or [])
raw_groupby = form_data.get("groupby") or []
# Guard against string groupby (e.g. heatmap_v2 migrated
# from legacy heatmap where all_columns_y was a string)
if isinstance(raw_groupby, str):
groupby_columns = [raw_groupby]
else:
groupby_columns = list(raw_groupby)
# Some chart types use "columns" instead of "groupby"
if not groupby_columns:
form_columns = form_data.get("columns")

View File

@@ -27,7 +27,7 @@ from urllib.parse import urlencode
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.mcp_service.sql_lab.schemas import (
OpenSqlLabRequest,
SqlLabResponse,
@@ -118,6 +118,15 @@ def open_sql_lab_with_context(
)
except Exception as e:
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except Exception: # noqa: BLE001
# Broad catch: the DB connection itself may be broken (e.g.,
# SSL drop), so even rollback can fail with non-SQLAlchemy
# errors. This is a cleanup path — swallow and log.
logger.warning(
"Database rollback failed during error handling", exc_info=True
)
logger.error("Error generating SQL Lab URL: %s", e)
return SqlLabResponse(
url="",

View File

@@ -23,9 +23,10 @@ InstanceInfoCore for flexible, extensible metrics calculation.
import logging
from fastmcp import Context
from sqlalchemy.exc import OperationalError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.mcp_service.mcp_core import InstanceInfoCore
from superset.mcp_service.system.schemas import (
GetSupersetInstanceInfoRequest,
@@ -92,38 +93,75 @@ def get_instance_info(
Returns counts, activity metrics, and database types.
"""
try:
# Import DAOs at runtime to avoid circular imports
from flask import g
return _run_instance_info()
from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.daos.tag import TagDAO
from superset.daos.user import UserDAO
except OperationalError as e:
logger.warning(
"Database connection error in get_instance_info, "
"resetting session and retrying: %s",
e,
)
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except Exception: # noqa: BLE001
# Broad catch: the DB connection itself may be broken (e.g.,
# SSL drop), so even rollback can fail with non-SQLAlchemy
# errors. This is a cleanup path — swallow and log.
logger.warning(
"Rollback failed during get_instance_info connection reset",
exc_info=True,
)
try:
db.session.remove() # pylint: disable=consider-using-transaction
except Exception: # noqa: BLE001
# Same as above — cleanup must not prevent the retry.
logger.warning(
"Session remove failed during get_instance_info connection reset",
exc_info=True,
)
# Configure DAO classes at runtime
_instance_info_core.dao_classes = {
"dashboards": DashboardDAO,
"charts": ChartDAO,
"datasets": DatasetDAO,
"databases": DatabaseDAO,
"users": UserDAO,
"tags": TagDAO,
}
# Run the configurable core
with event_logger.log_context(action="mcp.get_instance_info.metrics"):
result = _instance_info_core.run_tool()
# Attach the authenticated user's identity to the response
user = getattr(g, "user", None)
if user is not None:
result.current_user = serialize_user_object(user)
return result
try:
result = _run_instance_info()
logger.info("get_instance_info retry succeeded after connection reset")
return result
except OperationalError as retry_error:
logger.error(
"get_instance_info retry failed after connection reset: %s",
retry_error,
exc_info=True,
)
raise
except Exception as e:
error_msg = f"Unexpected error in instance info: {str(e)}"
logger.error(error_msg, exc_info=True)
raise
def _run_instance_info() -> InstanceInfo:
"""Execute the instance info core logic."""
from flask import g
from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.daos.tag import TagDAO
from superset.daos.user import UserDAO
_instance_info_core.dao_classes = {
"dashboards": DashboardDAO,
"charts": ChartDAO,
"datasets": DatasetDAO,
"databases": DatabaseDAO,
"users": UserDAO,
"tags": TagDAO,
}
with event_logger.log_context(action="mcp.get_instance_info.metrics"):
result = _instance_info_core.run_tool()
if (user := getattr(g, "user", None)) is not None:
result.current_user = serialize_user_object(user)
return result