diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py index 1a1e68054f4..f70b1623d65 100644 --- a/superset/mcp_service/auth.py +++ b/superset/mcp_service/auth.py @@ -46,23 +46,34 @@ def get_user_from_request() -> User: """ Get the current user for the MCP tool request. - TODO (future PR): Add JWT token extraction and validation. - TODO (future PR): Add user impersonation support. - TODO (future PR): Add fallback user configuration. + Priority order: + 1. g.user if already set (by Preset workspace middleware) + 2. MCP_DEV_USERNAME from configuration (for development/testing) - For now, this returns the admin user for development. + Returns: + User object with roles and groups eagerly loaded + + Raises: + ValueError: If user cannot be authenticated or found """ from flask import current_app from sqlalchemy.orm import joinedload from superset.extensions import db - # TODO: Extract from JWT token once authentication is implemented - # For now, use MCP_DEV_USERNAME from configuration + # First check if user is already set by Preset workspace middleware + if hasattr(g, "user") and g.user: + return g.user + + # Fall back to configured username for development/single-user deployments username = current_app.config.get("MCP_DEV_USERNAME") if not username: - raise ValueError("Username not configured") + raise ValueError( + "No authenticated user found. " + "Either pass a valid JWT bearer token or configure " + "MCP_DEV_USERNAME for development." + ) # Query user directly with eager loading to ensure fresh session-bound object # Do NOT use security_manager.find_user() as it may return cached/detached user @@ -115,65 +126,109 @@ def has_dataset_access(dataset: "SqlaTable") -> bool: return False # Deny access on error +def _setup_user_context() -> User: + """ + Set up user context for MCP tool execution. + + Returns: + User object with roles and groups loaded + """ + user = get_user_from_request() + + # Validate user has necessary relationships loaded + # (Force access to ensure they're loaded if lazy) + user_roles = user.roles # noqa: F841 + if hasattr(user, "groups"): + user_groups = user.groups # noqa: F841 + + g.user = user + return user + + +def _cleanup_session_on_error() -> None: + """Clean up database session after an exception.""" + from superset.extensions import db + + # pylint: disable=consider-using-transaction + try: + db.session.rollback() + db.session.remove() + except Exception as e: + logger.warning("Error cleaning up session after exception: %s", e) + + +def _cleanup_session_finally() -> None: + """Clean up database session in finally block.""" + from superset.extensions import db + + # Rollback active session (no exception occurred) + # Do NOT call remove() on success to avoid detaching user + try: + if db.session.is_active: + # pylint: disable=consider-using-transaction + db.session.rollback() + except Exception as e: + logger.warning("Error in finally block: %s", e) + + def mcp_auth_hook(tool_func: F) -> F: """ Authentication and authorization decorator for MCP tools. - This is a minimal implementation that: - 1. Gets the current user - 2. Sets g.user for Flask context + This decorator assumes Flask application context and g.user + have already been set by WorkspaceContextMiddleware. + + Supports both sync and async tool functions. TODO (future PR): Add permission checking TODO (future PR): Add JWT scope validation TODO (future PR): Add comprehensive audit logging - TODO (future PR): Add rate limiting integration """ import functools + import inspect - @functools.wraps(tool_func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - from superset.extensions import db + is_async = inspect.iscoroutinefunction(tool_func) - # Get user and set Flask context OUTSIDE try block - user = get_user_from_request() + if is_async: - # Force load relationships NOW while session is definitely active - _ = user.roles - if hasattr(user, "groups"): - _ = user.groups + @functools.wraps(tool_func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + user = _setup_user_context() - g.user = user - - try: - # TODO: Add permission checks here in future PR - # TODO: Add audit logging here in future PR - - logger.debug( - "MCP tool call: user=%s, tool=%s", user.username, tool_func.__name__ - ) - - result = tool_func(*args, **kwargs) - - return result - - except Exception: - # On error, rollback and cleanup session - # pylint: disable=consider-using-transaction try: - db.session.rollback() - db.session.remove() - except Exception as e: - logger.warning("Error cleaning up session after exception: %s", e) - raise + logger.debug( + "MCP tool call: user=%s, tool=%s", + user.username, + tool_func.__name__, + ) + result = await tool_func(*args, **kwargs) + return result + except Exception: + _cleanup_session_on_error() + raise + finally: + _cleanup_session_finally() + + return async_wrapper # type: ignore[return-value] + + else: + + @functools.wraps(tool_func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + user = _setup_user_context() - finally: - # Only rollback if session is still active (no exception occurred) - # Do NOT call remove() on success to avoid detaching user try: - if db.session.is_active: - # pylint: disable=consider-using-transaction - db.session.rollback() - except Exception as e: - logger.warning("Error in finally block: %s", e) + logger.debug( + "MCP tool call: user=%s, tool=%s", + user.username, + tool_func.__name__, + ) + result = tool_func(*args, **kwargs) + return result + except Exception: + _cleanup_session_on_error() + raise + finally: + _cleanup_session_finally() - return wrapper # type: ignore[return-value] + return sync_wrapper # type: ignore[return-value] diff --git a/superset/mcp_service/common/error_schemas.py b/superset/mcp_service/common/error_schemas.py index a0527c34f12..ec0274cc0be 100644 --- a/superset/mcp_service/common/error_schemas.py +++ b/superset/mcp_service/common/error_schemas.py @@ -48,9 +48,16 @@ class ValidationError(BaseModel): class DatasetContext(BaseModel): """Dataset information for error context""" + model_config = {"populate_by_name": True} + id: int = Field(..., description="Dataset ID") table_name: str = Field(..., description="Table name") - schema: str | None = Field(None, description="Schema name") + schema_name: str | None = Field( + None, + alias="schema", + serialization_alias="schema", + description="Schema name", + ) database_name: str = Field(..., description="Database name") available_columns: List[Dict[str, Any]] = Field( default_factory=list, description="Available columns with metadata" diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index e723ef611a9..3fe0907651c 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -82,9 +82,26 @@ def run_server( factory_config = get_mcp_factory_config() mcp_instance = create_mcp_app(**factory_config) else: - # Use default initialization + # Use default initialization with auth from Flask config logging.info("Creating MCP app with default configuration...") - mcp_instance = init_fastmcp_server() + from superset.mcp_service.flask_singleton import get_flask_app + + flask_app = get_flask_app() + + # Get auth factory from config and create auth provider + auth_provider = None + auth_factory = flask_app.config.get("MCP_AUTH_FACTORY") + if auth_factory: + try: + auth_provider = auth_factory(flask_app) + logging.info( + "Auth provider created: %s", + type(auth_provider).__name__ if auth_provider else "None", + ) + except Exception as e: + logging.error("Failed to create auth provider: %s", e) + + mcp_instance = init_fastmcp_server(auth=auth_provider) env_key = f"FASTMCP_RUNNING_{port}" if not os.environ.get(env_key): diff --git a/superset/mcp_service/system/tool/health_check.py b/superset/mcp_service/system/tool/health_check.py index 133d1ad0325..7dfa5ca7cf7 100644 --- a/superset/mcp_service/system/tool/health_check.py +++ b/superset/mcp_service/system/tool/health_check.py @@ -21,6 +21,7 @@ import datetime import logging import platform +from fastmcp import Context from flask import current_app from superset.mcp_service.app import mcp @@ -33,7 +34,7 @@ logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook -async def health_check() -> HealthCheckResponse: +async def health_check(ctx: Context) -> HealthCheckResponse: """ Simple health check tool for testing the MCP service.