mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
fix(mcp-service): ensure Flask app context in auth hook and resolve Pydantic warnings (#36013)
This commit is contained in:
@@ -46,23 +46,34 @@ def get_user_from_request() -> User:
|
||||
"""
|
||||
Get the current user for the MCP tool request.
|
||||
|
||||
TODO (future PR): Add JWT token extraction and validation.
|
||||
TODO (future PR): Add user impersonation support.
|
||||
TODO (future PR): Add fallback user configuration.
|
||||
Priority order:
|
||||
1. g.user if already set (by Preset workspace middleware)
|
||||
2. MCP_DEV_USERNAME from configuration (for development/testing)
|
||||
|
||||
For now, this returns the admin user for development.
|
||||
Returns:
|
||||
User object with roles and groups eagerly loaded
|
||||
|
||||
Raises:
|
||||
ValueError: If user cannot be authenticated or found
|
||||
"""
|
||||
from flask import current_app
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from superset.extensions import db
|
||||
|
||||
# TODO: Extract from JWT token once authentication is implemented
|
||||
# For now, use MCP_DEV_USERNAME from configuration
|
||||
# First check if user is already set by Preset workspace middleware
|
||||
if hasattr(g, "user") and g.user:
|
||||
return g.user
|
||||
|
||||
# Fall back to configured username for development/single-user deployments
|
||||
username = current_app.config.get("MCP_DEV_USERNAME")
|
||||
|
||||
if not username:
|
||||
raise ValueError("Username not configured")
|
||||
raise ValueError(
|
||||
"No authenticated user found. "
|
||||
"Either pass a valid JWT bearer token or configure "
|
||||
"MCP_DEV_USERNAME for development."
|
||||
)
|
||||
|
||||
# Query user directly with eager loading to ensure fresh session-bound object
|
||||
# Do NOT use security_manager.find_user() as it may return cached/detached user
|
||||
@@ -115,65 +126,109 @@ def has_dataset_access(dataset: "SqlaTable") -> bool:
|
||||
return False # Deny access on error
|
||||
|
||||
|
||||
def _setup_user_context() -> User:
|
||||
"""
|
||||
Set up user context for MCP tool execution.
|
||||
|
||||
Returns:
|
||||
User object with roles and groups loaded
|
||||
"""
|
||||
user = get_user_from_request()
|
||||
|
||||
# Validate user has necessary relationships loaded
|
||||
# (Force access to ensure they're loaded if lazy)
|
||||
user_roles = user.roles # noqa: F841
|
||||
if hasattr(user, "groups"):
|
||||
user_groups = user.groups # noqa: F841
|
||||
|
||||
g.user = user
|
||||
return user
|
||||
|
||||
|
||||
def _cleanup_session_on_error() -> None:
|
||||
"""Clean up database session after an exception."""
|
||||
from superset.extensions import db
|
||||
|
||||
# pylint: disable=consider-using-transaction
|
||||
try:
|
||||
db.session.rollback()
|
||||
db.session.remove()
|
||||
except Exception as e:
|
||||
logger.warning("Error cleaning up session after exception: %s", e)
|
||||
|
||||
|
||||
def _cleanup_session_finally() -> None:
|
||||
"""Clean up database session in finally block."""
|
||||
from superset.extensions import db
|
||||
|
||||
# Rollback active session (no exception occurred)
|
||||
# Do NOT call remove() on success to avoid detaching user
|
||||
try:
|
||||
if db.session.is_active:
|
||||
# pylint: disable=consider-using-transaction
|
||||
db.session.rollback()
|
||||
except Exception as e:
|
||||
logger.warning("Error in finally block: %s", e)
|
||||
|
||||
|
||||
def mcp_auth_hook(tool_func: F) -> F:
|
||||
"""
|
||||
Authentication and authorization decorator for MCP tools.
|
||||
|
||||
This is a minimal implementation that:
|
||||
1. Gets the current user
|
||||
2. Sets g.user for Flask context
|
||||
This decorator assumes Flask application context and g.user
|
||||
have already been set by WorkspaceContextMiddleware.
|
||||
|
||||
Supports both sync and async tool functions.
|
||||
|
||||
TODO (future PR): Add permission checking
|
||||
TODO (future PR): Add JWT scope validation
|
||||
TODO (future PR): Add comprehensive audit logging
|
||||
TODO (future PR): Add rate limiting integration
|
||||
"""
|
||||
import functools
|
||||
import inspect
|
||||
|
||||
@functools.wraps(tool_func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
from superset.extensions import db
|
||||
is_async = inspect.iscoroutinefunction(tool_func)
|
||||
|
||||
# Get user and set Flask context OUTSIDE try block
|
||||
user = get_user_from_request()
|
||||
if is_async:
|
||||
|
||||
# Force load relationships NOW while session is definitely active
|
||||
_ = user.roles
|
||||
if hasattr(user, "groups"):
|
||||
_ = user.groups
|
||||
@functools.wraps(tool_func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
user = _setup_user_context()
|
||||
|
||||
g.user = user
|
||||
|
||||
try:
|
||||
# TODO: Add permission checks here in future PR
|
||||
# TODO: Add audit logging here in future PR
|
||||
|
||||
logger.debug(
|
||||
"MCP tool call: user=%s, tool=%s", user.username, tool_func.__name__
|
||||
)
|
||||
|
||||
result = tool_func(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
# On error, rollback and cleanup session
|
||||
# pylint: disable=consider-using-transaction
|
||||
try:
|
||||
db.session.rollback()
|
||||
db.session.remove()
|
||||
except Exception as e:
|
||||
logger.warning("Error cleaning up session after exception: %s", e)
|
||||
raise
|
||||
logger.debug(
|
||||
"MCP tool call: user=%s, tool=%s",
|
||||
user.username,
|
||||
tool_func.__name__,
|
||||
)
|
||||
result = await tool_func(*args, **kwargs)
|
||||
return result
|
||||
except Exception:
|
||||
_cleanup_session_on_error()
|
||||
raise
|
||||
finally:
|
||||
_cleanup_session_finally()
|
||||
|
||||
return async_wrapper # type: ignore[return-value]
|
||||
|
||||
else:
|
||||
|
||||
@functools.wraps(tool_func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
user = _setup_user_context()
|
||||
|
||||
finally:
|
||||
# Only rollback if session is still active (no exception occurred)
|
||||
# Do NOT call remove() on success to avoid detaching user
|
||||
try:
|
||||
if db.session.is_active:
|
||||
# pylint: disable=consider-using-transaction
|
||||
db.session.rollback()
|
||||
except Exception as e:
|
||||
logger.warning("Error in finally block: %s", e)
|
||||
logger.debug(
|
||||
"MCP tool call: user=%s, tool=%s",
|
||||
user.username,
|
||||
tool_func.__name__,
|
||||
)
|
||||
result = tool_func(*args, **kwargs)
|
||||
return result
|
||||
except Exception:
|
||||
_cleanup_session_on_error()
|
||||
raise
|
||||
finally:
|
||||
_cleanup_session_finally()
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
return sync_wrapper # type: ignore[return-value]
|
||||
|
||||
@@ -48,9 +48,16 @@ class ValidationError(BaseModel):
|
||||
class DatasetContext(BaseModel):
|
||||
"""Dataset information for error context"""
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
id: int = Field(..., description="Dataset ID")
|
||||
table_name: str = Field(..., description="Table name")
|
||||
schema: str | None = Field(None, description="Schema name")
|
||||
schema_name: str | None = Field(
|
||||
None,
|
||||
alias="schema",
|
||||
serialization_alias="schema",
|
||||
description="Schema name",
|
||||
)
|
||||
database_name: str = Field(..., description="Database name")
|
||||
available_columns: List[Dict[str, Any]] = Field(
|
||||
default_factory=list, description="Available columns with metadata"
|
||||
|
||||
@@ -82,9 +82,26 @@ def run_server(
|
||||
factory_config = get_mcp_factory_config()
|
||||
mcp_instance = create_mcp_app(**factory_config)
|
||||
else:
|
||||
# Use default initialization
|
||||
# Use default initialization with auth from Flask config
|
||||
logging.info("Creating MCP app with default configuration...")
|
||||
mcp_instance = init_fastmcp_server()
|
||||
from superset.mcp_service.flask_singleton import get_flask_app
|
||||
|
||||
flask_app = get_flask_app()
|
||||
|
||||
# Get auth factory from config and create auth provider
|
||||
auth_provider = None
|
||||
auth_factory = flask_app.config.get("MCP_AUTH_FACTORY")
|
||||
if auth_factory:
|
||||
try:
|
||||
auth_provider = auth_factory(flask_app)
|
||||
logging.info(
|
||||
"Auth provider created: %s",
|
||||
type(auth_provider).__name__ if auth_provider else "None",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Failed to create auth provider: %s", e)
|
||||
|
||||
mcp_instance = init_fastmcp_server(auth=auth_provider)
|
||||
|
||||
env_key = f"FASTMCP_RUNNING_{port}"
|
||||
if not os.environ.get(env_key):
|
||||
|
||||
@@ -21,6 +21,7 @@ import datetime
|
||||
import logging
|
||||
import platform
|
||||
|
||||
from fastmcp import Context
|
||||
from flask import current_app
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
@@ -33,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
async def health_check() -> HealthCheckResponse:
|
||||
async def health_check(ctx: Context) -> HealthCheckResponse:
|
||||
"""
|
||||
Simple health check tool for testing the MCP service.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user