# superset2/superset/mcp_service/utils/retry_utils.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Retry utilities for handling transient failures in MCP service operations.
"""
import asyncio
import functools
import logging
import secrets
import time
from typing import Any, Callable, Type, TypeVar
from sqlalchemy.exc import OperationalError, TimeoutError
from starlette.exceptions import HTTPException
logger = logging.getLogger(__name__)
T = TypeVar("T")
# Default retryable exceptions
DEFAULT_RETRYABLE_EXCEPTIONS = (
OperationalError, # Database connection issues
TimeoutError, # Database timeouts
ConnectionError, # Network issues
OSError, # File system issues (for screenshots)
)
def exponential_backoff(
    attempt: int, base_delay: float = 1.0, max_delay: float = 60.0, jitter: bool = True
) -> float:
    """
    Compute the delay before the next retry using exponential backoff.

    The delay doubles with each attempt (``base_delay * 2**attempt``), capped
    at ``max_delay``. With ``jitter`` enabled, the result is perturbed by up
    to +/-25% so that many clients retrying at once do not synchronize
    (thundering-herd avoidance).

    Args:
        attempt: Current attempt number (0-based)
        base_delay: Base delay in seconds
        max_delay: Maximum delay in seconds
        jitter: Whether to add random jitter to avoid thundering herd

    Returns:
        Delay in seconds (never negative)
    """
    capped = min(base_delay * (2**attempt), max_delay)
    if not jitter:
        return max(0, capped)
    # SystemRandom is a cryptographically secure source; jitter is uniform
    # over [-25%, +25%] of the capped delay.
    spread = capped * 0.25
    jittered = capped + secrets.SystemRandom().uniform(-spread, spread)
    return max(0, jittered)
def retry_on_exception(
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exceptions: tuple[Type[Exception], ...] = DEFAULT_RETRYABLE_EXCEPTIONS,
    jitter: bool = True,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Decorator that retries a synchronous callable on selected exceptions.

    Delays between attempts grow exponentially (see ``exponential_backoff``).
    Exceptions outside ``exceptions`` propagate immediately without retrying.

    Args:
        max_attempts: Maximum number of attempts (including initial attempt)
        base_delay: Base delay in seconds for exponential backoff
        max_delay: Maximum delay in seconds between retries
        exceptions: Tuple of exception types to retry on
        jitter: Whether to add random jitter to backoff delays

    Returns:
        Decorated function
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            failure: Exception | None = None
            # Attempts are counted 1-based for log messages.
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as exc:  # noqa: PERF203
                    failure = exc
                    if attempt >= max_attempts:
                        break  # out of retries; re-raise below without waiting
                    wait = exponential_backoff(
                        attempt - 1, base_delay, max_delay, jitter
                    )
                    logger.warning(
                        "Attempt %s/%s failed for %s: %s. Retrying in %.2fs...",
                        attempt,
                        max_attempts,
                        func.__name__,
                        exc,
                        wait,
                    )
                    time.sleep(wait)
                except Exception as exc:
                    # Not in the retryable set: propagate immediately.
                    logger.error("Non-retryable exception in %s: %s", func.__name__, exc)
                    raise

            if failure is not None:
                logger.error(
                    "All %s attempts failed for %s: %s",
                    max_attempts,
                    func.__name__,
                    failure,
                )
                raise failure
            # Only reachable when max_attempts <= 0, so the loop never ran.
            logger.error("All %s attempts failed for %s", max_attempts, func.__name__)
            raise RuntimeError(
                f"All {max_attempts} attempts failed for {func.__name__}"
            )

        return wrapper

    return decorator
def async_retry_on_exception(
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exceptions: tuple[Type[Exception], ...] = DEFAULT_RETRYABLE_EXCEPTIONS,
    jitter: bool = True,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Async version of retry_on_exception decorator.

    Same semantics as ``retry_on_exception`` except the wrapped callable is
    awaited and backoff uses ``asyncio.sleep`` so the event loop is not
    blocked between attempts.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            failure: Exception | None = None
            # Attempts are counted 1-based for log messages.
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except exceptions as exc:  # noqa: PERF203
                    failure = exc
                    if attempt >= max_attempts:
                        break  # out of retries; re-raise below without waiting
                    wait = exponential_backoff(
                        attempt - 1, base_delay, max_delay, jitter
                    )
                    logger.warning(
                        "Attempt %s/%s failed for %s: %s. Retrying in %.2fs...",
                        attempt,
                        max_attempts,
                        func.__name__,
                        exc,
                        wait,
                    )
                    await asyncio.sleep(wait)
                except Exception as exc:
                    # Not in the retryable set: propagate immediately.
                    logger.error("Non-retryable exception in %s: %s", func.__name__, exc)
                    raise

            if failure is not None:
                logger.error(
                    "All %s attempts failed for %s: %s",
                    max_attempts,
                    func.__name__,
                    failure,
                )
                raise failure
            # Only reachable when max_attempts <= 0, so the loop never ran.
            logger.error("All %s attempts failed for %s", max_attempts, func.__name__)
            raise RuntimeError(
                f"All {max_attempts} attempts failed for {func.__name__}"
            )

        return wrapper

    return decorator
class RetryableOperation:
    """
    Context manager that tracks state for a caller-driven retry loop.

    Unlike the decorators in this module, this class never re-executes the
    protected code itself: ``__exit__`` only decides whether a raised
    exception should be swallowed (returning True) after sleeping a backoff
    delay. A plain ``with`` block runs its body once, so callers are
    presumably expected to re-enter the block themselves, e.g. inside a
    ``while op.should_retry():`` loop.
    NOTE(review): usage pattern inferred from the API shape — confirm
    against actual call sites.
    """

    def __init__(
        self,
        operation_name: str,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exceptions: tuple[Type[Exception], ...] = DEFAULT_RETRYABLE_EXCEPTIONS,
        jitter: bool = True,
    ) -> None:
        # Human-readable name, used only in log messages.
        self.operation_name = operation_name
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        # Exception types eligible for retry; everything else escapes.
        self.exceptions = exceptions
        self.jitter = jitter
        # Number of failed attempts observed so far (incremented in __exit__).
        self.current_attempt = 0
        # Most recent retryable exception, kept for caller inspection.
        self.last_exception: Exception | None = None

    def __enter__(self) -> "RetryableOperation":
        return self

    def __exit__(
        self,
        exc_type: Type[Exception] | None,
        exc_val: Exception | None,
        exc_tb: Any,
    ) -> bool:
        """Return True to suppress a retryable exception (retry allowed)."""
        if exc_type is None:
            # No exception, operation succeeded
            return False
        if not issubclass(exc_type, self.exceptions):
            # Non-retryable exception
            logger.error(
                "Non-retryable exception in %s: %s", self.operation_name, exc_val
            )
            return False
        self.last_exception = exc_val
        self.current_attempt += 1
        if self.current_attempt >= self.max_attempts:
            # Max attempts reached
            logger.error(
                "All %s attempts failed for %s. ",
                self.max_attempts,
                self.operation_name,
            )
            return False
        # Calculate delay and wait
        delay = exponential_backoff(
            self.current_attempt - 1, self.base_delay, self.max_delay, self.jitter
        )
        logger.warning(
            "Attempt %s/%s failed for %s: %s. Retrying in %.2fs...",
            self.current_attempt,
            self.max_attempts,
            self.operation_name,
            exc_val,
            delay,
        )
        # Blocking sleep; do not use this class inside an event loop.
        time.sleep(delay)
        return True  # Suppress the exception and continue

    def should_retry(self) -> bool:
        """Check if we should continue retrying"""
        return self.current_attempt < self.max_attempts
# Convenience functions for common operations
def retry_database_operation(
    func: Callable[..., T], *args: Any, max_attempts: int = 3, **kwargs: Any
) -> T:
    """
    Call ``func(*args, **kwargs)`` with retry semantics tuned for databases.

    Retries only on SQLAlchemy ``OperationalError`` / ``TimeoutError``, with
    a short initial backoff (0.5s) capped at 30s.
    """

    # Keep the inner name "_wrapped": it appears in the retry log messages.
    def _wrapped() -> T:
        return func(*args, **kwargs)

    retrying = retry_on_exception(
        max_attempts=max_attempts,
        exceptions=(OperationalError, TimeoutError),
        base_delay=0.5,
        max_delay=30.0,
    )
    return retrying(_wrapped)()
async def async_retry_database_operation(
    func: Callable[..., Any], *args: Any, max_attempts: int = 3, **kwargs: Any
) -> Any:
    """
    Await ``func(*args, **kwargs)`` with retry semantics tuned for databases.

    Retries only on SQLAlchemy ``OperationalError`` / ``TimeoutError``, with
    a short initial backoff (0.5s) capped at 30s, sleeping via asyncio.
    """

    # Keep the inner name "_wrapped": it appears in the retry log messages.
    async def _wrapped() -> Any:
        return await func(*args, **kwargs)

    retrying = async_retry_on_exception(
        max_attempts=max_attempts,
        exceptions=(OperationalError, TimeoutError),
        base_delay=0.5,
        max_delay=30.0,
    )
    return await retrying(_wrapped)()
def retry_screenshot_operation(
    func: Callable[..., T],
    *args: Any,
    max_attempts: int = 2,  # Screenshots are expensive, fewer retries
    **kwargs: Any,
) -> T:
    """
    Call ``func(*args, **kwargs)`` with retry behaviour tuned for screenshots.

    Retries on ``OSError``, ``ConnectionError`` and ``HTTPException`` with a
    longer initial backoff (2s, capped at 30s) and only one retry by default,
    since screenshot rendering is costly.
    """

    # Keep the inner name "_wrapped": it appears in the retry log messages.
    def _wrapped() -> T:
        return func(*args, **kwargs)

    retrying = retry_on_exception(
        max_attempts=max_attempts,
        exceptions=(OSError, ConnectionError, HTTPException),
        base_delay=2.0,  # Longer initial delay for screenshots
        max_delay=30.0,
    )
    return retrying(_wrapped)()