fix(mcp-service): ensure Flask app context in auth hook and resolve Pydantic warnings (#36013)

This commit is contained in:
Amin Ghadersohi
2025-11-17 20:31:03 +10:00
committed by GitHub
parent c2baba50f9
commit 9605a4a9cb
4 changed files with 135 additions and 55 deletions

View File

@@ -46,23 +46,34 @@ def get_user_from_request() -> User:
"""
Get the current user for the MCP tool request.
TODO (future PR): Add JWT token extraction and validation.
TODO (future PR): Add user impersonation support.
TODO (future PR): Add fallback user configuration.
Priority order:
1. g.user if already set (by Preset workspace middleware)
2. MCP_DEV_USERNAME from configuration (for development/testing)
For now, this returns the admin user for development.
Returns:
User object with roles and groups eagerly loaded
Raises:
ValueError: If user cannot be authenticated or found
"""
from flask import current_app
from sqlalchemy.orm import joinedload
from superset.extensions import db
# TODO: Extract from JWT token once authentication is implemented
# For now, use MCP_DEV_USERNAME from configuration
# First check if user is already set by Preset workspace middleware
if hasattr(g, "user") and g.user:
return g.user
# Fall back to configured username for development/single-user deployments
username = current_app.config.get("MCP_DEV_USERNAME")
if not username:
raise ValueError("Username not configured")
raise ValueError(
"No authenticated user found. "
"Either pass a valid JWT bearer token or configure "
"MCP_DEV_USERNAME for development."
)
# Query user directly with eager loading to ensure fresh session-bound object
# Do NOT use security_manager.find_user() as it may return cached/detached user
@@ -115,65 +126,109 @@ def has_dataset_access(dataset: "SqlaTable") -> bool:
return False # Deny access on error
def _setup_user_context() -> User:
"""
Set up user context for MCP tool execution.
Returns:
User object with roles and groups loaded
"""
user = get_user_from_request()
# Validate user has necessary relationships loaded
# (Force access to ensure they're loaded if lazy)
user_roles = user.roles # noqa: F841
if hasattr(user, "groups"):
user_groups = user.groups # noqa: F841
g.user = user
return user
def _cleanup_session_on_error() -> None:
"""Clean up database session after an exception."""
from superset.extensions import db
# pylint: disable=consider-using-transaction
try:
db.session.rollback()
db.session.remove()
except Exception as e:
logger.warning("Error cleaning up session after exception: %s", e)
def _cleanup_session_finally() -> None:
"""Clean up database session in finally block."""
from superset.extensions import db
# Rollback active session (no exception occurred)
# Do NOT call remove() on success to avoid detaching user
try:
if db.session.is_active:
# pylint: disable=consider-using-transaction
db.session.rollback()
except Exception as e:
logger.warning("Error in finally block: %s", e)
def mcp_auth_hook(tool_func: F) -> F:
"""
Authentication and authorization decorator for MCP tools.
This is a minimal implementation that:
1. Gets the current user
2. Sets g.user for Flask context
This decorator assumes Flask application context and g.user
have already been set by WorkspaceContextMiddleware.
Supports both sync and async tool functions.
TODO (future PR): Add permission checking
TODO (future PR): Add JWT scope validation
TODO (future PR): Add comprehensive audit logging
TODO (future PR): Add rate limiting integration
"""
import functools
import inspect
@functools.wraps(tool_func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
from superset.extensions import db
is_async = inspect.iscoroutinefunction(tool_func)
# Get user and set Flask context OUTSIDE try block
user = get_user_from_request()
if is_async:
# Force load relationships NOW while session is definitely active
_ = user.roles
if hasattr(user, "groups"):
_ = user.groups
@functools.wraps(tool_func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
user = _setup_user_context()
g.user = user
try:
# TODO: Add permission checks here in future PR
# TODO: Add audit logging here in future PR
logger.debug(
"MCP tool call: user=%s, tool=%s", user.username, tool_func.__name__
)
result = tool_func(*args, **kwargs)
return result
except Exception:
# On error, rollback and cleanup session
# pylint: disable=consider-using-transaction
try:
db.session.rollback()
db.session.remove()
except Exception as e:
logger.warning("Error cleaning up session after exception: %s", e)
raise
logger.debug(
"MCP tool call: user=%s, tool=%s",
user.username,
tool_func.__name__,
)
result = await tool_func(*args, **kwargs)
return result
except Exception:
_cleanup_session_on_error()
raise
finally:
_cleanup_session_finally()
return async_wrapper # type: ignore[return-value]
else:
@functools.wraps(tool_func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
user = _setup_user_context()
finally:
# Only rollback if session is still active (no exception occurred)
# Do NOT call remove() on success to avoid detaching user
try:
if db.session.is_active:
# pylint: disable=consider-using-transaction
db.session.rollback()
except Exception as e:
logger.warning("Error in finally block: %s", e)
logger.debug(
"MCP tool call: user=%s, tool=%s",
user.username,
tool_func.__name__,
)
result = tool_func(*args, **kwargs)
return result
except Exception:
_cleanup_session_on_error()
raise
finally:
_cleanup_session_finally()
return wrapper # type: ignore[return-value]
return sync_wrapper # type: ignore[return-value]

View File

@@ -48,9 +48,16 @@ class ValidationError(BaseModel):
class DatasetContext(BaseModel):
"""Dataset information for error context"""
model_config = {"populate_by_name": True}
id: int = Field(..., description="Dataset ID")
table_name: str = Field(..., description="Table name")
schema: str | None = Field(None, description="Schema name")
schema_name: str | None = Field(
None,
alias="schema",
serialization_alias="schema",
description="Schema name",
)
database_name: str = Field(..., description="Database name")
available_columns: List[Dict[str, Any]] = Field(
default_factory=list, description="Available columns with metadata"

View File

@@ -82,9 +82,26 @@ def run_server(
factory_config = get_mcp_factory_config()
mcp_instance = create_mcp_app(**factory_config)
else:
# Use default initialization
# Use default initialization with auth from Flask config
logging.info("Creating MCP app with default configuration...")
mcp_instance = init_fastmcp_server()
from superset.mcp_service.flask_singleton import get_flask_app
flask_app = get_flask_app()
# Get auth factory from config and create auth provider
auth_provider = None
auth_factory = flask_app.config.get("MCP_AUTH_FACTORY")
if auth_factory:
try:
auth_provider = auth_factory(flask_app)
logging.info(
"Auth provider created: %s",
type(auth_provider).__name__ if auth_provider else "None",
)
except Exception as e:
logging.error("Failed to create auth provider: %s", e)
mcp_instance = init_fastmcp_server(auth=auth_provider)
env_key = f"FASTMCP_RUNNING_{port}"
if not os.environ.get(env_key):

View File

@@ -21,6 +21,7 @@ import datetime
import logging
import platform
from fastmcp import Context
from flask import current_app
from superset.mcp_service.app import mcp
@@ -33,7 +34,7 @@ logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
async def health_check() -> HealthCheckResponse:
async def health_check(ctx: Context) -> HealthCheckResponse:
"""
Simple health check tool for testing the MCP service.