Files
superset2/superset/mcp_service/middleware.py

758 lines
26 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import time
from collections import defaultdict
from typing import Any, Awaitable, Callable, Dict, Protocol
from fastmcp.exceptions import ToolError
from fastmcp.server.middleware import Middleware, MiddlewareContext
from pydantic import ValidationError
from sqlalchemy.exc import OperationalError, TimeoutError
from starlette.exceptions import HTTPException
from superset.extensions import event_logger
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
def _sanitize_error_for_logging(error: Exception) -> str:
    """
    Return a log-safe description of ``error``.

    Database, permission and validation errors are collapsed into a fixed
    generic message. For every other error the string form is redacted:
    database connection URLs, API keys, tokens, absolute paths leading to a
    ``/superset/`` directory and IPv4 addresses are masked, and the result
    is capped at 500 characters.

    Args:
        error: The exception to describe.

    Returns:
        A sanitized, length-limited message suitable for logs.
    """
    # These types always map to a fixed message, so decide that before
    # doing any redaction work that would be thrown away.
    if isinstance(error, (OperationalError, TimeoutError)):
        return "Database operation failed"
    if isinstance(error, PermissionError):
        return "Access denied"
    if isinstance(error, ValidationError):
        return "Request validation failed"

    import re

    # Bound the amount of text scanned (keeps regex cost predictable), but
    # redact BEFORE the final 500-char cut so a secret that straddles the
    # cut cannot survive as an unredacted fragment.
    max_scan_chars = 2000
    max_output_chars = 500

    raw = str(error)
    was_cut = len(raw) > max_scan_chars
    error_str = raw[:max_scan_chars]

    # All quantifiers are bounded and every pattern is applied
    # case-insensitively. Order matters: URLs first, then key/token pairs,
    # then paths and IP addresses.
    redactions = (
        # postgres:// / postgresql:// with optional "+driver" and optional
        # "/database" tail.
        (
            r"\bpostgres(?:ql)?(?:\+[a-z0-9_]{1,30})?://"
            r"[^@\s]{1,100}@[^/\s]{1,100}(?:/[^\s]{0,100})?",
            "postgresql://[REDACTED]@[REDACTED]/[REDACTED]",
        ),
        # mysql:// with optional "+driver" and optional "/database" tail.
        (
            r"\bmysql(?:\+[a-z0-9_]{1,30})?://"
            r"[^@\s]{1,100}@[^/\s]{1,100}(?:/[^\s]{0,100})?",
            "mysql://[REDACTED]@[REDACTED]/[REDACTED]",
        ),
        # api_key / api-key / apikey followed by ":", "=" or whitespace and
        # an optionally quoted value.
        (
            r"api[_-]?key['\"]?[\s:=]{0,5}['\"]?[^\s'\"]{1,100}['\"]?",
            "ApiKey: [REDACTED]",
        ),
        # token followed by ":", "=" or whitespace and an optionally quoted
        # value.
        (
            r"token['\"]?[\s:=]{0,5}['\"]?[^\s'\"]{1,100}['\"]?",
            "Token: [REDACTED]",
        ),
        # Absolute path prefix in front of a "/superset/" directory.
        (r"/[a-zA-Z0-9_\-/.]{1,200}/superset/", "/[REDACTED]/superset/"),
        # IPv4 addresses: keep only the first octet.
        (r"\b(\d+)\.\d+\.\d+\.\d+\b", r"\1.xxx.xxx.xxx"),
    )
    for pattern, replacement in redactions:
        error_str = re.sub(pattern, replacement, error_str, flags=re.IGNORECASE)

    if was_cut or len(error_str) > max_output_chars:
        error_str = error_str[:max_output_chars] + "...[truncated]"
    return error_str
class LoggingMiddleware(Middleware):
    """
    Middleware that logs every MCP message using Superset's event logger
    (action ``mcp_tool_call``) and the module logger.

    Besides the tool name, method, agent and user, it records
    ``dashboard_id``, ``chart_id``/``slice_id`` and ``dataset_id`` when they
    are present in the message params.
    """

    async def on_message(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """
        Record the incoming message, then pass it down the middleware chain.

        Args:
            context: The middleware context for the current message.
            call_next: Continuation that invokes the next handler.

        Returns:
            Whatever ``call_next`` returns.
        """
        tool_name = getattr(context.message, "name", None)
        params = getattr(context.message, "params", {}) or {}

        # The agent id may come from the context metadata or, failing that,
        # from the session object.
        agent_id = None
        if hasattr(context, "metadata") and context.metadata:
            agent_id = context.metadata.get("agent_id")
        if not agent_id and hasattr(context, "session") and context.session:
            agent_id = getattr(context.session, "agent_id", None)

        try:
            user_id = get_user_id()
        except Exception:
            user_id = None

        dashboard_id = None
        slice_id = None
        dataset_id = None
        if isinstance(params, dict):
            dashboard_id = params.get("dashboard_id")
            # Chart ID may be under 'chart_id' or 'slice_id'
            slice_id = params.get("chart_id") or params.get("slice_id")
            dataset_id = params.get("dataset_id")

        # Audit logging is best-effort: a failing event logger must not
        # abort the request. This matches how the other middleware in this
        # module treat event_logger failures.
        try:
            event_logger.log(
                user_id=user_id,
                action="mcp_tool_call",
                dashboard_id=dashboard_id,
                duration_ms=None,
                slice_id=slice_id,
                referrer=None,
                curated_payload={
                    "tool": tool_name,
                    "agent_id": agent_id,
                    "params": params,
                    "method": context.method,
                    "dashboard_id": dashboard_id,
                    "slice_id": slice_id,
                    "dataset_id": dataset_id,
                },
            )
        except Exception as log_error:
            logger.warning("Failed to log MCP tool call event: %s", log_error)

        # Also log to the standard logger for debugging
        logger.info(
            "MCP tool call: tool=%s, agent_id=%s, user_id=%s, method=%s, "
            "dashboard_id=%s, slice_id=%s, dataset_id=%s",
            tool_name,
            agent_id,
            user_id,
            context.method,
            dashboard_id,
            slice_id,
            dataset_id,
        )
        return await call_next(context)
class PrivateToolMiddleware(Middleware):
    """Reject calls to any tool whose tags include ``private``."""

    async def on_call_tool(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """
        Look up the requested tool and refuse the call if it is private.

        Raises:
            ToolError: When the tool carries the ``private`` tag.
        """
        server = context.fastmcp_context.fastmcp
        requested_name = context.message.name
        tool = await server.get_tool(requested_name)

        # Tools without a ``tags`` attribute are treated as untagged.
        tool_tags = getattr(tool, "tags", set())
        if "private" not in tool_tags:
            return await call_next(context)

        raise ToolError(f"Access denied to private tool: {context.message.name}")
class GlobalErrorHandlerMiddleware(Middleware):
    """
    Global error handler middleware that provides consistent error responses
    and error logging for all MCP messages.

    Any exception raised further down the chain is logged (sanitized) and
    re-raised as a ``ToolError`` with a message chosen by exception type.
    """

    async def on_message(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """
        Run the rest of the chain and convert failures into ``ToolError``.

        Returns:
            The downstream result when no exception is raised.

        Raises:
            ToolError: For every exception raised downstream.
        """
        start_time = time.time()
        tool_name = getattr(context.message, "name", "unknown")
        try:
            return await call_next(context)
        except Exception as e:
            duration_ms = int((time.time() - start_time) * 1000)
            # _handle_error always raises; the ``return`` only keeps the
            # control flow explicit.
            return await self._handle_error(e, context, tool_name, duration_ms)

    async def _handle_error(  # noqa: C901
        self,
        error: Exception,
        context: MiddlewareContext,
        tool_name: str,
        duration_ms: int,
    ) -> None:
        """
        Log ``error`` and raise the matching ``ToolError``.

        Args:
            error: The exception raised by the downstream handler.
            context: The middleware context of the failed message.
            tool_name: Name of the tool (or ``"unknown"``).
            duration_ms: Time spent before the failure, in milliseconds.

        Raises:
            ToolError: Always. An incoming ``ToolError`` is re-raised as is;
                every other exception is wrapped and chained.
        """
        # Extract user context for logging
        try:
            user_id = get_user_id()
        except Exception:
            user_id = None  # User not authenticated

        # One sanitized message is used for every log sink below, so raw
        # exception text (which may embed credentials or paths) never
        # reaches the logs or the event table.
        sanitized_error = _sanitize_error_for_logging(error)
        error_type = type(error).__name__
        logger.error(
            "MCP tool error: tool=%s, user_id=%s, duration_ms=%s, "
            "error_type=%s, error=%s",
            tool_name,
            user_id,
            duration_ms,
            error_type,
            sanitized_error,
        )

        # Log to Superset's event system (best-effort)
        try:
            event_logger.log(
                user_id=user_id,
                action="mcp_tool_error",
                duration_ms=duration_ms,
                curated_payload={
                    "tool": tool_name,
                    "error_type": error_type,
                    "error_message": sanitized_error,
                    "method": context.method,
                },
            )
        except Exception as log_error:
            logger.warning("Failed to log error event: %s", log_error)

        # Handle specific error types with appropriate responses
        if isinstance(error, ToolError):
            # Tool errors are already formatted for MCP
            raise error
        elif isinstance(error, ValidationError):
            # Pydantic validation errors: list each failing field
            validation_details = []
            for err in error.errors():
                field = " -> ".join(str(loc) for loc in err["loc"])
                validation_details.append(f"{field}: {err['msg']}")
            raise ToolError(
                f"Validation error in {tool_name}: {'; '.join(validation_details)}"
            ) from error
        elif isinstance(error, (OperationalError, TimeoutError)):
            # Database errors: never expose driver details to the client
            raise ToolError(
                f"Database error in {tool_name}: Service temporarily unavailable. "
                f"Please try again in a few moments."
            ) from error
        elif isinstance(error, HTTPException):
            # HTTP errors from screenshot endpoints or API calls
            raise ToolError(f"Service error in {tool_name}: {error.detail}") from error
        elif isinstance(error, PermissionError):
            # Permission/authorization errors
            raise ToolError(
                f"Permission denied for {tool_name}: "
                f"You don't have access to this resource."
            ) from error
        elif isinstance(error, FileNotFoundError):
            # File/resource not found errors
            raise ToolError(
                f"Resource not found in {tool_name}: {str(error)}"
            ) from error
        elif isinstance(error, ValueError):
            # Value/parameter errors
            raise ToolError(
                f"Invalid parameter in {tool_name}: {str(error)}"
            ) from error
        else:
            # Generic internal errors: give the client an opaque id that
            # can be matched against the log line below.
            error_id = f"err_{int(time.time())}"
            logger.error(
                "Unexpected error [%s] in %s: %s",
                error_id,
                tool_name,
                sanitized_error,
            )
            raise ToolError(
                f"Internal error in {tool_name}: An unexpected error occurred. "
                f"Error ID: {error_id}. Please contact support if this persists."
            ) from error
class RateLimiterProtocol(Protocol):
    """Structural interface shared by the rate limiter backends."""

    def is_rate_limited(
        self, key: str, limit: int, window: int = 60
    ) -> tuple[bool, dict[str, Any]]:
        """
        Report whether ``key`` has exhausted ``limit`` within ``window`` seconds.

        Returns:
            A ``(limited, info)`` pair, where ``info`` describes the limit state.
        """
        ...

    def cleanup(self) -> None:
        """Discard stale bookkeeping, if the implementation keeps any."""
        ...
class InMemoryRateLimiter:
    """
    In-memory sliding-window rate limiter for development.

    State lives in this process only. Entries older than one hour are
    discarded, so windows longer than 3600 seconds are not supported.
    """

    def __init__(self) -> None:
        # Structure: {key: [(timestamp, count), ...]}, oldest entry first
        self._requests: Dict[str, list[tuple[float, int]]] = defaultdict(list)
        self._cleanup_interval = 300  # Clean up every 5 minutes
        self._last_cleanup = time.time()

    def is_rate_limited(
        self, key: str, limit: int, window: int = 60
    ) -> tuple[bool, dict[str, Any]]:
        """
        Check whether ``key`` may make another request, and record it if so.

        Args:
            key: Identifier of the caller/tool bucket.
            limit: Maximum number of requests allowed per window.
            window: Window length in seconds.

        Returns:
            ``(limited, info)``. ``info`` holds ``limit``, ``remaining``,
            ``reset_time`` (epoch seconds at which the oldest request in the
            window expires) and ``window_seconds``.
        """
        current_time = time.time()
        window_start = current_time - window

        # Requests that still count against the limit
        requests_in_window = [
            (timestamp, count)
            for timestamp, count in self._requests[key]
            if timestamp > window_start
        ]
        total_requests = sum(count for _, count in requests_in_window)

        # Check the limit BEFORE recording the current request
        if total_requests >= limit:
            return True, {
                "limit": limit,
                "remaining": 0,
                "reset_time": self._reset_time(
                    requests_in_window, current_time, window
                ),
                "window_seconds": window,
            }

        # Record the current request
        requests_in_window.append((current_time, 1))
        self._requests[key].append((current_time, 1))
        total_requests += 1

        # Keep only the last hour of entries for this key
        retention_start = current_time - 3600
        self._requests[key] = [
            (timestamp, count)
            for timestamp, count in self._requests[key]
            if timestamp > retention_start
        ]

        return False, {
            "limit": limit,
            "remaining": max(0, limit - total_requests),
            "reset_time": self._reset_time(requests_in_window, current_time, window),
            "window_seconds": window,
        }

    @staticmethod
    def _reset_time(
        requests_in_window: list[tuple[float, int]],
        current_time: float,
        window: int,
    ) -> int:
        """Return the epoch second at which the oldest in-window entry expires."""
        # With no entries in the window, fall back to a full window from now.
        oldest = min(
            (timestamp for timestamp, _ in requests_in_window),
            default=current_time,
        )
        return int(oldest + window)

    def cleanup(self) -> None:
        """
        Remove entries older than 1 hour to prevent memory leaks.

        Runs at most once per cleanup interval, unless the total number of
        tracked entries exceeds the size threshold. If pruning by age still
        leaves too many entries, each key is trimmed to its 100 most recent
        entries.
        """
        current_time = time.time()
        size_threshold = 10000  # Maximum entries before forced cleanup
        total_entries = sum(len(requests) for requests in self._requests.values())

        time_based_cleanup = current_time - self._last_cleanup >= self._cleanup_interval
        size_based_cleanup = total_entries > size_threshold
        if not (time_based_cleanup or size_based_cleanup):
            return

        cutoff_time = current_time - 3600  # 1 hour ago
        # Iterate over a snapshot so keys can be removed while pruning.
        for key, requests in list(self._requests.items()):
            recent = [
                (timestamp, count)
                for timestamp, count in requests
                if timestamp > cutoff_time
            ]
            if recent:
                self._requests[key] = recent
            else:
                del self._requests[key]

        # Re-count after pruning: only trim further if the age-based pass
        # did not bring the total back under the threshold.
        remaining_entries = sum(len(requests) for requests in self._requests.values())
        if remaining_entries > size_threshold:
            logger.warning(
                "Rate limiter memory high (%d entries), performing aggressive cleanup",
                remaining_entries,
            )
            # Keep only the most recent entries per key
            for key, requests in list(self._requests.items()):
                if len(requests) > 100:  # Keep max 100 entries per key
                    self._requests[key] = requests[-100:]

        self._last_cleanup = current_time
class RedisRateLimiter:
    """
    Cache-backed rate limiter for production.

    Keeps one counter per key in Superset's cache (``cache_manager.cache``).
    The counter is written with a timeout equal to the window, so it
    expires on its own once requests stop. The read-then-write update is
    not atomic; concurrent requests may be slightly under-counted.
    """

    def __init__(self) -> None:
        from superset.extensions import cache_manager

        self._cache = cache_manager.cache
        self._prefix = "mcp:ratelimit:"

    def is_rate_limited(
        self, key: str, limit: int, window: int = 60
    ) -> tuple[bool, dict[str, Any]]:
        """
        Check whether ``key`` may make another request, and count it if so.

        Args:
            key: Identifier of the caller/tool bucket.
            limit: Maximum number of requests allowed per window.
            window: Counter lifetime in seconds.

        Returns:
            ``(limited, info)``. ``info`` holds ``limit``, ``remaining``
            (after counting this request), ``reset_time`` and
            ``window_seconds``. If the cache raises, the request is allowed.
        """
        current_time = time.time()
        count_key = "%s%s:count" % (self._prefix, key)

        try:
            request_count = self._cache.get(count_key) or 0

            if request_count >= limit:
                return True, {
                    "limit": limit,
                    "remaining": 0,
                    # Upper bound: the counter expires at most one window
                    # after the last counted request.
                    "reset_time": int(current_time + window),
                    "window_seconds": window,
                }

            # Count this request and refresh the counter's timeout
            new_count = request_count + 1
            self._cache.set(count_key, new_count, timeout=window)

            return False, {
                "limit": limit,
                "remaining": max(0, limit - new_count),
                "reset_time": int(current_time + window),
                "window_seconds": window,
            }
        except Exception as e:
            # Fail open: a cache outage must not block every tool call.
            logger.warning("Redis rate limiter error: %s, allowing request", e)
            return False, {
                "limit": limit,
                "remaining": limit,
                "reset_time": 0,
                "window_seconds": window,
            }

    def cleanup(self) -> None:
        """No cleanup needed - the counter timeout handles expiration."""
def create_rate_limiter() -> RateLimiterProtocol:
    """
    Pick a rate limiter backend for the current environment.

    A write/read probe is run against Superset's cache. If the probe value
    can be read back, the cache-backed limiter is returned; otherwise, or
    if anything raises, the in-memory limiter is used.
    """
    try:
        from superset.extensions import cache_manager

        backend = cache_manager.cache if cache_manager else None
        if backend:
            # Round-trip a short-lived key to confirm the cache works
            probe_key = "mcp:ratelimit:test"
            backend.set(probe_key, 1, timeout=1)
            if backend.get(probe_key):
                backend.delete(probe_key)
                logger.info("Using Redis for rate limiting")
                return RedisRateLimiter()
    except Exception as exc:
        logger.warning(
            "Redis not available for rate limiting: %s, falling back to in-memory",
            exc,
        )

    # Development fallback
    logger.info("Using in-memory rate limiter")
    return InMemoryRateLimiter()
class RateLimitMiddleware(Middleware):
    """
    Rate limiting middleware to prevent abuse of MCP tools.

    Limits are tracked per caller and per tool. The caller is the
    authenticated user when available, otherwise the agent id, otherwise a
    shared ``anonymous`` bucket. The applicable limit is chosen as follows:

    - tools listed in ``expensive_tools`` use the expensive-tool limit
    - otherwise authenticated users use the per-user limit
    - otherwise the default limit applies

    Configuration:
    - default_requests_per_minute: Default rate limit (60 requests/minute)
    - per_user_requests_per_minute: Rate limit per authenticated user (120/min)
    - expensive_tool_requests_per_minute: Rate limit for expensive tools (10/min)
    - expensive_tools: Tool names treated as expensive; ``None`` selects the
      built-in list, an empty list disables the expensive tier
    """

    def __init__(
        self,
        default_requests_per_minute: int = 60,
        per_user_requests_per_minute: int = 120,
        expensive_tool_requests_per_minute: int = 10,
        expensive_tools: list[str] | None = None,
    ) -> None:
        self.default_rpm = default_requests_per_minute
        self.user_rpm = per_user_requests_per_minute
        self.expensive_rpm = expensive_tool_requests_per_minute
        # Only fall back to the defaults when nothing was passed; an
        # explicit empty list means "no expensive tools".
        if expensive_tools is None:
            expensive_tools = [
                "get_chart_preview",
                "generate_chart",
                "generate_dashboard",
                "get_chart_data",
            ]
        self.expensive_tools = set(expensive_tools)
        # Cache-backed limiter when available, in-memory otherwise
        self._rate_limiter = create_rate_limiter()

    @staticmethod
    def _current_user_id() -> Any:
        """Return the current user id, or ``None`` when it cannot be determined."""
        try:
            return get_user_id()
        except Exception:
            return None  # User not authenticated

    def _get_rate_limit_key(self, context: MiddlewareContext) -> tuple[str, int]:
        """
        Generate rate limit key and determine applicable limit.

        Returns:
            Tuple of (key, requests_per_minute_limit)
        """
        tool_name = getattr(context.message, "name", "unknown")
        user_id = self._current_user_id()

        # Determine rate limit tier
        if tool_name in self.expensive_tools:
            limit = self.expensive_rpm
            key_prefix = "expensive"
        elif user_id:
            limit = self.user_rpm
            key_prefix = "user"
        else:
            limit = self.default_rpm
            key_prefix = "default"

        if user_id:
            return f"{key_prefix}:user:{user_id}:{tool_name}", limit

        # No user: fall back to the agent id from metadata or session
        agent_id = None
        if hasattr(context, "metadata") and context.metadata:
            agent_id = context.metadata.get("agent_id")
        if not agent_id and hasattr(context, "session") and context.session:
            agent_id = getattr(context.session, "agent_id", None)

        if agent_id:
            return f"{key_prefix}:agent:{agent_id}:{tool_name}", limit
        return f"{key_prefix}:anonymous:{tool_name}", limit

    async def on_call_tool(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """
        Check rate limits before allowing tool calls.

        Raises:
            ToolError: When the caller has exceeded the applicable limit.
        """
        # Clean up old entries periodically (only needed for in-memory)
        self._rate_limiter.cleanup()

        key, limit = self._get_rate_limit_key(context)
        is_limited, rate_info = self._rate_limiter.is_rate_limited(key, limit)
        tool_name = getattr(context.message, "name", "unknown")

        if is_limited:
            # Compute the retry delay once and never report a negative wait
            retry_in = max(0, rate_info["reset_time"] - int(time.time()))

            # Log rate limit event (best-effort)
            try:
                event_logger.log(
                    user_id=self._current_user_id(),
                    action="mcp_rate_limit_exceeded",
                    curated_payload={
                        "tool": tool_name,
                        "rate_limit_key": key,
                        "limit": limit,
                        "window_seconds": rate_info.get("window_seconds", 60),
                    },
                )
            except Exception as log_error:
                logger.warning("Failed to log rate limit event: %s", log_error)

            logger.warning(
                "Rate limit exceeded for %s: key=%s, limit=%s/min, reset_in=%ss",
                tool_name,
                key,
                limit,
                retry_in,
            )
            raise ToolError(
                "Rate limit exceeded for %s. "
                "Limit: %s requests per minute. "
                "Try again in %s seconds." % (tool_name, limit, retry_in)
            )

        # Log rate limit info for monitoring
        logger.debug(
            "Rate limit check: %s: key=%s, remaining=%s/%s",
            tool_name,
            key,
            rate_info["remaining"],
            limit,
        )
        return await call_next(context)
class FieldPermissionsMiddleware(Middleware):
    """
    Post-processes tool responses, passing them through
    ``filter_sensitive_data`` according to the object type the tool deals
    with and the current user.
    """

    # Tool name -> object type handed to the permission filter
    TOOL_OBJECT_TYPE_MAP = {
        "list_datasets": "dataset",
        "get_dataset_info": "dataset",
        "list_charts": "chart",
        "get_chart_info": "chart",
        "get_chart_data": "chart",
        "get_chart_preview": "chart",
        "update_chart": "chart",
        "generate_chart": "chart",
        "list_dashboards": "dashboard",
        "get_dashboard_info": "dashboard",
        "generate_dashboard": "dashboard",
        "add_chart_to_existing_dashboard": "dashboard",
    }

    async def on_call_tool(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """
        Run the tool, then filter its response when the tool is mapped.

        Tools missing from ``TOOL_OBJECT_TYPE_MAP`` are returned untouched.
        If filtering raises, the error is logged and the unfiltered
        response is returned.
        """
        raw_response = await call_next(context)

        name = getattr(context.message, "name", "unknown")
        mapped_type = self.TOOL_OBJECT_TYPE_MAP.get(name)
        if not mapped_type:
            # Tool is not subject to field filtering
            return raw_response

        # Resolve the user; filtering proceeds with None if that fails
        try:
            current_user = self._get_current_user()
        except Exception as e:
            logger.warning("Could not get current user for field filtering: %s", e)
            current_user = None

        try:
            result = self._filter_response(raw_response, mapped_type, current_user)
            # Trace filtering activity for monitoring
            logger.debug(
                "Applied field-level permissions for %s (object_type=%s, user=%s)",
                name,
                mapped_type,
                getattr(current_user, "username", "anonymous"),
            )
            return result
        except Exception as e:
            logger.error("Error applying field permissions to %s: %s", name, e)
            # Fall back to the original response
            return raw_response

    def _get_current_user(self) -> Any:
        """
        Return the current user, or ``None`` when it cannot be resolved.

        ``flask.g.user`` is tried first. If that raises, the user is looked
        up in the database by the id from ``get_user_id()``.
        """
        try:
            from flask import g

            return getattr(g, "user", None)
        except Exception:
            pass

        # Fallback: resolve the user by id
        try:
            uid = get_user_id()
            if not uid:
                return None

            from flask_appbuilder.security.sqla.models import User

            from superset.extensions import db

            return db.session.query(User).filter_by(id=uid).first()
        except Exception as e:
            logger.debug("Could not get user from session: %s", e)
            return None

    def _filter_response(self, response: Any, object_type: str, user: Any) -> Any:
        """
        Filter response data based on object type and user permissions.

        Args:
            response: The response object to filter
            object_type: Type of object ('dataset', 'chart', 'dashboard')
            user: User object for permission checking

        Returns:
            The filtered data. Objects exposing ``model_dump`` come back as
            filtered dicts, lists are filtered item by item, and falsy or
            unrecognised responses are returned unchanged.
        """
        from superset.mcp_service.utils.permissions_utils import filter_sensitive_data

        if not response:
            return response

        if hasattr(response, "model_dump"):
            # Model-like object: dump to a dict and filter that
            dumped = response.model_dump()
            return filter_sensitive_data(dumped, object_type, user)

        if isinstance(response, dict):
            return filter_sensitive_data(response, object_type, user)

        if isinstance(response, list):
            filtered_items = []
            for entry in response:
                filtered_items.append(filter_sensitive_data(entry, object_type, user))
            return filtered_items

        # Anything else is passed through untouched
        logger.debug("Unknown response type for field filtering: %s", type(response))
        return response